Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support lowering for aten.log1p #93

Merged
merged 3 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions lib/Conversion/TorchToTcp/Elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,37 @@ class ConvertAtenSqrtOp : public OpConversionPattern<AtenSqrtOp> {
}
};

class ConvertAtenLog1pOp : public OpConversionPattern<AtenLog1pOp> {
public:
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(AtenLog1pOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value input = adaptor.getSelf();
RankedTensorType inputType = input.getType().dyn_cast<RankedTensorType>();

if (!inputType)
return rewriter.notifyMatchFailure(
op, "Only Ranked Tensor types are supported in TCP");

auto elementType = inputType.getElementType();
if (!isa<mlir::FloatType>(elementType))
return rewriter.notifyMatchFailure(
op, "Only floating-point datatype is supported");

auto constOp = torch_to_tcp::getConstTensor<float>(
rewriter, op, llvm::ArrayRef((float)1.0), {})
.value();
constOp = torch_to_tcp::broadcast0DOr1DToNDAndMatchShape(
rewriter, constOp, input, elementType);
auto addOp =
rewriter.create<tcp::AddOp>(op.getLoc(), inputType, input, constOp);
rewriter.replaceOpWithNewOp<tcp::LogOp>(op, inputType, addOp);
return success();
}
};

template <typename AtenOpT, typename TcpOpT>
class ConvertAtenUnaryIntOrFpOp : public OpConversionPattern<AtenOpT> {
public:
Expand Down Expand Up @@ -694,6 +725,7 @@ void torch_to_tcp::populateElementwisePatternsAndLegality(
INSERT_ATEN_ELEMENTWISE_OP_PATTERN(AtenBatchNormOp);
INSERT_ATEN_ELEMENTWISE_OP_PATTERN(AtenAtan2Op);
INSERT_ATEN_ELEMENTWISE_OP_PATTERN(AtenSqrtOp);
INSERT_ATEN_ELEMENTWISE_OP_PATTERN(AtenLog1pOp);
#undef INSERT_ATEN_ELEMENTWISE_OP_PATTERN

#define INSERT_ATEN_ELEMENTWISE_ADD_SUB_PATTERN(AtenOp, TcpOp) \
Expand Down
1 change: 1 addition & 0 deletions test/AotCompile/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ AOT_TEST_SUITE = [
("tanh", False),
("clamp", False),
("relu", False),
("log1p", False),
("round_even", False),
("sqrt_float", False),
("sqrt_int", False),
Expand Down
18 changes: 18 additions & 0 deletions test/AotCompile/model_loader_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,24 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return TorchLoaderOutput(model=Relu(), inputs=(x,), dynamic_shapes=dynamic_shapes)


def log1p_loader() -> TorchLoaderOutput:
class Log1p(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.log1p(x)

# Sample inputs
x = torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])

# Dynamic dim constraints
batch = Dim("batch")
dynamic_shapes = {"x": {0: batch}}

return TorchLoaderOutput(model=Log1p(), inputs=(x,), dynamic_shapes=dynamic_shapes)


def round_even_loader() -> TorchLoaderOutput:
class RoundEven(torch.nn.Module):
def __init__(self):
Expand Down
19 changes: 19 additions & 0 deletions test/Conversion/TorchToTcp/elementwise.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -762,3 +762,22 @@ func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[?,?],i1>) -> !torch.vtenso
%0 = torch.aten.to.dtype %arg0, %int11, %false, %false, %none : !torch.vtensor<[?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?],ui8>
return %0 : !torch.vtensor<[?,?],ui8>
}

// -----

// CHECK-LABEL: func.func @torch.aten.log1p(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,4,19,2],f32>) -> !torch.vtensor<[?,4,19,2],f32> {
// CHECK-DAG: %[[TO_BUILTIN0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,4,19,2],f32> -> tensor<?x4x19x2xf32>
// CHECK: %[[CONST:.*]] = tcp.const {value = dense<1.000000e+00> : tensor<f32>} : tensor<f32>
// CHECK: %[[EXPAND_SHAPE:.*]] = tensor.expand_shape %[[CONST]] [] output_shape [1, 1, 1, 1] : tensor<f32> into tensor<1x1x1x1xf32>
// CHECK: %[[CONST0:.*]] = arith.constant 0 : index
// CHECK: %[[DIM0:.*]] = tensor.dim %[[TO_BUILTIN0]], %[[CONST0]] : tensor<?x4x19x2xf32>
// CHECK: %[[BROADCAST:.*]] = tcp.broadcast %[[EXPAND_SHAPE]], %[[DIM0]]
// CHECK: %[[ADD:.*]] = tcp.add %[[TO_BUILTIN0]], %[[BROADCAST]] : tensor<?x4x19x2xf32>, tensor<?x4x19x2xf32> -> tensor<?x4x19x2xf32>
// CHECK: %[[LOG:.*]] = tcp.log %[[ADD]] : tensor<?x4x19x2xf32> -> tensor<?x4x19x2xf32>
// CHECK: %[[FROM_BUILTIN:.*]] = torch_c.from_builtin_tensor %[[LOG]] : tensor<?x4x19x2xf32> -> !torch.vtensor<[?,4,19,2],f32>
// CHECK: return %[[FROM_BUILTIN]] : !torch.vtensor<[?,4,19,2],f32>
func.func @torch.aten.log1p(%arg0: !torch.vtensor<[?,4,19,2],f32>) -> !torch.vtensor<[?,4,19,2],f32> {
%1 = torch.aten.log1p %arg0 : !torch.vtensor<[?,4,19,2],f32> -> !torch.vtensor<[?,4,19,2],f32>
return %1 : !torch.vtensor<[?,4,19,2],f32>
}