support lowering for aten.log1p (#93)
Lower the aten.log1p op to `tcp.log(tcp.add(input, 1.0))`.

To test:

`bazel test //...` (in docker)
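
For reference, the decomposition matches the elementwise identity log1p(x) = log(1 + x). A minimal PyTorch sketch (not part of this change) of the equivalence the lowering relies on:

```python
import torch

x = torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])

# aten.log1p is rewritten as tcp.log(tcp.add(input, 1.0)), so the decomposed
# form should agree with the native op (up to floating-point rounding; a
# native log1p is more accurate than log(1 + x) for very small inputs).
assert torch.allclose(torch.log(x + 1.0), torch.log1p(x))
```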
zezhang authored Sep 3, 2024
1 parent 57d5e00 commit 29062f8
Showing 4 changed files with 70 additions and 0 deletions.
32 changes: 32 additions & 0 deletions lib/Conversion/TorchToTcp/Elementwise.cpp
@@ -478,6 +478,37 @@ class ConvertAtenSqrtOp : public OpConversionPattern<AtenSqrtOp> {
}
};

class ConvertAtenLog1pOp : public OpConversionPattern<AtenLog1pOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AtenLog1pOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value input = adaptor.getSelf();
    RankedTensorType inputType = input.getType().dyn_cast<RankedTensorType>();

    if (!inputType)
      return rewriter.notifyMatchFailure(
          op, "Only Ranked Tensor types are supported in TCP");

    auto elementType = inputType.getElementType();
    if (!isa<mlir::FloatType>(elementType))
      return rewriter.notifyMatchFailure(
          op, "Only floating-point datatype is supported");

    // Materialize a rank-0 constant 1.0 and broadcast it to match the input
    // shape.
    auto constOp = torch_to_tcp::getConstTensor<float>(
                       rewriter, op, llvm::ArrayRef((float)1.0), {})
                       .value();
    constOp = torch_to_tcp::broadcast0DOr1DToNDAndMatchShape(
        rewriter, constOp, input, elementType);

    // log1p(x) == log(x + 1): emit tcp.add followed by tcp.log.
    auto addOp =
        rewriter.create<tcp::AddOp>(op.getLoc(), inputType, input, constOp);
    rewriter.replaceOpWithNewOp<tcp::LogOp>(op, inputType, addOp);
    return success();
  }
};

template <typename AtenOpT, typename TcpOpT>
class ConvertAtenUnaryIntOrFpOp : public OpConversionPattern<AtenOpT> {
public:
@@ -694,6 +725,7 @@ void torch_to_tcp::populateElementwisePatternsAndLegality(
INSERT_ATEN_ELEMENTWISE_OP_PATTERN(AtenBatchNormOp);
INSERT_ATEN_ELEMENTWISE_OP_PATTERN(AtenAtan2Op);
INSERT_ATEN_ELEMENTWISE_OP_PATTERN(AtenSqrtOp);
INSERT_ATEN_ELEMENTWISE_OP_PATTERN(AtenLog1pOp);
#undef INSERT_ATEN_ELEMENTWISE_OP_PATTERN

#define INSERT_ATEN_ELEMENTWISE_ADD_SUB_PATTERN(AtenOp, TcpOp) \
1 change: 1 addition & 0 deletions test/AotCompile/BUILD
@@ -21,6 +21,7 @@ AOT_TEST_SUITE = [
("tanh", False),
("clamp", False),
("relu", False),
("log1p", False),
("round_even", False),
("sqrt_float", False),
("sqrt_int", False),
18 changes: 18 additions & 0 deletions test/AotCompile/model_loader_lib.py
@@ -263,6 +263,24 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
    return TorchLoaderOutput(model=Relu(), inputs=(x,), dynamic_shapes=dynamic_shapes)


def log1p_loader() -> TorchLoaderOutput:
    class Log1p(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.log1p(x)

    # Sample inputs
    x = torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])

    # Dynamic dim constraints
    batch = Dim("batch")
    dynamic_shapes = {"x": {0: batch}}

    return TorchLoaderOutput(model=Log1p(), inputs=(x,), dynamic_shapes=dynamic_shapes)


def round_even_loader() -> TorchLoaderOutput:
    class RoundEven(torch.nn.Module):
        def __init__(self):
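A rough sketch of how this loader could be exercised with torch.export outside the AOT test harness; the import path and the assumption that TorchLoaderOutput exposes its constructor arguments as attributes are mine, and the real harness may drive this differently:

```python
import torch
from torch.export import export

from model_loader_lib import log1p_loader  # import path assumed; adjust as needed

out = log1p_loader()
# Export with the declared dynamic batch dimension, then check the exported
# module against the eager PyTorch reference.
prog = export(out.model, out.inputs, dynamic_shapes=out.dynamic_shapes)
x = out.inputs[0]
assert torch.allclose(prog.module()(x), torch.log1p(x))
```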
19 changes: 19 additions & 0 deletions test/Conversion/TorchToTcp/elementwise.mlir
@@ -762,3 +762,22 @@ func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[?,?],i1>) -> !torch.vtenso
%0 = torch.aten.to.dtype %arg0, %int11, %false, %false, %none : !torch.vtensor<[?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?],ui8>
return %0 : !torch.vtensor<[?,?],ui8>
}

// -----

// CHECK-LABEL: func.func @torch.aten.log1p(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,4,19,2],f32>) -> !torch.vtensor<[?,4,19,2],f32> {
// CHECK-DAG: %[[TO_BUILTIN0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,4,19,2],f32> -> tensor<?x4x19x2xf32>
// CHECK: %[[CONST:.*]] = tcp.const {value = dense<1.000000e+00> : tensor<f32>} : tensor<f32>
// CHECK: %[[EXPAND_SHAPE:.*]] = tensor.expand_shape %[[CONST]] [] output_shape [1, 1, 1, 1] : tensor<f32> into tensor<1x1x1x1xf32>
// CHECK: %[[CONST0:.*]] = arith.constant 0 : index
// CHECK: %[[DIM0:.*]] = tensor.dim %[[TO_BUILTIN0]], %[[CONST0]] : tensor<?x4x19x2xf32>
// CHECK: %[[BROADCAST:.*]] = tcp.broadcast %[[EXPAND_SHAPE]], %[[DIM0]]
// CHECK: %[[ADD:.*]] = tcp.add %[[TO_BUILTIN0]], %[[BROADCAST]] : tensor<?x4x19x2xf32>, tensor<?x4x19x2xf32> -> tensor<?x4x19x2xf32>
// CHECK: %[[LOG:.*]] = tcp.log %[[ADD]] : tensor<?x4x19x2xf32> -> tensor<?x4x19x2xf32>
// CHECK: %[[FROM_BUILTIN:.*]] = torch_c.from_builtin_tensor %[[LOG]] : tensor<?x4x19x2xf32> -> !torch.vtensor<[?,4,19,2],f32>
// CHECK: return %[[FROM_BUILTIN]] : !torch.vtensor<[?,4,19,2],f32>
func.func @torch.aten.log1p(%arg0: !torch.vtensor<[?,4,19,2],f32>) -> !torch.vtensor<[?,4,19,2],f32> {
%1 = torch.aten.log1p %arg0 : !torch.vtensor<[?,4,19,2],f32> -> !torch.vtensor<[?,4,19,2],f32>
return %1 : !torch.vtensor<[?,4,19,2],f32>
}
