diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt-c/Compiler/Compiler.h b/mlir-tensorrt/compiler/include/mlir-tensorrt-c/Compiler/Compiler.h index 4d3a06610..56c053973 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt-c/Compiler/Compiler.h +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt-c/Compiler/Compiler.h @@ -25,6 +25,7 @@ #define MLIR_TENSORRT_C_COMPILER_COMPILER #include "mlir-c/IR.h" +#include "mlir-c/Pass.h" #include "mlir-c/Support.h" #include "mlir-executor-c/Common/Common.h" #include "mlir-executor-c/Support/Status.h" @@ -47,8 +48,8 @@ mtrtCompilerClientCreate(MlirContext context, MTRT_CompilerClient *client); MLIR_CAPI_EXPORTED MTRT_Status mtrtCompilerClientDestroy(MTRT_CompilerClient client); -static inline bool mtrtCompilerClientIsNull(MTRT_CompilerClient options) { - return !options.ptr; +static inline bool mtrtCompilerClientIsNull(MTRT_CompilerClient client) { + return !client.ptr; } //===----------------------------------------------------------------------===// @@ -108,6 +109,14 @@ MLIR_CAPI_EXPORTED MTRT_Status mtrtCompilerStableHLOToExecutable( MTRT_CompilerClient client, MlirOperation module, MTRT_StableHLOToExecutableOptions options, MTRT_Executable *result); +MLIR_CAPI_EXPORTED MTRT_Status mtrtCompilerPopulatePassManager( + MTRT_CompilerClient compilerClient, + MTRT_StableHLOToExecutableOptions stableHloToExecutableOptions, + MlirPassManager *passManager); + +MLIR_CAPI_EXPORTED MTRT_Status mtrtTranslateRuntimeToExecutable( + MlirOperation moduleOp, MTRT_Executable *result); + //===----------------------------------------------------------------------===// // MTRT_StableHLOProgramSignatureRefinementOptions //===----------------------------------------------------------------------===// diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StableHloToExecutable.h b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StableHloToExecutable.h index 8358ff0dd..c71af66cb 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StableHloToExecutable.h +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StableHloToExecutable.h @@ -202,13 +202,6 @@ class StableHloToExecutableTask static void populatePassManager(mlir::PassManager &pm, const StableHLOToExecutableOptions &options); - /// Compile a StableHLO module into a MLIR-TensorRT Runtime executable. - /// This is the "functional" entrypoint that will allocate a new PassManager - /// for a single run. - static mlirtrt::StatusOr> - compileStableHLOToExecutable(mlir::ModuleOp module, - const StableHLOToExecutableOptions &options); - /// Compile a StableHLO module into a MLIR-TensorRT Runtime executable. /// This is the "functional" entrypoint that will allocate a new PassManager /// for a single run. diff --git a/mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp b/mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp index 9be562712..f6859bfbd 100644 --- a/mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp +++ b/mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp @@ -23,15 +23,19 @@ //===----------------------------------------------------------------------===// #include "mlir-tensorrt-c/Compiler/Compiler.h" #include "mlir-c/IR.h" +#include "mlir-c/Pass.h" #include "mlir-c/Support.h" #include "mlir-executor-c/Support/Status.h" #include "mlir-tensorrt-dialect/Target/TranslateToTensorRT.h" #include "mlir-tensorrt/Compiler/Extension.h" #include "mlir-tensorrt/Compiler/StableHloToExecutable.h" #include "mlir-tensorrt/Compiler/TensorRTExtension/TensorRTExtension.h" +#include "mlir-executor/Target/Lua/TranslateToRuntimeExecutable.h" #include "mlir-tensorrt/Dialect/Plan/IR/Plan.h" #include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Pass.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/Support/Debug.h" using namespace mlirtrt; using namespace mlirtrt::compiler; @@ -46,10 +50,14 @@ DEFINE_C_API_PTR_METHODS(MTRT_StableHLOToExecutableOptions, StableHLOToExecutableOptions) DEFINE_C_API_PTR_METHODS(MTRT_StableHLOProgramSignatureRefinementOptions, StableHLOProgramSignatureRefinementOptions) + #if defined(__GNUC__) || defined(__clang__) #pragma GCC diagnostic pop #endif +#define DEBUG_TYPE "compiler-api" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]") + /// Return the MTRT_StatusCode. These are auto-generated from the same schema as /// the `mlirtrt::StatusCode`. static MTRT_StatusCode @@ -97,7 +105,7 @@ MTRT_Status mtrtCompilerClientCreate(MlirContext context, } MTRT_Status mtrtCompilerClientDestroy(MTRT_CompilerClient client) { - delete reinterpret_cast(client.ptr); + delete reinterpret_cast(client.ptr); return mtrtStatusGetOk(); } @@ -256,6 +264,68 @@ MTRT_Status mtrtCompilerStableHLOToExecutable( return mtrtStatusGetOk(); } +MTRT_Status mtrtCompilerPopulatePassManager( + MTRT_CompilerClient compilerClient, + MTRT_StableHLOToExecutableOptions stableHloToExecutableOptions, + MlirPassManager *pm) { + + PassManager *passManager = llvm::dyn_cast(unwrap(*pm)); + + std::unique_ptr runner{}; + + CompilerClient &client = *unwrap(compilerClient); + const StableHLOToExecutableOptions &options = + *unwrap(stableHloToExecutableOptions); + + LLVM_DEBUG({ + DBGS() << "compiling with options:\n"; + options.print(llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + +#ifndef NDEBUG + if (options.debugOptions.enableLLVMDebugFlag) { + SmallVector debugTypeLiterals = + llvm::map_to_vector(options.debugOptions.llvmDebugTypes, + [](const std::string &x) { return x.c_str(); }); + llvm::setCurrentDebugTypes(debugTypeLiterals.data(), + debugTypeLiterals.size()); + llvm::DebugFlag = true; + } +#endif + + if (options.getHash()) + passManager = + &client.getOrCreatePassManager(options); + else { + runner.reset(new StableHloToExecutableTask(client.getContext(), options)); + CompilerClient::setupPassManagerLogging(*passManager, options.debugOptions); + passManager = runner.get(); + } + + return mtrtStatusGetOk(); +} + +MTRT_Status mtrtTranslateRuntimeToExecutable(MlirOperation moduleOp, + MTRT_Executable *result) { + ModuleOp module = llvm::dyn_cast(unwrap(moduleOp)); + + // Translate to Runtime Executable + FailureOr> exeStorage = + mlir::translateToRuntimeExecutable(module); + + if (failed(exeStorage)) + return mtrtStatusCreate(MTRT_StatusCode::MTRT_StatusCode_InternalError, + "failed to translate compiled MLIR module to a " + "MLIR-TensorRT runtime Executable"); + + auto exe = std::make_unique(std::move(*exeStorage)); + + result->ptr = exe.release(); + + return mtrtStatusGetOk(); +} + //===----------------------------------------------------------------------===// // Main StableHLO Program Signature Refinement Functions //===----------------------------------------------------------------------===// diff --git a/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp b/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp index ba79bcb6d..9cf52c658 100644 --- a/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp +++ b/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp @@ -388,66 +388,6 @@ void StableHloToExecutableTask::populatePassManager( mlir::executor::buildExecutorLoweringPipeline(pm, stdToExecOpts); } -StatusOr> -StableHloToExecutableTask::compileStableHLOToExecutable( - mlir::ModuleOp module, const StableHLOToExecutableOptions &options) { - LLVM_DEBUG({ - DBGS() << "compiling with options:\n"; - options.print(llvm::dbgs()); - llvm::dbgs() << "\n"; - }); - -#ifndef NDEBUG - //===----------------------------------------------------------------------===// - // Set debug options. - //===----------------------------------------------------------------------===// - if (options.debugOptions.enableLLVMDebugFlag) { - SmallVector debugTypeLiterals = - llvm::map_to_vector(options.debugOptions.llvmDebugTypes, - [](const std::string &x) { return x.c_str(); }); - llvm::setCurrentDebugTypes(debugTypeLiterals.data(), - debugTypeLiterals.size()); - llvm::DebugFlag = true; - } -#endif - - //===----------------------------------------------------------------------===// - // Setup pass manager - //===----------------------------------------------------------------------===// - - StableHloToExecutableTask runner(module->getContext(), options); - if (failed(setupPassManager(runner, options.debugOptions))) { - /// TODO: Ignored. This can fail if pass manager static CL options were not - /// registered/initialized. This happens through invocation of e.g. this - /// function in e.g. Python bindings or standalone calls to C++ or C API - /// without doing all the typical static CL setup. We should instead be - /// accepting a PassManager here that has already been setup to the caller's - /// specifications. - } - if (failed(runner.run(module))) - return getInternalErrorStatus( - "failed to run compilation on module with symbol name: {0}", - module.getName() ? *module.getName() : "no-symbol-name"); - - //===----------------------------------------------------------------------===// - // Translate to Runtime Executable - //===----------------------------------------------------------------------===// - - FailureOr> exeStorage = - mlir::translateToRuntimeExecutable(module); - if (failed(exeStorage)) - return getStatusWithMsg(StatusCode::InternalError, - "failed to translate compiled MLIR module to a " - "MLIR-TensorRT runtime Executable"); - -#ifndef NDEBUG - // Turn debugging back off if we turned it on. - if (options.debugOptions.enableLLVMDebugFlag) - llvm::DebugFlag = false; -#endif - - return std::make_unique(std::move(*exeStorage)); -} mlirtrt::StatusOr> StableHloToExecutableTask::compileStableHLOToExecutable( diff --git a/mlir-tensorrt/python/bindings/Compiler/CompilerPyBind.cpp b/mlir-tensorrt/python/bindings/Compiler/CompilerPyBind.cpp index 284f12f40..b9f670657 100644 --- a/mlir-tensorrt/python/bindings/Compiler/CompilerPyBind.cpp +++ b/mlir-tensorrt/python/bindings/Compiler/CompilerPyBind.cpp @@ -11,6 +11,7 @@ #include "../Utils.h" #include "NvInferRuntime.h" #include "mlir-c/IR.h" +#include "mlir-c/Pass.h" #include "mlir-c/Support.h" #include "mlir-executor-c/Common/Common.h" #include "mlir-executor-c/Support/Status.h" @@ -37,6 +38,26 @@ MTRT_DEFINE_COMPILER_INLINE_PY_CAPSULE_CASTER_FUNCS( namespace { +// Define a type caster for MlirPassManager +namespace pybind11 { namespace detail { + template <> struct type_caster { + PYBIND11_TYPE_CASTER(MlirPassManager, _("MlirPassManager")); + + // Conversion from Python to C++ + bool load(py::handle src, bool) { + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToPassManager(capsule.ptr()); + return !mlirPassManagerIsNull(value); + } + + // Conversion from C++ to Python + static py::handle cast(MlirPassManager pm, py::return_value_policy, py::handle) { + if (mlirPassManagerIsNull(pm)) return py::none(); + return py::reinterpret_steal(mlirPythonPassManagerToCapsule(pm)); + } + }; +}} + //===----------------------------------------------------------------------===// // Python Wrapper Classes //===----------------------------------------------------------------------===// @@ -325,6 +346,27 @@ PYBIND11_MODULE(_api, m) { }, py::arg("client"), py::arg("module"), py::arg("options")); + m.def( + "compiler_populate_pass_manager", + [](PyCompilerClient &client, PyStableHLOToExecutableOptions &options) { + MlirPassManager pm{nullptr}; + MTRT_Status status = + mtrtCompilerPopulatePassManager(client, options, &pm); + THROW_IF_MTRT_ERROR(status); + return py::reinterpret_steal(mlirPythonPassManagerToCapsule(pm)); + }, + py::arg("client"), py::arg("options")); + + m.def( + "compiler_translate_to_executable", + [](MlirOperation module) { + MTRT_Executable exe{nullptr}; + MTRT_Status status = mtrtTranslateRuntimeToExecutable(module, &exe); + THROW_IF_MTRT_ERROR(status); + return new PyExecutable(exe); + }, + py::arg("module")); + m.def( "get_stablehlo_program_refined_signature", [](PyCompilerClient &client, MlirOperation module, std::string funcName) { diff --git a/mlir-tensorrt/test/python/mlir_tensorrt_compiler/compiler_api/test_compiler_api.py b/mlir-tensorrt/test/python/mlir_tensorrt_compiler/compiler_api/test_compiler_api.py index edc51ada6..7ad7e0dfe 100644 --- a/mlir-tensorrt/test/python/mlir_tensorrt_compiler/compiler_api/test_compiler_api.py +++ b/mlir-tensorrt/test/python/mlir_tensorrt_compiler/compiler_api/test_compiler_api.py @@ -71,7 +71,7 @@ def flush(): sys.stderr.flush() -def compile_asm(ASM): +def compile_asm(ASM, use_pass_manager_api=False): with Context() as context: m = Module.parse(ASM) client = api.CompilerClient(context) @@ -93,7 +93,18 @@ def compile_asm(ASM): print("running compilation (1)") flush() - exe = api.compiler_stablehlo_to_executable(client, m.operation.clone(), opts) + if use_pass_manager_api: + pm = api.compiler_populate_pass_manager(client, opts) + import pdb + + pdb.set_trace() + compiled_module = pm.run(m.operation.clone()) + exe = api.compiler_translate_to_executable(compiled_module) + else: + exe = api.compiler_stablehlo_to_executable( + client, m.operation.clone(), opts + ) + # Options don't change, so the cached pipeline should be re-used. print("running compilation (2)") flush() @@ -126,7 +137,7 @@ def compile_asm(ASM): print("Compiling static asm") -compile_asm(STATIC_ASM) +compile_asm(STATIC_ASM, use_pass_manager_api=True) # CHECK-LABEL: Compiling static asm # CHECK-LABEL: running compilation (1) # CHECK: [translate-to-tensorrt] TranslateToTensorRTEnginePass is generating a new TensorRT builder diff --git a/mlir-tensorrt/test/python/mlir_tensorrt_compiler/compiler_api/test_compiler_debug_dump.py b/mlir-tensorrt/test/python/mlir_tensorrt_compiler/compiler_api/test_compiler_debug_dump.py index 49ad84ffd..f4b98a6fc 100644 --- a/mlir-tensorrt/test/python/mlir_tensorrt_compiler/compiler_api/test_compiler_debug_dump.py +++ b/mlir-tensorrt/test/python/mlir_tensorrt_compiler/compiler_api/test_compiler_debug_dump.py @@ -1,4 +1,4 @@ -# RUN: %PYTHON %s 2>&1 +# RUN: %PYTHON %s 2>&1 | FileCheck %s # REQUIRES: host-has-at-least-1-gpus import os import tempfile @@ -50,3 +50,10 @@ def compile_asm(ASM): compile_asm(ASM) + +# CHECK: builtin.module +# CHECK: [translate-to-tensorrt] TranslateToTensorRTEnginePass is generating a new TensorRT builder +# CHECK: [translate-to-tensorrt] timing cache path was not specified, creating a fresh timing cache +# CHECK: [translate-to-tensorrt] deserializing TensorRT builder timing cache (0 bytes) +# CHECK: [translate-to-tensorrt] Setting builder optimization level to 3 +# CHECK: [translate-to-tensorrt] replacing cache with updated data