Skip to content

Commit

Permalink
Adds a mechanism to invalidate cache entries when a callback is provided
Browse files Browse the repository at this point in the history
  • Loading branch information
pranavm-nvidia committed Sep 23, 2024
1 parent 1765e21 commit 3b82968
Show file tree
Hide file tree
Showing 8 changed files with 73 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,11 @@ class CompilerClient {
auto key = std::make_pair(mlir::TypeID::get<CompilationTaskType>(),
options.getHash());
auto it = cachedPassManagers.find(key);
if (it == cachedPassManagers.end()) {
if (it == cachedPassManagers.end() || options.shouldInvalidateCache()) {
auto pm = std::make_unique<CompilationTaskType>(context, options);
setupPassManagerLogging(*pm, options.debugOptions);
auto *ptr = pm.get();
cachedPassManagers.insert(std::make_pair(key, std::move(pm)));
cachedPassManagers[key] = std::move(pm);
return *ptr;
}
return *it->second;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,15 @@ struct StableHLOToExecutableOptions : public mlir::OptionsContext {
/// Get the mutable DebugOptions.
DebugOptions &getDebugOptions() { return debugOptions; }

llvm::hash_code getHash() const override;

bool shouldInvalidateCache() const override {
  // Two callbacks cannot be checked for equivalence, so options that carry
  // any callback at all cannot be matched against a callback from another
  // set of options. Whenever a callback is set, report that an existing
  // cache entry must be invalidated.
  return layerMetadataCallback != nullptr;
}

/// The host index bit-width.
int64_t executorIndexBitwidth{64};

Expand All @@ -129,8 +138,7 @@ struct StableHLOToExecutableOptions : public mlir::OptionsContext {

DebugOptions debugOptions;

std::function<std::string(mlir::Operation *)> layerMetadataCallback =
[](mlir::Operation *) { return ""; };
std::function<std::string(mlir::Operation *)> layerMetadataCallback{nullptr};

/// Base class for extensions associated with StableHloToExecutableTask.
class ExtensionBase : public TaskExtensionBase {
Expand Down
7 changes: 7 additions & 0 deletions mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,13 @@ Status StableHLOToExecutableOptions::inferDeviceOptionsFromHost() {
return Status::getOk();
}

/// Returns the hash of these options. Starts from the base
/// `OptionsContext::getHash()` value and, only when a layer metadata callback
/// is set, combines it with the address of the `layerMetadataCallback` member.
llvm::hash_code StableHLOToExecutableOptions::getHash() const {
  const llvm::hash_code baseHash = OptionsContext::getHash();
  // Without a callback the base hash is returned unchanged.
  return layerMetadataCallback
             ? llvm::hash_combine(baseHash, &layerMetadataCallback)
             : baseHash;
}

//===----------------------------------------------------------------------===//
// StableHloToExecutableTask
//===----------------------------------------------------------------------===//
Expand Down
5 changes: 3 additions & 2 deletions mlir-tensorrt/python/bindings/Compiler/CompilerPyBind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ class PyStableHLOToExecutableOptions

// We need this member so we can keep the Python callback alive long enough.
std::function<std::string(MlirOperation)> callback;

// Explicitly clear the stored callback when this options object is destroyed.
// NOTE(review): presumably this releases the captured Python callable at a
// well-defined point during destruction — confirm against the code that
// assigns `callback`.
~PyStableHLOToExecutableOptions() { callback = nullptr; }
};
} // namespace

Expand Down Expand Up @@ -323,8 +325,7 @@ PYBIND11_MODULE(_api, m) {
THROW_IF_MTRT_ERROR(status);
return new PyExecutable(exe);
},
py::arg("client"), py::arg("module"), py::arg("options"),
py::keep_alive<1, 3>());
py::arg("client"), py::arg("module"), py::arg("options"));

m.def(
"get_stablehlo_program_refined_signature",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,7 @@ FailureOr<TensorRTEngineResult> buildFunction(
TensorRTSerializedTimingCache &serializedTimingCache,
const TensorRTTranslationOptions &options =
TensorRTTranslationOptions::fromCLFlags(),
std::function<std::string(Operation *)> layerMetadataCallback =
[](Operation *op) { return ""; });
std::function<std::string(Operation *)> layerMetadataCallback = nullptr);

/// Create an instance of a translate-to-tensorrt pass using an existing
/// TensorRTBuilderContext.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ namespace mlir {
/// ```
class OptionsContext : public llvm::cl::SubCommand {
public:
/// An OptionsContext can be default-constructed and move-constructed, but
/// copy construction is disabled.
OptionsContext() = default;
OptionsContext(const OptionsContext &) = delete;
OptionsContext(OptionsContext &&) = default;
/// Virtual destructor: this class declares virtual members that derived
/// option types override.
virtual ~OptionsContext() = default;

/// Add an option to this context. The storage `value` must outlive the
/// OptionsContext.
template <typename DataType, typename... Mods>
Expand Down Expand Up @@ -124,7 +129,9 @@ class OptionsContext : public llvm::cl::SubCommand {
void print(llvm::raw_ostream &os) const;

/// Get a hash derived from the string representation of the options.
llvm::hash_code getHash() const;
virtual llvm::hash_code getHash() const;

/// Returns true if anything cached for these options must be rebuilt instead
/// of being reused. The base implementation always returns false; derived
/// option types may override it.
virtual bool shouldInvalidateCache() const { return false; }

private:
struct OptionInfo {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,8 @@ void NvInferNetworkEncoder::setMetadata(nvinfer1::ILayer *layer,
Operation *sourceOp) {
std::string name = createName(namesSet, sourceOp);
layer->setName(name.c_str());
layer->setMetadata(layerMetadataCallback(sourceOp).c_str());
if (layerMetadataCallback)
layer->setMetadata(layerMetadataCallback(sourceOp).c_str());
}

nvinfer1::ITensor *NvInferNetworkEncoder::lookup(Value v) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import glob
import os
import json
import gc

STATIC_ASM = """
func.func @main(%arg0: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
Expand All @@ -19,17 +20,14 @@


def layer_metadata_callback(op) -> str:
    """Print a marker line showing the callback ran, then return fixed metadata.

    The `op` argument is accepted but not inspected.
    """
    print("layer_metadata_callback CALLED")
    metadata = "TEST_CUSTOM_METADATA"
    return metadata


def compile_asm():
with Context() as context:
m = Module.parse(STATIC_ASM)
client = api.CompilerClient(context)
opts = api.StableHLOToExecutableOptions(
client,
["--tensorrt-builder-opt-level=3", "--tensorrt-strongly-typed=false"],
)

with tempfile.TemporaryDirectory() as tmp:
opts = api.StableHLOToExecutableOptions(
Expand Down Expand Up @@ -58,7 +56,45 @@ def compile_asm():
# CHECK-LABEL: Compiling ASM
# CHECK: [translate-to-tensorrt] TranslateToTensorRTEnginePass is generating a new TensorRT builder
# CHECK: [translate-to-tensorrt] timing cache path was not specified, creating a fresh timing cache
# CHECK: layer_metadata_callback CALLED
# CHECK: [translate-to-tensorrt] deserializing TensorRT builder timing cache (0 bytes)
# CHECK: [translate-to-tensorrt] Setting builder optimization level to 3
# CHECK: [translate-to-tensorrt] replacing cache with updated data (0 -> 2057 bytes)
# CHECK: TEST_CUSTOM_METADATA


def layer_metadata_callback2(op) -> str:
    """Print a marker line showing this second callback ran, then return its fixed metadata.

    The `op` argument is accepted but not inspected.
    """
    print("layer_metadata_callback2 CALLED")
    metadata = "TEST_CUSTOM_METADATA2"
    return metadata


def compile_multiple():
    """Compile the same module twice, using a different metadata callback each time.

    Ensures that pass manager caching does not cause issues when otherwise
    identical options carry different callbacks.
    """
    with Context() as context:
        module = Module.parse(STATIC_ASM)
        client = api.CompilerClient(context)
        builder_flags = [
            "--tensorrt-builder-opt-level=3",
            "--tensorrt-strongly-typed=false",
        ]

        # First compilation, using the first callback.
        first_opts = api.StableHLOToExecutableOptions(client, list(builder_flags))
        first_opts.set_tensorrt_translation_metadata_callback(layer_metadata_callback)
        api.compiler_stablehlo_to_executable(
            client, module.operation.clone(), first_opts
        )

        # Drop the first options object and force a collection before the
        # second options object is created.
        del first_opts
        gc.collect()

        # Second compilation: same flags, different callback.
        second_opts = api.StableHLOToExecutableOptions(client, list(builder_flags))
        second_opts.set_tensorrt_translation_metadata_callback(layer_metadata_callback2)
        api.compiler_stablehlo_to_executable(
            client, module.operation.clone(), second_opts
        )


print("Checking multiple compile calls")
compile_multiple()

# CHECK-LABEL: Checking multiple compile calls
# CHECK: layer_metadata_callback CALLED
# CHECK: layer_metadata_callback2 CALLED

0 comments on commit 3b82968

Please sign in to comment.