Skip to content

Commit

Permalink
Adds a layer metadata callback API
Browse files Browse the repository at this point in the history
- Adds a new API that allows setting a layer metadata callback, which will
    be invoked for each MLIR operation in order to set metadata on the
    corresponding TensorRT network layers.
  • Loading branch information
pranavm-nvidia committed Sep 24, 2024
1 parent 8434208 commit c4ef5c9
Show file tree
Hide file tree
Showing 15 changed files with 317 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ typedef struct MTRT_StableHLOToExecutableOptions {
void *ptr;
} MTRT_StableHLOToExecutableOptions;

/// Callback type used to produce a metadata string for the TensorRT layers
/// generated from an MLIR operation. The implementation emits its metadata
/// text by invoking `append` (forwarding `appendCtx` unchanged); the pieces
/// are accumulated on the C++ side, so no string type is exposed in the C
/// API. `userData` is the opaque pointer supplied when the callback was
/// registered.
typedef void (*MTRT_MetadataCallback)(MlirOperation op,
                                      MlirStringCallback append,
                                      void *appendCtx, void *userData);

/// Creates a new StableHLO-to-executable options object owned by `client`,
/// returning it through `options`. The TensorRT builder optimization level
/// and strongly-typed mode are configured at creation time.
MLIR_CAPI_EXPORTED MTRT_Status mtrtStableHloToExecutableOptionsCreate(
    MTRT_CompilerClient client, MTRT_StableHLOToExecutableOptions *options,
    int32_t tensorRTBuilderOptLevel, bool tensorRTStronglyTyped);
Expand All @@ -77,6 +81,11 @@ MLIR_CAPI_EXPORTED MTRT_Status mtrtStableHloToExecutableOptionsSetDebugOptions(
const char **debugTypes, size_t debugTypeSizes,
const char *dumpIrTreeDir = nullptr, const char *dumpTensorRTDir = nullptr);

/// Registers `callback` to be invoked for each MLIR operation during TensorRT
/// translation so it can supply metadata for the corresponding network
/// layers. `userData` is passed through to every invocation of `callback`.
/// The caller must keep `userData` valid for as long as `options` may be used
/// for compilation.
MLIR_CAPI_EXPORTED MTRT_Status
mtrtStableHloToExecutableOptionsSetTensorRTTranslationMetadataCallback(
    MTRT_StableHLOToExecutableOptions options, MTRT_MetadataCallback callback,
    void *userData);

/// Destroys an options object previously created with
/// `mtrtStableHloToExecutableOptionsCreate`. The handle must not be used
/// after this call.
MLIR_CAPI_EXPORTED MTRT_Status mtrtStableHloToExecutableOptionsDestroy(
    MTRT_StableHLOToExecutableOptions options);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,11 @@ class CompilerClient {
auto key = std::make_pair(mlir::TypeID::get<CompilationTaskType>(),
options.getHash());
auto it = cachedPassManagers.find(key);
if (it == cachedPassManagers.end()) {
if (it == cachedPassManagers.end() || options.shouldInvalidateCache()) {
auto pm = std::make_unique<CompilationTaskType>(context, options);
setupPassManagerLogging(*pm, options.debugOptions);
auto *ptr = pm.get();
cachedPassManagers.insert(std::make_pair(key, std::move(pm)));
cachedPassManagers[key] = std::move(pm);
return *ptr;
}
return *it->second;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@

#include "mlir-executor/Runtime/API/API.h"
#include "mlir-executor/Support/Status.h"
#include "mlir-tensorrt-dialect/Target/TranslateToTensorRT.h"
#include "mlir-tensorrt/Compiler/Client.h"
#include "mlir-tensorrt/Compiler/Extension.h"
#include "mlir-tensorrt/Compiler/Options.h"
Expand Down Expand Up @@ -106,6 +105,15 @@ struct StableHLOToExecutableOptions : public mlir::OptionsContext {
/// Get the mutable DebugOptions.
DebugOptions &getDebugOptions() { return debugOptions; }

llvm::hash_code getHash() const override;

bool shouldInvalidateCache() const override {
  // Two std::function callbacks cannot be compared for equivalence, so any
  // options instance that carries a callback must always force a rebuild of
  // the cached compilation task rather than reuse a cache entry.
  return layerMetadataCallback != nullptr;
}

/// The host index bit-width.
int64_t executorIndexBitwidth{64};

Expand All @@ -125,11 +133,13 @@ struct StableHLOToExecutableOptions : public mlir::OptionsContext {
/// Whether to disallow host tensors in TensorRT clusters.
bool disallowHostTensorsInTensorRTClusters = false;

/// Entrypiont function name.
/// Entrypoint function name.
std::string entrypoint = "main";

DebugOptions debugOptions;

std::function<std::string(mlir::Operation *)> layerMetadataCallback{nullptr};

/// Base class for extensions associated with StableHloToExecutableTask.
class ExtensionBase : public TaskExtensionBase {
public:
Expand Down
27 changes: 26 additions & 1 deletion mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
#include "mlir-tensorrt/Dialect/Plan/IR/Plan.h"
#include "mlir/CAPI/IR.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlirtrt;
using namespace mlirtrt::compiler;
Expand Down Expand Up @@ -199,6 +198,32 @@ MTRT_Status mtrtStableHloToExecutableOptionsSetDebugOptions(
return mtrtStatusGetOk();
}

MTRT_Status
mtrtStableHloToExecutableOptionsSetTensorRTTranslationMetadataCallback(
    MTRT_StableHLOToExecutableOptions options, MTRT_MetadataCallback callback,
    void *userData) {
  StableHLOToExecutableOptions *cppOpts = unwrap(options);

  // A null callback clears any previously installed metadata callback.
  // Without this check we would store a wrapper that dereferences a null
  // function pointer when translation later invokes it.
  if (!callback) {
    cppOpts->layerMetadataCallback = nullptr;
    return mtrtStatusGetOk();
  }

  // Capture `callback` and `userData` by value: capturing by reference would
  // dangle once this function returns.
  cppOpts->layerMetadataCallback = [callback, userData](Operation *op) {
    std::string accum;
    // Append callback handed to the user's callback so the C API does not
    // need to expose a string construction facility: it accumulates each
    // emitted piece into the std::string pointed to by `appendCtx`. The
    // captureless lambda decays to the `MlirStringCallback` function pointer.
    auto appendFunc = [](MlirStringRef str, void *appendCtx) {
      std::string &accum = *static_cast<std::string *>(appendCtx);
      accum.append(str.data, str.length);
    };
    callback(wrap(op), appendFunc, static_cast<void *>(&accum), userData);
    return accum;
  };

  return mtrtStatusGetOk();
}

MTRT_Status mtrtStableHloToExecutableOptionsDestroy(
MTRT_StableHLOToExecutableOptions options) {
delete reinterpret_cast<StableHLOToExecutableOptions *>(options.ptr);
Expand Down
7 changes: 7 additions & 0 deletions mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,13 @@ Status StableHLOToExecutableOptions::inferDeviceOptionsFromHost() {
return Status::getOk();
}

llvm::hash_code StableHLOToExecutableOptions::getHash() const {
  // Start from the hash of all registered option values.
  llvm::hash_code result = OptionsContext::getHash();
  // A std::function has no comparable identity, so fold in the address of
  // this instance's member to keep options objects that carry a callback
  // distinct from ones that do not.
  if (!layerMetadataCallback)
    return result;
  return llvm::hash_combine(result, &layerMetadataCallback);
}

//===----------------------------------------------------------------------===//
// StableHloToExecutableTask
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ void StableHLOToExecutableTensorRTExtension::populatePasses(
auto &trtPM = pm.nest<tensorrt::TensorRTModuleOp>();
tensorrt::buildTensorRTModuleTransformationPipeline(
trtPM, translationOptions.enableStronglyTyped);
trtPM.addPass(
tensorrt::createTranslateTensorRTPass(nullptr, translationOptions));
trtPM.addPass(tensorrt::createTranslateTensorRTPass(
nullptr, options.layerMetadataCallback, translationOptions));
return;
}

Expand Down
47 changes: 45 additions & 2 deletions mlir-tensorrt/python/bindings/Compiler/CompilerPyBind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "pybind11/pybind11.h"
#include "llvm/Support/DynamicLibrary.h"
#include "llvm/Support/raw_ostream.h"
#include <pybind11/attr.h>
#include <pybind11/functional.h>

#ifdef MLIR_TRT_TARGET_TENSORRT
#include "mlir-tensorrt-dialect/Utils/NvInferAdaptor.h"
Expand Down Expand Up @@ -66,6 +68,11 @@ class PyStableHLOToExecutableOptions
mtrtStableHloToExecutableOptionsDestroy,
mtrtPythonCapsuleToStableHLOToExecutableOptions,
mtrtPythonStableHLOToExecutableOptionsToCapsule};

// We need this member so we can keep the Python callback alive long enough.
std::function<std::string(MlirOperation)> callback;

~PyStableHLOToExecutableOptions() { callback = nullptr; }
};
} // namespace

Expand Down Expand Up @@ -270,7 +277,43 @@ PYBIND11_MODULE(_api, m) {
py::arg("enabled"),
py::arg("debug_types") = std::vector<std::string>{},
py::arg("dump_ir_tree_dir") = py::none(),
py::arg("dump_tensorrt_dir") = py::none());
py::arg("dump_tensorrt_dir") = py::none())

#ifdef MLIR_TRT_TARGET_TENSORRT
.def(
"set_tensorrt_translation_metadata_callback",
[](PyStableHLOToExecutableOptions &self,
std::function<std::string(MlirOperation)> pyCallback) {
// Since we're constructing a C callback, our closures must not
// capture. We can pass in the Python callback via the userData
// argument.
auto callback = [](MlirOperation op, MlirStringCallback append,
void *appendCtx, void *userDataVoid) {
auto &pyCallback =
*static_cast<std::function<std::string(MlirOperation)> *>(
userDataVoid);

if (!pyCallback)
return;

std::string result;
try {
result = pyCallback(op);
} catch (const std::exception &e) {
llvm::errs() << e.what() << '\n';
}

append(MlirStringRef{result.data(), result.size()}, appendCtx);
};

self.callback = pyCallback;
THROW_IF_MTRT_ERROR(
mtrtStableHloToExecutableOptionsSetTensorRTTranslationMetadataCallback(
self, callback, reinterpret_cast<void *>(&self.callback)));
},
py::arg("callback"), py::keep_alive<1, 2>{})
#endif
;

m.def(
"compiler_stablehlo_to_executable",
Expand Down Expand Up @@ -308,4 +351,4 @@ PYBIND11_MODULE(_api, m) {
bindTensorRTPluginAdaptorObjects(m);
#endif
#endif
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,25 +74,31 @@ static constexpr nvinfer1::Weights kNullWeights =

class NvInferNetworkEncoder {
public:
NvInferNetworkEncoder(nvinfer1::INetworkDefinition *network,
nvinfer1::IOptimizationProfile *profile,
TensorRTVersion version, bool usesStronglyTyped)
NvInferNetworkEncoder(
nvinfer1::INetworkDefinition *network,
nvinfer1::IOptimizationProfile *profile, TensorRTVersion version,
bool usesStronglyTyped,
std::function<std::string(Operation *)> metadataCallback)
: network(network), profile(profile), version(std::move(version)),
usesStronglyTyped(usesStronglyTyped) {}
usesStronglyTyped(usesStronglyTyped),
layerMetadataCallback(std::move(metadataCallback)) {}

/// Lookup the TRT ITensor* equivalent of a Value.
nvinfer1::ITensor *lookup(Value v) const;

/// Lookup the TRT ITensor* equivalents of a ValueRange.
SmallVector<nvinfer1::ITensor *> lookupValues(ValueRange values);

/// Add a map from a Value to a TRT ITEnsor*.
/// Add a map from a Value to a TRT ITensor*.
void map(Value from, nvinfer1::ITensor *to);

/// Remap values in `from` to each layer in `to` using the output at index 0
/// for each layer.
void map(ValueRange from, ArrayRef<nvinfer1::ILayer *> to);

// Add a map from an Operation to a TRT ILayer*
void map(Operation *op, nvinfer1::ILayer *layer);

/// Check whether the value map contains `v`.
size_t contains(Value v) { return valueMap.count(v); }

Expand Down Expand Up @@ -132,6 +138,10 @@ class NvInferNetworkEncoder {
/// and other temporary buffers.
using WeightsMap = llvm::DenseMap<mlir::Attribute, std::vector<int8_t>>;

// Tracks the mapping of mlir::Operations to layers. Note that one operation
// may map to multiple layers.
using LayerMap = llvm::DenseMap<Operation *, std::vector<nvinfer1::ILayer *>>;

using NamesSet = llvm::StringSet<>;

TensorMap &getTensorMap() { return valueMap; }
Expand All @@ -141,7 +151,7 @@ class NvInferNetworkEncoder {

/// Set the name of the `trtLayer` to a unique string that contains the op
/// name and location information from `sourceOp`.
void setName(nvinfer1::ILayer *layer, Operation *sourceOp);
void setMetadata(nvinfer1::ILayer *layer, Operation *sourceOp);

// Check if network uses fp16 types.
bool hasFp16Usage() const { return usesFp16; }
Expand Down Expand Up @@ -207,6 +217,9 @@ class NvInferNetworkEncoder {
// build ends.
SmallVector<NvInferPluginPtr> pluginReferences;

// Tracks the mapping between mlir::Operations and TensorRT ILayers.
LayerMap layerMap;

/// Holds the set of strings currently assigned as names to TensorRT ILayers.
/// This is required because we must make new names unique. The TensorRT API
/// does not have a set object to query names.
Expand Down Expand Up @@ -238,6 +251,8 @@ class NvInferNetworkEncoder {
bool hasQDQOps{false};

PluginManager pluginMgr;

std::function<std::string(Operation *)> layerMetadataCallback;
};

//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

#ifdef MLIR_TRT_TARGET_TENSORRT
#include "mlir-tensorrt-dialect/Target/TensorRTEncodingOpInterface/NetworkEncoder.h"
#include "mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.h"
#include "mlir-tensorrt-dialect/Utils/Options.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"
Expand Down Expand Up @@ -208,17 +207,18 @@ class TensorRTSerializedTimingCache {
/// `tensorrt.shape_profile` arguments have been populated for each argument
/// that has unknown dimensions.
/// TODO(cbate): add additional options here for builder configuration.
FailureOr<TensorRTEngineResult>
buildFunction(mlir::FunctionOpInterface op,
TensorRTBuilderContext &builderContext,
TensorRTSerializedTimingCache &serializedTimingCache,
const TensorRTTranslationOptions &options =
TensorRTTranslationOptions::fromCLFlags());
FailureOr<TensorRTEngineResult> buildFunction(
mlir::FunctionOpInterface op, TensorRTBuilderContext &builderContext,
TensorRTSerializedTimingCache &serializedTimingCache,
const TensorRTTranslationOptions &options =
TensorRTTranslationOptions::fromCLFlags(),
std::function<std::string(Operation *)> layerMetadataCallback = nullptr);

/// Create an instance of a translate-to-tensorrt pass using an existing
/// TensorRTBuilderContext.
std::unique_ptr<mlir::Pass> createTranslateTensorRTPass(
std::shared_ptr<tensorrt::TensorRTBuilderContext> context,
std::function<std::string(Operation *)> layerMetadataCallback,
TensorRTTranslationOptions options =
TensorRTTranslationOptions::fromCLFlags());

Expand Down
Loading

0 comments on commit c4ef5c9

Please sign in to comment.