Adds layer metadata APIs #191
8ba77d2 to 2b91dd4
mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/Utils/Types.h
Here's how to fix this up:
- Remove the `Types.h` header you created. It is not necessary.
- Refer to this basic outline:
In `NetworkEncoder.h`, you should just use `std::function` directly. Notice that we use `Operation*` here, not `MlirOperation`.
std::function<std::string(Operation*)> layerMetadataCallback;
`NetworkEncoder.cpp` and `StableHloToExecutable.h` would be changed accordingly to reflect the use of `Operation*`.
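As a minimal sketch of what the C++ side could look like with a plain `std::function` member: the `Operation` struct, `EncoderOptions`, and `getLayerMetadata` below are hypothetical stand-ins so the snippet compiles on its own; only the `layerMetadataCallback` member type comes from the outline above.

```cpp
#include <cassert>
#include <functional>
#include <string>

// Hypothetical stand-in for mlir::Operation so the sketch is self-contained.
struct Operation {
  std::string name;
};

// Hypothetical options struct; only the callback member mirrors the outline.
struct EncoderOptions {
  std::function<std::string(Operation *)> layerMetadataCallback;
};

// Returns the metadata for `op`, or an empty string if no callback is set.
std::string getLayerMetadata(const EncoderOptions &opts, Operation *op) {
  assert(op != nullptr && "expected a valid operation");
  if (!opts.layerMetadataCallback)
    return "";
  return opts.layerMetadataCallback(op);
}

// A capturing lambda is fine on the C++ side because std::function stores
// its own copy of the closure.
std::string exampleMetadata() {
  std::string prefix = "layer:";
  EncoderOptions opts;
  opts.layerMetadataCallback = [prefix](Operation *op) {
    return prefix + op->name;
  };
  Operation op{"example.op"};
  return getLayerMetadata(opts, &op);
}
```

Because the closure is copied into the `std::function`, the callback stays valid after `prefix` goes out of scope.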
Now the CAPI can be reworked as follows. First, you can't use `std::function` in the C API since it is plain C. You must use function pointers. The idea here is that the user must specify a callback that communicates a string across the C API boundary. The caller (say in the C++ PyBind binding code) will construct a C++ function and pass it as a raw function pointer. There are two considerations:
- To avoid exposing a whole string interface to the C API, we can just pass an "append" function for the caller to invoke in order to construct the string (again using function pointers).
- When we construct a function that is passed by pointer across the C API boundary, we can use C++ lambdas to specify this function, but we are not allowed to have captured variables. Therefore, if we require capturing some object, we must 'lift' that state out into a separate "user data" field which can be treated opaquely by the API. There are two instances of this at work in the below example. We use one raw `void*` to represent the `std::string` being constructed by the CAPI implementation code. Then we also use one `void*` to represent the "context" of the callback passed from the CAPI client/caller code. This allows for wrapping things like the `py::object` coming from Python into an effectively static C function handle.
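The "lift state into user data" idea can be illustrated in isolation. The sketch below is not part of the project's API: `ValueCallback`, `sumWithCallback`, and `ScaleState` are hypothetical names chosen only to show a captureless lambda converting to a plain function pointer while its state travels through a `void*`.

```cpp
#include <cassert>

// Hypothetical C-style callback type: no closure, state travels in `userData`.
typedef int (*ValueCallback)(int value, void *userData);

// Hypothetical C-style API that invokes the callback once per value.
int sumWithCallback(const int *values, int count, ValueCallback cb,
                    void *userData) {
  assert(cb != nullptr && "expected a valid callback");
  int total = 0;
  for (int i = 0; i < count; ++i)
    total += cb(values[i], userData);
  return total;
}

// State that a capturing lambda would have held, lifted into a struct.
struct ScaleState {
  int factor;
};

int exampleScaledSum() {
  ScaleState state{3};
  // A lambda with an empty capture list converts implicitly to a function
  // pointer; the state is recovered from the opaque pointer instead.
  ValueCallback cb = [](int value, void *userData) -> int {
    auto *s = static_cast<ScaleState *>(userData);
    return value * s->factor;
  };
  const int values[] = {1, 2, 3};
  return sumWithCallback(values, 3, cb, &state);
}
```

The API never inspects `userData`; it only hands the pointer back, so the caller is free to put any object behind it as long as that object outlives every invocation.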
To do this, first in `mlir-tensorrt-c/Compiler/Compiler.h`, change `MTRT_MetadataCallback` to be just a typedef of a function pointer type:
typedef void (*MTRT_MetadataCallback)(MlirOperation op, MlirStringCallback append, void *appendCtx, void *userData);
Now in `CAPI/Compiler/Compiler.cpp` we can do this:
MTRT_Status
mtrtStableHloToExecutableOptionsSetTensorRTTranslationMetadataCallback(
    MTRT_StableHLOToExecutableOptions options, MTRT_MetadataCallback callback,
    void *userData) {
  StableHLOToExecutableOptions *cppOpts = unwrap(options);
  // Note: this is a pure C function pointer. Cannot capture, hence the need
  // to lift the accumulation state into `appendCtx`.
  MlirStringCallback appendCallback = [](MlirStringRef str, void *appendCtx) {
    std::string &accum = *static_cast<std::string *>(appendCtx);
    accum.append(str.data, str.length);
  };
  // Note: capture here is OK since we are now constructing a C++ std::function.
  // Capture by value: `callback`, `userData`, and `appendCallback` are local to
  // this function and no longer exist when the std::function is invoked later.
  // It's important that we just pass the user data as given.
  cppOpts->layerMetadataCallback =
      [callback, userData, appendCallback](Operation *op) -> std::string {
    std::string accum;
    void *appendCtx = static_cast<void *>(&accum);
    callback(wrap(op), appendCallback, appendCtx, userData);
    return accum;
  };
  return mtrtStatusGetOk();
}
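To see the whole round trip without the MLIR headers, here is a self-contained sketch under stated assumptions: `StringRef`, `StringCallback`, `COperation`, `Operation`, `MetadataCallback`, and `adaptCallback` are hypothetical stand-ins that only mimic the shape of `MlirStringRef`, `MlirStringCallback`, `MlirOperation`, and `MTRT_MetadataCallback`. It shows a C-style metadata callback being adapted into a `std::function`, with the string accumulated through the append function.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>

// Hypothetical stand-ins that mimic the shape of the C API types.
struct StringRef { const char *data; size_t length; };
typedef void (*StringCallback)(StringRef, void *);
struct COperation { void *ptr; };
struct Operation { std::string name; };
typedef void (*MetadataCallback)(COperation op, StringCallback append,
                                 void *appendCtx, void *userData);

// Adapts a C-style metadata callback into a C++ std::function.
std::function<std::string(Operation *)>
adaptCallback(MetadataCallback callback, void *userData) {
  assert(callback != nullptr && "expected a valid callback");
  // Capture by value: the parameters go out of scope when we return.
  return [callback, userData](Operation *op) -> std::string {
    std::string accum;
    // Captureless, so it converts to a plain function pointer.
    StringCallback append = [](StringRef str, void *ctx) {
      static_cast<std::string *>(ctx)->append(str.data, str.length);
    };
    callback(COperation{op}, append, &accum, userData);
    return accum;
  };
}

// Example client: appends the op name followed by a suffix held in user data.
std::string exampleRoundTrip() {
  std::string suffix = "!";
  MetadataCallback cb = [](COperation op, StringCallback append,
                           void *appendCtx, void *userData) {
    auto *cppOp = static_cast<Operation *>(op.ptr);
    auto *suffixPtr = static_cast<std::string *>(userData);
    append(StringRef{cppOp->name.data(), cppOp->name.size()}, appendCtx);
    append(StringRef{suffixPtr->data(), suffixPtr->size()}, appendCtx);
  };
  auto fn = adaptCallback(cb, &suffix);
  Operation op{"example.op"};
  return fn(&op);
}
```

The client may call `append` several times; each call extends the same accumulator, which is why no string type needs to cross the boundary.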
Now the tricky part here is how to call it from the client side (PyBind). Below I assume the `py::object` is a callable object passed from Python. Exact syntax may vary, though.
.def("set....",
     [](PyStableHLOToExecutableOptions &self, py::object pyCallback) {
       // We are constructing a C function, so we cannot capture.
       // We must lift the function state into a separate struct.
       struct MetadataCallbackUserData {
         py::object pyCallback;
       };
       MTRT_MetadataCallback callback = [](MlirOperation op, MlirStringCallback append,
                                           void *appendCtx, void *userDataVoid) {
         auto *userData = static_cast<MetadataCallbackUserData *>(userDataVoid);
         // Call the actual Python callback and convert the result to a string.
         // Wrap this in a try...catch.
         std::string result = userData->pyCallback(op).cast<std::string>();
         // Send the string data across the CAPI boundary.
         append(MlirStringRef{result.data(), result.size()}, appendCtx);
       };
       // This is stack data, so it is only safe if the callback is invoked
       // solely during the call below. If the options object keeps the callback
       // for later use, the user data needs a longer lifetime.
       MetadataCallbackUserData userData{std::move(pyCallback)};
       // omitted: handle errors.
       mtrtStableHloToExecutableOptionsSetTensorRTTranslationMetadataCallback(
           self, callback, reinterpret_cast<void *>(&userData));
     });
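The "wrap this in a try...catch" remark matters because an exception should not propagate out of a function that is called through a C function pointer. The sketch below shows one possible shape, with hypothetical names (`AppendFn`, `UserData`, `safeMetadataCallback`, `runCallback`) and a `std::function` standing in for the Python callable. Appending nothing on failure is an assumed policy, not something the review prescribes; in real pybind11 code the exception of interest would typically be `py::error_already_set`.

```cpp
#include <cassert>
#include <cstddef>
#include <exception>
#include <functional>
#include <stdexcept>
#include <string>

// Hypothetical append function type, simplified from the C API shape.
typedef void (*AppendFn)(const char *data, size_t length, void *appendCtx);

// Hypothetical user data; `produce` stands in for the Python callable.
struct UserData {
  std::function<std::string()> produce;
};

// C-compatible entry point: exceptions are caught here and never escape.
void safeMetadataCallback(AppendFn append, void *appendCtx,
                          void *userDataVoid) {
  auto *userData = static_cast<UserData *>(userDataVoid);
  assert(userData != nullptr && "expected valid user data");
  try {
    std::string result = userData->produce();
    append(result.data(), result.size(), appendCtx);
  } catch (const std::exception &) {
    // Assumed policy: leave the metadata empty when the callable fails.
  }
}

// Drives the callback the way an API implementation might.
std::string runCallback(UserData &userData) {
  std::string accum;
  AppendFn append = [](const char *data, size_t length, void *ctx) {
    static_cast<std::string *>(ctx)->append(data, length);
  };
  safeMetadataCallback(append, &accum, &userData);
  return accum;
}
```

A real binding would probably also want to report the failure, for example by recording the error in the user data so the caller can surface it after the C API call returns.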
50c6388 to b0bf532
3db8e37 to 3b82968
3b82968 to c4ef5c9
- Adds a new API which allows for setting a layer metadata callback which will be invoked for each MLIR operation in order to set metadata for the corresponding TensorRT network layers.
280ecaa to 821f9ec
821f9ec to 67e72cf