diff --git a/lib/Conversion/TorchToTcp/Elementwise.cpp b/lib/Conversion/TorchToTcp/Elementwise.cpp index 98025bd..d28f28d 100644 --- a/lib/Conversion/TorchToTcp/Elementwise.cpp +++ b/lib/Conversion/TorchToTcp/Elementwise.cpp @@ -478,6 +478,37 @@ class ConvertAtenSqrtOp : public OpConversionPattern { } }; +class ConvertAtenLog1pOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(AtenLog1pOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value input = adaptor.getSelf(); + RankedTensorType inputType = input.getType().dyn_cast(); + + if (!inputType) + return rewriter.notifyMatchFailure( + op, "Only Ranked Tensor types are supported in TCP"); + + auto elementType = inputType.getElementType(); + if (!isa(elementType)) + return rewriter.notifyMatchFailure( + op, "Only floating-point datatype is supported"); + + auto constOp = torch_to_tcp::getConstTensor( + rewriter, op, llvm::ArrayRef((float)1.0), {}) + .value(); + constOp = torch_to_tcp::broadcast0DOr1DToNDAndMatchShape( + rewriter, constOp, input, elementType); + auto addOp = + rewriter.create(op.getLoc(), inputType, input, constOp); + rewriter.replaceOpWithNewOp(op, inputType, addOp); + return success(); + } +}; + template class ConvertAtenUnaryIntOrFpOp : public OpConversionPattern { public: @@ -694,6 +725,7 @@ void torch_to_tcp::populateElementwisePatternsAndLegality( INSERT_ATEN_ELEMENTWISE_OP_PATTERN(AtenBatchNormOp); INSERT_ATEN_ELEMENTWISE_OP_PATTERN(AtenAtan2Op); INSERT_ATEN_ELEMENTWISE_OP_PATTERN(AtenSqrtOp); + INSERT_ATEN_ELEMENTWISE_OP_PATTERN(AtenLog1pOp); #undef INSERT_ATEN_ELEMENTWISE_OP_PATTERN #define INSERT_ATEN_ELEMENTWISE_ADD_SUB_PATTERN(AtenOp, TcpOp) \ diff --git a/test/AotCompile/BUILD b/test/AotCompile/BUILD index 6de2f31..cd79257 100644 --- a/test/AotCompile/BUILD +++ b/test/AotCompile/BUILD @@ -21,6 +21,7 @@ AOT_TEST_SUITE = [ ("tanh", False), ("clamp", False), ("relu", False), + ("log1p", False), ("round_even", False), ("sqrt_float", False), ("sqrt_int", False), diff --git a/test/AotCompile/model_loader_lib.py b/test/AotCompile/model_loader_lib.py index ddccad5..1b3e88c 100644 --- a/test/AotCompile/model_loader_lib.py +++ b/test/AotCompile/model_loader_lib.py @@ -263,6 +263,24 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return TorchLoaderOutput(model=Relu(), inputs=(x,), dynamic_shapes=dynamic_shapes) +def log1p_loader() -> TorchLoaderOutput: + class Log1p(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.log1p(x) + + # Sample inputs + x = torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]) + + # Dynamic dim constraints + batch = Dim("batch") + dynamic_shapes = {"x": {0: batch}} + + return TorchLoaderOutput(model=Log1p(), inputs=(x,), dynamic_shapes=dynamic_shapes) + + def round_even_loader() -> TorchLoaderOutput: class RoundEven(torch.nn.Module): def __init__(self): diff --git a/test/Conversion/TorchToTcp/elementwise.mlir b/test/Conversion/TorchToTcp/elementwise.mlir index 8554f42..f9a3b57 100644 --- a/test/Conversion/TorchToTcp/elementwise.mlir +++ b/test/Conversion/TorchToTcp/elementwise.mlir @@ -762,3 +762,22 @@ func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[?,?],i1>) -> !torch.vtenso %0 = torch.aten.to.dtype %arg0, %int11, %false, %false, %none : !torch.vtensor<[?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?],ui8> return %0 : !torch.vtensor<[?,?],ui8> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.log1p( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,4,19,2],f32>) -> !torch.vtensor<[?,4,19,2],f32> { +// CHECK-DAG: %[[TO_BUILTIN0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,4,19,2],f32> -> tensor +// CHECK: %[[CONST:.*]] = tcp.const {value = dense<1.000000e+00> : tensor} : tensor +// CHECK: %[[EXPAND_SHAPE:.*]] = tensor.expand_shape %[[CONST]] [] output_shape [1, 1, 1, 1] : tensor into tensor<1x1x1x1xf32> +// CHECK: %[[CONST0:.*]] = arith.constant 0 : index +// CHECK: %[[DIM0:.*]] = tensor.dim %[[TO_BUILTIN0]], %[[CONST0]] : tensor +// CHECK: %[[BROADCAST:.*]] = tcp.broadcast %[[EXPAND_SHAPE]], %[[DIM0]] +// CHECK: %[[ADD:.*]] = tcp.add %[[TO_BUILTIN0]], %[[BROADCAST]] : tensor, tensor -> tensor +// CHECK: %[[LOG:.*]] = tcp.log %[[ADD]] : tensor -> tensor +// CHECK: %[[FROM_BUILTIN:.*]] = torch_c.from_builtin_tensor %[[LOG]] : tensor -> !torch.vtensor<[?,4,19,2],f32> +// CHECK: return %[[FROM_BUILTIN]] : !torch.vtensor<[?,4,19,2],f32> +func.func @torch.aten.log1p(%arg0: !torch.vtensor<[?,4,19,2],f32>) -> !torch.vtensor<[?,4,19,2],f32> { + %1 = torch.aten.log1p %arg0 : !torch.vtensor<[?,4,19,2],f32> -> !torch.vtensor<[?,4,19,2],f32> + return %1 : !torch.vtensor<[?,4,19,2],f32> +}