int8 input not supported for average pooling in MLIR_TRT #457

Open
farazkh80 opened this issue Dec 18, 2024 · 7 comments
Assignees: shelkesagar29
Labels: mlir-tensorrt
@farazkh80 (Collaborator)

Happened when running test_dtype_constraints[avgpool-valid:T1-int8]

summary = 'MTRTException: InternalError: failed to run compilation on module with symbol name: ins_t9521_outs_t9522_988\n\nAddit...s._api.MTRTException: InternalError: failed to run compilation on module with symbol name: ins_t9521_outs_t9522_988\n.'
details = ["IBuilder::buildSerializedNetwork: Error Code 1: Internal Error (Node [tensorrt.pooling] (t9522)cannot be quantized b...s=[1, 1, 1, 1], padding=[(0, 0), (0, 0), (0, 0), (0, 0)])\n      | ", '\n', '\nThis operation was introduced to ', ...]

    def raise_error(summary: str, details: List[Any] = []):
        """
        Raises a Tripy exception with a formatted message.
    
        Args:
            summary: A summary of the error message. This will be displayed before any other details.
            details: Details on the error. This function handles objects in this list as follows:
                - If they include a `stack_info` member, then information on the first user frame is displayed,
                    including file/line information as well as the line of code.
    
                    IMPORTANT: Any stack frames from the function registry are not displayed since
                    the function registry is an implementation detail used to dispatch to the real functions
                    we care about. Additionally, any code defined in the functions listed in ``EXCLUDE_FUNCTIONS``
                    is omitted.
    
                - In all other cases, the object is just converted to a string.
    
        Raises:
            TripyException
        """
    
        pre_summary = ""
        stack_info = utils.get_stack_info()
        user_frame_index = stack_info.get_first_user_frame_index()
        if user_frame_index is not None:
            stack_info.fetch_source_code()
            pre_summary = str_from_source_info(stack_info[user_frame_index])
    
        detail_msg = ""
        for detail in details:
            stack_info_message = None
            if hasattr(detail, "stack_info"):
                stack_info_message = str_from_stack_info(detail.stack_info)
            elif isinstance(detail, utils.StackInfo):
                stack_info_message = str_from_stack_info(detail)
    
            if stack_info_message is not None:
                detail_msg += stack_info_message
            else:
                detail_msg += str(detail)
    
        msg = f"{pre_summary}{summary}\n" + indent(detail_msg, " " * 4)
        # We use `from None` to suppress output from previous exceptions, since we want to handle them internally.
>       raise TripyException(msg) from None
E       tripy.common.exception.TripyException: 
E       
E       --> /tripy/tests/wrappers/test_interface.py:221 in _run_dtype_constraints_subtest()
E             |
E         221 |     ret_val.eval()
E             | 
E       
E       MTRTException: InternalError: failed to run compilation on module with symbol name: ins_t9521_outs_t9522_988
E       
E       Additional context:
E       Traceback (most recent call last):
E         File "/tripy/tripy/backend/mlir/compiler.py", line 86, in compile
E           executable = compiler.compiler_stablehlo_to_executable(
E       mlir_tensorrt.runtime._mlir_libs._api.MTRTException: InternalError: failed to run compilation on module with symbol name: ins_t9521_outs_t9522_988
E       .
E           IBuilder::buildSerializedNetwork: Error Code 1: Internal Error (Node [tensorrt.pooling] (t9522)cannot be quantized by arg0. You might want to add a DQ node before [tensorrt.pooling] (t9522).
E           )
E           (t9522)error: failed to translate function 'tensorrt_cluster' to a TensorRT engine
E       
E           This error occured while trying to compile the following FlatIR expression:
E                 |
E                 | t_inter2: [rank=(4), shape=((-1, -1, -1, -1)), dtype=(int8), loc=(gpu:0)] = ReduceWindowOp(t9521, t_inter3, reduce_mode='avg', window_dims=[1, 1, 2, 2], window_strides=[1, 1, 1, 1], padding=[(0, 0), (0, 0), (0, 0), (0, 0)])
E                 | 
E       
E           This operation was introduced to create the output of reduce `avg` operation..
E       
E           Note: This originated from the following expression:
E       
E           --> <string>:7 in <module>()
E       
E           Input 0:
E       
E           --> /tripy/tests/wrappers/object_builders.py:35 in tensor_builder()
E                 |
E              35 |         out = tp.cast(out, dtype=namespace[dtype])
E                 |               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

@pranavm-nvidia (Collaborator)

@farazkh80 could you post the trace if you still have it?

@yizhuoz004 (Collaborator)

How do I reproduce this error? This test passes locally.

@pranavm-nvidia (Collaborator)

@yizhuoz004 the error only happens when the pooling layer receives a runtime input. My suspicion is that it gets constant-folded in the other cases. It will probably repro if you tp.compile tp.avgpool and use an int8 input.
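
For reference, a rough repro along those lines might look like the sketch below. The exact tp.compile / tp.InputInfo / tp.avgpool signatures are assumptions inferred from the trace later in this thread, so treat it as a starting point rather than a verified reproducer.

    import tripy as tp

    # Hypothetical repro sketch (untested): give avgpool an int8 runtime input so
    # the pooling op cannot be constant-folded before it reaches TensorRT.
    def pool(x):
        return tp.avgpool(x, kernel_dims=[2, 2], stride=[1, 1], padding=[(0, 0), (0, 0)])

    # Compile with a dynamic int8 input; the API names here are assumptions.
    compiled = tp.compile(pool, args=[tp.InputInfo(shape=(1, 1, 8, 8), dtype=tp.int8)])

    inp = tp.cast(tp.ones((1, 1, 8, 8)), dtype=tp.int8)
    print(compiled(inp))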

@farazkh80 (Collaborator, Author)

Here is the trace:

inputs:
    t35: [shape=([1, 1, 8, 8]), dtype=(int8), loc=(gpu:0)]
t36 = pooling(t35, kind=Kind.AVG, kernel_dims=[2, 2], stride=[1, 1], padding=[(0, 0), (0, 0)])
outputs:
    t36: [shape=([-1, -1, -1, -1]), dtype=(int8), loc=(gpu:0)]

@yizhuoz004 (Collaborator)

I can reproduce it by making the int8 tensor an input. This is most likely a TRT constraint; I will file a bug. We can waive it for now. Also, torch average pooling does not support int8, so this should be a rare use case.

@yizhuoz004 (Collaborator)

There are issues in both TRT and MLIR-TRT.
TRT: this error is not expected when there is no explicit Q/DQ node.
MLIR-TRT: the stablehlo -> tensorrt translation emits unnecessary elementwise layers for a single avg pooling layer:

  tensorrt.module @trt_engines {
    func.func @tensorrt_cluster(%arg0: tensor<1x1x8x8xi8>) -> (tensor<1x1x7x7xi8> {tensorrt.shape_profile = #profile}) attributes {cluster.tensorrt} {
      %cst_i8 = tensorrt.constant dense<4> : tensor<1x1x1x1xi8>
      %cst_i8_0 = tensorrt.constant dense<4> : tensor<1x1x7x7xi8>
      %0 = tensorrt.pooling {averageCountExcludesPadding = true, poolingType = #tensorrt.pooling_type<kAVERAGE>, postPadding = array<i64: 0, 0>, prePadding = array<i64: 0, 0>, stride = array<i64: 1, 1>, windowSize = array<i64: 2, 2>} ins(%arg0 : tensor<1x1x8x8xi8>) -> tensor<1x1x7x7xi8>
      %1 = tensorrt.element_wise <kPROD>(%0, %cst_i8 : tensor<1x1x7x7xi8>, tensor<1x1x1x1xi8>) -> tensor<1x1x7x7xi8>
      %2 = tensorrt.element_wise <kDIV>(%1, %cst_i8_0 : tensor<1x1x7x7xi8>, tensor<1x1x7x7xi8>) -> tensor<1x1x7x7xi8>
      return %2 : tensor<1x1x7x7xi8>
    }
  }

shelkesagar29 self-assigned this Jan 15, 2025
shelkesagar29 added the mlir-tensorrt label Jan 15, 2025
@shelkesagar29 (Collaborator)

Pushed a fix internally. It should be available in OSS soon.

christopherbate added a commit that referenced this issue Feb 4, 2025
26523df3e94cc4de47a744e4e48621b74743dd00 by Sagar Shelke <[email protected]>:

[compiler/lib/Conversion] Update `stablehlo.reduce_window` conversion pattern

Previously, if the `stablehlo.reduce_window` op body contained `stablehlo.add` (i.e.
reduce_window<add>), we always converted such a reduce_window op to a
`tensorrt.pool` op with average pooling. We also inserted a multiplication
operation to compensate for the fact that `reduce_window<add>` was being replaced
with average pooling.

This pattern did not consider the case where a `stablehlo.divide` op follows
`reduce_window<add>` and is its only user; that combination is truly average pooling.
In that case, instead of inserting a multiplication, we can fold the subsequent
`stablehlo.divide` op into the TensorRT average pooling.

This MR updates the conversion pattern to accommodate this and adds positive
and negative MLIR tests.
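
To make the intuition behind the pattern concrete, here is a small NumPy sketch (an illustration only, not the conversion code) of the equivalence it exploits: reduce_window<add> followed by a divide by the window size is exactly average pooling, so the divide can be folded into the pooling op instead of emitting a compensating multiplication.

    import numpy as np

    # Toy illustration: sum pooling (reduce_window<add>) divided by the window
    # size equals average pooling.
    x = np.arange(64, dtype=np.float32).reshape(1, 1, 8, 8)
    kh, kw = 2, 2  # window size; stride 1, no padding -> 7x7 output

    sum_pool = np.zeros((1, 1, 7, 7), dtype=np.float32)
    for i in range(7):
        for j in range(7):
            sum_pool[0, 0, i, j] = x[0, 0, i:i + kh, j:j + kw].sum()

    avg_pool = sum_pool / (kh * kw)  # the divide folds into average pooling

    # Before this change, the conversion instead emitted average pooling followed
    # by a multiply-by-window-size (kPROD), leaving a redundant PROD/DIV pair
    # like the one visible in the tensorrt.module shown earlier in this thread.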

--
d6631d8be74994ae26c58f4afe9d087329a20985 by Christopher Bate <[email protected]>:

Revert "NFC: fix some unused value warnings in Release build"

This reverts commit 0a9df7a05ace412a423d72076ccde1380782852f.

--
0a9df7a05ace412a423d72076ccde1380782852f by Christopher Bate <[email protected]>:

NFC: fix some unused value warnings in Release build

--
842b0171738aaf14327b6db74ace88733f3f62fe by Chris Bate <[email protected]>:

[compiler] Add ability to export large binary data to external file resource

This change updates the `cuda-to-llvm` and `tensorrt-runtime-to-llvm`
passes to add an `artifacts-dir` option. When non-empty, this
directory is used to offload binary blob data that would otherwise be
encoded into the compiled program.

The large encoded constants are typically passed directly to an
externally defined function (e.g. `mtrt_cuda_module_load_from_ptx`)
using `llvm.addressof`, so the relevant conversions simply add a
`*_file` variant of the runtime call
(e.g. `mtrt_cuda_module_load_from_ptx_file`). This lets the runtime
figure out how to load the file data.

The one other op where large data is often encoded is `memref.global`, which will be
handled in a follow-on change.

--
c270c740d1032a40e7b34aeaac36c42b8b4298c0 by Chris Bate <[email protected]>:

[compiler] Update cuda-to-llvm to create separate ctor/dtor for each global

Create separate ctor/dtor for each global. This avoids issues if the pass
is run multiple times and prevents empty functions from being created when
no globals are created.

--
7d3c310f64b7a5c513c0435a75675003e32f0c58 by Chris Bate <[email protected]>:

[compiler] Add promise of ConvertToLLVMPatternInterface to CUDADialect

CUDA dialect was not promising `ConvertToLLVMPatternInterface`, which
meant that `cuda-to-llvm` patterns were not being populated by
`host-to-llvm`. Currently we explicitly run `cuda-to-llvm` prior to
`host-to-llvm`, but in the future that will change.

--
7e2e7aaaace73842efdeb6d5e8002e3f7c76b4af by Chris Bate <[email protected]>:

[cmake] Improve CMake package/install organization

- Creates dedicated CMake install components for easier packaging
- Disable upstream MLIR CMake install logic when invoking upstream
  MLIR cmake commands.

--
d786bcdc4bc1c683eabc6a37a902ac7a75568bfe by Chris Bate <[email protected]>:

[compiler] Enable end-to-end host LLVM JIT and C++ gen compile & execute test

This change helps enable the first integration test, which can be compiled and
executed in three different ways:

1. compile to Executor IR -> Lua based interpreter
2. compile to LLVM IR -> LLVM JIT runner

- We will need to expand the C support libraries for more CUDA, TensorRT,
  CuBlas, and NCCL module coverage equivalent to what we have for Lua.

- In the Lua-based backend, we had the convenience of some auto-magical
  error reporting mechanisms that get added to the runtime support
  library functions (e.g. using the `lua::this_state` argument). We need to
  craft something that is compatible with LLVM and EmitC pipelines because
  right now the C runtime support functions just abort on error, which
  isn't OK outside of simple integration tests.

--
8936f5b68e4d3efa646884c8ffee2ce6940c76b4 by Christopher Bate <[email protected]>:

Cherry-pick upstream EmitC fixes

--
fb672f3da68dbdcec4617bddbbaadc0d40b7f48d by Chris Bate <[email protected]>:

NFC: [compiler] Untangle conversion passes header inclusion of NvInfer.h

--
2528dd25f80cb838d82181f6b9bc386d0e1bda6d by Chris Bate <[email protected]>:

[python] Improve performance of compiler API tests

- Migrate some compiler python API tests to the newer APIs.
- Reuse existing compiler clients where possible
- Don't use "tensorrt-builder-opt-level=3" unnecessarily.
- Compiler python test performance is now dramatically improved.

--
e1ebdced50ac27c808c4bd4488a3f2d7ee645983 by Chris Bate <[email protected]>:

[compiler] Fix two issues in LLVM conversion utilities

Fixes an issue where a utility may create a `SymbolTable` in the middle
of a dialect conversion. This could cause an assertion to trigger since
symbols may not be unique in the middle of a conversion.

Additionally, `llvm.global` ops representing string literals now use
StringAttr to hold their value.

--
fd65b941673b49b6ff9a8b053b845dd1e051eb8c by Chris Bate <[email protected]>:

[compiler] Fix some issues with CUDA lowerings

- Adds fields to memory alloc/free to communicate whether a buffer is
  device/host_pinned/managed.

--
9a88722b7ca54fe59c280b70068c98a5a07784ce by Chris Bate <[email protected]>:

[compiler] Add additional support for `cuda-to-llvm`

Adds some of the major missing ops dealing with stream, device, and
memory management to the `convert-cuda-to-llvm` pass.
After this change, we have enough support to enable runtime testing.

--
0df274addcd00aab6501ecd3c72d1e72008cbe3b by Christopher Bate <[email protected]>:

[compiler] Add 'convert-tensorrt-runtime-to-llvm' pass

Adds a pass that converts TensorRTRuntime dialect operations/types to LLVM
dialect operations/types.

--
ca94e04dfc5f0134a25b218ef6de15503a0636cf by Christopher Bate <[email protected]>:

[compiler] Simplify workflow for lowering TensorRT engines

Adds a global symbol op and corresponding load op to the TensorRTRuntime
dialect. These operations represent the TensorRT engine binary and loading
of the TensorRT engine into a runtime execution context. The conversions
`tensorrt-to-tensorrt-runtime` and `tensorrt-runtime-to-executor` are
updated/simplified, and this change helps to reduce complexity for the
`tensorrt-runtime-to-llvm` change to be added. Two ops and a type can
be dropped from the TensorRTRuntime dialect.

--
5db84d6a239a275b7c29f8c40635622deff29bb8 by Christopher Bate <[email protected]>:

[compiler] Fix some issues in cuda-to-llvm conversion

Fixes a couple of issues in the CUDA-to-LLVM conversion. To make it easier
to write similar patterns for other dialects, this change also introduces
some utilities used to clean up the code. They make it easier
to create LLVM globals for string literals and other objects, and they
help ensure there are no symbol name conflicts when rewriting globals
from one dialect to another.

--
28ad3e14eed4eb7112876fdf8f7a892532566dae by Chris Bate <[email protected]>:

[compiler] Add aggregate "convert-host-to-llvm" conversion pass

Adds an aggregate pass "convert-host-to-llvm" which converts the host
program IR to LLVM. The purpose is to aggregate patterns and type
conversions for CUDA, TensorRTRuntime, and various other upstream
dialects in order to enable lowering to LLVM IR in a single pass.

Additionally, this change provides ConvertToLLVMPatternInterface
for the CUDA and Plan dialects so that they may hook into
"convert-host-to-llvm". The CUDA dialect to LLVM conversion is further
updated to correct the type of the stream (to pointer, not i8).

--
e98df022081e3e648a67c160d329f668c99bdcac by Christopher Bate <[email protected]>:

nfc: remove whitespace from cuda-to-llvm test

--
bbefdbf03fcbd8a60c1e1c942faef3b67817d57a by Chris Bate <[email protected]>:

[cmake] NFC: Remove 'MLIRTensorRTRegistration' library

Previously we used a catch-all "registration" library to collect all
dependencies of the "registerAllDialects|Passes" functions. However,
this actually caused a subtle circular dependency with the
StablehloToExecutable library which could manifest as a random build-time
error due to headers not being generated in the correct order. This
change removes that library and instead declares dependencies in a more
fine-grained manner where required.

--
cf3a76338ebe3e2c6770947bc8944dd4b6bb4e59 by Sagar Shelke <[email protected]>:

Add integer support to `RemoveProdDivPair` canonicalize pattern

This MR adds integer support to the `RemoveProdDivPair`
canonicalization pattern, which removes a pair of `kPROD` and
`kDIV` ops if the constant RHS of both the multiplication and the division
is 1.

Positive and negative MLIR tests are added.
This fixes OSS issue #457.

--
df381b5e747bb79842afdd62aed3cd2dff7fe564 by Christopher Bate <[email protected]>:

NFC: fix test file location and pass dependencies

Fixes a couple minor issues from f4959954b4daccd270323fa47867cbd12a62f97d.

--
f4959954b4daccd270323fa47867cbd12a62f97d by Zixin Huang <[email protected]>:

[compiler] New cuda-to-llvm pass

This MR converts cuda dialect ops into llvm ops. This will allow us to
generate LLVM IR for host code.

--
4343a6c60331de176eafab5fe4c91374e7d62a2e by Chris Bate <[email protected]>:

[executor] Fix conversion of function signature metadata

When a function lacks an 'executor.function_metadata' attribute, we should
create a function signature just using the function's MLIR type. It will
lack certain information (e.g. bounds information), but that is better
than not serializing a signature at all. Certain methods like the associated
runtime API's 'print' method for function flatbuffer objects were not
handling the case where the signature could be null.

--
2d4cc749800887f7cf5580ab664a9a18a29f4d2f by Chris Bate <[email protected]>:

[compiler] NFC: change 'enable-non-dps-returns' to 'force-entrypoints-return-allocs' in 'plan-alloc-tensors'

In the 'plan-alloc-tensors' pass and related pipelines, we had an option
previously named 'enable-non-dps-returns'. However, this doesn't accurately
reflect the desired effect -- even if this option is off, some tensor results
of entrypoint functions may be lowered into returned allocations.
If the shape is not computable from the input parameters, then the user
cannot pre-allocate a result buffer, and therefore the tensor must be lowered
into a returned allocation.

--
54319fa2fd1093bab1f4a16d85af5a19d6d9a6d3 by Chris Bate <[email protected]>:

NFC: fix typo in RuntimeSession member function name

--
d7a6e722c4e23e239ea32861ac542d4e943d40d3 by Chris Bate <[email protected]>:

NFC: [executor] Fix typo in CAPI type name

--
497647f15587fa927748a66721c77d1df6f6089c by Chris Bate <[email protected]>:

[compiler] Add support for non-DPS TensorRT call variants in `plan-outline-clusters`

Adds support for outlining `plan.closed_alloc_group` regions targeting
TensorRT in the `plan-outline-clusters` pass.

Co-authored-by: Jhalak Patel <[email protected]>

--
247de07ad07a017c8ff9408083439488d4f0220d by Chris Bate <[email protected]>:

[compiler] Add support for `tensorrt.call_alloc` in `plan-eliminate-shape-ops`

Add support for shape-op and argument cleanup in `plan-eliminate-shape-ops`
for the non-DPS TensorRT call variant `tensorrt.call_alloc`.

Co-authored-by: Jhalak Patel <[email protected]>

--
11621ff176f2503d04a7eba5e59e79ada3560310 by Chris Bate <[email protected]>:

[compiler] Fix incorrect conversion in `tensorrt-runtime-to-executor`

Fix miscellaneous issues in the conversion of `trtrt.enqueue_alloc`
to Executor IR. Previously, the offsets into the output descriptor were
not being correctly calculated.

Co-authored-by: Jhalak Patel <[email protected]>

GitOrigin-RevId: 1151c1999d0aa77991637455673a8f4ba5dd8cf3