Skip to content

Commit

Permalink
[Pipeline] Add TorchToTcpCustomOpPass to frontend pipeline (#20)
Browse files Browse the repository at this point in the history
This was missed from the earlier PR.
  • Loading branch information
sjain-stanford authored Nov 6, 2023
1 parent 32c2532 commit eaaa7cf
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
2 changes: 2 additions & 0 deletions lib/Pipeline/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir-tcp/Conversion/TcpToArith/TcpToArith.h"
#include "mlir-tcp/Conversion/TcpToLinalg/TcpToLinalg.h"
#include "mlir-tcp/Conversion/TorchToTcp/TorchToTcp.h"
#include "mlir-tcp/Conversion/TorchToTcp/TorchToTcpCustomOp.h"

#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
Expand All @@ -36,6 +37,7 @@ using namespace mlir;

static void createTorchBackendToTcpBackendPipeline(OpPassManager &pm) {
pm.addNestedPass<func::FuncOp>(tcp::createConvertTorchToTcpPass());
pm.addNestedPass<func::FuncOp>(tcp::createConvertTorchToTcpCustomOpPass());

// Clean up any non-canonical code introduced above.
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
Expand Down
14 changes: 14 additions & 0 deletions test/Pipeline/torch_to_tcp_pipeline.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,17 @@ func.func @torch.aten.div.Tensor$mixed_type_int(%arg0: !torch.vtensor<[?, ?],si1
%0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],si16>, !torch.vtensor<[?, ?],si32> -> !torch.vtensor<[?, ?],si32>
return %0 : !torch.vtensor<[?, ?],si32>
}

// -----

// CHECK-LABEL: torch.aten.gather_op
// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x2xi64>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<2x2xf32>
// CHECK: %[[VAL_2:.*]] = tcp.custom_op("torch.aten.gather") %[[VAL_1]], %[[VAL_0]] {axis = 1 : i64} : tensor<2x2xf32>, tensor<2x2xi64> -> tensor<2x2xf32>
// CHECK: return %[[VAL_2]] : tensor<2x2xf32>
func.func @torch.aten.gather_op(%arg0: !torch.vtensor<[2,2],si64>, %arg1: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,2],f32> {
%false = torch.constant.bool false
%int1 = torch.constant.int 1
%0 = torch.aten.gather %arg1, %int1, %arg0, %false : !torch.vtensor<[2,2],f32>, !torch.int, !torch.vtensor<[2,2],si64>, !torch.bool -> !torch.vtensor<[2,2],f32>
return %0 : !torch.vtensor<[2,2],f32>
}

0 comments on commit eaaa7cf

Please sign in to comment.