From 32c25322513ffe02b869ef245606c0f9eae65395 Mon Sep 17 00:00:00 2001 From: Sambhav Jain Date: Sun, 5 Nov 2023 06:39:46 -0800 Subject: [PATCH] Separate conversion pass for `torch -> tcp.custom_op` to support both kernel and codegen approaches (#19) This allows switching between codegen vs kernels approaches, for torch ops that can lower to either tcp ops (codegen) or tcp.custom_op (kernel). Addresses https://github.com/cruise-automation/mlir-tcp/pull/17#discussion_r1382303711 using [Option 2](https://github.com/cruise-automation/mlir-tcp/pull/17#discussion_r1382436616). --- BUILD | 6 +- include/mlir-tcp/Conversion/Passes.td | 17 +++ .../Conversion/TorchToTcp/TorchToTcp.h | 1 + .../TorchToTcp/TorchToTcpCustomOp.h | 30 +++++ lib/Conversion/Passes.cpp | 1 + lib/Conversion/TorchToTcp/TorchToTcp.cpp | 3 - .../TorchToTcp/TorchToTcpCustomOp.cpp | 106 ++++++++++++++++++ .../Conversion/TorchToTcp/tcp_custom_ops.mlir | 2 +- 8 files changed, 161 insertions(+), 5 deletions(-) create mode 100644 include/mlir-tcp/Conversion/TorchToTcp/TorchToTcpCustomOp.h create mode 100644 lib/Conversion/TorchToTcp/TorchToTcpCustomOp.cpp diff --git a/BUILD b/BUILD index 9b6a4ffe..bfa973da 100644 --- a/BUILD +++ b/BUILD @@ -192,10 +192,14 @@ cc_library( "lib/Conversion/TorchToTcp/PopulatePatterns.h", "lib/Conversion/TorchToTcp/TcpCustomOp.cpp", "lib/Conversion/TorchToTcp/TorchToTcp.cpp", + "lib/Conversion/TorchToTcp/TorchToTcpCustomOp.cpp", "lib/Conversion/TorchToTcp/Utils.cpp", "lib/Conversion/TorchToTcp/Utils.h", ], - hdrs = ["include/mlir-tcp/Conversion/TorchToTcp/TorchToTcp.h"], + hdrs = [ + "include/mlir-tcp/Conversion/TorchToTcp/TorchToTcp.h", + "include/mlir-tcp/Conversion/TorchToTcp/TorchToTcpCustomOp.h", + ], strip_include_prefix = "include", deps = [ ":TcpConversionPassesIncGen", diff --git a/include/mlir-tcp/Conversion/Passes.td b/include/mlir-tcp/Conversion/Passes.td index f242763d..f25954e1 100644 --- a/include/mlir-tcp/Conversion/Passes.td +++ b/include/mlir-tcp/Conversion/Passes.td @@ -29,6 +29,23 @@ def ConvertTorchToTcp : Pass<"convert-torch-to-tcp", "func::FuncOp"> { ]; } +//===----------------------------------------------------------------------===// +// TorchToTcpCustomOp +//===----------------------------------------------------------------------===// + +def ConvertTorchToTcpCustomOp : Pass<"convert-torch-to-tcp-custom-op", "func::FuncOp"> { + let summary = "Convert Torch ops to Tcp custom ops"; + let description = [{ + Convert Torch ops to Tcp custom ops. + }]; + let constructor = "mlir::tcp::createConvertTorchToTcpCustomOpPass()"; + let options = [ + ListOption<"convertTorchOps", "convert-torch-ops", "std::string", + "List of Torch operation names that should be converted to Tcp custom op", + "llvm::cl::ZeroOrMore">, + ]; +} + //===----------------------------------------------------------------------===// // StablehloToTcp //===----------------------------------------------------------------------===// diff --git a/include/mlir-tcp/Conversion/TorchToTcp/TorchToTcp.h b/include/mlir-tcp/Conversion/TorchToTcp/TorchToTcp.h index d7dd5c24..16400407 100644 --- a/include/mlir-tcp/Conversion/TorchToTcp/TorchToTcp.h +++ b/include/mlir-tcp/Conversion/TorchToTcp/TorchToTcp.h @@ -20,6 +20,7 @@ namespace mlir { namespace tcp { std::unique_ptr> createConvertTorchToTcpPass(); + std::unique_ptr> createConvertTorchToTcpPass(llvm::ArrayRef convertTorchOps); diff --git a/include/mlir-tcp/Conversion/TorchToTcp/TorchToTcpCustomOp.h b/include/mlir-tcp/Conversion/TorchToTcp/TorchToTcpCustomOp.h new file mode 100644 index 00000000..89810e18 --- /dev/null +++ b/include/mlir-tcp/Conversion/TorchToTcp/TorchToTcpCustomOp.h @@ -0,0 +1,30 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { + +#define GEN_PASS_DECL_CONVERTTORCHTOTCPCUSTOMOP +#include "mlir-tcp/Conversion/Passes.h.inc" + +namespace tcp { + +std::unique_ptr> +createConvertTorchToTcpCustomOpPass(); + +std::unique_ptr> +createConvertTorchToTcpCustomOpPass( + llvm::ArrayRef convertTorchOps); + +} // namespace tcp +} // namespace mlir diff --git a/lib/Conversion/Passes.cpp b/lib/Conversion/Passes.cpp index 446ef4e2..d2328232 100644 --- a/lib/Conversion/Passes.cpp +++ b/lib/Conversion/Passes.cpp @@ -13,6 +13,7 @@ #include "mlir-tcp/Conversion/TcpToArith/TcpToArith.h" #include "mlir-tcp/Conversion/TcpToLinalg/TcpToLinalg.h" #include "mlir-tcp/Conversion/TorchToTcp/TorchToTcp.h" +#include "mlir-tcp/Conversion/TorchToTcp/TorchToTcpCustomOp.h" //===----------------------------------------------------------------------===// // Pass registration diff --git a/lib/Conversion/TorchToTcp/TorchToTcp.cpp b/lib/Conversion/TorchToTcp/TorchToTcp.cpp index 83608d50..28d2b294 100644 --- a/lib/Conversion/TorchToTcp/TorchToTcp.cpp +++ b/lib/Conversion/TorchToTcp/TorchToTcp.cpp @@ -85,9 +85,6 @@ class ConvertTorchToTcp : public ConvertTorchToTcpBase { torch_to_tcp::populateDataMovementPatternsAndLegality( typeConverter, patterns, target, convertTorchOpsSet); - torch_to_tcp::populateTcpCustomOpPatternsAndLegality( - typeConverter, patterns, target, convertTorchOpsSet); - if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { return signalPassFailure(); diff --git a/lib/Conversion/TorchToTcp/TorchToTcpCustomOp.cpp b/lib/Conversion/TorchToTcp/TorchToTcpCustomOp.cpp new file mode 100644 index 00000000..a9e57f8e --- /dev/null +++ b/lib/Conversion/TorchToTcp/TorchToTcpCustomOp.cpp @@ -0,0 +1,106 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "mlir-tcp/Conversion/TorchToTcp/TorchToTcpCustomOp.h" + +#include "mlir-tcp/Dialect/IR/TcpDialect.h" +#include "mlir-tcp/Dialect/IR/TcpOps.h" + +#include "../PassDetail.h" +#include "PopulatePatterns.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Traits.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Transforms/DialectConversion.h" +#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" +#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" +#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringSet.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::Torch; + +namespace mlir { + +#define GEN_PASS_DEF_CONVERTTORCHTOTCPCUSTOMOP +#include "mlir-tcp/Conversion/Passes.h.inc" + +namespace tcp { + +namespace { + +class ConvertTorchToTcpCustomOp + : public ConvertTorchToTcpCustomOpBase { +private: + llvm::StringSet<> convertTorchOpsSet; + +public: + ConvertTorchToTcpCustomOp() = default; + ConvertTorchToTcpCustomOp(ArrayRef convertTorchOps) { + this->convertTorchOps = convertTorchOps; + } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + TorchConversion::getBackendTypeConversionDependentDialects(registry); + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + + // Usually the default constructor is called which means `convertTorchOps` + // is usually unset. Doing this here allows the initialization of + // `convertTorchOpsSet` to be be delayed to when `runOnOperation` is called. + convertTorchOpsSet.clear(); + convertTorchOpsSet.insert(convertTorchOps.begin(), convertTorchOps.end()); + + ConversionTarget target(*context); + target.addLegalDialect(); + + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { return type; }); + TorchConversion::setupBackendTypeConversion(target, typeConverter); + + torch_to_tcp::populateTcpCustomOpPatternsAndLegality( + typeConverter, patterns, target, convertTorchOpsSet); + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) { + return signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr> +createConvertTorchToTcpCustomOpPass() { + llvm::ArrayRef emptyArrayRef; + return std::make_unique( + /*convertTorchOps=*/emptyArrayRef); +} + +std::unique_ptr> +createConvertTorchToTcpCustomOpPass( + llvm::ArrayRef convertTorchOps) { + return std::make_unique(convertTorchOps); +} + +} // namespace tcp +} // namespace mlir diff --git a/test/Conversion/TorchToTcp/tcp_custom_ops.mlir b/test/Conversion/TorchToTcp/tcp_custom_ops.mlir index 1f6593c0..a201589a 100644 --- a/test/Conversion/TorchToTcp/tcp_custom_ops.mlir +++ b/test/Conversion/TorchToTcp/tcp_custom_ops.mlir @@ -1,4 +1,4 @@ -// RUN: tcp-opt <%s -convert-torch-to-tcp -canonicalize -split-input-file -verify-diagnostics | FileCheck %s +// RUN: tcp-opt <%s -convert-torch-to-tcp-custom-op -canonicalize -split-input-file -verify-diagnostics | FileCheck %s // CHECK-LABEL: func.func @torch.aten.gather_op( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,2],si64>