From 3cfd054696495317f1e38778db4773f0ad1d0f7d Mon Sep 17 00:00:00 2001
From: archana-ramalingam
Date: Wed, 8 May 2024 11:15:15 -0700
Subject: [PATCH] Fix onnx.Selu lowering and update ReduceLogSumExp lit tests

---
 .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp  | 28 ++-----------
 .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 40 ++++++++++++-------
 2 files changed, 29 insertions(+), 39 deletions(-)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index db83bf5fb2c4..0b830c86b903 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -851,9 +851,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
             binder.tensorResultType(resultType))
           return failure();
 
-        Torch::ValueTensorType inputType =
-            operand.getType().cast<Torch::ValueTensorType>();
-
         Value vAlpha = rewriter.create<Torch::ConstantFloatOp>(
             binder.getLoc(), rewriter.getType<Torch::FloatType>(),
             rewriter.getFloatAttr(rewriter.getF64Type(), alpha));
@@ -862,31 +859,12 @@
             binder.getLoc(), rewriter.getType<Torch::FloatType>(),
             rewriter.getFloatAttr(rewriter.getF64Type(), gamma));
 
-        Value cstOne = rewriter.create<Torch::ConstantFloatOp>(
+        Value vInputScale = rewriter.create<Torch::ConstantFloatOp>(
             binder.getLoc(), rewriter.getType<Torch::FloatType>(),
             rewriter.getFloatAttr(rewriter.getF64Type(), 1.0));
 
-        Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
-        Value zeroTensor = rewriter.create<Torch::AtenZerosLikeOp>(
-            binder.getLoc(), resultType, operand, cstNone, cstNone, cstNone,
-            cstNone, cstNone);
-        Value exp = rewriter.create<Torch::AtenExpOp>(binder.getLoc(),
-                                                      resultType, operand);
-        Value expMulAlpha = rewriter.create<Torch::AtenMulScalarOp>(
-            binder.getLoc(), resultType, exp, vAlpha);
-        Value expMulAlphaSubAlpha = rewriter.create<Torch::AtenSubScalarOp>(
-            binder.getLoc(), resultType, expMulAlpha, vAlpha, cstOne);
-        Value neg = rewriter.create<Torch::AtenMulScalarOp>(
-            binder.getLoc(), resultType, expMulAlphaSubAlpha, vScale);
-        Value pos = rewriter.create<Torch::AtenMulScalarOp>(
-            binder.getLoc(), resultType, operand, vScale);
-        Type compareType = inputType.getWithSizesAndDtype(
-            inputType.getOptionalSizes(), rewriter.getI1Type());
-        Value xLessThanZero = rewriter.create<Torch::AtenLtTensorOp>(
-            binder.getLoc(), compareType, operand, zeroTensor);
-
-        rewriter.replaceOpWithNewOp<Torch::AtenWhereSelfOp>(
-            binder.op, resultType, xLessThanZero, neg, pos);
+        rewriter.replaceOpWithNewOp<Torch::AtenEluOp>(
+            binder.op, resultType, operand, vAlpha, vScale, vInputScale);
         return success();
       });
   patterns.onOp("ReduceL1", 1,
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
index b1b1943014c3..d1dd20d32ef1 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
@@ -582,18 +582,10 @@ func.func @test_softmax_negative_axis(%arg0: !torch.vtensor<[3,4,5],f32>) -> !to
 
 // CHECK-LABEL: func.func @test_selu
 func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 6 : si64} {
-  // CHECK: %[[F2:.+]] = torch.constant.float 2.000000e+00
-  // CHECK: %[[F3:.+]] = torch.constant.float 3.000000e+00
-  // CHECK: %[[F1:.+]] = torch.constant.float 1.000000e+00
-  // CHECK: %[[NONE:.+]] = torch.constant.none
-  // CHECK: %[[ZEROS:.+]] = torch.aten.zeros_like %arg0, %none, %none, %none, %none, %none : !torch.vtensor<[3,4,5],f32>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4,5],f32>
-  // CHECK: %[[EXP:.+]] = torch.aten.exp %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
-  // CHECK: %[[MUL:.+]] = torch.aten.mul.Scalar %[[EXP]], %[[F2]] : !torch.vtensor<[3,4,5],f32>, !torch.float -> !torch.vtensor<[3,4,5],f32>
-  // CHECK: %[[SUB:.+]] = torch.aten.sub.Scalar %[[MUL]], %[[F2]], %[[F1]] : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.float -> !torch.vtensor<[3,4,5],f32>
-  // CHECK: %[[MUL_1:.+]] = torch.aten.mul.Scalar %[[SUB]], %[[F3]] : !torch.vtensor<[3,4,5],f32>, !torch.float -> !torch.vtensor<[3,4,5],f32>
-  // CHECK: %[[MUL_2:.+]] = torch.aten.mul.Scalar %arg0, %[[F3]] : !torch.vtensor<[3,4,5],f32>, !torch.float -> !torch.vtensor<[3,4,5],f32>
-  // CHECK: %[[LT:.+]] = torch.aten.lt.Tensor %arg0, %[[ZEROS]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],i1>
-  // CHECK: torch.aten.where.self %[[LT]], %[[MUL_1]], %[[MUL_2]] : !torch.vtensor<[3,4,5],i1>, !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
+  // CHECK-DAG: %[[F1:.+]] = torch.constant.float 1
+  // CHECK-DAG: %[[F2:.+]] = torch.constant.float 2
+  // CHECK-DAG: %[[F3:.+]] = torch.constant.float 3
+  // CHECK: %[[ELU:.+]] = torch.aten.elu %arg0, %[[F2]], %[[F3]], %[[F1]]
   %0 = torch.operator "onnx.Selu"(%arg0) {torch.onnx.alpha = 2.000000e+00 : f32, torch.onnx.gamma = 3.000000e+00 : f32} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
   return %0 : !torch.vtensor<[3,4,5],f32>
 }
@@ -921,6 +913,11 @@ func.func @test_reduce_log_sum_do_not_keepdims_example(%arg0:!torch.vtensor<[3,2
 
 // CHECK-LABEL: func.func @test_reduce_log_sum_exp_default_axes_keepdims_example
 func.func @test_reduce_log_sum_exp_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[INT7:.+]] = torch.constant.int 7
+  // CHECK: %[[NONE_0:.+]] = torch.constant.none
+  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,2],f64>
+  // CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[3,2,2],f64> -> !torch.vtensor<[3,2,2],f64>
   // CHECK: %[[INT0:.+]] = torch.constant.int 0
   // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
   // CHECK: %[[INT7:.+]] = torch.constant.int 7
@@ -945,6 +942,11 @@
 
 // CHECK-LABEL: func.func @test_reduce_log_sum_exp_do_not_keepdims_example_expanded
 func.func @test_reduce_log_sum_exp_do_not_keepdims_example_expanded(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[INT7:.+]] = torch.constant.int 7
+  // CHECK: %[[NONE_0:.+]] = torch.constant.none
+  // CHECK: %[[FALSE_0:.+]] = torch.constant.bool false
+  // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,2],f64>
+  // CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[3,2,2],f64> -> !torch.vtensor<[3,2,2],f64>
   // CHECK: %[[INT0:.+]] = torch.constant.int 0
   // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
   // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
@@ -975,6 +977,11 @@
 
 // CHECK-LABEL: func.func @test_reduce_log_sum_exp_keep_dims_example
 func.func @test_reduce_log_sum_exp_keep_dims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[INT7:.+]] = torch.constant.int 7
+  // CHECK: %[[NONE_0:.+]] = torch.constant.none
+  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,2],f64>
+  // CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[3,2,2],f64> -> !torch.vtensor<[3,2,2],f64>
   // CHECK: %[[INT0:.+]] = torch.constant.int 0
   // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
   // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
@@ -1005,9 +1012,14 @@
 
 // CHECK-LABEL: func.func @test_reduce_log_sum_exp_keep_dims_int_input_example
 func.func @test_reduce_log_sum_exp_keep_dims_int_input_example(%arg0: !torch.vtensor<[3,2,2],si64>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[INT7:.+]] = torch.constant.int 7
+  // CHECK: %[[NONE_0:.+]] = torch.constant.none
+  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[3,2,2],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,2],f64>
+  // CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[3,2,2],f64> -> !torch.vtensor<[3,2,2],f64>
+  // CHECK: %[[INT0:.+]] = torch.constant.int 0
   // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
-  // CHECK: %[[INT0_1:.+]] = torch.constant.int 0
-  // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0_0]], %[[INT0_1]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
+  // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
   // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
   // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
   // CHECK: %[[INT7:.+]] = torch.constant.int 7
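
A note on the Selu hunk above: swapping the seven-op zeros_like/exp/mul/sub/lt/where decomposition for a single torch.aten.elu is algebraically exact. aten.elu computes scale * (max(0, x) + min(0, alpha * (exp(input_scale * x) - 1))), so passing scale = gamma and input_scale = 1.0 (the vAlpha, vScale, vInputScale operands) reproduces ONNX Selu: the old negative branch gamma * (alpha * exp(x) - alpha) is just gamma * alpha * (exp(x) - 1). A minimal standalone C++ check of that identity (illustrative only, not part of the patch; the helper names are invented):

    #include <algorithm>
    #include <cassert>
    #include <cmath>

    // The deleted decomposition: where(x < 0, gamma * (alpha * exp(x) - alpha), gamma * x).
    static double seluDecomposed(double x, double alpha, double gamma) {
      return x < 0.0 ? gamma * (alpha * std::exp(x) - alpha) : gamma * x;
    }

    // aten.elu semantics: scale * (max(0, x) + min(0, alpha * (exp(input_scale * x) - 1))).
    static double elu(double x, double alpha, double scale, double inputScale) {
      return scale * (std::max(0.0, x) +
                      std::min(0.0, alpha * (std::exp(inputScale * x) - 1.0)));
    }

    int main() {
      const double alpha = 2.0, gamma = 3.0; // the attribute values from @test_selu
      for (double x : {-2.5, -1.0, 0.0, 0.5, 3.0})
        assert(std::abs(seluDecomposed(x, alpha, gamma) -
                        elu(x, alpha, /*scale=*/gamma, /*inputScale=*/1.0)) < 1e-12);
      return 0;
    }

Besides being shorter, the single elu op leaves backends one op to pattern-match instead of a seven-op subgraph.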
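On the new ReduceLogSumExp CHECK lines: each test now expects the input to be widened to f64 before exp (torch.constant.int 7 is the torch dtype code for float64, consistent with the !torch.vtensor<[3,2,2],f64> result of to.dtype), so the exp/sum/log chain runs in double precision. A small sketch of the f32 overflow this widening guards against (illustrative, not from the patch):

    #include <cmath>
    #include <cstdio>

    int main() {
      // log(sum(exp(x))) of a single element should return x itself.
      float x = 89.0f;                                // exp(89) > FLT_MAX (~3.4e38)
      float naive = std::log(std::exp(x));            // f32 exp overflows to inf, log(inf) = inf
      double widened = std::log(std::exp((double)x)); // f64 intermediate stays finite
      std::printf("f32 path: %f, f64 path: %f\n", naive, widened); // inf vs. 89.0
      return 0;
    }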