Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Compiler] Split compiler pipeline into pass pipeline, compilation, and translation steps #284

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#define MLIR_TENSORRT_C_COMPILER_COMPILER

#include "mlir-c/IR.h"
#include "mlir-c/Pass.h"
#include "mlir-c/Support.h"
#include "mlir-executor-c/Common/Common.h"
#include "mlir-executor-c/Support/Status.h"
Expand All @@ -47,8 +48,8 @@ mtrtCompilerClientCreate(MlirContext context, MTRT_CompilerClient *client);
MLIR_CAPI_EXPORTED MTRT_Status
mtrtCompilerClientDestroy(MTRT_CompilerClient client);

static inline bool mtrtCompilerClientIsNull(MTRT_CompilerClient options) {
return !options.ptr;
static inline bool mtrtCompilerClientIsNull(MTRT_CompilerClient client) {
return !client.ptr;
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -108,6 +109,14 @@ MLIR_CAPI_EXPORTED MTRT_Status mtrtCompilerStableHLOToExecutable(
MTRT_CompilerClient client, MlirOperation module,
MTRT_StableHLOToExecutableOptions options, MTRT_Executable *result);

MLIR_CAPI_EXPORTED MTRT_Status mtrtCompilerPopulatePassManager(
MTRT_CompilerClient compilerClient,
MTRT_StableHLOToExecutableOptions stableHloToExecutableOptions,
MlirPassManager *passManager);

MLIR_CAPI_EXPORTED MTRT_Status mtrtTranslateRuntimeToExecutable(
MlirOperation moduleOp, MTRT_Executable *result);

//===----------------------------------------------------------------------===//
// MTRT_StableHLOProgramSignatureRefinementOptions
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,13 +202,6 @@ class StableHloToExecutableTask
static void populatePassManager(mlir::PassManager &pm,
const StableHLOToExecutableOptions &options);

/// Compile a StableHLO module into a MLIR-TensorRT Runtime executable.
/// This is the "functional" entrypoint that will allocate a new PassManager
/// for a single run.
static mlirtrt::StatusOr<std::unique_ptr<runtime::Executable>>
compileStableHLOToExecutable(mlir::ModuleOp module,
const StableHLOToExecutableOptions &options);

/// Compile a StableHLO module into a MLIR-TensorRT Runtime executable.
/// This is the "functional" entrypoint that will allocate a new PassManager
/// for a single run.
Expand Down
72 changes: 71 additions & 1 deletion mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,19 @@
//===----------------------------------------------------------------------===//
#include "mlir-tensorrt-c/Compiler/Compiler.h"
#include "mlir-c/IR.h"
#include "mlir-c/Pass.h"
#include "mlir-c/Support.h"
#include "mlir-executor-c/Support/Status.h"
#include "mlir-tensorrt-dialect/Target/TranslateToTensorRT.h"
#include "mlir-tensorrt/Compiler/Extension.h"
#include "mlir-tensorrt/Compiler/StableHloToExecutable.h"
#include "mlir-tensorrt/Compiler/TensorRTExtension/TensorRTExtension.h"
#include "mlir-executor/Target/Lua/TranslateToRuntimeExecutable.h"
#include "mlir-tensorrt/Dialect/Plan/IR/Plan.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Pass.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Debug.h"

using namespace mlirtrt;
using namespace mlirtrt::compiler;
Expand All @@ -46,10 +50,14 @@ DEFINE_C_API_PTR_METHODS(MTRT_StableHLOToExecutableOptions,
StableHLOToExecutableOptions)
DEFINE_C_API_PTR_METHODS(MTRT_StableHLOProgramSignatureRefinementOptions,
StableHLOProgramSignatureRefinementOptions)

#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic pop
#endif

#define DEBUG_TYPE "compiler-api"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]")

/// Return the MTRT_StatusCode. These are auto-generated from the same schema as
/// the `mlirtrt::StatusCode`.
static MTRT_StatusCode
Expand Down Expand Up @@ -97,7 +105,7 @@ MTRT_Status mtrtCompilerClientCreate(MlirContext context,
}

MTRT_Status mtrtCompilerClientDestroy(MTRT_CompilerClient client) {
delete reinterpret_cast<MTRT_CompilerClient *>(client.ptr);
delete reinterpret_cast<CompilerClient *>(client.ptr);
return mtrtStatusGetOk();
}

Expand Down Expand Up @@ -256,6 +264,68 @@ MTRT_Status mtrtCompilerStableHLOToExecutable(
return mtrtStatusGetOk();
}

MTRT_Status mtrtCompilerPopulatePassManager(
MTRT_CompilerClient compilerClient,
MTRT_StableHLOToExecutableOptions stableHloToExecutableOptions,
MlirPassManager *pm) {

PassManager *passManager = llvm::dyn_cast<PassManager>(unwrap(*pm));

std::unique_ptr<StableHloToExecutableTask> runner{};

CompilerClient &client = *unwrap(compilerClient);
const StableHLOToExecutableOptions &options =
*unwrap(stableHloToExecutableOptions);

LLVM_DEBUG({
DBGS() << "compiling with options:\n";
options.print(llvm::dbgs());
llvm::dbgs() << "\n";
});

#ifndef NDEBUG
if (options.debugOptions.enableLLVMDebugFlag) {
SmallVector<const char *> debugTypeLiterals =
llvm::map_to_vector(options.debugOptions.llvmDebugTypes,
[](const std::string &x) { return x.c_str(); });
llvm::setCurrentDebugTypes(debugTypeLiterals.data(),
debugTypeLiterals.size());
llvm::DebugFlag = true;
}
#endif

if (options.getHash())
passManager =
&client.getOrCreatePassManager<StableHloToExecutableTask>(options);
else {
runner.reset(new StableHloToExecutableTask(client.getContext(), options));
CompilerClient::setupPassManagerLogging(*passManager, options.debugOptions);
passManager = runner.get();
}

return mtrtStatusGetOk();
}

MTRT_Status mtrtTranslateRuntimeToExecutable(MlirOperation moduleOp,
MTRT_Executable *result) {
ModuleOp module = llvm::dyn_cast<ModuleOp>(unwrap(moduleOp));

// Translate to Runtime Executable
FailureOr<std::unique_ptr<runtime::ExecutableStorage>> exeStorage =
mlir::translateToRuntimeExecutable(module);

if (failed(exeStorage))
return mtrtStatusCreate(MTRT_StatusCode::MTRT_StatusCode_InternalError,
"failed to translate compiled MLIR module to a "
"MLIR-TensorRT runtime Executable");

auto exe = std::make_unique<runtime::Executable>(std::move(*exeStorage));

result->ptr = exe.release();

return mtrtStatusGetOk();
}

//===----------------------------------------------------------------------===//
// Main StableHLO Program Signature Refinement Functions
//===----------------------------------------------------------------------===//
Expand Down
60 changes: 0 additions & 60 deletions mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -388,66 +388,6 @@ void StableHloToExecutableTask::populatePassManager(
mlir::executor::buildExecutorLoweringPipeline(pm, stdToExecOpts);
}

StatusOr<std::unique_ptr<runtime::Executable>>
StableHloToExecutableTask::compileStableHLOToExecutable(
mlir::ModuleOp module, const StableHLOToExecutableOptions &options) {
LLVM_DEBUG({
DBGS() << "compiling with options:\n";
options.print(llvm::dbgs());
llvm::dbgs() << "\n";
});

#ifndef NDEBUG
//===----------------------------------------------------------------------===//
// Set debug options.
//===----------------------------------------------------------------------===//
if (options.debugOptions.enableLLVMDebugFlag) {
SmallVector<const char *> debugTypeLiterals =
llvm::map_to_vector(options.debugOptions.llvmDebugTypes,
[](const std::string &x) { return x.c_str(); });
llvm::setCurrentDebugTypes(debugTypeLiterals.data(),
debugTypeLiterals.size());
llvm::DebugFlag = true;
}
#endif

//===----------------------------------------------------------------------===//
// Setup pass manager
//===----------------------------------------------------------------------===//

StableHloToExecutableTask runner(module->getContext(), options);
if (failed(setupPassManager(runner, options.debugOptions))) {
/// TODO: Ignored. This can fail if pass manager static CL options were not
/// registered/initialized. This happens through invocation of e.g. this
/// function in e.g. Python bindings or standalone calls to C++ or C API
/// without doing all the typical static CL setup. We should instead be
/// accepting a PassManager here that has already been setup to the caller's
/// specifications.
}
if (failed(runner.run(module)))
return getInternalErrorStatus(
"failed to run compilation on module with symbol name: {0}",
module.getName() ? *module.getName() : "no-symbol-name");

//===----------------------------------------------------------------------===//
// Translate to Runtime Executable
//===----------------------------------------------------------------------===//

FailureOr<std::unique_ptr<runtime::ExecutableStorage>> exeStorage =
mlir::translateToRuntimeExecutable(module);
if (failed(exeStorage))
return getStatusWithMsg(StatusCode::InternalError,
"failed to translate compiled MLIR module to a "
"MLIR-TensorRT runtime Executable");

#ifndef NDEBUG
// Turn debugging back off if we turned it on.
if (options.debugOptions.enableLLVMDebugFlag)
llvm::DebugFlag = false;
#endif

return std::make_unique<runtime::Executable>(std::move(*exeStorage));
}

mlirtrt::StatusOr<std::unique_ptr<runtime::Executable>>
StableHloToExecutableTask::compileStableHLOToExecutable(
Expand Down
42 changes: 42 additions & 0 deletions mlir-tensorrt/python/bindings/Compiler/CompilerPyBind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "../Utils.h"
#include "NvInferRuntime.h"
#include "mlir-c/IR.h"
#include "mlir-c/Pass.h"
#include "mlir-c/Support.h"
#include "mlir-executor-c/Common/Common.h"
#include "mlir-executor-c/Support/Status.h"
Expand All @@ -37,6 +38,26 @@ MTRT_DEFINE_COMPILER_INLINE_PY_CAPSULE_CASTER_FUNCS(

namespace {

// Define a type caster for MlirPassManager
namespace pybind11 { namespace detail {
template <> struct type_caster<MlirPassManager> {
PYBIND11_TYPE_CASTER(MlirPassManager, _("MlirPassManager"));

// Conversion from Python to C++
bool load(py::handle src, bool) {
py::object capsule = mlirApiObjectToCapsule(src);
value = mlirPythonCapsuleToPassManager(capsule.ptr());
return !mlirPassManagerIsNull(value);
}

// Conversion from C++ to Python
static py::handle cast(MlirPassManager pm, py::return_value_policy, py::handle) {
if (mlirPassManagerIsNull(pm)) return py::none();
return py::reinterpret_steal<py::object>(mlirPythonPassManagerToCapsule(pm));
}
};
}}

//===----------------------------------------------------------------------===//
// Python Wrapper Classes
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -325,6 +346,27 @@ PYBIND11_MODULE(_api, m) {
},
py::arg("client"), py::arg("module"), py::arg("options"));

m.def(
"compiler_populate_pass_manager",
[](PyCompilerClient &client, PyStableHLOToExecutableOptions &options) {
MlirPassManager pm{nullptr};
MTRT_Status status =
mtrtCompilerPopulatePassManager(client, options, &pm);
THROW_IF_MTRT_ERROR(status);
return py::reinterpret_steal<py::object>(mlirPythonPassManagerToCapsule(pm));
},
py::arg("client"), py::arg("options"));

m.def(
"compiler_translate_to_executable",
[](MlirOperation module) {
MTRT_Executable exe{nullptr};
MTRT_Status status = mtrtTranslateRuntimeToExecutable(module, &exe);
THROW_IF_MTRT_ERROR(status);
return new PyExecutable(exe);
},
py::arg("module"));

m.def(
"get_stablehlo_program_refined_signature",
[](PyCompilerClient &client, MlirOperation module, std::string funcName) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def flush():
sys.stderr.flush()


def compile_asm(ASM):
def compile_asm(ASM, use_pass_manager_api=False):
with Context() as context:
m = Module.parse(ASM)
client = api.CompilerClient(context)
Expand All @@ -93,7 +93,18 @@ def compile_asm(ASM):

print("running compilation (1)")
flush()
exe = api.compiler_stablehlo_to_executable(client, m.operation.clone(), opts)
if use_pass_manager_api:
pm = api.compiler_populate_pass_manager(client, opts)
import pdb

pdb.set_trace()
compiled_module = pm.run(m.operation.clone())
exe = api.compiler_translate_to_executable(compiled_module)
else:
exe = api.compiler_stablehlo_to_executable(
client, m.operation.clone(), opts
)

# Options don't change, so the cached pipeline should be re-used.
print("running compilation (2)")
flush()
Expand Down Expand Up @@ -126,7 +137,7 @@ def compile_asm(ASM):


print("Compiling static asm")
compile_asm(STATIC_ASM)
compile_asm(STATIC_ASM, use_pass_manager_api=True)
# CHECK-LABEL: Compiling static asm
# CHECK-LABEL: running compilation (1)
# CHECK: [translate-to-tensorrt] TranslateToTensorRTEnginePass is generating a new TensorRT builder
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# RUN: %PYTHON %s 2>&1
# RUN: %PYTHON %s 2>&1 | FileCheck %s
# REQUIRES: host-has-at-least-1-gpus
import os
import tempfile
Expand Down Expand Up @@ -50,3 +50,10 @@ def compile_asm(ASM):


compile_asm(ASM)

# CHECK: builtin.module
# CHECK: [translate-to-tensorrt] TranslateToTensorRTEnginePass is generating a new TensorRT builder
# CHECK: [translate-to-tensorrt] timing cache path was not specified, creating a fresh timing cache
# CHECK: [translate-to-tensorrt] deserializing TensorRT builder timing cache (0 bytes)
# CHECK: [translate-to-tensorrt] Setting builder optimization level to 3
# CHECK: [translate-to-tensorrt] replacing cache with updated data