Skip to content

Commit

Permalink
Overhaul backend function execution for improved performance and flex…
Browse files Browse the repository at this point in the history
…ibility

This PR replaces the DPS-style calling convention with a non-DPS approach, eliminating the requirement for call sites to preallocate output buffers. This change enables us to bypass the computation of output shapes and advance allocation of output buffers, laying the groundwork for supporting data-dependent shapes where network outputs can have dynamic dimensions.

The underlying compiler stack has been enhanced to avoid allocating oversized buffers and eliminate an extra device-to-device copy operation from TensorRT-allocated memory to MLIR-TRT managed memory.

Additionally, we've improved the copy operation to support copying to host memory. This enhancement removes the need to track output device allocations for device-to-host copies. Previously, copy outputs were restricted to device allocations; now they can be allocated on both device and host.

Tests have been updated to align with the new calling convention, ensuring compatibility and correctness.
  • Loading branch information
jhalakpatel committed Nov 4, 2024
1 parent 3ac751b commit 780e18b
Show file tree
Hide file tree
Showing 10 changed files with 61 additions and 150 deletions.
2 changes: 1 addition & 1 deletion tripy/tests/backend/api/test_executable.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_signature(self, single_return_executable):
assert param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
assert param.annotation == tp.Tensor

assert signature.return_annotation == tp.Tensor
assert signature.return_annotation == Sequence[tp.Tensor]

def test_signature_multiple_return_values(self, multiple_return_executable):
signature = inspect.signature(multiple_return_executable)
Expand Down
3 changes: 1 addition & 2 deletions tripy/tests/frontend/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,7 @@ def test_no_explicit_cast(self):
"devices",
[
("cpu", "gpu"),
# TODO(#155)
# ("gpu", "cpu"),
("gpu", "cpu"),
],
)
def test_explicit_copy(self, devices):
Expand Down
15 changes: 8 additions & 7 deletions tripy/tests/integration/test_iota.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,17 @@ def test_iota_like(self, dtype, shape, dim):

@pytest.mark.parametrize("dtype", DATA_TYPES.values())
def test_negative_no_casting(self, dtype):
from tripy.frontend.trace.ops.iota import Iota
with tp.logger.use_verbosity("ir"):
from tripy.frontend.trace.ops.iota import Iota

if dtype in [tp.float32, tp.int32, tp.int64]:
pytest.skip("tp.iota() supports float32, int32, and int64 without cast")
if dtype in [tp.float32, tp.int32, tp.int64]:
pytest.skip("tp.iota() supports float32, int32, and int64 without cast")

# TODO: update the 'match' error msg when MLIR-TRT fixes dtype constraint
a = tp.ones((2, 2))
out = Iota.build([frontend_utils.tensor_from_shape_like(a.shape)], dim=0, output_rank=2, dtype=dtype)
# TODO: update the 'match' error msg when MLIR-TRT fixes dtype constraint
a = tp.ones((2, 2))
out = Iota.build([frontend_utils.tensor_from_shape_like(a.shape)], dim=0, output_rank=2, dtype=dtype)

exception_str = "error: 'tensorrt.linspace' op result #0 must be 0D/1D/2D/3D/4D/5D/6D/7D/8D tensor of 32-bit float or 32-bit signless integer values"
exception_str = "InternalError: failed to run compilation on module with symbol name."
if dtype == tp.bool:
exception_str = "InternalError: failed to run compilation"
with helper.raises(
Expand Down
3 changes: 2 additions & 1 deletion tripy/tests/integration/test_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,5 +117,6 @@ def test_non_constant_scale(self):
input = tp.ones((4, 4))
scale = tp.ones((4,))
quantized = tp.quantize(input, scale, tp.int8, dim=0)
quantized_int32 = tp.cast(quantized, tp.int32)

assert bool(tp.all(quantized == tp.ones((4, 4), dtype=tp.int8)))
assert bool(tp.all(quantized_int32 == tp.ones((4, 4), dtype=tp.int32)))
1 change: 0 additions & 1 deletion tripy/tripy/backend/api/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,5 +196,4 @@ def process_arg(name, arg):
return Executable(
executable,
compiled_arg_names,
output_devices=[out.device for out in trace.outputs],
)
50 changes: 29 additions & 21 deletions tripy/tripy/backend/api/executable.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
import base64
import inspect
from typing import Sequence, Union
from typing import Sequence, Union, Tuple, Callable

import mlir_tensorrt.runtime.api as runtime

Expand All @@ -37,13 +37,11 @@ class Executable:
"""

# The constructor is intentionally undocumented because it is not meant to be called by users.
# TODO(#155): output_devices is not needed after they can be queried from executable
def __init__(self, executable, arg_names, output_devices):
def __init__(self, executable, arg_names):
self._executable = executable
self._executor = Executor(self._executable)
self._arg_names = arg_names
self._num_expected_args = len(arg_names)
self._output_devices = output_devices
self._executable_signature = self._executable.get_signature("main")

# Build a signature so the executable works with `inspect.signature`
Expand Down Expand Up @@ -128,7 +126,7 @@ def add(a, b):
tensor.eval()

try:
executor_outputs = self._executor.execute(self._output_devices, input_tensors)
executor_outputs = self._executor.execute(input_tensors)
except runtime.MTRTException as err:
# TODO: Evaluate whether this should be moved into the executor
if "function expects a memref type with element type" in str(err):
Expand Down Expand Up @@ -170,15 +168,22 @@ def add(a, b):
output_tensors = output_tensors[0]
return output_tensors

def _get_arg_info(self, idx):
arg = self._executable_signature.get_arg(idx)
arg = runtime.MemRefType(arg)
arg_bound = self._executable_signature.get_arg_bound(idx)
shape_bounds = tuple(zip(arg_bound.min(), arg_bound.max()))
if len(shape_bounds) == 0:
# For static shape arguments, get_arg_bound returns an empty list and we fallback to arg.shape
shape_bounds = tuple((x, x) for x in arg.shape)
return ArgInfo(shape_bounds, mlir_utils.convert_runtime_dtype_to_tripy_dtype(arg.dtype))
def _get_info(self, idx: int, get_item: Callable, get_bound: Callable) -> ArgInfo:
item = runtime.MemRefType(get_item(idx))
bound = get_bound(idx)
shape_bounds = tuple(zip(bound.min(), bound.max()))

if not shape_bounds:
# For static shape, fallback to item.shape
shape_bounds = tuple((x, x) for x in item.shape)

return ArgInfo(shape_bounds, mlir_utils.convert_runtime_dtype_to_tripy_dtype(item.dtype))

def _get_arg_info(self, idx: int) -> ArgInfo:
return self._get_info(idx, self._executable_signature.get_arg, self._executable_signature.get_arg_bound)

def _get_result_info(self, idx: int) -> ArgInfo:
return self._get_info(idx, self._executable_signature.get_result, self._executable_signature.get_res_bound)

def get_input_info(self) -> Sequence[ArgInfo]:
"""
Expand Down Expand Up @@ -221,11 +226,16 @@ def add(a, b):
compiled_add = tp.compile(add, args=[tp.InputInfo(([1, 2, 3],), dtype=tp.float32), tp.InputInfo(([1, 2, 3],), dtype=tp.float32)])
print(compiled_add.get_output_info())
"""
output_info = []
offset = self._executable_signature.get_num_input_args()
for idx in range(self._executable_signature.get_num_output_args()):
output_info.append(self._get_arg_info(idx + offset))
return output_info
num_input_args = self._executable_signature.get_num_input_args()
num_output_args = self._executable_signature.get_num_output_args()
num_results = self._executable_signature.get_num_results()

assert not (num_output_args and num_results), "Cannot have both output arguments and results"

if num_output_args:
return [self._get_arg_info(idx + num_input_args) for idx in range(num_output_args)]
else:
return [self._get_result_info(idx) for idx in range(num_results)]

def save(self, path: str) -> None:
"""
Expand Down Expand Up @@ -289,7 +299,6 @@ def add(a, b):
def encode_executable(executable):
return {
"arg_names": executable._arg_names,
"output_devices": executable._output_devices,
"executable": base64.b64encode(executable._executable.serialize()).decode(),
}

Expand All @@ -300,5 +309,4 @@ def decode_executable(executable_dict):
return Executable(
runtime.Executable(executable_bytes),
executable_dict["arg_names"],
executable_dict["output_devices"],
)
1 change: 1 addition & 0 deletions tripy/tripy/backend/mlir/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def _make_mlir_opts(self, trt_builder_opt_level):
f"--tensorrt-timing-cache-path={G_TIMING_CACHE_FILE}",
f"--tensorrt-builder-opt-level={trt_builder_opt_level}",
"--tensorrt-strongly-typed=True",
"--enable-non-dps-returns",
]
if config.enable_mlir_debug or config.enable_tensorrt_debug:
opts.append("--debug=true")
Expand Down
121 changes: 7 additions & 114 deletions tripy/tripy/backend/mlir/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,89 +31,17 @@

class Executor:
def __init__(self, executable: runtime.Executable) -> None:

runtime.GlobalDebug.flag = True
debug_types = ["allocator", "runtime"]
runtime.GlobalDebug.set_types(debug_types)
self.runtime_client = MLIRRuntimeClient()
session_options = runtime.RuntimeSessionOptions(num_devices=1, device_id=0)
self.session = runtime.RuntimeSession(session_options, executable)
self.device = self.runtime_client.get_devices()[0] # Assume a single device is available.
self.signature = executable.get_signature("main")
self.stream = default_stream()
self.num_input_args = self.signature.get_num_input_args()
self.num_output_args = self.signature.get_num_output_args()
self.output_args = [
self.signature.get_arg(index + self.num_input_args) for index in range(self.num_output_args)
]
self.output_memrefs = [runtime.MemRefType(out) for out in self.output_args]

def _create_shape_memref(self, shape):
shape = make_tuple(shape)
if len(shape) == 0:
return create_memref(
shape=(0,),
dtype=datatype.int64,
device=device("cpu"),
)
return create_memref(
array=convert_list_to_array(shape, datatype.int64),
shape=(len(shape),),
dtype=datatype.int64,
device=device("cpu"),
)

def _get_outputs_shape(self):
outputs_shape = []
all_outputs_known = True
for memref in self.output_memrefs:
outputs_shape.append(memref.shape)
all_outputs_known &= all(dim >= 0 for dim in memref.shape)
return outputs_shape, all_outputs_known

def _get_inputs_runtime_shape(self, inputs):
inputs_shape = []
for input in inputs:
inputs_shape.append(input.trace_tensor.producer.data.shape)
return inputs_shape

def _execute_shape_inference(self, inputs_shape, outputs_shape):
inputs_shape_memref = [self._create_shape_memref(inp_shape) for inp_shape in inputs_shape]
outputs_shape_memref = [self._create_shape_memref(out_shape) for out_shape in outputs_shape]
self.session.execute_function(
name=self.signature.get_shape_func_name(), in_args=inputs_shape_memref, out_args=outputs_shape_memref
)

outputs_runtime_shape = [memoryview(s).tolist() for s in outputs_shape_memref]
return outputs_runtime_shape

def _get_output_tensor_info(self, outputs_runtime_shape, output_devices):
outputs_tensor_info = []
for index in range(self.num_output_args):
memref = self.output_memrefs[index]
dtype = convert_runtime_dtype_to_tripy_dtype(memref.dtype)

output_device = output_devices[index]
if not output_device:
output_device = device(("gpu" if memref.address_space == runtime.PointerType.device else "cpu", 0))

runtime_shape = [rs if dim < 0 else dim for dim, rs in zip(memref.shape, outputs_runtime_shape[index])]
outputs_tensor_info.append(
TensorInfo(
len(runtime_shape),
tuple(runtime_shape),
dtype,
output_device,
)
)
return outputs_tensor_info

def get_output_tensor_runtime_info(self, inputs, output_devices=List[device]):
outputs_shape, all_outputs_known = self._get_outputs_shape()
if not all_outputs_known:
inputs_shape = self._get_inputs_runtime_shape(inputs)
outputs_shape = self._execute_shape_inference(inputs_shape, outputs_shape)
output_tensor_info = self._get_output_tensor_info(outputs_shape, output_devices)
return output_tensor_info

def execute(self, output_devices: List[device], inputs: List["Tensor"] = []) -> List[runtime.MemRefValue]:
def execute(self, inputs: List["Tensor"] = []) -> List[runtime.MemRefValue]:
in_args = []
for inp in inputs:
memref = inp.trace_tensor.producer.data
Expand All @@ -131,45 +59,10 @@ def execute(self, output_devices: List[device], inputs: List["Tensor"] = []) ->
)
in_args.append(memref)

# HACK (#155): Remove `get_devices` once executable output tensor location matches Trace IR.
out_tensor_info = self.get_output_tensor_runtime_info(inputs, output_devices)

# Allocate output memory and store buffer pointers.
outputs = [
create_memref(
shape=info.shape, dtype=info.dtype, device=info.device, stream=self.stream._active_cuda_stream
)
for info in out_tensor_info
]

out_args = []
for out in outputs:
memref = out
# HACK (#155): MLIR-TensorRT requires inputs to be on device.
# Remove explicit copy to device once #155 is addressed.
if memref.address_space != runtime.PointerType.device:
memref = self.runtime_client.copy_to_device(
host_memref=memref,
device=self.runtime_client.get_devices()[0],
stream=self.stream._active_cuda_stream,
)
if not memref:
raise_error("Could not allocate output memref", details=memref.error_details)
out_args.append(memref)

# Execute and populate device pointers.
self.session.execute_function(
"main", in_args=in_args, out_args=out_args, stream=self.stream._active_cuda_stream
outputs = self.session.execute_function(
"main", in_args=in_args, stream=self.stream._active_cuda_stream, client=self.runtime_client
)

# For outputs that were on the host, do the copy back
# TODO(#155): MLIR-TensorRT should allow output tensor placements on host.
for idx, out_info in enumerate(out_tensor_info):
if out_info.device.kind != "gpu":
self.runtime_client.copy_to_host(
device_memref=out_args[idx],
existing_host_memref=outputs[idx],
stream=self.stream._active_cuda_stream,
)

# For now return results on GPU.
return outputs
9 changes: 9 additions & 0 deletions tripy/tripy/flat_ir/ops/copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ class CopyOp(BaseFlatIROp):

target: tripy.common.device

def set_memory_space_attr(self, tensor, mem_space_attr):
current_type = tensor.type
# Set the encoding attribute on the operation's result
new_type = ir.RankedTensorType.get(current_type.shape, current_type.element_type, encoding=mem_space_attr)
tensor.set_type(new_type)

def to_mlir(self, operands):
from mlir_tensorrt.compiler.dialects import bufferization, tensor, arith

Expand All @@ -46,7 +52,10 @@ def to_mlir(self, operands):
sliced_dims.append(dim)

alloc_tensor = bufferization.alloc_tensor(inp_type, sliced_dims, memory_space=mem_space_attr)
self.set_memory_space_attr(alloc_tensor, mem_space_attr)
result_tensor = bufferization.materialize_in_destination(inp_type, operands[0], alloc_tensor)
self.set_memory_space_attr(result_tensor, mem_space_attr)
cast_tensor = tensor.cast(self.outputs[0].to_mlir(), result_tensor)
self.set_memory_space_attr(cast_tensor, mem_space_attr)

return [cast_tensor]
6 changes: 3 additions & 3 deletions tripy/tripy/frontend/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,11 +185,11 @@ def eval(self) -> runtime.MemRefValue:

compiler = Compiler(trt_builder_opt_level=0)
executable = compiler.compile(mlir, flat_ir=flat_ir)
executor = Executor(executable)
self.executor = Executor(executable)
# Upon computing the value of this tensor, we switch it to have a `Storage`
# parameter so that it does not need to be computed again.
data = executor.execute([out.device for out in flat_ir.outputs])
executor.stream.synchronize()
data = self.executor.execute()
self.executor.stream.synchronize()
assert len(data) == 1, "Expects only one output from mlir_tensorrt.compiler executor"
data = data[0]

Expand Down

0 comments on commit 780e18b

Please sign in to comment.