From 586b9afa4787953242638601174dacd870d3a91b Mon Sep 17 00:00:00 2001
From: Boian Petkantchin
Date: Wed, 9 Oct 2024 09:39:45 -0400
Subject: [PATCH] Enforce static dimensions in generation of
 flow.tensor.transfer (#205)

This solves the problem in https://github.com/iree-org/iree/issues/18283.

The issue is that we generate casts to/from dynamic tensors that later
lowering in IREE chokes on. My assumption is that IREE should be able to
digest this IR, since it is of the form:

```mlir
%2 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[2,3,11,13],f32> -> tensor<2x3x11x13xf32>
%cast = tensor.cast %2 : tensor<2x3x11x13xf32> to tensor<?x?x?x?xf32>
%c0 = arith.constant 0 : index
%dim = tensor.dim %cast, %c0 : tensor<?x?x?x?xf32>
%c1 = arith.constant 1 : index
%dim_0 = tensor.dim %cast, %c1 : tensor<?x?x?x?xf32>
%c2 = arith.constant 2 : index
%dim_1 = tensor.dim %cast, %c2 : tensor<?x?x?x?xf32>
%c3 = arith.constant 3 : index
%dim_2 = tensor.dim %cast, %c3 : tensor<?x?x?x?xf32>
%3 = flow.tensor.transfer %cast : tensor<?x?x?x?xf32>{%dim, %dim_0, %dim_1, %dim_2} to #hal.device.promise<@__device_0>
%cast_3 = tensor.cast %3 : tensor<?x?x?x?xf32> to tensor<2x3x11x13xf32>
%4 = torch_c.from_builtin_tensor %cast_3 : tensor<2x3x11x13xf32> -> !torch.vtensor<[2,3,11,13],f32>
```

It essentially casts to a dynamic `tensor<...>` for the purpose of
performing `flow.tensor.transfer` and then casts back to a static
`torch.vtensor`, so it should be fine.

With this change we get:

```mlir
%2 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[2,3,11,13],f32> -> tensor<2x3x11x13xf32>
%3 = flow.tensor.transfer %2 : tensor<2x3x11x13xf32> to #hal.device.promise<@__device_0>
%4 = torch_c.from_builtin_tensor %3 : tensor<2x3x11x13xf32> -> !torch.vtensor<[2,3,11,13],f32>
```

Signed-off-by: Boian Petkantchin
---
 iree/turbine/ops/iree.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/iree/turbine/ops/iree.py b/iree/turbine/ops/iree.py
index 1609db2b..b4d79aee 100644
--- a/iree/turbine/ops/iree.py
+++ b/iree/turbine/ops/iree.py
@@ -83,7 +83,8 @@ class transfer_to_logical_device(CustomOp):
     def select(self, ksel: KernelSelection):
         ksel.attr_str(0)
         ta = ksel.arg_tensor(1)
-        ksel.return_tensor(ta.t)
+        ta.specialize_all_dims()
+        ksel.return_tensor(ta.t).specialize_all_dims()

     def eager_execute(self, device_moniker, tensor):
         return tensor
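
For reference, a minimal usage sketch (not part of the patch) of the op whose
kernel selection changes above. The call shape is inferred from the
`select()`/`eager_execute()` signatures in `iree/turbine/ops/iree.py`: argument
0 is a device-moniker string and argument 1 is the tensor. The moniker
`"__device_0"` mirrors the `#hal.device.promise<@__device_0>` in the IR above
and is purely illustrative.

```python
import torch

from iree.turbine.ops.iree import transfer_to_logical_device

# A statically-shaped tensor matching the IR example above.
t = torch.randn(2, 3, 11, 13)

# With this patch, tracing/exporting a program containing this call emits
# flow.tensor.transfer directly on tensor<2x3x11x13xf32> instead of
# round-tripping through tensor<?x?x?x?xf32> casts. In eager mode the op
# simply returns the input tensor (see eager_execute above).
result = transfer_to_logical_device("__device_0", t)
```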