Skip to content

Commit

Permalink
Separate conversion pass for torch -> tcp.custom_op to support both…
Browse files Browse the repository at this point in the history
… kernel and codegen approaches (#19)

This allows switching between codegen vs kernels approaches, for torch
ops that can lower to either tcp ops (codegen) or tcp.custom_op
(kernel). Addresses
#17 (comment)
using [Option
2](#17 (comment)).
  • Loading branch information
sjain-stanford authored Nov 5, 2023
1 parent 6070960 commit 32c2532
Show file tree
Hide file tree
Showing 8 changed files with 161 additions and 5 deletions.
6 changes: 5 additions & 1 deletion BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -192,10 +192,14 @@ cc_library(
"lib/Conversion/TorchToTcp/PopulatePatterns.h",
"lib/Conversion/TorchToTcp/TcpCustomOp.cpp",
"lib/Conversion/TorchToTcp/TorchToTcp.cpp",
"lib/Conversion/TorchToTcp/TorchToTcpCustomOp.cpp",
"lib/Conversion/TorchToTcp/Utils.cpp",
"lib/Conversion/TorchToTcp/Utils.h",
],
hdrs = ["include/mlir-tcp/Conversion/TorchToTcp/TorchToTcp.h"],
hdrs = [
"include/mlir-tcp/Conversion/TorchToTcp/TorchToTcp.h",
"include/mlir-tcp/Conversion/TorchToTcp/TorchToTcpCustomOp.h",
],
strip_include_prefix = "include",
deps = [
":TcpConversionPassesIncGen",
Expand Down
17 changes: 17 additions & 0 deletions include/mlir-tcp/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,23 @@ def ConvertTorchToTcp : Pass<"convert-torch-to-tcp", "func::FuncOp"> {
];
}

//===----------------------------------------------------------------------===//
// TorchToTcpCustomOp
//===----------------------------------------------------------------------===//

def ConvertTorchToTcpCustomOp : Pass<"convert-torch-to-tcp-custom-op", "func::FuncOp"> {
let summary = "Convert Torch ops to Tcp custom ops";
let description = [{
Convert Torch ops to Tcp custom ops.
}];
let constructor = "mlir::tcp::createConvertTorchToTcpCustomOpPass()";
let options = [
ListOption<"convertTorchOps", "convert-torch-ops", "std::string",
"List of Torch operation names that should be converted to Tcp custom op",
"llvm::cl::ZeroOrMore">,
];
}

//===----------------------------------------------------------------------===//
// StablehloToTcp
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions include/mlir-tcp/Conversion/TorchToTcp/TorchToTcp.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ namespace mlir {
namespace tcp {

std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTcpPass();

std::unique_ptr<OperationPass<func::FuncOp>>
createConvertTorchToTcpPass(llvm::ArrayRef<std::string> convertTorchOps);

Expand Down
30 changes: 30 additions & 0 deletions include/mlir-tcp/Conversion/TorchToTcp/TorchToTcpCustomOp.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#pragma once

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"

namespace mlir {

#define GEN_PASS_DECL_CONVERTTORCHTOTCPCUSTOMOP
#include "mlir-tcp/Conversion/Passes.h.inc"

namespace tcp {

std::unique_ptr<OperationPass<func::FuncOp>>
createConvertTorchToTcpCustomOpPass();

std::unique_ptr<OperationPass<func::FuncOp>>
createConvertTorchToTcpCustomOpPass(
llvm::ArrayRef<std::string> convertTorchOps);

} // namespace tcp
} // namespace mlir
1 change: 1 addition & 0 deletions lib/Conversion/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "mlir-tcp/Conversion/TcpToArith/TcpToArith.h"
#include "mlir-tcp/Conversion/TcpToLinalg/TcpToLinalg.h"
#include "mlir-tcp/Conversion/TorchToTcp/TorchToTcp.h"
#include "mlir-tcp/Conversion/TorchToTcp/TorchToTcpCustomOp.h"

//===----------------------------------------------------------------------===//
// Pass registration
Expand Down
3 changes: 0 additions & 3 deletions lib/Conversion/TorchToTcp/TorchToTcp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,6 @@ class ConvertTorchToTcp : public ConvertTorchToTcpBase<ConvertTorchToTcp> {
torch_to_tcp::populateDataMovementPatternsAndLegality(
typeConverter, patterns, target, convertTorchOpsSet);

torch_to_tcp::populateTcpCustomOpPatternsAndLegality(
typeConverter, patterns, target, convertTorchOpsSet);

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
return signalPassFailure();
Expand Down
106 changes: 106 additions & 0 deletions lib/Conversion/TorchToTcp/TorchToTcpCustomOp.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "mlir-tcp/Conversion/TorchToTcp/TorchToTcpCustomOp.h"

#include "mlir-tcp/Dialect/IR/TcpDialect.h"
#include "mlir-tcp/Dialect/IR/TcpOps.h"

#include "../PassDetail.h"
#include "PopulatePatterns.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringSet.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

namespace mlir {

#define GEN_PASS_DEF_CONVERTTORCHTOTCPCUSTOMOP
#include "mlir-tcp/Conversion/Passes.h.inc"

namespace tcp {

namespace {

class ConvertTorchToTcpCustomOp
: public ConvertTorchToTcpCustomOpBase<ConvertTorchToTcpCustomOp> {
private:
llvm::StringSet<> convertTorchOpsSet;

public:
ConvertTorchToTcpCustomOp() = default;
ConvertTorchToTcpCustomOp(ArrayRef<std::string> convertTorchOps) {
this->convertTorchOps = convertTorchOps;
}

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<tcp::TcpDialect>();
registry.insert<tensor::TensorDialect>();
TorchConversion::getBackendTypeConversionDependentDialects(registry);
}

void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);

// Usually the default constructor is called which means `convertTorchOps`
// is usually unset. Doing this here allows the initialization of
// `convertTorchOpsSet` to be be delayed to when `runOnOperation` is called.
convertTorchOpsSet.clear();
convertTorchOpsSet.insert(convertTorchOps.begin(), convertTorchOps.end());

ConversionTarget target(*context);
target.addLegalDialect<tcp::TcpDialect, tensor::TensorDialect,
arith::ArithDialect>();

TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
TorchConversion::setupBackendTypeConversion(target, typeConverter);

torch_to_tcp::populateTcpCustomOpPatternsAndLegality(
typeConverter, patterns, target, convertTorchOpsSet);

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
return signalPassFailure();
}
}
};

} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
createConvertTorchToTcpCustomOpPass() {
llvm::ArrayRef<std::string> emptyArrayRef;
return std::make_unique<ConvertTorchToTcpCustomOp>(
/*convertTorchOps=*/emptyArrayRef);
}

std::unique_ptr<OperationPass<func::FuncOp>>
createConvertTorchToTcpCustomOpPass(
llvm::ArrayRef<std::string> convertTorchOps) {
return std::make_unique<ConvertTorchToTcpCustomOp>(convertTorchOps);
}

} // namespace tcp
} // namespace mlir
2 changes: 1 addition & 1 deletion test/Conversion/TorchToTcp/tcp_custom_ops.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: tcp-opt <%s -convert-torch-to-tcp -canonicalize -split-input-file -verify-diagnostics | FileCheck %s
// RUN: tcp-opt <%s -convert-torch-to-tcp-custom-op -canonicalize -split-input-file -verify-diagnostics | FileCheck %s

// CHECK-LABEL: func.func @torch.aten.gather_op(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,2],si64>
Expand Down

0 comments on commit 32c2532

Please sign in to comment.