diff --git a/include/torch-mlir-dialects/Dialect/TMTensor/Transforms/PassDetail.h b/include/torch-mlir-dialects/Dialect/TMTensor/Transforms/PassDetail.h deleted file mode 100644 index 391280bbc1a0..000000000000 --- a/include/torch-mlir-dialects/Dialect/TMTensor/Transforms/PassDetail.h +++ /dev/null @@ -1,27 +0,0 @@ -//===- PassDetail.h - TMTensor Pass class details -------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// Also available under a BSD-style license. See LICENSE. -// -//===----------------------------------------------------------------------===// - -#ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASS_DETAIL_H_ -#define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASS_DETAIL_H_ - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Pass/Pass.h" - -namespace mlir { -namespace torch { -namespace TMTensor { - -#define GEN_PASS_CLASSES -#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h.inc" // IWYU pragma: keep - -} // namespace TMTensor -} // namespace torch -} // namespace mlir - -#endif // TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASS_DETAIL_H_ diff --git a/include/torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h b/include/torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h index c1926015989e..36f171a4b52b 100644 --- a/include/torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h +++ b/include/torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h @@ -16,10 +16,18 @@ namespace mlir { namespace torch { + +#define GEN_PASS_DECL_CONVERTTORCHTOSTABLEHLO +#include "torch-mlir/Conversion/Passes.h.inc" + std::unique_ptr> createConvertTorchToStablehloPass(); + +// Convenience wrapper for users who want to pass options as individual +// parameters std::unique_ptr> createConvertTorchToStablehloPass(bool enableStaticShape, bool enableI32Index); + } // namespace torch } // namespace mlir diff --git a/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h b/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h index 8ee6fecaa015..c9d9688e04fd 100644 --- a/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h +++ b/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h @@ -19,6 +19,9 @@ namespace mlir { namespace torch { +#define GEN_PASS_DECL_CONVERTTORCHTOTOSA +#include "torch-mlir/Conversion/Passes.h.inc" + /// Collect a set of legal/illegal ops for converting Torch operations to Tosa /// dialect. void populateTorchToTosaConversionLegalOps(ConversionTarget &target); @@ -30,8 +33,12 @@ populateTorchToTosaConversionPatternsAndIllegalOps(TypeConverter &typeConverter, RewritePatternSet &patterns); std::unique_ptr> createConvertTorchToTosaPass(); + +// Convenience wrapper for users who want to pass options as individual +// parameters std::unique_ptr> createConvertTorchToTosaPass(bool requireFullTosaConversion); + } // namespace torch } // namespace mlir diff --git a/include/torch-mlir/RefBackend/Passes.h b/include/torch-mlir/RefBackend/Passes.h index be5e43a1e63c..19879ace8194 100644 --- a/include/torch-mlir/RefBackend/Passes.h +++ b/include/torch-mlir/RefBackend/Passes.h @@ -15,25 +15,12 @@ #include "mlir/Pass/PassManager.h" namespace mlir { -class ModuleOp; - namespace torch { namespace RefBackend { /// Registers all RefBackend passes. void registerRefBackendPasses(); -std::unique_ptr> createMungeCallingConventionsPass(); - -std::unique_ptr> createExpandOpsForLLVMPass(); - -std::unique_ptr> createMLProgramBufferizePass(); - -std::unique_ptr> createMungeMemrefCopyPass(); - -std::unique_ptr> createGeneralizeTensorConcatPass(); - -std::unique_ptr> createGeneralizeTensorPadPass(); } // namespace RefBackend } // namespace torch } // namespace mlir diff --git a/include/torch-mlir/RefBackend/Passes.td b/include/torch-mlir/RefBackend/Passes.td index 3d8b7fd41b1b..2f08518f92c2 100644 --- a/include/torch-mlir/RefBackend/Passes.td +++ b/include/torch-mlir/RefBackend/Passes.td @@ -14,35 +14,29 @@ include "mlir/Pass/PassBase.td" def MungeCallingConventions : Pass<"refback-munge-calling-conventions", "ModuleOp"> { let summary = "Munge calling conventions for calling via ExecutionEngine"; - let constructor = "mlir::torch::RefBackend::createMungeCallingConventionsPass();"; let dependentDialects = ["memref::MemRefDialect"]; } def MLProgramBufferize: Pass<"refback-mlprogram-bufferize", "ModuleOp"> { let summary = "Bufferize the MLProgram dialect ops"; - let constructor = "mlir::torch::RefBackend::createMLProgramBufferizePass();"; let dependentDialects = ["memref::MemRefDialect"]; } def ExpandOpsForLLVM : Pass<"refback-expand-ops-for-llvm", "func::FuncOp"> { let summary = "Expand ops into more primitive ops before LLVM lowering."; - let constructor = "mlir::torch::RefBackend::createExpandOpsForLLVMPass();"; } def MungeMemrefCopy : Pass<"refback-munge-memref-copy", "func::FuncOp"> { let summary = "Munge memref.copy to linalg.copy"; - let constructor = "mlir::torch::RefBackend::createMungeMemrefCopyPass();"; let dependentDialects = ["memref::MemRefDialect"]; } def GeneralizeTensorConcat : Pass<"refback-generalize-tensor-concat", "func::FuncOp"> { let summary = "Convert tensor.concat to other tensor ops"; - let constructor = "mlir::torch::RefBackend::createGeneralizeTensorConcatPass()"; } def GeneralizeTensorPad : Pass<"refback-generalize-tensor-pad", "func::FuncOp"> { let summary = "Convert tensor.pad to linalg ops"; - let constructor = "mlir::torch::RefBackend::createGeneralizeTensorPadPass()"; } #endif // TORCHMLIR_REFBACKEND_PASSES diff --git a/lib/Conversion/PassDetail.h b/lib/Conversion/PassDetail.h deleted file mode 100644 index aa832141f1de..000000000000 --- a/lib/Conversion/PassDetail.h +++ /dev/null @@ -1,26 +0,0 @@ -//===- PassDetail.h - Conversion Pass class details -------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// Also available under a BSD-style license. See LICENSE. -// -//===----------------------------------------------------------------------===// - -#ifndef TORCHMLIR_CONVERSION_PASSDETAIL_H -#define TORCHMLIR_CONVERSION_PASSDETAIL_H - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/Pass/Pass.h" - -namespace mlir { -namespace torch { - -#define GEN_PASS_CLASSES -#include "torch-mlir/Conversion/Passes.h.inc" - -} // namespace torch -} // end namespace mlir - -#endif // TORCHMLIR_CONVERSION_PASSDETAIL_H diff --git a/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp b/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp index e5dee5b4c3bb..d45506a16088 100644 --- a/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp +++ b/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp @@ -8,8 +8,10 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" +#include "torch-mlir/Conversion/Passes.h" -#include "../PassDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/MLProgram/IR/MLProgram.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -20,6 +22,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; using namespace mlir::torch::TorchConversion; +namespace mlir::torch { + +#define GEN_PASS_DEF_CONVERTTORCHCONVERSIONTOMLPROGRAM +#include "torch-mlir/Conversion/Passes.h.inc" static constexpr StringRef getSeedGobalVarName() { return "global_seed"; } @@ -102,7 +108,7 @@ class ConvertGetNextSeedOp : public OpConversionPattern { namespace { class ConvertTorchConversionToMLProgram - : public ConvertTorchConversionToMLProgramBase< + : public impl::ConvertTorchConversionToMLProgramBase< ConvertTorchConversionToMLProgram> { public: void getDependentDialects(DialectRegistry ®istry) const override { @@ -138,6 +144,8 @@ class ConvertTorchConversionToMLProgram } // namespace std::unique_ptr> -mlir::torch::createConvertTorchConversionToMLProgramPass() { +createConvertTorchConversionToMLProgramPass() { return std::make_unique(); } + +} // namespace mlir::torch diff --git a/lib/Conversion/TorchOnnxToTorch/PassDetail.h b/lib/Conversion/TorchOnnxToTorch/PassDetail.h deleted file mode 100644 index bbcd3413c59c..000000000000 --- a/lib/Conversion/TorchOnnxToTorch/PassDetail.h +++ /dev/null @@ -1,24 +0,0 @@ -//===------------------------------------------------------------*- C++ -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// Also available under a BSD-style license. See LICENSE. -// -//===----------------------------------------------------------------------===// - -#ifndef TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_PASSDETAIL_H -#define TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_PASSDETAIL_H - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/Pass/Pass.h" - -namespace mlir::torch::onnx_c { - -#define GEN_PASS_CLASSES -#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h.inc" - -} // namespace mlir::torch::onnx_c - -#endif // TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_PASSDETAIL_H diff --git a/lib/Conversion/TorchOnnxToTorch/TorchOnnxToTorch.cpp b/lib/Conversion/TorchOnnxToTorch/TorchOnnxToTorch.cpp index d2f7517376d8..cc42d947175e 100644 --- a/lib/Conversion/TorchOnnxToTorch/TorchOnnxToTorch.cpp +++ b/lib/Conversion/TorchOnnxToTorch/TorchOnnxToTorch.cpp @@ -7,7 +7,8 @@ // //===----------------------------------------------------------------------===// -#include "./PassDetail.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h" #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" @@ -19,6 +20,10 @@ using llvm::dbgs; using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::onnx_c; +namespace mlir::torch::onnx_c { + +#define GEN_PASS_DEF_CONVERTTORCHONNXTOTORCH +#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h.inc" #define DEBUG_TYPE "torch-onnx" @@ -37,7 +42,7 @@ int64_t getDefaultOpsetVersion(Operation *containerOp) { } class ConvertTorchOnnxToTorch - : public ConvertTorchOnnxToTorchBase { + : public impl::ConvertTorchOnnxToTorchBase { public: ConvertTorchOnnxToTorch() = default; void runOnOperation() override { @@ -82,7 +87,8 @@ class ConvertTorchOnnxToTorch } // namespace -std::unique_ptr> -mlir::torch::onnx_c::createTorchOnnxToTorchPass() { +std::unique_ptr> createTorchOnnxToTorchPass() { return std::make_unique(); } + +} // namespace mlir::torch::onnx_c diff --git a/lib/Conversion/TorchToArith/TorchToArith.cpp b/lib/Conversion/TorchToArith/TorchToArith.cpp index 17614f95ea16..2dd15ef2c651 100644 --- a/lib/Conversion/TorchToArith/TorchToArith.cpp +++ b/lib/Conversion/TorchToArith/TorchToArith.cpp @@ -8,8 +8,10 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToArith/TorchToArith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" +#include "torch-mlir/Conversion/Passes.h" -#include "../PassDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Math/IR/Math.h" @@ -25,6 +27,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch { + +#define GEN_PASS_DEF_CONVERTTORCHTOARITH +#include "torch-mlir/Conversion/Passes.h.inc" // ----------------------------------------------------------------------------- // Patterns (as this grows, it should be organized into multiple files) @@ -407,7 +413,7 @@ class ConvertAtenBoolLikeOp : public OpConversionPattern { namespace { class ConvertTorchToArith - : public ConvertTorchToArithBase { + : public impl::ConvertTorchToArithBase { public: void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); @@ -565,7 +571,8 @@ class ConvertTorchToArith }; } // namespace -std::unique_ptr> -mlir::torch::createConvertTorchToArithPass() { +std::unique_ptr> createConvertTorchToArithPass() { return std::make_unique(); } + +} // namespace mlir::torch diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp index 01b1d4b973b6..8b0c6ab8ad19 100644 --- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp +++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp @@ -8,8 +8,10 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" +#include "torch-mlir/Conversion/Passes.h" -#include "../PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" @@ -24,6 +26,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch { + +#define GEN_PASS_DEF_CONVERTTORCHTOLINALG +#include "torch-mlir/Conversion/Passes.h.inc" // ----------------------------------------------------------------------------- // The pass @@ -34,7 +40,7 @@ using namespace mlir::torch::Torch; namespace { class ConvertTorchToLinalg - : public ConvertTorchToLinalgBase { + : public impl::ConvertTorchToLinalgBase { public: void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); @@ -89,7 +95,8 @@ class ConvertTorchToLinalg }; } // namespace -std::unique_ptr> -mlir::torch::createConvertTorchToLinalgPass() { +std::unique_ptr> createConvertTorchToLinalgPass() { return std::make_unique(); } + +} // namespace mlir::torch diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index c2b584efdecc..8630e6a7ac1a 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -8,7 +8,6 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Tensor/Utils/Utils.h" -#include "../PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" diff --git a/lib/Conversion/TorchToSCF/TorchToSCF.cpp b/lib/Conversion/TorchToSCF/TorchToSCF.cpp index 8978a75c01a4..57ee17700f53 100644 --- a/lib/Conversion/TorchToSCF/TorchToSCF.cpp +++ b/lib/Conversion/TorchToSCF/TorchToSCF.cpp @@ -8,8 +8,10 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" +#include "torch-mlir/Conversion/Passes.h" -#include "../PassDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Transforms/DialectConversion.h" @@ -21,6 +23,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch { + +#define GEN_PASS_DEF_CONVERTTORCHTOSCF +#include "torch-mlir/Conversion/Passes.h.inc" namespace { class ConvertTorchPrimIfYieldOp : public OpConversionPattern { @@ -312,7 +318,8 @@ class ConvertTorchPrimLoopForLikeOp : public OpConversionPattern { } // namespace namespace { -class ConvertTorchToSCF : public ConvertTorchToSCFBase { +class ConvertTorchToSCF + : public impl::ConvertTorchToSCFBase { public: void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); @@ -345,7 +352,8 @@ class ConvertTorchToSCF : public ConvertTorchToSCFBase { }; } // namespace -std::unique_ptr> -mlir::torch::createConvertTorchToSCFPass() { +std::unique_ptr> createConvertTorchToSCFPass() { return std::make_unique(); } + +} // namespace mlir::torch diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index a22e6658a2ac..2b6f4a90ef7e 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -9,7 +9,6 @@ #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" -#include "../PassDetail.h" #include "PopulatePatterns.h" #include "Utils.h" diff --git a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp index 8ebb7050b124..95435dd5805b 100644 --- a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp +++ b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp @@ -9,7 +9,6 @@ #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" -#include "../PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/lib/Conversion/TorchToStablehlo/Linear.cpp b/lib/Conversion/TorchToStablehlo/Linear.cpp index 56094c8d0f52..892a0158667b 100644 --- a/lib/Conversion/TorchToStablehlo/Linear.cpp +++ b/lib/Conversion/TorchToStablehlo/Linear.cpp @@ -9,7 +9,6 @@ #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" -#include "../PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/lib/Conversion/TorchToStablehlo/Pooling.cpp b/lib/Conversion/TorchToStablehlo/Pooling.cpp index 5c0ecb19c5a4..45982e108d00 100644 --- a/lib/Conversion/TorchToStablehlo/Pooling.cpp +++ b/lib/Conversion/TorchToStablehlo/Pooling.cpp @@ -9,7 +9,6 @@ #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" -#include "../PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/lib/Conversion/TorchToStablehlo/Reduction.cpp b/lib/Conversion/TorchToStablehlo/Reduction.cpp index f66d9e040951..ec078080708a 100644 --- a/lib/Conversion/TorchToStablehlo/Reduction.cpp +++ b/lib/Conversion/TorchToStablehlo/Reduction.cpp @@ -9,7 +9,6 @@ #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" -#include "../PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/lib/Conversion/TorchToStablehlo/Rng.cpp b/lib/Conversion/TorchToStablehlo/Rng.cpp index b71af126c69e..c7627431cf56 100644 --- a/lib/Conversion/TorchToStablehlo/Rng.cpp +++ b/lib/Conversion/TorchToStablehlo/Rng.cpp @@ -9,7 +9,6 @@ #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" -#include "../PassDetail.h" #include "./PopulatePatterns.h" #include "stablehlo/dialect/StablehloOps.h" diff --git a/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp b/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp index 5a7eb398dc9b..03d36e0ec91b 100644 --- a/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp +++ b/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp @@ -8,8 +8,10 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" +#include "torch-mlir/Conversion/Passes.h" -#include "../PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -23,17 +25,18 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch { + +#define GEN_PASS_DEF_CONVERTTORCHTOSTABLEHLO +#include "torch-mlir/Conversion/Passes.h.inc" namespace { class ConvertTorchToStablehlo - : public ConvertTorchToStablehloBase { + : public impl::ConvertTorchToStablehloBase { public: - ConvertTorchToStablehlo() = default; - ConvertTorchToStablehlo(bool enableStaticShape, bool enableI32Index) { - this->enableStaticShape = enableStaticShape; - this->enableI32Index = enableI32Index; - } + using impl::ConvertTorchToStablehloBase< + ConvertTorchToStablehlo>::ConvertTorchToStablehloBase; void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); @@ -86,14 +89,20 @@ class ConvertTorchToStablehlo } // namespace +// Default pass creation function (required by tablegen) std::unique_ptr> -mlir::torch::createConvertTorchToStablehloPass() { - return std::make_unique(false, false); +createConvertTorchToStablehloPass() { + return std::make_unique(); } +// Convenience wrapper for users who want to pass options as individual +// parameters std::unique_ptr> -mlir::torch::createConvertTorchToStablehloPass(bool enableStaticShape, - bool enableI32Index) { - return std::make_unique(enableStaticShape, - enableI32Index); +createConvertTorchToStablehloPass(bool enableStaticShape, bool enableI32Index) { + ConvertTorchToStablehloOptions options; + options.enableStaticShape = enableStaticShape; + options.enableI32Index = enableI32Index; + return std::make_unique(options); } + +} // namespace mlir::torch diff --git a/lib/Conversion/TorchToStablehlo/Uncategorized.cpp b/lib/Conversion/TorchToStablehlo/Uncategorized.cpp index f8af1529ff15..5026e8dd09ee 100644 --- a/lib/Conversion/TorchToStablehlo/Uncategorized.cpp +++ b/lib/Conversion/TorchToStablehlo/Uncategorized.cpp @@ -10,7 +10,6 @@ #include "mlir/IR/BuiltinAttributes.h" #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" -#include "../PassDetail.h" #include "PopulatePatterns.h" #include "Utils.h" diff --git a/lib/Conversion/TorchToStablehlo/ViewLike.cpp b/lib/Conversion/TorchToStablehlo/ViewLike.cpp index af48f84fc357..632d64a3eae1 100644 --- a/lib/Conversion/TorchToStablehlo/ViewLike.cpp +++ b/lib/Conversion/TorchToStablehlo/ViewLike.cpp @@ -9,7 +9,6 @@ #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" -#include "../PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp index 60a04bbd7e55..cb5baab07b67 100644 --- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp +++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp @@ -8,8 +8,10 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" +#include "torch-mlir/Conversion/Passes.h" -#include "../PassDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" @@ -33,6 +35,10 @@ using namespace mlir::torch; using namespace mlir::torch::Torch; using namespace mlir::torch::TorchConversion; using namespace mlir::torch::TMTensor; +namespace mlir::torch { + +#define GEN_PASS_DEF_CONVERTTORCHTOTMTENSOR +#include "torch-mlir/Conversion/Passes.h.inc" // ----------------------------------------------------------------------------- // Patterns (as this grows, it should be organized into multiple files) @@ -2459,7 +2465,7 @@ class ConvertAtenKthvalueOp : public OpConversionPattern { namespace { class ConvertTorchToTMTensor - : public ConvertTorchToTMTensorBase { + : public impl::ConvertTorchToTMTensorBase { public: void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); @@ -2519,6 +2525,8 @@ class ConvertTorchToTMTensor } // namespace std::unique_ptr> -mlir::torch::createConvertTorchToTMTensorPass() { +createConvertTorchToTMTensorPass() { return std::make_unique(); } + +} // namespace mlir::torch diff --git a/lib/Conversion/TorchToTensor/TorchToTensor.cpp b/lib/Conversion/TorchToTensor/TorchToTensor.cpp index 10fd2a160d0d..890ca4ec1860 100644 --- a/lib/Conversion/TorchToTensor/TorchToTensor.cpp +++ b/lib/Conversion/TorchToTensor/TorchToTensor.cpp @@ -8,8 +8,9 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToTensor/TorchToTensor.h" - -#include "../PassDetail.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" +#include "torch-mlir/Conversion/Passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -21,6 +22,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch { + +#define GEN_PASS_DEF_CONVERTTORCHTOTENSOR +#include "torch-mlir/Conversion/Passes.h.inc" namespace { @@ -139,7 +144,7 @@ class ConvertAtenTensorOpPattern : public OpConversionPattern { }; class ConvertTorchToTensor - : public ConvertTorchToTensorBase { + : public impl::ConvertTorchToTensorBase { public: void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); @@ -170,7 +175,8 @@ class ConvertTorchToTensor } // namespace -std::unique_ptr> -mlir::torch::createConvertTorchToTensorPass() { +std::unique_ptr> createConvertTorchToTensorPass() { return std::make_unique(); } + +} // namespace mlir::torch diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index c959f06c6a66..850ca3f3cfb9 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -8,13 +8,15 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" -#include "../PassDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "mlir/IR/Matchers.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" +#include "torch-mlir/Conversion/Passes.h" #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h" #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h" #include "torch-mlir/Conversion/Utils/Utils.h" @@ -34,6 +36,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch { + +#define GEN_PASS_DEF_CONVERTTORCHTOTOSA +#include "torch-mlir/Conversion/Passes.h.inc" namespace { @@ -9033,12 +9039,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // ----------------------------------------------------------------------------- namespace { -class ConvertTorchToTosa : public ConvertTorchToTosaBase { +class ConvertTorchToTosa + : public impl::ConvertTorchToTosaBase { public: - ConvertTorchToTosa() = default; - ConvertTorchToTosa(bool requireFullTosaConversion) { - this->requireFullTosaConversion = requireFullTosaConversion; - } + using impl::ConvertTorchToTosaBase< + ConvertTorchToTosa>::ConvertTorchToTosaBase; void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); @@ -9081,7 +9086,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { }; } // namespace -void torch::populateTorchToTosaConversionLegalOps(ConversionTarget &target) { +void populateTorchToTosaConversionLegalOps(ConversionTarget &target) { // The following ops are never the primary reason why lowering fails. // The backend contract only allows functions to return tensors thus there // is always another op using them. @@ -9098,7 +9103,7 @@ void torch::populateTorchToTosaConversionLegalOps(ConversionTarget &target) { target.addLegalOp(); } -std::set torch::populateTorchToTosaConversionPatternsAndIllegalOps( +std::set populateTorchToTosaConversionPatternsAndIllegalOps( TypeConverter &typeConverter, RewritePatternSet &patterns) { MLIRContext *context = patterns.getContext(); @@ -9411,12 +9416,18 @@ std::set torch::populateTorchToTosaConversionPatternsAndIllegalOps( return illegalOps; } -std::unique_ptr> -mlir::torch::createConvertTorchToTosaPass() { - return std::make_unique(true); +// Default pass creation function (required by tablegen) +std::unique_ptr> createConvertTorchToTosaPass() { + return std::make_unique(); } +// Convenience wrapper for users who want to pass options as individual +// parameters std::unique_ptr> -mlir::torch::createConvertTorchToTosaPass(bool requireFullTosaConversion) { - return std::make_unique(requireFullTosaConversion); +createConvertTorchToTosaPass(bool requireFullTosaConversion) { + ConvertTorchToTosaOptions options; + options.requireFullTosaConversion = requireFullTosaConversion; + return std::make_unique(options); } + +} // namespace mlir::torch diff --git a/lib/Dialect/TMTensor/Transforms/Bufferize.cpp b/lib/Dialect/TMTensor/Transforms/Bufferize.cpp index ca47cdd6033a..f70099e0a478 100644 --- a/lib/Dialect/TMTensor/Transforms/Bufferize.cpp +++ b/lib/Dialect/TMTensor/Transforms/Bufferize.cpp @@ -23,11 +23,14 @@ #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h" #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h" -#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/PassDetail.h" #include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h" using namespace ::mlir; using namespace ::mlir::torch::TMTensor; +namespace mlir::torch::TMTensor { + +#define GEN_PASS_DEF_TMTENSORBUFFERIZE +#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h.inc" static Value cloneMemref(Location loc, Value memref, OpBuilder &b) { auto memrefType = cast(memref.getType()); @@ -134,7 +137,7 @@ static Value materializeToTensor(OpBuilder &builder, TensorType type, /// Converts TMTensor operations that work on tensor-type operands or results to /// work on buffers. struct TMTensorBufferizePass - : public TMTensorBufferizeBase { + : public impl::TMTensorBufferizeBase { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); @@ -202,7 +205,8 @@ struct TMTensorBufferizePass }; } // namespace -std::unique_ptr> -torch::TMTensor::createTMTensorBufferizePass() { +std::unique_ptr> createTMTensorBufferizePass() { return std::make_unique(); } + +} // namespace mlir::torch::TMTensor diff --git a/lib/Dialect/TMTensor/Transforms/ConvertToLoops.cpp b/lib/Dialect/TMTensor/Transforms/ConvertToLoops.cpp index 74d539ab6d8a..5c7755d210fa 100644 --- a/lib/Dialect/TMTensor/Transforms/ConvertToLoops.cpp +++ b/lib/Dialect/TMTensor/Transforms/ConvertToLoops.cpp @@ -20,7 +20,6 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h" #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h" -#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/PassDetail.h" #include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" @@ -28,6 +27,10 @@ using namespace mlir; using namespace mlir::torch::TMTensor; +namespace mlir::torch::TMTensor { + +#define GEN_PASS_DEF_TMTENSORTOLOOPS +#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h.inc" /// Recursive method that lowers one dimension of the `ScalarLoopOpInterface` to /// scalar loops at a time. @@ -98,7 +101,8 @@ struct ScalarLoopOpInterfaceLowerToLoopsPattern : public RewritePattern { //===----------------------------------------------------------------------===// namespace { -struct TMTensorToLoopsPass : public TMTensorToLoopsBase { +struct TMTensorToLoopsPass + : public impl::TMTensorToLoopsBase { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert { }; } // namespace -std::unique_ptr> -torch::TMTensor::createTMTensorToLoopsPass() { +std::unique_ptr> createTMTensorToLoopsPass() { return std::make_unique(); } + +} // namespace mlir::torch::TMTensor diff --git a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp index 25a38c83627c..37c54b51d874 100644 --- a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp +++ b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" @@ -17,6 +17,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DEF_ADJUSTCALLINGCONVENTIONS +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" // Map from func name and arg index to the type bound for that arg. // This is needed because to rewrite calls, we need the non-local information @@ -285,7 +289,7 @@ static LogicalResult adjustCallingConventions(func::FuncOp func, namespace { class AdjustCallingConventionsPass - : public AdjustCallingConventionsBase { + : public impl::AdjustCallingConventionsBase { void runOnOperation() override { auto module = getOperation(); TypeBoundMap typeBoundMap; @@ -306,7 +310,8 @@ class AdjustCallingConventionsPass }; } // namespace -std::unique_ptr> -mlir::torch::Torch::createAdjustCallingConventionsPass() { +std::unique_ptr> createAdjustCallingConventionsPass() { return std::make_unique(); } + +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index c9c42b43c463..08b25c9b6f60 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -7,11 +7,11 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" @@ -27,6 +27,11 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DECL_DECOMPOSECOMPLEXOPS +#define GEN_PASS_DEF_DECOMPOSECOMPLEXOPS +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" // Helper function to check whether the `dtype` is None or Float type. static bool isNoneOrFloatDtype(MLIRContext *context, Value dtype) { @@ -13047,7 +13052,7 @@ class DecomposeAtenAsStridedOp : public OpRewritePattern { namespace { class DecomposeComplexOpsPass - : public DecomposeComplexOpsBase { + : public impl::DecomposeComplexOpsBase { private: llvm::StringSet<> legalOpsSet; @@ -13068,10 +13073,8 @@ class DecomposeComplexOpsPass } public: - DecomposeComplexOpsPass() = default; - DecomposeComplexOpsPass(ArrayRef legalOps) { - this->legalOps = legalOps; - } + using impl::DecomposeComplexOpsBase< + DecomposeComplexOpsPass>::DecomposeComplexOpsBase; void runOnOperation() override { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); @@ -13392,7 +13395,10 @@ class DecomposeComplexOpsPass } // namespace std::unique_ptr> -mlir::torch::Torch::createDecomposeComplexOpsPass( - ArrayRef legalOps) { - return std::make_unique(legalOps); +createDecomposeComplexOpsPass(ArrayRef legalOps) { + DecomposeComplexOpsOptions options; + options.legalOps.append(legalOps.begin(), legalOps.end()); + return std::make_unique(options); } + +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/Torch/Transforms/DropAbstractInterpCalculations.cpp b/lib/Dialect/Torch/Transforms/DropAbstractInterpCalculations.cpp index c3236c0324d1..f7cf7e5f4384 100644 --- a/lib/Dialect/Torch/Transforms/DropAbstractInterpCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/DropAbstractInterpCalculations.cpp @@ -7,8 +7,8 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" @@ -17,6 +17,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DEF_DROPABSTRACTINTERPCALCULATIONS +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" namespace { template @@ -39,7 +43,7 @@ class DropCalculateOp : public OpConversionPattern { namespace { class DropAbstractInterpCalculationsPass - : public DropAbstractInterpCalculationsBase< + : public impl::DropAbstractInterpCalculationsBase< DropAbstractInterpCalculationsPass> { void runOnOperation() override { MLIRContext *context = &getContext(); @@ -61,6 +65,8 @@ class DropAbstractInterpCalculationsPass } // namespace std::unique_ptr> -mlir::torch::Torch::createDropAbstractInterpCalculationsPass() { +createDropAbstractInterpCalculationsPass() { return std::make_unique(); } + +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/Torch/Transforms/EraseModuleInitializer.cpp b/lib/Dialect/Torch/Transforms/EraseModuleInitializer.cpp index db80714127e1..7602169f74ab 100644 --- a/lib/Dialect/Torch/Transforms/EraseModuleInitializer.cpp +++ b/lib/Dialect/Torch/Transforms/EraseModuleInitializer.cpp @@ -7,10 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IRMapping.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" @@ -18,10 +17,14 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DEF_ERASEMODULEINITIALIZER +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" namespace { class EraseModuleInitializerPass - : public EraseModuleInitializerBase { + : public impl::EraseModuleInitializerBase { void runOnOperation() override { for (auto initializer : getOperation().getOps()) { @@ -37,7 +40,8 @@ class EraseModuleInitializerPass }; } // namespace -std::unique_ptr> -mlir::torch::Torch::createEraseModuleInitializerPass() { +std::unique_ptr> createEraseModuleInitializerPass() { return std::make_unique(); } + +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp index e418a4b08ec0..4b04151f68cc 100644 --- a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp +++ b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp @@ -7,8 +7,8 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" @@ -18,6 +18,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DEF_FUSEQUANTIZEDOPS +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" namespace { @@ -438,7 +442,8 @@ template class RemoveUnused : public OpRewritePattern { } }; -class FuseQuantizedOpsPass : public FuseQuantizedOpsBase { +class FuseQuantizedOpsPass + : public impl::FuseQuantizedOpsBase { public: void runOnOperation() override { MLIRContext *context = &getContext(); @@ -470,7 +475,8 @@ class FuseQuantizedOpsPass : public FuseQuantizedOpsBase { } // namespace -std::unique_ptr> -mlir::torch::Torch::createFuseQuantizedOpsPass() { +std::unique_ptr> createFuseQuantizedOpsPass() { return std::make_unique(); } + +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp b/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp index d47220e348ea..a3f15c07183e 100644 --- a/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp +++ b/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp @@ -7,10 +7,10 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IRMapping.h" +#include "mlir/Pass/Pass.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "llvm/ADT/STLExtras.h" @@ -21,6 +21,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DEF_GLOBALIZEOBJECTGRAPH +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" static FailureOr findRootNnModule(ModuleOp module) { NnModuleOp rootNnModule; @@ -299,6 +303,8 @@ struct Monomorphization { }; } // namespace +} // namespace mlir::torch::Torch + template <> struct llvm::DenseMapInfo { static Monomorphization getEmptyKey() { return Monomorphization{nullptr, {ArgInstance{-1, nullptr}}}; @@ -318,6 +324,8 @@ template <> struct llvm::DenseMapInfo { } }; +namespace mlir::torch::Torch { + // Populate `mapping` such that values of NnModuleType in the function are // mapped to appropriate global objects of NnModuleType. // @@ -696,7 +704,7 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) { namespace { class GlobalizeObjectGraphPass - : public GlobalizeObjectGraphBase { + : public impl::GlobalizeObjectGraphBase { void runOnOperation() override { if (failed(globalizeObjectGraph(getOperation()))) return signalPassFailure(); @@ -704,7 +712,7 @@ class GlobalizeObjectGraphPass }; } // namespace -std::unique_ptr> -mlir::torch::Torch::createGlobalizeObjectGraphPass() { +std::unique_ptr> createGlobalizeObjectGraphPass() { return std::make_unique(); } +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp b/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp index 12660cfee47c..d6308f11b8d8 100644 --- a/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp +++ b/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp @@ -23,12 +23,11 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - #include "mlir/Analysis/DataFlowFramework.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IRMapping.h" +#include "mlir/Pass/Pass.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "llvm/Support/Debug.h" @@ -38,6 +37,11 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DEF_FLATSYMBOLREFLATTICEANCHOR +#define GEN_PASS_DEF_INLINEGLOBALSLOTS +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" /// A program point representing a symbol. /// @@ -276,7 +280,7 @@ static bool isInitialValueTransitivelySafeToInline(Value initialValue, namespace { class InlineGlobalSlotsPass - : public InlineGlobalSlotsBase { + : public impl::InlineGlobalSlotsBase { void runOnOperation() override { ModuleOp module = getOperation(); @@ -417,7 +421,8 @@ class InlineGlobalSlotsPass }; } // namespace -std::unique_ptr> -mlir::torch::Torch::createInlineGlobalSlotsPass() { +std::unique_ptr> createInlineGlobalSlotsPass() { return std::make_unique(); } + +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index cfc8bb96118b..b149d172496c 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" @@ -24,6 +24,12 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DECL_LOWERTOBACKENDCONTRACT +#define GEN_PASS_DEF_LOWERTOBACKENDCONTRACT +#define GEN_PASS_DEF_VERIFYBACKENDCONTRACTNODECOMPOSITIONS +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" //===----------------------------------------------------------------------===// // Checking the backend contract. @@ -258,19 +264,10 @@ getBackendContractTarget(MLIRContext *context, bool decompose, namespace { class LowerToBackendContractPass - : public LowerToBackendContractBase { + : public impl::LowerToBackendContractBase { public: - LowerToBackendContractPass() = default; - LowerToBackendContractPass(int maxIterations, bool decompose, - bool shapeDtypeRefine, - ArrayRef backendLegalOps, - StringRef extraLibrary) { - this->maxIterations = maxIterations; - this->decompose = decompose; - this->shapeDtypeRefine = shapeDtypeRefine; - this->backendLegalOps = backendLegalOps; - this->extraLibrary = extraLibrary.str(); - } + using impl::LowerToBackendContractBase< + LowerToBackendContractPass>::LowerToBackendContractBase; void runOnOperation() override { ModuleOp module = getOperation(); MLIRContext *context = &getContext(); @@ -317,7 +314,7 @@ class LowerToBackendContractPass }; class VerifyBackendContractNoDecompositionsPass - : public VerifyBackendContractNoDecompositionsBase< + : public impl::VerifyBackendContractNoDecompositionsBase< VerifyBackendContractNoDecompositionsPass> { public: VerifyBackendContractNoDecompositionsPass() = default; @@ -336,17 +333,21 @@ class VerifyBackendContractNoDecompositionsPass }; } // namespace -std::unique_ptr> -mlir::torch::Torch::createLowerToBackendContractPass( +std::unique_ptr> createLowerToBackendContractPass( int maxIterations, bool decompose, bool shapeDtypeRefine, ArrayRef backendLegalOps, StringRef extraLibrary) { - return std::make_unique( - maxIterations, decompose, shapeDtypeRefine, backendLegalOps, - extraLibrary); + LowerToBackendContractOptions options; + options.maxIterations = maxIterations; + options.decompose = decompose; + options.shapeDtypeRefine = shapeDtypeRefine; + options.backendLegalOps.append(backendLegalOps.begin(), + backendLegalOps.end()); + options.extraLibrary = extraLibrary.str(); + return std::make_unique(options); } std::unique_ptr> -mlir::torch::Torch::createVerifyBackendContractNoDecompositionsPass() { +createVerifyBackendContractNoDecompositionsPass() { return std::make_unique(); } @@ -606,3 +607,5 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, return backendLegalOpsSet.contains(opName); }); } + +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp b/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp index e5c279415c7f..d873892ba61e 100644 --- a/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp +++ b/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp @@ -7,8 +7,8 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" @@ -17,6 +17,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DEF_MATCHQUANTIZEDCUSTOMOPS +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" namespace { @@ -115,7 +119,7 @@ class MatchQuantizeOperator : public OpRewritePattern { }; class MatchQuantizedCustomOpsPass - : public MatchQuantizedCustomOpsBase { + : public impl::MatchQuantizedCustomOpsBase { public: void runOnOperation() override { MLIRContext *context = &getContext(); @@ -132,6 +136,8 @@ class MatchQuantizedCustomOpsPass } // namespace std::unique_ptr> -mlir::torch::Torch::createMatchQuantizedCustomOpsPass() { +createMatchQuantizedCustomOpsPass() { return std::make_unique(); } + +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp index 10580b81876b..aec53aa6535a 100644 --- a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp +++ b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp @@ -7,10 +7,10 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" @@ -19,6 +19,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DEF_MAXIMIZEVALUESEMANTICS +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" static Value assertNonValueTensor(Value tensor) { assert(isa(tensor.getType()) && @@ -364,7 +368,7 @@ class RewriteViewLikeSubgraph namespace { class MaximizeValueSemanticsPass - : public MaximizeValueSemanticsBase { + : public impl::MaximizeValueSemanticsBase { void runOnOperation() override { MLIRContext *context = &getContext(); auto func = getOperation(); @@ -379,6 +383,8 @@ class MaximizeValueSemanticsPass } // namespace std::unique_ptr> -mlir::torch::Torch::createMaximizeValueSemanticsPass() { +createMaximizeValueSemanticsPass() { return std::make_unique(); } + +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/Torch/Transforms/PassDetail.h b/lib/Dialect/Torch/Transforms/PassDetail.h deleted file mode 100644 index 85fc116fe5ae..000000000000 --- a/lib/Dialect/Torch/Transforms/PassDetail.h +++ /dev/null @@ -1,28 +0,0 @@ -//===- PassDetail.h - Pass details ------------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// Also available under a BSD-style license. See LICENSE. -// -//===----------------------------------------------------------------------===// - -#ifndef TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H -#define TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Pass/Pass.h" - -namespace mlir { -class ModuleOp; -namespace torch { -namespace Torch { - -#define GEN_PASS_CLASSES -#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" - -} // namespace Torch -} // namespace torch -} // end namespace mlir - -#endif // TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H diff --git a/lib/Dialect/Torch/Transforms/PrepareForGlobalizeObjectGraph.cpp b/lib/Dialect/Torch/Transforms/PrepareForGlobalizeObjectGraph.cpp index c7ff95270d98..a53c24954627 100644 --- a/lib/Dialect/Torch/Transforms/PrepareForGlobalizeObjectGraph.cpp +++ b/lib/Dialect/Torch/Transforms/PrepareForGlobalizeObjectGraph.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" @@ -18,6 +18,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DEF_PREPAREFORGLOBALIZEOBJECTGRAPH +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" namespace { class ConvertPrimCallMethodToCall : public OpRewritePattern { @@ -63,7 +67,7 @@ class EraseUnusedConstantOp : public OpRewritePattern { namespace { class PrepareForGlobalizeObjectGraphPass - : public PrepareForGlobalizeObjectGraphBase< + : public impl::PrepareForGlobalizeObjectGraphBase< PrepareForGlobalizeObjectGraphPass> { void runOnOperation() override { @@ -105,6 +109,8 @@ class PrepareForGlobalizeObjectGraphPass } // namespace std::unique_ptr> -mlir::torch::Torch::createPrepareForGlobalizeObjectGraphPass() { +createPrepareForGlobalizeObjectGraphPass() { return std::make_unique(); } + +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp index 1d7c926473c2..dc533427d773 100644 --- a/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp @@ -7,8 +7,8 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" @@ -18,6 +18,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DEF_RECOMPOSECOMPLEXOPS +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" namespace { @@ -806,7 +810,7 @@ class RecomposeMeshgridIndexingListUnpack namespace { class RecomposeComplexOpsPass - : public RecomposeComplexOpsBase { + : public impl::RecomposeComplexOpsBase { public: void runOnOperation() override { MLIRContext *context = &getContext(); @@ -841,7 +845,8 @@ class RecomposeComplexOpsPass }; } // namespace -std::unique_ptr> -mlir::torch::Torch::createRecomposeComplexOpsPass() { +std::unique_ptr> createRecomposeComplexOpsPass() { return std::make_unique(); } + +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp index b84b4465eab5..187d234183a3 100644 --- a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp +++ b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - #include "ReifyAbstractInterpCalculationsUtils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" @@ -18,6 +18,11 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DECL_REDUCEOPVARIANTS +#define GEN_PASS_DEF_REDUCEOPVARIANTS +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" // Create an overwrite in a manner that preserves the // `OverwriteTensorContentsOp` invariant that both arguments @@ -403,11 +408,8 @@ reduceNonValueTensorLiteralOpToValueTensorLiteralOp(NonValueTensorLiteralOp op, namespace { struct ReduceOpVariantsPass - : public ReduceOpVariantsBase { - ReduceOpVariantsPass() = default; - ReduceOpVariantsPass(StringRef extraLibrary) { - this->extraLibrary = extraLibrary.str(); - } + : public impl::ReduceOpVariantsBase { + using impl::ReduceOpVariantsBase::ReduceOpVariantsBase; void runOnOperation() override { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); @@ -481,6 +483,10 @@ struct ReduceOpVariantsPass } // namespace std::unique_ptr> -mlir::torch::Torch::createReduceOpVariantsPass(StringRef extraLibrary) { - return std::make_unique(extraLibrary); +createReduceOpVariantsPass(StringRef extraLibrary) { + ReduceOpVariantsOptions options; + options.extraLibrary = extraLibrary.str(); + return std::make_unique(options); } + +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp b/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp index 6f45e8876ee1..1040b8e9976d 100644 --- a/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp +++ b/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp @@ -7,20 +7,24 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DEF_REFINEPUBLICRETURN +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" namespace { class RefinePublicReturnPass - : public RefinePublicReturnBase { + : public impl::RefinePublicReturnBase { void runOnOperation() override { auto module = getOperation(); module.walk([&](func::FuncOp func) { @@ -101,7 +105,8 @@ class RefinePublicReturnPass } // namespace -std::unique_ptr> -mlir::torch::Torch::createRefinePublicReturnPass() { +std::unique_ptr> createRefinePublicReturnPass() { return std::make_unique(); } + +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp b/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp index 790fd80a2f71..e1b781957f0e 100644 --- a/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp @@ -7,10 +7,10 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - #include "ReifyAbstractInterpCalculationsUtils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Parser/Parser.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" @@ -18,6 +18,11 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DECL_REIFYDTYPECALCULATIONS +#define GEN_PASS_DEF_REIFYDTYPECALCULATIONS +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" // Massage the op operands to match the dtype function signature. // The dtype function generally takes the same operands as the op, with a few @@ -62,11 +67,10 @@ dtypeFunctionArgsBuilder(OpBuilder &b, Location loc, namespace { struct ReifyDtypeCalculationsPass - : public ReifyDtypeCalculationsBase { - ReifyDtypeCalculationsPass() = default; - ReifyDtypeCalculationsPass(StringRef extraLibrary) { - this->extraLibrary = extraLibrary.str(); - } + : public impl::ReifyDtypeCalculationsBase { + using impl::ReifyDtypeCalculationsBase< + ReifyDtypeCalculationsPass>::ReifyDtypeCalculationsBase; + void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp module = getOperation(); @@ -96,6 +100,10 @@ struct ReifyDtypeCalculationsPass } // namespace std::unique_ptr> -Torch::createReifyDtypeCalculationsPass(StringRef extraLibrary) { - return std::make_unique(extraLibrary); +createReifyDtypeCalculationsPass(StringRef extraLibrary) { + ReifyDtypeCalculationsOptions options; + options.extraLibrary = extraLibrary.str(); + return std::make_unique(options); } + +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/Torch/Transforms/ReifyShapeCalculations.cpp b/lib/Dialect/Torch/Transforms/ReifyShapeCalculations.cpp index 4b81970909d2..81bd5e45a30a 100644 --- a/lib/Dialect/Torch/Transforms/ReifyShapeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/ReifyShapeCalculations.cpp @@ -7,10 +7,10 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - #include "ReifyAbstractInterpCalculationsUtils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Parser/Parser.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" @@ -19,6 +19,11 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DECL_REIFYSHAPECALCULATIONS +#define GEN_PASS_DEF_REIFYSHAPECALCULATIONS +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" static FailureOr> shapeFunctionArgsBuilder(OpBuilder &b, Location loc, @@ -57,11 +62,9 @@ shapeFunctionArgsBuilder(OpBuilder &b, Location loc, namespace { struct ReifyShapeCalculationsPass - : public ReifyShapeCalculationsBase { - ReifyShapeCalculationsPass() = default; - ReifyShapeCalculationsPass(StringRef extraLibrary) { - this->extraLibrary = extraLibrary.str(); - } + : public impl::ReifyShapeCalculationsBase { + using impl::ReifyShapeCalculationsBase< + ReifyShapeCalculationsPass>::ReifyShapeCalculationsBase; void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp module = getOperation(); @@ -95,6 +98,10 @@ struct ReifyShapeCalculationsPass } // namespace std::unique_ptr> -mlir::torch::Torch::createReifyShapeCalculationsPass(StringRef extraLibrary) { - return std::make_unique(extraLibrary); +createReifyShapeCalculationsPass(StringRef extraLibrary) { + ReifyShapeCalculationsOptions options; + options.extraLibrary = extraLibrary.str(); + return std::make_unique(options); } + +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/Torch/Transforms/RestructureNonConstantAxes.cpp b/lib/Dialect/Torch/Transforms/RestructureNonConstantAxes.cpp index 0ea79a02a799..7a5d9270bf77 100644 --- a/lib/Dialect/Torch/Transforms/RestructureNonConstantAxes.cpp +++ b/lib/Dialect/Torch/Transforms/RestructureNonConstantAxes.cpp @@ -8,9 +8,9 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -26,6 +26,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DEF_RESTRUCTURENONCONSTANTAXES +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" namespace { @@ -251,7 +255,8 @@ void populateRestructureNonConstantAxesPattern(RewritePatternSet &patterns, } class RestructureNonConstantAxesPass - : public RestructureNonConstantAxesBase { + : public impl::RestructureNonConstantAxesBase< + RestructureNonConstantAxesPass> { public: RestructureNonConstantAxesPass() = default; @@ -276,6 +281,8 @@ class RestructureNonConstantAxesPass } // namespace std::unique_ptr> -mlir::torch::Torch::createRestructureNonConstantAxesPass() { +createRestructureNonConstantAxesPass() { return std::make_unique(); } + +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp index d6db40d0c182..5c990b720d51 100644 --- a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp +++ b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp @@ -7,12 +7,13 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Iterators.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" @@ -23,6 +24,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DEF_SCALARIZESHAPES +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" namespace { @@ -1534,7 +1539,8 @@ void populateScalarizationRemovePatterns(RewritePatternSet &patterns) { } // namespace namespace { -class ScalarizeShapesPass : public ScalarizeShapesBase { +class ScalarizeShapesPass + : public impl::ScalarizeShapesBase { public: void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); @@ -1615,7 +1621,8 @@ class ScalarizeShapesPass : public ScalarizeShapesBase { }; } // namespace -std::unique_ptr> -mlir::torch::Torch::createScalarizeShapesPass() { +std::unique_ptr> createScalarizeShapesPass() { return std::make_unique(); } + +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp b/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp index 2432a3b4686d..25a3cbabb5b9 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - #include "SimplifyAbstractInterpCalculationsUtils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" @@ -18,6 +18,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DEF_SIMPLIFYDTYPECALCULATIONS +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" static LogicalResult refineDtypeCalculateResult(DtypeCalculateOp op, int resultNum, @@ -192,7 +196,8 @@ class RefineNumToTensorScalarOpType namespace { class SimplifyDtypeCalculationsPass - : public SimplifyDtypeCalculationsBase { + : public impl::SimplifyDtypeCalculationsBase< + SimplifyDtypeCalculationsPass> { void runOnOperation() override { MLIRContext *context = &getContext(); @@ -222,6 +227,8 @@ class SimplifyDtypeCalculationsPass } // namespace std::unique_ptr> -mlir::torch::Torch::createSimplifyDtypeCalculationsPass() { +createSimplifyDtypeCalculationsPass() { return std::make_unique(); } + +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp index 54a9fb07d72b..0f78e668310f 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - #include "SimplifyAbstractInterpCalculationsUtils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" @@ -17,6 +17,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DEF_SIMPLIFYSHAPECALCULATIONS +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" namespace { class DecomposeAtenSizeOp : public OpRewritePattern { @@ -186,7 +190,8 @@ class RefineShapeCalculateOp : public OpRewritePattern { namespace { class SimplifyShapeCalculationsPass - : public SimplifyShapeCalculationsBase { + : public impl::SimplifyShapeCalculationsBase< + SimplifyShapeCalculationsPass> { void runOnOperation() override { MLIRContext *context = &getContext(); @@ -219,6 +224,8 @@ class SimplifyShapeCalculationsPass } // namespace std::unique_ptr> -mlir::torch::Torch::createSimplifyShapeCalculationsPass() { +createSimplifyShapeCalculationsPass() { return std::make_unique(); } + +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp index dadd865a54a7..8625a55205d3 100644 --- a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp +++ b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp @@ -7,12 +7,11 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Transforms/FuncConversions.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" @@ -22,6 +21,13 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::TorchConversion; +namespace mlir::torch::TorchConversion { + +#define GEN_PASS_DEF_FUNCBACKENDTYPECONVERSION +#define GEN_PASS_DEF_FUNCBACKENDTYPECONVERSIONFORSTABLEHLO +#define GEN_PASS_DEF_FINALIZINGBACKENDTYPECONVERSION +#define GEN_PASS_DEF_FINALIZINGBACKENDTYPECONVERSIONFORSTABLEHLO +#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h.inc" //===----------------------------------------------------------------------===// // FuncBackendTypeConversionPass @@ -74,7 +80,8 @@ void populateFuncBackendTypeConversionPatterns(TypeConverter &typeConverter, } struct FuncBackendTypeConversionPass - : public FuncBackendTypeConversionBase { + : public impl::FuncBackendTypeConversionBase< + FuncBackendTypeConversionPass> { using FuncBackendTypeConversionBase< FuncBackendTypeConversionPass>::FuncBackendTypeConversionBase; void getDependentDialects(DialectRegistry ®istry) const override { @@ -99,7 +106,7 @@ struct FuncBackendTypeConversionPass #ifdef TORCH_MLIR_ENABLE_STABLEHLO struct FuncBackendTypeConversionForStablehloPass - : public FuncBackendTypeConversionForStablehloBase< + : public impl::FuncBackendTypeConversionForStablehloBase< FuncBackendTypeConversionForStablehloPass> { using FuncBackendTypeConversionForStablehloBase< FuncBackendTypeConversionForStablehloPass>:: @@ -127,14 +134,14 @@ struct FuncBackendTypeConversionForStablehloPass #endif // TORCH_MLIR_ENABLE_STABLEHLO } // namespace -std::unique_ptr> -mlir::torch::TorchConversion::createFuncBackendTypeConversionPass() { +// Create functions for passes +std::unique_ptr> createFuncBackendTypeConversionPass() { return std::make_unique(); } #ifdef TORCH_MLIR_ENABLE_STABLEHLO -std::unique_ptr> mlir::torch::TorchConversion:: - createFuncBackendTypeConversionForStablehloPass() { +std::unique_ptr> +createFuncBackendTypeConversionForStablehloPass() { return std::make_unique(); } #endif // TORCH_MLIR_ENABLE_STABLEHLO @@ -195,7 +202,7 @@ static void stripTorchAttrs(FunctionOpInterface func) { namespace { struct FinalizingBackendTypeConversionPass - : public FinalizingBackendTypeConversionBase< + : public impl::FinalizingBackendTypeConversionBase< FinalizingBackendTypeConversionPass> { using FinalizingBackendTypeConversionBase< FinalizingBackendTypeConversionPass>::FinalizingBackendTypeConversionBase; @@ -242,7 +249,7 @@ struct FinalizingBackendTypeConversionPass #ifdef TORCH_MLIR_ENABLE_STABLEHLO struct FinalizingBackendTypeConversionForStablehloPass - : public FinalizingBackendTypeConversionForStablehloBase< + : public impl::FinalizingBackendTypeConversionForStablehloBase< FinalizingBackendTypeConversionForStablehloPass> { using FinalizingBackendTypeConversionForStablehloBase< FinalizingBackendTypeConversionForStablehloPass>:: @@ -287,13 +294,15 @@ struct FinalizingBackendTypeConversionForStablehloPass } // namespace std::unique_ptr> -mlir::torch::TorchConversion::createFinalizingBackendTypeConversionPass() { +createFinalizingBackendTypeConversionPass() { return std::make_unique(); } #ifdef TORCH_MLIR_ENABLE_STABLEHLO -std::unique_ptr> mlir::torch:: - TorchConversion::createFinalizingBackendTypeConversionForStablehloPass() { +std::unique_ptr> +createFinalizingBackendTypeConversionForStablehloPass() { return std::make_unique(); } #endif // TORCH_MLIR_ENABLE_STABLEHLO + +} // namespace mlir::torch::TorchConversion diff --git a/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp b/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp index a4c28c2c3160..8b55b664a578 100644 --- a/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp +++ b/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp @@ -7,11 +7,11 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" @@ -22,6 +22,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::TorchConversion { + +#define GEN_PASS_DEF_CONVERTCUSTOMQUANTOP +#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h.inc" namespace { class ConvertCustomQuantizedMatmulOp : public OpConversionPattern { @@ -191,8 +195,7 @@ class ConvertCustomQuantizedMatmulOp : public OpConversionPattern { namespace { class ConvertCustomQuantOpPass - : public TorchConversion::ConvertCustomQuantOpBase< - ConvertCustomQuantOpPass> { + : public impl::ConvertCustomQuantOpBase { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); registry.insert(); @@ -225,6 +228,8 @@ class ConvertCustomQuantOpPass } // namespace std::unique_ptr> -mlir::torch::TorchConversion::createConvertCustomQuantOpPass() { +createConvertCustomQuantOpPass() { return std::make_unique(); } + +} // namespace mlir::torch::TorchConversion diff --git a/lib/Dialect/TorchConversion/Transforms/PassDetail.h b/lib/Dialect/TorchConversion/Transforms/PassDetail.h deleted file mode 100644 index cb80ebd89a3c..000000000000 --- a/lib/Dialect/TorchConversion/Transforms/PassDetail.h +++ /dev/null @@ -1,29 +0,0 @@ -//===- PassDetail.h - Pass details ------------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// Also available under a BSD-style license. See LICENSE. -// -//===----------------------------------------------------------------------===// - -#ifndef TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSDETAIL_H -#define TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSDETAIL_H - -#include "mlir/Interfaces/FunctionInterfaces.h" -#include "mlir/Pass/Pass.h" - -namespace mlir { -class ModuleOp; - -namespace torch { -namespace TorchConversion { - -#define GEN_PASS_CLASSES -#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h.inc" - -} // namespace TorchConversion -} // namespace torch -} // end namespace mlir - -#endif // TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSDETAIL_H diff --git a/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp b/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp index fcc0beb8d0c3..b621eea1fcd7 100644 --- a/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp +++ b/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp @@ -7,8 +7,8 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" @@ -19,6 +19,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::TorchConversion { + +#define GEN_PASS_DEF_UNPACKQUANTTENSOR +#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h.inc" namespace { class UnpackQuantizedMatmulWeights @@ -119,7 +123,7 @@ class UnpackQuantizedMatmulWeights namespace { class UnpackQuantTensorPass - : public TorchConversion::UnpackQuantTensorBase { + : public impl::UnpackQuantTensorBase { using UnpackQuantTensorBase::UnpackQuantTensorBase; void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); @@ -138,6 +142,8 @@ class UnpackQuantTensorPass } // namespace std::unique_ptr> -mlir::torch::TorchConversion::createUnpackQuantTensorPass() { +createUnpackQuantTensorPass() { return std::make_unique(); } + +} // namespace mlir::torch::TorchConversion diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp b/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp index 5189a17fc942..f08a1f389f81 100644 --- a/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp +++ b/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp @@ -7,8 +7,6 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" @@ -19,6 +17,7 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" @@ -30,10 +29,14 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::TorchConversion; using namespace TMTensor; +namespace mlir::torch::TorchConversion { + +#define GEN_PASS_DEF_VERIFYLINALGONTENSORSBACKENDCONTRACT +#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h.inc" namespace { class VerifyLinalgOnTensorsBackendContractPass - : public VerifyLinalgOnTensorsBackendContractBase< + : public impl::VerifyLinalgOnTensorsBackendContractBase< VerifyLinalgOnTensorsBackendContractPass> { void runOnOperation() override { MLIRContext *context = &getContext(); @@ -105,6 +108,8 @@ class VerifyLinalgOnTensorsBackendContractPass } // namespace std::unique_ptr> -mlir::torch::TorchConversion::createVerifyLinalgOnTensorsBackendContractPass() { +createVerifyLinalgOnTensorsBackendContractPass() { return std::make_unique(); } + +} // namespace mlir::torch::TorchConversion diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp b/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp index 3ff6e4732db2..9b0b8986bf8e 100644 --- a/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp +++ b/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp @@ -7,13 +7,12 @@ // //===----------------------------------------------------------------------===// #ifdef TORCH_MLIR_ENABLE_STABLEHLO -#include "PassDetail.h" - #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" @@ -22,10 +21,14 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::TorchConversion; +namespace mlir::torch::TorchConversion { + +#define GEN_PASS_DEF_VERIFYSTABLEHLOBACKENDCONTRACT +#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h.inc" namespace { class VerifyStablehloBackendContractPass - : public VerifyStablehloBackendContractBase< + : public impl::VerifyStablehloBackendContractBase< VerifyStablehloBackendContractPass> { void runOnOperation() override { TypeConverter converter; @@ -66,7 +69,10 @@ class VerifyStablehloBackendContractPass } // namespace std::unique_ptr> -mlir::torch::TorchConversion::createVerifyStablehloBackendContractPass() { +createVerifyStablehloBackendContractPass() { return std::make_unique(); } + +} // namespace mlir::torch::TorchConversion + #endif // TORCH_MLIR_ENABLE_STABLEHLO diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp b/lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp index efa40a02aeb0..3d48a4b8ef81 100644 --- a/lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp +++ b/lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp @@ -7,22 +7,26 @@ // //===----------------------------------------------------------------------===// #ifdef TORCH_MLIR_ENABLE_TOSA -#include "PassDetail.h" - #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::TorchConversion; +namespace mlir::torch::TorchConversion { + +#define GEN_PASS_DEF_VERIFYTOSABACKENDCONTRACT +#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h.inc" namespace { class VerifyTosaBackendContractPass - : public VerifyTosaBackendContractBase { + : public impl::VerifyTosaBackendContractBase< + VerifyTosaBackendContractPass> { void runOnOperation() override { MLIRContext *context = &getContext(); auto module = getOperation(); @@ -59,8 +63,10 @@ class VerifyTosaBackendContractPass }; } // namespace -std::unique_ptr> -mlir::torch::TorchConversion::createVerifyTosaBackendContractPass() { +std::unique_ptr> createVerifyTosaBackendContractPass() { return std::make_unique(); } + +} // namespace mlir::torch::TorchConversion + #endif // TORCH_MLIR_ENABLE_TOSA diff --git a/lib/RefBackend/PassDetail.h b/lib/RefBackend/PassDetail.h deleted file mode 100644 index aad2c369168b..000000000000 --- a/lib/RefBackend/PassDetail.h +++ /dev/null @@ -1,26 +0,0 @@ -//===- PassDetail.h - RefBackend Pass class details -------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// Also available under a BSD-style license. See LICENSE. -// -//===----------------------------------------------------------------------===// - -#ifndef REFBACKEND_PASSDETAIL_H -#define REFBACKEND_PASSDETAIL_H - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Pass/Pass.h" - -namespace mlir { -namespace torch { - -#define GEN_PASS_CLASSES -#include "torch-mlir/RefBackend/Passes.h.inc" - -} // namespace torch -} // end namespace mlir - -#endif // REFBACKEND_PASSDETAIL_H diff --git a/lib/RefBackend/RefBackend.cpp b/lib/RefBackend/RefBackend.cpp index 89c7fb5df21a..f5e005994432 100644 --- a/lib/RefBackend/RefBackend.cpp +++ b/lib/RefBackend/RefBackend.cpp @@ -14,7 +14,6 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -42,6 +41,26 @@ using namespace mlir::torch::RefBackend; // Pass registration //===----------------------------------------------------------------------===// +namespace mlir::torch::RefBackend { + +#define GEN_PASS_DEF_MUNGECALLINGCONVENTIONS +#define GEN_PASS_DEF_MLPROGRAMBUFFERIZE +#define GEN_PASS_DEF_EXPANDOPSFORLLVM +#define GEN_PASS_DEF_MUNGEMEMREFCOPY +#define GEN_PASS_DEF_GENERALIZETENSORCONCAT +#define GEN_PASS_DEF_GENERALIZETENSORPAD +#include "torch-mlir/RefBackend/Passes.h.inc" + +} // namespace mlir::torch::RefBackend + +// Bring Base classes into scope for anonymous namespace passes +using mlir::torch::RefBackend::impl::ExpandOpsForLLVMBase; +using mlir::torch::RefBackend::impl::GeneralizeTensorConcatBase; +using mlir::torch::RefBackend::impl::GeneralizeTensorPadBase; +using mlir::torch::RefBackend::impl::MLProgramBufferizeBase; +using mlir::torch::RefBackend::impl::MungeCallingConventionsBase; +using mlir::torch::RefBackend::impl::MungeMemrefCopyBase; + namespace { #define GEN_PASS_REGISTRATION #include "torch-mlir/RefBackend/Passes.h.inc" @@ -220,11 +239,6 @@ class MungeCallingConventions }; } // namespace -std::unique_ptr> -mlir::torch::RefBackend::createMungeCallingConventionsPass() { - return std::make_unique(); -} - //===----------------------------------------------------------------------===// // MLProgramBufferize //===----------------------------------------------------------------------===// @@ -346,11 +360,6 @@ class MLProgramBufferize : public MLProgramBufferizeBase { }; } // namespace -std::unique_ptr> -mlir::torch::RefBackend::createMLProgramBufferizePass() { - return std::make_unique(); -} - //===----------------------------------------------------------------------===// // ExpandOpsForLLVM //===----------------------------------------------------------------------===// @@ -376,11 +385,6 @@ class ExpandOpsForLLVM : public ExpandOpsForLLVMBase { }; } // namespace -std::unique_ptr> -mlir::torch::RefBackend::createExpandOpsForLLVMPass() { - return std::make_unique(); -} - //===----------------------------------------------------------------------===// // MungeMemrefCopy //===----------------------------------------------------------------------===// @@ -432,11 +436,6 @@ class MungeMemrefCopy : public MungeMemrefCopyBase { }; } // namespace -std::unique_ptr> -mlir::torch::RefBackend::createMungeMemrefCopyPass() { - return std::make_unique(); -} - namespace { class GeneralizeTensorConcat : public GeneralizeTensorConcatBase { @@ -454,11 +453,6 @@ class GeneralizeTensorConcat }; } // namespace -std::unique_ptr> -mlir::torch::RefBackend::createGeneralizeTensorConcatPass() { - return std::make_unique(); -} - namespace { class GeneralizeTensorPad : public GeneralizeTensorPadBase { @@ -476,8 +470,3 @@ class GeneralizeTensorPad } }; } // namespace - -std::unique_ptr> -mlir::torch::RefBackend::createGeneralizeTensorPadPass() { - return std::make_unique(); -}