Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds layer metadata APIs #191

Merged
merged 2 commits into from
Oct 1, 2024
Merged

Adds layer metadata APIs #191

merged 2 commits into from
Oct 1, 2024

Conversation

pranavm-nvidia
Copy link
Collaborator

No description provided.

Copy link
Collaborator

@christopherbate christopherbate left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's how to fix this up:

  1. Remove the Types.h header you created. That is not necessary.

  2. Refer to this basic outline:

In NetworkEncoder.h, you should just use std::function directly. Notice that we use Operation* here, not MlirOperation.

  std::function<std::string(Operation*)> layerMetadataCallback;

NetworkEncoder.cpp and StableHloToExecutable.h would be changed accordingly to reflect the use of Operation*.

Now the CAPI can be reworked as follows. First, you can't use std::function in the C API since it is plain C. You must use function pointers. The idea here is that the
user must specify a callback that communicates a string across the C API boundary. The caller (say in the C++ PyBind binding code) will construct a C++ function and pass it as a raw function pointer. There are two considerations:

  1. To avoid exposing a whole string interface to the C API, we can just pass an "append" function for the caller to invoke in order to construct the string (again using function pointers).
  2. When we construct a function that is passed-by-pointer across the C API boundary, we can use C++ lambda's to specify this function, but we are not allowed to have captured variables. Therefore, if we require capturing some object, we must 'lift' that state out into a separate "user data" field which can be treated opaquely by the API. There are two instances of this at work in the below example. We use one raw void* to represent the std::string being constructed by the CAPI implementation code. Then we also use one void* to represent the "context" of the callback passed from the CAPI client/caller code. This allows for wrapping things like the py::object combing from Python into an effectively static C function handle.

To do this, first in mlir-tensorrt-c/Compiler/Compiler.h, change MTRT_MetadataCallback to be just a typedef of a function pointer type:

typedef void (*MTRT_MetadataCallback)(MlirOperation op, MlirStringCallback append, void *appendCtx, void *userData);

Now in CAPI/Compiler/Compiler.cpp we can do this:


MTRT_Status
mtrtStableHloToExecutableOptionsSetTensorRTTranslationMetadataCallback(
    MTRT_StableHLOToExecutableOptions options, MTRT_MetadataCallback callback, void * userData) {
  StableHLOToExecutableOptions *cppOpts = unwrap(options);

  // Note: this is a pure C function pointer. Cannot capture, hence the need
  // to lift the accumulation state into `appendCtx`.
  auto appendCallback = [](MlirStringRef str, void *appendCtx) {
     std::string &accum = *reinterpret_cast<std::string*>(appendCtx);
     accum += std::string(str.data, str.length);
   };

  // Note: capture here is OK since we are now constructing a C++ std::function.
  // It's important that we just pass the user data as given.
  cppOpts->layerMetadataCallback = [&](Operation *op) -> std::string {
    std::string accum;
    void *appendCtx = reinterpret_cast<void*>(&accum);
    callback(wrap(op), &appendFunc, appendCtx, userData);
    return accum;
  };
  return mtrtStatusGetOk();
}

Now the tricky part here is how to call it from the client side (PyBind).

Below I assume py::object is a callable object passed from Python. Exact syntax may vary, though.

.def("set....", 
[](PyStableHLOToExecutableOptions &self, py::object callback) {
   
   // We are constructing a C function, so we cannot capture. 
   // We must lift the function state into a separate struct.
   struct MetadataCallbackUserData { 
      py::object pyCallback;
   };


   MTRT_MetadataCallback callback = [](MlirOperation op, MlirStringCallback append,
        void* appendCtx, void *userDataVoid) {

       auto *userData = static_cast<MetadataCallbackUserData*>(userDataVoid);
       
       // Call the actual Python callback.
       // Wrap this in a try...catch.
       std::string result = userData->pyCallback(op);

       // Send the string data across the CAPI boundary.
       append(MlirStringRef{result.data(),result.size()}, appendCtx) 
   };
   
   // This is stack data, but it only ever gets accessed inside the call below. 
   MetadataCallbackUserData userData {
     std::move(callback);
   };
      
   // omitted: handle errors. 
   mtrtStableHloToExecutableOptionsSetTensorRTTranslationMetadataCallback(
     self, callback, reinterpret_cast<void*>(&userData));
});

@pranavm-nvidia pranavm-nvidia force-pushed the layer-metadata branch 3 times, most recently from 50c6388 to b0bf532 Compare September 19, 2024 20:48
@pranavm-nvidia pranavm-nvidia marked this pull request as ready for review September 19, 2024 20:49
@pranavm-nvidia pranavm-nvidia force-pushed the layer-metadata branch 2 times, most recently from 3db8e37 to 3b82968 Compare September 23, 2024 17:19
- Adds a new API which allows for setting a layer metadata callback which
    will be invoked for each MLIR operation in order to set metadata for the
    corresponding TensorRT network layers.
@pranavm-nvidia pranavm-nvidia force-pushed the layer-metadata branch 2 times, most recently from 280ecaa to 821f9ec Compare September 30, 2024 18:11
@pranavm-nvidia pranavm-nvidia merged commit d9339f5 into main Oct 1, 2024
1 check passed
@pranavm-nvidia pranavm-nvidia deleted the layer-metadata branch October 1, 2024 16:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants