diff --git a/compiler/plugins/input/Torch/InputConversion/BitCastQuantTensor.cpp b/compiler/plugins/input/Torch/InputConversion/BitCastQuantTensor.cpp
index 6a3dccdb97f9..64a99f4782ad 100644
--- a/compiler/plugins/input/Torch/InputConversion/BitCastQuantTensor.cpp
+++ b/compiler/plugins/input/Torch/InputConversion/BitCastQuantTensor.cpp
@@ -6,6 +6,7 @@
 
 #include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
 #include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
@@ -20,6 +21,59 @@ namespace mlir::iree_compiler::TorchInput {
 
 namespace {
 
+// Rewrites `torch.aten.view.dtype` into a `flow.tensor.bitcast` between the
+// corresponding builtin tensor types, converting signed/unsigned integer
+// element types to signless integers of the same bit width.
+class BitCastViewDtype
+    : public OpRewritePattern<torch::Torch::AtenViewDtypeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(torch::Torch::AtenViewDtypeOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Value in = op.getSelf();
+    auto loc = op.getLoc();
+    auto inType = cast<torch::Torch::ValueTensorType>(in.getType());
+    auto resultType = cast<torch::Torch::ValueTensorType>(op.getType());
+
+    // `toBuiltinTensor` requires a known dtype, and `flow.tensor.bitcast`
+    // needs explicit SSA values for any dynamic dimensions (none are
+    // provided below), so restrict the rewrite to fully static cases.
+    if (!inType.hasDtype() || !inType.areAllSizesKnown() ||
+        !resultType.hasDtype() || !resultType.areAllSizesKnown())
+      return rewriter.notifyMatchFailure(op, "expected static shape/dtype");
+
+    auto bType = inType.toBuiltinTensor();
+
+    // Builtin tensors use signless integers; keep the width, drop the sign.
+    if (auto dtype = dyn_cast<IntegerType>(bType.getElementType())) {
+      bType = bType.clone(
+          rewriter.getType<IntegerType>(dtype.getIntOrFloatBitWidth()));
+    }
+
+    // Cast to the builtin tensor type.
+    Value builtinCast =
+        rewriter.create<torch::TorchConversion::ToBuiltinTensorOp>(loc, bType,
+                                                                   in);
+
+    auto rType = resultType.toBuiltinTensor();
+    if (auto dtype = dyn_cast<IntegerType>(rType.getElementType())) {
+      rType = rType.clone(
+          rewriter.getType<IntegerType>(dtype.getIntOrFloatBitWidth()));
+    }
+
+    // Reinterpret the underlying bits without moving any data.
+    Value flowBitcast = rewriter.create<IREE::Flow::TensorBitCastOp>(
+        loc, rType, builtinCast, ValueRange(), ValueRange());
+
+    auto torchCast =
+        rewriter.create<torch::TorchConversion::FromBuiltinTensorOp>(
+            loc, resultType, flowBitcast);
+    rewriter.replaceOp(op, torchCast);
+    return success();
+  }
+};
+
 class BitCastQuantizedMatmul
     : public OpRewritePattern<torch::Torch::OperatorOp> {
 public:
@@ -117,7 +171,7 @@ class BitCastQuantTensorPass final
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     RewritePatternSet patterns(context);
-    patterns.add<BitCastQuantizedMatmul>(context);
+    patterns.add<BitCastQuantizedMatmul, BitCastViewDtype>(context);
     if (failed(
             applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
       signalPassFailure();
diff --git a/compiler/plugins/input/Torch/InputConversion/test/bitcast_quant_tensor.mlir b/compiler/plugins/input/Torch/InputConversion/test/bitcast_quant_tensor.mlir
index fad4e7c9194a..95465956a1d2 100644
--- a/compiler/plugins/input/Torch/InputConversion/test/bitcast_quant_tensor.mlir
+++ b/compiler/plugins/input/Torch/InputConversion/test/bitcast_quant_tensor.mlir
@@ -14,3 +14,13 @@ func.func @forward(%arg0: !torch.vtensor<[1,1,8],f16>) -> !torch.vtensor<[1,1,8]
   %output = torch.operator "quant.matmul_rhs_group_quant"(%arg0, %q_rhs, %scales, %zps, %bit_width, %group_size) : (!torch.vtensor<[1,1,8],f16>, !torch.vtensor<[8,4],ui8>, !torch.vtensor<[8,4,1],f16>, !torch.vtensor<[8,4,1],f16>, !torch.int, !torch.int) -> !torch.vtensor<[1,1,8],f16>
   return %output : !torch.vtensor<[1,1,8],f16>
 }
+
+// -----
+
+// CHECK-LABEL: @view_type
+func.func @view_type(%arg0 : !torch.vtensor<[295501824],ui8>) -> !torch.vtensor<[147750912],si16> {
+    %int2 = torch.constant.int 2  // torch dtype code 2 == int16, matching the si16 result.
+    // CHECK: flow.tensor.bitcast %[[IN:.+]] : tensor<295501824xi8> -> tensor<147750912xi16>
+    %0 = torch.aten.view.dtype %arg0, %int2 : !torch.vtensor<[295501824],ui8>, !torch.int -> !torch.vtensor<[147750912],si16>
+    return %0 : !torch.vtensor<[147750912],si16>
+}