From f4562a8eaa3d13e17768f50164ede38383f7983e Mon Sep 17 00:00:00 2001 From: Gaurav Shukla Date: Mon, 5 Feb 2024 23:46:58 +0530 Subject: [PATCH] [ONNX] Fix the lowering of onnx.expand op (#2861) Signed-off-by: Gaurav Shukla --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 1 - .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 36 +++++++++---------- 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 1161b981c09f..05a1e5fcb15d 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -1387,7 +1387,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( Torch::BaseTensorType shapeType = shape.getType().cast(); SmallVector selectSizes; - selectSizes.push_back(1); Type selectResultType = shapeType.getWithSizesAndDtype( llvm::ArrayRef(selectSizes), shapeType.getOptionalDtype()); // Variable to store 1-D onnx shape tensor, shapeSizes[0] has the diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 797f9b6c2054..e757e3776d1b 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -1040,11 +1040,11 @@ func.func @test_expand_dim2_shape2(%arg0: !torch.vtensor<[1,4],f32>, %arg1: !tor -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.+]] = torch.constant.int 0 // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[1],si32> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si32> -> !torch.int + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32> + // CHECK: torch.aten.item %0 : !torch.vtensor<[],si32> -> !torch.int // CHECK: %[[INT1:.+]] = torch.constant.int 1 - // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[1],si32> - // CHECK: torch.aten.item %2 : !torch.vtensor<[1],si32> -> !torch.int + // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32> + // CHECK: torch.aten.item %2 : !torch.vtensor<[],si32> -> !torch.int // CHECK: torch.prim.ListConstruct %1, %3 : (!torch.int, !torch.int) -> !torch.list // CHECK: torch.aten.broadcast_to %arg0, %4 : !torch.vtensor<[1,4],f32>, !torch.list -> !torch.vtensor<[3,4],f32> %0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[1,4],f32>, !torch.vtensor<[2],si32>) -> !torch.vtensor<[3,4],f32> @@ -1057,14 +1057,14 @@ func.func @test_expand_dim2_shape2(%arg0: !torch.vtensor<[1,4],f32>, %arg1: !tor func.func @test_expand_dim2_shape3(%arg0: !torch.vtensor<[3,1],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,6],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.+]] = torch.constant.int 0 // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[],si64> -> !torch.int // CHECK: %[[INT1:.+]] = torch.constant.int 1 - // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %2 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: torch.aten.item %2 : !torch.vtensor<[],si64> -> !torch.int // CHECK: %[[INT2:.+]] = torch.constant.int 2 - // CHECK: torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %4 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: torch.aten.item %4 : !torch.vtensor<[],si64> -> !torch.int // CHECK: torch.prim.ListConstruct %1, %3, %5 : (!torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: torch.aten.broadcast_to %arg0, %6 : !torch.vtensor<[3,1],f32>, !torch.list -> !torch.vtensor<[2,3,6],f32> %0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[3,1],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,6],f32> @@ -1077,17 +1077,17 @@ func.func @test_expand_dim2_shape3(%arg0: !torch.vtensor<[3,1],f32>, %arg1: !tor func.func @test_expand_dim3_shape4(%arg0: !torch.vtensor<[1,3,1],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[3,3,3,3],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.+]] = torch.constant.int 0 // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[],si64> -> !torch.int // CHECK: %[[INT1:.+]] = torch.constant.int 1 - // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %2 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: torch.aten.item %2 : !torch.vtensor<[],si64> -> !torch.int // CHECK: %[[INT2:.+]] = torch.constant.int 2 - // CHECK: torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %4 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: torch.aten.item %4 : !torch.vtensor<[],si64> -> !torch.int // CHECK: %[[INT3:.+]] = torch.constant.int 3 - // CHECK: torch.aten.select.int %arg1, %int0, %int3 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.select.int %arg1, %int0, %int3 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: torch.aten.item %6 : !torch.vtensor<[],si64> -> !torch.int // CHECK: torch.prim.ListConstruct %1, %3, %5, %7 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %9 = torch.aten.broadcast_to %arg0, %8 : !torch.vtensor<[1,3,1],f32>, !torch.list -> !torch.vtensor<[3,3,3,3],f32> %0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[1,3,1],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[3,3,3,3],f32>