From 959e069eea3be1bdc469235b1a07d0f36ec9695d Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Tue, 12 Nov 2024 14:14:14 -0500 Subject: [PATCH 01/29] [Tripy] Check type annotations for variadic arguments in the function registry (#345) Addresses issue #341. --- tripy/tests/test_function_registry.py | 25 ++++++++++++++++++++++++- tripy/tripy/function_registry.py | 26 ++++++++++++++++++++++---- 2 files changed, 46 insertions(+), 5 deletions(-) diff --git a/tripy/tests/test_function_registry.py b/tripy/tests/test_function_registry.py index 20dcadf57..620d39c02 100644 --- a/tripy/tests/test_function_registry.py +++ b/tripy/tests/test_function_registry.py @@ -255,10 +255,22 @@ def func(a: float): def test_variadic_positional_args(self, registry): @registry("test") - def func(*args: List[Any]): + def func(*args: Any): return sum(args) + assert registry["test"](1.0) == 1.0 assert registry["test"](1.0, 2.0, 3.0) == 6.0 + assert registry["test"]() == 0 + + def test_variadic_positional_and_keyword_args(self, registry): + # ensure the interaction succeeds + @registry("test") + def func(a: int, *args: int, b: float, c: str): + return a + sum(args) + int(b) + len(c) + + assert registry["test"](3, b=1.0, c="ab") == 6 + assert registry["test"](3, 4, b=1.0, c="ab") == 10 + assert registry["test"](3, 4, 5, b=1.0, c="ab") == 15 def test_variadic_keyword_args(self, registry): @registry("test") @@ -469,6 +481,17 @@ def func(n: Sequence[Union[int, float]]) -> int: ): registry["test"](["a", "b", "c"]) + def test_error_variadic_positional_arg_mismatch(self, registry): + @registry("test") + def func(a: int, *args: int) -> int: + return a + sum(args) + + with helper.raises( + TripyException, + match="Not a valid overload because: For parameter: 'args', expected an instance of type: 'int' but got argument of type: 'str'", + ): + registry["test"](1, 2, 3, 4, "hi") + @pytest.mark.parametrize( "typ, expected", diff --git a/tripy/tripy/function_registry.py b/tripy/tripy/function_registry.py index 4f353bd43..4481fac6f 100644 --- a/tripy/tripy/function_registry.py +++ b/tripy/tripy/function_registry.py @@ -147,6 +147,8 @@ def _get_annotations(self): return self.annotations def matches_arg_types(self, args, kwargs) -> "Result": + from itertools import chain + from tripy.utils.result import Result def matches_type(name: str, annotation: type, arg: Any) -> bool: @@ -197,14 +199,30 @@ def matches_type(name: str, annotation: type, arg: Any) -> bool: annotations = self._get_annotations() # Check if we have too many positional arguments. We can only do this if there isn't a variadic positional argument. - if not any(annotation.kind == inspect.Parameter.VAR_POSITIONAL for annotation in annotations.values()) and len( - args - ) > len(annotations): + annotation_items = list(annotations.items()) + variadic_idx = None + for idx, (_, annotation) in enumerate(annotation_items): + # there can only be at most one variadic arg and it must come after all positional args and before keyword-only args + if annotation.kind == inspect.Parameter.VAR_POSITIONAL: + variadic_idx = idx + break + + if variadic_idx is None and len(args) > len(annotations): return Result.err( [f"Function expects {len(annotations)} parameters, but {len(args)} arguments were provided."], ) - for (name, annotation), arg in zip(annotations.items(), args): + # If there is a variadic positional arg, we can copy the final annotation for the remaining args. + # Keyword-only args (only possible with a variadic arg) will appear in kwargs and don't need to be checked here. + if variadic_idx is not None: + positional_args_to_check = chain( + zip(annotation_items[:variadic_idx], args), + map(lambda arg: (annotation_items[variadic_idx], arg), args[len(annotations) - 1 :]), + ) + else: + positional_args_to_check = zip(annotation_items, args) + + for (name, annotation), arg in positional_args_to_check: if not matches_type(name, annotation.type_info, arg): return Result.err( [ From bbef6e175a039d5108214f48e62dfc513aa532d9 Mon Sep 17 00:00:00 2001 From: Faraz <58580514+farazkh80@users.noreply.github.com> Date: Tue, 12 Nov 2024 16:33:36 -0500 Subject: [PATCH 02/29] [tripy] Sequential module feature branch (#321) Sequential similar to [torch.nn.sequential ](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html) for addressing https://github.com/NVIDIA/TensorRT-Incubator/issues/295 Things done - [x] new frontend module called `tp.Sequential` - [x] Unit tests under `tests/frontend/module/test_sequential.py` - [x] integration tests and comparison with torch under Supports nested tp.Sequential and most list operations other than modifications. `tp.Sequential` can not be modified after creation. --------- Signed-off-by: Faraz Khoubsirat Signed-off-by: Faraz <58580514+farazkh80@users.noreply.github.com> Co-authored-by: Faraz Khoubsirat Co-authored-by: pranavm-nvidia <49246958+pranavm-nvidia@users.noreply.github.com> --- .../tests/frontend/module/test_sequential.py | 210 ++++++++++++++++++ tripy/tests/integration/test_sequential.py | 176 +++++++++++++++ tripy/tripy/frontend/module/sequential.py | 181 +++++++++++++++ 3 files changed, 567 insertions(+) create mode 100644 tripy/tests/frontend/module/test_sequential.py create mode 100644 tripy/tests/integration/test_sequential.py create mode 100644 tripy/tripy/frontend/module/sequential.py diff --git a/tripy/tests/frontend/module/test_sequential.py b/tripy/tests/frontend/module/test_sequential.py new file mode 100644 index 000000000..3a751f99f --- /dev/null +++ b/tripy/tests/frontend/module/test_sequential.py @@ -0,0 +1,210 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# SPDX-LicenseCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import cupy as cp +import numpy as np +import pytest + +import tripy as tp +from tests import helper +from textwrap import dedent + + +@pytest.fixture +def sequential_network(): + yield tp.Sequential(tp.Linear(1, 3), tp.Linear(3, 2)) + + +@pytest.fixture +def dict_sequential_network(): + yield tp.Sequential({"layer1": tp.Linear(1, 3), "layer2": tp.Linear(3, 2)}) + + +@pytest.fixture +def nested_sequential_network(): + yield tp.Sequential(tp.Linear(2, 4), tp.Sequential(tp.Linear(4, 3), tp.Linear(3, 1))) + + +class TestSequential: + def test_basic_structure(self, sequential_network): + assert len(sequential_network) == 2 + + assert isinstance(sequential_network[0], tp.Linear) + assert np.array_equal( + cp.from_dlpack(sequential_network[0].weight), cp.from_dlpack(sequential_network[0].weight) + ) + assert np.array_equal(cp.from_dlpack(sequential_network[0].bias), cp.from_dlpack(sequential_network[0].bias)) + + def test_named_children(self, sequential_network): + expected_names = [("0", sequential_network[0]), ("1", sequential_network[1])] + assert list(sequential_network.named_children()) == expected_names + + def test_forward_pass(self, sequential_network): + input_data = tp.Tensor([1.0]) + output = sequential_network(input_data) + assert output.shape == [1, 2] + + def test_state_dict(self, sequential_network): + state_dict = sequential_network.state_dict() + param_count = sum(len(dict(m.named_parameters())) for m in sequential_network) + assert len(state_dict) == param_count + + expected_state_dict_keys = ["0.weight", "0.bias", "1.weight", "1.bias"] + assert list(state_dict.keys()) == expected_state_dict_keys + + def test_load_state_dict(self, sequential_network): + new_state_dict = {"0.weight": tp.Parameter(tp.ones((3, 1)))} + sequential_network.load_state_dict(new_state_dict, strict=False) + assert np.array_equal(cp.from_dlpack(sequential_network[0].weight), cp.from_dlpack(new_state_dict["0.weight"])) + + def test_modify_parameters(self, sequential_network): + new_param = tp.Parameter(tp.ones((2, 3))) + sequential_network[1].weight = new_param + assert sequential_network[1].weight is new_param + + def test_invalid_index_access(self, sequential_network): + with helper.raises(tp.TripyException, match="Key: '2' not found in modules"): + _ = sequential_network[2] + + def test_str_representation(self, sequential_network): + expected_str = dedent( + """\ + Sequential( + 0= + Linear( + weight=[3, 1], + bias=[3], + ), + 1= + Linear( + weight=[2, 3], + bias=[2], + ), + )""" + ) + assert str(sequential_network) == expected_str + + +class TestDictSequential: + def test_basic_structure(self, dict_sequential_network): + assert len(dict_sequential_network) == 2 + assert isinstance(dict_sequential_network["layer1"], tp.Linear) + assert isinstance(dict_sequential_network["layer2"], tp.Linear) + + def test_named_children(self, dict_sequential_network): + expected_names = [("layer1", dict_sequential_network["layer1"]), ("layer2", dict_sequential_network["layer2"])] + assert list(dict_sequential_network.named_children()) == expected_names + + def test_forward_pass(self, dict_sequential_network): + input_data = tp.Tensor([[1.0]]) + output = dict_sequential_network(input_data) + assert output.shape == [1, 2] + + def test_state_dict(self, dict_sequential_network): + state_dict = dict_sequential_network.state_dict() + expected_keys = ["layer1.weight", "layer1.bias", "layer2.weight", "layer2.bias"] + assert list(state_dict.keys()) == expected_keys + + def test_load_state_dict(self, dict_sequential_network): + new_state_dict = {"layer1.weight": tp.Parameter(tp.ones((3, 1)))} + dict_sequential_network.load_state_dict(new_state_dict, strict=False) + assert np.array_equal( + cp.from_dlpack(dict_sequential_network["layer1"].weight), cp.from_dlpack(new_state_dict["layer1.weight"]) + ) + + def test_modify_parameters(self, dict_sequential_network): + new_weight = tp.Parameter(tp.ones((2, 3))) + dict_sequential_network["layer2"].weight = new_weight + assert dict_sequential_network["layer2"].weight is new_weight + + def test_str_representation(self, dict_sequential_network): + expected_str = dedent( + """\ + Sequential( + layer1= + Linear( + weight=[3, 1], + bias=[3], + ), + layer2= + Linear( + weight=[2, 3], + bias=[2], + ), + )""" + ) + assert str(dict_sequential_network) == expected_str + + +class TestNestedSequential: + def test_basic_structure(self, nested_sequential_network): + # Check that the top-level Sequential has two layers and that one of them is a nested Sequential + assert len(nested_sequential_network) == 2 + assert isinstance(nested_sequential_network[1], tp.Sequential) + + def test_named_children_top_level(self, nested_sequential_network): + expected_names = [ + ("0", nested_sequential_network[0]), + ("1", nested_sequential_network[1]), + ] + assert list(nested_sequential_network.named_children()) == expected_names + + def test_named_children_nested(self, nested_sequential_network): + expected_names = [ + ("0", nested_sequential_network[1][0]), + ("1", nested_sequential_network[1][1]), + ] + assert list(nested_sequential_network[1].named_children()) == expected_names + + def test_load_state_dict_nested(self, nested_sequential_network): + # Loading state dict with parameters for both top-level and nested modules + new_state_dict = { + "1.1.weight": tp.Parameter(tp.ones((1, 3))), + } + nested_sequential_network.load_state_dict(new_state_dict, strict=False) + assert np.array_equal( + cp.from_dlpack(nested_sequential_network[1][1].weight), cp.from_dlpack(new_state_dict["1.1.weight"]) + ) + + def test_str_representation(self, nested_sequential_network): + expected_str = dedent( + """\ + Sequential( + 0= + Linear( + weight=[4, 2], + bias=[4], + ), + 1= + Sequential( + 0= + Linear( + weight=[3, 4], + bias=[3], + ), + 1= + Linear( + weight=[1, 3], + bias=[1], + ), + ), + )""" + ) + assert str(nested_sequential_network) == expected_str diff --git a/tripy/tests/integration/test_sequential.py b/tripy/tests/integration/test_sequential.py new file mode 100644 index 000000000..b6ef3e260 --- /dev/null +++ b/tripy/tests/integration/test_sequential.py @@ -0,0 +1,176 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict + +import torch +import tripy as tp + + +class TestSequential: + def test_basic_forward_pass_accuracy(self): + torch_model = torch.nn.Sequential( + torch.nn.Linear(1, 3, dtype=torch.float32, device="cuda"), + torch.nn.Linear(3, 2, dtype=torch.float32, device="cuda"), + ) + tp_model = tp.Sequential(tp.Linear(1, 3, dtype=tp.float32), tp.Linear(3, 2, dtype=tp.float32)) + + tp_model[0].weight = tp.Parameter(torch_model[0].weight.detach()) + tp_model[0].bias = tp.Parameter(torch_model[0].bias.detach()) + tp_model[1].weight = tp.Parameter(torch_model[1].weight.detach()) + tp_model[1].bias = tp.Parameter(torch_model[1].bias.detach()) + + input_tensor = torch.tensor([[1.0]], dtype=torch.float32, device="cuda") + tp_input = tp.Tensor(input_tensor, dtype=tp.float32) + + tp_output = tp_model(tp_input) + + torch_model.eval() + with torch.no_grad(): + torch_output = torch_model(input_tensor) + + rtol_ = 2e-6 + assert torch.allclose(torch.from_dlpack(tp_output), torch_output, rtol=rtol_) + + def test_dict_forward_pass_accuracy(self): + torch_model = torch.nn.Sequential( + torch.nn.Linear(1, 3, dtype=torch.float32, device="cuda"), + torch.nn.Linear(3, 2, dtype=torch.float32, device="cuda"), + ) + + tp_model = tp.Sequential( + {"layer1": tp.Linear(1, 3, dtype=tp.float32), "layer2": tp.Linear(3, 2, dtype=tp.float32)} + ) + + tp_model["layer1"].weight = tp.Parameter(torch_model[0].weight.detach()) + tp_model["layer1"].bias = tp.Parameter(torch_model[0].bias.detach()) + tp_model["layer2"].weight = tp.Parameter(torch_model[1].weight.detach()) + tp_model["layer2"].bias = tp.Parameter(torch_model[1].bias.detach()) + + input_tensor = torch.tensor([[1.0]], dtype=torch.float32, device="cuda") + tp_input = tp.Tensor(input_tensor, dtype=tp.float32) + + tp_output = tp_model(tp_input) + + torch_model.eval() + with torch.no_grad(): + torch_output = torch_model(input_tensor) + + rtol_ = 2e-6 + assert torch.allclose( + torch.from_dlpack(tp_output), torch_output, rtol=rtol_ + ), "Forward pass outputs do not match." + + def test_nested_forward_pass_accuracy(self): + torch_model = torch.nn.Sequential( + torch.nn.Linear(1, 3, dtype=torch.float32, device="cuda"), + torch.nn.Sequential( + torch.nn.Linear(3, 4, dtype=torch.float32, device="cuda"), + torch.nn.Linear(4, 2, dtype=torch.float32, device="cuda"), + ), + ) + tp_model = tp.Sequential( + tp.Linear(1, 3, dtype=tp.float32), + tp.Sequential(tp.Linear(3, 4, dtype=tp.float32), tp.Linear(4, 2, dtype=tp.float32)), + ) + + tp_model[0].weight = tp.Parameter(torch_model[0].weight.detach()) + tp_model[0].bias = tp.Parameter(torch_model[0].bias.detach()) + tp_model[1][0].weight = tp.Parameter(torch_model[1][0].weight.detach()) + tp_model[1][0].bias = tp.Parameter(torch_model[1][0].bias.detach()) + tp_model[1][1].weight = tp.Parameter(torch_model[1][1].weight.detach()) + tp_model[1][1].bias = tp.Parameter(torch_model[1][1].bias.detach()) + + input_tensor = torch.tensor([[1.0]], dtype=torch.float32, device="cuda") + tp_input = tp.Tensor(input_tensor, dtype=tp.float32) + + tp_output = tp_model(tp_input) + + torch_model.eval() + with torch.no_grad(): + torch_output = torch_model(input_tensor) + + rtol_ = 2e-6 + assert torch.allclose(torch.from_dlpack(tp_output), torch_output, rtol=rtol_) + + def test_basic_state_dict_comparison(self): + torch_model = torch.nn.Sequential( + torch.nn.Linear(1, 3, dtype=torch.float32), torch.nn.Linear(3, 2, dtype=torch.float32) + ) + tp_model = tp.Sequential(tp.Linear(1, 3, dtype=tp.float32), tp.Linear(3, 2, dtype=tp.float32)) + + tp_model[0].weight = tp.Parameter(torch_model[0].weight.detach()) + tp_model[0].bias = tp.Parameter(torch_model[0].bias.detach()) + tp_model[1].weight = tp.Parameter(torch_model[1].weight.detach()) + tp_model[1].bias = tp.Parameter(torch_model[1].bias.detach()) + + torch_state_dict = torch_model.state_dict() + tp_state_dict = tp_model.state_dict() + + for name, torch_param in torch_state_dict.items(): + tp_param = tp_state_dict[name] + assert torch.allclose(torch_param, torch.from_dlpack(tp_param), rtol=1e-5), f"Mismatch in {name}" + + def test_dict_sequential_state_dict_comparison(self): + torch_model = torch.nn.Sequential( + OrderedDict( + [ + ("layer1", torch.nn.Linear(1, 3, dtype=torch.float32)), + ("layer2", torch.nn.Linear(3, 2, dtype=torch.float32)), + ] + ) + ) + + tp_model = tp.Sequential( + {"layer1": tp.Linear(1, 3, dtype=tp.float32), "layer2": tp.Linear(3, 2, dtype=tp.float32)} + ) + + tp_model["layer1"].weight = tp.Parameter(torch_model[0].weight.detach()) + tp_model["layer1"].bias = tp.Parameter(torch_model[0].bias.detach()) + tp_model["layer2"].weight = tp.Parameter(torch_model[1].weight.detach()) + tp_model["layer2"].bias = tp.Parameter(torch_model[1].bias.detach()) + + torch_state_dict = torch_model.state_dict() + tp_state_dict = tp_model.state_dict() + + for name, torch_param in torch_state_dict.items(): + tp_param = tp_state_dict[name] + assert torch.allclose(torch_param, torch.from_dlpack(tp_param), rtol=1e-5), f"Mismatch in {name}" + + def test_nested_sequential_state_dict_comparison(self): + torch_model = torch.nn.Sequential( + torch.nn.Linear(1, 3, dtype=torch.float32), + torch.nn.Sequential(torch.nn.Linear(3, 4, dtype=torch.float32), torch.nn.Linear(4, 2, dtype=torch.float32)), + ) + + tp_model = tp.Sequential( + tp.Linear(1, 3, dtype=tp.float32), + tp.Sequential(tp.Linear(3, 4, dtype=tp.float32), tp.Linear(4, 2, dtype=tp.float32)), + ) + + tp_model[0].weight = tp.Parameter(torch_model[0].weight.detach()) + tp_model[0].bias = tp.Parameter(torch_model[0].bias.detach()) + tp_model[1][0].weight = tp.Parameter(torch_model[1][0].weight.detach()) + tp_model[1][0].bias = tp.Parameter(torch_model[1][0].bias.detach()) + tp_model[1][1].weight = tp.Parameter(torch_model[1][1].weight.detach()) + tp_model[1][1].bias = tp.Parameter(torch_model[1][1].bias.detach()) + + torch_state_dict = torch_model.state_dict() + tp_state_dict = tp_model.state_dict() + + for name, torch_param in torch_state_dict.items(): + tp_param = tp_state_dict[name] + assert torch.allclose(torch_param, torch.from_dlpack(tp_param), rtol=1e-5), f"Mismatch in {name}" diff --git a/tripy/tripy/frontend/module/sequential.py b/tripy/tripy/frontend/module/sequential.py new file mode 100644 index 000000000..d30982be9 --- /dev/null +++ b/tripy/tripy/frontend/module/sequential.py @@ -0,0 +1,181 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Union, Tuple, Dict, Iterator, Any +from dataclasses import dataclass +import copy + +from tripy import export +from tripy.common.exception import raise_error +from tripy.frontend.module import Module + + +@export.public_api( + document_under="modules/sequential.rst", +) +@dataclass +class Sequential(Module): + r""" + A module to stack multiple layers or modules in a sequential order. The `Sequential` + container can accept either a list of modules or a dictionary of named modules. Modules are + added in the order they are passed, and each is called sequentially during the forward pass. + """ + + def __init__(self, *modules: Union[Module, Dict[str, Module]]) -> None: + r""" + Args: + *modules: The modules to include in the sequence. + Can be passed as individual positional arguments or as a single dictionary of named modules. + + .. code-block:: python + :linenos: + :caption: Sequential with Positional Arguments + + model = tp.Sequential(tp.Linear(1, 3), tp.Linear(3, 2)) + + input = tp.Tensor([1.0]) + output = model(input) + + .. code-block:: python + :linenos: + :caption: Sequential with a Dictionary + + model = tp.Sequential({'layer1': tp.Linear(1, 3), 'layer2': tp.Linear(3, 2)}) + + input = tp.Tensor([1.0]) + output = model(input) + """ + super().__init__() + self.modules = {} + + if len(modules) == 1 and isinstance(modules[0], dict): + self.modules = copy.copy(modules[0]) + else: + for idx, module in enumerate(modules): + self.modules[str(idx)] = module + + def __call__(self, input: "tripy.Tensor") -> "tripy.Tensor": + r""" + Defines the forward pass by applying each module in the container sequentially to `input` + + Args: + input: The input tensor to pass through the sequence of modules. + + Returns: + The output tensor after passing through each module in sequence. + """ + for module in self.modules.values(): + input = module(input) + return input + + def __getattr__(self, name: str) -> Any: + """ + Custom __getattr__ to search both in `modules` dictionary and in other attributes. This is for handling + `module = operator.attrgetter(child_name)(module)` calls in tripy/frontend/module/module.py:load_state_dict + """ + if name in self.modules: + return self.modules[name] + + # Fallback to regular attribute access if not found in modules + return super().__getattr__(name) + + def __len__(self) -> int: + r""" + Returns the total number of modules in the sequence. + + Returns: + The number of modules in the sequence. + + .. code-block:: python + :linenos: + :caption: Example + + # doc: print-locals model length + + model = tp.Sequential(tp.Linear(1, 64), tp.Linear(64, 128)) + length = len(model) + assert length == 2 + """ + return len(self.modules) + + def __iter__(self) -> Iterator[Module]: + r""" + Returns an iterator over the modules in the sequence. + + Returns: + An iterator over the modules. + + .. code-block:: python + :linenos: + :caption: Example + + model = tp.Sequential(tp.Linear(1, 3), tp.Linear(3, 2)) + for layer in model: + print(layer) + """ + return iter(self.modules.values()) + + def __getitem__(self, idx: Union[int, str]) -> Module: + r""" + Accesses a module by index (int) or name (str). + + Args: + idx: The index or name of the module to retrieve. + + Returns: + The module at the specified index or name. + + Raises: + TypeError: If `idx` is not an int or str. + + .. code-block:: python + :linenos: + :caption: Example + + model = tp.Sequential(tp.Linear(1, 3), tp.Linear(3, 2)) + print(model[1]) + """ + key = str(idx) if isinstance(idx, int) else idx + + if key not in self.modules: + raise_error( + f"Key: '{key}' not found in modules.", [f"Note: Available keys were: {list(self.modules.keys())}"] + ) + + return self.modules[key] + + def named_children(self) -> Iterator[Tuple[str, "Module"]]: + r""" + Returns an iterator over all the first-order modules in this `Sequential` container. + Each child module is represented by its name and the module object itself. + + Returns: + An iterator over tuples containing + the name and module of each child. + + .. code-block:: python + :linenos: + :caption: Example + + model = tp.Sequential(tp.Linear(1, 3), tp.Linear(3, 2)) + + for name, child in model.named_children(): + print(f"{name}: {type(child).__name__}") + + """ + # Overriding the base implementation to prevent displaying every child module + # with the 'modules' prefix in the state_dict. This change ensures compatibility + # with PyTorch's naming conventions. + for name, module in self.modules.items(): + yield name, module From 16303566158ccb70ee44460d0d3bb082965692c8 Mon Sep 17 00:00:00 2001 From: Christopher Bate Date: Wed, 13 Nov 2024 09:39:21 -0700 Subject: [PATCH 03/29] Integrate changes from internal (#362) ## [Dialect/TensorRT,StablehloToTensorRT] Add better support for dynamic n-d iota This change adds better support for converting `stablehlo.dynamic_iota` to TensorRT operations. It also expands testing for dynamic `tensorrt.linspace` and corrects some documentation issues for `tensorrt.linspace`. Finally, it plugs a gap in serialization of i64 tensor constants in the TensorRT translation utilities. ## [compiler] Better organize bounds attribute conversions and verifications Prior to this change, we sporadically used attribute names like "tensorrt.value_bounds" and "tensorrt.shape_profile" throughout the compilation process despite using different attributes under those names at different stages. This change ensures that the attribute names referring to a) shape bounds and b) value bounds are properly converted at each stage along with the actual attribute types. Each of the tensorrt, plan, and executor dialects have their own bounds attribute types, and they now have different key names (e.g. `(tensorrt|plan|executor).shape_profile`) to be used with argument/result attributes. Region argument attribute verifiers are added to the Plan and Executor dialects. Additional tests are added. Additionally, we add a op ASM aliasing callbacks to the Plan dialect since the bounds attributes can be very verbose after closed cluster formation. ## [Dialect/TensorRT] Fix verification issues in `tensorrt.dequantize` operation Previously, the `tensorrt.dequantize` operation was missing some restrictions: 1. The input rank must be non-zero. 2. The input shape final dimension must be even if the input element type is a sub-byte type. We weren't catching point (2) earlier because our main CI workflow was using a version of TensorRT that was missing explicit validation of this constraint (which was a bug). ## [Dialect/Plan] Fix clustering of stablehlo ops into host/scalarized clusters This change fixes the clustering of StableHLO ops when `plan.with_value` operations are interleaved. It is fine to add `plan.with_values` operations to host clusters, just like TRT clusters when the clustering options allow shape tensor calculations to be offloaded to TRT. ## [Plan/RefineTypes] Improve logic for refining types in 'plan-refine-types' pass Updates the rewrite that uses `plan.with_shape`` to refine the types of the stablehlo/tensorrt operation producers. Previously, had an unnecessary "single use" restriction that blocked some type refinements from occurring. ## [tensorrt] Fix bug in `tensorrt.slice` canonicalizer The canonicalizer for `tensorrt.slice` was incorrectly allowing the builder to set the type, which could incorrectly replace the original result type without using a cast. Fix and add tests. ## [tensorrt] Fix >2 operand shape broadcasting validity check Fixes an issue where a TensorRT operation that allows implicit broadcasting of > 2 operands (e.g. `tensorrt.select`) would incorrectly complain about some valid shape configurations due. Adds additional regression test. ## [StableHloToTensorRT] Add support for converting some specific kinds of `stablehlo.dynamic_gather` to TensorRT This commit adds a straightforward extension to the "stablehlo-to-tensorrt" conversions for converting `stablehlo.dynamic_gather` to `tensorrt.gather`. The specific case that this change adds support for is when `stablehlo.dynamic_gather` represents a "simple, single dimension gather" with implicit dimension index (for complete definition and examples of this term, see the doc comments in 'compiler/include/mlir-tensorrt/Dialect/StableHloExt/Utils/GatherScatterUtils.h'). ## NFC: [tensorrt] add translation test for `tensorrt.gather` with dynamic dimensions ## [StableHloExt] Migrate last remaining patterns in the 'tensorrt-stablehlo-preprocessing' passs This change eliminates the `tensorrt-stablehlo-input-preprocessing` pass. The only remaining patterns in that pass after the previous changes were a pattern to lower `chlo.erf` to stablehlo ops (duplicated from upstream). We can now isolate that requirement to selectively convert CHLO ops into a new pass, `convert-chlo-to-stablehlo-ext`. The new pass uses the upstream patterns and the MLIR dialect conversion infra to choose which CHLO operations to preserve. We can then eliminate the duplicated code which was being used in the old pass. This also provides an opportunity to cleanup the various CHLO options present in the preprocessing pass pipeline. ## [StableHloExt] Migrate `stablehlo.logical_shift_right` simplification pattern to `stablehlo-ext-constant-folding` This change moves a pattern that folds trivial `stablehlo.logical_shift_right` from the `tensorrt-stablehlo-input-preprocessing` pass to `stablehlo-ext-constant-folding`. The pattern is improved to solve two latent bugs which were never encountered previously. Additional tests are added to get better coverage of dynamic shape cases where the original pattern would have crashed. ## [StableHloExt] Improve organization of StableHlo pre-processing passes This change improves the organization of transforms under StableHloExt in the following ways: - Adds an additional pass 'stablehlo-ext-canonicalize-convolution' specifically for testing the convolution canonicalization patterns. - Folds the patterns from `StableHloPrepareScatter.cpp` into `CanonicalizeScatter.cpp`, allowing the former to be eliminated. ## NFC: Move "stablehlo-input-preprocessing" and "stablehlo-raise-qdq" passes under the StableHloExt folder ## NFC: Don't unnecessarily restrict StableHLO preprocessing passes to 'func' ops ## NFC: [StableHloExt] Consolidate scatter/gather utilities into common translation unit and under `stablehlo_ext` namespace This change consolidates utility functions related to scatter/gather ops into the "ScatterGatherUtils.(h|cpp)" translation unit and changes all the functions in those files to be under the 'stablehlo_ext' namespace. This is part one of a series of changes that will attempt to consolidate the stablehlo simplification passes under 'compiler/lib/Transforms' with the transformations under 'compiler/lib/Dialect/StableHloExt/Transforms'. The fact that the simplification transforms are currently split into these two different sets is a product of the last major restructuring of the project. Co-authored-by: Copybara Bot --- .../mlir-tensorrt/Conversion/Passes.td | 24 ++ .../StablehloToTensorRT/StablehloToTensorRT.h | 4 +- .../Dialect/Plan/Analysis/BoundsAnalysis.h | 2 + .../Dialect/Plan/IR/PlanAttributes.td | 9 +- .../Dialect/Plan/IR/PlanDialect.td | 14 + .../Dialect/Plan/IR/PlanInterfaces.td | 28 ++ .../Dialect/StableHloExt/Transforms/Passes.td | 27 +- .../StableHloExt/Utils/GatherScatterUtils.h | 31 +- .../Pipelines/StableHloInputPipelines.h | 15 +- .../include/mlir-tensorrt/Utils/ShapeInfo.h | 96 ++++++ .../lib/Compiler/StableHloToExecutable.cpp | 6 +- .../compiler/lib/Conversion/CMakeLists.txt | 1 + .../ChloToStablehloExt/CMakeLists.txt | 17 + .../ChloToStablehloExt/ChloToStablehloExt.cpp | 79 +++++ .../PlanToExecutor/PlanToExecutor.cpp | 61 ++-- .../StablehloToTensorRT/CMakeLists.txt | 1 + .../StablehloToTensorRT.cpp | 156 ++++++++- .../Dialect/Plan/Analysis/BoundsAnalysis.cpp | 101 +++--- .../Dialect/Plan/IR/BuiltinClusterKinds.cpp | 79 ++++- .../lib/Dialect/Plan/IR/CMakeLists.txt | 1 + .../compiler/lib/Dialect/Plan/IR/PlanOps.cpp | 110 +++++- .../Dialect/Plan/Transforms/AllocTensors.cpp | 16 +- .../MaterializeShapeCalculations.cpp | 81 ++++- .../PopulateFunctionBoundsAttributes.cpp | 17 +- .../Dialect/Plan/Transforms/RefineTypes.cpp | 48 +-- .../StableHloExt/Transforms/CMakeLists.txt | 4 +- ...lution.cpp => CanonicalizeConvolution.cpp} | 28 ++ .../Transforms/CanonicalizeGather.cpp | 15 +- .../Transforms/CanonicalizeScatter.cpp | 127 ++++++- .../Transforms/ConstantFolding.cpp | 36 +- .../StablehloInputPreprocessing.cpp | 200 ----------- .../Transforms/StablehloPrepareScatter.cpp | 199 ----------- .../Dialect/StableHloExt/Utils/CMakeLists.txt | 1 + .../StableHloExt/Utils/GatherScatterUtils.cpp | 262 ++++++++++++-- .../compiler/lib/Pipelines/CMakeLists.txt | 1 + .../lib/Pipelines/StableHloInputPipelines.cpp | 50 +-- .../compiler/lib/Utils/CMakeLists.txt | 7 + .../compiler/lib/Utils/ShapeInfo.cpp | 41 +++ .../Executor/IR/ExecutorDialect.td | 6 +- .../executor/lib/Executor/IR/Executor.cpp | 83 ++++- .../Lua/TranslateToRuntimeExecutable.cpp | 12 +- .../executor/test/Executor/invalid.mlir | 67 ++++ .../executor/test/Executor/roundtrip.mlir | 13 +- .../mlir_tensorrt/tools/gpu_tools.py | 30 +- .../NetworkEncoder.h | 11 +- .../TensorRT/IR/TensorRTDialect.td | 4 + .../TensorRT/IR/TensorRTOps.td | 19 +- .../NetworkEncoder.cpp | 30 +- .../TensorRT/IR/TensorKindOpInterfaceImpl.cpp | 5 + .../tensorrt/lib/TensorRT/IR/TensorRT.cpp | 24 +- .../tensorrt/lib/TensorRT/IR/Verification.cpp | 43 ++- .../tensorrt/lib/Utils/ShapeUtils.cpp | 5 +- .../test/Dialect/TensorRT/canonicalize.mlir | 41 +++ .../test/Dialect/TensorRT/invalid.mlir | 32 +- .../test/Dialect/TensorRT/roundtrip.mlir | 12 + .../Target/TensorRT/TRT10/convolution.mlir | 10 +- .../tensorrt/test/Target/TensorRT/gather.mlir | 14 + .../test/Target/TensorRT/linspace.mlir | 47 ++- .../chlo-to-stablehlo-ext.mlir | 42 +++ .../ChloToStablehloExt/lit.local.cfg | 2 + .../PlanToExecutor/plan-to-executor.mlir | 16 +- .../Conversion/StablehloToScf/lit.local.cfg | 2 + .../StablehloToTensorRT/stablehlo-conv.mlir | 2 +- .../StablehloToTensorRT/stablehlo-gather.mlir | 73 ++++ .../stablehlo-to-tensorrt.mlir | 32 +- .../test/Dialect/Plan/bounds-analysis.mlir | 36 +- .../Dialect/Plan/create-closed-regions.mlir | 320 +++++++++++------- mlir-tensorrt/test/Dialect/Plan/invalid.mlir | 69 ++++ .../Plan/materialize-shape-calculations.mlir | 4 +- .../Plan/populate-func-bounds-attrs.mlir | 86 +++-- .../test/Dialect/Plan/refine-types.mlir | 30 ++ .../test/Dialect/Plan/roundtrip.mlir | 2 +- .../Dialect/Plan/segmentation-pipeline.mlir | 213 ++++++++++++ .../Dialect/Plan/stablehlo-clustering.mlir | 210 +++++------- ...ion.mlir => canonicalize-convolution.mlir} | 48 +-- .../StableHloExt/canonicalize-gather.mlir | 26 +- ...tter.mlir => canonicalize-scatter-nd.mlir} | 10 +- .../StableHloExt/canonicalize-scatter.mlir | 147 +++++++- .../constant-folding-bitwise.mlir | 63 ++++ .../stablehlo-input-preprocessing.mlir | 123 ------- .../test_stablehlo_dynamic_iota.py | 105 ++++++ 81 files changed, 2930 insertions(+), 1234 deletions(-) create mode 100644 mlir-tensorrt/compiler/include/mlir-tensorrt/Utils/ShapeInfo.h create mode 100644 mlir-tensorrt/compiler/lib/Conversion/ChloToStablehloExt/CMakeLists.txt create mode 100644 mlir-tensorrt/compiler/lib/Conversion/ChloToStablehloExt/ChloToStablehloExt.cpp rename mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Transforms/{StablehloPrepareConvolution.cpp => CanonicalizeConvolution.cpp} (90%) delete mode 100644 mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Transforms/StablehloInputPreprocessing.cpp delete mode 100644 mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Transforms/StablehloPrepareScatter.cpp create mode 100644 mlir-tensorrt/compiler/lib/Utils/ShapeInfo.cpp create mode 100644 mlir-tensorrt/test/Conversion/ChloToStablehloExt/chlo-to-stablehlo-ext.mlir create mode 100644 mlir-tensorrt/test/Conversion/ChloToStablehloExt/lit.local.cfg create mode 100644 mlir-tensorrt/test/Conversion/StablehloToScf/lit.local.cfg create mode 100644 mlir-tensorrt/test/Dialect/Plan/segmentation-pipeline.mlir rename mlir-tensorrt/test/Dialect/StableHloExt/{stablehlo-prepare-convolution.mlir => canonicalize-convolution.mlir} (98%) rename mlir-tensorrt/test/Dialect/StableHloExt/{stablehlo-prepare-scatter.mlir => canonicalize-scatter-nd.mlir} (92%) create mode 100644 mlir-tensorrt/test/Dialect/StableHloExt/constant-folding-bitwise.mlir delete mode 100644 mlir-tensorrt/test/Dialect/StableHloExt/stablehlo-input-preprocessing.mlir create mode 100644 mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_dynamic_iota.py diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Conversion/Passes.td b/mlir-tensorrt/compiler/include/mlir-tensorrt/Conversion/Passes.td index e4eb9f427..57cc84db9 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Conversion/Passes.td +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Conversion/Passes.td @@ -25,6 +25,7 @@ include "mlir/Pass/PassBase.td" //===----------------------------------------------------------------------===// // StablehloToTensorRT //===----------------------------------------------------------------------===// + #ifdef MLIR_TENSORRT_ENABLE_HLO def ConvertStablehloToTensorRTPass : Pass<"convert-stablehlo-to-tensorrt"> { let summary = "Convert Stable HLO dialect to TensorRT dialect"; @@ -44,7 +45,30 @@ def ConvertStablehloToTensorRTPass : Pass<"convert-stablehlo-to-tensorrt"> { "target TensorRT version for conversion"> ]; } +#endif // MLIR_TENSORRT_ENABLE_HLO + +//===----------------------------------------------------------------------===// +// ChloToStableHloExt +//===----------------------------------------------------------------------===// + +#ifdef MLIR_TENSORRT_ENABLE_HLO +def ConvertChloToStableHloExtPass : Pass<"convert-chlo-to-stablehlo-ext"> { + let summary = "Convert specific CHLO operations to stablehlo"; + let description = [{ + This pass converts a CHLO operations to StableHlo while also allowing + for some CHLO operations to be preserved (see options). + }]; + let dependentDialects = [ + "::mlir::stablehlo::StablehloDialect" + ]; + let options = [ + Option<"preserveErf", "preserve-erf", "bool", "true", + "do not convert chlo.erf ops">, + Option<"preserveTopK", "preserve-topk", "bool", "true", + "do not convert chlo.topk ops">, + ]; +} #endif // MLIR_TENSORRT_ENABLE_HLO //===----------------------------------------------------------------------===// diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Conversion/StablehloToTensorRT/StablehloToTensorRT.h b/mlir-tensorrt/compiler/include/mlir-tensorrt/Conversion/StablehloToTensorRT/StablehloToTensorRT.h index 5d0e46793..01acc1a22 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Conversion/StablehloToTensorRT/StablehloToTensorRT.h +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Conversion/StablehloToTensorRT/StablehloToTensorRT.h @@ -25,6 +25,7 @@ #define MLIR_TENSORRT_CONVERSION_HLOTOTENSORRT_HLOTOTENSORRT_H #include "mlir-tensorrt/Conversion/TensorRTCommon/ConvertToTensorRTCommon.h" +#include "mlir-tensorrt/Dialect/StableHloExt/Utils/GatherScatterUtils.h" #include "mlir/IR/PatternMatch.h" namespace mlir { @@ -33,7 +34,8 @@ class ConversionTarget; // Collection of rewrite patterns for lowering of Stable HLO to TensorRT // dialect. void populateStablehloToTensorRtConversionPattern( - TensorRTTypeConverter &typeConverter, RewritePatternSet &patterns); + TensorRTTypeConverter &typeConverter, RewritePatternSet &patterns, + ShapeInfoCallbacks shapeInfoCallbacks = {}); /// Populate patterns for convert Chlo ops to TensorRT ops. void populateChloToTensorRtLegalityAndPatterns( diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/Analysis/BoundsAnalysis.h b/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/Analysis/BoundsAnalysis.h index 97e3e98e0..5f5da0a73 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/Analysis/BoundsAnalysis.h +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/Analysis/BoundsAnalysis.h @@ -74,6 +74,8 @@ class BoundsArray { static BoundsArray fromIntegerValueBounds(unsigned bitwidth, ArrayRef min, ArrayRef max); + static BoundsArray fromIntegerValueBounds(ArrayRef min, + ArrayRef max); /// For the given tensor-typed value, return the most conservative bounds for /// the shape of `v`. For each unknown dimension of the shape of `v` the diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanAttributes.td b/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanAttributes.td index 6cb8489d0..7b69eb98f 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanAttributes.td +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanAttributes.td @@ -24,7 +24,8 @@ def Plan_HostClusterKindAttr : Plan_Attr<"HostClusterKind", "host_cluster", } -def Plan_BoundsAttr : Plan_Attr<"Bounds", "bounds">{ +def Plan_BoundsAttr : Plan_Attr<"Bounds", "bounds", [ + DeclareAttrInterfaceMethods]>{ let parameters = (ins EnumParameter:$kind, OptionalParameter<"DenseI64ArrayAttr">:$min_shape, @@ -46,17 +47,17 @@ def Plan_BoundsAttr : Plan_Attr<"Bounds", "bounds">{ let extraClassDeclaration = [{ /// Returns true if this bounds is for shape dimension extents. - bool isShapeBound() { + bool isShapeBound() const { return getKind() == BoundsKind::Shape; } /// Returns true if this bounds is a 'none' bounds kind. - bool isNone() { + bool isNone() const { return getKind() == BoundsKind::None; } /// Returns true if this bounds is for values of a tensor. - bool isValueBound() { + bool isValueBound() const { return getKind() == BoundsKind::Value; } diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanDialect.td b/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanDialect.td index 9b63534d7..ec0146a3b 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanDialect.td +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanDialect.td @@ -15,6 +15,8 @@ def Plan_Dialect : Dialect { }]; let cppNamespace = "::mlir::plan"; + + let hasRegionArgAttrVerify = 1; let extraClassDeclaration = [{ @@ -72,6 +74,18 @@ def Plan_Dialect : Dialect { (addExtensionOperation(), ...); } + /// Return the name of the function arg/result attributes that encode + /// host tensor value bounds. It should have a type `plan::BoundsAttr`. + static StringRef getValueBoundsAttrName() { + return "plan.value_bounds"; + } + + /// Return the name of the function arg/result attributes that encode + /// the shape bounds. It should have a type `plan::BoundsAttr`. + static StringRef getShapeBoundsAttrName() { + return "plan.shape_profile"; + } + private: ::llvm::StringMap attrParsingHooks; ::llvm::DenseMap<::mlir::TypeID, AttrPrintingHook> attrPrintingHooks; diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanInterfaces.td b/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanInterfaces.td index 10393038a..91c2262cb 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanInterfaces.td +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanInterfaces.td @@ -3,6 +3,34 @@ include "mlir/IR/OpBase.td" +//===----------------------------------------------------------------------===// +// TensorBoundsAttrInterface +//===----------------------------------------------------------------------===// + +def TensorBoundsAttrInterface : AttrInterface<"TensorBoundsAttrInterface"> { + let cppNamespace = "::mlir::plan"; + let methods = [ + InterfaceMethod< + /*desc=*/"Return the shape bounds associated with the attribute", + /*retTy=*/"LogicalResult", + "getShapeBounds", + (ins "llvm::SmallVectorImpl &":$min, + "llvm::SmallVectorImpl &":$max), + /*body=*/"", + "" + >, + InterfaceMethod< + /*desc=*/"Return the integer value bounds associated with the attribute", + /*retTy=*/"LogicalResult", + "getIntegerValueBounds", + (ins "llvm::SmallVectorImpl &":$min, + "llvm::SmallVectorImpl &":$max), + /*body=*/"", + "" + > + ]; +} + //===----------------------------------------------------------------------===// // ClusterKindInterface //===----------------------------------------------------------------------===// diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/StableHloExt/Transforms/Passes.td b/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/StableHloExt/Transforms/Passes.td index 09da5db62..398f70aa1 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/StableHloExt/Transforms/Passes.td +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/StableHloExt/Transforms/Passes.td @@ -202,32 +202,11 @@ def LowerSpecialCustomCalls : Pass<"stablehlo-ext-lower-special-custom-calls"> { } //===----------------------------------------------------------------------===// -// StablehloInputPreprocessingPass +// CanonicalizeConvolutionPass //===----------------------------------------------------------------------===// -def StablehloInputPreprocessingPass : Pass<"tensorrt-stablehlo-input-preprocessing"> { - let summary = "Prepares Stable HLO dialect operations for conversion to TensorRT"; - - let description = [{ - This pass contains a set of patterns for simplifying or transforming Stable HLO - input IR so that conversion to the TensorRT dialect is more straightforward. - - In particular: - - - Simplify certain patterns commonly found in IR emitted for JAX programs - but not covered by existing Stable HLO canonicalizations/transforms. - - - Prepare convolutions to be NCHW/FCRS format and have at least two - "spatial" dimensions. - - - Canonicalize `stablehlo.scatter` operations so that they can be converted to - `tensorrt.scatter` in a straightforward manner. - }]; - - let dependentDialects = [ - "::mlir::tensor::TensorDialect", - "::mlir::stablehlo::StablehloDialect" - ]; +def CanonicalizeConvolutionPass : Pass<"stablehlo-ext-canonicalize-convolution"> { + let summary = "Canonicalizes stablehlo convolution operations"; } diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/StableHloExt/Utils/GatherScatterUtils.h b/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/StableHloExt/Utils/GatherScatterUtils.h index cffb9470d..2b1c941cb 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/StableHloExt/Utils/GatherScatterUtils.h +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/StableHloExt/Utils/GatherScatterUtils.h @@ -26,6 +26,8 @@ #ifndef MLIR_TENSORRT_DIALECT_STABLEHLOEXT_UTILS_GATHERSCATTERUTILS_H #define MLIR_TENSORRT_DIALECT_STABLEHLOEXT_UTILS_GATHERSCATTERUTILS_H +#include "mlir-tensorrt/Utils/ShapeInfo.h" +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/IR/Value.h" #include @@ -34,6 +36,7 @@ namespace mlir { class OpBuilder; namespace stablehlo { +class DynamicGatherOp; class GatherOp; class ScatterOp; @@ -124,6 +127,15 @@ namespace stablehlo_ext { std::optional isSingleDimSimpleGatherWithImplicitIndexDim(stablehlo::GatherOp op); +/// Returns the "gather dimension" if `op` is a 'simple, single dimension' +/// gather op with implicit index vector dimension (see above for definition). +/// This version works for `stablehlo.dynamic_gather` using pattern matching +/// against the expected canonical form when the operand shape along some +/// "offset dimensions" is dynamic. +std::optional isSingleDimSimpleGatherWithImplicitIndexDim( + stablehlo::DynamicGatherOp op, + const ShapeInfoCallbacks &shapeInfoCallbacks); + /// Returns the "gather dimension" if `op` is a 'simple, single dimension' /// gather op with explicit size-1 index vector dimension (see above for /// definition). @@ -138,6 +150,21 @@ bool isSimpleLeadingMultiDimGather(stablehlo::GatherOp op); /// gather' (see definition above). bool isSimpleLeadingMultiDimGatherWithDegenerateDims(stablehlo::GatherOp op); +/// Attempts to construct a `stablehlo.reshape` if result type is statically +/// shaped, otherwise creates `stablehlo.dynamic_reshape`. +Value createCollapsingReshape(OpBuilder &b, Location loc, Value input, + ArrayRef reassociation); + +/// Attempts to construct a `stablehlo.reshape` if `resultType` is statically +/// shaped, otherwise creates a `stablehlo.dynamic_reshape`. +Value createExpandingReshape(OpBuilder &b, Location loc, + RankedTensorType resultType, Value input, + ArrayRef reassociation); + +/// Returns true if the `scatterOp` has a configuration that corresponds to the +/// ONNX ScatterNd operation semantic. +bool isCanonicalScatterNd(stablehlo::ScatterOp scatterOp); + //===----------------------------------------------------------------------===// // Code below this point was adapted from the MLIR-HLO project (part of OpenXLA // project) `xla/mlir_hlo/mhlo/utils/mhlo_scatter_gather_utils.h` and has the @@ -155,10 +182,6 @@ bool isSimpleLeadingMultiDimGatherWithDegenerateDims(stablehlo::GatherOp op); // - scatter_dims_to_operand_dims is [0, 1, ...] bool isCanonicalScatter(stablehlo::ScatterOp scatterOp); -/// Returns true if the `scatterOp` has a configuration that corresponds to the -/// ONNX ScatterNd operation semantic. -bool isCanonicalScatterNd(stablehlo::ScatterOp scatterOp); - // Checks if the gather has the following characteristics: // - start_indices is a two-dimensional tensor // - index_vector_dim is 1 diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Pipelines/StableHloInputPipelines.h b/mlir-tensorrt/compiler/include/mlir-tensorrt/Pipelines/StableHloInputPipelines.h index 9b96ecc31..506c72281 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Pipelines/StableHloInputPipelines.h +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Pipelines/StableHloInputPipelines.h @@ -32,12 +32,19 @@ class OpPassManager; struct StableHloInputOptions { /// Whether to lower Stablehlo control flow ops to SCF dialect ops. bool legalizeControlFlowToSCF = false; - /// Whether to lower chlo.erf into primitive stablehlo operations. - bool legalizeChloErfToStablehlo = false; + + /// Whether to preserve 'chlo.erf' ops or lower them to 'stablehlo' ops. + /// By default, we preserve since it has a 1-1 correspondence with a TensorRT + /// op. + bool preserveChloErf = true; + + /// Whether to preserve 'chlo.top_k' ops or lower them to 'stablehlo' ops. + /// By default, we preserve since it has a 1-1 correspondence with a TensorRT + /// op. + bool preserveChloTopK = true; + /// Whether to disable running the inliner. bool disableInliner = false; - /// Whether to lower chlo to stablehlo. - bool convertChloToStablehlo = false; }; /// Construct a pipeline for preprocessing StableHLO IR to convert it into the diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Utils/ShapeInfo.h b/mlir-tensorrt/compiler/include/mlir-tensorrt/Utils/ShapeInfo.h new file mode 100644 index 000000000..e0b160304 --- /dev/null +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Utils/ShapeInfo.h @@ -0,0 +1,96 @@ +//===- ShapeInfo.h ---------------------------------------------*- C++ -*-===// +// +// SPDX-FileCopyrightText: Copyright 2024 NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +/// +/// Declarations for callback types are used to abstract away how to infer +/// shape knowledge from a pass or transformation. For example, a pass operating +/// on StableHlo IR may need to check whether the *values* of tensor A represent +/// the actual *shape* of tensor B, whose shape may not be known statically at +/// compile time. +/// +/// The specific mechanism that one may use to determine the validity of a +/// specific proposition like the example above (which must be reported as +/// "unknown", "true", or "false") may depend on the context. In the case +/// of the StableHlo example above, we could try to naively pattern match +/// whether tensor A is the result of `stablehlo.concat` of appropriate +/// `stablehlo.get_dimensions_size %A, dim = ...` results. In other cases, +/// we may have access to an analysis that assists with more robustly +/// checking the proposition. +/// +/// This file just contains callback types that a Pass or rewrite/transform can +/// accept as a parameter, allowing the creator or caller to hand in a +/// particular implementation. +/// +//===----------------------------------------------------------------------===// +#ifndef MLIR_TENSORRT_UTILS_SHAPEINFO +#define MLIR_TENSORRT_UTILS_SHAPEINFO + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Value.h" + +namespace mlir { + +/// TensorElementValue identifies a particular scalar element value of a +/// statically-shaped tensor. +struct TensorElementValue { + TensorElementValue(Value value, ArrayRef coord); + + TypedValue getTensor() const { return tensor; } + int64_t getLinearIndex() const { return linearIndex; } + + /// A value of type (must be statically-shaped) RankedTensorType. + TypedValue tensor; + + /// The linear coordinate of the value. + int64_t linearIndex; +}; + +/// TensorShapeDimExtent identifies a (potentially dynamically shaped) size +/// of a particular dimension of a tensor's shape. +struct TensorShapeDimExtent { + TensorShapeDimExtent(Value value, int64_t dim); + + std::optional getConstantSize() const; + + /// A value of type RankedTensorType. + TypedValue tensor; + + /// The dimension. + int64_t dim; +}; + +struct ShapeInfoCallbacks { + // Check whether 'tensorElementValue' is provably equivalent to + // `tensorShapeDimExtent`. Returning 'nullopt' means "unknown", true means + // "equal", false means "not equal". + std::function(TensorElementValue tensorElementValue, + TensorShapeDimExtent tensorShapeDimExtent)> + isElementValueEqualToShapeDimExtent; + + // Check whether 'tensorElementValue' is provably equivalent to the given + // static value. Returning 'nullopt' means "unknown", true means "equal", + // false means "not equal". + std::function(TensorElementValue tensorElementValue, + Attribute constantValue)> + isElementValueEqualToConstant; +}; + +} // namespace mlir + +#endif // MLIR_TENSORRT_UTILS_SHAPEINFO diff --git a/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp b/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp index 3b8308449..50f146645 100644 --- a/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp +++ b/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp @@ -185,7 +185,8 @@ compiler::getStableHLOProgramRefinedSignature( { mlir::StableHloInputOptions opts{}; opts.legalizeControlFlowToSCF = false; - opts.convertChloToStablehlo = false; + opts.preserveChloErf = true; + opts.preserveChloTopK = true; mlir::buildStablehloPreProcessingPipeline(pm, opts); } @@ -382,7 +383,8 @@ void StableHloToExecutableTask::populatePassManager( // StableHLO Preprocessing mlir::StableHloInputOptions opts{}; opts.legalizeControlFlowToSCF = false; - opts.convertChloToStablehlo = false; + opts.preserveChloErf = true; + opts.preserveChloTopK = true; mlir::buildStablehloPreProcessingPipeline(pm, opts); buildStablehloClusteringPipeline(pm, options); diff --git a/mlir-tensorrt/compiler/lib/Conversion/CMakeLists.txt b/mlir-tensorrt/compiler/lib/Conversion/CMakeLists.txt index dbf98cbf6..804fb3dfc 100644 --- a/mlir-tensorrt/compiler/lib/Conversion/CMakeLists.txt +++ b/mlir-tensorrt/compiler/lib/Conversion/CMakeLists.txt @@ -4,6 +4,7 @@ if(MLIR_TRT_ENABLE_HLO) add_subdirectory(StablehloToTensorRT) add_subdirectory(StablehloScalarToArith) add_subdirectory(StablehloToScf) + add_subdirectory(ChloToStablehloExt) endif() if(MLIR_TRT_TARGET_CPP) diff --git a/mlir-tensorrt/compiler/lib/Conversion/ChloToStablehloExt/CMakeLists.txt b/mlir-tensorrt/compiler/lib/Conversion/ChloToStablehloExt/CMakeLists.txt new file mode 100644 index 000000000..94feea8b8 --- /dev/null +++ b/mlir-tensorrt/compiler/lib/Conversion/ChloToStablehloExt/CMakeLists.txt @@ -0,0 +1,17 @@ +add_mlir_tensorrt_library(MLIRTensorRTChloToStablehloExt + ChloToStablehloExt.cpp + + DEPENDS + MLIRTensorRTConversionPassIncGen + + LINK_LIBS PUBLIC + ChloOps + MLIRDialectUtils + MLIRIR + MLIRPass + MLIRRewrite + MLIRTensorRTDialect + MLIRTransformUtils + StablehloOps + StablehloPasses + ) \ No newline at end of file diff --git a/mlir-tensorrt/compiler/lib/Conversion/ChloToStablehloExt/ChloToStablehloExt.cpp b/mlir-tensorrt/compiler/lib/Conversion/ChloToStablehloExt/ChloToStablehloExt.cpp new file mode 100644 index 000000000..edc68dea3 --- /dev/null +++ b/mlir-tensorrt/compiler/lib/Conversion/ChloToStablehloExt/ChloToStablehloExt.cpp @@ -0,0 +1,79 @@ +//===- ChloToStablehloExt.cpp ---------------------------------------------===// +// +// SPDX-FileCopyrightText: Copyright 2024 NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +/// +/// Convert certain CHLO ops to stablehlo ops. We only need this instantiation +/// of the upstream pass since we need to selectively preserve certain CHLO ops +/// like 'top k'. +/// +//===----------------------------------------------------------------------===// +#include "mlir-tensorrt/Conversion/Passes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" +#include "stablehlo/dialect/ChloOps.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/transforms/Passes.h" + +namespace mlir { +#define GEN_PASS_DEF_CONVERTCHLOTOSTABLEHLOEXTPASS +#include "mlir-tensorrt/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +namespace { + +struct ChloToStablehloExtPass + : public impl::ConvertChloToStableHloExtPassBase { +public: + using Base::Base; + + LogicalResult initialize(MLIRContext *context) override { + target = std::make_shared(*context); + + target->addDynamicallyLegalDialect([&](Operation *op) { + if (isa(op)) + return preserveErf.getValue(); + if (isa(op)) + return preserveTopK.getValue(); + return false; + }); + + target->markUnknownOpDynamicallyLegal([](Operation *op) { return true; }); + + RewritePatternSet patterns_(context); + stablehlo::populateChloToStablehloPatterns(context, &patterns_); + patterns = std::move(patterns_); + + return success(); + } + + void runOnOperation() override { + if (failed(applyPartialConversion(getOperation(), *target, patterns))) { + emitError(getOperation()->getLoc()) + << "failed to apply patterns in " << getArgument(); + signalPassFailure(); + } + } + +private: + std::shared_ptr target; + FrozenRewritePatternSet patterns; +}; +} // namespace \ No newline at end of file diff --git a/mlir-tensorrt/compiler/lib/Conversion/PlanToExecutor/PlanToExecutor.cpp b/mlir-tensorrt/compiler/lib/Conversion/PlanToExecutor/PlanToExecutor.cpp index badf6ff42..017c9e9ad 100644 --- a/mlir-tensorrt/compiler/lib/Conversion/PlanToExecutor/PlanToExecutor.cpp +++ b/mlir-tensorrt/compiler/lib/Conversion/PlanToExecutor/PlanToExecutor.cpp @@ -41,10 +41,6 @@ namespace mlir { using namespace mlir; -static constexpr llvm::StringRef kShapeBoundsAttrName = - "tensorrt.shape_profile"; -static constexpr llvm::StringRef kValueBoundsAttrName = "tensorrt.value_bounds"; - namespace { class GenericStructuralConverter : public ConversionPattern { @@ -130,8 +126,7 @@ struct ConstantOpConverter : public OpConversionPattern { /// Convert 'plan' dialect or 'tensorrt' dialect bounds into 'executor' bounds /// attributes. -static Attribute convertArgOrResultAttr(OpBuilder &b, Attribute attr, - llvm::StringRef name) { +static Attribute convertArgOrResultAttr(OpBuilder &b, Attribute attr) { MLIRContext *ctx = attr.getContext(); if (auto planAttr = dyn_cast(attr)) { if (planAttr.isShapeBound()) @@ -141,29 +136,33 @@ static Attribute convertArgOrResultAttr(OpBuilder &b, Attribute attr, return executor::ValueBoundsAttr::get(ctx, planAttr.getMinValues(), planAttr.getMaxValues()); } - if (auto trtAttr = dyn_cast(attr)) { - if (name == kValueBoundsAttrName) - return executor::ValueBoundsAttr::get( - ctx, b.getI64TensorAttr(trtAttr.getMin()), - b.getI64TensorAttr(trtAttr.getMax())); - if (name == kShapeBoundsAttrName) - return executor::DimensionBoundsAttr::get( - ctx, b.getDenseI64ArrayAttr(trtAttr.getMin()), - b.getDenseI64ArrayAttr(trtAttr.getMax())); - } return attr; } -/// Convert 'plan' dialect or 'tensorrt' dialect bounds into 'executor' bounds +/// Convert 'plan' dialect arg|result attributes into 'executor' dialect /// attributes for all function arg attrs and res attrs. static void convertArgAndResultAttrs(OpBuilder &b, func::FuncOp op) { + StringRef executorShapeBoundsAttrName = + mlir::executor::ExecutorDialect::getShapeBoundsAttrName(); + StringRef executorValueBoundsAttrName = + mlir::executor::ExecutorDialect::getValueBoundsAttrName(); + + StringRef planShapeBoundsAttrName = + mlir::plan::PlanDialect::getShapeBoundsAttrName(); + StringRef planValueBoundsAttrName = + mlir::plan::PlanDialect::getValueBoundsAttrName(); + for (unsigned idx = 0; idx < op.getNumArguments(); idx++) { - if (auto attr = op.getArgAttr(idx, kShapeBoundsAttrName)) - op.setArgAttr(idx, kShapeBoundsAttrName, - convertArgOrResultAttr(b, attr, kShapeBoundsAttrName)); - if (auto attr = op.getArgAttr(idx, kValueBoundsAttrName)) - op.setArgAttr(idx, kValueBoundsAttrName, - convertArgOrResultAttr(b, attr, kValueBoundsAttrName)); + if (auto attr = op.getArgAttr(idx, planShapeBoundsAttrName)) { + op.removeArgAttr(idx, planShapeBoundsAttrName); + op.setArgAttr(idx, executorShapeBoundsAttrName, + convertArgOrResultAttr(b, attr)); + } + if (auto attr = op.getArgAttr(idx, planValueBoundsAttrName)) { + op.removeArgAttr(idx, planValueBoundsAttrName); + op.setArgAttr(idx, executorValueBoundsAttrName, + convertArgOrResultAttr(b, attr)); + } if (auto attr = op.getArgAttr(idx, plan::PlanDialect::kResultArgAttrName)) { op.removeArgAttr(idx, plan::PlanDialect::kResultArgAttrName); @@ -171,12 +170,16 @@ static void convertArgAndResultAttrs(OpBuilder &b, func::FuncOp op) { } } for (unsigned idx = 0; idx < op.getNumResults(); idx++) { - if (auto attr = op.getResultAttr(idx, kShapeBoundsAttrName)) - op.setResultAttr(idx, kShapeBoundsAttrName, - convertArgOrResultAttr(b, attr, kShapeBoundsAttrName)); - if (auto attr = op.getResultAttr(idx, kValueBoundsAttrName)) - op.setResultAttr(idx, kValueBoundsAttrName, - convertArgOrResultAttr(b, attr, kValueBoundsAttrName)); + if (auto attr = op.getResultAttr(idx, planShapeBoundsAttrName)) { + op.removeResultAttr(idx, b.getStringAttr(planShapeBoundsAttrName)); + op.setResultAttr(idx, executorShapeBoundsAttrName, + convertArgOrResultAttr(b, attr)); + } + if (auto attr = op.getResultAttr(idx, planValueBoundsAttrName)) { + op.removeResultAttr(idx, b.getStringAttr(planValueBoundsAttrName)); + op.setResultAttr(idx, executorValueBoundsAttrName, + convertArgOrResultAttr(b, attr)); + } } } diff --git a/mlir-tensorrt/compiler/lib/Conversion/StablehloToTensorRT/CMakeLists.txt b/mlir-tensorrt/compiler/lib/Conversion/StablehloToTensorRT/CMakeLists.txt index a4367fb04..c41f91c0e 100644 --- a/mlir-tensorrt/compiler/lib/Conversion/StablehloToTensorRT/CMakeLists.txt +++ b/mlir-tensorrt/compiler/lib/Conversion/StablehloToTensorRT/CMakeLists.txt @@ -18,6 +18,7 @@ add_mlir_tensorrt_library(MLIRTensorRTStablehloToTensorRT MLIRTensorRTStableHloExtUtils MLIRTensorRTStablehloMatchers MLIRTensorRTTensorRTUtils + MLIRTensorRTUtilsShapeInfo MLIRTransforms MLIRTransformUtils StablehloOps diff --git a/mlir-tensorrt/compiler/lib/Conversion/StablehloToTensorRT/StablehloToTensorRT.cpp b/mlir-tensorrt/compiler/lib/Conversion/StablehloToTensorRT/StablehloToTensorRT.cpp index b07eec882..ac178f515 100644 --- a/mlir-tensorrt/compiler/lib/Conversion/StablehloToTensorRT/StablehloToTensorRT.cpp +++ b/mlir-tensorrt/compiler/lib/Conversion/StablehloToTensorRT/StablehloToTensorRT.cpp @@ -2404,14 +2404,53 @@ struct ConvertDynamicIota auto resultType = dyn_cast_or_null( this->getTypeConverter()->convertType(op.getType())); - if (!resultType || resultType.getRank() != 1) + if (!resultType) return failure(); + + // For rank-1 iota, we don't need to worry about creating a dynamic + // "step" tensor. + if (resultType.getRank() == 1) + return trtRewriter.checkAndReplaceOpWithNewOp( + op, targetTrtMajorVersion, resultType, + /*shape=*/adaptor.getOutputShape(), /*start=*/Value(), + /*step=*/Value(), + /*static_start=*/rewriter.getF64FloatAttr(0.0), + /*static_step=*/rewriter.getF64FloatAttr(1.0)) + ? success() + : failure(); + + // For greater thank rank-1 iota, we generate use a 1-d constant + // "step tensor". `tensorrt.linspace` has the following semantic: + // `result[coord...] = start + dot(step, [coord...])`. + const uint64_t dim = op.getIotaDimension(); + SmallVector stepValues( + resultType.getRank(), + rewriter.getZeroAttr(resultType.getElementType())); + stepValues[dim] = rewriter.getOneAttr(resultType.getElementType()); + + RankedTensorType stepTensorType = resultType.clone({resultType.getRank()}); + RankedTensorType startTensorType = resultType.clone(ArrayRef{}); + + auto constStart = trtRewriter.checkAndCreate( + op.getLoc(), targetTrtMajorVersion, startTensorType, + DenseElementsAttr::get( + startTensorType, + rewriter.getZeroAttr(resultType.getElementType()))); + if (!constStart) + return failure(); + + auto constStep = trtRewriter.checkAndCreate( + op.getLoc(), targetTrtMajorVersion, stepTensorType, + DenseElementsAttr::get(stepTensorType, stepValues)); + if (!constStep) + return failure(); + return trtRewriter.checkAndReplaceOpWithNewOp( op, targetTrtMajorVersion, resultType, - /*shape=*/adaptor.getOutputShape(), /*start=*/Value(), - /*step=*/Value(), - /*static_start=*/rewriter.getF64FloatAttr(0.0), - /*static_step=*/rewriter.getF64FloatAttr(1.0)) + /*shape=*/adaptor.getOutputShape(), /*start=*/constStart, + /*step=*/constStep, + /*static_start=*/FloatAttr{}, + /*static_step=*/FloatAttr{}) ? success() : failure(); } @@ -3933,6 +3972,48 @@ struct SingleDimSimpleGatherToTensorRTGatherPattern } }; +/// Convert `stablehlo.dynamic_gather` that represents a 'Simple, Single +/// Dimension Gather' with implicit index dimension to `tensorrt.gather`. +struct SingleDimSimpleDynamicGatherToTensorRTGatherPattern + : public ConvertHloOpToTensorRTPattern { + + SingleDimSimpleDynamicGatherToTensorRTGatherPattern( + const ShapeInfoCallbacks &shapeInfoCallbacks, + TensorRTTypeConverter &typeConverter, MLIRContext *ctx, + PatternBenefit benefit = 1) + : ConvertHloOpToTensorRTPattern(typeConverter, ctx, benefit), + shapeInfoCallbacks(shapeInfoCallbacks) {} + + using ConvertHloOpToTensorRTPattern::ConvertHloOpToTensorRTPattern; + LogicalResult + matchAndRewrite(stablehlo::DynamicGatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + TensorRTConversionPatternRewriter trtRewriter(rewriter); + int64_t targetTrtMajorVersion = + this->getTypeConverter()->getOptions().getTensorRTVersion(); + RankedTensorType resultType = dyn_cast_or_null( + typeConverter->convertType(op.getType())); + if (!resultType) + return failure(); + + std::optional gatherDim = + stablehlo_ext::isSingleDimSimpleGatherWithImplicitIndexDim( + op, shapeInfoCallbacks); + if (!gatherDim) + return rewriter.notifyMatchFailure(op, "not a correct gather op"); + + return trtRewriter.checkAndReplaceOpWithNewOp( + op, targetTrtMajorVersion, resultType, adaptor.getOperand(), + adaptor.getStartIndices(), *gatherDim) + ? success() + : rewriter.notifyMatchFailure(op, + "could not create a valid TRT op"); + } + +protected: + ShapeInfoCallbacks shapeInfoCallbacks; +}; + struct ConvertGatherToTensorRT : public ConvertHloOpToTensorRTPattern { using ConvertHloOpToTensorRTPattern::ConvertHloOpToTensorRTPattern; @@ -4453,7 +4534,8 @@ class ConvertStablehloToTensorRtPass } // namespace void mlir::populateStablehloToTensorRtConversionPattern( - TensorRTTypeConverter &typeConverter, RewritePatternSet &patterns) { + TensorRTTypeConverter &typeConverter, RewritePatternSet &patterns, + ShapeInfoCallbacks shapeInfoCallbacks) { // Add larger patterns with a higher // benefit so that they run first. patterns.add( @@ -4527,4 +4609,66 @@ void mlir::populateStablehloToTensorRtConversionPattern( CompositeToQDQConverter >(typeConverter, patterns.getContext(), PatternBenefit(1)); // clang-format on + + if (!shapeInfoCallbacks.isElementValueEqualToConstant) + shapeInfoCallbacks.isElementValueEqualToConstant = + [](TensorElementValue elementValue, + Attribute constValue) -> std::optional { + RankedTensorType shapeTensorType = elementValue.getTensor().getType(); + if (shapeTensorType.getRank() != 1 || + !shapeTensorType.getElementType().isIntOrIndex()) + return {}; + + auto concatOp = + elementValue.getTensor().getDefiningOp(); + if (!concatOp || static_cast(concatOp.getOperands().size()) != + shapeTensorType.getDimSize(0)) + return {}; + + Value element = concatOp.getOperands()[elementValue.getLinearIndex()]; + SplatElementsAttr splat = {}; + if (!matchPattern(element, m_Constant(&splat))) + return {}; + return splat.getSplatValue() == constValue; + }; + + if (!shapeInfoCallbacks.isElementValueEqualToShapeDimExtent) + shapeInfoCallbacks.isElementValueEqualToShapeDimExtent = + [](TensorElementValue elementValue, + TensorShapeDimExtent dimExtent) -> std::optional { + RankedTensorType shapeTensorType = elementValue.getTensor().getType(); + RankedTensorType valueTensorType = dimExtent.tensor.getType(); + + if (shapeTensorType.getRank() != 1 || + !shapeTensorType.getElementType().isIntOrIndex()) + return {}; + + auto concatOp = + elementValue.getTensor().getDefiningOp(); + if (!concatOp || static_cast(concatOp.getOperands().size()) != + shapeTensorType.getDimSize(0)) + return {}; + + Value element = concatOp.getOperands()[elementValue.getLinearIndex()]; + + DenseElementsAttr splat = {}; + if (!valueTensorType.isDynamicDim(dimExtent.dim) && + matchPattern(element, m_Constant(&splat))) + return splat.getSplatValue().getSExtValue() == + valueTensorType.getDimSize(dimExtent.dim); + + if (!matchPattern(element, m_Op( + m_Op()))) + return {}; + + auto dimSizeOp = element.getDefiningOp() + .getOperand() + .getDefiningOp(); + return dimSizeOp.getOperand() == dimExtent.tensor && + dimSizeOp.getDimension() == static_cast(dimExtent.dim); + }; + + patterns.add( + shapeInfoCallbacks, typeConverter, patterns.getContext(), + PatternBenefit(1)); } diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/Analysis/BoundsAnalysis.cpp b/mlir-tensorrt/compiler/lib/Dialect/Plan/Analysis/BoundsAnalysis.cpp index 03c8cecc6..2236c21fc 100644 --- a/mlir-tensorrt/compiler/lib/Dialect/Plan/Analysis/BoundsAnalysis.cpp +++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/Analysis/BoundsAnalysis.cpp @@ -24,6 +24,7 @@ #include "mlir-tensorrt/Dialect/Plan/Analysis/BoundsAnalysis.h" #include "mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.h" #include "mlir-tensorrt/Dialect/Plan/IR/Plan.h" +#include "mlir-tensorrt/Dialect/Plan/IR/PlanInterfaces.h" #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -42,16 +43,15 @@ using namespace mlir::plan; #define DEBUG_TYPE "plan-bounds-analysis" #define DBGS(x) llvm::dbgs() << " [" DEBUG_TYPE "][" x "] " -static std::optional -maybeGetFunctionArgBound(Value value, StringRef attrName) { +template +std::optional maybeGetFunctionArgBound(Value value, StringRef attrName) { BlockArgument source = dyn_cast(value); if (!source) return {}; func::FuncOp func = dyn_cast(source.getOwner()->getParentOp()); if (!func) return {}; - auto shapeProfile = func.getArgAttrOfType( - source.getArgNumber(), attrName); + auto shapeProfile = func.getArgAttrOfType(source.getArgNumber(), attrName); if (!shapeProfile) return {}; return shapeProfile; @@ -137,6 +137,14 @@ BoundsArray BoundsArray::fromIntegerValueBounds(unsigned bitWidth, return BoundsArray(std::move(res)); } +BoundsArray BoundsArray::fromIntegerValueBounds(ArrayRef min, + ArrayRef max) { + SmallVector res; + for (auto [l, r] : llvm::zip_equal(min, max)) + res.push_back(ConstantIntRanges::fromSigned(l, r)); + return BoundsArray(std::move(res)); +} + BoundsArray BoundsArray::join(const BoundsArray &lhs, const BoundsArray &rhs) { if (lhs.isUninitialized()) return rhs; @@ -274,16 +282,17 @@ void ShapeBoundsForwardAnalysis::setToEntryState(ShapeBoundsLattice *lattice) { lattice, lattice->join(BoundsArray::getMaxRangeForShapeBounds( lattice->getAnchor()))); - std::optional shapeProfile = - maybeGetFunctionArgBound( - lattice->getAnchor(), - tensorrt::TensorRTDialect::getShapeProfileArgAttrName()); + std::optional shapeProfile = + maybeGetFunctionArgBound( + lattice->getAnchor(), plan::PlanDialect::getShapeBoundsAttrName()); + if (!shapeProfile) return propagateIfChanged(lattice, lattice->join(BoundsArray())); - + SmallVector minBound, maxBound; + if (failed(shapeProfile->getShapeBounds(minBound, maxBound))) + return; return propagateIfChanged( - lattice, lattice->join(BoundsArray::fromShapeBounds( - shapeProfile->getMin(), shapeProfile->getMax()))); + lattice, lattice->join(BoundsArray::fromShapeBounds(minBound, maxBound))); } LogicalResult ShapeBoundsForwardAnalysis::visitOperation( @@ -458,20 +467,23 @@ static std::optional maybeGetValueBounds(Value value, std::optional linearIndex) { Type elType = cast(value.getType()).getElementType(); assert(elType.isSignlessIntOrIndex() && "expected integer or index type"); - unsigned bitWidth = elType.isIndex() ? IndexType::kInternalStorageBitWidth - : elType.getIntOrFloatBitWidth(); - std::optional bound = maybeGetFunctionArgBound( - value, tensorrt::TensorRTDialect::getShapeTensorValueBoundsArgAttrName()); + std::optional bound = + maybeGetFunctionArgBound( + value, plan::PlanDialect::getValueBoundsAttrName()); if (!bound) return {}; - ArrayRef min = bound->getMin(); - ArrayRef max = bound->getMax(); + + SmallVector mins, maxs; + if (failed(bound->getIntegerValueBounds(mins, maxs))) + return {}; + auto comp = [](const llvm::APInt &lhs, const llvm::APInt &rhs) { + return lhs.sle(rhs); + }; if (!linearIndex) return ConstantIntRanges::fromSigned( - APInt(bitWidth, *std::min_element(min.begin(), min.end())), - APInt(bitWidth, *std::max_element(max.begin(), max.end()))); - return ConstantIntRanges::fromSigned(APInt(bitWidth, min[*linearIndex]), - APInt(bitWidth, max[*linearIndex])); + *std::min_element(mins.begin(), mins.end(), comp), + *std::max_element(maxs.begin(), maxs.end(), comp)); + return ConstantIntRanges::fromSigned(mins[*linearIndex], maxs[*linearIndex]); } static ConstantIntRanges @@ -540,19 +552,18 @@ static void inferResultRanges(tensor::DimOp dimOp, ConstantIntRanges::fromSigned(intStatic, intStatic)); } - std::optional shapeProfile = - maybeGetFunctionArgBound( - dimOp.getSource(), - tensorrt::TensorRTDialect::getShapeProfileArgAttrName()); + std::optional shapeProfile = + maybeGetFunctionArgBound( + dimOp.getSource(), plan::PlanDialect::getShapeBoundsAttrName()); if (!shapeProfile) return setResultRanges(dimOp.getResult(), getMaxDimRange()); setResultRanges(dimOp.getResult(), ConstantIntRanges::fromSigned( APInt(IndexType::kInternalStorageBitWidth, - shapeProfile->getMin()[*staticDimNum]), + shapeProfile->getMinShape()[*staticDimNum]), APInt(IndexType::kInternalStorageBitWidth, - shapeProfile->getMax()[*staticDimNum]))); + shapeProfile->getMaxShape()[*staticDimNum]))); } static void inferResultRanges(tensor::ExtractOp extractOp, @@ -603,10 +614,9 @@ void ShapeIntegerRangeAnalysis::setToEntryState( if (!lattice->getAnchor().getType().isIntOrIndex()) return propagateIfChanged(lattice, lattice->join(IntegerValueRange())); - std::optional shapeProfile = - maybeGetFunctionArgBound( - lattice->getAnchor(), - tensorrt::TensorRTDialect::getShapeTensorValueBoundsArgAttrName()); + std::optional shapeProfile = + maybeGetFunctionArgBound( + lattice->getAnchor(), plan::PlanDialect::getValueBoundsAttrName()); if (!shapeProfile) { IntegerValueRange range = IntegerValueRange::getMaxRange(lattice->getAnchor()); @@ -614,15 +624,13 @@ void ShapeIntegerRangeAnalysis::setToEntryState( range = IntegerValueRange(truncateToNonNegative(range.getValue())); return propagateIfChanged(lattice, lattice->join(range)); } - assert(shapeProfile->getMax().size() == 1 && + assert(shapeProfile->getMaxValues().getNumElements() == 1 && "expected one element for scalar value bounds"); - Type intType = lattice->getAnchor().getType(); - unsigned bitWidth = intType.isIndex() ? IndexType::kInternalStorageBitWidth - : intType.getIntOrFloatBitWidth(); + return propagateIfChanged( lattice, lattice->join(IntegerValueRange(ConstantIntRanges::fromSigned( - APInt(bitWidth, shapeProfile->getMin()[0]), - APInt(bitWidth, shapeProfile->getMax()[0]))))); + shapeProfile->getMinValues().getSplatValue(), + shapeProfile->getMaxValues().getSplatValue())))); } /// Visit an operation. Invoke the transfer function on each operation that @@ -716,18 +724,19 @@ void TensorValueBoundsAnalysis::setToEntryState( if (!shouldAnalyzeValueBounds(point)) return propagateIfChanged(lattice, lattice->join(BoundsArray())); - std::optional shapeProfile = - maybeGetFunctionArgBound( - point, - tensorrt::TensorRTDialect::getShapeTensorValueBoundsArgAttrName()); - if (!shapeProfile) + std::optional shapeProfile = + maybeGetFunctionArgBound( + point, plan::PlanDialect::getValueBoundsAttrName()); + if (!shapeProfile || !shapeProfile->isValueBound()) return propagateIfChanged(lattice, lattice->join(BoundsArray())); return propagateIfChanged( - lattice, lattice->join(BoundsArray::fromIntegerValueBounds( - ConstantIntRanges::getStorageBitwidth( - mlir::getElementTypeOrSelf(point.getType())), - shapeProfile->getMin(), shapeProfile->getMax()))); + lattice, + lattice->join(BoundsArray::fromIntegerValueBounds( + llvm::to_vector( + shapeProfile->getMinValues().getValues()), + llvm::to_vector( + shapeProfile->getMaxValues().getValues())))); } static void maybePopulateConstantValueBounds( diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/IR/BuiltinClusterKinds.cpp b/mlir-tensorrt/compiler/lib/Dialect/Plan/IR/BuiltinClusterKinds.cpp index 2e732e9a6..3bdd596d0 100644 --- a/mlir-tensorrt/compiler/lib/Dialect/Plan/IR/BuiltinClusterKinds.cpp +++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/IR/BuiltinClusterKinds.cpp @@ -27,8 +27,12 @@ #include "mlir-tensorrt/Conversion/StablehloToTensorRT/StablehloToTensorRT.h" #include "mlir-tensorrt/Conversion/TensorRTCommon/ConvertToTensorRTCommon.h" #include "mlir-tensorrt/Dialect/Plan/IR/Plan.h" +#include "mlir-tensorrt/Utils/ShapeInfo.h" #include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" #include "llvm/Support/Debug.h" @@ -60,6 +64,8 @@ ClusteringOpts HostClusterKindAttr::getClusterKindOptions( ClusterRange) { return true; }; opts.clusterTarget = *this; opts.isClusterableOp = [&solver](Operation *op) { + if (llvm::isa(op)) + return true; return plan::detail::shouldRunOnHost(op, solver); }; return opts; @@ -106,7 +112,12 @@ HostClusterKindAttr::getClusterOutliningOptions( std::function HostClusterKindAttr::getClusterFilter() const { - return [](const Cluster &cluster) { return true; }; + return [](const Cluster &cluster) { + return !llvm::all_of(cluster, [](Operation *op) { + return op->hasTrait() || + llvm::isa(op); + }); + }; } //===----------------------------------------------------------------------===// @@ -117,6 +128,69 @@ std::string TensorRTClusterKindAttr::getClusterKindName() const { return "tensorrt"; } +static ShapeInfoCallbacks getShapeInfoCallbacks() { + ShapeInfoCallbacks shapeInfoCallbacks{}; + shapeInfoCallbacks.isElementValueEqualToConstant = + [](TensorElementValue elementValue, + Attribute constValue) -> std::optional { + auto withValuesOp = + elementValue.getTensor().getDefiningOp(); + if (!withValuesOp) + return {}; + Value element = withValuesOp.getElements()[elementValue.getLinearIndex()]; + + Attribute intAttr = {}; + if (!matchPattern(element, m_Constant(&intAttr))) + return {}; + return intAttr == constValue; + }; + shapeInfoCallbacks.isElementValueEqualToShapeDimExtent = + [](TensorElementValue elementValue, + TensorShapeDimExtent dimExtent) -> std::optional { + assert(elementValue.getTensor().getType().getElementType().isIntOrIndex() && + "expected int or integer tensor"); + auto withValuesOp = + elementValue.getTensor().getDefiningOp(); + if (!withValuesOp) + return {}; + + /// Scalar value will be of type equivalent to `elementValue.tensor` element + /// type. + Value scalarValue = + withValuesOp.getElements()[elementValue.getLinearIndex()]; + + /// Check if it is statically known to be equal to the `dimExtent`. + IntegerAttr constInt = {}; + if (std::optional staticSize = dimExtent.getConstantSize()) { + if (matchPattern(scalarValue, m_Constant(&constInt))) + return constInt.getValue().getSExtValue() == *staticSize; + } + + /// Otherwise, we need to check equivalence of the dynamic values. + /// There are two cases to consider: either both have the same type, or + /// `plan.with_shape` may have index type scalars and `plan.with_values` + /// will have a more specific integer type that matches the shape tensor. + /// We can try to handle the later case where the conversion is done by + /// `arith.index_cast`. + /// TODO: we should change the shape materialization pass so that we infer + /// the desired shape tensor element type and have all `plan.with_shape` + /// materialize with that scalar type using casts. + if (auto withShape = dimExtent.tensor.getDefiningOp()) { + Value dimExtentValue = withShape.getShape()[dimExtent.dim]; + if (dimExtentValue == scalarValue) + return true; + if (auto indexCastOp = + dyn_cast(scalarValue.getDefiningOp())) { + if (indexCastOp.getOperand() == dimExtentValue) + return true; + } + } + + return {}; + }; + return shapeInfoCallbacks; +} + /// ClusteringOpts that identifies groups of TensorRT operations and will be /// clustered into one TensorRT function (which is eventually translated to a /// engine). @@ -151,7 +225,8 @@ ClusteringOpts TensorRTClusterKindAttr::getClusterKindOptions( loweringOptions.setTensorRTVersion(*trtMajorVersion); TensorRTTypeConverter typeConverter(ctx, loweringOptions); TensorRTConversionTarget target(*ctx, typeConverter); - populateStablehloToTensorRtConversionPattern(typeConverter, patterns); + populateStablehloToTensorRtConversionPattern(typeConverter, patterns, + getShapeInfoCallbacks()); populateChloToTensorRtLegalityAndPatterns(typeConverter, target, patterns); // Analyze the convertible operations. diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/IR/CMakeLists.txt b/mlir-tensorrt/compiler/lib/Dialect/Plan/IR/CMakeLists.txt index 3bb5f00d7..4a1ed1fa6 100644 --- a/mlir-tensorrt/compiler/lib/Dialect/Plan/IR/CMakeLists.txt +++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/IR/CMakeLists.txt @@ -23,5 +23,6 @@ add_mlir_tensorrt_dialect_library(MLIRTensorRTPlanDialect MLIRTensorRTStablehloScalarToArith MLIRTensorRTStablehloToTensorRT MLIRTensorRTSupportStatus + MLIRTensorRTUtilsShapeInfo StablehloOps ) diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/IR/PlanOps.cpp b/mlir-tensorrt/compiler/lib/Dialect/Plan/IR/PlanOps.cpp index f75a8b2db..33d805d6e 100644 --- a/mlir-tensorrt/compiler/lib/Dialect/Plan/IR/PlanOps.cpp +++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/IR/PlanOps.cpp @@ -25,6 +25,7 @@ #include "mlir-tensorrt/Dialect/Plan/IR/Plan.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" @@ -240,6 +241,7 @@ BoundsAttr BoundsAttr::get(MLIRContext *ctx, BoundsKind kind, DenseIntElementsAttr::get(type, min), DenseIntElementsAttr::get(type, max)); } + BoundsAttr BoundsAttr::getChecked(llvm::function_ref emitError, MLIRContext *ctx, BoundsKind kind, ArrayRef min, @@ -262,6 +264,28 @@ BoundsAttr::getChecked(llvm::function_ref emitError, return BoundsAttr::get(ctx); } +LogicalResult BoundsAttr::getShapeBounds(SmallVectorImpl &min, + SmallVectorImpl &max) const { + ArrayRef minShape = getMinShape(); + min.assign(minShape.begin(), minShape.end()); + ArrayRef maxShape = getMaxShape(); + max.assign(maxShape.begin(), maxShape.end()); + return success(); +} + +LogicalResult +BoundsAttr::getIntegerValueBounds(SmallVectorImpl &min, + SmallVectorImpl &max) const { + if (!isValueBound() || !getMinValues().getElementType().isIntOrIndex()) + return failure(); + auto mins = getMinValues().getValues(); + min.assign(mins.begin(), mins.end()); + + auto maxs = getMaxValues().getValues(); + max.assign(maxs.begin(), maxs.end()); + return success(); +} + //===----------------------------------------------------------------------===// // InlineGroupOp //===----------------------------------------------------------------------===// @@ -306,9 +330,8 @@ verifyBoundsAttr(StringRef argOrResult, unsigned idx, Type type, if (boundsAttr.isNone()) return success(); if (boundsAttr.isShapeBound()) { - int64_t boundsLength = boundsAttr.getMinShape().size(); - if (std::max(shapedType.getRank(), 1) != - std::max(boundsLength, 1)) + const int64_t boundsLength = boundsAttr.getMinShape().size(); + if (shapedType.getRank() != boundsLength) return emitOpError() << argOrResult << " #" << idx << " has type " << type << ", whose rank is not equal to the rank of the " @@ -324,18 +347,17 @@ verifyBoundsAttr(StringRef argOrResult, unsigned idx, Type type, "only allowed for staticly shaped operands"; if (boundsAttr.isValueBound()) { - Type elType = boundsAttr.getMinValues().getElementType(); + Type elType = boundsAttr.getValuesType().getElementType(); if (elType != shapedType.getElementType()) return emitOpError() << argOrResult << " #" << idx << " expected element type of value bounds elements (" << elType << ") to be compatible with the type (" << type << ")"; - if (boundsAttr.getMinValues().getType().getShape() != - shapedType.getShape()) + if (boundsAttr.getValuesType().getShape() != shapedType.getShape()) return emitOpError() << argOrResult << " #" << idx << " expected type of values bounds elements (" - << boundsAttr.getMinValues().getType() + << boundsAttr.getValuesType() << ") to be compatible with the type (" << type << ")"; } @@ -351,21 +373,23 @@ verifyBoundsAttr(StringRef argOrResult, unsigned idx, Type type, << argOrResult << " #" << idx << " of type " << type << ", but got " << boundsAttr; if (boundsAttr.isValueBound()) { - int64_t numEls = boundsAttr.getMinValues().getNumElements(); - if (numEls != 1) + + if (boundsAttr.getValuesType().getRank() != 0) return emitOpError() << argOrResult << " #" << idx - << " expected number of values bounds elements (" << numEls - << ") to equal number of elements of the type (1)"; - Type elType = boundsAttr.getMinValues().getElementType(); + << " type expects rank-0 value bounds type, but got " + << boundsAttr.getValuesType(); + + Type elType = boundsAttr.getValuesType().getElementType(); if (elType != type) return emitOpError() << argOrResult << " #" << idx << " expected element type of value bounds elements (" << elType << ") to be compatible with the type (" << type << ")"; } + return success(); } - // For all other types, the bounds kind must be none. + if (!boundsAttr.isNone()) return emitOpError() << "expected only 'none' bounds for type " << type; return success(); @@ -723,9 +747,24 @@ struct PlanInlinerInterface : public DialectInlinerInterface { #include "mlir-tensorrt/Dialect/Plan/IR/PlanOps.cpp.inc" //===----------------------------------------------------------------------===// -// Dialect initialization +// PlanDialect Definitions //===----------------------------------------------------------------------===// +namespace { +class PlanDialectOpAsmInterface : public OpAsmDialectInterface { + using OpAsmDialectInterface::OpAsmDialectInterface; + /// Tells MLIR assembly printer/parser that the BoundsAttr can be + /// aliased using #bounds[num]. This make the IR more readable. + AliasResult getAlias(Attribute attr, raw_ostream &os) const override { + if (isa(attr)) { + os << "bounds"; + return AliasResult::OverridableAlias; + } + return AliasResult::NoAlias; + } +}; +} // namespace + void PlanDialect::initialize() { addOperations< #define GET_OP_LIST @@ -746,7 +785,7 @@ void PlanDialect::initialize() { (void)&generatedAttributePrinter; (void)&generatedAttributeParser; - addInterfaces(); + addInterfaces(); } Attribute PlanDialect::parseAttribute(DialectAsmParser &parser, @@ -771,3 +810,44 @@ void PlanDialect::printAttribute(Attribute attr, assert(it != attrPrintingHooks.end() && "printing unknown type"); it->getSecond()(attr, printer); } + +static LogicalResult verifyBoundsAttribute(Operation *op, unsigned argIndex, + plan::BoundsAttr attr, + StringRef attrName) { + auto func = dyn_cast(op); + if (!func) + return success(); + + Type argType = func.getArgument(argIndex).getType(); + return verifyBoundsAttr( + "arg", argIndex, argType, attr, + [&]() -> InFlightDiagnostic { return op->emitOpError(); }); + + return success(); +} + +LogicalResult PlanDialect::verifyRegionArgAttribute(Operation *op, + unsigned regionIndex, + unsigned argIndex, + NamedAttribute attribute) { + if (attribute.getName() == getValueBoundsAttrName()) { + auto boundsAttr = dyn_cast(attribute.getValue()); + if (!boundsAttr || !boundsAttr.isValueBound()) + return op->emitError() + << "expected named attribute \"" << getValueBoundsAttrName() + << "\" to be a \"#plan.bounds\" attribute containing value bounds"; + + return verifyBoundsAttribute(op, argIndex, boundsAttr, attribute.getName()); + } + + if (attribute.getName() == getShapeBoundsAttrName()) { + auto boundsAttr = dyn_cast(attribute.getValue()); + if (!boundsAttr || !boundsAttr.isShapeBound()) + return op->emitError() + << "expected named attribute \"" << getShapeBoundsAttrName() + << "\" to be a \"#plan.bounds\" attribute containing shape bounds"; + return verifyBoundsAttribute(op, argIndex, boundsAttr, attribute.getName()); + } + + return success(); +} diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/AllocTensors.cpp b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/AllocTensors.cpp index d2d9785ef..b3e133420 100644 --- a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/AllocTensors.cpp +++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/AllocTensors.cpp @@ -23,6 +23,7 @@ //===----------------------------------------------------------------------===// #include "mlir-tensorrt-dialect/Analysis/TensorKindAnalysis.h" #include "mlir-tensorrt-dialect/Interface/TensorKindOpInterface.h" +#include "mlir-tensorrt/Dialect/Plan/IR/Plan.h" #include "mlir-tensorrt/Dialect/Plan/Transforms/Passes.h" #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" @@ -521,13 +522,14 @@ updateFunctionWithNewDpsArg(func::FuncOp func, Location loc, Type argType, UnitAttr::get(ctx))}); func.insertArgument(func.getNumArguments(), argType, argAttrs, loc); - if (auto boundsAttr = - func.getResultAttr(tiedResult, "tensorrt.shape_profile")) - func.setArgAttr(func.getNumArguments() - 1, "tensorrt.shape_profile", - boundsAttr); - if (auto boundsAttr = func.getResultAttr(tiedResult, "tensorrt.value_bounds")) - func.setArgAttr(func.getNumArguments() - 1, "tensorrt.value_bounds", - boundsAttr); + if (auto boundsAttr = func.getResultAttr( + tiedResult, plan::PlanDialect::getShapeBoundsAttrName())) + func.setArgAttr(func.getNumArguments() - 1, + plan::PlanDialect::getShapeBoundsAttrName(), boundsAttr); + if (auto boundsAttr = func.getResultAttr( + tiedResult, plan::PlanDialect::getValueBoundsAttrName())) + func.setArgAttr(func.getNumArguments() - 1, + plan::PlanDialect::getValueBoundsAttrName(), boundsAttr); return func.getArguments().back(); } diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/MaterializeShapeCalculations.cpp b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/MaterializeShapeCalculations.cpp index 6b52810bc..490b60aab 100644 --- a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/MaterializeShapeCalculations.cpp +++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/MaterializeShapeCalculations.cpp @@ -22,6 +22,7 @@ /// //===----------------------------------------------------------------------===// #include "mlir-tensorrt-dialect/Analysis/TensorKindAnalysis.h" +#include "mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.h" #include "mlir-tensorrt/Dialect/Plan/Analysis/BoundsAnalysis.h" #include "mlir-tensorrt/Dialect/Plan/IR/Plan.h" #include "mlir-tensorrt/Dialect/StableHloExt/Transforms/Patterns.h" @@ -731,6 +732,78 @@ static void addCanonicalizationPatterns(RewritePatternSet &patterns) { (Ops::getCanonicalizationPatterns(patterns, patterns.getContext()), ...); } +/// Convert 'tensorrt' dialect arg/result bounds attribute into 'plan' bounds +/// attribute. +static Attribute convertArgOrResultAttr(OpBuilder &b, Type type, + tensorrt::ShapeProfileAttr trtAttr, + bool isValueBounds) { + MLIRContext *ctx = b.getContext(); + if (isValueBounds) { + Type elementType = mlir::getElementTypeOrSelf(type); + assert(elementType.isIntOrIndex() && "expected int or index element type"); + SmallVector boundsShape; + if (auto shapedType = dyn_cast(type)) + boundsShape = llvm::to_vector(shapedType.getShape()); + auto boundsValueType = RankedTensorType::get(boundsShape, elementType); + auto convertI64ArrayToDenseElements = [&](ArrayRef i64Vals) { + return DenseElementsAttr::get( + boundsValueType, + llvm::map_to_vector(i64Vals, [&](int64_t i64Val) -> Attribute { + return b.getIntegerAttr(elementType, i64Val); + })); + }; + return plan::BoundsAttr::get( + ctx, BoundsKind::Value, DenseI64ArrayAttr{}, DenseI64ArrayAttr{}, + convertI64ArrayToDenseElements(trtAttr.getMin()), + convertI64ArrayToDenseElements(trtAttr.getMax())); + } + return plan::BoundsAttr::get(ctx, plan::BoundsKind::Shape, trtAttr.getMin(), + trtAttr.getMax()); +} + +static void convertArgAndResultAttrs(OpBuilder &b, func::FuncOp op) { + StringRef tensorrtShapeBoundsAttrName = + mlir::tensorrt::TensorRTDialect::getShapeProfileArgAttrName(); + StringRef tensorrtValueBoundsAttrName = + mlir::tensorrt::TensorRTDialect::getShapeTensorValueBoundsArgAttrName(); + + StringRef planShapeBoundsAttrName = + mlir::plan::PlanDialect::getShapeBoundsAttrName(); + StringRef planValueBoundsAttrName = + mlir::plan::PlanDialect::getValueBoundsAttrName(); + + for (unsigned idx = 0; idx < op.getNumArguments(); idx++) { + Type type = op.getArgumentTypes()[idx]; + if (auto attr = op.getArgAttrOfType( + idx, tensorrtShapeBoundsAttrName)) { + op.removeArgAttr(idx, tensorrtShapeBoundsAttrName); + op.setArgAttr(idx, planShapeBoundsAttrName, + convertArgOrResultAttr(b, type, attr, false)); + } + if (auto attr = op.getArgAttrOfType( + idx, tensorrtValueBoundsAttrName)) { + op.removeArgAttr(idx, tensorrtValueBoundsAttrName); + op.setArgAttr(idx, planValueBoundsAttrName, + convertArgOrResultAttr(b, type, attr, true)); + } + } + for (unsigned idx = 0; idx < op.getNumResults(); idx++) { + Type type = op.getResultTypes()[idx]; + if (auto attr = op.getResultAttrOfType( + idx, tensorrtShapeBoundsAttrName)) { + op.removeArgAttr(idx, tensorrtShapeBoundsAttrName); + op.setResultAttr(idx, planShapeBoundsAttrName, + convertArgOrResultAttr(b, type, attr, false)); + } + if (auto attr = op.getResultAttrOfType( + idx, tensorrtValueBoundsAttrName)) { + op.removeArgAttr(idx, tensorrtValueBoundsAttrName); + op.setResultAttr(idx, planValueBoundsAttrName, + convertArgOrResultAttr(b, type, attr, true)); + } + } +} + //===----------------------------------------------------------------------===// // MaterializeShapeCalculationsPass //===----------------------------------------------------------------------===// @@ -745,9 +818,15 @@ class MaterializeShapeCalculationsPass Operation *op = getOperation(); MLIRContext *ctx = &getContext(); IRRewriter rewriter(ctx); - SymbolTableCollection symbolTable; + /// Convert `tensorrt` dialect bounds func arg/result attributes. + /// TODO: should this be moved to a dedicated pass? + op->walk([&](func::FuncOp func) { + convertArgAndResultAttrs(rewriter, func); + return WalkResult::skip(); + }); + // Run TensorKindAnalysis and populate the `plan.with_shape|with_values` // operations. { diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/PopulateFunctionBoundsAttributes.cpp b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/PopulateFunctionBoundsAttributes.cpp index 8a2c595e7..adc8bb992 100644 --- a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/PopulateFunctionBoundsAttributes.cpp +++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/PopulateFunctionBoundsAttributes.cpp @@ -91,10 +91,8 @@ class PlanPopulateFunctionBoundsAttributesPass llvm::none_of( func.getArgAttrs()->getAsRange(), [&](DictionaryAttr dict) { - return dict.getNamed(tensorrt::TensorRTDialect:: - getShapeProfileArgAttrName()) || - dict.getNamed(tensorrt::TensorRTDialect:: - getShapeTensorValueBoundsArgAttrName()); + return dict.getNamed(PlanDialect::getShapeBoundsAttrName()) || + dict.getNamed(PlanDialect::getValueBoundsAttrName()); })) return; @@ -137,9 +135,8 @@ class PlanPopulateFunctionBoundsAttributesPass << "failed to compute lower/upper shape bounds attribute"; return signalPassFailure(); } - func.setResultAttr( - idx, tensorrt::TensorRTDialect::getShapeProfileArgAttrName(), - boundsAttr); + func.setResultAttr(idx, plan::PlanDialect::getShapeBoundsAttrName(), + boundsAttr); continue; } @@ -167,10 +164,8 @@ class PlanPopulateFunctionBoundsAttributesPass return signalPassFailure(); } - func.setResultAttr( - idx, - tensorrt::TensorRTDialect::getShapeTensorValueBoundsArgAttrName(), - boundsAttr); + func.setResultAttr(idx, plan::PlanDialect::getValueBoundsAttrName(), + boundsAttr); } } }; diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/RefineTypes.cpp b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/RefineTypes.cpp index 4142f4ace..065916f96 100644 --- a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/RefineTypes.cpp +++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/RefineTypes.cpp @@ -235,7 +235,8 @@ struct AbsorbCastsIntoFuncReturnPattern } }; -struct WithShapeAbsorbCastPattern : public OpRewritePattern { +/// Push generalizing `tensor.cast` down below the `plan.with_shape`. +struct WithShapeCastPushDownPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(WithShapeOp op, PatternRewriter &rewriter) const override { @@ -267,46 +268,14 @@ struct WithShapeAbsorbCastPattern : public OpRewritePattern { /// `dims` yields an opportunity to refine the type of `with_shape`, then /// `stablehlo_op` can also be refined. The refinements are made (and casts are /// inserted if required). -struct StableHloRefineTypeFromWithShapeGeneric - : public OpRewritePattern { +struct RefineTypeFromWithShapeGeneric : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(WithShapeOp withOp, PatternRewriter &rewriter) const override { auto producer = withOp.getOperand().getDefiningOp(); - if (!producer || !producer->hasOneUse() || - !isa(producer->getDialect())) - return failure(); - - // Create a new shape and try to refine it. - std::optional> newShape = - getRefinedShape(withOp.getShape(), withOp.getOperand().getType()); - if (!newShape) - return failure(); - - // Update type of the producer. - updateTypeInPlaceAndMaybeInsertCast( - rewriter, withOp.getOperand(), - withOp.getOperand().getType().clone(*newShape)); - - // Update type of the WithShapeOp. - updateTypeInPlaceAndMaybeInsertCast(rewriter, withOp.getResult(), - withOp.getType().clone(*newShape)); - return success(); - } -}; - -/// Given a pattern `plan.with_shape(tensorrt_op, dims...)`, if inspection of -/// `dims` yields an opportunity to refine the type of `with_shape`, then -/// `tensorrt_op` can also be refined. The refinements are made (and casts are -/// inserted if required). -struct TensorRTRefineTypeFromWithShapeGeneric - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(WithShapeOp withOp, - PatternRewriter &rewriter) const override { - auto producer = withOp.getOperand().getDefiningOp(); - if (!producer || !producer->hasOneUse() || - !isa(producer->getDialect())) + if (!producer || + !isa_and_present(producer->getDialect())) return failure(); // Create a new shape and try to refine it. @@ -347,9 +316,8 @@ class PlanRefineTypesPass RefineDynamicBroadcast, RefineDynamicIota, SimplifyIdentityDynamicBroadcast, - StableHloRefineTypeFromWithShapeGeneric, - WithShapeAbsorbCastPattern, - TensorRTRefineTypeFromWithShapeGeneric + RefineTypeFromWithShapeGeneric, + WithShapeCastPushDownPattern >(ctx); // clang-format on tensor::CastOp::getCanonicalizationPatterns(patterns, ctx); diff --git a/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Transforms/CMakeLists.txt b/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Transforms/CMakeLists.txt index 583d389cc..0196dd845 100644 --- a/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Transforms/CMakeLists.txt +++ b/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Transforms/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_tensorrt_library(MLIRTensorRTStableHloExtTransforms + CanonicalizeConvolution.cpp CanonicalizeDotGeneral.cpp CanonicalizeGather.cpp CanonicalizeScatter.cpp @@ -7,9 +8,6 @@ add_mlir_tensorrt_library(MLIRTensorRTStableHloExtTransforms ExpandTuples.cpp GatherToSlice.cpp LowerSpecialCustomCalls.cpp - StablehloInputPreprocessing.cpp - StablehloPrepareConvolution.cpp - StablehloPrepareScatter.cpp StablehloRaiseQDQ.cpp LINK_LIBS PUBLIC diff --git a/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Transforms/StablehloPrepareConvolution.cpp b/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Transforms/CanonicalizeConvolution.cpp similarity index 90% rename from mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Transforms/StablehloPrepareConvolution.cpp rename to mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Transforms/CanonicalizeConvolution.cpp index e61d472b2..131ff87c2 100644 --- a/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Transforms/StablehloPrepareConvolution.cpp +++ b/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Transforms/CanonicalizeConvolution.cpp @@ -22,12 +22,20 @@ /// tensorrt dialect. /// //===----------------------------------------------------------------------===// +#include "mlir-tensorrt/Dialect/StableHloExt/Transforms/Passes.h" #include "mlir-tensorrt/Dialect/StableHloExt/Transforms/Patterns.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "stablehlo/dialect/StablehloOps.h" +namespace mlir::stablehlo_ext { +#define GEN_PASS_DEF_CANONICALIZECONVOLUTIONPASS +#include "mlir-tensorrt/Dialect/StableHloExt/Transforms/Passes.h.inc" +} // namespace mlir::stablehlo_ext + using namespace mlir; +using namespace mlir::stablehlo; /// Expand (n, c, h) input to (n, c, 1, h) input. static Value stablehloExpandSpatialDims(OpBuilder &b, Location loc, @@ -228,3 +236,23 @@ void stablehlo_ext::populateCanonicalizeStablehloConvolutionPatterns( RewritePatternSet &patterns) { patterns.insert(patterns.getContext()); } + +namespace { +class CanonicalizeConvolutionPass + : public stablehlo_ext::impl::CanonicalizeConvolutionPassBase< + CanonicalizeConvolutionPass> { +public: + using Base::Base; + void runOnOperation() override { + Operation *op = getOperation(); + MLIRContext *ctx = &getContext(); + RewritePatternSet patterns(ctx); + stablehlo_ext::populateCanonicalizeStablehloConvolutionPatterns(patterns); + if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) { + emitError(op->getLoc()) + << "failed to apply rewrite patterns in " << getArgument(); + return signalPassFailure(); + } + } +}; +} // namespace diff --git a/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Transforms/CanonicalizeGather.cpp b/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Transforms/CanonicalizeGather.cpp index abfb273c5..b3916bae0 100644 --- a/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Transforms/CanonicalizeGather.cpp +++ b/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Transforms/CanonicalizeGather.cpp @@ -52,8 +52,8 @@ collapseSliceDims(RewriterBase &rewriter, Location loc, return {}; return cast>( - rewriter.create(loc, input, *reassociations) - .getResult()); + stablehlo_ext::createCollapsingReshape(rewriter, loc, input, + *reassociations)); } // Expands the first dimension of `input` into the shape of `startIndices`, @@ -86,13 +86,12 @@ expandBatchDimension(RewriterBase &rewriter, Location loc, return {}; if (static_cast(newShape.size()) > input.getType().getRank()) return cast>( - rewriter - .create(loc, newType, input, *reassociations) - .getResult()); + stablehlo_ext::createExpandingReshape(rewriter, loc, newType, input, + *reassociations)); + return cast>( - rewriter - .create(loc, newType, input, *reassociations) - .getResult()); + stablehlo_ext::createCollapsingReshape(rewriter, loc, input, + *reassociations)); } static TypedValue diff --git a/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Transforms/CanonicalizeScatter.cpp b/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Transforms/CanonicalizeScatter.cpp index 2056f9001..c422723e3 100644 --- a/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Transforms/CanonicalizeScatter.cpp +++ b/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Transforms/CanonicalizeScatter.cpp @@ -14,6 +14,7 @@ /// //===----------------------------------------------------------------------===// #include "mlir-tensorrt/Dialect/StableHloExt/Transforms/Passes.h" +#include "mlir-tensorrt/Dialect/StableHloExt/Transforms/Patterns.h" #include "mlir-tensorrt/Dialect/StableHloExt/Utils/GatherScatterUtils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -101,7 +102,7 @@ static SmallVector reshapeUpdatesToEnsureSingleScatterDimension( reassociation.push_back({i}); return llvm::map_to_vector(updates, [&](Value update) -> Value { - return b.create(loc, update, reassociation); + return createCollapsingReshape(b, loc, update, reassociation); }); } if (numScatterDims == 0) { @@ -169,7 +170,7 @@ struct CanonicalizeScatterPattern : public OpRewritePattern { LogicalResult matchAndRewrite(ScatterOp scatterOp, PatternRewriter &rewriter) const override { - if (isCanonicalScatter(scatterOp)) + if (isCanonicalScatter(scatterOp) || isCanonicalScatterNd(scatterOp)) return failure(); Location loc = scatterOp.getLoc(); @@ -220,6 +221,126 @@ struct CanonicalizeScatterPattern : public OpRewritePattern { return success(); } }; +} // namespace + +// Ensure that there are enough "inserted window dimensions" so that the window +// update and the batch dims access disjoint areas of the result index space. +// This ensures we can map to `tensorr.scatter` or `onnx.scatter_nd`. This must +// be called after ensuring that there is a single scatter batch dimension. +static FailureOr> +stablehloReshapeScatterUpdatesToAddInsertedDims(OpBuilder &b, Location loc, + ValueRange updates, + int64_t indexDepth, + int64_t inputRank) { + assert(indexDepth >= 1 && "expected non-zero index depth"); + const size_t numScatterBatchDims = 1; + RankedTensorType updateType = + cast(updates.front().getType()); + const int64_t currUpdateSliceRank = + updateType.getRank() - numScatterBatchDims; + const int64_t expectedUpdateSliceRank = inputRank - indexDepth; + const int64_t expectedInsertWindowDims = indexDepth; + assert(expectedInsertWindowDims >= 1 && + "expected positive number of window insert dims"); + + // No need to do anything. + if (expectedUpdateSliceRank == currUpdateSliceRank) + return llvm::to_vector(updates); + + // Otherwise, we need to drop leading dimensions (hopefully). If leading + // dimensions are not unit dims, then we can't proceed. + if (currUpdateSliceRank > expectedUpdateSliceRank) { + const int64_t dimToDrop = currUpdateSliceRank - expectedUpdateSliceRank; + + if (!llvm::all_equal( + updateType.getShape().drop_front(1).take_front(dimToDrop)) || + updateType.getDimSize(1) != 1) + return failure(); + + RankedTensorType newShape = + RankedTensorType::Builder(updateType) + .setShape(updateType.getShape().drop_front(1 + dimToDrop)) + .insertDim(updateType.getDimSize(0), 0); + return llvm::to_vector(llvm::map_range(updates, [&](Value update) -> Value { + return b.create(loc, newShape, update); + })); + } + + return failure(); +} + +namespace { + +/// Simplify `stablehlo.scatter` to conform with `tensorrt.scatter`. +struct StablehloCanonicalizeScatterToTensorRtScatterNdFormat + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(stablehlo::ScatterOp op, + PatternRewriter &rewriter) const override { + + // Only proceed if we are in the StableHLO canonical form. This is covered + // by the "stablehlo-ext-canonicalize-scatter" pass that runs before this + // pass. All scatter ops should be in stablehlo canonical form at this + // point. + if (!stablehlo_ext::isCanonicalScatter(op)) + return failure(); + + RankedTensorType canonicalizedInputType = + cast(op.getInputs().front().getType()); + RankedTensorType canonicalizedIndexType = + cast(op.getScatterIndices().getType()); + + // Reshape the updates if possible. + int64_t inputRank = canonicalizedInputType.getRank(); + int64_t indexDepth = canonicalizedIndexType.getDimSize(1); + FailureOr> canonicalizedUpdates = + stablehloReshapeScatterUpdatesToAddInsertedDims( + rewriter, op.getLoc(), op.getUpdates(), indexDepth, inputRank); + if (failed(canonicalizedUpdates)) + return rewriter.notifyMatchFailure(op, "failed to canonicalize updates"); + + // Create the new scatter op. + auto canonicalizedUpdatesType = + cast(canonicalizedUpdates->front().getType()); + assert(((canonicalizedInputType.getRank() - indexDepth) + + (canonicalizedIndexType.getRank() - 1)) == + canonicalizedUpdatesType.getRank() && + "expected slice size to equal inputRank - index_depth"); + auto newConfig = stablehlo::ScatterDimensionNumbersAttr::get( + getContext(), + /*updateWindowDims=*/ + llvm::to_vector( + llvm::seq(1, canonicalizedUpdatesType.getRank())), + /*insertedWindowDims=*/ + llvm::to_vector(llvm::seq(0, indexDepth)), + /*inputBatchingDims=*/{}, + /*scatterIndicesBatchingDims=*/{}, + /*scatterDimsToOperandDims=*/ + llvm::to_vector(llvm::seq(0, indexDepth)), 1); + auto scatterOp = rewriter.create( + op.getLoc(), TypeRange(ValueRange(op.getInputs())), op.getInputs(), + op.getScatterIndices(), *canonicalizedUpdates, newConfig); + Region ®ion = scatterOp.getUpdateComputation(); + rewriter.inlineRegionBefore(op.getUpdateComputation(), region, + region.end()); + rewriter.replaceOp(op, scatterOp.getResults()); + + scatterOp->setAttr("tensorrt.canonicalized_scatter", + rewriter.getUnitAttr()); + + return success(); + } +}; +} // namespace + +void stablehlo_ext::populateCanonicalizeStablehloScatterPatterns( + RewritePatternSet &patterns) { + patterns.insert( + patterns.getContext()); +} + +namespace { struct CanonicalizeScatterPass : stablehlo_ext::impl::CanonicalizeScatterPassBase< @@ -227,7 +348,7 @@ struct CanonicalizeScatterPass void runOnOperation() override { MLIRContext *ctx = &getContext(); RewritePatternSet patterns(ctx); - patterns.add(ctx); + populateCanonicalizeStablehloScatterPatterns(patterns); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { emitError(getOperation()->getLoc()) diff --git a/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Transforms/ConstantFolding.cpp b/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Transforms/ConstantFolding.cpp index 89038a540..a36004bfb 100644 --- a/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Transforms/ConstantFolding.cpp +++ b/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Transforms/ConstantFolding.cpp @@ -933,6 +933,39 @@ struct ConstFoldGatherOnSplat : public OpRewritePattern { } }; +//===----------------------------------------------------------------------===// +// LogicalRightShiftOp +//===----------------------------------------------------------------------===// + +/// Fold trivial `stablehlo.logical_shift_right` when the shift has a greater +/// width than the element type. +struct RewriteTrivialLogicalRightShiftPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(stablehlo::ShiftRightLogicalOp op, + PatternRewriter &rewriter) const override { + TensorType resultType = op.getType(); + + // Make sure we rule out index type, since 'getElementTypeBitWidth' will + // fail in that case. + if (resultType.isIndex() || !resultType.hasStaticShape()) + return failure(); + + int64_t bitWidth = resultType.getElementTypeBitWidth(); + ElementsAttr attr; + + // Try to match a constant shift amount. + if (!matchPattern(op.getRhs(), m_Constant(&attr)) || !attr.isSplat()) + return failure(); + + int64_t shiftAmount = attr.getSplatValue().getSExtValue(); + if (shiftAmount < bitWidth) + return failure(); + return replaceOpWithNewOpAndMaybeCast( + rewriter, op, rewriter.getZeroAttr(resultType)); + } +}; + //===----------------------------------------------------------------------===// // Misc Patterns //===----------------------------------------------------------------------===// @@ -991,7 +1024,7 @@ struct AbsorbTensorCastProducer : public RewritePattern { LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - if (!isa(op->getDialect()) || + if (!isa_and_present(op->getDialect()) || // Composite op types cannot be refined in-place. isa(op)) return failure(); @@ -1066,6 +1099,7 @@ class ConstantFoldingPass FixInvalidReturnWorkaround, FoldAndOp, FoldOrOp, + RewriteTrivialLogicalRightShiftPattern, RsqrtFolder, SimplifyReshapeBroadcastInDimReshape, SimplifyTrivialMinOrTrivalMax, diff --git a/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Transforms/StablehloInputPreprocessing.cpp b/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Transforms/StablehloInputPreprocessing.cpp deleted file mode 100644 index cd7248b1f..000000000 --- a/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Transforms/StablehloInputPreprocessing.cpp +++ /dev/null @@ -1,200 +0,0 @@ -//===- StablehloInputPreprocessing.cpp -----------------------------------===// -// -// SPDX-FileCopyrightText: Copyright 2024 NVIDIA CORPORATION & AFFILIATES. -// All rights reserved. -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -//===----------------------------------------------------------------------===// -/// -/// Implements a pass that applies various patterns to StableHLO IR to prepare -/// it for conversion to the TensorRT dialect. -/// -//===----------------------------------------------------------------------===// -#include "mlir-tensorrt/Dialect/StableHloExt/Transforms/Passes.h" -#include "mlir-tensorrt/Dialect/StableHloExt/Transforms/Patterns.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Func/Transforms/FuncConversions.h" -#include "mlir/Dialect/SCF/Transforms/Transforms.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/OperationSupport.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "stablehlo/dialect/ChloOps.h" -#include "stablehlo/dialect/StablehloOps.h" - -namespace mlir::stablehlo_ext { -#define GEN_PASS_DEF_STABLEHLOINPUTPREPROCESSINGPASS -#include "mlir-tensorrt/Dialect/StableHloExt/Transforms/Passes.h.inc" -} // namespace mlir::stablehlo_ext - -using namespace mlir; -using namespace mlir::stablehlo; - -namespace { -/// Fold trivial `stablehlo.logical_shift_right` when the shift has a greater -/// width than the element type. -struct StablehloRewriteTrivialLogicalRightShift - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(stablehlo::ShiftRightLogicalOp op, - PatternRewriter &rewriter) const override { - TensorType resultType = op.getType(); - int64_t bitWidth = resultType.getElementTypeBitWidth(); - ElementsAttr attr; - // Try to match a constant shift amount. - if (matchPattern(op.getRhs(), m_Constant(&attr))) { - if (!attr.isSplat()) - return failure(); - int64_t shiftAmount = attr.getSplatValue().getSExtValue(); - if (shiftAmount < bitWidth) - return failure(); - rewriter.replaceOpWithNewOp( - op, rewriter.getZeroAttr(resultType)); - return success(); - } - return failure(); - } -}; -} // namespace - -static Value makeSplatF32TensorConstantLike(OpBuilder &b, Location loc, - float constant, Value val) { - auto rtt = cast(val.getType()); - return b.create(loc, - DenseElementsAttr::get(rtt, constant)); -} - -static Value makeSplatTensorInfConstantLike(OpBuilder &b, Location loc, - Value val, bool isNegInf) { - auto ty = cast(getElementTypeOrSelf(val.getType())); - return makeSplatF32TensorConstantLike( - b, loc, - APFloat::getInf(ty.getFloatSemantics(), isNegInf).convertToFloat(), val); -} - -/// Reproduced from `chlo-legalize-to-stablehlo` from here: -/// https://github.com/openxla/stablehlo/blob/c3f456500f0f2e96fdb4a98fde4fbe48b4d624b8/stablehlo/transforms/ChloLegalizeToStablehlo.cpp -/// (Apache 2.0 License: -/// https://github.com/openxla/stablehlo/blob/c3f456500f0f2e96fdb4a98fde4fbe48b4d624b8/LICENSE) -/// TODO: remove this once we modify upstream to expose this lowering pattern in -/// a public header so we don't have to copy-paste like this. -static Value erfInv32Stablehlo(RewriterBase &b, Location loc, ValueRange args) { - constexpr int kDegree = 9; - constexpr std::array wLessThan5Constants = { - 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, - -4.39150654e-06f, 0.00021858087f, -0.00125372503f, - -0.00417768164f, 0.246640727f, 1.50140941f}; - constexpr std::array wGreaterThan5Constants = { - -0.000200214257f, 0.000100950558f, 0.00134934322f, - -0.00367342844f, 0.00573950773f, -0.0076224613f, - 0.00943887047f, 1.00167406f, 2.83297682f}; - - Value x = args[0]; - // Compute logarithm of (1+arg) using log1p(arg) which is more precise than - // log(1+arg) when arg is close to zero. For more details, see - // https://en.cppreference.com/w/cpp/numeric/math/log1p - Value minusXSquared = - b.create(loc, x, b.create(loc, x)); - Value w = b.create( - loc, b.create(loc, minusXSquared)); - - Value lt = b.create( - loc, w, makeSplatF32TensorConstantLike(b, loc, 5.0, x), - stablehlo::ComparisonDirection::LT); - auto coefficient = [&](int i) { - return b.create( - loc, lt, - makeSplatF32TensorConstantLike(b, loc, wLessThan5Constants[i], x), - makeSplatF32TensorConstantLike(b, loc, wGreaterThan5Constants[i], x)); - }; - w = b.create( - loc, lt, - b.create( - loc, w, makeSplatF32TensorConstantLike(b, loc, 2.5, x)), - b.create( - loc, b.create(loc, w), - makeSplatF32TensorConstantLike(b, loc, 3.0, x))); - Value p = coefficient(0); - for (int i = 1; i < kDegree; ++i) { - p = b.create(loc, coefficient(i), - b.create(loc, p, w)); - } - - // Result modulo edge cases. - Value result = b.create(loc, p, x); - - // Handle edge cases, namely erfinv(+/-1) = +/-inf. (The above computation is - // indeterminate, and can give nan or -/+inf.) - return b.create( - loc, - b.create( - loc, b.create(loc, x), - makeSplatF32TensorConstantLike(b, loc, 1, x), - stablehlo::ComparisonDirection::EQ), - b.create( - loc, x, makeSplatTensorInfConstantLike(b, loc, x, false)), - result); -} - -namespace { - -struct ConvertErfInvOpToStablehlo final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(chlo::ErfInvOp op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Type elementType = op.getResult().getType().getElementType(); - if (!isa(elementType) || - elementType.getIntOrFloatBitWidth() > 32) - return failure(); - - Value operand = op.getOperand(); - if (!elementType.isF32()) - operand = rewriter.create(loc, operand, - rewriter.getF32Type()); - - Value result = erfInv32Stablehlo(rewriter, loc, operand); - if (result.getType() != op.getResult().getType()) - result = rewriter.create( - loc, result, op.getResult().getType().getElementType()); - rewriter.replaceOp(op, result); - return success(); - } -}; - -class StablehloInputPreprocessing - : public mlir::stablehlo_ext::impl::StablehloInputPreprocessingPassBase< - StablehloInputPreprocessing> { - using Base::Base; - - void runOnOperation() override { - Operation *op = getOperation(); - MLIRContext *ctx = &getContext(); - RewritePatternSet patterns(ctx); - patterns.insert(ctx); - stablehlo_ext::populateCanonicalizeStablehloConvolutionPatterns(patterns); - stablehlo_ext::populateCanonicalizeStablehloScatterPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) { - emitError(op->getLoc()) << "failed to run patterns in " << getArgument(); - return signalPassFailure(); - } - } -}; - -} // namespace diff --git a/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Transforms/StablehloPrepareScatter.cpp b/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Transforms/StablehloPrepareScatter.cpp deleted file mode 100644 index cdaf60af0..000000000 --- a/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Transforms/StablehloPrepareScatter.cpp +++ /dev/null @@ -1,199 +0,0 @@ -//===- StablehloPrepareScatter.cpp ---------------------------------------===// -// -// SPDX-FileCopyrightText: Copyright 2024 NVIDIA CORPORATION & AFFILIATES. -// All rights reserved. -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -//===----------------------------------------------------------------------===// -/// -/// Prepare `stablehlo.scatter` for conversion to TensorRT dialect. This pass -/// canonicalizes the scatter operations so that they have a form compatible -/// with the "onnx.ScatterNd" semantic, which is the same as the -/// `tensorrt.scatter` operation semantic. -/// -//===----------------------------------------------------------------------===// -#include "mlir-tensorrt/Dialect/StableHloExt/Transforms/Patterns.h" -#include "mlir-tensorrt/Dialect/StableHloExt/Utils/GatherScatterUtils.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "stablehlo/dialect/StablehloOps.h" -#include "llvm/Support/Debug.h" - -#define DEBUG_TYPE "tensorrt-stablehlo-prepare-scatter" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") - -using namespace mlir; -using namespace mlir::stablehlo; - -// Ensure that there are enough "inserted window dimensions" so that the window -// update and the batch dims access disjoint areas of the result index space. -// This ensures we can map to `tensorr.scatter` or `onnx.scatter_nd`. This must -// be called after ensuring that there is a single scatter batch dimension. -static FailureOr> -stablehloReshapeScatterUpdatesToAddInsertedDims(OpBuilder &b, Location loc, - ValueRange updates, - int64_t indexDepth, - int64_t inputRank) { - assert(indexDepth >= 1 && "expected non-zero index depth"); - const size_t numScatterBatchDims = 1; - RankedTensorType updateType = - cast(updates.front().getType()); - const int64_t currUpdateSliceRank = - updateType.getRank() - numScatterBatchDims; - const int64_t expectedUpdateSliceRank = inputRank - indexDepth; - const int64_t expectedInsertWindowDims = indexDepth; - assert(expectedInsertWindowDims >= 1 && - "expected positive number of window insert dims"); - - LLVM_DEBUG(DBGS() << "update slice rank expected = " - << expectedUpdateSliceRank << ", current = " - << currUpdateSliceRank << " = " << updateType << "\n"); - - // No need to do anything. - if (expectedUpdateSliceRank == currUpdateSliceRank) - return llvm::to_vector(updates); - - // Otherwise, we need to drop leading dimensions (hopefully). If leading - // dimensions are not unit dims, then we can't proceed. - if (currUpdateSliceRank > expectedUpdateSliceRank) { - const int64_t dimToDrop = currUpdateSliceRank - expectedUpdateSliceRank; - LLVM_DEBUG(DBGS() << "need to drop " << dimToDrop - << " from the updates tensor (after index dim)\n"); - if (!llvm::all_equal( - updateType.getShape().drop_front(1).take_front(dimToDrop)) || - updateType.getDimSize(1) != 1) - return failure(); - - RankedTensorType newShape = - RankedTensorType::Builder(updateType) - .setShape(updateType.getShape().drop_front(1 + dimToDrop)) - .insertDim(updateType.getDimSize(0), 0); - return llvm::to_vector(llvm::map_range(updates, [&](Value update) -> Value { - return b.create(loc, newShape, update); - })); - } - - return failure(); -} - -namespace { -/// Simplify `stablehlo.scatter` to conform with `tensorrt.scatter`. -struct StablehloCanonicalizeScatterToTensorRtScatterNdFormat - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(stablehlo::ScatterOp op, - PatternRewriter &rewriter) const override { - // If we are already in canonical form, then there is nothing to do. - if (stablehlo_ext::isCanonicalScatterNd(op)) - return failure(); - - // Only proceed if we are in the StableHLO canonical form. This is covered - // by the "stablehlo-ext-canonicalize-scatter" pass that runs before this - // pass. All scatter ops should be in stablehlo canonical form at this - // point. - if (!stablehlo_ext::isCanonicalScatter(op)) - return failure(); - - RankedTensorType canonicalizedInputType = - cast(op.getInputs().front().getType()); - RankedTensorType canonicalizedIndexType = - cast(op.getScatterIndices().getType()); - - // Reshape the updates if possible. - int64_t inputRank = canonicalizedInputType.getRank(); - int64_t indexDepth = canonicalizedIndexType.getDimSize(1); - FailureOr> canonicalizedUpdates = - stablehloReshapeScatterUpdatesToAddInsertedDims( - rewriter, op.getLoc(), op.getUpdates(), indexDepth, inputRank); - if (failed(canonicalizedUpdates)) - return rewriter.notifyMatchFailure(op, "failed to canonicalize updates"); - - // Create the new scatter op. - auto canonicalizedUpdatesType = - cast(canonicalizedUpdates->front().getType()); - assert(((canonicalizedInputType.getRank() - indexDepth) + - (canonicalizedIndexType.getRank() - 1)) == - canonicalizedUpdatesType.getRank() && - "expected slice size to equal inputRank - index_depth"); - auto newConfig = stablehlo::ScatterDimensionNumbersAttr::get( - getContext(), - /*updateWindowDims=*/ - llvm::to_vector( - llvm::seq(1, canonicalizedUpdatesType.getRank())), - /*insertedWindowDims=*/ - llvm::to_vector(llvm::seq(0, indexDepth)), - /*inputBatchingDims=*/{}, - /*scatterIndicesBatchingDims=*/{}, - /*scatterDimsToOperandDims=*/ - llvm::to_vector(llvm::seq(0, indexDepth)), 1); - auto scatterOp = rewriter.create( - op.getLoc(), TypeRange(ValueRange(op.getInputs())), op.getInputs(), - op.getScatterIndices(), *canonicalizedUpdates, newConfig); - Region ®ion = scatterOp.getUpdateComputation(); - rewriter.inlineRegionBefore(op.getUpdateComputation(), region, - region.end()); - rewriter.replaceOp(op, scatterOp.getResults()); - - scatterOp->setAttr("tensorrt.canonicalized_scatter", - rewriter.getUnitAttr()); - - return success(); - } -}; - -/// Rewrite `arith.constant` to `stablehlo.constant`. Arith constant can be -/// created by `tensor` dialect canonicalizers. Some `arith` constants may be -/// created by `stablehlo-canonicalize-scatter` pass. -struct RewriteArithConstToStablehlo - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(arith::ConstantOp op, - PatternRewriter &rewriter) const override { - if (!isa(op.getType())) - return failure(); - rewriter.replaceOpWithNewOp( - op, op.getType(), cast(op.getValueAttr())); - return success(); - } -}; -} // namespace - -/// Rewrite `tensor.expand_shape`/`tensor.collapse_shape` into a -/// `stablehlo.reshape` operation. -template -static LogicalResult -stablehloRewriteTensorExpandCollapseShape(OpType op, - PatternRewriter &rewriter) { - if (!op.getType().hasStaticShape()) - return failure(); - rewriter.replaceOpWithNewOp(op, op.getType(), - op->getOperand(0)); - return success(); -} - -void stablehlo_ext::populateCanonicalizeStablehloScatterPatterns( - RewritePatternSet &patterns) { - patterns.add( - stablehloRewriteTensorExpandCollapseShape, - PatternBenefit(1), {"tensorCollapseShapeToStablehloReshape"}); - patterns.add(stablehloRewriteTensorExpandCollapseShape, - PatternBenefit(1), {"tensorExpandShapeToStablehloReshape"}); - patterns.insert(patterns.getContext()); - tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, - patterns.getContext()); - tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, - patterns.getContext()); -} diff --git a/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Utils/CMakeLists.txt b/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Utils/CMakeLists.txt index 7aeeb80e9..cab45247e 100644 --- a/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Utils/CMakeLists.txt +++ b/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Utils/CMakeLists.txt @@ -5,4 +5,5 @@ add_mlir_tensorrt_library(MLIRTensorRTStableHloExtUtils LINK_LIBS PUBLIC MLIRTensorRTStableHloExtIR + MLIRTensorRTUtilsShapeInfo ) diff --git a/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Utils/GatherScatterUtils.cpp b/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Utils/GatherScatterUtils.cpp index d0a4823b2..6d0fdc3f8 100644 --- a/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Utils/GatherScatterUtils.cpp +++ b/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Utils/GatherScatterUtils.cpp @@ -10,6 +10,9 @@ #include "mlir-tensorrt/Dialect/StableHloExt/Utils/GatherScatterUtils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "stablehlo/dialect/StablehloOps.h" using namespace mlir; @@ -22,8 +25,10 @@ static bool isSeq(R &&range, int64_t start, int64_t size) { llvm::seq(start, start + size)); } -std::optional -stablehlo_ext::isSingleDimSimpleGatherWithImplicitIndexDim(GatherOp op) { +/// Common conditions shared by the static and dynamic versions of +/// "isSingleDimSimpleGatherWithImplicitIndexDim". +template +static bool isSingleDimSimpleGatherWithImplicitIndexDimImpl(OpType op) { RankedTensorType operandType = op.getOperand().getType(); RankedTensorType startIndicesType = op.getStartIndices().getType(); RankedTensorType resultType = op.getType(); @@ -31,31 +36,86 @@ stablehlo_ext::isSingleDimSimpleGatherWithImplicitIndexDim(GatherOp op) { /// Sanity check the expected rank of the result. if (resultType.getRank() != operandType.getRank() + startIndicesType.getRank() - 1) - return {}; + return false; const auto &dims = op.getDimensionNumbers(); // (C3) Check for implicit size-1 index vector. if (dims.getIndexVectorDim() != startIndicesType.getRank()) - return {}; + return false; // (C0) The dimension being gathered is the one that should be collapsed. if (dims.getStartIndexMap().size() != 1 || dims.getStartIndexMap() != dims.getCollapsedSliceDims()) + return false; + + // (C2) The offset dims of the result are the trailing dimensions after the + // start index result dimensions. + if (!isSeq(dims.getOffsetDims(), startIndicesType.getRank(), + resultType.getRank() - startIndicesType.getRank())) + return false; + + return true; +} + +std::optional +stablehlo_ext::isSingleDimSimpleGatherWithImplicitIndexDim(GatherOp op) { + if (!isSingleDimSimpleGatherWithImplicitIndexDimImpl(op)) return {}; // (C1) The `slice_sizes` should equal the shape of the operand except // along the gather dimension, which is size 1. + const auto &dims = op.getDimensionNumbers(); SmallVector expectedSliceSizes(op.getOperand().getType().getShape()); expectedSliceSizes[dims.getStartIndexMap()[0]] = 1; if (!llvm::equal(expectedSliceSizes, op.getSliceSizes())) return {}; - // (C2) The offset dims of the result are the trailing dimensions after the - // start index result dimensions. - if (!isSeq(dims.getOffsetDims(), startIndicesType.getRank(), - resultType.getRank() - startIndicesType.getRank())) + return dims.getStartIndexMap().front(); +} + +std::optional +stablehlo_ext::isSingleDimSimpleGatherWithImplicitIndexDim( + DynamicGatherOp op, const ShapeInfoCallbacks &shapeInfoCallbacks) { + + // The dynamic gather 3rd parameter is the "slice sizes". We want to + // ensure that "slice sizes" matches the operand shape in all dimensions + // except for those dropped using "collapsed_dims". + TypedValue sliceSizes = op.getSliceSizes(); + RankedTensorType sliceSizesType = sliceSizes.getType(); + + if (!isSingleDimSimpleGatherWithImplicitIndexDimImpl(op)) + return {}; + + // (C1) The `slice_sizes` should equal the shape of the operand except + // along the gather dimension, which is size 1. + const auto &dims = op.getDimensionNumbers(); + for (int64_t i = 0; i < sliceSizesType.getDimSize(0); i++) { + if (i == dims.getStartIndexMap()[0]) { + auto one = + IntegerAttr::get(op.getSliceSizes().getType().getElementType(), 1); + if (std::optional isEqualToOne = + shapeInfoCallbacks.isElementValueEqualToConstant( + TensorElementValue(op.getSliceSizes(), i), one)) { + if (!*isEqualToOne) + return {}; + continue; + } + return {}; + } + + if (std::optional isEquivalent = + shapeInfoCallbacks.isElementValueEqualToShapeDimExtent( + TensorElementValue(op.getSliceSizes(), i), + TensorShapeDimExtent(op.getOperand(), i))) { + if (!*isEquivalent) + return {}; + continue; + } + return {}; + } + return dims.getStartIndexMap().front(); } @@ -217,6 +277,125 @@ bool stablehlo_ext::isCanonicalScatterNd(stablehlo::ScatterOp scatterOp) { updateType.getRank(); } +Value stablehlo_ext::createCollapsingReshape( + OpBuilder &b, Location loc, Value input, + ArrayRef reassociation) { + RankedTensorType inputType = cast(input.getType()); + + std::vector newShape(reassociation.size()); + for (auto [idx, re] : llvm::enumerate(reassociation)) { + int64_t dim = 1; + for (int64_t i : re) { + if (inputType.isDynamicDim(i)) { + dim = ShapedType::kDynamic; + break; + } + dim *= inputType.getDimSize(i); + } + newShape[idx] = dim; + } + auto resultType = inputType.clone(newShape); + + assert(inputType.getRank() > resultType.getRank() && + "input rank should be > result rank"); + assert(static_cast(reassociation.size()) == resultType.getRank() && + "invalid reassociation indices"); + + if (resultType.hasStaticShape()) + return b.create(loc, resultType, input); + + // Calculate the shape. + Type i32Type = b.getI32Type(); + RankedTensorType i32ScalarTensorType = RankedTensorType::get({}, i32Type); + SmallVector components; + for (const ReassociationIndices &indices : reassociation) { + Value vol = b.create( + loc, + DenseElementsAttr::get(i32ScalarTensorType, static_cast(1))); + for (int64_t index : indices) { + Value extent = b.create( + loc, i32ScalarTensorType, input, index); + vol = b.create(loc, vol, extent); + } + components.push_back(b.create( + loc, i32ScalarTensorType.clone({1}), vol)); + } + Value shape = b.create( + loc, i32ScalarTensorType.clone({resultType.getRank()}), components, + /*dimension=*/0); + return b.create(loc, resultType, input, shape); +} + +Value stablehlo_ext::createExpandingReshape( + OpBuilder &b, Location loc, RankedTensorType resultType, Value input, + ArrayRef reassociation) { + RankedTensorType inputType = cast(input.getType()); + + assert(inputType.getRank() < resultType.getRank() && + "input rank should be > result rank"); + assert(static_cast(reassociation.size()) == inputType.getRank() && + "invalid reassociation indices"); + + if (resultType.hasStaticShape()) + return b.create(loc, resultType, input); + + // Calculate the shape. + Type i32Type = b.getI32Type(); + RankedTensorType i32ScalarTensorType = RankedTensorType::get({}, i32Type); + SmallVector components; + for (auto [inputDim, resultIndices] : llvm::enumerate(reassociation)) { + + // Calculate how many dynamic dimensions are in the group. This function + // only supports up to 1 dynamic dimension in each group, otherwise we can't + // calculate the shape. + int64_t numDynamicInGroup = 0; + int64_t divisor = 1; + for (int64_t resultDim : resultIndices) { + if (resultType.isDynamicDim(resultDim)) { + numDynamicInGroup += 1; + continue; + } + divisor *= resultType.getDimSize(resultDim); + } + assert(numDynamicInGroup <= 1 && "invalid reshape configuration requested"); + + for (int64_t resultDim : resultIndices) { + if (!resultType.isDynamicDim(resultDim)) { + components.push_back(b.create( + loc, DenseElementsAttr::get( + i32ScalarTensorType.clone({1}), + static_cast(resultType.getDimSize(resultDim))))); + continue; + } + + Value extent = b.create( + loc, i32ScalarTensorType, input, inputDim); + + if (resultIndices.size() == 1 || divisor == 1) { + components.push_back(b.create( + loc, i32ScalarTensorType.clone({1}), extent)); + continue; + } + + // In the case where we are factoring out multiple constant dimensions, we + // divide by the product of the other dimensions to get the expected + // extent. + extent = b.create( + loc, extent, + b.create( + loc, DenseElementsAttr::get(i32ScalarTensorType, + static_cast(divisor)))); + components.push_back(b.create( + loc, i32ScalarTensorType.clone({1}), extent)); + } + } + + Value shape = b.create( + loc, i32ScalarTensorType.clone({resultType.getRank()}), components, + /*dimension=*/0); + return b.create(loc, resultType, input, shape); +} + //===----------------------------------------------------------------------===// // Code below this point was adapted from the MLIR-HLO project (part of OpenXLA // project) `xla/mlir_hlo/mhlo/utils/mhlo_scatter_gather_utils.cc` and has the @@ -278,17 +457,62 @@ Value stablehlo_ext::insertDegenerateDimensions( assert(llvm::is_sorted(dimsToInsert) && "dimsToInsert must be sorted"); if (dimsToInsert.empty()) return tensor; - TensorType type = mlir::cast(tensor.getType()); + auto type = mlir::cast(tensor.getType()); SmallVector newShape{type.getShape()}; - for (int64_t dim : dimsToInsert) + + // Create an initial identity reassociation. We will update this as we insert + // the degenerate dimensions. + SmallVector reassociation; + for (unsigned i = 0; i < newShape.size(); i++) { + ReassociationIndices &back = reassociation.emplace_back(); + back.push_back(i); + } + + for (int64_t dim : dimsToInsert) { newShape.insert(newShape.begin() + dim, 1); + + if (type.getRank() == 0) + continue; + + /// Calculate which reassociation group this new degenerate dimension + /// belongs to and where the degenerate dimension should be inserted. + unsigned reassociationGroupIdx = 0; + unsigned insertionPositionWithinGroup = 0; + for (auto [idx, re] : llvm::enumerate(reassociation)) { + if (reassociationGroupIdx + re.size() > static_cast(dim)) { + insertionPositionWithinGroup = dim - reassociationGroupIdx; + reassociationGroupIdx = idx; + break; + } + reassociationGroupIdx += re.size(); + } + reassociationGroupIdx = + std::min(reassociationGroupIdx, reassociation.size() - 1); + + assert(reassociationGroupIdx < reassociation.size() && + "invalid reassociation group"); + + reassociation[reassociationGroupIdx].insert( + reassociation[reassociationGroupIdx].begin() + + insertionPositionWithinGroup, + reassociation[reassociationGroupIdx][insertionPositionWithinGroup]); + // Update all indices to the right of where we inserted, for all groups. + for (int64_t &d : + llvm::MutableArrayRef(reassociation[reassociationGroupIdx]) + .drop_front(insertionPositionWithinGroup + 1)) + d += 1; + + for (ReassociationIndices &other : + llvm::MutableArrayRef(reassociation) + .drop_front(reassociationGroupIdx + 1)) { + for (int64_t &d : other) + d += 1; + } + } + auto newType = RankedTensorType::get(newShape, type.getElementType()); - return b - .create( - loc, newType, tensor, - *getReassociationIndicesForReshape(type, newType)) - .getResult(); + return createExpandingReshape(b, loc, newType, tensor, reassociation); } // Checks if the indexVectorDim is equal to the rank of `indices`. In that @@ -315,8 +539,8 @@ Value stablehlo_ext::canonicalizeStartIndices(OpBuilder &b, Location loc, Value indices, int64_t indexVectorDim) { indices = ensureIndexVectorDimPosition(b, loc, indices, indexVectorDim); - - int64_t indicesRank = mlir::cast(indices.getType()).getRank(); + auto indicesType = cast(indices.getType()); + int64_t indicesRank = indicesType.getRank(); if (indicesRank == 2) return indices; @@ -327,6 +551,6 @@ Value stablehlo_ext::canonicalizeStartIndices(OpBuilder &b, Location loc, SmallVector reassociation{ llvm::to_vector<2>(llvm::seq(0, indicesRank - 1)), {indicesRank - 1}}; - return b.create(loc, indices, reassociation) - .getResult(); + + return createCollapsingReshape(b, loc, indices, reassociation); } diff --git a/mlir-tensorrt/compiler/lib/Pipelines/CMakeLists.txt b/mlir-tensorrt/compiler/lib/Pipelines/CMakeLists.txt index 0e7f92df6..b0450f933 100644 --- a/mlir-tensorrt/compiler/lib/Pipelines/CMakeLists.txt +++ b/mlir-tensorrt/compiler/lib/Pipelines/CMakeLists.txt @@ -14,6 +14,7 @@ if(MLIR_TRT_ENABLE_HLO) StableHloInputPipelines.cpp ) list(APPEND pipeline_deps_ + MLIRTensorRTChloToStablehloExt MLIRTensorRTStableHloExtTransforms MLIRTensorRTStablehloToSCF MLIRTensorRTStablehloToTensorRT diff --git a/mlir-tensorrt/compiler/lib/Pipelines/StableHloInputPipelines.cpp b/mlir-tensorrt/compiler/lib/Pipelines/StableHloInputPipelines.cpp index 6aa92a2cd..b05988d84 100644 --- a/mlir-tensorrt/compiler/lib/Pipelines/StableHloInputPipelines.cpp +++ b/mlir-tensorrt/compiler/lib/Pipelines/StableHloInputPipelines.cpp @@ -29,8 +29,9 @@ using namespace mlir; -static void buildStableHloSimplificationPipeline(OpPassManager &pm, - bool legalizeChlo) { +static void buildStableHloSimplificationPipeline( + OpPassManager &pm, + const mlir::ConvertChloToStableHloExtPassOptions &chloToStablehloOptions) { // Some match-and-raise patterns should be performed before canonicalization, // since the pattern is based on specific frontend patterns (e.g. JAX). pm.addPass(stablehlo_ext::createExpandTuplesPass()); @@ -43,9 +44,7 @@ static void buildStableHloSimplificationPipeline(OpPassManager &pm, // We don't do the CHLO legalization until this point since we want to wait // until after `canonicalize-shapes` has run at least once. This reduces the // likelihood of generating `shape` dialect ops. - if (legalizeChlo) - pm.addNestedPass( - stablehlo::createChloLegalizeToStablehloPass()); + pm.addPass(mlir::createConvertChloToStableHloExtPass(chloToStablehloOptions)); pm.addPass(stablehlo_ext::createCanonicalizeDotGeneralPass()); pm.addPass(stablehlo_ext::createConstantFoldingPass()); @@ -64,10 +63,13 @@ void mlir::buildStablehloPreProcessingPipeline( pm.addPass(stablehlo_ext::createLowerSpecialCustomCalls()); // Simplify StableHLO graph - buildStableHloSimplificationPipeline(pm, opts.convertChloToStablehlo); + buildStableHloSimplificationPipeline( + pm, ConvertChloToStableHloExtPassOptions{ + /*preserveErf=*/opts.preserveChloErf, + /*preserveTopK=*/opts.preserveChloTopK, + }); pm.addPass(createCSEPass()); - pm.addNestedPass( - stablehlo_ext::createStablehloInputPreprocessingPass()); + pm.addPass(stablehlo_ext::createCanonicalizeConvolutionPass()); if (opts.legalizeControlFlowToSCF) pm.addPass(mlir::createConvertStablehloToScfPass()); pm.addPass(createCSEPass()); @@ -83,20 +85,20 @@ struct StableHloInputPipelineOptions *this, "legalize-control-flow-to-scf", llvm::cl::desc("lower StableHLO control flow ops to SCF"), llvm::cl::init(false)}; - Option legalizeChloErfToStablehlo{ - *this, "legalize-chlo-erf-to-stablehlo", - llvm::cl::desc( - "Whether to lower chlo.erf into primitive stablehlo operations"), + + Option preserveChloErf{ + *this, "preserve-chlo-erf", + llvm::cl::desc("don't lower chlo.erf to stablehlo"), + llvm::cl::init(true)}; + Option preserveChloTopK{ + *this, "preserve-chlo-topk", + llvm::cl::desc("don't lower chlo.top_k to stablehlo"), llvm::cl::init(true)}; Option disableInliner{ *this, "disable-inliner", llvm::cl::desc( "Whether to disable running the inliner as part of the pipeline"), llvm::cl::init(false)}; - Option convertChloToStablehlo{ - *this, "convert-chlo-to-stablehlo", - llvm::cl::desc("Whether to lower chlo to stablehlo"), - llvm::cl::init(false)}; }; } // namespace @@ -108,15 +110,17 @@ void mlir::registerStableHloInputPipelines() { [](OpPassManager &pm, const StableHloInputPipelineOptions &opts) { StableHloInputOptions inputOpts; inputOpts.legalizeControlFlowToSCF = opts.legalizeControlFlowToSCF; - inputOpts.legalizeChloErfToStablehlo = opts.legalizeChloErfToStablehlo; + inputOpts.preserveChloErf = opts.preserveChloErf; + inputOpts.preserveChloTopK = opts.preserveChloTopK; inputOpts.disableInliner = opts.disableInliner; - inputOpts.convertChloToStablehlo = opts.convertChloToStablehlo; buildStablehloPreProcessingPipeline(pm, inputOpts); }); - PassPipelineRegistration<>("stablehlo-simplification-pipeline", - "Apply StableHLO simplification passes", - [](OpPassManager &pm) { - buildStableHloSimplificationPipeline(pm, false); - }); + PassPipelineRegistration( + "stablehlo-simplification-pipeline", + "Apply StableHLO simplification passes", + [](OpPassManager &pm, const StableHloInputPipelineOptions &opts) { + buildStableHloSimplificationPipeline( + pm, {opts.preserveChloErf, opts.preserveChloTopK}); + }); } diff --git a/mlir-tensorrt/compiler/lib/Utils/CMakeLists.txt b/mlir-tensorrt/compiler/lib/Utils/CMakeLists.txt index 8b1378917..a6b68e3bc 100644 --- a/mlir-tensorrt/compiler/lib/Utils/CMakeLists.txt +++ b/mlir-tensorrt/compiler/lib/Utils/CMakeLists.txt @@ -1 +1,8 @@ +add_mlir_tensorrt_library(MLIRTensorRTUtilsShapeInfo + ShapeInfo.cpp + + LINK_LIBS PUBLIC + MLIRIR + MLIRDialectUtils + ) diff --git a/mlir-tensorrt/compiler/lib/Utils/ShapeInfo.cpp b/mlir-tensorrt/compiler/lib/Utils/ShapeInfo.cpp new file mode 100644 index 000000000..3b469a33d --- /dev/null +++ b/mlir-tensorrt/compiler/lib/Utils/ShapeInfo.cpp @@ -0,0 +1,41 @@ +//===- ShapeInfo.h ---------------------------------------------*- C++ -*-===// +// +// SPDX-FileCopyrightText: Copyright 2024 NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +#include "mlir-tensorrt/Utils/ShapeInfo.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/IR/BuiltinTypes.h" + +using namespace mlir; + +TensorElementValue::TensorElementValue(Value value, ArrayRef coord) + : tensor(cast>(value)), + linearIndex(mlir::linearize( + coord, mlir::computeSuffixProduct(tensor.getType().getShape()))) {} + +TensorShapeDimExtent::TensorShapeDimExtent(Value value, int64_t dim) + : tensor(cast>(value)), dim(dim) { + assert(dim > 0 && dim < tensor.getType().getRank() && + "dim must be > 0 and < tensor rank"); +} + +std::optional TensorShapeDimExtent::getConstantSize() const { + if (tensor.getType().isDynamicDim(dim)) + return {}; + return tensor.getType().getDimSize(dim); +} diff --git a/mlir-tensorrt/executor/include/mlir-executor/Executor/IR/ExecutorDialect.td b/mlir-tensorrt/executor/include/mlir-executor/Executor/IR/ExecutorDialect.td index 308790202..1de15dd7e 100644 --- a/mlir-tensorrt/executor/include/mlir-executor/Executor/IR/ExecutorDialect.td +++ b/mlir-tensorrt/executor/include/mlir-executor/Executor/IR/ExecutorDialect.td @@ -58,6 +58,8 @@ def Executor_Dialect : Dialect { let useDefaultTypePrinterParser = 1; let useDefaultAttributePrinterParser = 1; + let hasRegionArgAttrVerify = 1; + let extraClassDeclaration = [{ Operation *materializeConstant(::mlir::OpBuilder &builder, ::mlir::Attribute value, @@ -85,13 +87,13 @@ def Executor_Dialect : Dialect { /// Return the name of the function arg attr that encodes /// host tensor value bounds. It should have a type `executor::ValueBoundsAttr`. static StringRef getValueBoundsAttrName() { - return "tensorrt.value_bounds"; + return "executor.value_bounds"; } /// Return the name of the function arg attr that encodes /// the shape bounds. It should have a type `executor::DimensionBoundsAttr`. static StringRef getShapeBoundsAttrName() { - return "tensorrt.shape_profile"; + return "executor.shape_profile"; } }]; // Temporary until CUDA ops are completely removed out of Executor dialect diff --git a/mlir-tensorrt/executor/lib/Executor/IR/Executor.cpp b/mlir-tensorrt/executor/lib/Executor/IR/Executor.cpp index 68d567383..ed812361b 100644 --- a/mlir-tensorrt/executor/lib/Executor/IR/Executor.cpp +++ b/mlir-tensorrt/executor/lib/Executor/IR/Executor.cpp @@ -86,23 +86,21 @@ LogicalResult executor::ValueBoundsAttr::verify( "matching types; found min type: " << min.getType() << ", max type: " << max.getType(); - if (!min.getType().getElementType().isIntOrIndex()) + if (!min.getType().getElementType().isSignlessIntOrIndex()) return emitError() - << "ValueBoundsAttr 'min' and 'max' value bounds type must " - "be either i64 or " - "an index"; + << "ValueBoundsAttr 'min' and 'max' value bounds element type must " + "be a signless integer or " + "an index type"; // Compare underlying values. - auto minV = min.getValues(); - auto maxV = max.getValues(); - for (unsigned i = 0; i < minV.size(); ++i) { - if (minV[i] < 0) - return emitError() << "ValueBoundsAttr min[" << i << "] : " << minV[i] - << " must be greater than or equal to 0"; - if (minV[i] > maxV[i]) - return emitError() << "ValueBoundsAttr min[" << i << "] : " << minV[i] + auto mins = min.getValues(); + auto maxs = max.getValues(); + for (auto [i, minV, maxV] : llvm::enumerate(mins, maxs)) { + if (minV.sgt(maxV)) + return emitError() << "ValueBoundsAttr min[" << i + << "] : " << minV.getSExtValue() << " must be less than equal to " - << "max[" << i << "] : " << maxV[i]; + << "max[" << i << "] : " << maxV.getSExtValue(); } return success(); } @@ -1616,6 +1614,65 @@ Operation *ExecutorDialect::materializeConstant(OpBuilder &builder, return builder.create(loc, type, typedAttr); } +static LogicalResult verifyValueBoundsAttribute(Operation *op, + unsigned argIndex, + executor::ValueBoundsAttr attr, + StringRef attrName) { + auto func = dyn_cast(op); + if (!func) + return op->emitError() + << attrName + << " should only be used for FunctionOpInterface argument " + "and result attributes"; + + ShapedType valuesType = attr.getMin().getType(); + + Type argType = func.getArgument(argIndex).getType(); + if (auto shapedType = dyn_cast(argType)) { + if (valuesType.getShape() != shapedType.getShape() || + (valuesType.getElementType().isIndex() && + !shapedType.getElementType().isIntOrIndex()) || + (!valuesType.getElementType().isIndex() && + shapedType.getElementType() != shapedType.getElementType())) + return op->emitError() + << attrName << " value bounds type " << valuesType + << " is not compatible with the argument type " << argType; + + return success(); + } + + if (argType.isIntOrIndexOrFloat()) { + if (attr.getMin().getType().getRank() != 0) + return op->emitError() + << attrName << " bounds of type " << valuesType + << " must be a 0-rank shaped type for scalar argument type " + << argType; + } + + // If the type is not a shaped type or scalar, then we don't do any + // validation. It may could correspond to whatever type that the memref was + // lowered into (e.g. pointer or table), so there's not much validation that + // is possible. + return success(); +} + +LogicalResult +ExecutorDialect::verifyRegionArgAttribute(Operation *op, unsigned regionIndex, + unsigned argIndex, + NamedAttribute attribute) { + if (attribute.getName() == getValueBoundsAttrName()) { + auto boundsAttr = dyn_cast(attribute.getValue()); + if (!boundsAttr) + return op->emitError() + << "expected named attribute \"" << getValueBoundsAttrName() + << "\" to be a \"#executor.value_bounds\" attribute containing " + "value bounds"; + return verifyValueBoundsAttribute(op, argIndex, boundsAttr, + attribute.getName()); + } + + return success(); +} //===----------------------------------------------------------------------===// // TableGen'd dialect definition. //===----------------------------------------------------------------------===// diff --git a/mlir-tensorrt/executor/lib/Target/Lua/TranslateToRuntimeExecutable.cpp b/mlir-tensorrt/executor/lib/Target/Lua/TranslateToRuntimeExecutable.cpp index 7e88f9d39..189f8c46a 100644 --- a/mlir-tensorrt/executor/lib/Target/Lua/TranslateToRuntimeExecutable.cpp +++ b/mlir-tensorrt/executor/lib/Target/Lua/TranslateToRuntimeExecutable.cpp @@ -288,17 +288,11 @@ translateAttribute(FBBuilder &fb, Attribute attr) { } if (auto vals = llvm::dyn_cast(attr)) { - auto elemT = vals.getMin().getElementType(); - assert(elemT == vals.getMax().getElementType()); - if (!elemT.isInteger(64)) - return emitError(UnknownLoc::get(attr.getContext())) - << "Unsupported element type " << elemT << " for attribute (" - << attr << ") in function metadata"; - + auto toI64 = [](const llvm::APInt &v) { return v.getSExtValue(); }; auto min = fb.serialize( - llvm::to_vector(vals.getMin().getValues())); + llvm::map_to_vector(vals.getMin().getValues(), toI64)); auto max = fb.serialize( - llvm::to_vector(vals.getMax().getValues())); + llvm::map_to_vector(vals.getMax().getValues(), toI64)); return std::make_pair(rt::impl::Bounds::ValueBounds, rt::impl::CreateValueBounds(fb, min, max).Union()); } diff --git a/mlir-tensorrt/executor/test/Executor/invalid.mlir b/mlir-tensorrt/executor/test/Executor/invalid.mlir index 748b844fc..f7c325358 100644 --- a/mlir-tensorrt/executor/test/Executor/invalid.mlir +++ b/mlir-tensorrt/executor/test/Executor/invalid.mlir @@ -308,3 +308,70 @@ func.func @coro_await() -> (i32) { %0:2 = executor.coro_await %coro (%c0, %c0_f32 : i32, f32) : (f32, i32) -> i32 return %0#1 : i32 } + +// ----- + +#bounds = #executor.value_bounds : tensor<1x10xi32>, max = dense<20> : tensor<1x10xi32>> + + +// expected-error @below {{executor.value_bounds value bounds type 'tensor<1x10xi32>' is not compatible with the argument type 'tensor<1x11xi32>'}} +func.func @value_bounds_shape_mismatch(%arg0: tensor<1x11xi32> {executor.value_bounds = #bounds}) { + return +} + +// ----- + +#bounds = #executor.value_bounds : tensor<1xi32>, max = dense<20> : tensor<1xi32>> + + +// expected-error @below {{executor.value_bounds value bounds type 'tensor<1xi32>' is not compatible with the argument type 'tensor'}} +func.func @value_bounds_0rank_shape_mismatch(%arg0: tensor {executor.value_bounds = #bounds}) { + return +} + +// ----- + +#bounds = #executor.value_bounds : tensor<1x10xi32>, max = dense<20> : tensor<1x10xi32>> + + +// expected-error @below {{executor.value_bounds value bounds type 'tensor<1x10xi32>' is not compatible with the argument type 'tensor<1x11xi64>'}} +func.func @value_bounds_element_type_mismatch(%arg0: tensor<1x11xi64> {executor.value_bounds = #bounds}) { + return +} + +// ----- + +#bounds = #executor.value_bounds : tensor, max = dense<20> : tensor> + +// We allow 'index' value type to be compatible with other integer types. +func.func @value_bounds_element_type_index_compat( + %arg0: tensor {executor.value_bounds = #bounds}, + %arg1: tensor {executor.value_bounds = #bounds}, + %arg2: tensor {executor.value_bounds = #bounds}) { + return +} + +// ----- + +#bounds = #executor.value_bounds : tensor<1xi32>, max = dense<20> : tensor<1xi32>> + +// expected-error @below {{executor.value_bounds bounds of type 'tensor<1xi32>' must be a 0-rank shaped type for scalar argument type 'i32'}} +func.func @value_bounds_scalar_shape_mismatch(%arg0: i32 {executor.value_bounds = #bounds}) { + return +} + +// ----- + +#bounds = #executor.value_bounds : tensor, max = dense<20> : tensor> + +func.func @value_bounds_scalar_shape_ok(%arg0: i32 {executor.value_bounds = #bounds}) { + return +} + +// ----- + +#bounds = #executor.value_bounds : tensor, max = dense<20> : tensor> + +func.func @dont_validate_bounds_to_non_shaped_or_scalar_type(%arg0: !executor.table {executor.value_bounds = #bounds}) { + return +} diff --git a/mlir-tensorrt/executor/test/Executor/roundtrip.mlir b/mlir-tensorrt/executor/test/Executor/roundtrip.mlir index c28677f8f..fedd57ac1 100644 --- a/mlir-tensorrt/executor/test/Executor/roundtrip.mlir +++ b/mlir-tensorrt/executor/test/Executor/roundtrip.mlir @@ -681,4 +681,15 @@ func.func @coro_await() -> (i32) { // CHECK: %[[v0:.+]] = executor.coro_create @coro : (f32, i32) -> i32 // CHECK: %[[status:status.*]], %[[results:.+]] = executor.coro_await %[[v0]](%[[cst]], %[[c0_i32]] : f32, i32) : (f32, i32) -> i32 // CHECK: %[[status_0:.+]], %[[results_1:.+]] = executor.coro_await %[[v0]]() : (f32, i32) -> i32 -// CHECK: return %[[results_1]] : i32 \ No newline at end of file +// CHECK: return %[[results_1]] : i32 + +// ----- + +#bounds = #executor.value_bounds : tensor<1x10xi32>, max = dense<20> : tensor<1x10xi32>> + +func.func @value_bounds(%arg0: tensor<1x10xi32> {executor.value_bounds = #bounds}) { + return +} + +// CHECK-LABEL: func.func @value_bounds +// CHECK-SAME: executor.value_bounds = #executor.value_bounds : tensor<1x10xi32>, max = dense<20> : tensor<1x10xi32>> diff --git a/mlir-tensorrt/python/mlir_tensorrt_tools/mlir_tensorrt/tools/gpu_tools.py b/mlir-tensorrt/python/mlir_tensorrt_tools/mlir_tensorrt/tools/gpu_tools.py index f578ec960..100b877be 100644 --- a/mlir-tensorrt/python/mlir_tensorrt_tools/mlir_tensorrt/tools/gpu_tools.py +++ b/mlir-tensorrt/python/mlir_tensorrt_tools/mlir_tensorrt/tools/gpu_tools.py @@ -6,7 +6,8 @@ """ from contextlib import contextmanager -from typing import List, Tuple +import time +from typing import List, Optional, Tuple import click import numpy as np @@ -73,16 +74,25 @@ def get_stats(devices: List[int]) -> Tuple[List[float], List[float], List[float] return avail_mem_gb, gpu_rates, mem_rates -def select_device(devices: List[int]) -> int: +def select_device(devices: List[int], required_memory: Optional[float] = None) -> int: """Selects the device (that is among those with the highest SM version if SM versions are not uniform) that has the most available GPU memory. """ assert len(devices) > 0 - avail_mem_gb, _, _ = get_stats(devices) - # All devices have same SM version. - # Check utilization rates. - max_mem = int(np.argmax(avail_mem_gb)) + while True: + avail_mem_gb, _, _ = get_stats(devices) + avail_mem_gb = np.asarray(avail_mem_gb) + + if required_memory and avail_mem_gb.max() * 1024.0 < required_memory: + time.sleep(1.0) + continue + + # All devices have same SM version. + # Check utilization rates. + max_mem = int(np.argmax(avail_mem_gb)) + break + return max_mem @@ -118,7 +128,13 @@ def cli(): @cli.command("pick-device") -def pick_device(): +@click.option( + "--required-memory", + help="causes the command to block until the specified amount of memory (in gigabytes) is available on some visible device", + required=False, + type=click.FLOAT, +) +def pick_device(required_memory: Optional[float]): with nvml_context() as devices: if len(devices) == 0: return diff --git a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/Target/TensorRTEncodingOpInterface/NetworkEncoder.h b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/Target/TensorRTEncodingOpInterface/NetworkEncoder.h index e351db1e9..777677d4a 100644 --- a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/Target/TensorRTEncodingOpInterface/NetworkEncoder.h +++ b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/Target/TensorRTEncodingOpInterface/NetworkEncoder.h @@ -194,11 +194,12 @@ class NvInferNetworkEncoder { /// Adds IFillLayer to the network. This switches between different /// APIs depending on the compile-time TensorRT version and whether or not /// the strongly-typed flags is enabled. - nvinfer1::ILayer *addFillLayer( - nvinfer1::DataType elementType, const nvinfer1::Dims &staticShape, - nvinfer1::ITensor *dynamicShape, nvinfer1::FillOperation fillOperation, - std::optional alpha, std::optional beta, - nvinfer1::ITensor *dynamicAlpha, nvinfer1::ITensor *dynamicBeta); + nvinfer1::ILayer * + addFillLayer(nvinfer1::DataType elementType, nvinfer1::Dims staticShape, + nvinfer1::ITensor *dynamicShape, + nvinfer1::FillOperation fillOperation, + std::optional alpha, std::optional beta, + nvinfer1::ITensor *dynamicAlpha, nvinfer1::ITensor *dynamicBeta); /// Adds a TensorRT plugin. FailureOr diff --git a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.td b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.td index 99bbad84d..2adfc1e04 100644 --- a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.td +++ b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.td @@ -143,6 +143,10 @@ def TensorRT_I4 : I<4>; class TensorRT_RankedTensorOf allowedTypes> : TensorRankOf; +// Describes a TensorRT Tensor whose rank must be greater than 0. +class TensorRT_Non0RankedTensorOf allowedTypes> + : TensorRankOf; + def TensorRT_Tensor : TensorRT_RankedTensorOf< [I1, TensorRT_I8, I32, I64, F16, BF16, F32]>; diff --git a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td index dad8043ba..10cb584e5 100644 --- a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td +++ b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td @@ -144,7 +144,7 @@ def TensorRT_CallAllocOp : TensorRT_Op<"call_alloc", [ `enqueue_alloc` operation in a later pass. }]; - let arguments = (ins + let arguments = (ins Variadic>:$inputs, SymbolRefAttr:$callee ); @@ -1037,8 +1037,13 @@ def TensorRT_LinspaceOp : TensorRT_Op<"linspace", [ slice of the result along dimension `i`. Then the contents of `s` are given by: - `s[j] = start + step[i] * j` if `step` is a tensor or - `s[j] = start + step * j` if `step` is a scalar. + - If `step` is a tensor: + + `s[coord...] = start + dot(step, [coord...]) + + - If `step` is a static scalar: + + `s[j] = start + step * j` The inputs `start` and `step` must either both be given by tensor-typed SSA values or both be given by static attributes. One @@ -1052,12 +1057,12 @@ def TensorRT_LinspaceOp : TensorRT_Op<"linspace", [ let arguments = (ins Optional<1DTensorOf<[I32]>>:$shape, - Optional<0DTensorOf<[I32, F32]>>:$start, - Optional<1DTensorOf<[I32, F32]>>:$step, + Optional<0DTensorOf<[I32, I64, F32]>>:$start, + Optional<1DTensorOf<[I32, I64, F32]>>:$step, OptionalAttr:$static_start, OptionalAttr:$static_step ); - let results = (outs TensorRT_RankedTensorOf<[I32, I64, F16, F32]>:$result); + let results = (outs TensorRT_RankedTensorOf<[I32, I64, F32]>:$result); let assemblyFormat = [{ attr-dict `[` ($start^ `:` type($start)) : ($static_start)? `]` @@ -3238,7 +3243,7 @@ def TensorRT_DequantizeOp : TensorRT_Op<"dequantize", }]; let arguments = (ins - TensorRT_RankedTensorOf<[TensorRT_I8, TensorRT_F8, TensorRT_I4]>:$input, + TensorRT_Non0RankedTensorOf<[TensorRT_I8, TensorRT_F8, TensorRT_I4]>:$input, TensorRankOf<[F32, F16, BF16], [0, 1, 2]>:$scale, OptionalAttr:$axis ); diff --git a/mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp b/mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp index 6ae989205..da5f73408 100644 --- a/mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp +++ b/mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp @@ -341,6 +341,7 @@ nvinfer1::IFillLayer *populateFillLayerParameters( nvinfer1::ITensor *dynamicShape, std::optional alpha, std::optional beta, nvinfer1::ITensor *dynamicAlpha, nvinfer1::ITensor *dynamicBeta) { + assert(layer != nullptr && "expected valid layer"); if (dynamicShape != nullptr) layer->setInput(0, *dynamicShape); else @@ -360,18 +361,25 @@ nvinfer1::IFillLayer *populateFillLayerParameters( } nvinfer1::ILayer *NvInferNetworkEncoder::addFillLayer( - nvinfer1::DataType elementType, const nvinfer1::Dims &staticShape, + nvinfer1::DataType elementType, nvinfer1::Dims staticShape, nvinfer1::ITensor *dynamicShape, nvinfer1::FillOperation fillOperation, std::optional alpha, std::optional beta, nvinfer1::ITensor *dynamicAlpha, nvinfer1::ITensor *dynamicBeta) { #if MLIR_TRT_COMPILE_TIME_TENSORRT_VERSION_GTE(9, 1, 0) - nvinfer1::IFillLayer *layer{nullptr}; - if (dynamicShape != nullptr) { - nvinfer1::Dims emptyDims{}; - layer = network->addFill(emptyDims, fillOperation, elementType); - } else { - layer = network->addFill(staticShape, fillOperation, elementType); + if (dynamicShape) { + // Starting in TensorRT 10.5, TensorRT will give an error if we don't give a + // fully valid static result shape, even if we are about to override it with + // a dynamic shape. + nvinfer1::Dims shapeDims = dynamicShape->getDimensions(); + assert(shapeDims.nbDims == 1 && shapeDims.d[0] > 0 && + "invalid shape tensor shape"); + staticShape.nbDims = shapeDims.d[0]; + for (int32_t i = 0; i < shapeDims.nbDims; i++) + staticShape.d[i] = 1; } + nvinfer1::IFillLayer *layer = + network->addFill(staticShape, fillOperation, elementType); + assert(layer != nullptr && "expected valid layer"); return populateFillLayerParameters(layer, staticShape, dynamicShape, alpha, beta, dynamicAlpha, dynamicBeta); #else @@ -530,6 +538,11 @@ static void serializeSplatElements(DenseIntOrFPElementsAttr values, values.getNumElements(), values.getSplatValue()); return; } + if (rtt.getElementType().isInteger(64)) { + std::fill_n(reinterpret_cast(data.data()), + values.getNumElements(), values.getSplatValue()); + return; + } if (rtt.getElementType().isInteger(8)) { std::fill_n(reinterpret_cast(data.data()), values.getNumElements(), values.getSplatValue()); @@ -624,6 +637,9 @@ NvInferNetworkEncoder::getNvInferWeights(ElementsAttr values) { // We also handle elided attributes by generating weights filled with zeros. if (mlir::getElidedResourceElementsAttr(values)) { std::memset(reinterpret_cast(data.data()), 0, data.size()); + } else if (rtt.getElementType().isInteger(64)) { + llvm::copy(values.getValues(), + reinterpret_cast(data.data())); } else if (rtt.getElementType().isInteger(32)) { llvm::copy(values.getValues(), reinterpret_cast(data.data())); diff --git a/mlir-tensorrt/tensorrt/lib/TensorRT/IR/TensorKindOpInterfaceImpl.cpp b/mlir-tensorrt/tensorrt/lib/TensorRT/IR/TensorKindOpInterfaceImpl.cpp index 79425f781..9ebbabe57 100644 --- a/mlir-tensorrt/tensorrt/lib/TensorRT/IR/TensorKindOpInterfaceImpl.cpp +++ b/mlir-tensorrt/tensorrt/lib/TensorRT/IR/TensorKindOpInterfaceImpl.cpp @@ -124,9 +124,14 @@ struct LinspaceTensorKindOpInterfaceImpl Operation *op, ArrayRef operands, ArrayRef results, llvm::function_ref setOperandKind) const { + assert(results.size() == 1 && "expected single result lattice"); auto linspaceOp = cast(op); if (linspaceOp.getShape()) setOperandKind(linspaceOp.getShapeMutable()[0], TensorKind::Host); + + if (!results[0] || results[0]->getValue().isUninitialized()) + return; + if (linspaceOp.getStart()) setOperandKind(linspaceOp.getStartMutable()[0], results[0]->getValue().getKind()); diff --git a/mlir-tensorrt/tensorrt/lib/TensorRT/IR/TensorRT.cpp b/mlir-tensorrt/tensorrt/lib/TensorRT/IR/TensorRT.cpp index 37bfd76eb..bbcb86330 100644 --- a/mlir-tensorrt/tensorrt/lib/TensorRT/IR/TensorRT.cpp +++ b/mlir-tensorrt/tensorrt/lib/TensorRT/IR/TensorRT.cpp @@ -1147,22 +1147,28 @@ void tensorrt::SliceOp::build(OpBuilder &odsBuilder, OperationState &odsState, toArrayAttr(stride), sliceMode, fill); } -void tensorrt::SliceOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add(+[](SliceOp op, PatternRewriter &rewriter) { +namespace { +/// Move size|start dynamic arguments to static attributes if possible. +struct SliceDynamicParameterToStaticPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(SliceOp op, + PatternRewriter &rewriter) const override { + // If the dynamic size parameter is foldable, fold to static parameter. DenseIntElementsAttr value; if (op.getSize() && matchPattern(op.getSize(), m_Constant(&value))) { rewriter.replaceOpWithNewOp( - op, op.getInput(), op.getStartAsOpFoldResult(), + op, op.getType(), op.getInput(), op.getStartAsOpFoldResult(), rewriter.getDenseI32ArrayAttr( llvm::to_vector(value.getValues())), op.getStrideAsOpFoldResult(), op.getMode(), op.getFill()); return success(); } + // If the dynamic start parameter is foldable, fold to static parameter. if (op.getStart() && matchPattern(op.getStart(), m_Constant(&value))) { rewriter.replaceOpWithNewOp( - op, op.getInput(), + op, op.getType(), op.getInput(), rewriter.getDenseI32ArrayAttr( llvm::to_vector(value.getValues())), op.getSizeAsOpFoldResult(), op.getStrideAsOpFoldResult(), @@ -1171,7 +1177,13 @@ void tensorrt::SliceOp::getCanonicalizationPatterns(RewritePatternSet &results, } return failure(); - }); + } +}; +} // namespace + +void tensorrt::SliceOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); } //===----------------------------------------------------------------------===// diff --git a/mlir-tensorrt/tensorrt/lib/TensorRT/IR/Verification.cpp b/mlir-tensorrt/tensorrt/lib/TensorRT/IR/Verification.cpp index 9e34afd2e..48e93cd22 100644 --- a/mlir-tensorrt/tensorrt/lib/TensorRT/IR/Verification.cpp +++ b/mlir-tensorrt/tensorrt/lib/TensorRT/IR/Verification.cpp @@ -24,6 +24,7 @@ #include "mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.h" #include "mlir-tensorrt-dialect/Utils/ShapeUtils.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Quant/QuantTypes.h" #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/IR/BuiltinTypes.h" #include "llvm/ADT/STLExtras.h" @@ -1161,22 +1162,23 @@ LogicalResult OpaquePluginOp::verifyRegions() { // SelectOp //===----------------------------------------------------------------------===// -LogicalResult tensorrt::SelectOp::verify() { - // Select impl start +LogicalResult tensorrt::SelectOp::verify() { return success(); } - // Select impl end - return success(); -} // LogicalResult tensorrt::SelectOp::verify() +//===----------------------------------------------------------------------===// +// AssertOp +//===----------------------------------------------------------------------===// LogicalResult tensorrt::AssertionOp::verify() { - // Assertion impl start const int64_t conditionRank = getCondition().getType().getRank(); if (conditionRank > 1) return emitOpError("expected condition to be of rank 0 or 1"); - // Assertion impl end return success(); -} // LogicalResult tensorrt::AssertionOp::verify() +} + +//===----------------------------------------------------------------------===// +// DequantizeOp +//===----------------------------------------------------------------------===// LogicalResult tensorrt::DequantizeOp::verify() { auto inputType = getInput().getType(); @@ -1219,8 +1221,31 @@ LogicalResult tensorrt::DequantizeOp::verify() { "tensor."); } } + + // Sub-byte input types must have even final dimension. As of TensorRT 10.5, + // the only sub-byte element type supported is i4, and the typical use case is + // to dequantize i4 constants (weights). I4 weights must be packed, and while + // this could allow a range of valid constants shapes like '4x3xi4', TensorRT + // effectively makes an additional requirement that the last dimension must be + // vectorizable to `vector<2xi4>`. + Type quantizedElementType = inputType.getElementType(); + if (auto quantType = dyn_cast(quantizedElementType)) + quantizedElementType = quantType.getStorageType(); + if (!quantizedElementType.isIntOrFloat()) + return emitOpError("expected element type to be int or float type"); + if (quantizedElementType.getIntOrFloatBitWidth() < 8 && + inputType.getDimSize(inputRank - 1) % 2 == 1) + return emitOpError( + "input tensor with sub-byte element type must have even final " + "dimension, but input tensor has final dimension of size ") + << inputType.getDimSize(inputRank - 1); + return success(); -} // LogicalResult tensorrt::DequantizeOp::verify() +} + +//===----------------------------------------------------------------------===// +// ScatterOp +//===----------------------------------------------------------------------===// LogicalResult tensorrt::ScatterOp::verify() { auto inputDataType = getData().getType(); diff --git a/mlir-tensorrt/tensorrt/lib/Utils/ShapeUtils.cpp b/mlir-tensorrt/tensorrt/lib/Utils/ShapeUtils.cpp index 64822cbac..2f800f95d 100644 --- a/mlir-tensorrt/tensorrt/lib/Utils/ShapeUtils.cpp +++ b/mlir-tensorrt/tensorrt/lib/Utils/ShapeUtils.cpp @@ -116,11 +116,14 @@ tensorrt::getBroadcastedShape(ArrayRef> shapes) { if (allEqual) return dimSizes.front(); - // Some dims are '1', all other dims are equal to another fixed number. + // Some dims are '1', all other dims are equal to another fixed number or + // dynamic. std::optional nonUnitSize{}; for (int64_t dimSize : dimSizes) { if (dimSize == 1) continue; + if (ShapedType::isDynamic(dimSize)) + continue; if (nonUnitSize && dimSize == *nonUnitSize) continue; if (nonUnitSize && dimSize != *nonUnitSize) diff --git a/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/canonicalize.mlir b/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/canonicalize.mlir index ac629a05f..16b608ba2 100644 --- a/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/canonicalize.mlir +++ b/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/canonicalize.mlir @@ -1560,3 +1560,44 @@ func.func @resize_no_absorb_invalid_cast(%arg0: tensor<1x144x?x?xf32>) -> tensor // CHECK: %[[cast:.+]] = tensor.cast %[[arg0]] : tensor<1x144x?x?xf32> to tensor<1x144x7x7xf32> // CHECK: %[[v0:.+]] = tensorrt.resize_cubic {coordinateTransformation = #tensorrt.resize_coordinate_transformation, cubicCoeff = -7.500000e-01 : f32, selectorForSinglePixel = #tensorrt.resize_selector} %[[cast]], %[[cst_i32]] : (tensor<1x144x7x7xf32>, tensor<4xi32>) -> tensor // CHECK-NEXT: return %[[v0]] + +// ----- + +func.func @slice_canon_dynamic_size(%arg0: tensor) -> tensor { + %size = tensorrt.constant dense<[2, 2]> : tensor<2xi32> + %0 = tensorrt.slice %arg0 [0, 0][%size: tensor<2xi32>][1, 1] : tensor to tensor + return %0 : tensor +} + +// CHECK-LABEL: func.func @slice_canon_dynamic_size +// CHECK-SAME: (%[[arg0:.+]]: tensor) +// CHECK-NEXT: %[[v0:.+]] = tensorrt.slice %[[arg0]][0, 0][2, 2][1, 1] : tensor to tensor +// CHECK-NEXT: return %[[v0]] : tensor + +// ----- + +func.func @slice_canon_static_type_dynamic_size(%arg0: tensor) -> tensor<2x2xf32> { + %size = tensorrt.constant dense<[2, 2]> : tensor<2xi32> + %0 = tensorrt.slice %arg0 [0, 0][%size: tensor<2xi32>][1, 1] : tensor to tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// CHECK-LABEL: func.func @slice_canon_static_type_dynamic_size +// CHECK-SAME: (%[[arg0:.+]]: tensor) -> tensor<2x2xf32> { +// CHECK: %[[v0:.+]] = tensorrt.slice %[[arg0]][0, 0][2, 2][1, 1] : tensor to tensor<2x2xf32> +// CHECK: return %[[v0]] : tensor<2x2xf32> + +// ----- + +func.func @slice_canon_dynamic_start(%arg0: tensor, %arg1: tensor<2xi32>) -> tensor { + %start = tensorrt.constant dense<[0, 1]> : tensor<2xi32> + %0 = tensorrt.slice %arg0 [%start: tensor<2xi32>][%arg1: tensor<2xi32>][1, 1] : tensor to tensor + return %0 : tensor +} + +// CHECK-LABEL: func.func @slice_canon_dynamic_start +// CHECK-SAME: (%[[arg0:.+]]: tensor, %[[arg1:.+]]: tensor<2xi32>) +// CHECK: %[[v0:.+]] = tensorrt.slice %[[arg0]][0, 1][%[[arg1:.+]]: tensor<2xi32>][1, 1] : tensor to tensor +// CHECK: return %[[v0]] : tensor + + diff --git a/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/invalid.mlir b/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/invalid.mlir index beb8c6461..333a26b08 100644 --- a/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/invalid.mlir +++ b/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/invalid.mlir @@ -1713,6 +1713,26 @@ func.func @trt_dequantize(%arg0: tensor<10x10xi8>, %arg1: tensor<2xf32>) -> tens // ----- +func.func @trt_dequanize_non_zero_input_rank(%arg0: tensor) -> tensor { + %k = tensorrt.constant dense<2> : tensor + %scale = tensorrt.constant dense<1.0> : tensor + // expected-error @below {{'tensorrt.dequantize' op operand #0 must be 1D/2D/3D/4D/5D/6D/7D/8D tensor of allowed TensorRT tensor i8 element types or f8E4M3FN type or 4-bit signless integer values, but got 'tensor'}} + %dq_k = tensorrt.dequantize in (%k: tensor) scale (%scale: tensor) -> tensor + return %dq_k : tensor +} + +// ----- + +func.func @trt_subbyte_dequantize_even_final_dim(%arg0: tensor<4x3xf16>) -> tensor<4x3xf16> { + %k = tensorrt.constant dense<2> : tensor<4x3xi4> + %scale = tensorrt.constant dense<1.0> : tensor + // expected-error @below {{'tensorrt.dequantize' op input tensor with sub-byte element type must have even final dimension, but input tensor has final dimension of size 3}} + %dq_k = tensorrt.dequantize in (%k: tensor<4x3xi4>) scale (%scale: tensor) -> tensor<4x3xf16> + return %dq_k : tensor<4x3xf16> +} + +// ----- + func.func @trt_matrix_multiply_trans_vec(%arg0: tensor<1x1x1x50x10xf32>, %arg1: tensor<1x4x240x50xf32>) -> tensor<1x1x240x10xf32> { // expected-error @below {{'tensorrt.matrix_multiply' op inferred type(s) 'tensor<1x4x240x10xf32>' are incompatible with return type(s) of operation 'tensor<1x1x240x10xf32>'}} @@ -2460,13 +2480,13 @@ func.func @trt_fill_linspace_i32(%arg0: tensor, %arg1: tensor<4xi32>) -> te // ----- -func.func @trt_fill_linspace_dynamic_f16() -> tensor<1024x1024xf16> { +func.func @trt_linspace_mismatched_types() -> tensor<1024x1024xf32> { %shape = tensorrt.constant dense<[1024, 1024]>:tensor<2xi32> - %start = tensorrt.constant dense<0.0>:tensor - %step = tensorrt.constant dense<[1.0,1.0]>:tensor<2xf16> - // expected-error @below {{'tensorrt.linspace' op operand #1 must be 0D tensor of 32-bit signless integer or 32-bit float values, but got 'tensor'}} - %0 = tensorrt.linspace [%start:tensor][%shape:tensor<2xi32>][%step:tensor<2xf16>] : tensor<1024x1024xf16> - return %0 : tensor<1024x1024xf16> + %start = tensorrt.constant dense<0> : tensor + %step = tensorrt.constant dense<[1.0,1.0]>:tensor<2xf32> + // expected-error @below {{'tensorrt.linspace' op start and step tensor types must have the same element type}} + %0 = tensorrt.linspace [%start: tensor][%shape: tensor<2xi32>][%step: tensor<2xf32>] : tensor<1024x1024xf32> + return %0 : tensor<1024x1024xf32> } // ----- diff --git a/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/roundtrip.mlir b/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/roundtrip.mlir index f02c3ad96..90a742e1c 100644 --- a/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/roundtrip.mlir +++ b/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/roundtrip.mlir @@ -763,6 +763,7 @@ func.func @trt_concatenation_dynamic_2(%arg0: tensor, %arg1: tensor // CHECK-LABEL: @trt_concatenation_dynamic_2 // CHECK: %[[v0:.+]] = tensorrt.concatenation {axis = 1 : i32} ins(%{{.+}}, %{{.+}} : tensor, tensor) -> tensor // CHECK: return %[[v0]] : tensor + // ----- func.func @trt_select(%arg0: tensor<10x10xi1>, %arg1: tensor<1x10xf32>, %arg2: tensor<10x1xf32>) -> tensor<10x10xf32> { @@ -787,6 +788,17 @@ func.func @trt_select2(%arg0: tensor<10x10xi1>, %arg1: tensor<1x1xf32>, %arg2: t // ----- +func.func @trt_select3(%arg0: tensor<1x10xi1>, %arg1: tensor<10x10xf32>, %arg2: tensor) -> tensor { + %0 = tensorrt.select ins(%arg0, %arg1, %arg2: tensor<1x10xi1>, tensor<10x10xf32>, tensor) + -> tensor + return %0 : tensor +} + +// CHECK-LABEL: @trt_select3 +// CHECK-NEXT: tensorrt.select + +// ----- + func.func @trt_softmax(%arg0 : tensor<10x10x10xf32>) -> tensor<10x10x10xf32> { %0 = tensorrt.softmax {axis = 2 : i64} %arg0 : tensor<10x10x10xf32> return %0 : tensor<10x10x10xf32> diff --git a/mlir-tensorrt/tensorrt/test/Target/TensorRT/TRT10/convolution.mlir b/mlir-tensorrt/tensorrt/test/Target/TensorRT/TRT10/convolution.mlir index 574e76def..2bf4bb94b 100644 --- a/mlir-tensorrt/tensorrt/test/Target/TensorRT/TRT10/convolution.mlir +++ b/mlir-tensorrt/tensorrt/test/Target/TensorRT/TRT10/convolution.mlir @@ -32,16 +32,14 @@ func.func @trt_2d_bf16_convolution(%arg0: tensor<1x32x128x128xbf16>) -> tensor<1 // CHECK-SAME: tensorrt.engine func.func @trt_2d_int4_convolution(%arg0: tensor<1x32x128x128xf16>) -> tensor<1x64x128x128xf16> { - %k = tensorrt.constant dense<2> : tensor<64x32x3x3xi4> + %k = tensorrt.constant dense<2> : tensor<64x32x4x4xi4> %scale = tensorrt.constant dense<1.0> : tensor - %dq_k = tensorrt.dequantize in (%k: tensor<64x32x3x3xi4>) scale (%scale: tensor) -> tensor<64x32x3x3xf16> + %dq_k = tensorrt.dequantize in (%k: tensor<64x32x4x4xi4>) scale (%scale: tensor) -> tensor<64x32x4x4xf16> %0 = tensorrt.convolution { pre_padding = array, - post_padding = array, + post_padding = array, stride = array - } in (%arg0 : tensor<1x32x128x128xf16>) kernel(%dq_k: tensor<64x32x3x3xf16>) -> tensor<1x64x128x128xf16> + } in (%arg0 : tensor<1x32x128x128xf16>) kernel(%dq_k: tensor<64x32x4x4xf16>) -> tensor<1x64x128x128xf16> return %0 : tensor<1x64x128x128xf16> } -// CHECK-LABEL: @trt_2d_int4_convolution -// CHECK-SAME: tensorrt.engine diff --git a/mlir-tensorrt/tensorrt/test/Target/TensorRT/gather.mlir b/mlir-tensorrt/tensorrt/test/Target/TensorRT/gather.mlir index 4a4fb2d49..e4d098ad7 100644 --- a/mlir-tensorrt/tensorrt/test/Target/TensorRT/gather.mlir +++ b/mlir-tensorrt/tensorrt/test/Target/TensorRT/gather.mlir @@ -29,6 +29,20 @@ func.func @trt_gather_default1(%arg0: tensor<10x20x30xf32>, %arg1: tensor<2x5xi3 // CHECK-LABEL: @trt_gather_default1 // CHECK-SAME: tensorrt.engine +func.func @trt_gather_default_dynamic(%arg0: tensor<10x20x30xf32>, + %arg1: tensor {tensorrt.shape_profile = #tensorrt.shape_profile}, + %arg2: tensor<10x?x5x30xf32> {tensorrt.shape_profile = #tensorrt.shape_profile}) + -> tensor<10x?x5x30xf32> { + %0 = tensorrt.gather { + axis = 1 : i64 + } ins(%arg0, %arg1 : tensor<10x20x30xf32>, tensor) -> tensor<10x?x5x30xf32> + %1 = tensorrt.element_wise (%0, %arg2 : tensor<10x?x5x30xf32>, tensor<10x?x5x30xf32>) -> tensor<10x?x5x30xf32> + return %1 : tensor<10x?x5x30xf32> +} + +// CHECK-LABEL: @trt_gather_default_dynamic +// CHECK-SAME: tensorrt.engine + func.func @trt_gather_default_i32(%arg0: tensor<10x20x30xi32>, %arg1: tensor<2x5xi32>, %arg2: tensor<10x2x5x30xi32>) -> tensor<10x2x5x30xi32> { %0 = tensorrt.gather { diff --git a/mlir-tensorrt/tensorrt/test/Target/TensorRT/linspace.mlir b/mlir-tensorrt/tensorrt/test/Target/TensorRT/linspace.mlir index 0da60bec1..f4abdb76c 100644 --- a/mlir-tensorrt/tensorrt/test/Target/TensorRT/linspace.mlir +++ b/mlir-tensorrt/tensorrt/test/Target/TensorRT/linspace.mlir @@ -8,6 +8,8 @@ func.func @trt_fill_linspace() -> tensor<1024xf32> { return %0 : tensor<1024xf32> } + + // CHECK-LABEL: @trt_fill_linspace_i32 // CHECK-SAME: tensorrt.engine func.func @trt_fill_linspace_i32() -> tensor<1024xi32> { @@ -15,6 +17,7 @@ func.func @trt_fill_linspace_i32() -> tensor<1024xi32> { return %0 : tensor<1024xi32> } + // CHECK-LABEL: @trt_fill_linspace_dynamic // CHECK-SAME: tensorrt.engine func.func @trt_fill_linspace_dynamic() -> tensor<1024x1024xf32> { @@ -25,13 +28,39 @@ func.func @trt_fill_linspace_dynamic() -> tensor<1024x1024xf32> { return %0 : tensor<1024x1024xf32> } -func.func @trt_fill_linspace_dynamic_dim(%arg0: tensor {tensorrt.shape_profile = #tensorrt.shape_profile}) -> tensor { - %cst_i32 = tensorrt.constant dense<0> : tensor<1xi32> - %0 = tensorrt.shape %arg0 : tensor -> tensor<2xi32> - %1 = tensorrt.gather {axis = 0 : i64} ins(%0, %cst_i32 : tensor<2xi32>, tensor<1xi32>) -> tensor<1xi32> - %2 = tensorrt.linspace[0.000000e+00] [%1 : tensor<1xi32>] [1.000000e+00] : tensor - %3 = tensorrt.cast %2 : tensor to tensor - return %3 : tensor +// CHECK-LABEL: @dynamic_nd_iota_1 +// CHECK-SAME: tensorrt.engine +func.func @dynamic_nd_iota_1(%arg0: tensor<2xi32> { + tensorrt.value_bounds = #tensorrt.shape_profile, + tensorrt.host_tensor +}) -> tensor { + %cst_i32 = tensorrt.constant dense<0> : tensor + %cst_i32_0 = tensorrt.constant dense<[0, 1]> : tensor<2xi32> + %0 = tensorrt.linspace[%cst_i32 : tensor] [%arg0 : tensor<2xi32>] [%cst_i32_0 : tensor<2xi32>] : tensor + return %0 : tensor } -// CHECK-LABEL: @trt_fill_linspace_dynamic_dim -// CHECK-SAME: tensorrt.engine \ No newline at end of file + +// CHECK-LABEL: @dynamic_nd_iota_2 +// CHECK-SAME: tensorrt.engine +func.func @dynamic_nd_iota_2(%arg0: tensor<2xi32> { + tensorrt.value_bounds = #tensorrt.shape_profile, + tensorrt.host_tensor +}) -> tensor { + %cst_i32 = tensorrt.constant dense<0> : tensor + %cst_i32_0 = tensorrt.constant dense<[1, 0]> : tensor<2xi32> + %0 = tensorrt.linspace[%cst_i32 : tensor] [%arg0 : tensor<2xi32>] [%cst_i32_0 : tensor<2xi32>] : tensor + return %0 : tensor +} + +// CHECK-LABEL: @dynamic_nd_iota_3 +// CHECK-SAME: tensorrt.engine +func.func @dynamic_nd_iota_3(%arg0: tensor<2xi32> { + tensorrt.value_bounds = #tensorrt.shape_profile, + tensorrt.host_tensor +}) -> tensor { + %cst_f16 = tensorrt.constant dense<0> : tensor + %cst_f16_0 = tensorrt.constant dense<[0, 1]> : tensor<2xi64> + %0 = tensorrt.linspace[%cst_f16 : tensor] [%arg0 : tensor<2xi32>] [%cst_f16_0 : tensor<2xi64>] : tensor + return %0 : tensor +} + diff --git a/mlir-tensorrt/test/Conversion/ChloToStablehloExt/chlo-to-stablehlo-ext.mlir b/mlir-tensorrt/test/Conversion/ChloToStablehloExt/chlo-to-stablehlo-ext.mlir new file mode 100644 index 000000000..b8773d028 --- /dev/null +++ b/mlir-tensorrt/test/Conversion/ChloToStablehloExt/chlo-to-stablehlo-ext.mlir @@ -0,0 +1,42 @@ +// RUN: mlir-tensorrt-opt %s -convert-chlo-to-stablehlo-ext="preserve-erf=true preserve-topk=true" -split-input-file | FileCheck %s --check-prefix=CHECK +// RUN: mlir-tensorrt-opt %s -convert-chlo-to-stablehlo-ext="preserve-erf=false preserve-topk=false" -split-input-file | FileCheck %s --check-prefix=LOWERALL + +func.func @erf_inv(%arg0 : tensor<4xf32>) -> tensor<4xf32> { + %0 = chlo.erf_inv %arg0 : tensor<4xf32> -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// We don't need to check the full lowering since that is tested upstream. +// We just test some basic logic and options. + +// CHECK-LABEL: func.func @erf_inv +// CHECK-NOT: chlo.erf_inv + +// LOWERALL-LABEL: func.func @erf_inv +// LOWERALL-NOT: chlo.erf_inv + +// ----- + +func.func @erf(%arg0 : tensor<4xf32>) -> tensor<4xf32> { + %0 = chlo.erf %arg0 : tensor<4xf32> -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: func.func @erf +// CHECK-NEXT: chlo.erf + +// LOWERALL-LABEL: func.func @erf +// LOWERALL-NOT: chlo.erf + +// ----- + +func.func @top_k(%arg0: tensor<1x50257xf32>) -> (tensor<1x50xf32>, tensor<1x50xi32>) { + %values, %indices = chlo.top_k(%arg0, k = 50) : tensor<1x50257xf32> -> (tensor<1x50xf32>, tensor<1x50xi32>) + return %values, %indices : tensor<1x50xf32>, tensor<1x50xi32> +} + +// CHECK-LABEL: func.func @top_k +// CHECK-NEXT: chlo.top_k + +// LOWERALL-LABEL: func.func @top_k +// LOWERALL-NOT: chlo.top_k diff --git a/mlir-tensorrt/test/Conversion/ChloToStablehloExt/lit.local.cfg b/mlir-tensorrt/test/Conversion/ChloToStablehloExt/lit.local.cfg new file mode 100644 index 000000000..1d783b418 --- /dev/null +++ b/mlir-tensorrt/test/Conversion/ChloToStablehloExt/lit.local.cfg @@ -0,0 +1,2 @@ +if not config.enable_hlo: + config.unsupported = True diff --git a/mlir-tensorrt/test/Conversion/PlanToExecutor/plan-to-executor.mlir b/mlir-tensorrt/test/Conversion/PlanToExecutor/plan-to-executor.mlir index df0b291ad..1bd7ac281 100644 --- a/mlir-tensorrt/test/Conversion/PlanToExecutor/plan-to-executor.mlir +++ b/mlir-tensorrt/test/Conversion/PlanToExecutor/plan-to-executor.mlir @@ -82,18 +82,18 @@ func.func @convert_extract(%arg0: tensor<2xf32, #plan.memory_space>) -> f3 // ----- -func.func @bounds_attr_conversion(%arg0: tensor {tensorrt.shape_profile = #tensorrt.shape_profile}, - %arg1: tensor<2xi32> {tensorrt.value_bounds = #tensorrt.shape_profile}) - -> (tensor {tensorrt.shape_profile = #plan.bounds}, - tensor<2xi32> {tensorrt.value_bounds = #plan.bounds : tensor<1xi64>, dense<10> : tensor<1xi64>>}) { +func.func @bounds_attr_conversion(%arg0: tensor {plan.shape_profile = #plan.bounds}, + %arg1: tensor<2xi32> {plan.value_bounds = #plan.bounds:tensor<2xi32>,dense<[10, 10]>:tensor<2xi32>>}) + -> (tensor {plan.shape_profile = #plan.bounds}, + tensor<2xi32> {plan.value_bounds = #plan.bounds : tensor<1xi32>, dense<10> : tensor<1xi32>>}) { return %arg0, %arg1 : tensor, tensor<2xi32> } // CHECK-LABEL: func.func @bounds_attr_conversion -// CHECK-SAME: tensor {tensorrt.shape_profile = #executor.dim_bounds} -// CHECK-SAME: tensor<2xi32> {tensorrt.value_bounds = #executor.value_bounds : tensor<2xi64>, max = dense<10> : tensor<2xi64>>}) -// CHECK-SAME: (tensor {tensorrt.shape_profile = #executor.dim_bounds} -// CHECK-SAME: tensor<2xi32> {tensorrt.value_bounds = #executor.value_bounds : tensor<1xi64>, max = dense<10> : tensor<1xi64>>}) +// CHECK-SAME: tensor {executor.shape_profile = #executor.dim_bounds} +// CHECK-SAME: tensor<2xi32> {executor.value_bounds = #executor.value_bounds : tensor<2xi32>, max = dense<10> : tensor<2xi32>>}) +// CHECK-SAME: (tensor {executor.shape_profile = #executor.dim_bounds} +// CHECK-SAME: tensor<2xi32> {executor.value_bounds = #executor.value_bounds : tensor<1xi32>, max = dense<10> : tensor<1xi32>>}) // ----- diff --git a/mlir-tensorrt/test/Conversion/StablehloToScf/lit.local.cfg b/mlir-tensorrt/test/Conversion/StablehloToScf/lit.local.cfg new file mode 100644 index 000000000..1d783b418 --- /dev/null +++ b/mlir-tensorrt/test/Conversion/StablehloToScf/lit.local.cfg @@ -0,0 +1,2 @@ +if not config.enable_hlo: + config.unsupported = True diff --git a/mlir-tensorrt/test/Conversion/StablehloToTensorRT/stablehlo-conv.mlir b/mlir-tensorrt/test/Conversion/StablehloToTensorRT/stablehlo-conv.mlir index 9bb876798..73dffa4a8 100644 --- a/mlir-tensorrt/test/Conversion/StablehloToTensorRT/stablehlo-conv.mlir +++ b/mlir-tensorrt/test/Conversion/StablehloToTensorRT/stablehlo-conv.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-tensorrt-opt -split-input-file -tensorrt-stablehlo-input-preprocessing -convert-stablehlo-to-tensorrt %s | FileCheck %s +// RUN: mlir-tensorrt-opt -split-input-file -stablehlo-ext-canonicalize-convolution -convert-stablehlo-to-tensorrt %s | FileCheck %s func.func @conv2d_nhwc_rsck_no_padding_dilated( %arg0: tensor<1x32x64x2xf32>, %arg1: tensor<3x3x2x128xf32>) diff --git a/mlir-tensorrt/test/Conversion/StablehloToTensorRT/stablehlo-gather.mlir b/mlir-tensorrt/test/Conversion/StablehloToTensorRT/stablehlo-gather.mlir index 2ee8101fb..15572c187 100644 --- a/mlir-tensorrt/test/Conversion/StablehloToTensorRT/stablehlo-gather.mlir +++ b/mlir-tensorrt/test/Conversion/StablehloToTensorRT/stablehlo-gather.mlir @@ -159,3 +159,76 @@ func.func @gather_negative(%arg0: tensor<1x1x64x12xf16> , %arg1: tensor<1xi32>) } : (tensor<1x1x64x12xf16>, tensor<1xi32>) -> tensor<1x1x64x4xf16> return %2 : tensor<1x1x64x4xf16> } + +// ----- + +func.func @simple_gather_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { + %c1 = stablehlo.constant dense<1> : tensor<1xi32> + %c256 = stablehlo.constant dense<256> : tensor<1xi32> + %dim = stablehlo.get_dimension_size %arg0, dim = 1 : (tensor) -> tensor + %dim.1 = stablehlo.reshape %dim : (tensor) -> tensor<1xi32> + %shape = stablehlo.concatenate %c1, %dim.1, %c256, %c256, dim = 0 : + (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + %0 = "stablehlo.dynamic_gather"(%arg0, %arg1, %shape) { + dimension_numbers = #stablehlo.gather< + offset_dims = [1, 2, 3], + collapsed_slice_dims = [0], + start_index_map = [0], + index_vector_dim = 1>, + indices_are_sorted = false, slice_sizes = array + } : (tensor, tensor, tensor<4xi32>) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func.func @simple_gather_dynamic +// CHECK-SAME: (%[[arg0:.+]]: tensor, %[[arg1:.+]]: tensor) +// CHECK-DAG: %[[v5:.+]] = tensorrt.gather {axis = 0 : i64} ins(%[[arg0]], %[[arg1]] : tensor, tensor) -> tensor +// CHECK-DAG: return %[[v5]] : tensor + +// ----- + +func.func @negative_gather_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { + %c1 = stablehlo.constant dense<1> : tensor<1xi32> + %c256 = stablehlo.constant dense<256> : tensor<1xi32> + // Wrong dimension index, should be dim = 1. + %dim = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor) -> tensor + %dim.1 = stablehlo.reshape %dim : (tensor) -> tensor<1xi32> + %shape = stablehlo.concatenate %c1, %dim.1, %c256, %c256, dim = 0 : + (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + %0 = "stablehlo.dynamic_gather"(%arg0, %arg1, %shape) { + dimension_numbers = #stablehlo.gather< + offset_dims = [1, 2, 3], + collapsed_slice_dims = [0], + start_index_map = [0], + index_vector_dim = 1>, + indices_are_sorted = false, slice_sizes = array + } : (tensor, tensor, tensor<4xi32>) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func.func @negative_gather_dynamic( +// CHECK-NOT: tensorrt.gather + +// ----- + +func.func @negative_gather_dynamic2(%arg0: tensor, %arg1: tensor) -> tensor { + %c1 = stablehlo.constant dense<1> : tensor<1xi32> + %c256 = stablehlo.constant dense<256> : tensor<1xi32> + // Dimension size should be arg0, not arg1 + %dim = stablehlo.get_dimension_size %arg1, dim = 0 : (tensor) -> tensor + %dim.1 = stablehlo.reshape %dim : (tensor) -> tensor<1xi32> + %shape = stablehlo.concatenate %c1, %dim.1, %c256, %c256, dim = 0 : + (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + %0 = "stablehlo.dynamic_gather"(%arg0, %arg1, %shape) { + dimension_numbers = #stablehlo.gather< + offset_dims = [1, 2, 3], + collapsed_slice_dims = [0], + start_index_map = [0], + index_vector_dim = 1>, + indices_are_sorted = false, slice_sizes = array + } : (tensor, tensor, tensor<4xi32>) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func.func @negative_gather_dynamic2( +// CHECK-NOT: tensorrt.gather diff --git a/mlir-tensorrt/test/Conversion/StablehloToTensorRT/stablehlo-to-tensorrt.mlir b/mlir-tensorrt/test/Conversion/StablehloToTensorRT/stablehlo-to-tensorrt.mlir index 3663969ce..1b2daee0e 100644 --- a/mlir-tensorrt/test/Conversion/StablehloToTensorRT/stablehlo-to-tensorrt.mlir +++ b/mlir-tensorrt/test/Conversion/StablehloToTensorRT/stablehlo-to-tensorrt.mlir @@ -1481,17 +1481,45 @@ func.func @hlo_round_nearest_even(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // ----- -func.func @hlo_dynamic_iota(%arg0 : tensor<1xi32>) -> tensor { +func.func @hlo_dynamic_iota_0(%arg0 : tensor<1xi32>) -> tensor { %0 = "stablehlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<1xi32>) -> tensor return %0 : tensor } -// CHECK-LABEL: @hlo_dynamic_iota +// CHECK-LABEL: @hlo_dynamic_iota_0 // CHECK: tensorrt.linspace // CHECK-SAME: [ 0.00{{.+}}] [%arg0 : tensor<1xi32>] [ 1.000{{.+}}] : tensor // ----- +func.func @dynamic_nd_iota_1(%arg0 : tensor<2xi32>) -> tensor { + %0 = "stablehlo.dynamic_iota"(%arg0) {iota_dimension = 1 : i64} : (tensor<2xi32>) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func.func @dynamic_nd_iota_1 +// CHECK-SAME: (%[[arg0:.+]]: tensor<2xi32>) -> tensor { +// CHECK: %[[cst_i32:.+]] = tensorrt.constant dense<0> : tensor +// CHECK: %[[cst_i32_0:.+]] = tensorrt.constant dense<[0, 1]> : tensor<2xi32> +// CHECK: %[[v0:.+]] = tensorrt.linspace[%[[cst_i32]] : tensor] [%[[arg0]] : tensor<2xi32>] [%[[cst_i32_0]] : tensor<2xi32>] : tensor +// CHECK: return %[[v0]] : tensor + +// ----- + +func.func @dynamic_nd_iota_2(%arg0 : tensor<2xi32>) -> tensor { + %0 = "stablehlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<2xi32>) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func.func @dynamic_nd_iota_2 +// CHECK-SAME: (%[[arg0:.+]]: tensor<2xi32>) -> tensor { +// CHECK: %[[cst_i32:.+]] = tensorrt.constant dense<0> : tensor +// CHECK: %[[cst_i32_0:.+]] = tensorrt.constant dense<[1, 0]> : tensor<2xi32> +// CHECK: %[[v0:.+]] = tensorrt.linspace[%[[cst_i32]] : tensor] [%[[arg0]] : tensor<2xi32>] [%[[cst_i32_0]] : tensor<2xi32>] : tensor +// CHECK: return %[[v0]] : tensor + +// ----- + func.func @stablehlo_broadcast(%arg0: tensor<8xf32>) -> tensor<4x8xf32> { %0 = "stablehlo.broadcast"(%arg0) { broadcast_sizes = array diff --git a/mlir-tensorrt/test/Dialect/Plan/bounds-analysis.mlir b/mlir-tensorrt/test/Dialect/Plan/bounds-analysis.mlir index e62b35372..a1ff75d06 100644 --- a/mlir-tensorrt/test/Dialect/Plan/bounds-analysis.mlir +++ b/mlir-tensorrt/test/Dialect/Plan/bounds-analysis.mlir @@ -16,9 +16,9 @@ func.func @test_simple_static(%arg0: tensor<10xf32>, %arg1: tensor<10xf32>) -> t // ----- -#profile0 = #tensorrt.shape_profile +#profile0 = #plan.bounds : tensor, dense<20> : tensor> -func.func @test_forward_backward(%arg0: tensor, %arg1: index {tensorrt.value_bounds = #profile0}) -> tensor { +func.func @test_forward_backward(%arg0: tensor, %arg1: index {plan.value_bounds = #profile0}) -> tensor { %0 = plan.with_shape {tag = "with_shape0"} %arg0(%arg1) : (tensor, index) -> tensor %1 = stablehlo.exponential %0 : tensor %2 = plan.with_shape {tag = "with_shape1"} %1(%arg1) : (tensor, index) -> tensor @@ -41,11 +41,11 @@ func.func @test_forward_backward(%arg0: tensor, %arg1: index {tensorrt.va // ----- -#profile0 = #tensorrt.shape_profile -#profile1 = #tensorrt.shape_profile +#profile0 = #plan.bounds +#profile1 = #plan.bounds -func.func @dot_general_c12(%arg0: tensor {tensorrt.shape_profile = #profile0}, - %arg1: tensor {tensorrt.shape_profile = #profile1}) +func.func @dot_general_c12(%arg0: tensor {plan.shape_profile = #profile0}, + %arg1: tensor {plan.shape_profile = #profile1}) -> tensor { %c2 = arith.constant 2 : index %c0 = arith.constant 0 : index @@ -77,9 +77,9 @@ func.func @dot_general_c12(%arg0: tensor {tensorrt.shape_profile = #p // ----- -#profile0 = #tensorrt.shape_profile +#profile0 = #plan.bounds -func.func @test_unneeded_dynamism(%arg0: tensor {tensorrt.shape_profile = #profile0}) -> tensor { +func.func @test_unneeded_dynamism(%arg0: tensor {plan.shape_profile = #profile0}) -> tensor { %0 = stablehlo.constant dense<[1]> : tensor<1xi32> %c1 = arith.constant 1 : index %1 = plan.inline_group target(#plan.tensorrt_cluster) -> tensor { @@ -106,9 +106,9 @@ func.func @test_unneeded_dynamism(%arg0: tensor {tensorrt.shape_profile = func.func @test_loop_concat( %arg0: tensor<1xf32>, %arg1: tensor<1xi32> - {tensorrt.value_bounds = #tensorrt.shape_profile}, + {plan.value_bounds = #plan.bounds : tensor<1xi32>, dense<[4]> : tensor<1xi32>>}, %arg2: tensor - {tensorrt.shape_profile = #tensorrt.shape_profile}) + {plan.shape_profile = #plan.bounds}) -> tensor { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -148,11 +148,11 @@ func.func @test_loop_concat( // ----- -#profile0 = #tensorrt.shape_profile -#profile1 = #tensorrt.shape_profile +#profile0 = #plan.bounds +#profile1 = #plan.bounds : tensor, dense<6> : tensor> -func.func @test_separated(%arg0: tensor {tensorrt.shape_profile = #profile0}, - %arg1: index {tensorrt.value_bounds = #profile1}) +func.func @test_separated(%arg0: tensor {plan.shape_profile = #profile0}, + %arg1: index {plan.value_bounds = #profile1}) -> tensor { %c0 = arith.constant 0 : index %dim = tensor.dim %arg0, %c0 : tensor @@ -189,11 +189,11 @@ func.func @test_separated(%arg0: tensor {tensorrt.shape_profile = #profil // ----- -#profile0 = #tensorrt.shape_profile -#profile1 = #tensorrt.shape_profile +#profile0 = #plan.bounds +#profile1 = #plan.bounds : tensor<2xi32>, dense<[40, 40]> : tensor<2xi32>> -func.func @test_reshape(%arg0: tensor {tensorrt.shape_profile = #profile0}, - %arg1: tensor<2xi32> {tensorrt.value_bounds = #profile1}) -> tensor { +func.func @test_reshape(%arg0: tensor {plan.shape_profile = #profile0}, + %arg1: tensor<2xi32> {plan.value_bounds = #profile1}) -> tensor { %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index %extracted = tensor.extract %arg1[%c0] : tensor<2xi32> diff --git a/mlir-tensorrt/test/Dialect/Plan/create-closed-regions.mlir b/mlir-tensorrt/test/Dialect/Plan/create-closed-regions.mlir index b530e7b9a..a7aafc285 100644 --- a/mlir-tensorrt/test/Dialect/Plan/create-closed-regions.mlir +++ b/mlir-tensorrt/test/Dialect/Plan/create-closed-regions.mlir @@ -10,31 +10,34 @@ func.func @test_simple_static(%arg0: tensor<10xf32>, %arg1: tensor<10xf32>) -> t return %0 : tensor<10xf32> } +// CHECK: #[[$nobounds:.+]] = #plan.bounds +// CHECK: #[[$bounds1:.+]] = #plan.bounds // CHECK-LABEL: @test_simple_static // CHECK-SAME: (%[[arg0:.+]]: tensor<10xf32>, %[[arg1:.+]]: tensor<10xf32>) -> tensor<10xf32> // CHECK: %[[v0:.+]] = tensor.empty() : tensor<10xf32> // CHECK: %[[v1:.+]] = plan.inline_closed_group target(#plan.tensorrt_cluster) // CHECK: inputs(%[[arg0]], %[[arg1]] : tensor<10xf32>, tensor<10xf32>) // CHECK: outs(%[[v0]] : tensor<10xf32>) -// CHECK: in_attrs [#plan.bounds, #plan.bounds] -// CHECK: res_attrs [#plan.bounds] -> tensor<10xf32> { +// CHECK: in_attrs [#[[$nobounds]], #[[$nobounds]]] +// CHECK: res_attrs [#[[$bounds1]]] -> tensor<10xf32> { // CHECK: ^bb0(%[[in:.+]]: tensor<10xf32>, %[[in_0:.+]]: tensor<10xf32>, %[[out:.+]]: tensor<10xf32>): // CHECK: return %[[v1]] : tensor<10xf32> +// CHECK-ALLOC: #[[$nobounds:.+]] = #plan.bounds // CHECK-ALLOC-LABEL: @test_simple_static // CHECK-ALLOC-SAME: (%[[arg0:.+]]: tensor<10xf32>, %[[arg1:.+]]: tensor<10xf32>) -> tensor<10xf32> // CHECK-ALLOC: %[[v1:.+]] = plan.inline_closed_alloc_group target(#plan.tensorrt_cluster) // CHECK-ALLOC: inputs(%[[arg0]], %[[arg1]] : tensor<10xf32>, tensor<10xf32>) -// CHECK-ALLOC: in_attrs [#plan.bounds, #plan.bounds] +// CHECK-ALLOC: in_attrs [#[[$nobounds]], #[[$nobounds]]] // CHECK-ALLOC: -> tensor<10xf32> { // CHECK-ALLOC: ^bb0(%[[in:.+]]: tensor<10xf32>, %[[in_0:.+]]: tensor<10xf32>): // CHECK-ALLOC: return %[[v1]] : tensor<10xf32> // ----- -#profile0 = #tensorrt.shape_profile +#profile0 = #plan.bounds -func.func @test_simple_shape_bound(%arg0: tensor {tensorrt.shape_profile=#profile0}) -> tensor { +func.func @test_simple_shape_bound(%arg0: tensor {plan.shape_profile=#profile0}) -> tensor { %c10 = arith.constant 10 : index %c0 = arith.constant 0 : index %dim = tensor.dim %arg0, %c0 : tensor @@ -46,7 +49,9 @@ func.func @test_simple_shape_bound(%arg0: tensor {tensorrt.shape_profi return %0 : tensor } -// CHECK: #[[$map:.+]] = affine_map<()[s0] -> (s0 * 10)> +// CHECK-DAG: #[[$nobounds:.+]] = #plan.bounds +// CHECK-DAG: #[[$arg0bounds:.+]] = #plan.bounds +// CHECK-DAG: #[[$map:.+]] = affine_map<()[s0] -> (s0 * 10)> // CHECK-LABEL: @test_simple_shape_bound // CHECK-SAME: (%[[arg0:.+]]: tensor {{.*}}) -> tensor { // CHECK: %[[c10:.+]] = arith.constant 10 : index @@ -61,14 +66,16 @@ func.func @test_simple_shape_bound(%arg0: tensor {tensorrt.shape_profi // CHECK: %[[v2:.+]] = plan.inline_closed_group target(#plan.tensorrt_cluster) // CHECK: inputs(%[[arg0]], %[[dim]], %[[c10]] : tensor, index, index) // CHECK: outs(%[[reshape]] : tensor) -// CHECK: in_attrs [#plan.bounds, #plan.bounds, #plan.bounds] -// CHECK: res_attrs [#plan.bounds] -> tensor { +// CHECK: in_attrs [#[[$arg0bounds]], #[[$nobounds]], #[[$nobounds]]] +// CHECK: res_attrs [#[[$arg0bounds]]] -> tensor { // CHECK: ^bb0(%[[in:.+]]: tensor, %[[in_1:.+]]: index, %[[in_2:.+]]: index, %[[out:.+]]: tensor): // CHECK: %[[v3:.+]] = stablehlo.exponential %[[in]] : tensor // CHECK: %[[v4:.+]] = with_shape %[[v3]](%[[in_1]], %[[in_2]]) : // CHECK: yield %[[v4]] : tensor // CHECK: return %[[v2]] : tensor +// CHECK-ALLOC-DAG: #[[$nobounds:.+]] = #plan.bounds +// CHECK-ALLOC-DAG: #[[$arg0bounds:.+]] = #plan.bounds // CHECK-ALLOC-LABEL: @test_simple_shape_bound // CHECK-ALLOC-SAME: (%[[arg0:.+]]: tensor {{.*}}) -> tensor { // CHECK-ALLOC: %[[c10:.+]] = arith.constant 10 : index @@ -76,7 +83,7 @@ func.func @test_simple_shape_bound(%arg0: tensor {tensorrt.shape_profi // CHECK-ALLOC: %[[dim:.+]] = tensor.dim %[[arg0]], %[[c0]] : tensor // CHECK-ALLOC: %[[v2:.+]] = plan.inline_closed_alloc_group target(#plan.tensorrt_cluster) // CHECK-ALLOC: inputs(%[[arg0]], %[[dim]], %[[c10]] : tensor, index, index) -// CHECK-ALLOC: in_attrs [#plan.bounds, #plan.bounds, #plan.bounds] +// CHECK-ALLOC: in_attrs [#[[$arg0bounds]], #[[$nobounds]], #[[$nobounds]]] // CHECK-ALLOC: -> tensor { // CHECK-ALLOC: ^bb0(%[[in:.+]]: tensor, %[[in_1:.+]]: index, %[[in_2:.+]]: index): // CHECK-ALLOC: %[[v3:.+]] = stablehlo.exponential %[[in]] : tensor @@ -86,11 +93,11 @@ func.func @test_simple_shape_bound(%arg0: tensor {tensorrt.shape_profi // ----- -#profile0 = #tensorrt.shape_profile -#profile1 = #tensorrt.shape_profile +#profile0 = #plan.bounds +#profile1 = #plan.bounds : tensor<2xi32>, dense<[40, 40]> : tensor<2xi32>> -func.func @test_dynamic_reshape(%arg0: tensor {tensorrt.shape_profile = #profile0}, - %arg1: tensor<2xi32> {tensorrt.value_bounds = #profile1}) -> tensor { +func.func @test_dynamic_reshape(%arg0: tensor {plan.shape_profile = #profile0}, + %arg1: tensor<2xi32> {plan.value_bounds = #profile1}) -> tensor { %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index %extracted = tensor.extract %arg1[%c0] : tensor<2xi32> @@ -105,7 +112,11 @@ func.func @test_dynamic_reshape(%arg0: tensor {tensorrt.shape_profile = # return %2 : tensor } -// CHECK: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 * s1)> +// CHECK-DAG: #[[$nobounds:.+]] = #plan.bounds +// CHECK-DAG: #[[$bounds0:.+]] = #plan.bounds +// CHECK-DAG: #[[$bounds1:.+]] = #plan.bounds : tensor<2xi32>, dense<40> : tensor<2xi32>> +// CHECK-DAG: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 * s1)> +// CHECK-DAG: #[[$bounds2:.+]] = #plan.bounds // CHECK-LABEL: @test_dynamic_reshape // CHECK-SAME: (%[[arg0:.+]]: tensor {{.*}}, %[[arg1:.+]]: tensor<2xi32> {{.*}}) -> tensor { // CHECK: %[[c1:.+]] = arith.constant 1 : index @@ -123,16 +134,20 @@ func.func @test_dynamic_reshape(%arg0: tensor {tensorrt.shape_profile = # // CHECK: inputs(%[[arg0]], %[[arg1]], %[[v0]], %[[v1]] : tensor, tensor<2xi32>, index, index) // CHECK: outs(%[[reshape]] : tensor) // CHECK-NEXT: in_attrs [ -// CHECK-SAME: #plan.bounds, -// CHECK-SAME: #plan.bounds : tensor<2xi32>, dense<40> : tensor<2xi32>>, -// CHECK-SAME: #plan.bounds, #plan.bounds] -// CHECK-NEXT: res_attrs [#plan.bounds] -> tensor +// CHECK-SAME: #[[$bounds0]], +// CHECK-SAME: #[[$bounds1]], +// CHECK-SAME: #[[$nobounds]], #[[$nobounds]]] +// CHECK-NEXT: res_attrs [#[[$bounds2]]] -> tensor // CHECK: ^bb0(%[[in:.+]]: tensor, %[[in_1:.+]]: tensor<2xi32>, %[[in_2:.+]]: index, %[[in_3:.+]]: index, %[[out:.+]]: tensor): // CHECK: %[[v5:.+]] = stablehlo.dynamic_reshape %[[in]], %[[in_1]] // CHECK: %[[v6:.+]] = with_shape %[[v5]](%[[in_2]], %[[in_3]]) // CHECK: yield %[[v6]] : tensor // CHECK: return %[[v4]] : tensor + +// CHECK-ALLOC-DAG: #[[$nobounds:.+]] = #plan.bounds +// CHECK-ALLOC-DAG: #[[$bounds0:.+]] = #plan.bounds +// CHECK-ALLOC-DAG: #[[$bounds1:.+]] = #plan.bounds : tensor<2xi32>, dense<40> : tensor<2xi32>> // CHECK-ALLOC-LABEL: @test_dynamic_reshape // CHECK-ALLOC-SAME: (%[[arg0:.+]]: tensor {{.*}}, %[[arg1:.+]]: tensor<2xi32> {{.*}}) -> tensor { // CHECK-ALLOC: %[[c1:.+]] = arith.constant 1 : index @@ -144,9 +159,9 @@ func.func @test_dynamic_reshape(%arg0: tensor {tensorrt.shape_profile = # // CHECK-ALLOC: %[[v4:.+]] = plan.inline_closed_alloc_group target(#plan.tensorrt_cluster) // CHECK-ALLOC: inputs(%[[arg0]], %[[arg1]], %[[v0]], %[[v1]] : tensor, tensor<2xi32>, index, index) // CHECK-ALLOC-NEXT: in_attrs [ -// CHECK-ALLOC-SAME: #plan.bounds, -// CHECK-ALLOC-SAME: #plan.bounds : tensor<2xi32>, dense<40> : tensor<2xi32>>, -// CHECK-ALLOC-SAME: #plan.bounds, #plan.bounds] +// CHECK-ALLOC-SAME: #[[$bounds0]], +// CHECK-ALLOC-SAME: #[[$bounds1]], +// CHECK-ALLOC-SAME: #[[$nobounds]], #[[$nobounds]]] // CHECK-ALLOC-NEXT: -> tensor // CHECK-ALLOC: ^bb0(%[[in:.+]]: tensor, %[[in_1:.+]]: tensor<2xi32>, %[[in_2:.+]]: index, %[[in_3:.+]]: index): // CHECK-ALLOC: %[[v5:.+]] = stablehlo.dynamic_reshape %[[in]], %[[in_1]] @@ -156,10 +171,10 @@ func.func @test_dynamic_reshape(%arg0: tensor {tensorrt.shape_profile = # // ----- -#profile0 = #tensorrt.shape_profile -#profile1 = #tensorrt.shape_profile +#profile0 = #plan.bounds +#profile1 = #plan.bounds -func.func @test_get_dim_size_max(%arg0: tensor {tensorrt.shape_profile=#profile0}, %arg1: tensor {tensorrt.shape_profile=#profile1}) -> tensor { +func.func @test_get_dim_size_max(%arg0: tensor {plan.shape_profile=#profile0}, %arg1: tensor {plan.shape_profile=#profile1}) -> tensor { %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index %0 = stablehlo.constant dense<0.000000e+00> : tensor<1x1xf32> @@ -188,7 +203,10 @@ func.func @test_get_dim_size_max(%arg0: tensor {tensorrt.shape_profile= return %3 : tensor } -// CHECK: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 * s1)> +// CHECK-DAG: #[[$nobounds:.+]] = #plan.bounds +// CHECK-DAG: #[[$bounds0:.+]] = #plan.bounds +// CHECK-DAG: #[[$bounds1:.+]] = #plan.bounds +// CHECK-DAG: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 * s1)> // CHECK-LABEL: @test_get_dim_size_max // CHECK-SAME: (%[[arg0:.+]]: tensor {{.+}}, %[[arg1:.+]]: tensor {{.+}}) // CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index @@ -208,8 +226,11 @@ func.func @test_get_dim_size_max(%arg0: tensor {tensorrt.shape_profile= // CHECK: %[[v5:.+]] = plan.inline_closed_group target(#plan.tensorrt_cluster) // CHECK: inputs(%[[arg0]], %[[arg1]], %[[v1]], %[[v2]] : tensor, tensor, index, index) // CHECK: outs(%[[reshape]] : tensor) -// CHECK: in_attrs [#plan.bounds, #plan.bounds, #plan.bounds, #plan.bounds] -// CHECK: res_attrs [#plan.bounds] -> tensor +// CHECK: in_attrs [ +// CHECK-SAME: #[[$bounds0]], +// CHECK-SAME: #[[$bounds1]], +// CHECK-SAME: #[[$nobounds]], #[[$nobounds]]] +// CHECK: res_attrs [#[[$bounds1]]] -> tensor // CHECK: ^bb0(%[[in:.+]]: tensor, %[[in_3:.+]]: tensor, %[[in_4:.+]]: index, %[[in_5:.+]]: index, %[[out:.+]]: tensor): // CHECK: %[[v18:.+]] = stablehlo.dynamic_broadcast_in_dim // CHECK: %[[v19:.+]] = with_shape %[[v18]](%[[in_4]], %[[in_5]]) : @@ -220,14 +241,14 @@ func.func @test_get_dim_size_max(%arg0: tensor {tensorrt.shape_profile= // ----- -#profile0 = #tensorrt.shape_profile -#profile1 = #tensorrt.shape_profile -#profile2 = #tensorrt.shape_profile +#profile0 = #plan.bounds +#profile1 = #plan.bounds : tensor<1xindex>, dense<[100]> : tensor<1xindex>> +#profile2 = #plan.bounds : tensor<1xindex>, dense<[2]> : tensor<1xindex>> -func.func @real_dynamic_slice(%arg0: tensor {tensorrt.shape_profile = #profile0}, - %arg1: tensor<1xindex> {tensorrt.value_bounds = #profile1}, - %arg2: tensor<1xindex> {tensorrt.value_bounds = #profile1}, - %arg3: tensor<1xindex> {tensorrt.value_bounds = #profile2}) -> tensor { +func.func @real_dynamic_slice(%arg0: tensor {plan.shape_profile = #profile0}, + %arg1: tensor<1xindex> {plan.value_bounds = #profile1}, + %arg2: tensor<1xindex> {plan.value_bounds = #profile1}, + %arg3: tensor<1xindex> {plan.value_bounds = #profile2}) -> tensor { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %extracted = tensor.extract %arg1[%c0] : tensor<1xindex> @@ -245,6 +266,11 @@ func.func @real_dynamic_slice(%arg0: tensor {tensorrt.shape_profile = #pr return %4 : tensor } +// CHECK-DAG: #[[$bounds0:.+]] = #plan.bounds +// CHECK-DAG: #[[$bounds1:.+]] = #plan.bounds : tensor<1xindex>, dense<100> : tensor<1xindex>> +// CHECK-DAG: #[[$bounds2:.+]] = #plan.bounds : tensor<1xindex>, dense<2> : tensor<1xindex>> +// CHECK-DAG: #[[$bounds3:.+]] = #plan.bounds +// CHECK-DAG: #[[$nobounds:.+]] = #plan.bounds // CHECK-LABEL: @real_dynamic_slice // CHECK-SAME: (%[[arg0:.+]]: tensor {{.+}}, %[[arg1:.+]]: tensor<1xindex> {{.+}}, %[[arg2:.+]]: tensor<1xindex> {{.+}}, %[[arg3:.+]]: tensor<1xindex> {{.+}}) -> tensor { // CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index @@ -262,12 +288,12 @@ func.func @real_dynamic_slice(%arg0: tensor {tensorrt.shape_profile = #pr // CHECK: inputs(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[v3]] : tensor, tensor<1xindex>, tensor<1xindex>, tensor<1xindex>, index) // CHECK: outs(%[[extracted_slice]] : tensor) // CHECK: in_attrs [ -// CHECK-SAME: #plan.bounds, -// CHECK-SAME: #plan.bounds : tensor<1xindex>, dense<100> : tensor<1xindex>>, -// CHECK-SAME: #plan.bounds : tensor<1xindex>, dense<100> : tensor<1xindex>>, -// CHECK-SAME: #plan.bounds : tensor<1xindex>, dense<2> : tensor<1xindex>>, -// CHECK-SAME: #plan.bounds] -// CHECK-NEXT: res_attrs [#plan.bounds] -> tensor { +// CHECK-SAME: #[[$bounds0]], +// CHECK-SAME: #[[$bounds1]], +// CHECK-SAME: #[[$bounds1]], +// CHECK-SAME: #[[$bounds2]], +// CHECK-SAME: #[[$nobounds]]] +// CHECK-NEXT: res_attrs [#[[$bounds3]]] -> tensor { // CHECK: ^bb0(%[[in:.+]]: tensor, %[[in_2:.+]]: tensor<1xindex>, %[[in_3:.+]]: tensor<1xindex>, %[[in_4:.+]]: tensor<1xindex>, %[[in_5:.+]]: index, %[[out:.+]]: tensor): // CHECK: %[[v6:.+]] = stablehlo.real_dynamic_slice %[[in]], %[[in_2]], %[[in_3]], %[[in_4]] // CHECK: %[[v7:.+]] = with_shape %[[v6]](%[[in_5]]) : @@ -276,11 +302,11 @@ func.func @real_dynamic_slice(%arg0: tensor {tensorrt.shape_profile = #pr // ----- -#profile0 = #tensorrt.shape_profile -#profile1 = #tensorrt.shape_profile +#profile0 = #plan.bounds +#profile1 = #plan.bounds -func.func @dot_general_c12(%arg0: tensor {tensorrt.shape_profile = #profile0}, - %arg1: tensor {tensorrt.shape_profile = #profile1}) +func.func @dot_general_c12(%arg0: tensor {plan.shape_profile = #profile0}, + %arg1: tensor {plan.shape_profile = #profile1}) -> tensor { %c2 = arith.constant 2 : index %c0 = arith.constant 0 : index @@ -295,6 +321,8 @@ func.func @dot_general_c12(%arg0: tensor {tensorrt.shape_profile = #p return %0 : tensor } +// CHECK-DAG: #[[$bounds0:.+]] = #plan.bounds +// CHECK: #[[$nobounds:.+]] = #plan.bounds // CHECK: #[[$map:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * s2)> // CHECK-LABEL: @dot_general_c12 // CHECK-SAME: (%[[arg0:.+]]: tensor {{.*}}, %[[arg1:.+]]: tensor {{.*}}) -> tensor { @@ -311,8 +339,8 @@ func.func @dot_general_c12(%arg0: tensor {tensorrt.shape_profile = #p // CHECK: %[[v2:.+]] = plan.inline_closed_group target(#plan.tensorrt_cluster) // CHECK: inputs(%[[arg0]], %[[arg1]], %[[dim]], %[[dim_0]], %[[dim_1]] : tensor, tensor, index, index, index) // CHECK: outs(%[[reshape]] : tensor) -// CHECK: in_attrs [#plan.bounds, #plan.bounds, #plan.bounds, #plan.bounds, #plan.bounds] -// CHECK: res_attrs [#plan.bounds] -> tensor { +// CHECK: in_attrs [#[[$bounds0]], #[[$bounds0]], #[[$nobounds]], #[[$nobounds]], #[[$nobounds]]] +// CHECK: res_attrs [#[[$bounds0]]] -> tensor { // CHECK: ^bb0(%[[in:.+]]: tensor, %[[in_2:.+]]: tensor, %[[in_3:.+]]: index, %[[in_4:.+]]: index, %[[in_5:.+]]: index, %[[out]]: tensor): // CHECK: %[[v3:.+]] = stablehlo.dot_general %[[in]], %[[in_2]] // CHECK: %[[v4:.+]] = with_shape %[[v3]](%[[in_3]], %[[in_4]], %[[in_5]]) : @@ -321,14 +349,14 @@ func.func @dot_general_c12(%arg0: tensor {tensorrt.shape_profile = #p // ----- -#profile0 = #tensorrt.shape_profile -#profile1 = #tensorrt.shape_profile +#profile0 = #plan.bounds +#profile1 = #plan.bounds : tensor<1xindex>, dense<2> : tensor<1xindex>> -func.func @dynamic_pad(%arg0: tensor {tensorrt.shape_profile = #profile0}, +func.func @dynamic_pad(%arg0: tensor {plan.shape_profile = #profile0}, %arg1: tensor, - %arg2: tensor<1xindex> {tensorrt.value_bounds = #profile1}, - %arg3: tensor<1xindex> {tensorrt.value_bounds = #profile1}, - %arg4: tensor<1xindex> {tensorrt.value_bounds = #profile1}) -> tensor { + %arg2: tensor<1xindex> {plan.value_bounds = #profile1}, + %arg3: tensor<1xindex> {plan.value_bounds = #profile1}, + %arg4: tensor<1xindex> {plan.value_bounds = #profile1}) -> tensor { %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index %dim = tensor.dim %arg0, %c0 : tensor @@ -350,6 +378,10 @@ func.func @dynamic_pad(%arg0: tensor {tensorrt.shape_profile = #profile0} return %8 : tensor } +// CHECK-DAG: #[[$nobounds:.+]] = #plan.bounds +// CHECK-DAG: #[[$bounds0:.+]] = #plan.bounds : tensor<1xindex>, dense<2> : tensor<1xindex>> +// CHECK-DAG: #[[$bounds1:.+]] = #plan.bounds +// CHECK-DAG: #[[$bounds2:.+]] = #plan.bounds // CHECK-LABEL: @dynamic_pad // CHECK-SAME: (%[[arg0:.+]]: tensor {{.*}}, %[[arg1:.+]]: tensor, %[[arg2:.+]]: tensor<1xindex> {{.*}}, %[[arg3:.+]]: tensor<1xindex> {{.*}}, %[[arg4:.+]]: tensor<1xindex> {{.*}}) -> tensor { // CHECK: %[[c1:.+]] = arith.constant 1 : index @@ -371,13 +403,13 @@ func.func @dynamic_pad(%arg0: tensor {tensorrt.shape_profile = #profile0} // CHECK: inputs(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]], %[[v6]] : tensor, tensor, tensor<1xindex>, tensor<1xindex>, tensor<1xindex>, index) // CHECK: outs(%[[extracted_slice]] : tensor) // CHECK: in_attrs [ -// CHECK-SAME: #plan.bounds, -// CHECK-SAME: #plan.bounds, -// CHECK-SAME: #plan.bounds : tensor<1xindex>, dense<2> : tensor<1xindex>>, -// CHECK-SAME: #plan.bounds : tensor<1xindex>, dense<2> : tensor<1xindex>>, -// CHECK-SAME: #plan.bounds : tensor<1xindex>, dense<2> : tensor<1xindex>>, -// CHECK-SAME: #plan.bounds] -// CHECK: res_attrs [#plan.bounds] -> tensor { +// CHECK-SAME: #[[$bounds2]], +// CHECK-SAME: #[[$nobounds]], +// CHECK-SAME: #[[$bounds0]], +// CHECK-SAME: #[[$bounds0]], +// CHECK-SAME: #[[$bounds0]], +// CHECK-SAME: #[[$nobounds]]] +// CHECK: res_attrs [#[[$bounds1]]] -> tensor { // CHECK: ^bb0(%[[in:.+]]: tensor, %[[in_2:.+]]: tensor, %[[in_3:.+]]: tensor<1xindex>, %[[in_4:.+]]: tensor<1xindex>, %[[in_5:.+]]: tensor<1xindex>, %[[in_6:.+]]: index, %[[out]]: tensor): // CHECK: %[[v9:.+]] = stablehlo.dynamic_pad %[[in]], %[[in_2]], %[[in_3]], %[[in_4]], %[[in_5]] // CHECK: %[[v10:.+]] = with_shape %[[v9]](%[[in_6]]) : @@ -386,9 +418,9 @@ func.func @dynamic_pad(%arg0: tensor {tensorrt.shape_profile = #profile0} // ----- -#profile0 = #tensorrt.shape_profile +#profile0 = #plan.bounds -func.func @broadcast(%arg0: tensor {tensorrt.shape_profile = #profile0}) -> tensor<1x2x?xi32> { +func.func @broadcast(%arg0: tensor {plan.shape_profile = #profile0}) -> tensor<1x2x?xi32> { %c0 = arith.constant 0 : index %c2 = arith.constant 2 : index %c1 = arith.constant 1 : index @@ -403,9 +435,9 @@ func.func @broadcast(%arg0: tensor {tensorrt.shape_profile = #profile0}) // ----- -#profile0 = #tensorrt.shape_profile +#profile0 = #plan.bounds -func.func @transpose(%arg0: tensor {tensorrt.shape_profile = #profile0}) -> tensor { +func.func @transpose(%arg0: tensor {plan.shape_profile = #profile0}) -> tensor { %c2 = arith.constant 2 : index %c3 = arith.constant 3 : index %c0 = arith.constant 0 : index @@ -422,7 +454,10 @@ func.func @transpose(%arg0: tensor {tensorrt.shape_profile = #profi return %0 : tensor } -// CHECK: #[[$map:.+]] = affine_map<()[s0, s1, s2, s3] -> (((s0 * s1) * s2) * s3)> +// CHECK-DAG: #[[$nobounds:.+]] = #plan.bounds +// CHECK-DAG: #[[$bounds0:.+]] = #plan.bounds +// CHECK-DAG: #[[$bounds1:.+]] = #plan.bounds +// CHECK-DAG: #[[$map:.+]] = affine_map<()[s0, s1, s2, s3] -> (((s0 * s1) * s2) * s3)> // CHECK-LABEL: @transpose // CHECK-SAME: (%[[arg0:.+]]: tensor {{.*}}) -> tensor { // CHECK: %[[c2:.+]] = arith.constant 2 : index @@ -441,8 +476,8 @@ func.func @transpose(%arg0: tensor {tensorrt.shape_profile = #profi // CHECK: %[[v2:.+]] = plan.inline_closed_group target(#plan.tensorrt_cluster) // CHECK: inputs(%[[arg0]], %[[dim]], %[[dim_0]], %[[dim_1]], %[[dim_2]] : tensor, index, index, index, index) // CHECK: outs(%[[reshape]] : tensor) -// CHECK: in_attrs [#plan.bounds, #plan.bounds, #plan.bounds, #plan.bounds, #plan.bounds] -// CHECK: res_attrs [#plan.bounds] -> tensor { +// CHECK: in_attrs [#[[$bounds0]], #[[$nobounds]], #[[$nobounds]], #[[$nobounds]], #[[$nobounds]]] +// CHECK: res_attrs [#[[$bounds1]]] // CHECK: ^bb0(%[[in:.+]]: tensor, %[[in_3:.+]]: index, %[[in_4:.+]]: index, %[[in_5:.+]]: index, %[[in_6:.+]]: index, %[[out]]: tensor): // CHECK: %[[v3:.+]] = stablehlo.transpose %[[in]] // CHECK: %[[v4:.+]] = with_shape %[[v3]](%[[in_3]], %[[in_4]], %[[in_5]], %[[in_6]]) : @@ -451,9 +486,9 @@ func.func @transpose(%arg0: tensor {tensorrt.shape_profile = #profi // ----- -#profile0 = #tensorrt.shape_profile +#profile0 = #plan.bounds : tensor<1xindex>, dense<6> : tensor<1xindex>> -func.func @dynamic_iota(%arg0: tensor<1xindex> {tensorrt.value_bounds = #profile0}) -> tensor { +func.func @dynamic_iota(%arg0: tensor<1xindex> {plan.value_bounds = #profile0}) -> tensor { %c0 = arith.constant 0 : index %extracted = tensor.extract %arg0[%c0] : tensor<1xindex> %0 = plan.inline_group target(#plan.tensorrt_cluster) attributes {__cluster_target__ = #plan.tensorrt_cluster} -> tensor { @@ -464,6 +499,10 @@ func.func @dynamic_iota(%arg0: tensor<1xindex> {tensorrt.value_bounds = #profile return %0 : tensor } +// CHECK-DAG: #[[$nobounds:.+]] = #plan.bounds +// CHECK-DAG: #[[$bounds0:.+]] = #plan.bounds +// CHECK-DAG: #[[$bounds1:.+]] = #plan.bounds : tensor<1xindex>, dense<6> : tensor<1xindex>> + // CHECK-LABEL: @dynamic_iota // CHECK-SAME: (%[[arg0:.+]]: tensor<1xindex> {{.*}}) -> tensor { // CHECK: %[[c0:.+]] = arith.constant 0 : index @@ -474,9 +513,9 @@ func.func @dynamic_iota(%arg0: tensor<1xindex> {tensorrt.value_bounds = #profile // CHECK: inputs(%[[arg0]], %[[extracted]] : tensor<1xindex>, index) // CHECK: outs(%[[extracted_slice]] : tensor) // CHECK: in_attrs [ -// CHECK-SAME: #plan.bounds : tensor<1xindex>, dense<6> : tensor<1xindex>>, -// CHECK-SAME: #plan.bounds] -// CHECK: res_attrs [#plan.bounds] -> tensor { +// CHECK-SAME: #[[$bounds1]], +// CHECK-SAME: #[[$nobounds]]] +// CHECK: res_attrs [#[[$bounds0]]] -> tensor { // CHECK: ^bb0(%[[in:.+]]: tensor<1xindex>, %[[in_0:.+]]: index, %[[out]]: tensor): // CHECK: %[[v2:.+]] = stablehlo.dynamic_iota %[[in]] // CHECK: %[[v3:.+]] = with_shape %[[v2]](%[[in_0]]) : @@ -485,13 +524,13 @@ func.func @dynamic_iota(%arg0: tensor<1xindex> {tensorrt.value_bounds = #profile // ----- -#profile0 = #tensorrt.shape_profile -#profile1 = #tensorrt.shape_profile -#profile2 = #tensorrt.shape_profile +#profile0 = #plan.bounds : tensor, dense<6> : tensor> +#profile1 = #plan.bounds +#profile2 = #plan.bounds -func.func @add_dynamic(%arg0: tensor {tensorrt.value_bounds = #profile0}, - %arg1: tensor {tensorrt.shape_profile = #profile1}, - %arg2: tensor<2x?x4xf32> {tensorrt.shape_profile = #profile2}) -> tensor<2x?x4xf32> { +func.func @add_dynamic(%arg0: tensor {plan.value_bounds = #profile0}, + %arg1: tensor {plan.shape_profile = #profile1}, + %arg2: tensor<2x?x4xf32> {plan.shape_profile = #profile2}) -> tensor<2x?x4xf32> { %c2 = arith.constant 2 : index %c4 = arith.constant 4 : index %c1 = arith.constant 1 : index @@ -522,6 +561,10 @@ func.func @add_dynamic(%arg0: tensor {tensorrt.value_bounds = #profile0}, return %5 : tensor<2x?x4xf32> } +// CHECK-DAG: #[[$nobounds:.+]] = #plan.bounds +// CHECK-DAG: #[[$bounds1:.+]] = #plan.bounds : tensor, dense<6> : tensor> +// CHECK-DAG: #[[$bounds2:.+]] = #plan.bounds +// CHECK-DAG: #[[$bounds3:.+]] = #plan.bounds // CHECK-LABEL: @add_dynamic // CHECK-SAME: (%[[arg0:.+]]: tensor {{.*}}, %[[arg1:.+]]: tensor {{.*}}, %[[arg2:.+]]: tensor<2x?x4xf32> {{.*}}) -> tensor<2x?x4xf32> { // CHECK: %[[c2:.+]] = arith.constant 2 : index @@ -544,20 +587,20 @@ func.func @add_dynamic(%arg0: tensor {tensorrt.value_bounds = #profile0}, // CHECK: %[[v7:.+]] = plan.inline_closed_group target(#plan.tensorrt_cluster) // CHECK-NEXT: inputs(%[[v3]], %[[arg1]], %[[c1]], %[[v4]], %[[c4]], %[[c2]], %[[arg2]] : // CHECK-NEXT: outs(%[[reshape]] : tensor<2x?x4xf32>) -// CHECK-NEXT: in_attrs [#plan.bounds : tensor, dense<6> : tensor>, #plan.bounds, #plan.bounds, #plan.bounds, #plan.bounds, #plan.bounds, #plan.bounds] -// CHECK-NEXT: res_attrs [#plan.bounds] -> tensor<2x?x4xf32> { +// CHECK-NEXT: in_attrs [#[[$bounds1]], #[[$bounds2]], #[[$nobounds]], #[[$nobounds]], #[[$nobounds]], #[[$nobounds]], #[[$bounds3]]] +// CHECK-NEXT: res_attrs [#[[$bounds3]]] -> tensor<2x?x4xf32> { // CHECK: return %[[v7]] : tensor<2x?x4xf32> // ----- -#profile0 = #tensorrt.shape_profile -#profile1 = #tensorrt.shape_profile +#profile0 = #plan.bounds : tensor, dense<6> : tensor> +#profile1 = #plan.bounds -func.func @collapse_dynamic(%arg0: tensor {tensorrt.value_bounds = #profile0}, - %arg1: tensor {tensorrt.value_bounds = #profile0}, - %arg2: tensor {tensorrt.value_bounds = #profile0}, - %arg3: tensor {tensorrt.shape_profile = #profile1}) -> tensor { +func.func @collapse_dynamic(%arg0: tensor {plan.value_bounds = #profile0}, + %arg1: tensor {plan.value_bounds = #profile0}, + %arg2: tensor {plan.value_bounds = #profile0}, + %arg3: tensor {plan.shape_profile = #profile1}) -> tensor { %c7 = arith.constant 7 : index %c5 = arith.constant 5 : index %0 = stablehlo.constant dense<7> : tensor<1xi32> @@ -599,7 +642,11 @@ func.func @collapse_dynamic(%arg0: tensor {tensorrt.value_bounds = #profile return %9 : tensor } - +// CHECK-DAG: #[[$nobounds:.+]] = #plan.bounds +// CHECK-DAG: #[[$bounds0:.+]] = #plan.bounds : tensor, dense<6> : tensor> +// CHECK-DAG: #[[$bounds1:.+]] = #plan.bounds +// CHECK-DAG: #[[$bounds2:.+]] = #plan.bounds : tensor, dense<36> : tensor> +// CHECK-DAG: #[[$bounds3:.+]] = #plan.bounds // CHECK-LABEL: @collapse_dynamic // CHECK-SAME: (%[[arg0:.+]]: tensor {{.+}}, %[[arg1:.+]]: tensor {{.*}}, %[[arg2:.+]]: tensor {{.*}}, %[[arg3:.+]]: tensor {{.*}}) -> tensor { // CHECK-DAG: %[[c7:.+]] = arith.constant 7 : index @@ -627,10 +674,10 @@ func.func @collapse_dynamic(%arg0: tensor {tensorrt.value_bounds = #profile // CHECK: %[[v11:.+]] = plan.inline_closed_group target(#plan.tensorrt_cluster) // CHECK: inputs(%[[v2]], %[[v3]], %[[arg3]], %[[v4]], %[[v8]], %[[c7]] : tensor, tensor, tensor, index, index, index) // CHECK: outs(%[[reshape]] : tensor) -// CHECK: in_attrs [#plan.bounds : tensor, dense<6> : tensor>, -// CHECK-SAME: #plan.bounds : tensor, dense<36> : tensor>, -// CHECK-SAME: #plan.bounds, #plan.bounds, #plan.bounds, #plan.bounds] -// CHECK: res_attrs [#plan.bounds] -> tensor { +// CHECK: in_attrs [#[[$bounds0]], +// CHECK-SAME: #[[$bounds2]], +// CHECK-SAME: #[[$bounds3]], #[[$nobounds]], #[[$nobounds]], #[[$nobounds]]] +// CHECK: res_attrs [#[[$bounds1]]] -> tensor { // CHECK: ^bb0(%[[in:.+]]: tensor, %[[in_3:.+]]: tensor, %[[in_4:.+]]: tensor, %[[in_5:.+]]: index, %[[in_6:.+]]: index, %[[in_7:.+]]: index, %[[out]]: tensor): // CHECK: %[[v12:.+]] = stablehlo.constant dense<7> : tensor<1xi32> // CHECK: %[[v13:.+]] = stablehlo.reshape %[[in]] : (tensor) -> tensor<1xi32> @@ -643,11 +690,11 @@ func.func @collapse_dynamic(%arg0: tensor {tensorrt.value_bounds = #profile // ----- -#profile0 = #tensorrt.shape_profile -#profile1 = #tensorrt.shape_profile +#profile0 = #plan.bounds +#profile1 = #plan.bounds : tensor, dense<6> : tensor> -func.func @test_separated(%arg0: tensor {tensorrt.shape_profile = #profile0}, - %arg1: index {tensorrt.value_bounds = #profile1}) +func.func @test_separated(%arg0: tensor {plan.shape_profile = #profile0}, + %arg1: index {plan.value_bounds = #profile1}) -> tensor { %c0 = arith.constant 0 : index %dim = tensor.dim %arg0, %c0 : tensor @@ -666,6 +713,9 @@ func.func @test_separated(%arg0: tensor {tensorrt.shape_profile = #profil return %1 : tensor } +// CHECK-DAG: #[[$nobounds:.+]] = #plan.bounds +// CHECK-DAG: #[[$bounds0:.+]] = #plan.bounds +// CHECK-DAG: #[[$bounds1:.+]] = #plan.bounds // CHECK-LABEL: @test_separated // CHECK-SAME: (%[[arg0:.+]]: tensor {{.*}}, %[[arg1:.+]]: index {{.*}}) -> tensor { // CHECK: %[[c0:.+]] = arith.constant 0 : index @@ -675,8 +725,8 @@ func.func @test_separated(%arg0: tensor {tensorrt.shape_profile = #profil // CHECK: %[[v1:.+]] = plan.inline_closed_group target(#plan.tensorrt_cluster) // CHECK: inputs(%[[arg0]], %[[dim]] : tensor, index) // CHECK: outs(%[[extracted_slice]] : tensor) -// CHECK: in_attrs [#plan.bounds, #plan.bounds] -// CHECK: res_attrs [#plan.bounds] -> tensor { +// CHECK: in_attrs [#[[$bounds0]], #[[$nobounds]]] +// CHECK: res_attrs [#[[$bounds0]]] -> tensor { // CHECK: ^bb0(%[[in:.+]]: tensor, %[[in_2:.+]]: index, %[[out]]: tensor): // CHECK: %[[v5:.+]] = stablehlo.exponential %[[in]] : tensor // CHECK: %[[v6:.+]] = with_shape {tag = "with_shape0"} %[[v5]](%[[in_2]]) : @@ -688,8 +738,8 @@ func.func @test_separated(%arg0: tensor {tensorrt.shape_profile = #profil // CHECK: %[[v4:.+]] = plan.inline_closed_group target(#plan.tensorrt_cluster) // CHECK: inputs(%[[extracted_slice_0]], %[[arg1]] : tensor, index) // CHECK: outs(%[[extracted_slice_1]] : tensor) -// CHECK: in_attrs [#plan.bounds, #plan.bounds] -// CHECK: res_attrs [#plan.bounds] -> tensor { +// CHECK: in_attrs [#[[$bounds1]], #[[$nobounds]]] +// CHECK: res_attrs [#[[$bounds1]]] -> tensor { // CHECK: ^bb0(%[[in:.+]]: tensor, %[[in_2:.+]]: index, %[[out]]: tensor): // CHECK: %[[v5:.+]] = stablehlo.exponential %[[in]] : tensor // CHECK: %[[v6:.+]] = with_shape {tag = "with_shape1"} %[[v5]](%[[in_2]]) : @@ -703,9 +753,9 @@ func.func @test_separated(%arg0: tensor {tensorrt.shape_profile = #profil // into its static version. Normally this would never occur, but we should still handle this // situation gracefully. -#profile0 = #tensorrt.shape_profile +#profile0 = #plan.bounds -func.func @test_unneeded_dynamism(%arg0: tensor {tensorrt.shape_profile = #profile0}) -> tensor { +func.func @test_unneeded_dynamism(%arg0: tensor {plan.shape_profile = #profile0}) -> tensor { %0 = stablehlo.constant dense<[1]> : tensor<1xi32> %c1 = arith.constant 1 : index %1 = plan.inline_group target(#plan.tensorrt_cluster) -> tensor { @@ -716,6 +766,8 @@ func.func @test_unneeded_dynamism(%arg0: tensor {tensorrt.shape_profile = return %1 : tensor } +// CHECK-DAG: #[[$nobounds:.+]] = #plan.bounds +// CHECK-DAG: #[[$bounds0:.+]] = #plan.bounds // CHECK-LABEL: @test_unneeded_dynamism // CHECK-SAME: (%[[arg0:.+]]: tensor {{.*}}) -> tensor { // CHECK: %[[v0:.+]] = stablehlo.constant dense<1> : tensor<1xi32> @@ -728,8 +780,8 @@ func.func @test_unneeded_dynamism(%arg0: tensor {tensorrt.shape_profile = // CHECK: %[[v2:.+]] = plan.inline_closed_group target(#plan.tensorrt_cluster) // CHECK: inputs(%[[arg0]], %[[c1]] : tensor, index) // CHECK: outs(%[[reshape]] : tensor) -// CHECK: in_attrs [#plan.bounds, #plan.bounds] -// CHECK: res_attrs [#plan.bounds] -> tensor { +// CHECK: in_attrs [#[[$bounds0]], #[[$nobounds]]] +// CHECK: res_attrs [#[[$bounds0]]] -> tensor { // CHECK: ^bb0(%[[in:.+]]: tensor, %[[in_1:.+]]: index, %[[out]]: tensor): // CHECK: %[[v3:.+]] = stablehlo.constant dense<1> : tensor<1xi32> // CHECK: %[[v4:.+]] = stablehlo.dynamic_broadcast_in_dim %[[in]], %[[v3]] @@ -739,14 +791,14 @@ func.func @test_unneeded_dynamism(%arg0: tensor {tensorrt.shape_profile = // ----- -#profile0 = #tensorrt.shape_profile -#profile1 = #tensorrt.shape_profile +#profile0 = #plan.bounds +#profile1 = #plan.bounds // Connected regions verifies that the bounds analysis result is correctly updated // as we rewrite the IR. -func.func @test_connected_regions(%arg0: tensor {tensorrt.shape_profile = #profile0}, - %arg1: tensor {tensorrt.shape_profile = #profile1}) +func.func @test_connected_regions(%arg0: tensor {plan.shape_profile = #profile0}, + %arg1: tensor {plan.shape_profile = #profile1}) -> tensor { %c0 = arith.constant 0 : index %dim = tensor.dim %arg0, %c0 : tensor @@ -764,6 +816,10 @@ func.func @test_connected_regions(%arg0: tensor {tensorrt.shape_profile = return %1 : tensor } +// CHECK-DAG: #[[$nobounds:.+]] = #plan.bounds +// CHECK-DAG: #[[$bounds0:.+]] = #plan.bounds +// CHECK-DAG: #[[$bounds1:.+]] = #plan.bounds + // CHECK-LABEL: func.func @test_connected_regions // CHECK-SAME: (%[[arg0:.+]]: tensor {{.*}}, %[[arg1:.+]]: tensor // CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index @@ -773,8 +829,8 @@ func.func @test_connected_regions(%arg0: tensor {tensorrt.shape_profile = // CHECK: %[[v1:.+]] = plan.inline_closed_group // CHECK-NEXT: inputs(%[[arg0]], %[[dim]] : // CHECK-NEXT: outs(%[[extracted_slice]] : -// CHECK-NEXT: in_attrs [#plan.bounds, #plan.bounds] -// CHECK-NEXT: res_attrs [#plan.bounds] -> tensor { +// CHECK-NEXT: in_attrs [#[[$bounds0]], #[[$nobounds]]] +// CHECK-NEXT: res_attrs [#[[$bounds0]]] -> tensor { // CHECK-NEXT: ^bb0(%[[in:.+]]: tensor, %[[in_2:.+]]: index, %[[out:.+]]: tensor): // CHECK-NEXT: %[[v4:.+]] = stablehlo.exponential %[[in]] : tensor // CHECK-NEXT: %[[v5:.+]] = with_shape %[[v4]](%[[in_2]]) : @@ -785,8 +841,8 @@ func.func @test_connected_regions(%arg0: tensor {tensorrt.shape_profile = // CHECK: %[[v3:.+]] = plan.inline_closed_group // CHECK-NEXT: inputs(%[[v1]], %[[arg1]], %[[dim_0]] : // CHECK-NEXT: outs(%[[extracted_slice_1]] : tensor) -// CHECK-NEXT: in_attrs [#plan.bounds, #plan.bounds, #plan.bounds] -// CHECK-NEXT: res_attrs [#plan.bounds] -> tensor { +// CHECK-NEXT: in_attrs [#[[$bounds0]], #[[$bounds1]], #[[$nobounds]]] +// CHECK-NEXT: res_attrs [#[[$bounds1]]] -> tensor { // CHECK-NEXT: ^bb0(%[[in:.+]]: tensor, %[[in_2:.+]]: tensor, %[[in_3:.+]]: index, %[[out:.+]]: tensor): // CHECK-NEXT: %[[v4:.+]] = stablehlo.add %[[in]], %[[in_2]] : tensor // CHECK-NEXT: %[[v5:.+]] = with_shape %[[v4]](%[[in_3]]) : @@ -798,13 +854,13 @@ func.func @test_connected_regions(%arg0: tensor {tensorrt.shape_profile = // Connected regions verifies that the bounds analysis result is correctly updated // as we rewrite the IR. -#profile0 = #tensorrt.shape_profile +#profile0 = #plan.bounds : tensor, dense<123> : tensor> func.func @test_connected_regions_host_values( %arg0: tensor<128xf32>, %arg1: tensor<4xf32>, - %arg2: tensor {tensorrt.host_tensor, tensorrt.value_bounds = #profile0}) + %arg2: tensor {tensorrt.host_tensor, plan.value_bounds = #profile0}) -> tensor<128xf32> { %0:2 = plan.inline_group target(#plan.tensorrt_cluster) attributes {__cluster_target__ = #plan.tensorrt_cluster} -> tensor<128xf32>, tensor { %1 = stablehlo.dynamic_update_slice %arg0, %arg1, %arg2 : (tensor<128xf32>, tensor<4xf32>, tensor) -> tensor<128xf32> @@ -817,23 +873,29 @@ func.func @test_connected_regions_host_values( return %1 : tensor<128xf32> } +// CHECK-DAG: #[[$nobounds:.+]] = #plan.bounds +// CHECK-DAG: #[[$bounds0:.+]] = #plan.bounds : tensor, dense<123> : tensor> +// CHECK-DAG: #[[$bounds1:.+]] = #plan.bounds +// CHECK-DAG: #[[$bounds2:.+]] = #plan.bounds // CHECK-LABEL: @test_connected_regions_host_values // CHECK: %[[v2:.+]]:2 = plan.inline_closed_group target(#plan.tensorrt_cluster) // CHECK-NEXT: inputs({{.+}} : tensor<128xf32>, tensor<4xf32>, tensor) // CHECK-NEXT: outs(%{{.+}} : tensor<128xf32>, tensor) -// CHECK-NEXT: in_attrs [#plan.bounds, #plan.bounds, #plan.bounds : tensor, dense<123> : tensor>] -// CHECK-NEXT: res_attrs [#plan.bounds, #plan.bounds] -> tensor<128xf32>, tensor { +// CHECK-NEXT: in_attrs [#[[$nobounds]], #[[$nobounds]], #[[$bounds0]]] +// CHECK-NEXT: res_attrs [#[[$bounds1]], #[[$bounds2]]] -> tensor<128xf32>, tensor { // CHECK: plan.inline_closed_group target(#plan.tensorrt_cluster) // CHECK-NEXT: inputs({{.*}}, %[[v2]]#1 : tensor<128xf32>, tensor<4xf32>, tensor) // CHECK-NEXT: outs(%[[v3]] : tensor<128xf32>) -// CHECK-NEXT: in_attrs [#plan.bounds, #plan.bounds, #plan.bounds : tensor, dense<123> : tensor>] -// CHECK-NEXT: res_attrs [#plan.bounds] -> tensor<128xf32> +// CHECK-NEXT: in_attrs [#[[$nobounds]], #[[$nobounds]], #[[$bounds0]]] +// CHECK-NEXT: res_attrs [#[[$bounds1]]] -> tensor<128xf32> // ----- -func.func @shape_calc(%arg0: tensor {tensorrt.shape_profile = #tensorrt.shape_profile}, - %arg1: tensor<2xi32> {tensorrt.value_bounds = #tensorrt.shape_profile}, - %arg2: tensor<2xi32> {tensorrt.value_bounds = #tensorrt.shape_profile}) + +func.func @shape_calc(%arg0: tensor {plan.shape_profile = #plan.bounds}, + %arg1: tensor<2xi32> {plan.value_bounds = #plan.bounds : tensor<2xi32>, dense<[2, 2]> : tensor<2xi32>>}, + %arg2: tensor<2xi32> {plan.value_bounds = #plan.bounds : tensor<2xi32>, dense<[2, 2]> : tensor<2xi32>>}) -> tensor { %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index @@ -867,9 +929,13 @@ func.func @shape_calc(%arg0: tensor {tensorrt.shape_profile = #tensorrt.s return %12 : tensor } -// CHECK: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 * s1)> +// CHECK-DAG: #[[$bounds0:.+]] = #plan.bounds : tensor<2xi32>, dense<2> : tensor<2xi32>> +// CHECK-DAG: #[[$bounds1:.+]] = #plan.bounds +// CHECK-DAG: #[[$bounds2:.+]] = #plan.bounds +// CHECK-DAG: #[[$nobounds:.+]] = #plan.bounds +// CHECK-DAG: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 * s1)> // CHECK-LABEL: func.func @shape_calc -// CHECK-SAME: (%[[arg0:.+]]: tensor {tensorrt.shape_profile = #tensorrt.shape_profile}, %[[arg1:.+]]: tensor<2xi32> {tensorrt.value_bounds = #tensorrt.shape_profile}, %[[arg2:.+]]: tensor<2xi32> {tensorrt.value_bounds = #tensorrt.shape_profile}) -> tensor { +// CHECK-SAME: (%[[arg0:.+]]: tensor{{.*}}, %[[arg1:.+]]: tensor<2xi32>{{.*}}, %[[arg2:.+]]: tensor<2xi32>{{.*}}) // CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index // CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index // CHECK-DAG: %[[extracted:.+]] = tensor.extract %[[arg2]][%[[c0]]] : tensor<2xi32> @@ -896,11 +962,11 @@ func.func @shape_calc(%arg0: tensor {tensorrt.shape_profile = #tensorrt.s // CHECK: %[[v14:.+]] = plan.inline_closed_group target(#plan.tensorrt_cluster) // CHECK: inputs(%[[arg2]], %[[extracted]], %[[extracted_0]], %[[arg1]], %[[extracted_1]], %[[extracted_2]], %[[v0]], %[[v1]], %[[v2]], %[[v3]], %[[arg0]], %[[v7]], %[[v11]] : // CHECK: outs(%[[reshape]] : tensor) -// CHECK: in_attrs [#plan.bounds : tensor<2xi32>, dense<2> : tensor<2xi32>>, -// CHECK-SAME: #plan.bounds, #plan.bounds, -// CHECK-SAME: #plan.bounds : tensor<2xi32>, dense<2> : tensor<2xi32>>, -// CHECK-SAME: #plan.bounds, #plan.bounds, #plan.bounds, #plan.bounds, #plan.bounds, -// CHECK-SAME: #plan.bounds, #plan.bounds, #plan.bounds, #plan.bounds] -// CHECK: res_attrs [#plan.bounds] -> tensor { +// CHECK: in_attrs [#[[$bounds0]], +// CHECK-SAME: #[[$nobounds]], #[[$nobounds]], +// CHECK-SAME: #[[$bounds0]], +// CHECK-SAME: #[[$nobounds]], #[[$nobounds]], #[[$nobounds]], #[[$nobounds]], #[[$nobounds]], +// CHECK-SAME: #[[$nobounds]], #[[$bounds2]], #[[$nobounds]], #[[$nobounds]]] +// CHECK: res_attrs [#[[$bounds1]]] -> tensor { // CHECK: return %[[v14]] : tensor diff --git a/mlir-tensorrt/test/Dialect/Plan/invalid.mlir b/mlir-tensorrt/test/Dialect/Plan/invalid.mlir index 5822ed610..10cb35186 100644 --- a/mlir-tensorrt/test/Dialect/Plan/invalid.mlir +++ b/mlir-tensorrt/test/Dialect/Plan/invalid.mlir @@ -36,6 +36,75 @@ #plan_value_bounds = #plan.bounds : tensor<2x3xf32>, dense<[[7., 8., 9.], [10., 4., 7.]]> : tensor<2x3xf32>> +// ----- + +#bounds = #plan.bounds : tensor<1x10xi32>, dense<20> : tensor<1x10xi32>> +// expected-error @below {{'func.func' op arg #0 expected type of values bounds elements ('tensor<1x10xi32>') to be compatible with the type ('tensor<1x11xi32>')}} +func.func @value_bounds_shape_mismatch(%arg0: tensor<1x11xi32> {plan.value_bounds = #bounds}) { + return +} + + +// ----- + +#bounds = #plan.bounds + +// expected-error @below {{'func.func' op arg #0 has type 'tensor<1x?xi32>', whose rank is not equal to the rank of the corresponding shape bounds #plan.bounds}} +func.func @value_bounds_shape_mismatch(%arg0: tensor<1x?xi32> {plan.shape_profile = #bounds}) { + return +} + +// ----- + +#bounds = #plan.bounds +func.func @value_bounds_shape_0d_match(%arg0: tensor {plan.shape_profile = #bounds}) { + return +} + +// ----- + +#bounds = #plan.bounds +// expected-error @below {{'func.func' op expected only value bounds or none bounds for scalar arg #0 of type 'i32', but got #plan.bounds}} +func.func @value_bounds_shape_mismatch(%arg0: i32 {plan.shape_profile = #bounds}) { + return +} + +// ----- + +#bounds = #plan.bounds : tensor<1xi32>, dense<20> : tensor<1xi32>> + +// expected-error @below {{'func.func' op arg #0 expected type of values bounds elements ('tensor<1xi32>') to be compatible with the type ('tensor')}} +func.func @value_bounds_0rank_shape_mismatch(%arg0: tensor {plan.value_bounds = #bounds}) { + return +} + +// ----- + +#bounds = #plan.bounds : tensor<1x11xi32>, dense<20> : tensor<1x11xi32>> + +// expected-error @below {{'func.func' op arg #0 expected element type of value bounds elements ('i32') to be compatible with the type ('tensor<1x11xi64>')}} +func.func @value_bounds_element_type_mismatch(%arg0: tensor<1x11xi64> {plan.value_bounds = #bounds}) { + return +} + +// ----- + +#bounds = #plan.bounds : tensor<1xi32>,dense<20> : tensor<1xi32>> + +// expected-error @below {{'func.func' op arg #0 type expects rank-0 value bounds type, but got 'tensor<1xi32>'}} +func.func @value_bounds_scalar_shape_mismatch(%arg0: i32 {plan.value_bounds = #bounds}) { + return +} + +// ----- + +#bounds = #plan.bounds : tensor, dense<20> : tensor> + +func.func @value_bounds_scalar_shape_ok(%arg0: i32 {plan.value_bounds = #bounds}) { + return +} + + // ----- func.func @plan_inline_group_mismatched_result_types(%arg0: tensor<10xf32>, %arg1: index) { diff --git a/mlir-tensorrt/test/Dialect/Plan/materialize-shape-calculations.mlir b/mlir-tensorrt/test/Dialect/Plan/materialize-shape-calculations.mlir index e9b6990c1..06f03bd0a 100644 --- a/mlir-tensorrt/test/Dialect/Plan/materialize-shape-calculations.mlir +++ b/mlir-tensorrt/test/Dialect/Plan/materialize-shape-calculations.mlir @@ -698,7 +698,7 @@ func.func @test_loop_concat( } // CHECK-LABEL: @test_loop_concat -// CHECK-SAME: (%[[arg0:.+]]: tensor<1xf32>, %[[arg1:.+]]: tensor<1xi32>{{.*}}, %[[arg2:.+]]: tensor{{.*}}, %[[arg3:.+]]: tensor<1024xf32>) +// CHECK-SAME: (%[[arg0:.+]]: tensor<1xf32>, %[[arg1:[a-zA-Z0-9]+]]: tensor<1xi32>{{.*}}, %[[arg2:[a-zA-Z0-9]+]]: tensor{{.*}}, %[[arg3:[a-zA-Z0-9]+]]: tensor<1024xf32>) // CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index // CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index // CHECK: %[[dim:.+]] = tensor.dim %[[arg2]], %[[c0]] : tensor @@ -789,7 +789,7 @@ func.func @zero_slice_slice(%arg4: tensor<1xi32>, } // CHECK-LABEL: func.func @zero_slice_slice -// CHECK-SAME: (%[[arg0:.+]]: tensor<1xi32>, %[[arg1:.+]]: tensor<1xi32> {tensorrt.value_bounds = #tensorrt.shape_profile}, %[[arg2:.+]]: tensor<1xi32> {tensorrt.value_bounds = #tensorrt.shape_profile}, %[[arg3:.+]]: tensor<1xi32> {tensorrt.value_bounds = #tensorrt.shape_profile}, %[[arg4:.+]]: tensor<1xi32> {tensorrt.shape_profile = #tensorrt.shape_profile}) -> tensor { +// CHECK-SAME: (%[[arg0:.+]]: tensor<1xi32>, %[[arg1:.+]]: tensor<1xi32> {plan.value_bounds = #plan.bounds : tensor<1xi32>, dense<1> : tensor<1xi32>>}, %[[arg2:.+]]: tensor<1xi32> {plan.value_bounds = #plan.bounds : tensor<1xi32>, dense<1> : tensor<1xi32>>}, %[[arg3:.+]]: tensor<1xi32> {plan.value_bounds = #plan.bounds : tensor<1xi32>, dense<1> : tensor<1xi32>>}, %[[arg4:.+]]: tensor<1xi32> {plan.shape_profile = #plan.bounds}) // CHECK-DAG: %[[cst:.+]] = arith.constant dense<1> : tensor<1xi32> // CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index // CHECK-DAG: %[[c1_i32:.+]] = arith.constant 1 : i32 diff --git a/mlir-tensorrt/test/Dialect/Plan/populate-func-bounds-attrs.mlir b/mlir-tensorrt/test/Dialect/Plan/populate-func-bounds-attrs.mlir index 393f37b8d..70416578b 100644 --- a/mlir-tensorrt/test/Dialect/Plan/populate-func-bounds-attrs.mlir +++ b/mlir-tensorrt/test/Dialect/Plan/populate-func-bounds-attrs.mlir @@ -1,7 +1,7 @@ // RUN: mlir-tensorrt-opt %s -split-input-file -plan-populate-func-bounds-attrs | FileCheck %s -func.func public @single_return(%arg0: tensor {tensorrt.shape_profile = #tensorrt.shape_profile}, - %arg1: tensor {tensorrt.shape_profile = #tensorrt.shape_profile}) -> tensor { +func.func public @single_return(%arg0: tensor {plan.shape_profile = #plan.bounds}, + %arg1: tensor {plan.shape_profile = #plan.bounds}) -> tensor { %c0 = arith.constant 0 : index %0 = stablehlo.add %arg0, %arg1 : tensor %dim = tensor.dim %arg0, %c0 : tensor @@ -9,12 +9,13 @@ func.func public @single_return(%arg0: tensor {tensorrt.shape_profile = # return %1 : tensor } +// CHECK: #[[$bounds1:.+]] = #plan.bounds // CHECK-LABEL: @single_return -// CHECK-SAME: -> (tensor {tensorrt.shape_profile = #plan.bounds}) +// CHECK-SAME: -> (tensor {plan.shape_profile = #[[$bounds1]]}) // ----- -func.func public @multiple_return(%arg0: tensor {tensorrt.shape_profile = #tensorrt.shape_profile}, %arg1: tensor {tensorrt.shape_profile = #tensorrt.shape_profile}) -> (tensor, tensor) { +func.func public @multiple_return(%arg0: tensor {plan.shape_profile = #plan.bounds}, %arg1: tensor {plan.shape_profile = #plan.bounds}) -> (tensor, tensor) { %c0 = arith.constant 0 : index %0 = stablehlo.add %arg0, %arg0 : tensor %1 = stablehlo.add %arg1, %arg1 : tensor @@ -25,12 +26,16 @@ func.func public @multiple_return(%arg0: tensor {tensorrt.shape_profile = return %2, %3 : tensor, tensor } +// CHECK: #[[$bounds1:.+]] = #plan.bounds +// CHECK: #[[$bounds2:.+]] = #plan.bounds + + // CHECK-LABEL: @multiple_return -// CHECK-SAME: -> (tensor {tensorrt.shape_profile = #plan.bounds}, tensor {tensorrt.shape_profile = #plan.bounds}) +// CHECK-SAME: -> (tensor {plan.shape_profile = #[[$bounds1]]}, tensor {plan.shape_profile = #[[$bounds2]]}) // ----- -func.func public @scalar_return(%arg0: tensor {tensorrt.shape_profile = #tensorrt.shape_profile}) -> i32 { +func.func public @scalar_return(%arg0: tensor {plan.shape_profile = #plan.bounds}) -> i32 { %c0 = arith.constant 0 : index %0 = stablehlo.add %arg0, %arg0 : tensor %1 = tensor.extract %0[%c0] : tensor @@ -42,7 +47,7 @@ func.func public @scalar_return(%arg0: tensor {tensorrt.shape_profile = # // ----- -func.func public @static_return(%arg0: tensor {tensorrt.shape_profile = #tensorrt.shape_profile}) -> tensor<1xi32> { +func.func public @static_return(%arg0: tensor {plan.shape_profile = #plan.bounds}) -> tensor<1xi32> { %c0 = arith.constant 0 : index %0 = stablehlo.add %arg0, %arg0 : tensor %1 = tensor.extract %0[%c0] : tensor @@ -55,7 +60,7 @@ func.func public @static_return(%arg0: tensor {tensorrt.shape_profile = # // ----- -func.func @mixed_dims(%arg0: tensor {tensorrt.shape_profile = #tensorrt.shape_profile}) -> tensor { +func.func @mixed_dims(%arg0: tensor {plan.shape_profile = #plan.bounds}) -> tensor { %c10 = arith.constant 10 : index %c0 = arith.constant 0 : index %0 = stablehlo.exponential %arg0 : tensor @@ -64,12 +69,13 @@ func.func @mixed_dims(%arg0: tensor {tensorrt.shape_profile = #tensorr return %1 : tensor } +// CHECK: #[[$bounds1:.+]] = #plan.bounds // CHECK-LABEL: @mixed_dims -// CHECK-SAME: -> (tensor {tensorrt.shape_profile = #plan.bounds}) +// CHECK-SAME: -> (tensor {plan.shape_profile = #[[$bounds1]]}) // ----- -func.func @transpose(%arg0: tensor {tensorrt.shape_profile = #tensorrt.shape_profile}) -> tensor { +func.func @transpose(%arg0: tensor {plan.shape_profile = #plan.bounds}) -> tensor { %c2 = arith.constant 2 : index %c3 = arith.constant 3 : index %c0 = arith.constant 0 : index @@ -83,12 +89,13 @@ func.func @transpose(%arg0: tensor {tensorrt.shape_profile = #tenso return %1 : tensor } +// CHECK: #[[$bounds1:.+]] = #plan.bounds // CHECK-LABEL: @transpose -// CHECK-SAME: -> (tensor {tensorrt.shape_profile = #plan.bounds}) +// CHECK-SAME: -> (tensor {plan.shape_profile = #[[$bounds1]]}) // ----- -func.func @reverse(%arg0: tensor {tensorrt.shape_profile = #tensorrt.shape_profile}) -> tensor { +func.func @reverse(%arg0: tensor {plan.shape_profile = #plan.bounds}) -> tensor { %c3 = arith.constant 3 : index %c2 = arith.constant 2 : index %c1 = arith.constant 1 : index @@ -102,12 +109,14 @@ func.func @reverse(%arg0: tensor {tensorrt.shape_profile = #tensorr return %1 : tensor } + +// CHECK: #[[$bounds1:.+]] = #plan.bounds // CHECK-LABEL: @reverse -// CHECK-SAME: -> (tensor {tensorrt.shape_profile = #plan.bounds}) +// CHECK-SAME: -> (tensor {plan.shape_profile = #[[$bounds1]]}) // ----- -func.func @broadcast(%arg0: tensor {tensorrt.shape_profile = #tensorrt.shape_profile}) -> tensor<1x2x?xi32> { +func.func @broadcast(%arg0: tensor {plan.shape_profile = #plan.bounds}) -> tensor<1x2x?xi32> { %c0 = arith.constant 0 : index %c2 = arith.constant 2 : index %c1 = arith.constant 1 : index @@ -117,12 +126,13 @@ func.func @broadcast(%arg0: tensor {tensorrt.shape_profile = #tensorrt.sh return %1 : tensor<1x2x?xi32> } +// CHECK: #[[$bounds1:.+]] = #plan.bounds // CHECK-LABEL: @broadcast -// CHECK-SAME: -> (tensor<1x2x?xi32> {tensorrt.shape_profile = #plan.bounds}) +// CHECK-SAME: -> (tensor<1x2x?xi32> {plan.shape_profile = #[[$bounds1]]}) // ----- -func.func @gather(%arg0: tensor<3x4x2xi32>, %arg1: tensor {tensorrt.shape_profile = #tensorrt.shape_profile}) -> tensor { +func.func @gather(%arg0: tensor<3x4x2xi32>, %arg1: tensor {plan.shape_profile = #plan.bounds}) -> tensor { %c2 = arith.constant 2 : index %c3 = arith.constant 3 : index %c0 = arith.constant 0 : index @@ -132,12 +142,14 @@ func.func @gather(%arg0: tensor<3x4x2xi32>, %arg1: tensor {tensorrt.s return %1 : tensor } + +// CHECK: #[[$bounds1:.+]] = #plan.bounds // CHECK-LABEL: @gather -// CHECK-SAME: -> (tensor {tensorrt.shape_profile = #plan.bounds}) +// CHECK-SAME: -> (tensor {plan.shape_profile = #[[$bounds1]]}) // ----- -func.func @test_dynamic_reshape(%arg0: tensor {tensorrt.shape_profile = #tensorrt.shape_profile}, %arg1: tensor<2xi32> {tensorrt.value_bounds = #tensorrt.shape_profile}) -> tensor { +func.func @test_dynamic_reshape(%arg0: tensor {plan.shape_profile = #plan.bounds}, %arg1: tensor<2xi32> {plan.value_bounds = #plan.bounds : tensor<2xi32>, dense<[40, 40]> : tensor<2xi32>>}) -> tensor { %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index %0 = stablehlo.dynamic_reshape %arg0, %arg1 : (tensor, tensor<2xi32>) -> tensor @@ -149,12 +161,13 @@ func.func @test_dynamic_reshape(%arg0: tensor {tensorrt.shape_profile = # return %3 : tensor } +// CHECK: #[[$bounds1:.+]] = #plan.bounds // CHECK-LABEL: @test_dynamic_reshape -// CHECK-SAME: -> (tensor {tensorrt.shape_profile = #plan.bounds}) +// CHECK-SAME: -> (tensor {plan.shape_profile = #[[$bounds1]]}) // ----- -func.func @test_get_dim_size_max(%arg0: tensor {tensorrt.shape_profile = #tensorrt.shape_profile}, %arg1: tensor {tensorrt.shape_profile = #tensorrt.shape_profile}) -> tensor { +func.func @test_get_dim_size_max(%arg0: tensor {plan.shape_profile = #plan.bounds}, %arg1: tensor {plan.shape_profile = #plan.bounds}) -> tensor { %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index %0 = stablehlo.constant dense<0.000000e+00> : tensor<1x1xf32> @@ -180,12 +193,13 @@ func.func @test_get_dim_size_max(%arg0: tensor {tensorrt.shape_profile return %15 : tensor } +// CHECK: #[[$bounds1:.+]] = #plan.bounds // CHECK-LABEL: @test_get_dim_size_max -// CHECK-SAME: -> (tensor {tensorrt.shape_profile = #plan.bounds}) +// CHECK-SAME: -> (tensor {plan.shape_profile = #[[$bounds1]]}) // ----- -func.func @dot_general(%arg0: tensor {tensorrt.shape_profile = #tensorrt.shape_profile}, %arg1: tensor {tensorrt.shape_profile = #tensorrt.shape_profile}) -> tensor { +func.func @dot_general(%arg0: tensor {plan.shape_profile = #plan.bounds}, %arg1: tensor {plan.shape_profile = #plan.bounds}) -> tensor { %c2 = arith.constant 2 : index %c0 = arith.constant 0 : index %0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0] x [0], contracting_dims = [1] x [1] : (tensor, tensor) -> tensor @@ -196,12 +210,14 @@ func.func @dot_general(%arg0: tensor {tensorrt.shape_profile = #tenso return %1 : tensor } + +// CHECK: #[[$bounds1:.+]] = #plan.bounds // CHECK-LABEL: @dot_general -// CHECK-SAME: -> (tensor {tensorrt.shape_profile = #plan.bounds}) +// CHECK-SAME: -> (tensor {plan.shape_profile = #[[$bounds1]]}) // ----- -func.func @test_loop_concat(%arg0: tensor<1xf32>, %arg1: tensor<1xi32> {tensorrt.value_bounds = #tensorrt.shape_profile}, %arg2: tensor {tensorrt.shape_profile = #tensorrt.shape_profile}, %arg3: tensor<1024xf32>) -> tensor { +func.func @test_loop_concat(%arg0: tensor<1xf32>, %arg1: tensor<1xi32> {plan.value_bounds = #plan.bounds : tensor<1xi32>, dense<[4]> : tensor<1xi32>>}, %arg2: tensor {plan.shape_profile = #plan.bounds}, %arg3: tensor<1024xf32>) -> tensor { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %dim = tensor.dim %arg2, %c0 : tensor @@ -219,12 +235,13 @@ func.func @test_loop_concat(%arg0: tensor<1xf32>, %arg1: tensor<1xi32> {tensorrt return %2 : tensor } +// CHECK: #[[$bounds1:.+]] = #plan.bounds // CHECK-LABEL: @test_loop_concat -// CHECK-SAME: -> (tensor {tensorrt.shape_profile = #plan.bounds}) +// CHECK-SAME: -> (tensor {plan.shape_profile = #[[$bounds1]]}) // ----- -func.func @real_dynamic_slice(%arg0: tensor {tensorrt.shape_profile = #tensorrt.shape_profile}, %arg1: tensor<1xindex> { tensorrt.value_bounds = #tensorrt.shape_profile}, %arg2: tensor<1xindex> { tensorrt.value_bounds = #tensorrt.shape_profile}, %arg3: tensor<1xindex> { tensorrt.value_bounds = #tensorrt.shape_profile}) -> tensor { +func.func @real_dynamic_slice(%arg0: tensor {plan.shape_profile = #plan.bounds}, %arg1: tensor<1xindex> { plan.value_bounds = #plan.bounds : tensor<1xindex>, dense<[0]> : tensor<1xindex>>}, %arg2: tensor<1xindex> { plan.value_bounds = #plan.bounds : tensor<1xindex>, dense<[5]> : tensor<1xindex>>}, %arg3: tensor<1xindex> { plan.value_bounds = #plan.bounds : tensor<1xindex>, dense<[1]> : tensor<1xindex>>}) -> tensor { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %0 = stablehlo.real_dynamic_slice %arg0, %arg1, %arg2, %arg3 : (tensor, tensor<1xindex>, tensor<1xindex>, tensor<1xindex>) -> tensor @@ -239,15 +256,16 @@ func.func @real_dynamic_slice(%arg0: tensor {tensorrt.shape_profile = #te return %5 : tensor } +// CHECK: #[[$bounds1:.+]] = #plan.bounds // CHECK-LABEL: @real_dynamic_slice -// CHECK-SAME: -> (tensor {tensorrt.shape_profile = #plan.bounds}) +// CHECK-SAME: -> (tensor {plan.shape_profile = #[[$bounds1]]}) // ----- -#bounds0 = #tensorrt.shape_profile -#bounds1 = #tensorrt.shape_profile +#bounds0 = #plan.bounds +#bounds1 = #plan.bounds : tensor<2xi32>, dense<[10,10]> : tensor<2xi32>> -func.func @value_bounds(%arg0: tensor {tensorrt.shape_profile = #bounds0}, %arg1: tensor<2xi32> {tensorrt.value_bounds = #bounds1}) -> (tensor, tensor<2xi32>) { +func.func @value_bounds(%arg0: tensor {plan.shape_profile = #bounds0}, %arg1: tensor<2xi32> {plan.value_bounds = #bounds1}) -> (tensor, tensor<2xi32>) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %d0 = tensor.extract %arg1[%c0] : tensor<2xi32> @@ -258,8 +276,8 @@ func.func @value_bounds(%arg0: tensor {tensorrt.shape_profile = #bounds0} return {tag="return"} %0, %with_bounds : tensor, tensor<2xi32> } +// CHECK-DAG: #[[$bounds2:.+]] = #plan.bounds +// CHECK-DAG: #[[$bounds4:.+]] = #plan.bounds : tensor<2xi64>, dense<10> : tensor<2xi64>> // CHECK-LABEL: func.func @value_bounds -// CHECK-SAME: tensor {tensorrt.shape_profile = #tensorrt.shape_profile} -// CHECK-SAME: tensor<2xi32> {tensorrt.value_bounds = #tensorrt.shape_profile}) -// CHECK-SAME: -> (tensor {tensorrt.shape_profile = #plan.bounds}, -// CHECK-SAME: tensor<2xi32> {tensorrt.value_bounds = #plan.bounds : tensor<2xi64>, dense<10> : tensor<2xi64>>} +// CHECK-SAME: -> (tensor {plan.shape_profile = #[[$bounds2]]}, +// CHECK-SAME: tensor<2xi32> {plan.value_bounds = #[[$bounds4]]} diff --git a/mlir-tensorrt/test/Dialect/Plan/refine-types.mlir b/mlir-tensorrt/test/Dialect/Plan/refine-types.mlir index aebfdf0e5..6c171a05d 100644 --- a/mlir-tensorrt/test/Dialect/Plan/refine-types.mlir +++ b/mlir-tensorrt/test/Dialect/Plan/refine-types.mlir @@ -84,6 +84,34 @@ func.func @refine_add_with_shape(%arg0: tensor, %arg1: tensor) -> // CHECK: %[[v0:.+]] = stablehlo.add %[[arg0]], %[[arg1]] : // CHECK: return %[[v0]] : tensor<1024xf32> +// ----- + +func.func @stablehlo_refine_with_shape_multi_user(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + %c1024 = arith.constant 1024 : i32 + %1 = stablehlo.add %arg0, %arg1 : tensor + %2 = plan.with_shape %1 (%c1024) : (tensor, i32) -> tensor + return %2, %1 : tensor, tensor +} + +// CHECK-LABEL: func.func @stablehlo_refine_with_shape_multi_user +// CHECK-SAME: (%[[arg0:.+]]: tensor, %[[arg1:.+]]: tensor) -> (tensor<1024xf32>, tensor<1024xf32>) { +// CHECK: %[[v0:.+]] = stablehlo.add %[[arg0]], %[[arg1]] : (tensor, tensor) -> tensor<1024xf32> +// CHECK: return %[[v0]], %[[v0]] : tensor<1024xf32>, tensor<1024xf32> + +// ----- + +func.func @tensorrt_refine_with_shape_multi_user(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + %c1024 = arith.constant 1024 : i32 + %1 = tensorrt.element_wise (%arg0, %arg1 : tensor, tensor) -> tensor + %2 = plan.with_shape %1 (%c1024) : (tensor, i32) -> tensor + return %2, %1 : tensor, tensor +} + +// CHECK-LABEL: func.func @tensorrt_refine_with_shape_multi_user +// CHECK-SAME: (%[[arg0:.+]]: tensor, %[[arg1:.+]]: tensor) -> (tensor<1024xf32>, tensor<1024xf32>) { +// CHECK: %[[v0:.+]] = tensorrt.element_wise (%[[arg0]], %[[arg1]] : tensor, tensor) -> tensor<1024xf32> +// CHECK: return %[[v0]], %[[v0]] : + // ----- @@ -159,3 +187,5 @@ func.func @refine_tensorrt_resize_with_shape() -> tensor { // CHECK-SAME: -> tensor<1x1x8x8xf32> // CHECK: %[[v0:.*]] = tensorrt.resize_linear {coordinateTransformation = #tensorrt.resize_coordinate_transformation, selectorForSinglePixel = #tensorrt.resize_selector} %cst, %c : (tensor<1x1x4x4xf32>, tensor<4xi32>) -> tensor<1x1x8x8xf32> // CHECK: return %[[v0]] : tensor<1x1x8x8xf32> + + diff --git a/mlir-tensorrt/test/Dialect/Plan/roundtrip.mlir b/mlir-tensorrt/test/Dialect/Plan/roundtrip.mlir index 8c065aaad..369cbf1a9 100644 --- a/mlir-tensorrt/test/Dialect/Plan/roundtrip.mlir +++ b/mlir-tensorrt/test/Dialect/Plan/roundtrip.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-tensorrt-opt %s -split-input-file | mlir-tensorrt-opt | FileCheck %s +// RUN: mlir-tensorrt-opt %s -split-input-file | mlir-tensorrt-opt -split-input-file -mlir-print-local-scope | FileCheck %s func.func @plan_attrs() attributes { diff --git a/mlir-tensorrt/test/Dialect/Plan/segmentation-pipeline.mlir b/mlir-tensorrt/test/Dialect/Plan/segmentation-pipeline.mlir new file mode 100644 index 000000000..216c20b65 --- /dev/null +++ b/mlir-tensorrt/test/Dialect/Plan/segmentation-pipeline.mlir @@ -0,0 +1,213 @@ +// RUN: mlir-tensorrt-opt -split-input-file \ +// RUN: -plan-segmentation-pipeline -cse -verify-diagnostics %s | FileCheck %s + +builtin.module attributes { + plan.cluster_kinds = [ + #plan.tensorrt_cluster, + #plan.host_cluster + ] +} { + func.func @chlo_erf_to_trt_cluster(%arg0: tensor<1x197x3072xf32>) -> tensor<1x197x3072xf32> { + %0 = chlo.erf %arg0 : tensor<1x197x3072xf32> -> tensor<1x197x3072xf32> + return %0 : tensor<1x197x3072xf32> + } +} + +// CHECK: #[[$profile:.+]] = #tensorrt.shape_profile +// CHECK-LABEL: @chlo_erf_to_trt_cluster +// CHECK-SAME: (%[[arg0:.+]]: tensor<1x197x3072xf32>) +// CHECK: %[[v0:.+]] = tensor.empty() : tensor<1x197x3072xf32> +// CHECK: %[[v1:.+]] = tensorrt.call @trt_engines::@tensorrt_cluster(%[[arg0]] : tensor<1x197x3072xf32>) outs(%[[v0]] : tensor<1x197x3072xf32>) -> tensor<1x197x3072xf32> +// CHECK: return %[[v1]] : tensor<1x197x3072xf32> + +// CHECK: tensorrt.module @trt_engines +// CHECK-LABEL: func.func @tensorrt_cluster +// CHECK-SAME: (%[[arg0:.+]]: tensor<1x197x3072xf32>) -> (tensor<1x197x3072xf32> {tensorrt.shape_profile = #[[$profile]]}) attributes {cluster.tensorrt} +// CHECK: %[[v0:.+]] = chlo.erf %[[arg0]] +// CHECK: return %[[v0]] + +// ----- + +builtin.module attributes { + plan.cluster_kinds = [ + #plan.tensorrt_cluster, + #plan.host_cluster + ] +} { + +func.func @reduce(%arg0: tensor<4xi32>, %arg1: tensor) -> (tensor, tensor) { + %0 = stablehlo.constant dense<0.000000e+00> : tensor + %1 = stablehlo.constant dense<1.000000e+00> : tensor + %2 = stablehlo.constant dense<0> : tensor + %3 = stablehlo.compare EQ, %2, %arg1 : (tensor, tensor) -> tensor + %4 = stablehlo.reduce(%arg0 init: %2) across dimensions = [0] : (tensor<4xi32>, tensor) -> tensor + reducer(%arg6: tensor, %arg7: tensor) { + %27 = stablehlo.add %arg6, %arg7 : tensor + stablehlo.return %27 : tensor + } + + return %4, %3 : tensor, tensor +} + +} + +// CHECK: #[[$profile:.+]] = #tensorrt.shape_profile +// CHECK-LABEL: func.func @reduce +// CHECK-SAME: (%[[arg0:.+]]: tensor<4xi32>, %[[arg1:.+]]: tensor) -> (tensor, tensor) { +// CHECK: %[[v0:.+]] = tensor.empty() : tensor +// CHECK: %[[v1:.+]] = tensor.empty() : tensor +// CHECK: %[[v2:.+]]:2 = tensorrt.call @trt_engines::@tensorrt_cluster(%[[arg1]], %[[arg0]] : tensor, tensor<4xi32>) outs(%[[v0]], %[[v1]] : +// CHECK: return %[[v2]]#1, %[[v2]]#0 : +// CHECK: tensorrt.module @trt_engines +// CHECK-LABEL: func.func @tensorrt_cluster +// CHECK-SAME: (%[[arg0:.+]]: tensor, %[[arg1:.+]]: tensor<4xi32>) -> (tensor {tensorrt.shape_profile = #[[$profile]]}, tensor {tensorrt.shape_profile = #[[$profile]]}) attributes {cluster.tensorrt} +// CHECK: stablehlo.constant +// CHECK: stablehlo.compare +// CHECK: stablehlo.reduce +// CHECK: return + +// ----- +builtin.module attributes { + plan.cluster_kinds = [ + #plan.tensorrt_cluster, + #plan.host_cluster + ] +} { + +func.func @small_reduce_host(%arg0: tensor<4xi32>, %arg1: tensor) + -> (tensor {tensorrt.host_tensor}, tensor {tensorrt.host_tensor}) { + %0 = stablehlo.constant dense<0.000000e+00> : tensor + %1 = stablehlo.constant dense<1.000000e+00> : tensor + %2 = stablehlo.constant dense<0> : tensor + %3 = stablehlo.compare EQ, %2, %arg1 : (tensor, tensor) -> tensor + %4 = stablehlo.reduce(%arg0 init: %2) across dimensions = [0] : (tensor<4xi32>, tensor) -> tensor + reducer(%arg6: tensor, %arg7: tensor) { + %27 = stablehlo.add %arg6, %arg7 : tensor + stablehlo.return %27 : tensor + } + return %4, %3 : tensor, tensor +} + +} + + +// CHECK-LABEL: func.func @small_reduce_host +// CHECK-SAME: (%[[arg0:.+]]: tensor<4xi32>, %[[arg1:.+]]: tensor) -> (tensor {tensorrt.host_tensor}, tensor {tensorrt.host_tensor}) + +// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index +// CHECK-DAG: %[[extracted:.+]] = tensor.extract %[[arg1]][] : tensor +// CHECK-DAG: %[[extracted_0:.+]] = tensor.extract %[[arg0]][%[[c0]]] : tensor<4xi32> +// CHECK-DAG: %[[extracted_1:.+]] = tensor.extract %[[arg0]][%[[c1]]] : tensor<4xi32> +// CHECK-DAG: %[[extracted_2:.+]] = tensor.extract %[[arg0]][%[[c2]]] : tensor<4xi32> +// CHECK-DAG: %[[extracted_3:.+]] = tensor.extract %[[arg0]][%[[c3]]] : tensor<4xi32> +// CHECK: %[[v0:.+]]:2 = call @host_cluster(%[[extracted]], %[[extracted_0]], %[[extracted_1]], %[[extracted_2]], %[[extracted_3]]) : +// CHECK: %[[from_elements:.+]] = tensor.from_elements %[[v0]]#0 : tensor +// CHECK: %[[from_elements_4:.+]] = tensor.from_elements %[[v0]]#1 : tensor +// CHECK: return %[[from_elements_4]], %[[from_elements]] : tensor, tensor +// CHECK-LABEL: private @host_cluster +// CHECK-SAME: (%[[arg0:.+]]: i32, %[[arg1:.+]]: i32, %[[arg2:.+]]: i32, %[[arg3:.+]]: i32, %[[arg4:.+]]: i32) -> (i1, i32) attributes {cluster.host} +// CHECK-DAG: %[[v0:.+]] = stablehlo.constant dense<0> : tensor +// CHECK-DAG: %[[from_elements:.+]] = tensor.from_elements %[[arg0]] : tensor +// CHECK-DAG: %[[from_elements_0:.+]] = tensor.from_elements %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : tensor<4xi32> +// CHECK: %[[v1:.+]] = stablehlo.compare EQ, %[[v0]] +// CHECK: %[[v2:.+]] = stablehlo.reduce(%[[from_elements_0]] init: %[[v0]]) applies stablehlo.add across dimensions = [0] : +// CHECK: %[[extracted:.+]] = tensor.extract %[[v1]][] +// CHECK: %[[extracted_1:.+]] = tensor.extract %[[v2]][] +// CHECK: return %[[extracted]], %[[extracted_1]] + +// ----- + +// Quantize f32 -> int8 +func.func @main(%arg0: tensor<2x3x300x300xf32>) -> tensor<2x3x300x300xi8> { + %0 = stablehlo.composite "tensorrt.pt_q" %arg0 {composite_attributes = {axis = -1 : i32, scale = dense<8.000000e-01> : tensor}, decomposition = @pt_q} : (tensor<2x3x300x300xf32>) -> tensor<2x3x300x300xi8> + return %0 : tensor<2x3x300x300xi8> +} +func.func private @pt_q(%arg0: tensor<2x3x300x300xf32>) -> tensor<2x3x300x300xi8> attributes {plan.decomposition} { + %cst = stablehlo.constant dense<-1.280000e+02> : tensor + %cst_0 = stablehlo.constant dense<1.270000e+02> : tensor + %cst_1 = stablehlo.constant dense<8.000000e-01> : tensor + %0 = stablehlo.broadcast_in_dim %cst_1, dims = [] : (tensor) -> tensor<2x3x300x300xf32> + %1 = stablehlo.divide %arg0, %0 : tensor<2x3x300x300xf32> + %2 = stablehlo.round_nearest_even %1 : tensor<2x3x300x300xf32> + %3 = stablehlo.clamp %cst, %2, %cst_0 : (tensor, tensor<2x3x300x300xf32>, tensor) -> tensor<2x3x300x300xf32> + %4 = stablehlo.convert %3 : (tensor<2x3x300x300xf32>) -> tensor<2x3x300x300xi8> + return %4 : tensor<2x3x300x300xi8> +} + +// CHECK-LABEL: func.func @main +// CHECK-SAME: (%[[arg0:.+]]: tensor<2x3x300x300xf32>) -> tensor<2x3x300x300xi8> +// CHECK-NEXT: %[[v0:.+]] = tensor.empty() : tensor<2x3x300x300xi8> +// CHECK-NEXT: %[[v1:.+]] = tensorrt.call @trt_engines::@tensorrt_cluster(%[[arg0]] : tensor<2x3x300x300xf32>) outs(%[[v0]] : tensor<2x3x300x300xi8>) -> tensor<2x3x300x300xi8> +// CHECK-NEXT: return %[[v1]] : tensor<2x3x300x300xi8> + +// CHECK-LABEL: tensorrt.module @trt_engines +// CHECK-LABEL: func.func @tensorrt_cluster +// CHECK-SAME: (%[[arg0:.+]]: tensor<2x3x300x300xf32> +// CHECK-NEXT: %[[v0:.+]] = stablehlo.composite "tensorrt.pt_q" %[[arg0]] {composite_attributes = {axis = -1 : i32, scale = dense<8.000000e-01> : tensor}, decomposition = @pt_q} : (tensor<2x3x300x300xf32>) -> tensor<2x3x300x300xi8> +// CHECK-NEXT: return %[[v0]] : tensor<2x3x300x300xi8> + +// CHECK-LABEL: func.func private @pt_q +// CHECK-SAME: (%[[arg0:.+]]: tensor<2x3x300x300xf32>) +// CHECK-SAME: attributes {plan.decomposition} +// CHECK-NEXT: %[[v0:.+]] = stablehlo.constant dense<-1.280000e+02> : tensor +// CHECK-NEXT: %[[v1:.+]] = stablehlo.constant dense<1.270000e+02> : tensor +// CHECK-NEXT: %[[v2:.+]] = stablehlo.constant dense<8.000000e-01> : tensor +// CHECK-NEXT: %[[v3:.+]] = stablehlo.broadcast_in_dim %[[v2]], dims = [] : (tensor) -> tensor<2x3x300x300xf32> +// CHECK-NEXT: %[[v4:.+]] = stablehlo.divide %[[arg0]], %[[v3]] : tensor<2x3x300x300xf32> +// CHECK-NEXT: %[[v5:.+]] = stablehlo.round_nearest_even %[[v4]] : tensor<2x3x300x300xf32> +// CHECK-NEXT: %[[v6:.+]] = stablehlo.clamp %[[v0]], %[[v5]], %[[v1]] +// CHECK-NEXT: %[[v7:.+]] = stablehlo.convert %[[v6]] +// CHECK-NEXT: return %[[v7]] : tensor<2x3x300x300xi8> + +// ----- + +builtin.module @simple_gather_dynamic { +func.func @simple_gather_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { + %c1 = stablehlo.constant dense<1> : tensor<1xi32> + %c256 = stablehlo.constant dense<256> : tensor<1xi32> + %dim = stablehlo.get_dimension_size %arg0, dim = 1 : (tensor) -> tensor + %dim.1 = stablehlo.reshape %dim : (tensor) -> tensor<1xi32> + %shape = stablehlo.concatenate %c1, %dim.1, %c256, %c256, dim = 0 : + (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + %0 = "stablehlo.dynamic_gather"(%arg0, %arg1, %shape) { + dimension_numbers = #stablehlo.gather< + offset_dims = [1, 2, 3], + collapsed_slice_dims = [0], + start_index_map = [0], + index_vector_dim = 1>, + indices_are_sorted = false, slice_sizes = array + } : (tensor, tensor, tensor<4xi32>) -> tensor + return %0 : tensor +} +} + +// CHECK: #[[$map:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) * 65536)> +// CHECK-LABEL: func.func @simple_gather_dynamic +// CHECK-SAME: (%[[arg0:.+]]: tensor, %[[arg1:.+]]: tensor) -> tensor +// CHECK-DAG: %[[c256:.+]] = arith.constant 256 : index +// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[dim:.+]] = tensor.dim %[[arg1]], %[[c0]] : tensor +// CHECK-DAG: %[[dim_0:.+]] = tensor.dim %[[arg0]], %[[c1]] : tensor +// CHECK-DAG: %[[v0:.+]] = arith.index_cast %[[dim_0]] : index to i32 +// CHECK-DAG: %[[v1:.+]] = arith.index_cast %[[v0]] : i32 to index +// CHECK-DAG: %[[v2:.+]] = tensor.empty() : tensor<65536xi32> +// CHECK-DAG: %[[v3:.+]] = affine.apply #[[$map]]()[%[[dim]], %[[v1]]] +// CHECK-DAG: %[[extracted_slice:.+]] = tensor.extract_slice %[[v2]][0] [%[[v3]]] [1] : tensor<65536xi32> to tensor +// CHECK-DAG: %[[from_elements:.+]] = tensor.from_elements %[[dim]], %[[v1]], %[[c256]], %[[c256]] : tensor<4xindex> +// CHECK-DAG: %[[reshape:.+]] = tensor.reshape %[[extracted_slice]](%[[from_elements]]) : (tensor, tensor<4xindex>) -> tensor +// CHECK-DAG: %[[v4:.+]] = tensorrt.call @trt_engines::@tensorrt_cluster(%[[arg1]], %[[arg0]] +// CHECK-DAG: return %[[v4]] : tensor + +// CHECK-LABEL: func.func @tensorrt_cluster +// CHECK-SAME: (%[[arg0:.+]]: tensor{{.*}}, %[[arg1:.+]]: tensor{{.*}}) +// CHECK-DAG: %[[c:.+]] = stablehlo.constant dense<1> : tensor<1xi32> +// CHECK-DAG: %[[c_0:.+]] = stablehlo.constant dense<256> : tensor<1xi32> +// CHECK-DAG: %[[v0:.+]] = stablehlo.get_dimension_size %[[arg1]], dim = 1 +// CHECK-DAG: %[[v1:.+]] = stablehlo.reshape %[[v0]] : (tensor) -> tensor<1xi32> +// CHECK-DAG: %[[v2:.+]] = stablehlo.concatenate %[[c]], %[[v1]], %[[c_0]], %[[c_0]] +// CHECK-DAG: %[[v3:.+]] = "stablehlo.dynamic_gather"(%[[arg1]], %[[arg0]], %[[v2]]) +// CHECK-DAG: return %[[v3]] : tensor \ No newline at end of file diff --git a/mlir-tensorrt/test/Dialect/Plan/stablehlo-clustering.mlir b/mlir-tensorrt/test/Dialect/Plan/stablehlo-clustering.mlir index b44382d16..d755eac01 100644 --- a/mlir-tensorrt/test/Dialect/Plan/stablehlo-clustering.mlir +++ b/mlir-tensorrt/test/Dialect/Plan/stablehlo-clustering.mlir @@ -1,163 +1,133 @@ // RUN: mlir-tensorrt-opt -split-input-file \ -// RUN: -plan-segmentation-pipeline -cse -verify-diagnostics %s | FileCheck %s - -builtin.module attributes { - plan.cluster_kinds = [ - #plan.tensorrt_cluster, - #plan.host_cluster - ] -} { - func.func @chlo_erf_to_trt_cluster(%arg0: tensor<1x197x3072xf32>) -> tensor<1x197x3072xf32> { - %0 = chlo.erf %arg0 : tensor<1x197x3072xf32> -> tensor<1x197x3072xf32> - return %0 : tensor<1x197x3072xf32> - } +// RUN: -stablehlo-clustering %s | FileCheck %s + +// Check that we can recognize `stablehlo.dynamic_gather` using `plan.with_shape|plan.with_values` to prove required shape/value equivalence +// propositions. + +func.func @simple_gather_dynamic(%arg0: tensor, %arg1: tensor, %arg2: tensor<4xi32>) -> tensor { + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : i32 + %dim.0 = tensor.dim %arg0, %c0 : tensor + %dim = tensor.dim %arg0, %c1 : tensor + %dim.1 = arith.index_cast %dim : index to i32 + %c1_i32 = arith.constant 1 : i32 + %data = plan.with_shape %arg0(%dim.0, %dim.1, %c256, %c256) : (tensor, index, i32, i32, i32) -> tensor + %slice_sizes = plan.with_values %arg2(%c1_i32, %dim.1, %c256, %c256) : tensor<4xi32> + %0 = "stablehlo.dynamic_gather"(%data, %arg1, %slice_sizes) { + dimension_numbers = #stablehlo.gather< + offset_dims = [1, 2, 3], + collapsed_slice_dims = [0], + start_index_map = [0], + index_vector_dim = 1>, + indices_are_sorted = false, slice_sizes = array + } : (tensor, tensor, tensor<4xi32>) -> tensor + return %0 : tensor } -// CHECK: #[[$profile:.+]] = #tensorrt.shape_profile -// CHECK-LABEL: @chlo_erf_to_trt_cluster -// CHECK-SAME: (%[[arg0:.+]]: tensor<1x197x3072xf32>) -// CHECK: %[[v0:.+]] = tensor.empty() : tensor<1x197x3072xf32> -// CHECK: %[[v1:.+]] = tensorrt.call @trt_engines::@tensorrt_cluster(%[[arg0]] : tensor<1x197x3072xf32>) outs(%[[v0]] : tensor<1x197x3072xf32>) -> tensor<1x197x3072xf32> -// CHECK: return %[[v1]] : tensor<1x197x3072xf32> - -// CHECK: tensorrt.module @trt_engines -// CHECK-LABEL: func.func @tensorrt_cluster -// CHECK-SAME: (%[[arg0:.+]]: tensor<1x197x3072xf32>) -> (tensor<1x197x3072xf32> {tensorrt.shape_profile = #[[$profile]]}) attributes {cluster.tensorrt} -// CHECK: %[[v0:.+]] = chlo.erf %[[arg0]] -// CHECK: return %[[v0]] +// CHECK-LABEL: func.func @simple_gather_dynamic( +// CHECK: %[[v1:.+]] = plan.inline_group target(#plan.tensorrt_cluster< +// CHECK-NEXT: with_shape +// CHECK-NEXT: with_values +// CHECK-NEXT: stablehlo.dynamic_gather +// CHECK-NEXT: yield // ----- -builtin.module attributes { +// Test that interleaved `plan.with_values` and `arith` dialect operations don't disrupt +// the clustering of stablehlo ops that can be put into host clusters. + +builtin.module @host_clusters_with_values attributes { plan.cluster_kinds = [ #plan.tensorrt_cluster, #plan.host_cluster ] } { -func.func @reduce(%arg0: tensor<4xi32>, %arg1: tensor) -> (tensor, tensor) { +func.func @test(%arg0: tensor<4xi32>, %arg1: tensor) + -> (tensor {tensorrt.host_tensor}, tensor {tensorrt.host_tensor}) { %0 = stablehlo.constant dense<0.000000e+00> : tensor %1 = stablehlo.constant dense<1.000000e+00> : tensor %2 = stablehlo.constant dense<0> : tensor + + %c0_i32 = arith.constant 0 : i32 %3 = stablehlo.compare EQ, %2, %arg1 : (tensor, tensor) -> tensor + %extract = tensor.extract %arg1[] : tensor + %cmp = arith.cmpi eq, %c0_i32, %extract : i32 + %with_values = plan.with_values %3(%cmp) : tensor + %4 = stablehlo.reduce(%arg0 init: %2) across dimensions = [0] : (tensor<4xi32>, tensor) -> tensor reducer(%arg6: tensor, %arg7: tensor) { %27 = stablehlo.add %arg6, %arg7 : tensor stablehlo.return %27 : tensor } - - return %4, %3 : tensor, tensor + return %4, %with_values : tensor, tensor } } -// CHECK: #[[$profile:.+]] = #tensorrt.shape_profile -// CHECK-LABEL: func.func @reduce -// CHECK-SAME: (%[[arg0:.+]]: tensor<4xi32>, %[[arg1:.+]]: tensor) -> (tensor, tensor) { -// CHECK: %[[v0:.+]] = tensor.empty() : tensor -// CHECK: %[[v1:.+]] = tensor.empty() : tensor -// CHECK: %[[v2:.+]]:2 = tensorrt.call @trt_engines::@tensorrt_cluster(%[[arg1]], %[[arg0]] : tensor, tensor<4xi32>) outs(%[[v0]], %[[v1]] : -// CHECK: return %[[v2]]#1, %[[v2]]#0 : -// CHECK: tensorrt.module @trt_engines -// CHECK-LABEL: func.func @tensorrt_cluster -// CHECK-SAME: (%[[arg0:.+]]: tensor, %[[arg1:.+]]: tensor<4xi32>) -> (tensor {tensorrt.shape_profile = #[[$profile]]}, tensor {tensorrt.shape_profile = #[[$profile]]}) attributes {cluster.tensorrt} -// CHECK: stablehlo.constant -// CHECK: stablehlo.compare -// CHECK: stablehlo.reduce -// CHECK: return +// CHECK-LABEL: module @host_clusters_with_values attributes +// CHECK-LABEL: func.func @test +// CHECK-SAME: (%[[arg0:.+]]: tensor<4xi32>, %[[arg1:.+]]: tensor) +// CHECK-DAG: %[[cst:.+]] = stablehlo.constant dense<0.000000e+00> : tensor +// CHECK-DAG: %[[cst_0:.+]] = stablehlo.constant dense<1.000000e+00> : tensor +// CHECK-DAG: %[[c:.+]] = stablehlo.constant dense<0> : tensor +// CHECK-DAG: %[[c0_i32:.+]] = arith.constant 0 : i32 +// CHECK-DAG: %[[c_1:.+]] = stablehlo.constant dense<0> : tensor +// CHECK-DAG: %[[extracted:.+]] = tensor.extract %[[arg1]][] : tensor +// CHECK-DAG: %[[v0:.+]] = arith.cmpi eq, %[[c0_i32]], %[[extracted]] : i32 +// CHECK-DAG: %[[v1:.+]]:2 = plan.inline_group target(#plan.host_cluster) +// CHECK-DAG: %[[v2:.+]] = stablehlo.compare EQ, %[[c_1]], %[[arg1]] : +// CHECK-DAG: %[[v3:.+]] = with_values %[[v2]](%[[v0]]) : tensor +// CHECK-DAG: %[[v4:.+]] = stablehlo.reduce(%[[arg0]] init: %[[c]]) +// CHECK-DAG: yield %[[v3]], %[[v4]] : tensor, tensor +// CHECK-DAG: return %[[v1]]#1, %[[v1]]#0 : tensor, tensor // ----- + +// Ensure that we don't create clusters containing 'plan.with_values' or +// 'plan.with_shape' operations. This can cause some un-intended side-effects +// if the compiler introduces extra ops to outline these clusters (e.g. scalar +// host clusters create extra `tensor.extract` and `tensor.from_elements` at +// the boundaries). These extra ops can block optimizations and +// reduce performance. + builtin.module attributes { plan.cluster_kinds = [ - #plan.tensorrt_cluster, #plan.host_cluster ] } { -func.func @small_reduce_host(%arg0: tensor<4xi32>, %arg1: tensor) - -> (tensor {tensorrt.host_tensor}, tensor {tensorrt.host_tensor}) { - %0 = stablehlo.constant dense<0.000000e+00> : tensor - %1 = stablehlo.constant dense<1.000000e+00> : tensor - %2 = stablehlo.constant dense<0> : tensor - %3 = stablehlo.compare EQ, %2, %arg1 : (tensor, tensor) -> tensor - %4 = stablehlo.reduce(%arg0 init: %2) across dimensions = [0] : (tensor<4xi32>, tensor) -> tensor - reducer(%arg6: tensor, %arg7: tensor) { - %27 = stablehlo.add %arg6, %arg7 : tensor - stablehlo.return %27 : tensor - } - return %4, %3 : tensor, tensor +func.func @host_cluster_filtering(%arg0: tensor, %arg1: i32) + -> (tensor {tensorrt.host_tensor}, tensor {tensorrt.host_tensora}) { + %0 = plan.with_values %arg0 (%arg1) : tensor + %1 = stablehlo.constant dense<1> : tensor + %c1_i32 = arith.constant 1 : i32 + %2 = plan.with_values %1(%c1_i32) : tensor + return %0, %2 : tensor, tensor } } - -// CHECK-LABEL: func.func @small_reduce_host -// CHECK-SAME: (%[[arg0:.+]]: tensor<4xi32>, %[[arg1:.+]]: tensor) -> (tensor {tensorrt.host_tensor}, tensor {tensorrt.host_tensor}) - -// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index -// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index -// CHECK-DAG: %[[extracted:.+]] = tensor.extract %[[arg1]][] : tensor -// CHECK-DAG: %[[extracted_0:.+]] = tensor.extract %[[arg0]][%[[c0]]] : tensor<4xi32> -// CHECK-DAG: %[[extracted_1:.+]] = tensor.extract %[[arg0]][%[[c1]]] : tensor<4xi32> -// CHECK-DAG: %[[extracted_2:.+]] = tensor.extract %[[arg0]][%[[c2]]] : tensor<4xi32> -// CHECK-DAG: %[[extracted_3:.+]] = tensor.extract %[[arg0]][%[[c3]]] : tensor<4xi32> -// CHECK: %[[v0:.+]]:2 = call @host_cluster(%[[extracted]], %[[extracted_0]], %[[extracted_1]], %[[extracted_2]], %[[extracted_3]]) : -// CHECK: %[[from_elements:.+]] = tensor.from_elements %[[v0]]#0 : tensor -// CHECK: %[[from_elements_4:.+]] = tensor.from_elements %[[v0]]#1 : tensor -// CHECK: return %[[from_elements_4]], %[[from_elements]] : tensor, tensor -// CHECK-LABEL: private @host_cluster -// CHECK-SAME: (%[[arg0:.+]]: i32, %[[arg1:.+]]: i32, %[[arg2:.+]]: i32, %[[arg3:.+]]: i32, %[[arg4:.+]]: i32) -> (i1, i32) attributes {cluster.host} -// CHECK-DAG: %[[v0:.+]] = stablehlo.constant dense<0> : tensor -// CHECK-DAG: %[[from_elements:.+]] = tensor.from_elements %[[arg0]] : tensor -// CHECK-DAG: %[[from_elements_0:.+]] = tensor.from_elements %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : tensor<4xi32> -// CHECK: %[[v1:.+]] = stablehlo.compare EQ, %[[v0]] -// CHECK: %[[v2:.+]] = stablehlo.reduce(%[[from_elements_0]] init: %[[v0]]) applies stablehlo.add across dimensions = [0] : -// CHECK: %[[extracted:.+]] = tensor.extract %[[v1]][] -// CHECK: %[[extracted_1:.+]] = tensor.extract %[[v2]][] -// CHECK: return %[[extracted]], %[[extracted_1]] +// CHECK-LABEL: func.func @host_cluster_filtering +// CHECK-NOT: plan.inline_group // ----- -// Quantize f32 -> int8 -func.func @main(%arg0: tensor<2x3x300x300xf32>) -> tensor<2x3x300x300xi8> { - %0 = stablehlo.composite "tensorrt.pt_q" %arg0 {composite_attributes = {axis = -1 : i32, scale = dense<8.000000e-01> : tensor}, decomposition = @pt_q} : (tensor<2x3x300x300xf32>) -> tensor<2x3x300x300xi8> - return %0 : tensor<2x3x300x300xi8> +builtin.module attributes { + plan.cluster_kinds = [ + #plan.tensorrt_cluster + ] +} { + +func.func @tensorrt_cluster_filtering(%arg0: tensor, %arg1: i32, %arg2: tensor) + -> (tensor, tensor {tensorrt.host_tensor}) { + %0 = plan.with_shape %arg0 (%arg1) : (tensor, i32) -> tensor + %1 = plan.with_values %arg2 (%arg1) : tensor + return %0, %1 : tensor, tensor } -func.func private @pt_q(%arg0: tensor<2x3x300x300xf32>) -> tensor<2x3x300x300xi8> attributes {plan.decomposition} { - %cst = stablehlo.constant dense<-1.280000e+02> : tensor - %cst_0 = stablehlo.constant dense<1.270000e+02> : tensor - %cst_1 = stablehlo.constant dense<8.000000e-01> : tensor - %0 = stablehlo.broadcast_in_dim %cst_1, dims = [] : (tensor) -> tensor<2x3x300x300xf32> - %1 = stablehlo.divide %arg0, %0 : tensor<2x3x300x300xf32> - %2 = stablehlo.round_nearest_even %1 : tensor<2x3x300x300xf32> - %3 = stablehlo.clamp %cst, %2, %cst_0 : (tensor, tensor<2x3x300x300xf32>, tensor) -> tensor<2x3x300x300xf32> - %4 = stablehlo.convert %3 : (tensor<2x3x300x300xf32>) -> tensor<2x3x300x300xi8> - return %4 : tensor<2x3x300x300xi8> + } -// CHECK-LABEL: func.func @main -// CHECK-SAME: (%[[arg0:.+]]: tensor<2x3x300x300xf32>) -> tensor<2x3x300x300xi8> -// CHECK-NEXT: %[[v0:.+]] = tensor.empty() : tensor<2x3x300x300xi8> -// CHECK-NEXT: %[[v1:.+]] = tensorrt.call @trt_engines::@tensorrt_cluster(%[[arg0]] : tensor<2x3x300x300xf32>) outs(%[[v0]] : tensor<2x3x300x300xi8>) -> tensor<2x3x300x300xi8> -// CHECK-NEXT: return %[[v1]] : tensor<2x3x300x300xi8> - -// CHECK-LABEL: tensorrt.module @trt_engines -// CHECK-LABEL: func.func @tensorrt_cluster -// CHECK-SAME: (%[[arg0:.+]]: tensor<2x3x300x300xf32> -// CHECK-NEXT: %[[v0:.+]] = stablehlo.composite "tensorrt.pt_q" %[[arg0]] {composite_attributes = {axis = -1 : i32, scale = dense<8.000000e-01> : tensor}, decomposition = @pt_q} : (tensor<2x3x300x300xf32>) -> tensor<2x3x300x300xi8> -// CHECK-NEXT: return %[[v0]] : tensor<2x3x300x300xi8> - -// CHECK-LABEL: func.func private @pt_q -// CHECK-SAME: (%[[arg0:.+]]: tensor<2x3x300x300xf32>) -// CHECK-SAME: attributes {plan.decomposition} -// CHECK-NEXT: %[[v0:.+]] = stablehlo.constant dense<-1.280000e+02> : tensor -// CHECK-NEXT: %[[v1:.+]] = stablehlo.constant dense<1.270000e+02> : tensor -// CHECK-NEXT: %[[v2:.+]] = stablehlo.constant dense<8.000000e-01> : tensor -// CHECK-NEXT: %[[v3:.+]] = stablehlo.broadcast_in_dim %[[v2]], dims = [] : (tensor) -> tensor<2x3x300x300xf32> -// CHECK-NEXT: %[[v4:.+]] = stablehlo.divide %[[arg0]], %[[v3]] : tensor<2x3x300x300xf32> -// CHECK-NEXT: %[[v5:.+]] = stablehlo.round_nearest_even %[[v4]] : tensor<2x3x300x300xf32> -// CHECK-NEXT: %[[v6:.+]] = stablehlo.clamp %[[v0]], %[[v5]], %[[v1]] -// CHECK-NEXT: %[[v7:.+]] = stablehlo.convert %[[v6]] -// CHECK-NEXT: return %[[v7]] : tensor<2x3x300x300xi8> \ No newline at end of file +// CHECK-LABEL: func.func @tensorrt_cluster_filtering +// CHECK-NOT: plan.inline_group diff --git a/mlir-tensorrt/test/Dialect/StableHloExt/stablehlo-prepare-convolution.mlir b/mlir-tensorrt/test/Dialect/StableHloExt/canonicalize-convolution.mlir similarity index 98% rename from mlir-tensorrt/test/Dialect/StableHloExt/stablehlo-prepare-convolution.mlir rename to mlir-tensorrt/test/Dialect/StableHloExt/canonicalize-convolution.mlir index b8e3551d0..3dfe6a6ce 100644 --- a/mlir-tensorrt/test/Dialect/StableHloExt/stablehlo-prepare-convolution.mlir +++ b/mlir-tensorrt/test/Dialect/StableHloExt/canonicalize-convolution.mlir @@ -1,27 +1,4 @@ -// RUN: mlir-tensorrt-opt %s -tensorrt-stablehlo-input-preprocessing -split-input-file | FileCheck %s - -func.func @conv2d_nchw_kcrs_padded( - %arg0: tensor<1x2x32x64xf32>, - %arg1: tensor<128x2x3x3xf32>) - -> tensor<1x128x28x62xf32> { - %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], - window = {stride = [1, 1], pad=[[0, 0], [0, 0]], rhs_dilate = [2, 1]} { - batch_group_count = 1 : i64, - feature_group_count = 1 : i64 - } : (tensor<1x2x32x64xf32>, tensor<128x2x3x3xf32>) - -> tensor<1x128x28x62xf32> - func.return %0 : tensor<1x128x28x62xf32> -} - -// CHECK-LABEL: @conv2d_nchw_kcrs_padded -// CHECK-SAME: (%[[arg0:.+]]: tensor<1x2x32x64xf32>, %[[arg1:.+]]: tensor<128x2x3x3xf32>) -> tensor<1x128x28x62xf32> { -// CHECK: %[[v0:.+]] = stablehlo.convolution(%[[arg0]], %[[arg1]]) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], -// CHECK-SAME: window = {stride = [1, 1], pad = {{\[}}[0, 0], [0, 0]], rhs_dilate = [2, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} -// CHECK-SAME: : (tensor<1x2x32x64xf32>, tensor<128x2x3x3xf32>) -> tensor<1x128x28x62xf32> -// CHECK: return %[[v0]] : tensor<1x128x28x62xf32> - -// ----- +// RUN: mlir-tensorrt-opt %s -stablehlo-ext-canonicalize-convolution -split-input-file | FileCheck %s func.func @conv2d_nhwc_rsck_no_padding_dilated( %arg0: tensor<1x32x64x2xf32>, @@ -49,6 +26,29 @@ func.func @conv2d_nhwc_rsck_no_padding_dilated( // ----- +func.func @conv2d_nchw_kcrs_padded( + %arg0: tensor<1x2x32x64xf32>, + %arg1: tensor<128x2x3x3xf32>) + -> tensor<1x128x28x62xf32> { + %0 = stablehlo.convolution(%arg0, %arg1) + dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], + window = {stride = [1, 1], pad=[[0, 0], [0, 0]], rhs_dilate = [2, 1]} { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : (tensor<1x2x32x64xf32>, tensor<128x2x3x3xf32>) + -> tensor<1x128x28x62xf32> + func.return %0 : tensor<1x128x28x62xf32> +} + +// CHECK-LABEL: @conv2d_nchw_kcrs_padded +// CHECK-SAME: (%[[arg0:.+]]: tensor<1x2x32x64xf32>, %[[arg1:.+]]: tensor<128x2x3x3xf32>) -> tensor<1x128x28x62xf32> { +// CHECK: %[[v0:.+]] = stablehlo.convolution(%[[arg0]], %[[arg1]]) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], +// CHECK-SAME: window = {stride = [1, 1], pad = {{\[}}[0, 0], [0, 0]], rhs_dilate = [2, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} +// CHECK-SAME: : (tensor<1x2x32x64xf32>, tensor<128x2x3x3xf32>) -> tensor<1x128x28x62xf32> +// CHECK: return %[[v0]] : tensor<1x128x28x62xf32> + +// ----- + func.func @conv1d_nhc_rcf( %arg0: tensor<1x32x2xf32>, %arg1: tensor<3x2x128xf32>) diff --git a/mlir-tensorrt/test/Dialect/StableHloExt/canonicalize-gather.mlir b/mlir-tensorrt/test/Dialect/StableHloExt/canonicalize-gather.mlir index 7815efc65..55253f83e 100644 --- a/mlir-tensorrt/test/Dialect/StableHloExt/canonicalize-gather.mlir +++ b/mlir-tensorrt/test/Dialect/StableHloExt/canonicalize-gather.mlir @@ -16,15 +16,13 @@ func.func @transform_start_indices(%operand: tensor<33x34xf32>, // CHECK-LABEL: func @transform_start_indices // CHECK-SAME: %[[OPERAND:.*]]: tensor<33x34xf32> // CHECK-SAME: %[[INDICES:.*]]: tensor<42x43xi32> -// CHECK: %[[WITH_IVD:.*]] = tensor.expand_shape %[[INDICES]] -// CHECK-SAME: into tensor<42x43x1xi32> -// CHECK: %[[FLATTENED:.*]] = tensor.collapse_shape %[[WITH_IVD]] -// CHECK-SAME: into tensor<1806x1xi32> +// CHECK: %[[FLATTENED:.*]] = stablehlo.reshape %[[INDICES]] +// CHECK-SAME: -> tensor<1806x1xi32> // CHECK: %[[GATHER:.*]] = "stablehlo.gather"(%[[OPERAND]], %[[FLATTENED]]) // CHECK-SAME: offset_dims = [1, 2] // CHECK-SAME: index_vector_dim = 1 -// CHECK: %[[RESULT:.*]] = tensor.expand_shape %[[GATHER]] -// CHECK-SAME: into tensor<42x43x7x8xf32> +// CHECK: %[[RESULT:.*]] = stablehlo.reshape %[[GATHER]] +// CHECK-SAME: -> tensor<42x43x7x8xf32> // CHECK: return %[[RESULT]] // ----- @@ -49,8 +47,8 @@ func.func @remove_collapsed_slice_dims(%operand: tensor<33x34xf32>, // CHECK: %[[GATHER:.*]] = "stablehlo.gather"(%[[OPERAND]], %[[INDICES]]) // CHECK-SAME: offset_dims = [1, 2] // CHECK-NOT: collapsed_slice_dims -// CHECK: %[[RESULT:.*]] = tensor.collapse_shape %[[GATHER]] -// CHECK-SAME: into tensor<42xf32> +// CHECK: %[[RESULT:.*]] = stablehlo.reshape %[[GATHER]] +// CHECK-SAME: -> tensor<42xf32> // CHECK: return %[[RESULT]] // ----- @@ -101,8 +99,8 @@ func.func @collapse_some_dims(%operand: tensor<33x34x35xf32>, // CHECK-SAME: %[[INDICES:.*]]: tensor<42x1x // CHECK: %[[GATHER:.*]] = "stablehlo.gather"(%[[OPERAND]], %[[INDICES]]) // CHECK-SAME: -> tensor<42x1x7x1xf32> -// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[GATHER]] -// CHECK-SAME: into tensor<42x7xf32> +// CHECK: %[[COLLAPSED:.*]] = stablehlo.reshape %[[GATHER]] +// CHECK-SAME: -> tensor<42x7xf32> // CHECK: %[[RESULT:.*]] = stablehlo.transpose %[[COLLAPSED]] // CHECK-SAME: -> tensor<7x42xf32> // CHECK: return %[[RESULT]] @@ -125,13 +123,13 @@ func.func @no_batch_dims(%operand: tensor<8x16xf32>, %indices: tensor<2xi32>) // CHECK-LABEL: func @no_batch_dims // CHECK-SAME: %[[OPERAND:.*]]: tensor<8x16x // CHECK-SAME: %[[INDICES:.*]]: tensor<2x -// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[INDICES]] -// CHECK-SAME: into tensor<1x2xi32> +// CHECK: %[[EXPANDED:.*]] = stablehlo.reshape %[[INDICES]] +// CHECK-SAME: -> tensor<1x2xi32> // CHECK: %[[GATHER:.*]] = "stablehlo.gather"(%[[OPERAND]], %[[EXPANDED]]) // CHECK-SAME: offset_dims = [1, 2] // CHECK-SAME: index_vector_dim = 1 -// CHECK: %[[RESULT:.*]] = tensor.collapse_shape %[[GATHER]] -// CHECK-SAME: into tensor<8x16xf32> +// CHECK: %[[RESULT:.*]] = stablehlo.reshape %[[GATHER]] +// CHECK-SAME: -> tensor<8x16xf32> // CHECK: return %[[RESULT]] // ----- diff --git a/mlir-tensorrt/test/Dialect/StableHloExt/stablehlo-prepare-scatter.mlir b/mlir-tensorrt/test/Dialect/StableHloExt/canonicalize-scatter-nd.mlir similarity index 92% rename from mlir-tensorrt/test/Dialect/StableHloExt/stablehlo-prepare-scatter.mlir rename to mlir-tensorrt/test/Dialect/StableHloExt/canonicalize-scatter-nd.mlir index d8cdbe93b..ee57d711e 100644 --- a/mlir-tensorrt/test/Dialect/StableHloExt/stablehlo-prepare-scatter.mlir +++ b/mlir-tensorrt/test/Dialect/StableHloExt/canonicalize-scatter-nd.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-tensorrt-opt %s --tensorrt-stablehlo-input-preprocessing --stablehlo-aggressive-simplification -split-input-file | FileCheck %s +// RUN: mlir-tensorrt-opt %s --stablehlo-ext-canonicalize-scatter --stablehlo-aggressive-simplification -split-input-file | FileCheck %s func.func @whisper_jax_scatter(%arg0: tensor<1x51865xf32>) -> tensor<1x51865xf32> { @@ -25,7 +25,7 @@ func.func @whisper_jax_scatter(%arg0: tensor<1x51865xf32>) -> tensor<1x51865xf32 // CHECK-LABEL: @whisper_jax_scatter // CHECK-SAME: (%[[arg0:.+]]: tensor<1x51865xf32>) // CHECK-DAG: %[[v0:.+]] = stablehlo.constant dense<0xFF800000> : tensor<1x1xf32> -// CHECK-DAG: %[[cst:.+]] = stablehlo.constant dense<50257> : tensor<1x1xi32> +// CHECK-DAG: %[[cst:.+]] = arith.constant dense<50257> : tensor<1x1xi32> // CHECK: %[[v1:.+]] = stablehlo.reshape %[[arg0]] // CHECK: %[[v2:.+]] = "stablehlo.scatter"(%[[v1]], %[[cst]], %[[v0]]) // CHECK-SAME: indices_are_sorted = false @@ -77,9 +77,9 @@ func.func @whisper_jax_scatter2(%arg0: tensor<1x51865xf32>, %arg1: tensor<88x1xi // ----- func.func @stablehlo_scatter_canonicalize(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>, %arg2: tensor<2xi32>, %arg3: tensor<2x3xf32>, %arg4: tensor<2x3xf32>) -> tensor<3x3xf32> { - %expanded = tensor.expand_shape %arg2 [[0, 1]] output_shape [2, 1] : tensor<2xi32> into tensor<2x1xi32> - %expanded_0 = tensor.expand_shape %arg3 [[0], [1, 2]] output_shape [2, 1, 3] : tensor<2x3xf32> into tensor<2x1x3xf32> - %expanded_1 = tensor.expand_shape %arg4 [[0], [1, 2]] output_shape [2, 1, 3] : tensor<2x3xf32> into tensor<2x1x3xf32> + %expanded = stablehlo.reshape %arg2 : (tensor<2xi32>) -> tensor<2x1xi32> + %expanded_0 = stablehlo.reshape %arg3 : (tensor<2x3xf32>) -> tensor<2x1x3xf32> + %expanded_1 = stablehlo.reshape %arg4 : (tensor<2x3xf32>) -> tensor<2x1x3xf32> %0:2 = "stablehlo.scatter"(%arg0, %arg1, %expanded, %expanded_0, %expanded_1) ({ ^bb0(%arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor): stablehlo.return %arg5, %arg7 : tensor, tensor diff --git a/mlir-tensorrt/test/Dialect/StableHloExt/canonicalize-scatter.mlir b/mlir-tensorrt/test/Dialect/StableHloExt/canonicalize-scatter.mlir index f30094711..29e74b203 100644 --- a/mlir-tensorrt/test/Dialect/StableHloExt/canonicalize-scatter.mlir +++ b/mlir-tensorrt/test/Dialect/StableHloExt/canonicalize-scatter.mlir @@ -24,13 +24,12 @@ func.func @insert_index_vector_and_window_dims(%dst1: tensor<3x3xf32>, // CHECK-SAME: %[[IND:.*]]: tensor<2xi32>, // CHECK-SAME: %[[UPD1:.*]]: tensor<2x3xf32>, %[[UPD2:.*]]: tensor<2x3xf32>) -// CHECK: %[[IND_:.*]] = tensor.expand_shape %[[IND]] {{\[}}[0, 1]] -// CHECK: %[[UPD1_:.*]] = tensor.expand_shape %[[UPD1]] {{\[}}[0], [1, 2]] -// CHECK: %[[UPD2_:.*]] = tensor.expand_shape %[[UPD2]] {{\[}}[0], [1, 2]] +// CHECK: %[[IND_:.*]] = stablehlo.reshape %[[IND]] -// CHECK: "stablehlo.scatter"(%[[DST1]], %[[DST2]], %[[IND_]], %[[UPD1_]], %[[UPD2_]]) -// CHECK: update_window_dims = [1, 2], -// CHECK-SAME: scatter_dims_to_operand_dims = [0], +// CHECK: "stablehlo.scatter"(%[[DST1]], %[[DST2]], %[[IND_]], %[[UPD1]], %[[UPD2]]) +// CHECK: update_window_dims = [1], +// CHECK-SAME: inserted_window_dims = [0] +// CHECK-SAME: scatter_dims_to_operand_dims = [0] // CHECK-SAME: index_vector_dim = 1 // CHECK-SAME: unique_indices = false @@ -53,12 +52,13 @@ func.func @collapse_scatter_dims(%dst: tensor<3x3xf32>, } : (tensor<3x3xf32>, tensor<2x1x2xi32>, tensor<2x1x1x3xf32>) -> tensor<3x3xf32> func.return %0 : tensor<3x3xf32> } + // CHECK-LABEL: func.func @collapse_scatter_dims( // CHECK-SAME: %[[DST:.*]]: tensor<3x3xf32>, %[[IND:.*]]: tensor<2x1x2xi32>, // CHECK-SAME: %[[UPD:.*]]: tensor<2x1x1x3xf32>) -// CHECK: %[[IND_:.*]] = tensor.collapse_shape %[[IND]] {{\[\[}}0, 1], [2]] -// CHECK: %[[UPD_:.*]] = tensor.collapse_shape %[[UPD]] {{\[\[}}0, 1], [2], [3]] +// CHECK: %[[IND_:.*]] = stablehlo.reshape %[[IND]] : (tensor<2x1x2xi32>) -> tensor<2x2xi32> +// CHECK: %[[UPD_:.*]] = stablehlo.reshape %[[UPD]] : (tensor<2x1x1x3xf32>) -> tensor<2x1x3xf32> // CHECK: "stablehlo.scatter"(%[[DST]], %[[IND_]], %[[UPD_]]) // CHECK: #stablehlo.scatter< // CHECK-SAME: update_window_dims = [1, 2], @@ -132,6 +132,43 @@ func.func @transform_updates_and_operands_using_scatter_dims(%dst: tensor<3x4x5x // CHECK-SAME: index_vector_dim = 1 // CHECK: stablehlo.transpose %[[NEW_OP:.*]], dims = [1, 2, 0] : (tensor<5x3x4xf32>) -> tensor<3x4x5xf32> + +// ----- + +func.func @dynamic_transform_updates_and_operands(%dst: tensor<3x?x5xf32>, + %indices: tensor<2x2xi32>, %update: tensor<2x1x1x3xf32>) -> tensor<3x?x5xf32> { + %0 = "stablehlo.scatter"(%dst, %indices, %update) ({ + ^bb0(%u: tensor, %d: tensor): + "stablehlo.return"(%u) : (tensor) -> () + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [1, 2, 3], + inserted_window_dims = [], + scatter_dims_to_operand_dims = [2, 0], + index_vector_dim = 1, + >, + unique_indices = false, + indices_are_sorted = false + } : (tensor<3x?x5xf32>, tensor<2x2xi32>, tensor<2x1x1x3xf32>) -> tensor<3x?x5xf32> + func.return %0 : tensor<3x?x5xf32> +} + +// CHECK-LABEL: func.func @dynamic_transform_updates_and_operands( +// CHECK-SAME: %[[DST:.*]]: tensor<3x?x5xf32>, +// CHECK-SAME: %[[IND:.*]]: tensor<2x2xi32>, +// CHECK-SAME: %[[UPD:.*]]: tensor<2x1x1x3xf32>) + +// CHECK: %[[DST_:.*]] = stablehlo.transpose %[[DST]], +// CHECK-SAME: dims = [2, 0, 1] : (tensor<3x?x5xf32>) -> tensor<5x3x?xf32> +// CHECK: %[[UPD_:.*]] = stablehlo.reshape %[[UPD]] +// CHECK-SAME: (tensor<2x1x1x3xf32>) -> tensor<2x3x1x1xf32> +// CHECK: %[[NEW_OP:.*]] = "stablehlo.scatter"(%[[DST_]], %[[IND]], %[[UPD_]]) +// CHECK: #stablehlo.scatter< +// CHECK-SAME: update_window_dims = [1, 2, 3], +// CHECK-SAME: scatter_dims_to_operand_dims = [0, 1], +// CHECK-SAME: index_vector_dim = 1 +// CHECK: stablehlo.transpose %[[NEW_OP:.*]], dims = [1, 2, 0] : (tensor<5x3x?xf32>) -> tensor<3x?x5xf32> + // ----- func.func @make_scatter_dims_leading_in_updates(%dst: tensor<3xf32>, @@ -188,8 +225,8 @@ func.func @zero_dim_scatter_indices(%dst: tensor<4x4xf32>, // CHECK-SAME: %[[DST:.*]]: tensor<4x4xf32>, %[[IND:.*]]: tensor<2xi32>, // CHECK-SAME: %[[UPD:.*]]: tensor<3x3xf32> -// CHECK: %[[IND_:.*]] = tensor.expand_shape %[[IND]] {{\[}}[0, 1]] -// CHECK: %[[UPD_:.*]] = tensor.expand_shape %[[UPD]] {{\[}}[0, 1], [2]] +// CHECK: %[[IND_:.*]] = stablehlo.reshape %[[IND]] : (tensor<2xi32>) -> tensor<1x2xi32> +// CHECK: %[[UPD_:.*]] = stablehlo.reshape %[[UPD]] : (tensor<3x3xf32>) -> tensor<1x3x3xf32> // CHECK: "stablehlo.scatter"(%[[DST]], %[[IND_]], %[[UPD_]]) // CHECK: #stablehlo.scatter< // CHECK-SAME: update_window_dims = [1, 2], @@ -198,6 +235,40 @@ func.func @zero_dim_scatter_indices(%dst: tensor<4x4xf32>, // ----- +func.func @dynamic_zero_dim_scatter_indices(%dst: tensor<4x4xf32>, + %indices: tensor<2xi32>, %update: tensor) -> tensor<4x4xf32> { + %0 = "stablehlo.scatter"(%dst, %indices, %update) ({ + ^bb0(%u: tensor, %d: tensor): + "stablehlo.return"(%u) : (tensor) -> () + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [0, 1], + inserted_window_dims = [], + scatter_dims_to_operand_dims = [0, 1], + index_vector_dim = 0, + >, + unique_indices = false, + indices_are_sorted = false + } : (tensor<4x4xf32>, tensor<2xi32>, tensor) -> tensor<4x4xf32> + func.return %0 : tensor<4x4xf32> +} + +// CHECK-LABEL: func.func @dynamic_zero_dim_scatter_indices +// CHECK-SAME: (%[[arg0:.+]]: tensor<4x4xf32>, %[[arg1:.+]]: tensor<2xi32>, %[[arg2:.+]]: tensor) -> tensor<4x4xf32> { +// CHECK-DAG: %[[c_0:.+]] = stablehlo.constant dense<1> : tensor<1xi32> +// CHECK-DAG: %[[v0:.+]] = stablehlo.reshape %[[arg1]] : (tensor<2xi32>) -> tensor<1x2xi32> +// CHECK-DAG: %[[v1:.+]] = stablehlo.get_dimension_size %[[arg2]], dim = 0 : (tensor) -> tensor +// CHECK-DAG: %[[v3:.+]] = stablehlo.reshape %[[v1]] : (tensor) -> tensor<1xi32> +// CHECK-DAG: %[[v4:.+]] = stablehlo.get_dimension_size %[[arg2]], dim = 1 : (tensor) -> tensor +// CHECK-DAG: %[[v5:.+]] = stablehlo.reshape %[[v4]] : (tensor) -> tensor<1xi32> +// CHECK-DAG: %[[v6:.+]] = stablehlo.concatenate %[[c_0]], %[[v3]], %[[v5]], dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> +// CHECK-DAG: %[[v7:.+]] = stablehlo.dynamic_reshape %[[arg2]], %[[v6]] : (tensor, tensor<3xi32>) -> tensor<1x?x?xf32> +// CHECK-DAG: %[[v8:.+]] = "stablehlo.scatter"(%[[arg0]], %[[v0]], %[[v7]]) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ +// CHECK: }) : (tensor<4x4xf32>, tensor<1x2xi32>, tensor<1x?x?xf32>) -> tensor<4x4xf32> +// CHECK-DAG: return %[[v8]] : tensor<4x4xf32> + +// ----- + func.func @multiple_window_and_scatter_dims( %dst: tensor<1x2x3x4x5xf32>, %indices: tensor<6x7x2xi32>, %updates: tensor<2x6x4x7xf32>) -> tensor<1x2x3x4x5xf32> { @@ -221,10 +292,56 @@ func.func @multiple_window_and_scatter_dims( // CHECK-SAME: %[[DST:.*]]: tensor<1x2x3x4x5xf32>, // CHECK-SAME: %[[IND:.*]]: tensor<6x7x2xi32>, // CHECK-SAME: %[[UPD:.*]]: tensor<2x6x4x7xf32> -// CHECK: %[[IND0:.*]] = tensor.collapse_shape %[[IND]] {{.*}} into tensor<42x2xi32> - +// CHECK: %[[IND0:.*]] = stablehlo.reshape %[[IND]] : (tensor<6x7x2xi32>) -> tensor<42x2xi32> // CHECK: %[[UPD0:.*]] = stablehlo.transpose %[[UPD]], dims = [1, 3, 0, 2] : (tensor<2x6x4x7xf32>) -> tensor<6x7x2x4xf32> +// CHECK: %[[UPD1:.*]] = stablehlo.reshape %[[UPD0]] : (tensor<6x7x2x4xf32>) -> tensor<42x1x2x1x4x1xf32> +// CHECK: "stablehlo.scatter"(%[[DST]], %[[IND0]], %[[UPD1]]) + +// ----- + +func.func @dynamic_window_size_multiple_window_and_scatter_dims( + %dst: tensor<1x2x3x4x5xf32>, %indices: tensor, + %updates: tensor<2x?x4x?xf32>) -> tensor<1x2x3x4x5xf32> { + %0 = "stablehlo.scatter"(%dst, %indices, %updates) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + stablehlo.return %arg3 : tensor + }) { + indices_are_sorted = false, + scatter_dimension_numbers = #stablehlo.scatter< + inserted_window_dims = [0, 2, 4], + update_window_dims = [0, 2], + scatter_dims_to_operand_dims = [0, 1], + index_vector_dim = 2 + >, unique_indices = false + } : (tensor<1x2x3x4x5xf32>, tensor, tensor<2x?x4x?xf32>) -> + tensor<1x2x3x4x5xf32> + return %0 : tensor<1x2x3x4x5xf32> +} -// CHECK: %[[UPD1:.*]] = tensor.collapse_shape %[[UPD0]] {{.*}} into tensor<42x2x4xf32> -// CHECK: %[[UPD2:.*]] = tensor.expand_shape %[[UPD1]] {{.*}} into tensor<42x1x2x1x4x1xf32> -// CHECK: "stablehlo.scatter"(%[[DST]], %[[IND0]], %[[UPD2]]) \ No newline at end of file +// CHECK-LABEL: func.func @dynamic_window_size_multiple_window_and_scatter_dims +// CHECK-SAME: (%[[arg0:.+]]: tensor<1x2x3x4x5xf32>, %[[arg1:.+]]: tensor, %[[arg2:.+]]: tensor<2x?x4x?xf32>) +// CHECK-DAG: %[[c:.+]] = stablehlo.constant dense<4> : tensor<1xi32> +// CHECK-DAG: %[[c_0:.+]] = stablehlo.constant dense<2> : tensor<1xi32> +// CHECK-DAG: %[[c_1:.+]] = stablehlo.constant dense<1> : tensor<1xi32> +// CHECK-DAG: %[[v0:.+]] = stablehlo.get_dimension_size %[[arg1]], dim = 0 : (tensor) -> tensor +// CHECK-DAG: %[[v1:.+]] = stablehlo.get_dimension_size %[[arg1]], dim = 1 : (tensor) -> tensor +// CHECK-DAG: %[[v2:.+]] = stablehlo.multiply %[[v0]], %[[v1]] : tensor +// CHECK-DAG: %[[v3:.+]] = stablehlo.reshape %[[v2]] : (tensor) -> tensor<1xi32> +// CHECK-DAG: %[[v4:.+]] = stablehlo.concatenate %[[v3]], %[[c_0]], dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> +// CHECK-DAG: %[[v5:.+]] = stablehlo.dynamic_reshape %[[arg1]], %[[v4]] : (tensor, tensor<2xi32>) -> tensor +// CHECK-DAG: %[[v6:.+]] = stablehlo.transpose %[[arg2]], dims = [1, 3, 0, 2] : (tensor<2x?x4x?xf32>) -> tensor +// CHECK-DAG: %[[v7:.+]] = stablehlo.get_dimension_size %[[v6]], dim = 0 : (tensor) -> tensor +// CHECK-DAG: %[[v8:.+]] = stablehlo.get_dimension_size %[[v6]], dim = 1 : (tensor) -> tensor +// CHECK-DAG: %[[v9:.+]] = stablehlo.multiply %[[v7]], %[[v8]] : tensor +// CHECK-DAG: %[[v10:.+]] = stablehlo.reshape %[[v9]] : (tensor) -> tensor<1xi32> +// CHECK-DAG: %[[v11:.+]] = stablehlo.concatenate %[[v10]], %[[c_0]], %[[c]], dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> +// CHECK-DAG: %[[v12:.+]] = stablehlo.dynamic_reshape %[[v6]], %[[v11]] : (tensor, tensor<3xi32>) -> tensor +// CHECK-DAG: %[[v13:.+]] = stablehlo.get_dimension_size %[[v12]], dim = 0 : (tensor) -> tensor +// CHECK-DAG: %[[v14:.+]] = stablehlo.reshape %[[v13]] : (tensor) -> tensor<1xi32> +// CHECK-DAG: %[[v15:.+]] = stablehlo.concatenate %[[v14]], %[[c_1]], %[[c_0]], %[[c_1]], %[[c]], %[[c_1]], dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<6xi32> +// CHECK-DAG: %[[v16:.+]] = stablehlo.dynamic_reshape %[[v12]], %[[v15]] : (tensor, tensor<6xi32>) -> tensor +// CHECK-DAG: %[[v17:.+]] = "stablehlo.scatter"(%[[arg0]], %[[v5]], %[[v16]]) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ +// CHECK-DAG: ^bb0(%[[arg3:.+]]: tensor, %[[arg4:.+]]: tensor): +// CHECK-DAG: stablehlo.return %[[arg3]] : tensor +// CHECK-DAG: }) : (tensor<1x2x3x4x5xf32>, tensor, tensor) -> tensor<1x2x3x4x5xf32> +// CHECK-DAG: return %[[v17]] : tensor<1x2x3x4x5xf32> diff --git a/mlir-tensorrt/test/Dialect/StableHloExt/constant-folding-bitwise.mlir b/mlir-tensorrt/test/Dialect/StableHloExt/constant-folding-bitwise.mlir new file mode 100644 index 000000000..5e97b8185 --- /dev/null +++ b/mlir-tensorrt/test/Dialect/StableHloExt/constant-folding-bitwise.mlir @@ -0,0 +1,63 @@ +// RUN: mlir-tensorrt-opt %s -split-input-file -stablehlo-ext-constant-folding | FileCheck %s + +func.func @trivial_right_shift(%arg0: tensor) -> tensor { + %c32 = stablehlo.constant dense<32> : tensor + %0 = stablehlo.shift_right_logical %arg0, %c32 : tensor + return %0 : tensor +} + +// CHECK-LABEL: @trivial_right_shift +// CHECK-SAME: (%[[arg0:.+]]: tensor) -> tensor { +// CHECK: %[[v0:.+]] = stablehlo.constant dense<0> : tensor +// CHECK: return %[[v0]] : tensor + +// ----- + +func.func @dynamic_trivial_right_shift(%arg0: tensor) -> tensor { + %c32 = stablehlo.constant dense<32> : tensor<1xi32> + %d = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor) -> tensor + %d.1 = stablehlo.reshape %d : (tensor) -> tensor<1xi32> + %c32_1 = stablehlo.dynamic_broadcast_in_dim %c32, %d.1, dims = [0] : (tensor<1xi32>, tensor<1xi32>) -> (tensor) + %0 = stablehlo.shift_right_logical %arg0, %c32_1 : tensor + return %0 : tensor +} + +// CHECK-LABEL: @dynamic_trivial_right_shift +// CHECK: stablehlo.shift_right_logical + +// ----- + +func.func @nontrivial_right_shift(%arg0: tensor) -> tensor { + %c16 = stablehlo.constant dense<16> : tensor + %0 = stablehlo.shift_right_logical %arg0, %c16 : tensor + return %0 : tensor +} + +// CHECK-LABEL: @nontrivial_right_shift +// CHECK-SAME: (%[[arg0:.+]]: tensor) -> tensor { +// CHECK: %[[v0:.+]] = stablehlo.constant dense<16> : tensor +// CHECK: %[[v1:.+]] = stablehlo.shift_right_logical %[[arg0]], %[[v0]] : tensor +// CHECK: return %[[v1]] : tensor + +// ----- + +func.func @jax_random_seed(%arg0: tensor) -> (tensor<2xi32>) { + %0 = stablehlo.constant dense<32> : tensor + %1 = stablehlo.shift_right_logical %arg0, %0 : tensor + %2 = stablehlo.convert %1 : (tensor) -> tensor + %3 = stablehlo.reshape %2 : (tensor) -> tensor<1xi32> + %4 = stablehlo.constant dense<4294967295> : tensor + %5 = stablehlo.convert %4 : (tensor) -> tensor + %6 = stablehlo.and %arg0, %5 : tensor + %7 = stablehlo.convert %6 : (tensor) -> tensor + %8 = stablehlo.reshape %7 : (tensor) -> tensor<1xi32> + %9 = "stablehlo.concatenate"(%3, %8) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + return %9 : tensor<2xi32> +} + +// CHECK-LABEL: @jax_random_seed +// CHECK-SAME: (%[[arg0:.+]]: tensor) -> tensor<2xi32> { +// CHECK-DAG: %[[v1:.+]] = stablehlo.constant dense<0> : tensor<1xi32> +// CHECK-DAG: %[[v3:.+]] = stablehlo.reshape %[[arg0]] : (tensor) -> tensor<1xi32> +// CHECK-DAG: %[[v4:.+]] = stablehlo.concatenate %[[v1]], %[[v3]], dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> +// CHECK-DAG: return %[[v4]] : tensor<2xi32> diff --git a/mlir-tensorrt/test/Dialect/StableHloExt/stablehlo-input-preprocessing.mlir b/mlir-tensorrt/test/Dialect/StableHloExt/stablehlo-input-preprocessing.mlir deleted file mode 100644 index 2772329e8..000000000 --- a/mlir-tensorrt/test/Dialect/StableHloExt/stablehlo-input-preprocessing.mlir +++ /dev/null @@ -1,123 +0,0 @@ -// RUN: mlir-tensorrt-opt %s -tensorrt-stablehlo-input-preprocessing -stablehlo-aggressive-simplification -split-input-file | FileCheck %s - -func.func @trivial_right_shift(%arg0: tensor) -> tensor { - %c32 = stablehlo.constant dense<32> : tensor - %0 = stablehlo.shift_right_logical %arg0, %c32 : tensor - return %0 : tensor -} - -// CHECK-LABEL: @trivial_right_shift -// CHECK-SAME: (%[[arg0:.+]]: tensor) -> tensor { -// CHECK: %[[v0:.+]] = stablehlo.constant dense<0> : tensor -// CHECK: return %[[v0]] : tensor - -// ----- - -func.func @nontrivial_right_shift(%arg0: tensor) -> tensor { - %c16 = stablehlo.constant dense<16> : tensor - %0 = stablehlo.shift_right_logical %arg0, %c16 : tensor - return %0 : tensor -} - -// CHECK-LABEL: @nontrivial_right_shift -// CHECK-SAME: (%[[arg0:.+]]: tensor) -> tensor { -// CHECK: %[[v0:.+]] = stablehlo.constant dense<16> : tensor -// CHECK: %[[v1:.+]] = stablehlo.shift_right_logical %[[arg0]], %[[v0]] : tensor -// CHECK: return %[[v1]] : tensor - -// ----- - -func.func @jax_random_seed(%arg0: tensor) -> (tensor<2xi32>) { - %0 = stablehlo.constant dense<32> : tensor - %1 = stablehlo.shift_right_logical %arg0, %0 : tensor - %2 = stablehlo.convert %1 : (tensor) -> tensor - %3 = stablehlo.reshape %2 : (tensor) -> tensor<1xi32> - %4 = stablehlo.constant dense<4294967295> : tensor - %5 = stablehlo.convert %4 : (tensor) -> tensor - %6 = stablehlo.and %arg0, %5 : tensor - %7 = stablehlo.convert %6 : (tensor) -> tensor - %8 = stablehlo.reshape %7 : (tensor) -> tensor<1xi32> - %9 = "stablehlo.concatenate"(%3, %8) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> - return %9 : tensor<2xi32> -} - -// CHECK-LABEL: @jax_random_seed -// CHECK-SAME: (%[[arg0:.+]]: tensor) -> tensor<2xi32> { -// CHECK: %[[v0:.+]] = stablehlo.constant dense<-1> : tensor -// CHECK: %[[v1:.+]] = stablehlo.constant dense<0> : tensor<1xi32> -// CHECK: %[[v2:.+]] = stablehlo.and %[[arg0]], %[[v0]] : tensor -// CHECK: %[[v3:.+]] = stablehlo.reshape %[[v2]] : (tensor) -> tensor<1xi32> -// CHECK: %[[v4:.+]] = stablehlo.concatenate %[[v1]], %[[v3]], dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> -// CHECK: return %[[v4]] : tensor<2xi32> - -// ----- - -func.func @erf_inv(%arg0 : tensor<4xf32>) -> tensor<4xf32> { - %0 = chlo.erf_inv %arg0 : tensor<4xf32> -> tensor<4xf32> - return %0 : tensor<4xf32> -} -// CHECK-LABEL: func.func @erf_inv -// CHECK-SAME: (%[[arg0:.+]]: tensor<4xf32>) -> tensor<4xf32> { -// CHECK-DAG: %[[cst:.+]] = stablehlo.constant dense<0x7F800000> : tensor<4xf32> -// CHECK-DAG: %[[cst_0:.+]] = stablehlo.constant dense<1.000000e+00> : tensor<4xf32> -// CHECK-DAG: %[[cst_1:.+]] = stablehlo.constant dense<2.83297682> : tensor<4xf32> -// CHECK-DAG: %[[cst_2:.+]] = stablehlo.constant dense<1.50140941> : tensor<4xf32> -// CHECK-DAG: %[[cst_3:.+]] = stablehlo.constant dense<1.00167406> : tensor<4xf32> -// CHECK-DAG: %[[cst_4:.+]] = stablehlo.constant dense<0.246640727> : tensor<4xf32> -// CHECK-DAG: %[[cst_5:.+]] = stablehlo.constant dense<0.00943887047> : tensor<4xf32> -// CHECK-DAG: %[[cst_6:.+]] = stablehlo.constant dense<-0.00417768164> : tensor<4xf32> -// CHECK-DAG: %[[cst_7:.+]] = stablehlo.constant dense<-0.0076224613> : tensor<4xf32> -// CHECK-DAG: %[[cst_8:.+]] = stablehlo.constant dense<-0.00125372503> : tensor<4xf32> -// CHECK-DAG: %[[cst_9:.+]] = stablehlo.constant dense<0.00573950773> : tensor<4xf32> -// CHECK-DAG: %[[cst_10:.+]] = stablehlo.constant dense<2.1858087E-4> : tensor<4xf32> -// CHECK-DAG: %[[cst_11:.+]] = stablehlo.constant dense<-0.00367342844> : tensor<4xf32> -// CHECK-DAG: %[[cst_12:.+]] = stablehlo.constant dense<-4.39150654E-6> : tensor<4xf32> -// CHECK-DAG: %[[cst_13:.+]] = stablehlo.constant dense<0.00134934322> : tensor<4xf32> -// CHECK-DAG: %[[cst_14:.+]] = stablehlo.constant dense<-3.5233877E-6> : tensor<4xf32> -// CHECK-DAG: %[[cst_15:.+]] = stablehlo.constant dense<1.00950558E-4> : tensor<4xf32> -// CHECK-DAG: %[[cst_16:.+]] = stablehlo.constant dense<3.43273939E-7> : tensor<4xf32> -// CHECK-DAG: %[[cst_17:.+]] = stablehlo.constant dense<-2.00214257E-4> : tensor<4xf32> -// CHECK-DAG: %[[cst_18:.+]] = stablehlo.constant dense<2.81022636E-8> : tensor<4xf32> -// CHECK-DAG: %[[cst_19:.+]] = stablehlo.constant dense<3.000000e+00> : tensor<4xf32> -// CHECK-DAG: %[[cst_20:.+]] = stablehlo.constant dense<2.500000e+00> : tensor<4xf32> -// CHECK-DAG: %[[cst_21:.+]] = stablehlo.constant dense<5.000000e+00> : tensor<4xf32> -// CHECK-DAG: %[[v0:.+]] = stablehlo.negate %[[arg0]] : tensor<4xf32> -// CHECK-DAG: %[[v1:.+]] = stablehlo.multiply %[[arg0]], %[[v0]] : tensor<4xf32> -// CHECK-DAG: %[[v2:.+]] = stablehlo.log_plus_one %[[v1]] : tensor<4xf32> -// CHECK-DAG: %[[v3:.+]] = stablehlo.negate %[[v2]] : tensor<4xf32> -// CHECK-DAG: %[[v4:.+]] = stablehlo.compare LT, %[[v3]], %[[cst_21]] : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> -// CHECK-DAG: %[[v5:.+]] = stablehlo.subtract %[[v3]], %[[cst_20]] : tensor<4xf32> -// CHECK-DAG: %[[v6:.+]] = stablehlo.sqrt %[[v3]] : tensor<4xf32> -// CHECK-DAG: %[[v7:.+]] = stablehlo.subtract %[[v6]], %[[cst_19]] : tensor<4xf32> -// CHECK-DAG: %[[v8:.+]] = stablehlo.select %[[v4]], %[[v5]], %[[v7]] : tensor<4xi1>, tensor<4xf32> -// CHECK-DAG: %[[v9:.+]] = stablehlo.select %[[v4]], %[[cst_18]], %[[cst_17]] : tensor<4xi1>, tensor<4xf32> -// CHECK-DAG: %[[v10:.+]] = stablehlo.select %[[v4]], %[[cst_16]], %[[cst_15]] : tensor<4xi1>, tensor<4xf32> -// CHECK-DAG: %[[v11:.+]] = stablehlo.multiply %[[v9]], %[[v8]] : tensor<4xf32> -// CHECK-DAG: %[[v12:.+]] = stablehlo.add %[[v10]], %[[v11]] : tensor<4xf32> -// CHECK-DAG: %[[v13:.+]] = stablehlo.select %[[v4]], %[[cst_14]], %[[cst_13]] : tensor<4xi1>, tensor<4xf32> -// CHECK-DAG: %[[v14:.+]] = stablehlo.multiply %[[v12]], %[[v8]] : tensor<4xf32> -// CHECK-DAG: %[[v15:.+]] = stablehlo.add %[[v13]], %[[v14]] : tensor<4xf32> -// CHECK-DAG: %[[v16:.+]] = stablehlo.select %[[v4]], %[[cst_12]], %[[cst_11]] : tensor<4xi1>, tensor<4xf32> -// CHECK-DAG: %[[v17:.+]] = stablehlo.multiply %[[v15]], %[[v8]] : tensor<4xf32> -// CHECK-DAG: %[[v18:.+]] = stablehlo.add %[[v16]], %[[v17]] : tensor<4xf32> -// CHECK-DAG: %[[v19:.+]] = stablehlo.select %[[v4]], %[[cst_10]], %[[cst_9]] : tensor<4xi1>, tensor<4xf32> -// CHECK-DAG: %[[v20:.+]] = stablehlo.multiply %[[v18]], %[[v8]] : tensor<4xf32> -// CHECK-DAG: %[[v21:.+]] = stablehlo.add %[[v19]], %[[v20]] : tensor<4xf32> -// CHECK-DAG: %[[v22:.+]] = stablehlo.select %[[v4]], %[[cst_8]], %[[cst_7]] : tensor<4xi1>, tensor<4xf32> -// CHECK-DAG: %[[v23:.+]] = stablehlo.multiply %[[v21]], %[[v8]] : tensor<4xf32> -// CHECK-DAG: %[[v24:.+]] = stablehlo.add %[[v22]], %[[v23]] : tensor<4xf32> -// CHECK-DAG: %[[v25:.+]] = stablehlo.select %[[v4]], %[[cst_6]], %[[cst_5]] : tensor<4xi1>, tensor<4xf32> -// CHECK-DAG: %[[v26:.+]] = stablehlo.multiply %[[v24]], %[[v8]] : tensor<4xf32> -// CHECK-DAG: %[[v27:.+]] = stablehlo.add %[[v25]], %[[v26]] : tensor<4xf32> -// CHECK-DAG: %[[v28:.+]] = stablehlo.select %[[v4]], %[[cst_4]], %[[cst_3]] : tensor<4xi1>, tensor<4xf32> -// CHECK-DAG: %[[v29:.+]] = stablehlo.multiply %[[v27]], %[[v8]] : tensor<4xf32> -// CHECK-DAG: %[[v30:.+]] = stablehlo.add %[[v28]], %[[v29]] : tensor<4xf32> -// CHECK-DAG: %[[v31:.+]] = stablehlo.select %[[v4]], %[[cst_2]], %[[cst_1]] : tensor<4xi1>, tensor<4xf32> -// CHECK-DAG: %[[v32:.+]] = stablehlo.multiply %[[v30]], %[[v8]] : tensor<4xf32> -// CHECK-DAG: %[[v33:.+]] = stablehlo.add %[[v31]], %[[v32]] : tensor<4xf32> -// CHECK-DAG: %[[v34:.+]] = stablehlo.multiply %[[v33]], %[[arg0]] : tensor<4xf32> -// CHECK-DAG: %[[v35:.+]] = stablehlo.abs %[[arg0]] : tensor<4xf32> -// CHECK-DAG: %[[v36:.+]] = stablehlo.compare EQ, %[[v35]], %[[cst_0]] : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> -// CHECK-DAG: %[[v37:.+]] = stablehlo.multiply %[[arg0]], %[[cst]] : tensor<4xf32> -// CHECK-DAG: %[[v38:.+]] = stablehlo.select %[[v36]], %[[v37]], %[[v34]] : tensor<4xi1>, tensor<4xf32> -// CHECK-DAG: return %[[v38]] : tensor<4xf32> \ No newline at end of file diff --git a/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_dynamic_iota.py b/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_dynamic_iota.py new file mode 100644 index 000000000..ffa0d461e --- /dev/null +++ b/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_dynamic_iota.py @@ -0,0 +1,105 @@ +# RUN: %PYTHON %s +import mlir_tensorrt.compiler.api as compiler +import mlir_tensorrt.compiler.ir as ir +from mlir_tensorrt.compiler.dialects import builtin, stablehlo, func +import mlir_tensorrt.runtime.api as runtime +import numpy as np + + +def build_program(dtype, iota_dim): + @builtin.module(sym_name=f"dynamic_iota_test_{dtype}") + def mlir_module(): + DYNAMIC = ir.RankedTensorType.get_dynamic_size() + i32 = ir.IntegerType.get_signless(32) + shape_type = ir.RankedTensorType.get([2], i32) + result_type = ir.RankedTensorType.get([DYNAMIC, 3], dtype) + + @func.func(shape_type) + def main(shape): + return stablehlo.dynamic_iota(result_type, shape, iota_dimension=iota_dim) + + main.func_op.arg_attrs = [ + ir.DictAttr.get( + { + "tensorrt.value_bounds": ir.Attribute.parse( + "#tensorrt.shape_profile" + ) + } + ) + ] + + return mlir_module + + +def get_mlir_dtype(dtype): + if dtype == np.int32: + return ir.IntegerType.get_signless(32) + elif dtype == np.int64: + return ir.IntegerType.get_signless(64) + elif dtype == np.float32: + return ir.F32Type.get() + else: + raise Exception("unsupported dtype") + + +def build_exe(dtype, iota_dim): + with ir.Context() as context, ir.Location.unknown(): + module = build_program(dtype=get_mlir_dtype(dtype), iota_dim=iota_dim) + print(module.operation) + + # Use the compiler API to compile to executable. + client = compiler.CompilerClient(context) + opts = compiler.StableHLOToExecutableOptions( + client, + [ + "--tensorrt-builder-opt-level=3", + "--tensorrt-strongly-typed=false", + "--entrypoint=main", + "--mlir-print-ir-tree-dir=tmp", + ], + ) + return compiler.compiler_stablehlo_to_executable(client, module.operation, opts) + + +def run_test(exe, dtype, iota_dim): + client = runtime.RuntimeClient() + stream = client.create_stream() + devices = client.get_devices() + if len(devices) == 0: + return + + session_options = runtime.RuntimeSessionOptions(num_devices=1, device_id=0) + session = runtime.RuntimeSession(session_options, exe) + + dynamic_size = 128 + + arg0 = client.create_memref( + np.asarray([dynamic_size, 3], dtype=np.int32), + device=devices[0], + stream=stream, + ) + arg1 = client.create_memref( + np.zeros(shape=(dynamic_size, 3), dtype=dtype), + device=devices[0], + stream=stream, + ) + session.execute_function("main", in_args=[arg0], out_args=[arg1], stream=stream) + data = np.asarray(client.copy_to_host(arg1, stream=stream)) + stream.sync() + + broadcast_shape = [dynamic_size, 3] + iota_size = dynamic_size if iota_dim == 0 else 3 + iota_reshape = [1, 1] + iota_reshape[iota_dim] = -1 + + expected = np.linspace(0, iota_size - 1, num=iota_size, dtype=dtype).reshape( + *iota_reshape + ) * np.ones(broadcast_shape, dtype=dtype) + np.testing.assert_array_equal(data, expected) + + +if __name__ == "__main__": + for dtype in [np.int64, np.int32, np.float32]: + for iota_dim in [0, 1]: + exe = build_exe(dtype, iota_dim) + run_test(exe, dtype, iota_dim) From e1b1a0cdc63cf806fa07d33d2f3ec98e6fe53f17 Mon Sep 17 00:00:00 2001 From: Christopher Bate Date: Wed, 13 Nov 2024 09:56:29 -0700 Subject: [PATCH 04/29] [mlir-tensorrt][Dialect/Plan] Enable 'plan-materialize-shape-calculations' to use additional StableHLO simplification patterns(#367) Prior to this change, the 'plan-materialize-shape-calculations' pass could perform a number of simplifications based on bounds analysis, but it did not actually perform StableHLO simplifications. This change adds additional `stablehlo-canonicalize-dynamism` patterns so that the IR can be further simplified. Co-authored-by: Copybara Bot --- .../MaterializeShapeCalculations.cpp | 3 + .../executor/test/Unit/CMakeLists.txt | 17 ++++- .../Plan/materialize-shape-calculations.mlir | 76 ++++++++++++++----- 3 files changed, 74 insertions(+), 22 deletions(-) diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/MaterializeShapeCalculations.cpp b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/MaterializeShapeCalculations.cpp index 490b60aab..6ae987c71 100644 --- a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/MaterializeShapeCalculations.cpp +++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/MaterializeShapeCalculations.cpp @@ -46,6 +46,7 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "stablehlo/conversions/linalg/transforms/MapStablehloToScalarOp.h" #include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/transforms/Passes.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" @@ -856,6 +857,8 @@ class MaterializeShapeCalculationsPass RewritePatternSet patterns_(ctx); memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns_); stablehlo_ext::populateStableHloAbsorbTensorCastPatterns(patterns_); + stablehlo::populateStablehloCanonicalizeDynamismPatterns(&patterns_, ctx); + // clang-format off addCanonicalizationPatterns< arith::AndIOp, diff --git a/mlir-tensorrt/executor/test/Unit/CMakeLists.txt b/mlir-tensorrt/executor/test/Unit/CMakeLists.txt index 4767c71d3..c7fe2824e 100644 --- a/mlir-tensorrt/executor/test/Unit/CMakeLists.txt +++ b/mlir-tensorrt/executor/test/Unit/CMakeLists.txt @@ -3,9 +3,20 @@ add_custom_target(MLIRTensorRTExecutorUnitTests) set_target_properties(MLIRTensorRTExecutorUnitTests PROPERTIES FOLDER "MLIR-TensorRT Executor Unit Tests") # Use this function for populating GTest-based unit tests. -function(add_mlir_executor_unittest name) - add_unittest(MLIRTensorRTExecutorUnitTests ${name} ${ARGN}) - llvm_update_compile_flags(${name}) +function(add_mlir_executor_unittest target) + set(LLVM_LINK_COMPONENTS Support) + add_llvm_executable(${target} IGNORE_EXTERNALIZE_DEBUGINFO NO_INSTALL_RPATH ${ARGN}) + add_dependencies(MLIRTensorRTExecutorUnitTests ${target}) + llvm_update_compile_flags(${target}) + if(TARGET gtest) + target_link_libraries(${target} PRIVATE + gtest gtest_main gmock) + elseif(TARGET llvm_gtest) + target_link_libraries(${target} PRIVATE + llvm_gtest llvm_gtest_main) + else() + message(FATAL_ERROR "No GTest library found") + endif() endfunction() if (EXISTS ${LLVM_THIRD_PARTY_DIR}/unittest/googletest/include/gtest/gtest.h) diff --git a/mlir-tensorrt/test/Dialect/Plan/materialize-shape-calculations.mlir b/mlir-tensorrt/test/Dialect/Plan/materialize-shape-calculations.mlir index 06f03bd0a..b3ab0521e 100644 --- a/mlir-tensorrt/test/Dialect/Plan/materialize-shape-calculations.mlir +++ b/mlir-tensorrt/test/Dialect/Plan/materialize-shape-calculations.mlir @@ -789,25 +789,28 @@ func.func @zero_slice_slice(%arg4: tensor<1xi32>, } // CHECK-LABEL: func.func @zero_slice_slice -// CHECK-SAME: (%[[arg0:.+]]: tensor<1xi32>, %[[arg1:.+]]: tensor<1xi32> {plan.value_bounds = #plan.bounds : tensor<1xi32>, dense<1> : tensor<1xi32>>}, %[[arg2:.+]]: tensor<1xi32> {plan.value_bounds = #plan.bounds : tensor<1xi32>, dense<1> : tensor<1xi32>>}, %[[arg3:.+]]: tensor<1xi32> {plan.value_bounds = #plan.bounds : tensor<1xi32>, dense<1> : tensor<1xi32>>}, %[[arg4:.+]]: tensor<1xi32> {plan.shape_profile = #plan.bounds}) -// CHECK-DAG: %[[cst:.+]] = arith.constant dense<1> : tensor<1xi32> -// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[c1_i32:.+]] = arith.constant 1 : i32 -// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[extracted:.+]] = tensor.extract %[[arg3]][%[[c0]]] : tensor<1xi32> -// CHECK-DAG: %[[extracted_0:.+]] = tensor.extract %[[arg2]][%[[c0]]] : tensor<1xi32> -// CHECK-DAG: %[[extracted_1:.+]] = tensor.extract %[[arg1]][%[[c0]]] : tensor<1xi32> -// CHECK-DAG: %[[v3:.+]] = stablehlo.real_dynamic_slice %[[arg4]], %[[cst]], %[[cst]], %[[cst]] : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor -// CHECK-DAG: %[[v4:.+]] = arith.subi %[[extracted_0]], %[[extracted_1]] : i32 -// CHECK-DAG: %[[v5:.+]] = arith.addi %[[extracted]], %[[v4]] : i32 -// CHECK-DAG: %[[v6:.+]] = arith.subi %[[v5]], %[[c1_i32]] : i32 -// CHECK-DAG: %[[v7:.+]] = arith.divsi %[[v6]], %[[extracted]] : i32 -// CHECK-DAG: %[[v8:.+]] = arith.index_cast %[[v7]] : i32 to index -// CHECK-DAG: %[[v9:.+]] = plan.with_shape %[[v3]](%[[v7]]) : (tensor, i32) -> tensor -// CHECK-DAG: %[[v10:.+]] = stablehlo.concatenate %[[arg0]], %[[v9]], dim = 0 : (tensor<1xi32>, tensor) -> tensor -// CHECK-DAG: %[[v11:.+]] = arith.addi %[[v8]], %[[c1]] : index -// CHECK-DAG: %[[v12:.+]] = plan.with_shape %[[v10]](%[[v11]]) : (tensor, index) -> tensor -// CHECK-DAG: return %[[v12]] : tensor +// CHECK-SAME: (%[[arg0:[a-zA-Z0-9]+]]: tensor<1xi32>, +// CHECK-SAME: %[[arg1:[a-zA-Z0-9]+]]: tensor<1xi32> {plan.value_bounds = #plan.bounds : tensor<1xi32>, dense<1> : tensor<1xi32>>}, +// CHECK-SAME: %[[arg2:[a-zA-Z0-9]+]]: tensor<1xi32> {plan.value_bounds = #plan.bounds : tensor<1xi32>, dense<1> : tensor<1xi32>>}, +// CHECK-SAME: %[[arg3:[a-zA-Z0-9]+]]: tensor<1xi32> {plan.value_bounds = #plan.bounds : tensor<1xi32>, dense<1> : tensor<1xi32>>}, +// CHECK-SAME: %[[arg4:[a-zA-Z0-9]+]]: tensor<1xi32> {plan.shape_profile = #plan.bounds}) +// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[c1_i32:.+]] = arith.constant 1 : i32 +// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[extracted:.+]] = tensor.extract %[[arg3]][%[[c0]]] +// CHECK-DAG: %[[extracted_0:.+]] = tensor.extract %[[arg2]][%[[c0]]] +// CHECK-DAG: %[[extracted_1:.+]] = tensor.extract %[[arg1]][%[[c0]]] +// CHECK-DAG: %[[v0:.+]] = stablehlo.slice %[[arg4]] [1:1] +// CHECK-DAG: %[[v1:.+]] = arith.subi %[[extracted_0]], %[[extracted_1]] : i32 +// CHECK-DAG: %[[v2:.+]] = arith.addi %[[extracted]], %[[v1]] : i32 +// CHECK-DAG: %[[v3:.+]] = arith.subi %[[v2]], %[[c1_i32]] : i32 +// CHECK-DAG: %[[v4:.+]] = arith.divsi %[[v3]], %[[extracted]] : i32 +// CHECK-DAG: %[[v5:.+]] = arith.index_cast %[[v4]] : i32 to index +// CHECK-DAG: %[[v6:.+]] = plan.with_shape %[[v0]](%[[v4]]) +// CHECK-DAG: %[[v7:.+]] = stablehlo.concatenate %[[arg0]], %[[v6]], dim = 0 : +// CHECK-DAG: %[[v8:.+]] = arith.addi %[[v5]], %[[c1]] : index +// CHECK-DAG: %[[v9:.+]] = plan.with_shape %[[v7]](%[[v8]]) : +// CHECK-DAG: return %[[v9]] : tensor // ----- @@ -1023,3 +1026,38 @@ func.func @refine_based_on_profile(%arg0: tensor {tensorrt.shape_profil // CHECK-DAG: %[[v1:.+]] = stablehlo.transpose %[[v0]], dims = [1, 0] : // CHECK-DAG: %[[v2:.+]] = plan.with_shape %[[v1]](%[[c128]], %[[dim]]) : (tensor, index, index) -> tensor // CHECK-DAG: return %[[v2]] : tensor + +// ----- + +#profile = #tensorrt.shape_profile + +func.func @dynamic_gather_simplify(%arg0: tensor, %arg1: tensor, %arg2: tensor { + tensorrt.shape_profile = #profile +}) -> tensor { + %4733 = stablehlo.get_dimension_size %arg2, dim = 0 : (tensor) -> tensor + %4734 = stablehlo.reshape %4733 : (tensor) -> tensor<1xi32> + %4735 = stablehlo.get_dimension_size %arg2, dim = 1 : (tensor) -> tensor + %4736 = stablehlo.reshape %4735 : (tensor) -> tensor<1xi32> + %4737 = stablehlo.concatenate %4734, %4736, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + %4941 = stablehlo.slice %4737 [1:2] : (tensor<2xi32>) -> tensor<1xi32> + %c_73 = stablehlo.constant dense<1> : tensor<1xi32> + %4942 = stablehlo.concatenate %c_73, %4941, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + %4943 = "stablehlo.dynamic_gather"(%arg0, %arg1, %4942) + <{dimension_numbers = #stablehlo.gather}> : (tensor, tensor, tensor<2xi32>) + -> tensor + return %4943 :tensor +} + +// CHECK-LABEL: func.func @dynamic_gather_simplify +// CHECK-SAME: (%[[arg0:.+]]: tensor, %[[arg1:.+]]: tensor, %[[arg2:.+]]: tensor {plan.shape_profile = #plan.bounds}) +// CHECK-DAG: %[[c3_i32:.+]] = arith.constant 3 : i32 +// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[dim:.+]] = tensor.dim %[[arg1]], %[[c0]] : tensor +// CHECK-DAG: %[[v0:.+]] = plan.with_shape %[[arg1]](%[[dim]]) +// CHECK-DAG: %[[dim_0:.+]] = tensor.dim %[[arg0]], %[[c0]] : tensor +// CHECK-DAG: %[[dim_1:.+]] = tensor.dim %[[arg0]], %[[c1]] : tensor +// CHECK-DAG: %[[v1:.+]] = plan.with_shape %[[arg0]](%[[dim_0]], %[[dim_1]]) +// CHECK-DAG: %[[v2:.+]] = "stablehlo.gather"(%[[v1]], %[[v0]]) +// CHECK-DAG: %[[v3:.+]] = plan.with_shape %[[v2]](%[[dim]], %[[c3_i32]]) +// CHECK-DAG: return %[[v3]] : tensor From 2ad2c0a4c7cb863cc0208844d5f2987910d57430 Mon Sep 17 00:00:00 2001 From: pranavm Date: Tue, 12 Nov 2024 18:03:33 -0800 Subject: [PATCH 05/29] Updates README to include links to common guides --- tripy/README.md | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/tripy/README.md b/tripy/README.md index 80abaeb08..b631d0db3 100644 --- a/tripy/README.md +++ b/tripy/README.md @@ -7,16 +7,19 @@ [![Tripy L1](https://github.com/NVIDIA/TensorRT-Incubator/actions/workflows/tripy-l1.yml/badge.svg)](https://github.com/NVIDIA/TensorRT-Incubator/actions/workflows/tripy-l1.yml) -Tripy is a Python programming model for [TensorRT](https://developer.nvidia.com/tensorrt) that aims to provide an excellent -user experience without compromising performance. Some of the features of Tripy include: +Tripy is a Python programming model for [TensorRT](https://developer.nvidia.com/tensorrt) that aims to provide +an excellent user experience without compromising performance. Some of the goals of Tripy are: - **Intuitive API**: Tripy doesn't reinvent the wheel: If you have used NumPy or PyTorch before, Tripy APIs should feel familiar. - **Excellent Error Messages**: When something goes wrong, Tripy tries to provide informative and actionable error messages. Even in cases where the error comes - from deep within the software stack, Tripy is able to map it back to the Python code - that caused it. + from deep within the software stack, Tripy is usually able to map it back to the + related Python code. + +- **Friendly Documentation**: The documentation is meant to be accessible and comprehensive, + with plenty of examples to illustrate important points. ## Installation @@ -73,6 +76,11 @@ We recommend starting with the [Introduction To Tripy](https://nvidia.github.io/TensorRT-Incubator/pre0_user_guides/00-introduction-to-tripy.html) guide. +Other features covered in our guides include: + +- [Compiling code (including dynamic shape support)](https://nvidia.github.io/TensorRT-Incubator/pre0_user_guides/02-compiler.html) +- [Quantization](https://nvidia.github.io/TensorRT-Incubator/pre0_user_guides/01-quantization.html) + To get an idea of the look and feel of Tripy, let's take a look at a short code example. All of the features used in this example are explained in more detail in the introduction guide mentioned above. From 087e0d52abc949803e5a639957c776271a272549 Mon Sep 17 00:00:00 2001 From: pranavm Date: Tue, 12 Nov 2024 18:03:49 -0800 Subject: [PATCH 06/29] Improves docstrings for overloaded functions Improves the docstrings for overloaded functions to be stylistically similar to non-overloaded functions. Also updates the helpers that generate strings from type annotations to be more consistent with the style the documentation uses. For example, `Union[int, float]` will now be rendered as `int | float`. --- tripy/docs/_static/style.css | 11 ++ tripy/tests/test_function_registry.py | 55 ++++--- tripy/tripy/backend/api/executable.py | 9 +- .../tripy/frontend/ops/tensor_initializers.py | 2 +- tripy/tripy/function_registry.py | 150 +++++++++++------- 5 files changed, 143 insertions(+), 84 deletions(-) diff --git a/tripy/docs/_static/style.css b/tripy/docs/_static/style.css index 40ccca4c7..efc9a968d 100644 --- a/tripy/docs/_static/style.css +++ b/tripy/docs/_static/style.css @@ -21,3 +21,14 @@ section { margin-top: 2rem; margin-bottom: 2rem; } + +.func-overload-sig { + padding-left: 3em !important; + color: var(--color-api-overall); + font-style: normal; +} + +.func-overload-sig p { + margin-bottom: 0 !important; + margin-top: 0 !important; +} diff --git a/tripy/tests/test_function_registry.py b/tripy/tests/test_function_registry.py index 620d39c02..1ad4c1ef0 100644 --- a/tripy/tests/test_function_registry.py +++ b/tripy/tests/test_function_registry.py @@ -25,7 +25,7 @@ import tripy as tp from tripy import TripyException -from tripy.function_registry import AnnotationInfo, FunctionRegistry, render_arg_type, sanitize_name +from tripy.function_registry import AnnotationInfo, FunctionRegistry, type_str_from_arg, str_from_type_annotation @pytest.fixture() @@ -199,10 +199,10 @@ def func(a: int): func_overload = registry.overloads["test"][0] - assert not func_overload.annotations + assert not func_overload._annotations assert registry["test"](0) == 1 - assert func_overload.annotations - assert func_overload.annotations["a"] == AnnotationInfo(int, False, inspect.Parameter.POSITIONAL_OR_KEYWORD) + assert func_overload._annotations + assert func_overload._annotations["a"] == AnnotationInfo(int, False, inspect.Parameter.POSITIONAL_OR_KEYWORD) def test_doc_of_non_overloaded_func(self, registry): # When there is no overload, the registry function should @@ -224,10 +224,11 @@ def func(a: int): """ pass + # Tripy types should turn into class links @registry("test") - def func(a: float): + def func(a: Union[int, "tripy.Tensor"]): """ - This func takes a float. + This func takes an int or a tensor. """ pass @@ -235,20 +236,34 @@ def func(a: float): assert ( registry["test"].__doc__ == dedent( - """ + r""" *This function has multiple overloads:* ---------- - > **test** (*a*: :class:`int`) -> None + .. role:: sig-prename + :class: sig-prename descclassname + .. role:: sig-name + :class: sig-name descname + + .. container:: func-overload-sig sig sig-object py + + :sig-prename:`tripy`\ .\ :sig-name:`test`\ (a: int) -> None This func takes an int. ---------- - > **test** (*a*: :class:`float`) -> None + .. role:: sig-prename + :class: sig-prename descclassname + .. role:: sig-name + :class: sig-name descname + + .. container:: func-overload-sig sig sig-object py + + :sig-prename:`tripy`\ .\ :sig-name:`test`\ (a: int | :class:`tripy.Tensor`) -> None - This func takes a float. + This func takes an int or a tensor. """ ).strip() ) @@ -379,7 +394,7 @@ def func(n: Union[int, float]) -> int: [0-9]+ \| \.\.\. \|\s - Not a valid overload because: For parameter: 'n', expected an instance of type: 'Union\[int, float\]' but got argument of type: 'List\[str\]'\. + Not a valid overload because: For parameter: 'n', expected an instance of type: 'int | float' but got argument of type: 'List\[str\]'\. """ ).strip(), ): @@ -403,7 +418,7 @@ def func(n: Sequence[int]) -> int: [0-9]+ \| \.\.\. \|\s - Not a valid overload because: For parameter: 'n', expected an instance of type: 'Sequence\[int\]' but got argument of type: 'List\[Union\[(int, str)|(str, int)\]\]'\. + Not a valid overload because: For parameter: 'n', expected an instance of type: 'Sequence\[int\]' but got argument of type: 'List\[(int \| str)|(str \| int)\]'\. """ ).strip(), ): @@ -475,7 +490,7 @@ def func(n: Sequence[Union[int, float]]) -> int: [0-9]+ \| \.\.\. \|\s - Not a valid overload because: For parameter: 'n', expected an instance of type: 'Sequence\[Union\[int, float\]\]' but got argument of type: 'List\[str\]'\. + Not a valid overload because: For parameter: 'n', expected an instance of type: 'Sequence\[int | float\]' but got argument of type: 'List\[str\]'\. """ ).strip(), ): @@ -496,16 +511,16 @@ def func(a: int, *args: int) -> int: @pytest.mark.parametrize( "typ, expected", [ - (tp.types.TensorLike, "Union[tripy.Tensor, numbers.Number]"), - (tp.types.ShapeLike, "Sequence[Union[int, tripy.DimensionSize]]"), + (tp.types.TensorLike, "tripy.Tensor | numbers.Number"), + (tp.types.ShapeLike, "Sequence[int | tripy.DimensionSize]"), (tp.Tensor, "Tensor"), (torch.Tensor, "torch.Tensor"), (int, "int"), - (Optional[int], "Optional[int]"), + (Optional[int], "int | None"), ], ) -def test_sanitize_name(typ, expected): - assert sanitize_name(typ) == expected +def test_str_from_type_annotation(typ, expected): + assert str_from_type_annotation(typ) == expected @pytest.mark.parametrize( @@ -517,5 +532,5 @@ def test_sanitize_name(typ, expected): ("hi", "str"), ], ) -def test_render_arg_type(typ, expected): - assert render_arg_type(typ) == expected +def test_type_str_from_arg(typ, expected): + assert type_str_from_arg(typ) == expected diff --git a/tripy/tripy/backend/api/executable.py b/tripy/tripy/backend/api/executable.py index d1f1736e3..7ededf394 100644 --- a/tripy/tripy/backend/api/executable.py +++ b/tripy/tripy/backend/api/executable.py @@ -23,7 +23,7 @@ from tripy.backend.mlir import utils as mlir_utils from tripy.common.exception import raise_error from tripy.frontend import Tensor -from tripy.function_registry import sanitize_name +from tripy.function_registry import str_from_type_annotation from tripy.utils import json as json_utils from dataclasses import dataclass @@ -73,8 +73,11 @@ def stream(self, stream): self._executor.stream = stream def __str__(self) -> str: - params = [f"{name}: {sanitize_name(param.annotation)}" for name, param in self.__signature__.parameters.items()] - return f"Executable({', '.join(params)}) -> {sanitize_name(self.__signature__.return_annotation)}" + params = [ + f"{name}: {str_from_type_annotation(param.annotation)}" + for name, param in self.__signature__.parameters.items() + ] + return f"Executable({', '.join(params)}) -> {str_from_type_annotation(self.__signature__.return_annotation)}" @staticmethod def load(path: str) -> "tripy.Executable": diff --git a/tripy/tripy/frontend/ops/tensor_initializers.py b/tripy/tripy/frontend/ops/tensor_initializers.py index faa395f7a..43142be2c 100644 --- a/tripy/tripy/frontend/ops/tensor_initializers.py +++ b/tripy/tripy/frontend/ops/tensor_initializers.py @@ -292,7 +292,7 @@ def arange( ) -> "tripy.Tensor": r""" Returns a 1D tensor containing a sequence of numbers in the half-open interval - :math:`[0, \text{stop})` incrementing by :math:`\text{step}`. + :math:`[\text{start}, \text{stop})` incrementing by :math:`\text{step}`. Args: start: The inclusive lower bound of the values to generate. If a tensor is provided, it must be a scalar tensor. diff --git a/tripy/tripy/function_registry.py b/tripy/tripy/function_registry.py index 4481fac6f..73f748235 100644 --- a/tripy/tripy/function_registry.py +++ b/tripy/tripy/function_registry.py @@ -20,8 +20,8 @@ from collections import OrderedDict, defaultdict from collections.abc import Sequence as ABCSequence from dataclasses import dataclass -from textwrap import dedent -from typing import Any, Callable, Dict, ForwardRef, List, Sequence, Union, get_args, get_origin +from textwrap import dedent, indent +from typing import Any, Callable, Dict, ForwardRef, List, Sequence, Tuple, Union, get_args, get_origin def get_type_name(typ): @@ -41,23 +41,31 @@ def get_type_name(typ): return module_name + typ.__qualname__ -def sanitize_name(annotation): - if get_origin(annotation) is Union and annotation._name == "Optional": - types = get_args(annotation) - return f"{annotation.__name__}[{sanitize_name(types[0])}]" +def str_from_type_annotation(annotation, postprocess_annotation=None): + postprocess_annotation = postprocess_annotation or (lambda x: x) + + if annotation is type(None): + return postprocess_annotation("None") + + if isinstance(annotation, str): + return postprocess_annotation(annotation) + + if get_origin(annotation) is Union: + types = list(get_args(annotation)) + return " | ".join(str_from_type_annotation(typ, postprocess_annotation) for typ in types) - if get_origin(annotation) in {Union, ABCSequence}: + if get_origin(annotation) in {ABCSequence, List, Tuple}: types = get_args(annotation) - return f"{annotation.__name__}[{', '.join(sanitize_name(typ) for typ in types)}]" + return f"{annotation.__name__}[{', '.join(str_from_type_annotation(typ, postprocess_annotation) for typ in types)}]" if isinstance(annotation, ForwardRef): - return annotation.__forward_arg__ + return postprocess_annotation(str(annotation.__forward_arg__)) # typing module annotations are likely to be better when pretty-printed due to including subscripts - return annotation if annotation.__module__ == "typing" else get_type_name(annotation) + return postprocess_annotation(str(annotation) if annotation.__module__ == "typing" else get_type_name(annotation)) -def render_arg_type(arg: Any) -> str: +def type_str_from_arg(arg: Any) -> str: # it is more useful to report more detailed types for sequences/tuples in error messages from typing import List, Tuple @@ -65,12 +73,12 @@ def render_arg_type(arg: Any) -> str: if len(arg) == 0: return "List" # catch inconsistencies this way - arg_types = {render_arg_type(member) for member in arg} + arg_types = {type_str_from_arg(member) for member in arg} if len(arg_types) == 1: return f"List[{list(arg_types)[0]}]" - return f"List[Union[{', '.join(arg_types)}]]" + return f"List[{' | '.join(arg_types)}]" if isinstance(arg, Tuple): - return f"Tuple[{', '.join(map(render_arg_type, arg))}]" + return f"Tuple[{', '.join(map(type_str_from_arg, arg))}]" return get_type_name(type(arg)) @@ -89,7 +97,7 @@ def __init__(self, func): # We *cannot* populate this at `__init__` time since that will be evaluated when the function # is first defined, at which point the required types in the annotations may not be accessible. # Instead, we populate this the first time the function is called. - self.annotations = None + self._annotations = None def __str__(self) -> str: from tripy.utils.utils import code_pretty_str @@ -112,39 +120,41 @@ def __str__(self) -> str: return pretty_code + "\n" def _get_annotations(self): - if self.annotations is None: - # Maps parameter names to their type annotations and a boolean indicating whether they are optional. - self.annotations: Dict[str, AnnotationInfo] = OrderedDict() - signature = inspect.signature(self.func) - for name, param in signature.parameters.items(): - if name == "self": - # Not likely to pass in the wrong `self` parameter, so we - # don't require an annotation for it. - annotation = Any - else: - assert (param.annotation and param.annotation is not signature.empty) or param.kind in { - inspect.Parameter.VAR_POSITIONAL, - inspect.Parameter.VAR_KEYWORD, - }, f"Non-variadic function parameters must have type annotations, but parameter: '{name}' of function: '{self.func.__name__}' has no type annotation!" - annotation = param.annotation - # In cases where a type is not available at the time of function definition, the type - # annotation may be provided as a string. Since we need the actual type, we just - # eval it here. - if isinstance(annotation, str): - try: - # Import tripy so we can evaluate types from within tripy. - import tripy - - annotation = eval(annotation) - except Exception as e: - raise NameError( - f"Error while evaluating type annotation: '{annotation}' for parameter: '{name}' of function: '{self.func.__name__}'." - f"\nNote: Error was: {e}" - ) + if self._annotations is not None: + return self._annotations + + # Maps parameter names to their type annotations and a boolean indicating whether they are optional. + self._annotations: Dict[str, AnnotationInfo] = OrderedDict() + signature = inspect.signature(self.func) + for name, param in signature.parameters.items(): + if name == "self": + # Not likely to pass in the wrong `self` parameter, so we + # don't require an annotation for it. + annotation = Any + else: + assert (param.annotation and param.annotation is not signature.empty) or param.kind in { + inspect.Parameter.VAR_POSITIONAL, + inspect.Parameter.VAR_KEYWORD, + }, f"Non-variadic function parameters must have type annotations, but parameter: '{name}' of function: '{self.func.__name__}' has no type annotation!" + annotation = param.annotation + # In cases where a type is not available at the time of function definition, the type + # annotation may be provided as a string. Since we need the actual type, we just + # eval it here. + if isinstance(annotation, str): + try: + # Import tripy so we can evaluate types from within tripy. + import tripy + + annotation = eval(annotation) + except Exception as e: + raise NameError( + f"Error while evaluating type annotation: '{annotation}' for parameter: '{name}' of function: '{self.func.__name__}'." + f"\nNote: Error was: {e}" + ) - self.annotations[name] = AnnotationInfo(annotation, param.default is not signature.empty, param.kind) + self._annotations[name] = AnnotationInfo(annotation, param.default is not signature.empty, param.kind) - return self.annotations + return self._annotations def matches_arg_types(self, args, kwargs) -> "Result": from itertools import chain @@ -226,8 +236,8 @@ def matches_type(name: str, annotation: type, arg: Any) -> bool: if not matches_type(name, annotation.type_info, arg): return Result.err( [ - f"For parameter: '{name}', expected an instance of type: '{sanitize_name(annotation.type_info)}' " - f"but got argument of type: '{render_arg_type(arg)}'." + f"For parameter: '{name}', expected an instance of type: '{str_from_type_annotation(annotation.type_info)}' " + f"but got argument of type: '{type_str_from_arg(arg)}'." ] ) @@ -237,8 +247,8 @@ def matches_type(name: str, annotation: type, arg: Any) -> bool: if not matches_type(name, typ, arg): return Result.err( [ - f"For parameter: '{name}', expected an instance of type: '{sanitize_name(typ)}' " - f"but got argument of type: '{render_arg_type(arg)}'." + f"For parameter: '{name}', expected an instance of type: '{str_from_type_annotation(typ)}' " + f"but got argument of type: '{type_str_from_arg(arg)}'." ] ) elif not any(annotation.kind == inspect.Parameter.VAR_KEYWORD for annotation in annotations.values()): @@ -387,30 +397,50 @@ def prepend_signature_to_docstring(f): if not f.__doc__: return "" + roles = "" + + def add_role(name, *additional_classes): + nonlocal roles + + classes = [name] + list(additional_classes) + roles += f".. role:: {name}\n :class: {' '.join(classes)}\n" + + add_role("sig-prename", "descclassname") + add_role("sig-name", "descname") + + # We cannot use `FuncOverload._get_annotations()` here because it is too early to be able + # to import tripy to evaluate annotations. signature = inspect.signature(f) - def str_from_annotation(annotation): - if isinstance(annotation, str): - ret = annotation - else: - ret = annotation.__qualname__ - return f":class:`{ret}`" + postprocess_annotation = lambda annotation: ( + f":class:`{annotation}`" if annotation.startswith("tripy.") else annotation + ) def make_param_str(param): - param_str = f"*{param.name}*: {str_from_annotation(param.annotation)}" + param_str = ( + f"{param.name}: {str_from_type_annotation(param.annotation, postprocess_annotation)}" + ) if param.default != signature.empty: param_str += f" = {param.default}" return param_str - sig_str = f"> **{key}** ({', '.join(make_param_str(param) for param in signature.parameters.values() if param.name != 'self')}) -> " + sig_str = rf":sig-prename:`tripy`\ .\ :sig-name:`{key}`\ ({', '.join(make_param_str(param) for param in signature.parameters.values() if param.name != 'self')}) -> " if signature.return_annotation != signature.empty: - sig_str += f"{str_from_annotation(signature.return_annotation)}" + sig_str += ( + f"{str_from_type_annotation(signature.return_annotation, postprocess_annotation)}" + ) else: sig_str += "None" section_divider = "-" * 10 - return (f"""\n\n{section_divider}\n\n{sig_str}\n{dedent(f.__doc__)}""").strip() + indent_prefix = " " * 4 + # We add a special `func-overload-sig` class here so we can correct the documentation + # styling for signatures of overloaded functions. + overload_doc = ( + f"""\n\n{section_divider}\n\n{dedent(roles).strip()}\n\n.. container:: func-overload-sig sig sig-object py\n\n{indent(sig_str, indent_prefix)}\n{dedent(f.__doc__)}""" + ).strip() + return overload_doc # The first time we add an overload, we need to retroactively process the existing docstring # to add signature information. From ee30299a0399055bef96433fd9852cbd09247e24 Mon Sep 17 00:00:00 2001 From: yizhuoz004 Date: Wed, 13 Nov 2024 14:38:19 -0800 Subject: [PATCH 07/29] Change topK to 10 in nanoGPT sample (#363) Note: top K helps stablize the result a bit but not much, the result of `int8-weight-only` can vary among several outputs on CI machine, but cannot be reproduced locally. Worth looking into this issue later on. --------- Signed-off-by: yizhuoz004 --- tripy/examples/nanogpt/README.md | 11 ++++++++++- tripy/examples/nanogpt/example.py | 2 +- tripy/tests/test_examples.py | 7 ++++++- 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/tripy/examples/nanogpt/README.md b/tripy/examples/nanogpt/README.md index 179e3a9bd..a3dc4678f 100644 --- a/tripy/examples/nanogpt/README.md +++ b/tripy/examples/nanogpt/README.md @@ -38,7 +38,10 @@ for expected accuracy. @@ -62,7 +65,13 @@ To run with a quantization mode, pass `--quant-mode` to `example.py`. The suppor diff --git a/tripy/examples/nanogpt/example.py b/tripy/examples/nanogpt/example.py index 4b2ecebea..0041b6426 100644 --- a/tripy/examples/nanogpt/example.py +++ b/tripy/examples/nanogpt/example.py @@ -82,7 +82,7 @@ def main(): input_ids = encoder.encode(args.input_text, allowed_special={"<|endoftext|>"}) TEMPERATURE = 0.8 - TOP_K = 200 + TOP_K = 5 padded_seq_len = len(input_ids) + args.max_new_tokens diff --git a/tripy/tests/test_examples.py b/tripy/tests/test_examples.py index e24b906fd..3fe546268 100644 --- a/tripy/tests/test_examples.py +++ b/tripy/tests/test_examples.py @@ -103,7 +103,12 @@ def test_examples(example, sandboxed_install_run): if block.has_marker("test: expected_stdout"): print("Checking command output against expected output: ", end="") out = statuses[-1].stdout.strip() - matched = re.match(dedent(block_text).strip(), out) + matched = False + expected_outs = dedent(block_text).split("====") + for expected in expected_outs: + if re.match(expected.strip(), out): + matched = True + break print("matched!" if matched else "did not match!") print(f"==== STDOUT ====\n{out}") assert matched From a2542cf72a60e0d1454323e78b736ac6a3204c8c Mon Sep 17 00:00:00 2001 From: Christopher Bate Date: Thu, 14 Nov 2024 13:00:03 -0700 Subject: [PATCH 08/29] [tensorrt] Fix handling of dynamic shapes in `tensorrt-transpose-elimination` (#371) Fixes an issue where the `tensorrt-transpose-elimination` pass did not correctly calculate cost for tensors that are dynamically shaped. After this change, the pass now considers any dynamically shaped tensor to have maximum cost; better heuristics would require that we move transpose elimination higher in the pipeline (e.g. to Plan dialect where we might have better information about shapes). Fixes issue #368 GitOrigin-RevId: 99e6260ee0c2679e1e6539e60b7840d9a90845b9 --- .../Transforms/TransposeElimination.cpp | 5 ++++- .../TensorRT/transpose-elimination.mlir | 19 +++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/TransposeElimination.cpp b/mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/TransposeElimination.cpp index 3cd23c29b..37453885f 100644 --- a/mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/TransposeElimination.cpp +++ b/mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/TransposeElimination.cpp @@ -44,7 +44,10 @@ namespace tensorrt { using namespace mlir; using namespace mlir::tensorrt; -static int64_t memoryCost(TensorType type) { +static int64_t memoryCost(RankedTensorType type) { + // If the type is dynamic, then return max. + if (!type.hasStaticShape()) + return std::numeric_limits::max(); ArrayRef shape = type.getShape(); return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>()); } diff --git a/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/transpose-elimination.mlir b/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/transpose-elimination.mlir index e88cd63e9..66997936e 100644 --- a/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/transpose-elimination.mlir +++ b/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/transpose-elimination.mlir @@ -51,6 +51,25 @@ func.func @transpose_pushdown_noop(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32 // ----- +#map = affine_map<(d0, d1, d2) -> (d0, d2, d1)> + +func.func @tranpose_pushdown_dynamic(%arg0: tensor) -> tensor { + %cst_f32 = tensorrt.constant dense<1.000000e+00> : tensor<1x1x1xf32> + %1 = tensorrt.transpose {permutation = #map} %arg0 : tensor to tensor + %2 = tensorrt.element_wise (%cst_f32, %1 : tensor<1x1x1xf32>, tensor) -> tensor + return %2 : tensor +} + +// CHECK: #[[$map:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)> +// CHECK-LABEL: func.func @tranpose_pushdown_dynamic +// CHECK-SAME: (%[[arg0:.+]]: tensor) -> tensor : tensor<1x1x1xf32> +// CHECK-DAG: %[[v0:.+]] = tensorrt.element_wise (%[[cst_f32]], %[[arg0]] : tensor<1x1x1xf32>, tensor) -> tensor +// CHECK-DAG: %[[v1:.+]] = tensorrt.transpose {permutation = #[[$map]]} %[[v0]] : tensor to tensor +// CHECK-DAG: return %[[v1]] : tensor + +// ----- + func.func @transpose_pushdown_switch(%arg0: tensor<2x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<2x2xf32> { %1 = tensorrt.transpose {permutation = affine_map<(d0, d1)->(d1, d0)>} %arg0 : tensor<2x2xf32> to tensor<2x2xf32> %2 = tensorrt.element_wise (%1, %arg1: tensor<2x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32> From 9dc1bd85094951dbafa0a20eaed27e6ab136bc12 Mon Sep 17 00:00:00 2001 From: Sagar Shelke Date: Thu, 14 Nov 2024 12:24:00 -0800 Subject: [PATCH 09/29] Integrate internal changes (#374) This is a combination of the following commits: ## [compiler] Drop the StableHLO "signature refinement" API Previously, we added a special API for TriPy to be able to perform only pre-processing (including entrypoint signature type refinement) on the input module. This path is no longer needed by TriPy, and since there are no other customers, it can be safely dropped from the C/C++/Python APIs. ## [executor/runtime] Update NCCL module to enable non-blocking communicator Updates the NCCL runtime module so that the communicators are non-blocking and so that more consistent logic is used for handling errors. This helps resolve issues where the test may deadlock (either because of an incorrect runtime implementation issue or because of system config issue) without errors being printed to stderr. Additional TODOs are noted where the implementation can be further improved. GitOrigin-RevId: 5321de2a3d779500436c7a62097e0fc219958caf Co-authored-by: Christopher Bate --- .../mlir-tensorrt-c/Compiler/Compiler.h | 32 --- .../Compiler/StableHloToExecutable.h | 35 --- .../compiler/lib/CAPI/Compiler/Compiler.cpp | 42 --- .../lib/Compiler/StableHloToExecutable.cpp | 62 ----- .../Runtime/Backend/Lua/LuaErrorHandling.h | 10 +- .../mlir-executor/Runtime/Support/Support.h | 7 +- .../include/mlir-executor/Support/Status.h | 4 +- .../Backend/Lua/Modules/NCCL/NCCLModule.cpp | 245 ++++++++++++------ .../bindings/Compiler/CompilerPyBind.cpp | 19 -- .../compiler_api/test_refine_signature.py | 28 -- 10 files changed, 186 insertions(+), 298 deletions(-) delete mode 100644 mlir-tensorrt/test/python/mlir_tensorrt_compiler/compiler_api/test_refine_signature.py diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt-c/Compiler/Compiler.h b/mlir-tensorrt/compiler/include/mlir-tensorrt-c/Compiler/Compiler.h index 4d3a06610..64cf76258 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt-c/Compiler/Compiler.h +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt-c/Compiler/Compiler.h @@ -108,38 +108,6 @@ MLIR_CAPI_EXPORTED MTRT_Status mtrtCompilerStableHLOToExecutable( MTRT_CompilerClient client, MlirOperation module, MTRT_StableHLOToExecutableOptions options, MTRT_Executable *result); -//===----------------------------------------------------------------------===// -// MTRT_StableHLOProgramSignatureRefinementOptions -//===----------------------------------------------------------------------===// - -/// Options for compiling StableHLO MLIR to an Executable. -typedef struct MTRT_StableHLOProgramSignatureRefinementOptions { - void *ptr; -} MTRT_StableHLOProgramSignatureRefinementOptions; - -MLIR_CAPI_EXPORTED MTRT_Status -mtrtStableHloProgramSignatureRefinementOptionsCreate( - MTRT_StringView funcName, - MTRT_StableHLOProgramSignatureRefinementOptions *options); - -MLIR_CAPI_EXPORTED MTRT_Status -mtrtStableHloProgramSignatureRefinementOptionsDestroy( - MTRT_StableHLOProgramSignatureRefinementOptions options); - -static inline bool mtrtStableHloProgramSignatureRefinementOptionsIsNull( - MTRT_StableHLOProgramSignatureRefinementOptions options) { - return !options.ptr; -} - -//===----------------------------------------------------------------------===// -// Main StableHLO Program Signature Refinement API Functions -//===----------------------------------------------------------------------===// - -/// Compiler StableHLO to Executable. -MLIR_CAPI_EXPORTED MTRT_Status mtrtGetStableHloProgramRefinedSignature( - MTRT_CompilerClient client, MlirOperation module, - MTRT_StableHLOProgramSignatureRefinementOptions options, MlirType *result); - #ifdef __cplusplus } #endif diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StableHloToExecutable.h b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StableHloToExecutable.h index 9a1b40f7c..0e3b8e3ab 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StableHloToExecutable.h +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StableHloToExecutable.h @@ -44,41 +44,6 @@ namespace mlirtrt::compiler { -//===----------------------------------------------------------------------===// -// StableHLOProgramSignatureRefinementOptions -//===----------------------------------------------------------------------===// - -struct StableHLOProgramSignatureRefinementOptions - : public mlir::OptionsContext { - /// Creates default compilation options. - StableHLOProgramSignatureRefinementOptions() { - this->addOption("func-name", funcName, llvm::cl::init("main")); - debugOptions.addToOptions(*this); - } - - /// Set the entrypoint function name. - StableHLOProgramSignatureRefinementOptions & - setFuncName(const std::string &name) { - funcName = name; - return *this; - } - - std::string funcName = "main"; - - DebugOptions debugOptions; -}; - -//===----------------------------------------------------------------------===// -// StableHLO Signature Refinement Entrypoint -//===----------------------------------------------------------------------===// - -/// Attempt to refine the function signature of a StableHLO program through -/// canonicalization and constant folding. Returns the refined signature of the -/// specified function of the module. -mlirtrt::StatusOr getStableHLOProgramRefinedSignature( - CompilerClient &client, mlir::ModuleOp module, - const StableHLOProgramSignatureRefinementOptions &options); - //===----------------------------------------------------------------------===// // StableHLOToExecutableOptions //===----------------------------------------------------------------------===// diff --git a/mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp b/mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp index 9be562712..6636768b0 100644 --- a/mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp +++ b/mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp @@ -44,8 +44,6 @@ using namespace mlir; DEFINE_C_API_PTR_METHODS(MTRT_CompilerClient, CompilerClient) DEFINE_C_API_PTR_METHODS(MTRT_StableHLOToExecutableOptions, StableHLOToExecutableOptions) -DEFINE_C_API_PTR_METHODS(MTRT_StableHLOProgramSignatureRefinementOptions, - StableHLOProgramSignatureRefinementOptions) #if defined(__GNUC__) || defined(__clang__) #pragma GCC diagnostic pop #endif @@ -255,43 +253,3 @@ MTRT_Status mtrtCompilerStableHLOToExecutable( return mtrtStatusGetOk(); } - -//===----------------------------------------------------------------------===// -// Main StableHLO Program Signature Refinement Functions -//===----------------------------------------------------------------------===// - -MTRT_Status mtrtStableHloProgramSignatureRefinementOptionsCreate( - MTRT_StringView funcName, - MTRT_StableHLOProgramSignatureRefinementOptions *options) { - auto result = std::make_unique(); - result->setFuncName(std::string(funcName.data, funcName.length)); - *options = wrap(result.release()); - return mtrtStatusGetOk(); -} - -MTRT_Status mtrtStableHloProgramSignatureRefinementOptionsDestroy( - MTRT_StableHLOProgramSignatureRefinementOptions options) { - delete unwrap(options); - return mtrtStatusGetOk(); -} - -MTRT_Status mtrtGetStableHloProgramRefinedSignature( - MTRT_CompilerClient client, MlirOperation module, - MTRT_StableHLOProgramSignatureRefinementOptions options, MlirType *result) { - ModuleOp moduleOp = llvm::dyn_cast(unwrap(module)); - if (!moduleOp) - return mtrtStatusCreate( - MTRT_StatusCode::MTRT_StatusCode_InvalidArgument, - "StableHLO program signature refinement expects a ModuleOp"); - - StatusOr funcType = - compiler::getStableHLOProgramRefinedSignature(*unwrap(client), moduleOp, - *unwrap(options)); - if (!funcType.isOk()) - return mtrtStatusCreate(MTRT_StatusCode::MTRT_StatusCode_InvalidArgument, - funcType.getString().c_str()); - - *result = wrap(mlir::Type(*funcType)); - - return mtrtStatusGetOk(); -} diff --git a/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp b/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp index 50f146645..537e4567f 100644 --- a/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp +++ b/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp @@ -144,68 +144,6 @@ class HloToStdPass }; } // namespace -//===----------------------------------------------------------------------===// -// StableHLO Signature Refinement Entrypoint -//===----------------------------------------------------------------------===// - -mlirtrt::StatusOr -compiler::getStableHLOProgramRefinedSignature( - CompilerClient &client, mlir::ModuleOp module, - const StableHLOProgramSignatureRefinementOptions &options) { - -#ifndef NDEBUG - //===----------------------------------------------------------------------===// - // Set debug options. - //===----------------------------------------------------------------------===// - if (options.debugOptions.enableLLVMDebugFlag) { - SmallVector debugTypeLiterals = - llvm::map_to_vector(options.debugOptions.llvmDebugTypes, - [](const std::string &x) { return x.c_str(); }); - llvm::setCurrentDebugTypes(debugTypeLiterals.data(), - debugTypeLiterals.size()); - llvm::DebugFlag = true; - } -#endif - - //===----------------------------------------------------------------------===// - // Setup pass manager - //===----------------------------------------------------------------------===// - - mlir::PassManager pm(module->getContext()); - if (failed(setupPassManager(pm, options.debugOptions))) { - /// TODO: Ignored. This can fail if pass manager static CL options were not - /// registered/initialized. This happens through invocation of e.g. this - /// function in e.g. Python bindings or standalone calls to C++ or C API - /// without doing all the typical static CL setup. We should instead be - /// accepting a PassManager here that has already been setup to the caller's - /// specifications. - } - - // Add pre-processing passes. - { - mlir::StableHloInputOptions opts{}; - opts.legalizeControlFlowToSCF = false; - opts.preserveChloErf = true; - opts.preserveChloTopK = true; - mlir::buildStablehloPreProcessingPipeline(pm, opts); - } - - // Run pass pipeline. - if (mlir::failed(pm.run(module))) - return getStatusWithMsg(StatusCode::InternalError, - "failed to run compilation pipeline"); - - // Get the signature. - auto func = llvm::dyn_cast_or_null( - module.lookupSymbol(options.funcName)); - if (!func) - return getInvalidArgStatus( - "function with name {0} does not exist in the MLIR module", - options.funcName); - - return func.getFunctionType(); -} - //===----------------------------------------------------------------------===// // StableHLOToExecutableOptions //===----------------------------------------------------------------------===// diff --git a/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaErrorHandling.h b/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaErrorHandling.h index 9b4ab3325..a08264085 100644 --- a/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaErrorHandling.h +++ b/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaErrorHandling.h @@ -67,12 +67,16 @@ } \ } while (false) -#define SET_LUA_ERROR_IF_NCCL_ERROR(x, lstate) \ +#define SET_LUA_ERROR_IF_NCCL_ERROR(x, lstate, comm) \ do { \ ncclResult_t err = (x); \ - if (err != ncclSuccess) { \ + if (err != ncclSuccess && err != ncclInProgress) { \ lua_State *L = lstate; \ - luaL_error(L, ncclGetLastError(nullptr)); \ + std::string msg = llvm::formatv( \ + "{0}:{1} NCCL error [msg=\"{2}\" ncclGetLastError=\"{3}\"]", \ + __FILE__, __LINE__, ncclGetErrorString(err), \ + comm ? ncclGetLastError(comm) : ""); \ + luaL_error(L, msg.c_str()); \ } \ } while (false) diff --git a/mlir-tensorrt/executor/include/mlir-executor/Runtime/Support/Support.h b/mlir-tensorrt/executor/include/mlir-executor/Runtime/Support/Support.h index b277153f8..783c064c5 100644 --- a/mlir-tensorrt/executor/include/mlir-executor/Runtime/Support/Support.h +++ b/mlir-tensorrt/executor/include/mlir-executor/Runtime/Support/Support.h @@ -35,8 +35,13 @@ namespace mlirtrt::runtime { // Debugging and logging tools //===----------------------------------------------------------------------===// +/// Prints the given printf-style formatted data to stderr if the 'runtime' +/// debug module is enabled. Has no effect in non-assert builds. +/// Note that we prepend a space to assist with readability when the logs are +/// prefixed by other text when wrapped by another runtime system (e.g. +/// 'mpirun'). #define MTRT_DBGF(fmt, ...) \ - DEBUG_WITH_TYPE("runtime", fprintf(stderr, "%s:%d " fmt "\n", __FILE__, \ + DEBUG_WITH_TYPE("runtime", fprintf(stderr, " %s:%d " fmt "\n", __FILE__, \ __LINE__, __VA_ARGS__)) template diff --git a/mlir-tensorrt/executor/include/mlir-executor/Support/Status.h b/mlir-tensorrt/executor/include/mlir-executor/Support/Status.h index f21911530..206415c44 100644 --- a/mlir-tensorrt/executor/include/mlir-executor/Support/Status.h +++ b/mlir-tensorrt/executor/include/mlir-executor/Support/Status.h @@ -217,10 +217,12 @@ class StatusOr { } \ } while (false); +/// Causes returning an InternalError status from the current scope if the NCCL +/// result is not ncclSuccess or ncclInProgress. #define RETURN_ERROR_IF_NCCL_ERROR(x, comm) \ do { \ ncclResult_t err = (x); \ - if (err != ncclSuccess) { \ + if (err != ncclSuccess && err != ncclInProgress) { \ return getInternalErrorStatus( \ "{0}:{1} NCCL error [msg=\"{2}\" ncclGetLastError=\"{3}\"]", \ __FILE__, __LINE__, ncclGetErrorString(err), \ diff --git a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/NCCL/NCCLModule.cpp b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/NCCL/NCCLModule.cpp index c3fb40c0e..c78ec810c 100644 --- a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/NCCL/NCCLModule.cpp +++ b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/NCCL/NCCLModule.cpp @@ -26,6 +26,7 @@ #include "mlir-executor/Runtime/Backend/Common/CUDACommon.h" #include "mlir-executor/Runtime/Backend/Common/CommonRuntime.h" #include "mlir-executor/Runtime/Backend/Lua/LuaErrorHandling.h" +#include #define OMPI_SKIP_MPICXX #if defined(__clang__) || defined(__GNUC__) @@ -54,6 +55,82 @@ struct NcclCommunicator { int32_t numRanks; }; +/// A simple RAII class that checks whether a specified `limit` number of +/// milliseconds have elapsed since the objects creation. +class TimeoutChecker { +public: + explicit TimeoutChecker(std::chrono::milliseconds limit) + : start(std::chrono::steady_clock::now()), limit(limit) {} + + /// Check whether the timeout limit has been exceeded and update internal + /// flag. + bool operator()() { + auto now = std::chrono::steady_clock::now(); + auto elapsed = + std::chrono::duration_cast(now - start); + exceeded = elapsed > limit; + return exceeded; + } + + /// Return true if the timeout was exceeded in the last call. + bool exceededLimit() const { return exceeded; } + +private: + bool exceeded = false; + std::chrono::steady_clock::time_point start; + std::chrono::milliseconds limit; +}; +} // namespace + +// Wait for the status of `comm` to be `ncclSuccess`. If an async error is +// detected, then returns an error status. +static Status waitUntilNcclCommunicatorIsReady( + ncclComm_t comm, + std::chrono::milliseconds timeout = std::chrono::milliseconds(5000)) { + TimeoutChecker checkTimeout(timeout); + ncclResult_t state; + do { + ncclResult_t getLastErrorResult = ncclCommGetAsyncError(comm, &state); + if (getLastErrorResult != ncclSuccess) + return getInternalErrorStatus("NCCL failed on ncclCommGetAsyncError"); + } while (state == ncclInProgress && checkTimeout() != true); + if (checkTimeout.exceededLimit()) + return getInternalErrorStatus("timed out waiting for NCCL operations to " + "compelete"); + if (ncclSuccess != state) + return getInternalErrorStatus("NCCL experienced an async error: {0}", + ncclGetErrorString(state)); + return getOkStatus(); +} + +/// Cleans up a specified `NcclCommunicator`. +/// TODO: we should add an error flag to the `NcclCommunicator` struct. If the +/// communicator is in an error state, then we should invoke `ncclCommAbort` +/// instead of the finalization+destruction sequence. +static void destroyNcclCommunicator(uintptr_t ptr) { + auto *obj = reinterpret_cast(ptr); + + if (obj->comm) { + MTRT_DBGF("Destroying NCCL communicator: %lu", + reinterpret_cast(obj->comm)); + Status waitStatus = waitUntilNcclCommunicatorIsReady(obj->comm); + if (!waitStatus.isOk()) { + llvm::errs() << "Error while waiting for NCCL communicator to be ready " + "prior to finalizing: " + << waitStatus.getString() << "\n"; + } + ncclResult_t ncclErr; + ncclErr = ncclCommDestroy(obj->comm); + if (ncclErr != ncclSuccess && ncclErr != ncclInProgress) { + llvm::errs() << "ncclCommDestroy error: " << ncclGetErrorString(ncclErr) + << "\n"; + } + obj->comm = nullptr; + } + delete obj; +} + +namespace { /// RAII wrapper for NCCL communicator. struct NcclCommWrapper : public PointerWrapper { using PointerWrapper::PointerWrapper; @@ -62,55 +139,64 @@ struct NcclCommWrapper : public PointerWrapper { ncclUniqueId commId, int32_t rank) { ncclComm_t comm; - RETURN_ERROR_IF_NCCL_ERROR(ncclCommInitRank(&comm, numRanks, commId, rank), - comm); + ncclConfig_t config = NCCL_CONFIG_INITIALIZER; + config.blocking = 0; + RETURN_ERROR_IF_NCCL_ERROR( + ncclCommInitRankConfig(&comm, numRanks, commId, rank, &config), + nullptr); + + // We cannot split until the source communicator is ready. + RETURN_STATUS_IF_ERROR(waitUntilNcclCommunicatorIsReady(comm)); + + /// TODO: sync abortFlag among all healthy ranks and abort if there is an + /// error status. MTRT_DBGF("Created NCCL communicator: %lu", reinterpret_cast(comm)); NcclCommunicator *result = new NcclCommunicator{comm, rank, numRanks}; - tracker->track(reinterpret_cast(result), [](uintptr_t ptr) { - auto *obj = reinterpret_cast(ptr); - if (obj->comm) { - MTRT_DBGF("Destroying NCCL communicator: %lu", - reinterpret_cast(obj->comm)); - ncclCommFinalize(obj->comm); - ncclCommDestroy(obj->comm); - obj->comm = nullptr; - } - delete obj; - }); + tracker->track(reinterpret_cast(result), + destroyNcclCommunicator); return result; } - // Create a new communicator by splitting an existing one. + /// Create a new communicator by splitting an existing one. + /// TODO: refactoring thandling of `NcclCommunicator` to not require + /// packaging rank + num_ranks with the communicator would enable + /// interleaved work between `ncclCommSplit` and the first use of + /// the communicator (and where `waitUntilNcclCommunicator` is ready would be + /// called). static StatusOr create(ResourceTracker *tracker, - ncclComm_t comm, int32_t color, - int32_t key) { + NcclCommunicator *comm, + int32_t color, int32_t key) { #if NCCL_VERSION_CODE < NCCL_VERSION(2, 18, 1) return getStatusWithMsg( StatusCode::InternalError, "NCCL 2.18.1 or greater is required for ncclCommSplit."); #else + // We cannot split until the source communicator is ready. + RETURN_STATUS_IF_ERROR(waitUntilNcclCommunicatorIsReady(comm->comm)); + + // Create a non-blocking communicator with shared resources. ncclComm_t newComm = nullptr; + ncclConfig_t config = NCCL_CONFIG_INITIALIZER; + config.blocking = 0; + config.splitShare = 1; RETURN_ERROR_IF_NCCL_ERROR( - ncclCommSplit(comm, color, key, &newComm, /*config=*/nullptr), comm); + ncclCommSplit(comm->comm, color, key, &newComm, &config), comm->comm); MTRT_DBGF("Created NCCL communicator via split: %lu", reinterpret_cast(comm)); + + // The communicator is non-blocking, but we use it immediately below, so we + // must wait until it is ready. For split, this is done by waiting on the + // source communicator. + RETURN_STATUS_IF_ERROR(waitUntilNcclCommunicatorIsReady(comm->comm)); + int32_t rank, numRanks; - RETURN_ERROR_IF_NCCL_ERROR(ncclCommUserRank(newComm, &rank), comm); - RETURN_ERROR_IF_NCCL_ERROR(ncclCommCount(newComm, &numRanks), comm); + RETURN_ERROR_IF_NCCL_ERROR(ncclCommUserRank(newComm, &rank), comm->comm); + RETURN_ERROR_IF_NCCL_ERROR(ncclCommCount(newComm, &numRanks), comm->comm); NcclCommunicator *result = new NcclCommunicator{newComm, rank, numRanks}; - tracker->track(reinterpret_cast(result), [](uintptr_t ptr) { - auto *obj = reinterpret_cast(ptr); - if (obj->comm) { - MTRT_DBGF("Destroying NCCL communicator: %lu", - reinterpret_cast(obj->comm)); - ncclCommFinalize(obj->comm); - ncclCommDestroy(obj->comm); - obj->comm = nullptr; - } - delete obj; - }); + tracker->track(reinterpret_cast(result), + destroyNcclCommunicator); return result; #endif @@ -144,8 +230,9 @@ static void registerNcclOps(sol::state_view &lua, ResourceTracker *tracker) { lua["__nccl_comm_split"] = [tracker](sol::this_state state, uintptr_t comm, int32_t color, int32_t key) -> uintptr_t { + MTRT_DBGF("__nccl_comm_split comm=0x%lx color=%d key=%d", comm, color, key); StatusOr newComm = NcclCommWrapper::create( - tracker, reinterpret_cast(comm)->comm, color, key); + tracker, reinterpret_cast(comm), color, key); SET_LUA_ERROR_AND_RETURN_IF_ERROR(newComm, state, 0); return reinterpret_cast(*newComm); }; @@ -180,13 +267,15 @@ static void registerNcclOps(sol::state_view &lua, ResourceTracker *tracker) { #define DEFINE_NCCL_ALL_REDUCE_METHOD(opsuffix, op, typesuffix, type) \ lua["__nccl_all_reduce_" #opsuffix "_" #typesuffix] = \ [](sol::this_state state, ExecPtr sendbuff, ExecPtr recvbuff, \ - size_t count, uintptr_t comm, CudaStreamPtr stream) { \ + size_t count, uintptr_t communicator, CudaStreamPtr stream) { \ + auto comm = reinterpret_cast(communicator); \ SET_LUA_ERROR_IF_NCCL_ERROR( \ ncclAllReduce(reinterpret_cast(sendbuff), \ reinterpret_cast(recvbuff), count, type, op, \ - reinterpret_cast(comm)->comm, \ - stream), \ - state); \ + comm->comm, stream), \ + state, comm->comm); \ + SET_LUA_ERROR_AND_RETURN_IF_ERROR( \ + waitUntilNcclCommunicatorIsReady(comm->comm), state, ); \ }; CALL_FOR_ALL_REDOPS_AND_TYPES(DEFINE_NCCL_ALL_REDUCE_METHOD) @@ -195,13 +284,15 @@ static void registerNcclOps(sol::state_view &lua, ResourceTracker *tracker) { #define DEFINE_NCCL_REDUCE_SCATTER_METHOD(opsuffix, op, typesuffix, type) \ lua["__nccl_reduce_scatter_" #opsuffix "_" #typesuffix] = \ [](sol::this_state state, ExecPtr sendbuff, ExecPtr recvbuff, \ - size_t recvcount, uintptr_t comm, CudaStreamPtr stream) { \ + size_t recvcount, uintptr_t communicator, CudaStreamPtr stream) { \ + auto *comm = reinterpret_cast(communicator); \ SET_LUA_ERROR_IF_NCCL_ERROR( \ - ncclReduceScatter( \ - reinterpret_cast(sendbuff), \ - reinterpret_cast(recvbuff), recvcount, type, op, \ - reinterpret_cast(comm)->comm, stream), \ - state); \ + ncclReduceScatter(reinterpret_cast(sendbuff), \ + reinterpret_cast(recvbuff), recvcount, \ + type, op, comm->comm, stream), \ + state, comm->comm); \ + SET_LUA_ERROR_AND_RETURN_IF_ERROR( \ + waitUntilNcclCommunicatorIsReady(comm->comm), state, ); \ }; CALL_FOR_ALL_REDOPS_AND_TYPES(DEFINE_NCCL_REDUCE_SCATTER_METHOD) @@ -209,63 +300,67 @@ static void registerNcclOps(sol::state_view &lua, ResourceTracker *tracker) { lua["__nccl_all_gather"] = [](sol::this_state state, ExecPtr sendbuff, ExecPtr recvbuff, size_t sendNumBytes, - uintptr_t comm, CudaStreamPtr stream) { + uintptr_t communicator, CudaStreamPtr stream) { + auto *comm = reinterpret_cast(communicator); SET_LUA_ERROR_IF_NCCL_ERROR( ncclAllGather(reinterpret_cast(sendbuff), reinterpret_cast(recvbuff), sendNumBytes, - ncclInt8, - reinterpret_cast(comm)->comm, stream), - state); + ncclInt8, comm->comm, stream), + state, comm->comm); + SET_LUA_ERROR_AND_RETURN_IF_ERROR( + waitUntilNcclCommunicatorIsReady(comm->comm), state, ); }; lua["__nccl_all_to_all"] = [](sol::this_state state, ExecPtr sendbuff, ExecPtr recvbuff, size_t numBytes, - uintptr_t comm, CudaStreamPtr stream) { - size_t sendBytes = - numBytes / reinterpret_cast(comm)->numRanks; - SET_LUA_ERROR_IF_NCCL_ERROR(ncclGroupStart(), state); - for (int r = 0; r < reinterpret_cast(comm)->numRanks; - ++r) { + uintptr_t communicator, CudaStreamPtr stream) { + auto *comm = reinterpret_cast(communicator); + size_t sendBytes = numBytes / comm->numRanks; + SET_LUA_ERROR_IF_NCCL_ERROR(ncclGroupStart(), state, comm->comm); + for (int r = 0; r < comm->numRanks; ++r) { SET_LUA_ERROR_IF_NCCL_ERROR( ncclSend(reinterpret_cast(sendbuff + r * sendBytes), - sendBytes, ncclInt8, r, - reinterpret_cast(comm)->comm, stream), - state); + sendBytes, ncclInt8, r, comm->comm, stream), + state, comm->comm); SET_LUA_ERROR_IF_NCCL_ERROR( ncclRecv(reinterpret_cast(recvbuff + r * sendBytes), - sendBytes, ncclInt8, r, - reinterpret_cast(comm)->comm, stream), - state); + sendBytes, ncclInt8, r, comm->comm, stream), + state, comm->comm); } - SET_LUA_ERROR_IF_NCCL_ERROR(ncclGroupEnd(), state); + SET_LUA_ERROR_IF_NCCL_ERROR(ncclGroupEnd(), state, comm->comm); + + SET_LUA_ERROR_AND_RETURN_IF_ERROR( + waitUntilNcclCommunicatorIsReady(comm->comm), state, ); }; lua["__nccl_permute"] = [](sol::this_state state, ExecPtr sendbuff, ExecPtr recvbuff, int32_t sendId, int32_t recvId, - size_t numBytes, uintptr_t comm, + size_t numBytes, uintptr_t communicator, CudaStreamPtr stream) { - SET_LUA_ERROR_IF_NCCL_ERROR(ncclGroupStart(), state); - if (sendId != -1) { - SET_LUA_ERROR_IF_NCCL_ERROR( - ncclSend(reinterpret_cast(sendbuff), numBytes, ncclInt8, - sendId, reinterpret_cast(comm)->comm, - stream), - state); - } - if (recvId != -1) { - SET_LUA_ERROR_IF_NCCL_ERROR( - ncclRecv(reinterpret_cast(recvbuff), numBytes, ncclInt8, - recvId, reinterpret_cast(comm)->comm, - stream), - state); - } else { + auto *comm = reinterpret_cast(communicator); + if (recvId == -1) { // Zero out recvbuff if not receiving. SET_LUA_ERROR_IF_CUDA_ERROR( cuMemsetD8Async(static_cast(recvbuff), 0, numBytes, stream), state); } - SET_LUA_ERROR_IF_NCCL_ERROR(ncclGroupEnd(), state); + SET_LUA_ERROR_IF_NCCL_ERROR(ncclGroupStart(), state, comm->comm); + if (sendId != -1) { + SET_LUA_ERROR_IF_NCCL_ERROR(ncclSend(reinterpret_cast(sendbuff), + numBytes, ncclInt8, sendId, + comm->comm, stream), + state, comm->comm); + } + if (recvId != -1) { + SET_LUA_ERROR_IF_NCCL_ERROR(ncclRecv(reinterpret_cast(recvbuff), + numBytes, ncclInt8, recvId, + comm->comm, stream), + state, comm->comm); + } + SET_LUA_ERROR_IF_NCCL_ERROR(ncclGroupEnd(), state, comm->comm); + SET_LUA_ERROR_AND_RETURN_IF_ERROR( + waitUntilNcclCommunicatorIsReady(comm->comm), state, ); }; } diff --git a/mlir-tensorrt/python/bindings/Compiler/CompilerPyBind.cpp b/mlir-tensorrt/python/bindings/Compiler/CompilerPyBind.cpp index 284f12f40..912df1fcf 100644 --- a/mlir-tensorrt/python/bindings/Compiler/CompilerPyBind.cpp +++ b/mlir-tensorrt/python/bindings/Compiler/CompilerPyBind.cpp @@ -325,25 +325,6 @@ PYBIND11_MODULE(_api, m) { }, py::arg("client"), py::arg("module"), py::arg("options")); - m.def( - "get_stablehlo_program_refined_signature", - [](PyCompilerClient &client, MlirOperation module, std::string funcName) { - MlirType signature{nullptr}; - MTRT_StableHLOProgramSignatureRefinementOptions options{nullptr}; - MTRT_Status status = - mtrtStableHloProgramSignatureRefinementOptionsCreate( - mtrtStringViewCreate(funcName.c_str(), funcName.size()), - &options); - THROW_IF_MTRT_ERROR(status); - status = mtrtGetStableHloProgramRefinedSignature(client, module, - options, &signature); - THROW_IF_MTRT_ERROR(status); - status = mtrtStableHloProgramSignatureRefinementOptionsDestroy(options); - THROW_IF_MTRT_ERROR(status); - return signature; - }, - py::arg("client"), py::arg("module"), py::arg("func_name")); - #ifdef MLIR_TRT_TARGET_TENSORRT #if MLIR_TRT_COMPILE_TIME_TENSORRT_VERSION_GTE(10, 0, 0) bindTensorRTPluginAdaptorObjects(m); diff --git a/mlir-tensorrt/test/python/mlir_tensorrt_compiler/compiler_api/test_refine_signature.py b/mlir-tensorrt/test/python/mlir_tensorrt_compiler/compiler_api/test_refine_signature.py deleted file mode 100644 index ee82c638e..000000000 --- a/mlir-tensorrt/test/python/mlir_tensorrt_compiler/compiler_api/test_refine_signature.py +++ /dev/null @@ -1,28 +0,0 @@ -# RUN: %PYTHON %s | FileCheck %s -from pathlib import Path - -import mlir_tensorrt.compiler.api as api -import mlir_tensorrt.compiler.ir as ir - -CANONICALIZER_STRESS_TEST_ASM = ( - Path(__file__).parent.parent.parent.parent - / "Pipelines" - / "StableHloInputPipeline" - / "canonicalizer-stress-test.mlir" -).read_text() - - -def refine_signature(ASM): - with ir.Context() as context: - m = ir.Module.parse(ASM) - client = api.CompilerClient(context) - refined_func_type = api.get_stablehlo_program_refined_signature( - client, m.operation, "main" - ) - print(f"Refined func type: {refined_func_type}") - - -print("Testing StableHlo Program Signature Refinement") -refine_signature(CANONICALIZER_STRESS_TEST_ASM) -# CHECK-LABEL: Testing StableHlo Program Signature Refinement -# CHECK: Refined func type: () -> tensor<4xi32> From f3ec678cc24889e6b430b1e4250854d88102e9cc Mon Sep 17 00:00:00 2001 From: yizhuoz004 Date: Thu, 14 Nov 2024 12:29:02 -0800 Subject: [PATCH 10/29] [StableHloExt] Add ReifyRankedShapedTypeOpInterface for ReduceWindowOp (#370) This change adds a ReifyRankedShapedTypeOpInterface interface for ReduceWindowOp to enable dynamic shape inference. --- .../IR/StableHloReifyTypeInterfaceImpl.cpp | 66 ++++++++++++++ .../Plan/materialize-shape-calculations.mlir | 27 ++++++ .../reify-ranked-shaped-type.mlir | 88 +++++++++++++++++++ 3 files changed, 181 insertions(+) diff --git a/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/IR/StableHloReifyTypeInterfaceImpl.cpp b/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/IR/StableHloReifyTypeInterfaceImpl.cpp index 4c0978124..250fefdec 100644 --- a/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/IR/StableHloReifyTypeInterfaceImpl.cpp +++ b/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/IR/StableHloReifyTypeInterfaceImpl.cpp @@ -279,6 +279,67 @@ class ConvolutionReifyRankedShapedTypeOpInterfaceImpl return success(); } }; + +class ReduceWindowReifyRankedShapedTypeOpInterfaceImpl + : public ReifyRankedShapedTypeOpInterface::ExternalModel< + ReduceWindowReifyRankedShapedTypeOpInterfaceImpl, + stablehlo::ReduceWindowOp> { +public: + LogicalResult + reifyResultShapes(Operation *op_, OpBuilder &builder, + ReifiedRankedShapedTypeDims &reifiedReturnShapes) const { + auto op = cast(op_); + Location loc = op.getLoc(); + + FailureOr>> padding = + convertPaddingAttribute(op.getPadding(), loc); + if (failed(padding)) + return failure(); + + // In ReduceWindowOp, size of window_dim, padding, stride and dilation all + // equal to input rank. So the output shape is inferred altogether. + SmallVector windowDims = llvm::to_vector(op.getWindowDimensions()); + SmallVector windowDimensionVals(windowDims.size()); + for (size_t i = 0; i < windowDims.size(); i++) + windowDimensionVals[i] = + builder.createOrFold(loc, windowDims[i]); + + FailureOr> windowOrErr = + getWindowDimensionInfo( + windowDimensionVals, + op.getWindowStrides().value_or(ArrayRef{}), *padding, + op.getBaseDilations().value_or(ArrayRef{}), + op.getWindowDilations().value_or(ArrayRef{}), + ArrayRef{}, loc); + if (failed(windowOrErr)) + return failure(); + + int64_t inputRank = static_cast(windowDims.size()); + SmallVector inputDimVals(inputRank); + for (int64_t i = 0; i < inputRank; ++i) + inputDimVals[i] = getDimExtent(builder, loc, op.getInputs().front(), i); + + SmallVector resultShape = + inferWindowOutputShape(builder, loc, inputDimVals, *windowOrErr); + + // Fixup the result to enforce the required convention for + // `reifyResultShapes` -- if the dimension is dynamic and we infer a static + // integer extent, we must still return a Value. Likewise, the above routine + // may produce a `Value` even though the result type already contains a + // known fixed extent. + RankedTensorType resultType = cast(op.getType(0)); + for (auto [idx, ofr] : llvm::enumerate(resultShape)) { + assert(ofr && "result shape is missing a value"); + if (resultType.isDynamicDim(idx) && !ofr.is()) + resultShape[idx] = getValueOrCreateConstantIndexOp(builder, loc, ofr); + if (!resultType.isDynamicDim(idx) && !ofr.is()) + resultShape[idx] = builder.getIndexAttr(resultType.getDimSize(idx)); + } + + reifiedReturnShapes.emplace_back(std::move(resultShape)); + return success(); + } +}; } // namespace void stablehlo::registerTypeInferenceExternalModels(DialectRegistry ®istry) { @@ -287,4 +348,9 @@ void stablehlo::registerTypeInferenceExternalModels(DialectRegistry ®istry) { stablehlo::ConvolutionOp::attachInterface< ConvolutionReifyRankedShapedTypeOpInterfaceImpl>(*ctx); }); + registry.addExtension( + +[](MLIRContext *ctx, stablehlo::StablehloDialect *dialect) { + stablehlo::ReduceWindowOp::attachInterface< + ReduceWindowReifyRankedShapedTypeOpInterfaceImpl>(*ctx); + }); } diff --git a/mlir-tensorrt/test/Dialect/Plan/materialize-shape-calculations.mlir b/mlir-tensorrt/test/Dialect/Plan/materialize-shape-calculations.mlir index b3ab0521e..fe419af1c 100644 --- a/mlir-tensorrt/test/Dialect/Plan/materialize-shape-calculations.mlir +++ b/mlir-tensorrt/test/Dialect/Plan/materialize-shape-calculations.mlir @@ -1061,3 +1061,30 @@ func.func @dynamic_gather_simplify(%arg0: tensor, %arg1: tensor, // CHECK-DAG: %[[v2:.+]] = "stablehlo.gather"(%[[v1]], %[[v0]]) // CHECK-DAG: %[[v3:.+]] = plan.with_shape %[[v2]](%[[dim]], %[[c3_i32]]) // CHECK-DAG: return %[[v3]] : tensor + +// ----- + +#profile = #tensorrt.shape_profile + +func.func @reduce_window_dynamic_input(%arg0: tensor {tensorrt.shape_profile = #profile}) -> tensor { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = "stablehlo.reduce_window"(%arg0, %cst) <{padding = dense<0> : tensor<4x2xi64>, window_dimensions = array, window_strides = array}> ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = stablehlo.maximum %arg1, %arg2 : tensor + stablehlo.return %1 : tensor + }) : (tensor, tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func.func @reduce_window_dynamic_input +// CHECK-SAME: (%[[arg0:.+]]: tensor +// CHECK-DAG: %[[c512:.+]] = arith.constant 512 : index +// CHECK-DAG: %[[c1024:.+]] = arith.constant 1024 : index +// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index +// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[dim:.+]] = tensor.dim %[[arg0]], %[[c0]] : tensor +// CHECK-DAG: %[[v0:.+]] = plan.with_shape %[[arg0]](%[[dim]], %[[c3]], %[[c1024]], %[[c1024]]) : +// CHECK-DAG: %[[v1:.+]] = "stablehlo.reduce_window" +// CHECK-DAG: %[[v2:.+]] = arith.maxsi %[[dim]], %[[c0]] : index +// CHECK-DAG: %[[v3:.+]] = plan.with_shape %[[v1]](%[[v2]], %[[c3]], %[[c512]], %[[c512]]) : +// CHECK-DAG: return %[[v3]] diff --git a/mlir-tensorrt/test/Dialect/StableHloExt/reify-ranked-shaped-type.mlir b/mlir-tensorrt/test/Dialect/StableHloExt/reify-ranked-shaped-type.mlir index beb7b5e78..3bd108216 100644 --- a/mlir-tensorrt/test/Dialect/StableHloExt/reify-ranked-shaped-type.mlir +++ b/mlir-tensorrt/test/Dialect/StableHloExt/reify-ranked-shaped-type.mlir @@ -96,3 +96,91 @@ func.func @dynamic_input(%arg0: tensor, %arg1: tensor<1x1x1024x1024 // CHECK-DAG: %[[v6:.+]] = arith.addi %[[v5]], %[[c1]] : index // CHECK-DAG: %[[v7:.+]] = arith.maxsi %[[v6]], %[[c0]] : index // CHECK-DAG: return %[[dim]], %[[c256]], %[[v3]], %[[v7]] : + +// ----- + +func.func @refine_reduce_window(%arg0 : tensor<4x3x1024x1024xf32>) + -> (index, index, index, index) { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %result = "stablehlo.reduce_window"(%arg0, %cst) <{ + padding = dense<1> : tensor<4x2xi64>, + window_dimensions = array, + window_strides = array, + base_dilations = array, + window_dilations = array}> ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = stablehlo.maximum %arg1, %arg2 : tensor + stablehlo.return %1 : tensor + }) : (tensor<4x3x1024x1024xf32>, tensor) -> tensor + + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %d0 = tensor.dim %result, %c0 : tensor + %d1 = tensor.dim %result, %c1 : tensor + %d2 = tensor.dim %result, %c2 : tensor + %d3 = tensor.dim %result, %c3 : tensor + return %d0, %d1, %d2, %d3 : index, index, index, index +} + +// CHECK-LABEL: @refine_reduce_window +// CHECK-DAG: %[[c6:.+]] = arith.constant 6 : index +// CHECK-DAG: %[[c5:.+]] = arith.constant 5 : index +// CHECK-DAG: %[[c1024:.+]] = arith.constant 1024 : index +// CHECK-DAG: return %[[c6]], %[[c5]], %[[c1024]], %[[c1024]] : + +// ----- + +func.func @dynamic_reduce_window(%arg0 : tensor<4x3x?x?xf32>) + -> (index, index, index, index) { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %result = "stablehlo.reduce_window"(%arg0, %cst) <{padding = dense<0> : tensor<4x2xi64>, window_dimensions = array, window_strides = array}> ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = stablehlo.maximum %arg1, %arg2 : tensor + stablehlo.return %1 : tensor + }) : (tensor<4x3x?x?xf32>, tensor) -> tensor + + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %d0 = tensor.dim %result, %c0 : tensor + %d1 = tensor.dim %result, %c1 : tensor + %d2 = tensor.dim %result, %c2 : tensor + %d3 = tensor.dim %result, %c3 : tensor + return %d0, %d1, %d2, %d3 : index, index, index, index +} + +// CHECK-LABEL: @dynamic_reduce_window +// CHECK-SAME: (%[[arg0:.+]]: tensor<4x3x?x?xf32>) +// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[cn1:.+]] = arith.constant -1 : index +// CHECK-DAG: %[[cn2:.+]] = arith.constant -2 : index +// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[c4:.+]] = arith.constant 4 : index +// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index +// CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[dim:.+]] = tensor.dim %[[arg0]], %[[c2]] : tensor<4x3x?x?xf32> +// CHECK-DAG: %[[v0:.+]] = arith.maxsi %[[dim]], %[[c0]] +// CHECK-DAG: %[[v1:.+]] = arith.addi %[[v0]], %[[cn2]] : index +// CHECK-DAG: %[[v2:.+]] = arith.cmpi slt, %[[v1]], %[[c0]] : index +// CHECK-DAG: %[[v3:.+]] = arith.subi %[[cn1]], %[[v1]] : index +// CHECK-DAG: %[[v4:.+]] = arith.select %[[v2]], %[[v3]], %[[v1]] : index +// CHECK-DAG: %[[v5:.+]] = arith.divsi %[[v4]], %[[c2]] : index +// CHECK-DAG: %[[v6:.+]] = arith.subi %[[cn1]], %[[v5]] : index +// CHECK-DAG: %[[v7:.+]] = arith.select %[[v2]], %[[v6]], %[[v5]] : index +// CHECK-DAG: %[[v8:.+]] = arith.addi %[[v7]], %[[c1]] : index +// CHECK-DAG: %[[v9:.+]] = arith.maxsi %[[v8]], %[[c0]] +// CHECK-DAG: %[[dim_0:.+]] = tensor.dim %[[arg0]], %[[c3]] : tensor<4x3x?x?xf32> +// CHECK-DAG: %[[v10:.+]] = arith.maxsi %[[dim_0]], %[[c0]] +// CHECK-DAG: %[[v11:.+]] = arith.addi %[[v10]], %[[cn2]] : index +// CHECK-DAG: %[[v12:.+]] = arith.cmpi slt, %[[v11]], %[[c0]] : index +// CHECK-DAG: %[[v13:.+]] = arith.subi %[[cn1]], %[[v11]] : index +// CHECK-DAG: %[[v14:.+]] = arith.select %[[v12]], %[[v13]], %[[v11]] : index +// CHECK-DAG: %[[v15:.+]] = arith.divsi %[[v14]], %[[c2]] : index +// CHECK-DAG: %[[v16:.+]] = arith.subi %[[cn1]], %[[v15]] : index +// CHECK-DAG: %[[v17:.+]] = arith.select %[[v12]], %[[v16]], %[[v15]] : index +// CHECK-DAG: %[[v18:.+]] = arith.addi %[[v17]], %[[c1]] : index +// CHECK-DAG: %[[v19:.+]] = arith.maxsi %[[v18]], %[[c0]] +// CHECK-DAG: return %[[c4]], %[[c3]], %[[v9]], %[[v19]] : From fc8d1b606776937d7c76be68a01355d0c3d4c8b3 Mon Sep 17 00:00:00 2001 From: Sagar Shelke Date: Thu, 14 Nov 2024 14:14:12 -0800 Subject: [PATCH 11/29] [tensorrt] Fix broadcast shape calculation (take 2) (#376) The previous adjustment to `tensorrt::getBroadcastedShape` was not sufficient, and some correct configurations still caused the utility to return an error. This change fixes the issue and adds additional test cases. GitOrigin-RevId: 4e390a802db0a75083a83149415510cc1cb5487a --- .../tensorrt/lib/Utils/ShapeUtils.cpp | 19 +++++++++-------- .../test/Dialect/TensorRT/invalid.mlir | 21 +++++++++++++++++++ 2 files changed, 31 insertions(+), 9 deletions(-) diff --git a/mlir-tensorrt/tensorrt/lib/Utils/ShapeUtils.cpp b/mlir-tensorrt/tensorrt/lib/Utils/ShapeUtils.cpp index 2f800f95d..5c68f6b5b 100644 --- a/mlir-tensorrt/tensorrt/lib/Utils/ShapeUtils.cpp +++ b/mlir-tensorrt/tensorrt/lib/Utils/ShapeUtils.cpp @@ -109,32 +109,33 @@ tensorrt::getBroadcastedShape(ArrayRef> shapes) { // shapes are broadcastable. Don't fail because we can't say for sure it's // invalid. const bool allEqual = llvm::all_equal(dimSizes); - if (allEqual && dimSizes.front() == ShapedType::kDynamic) - return ShapedType::kDynamic; - // Dimensions are all equal to a static size. + // Dimensions are all equal to a fixed value or dynamic. if (allEqual) return dimSizes.front(); - // Some dims are '1', all other dims are equal to another fixed number or - // dynamic. + // Mixture of fixed or unkown extents. std::optional nonUnitSize{}; for (int64_t dimSize : dimSizes) { + // Extent of 1 is always valid. if (dimSize == 1) continue; + // Dynamic extent is always valid. if (ShapedType::isDynamic(dimSize)) continue; + // If a extent > 1 is present, check that it matches any previously seen + // static >1 extent. if (nonUnitSize && dimSize == *nonUnitSize) continue; if (nonUnitSize && dimSize != *nonUnitSize) return failure(); nonUnitSize = dimSize; } - if (nonUnitSize) - return *nonUnitSize; - // No other case is valid. - return failure(); + // Return the size >1 is seen, otherwise return dynamic indicator. An + // inferred size of 1 is only possible if all extents are 1; this case is + // captured by the check before the loop. + return nonUnitSize ? *nonUnitSize : ShapedType::kDynamic; }; for (auto dim : llvm::seq(0, rank)) { diff --git a/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/invalid.mlir b/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/invalid.mlir index 333a26b08..a0afa52e9 100644 --- a/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/invalid.mlir +++ b/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/invalid.mlir @@ -857,6 +857,27 @@ func.func @trt_select(%arg0: tensor<10x10xi1>, %arg1: tensor<1x10xf32>, %arg2: t // ----- +func.func @valid_select_ds_infer(%arg0: tensor, %arg1: tensor, %arg2: tensor<1x1xf16>) -> tensor { + %0 = tensorrt.select ins(%arg0, %arg1, %arg2 : tensor, tensor, tensor<1x1xf16>) -> tensor + return %0 : tensor +} + +// ----- + +func.func @valid_select_ds_infer2(%arg0: tensor<1x?xi1>, %arg1: tensor<1x?xf16>, %arg2: tensor<1x1xf16>) -> tensor { + %0 = tensorrt.select ins(%arg0, %arg1, %arg2 : tensor<1x?xi1>, tensor<1x?xf16>, tensor<1x1xf16>) -> tensor + return %0 : tensor +} + +// ----- + +func.func @valid_select_ds_infer3(%arg0: tensor<1x?xi1>, %arg1: tensor<1x?xf16>, %arg2: tensor<1x1xf16>) -> tensor<1x1xf16> { + %0 = tensorrt.select ins(%arg0, %arg1, %arg2 : tensor<1x?xi1>, tensor<1x?xf16>, tensor<1x1xf16>) -> tensor<1x1xf16> + return %0 : tensor<1x1xf16> +} + +// ----- + func.func @trt_softmax(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> { // expected-error @below {{'tensorrt.softmax' op expected axis to be non-negative and less than 2}} %0 = tensorrt.softmax {axis = 2 : i64} %arg0 : tensor<10x10xf32> From 87733ec9b684dc124d18e1e18e181f4daa73a352 Mon Sep 17 00:00:00 2001 From: Sagar Shelke Date: Thu, 14 Nov 2024 15:17:15 -0800 Subject: [PATCH 12/29] [mlir-tensorrt][CI] Fix CPM cache path (#373) This PR fixes CPM cache path in mlir-tensorrt CI. --- .github/workflows/mlir-tensorrt-ci.yml | 202 +++++++++++++++++++++++-- 1 file changed, 188 insertions(+), 14 deletions(-) diff --git a/.github/workflows/mlir-tensorrt-ci.yml b/.github/workflows/mlir-tensorrt-ci.yml index 495eb7ca8..ffb0ce007 100644 --- a/.github/workflows/mlir-tensorrt-ci.yml +++ b/.github/workflows/mlir-tensorrt-ci.yml @@ -6,14 +6,18 @@ on: - main types: [synchronize, opened, reopened, ready_for_review] paths: ["mlir-tensorrt/**"] + push: + branches: + - main + paths: ["mlir-tensorrt/**"] env: DEFAULT_IMAGE: ghcr.io/nvidia/tensorrt-incubator/mlir-tensorrt:cuda12.5-ubuntu-llvm17 REGISTRY: ghcr.io jobs: - mlir-tensorrt-tests: - if: github.event.pull_request.draft == false + mlir-tensorrt-test-pr: + if: github.event_name == 'pull_request' && github.event.pull_request.draft == false # `ubuntu-latest` is a CPU runner. # If selected, tests requiring GPU are not run. runs-on: ubuntu-latest @@ -128,9 +132,183 @@ jobs: username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} run: | - export CCACHE_BASEDIR="$PWD" - export CCACHE_DIR="$PWD/ccache" - export CCACHE_COMPILERCHECK=content + export CCACHE_DIR="/ccache" + export CCACHE_MAXSIZE=10G + ccache --zero-stats || true + ccache --show-stats || true + + cd mlir-tensorrt + cat > build_and_test.sh < build_and_test.sh < build_and_test.sh < Date: Thu, 14 Nov 2024 12:09:19 -0800 Subject: [PATCH 13/29] Adds a tp.equal API and updates tests to use it where applicable --- .../how-to-add-new-ops.md | 5 +- .../00-introduction-to-tripy.md | 2 +- tripy/tests/backend/api/test_compile.py | 7 +- tripy/tests/backend/api/test_executable.py | 3 +- tripy/tests/backend/api/test_stream.py | 4 +- tripy/tests/frontend/module/test_parameter.py | 2 +- .../tests/frontend/module/test_sequential.py | 14 +--- tripy/tests/frontend/ops/test_equal.py | 22 ++++++ tripy/tests/frontend/ops/test_outer.py | 5 +- tripy/tests/integration/test_allclose.py | 22 +++--- tripy/tests/integration/test_arange.py | 4 +- tripy/tests/integration/test_equal.py | 30 ++++++++ tripy/tests/integration/test_flip.py | 6 +- tripy/tripy/backend/api/stream.py | 2 +- tripy/tripy/frontend/ops/allclose.py | 22 +++--- tripy/tripy/frontend/ops/equal.py | 70 +++++++++++++++++++ 16 files changed, 163 insertions(+), 57 deletions(-) create mode 100644 tripy/tests/frontend/ops/test_equal.py create mode 100644 tripy/tests/integration/test_equal.py create mode 100644 tripy/tripy/frontend/ops/equal.py diff --git a/tripy/docs/post0_developer_guides/how-to-add-new-ops.md b/tripy/docs/post0_developer_guides/how-to-add-new-ops.md index 8c5b816f0..928c52e69 100644 --- a/tripy/docs/post0_developer_guides/how-to-add-new-ops.md +++ b/tripy/docs/post0_developer_guides/how-to-add-new-ops.md @@ -325,10 +325,9 @@ import tripy as tp def test_multi_dimensional(): output = tp.theta([2, 3], dim=1) - expected = np.broadcast_to(np.arange(0, 3, dtype=np.float32), (2, 3)) - - assert np.array_equal(cp.from_dlpack(output).get(), expected) + expected = tp.Tensor([[0.0, 1.0, 2.0], [0.0, 1.0, 2.0]], dtype=tp.float32) + assert tp.equal(output, expected) ``` ## Done! diff --git a/tripy/docs/pre0_user_guides/00-introduction-to-tripy.md b/tripy/docs/pre0_user_guides/00-introduction-to-tripy.md index dd714fb08..790a006b8 100644 --- a/tripy/docs/pre0_user_guides/00-introduction-to-tripy.md +++ b/tripy/docs/pre0_user_guides/00-introduction-to-tripy.md @@ -12,7 +12,7 @@ It aims to be fast, easy to debug, and provide an easy-to-use Pythonic interface a = tp.arange(5) c = a + 1.5 print(c) -assert np.array_equal(cp.from_dlpack(c).get(), np.arange(5, dtype=np.float32) + 1.5) # doc: omit +assert cp.array_equal(cp.from_dlpack(c), cp.arange(5, dtype=np.float32) + 1.5) # doc: omit ``` This should look familiar if you've used linear algebra or deep learning libraries like diff --git a/tripy/tests/backend/api/test_compile.py b/tripy/tests/backend/api/test_compile.py index 983708923..a4c241dc9 100644 --- a/tripy/tests/backend/api/test_compile.py +++ b/tripy/tests/backend/api/test_compile.py @@ -31,8 +31,7 @@ def test_function(self): inp = tp.ones((2, 2), dtype=tp.float32) out = compiled_gelu(inp) - # TODO (#225): Replace with tp.all - assert cp.array_equal(cp.from_dlpack(out), cp.from_dlpack(tp.relu(inp))) + assert tp.equal(out, tp.relu(inp)) def test_module(self): layernorm = tp.LayerNorm(2) @@ -41,7 +40,7 @@ def test_module(self): inp = tp.ones((2, 2), dtype=tp.float32) out = compiled_layernorm(inp) - assert cp.array_equal(cp.from_dlpack(out), cp.from_dlpack(layernorm(inp))) + assert tp.equal(out, layernorm(inp)) def test_compile_arg_order_irrelevant(self): # The order of arguments we specify to `compile` should not affect the order @@ -214,4 +213,4 @@ def test_linear(self): out = compiled_linear(a) - assert cp.array_equal(cp.from_dlpack(out), cp.from_dlpack(linear(a))) + assert tp.equal(out, linear(a)) diff --git a/tripy/tests/backend/api/test_executable.py b/tripy/tests/backend/api/test_executable.py index 0d64d9b6c..cc412b086 100644 --- a/tripy/tests/backend/api/test_executable.py +++ b/tripy/tests/backend/api/test_executable.py @@ -17,7 +17,6 @@ import tempfile from typing import Sequence -import cupy as cp import pytest from tests import helper from tests.backend.api.conftest import * @@ -117,4 +116,4 @@ def test_file_io(self, single_return_executable): inp = tp.iota((2, 2), dtype=tp.float32) out1 = single_return_executable(inp, inp) out2 = loaded_executable(inp, inp) - assert cp.array_equal(cp.from_dlpack(out1), cp.from_dlpack(out2)) + assert tp.equal(out1, out2) diff --git a/tripy/tests/backend/api/test_stream.py b/tripy/tests/backend/api/test_stream.py index 5d05ebea1..72c160278 100644 --- a/tripy/tests/backend/api/test_stream.py +++ b/tripy/tests/backend/api/test_stream.py @@ -39,11 +39,11 @@ def test_enqueue_work_on_stream(): out = compiled_linear(a) tp.default_stream().synchronize() - assert cp.array_equal(cp.from_dlpack(out), cp.from_dlpack(linear(a))) + assert tp.equal(out, linear(a)) stream = tp.Stream() compiled_linear.stream = stream out = compiled_linear(a) # stream sync below is not required since from_dlpack method will eval() the tensor which will call stream sync anyway. compiled_linear.stream.synchronize() - assert cp.array_equal(cp.from_dlpack(out), cp.from_dlpack(linear(a))) + assert tp.equal(out, linear(a)) diff --git a/tripy/tests/frontend/module/test_parameter.py b/tripy/tests/frontend/module/test_parameter.py index 09f2754e7..50ed6c084 100644 --- a/tripy/tests/frontend/module/test_parameter.py +++ b/tripy/tests/frontend/module/test_parameter.py @@ -36,7 +36,7 @@ def test_is_equivalent_to_tensor(self): tensor = tp.Tensor([1, 2, 3]) param = tp.Parameter(tensor) - assert np.array_equal(cp.from_dlpack(param).get(), cp.from_dlpack(tensor).get()) + assert tp.equal(param, tensor) def test_can_construct_from_non_tensor(self): param = tp.Parameter([1, 2, 3]) diff --git a/tripy/tests/frontend/module/test_sequential.py b/tripy/tests/frontend/module/test_sequential.py index 3a751f99f..7643a6fde 100644 --- a/tripy/tests/frontend/module/test_sequential.py +++ b/tripy/tests/frontend/module/test_sequential.py @@ -47,10 +47,6 @@ def test_basic_structure(self, sequential_network): assert len(sequential_network) == 2 assert isinstance(sequential_network[0], tp.Linear) - assert np.array_equal( - cp.from_dlpack(sequential_network[0].weight), cp.from_dlpack(sequential_network[0].weight) - ) - assert np.array_equal(cp.from_dlpack(sequential_network[0].bias), cp.from_dlpack(sequential_network[0].bias)) def test_named_children(self, sequential_network): expected_names = [("0", sequential_network[0]), ("1", sequential_network[1])] @@ -72,7 +68,7 @@ def test_state_dict(self, sequential_network): def test_load_state_dict(self, sequential_network): new_state_dict = {"0.weight": tp.Parameter(tp.ones((3, 1)))} sequential_network.load_state_dict(new_state_dict, strict=False) - assert np.array_equal(cp.from_dlpack(sequential_network[0].weight), cp.from_dlpack(new_state_dict["0.weight"])) + assert tp.equal(sequential_network[0].weight, new_state_dict["0.weight"]) def test_modify_parameters(self, sequential_network): new_param = tp.Parameter(tp.ones((2, 3))) @@ -125,9 +121,7 @@ def test_state_dict(self, dict_sequential_network): def test_load_state_dict(self, dict_sequential_network): new_state_dict = {"layer1.weight": tp.Parameter(tp.ones((3, 1)))} dict_sequential_network.load_state_dict(new_state_dict, strict=False) - assert np.array_equal( - cp.from_dlpack(dict_sequential_network["layer1"].weight), cp.from_dlpack(new_state_dict["layer1.weight"]) - ) + assert tp.equal(dict_sequential_network["layer1"].weight, new_state_dict["layer1.weight"]) def test_modify_parameters(self, dict_sequential_network): new_weight = tp.Parameter(tp.ones((2, 3))) @@ -179,9 +173,7 @@ def test_load_state_dict_nested(self, nested_sequential_network): "1.1.weight": tp.Parameter(tp.ones((1, 3))), } nested_sequential_network.load_state_dict(new_state_dict, strict=False) - assert np.array_equal( - cp.from_dlpack(nested_sequential_network[1][1].weight), cp.from_dlpack(new_state_dict["1.1.weight"]) - ) + assert tp.equal(nested_sequential_network[1][1].weight, new_state_dict["1.1.weight"]) def test_str_representation(self, nested_sequential_network): expected_str = dedent( diff --git a/tripy/tests/frontend/ops/test_equal.py b/tripy/tests/frontend/ops/test_equal.py new file mode 100644 index 000000000..bc278ac9e --- /dev/null +++ b/tripy/tests/frontend/ops/test_equal.py @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from tests import helper +import tripy as tp + + +class TestEqual: + def test_mismatched_dtypes_disallowed(self): + with helper.raises(tp.TripyException, match="Mismatched data types for 'equal'."): + tp.equal(tp.ones((2,), dtype=tp.float32), tp.ones((2,), dtype=tp.float16)) diff --git a/tripy/tests/frontend/ops/test_outer.py b/tripy/tests/frontend/ops/test_outer.py index 17d2fd7be..11732fe39 100644 --- a/tripy/tests/frontend/ops/test_outer.py +++ b/tripy/tests/frontend/ops/test_outer.py @@ -14,10 +14,13 @@ # limitations under the License. from tests import helper + import tripy as tp + + class TestOuter: def test_invalid_rank_fails(self): a = tp.ones((5, 1)) b = tp.ones((1, 4)) with helper.raises(tp.TripyException, "Expected input vectors to be 1-d."): - tp.outer(a, b) \ No newline at end of file + tp.outer(a, b) diff --git a/tripy/tests/integration/test_allclose.py b/tripy/tests/integration/test_allclose.py index cf83c4770..2458e5b64 100644 --- a/tripy/tests/integration/test_allclose.py +++ b/tripy/tests/integration/test_allclose.py @@ -16,27 +16,23 @@ # import pytest -import torch import tripy as tp class TestAllClose: @pytest.mark.parametrize( - "tensor_a, tensor_b, rtol, atol", + "tensor_a, tensor_b, rtol, atol, expected", [ - ([1e10, 1e-7], [1.00001e10, 1e-8], 1e-05, 1e-08), - # TODO (#232): Reenable when fixed - # ([1e10, 1e-8], [1.00001e10, 1e-9], 1e-05, 1e-08), - ([1e10, 1e-8], [1.0001e10, 1e-9], 1e-05, 1e-08), - ([1.0, 2.0, 3.0], [1.01, 2.01, 3.01], 0.0, 0.01), - ([1.0, 2.0, 3.0], [1.01, 2.01, 3.01], 0.01, 0.0), - ([1.0, 2.0, 3.0], [1.01, 2.01, 3.01], 0.01, 0.01), + ([1e10, 1e-7], [1e10, 1e-8], 1e-05, 1e-08, False), + ([1e10, 1e-8], [1.0001e10, 1e-9], 1e-05, 1e-08, False), + ([1.0, 2.0, 3.0], [1.01, 2.01, 3.01], 0.0, 0.01, True), + ([1.0, 2.0, 3.0], [1.01, 2.01, 3.01], 0.01, 0.0, True), + ([1.0, 2.0, 3.0], [1.01, 2.01, 3.01], 0.01, 0.01, True), ], ) - def test_all_close_float32(self, tensor_a, tensor_b, rtol, atol): - torch_result = torch.allclose(torch.FloatTensor(tensor_a), torch.FloatTensor(tensor_b), rtol=rtol, atol=atol) - tp_result = tp.allclose( + def test_all_close_float32(self, tensor_a, tensor_b, rtol, atol, expected): + out = tp.allclose( tp.Tensor(tensor_a, dtype=tp.float32), tp.Tensor(tensor_b, dtype=tp.float32), rtol=rtol, atol=atol ) - assert torch_result == tp_result + assert out == expected diff --git a/tripy/tests/integration/test_arange.py b/tripy/tests/integration/test_arange.py index 1a19e32ae..fd2282cd3 100644 --- a/tripy/tests/integration/test_arange.py +++ b/tripy/tests/integration/test_arange.py @@ -15,9 +15,9 @@ import cupy as cp import numpy as np -import tripy as tp from tests import helper -import pytest + +import tripy as tp class TestArange: diff --git a/tripy/tests/integration/test_equal.py b/tripy/tests/integration/test_equal.py new file mode 100644 index 000000000..3e56eb498 --- /dev/null +++ b/tripy/tests/integration/test_equal.py @@ -0,0 +1,30 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +import tripy as tp + + +@pytest.mark.parametrize( + "a, b, expected", + [ + (tp.Tensor([1, 2], dtype=tp.float32), tp.Tensor([1, 2], dtype=tp.float32), True), + (tp.ones((2, 2), dtype=tp.int32), tp.Tensor([[1, 1], [1, 1]], dtype=tp.int32), True), + (tp.ones((1, 4)), tp.ones((4, 1)), False), + ], +) +def test_equal(a, b, expected): + out = tp.equal(a, b) + assert out == expected diff --git a/tripy/tests/integration/test_flip.py b/tripy/tests/integration/test_flip.py index 20e7eea1f..8118716d5 100644 --- a/tripy/tests/integration/test_flip.py +++ b/tripy/tests/integration/test_flip.py @@ -40,12 +40,12 @@ def test_no_op(self): cp_a = cp.arange(16).reshape((4, 4)).astype(cp.float32) a = tp.Tensor(cp_a, device=tp.device("gpu")) f = tp.flip(a, dims=[]) - assert cp.array_equal(cp.from_dlpack(a), cp.from_dlpack(f)) + assert tp.equal(a, f) def test_zero_rank(self): t = tp.Tensor(1) f = tp.flip(t) - assert cp.array_equal(cp.from_dlpack(t), cp.from_dlpack(f)) + assert tp.equal(t, f) @pytest.mark.parametrize( "dims1, dims2", @@ -56,4 +56,4 @@ def test_equivalences(self, dims1, dims2): a = tp.Tensor(cp_a, device=tp.device("gpu")) f1 = tp.flip(a, dims=dims1) f2 = tp.flip(a, dims=dims2) - assert cp.array_equal(cp.from_dlpack(f1), cp.from_dlpack(f2)) + assert tp.equal(f1, f2) diff --git a/tripy/tripy/backend/api/stream.py b/tripy/tripy/backend/api/stream.py index 92099cf9d..0a4e73c2f 100644 --- a/tripy/tripy/backend/api/stream.py +++ b/tripy/tripy/backend/api/stream.py @@ -61,7 +61,7 @@ def __init__(self, priority: int = 0) -> None: input = tp.ones((2, 2), dtype=tp.float32) output = compiled_linear(input) - assert cp.array_equal(cp.from_dlpack(output), cp.from_dlpack(linear(input))) + assert tp.equal(output, linear(input)) """ if priority != 0: raise_error( diff --git a/tripy/tripy/frontend/ops/allclose.py b/tripy/tripy/frontend/ops/allclose.py index dbb811fa7..70e6b7a5c 100644 --- a/tripy/tripy/frontend/ops/allclose.py +++ b/tripy/tripy/frontend/ops/allclose.py @@ -15,24 +15,20 @@ # limitations under the License. # -from tripy import export, constraints -from tripy.common.exception import raise_error +from tripy import constraints, export @export.public_api(document_under="operations/functions") -@constraints.dtypes( - constraints={"a": "T1", "b": "T1"}, - variables={"T1": ["float32", "float16", "bfloat16"]}, -) -def allclose(a: "tripy.Tensor", b: "tripy.Tensor", rtol: float = 1e-05, atol: float = 1e-08) -> bool: +@constraints.dtypes(constraints={"input": "T1", "other": "T1"}, variables={"T1": ["float32", "float16", "bfloat16"]}) +def allclose(input: "tripy.Tensor", other: "tripy.Tensor", rtol: float = 1e-05, atol: float = 1e-08) -> bool: r""" - Returns true if the following equation is true for every element in ``a`` and ``b`` : + Returns ``True`` if the following equation is true for every element in ``input`` and ``other`` : - :math:`|a_i - b_i| <= (\text{atol} + \text{rtol} * |b_i|)` + :math:`|\text{input}_i - \text{other}_i| <= (\text{atol} + \text{rtol} * |\text{other}_i|)` Args: - a: First tensor to compare. - b: Second tensor to compare. + input: First tensor to compare. + other: Second tensor to compare. rtol: The relative tolerance. atol: The absolute tolerance. @@ -55,8 +51,8 @@ def allclose(a: "tripy.Tensor", b: "tripy.Tensor", rtol: float = 1e-05, atol: fl out = tp.allclose(tp.Tensor([1e-7]), tp.Tensor([1.2e-7])) assert not out """ - from tripy.frontend.trace.ops.unary_elementwise import abs from tripy.frontend.trace.ops.reduce import all + from tripy.frontend.trace.ops.unary_elementwise import abs - compare = abs(a - b) <= (atol + rtol * abs(b)) + compare = abs(input - other) <= (atol + rtol * abs(other)) return bool(all(compare)) diff --git a/tripy/tripy/frontend/ops/equal.py b/tripy/tripy/frontend/ops/equal.py new file mode 100644 index 000000000..357c82e54 --- /dev/null +++ b/tripy/tripy/frontend/ops/equal.py @@ -0,0 +1,70 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from tripy import constraints, export +from tripy.common.datatype import DATA_TYPES + + +@export.public_api(document_under="operations/functions") +@constraints.dtypes(constraints={"input": "T1", "other": "T1"}, variables={"T1": list(DATA_TYPES.keys())}) +def equal(input: "tripy.Tensor", other: "tripy.Tensor") -> bool: + r""" + Returns ``True`` if ``input`` and ``other`` have the same shape and elements. + + Args: + input: First tensor to compare. + other: Second tensor to compare. + + Returns: + ``True`` if the tensors have the same shape and elements and ``False`` otherwise. + + .. code-block:: python + :linenos: + :caption: Identical tensors + + # doc: print-locals a b is_equal + a = tp.ones((1, 2), dtype=tp.float32) + b = tp.ones((1, 2), dtype=tp.float32) + + is_equal = tp.equal(a, b) + assert is_equal + + .. code-block:: python + :linenos: + :caption: Different shapes + + # doc: print-locals a b is_equal + a = tp.ones((1, 2), dtype=tp.float32) + b = tp.ones((2, 2), dtype=tp.float32) + + is_equal = tp.equal(a, b) + assert not is_equal + + .. code-block:: python + :linenos: + :caption: Different elements + + # doc: print-locals a b is_equal + a = tp.ones((1, 2), dtype=tp.float32) + b = tp.zeros((1, 2), dtype=tp.float32) + + is_equal = tp.equal(a, b) + assert not is_equal + """ + from tripy.frontend.trace.ops.reduce import all + + if input.shape != other.shape: + return False + + return bool(all(input == other)) From bcfb737b92117a302182781c0518fbcdb0e8e351 Mon Sep 17 00:00:00 2001 From: pranavm Date: Thu, 14 Nov 2024 12:37:41 -0800 Subject: [PATCH 14/29] Cleans up printing of modules, adds automatic conversion to Parameter when setting attributes --- tripy/tests/frontend/module/test_module.py | 32 ++++---- .../tests/frontend/module/test_sequential.py | 74 +++++++++---------- tripy/tests/helper.py | 3 +- tripy/tripy/frontend/module/module.py | 15 ++-- 4 files changed, 63 insertions(+), 61 deletions(-) diff --git a/tripy/tests/frontend/module/test_module.py b/tripy/tests/frontend/module/test_module.py index a336c0687..245f0483b 100644 --- a/tripy/tests/frontend/module/test_module.py +++ b/tripy/tests/frontend/module/test_module.py @@ -44,6 +44,12 @@ def test_get_set_attr(self, network): assert cp.from_dlpack(dict(network.named_parameters())["param"]).get().tolist() == [0.0, 1.0] assert "dummy1" not in dict(network.named_children()) + def test_automatic_conversion_to_parameter_of_direct_attributes(self, network): + network.param = [0.0, 1.0] + assert isinstance(network.param, tp.Parameter) + + assert "param" in dict(network.named_parameters()) + def test_incompatible_parameter_cannot_be_set(self, network): with helper.raises( tp.TripyException, match=r"New parameter shape: \[2, 3\] is not compatible with current shape: \[2\]" @@ -130,25 +136,23 @@ def test_mixed_collections_not_registered(self, network): def test_module_print(self, network): expected_output = dedent( - """\ + """ Network( - dummy1= - DummyOp( - nested= - DummyNestedOp( - param=shape(2), + param: Parameter = (shape=[2], dtype=float32), + dummy1: Module = DummyOp( + nested: Module = DummyNestedOp( + param: Parameter = (shape=[2], dtype=float32), ), ), - dummy2= - DummyOp( - nested= - DummyNestedOp( - param=shape(2), + dummy2: Module = DummyOp( + nested: Module = DummyNestedOp( + param: Parameter = (shape=[2], dtype=float32), ), ), - param=shape(2), - )""" - ) + ) + """ + ).strip() + assert str(network) == expected_output class TestModuleWithList: diff --git a/tripy/tests/frontend/module/test_sequential.py b/tripy/tests/frontend/module/test_sequential.py index 7643a6fde..43ad48ed2 100644 --- a/tripy/tests/frontend/module/test_sequential.py +++ b/tripy/tests/frontend/module/test_sequential.py @@ -81,20 +81,20 @@ def test_invalid_index_access(self, sequential_network): def test_str_representation(self, sequential_network): expected_str = dedent( - """\ + """ Sequential( - 0= - Linear( - weight=[3, 1], - bias=[3], + 0: Module = Linear( + weight: Parameter = (shape=[3, 1], dtype=float32), + bias: Parameter = (shape=[3], dtype=float32), ), - 1= - Linear( - weight=[2, 3], - bias=[2], + 1: Module = Linear( + weight: Parameter = (shape=[2, 3], dtype=float32), + bias: Parameter = (shape=[2], dtype=float32), ), - )""" - ) + ) + """ + ).strip() + assert str(sequential_network) == expected_str @@ -130,20 +130,19 @@ def test_modify_parameters(self, dict_sequential_network): def test_str_representation(self, dict_sequential_network): expected_str = dedent( - """\ + """ Sequential( - layer1= - Linear( - weight=[3, 1], - bias=[3], + layer1: Module = Linear( + weight: Parameter = (shape=[3, 1], dtype=float32), + bias: Parameter = (shape=[3], dtype=float32), ), - layer2= - Linear( - weight=[2, 3], - bias=[2], + layer2: Module = Linear( + weight: Parameter = (shape=[2, 3], dtype=float32), + bias: Parameter = (shape=[2], dtype=float32), ), - )""" - ) + ) + """ + ).strip() assert str(dict_sequential_network) == expected_str @@ -177,26 +176,23 @@ def test_load_state_dict_nested(self, nested_sequential_network): def test_str_representation(self, nested_sequential_network): expected_str = dedent( - """\ + """ Sequential( - 0= - Linear( - weight=[4, 2], - bias=[4], + 0: Module = Linear( + weight: Parameter = (shape=[4, 2], dtype=float32), + bias: Parameter = (shape=[4], dtype=float32), ), - 1= - Sequential( - 0= - Linear( - weight=[3, 4], - bias=[3], + 1: Module = Sequential( + 0: Module = Linear( + weight: Parameter = (shape=[3, 4], dtype=float32), + bias: Parameter = (shape=[3], dtype=float32), ), - 1= - Linear( - weight=[1, 3], - bias=[1], + 1: Module = Linear( + weight: Parameter = (shape=[1, 3], dtype=float32), + bias: Parameter = (shape=[1], dtype=float32), ), ), - )""" - ) + ) + """ + ).strip() assert str(nested_sequential_network) == expected_str diff --git a/tripy/tests/helper.py b/tripy/tests/helper.py index db8dce5fe..ebefc94bb 100644 --- a/tripy/tests/helper.py +++ b/tripy/tests/helper.py @@ -573,7 +573,8 @@ def pretty_str_from_dict(dct): locals_str += f"\n>>> {name}" if isinstance(obj, tp.Module): - locals_str += f".state_dict()\n{pretty_str_from_dict(obj.state_dict())}" + locals_str += f"\n{obj}" + locals_str += f"\n>>> {name}.state_dict()\n{pretty_str_from_dict(obj.state_dict())}" elif isinstance(obj, dict): locals_str += f"\n{pretty_str_from_dict(obj)}" else: diff --git a/tripy/tripy/frontend/module/module.py b/tripy/tripy/frontend/module/module.py index 17bd28999..2cc41cc85 100644 --- a/tripy/tripy/frontend/module/module.py +++ b/tripy/tripy/frontend/module/module.py @@ -97,6 +97,8 @@ def __call__(self, x): def __setattr__(self, name: str, value: Any) -> None: if isinstance(value, Parameter) or name in dict(self.named_parameters()): + if not isinstance(value, Parameter): + value = Parameter(value) _check_param_compatible(getattr(self, name, None), value, name) super().__setattr__(name, value) @@ -326,14 +328,13 @@ def __str__(self): class_name = self.__class__.__name__ module_str = f"{class_name}(\n" - # Add children with hierarchical indentation - for name, child in self.named_children(): - c = indent(str(child), prefix=" ") - module_str += f" {name}=\n{c},\n" - - # Add parameters with hierarchical indentation + body_str = "" for name, param in self.named_parameters(): - module_str += f" {name}={param.shape},\n" + body_str += f"{name}: Parameter = (shape={param.shape}, dtype={param.dtype}),\n" + + for name, child in self.named_children(): + body_str += f"{name}: Module = {str(child).strip()},\n" + module_str += indent(body_str, " " * 4) module_str += f")" return module_str From 64c780f0440ee0dd135e0f3e3c71fbfb7e76f5c9 Mon Sep 17 00:00:00 2001 From: pranavm Date: Thu, 14 Nov 2024 12:48:16 -0800 Subject: [PATCH 15/29] Updates nanoGPT test to use standard regexs --- tripy/examples/nanogpt/README.md | 13 ++----------- tripy/tests/test_examples.py | 7 +------ 2 files changed, 3 insertions(+), 17 deletions(-) diff --git a/tripy/examples/nanogpt/README.md b/tripy/examples/nanogpt/README.md index a3dc4678f..11e9b29a0 100644 --- a/tripy/examples/nanogpt/README.md +++ b/tripy/examples/nanogpt/README.md @@ -38,10 +38,7 @@ for expected accuracy. @@ -65,13 +62,7 @@ To run with a quantization mode, pass `--quant-mode` to `example.py`. The suppor diff --git a/tripy/tests/test_examples.py b/tripy/tests/test_examples.py index 3fe546268..e24b906fd 100644 --- a/tripy/tests/test_examples.py +++ b/tripy/tests/test_examples.py @@ -103,12 +103,7 @@ def test_examples(example, sandboxed_install_run): if block.has_marker("test: expected_stdout"): print("Checking command output against expected output: ", end="") out = statuses[-1].stdout.strip() - matched = False - expected_outs = dedent(block_text).split("====") - for expected in expected_outs: - if re.match(expected.strip(), out): - matched = True - break + matched = re.match(dedent(block_text).strip(), out) print("matched!" if matched else "did not match!") print(f"==== STDOUT ====\n{out}") assert matched From d4a305809fd665ddf2d864635916d66f9d6882f5 Mon Sep 17 00:00:00 2001 From: Sagar Shelke Date: Fri, 15 Nov 2024 16:13:56 -0800 Subject: [PATCH 16/29] Move internal changes (#380) This PR moves the following internal changes to OSS repo. ## [plan][transforms] Fix an issue in shape materialization pass This MR fixes an issue in `SimplifyExtractOfReshape` pattern of shape materialization pass. This patten was being applied even when reshape op operand is dynamic. However, with dynamic operand, mapping extract index into reshape operand doesn't work. With this change, we return failure if reshape op operand is dynamic. MLIR test is added for scenario when pattern should return failure. ## [tensorrt] Fix incorrect handling of dynamic shape in `tensorrt-broadcast-elimination` Fixes an issue where a `tensorrt-broadcast-elimination` would improperly handle dynamically shaped tensors when attempting to reshape them. In certain cases (when more than 1 dynamic dimension is present), to perform a reshape, the target shape must be explicitly calculated in the IR and a dynamic reshape must be created. Co-authored-by: Copybara Bot --- .../MaterializeShapeCalculations.cpp | 3 + .../Transforms/BroadcastElimination.cpp | 66 ++++++++++++++++++- .../TensorRT/broadcast-elimination.mlir | 39 +++++++++++ .../Plan/materialize-shape-calculations.mlir | 24 +++++++ .../test/models/bert.stablehlo.elided.mlir | 2 +- .../models/gpt2.stablehlo.bs2.elided.mlir | 2 +- .../test/models/gpt2.stablehlo.elided.mlir | 2 +- .../models/llama-68m.stablehlo.elided.mlir | 2 +- .../models/llama-v2.stablehlo.elided.mlir | 2 +- .../models/resnet50.stablehlo.elided.mlir | 2 +- .../test/models/swin.stablehlo.elided.mlir | 2 +- .../models/whisper-jax.stablehlo.elided.mlir | 2 +- 12 files changed, 138 insertions(+), 10 deletions(-) diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/MaterializeShapeCalculations.cpp b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/MaterializeShapeCalculations.cpp index 6ae987c71..7e25472a3 100644 --- a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/MaterializeShapeCalculations.cpp +++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/MaterializeShapeCalculations.cpp @@ -361,6 +361,9 @@ struct SimplifyExtractOfReshape : public OpRewritePattern { if (!reshapeOp) return failure(); + if (!reshapeOp.getOperand().getType().hasStaticShape()) + return failure(); + std::optional> coords = getConstantIntValues(getAsOpFoldResult(op.getIndices())); if (!coords) diff --git a/mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/BroadcastElimination.cpp b/mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/BroadcastElimination.cpp index b5ff60264..de5f1f099 100644 --- a/mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/BroadcastElimination.cpp +++ b/mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/BroadcastElimination.cpp @@ -120,6 +120,65 @@ struct PushDownBroadcastReduceRankOp : public OpRewritePattern { }; } // namespace +static Value expandRank(RewriterBase &rewriter, Location loc, + TypedValue input, + ArrayRef reorderedBroadcastDims, + RankedTensorType resultType) { + RankedTensorType inputType = input.getType(); + // For <= 1 dynamic dims, no need to do dynamic reshape. + if (input.getType().getNumDynamicDims() <= 1) { + SmallVector staticShape(resultType.getRank()); + + unsigned inputIdx = 0; + for (unsigned i = 0, e = staticShape.size(); i < e; i++) { + if (inputIdx < reorderedBroadcastDims.size() && + i == reorderedBroadcastDims[inputIdx]) { + staticShape[i] = inputType.getDimSize(inputIdx++); + continue; + } + staticShape[i] = 1; + } + return rewriter.create(loc, resultType.clone(staticShape), + input); + } + + // Otherwise, we need to do dynamic reshape. + auto shape = rewriter.create(loc, input); + SmallVector shapeComponents(resultType.getRank()); + SmallVector staticShape(resultType.getRank()); + unsigned inputIdx = 0; + for (unsigned i = 0, e = shapeComponents.size(); i < e; i++) { + if (inputIdx < reorderedBroadcastDims.size() && + i == reorderedBroadcastDims[inputIdx]) { + if (!inputType.isDynamicDim(inputIdx)) { + staticShape[i] = inputType.getDimSize(inputIdx); + shapeComponents[i] = rewriter.create( + loc, rewriter.getI32TensorAttr( + {static_cast(inputType.getDimSize(inputIdx++))})); + continue; + } + shapeComponents[i] = rewriter.create( + loc, shape, + /*offset=*/ArrayRef{static_cast(inputIdx++)}, + ArrayRef{1}, ArrayRef{1}); + staticShape[i] = ShapedType::kDynamic; + continue; + } + staticShape[i] = 1; + shapeComponents[i] = rewriter.create( + loc, rewriter.getI32TensorAttr( + {static_cast(inputType.getDimSize(1))})); + } + auto newShape = rewriter.create( + loc, + RankedTensorType::get(static_cast(shapeComponents.size()), + rewriter.getI32Type()), + shapeComponents, /*axis=*/0); + + return rewriter.create(loc, resultType.clone(staticShape), input, + newShape); +} + namespace { /// Create transpose + expand_rank on the input of a `tensorrt.broadcast` so /// that the result has the same rank as the `tensorrt.broadcast` result and the @@ -157,8 +216,9 @@ struct SimplifyBroadcast : public OpRewritePattern { } expandedShape[i] = 1; } - Value expanded = rewriter.create( - loc, resultType.clone(expandedShape), transposeOp); + + Value expanded = expandRank(rewriter, loc, transposeOp, + reorderedBroadcastDims, resultType); rewriter.replaceOpWithNewOp( op, op.getType(), expanded, op.getShape(), llvm::to_vector(llvm::seq(0, resultType.getRank()))); @@ -341,6 +401,8 @@ class BroadcastEliminationPass patterns.add(&getContext()); + tensorrt::ReshapeOp::getCanonicalizationPatterns(patterns, + patterns.getContext()); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { emitError(getOperation()->getLoc()) diff --git a/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/broadcast-elimination.mlir b/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/broadcast-elimination.mlir index 4cfba69db..bba942f64 100644 --- a/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/broadcast-elimination.mlir +++ b/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/broadcast-elimination.mlir @@ -236,3 +236,42 @@ func.func @broadcast_elim_matmul_vector(%arg0: tensor, %arg1: tenso // CHECK: return %[[v0]] : tensor +// ----- + +func.func @broadcast_dynamic_expand_shape_regression(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor<4xi32>) -> tensor { + %0 = tensorrt.broadcast %arg0 broadcast_dims<0, 1, 2, 3> shape(%arg3 : tensor<4xi32>) : tensor to tensor + %1 = tensorrt.broadcast %arg1 broadcast_dims<2, 3> shape(%arg3 : tensor<4xi32>) : tensor to tensor + %2 = tensorrt.select ins(%0, %arg2, %1 : tensor, tensor, tensor) + -> tensor + return %2 : tensor +} + +// CHECK-LABEL: func.func @broadcast_dynamic_expand_shape_regression +// CHECK-SAME: (%[[arg0:.+]]: tensor, %[[arg1:.+]]: tensor, %[[arg2:.+]]: tensor, %[[arg3:.+]]: tensor<4xi32>) -> tensor { +// CHECK: %[[v0:.+]] = tensorrt.reshape %[[arg1]] : tensor to tensor<1x1x?x1xf16> +// CHECK: %[[v1:.+]] = tensorrt.select ins(%[[arg0]], %[[arg2]], %[[v0]] : tensor, tensor, tensor<1x1x?x1xf16>) -> tensor +// CHECK: return %[[v1]] : tensor + +// ----- + +func.func @broadcast_dynamic_expand_shape_regression(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor<4xi32>) -> tensor { + %0 = tensorrt.broadcast %arg0 broadcast_dims<0, 1, 2, 3> shape(%arg3 : tensor<4xi32>) : tensor to tensor + %1 = tensorrt.broadcast %arg1 broadcast_dims<3, 2, 1> shape(%arg3 : tensor<4xi32>) : tensor to tensor + %2 = tensorrt.select ins(%0, %arg2, %1 : tensor, tensor, tensor) + -> tensor + return %2 : tensor +} + +// CHECK: #[[$map:.+]] = affine_map<(d0, d1, d2) -> (d2, d1, d0)> +// CHECK: module { +// CHECK-LABEL: func.func @broadcast_dynamic_expand_shape_regression +// CHECK-SAME: (%[[arg0:.+]]: tensor, %[[arg1:.+]]: tensor, %[[arg2:.+]]: tensor, %[[arg3:.+]]: tensor<4xi32>) -> tensor { +// CHECK: %[[cst_i32:.+]] = tensorrt.constant dense<1> : tensor<1xi32> +// CHECK: %[[v0:.+]] = tensorrt.transpose {permutation = #[[$map]]} %[[arg1]] : tensor to tensor +// CHECK: %[[v1:.+]] = tensorrt.shape %[[v0]] : tensor -> tensor<3xi32> +// CHECK: %[[v2:.+]] = tensorrt.slice %[[v1]][0][1][1] : tensor<3xi32> to tensor<1xi32> +// CHECK: %[[v3:.+]] = tensorrt.slice %[[v1]][2][1][1] : tensor<3xi32> to tensor<1xi32> +// CHECK: %[[v4:.+]] = tensorrt.concatenation {axis = 0 : i32} ins(%[[cst_i32]], %[[v2]], %[[cst_i32]], %[[v3]] : tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> +// CHECK: %[[v5:.+]] = tensorrt.reshape %[[v0]] shape(%[[v4]]: tensor<4xi32>) : tensor to tensor<1x?x1x?xf16> +// CHECK: %[[v6:.+]] = tensorrt.select ins(%[[arg0]], %[[arg2]], %[[v5]] : tensor, tensor, tensor<1x?x1x?xf16>) -> tensor +// CHECK: return %[[v6]] : tensor \ No newline at end of file diff --git a/mlir-tensorrt/test/Dialect/Plan/materialize-shape-calculations.mlir b/mlir-tensorrt/test/Dialect/Plan/materialize-shape-calculations.mlir index fe419af1c..935ab15a6 100644 --- a/mlir-tensorrt/test/Dialect/Plan/materialize-shape-calculations.mlir +++ b/mlir-tensorrt/test/Dialect/Plan/materialize-shape-calculations.mlir @@ -1088,3 +1088,27 @@ func.func @reduce_window_dynamic_input(%arg0: tensor {tensorrt.shap // CHECK-DAG: %[[v2:.+]] = arith.maxsi %[[dim]], %[[c0]] : index // CHECK-DAG: %[[v3:.+]] = plan.with_shape %[[v1]](%[[v2]], %[[c3]], %[[c512]], %[[c512]]) : // CHECK-DAG: return %[[v3]] + +// ----- + +func.func @simplify_extract_of_reshape_negative(%arg0: tensor<1x?x3x4xf32>) -> f32 { + %c0 = arith.constant 0: index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %1 = stablehlo.reshape %arg0 : (tensor<1x?x3x4xf32>) -> tensor<1x6x4xf32> + %2 = tensor.extract %1[%c0, %c1, %c2] : tensor<1x6x4xf32> + return %2 : f32 +} + +// CHECK-LABEL: simplify_extract_of_reshape_negative +// CHECK-SAME: (%[[arg0:.+]]: tensor<1x?x3x4xf32>) +// CHECK-NEXT: %[[c4:.+]] = arith.constant 4 : index +// CHECK-NEXT: %[[c3:.+]] = arith.constant 3 : index +// CHECK-NEXT: %[[c2:.+]] = arith.constant 2 : index +// CHECK-NEXT: %[[c1:.+]] = arith.constant 1 : index +// CHECK-NEXT: %[[c0:.+]] = arith.constant 0 : index +// CHECK-NEXT: %[[dim:.+]] = tensor.dim %[[arg0]], %[[c1]] : tensor<1x?x3x4xf32> +// CHECK-NEXT: %[[v0:.+]] = plan.with_shape %[[arg0]](%[[c1]], %[[dim]], %[[c3]], %[[c4]]) +// CHECK-NEXT: %[[v1:.+]] = stablehlo.reshape %[[v0]] +// CHECK-NEXT: %[[extracted:.+]] = tensor.extract %[[v1]][%[[c0]], %[[c1]], %[[c2]]] +// CHECK-NEXT: return %extracted \ No newline at end of file diff --git a/mlir-tensorrt/test/models/bert.stablehlo.elided.mlir b/mlir-tensorrt/test/models/bert.stablehlo.elided.mlir index fdea3997d..a08d9ed06 100644 --- a/mlir-tensorrt/test/models/bert.stablehlo.elided.mlir +++ b/mlir-tensorrt/test/models/bert.stablehlo.elided.mlir @@ -1,4 +1,4 @@ -module @jit__unnamed_wrapped_function_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { +module @bert attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<32x8xi32> {mhlo.layout_mode = "default"}) -> (tensor<32x8x768xf16> {mhlo.layout_mode = "default"}, tensor<32x768xf16> {mhlo.layout_mode = "default"}) { %0 = stablehlo.constant dense_resource<__elided__> : tensor<30522x768xf32> %1 = stablehlo.constant dense_resource<__elided__> : tensor<512x768xf32> diff --git a/mlir-tensorrt/test/models/gpt2.stablehlo.bs2.elided.mlir b/mlir-tensorrt/test/models/gpt2.stablehlo.bs2.elided.mlir index 4e8169f81..b08bbfff8 100644 --- a/mlir-tensorrt/test/models/gpt2.stablehlo.bs2.elided.mlir +++ b/mlir-tensorrt/test/models/gpt2.stablehlo.bs2.elided.mlir @@ -1,4 +1,4 @@ -module @jit_generate attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { +module @gpt2_bs2 attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<2x6xi32> {mhlo.sharding = "{replicated}"}, %arg1: tensor<2x6xi32> {mhlo.sharding = "{replicated}"}) -> (tensor<2x20xi32> {jax.result_info = ""}) { %0 = stablehlo.constant dense<0> : tensor<1xi32> %1 = stablehlo.constant dense<768> : tensor diff --git a/mlir-tensorrt/test/models/gpt2.stablehlo.elided.mlir b/mlir-tensorrt/test/models/gpt2.stablehlo.elided.mlir index 0be14b1d9..2cdb2e0aa 100644 --- a/mlir-tensorrt/test/models/gpt2.stablehlo.elided.mlir +++ b/mlir-tensorrt/test/models/gpt2.stablehlo.elided.mlir @@ -1,4 +1,4 @@ -module @jit_generate attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { +module @gpt_bs1 attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<1x7xi32> {jax.arg_info = "inputs['attention_mask']", mhlo.sharding = "{replicated}"}, %arg1: tensor<1x7xi32> {jax.arg_info = "inputs['input_ids']", mhlo.sharding = "{replicated}"}) -> (tensor<1x20xi32> {jax.result_info = ""}) { %0 = stablehlo.constant dense_resource<__elided__> : tensor<50257x768xf16> %1 = stablehlo.constant dense_resource<__elided__> : tensor<1024x768xf16> diff --git a/mlir-tensorrt/test/models/llama-68m.stablehlo.elided.mlir b/mlir-tensorrt/test/models/llama-68m.stablehlo.elided.mlir index cfd987ea9..698d515ec 100644 --- a/mlir-tensorrt/test/models/llama-68m.stablehlo.elided.mlir +++ b/mlir-tensorrt/test/models/llama-68m.stablehlo.elided.mlir @@ -1,4 +1,4 @@ -module @jit_generate attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32, mhlo.use_auto_spmd_partitioning = false} { +module @llama_68m attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32, mhlo.use_auto_spmd_partitioning = false} { func.func @main(%arg0: tensor<1x9xi32> {mhlo.sharding = "{replicated}"}, %arg1: tensor<1x9xi32> {mhlo.sharding = "{replicated}"}) -> tensor<1x20xi32> { %0 = stablehlo.constant dense<1.000000e+00> : tensor<1x1x3072xf32> %1 = stablehlo.constant dense<-3.40282347E+38> : tensor<1x1x1x20xf32> diff --git a/mlir-tensorrt/test/models/llama-v2.stablehlo.elided.mlir b/mlir-tensorrt/test/models/llama-v2.stablehlo.elided.mlir index 4fac66c90..41a59a058 100644 --- a/mlir-tensorrt/test/models/llama-v2.stablehlo.elided.mlir +++ b/mlir-tensorrt/test/models/llama-v2.stablehlo.elided.mlir @@ -1,4 +1,4 @@ -module @jit__unnamed_wrapped_function_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { +module @llama_v2 attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<1x27xf32> {mhlo.layout_mode = "default"}) -> (tensor<1x27x32000xf32> {mhlo.layout_mode = "default"}) { %0 = stablehlo.constant dense_resource<__elided__> : tensor<32000x4096xf16> %1 = stablehlo.constant dense_resource<__elided__> : tensor<4096xf16> diff --git a/mlir-tensorrt/test/models/resnet50.stablehlo.elided.mlir b/mlir-tensorrt/test/models/resnet50.stablehlo.elided.mlir index 781ce3e78..43cce81ab 100644 --- a/mlir-tensorrt/test/models/resnet50.stablehlo.elided.mlir +++ b/mlir-tensorrt/test/models/resnet50.stablehlo.elided.mlir @@ -1,4 +1,4 @@ -module @jit__unnamed_wrapped_function_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { +module @resnet50 attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<16x3x224x224xf16> {mhlo.layout_mode = "default"}) -> (tensor<16x1000xf16> {mhlo.layout_mode = "default"}) { %0 = stablehlo.constant dense_resource<__elided__> : tensor<7x7x3x64xf32> %1 = stablehlo.constant dense_resource<__elided__> : tensor<64xf32> diff --git a/mlir-tensorrt/test/models/swin.stablehlo.elided.mlir b/mlir-tensorrt/test/models/swin.stablehlo.elided.mlir index 91c58b585..732a39abc 100644 --- a/mlir-tensorrt/test/models/swin.stablehlo.elided.mlir +++ b/mlir-tensorrt/test/models/swin.stablehlo.elided.mlir @@ -1,4 +1,4 @@ -module @jit_run attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32, mhlo.use_auto_spmd_partitioning = false} { +module @swin attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32, mhlo.use_auto_spmd_partitioning = false} { func.func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1000xf32> { %cst = stablehlo.constant dense_resource<__elided__> : tensor<1x1000xf32> %cst_0 = stablehlo.constant dense_resource<__elided__> : tensor<1024x1000xf32> diff --git a/mlir-tensorrt/test/models/whisper-jax.stablehlo.elided.mlir b/mlir-tensorrt/test/models/whisper-jax.stablehlo.elided.mlir index fcdedf2b3..6c6d1448a 100644 --- a/mlir-tensorrt/test/models/whisper-jax.stablehlo.elided.mlir +++ b/mlir-tensorrt/test/models/whisper-jax.stablehlo.elided.mlir @@ -1,4 +1,4 @@ -module @jit_generate_fn attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { +module @whisper_jax attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<1x80x3000xf32> {jax.arg_info = "input_features", mhlo.sharding = "{replicated}"}) -> (tensor<1x448xi32> {jax.result_info = ""}) { %0 = stablehlo.constant dense_resource<__elided__> : tensor<3x80x384xf32> %1 = stablehlo.constant dense_resource<__elided__> : tensor<384xf32> From 49fede341016391c4022c27a60dbec31b75eb261 Mon Sep 17 00:00:00 2001 From: Christopher Bate Date: Sat, 16 Nov 2024 01:23:29 -0700 Subject: [PATCH 17/29] [compiler/StableHloExt] Incorporate some simplification patterns from upstream StableHLO (#384) Our next LLVM & StableHLO upgrade will incorporate some additional simplification patterns. However, the upgrade is large and will not fully land until mid next week. Until then, incorporate some critical concat and slice simplification patterns that simplify shape calculation IR. Solves https://github.com/NVIDIA/TensorRT-Incubator/issues/381. GitOrigin-RevId: 2a875a35547b89c96d530878907312a0f898e508 --- .../Transforms/ConstantFolding.cpp | 153 ++++++++++++++++++ .../StableHloExt/constant-folding.mlir | 16 ++ 2 files changed, 169 insertions(+) diff --git a/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Transforms/ConstantFolding.cpp b/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Transforms/ConstantFolding.cpp index a36004bfb..abd7b52ab 100644 --- a/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Transforms/ConstantFolding.cpp +++ b/mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Transforms/ConstantFolding.cpp @@ -29,6 +29,7 @@ #include "mlir/Dialect/CommonFolders.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "stablehlo/dialect/StablehloOps.h" #include "stablehlo/dialect/TypeInference.h" @@ -1064,6 +1065,10 @@ struct AbsorbTensorCastProducer : public RewritePattern { }; } // namespace +/// Populates patterns that are temporarily reproduced here from upstream +/// commits we have not yet integrated. +static void populateFutureUpstreamPatterns(RewritePatternSet &patterns); + void stablehlo_ext::populateStableHloAbsorbTensorCastPatterns( RewritePatternSet &patterns) { patterns.add(patterns.getContext()); @@ -1108,6 +1113,7 @@ class ConstantFoldingPass SqrtOpFolder >(ctx); // clang-format on + populateFutureUpstreamPatterns(patterns); populateStableHloAbsorbTensorCastPatterns(patterns); stablehlo::populateStablehloCanonicalizationPatterns(ctx, &patterns); tensor::EmptyOp::getCanonicalizationPatterns(patterns, ctx); @@ -1124,3 +1130,150 @@ class ConstantFoldingPass } }; } // namespace + +//===----------------------------------------------------------------------===// +/// The patterns below this point are reproduced from +/// https://github.com/openxla/stablehlo/commit/5d15ab064f165cc6773ef4ba949ac083ae8e1fea, +/// which is in upstream, but our current pinned StableHlo commit is not there +/// yet. The patterns can be removed in the next StableHLO upgrade. +/// +//===----------------------------------------------------------------------===// + +/// +/// In cases where a concat is fed into a slice, it +/// is possible the concat can be simplified or bypassed. This checks which +/// inputs to the concat are used by the slice, either reducing the number of +/// concatenated values or entirely removes the concat. Pattern: +/// slice(concat(X,Y,Z,...),...) -> concat(slice(X),slice(Y),slice(Z)) +struct SimplifySliceOfConcat : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(SliceOp slice, + PatternRewriter &rewriter) const override { + RankedTensorType resultTy = slice.getType(); + if (!resultTy.hasStaticShape()) + return rewriter.notifyMatchFailure(slice, "result shape not static"); + + auto concat = slice.getOperand().getDefiningOp(); + if (!concat) + return rewriter.notifyMatchFailure(slice, "slice input not concat"); + + RankedTensorType concatType = concat.getType(); + uint64_t dimension = concat.getDimension(); + + ArrayRef start = slice.getStartIndices(); + ArrayRef limit = slice.getLimitIndices(); + + int64_t sliceStart = start[dimension]; + int64_t sliceLimit = limit[dimension]; + + // We need to determine what inputs from the concat affect the slice, and + // how the bounds of the slice need to be updated for the minimally required + // inputs. + int64_t runningSize = 0; + int64_t frontOffset = concatType.getShape()[dimension]; + + auto subsetStart = concat.operand_end(); + auto subsetEnd = concat.operand_end(); + for (auto it = concat.operand_begin(); it < concat.operand_end(); ++it) { + Value input = *it; + auto inputTy = cast(input.getType()); + if (inputTy.isDynamicDim(dimension)) + return rewriter.notifyMatchFailure( + slice, "concat input has dynamic dimension"); + + int64_t dimSize = inputTy.getShape()[dimension]; + + // If this position is in the slice its the start of the subset and we + // need to update the start and limit values. + if (runningSize + dimSize > sliceStart && + subsetStart == concat.operand_end()) { + subsetStart = it; + frontOffset = runningSize; + } + + // Determine the last required offset. + if (runningSize < sliceLimit) { + subsetEnd = it + 1; + } + + runningSize += dimSize; + } + + auto subsetSize = subsetEnd - subsetStart; + // We need all inputs so no optimization. + if (subsetSize == concat.getNumOperands()) + return rewriter.notifyMatchFailure(slice, + "slice needs all concat inputs"); + + // If there's nothing to slice that means the output is an empty tensor and + // there is dead code. We do nothing here and rely on other passes to clean + // this up. + if (subsetSize == 0) + return rewriter.notifyMatchFailure(slice, "slice is empty"); + + if (subsetSize > 1 && !concat.getResult().hasOneUse()) + return rewriter.notifyMatchFailure(slice, + "slice is not the only concat user"); + + auto concatRange = OperandRange(subsetStart, subsetEnd); + auto newConcat = rewriter.create( + concat.getLoc(), concatRange, concat.getDimension()); + + SmallVector newStart(start); + SmallVector newLimit(limit); + newStart[dimension] -= frontOffset; + newLimit[dimension] -= frontOffset; + + rewriter.replaceOpWithNewOp( + slice, newConcat, rewriter.getDenseI64ArrayAttr(newStart), + rewriter.getDenseI64ArrayAttr(newLimit), slice.getStrides()); + return success(); + } +}; + +/// Flatten sequential concatenations as long as the parent concatenation either +/// has a single use or is <= 32 elements. +class SimplifyConcatOfConcatPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ConcatenateOp op, + PatternRewriter &rewriter) const override { + auto getFlattenedOperands = [&](const Value &val) -> ValueRange { + auto definingOp = dyn_cast_or_null(val.getDefiningOp()); + if (!definingOp || definingOp.getDimension() != op.getDimension()) + return val; + if (definingOp->hasOneUse()) + return definingOp.getInputs(); + if (!definingOp.getType().hasStaticShape()) + return val; + if (definingOp.getType().getNumElements() <= 32) + return definingOp.getInputs(); + return val; + }; + + bool needToFlatten = false; + int operandCount = 0; + for (Value val : op.getInputs()) { + ValueRange result = getFlattenedOperands(val); + if (result.size() != 1 || result[0] != val) + needToFlatten = true; + operandCount += result.size(); + } + if (!needToFlatten) + return rewriter.notifyMatchFailure(op, "no need to flatten"); + + llvm::SmallVector newOperands; + newOperands.reserve(operandCount); + for (Value operand : op.getInputs()) + llvm::append_range(newOperands, getFlattenedOperands(operand)); + + rewriter.modifyOpInPlace(op, [&] { op->setOperands(newOperands); }); + return success(); + } +}; + +void populateFutureUpstreamPatterns(RewritePatternSet &patterns) { + patterns.add( + patterns.getContext()); +} diff --git a/mlir-tensorrt/test/Dialect/StableHloExt/constant-folding.mlir b/mlir-tensorrt/test/Dialect/StableHloExt/constant-folding.mlir index 682875219..515c6160e 100644 --- a/mlir-tensorrt/test/Dialect/StableHloExt/constant-folding.mlir +++ b/mlir-tensorrt/test/Dialect/StableHloExt/constant-folding.mlir @@ -402,6 +402,22 @@ func.func @concat_simplify_single_operand_requires_cast(%arg0: tensor<4xi32>) -> // ----- +func.func @concat_slice_concat(%arg0: tensor<1xi32>, %arg1: tensor<3xi32>, %arg2: tensor<1xi32>) -> tensor<5xi32> { + %0 = stablehlo.concatenate %arg0, %arg1, %arg2, dim = 0 : (tensor<1xi32>, tensor<3xi32>, tensor<1xi32>) -> tensor<5xi32> + %1 = stablehlo.slice %0 [1:5] : (tensor<5xi32>) -> tensor<4xi32> + %2 = stablehlo.constant dense<1> : tensor<1xi32> + %3 = stablehlo.concatenate %2, %1, dim = 0 : (tensor<1xi32>, tensor<4xi32>) -> tensor<5xi32> + return %3 : tensor<5xi32> +} + +// CHECK-LABEL: func.func @concat_slice_concat +// CHECK-SAME: (%[[arg0:.+]]: tensor<1xi32>, %[[arg1:.+]]: tensor<3xi32>, %[[arg2:.+]]: tensor<1xi32>) -> tensor<5xi32> +// CHECK: %[[c:.+]] = stablehlo.constant dense<1> : tensor<1xi32> +// CHECK: %[[v0:.+]] = stablehlo.concatenate %[[c]], %[[arg1]], %[[arg2]], dim = 0 +// CHECK: return %[[v0]] : tensor<5xi32> + +// ----- + func.func @bitwise_or_fold_lhs(%arg0: tensor<5xi8>, %arg1: tensor<5xi1>, %arg2: tensor<5xi32>) -> (tensor<5xi8>, tensor<5xi1>, tensor<5xi32>, tensor<5xi32>){ %0 = stablehlo.constant dense<[255, 255, 255, 255, 255]> : tensor<5xi8> %1 = stablehlo.or %0, %arg0 : tensor<5xi8> From 38cfd580a1554a623259176f893e5a2d0e6a2854 Mon Sep 17 00:00:00 2001 From: Christopher Bate Date: Sat, 16 Nov 2024 15:10:37 -0700 Subject: [PATCH 18/29] [compiler/StableHloExt] Bump max iterations in `stablehlo-ext-refine-shapes` (#385) Increases the max number of iterations that the initial dynamic shape refinement pipeline will iterate in order to better reported use cases. --- .../Dialect/StableHloExt/Transforms/Passes.td | 2 +- .../Dialect/StableHloExt/refine-shapes.mlir | 86 +++++++++++++++++++ 2 files changed, 87 insertions(+), 1 deletion(-) create mode 100644 mlir-tensorrt/test/Dialect/StableHloExt/refine-shapes.mlir diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/StableHloExt/Transforms/Passes.td b/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/StableHloExt/Transforms/Passes.td index 398f70aa1..76e7b1164 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/StableHloExt/Transforms/Passes.td +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/StableHloExt/Transforms/Passes.td @@ -95,7 +95,7 @@ def CanonicalizeShapesPass : Pass<"stablehlo-ext-canonicalize-shapes", "ModuleOp }]; let options = [ - Option<"maxIterations", "max-iterations", "int64_t", "4", + Option<"maxIterations", "max-iterations", "int64_t", "8", "the maximum number of iterations to run the dynamism simplification and " "shape refinement if a fixed-point is not reached"> ]; diff --git a/mlir-tensorrt/test/Dialect/StableHloExt/refine-shapes.mlir b/mlir-tensorrt/test/Dialect/StableHloExt/refine-shapes.mlir new file mode 100644 index 000000000..cbf0ca09e --- /dev/null +++ b/mlir-tensorrt/test/Dialect/StableHloExt/refine-shapes.mlir @@ -0,0 +1,86 @@ +// RUN: mlir-tensorrt-opt %s -split-input-file -stablehlo-ext-refine-shapes | FileCheck %s + +func.func @check_type_refinement() -> tensor { + %c = stablehlo.constant dense<[1, 2, 3]> : tensor<3xi32> + %c_0 = stablehlo.constant dense<3> : tensor + %c_1 = stablehlo.constant dense<1> : tensor<1xi32> + %c_2 = stablehlo.constant dense<3> : tensor<1xi32> + %c_3 = stablehlo.constant dense<1> : tensor + %c_4 = stablehlo.constant dense<1> : tensor<1xi32> + %c_5 = stablehlo.constant dense<0> : tensor + %c_6 = stablehlo.constant dense<1> : tensor + %c_7 = stablehlo.constant dense<0> : tensor<1xi32> + %c_8 = stablehlo.constant dense<1> : tensor<1xi32> + %0 = stablehlo.compare LE, %c_7, %c_8 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %1 = stablehlo.select %0, %c_7, %c_8 : tensor<1xi1>, tensor<1xi32> + %c_9 = stablehlo.constant dense<1> : tensor<1xi32> + %2 = stablehlo.real_dynamic_slice %c_4, %1, %c_8, %c_9 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %c_10 = stablehlo.constant dense<> : tensor<0xi32> + %3 = stablehlo.dynamic_reshape %2, %c_10 : (tensor, tensor<0xi32>) -> tensor + %c_11 = stablehlo.constant dense<-1> : tensor + %c_12 = stablehlo.constant dense<> : tensor<0xi32> + %4 = stablehlo.compare EQ, %c_12, %c_10 : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi1> + %5 = stablehlo.select %4, %c_12, %c_12 : tensor<0xi1>, tensor<0xi32> + %6 = stablehlo.dynamic_broadcast_in_dim %3, %5, dims = [] : (tensor, tensor<0xi32>) -> tensor + %7 = stablehlo.dynamic_broadcast_in_dim %c_11, %5, dims = [] : (tensor, tensor<0xi32>) -> tensor + %8 = stablehlo.add %6, %7 : tensor + %c_13 = stablehlo.constant dense<0> : tensor<1xi32> + %c_14 = stablehlo.constant dense<1> : tensor<1xi32> + %9 = stablehlo.compare LE, %c_13, %c_14 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %10 = stablehlo.select %9, %c_13, %c_14 : tensor<1xi1>, tensor<1xi32> + %c_15 = stablehlo.constant dense<1> : tensor<1xi32> + %11 = stablehlo.real_dynamic_slice %c_4, %10, %c_14, %c_15 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %12 = stablehlo.dynamic_reshape %11, %c_10 : (tensor, tensor<0xi32>) -> tensor + %13 = stablehlo.compare EQ, %c_12, %c_10 : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi1> + %14 = stablehlo.select %13, %c_12, %c_12 : tensor<0xi1>, tensor<0xi32> + %15 = stablehlo.dynamic_broadcast_in_dim %12, %14, dims = [] : (tensor, tensor<0xi32>) -> tensor + %16 = stablehlo.dynamic_broadcast_in_dim %c_11, %14, dims = [] : (tensor, tensor<0xi32>) -> tensor + %17 = stablehlo.add %15, %16 : tensor + %18 = stablehlo.compare EQ, %c_12, %c_10 : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi1> + %19 = stablehlo.select %18, %c_12, %c_12 : tensor<0xi1>, tensor<0xi32> + %20 = stablehlo.dynamic_broadcast_in_dim %17, %19, dims = [] : (tensor, tensor<0xi32>) -> tensor + %21 = stablehlo.dynamic_broadcast_in_dim %c_6, %19, dims = [] : (tensor, tensor<0xi32>) -> tensor + %22 = stablehlo.add %20, %21 : tensor + %23 = stablehlo.reshape %8 : (tensor) -> tensor<1xi32> + %24 = stablehlo.reshape %22 : (tensor) -> tensor<1xi32> + %25 = stablehlo.compare LE, %23, %24 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %26 = stablehlo.select %25, %23, %24 : tensor<1xi1>, tensor<1xi32> + %c_16 = stablehlo.constant dense<1> : tensor<1xi32> + %27 = stablehlo.real_dynamic_slice %c_2, %26, %24, %c_16 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %28 = stablehlo.dynamic_reshape %27, %c_10 : (tensor, tensor<0xi32>) -> tensor + %29 = stablehlo.dynamic_broadcast_in_dim %28, %c_1, dims = [] : (tensor, tensor<1xi32>) -> tensor<1xi32> + %cst = stablehlo.constant dense<1.000000e+00> : tensor + %30 = stablehlo.dynamic_broadcast_in_dim %cst, %29, dims = [] : (tensor, tensor<1xi32>) -> tensor + return %30 : tensor +} + +// CHECK-LABEL: func.func @check_type_refinement +// CHECK-DAG: %[[cst:.+]] = stablehlo.constant dense<1.000000e+00> : tensor +// CHECK-DAG: %[[c:.+]] = stablehlo.constant dense<-1> : tensor +// CHECK-DAG: %[[c_0:.+]] = stablehlo.constant dense<> : tensor<0xi32> +// CHECK-DAG: %[[c_1:.+]] = stablehlo.constant dense<1> : tensor<1xi32> +// CHECK-DAG: %[[c_2:.+]] = stablehlo.constant dense<3> : tensor<1xi32> +// CHECK-DAG: %[[c_3:.+]] = stablehlo.constant dense<1> : tensor +// CHECK-DAG: %[[c_4:.+]] = stablehlo.constant dense<0> : tensor<1xi32> +// CHECK-DAG: %[[v0:.+]] = stablehlo.real_dynamic_slice %[[c_1]], %[[c_4]], %[[c_1]], %[[c_1]] : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> +// CHECK-DAG: %[[v1:.+]] = stablehlo.dynamic_reshape %[[v0]], %[[c_0]] : (tensor<1xi32>, tensor<0xi32>) -> tensor +// CHECK-DAG: %[[v2:.+]] = stablehlo.dynamic_broadcast_in_dim %[[v1]], %[[c_0]], dims = [] : (tensor, tensor<0xi32>) -> tensor +// CHECK-DAG: %[[v3:.+]] = stablehlo.dynamic_broadcast_in_dim %[[c]], %[[c_0]], dims = [] : (tensor, tensor<0xi32>) -> tensor +// CHECK-DAG: %[[v4:.+]] = stablehlo.add %[[v2]], %[[v3]] : tensor +// CHECK-DAG: %[[v5:.+]] = stablehlo.real_dynamic_slice %[[c_1]], %[[c_4]], %[[c_1]], %[[c_1]] : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> +// CHECK-DAG: %[[v6:.+]] = stablehlo.dynamic_reshape %[[v5]], %[[c_0]] : (tensor<1xi32>, tensor<0xi32>) -> tensor +// CHECK-DAG: %[[v7:.+]] = stablehlo.dynamic_broadcast_in_dim %[[v6]], %[[c_0]], dims = [] : (tensor, tensor<0xi32>) -> tensor +// CHECK-DAG: %[[v8:.+]] = stablehlo.dynamic_broadcast_in_dim %[[c]], %[[c_0]], dims = [] : (tensor, tensor<0xi32>) -> tensor +// CHECK-DAG: %[[v9:.+]] = stablehlo.add %[[v7]], %[[v8]] : tensor +// CHECK-DAG: %[[v10:.+]] = stablehlo.dynamic_broadcast_in_dim %[[v9]], %[[c_0]], dims = [] : (tensor, tensor<0xi32>) -> tensor +// CHECK-DAG: %[[v11:.+]] = stablehlo.dynamic_broadcast_in_dim %[[c_3]], %[[c_0]], dims = [] : (tensor, tensor<0xi32>) -> tensor +// CHECK-DAG: %[[v12:.+]] = stablehlo.add %[[v10]], %[[v11]] : tensor +// CHECK-DAG: %[[v13:.+]] = stablehlo.reshape %[[v4]] : (tensor) -> tensor<1xi32> +// CHECK-DAG: %[[v14:.+]] = stablehlo.reshape %[[v12]] : (tensor) -> tensor<1xi32> +// CHECK-DAG: %[[v15:.+]] = stablehlo.compare LE, %[[v13]], %[[v14]] : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> +// CHECK-DAG: %[[v16:.+]] = stablehlo.select %[[v15]], %[[v13]], %[[v14]] : tensor<1xi1>, tensor<1xi32> +// CHECK-DAG: %[[v17:.+]] = stablehlo.real_dynamic_slice %[[c_2]], %[[v16]], %[[v14]], %[[c_1]] : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor +// CHECK-DAG: %[[v18:.+]] = stablehlo.dynamic_reshape %[[v17]], %[[c_0]] : (tensor, tensor<0xi32>) -> tensor +// CHECK-DAG: %[[v19:.+]] = stablehlo.dynamic_broadcast_in_dim %[[v18]], %[[c_1]], dims = [] : (tensor, tensor<1xi32>) -> tensor<1xi32> +// CHECK-DAG: %[[v20:.+]] = stablehlo.dynamic_broadcast_in_dim %[[cst]], %[[v19]], dims = [] : (tensor, tensor<1xi32>) -> tensor +// CHECK-DAG: return %[[v20]] : tensor \ No newline at end of file From d37e4f8a46e274a7d68977e386a8ed635425683e Mon Sep 17 00:00:00 2001 From: Faraz <58580514+farazkh80@users.noreply.github.com> Date: Mon, 18 Nov 2024 10:35:56 -0500 Subject: [PATCH 19/29] [Tripy] Mixed module list/dict for `tp.Module` (#365) PR for addressing issue https://github.com/NVIDIA/TensorRT-Incubator/issues/302. Supporting mixed types is important, since tripy does not have modules like `Identity` or `AveragePool`. With this, users can make sequential nets or module lists that include lambda functions or other callables like `tp.avgpool`. - [x] modified `_iterate_members_of_type` to have a more relaxed check on checking module/paramater types. - [x] Added mixed container tests to `tests/frontend/module/test_module.py` - [x] Added mixed container tests to `tests/frontend/module/test_sequential.py` - [x] Added example mixed container usage to `tp.Module` - [x] Added example mixed container usage to `tp.Sequential` --------- Signed-off-by: Faraz <58580514+farazkh80@users.noreply.github.com> Co-authored-by: pranavm-nvidia <49246958+pranavm-nvidia@users.noreply.github.com> --- tripy/tests/frontend/module/conftest.py | 33 +++++++++ tripy/tests/frontend/module/test_module.py | 74 +++++++++++++++++++ .../tests/frontend/module/test_sequential.py | 71 ++++++++++++++++++ tripy/tripy/frontend/module/module.py | 42 +++++------ tripy/tripy/frontend/module/sequential.py | 19 ++++- 5 files changed, 212 insertions(+), 27 deletions(-) diff --git a/tripy/tests/frontend/module/conftest.py b/tripy/tests/frontend/module/conftest.py index bbcca76e7..2cac5f87b 100644 --- a/tripy/tests/frontend/module/conftest.py +++ b/tripy/tests/frontend/module/conftest.py @@ -109,6 +109,34 @@ def __call__(self): return out1 + out2 +class MixedContainerNetwork(tp.Module): + def __init__(self): + super().__init__() + self.param = tp.Parameter(tp.ones((2,), dtype=tp.float32)) + + # Define a mixed list with both modules and lambda functions + self.mixed_list = [ + DummyOp(tp.zeros((2,), dtype=tp.float32)), + lambda: tp.ones((2,), dtype=tp.float32), + DummyOp(tp.zeros((2,), dtype=tp.float32)), + ] + + # Define a mixed dictionary with modules and lambda functions + self.mixed_dict = { + "scale": lambda: tp.ones((2,), dtype=tp.float32), + "dummy": DummyOp(tp.zeros((2,), dtype=tp.float32)), + } + + def __call__(self): + out = self.param + for op in self.mixed_list: + out = out + op() + for _, op in self.mixed_dict.items(): + out = out + op() + + return out + + @pytest.fixture(params=[(Network, ())]) def all_network_modes(request): call_args = request.param[1] @@ -139,3 +167,8 @@ def mixed_network(): @pytest.fixture def complex_network(): yield ComplexNetwork() + + +@pytest.fixture +def mixed_container_network(): + yield MixedContainerNetwork() diff --git a/tripy/tests/frontend/module/test_module.py b/tripy/tests/frontend/module/test_module.py index 245f0483b..00ceed886 100644 --- a/tripy/tests/frontend/module/test_module.py +++ b/tripy/tests/frontend/module/test_module.py @@ -308,3 +308,77 @@ def test_state_dict(self, complex_network): assert module.nets["dict_net"].dummy_dict["op1"].nested.param is tensor assert module.nets["list_net"].dummy_list[0].nested.param is tensor assert module.nets["list_net"].dummy_list[1].nested.param is tensor + + +class TestMixedContainerNetwork: + def test_basic_structure(self, mixed_container_network): + assert hasattr(mixed_container_network, "mixed_list") + assert hasattr(mixed_container_network, "mixed_dict") + + assert isinstance(mixed_container_network.mixed_list, list) + assert isinstance(mixed_container_network.mixed_dict, dict) + + assert any(isinstance(item, tp.Module) for item in mixed_container_network.mixed_list) + assert any(callable(item) and not isinstance(item, tp.Module) for item in mixed_container_network.mixed_list) + + assert any(isinstance(value, tp.Module) for value in mixed_container_network.mixed_dict.values()) + assert any( + callable(value) and not isinstance(value, tp.Module) + for value in mixed_container_network.mixed_dict.values() + ) + + def test_named_children(self, mixed_container_network): + children = list(mixed_container_network.named_children()) + + assert len(children) == 3 + for _, child in children: + assert isinstance(child, tp.Module) + + def test_state_dict(self, mixed_container_network): + expected_keys = set( + ["param", "mixed_list.0.nested.param", "mixed_list.2.nested.param", "mixed_dict.dummy.nested.param"] + ) + + state_dict = mixed_container_network.state_dict() + assert set(state_dict.keys()) == expected_keys + + def test_load_state_dict(self, mixed_container_network): + tensor = tp.Parameter(tp.ones((2,))) + state_dict = { + "param": tensor, + "mixed_list.0.nested.param": tensor, + "mixed_list.2.nested.param": tensor, + "mixed_dict.dummy.nested.param": tensor, + } + + missing_keys, unexpected_keys = mixed_container_network.load_state_dict(state_dict, strict=True) + + assert not missing_keys + assert not unexpected_keys + assert mixed_container_network.param is state_dict["param"] + assert mixed_container_network.mixed_list[0].nested.param is state_dict["mixed_list.0.nested.param"] + assert mixed_container_network.mixed_list[2].nested.param is state_dict["mixed_list.2.nested.param"] + assert mixed_container_network.mixed_dict["dummy"].nested.param is state_dict["mixed_dict.dummy.nested.param"] + + def test_print_structure(self, mixed_container_network): + assert str(mixed_container_network) == dedent( + """\ + MixedContainerNetwork( + param: Parameter = (shape=[2], dtype=float32), + mixed_list.0: Module = DummyOp( + nested: Module = DummyNestedOp( + param: Parameter = (shape=[2], dtype=float32), + ), + ), + mixed_list.2: Module = DummyOp( + nested: Module = DummyNestedOp( + param: Parameter = (shape=[2], dtype=float32), + ), + ), + mixed_dict.dummy: Module = DummyOp( + nested: Module = DummyNestedOp( + param: Parameter = (shape=[2], dtype=float32), + ), + ), + )""" + ) diff --git a/tripy/tests/frontend/module/test_sequential.py b/tripy/tests/frontend/module/test_sequential.py index 43ad48ed2..fd69f8325 100644 --- a/tripy/tests/frontend/module/test_sequential.py +++ b/tripy/tests/frontend/module/test_sequential.py @@ -37,6 +37,16 @@ def dict_sequential_network(): yield tp.Sequential({"layer1": tp.Linear(1, 3), "layer2": tp.Linear(3, 2)}) +@pytest.fixture +def mixed_container_sequential_network(): + yield tp.Sequential( + tp.Conv(in_channels=2, out_channels=2, kernel_dims=(1, 1), stride=(1, 1)), + lambda x: tp.avgpool(x, kernel_dims=(2, 2), stride=(1, 1)), + lambda x: tp.flatten(x, start_dim=1), + tp.Linear(2, 1), + ) + + @pytest.fixture def nested_sequential_network(): yield tp.Sequential(tp.Linear(2, 4), tp.Sequential(tp.Linear(4, 3), tp.Linear(3, 1))) @@ -146,6 +156,67 @@ def test_str_representation(self, dict_sequential_network): assert str(dict_sequential_network) == expected_str +class TestMixedContainerSequential: + def test_basic_structure(self, mixed_container_sequential_network): + assert len(mixed_container_sequential_network) == 4 + assert isinstance(mixed_container_sequential_network[0], tp.Module) + assert callable(mixed_container_sequential_network[1]) + assert callable(mixed_container_sequential_network[2]) + assert isinstance(mixed_container_sequential_network[3], tp.Module) + + def test_forward_pass(self, mixed_container_sequential_network): + input_data = tp.Tensor(tp.ones((1, 2, 2, 2), dtype=tp.float32)) + output = mixed_container_sequential_network(input_data) + assert output.shape == [1, 1] + + def test_named_children(self, mixed_container_sequential_network): + expected_names = [("0", mixed_container_sequential_network[0]), ("3", mixed_container_sequential_network[3])] + assert list(mixed_container_sequential_network.named_children()) == expected_names + + def test_state_dict(self, mixed_container_sequential_network): + state_dict = mixed_container_sequential_network.state_dict() + expected_keys = set(["0.bias", "0.weight", "3.weight", "3.bias"]) + assert set(state_dict.keys()) == expected_keys + + def test_load_state_dict(self, mixed_container_sequential_network): + new_state_dict = { + "0.weight": tp.Parameter(tp.ones((2, 2, 1, 1), dtype=tp.float32)), + "0.bias": tp.Parameter(tp.zeros((2,), dtype=tp.float32)), + "3.weight": tp.Parameter(tp.zeros((1, 2), dtype=tp.float32)), + "3.bias": tp.Parameter(tp.zeros((1,), dtype=tp.float32)), + } + mixed_container_sequential_network.load_state_dict(new_state_dict, strict=False) + + assert np.array_equal( + cp.from_dlpack(mixed_container_sequential_network[0].weight), cp.from_dlpack(new_state_dict["0.weight"]) + ) + assert np.array_equal( + cp.from_dlpack(mixed_container_sequential_network[0].bias), cp.from_dlpack(new_state_dict["0.bias"]) + ) + assert np.array_equal( + cp.from_dlpack(mixed_container_sequential_network[3].weight), cp.from_dlpack(new_state_dict["3.weight"]) + ) + assert np.array_equal( + cp.from_dlpack(mixed_container_sequential_network[3].bias), cp.from_dlpack(new_state_dict["3.bias"]) + ) + + def test_str_representation(self, mixed_container_sequential_network): + expected_str = dedent( + """\ + Sequential( + 0: Module = Conv( + bias: Parameter = (shape=[2], dtype=float32), + weight: Parameter = (shape=[2, 2, 1, 1], dtype=float32), + ), + 3: Module = Linear( + weight: Parameter = (shape=[1, 2], dtype=float32), + bias: Parameter = (shape=[1], dtype=float32), + ), + )""" + ) + assert str(mixed_container_sequential_network) == expected_str + + class TestNestedSequential: def test_basic_structure(self, nested_sequential_network): # Check that the top-level Sequential has two layers and that one of them is a nested Sequential diff --git a/tripy/tripy/frontend/module/module.py b/tripy/tripy/frontend/module/module.py index 2cc41cc85..ccf7fbc56 100644 --- a/tripy/tripy/frontend/module/module.py +++ b/tripy/tripy/frontend/module/module.py @@ -40,10 +40,6 @@ def _check_param_compatible(original_param, new_param, param_name): ) -def _is_homogeneous_container(container: Sequence, typ: T): - return all(isinstance(elem, typ) for elem in container) - - def _contains_types(container: Sequence, types: type): return any(any(isinstance(elem, typ) for typ in types) for elem in container) @@ -54,8 +50,10 @@ class Module: Base class used to define neural network modules. You can nest modules by assigning them as attributes of other modules. - Child modules or :class:`tripy.Parameter` s may be contained in Python ``list``\s or ``dict``\s. - If using ``dict``\s, the keys must be strings. + Child modules, :class:`tripy.Parameter` s, or other callables/lambda functions may be contained + in Python ``list``\ s or ``dict``\ s. + + If using ``dict``\ s, the keys must be strings. Nested data structures (for example, ``list``\s of ``list``\s) are not supported. Taking child modules as an example, this is allowed: :: @@ -67,6 +65,14 @@ class Module: "layernorm": tp.LayerNorm(2), } + This is another valid example with a wrapped :class:`tripy.avgpool` lambda function + :: + + self.dict_modules = { + "convolution": tp.Conv(in_channels=2, out_channels=2, kernel_dims=(1,1), stride=(1,1)), + "pool": lambda x: tp.avgpool(x, kernel_dims=(2,2), stride=(1,1)) + } + Whereas this is not supported: :: @@ -106,20 +112,6 @@ def __setattr__(self, name: str, value: Any) -> None: if value is None: return - if isinstance(value, List) or isinstance(value, Dict): - container = value if isinstance(value, List) else value.values() - if _contains_types(container, [Parameter, Module]) and ( - not _is_homogeneous_container(container, Parameter) and not _is_homogeneous_container(container, Module) - ): - stack_info = utils.get_stack_info() - stack_info.fetch_source_code() - stack_info_msg = str_from_stack_info(stack_info) - - logger.warning( - "A container of mixed types will not be registered with this module's state_dict()." - + (f"\nNote: container was set here: {stack_info_msg}" if stack_info_msg else "") - ) - def state_dict(self) -> Dict[str, Parameter]: r""" Returns a dictionary mapping names to parameters in the module. @@ -315,12 +307,14 @@ def _iterate_members_of_type(self, typ: T) -> Iterator[Tuple[str, T]]: for name, value in vars(self).items(): if isinstance(value, typ): yield name, value - elif isinstance(value, List) and _is_homogeneous_container(value, typ): + elif isinstance(value, List): for i, obj in enumerate(value): - yield f"{name}.{i}", obj - elif isinstance(value, Dict) and _is_homogeneous_container(value.values(), typ): + if isinstance(obj, typ): + yield f"{name}.{i}", obj + elif isinstance(value, Dict): for key, obj in value.items(): - yield f"{name}.{key}", obj + if isinstance(obj, typ): + yield f"{name}.{key}", obj def __str__(self): from textwrap import indent diff --git a/tripy/tripy/frontend/module/sequential.py b/tripy/tripy/frontend/module/sequential.py index d30982be9..cd825efb2 100644 --- a/tripy/tripy/frontend/module/sequential.py +++ b/tripy/tripy/frontend/module/sequential.py @@ -27,8 +27,8 @@ @dataclass class Sequential(Module): r""" - A module to stack multiple layers or modules in a sequential order. The `Sequential` - container can accept either a list of modules or a dictionary of named modules. Modules are + A module to stack multiple callable layers or modules in a sequential order. The `Sequential` + container can accept either a list of modules/callable objects or a dictionary of named modules/callable objects. Layers are added in the order they are passed, and each is called sequentially during the forward pass. """ @@ -55,6 +55,18 @@ def __init__(self, *modules: Union[Module, Dict[str, Module]]) -> None: input = tp.Tensor([1.0]) output = model(input) + + .. code-block:: python + :linenos: + :caption: Sequential with Callables + + model = tp.Sequential( + tp.Conv(in_channels=2, out_channels=2, kernel_dims=(1,1), stride=(1,1)), + lambda x: tp.avgpool(x, kernel_dims=(2,2), stride=(1,1)) + ) + + input = tp.ones((1,2,2,2), dtype=tp.float32) + output = model(input) """ super().__init__() self.modules = {} @@ -178,4 +190,5 @@ def named_children(self) -> Iterator[Tuple[str, "Module"]]: # with the 'modules' prefix in the state_dict. This change ensures compatibility # with PyTorch's naming conventions. for name, module in self.modules.items(): - yield name, module + if isinstance(module, Module): + yield name, module From 718ded641e79b2500faf830055124e95612a88ab Mon Sep 17 00:00:00 2001 From: Christopher Bate Date: Mon, 18 Nov 2024 15:03:28 -0800 Subject: [PATCH 20/29] [mlir-tensorrt] NFC: [executor] update NCCL debug logging in runtime Adds some additional debugging logging in the NCCL executor module. GitOrigin-RevId: b628f4613e87e148ed444342932dc1cf909851e3 --- .../Backend/Lua/Modules/CUDA/CUDAModule.cpp | 8 ++++++-- .../Backend/Lua/Modules/NCCL/NCCLModule.cpp | 15 +++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/CUDA/CUDAModule.cpp b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/CUDA/CUDAModule.cpp index 912873b93..6410a273c 100644 --- a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/CUDA/CUDAModule.cpp +++ b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/CUDA/CUDAModule.cpp @@ -227,6 +227,7 @@ registerCudaMemoryManagementOps(sol::state_view &lua, }; lua["__cuda_stream_sync"] = [](sol::this_state state, CudaStreamPtr stream) { + MTRT_DBG("__cuda_stream_sync @ {0}", reinterpret_cast(stream.ptr)); ADD_CUDA_MODULE_RANGE("cuda_stream_sync"); SET_LUA_ERROR_IF_CUDART_ERROR(cudaStreamSynchronize(stream), state); }; @@ -439,7 +440,8 @@ registerCudaMemoryManagementOps(sol::state_view &lua, size_t srcOffset, uintptr_t dest, size_t destOffset, size_t numBytes) { ADD_CUDA_MODULE_RANGE("cuda_memcpy_host_pinned2device"); - MTRT_DBGF("cuda_memcpy_h2d %lu bytes from 0x%lx + %lu to 0x%lx + %lu", + MTRT_DBGF("__cuda_memcpy_host_pinned2device: %lu bytes from 0x%lx + " + "%lu to 0x%lx + %lu", numBytes, src, srcOffset, dest, destOffset); void *srcPtr = reinterpret_cast(src + srcOffset); void *dstPtr = reinterpret_cast(dest + destOffset); @@ -475,7 +477,9 @@ registerCudaMemoryManagementOps(sol::state_view &lua, "expected src to be a device ptr and dest to be a host ptr"); } #endif - MTRT_DBGF("executor_memcpy device-host %lu bytes", numBytes); + MTRT_DBGF("__cuda_memcpy_device2host_pinned: %lu bytes from 0x%lx + " + "%lu to 0x%lx + %lu", + numBytes, src, srcOffset, dest, destOffset); SET_LUA_ERROR_IF_CUDART_ERROR(cudaMemcpyAsync(dstPtr, srcPtr, numBytes, cudaMemcpyDeviceToHost, stream), diff --git a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/NCCL/NCCLModule.cpp b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/NCCL/NCCLModule.cpp index c78ec810c..65e0e89c3 100644 --- a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/NCCL/NCCLModule.cpp +++ b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/NCCL/NCCLModule.cpp @@ -268,6 +268,10 @@ static void registerNcclOps(sol::state_view &lua, ResourceTracker *tracker) { lua["__nccl_all_reduce_" #opsuffix "_" #typesuffix] = \ [](sol::this_state state, ExecPtr sendbuff, ExecPtr recvbuff, \ size_t count, uintptr_t communicator, CudaStreamPtr stream) { \ + MTRT_DBG("__nccl_all_reduce_" #opsuffix "_" #typesuffix \ + ": count={0} send={1} recv={2}", \ + count, reinterpret_cast(sendbuff), \ + reinterpret_cast(recvbuff)); \ auto comm = reinterpret_cast(communicator); \ SET_LUA_ERROR_IF_NCCL_ERROR( \ ncclAllReduce(reinterpret_cast(sendbuff), \ @@ -285,6 +289,10 @@ static void registerNcclOps(sol::state_view &lua, ResourceTracker *tracker) { lua["__nccl_reduce_scatter_" #opsuffix "_" #typesuffix] = \ [](sol::this_state state, ExecPtr sendbuff, ExecPtr recvbuff, \ size_t recvcount, uintptr_t communicator, CudaStreamPtr stream) { \ + MTRT_DBG("__nccl_reduce_scatter_" #opsuffix "_" #typesuffix \ + ": count={0} sendbuff={1} recvbuff={2}", \ + recvcount, reinterpret_cast(sendbuff), \ + reinterpret_cast(recvbuff)); \ auto *comm = reinterpret_cast(communicator); \ SET_LUA_ERROR_IF_NCCL_ERROR( \ ncclReduceScatter(reinterpret_cast(sendbuff), \ @@ -338,6 +346,13 @@ static void registerNcclOps(sol::state_view &lua, ResourceTracker *tracker) { size_t numBytes, uintptr_t communicator, CudaStreamPtr stream) { auto *comm = reinterpret_cast(communicator); + MTRT_DBG("__nccl_permute[{6}/{7}]: send {0} bytes @ {1} to {2}, recv {0} " + "bytes @ " + "{3} from {4}, comm @{5}", + numBytes, reinterpret_cast(sendbuff), sendId, + reinterpret_cast(recvbuff), recvId, + reinterpret_cast(comm->comm), comm->rank, comm->numRanks); + if (recvId == -1) { // Zero out recvbuff if not receiving. SET_LUA_ERROR_IF_CUDA_ERROR( From 2e9cb365b44f2d74837493695015f6e8c2b0d1a3 Mon Sep 17 00:00:00 2001 From: Copybara Bot Date: Tue, 19 Nov 2024 10:06:13 -0800 Subject: [PATCH 21/29] [compiler] Fix CompilerClient CAPI destroy function Fixes an issue where the `mtrtCompilerClientDestroy` function was not correctly destroying the underlying C++ CompilerClient object. GitOrigin-RevId: 39362fcb2fc2bf8233b4695fed5e6a3944cdf606 --- mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp b/mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp index 6636768b0..185824226 100644 --- a/mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp +++ b/mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp @@ -90,12 +90,12 @@ MTRT_Status mtrtCompilerClientCreate(MlirContext context, if (!cppClient.isOk()) return wrap(cppClient.getStatus()); - *client = MTRT_CompilerClient{cppClient->release()}; + *client = wrap(cppClient->release()); return mtrtStatusGetOk(); } MTRT_Status mtrtCompilerClientDestroy(MTRT_CompilerClient client) { - delete reinterpret_cast(client.ptr); + delete unwrap(client); return mtrtStatusGetOk(); } From 3a8362c3a50d6092806b680087cd6a7bc4942b85 Mon Sep 17 00:00:00 2001 From: Parth Chadha Date: Wed, 20 Nov 2024 08:57:15 -0800 Subject: [PATCH 22/29] Increment mlir-tensorrt minor version to 0.1.37 (#395) --- mlir-tensorrt/Version.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir-tensorrt/Version.cmake b/mlir-tensorrt/Version.cmake index 05dedb434..969c4498c 100644 --- a/mlir-tensorrt/Version.cmake +++ b/mlir-tensorrt/Version.cmake @@ -1,6 +1,6 @@ set(MLIR_TENSORRT_VERSION_MAJOR "0") set(MLIR_TENSORRT_VERSION_MINOR "1") -set(MLIR_TENSORRT_VERSION_PATCH "36") +set(MLIR_TENSORRT_VERSION_PATCH "37") set(MLIR_TENSORRT_VERSION "${MLIR_TENSORRT_VERSION_MAJOR}.${MLIR_TENSORRT_VERSION_MINOR}.${MLIR_TENSORRT_VERSION_PATCH}") From 89e2090dc78bd0c340e7e09535be25023d089322 Mon Sep 17 00:00:00 2001 From: Sagar Shelke Date: Wed, 20 Nov 2024 10:21:17 -0800 Subject: [PATCH 23/29] [mlir-tensorrt] Add TensorRT 8.6 support (#391) This PR makes the following changes, - Make TensorRT 10.5 as default version - Add TensorRT 8.6 download support - Add TensorRT 8.6 to CI - TensorRT 9 checks from CI are removed to deal with device space error. - Fix tests to support above changes --- .github/workflows/mlir-tensorrt-ci.yml | 20 +- mlir-tensorrt/CMakeLists.txt | 2 +- .../build_tools/cmake/Dependencies.cmake | 38 +- .../StablehloToTensorRT.cpp | 38 +- .../stablehlo-control-flow.mlir | 2 +- .../stablehlo-to-tensorrt-invalid.mlir | 2 +- .../stablehlo-to-tensorrt-trt10.mlir | 345 +++++++++++++++++ .../stablehlo-to-tensorrt.mlir | 351 +----------------- .../Dialect/Plan/segmentation-pipeline.mlir | 2 +- .../test_stablehlo_dynamic_iota.py | 0 10 files changed, 422 insertions(+), 378 deletions(-) create mode 100644 mlir-tensorrt/test/Conversion/StablehloToTensorRT/stablehlo-to-tensorrt-trt10.mlir rename mlir-tensorrt/test/python/IntegrationTests/{ => TRT10}/test_stablehlo_dynamic_iota.py (100%) diff --git a/.github/workflows/mlir-tensorrt-ci.yml b/.github/workflows/mlir-tensorrt-ci.yml index ffb0ce007..e9ff1e611 100644 --- a/.github/workflows/mlir-tensorrt-ci.yml +++ b/.github/workflows/mlir-tensorrt-ci.yml @@ -148,7 +148,7 @@ jobs: -DCMAKE_BUILD_TYPE=RelWithDebInfo \ -DMLIR_TRT_PACKAGE_CACHE_DIR=/.cache.cpm \ -DMLIR_TRT_ENABLE_ASSERTIONS=ON \ - -DMLIR_TRT_DOWNLOAD_TENSORRT_VERSION=10.2 \ + -DMLIR_TRT_DOWNLOAD_TENSORRT_VERSION=10.5 \ -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ \ -DMLIR_TRT_USE_LINKER=lld \ -DMLIR_EXECUTOR_ENABLE_GPU_INTEGRATION_TESTS=OFF @@ -191,7 +191,7 @@ jobs: -DCMAKE_BUILD_TYPE=RelWithDebInfo \ -DMLIR_TRT_PACKAGE_CACHE_DIR=/.cache.cpm \ -DMLIR_TRT_ENABLE_ASSERTIONS=ON \ - -DMLIR_TRT_DOWNLOAD_TENSORRT_VERSION=10.2 \ + -DMLIR_TRT_DOWNLOAD_TENSORRT_VERSION=10.5 \ -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ \ -DMLIR_TRT_USE_LINKER=lld \ -DMLIR_EXECUTOR_ENABLE_GPU_INTEGRATION_TESTS=OFF \ @@ -209,8 +209,8 @@ jobs: bash build_and_test.sh - # Run LIT tests with TensorRT 9 - - name: Run MLIR-TensorRT lit tests with TensorRT 9 + # Run LIT tests with TensorRT 8 + - name: Run MLIR-TensorRT lit tests with TensorRT 8 uses: addnab/docker-run-action@v3 with: image: ${{ env.DEFAULT_IMAGE }} @@ -235,7 +235,7 @@ jobs: -DCMAKE_BUILD_TYPE=RelWithDebInfo \ -DMLIR_TRT_PACKAGE_CACHE_DIR=/.cache.cpm \ -DMLIR_TRT_ENABLE_ASSERTIONS=ON \ - -DMLIR_TRT_DOWNLOAD_TENSORRT_VERSION=9.2.0.5 \ + -DMLIR_TRT_DOWNLOAD_TENSORRT_VERSION=8.6.1.6 \ -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ \ -DMLIR_TRT_USE_LINKER=lld \ -DMLIR_EXECUTOR_ENABLE_GPU_INTEGRATION_TESTS=OFF @@ -324,7 +324,7 @@ jobs: -DCMAKE_BUILD_TYPE=RelWithDebInfo \ -DMLIR_TRT_PACKAGE_CACHE_DIR=/.cache.cpm \ -DMLIR_TRT_ENABLE_ASSERTIONS=ON \ - -DMLIR_TRT_DOWNLOAD_TENSORRT_VERSION=10.2 \ + -DMLIR_TRT_DOWNLOAD_TENSORRT_VERSION=10.5 \ -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ \ -DMLIR_TRT_USE_LINKER=lld \ -DMLIR_EXECUTOR_ENABLE_GPU_INTEGRATION_TESTS=OFF @@ -367,7 +367,7 @@ jobs: -DCMAKE_BUILD_TYPE=RelWithDebInfo \ -DMLIR_TRT_PACKAGE_CACHE_DIR=/.cache.cpm \ -DMLIR_TRT_ENABLE_ASSERTIONS=ON \ - -DMLIR_TRT_DOWNLOAD_TENSORRT_VERSION=10.2 \ + -DMLIR_TRT_DOWNLOAD_TENSORRT_VERSION=10.5 \ -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ \ -DMLIR_TRT_USE_LINKER=lld \ -DMLIR_EXECUTOR_ENABLE_GPU_INTEGRATION_TESTS=OFF \ @@ -385,8 +385,8 @@ jobs: bash build_and_test.sh - # Run LIT tests with TensorRT 9 - - name: Run MLIR-TensorRT lit tests with TensorRT 9 + # Run LIT tests with TensorRT 8 + - name: Run MLIR-TensorRT lit tests with TensorRT 8 uses: addnab/docker-run-action@v3 with: image: ${{ env.DEFAULT_IMAGE }} @@ -411,7 +411,7 @@ jobs: -DCMAKE_BUILD_TYPE=RelWithDebInfo \ -DMLIR_TRT_PACKAGE_CACHE_DIR=/.cache.cpm \ -DMLIR_TRT_ENABLE_ASSERTIONS=ON \ - -DMLIR_TRT_DOWNLOAD_TENSORRT_VERSION=9.2.0.5 \ + -DMLIR_TRT_DOWNLOAD_TENSORRT_VERSION=8.6.1.6 \ -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ \ -DMLIR_TRT_USE_LINKER=lld \ -DMLIR_EXECUTOR_ENABLE_GPU_INTEGRATION_TESTS=OFF diff --git a/mlir-tensorrt/CMakeLists.txt b/mlir-tensorrt/CMakeLists.txt index db6385850..41f81149b 100644 --- a/mlir-tensorrt/CMakeLists.txt +++ b/mlir-tensorrt/CMakeLists.txt @@ -55,7 +55,7 @@ mtrt_option(MLIR_TRT_ENABLE_EXECUTOR "Build the Executor dialect and MLIR-Tensor mtrt_option(MLIR_TRT_ENABLE_NCCL "Enable the NCCL runtime module" ON) set(MLIR_TRT_TENSORRT_DIR "" CACHE STRING "Path to TensorRT install directory") -set(MLIR_TRT_DOWNLOAD_TENSORRT_VERSION "10.2" CACHE STRING +set(MLIR_TRT_DOWNLOAD_TENSORRT_VERSION "10.5" CACHE STRING "Version of TensorRT to download and use. It overrides MLIR_TRT_TENSORRT_DIR.") set(MLIR_TRT_PACKAGE_CACHE_DIR "" CACHE STRING "Directory where to cache downloaded C++ packages") set(MLIR_TRT_USE_LINKER "" CACHE STRING "Specify a linker to use (e.g. LLD); this is just an alias for LLVM_USE_LINKER") diff --git a/mlir-tensorrt/build_tools/cmake/Dependencies.cmake b/mlir-tensorrt/build_tools/cmake/Dependencies.cmake index ef063d0e6..68a975a15 100644 --- a/mlir-tensorrt/build_tools/cmake/Dependencies.cmake +++ b/mlir-tensorrt/build_tools/cmake/Dependencies.cmake @@ -86,11 +86,15 @@ function(download_tensorrt) if(ARG_VERSION VERSION_EQUAL "10.2") set(ARG_VERSION "10.2.0.19") endif() + # Canonicalize "10.5" version by setting it to the latest public TRT 10.5 version. + if(ARG_VERSION VERSION_EQUAL "10.5") + set(ARG_VERSION "10.5.0.18") + endif() set(downloadable_versions - "9.0.1.4" "9.1.0.4" "9.2.0.5" + "8.6.1.6" "9.0.1.4" "9.1.0.4" "9.2.0.5" "10.0.0.6" "10.1.0.27" - "10.2.0.19" + "10.2.0.19" "10.5.0.18" ) if(NOT ARG_VERSION IN_LIST downloadable_versions) @@ -100,6 +104,28 @@ function(download_tensorrt) set(TRT_VERSION "${ARG_VERSION}") + # Handle TensorRT 8 versions. These are publicly accessible download links. + if(ARG_VERSION VERSION_LESS 9.0.0 AND ARG_VERSION VERSION_GREATER 8.0.0) + string(REGEX MATCH "[0-9]+\\.[0-9]+\\.[0-9]+" trt_short_version ${ARG_VERSION}) + set(CUDA_VERSION "12.0") + set(OS "linux") + EXECUTE_PROCESS(COMMAND uname -m + COMMAND tr -d '\n' + OUTPUT_VARIABLE ARCH) + if(ARCH STREQUAL "arm64") + set(ARCH "aarch64") + set(OS "Ubuntu-20.04") + elseif(ARCH STREQUAL "amd64") + set(ARCH "x86_64") + set(OS "Linux") + elseif(ARCH STREQUAL "aarch64") + set(OS "Ubuntu-20.04") + elseif(NOT (ARCH STREQUAL "x86_64")) + message(FATAL_ERROR "Direct download not available for architecture: ${ARCH}") + endif() + set(_url "https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/secure/${trt_short_version}/tars/TensorRT-${TRT_VERSION}.${OS}.${ARCH}-gnu.cuda-${CUDA_VERSION}.tar.gz") + endif() + # Handle TensorRT 9 versions. These are publicly accessible download links. if(ARG_VERSION VERSION_LESS 10.0.0 AND ARG_VERSION VERSION_GREATER 9.0.0) string(REGEX MATCH "[0-9]+\\.[0-9]+\\.[0-9]+" trt_short_version ${ARG_VERSION}) @@ -137,6 +163,10 @@ function(download_tensorrt) set(_url "https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.2.0/tars/TensorRT-10.2.0.19.Linux.x86_64-gnu.cuda-12.5.tar.gz") endif() + if(ARG_VERSION VERSION_EQUAL 10.5.0.18) + set(_url "https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.5.0/tars/TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-12.6.tar.gz") + endif() + if(NOT _url) message(FATAL_ERROR "Could not determine TensorRT download URL") endif() @@ -144,12 +174,12 @@ function(download_tensorrt) message(STATUS "TensorRT Download URL: ${_url}") CPMAddPackage( - NAME TensorRT9 + NAME TensorRT VERSION "${TRT_VERSION}" URL ${_url} DOWNLOAD_ONLY ) - set("${ARG_OUT_VAR}" "${TensorRT9_SOURCE_DIR}" PARENT_SCOPE) + set("${ARG_OUT_VAR}" "${TensorRT_SOURCE_DIR}" PARENT_SCOPE) endfunction() #------------------------------------------------------------------------------------- diff --git a/mlir-tensorrt/compiler/lib/Conversion/StablehloToTensorRT/StablehloToTensorRT.cpp b/mlir-tensorrt/compiler/lib/Conversion/StablehloToTensorRT/StablehloToTensorRT.cpp index ac178f515..34622ff91 100644 --- a/mlir-tensorrt/compiler/lib/Conversion/StablehloToTensorRT/StablehloToTensorRT.cpp +++ b/mlir-tensorrt/compiler/lib/Conversion/StablehloToTensorRT/StablehloToTensorRT.cpp @@ -2818,10 +2818,15 @@ struct PadConverter : public ConvertHloOpToTensorRTPattern { auto padLowHighSum = trtRewriter.checkAndCreate( loc, targetTrtMajorVersion, shapeTensorType, padLowConst, padHighConst, tensorrt::ElementWiseOperation::kSUM); + if (!padLowHighSum) + return failure(); Value size = padLowHighSum.getResult(); - size = trtRewriter.checkAndCreate( + auto sumWithResult = trtRewriter.checkAndCreate( loc, targetTrtMajorVersion, shapeTensorType, size, shape.getResult(), tensorrt::ElementWiseOperation::kSUM); + if (!sumWithResult) + return failure(); + size = sumWithResult.getResult(); SmallVector stride(inputType.getRank(), 1); return trtRewriter.checkAndReplaceOpWithNewOp( @@ -3858,7 +3863,7 @@ struct ConvertScatterToTensorRTScatterElements if (!constOneTuple) return failure(); - Value newIndices = trtRewriter.checkAndCreate( + auto newIndices = trtRewriter.checkAndCreate( op->getLoc(), targetTrtMajorVersion, newUpdateType.clone(rewriter.getI32Type()), Value(), startIndex, constOneTuple, FloatAttr(), FloatAttr()); @@ -3884,7 +3889,7 @@ struct ConvertScatterToTensorRTScatterElements auto newOp = trtRewriter.checkAndCreate( op->getLoc(), targetTrtMajorVersion, /*data*/ convertToI32(adaptor.getInputs().front()), - /*indices*/ newIndices, + /*indices*/ newIndices.getResult(), /*updates*/ convertToI32(newUpdates), /*axis*/ rewriter.getI64IntegerAttr(axis)); if (!newOp) @@ -3894,7 +3899,8 @@ struct ConvertScatterToTensorRTScatterElements auto newOp = trtRewriter.checkAndCreate( op->getLoc(), targetTrtMajorVersion, /*data*/ adaptor.getInputs().front(), - /*indices*/ newIndices, /*updates*/ newUpdates.getResult(), + /*indices*/ newIndices.getResult(), + /*updates*/ newUpdates.getResult(), /*axis*/ rewriter.getI64IntegerAttr(axis)); if (!newOp) return failure(); @@ -4327,24 +4333,32 @@ struct DynamicUpdateSliceToConcatConverter // start and shape to be the values appropriate for !hasNonZeroUpdateStart // (static case). We will update them in the condition block. // Calculate the slice start = update offset + update size. - TypedValue concatDimOffset = - trtRewriter.checkAndCreate( - loc, targetTrtMajorVersion, updateStartOffset, - tensorrt::createConstShapeTensor( - rewriter, loc, - {static_cast(updateType.getDimSize(*concatAxis))}), - tensorrt::ElementWiseOperation::kSUM); + auto sliceStart = trtRewriter.checkAndCreate( + loc, targetTrtMajorVersion, updateStartOffset, + tensorrt::createConstShapeTensor( + rewriter, loc, + {static_cast(updateType.getDimSize(*concatAxis))}), + tensorrt::ElementWiseOperation::kSUM); + if (!sliceStart) + return failure(); + TypedValue concatDimOffset = sliceStart.getResult(); + TypedValue endOffset = tensorrt::scatterShapeTensor( rewriter, loc, SmallVector(updateType.getRank(), 0), *concatAxis, concatDimOffset); // Calculate the slice size = result shape - update offset. - TypedValue finalPartDimSize = + auto finalPartDimSizeOp = trtRewriter.checkAndCreate( loc, targetTrtMajorVersion, tensorrt::createConstShapeTensor( rewriter, loc, {static_cast(resultType.getDimSize(*concatAxis))}), concatDimOffset, tensorrt::ElementWiseOperation::kSUB); + if (!finalPartDimSizeOp) + return failure(); + TypedValue finalPartDimSize = + finalPartDimSizeOp.getResult(); + TypedValue endShape = tensorrt::scatterShapeTensor( rewriter, loc, resultType.getShape(), *concatAxis, finalPartDimSize); diff --git a/mlir-tensorrt/test/Conversion/StablehloToTensorRT/stablehlo-control-flow.mlir b/mlir-tensorrt/test/Conversion/StablehloToTensorRT/stablehlo-control-flow.mlir index 9521cc241..c54f5e1a4 100644 --- a/mlir-tensorrt/test/Conversion/StablehloToTensorRT/stablehlo-control-flow.mlir +++ b/mlir-tensorrt/test/Conversion/StablehloToTensorRT/stablehlo-control-flow.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-tensorrt-opt -split-input-file %s --convert-stablehlo-to-tensorrt=convert-loops | FileCheck %s +// RUN: mlir-tensorrt-opt -split-input-file %s --convert-stablehlo-to-tensorrt="convert-loops=true trt-major-version=10" | FileCheck %s func.func @while() -> tensor { %arg0 = stablehlo.constant dense<0> : tensor diff --git a/mlir-tensorrt/test/Conversion/StablehloToTensorRT/stablehlo-to-tensorrt-invalid.mlir b/mlir-tensorrt/test/Conversion/StablehloToTensorRT/stablehlo-to-tensorrt-invalid.mlir index e4924f952..bd71c1cf1 100644 --- a/mlir-tensorrt/test/Conversion/StablehloToTensorRT/stablehlo-to-tensorrt-invalid.mlir +++ b/mlir-tensorrt/test/Conversion/StablehloToTensorRT/stablehlo-to-tensorrt-invalid.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-tensorrt-opt -split-input-file %s --convert-stablehlo-to-tensorrt -verify-diagnostics | FileCheck %s +// RUN: mlir-tensorrt-opt -split-input-file %s --convert-stablehlo-to-tensorrt="trt-major-version=10" -verify-diagnostics | FileCheck %s func.func @stablehlo_all_reduce_region(%arg0 : tensor) -> tensor { %0 = "stablehlo.all_reduce"(%arg0) ({ diff --git a/mlir-tensorrt/test/Conversion/StablehloToTensorRT/stablehlo-to-tensorrt-trt10.mlir b/mlir-tensorrt/test/Conversion/StablehloToTensorRT/stablehlo-to-tensorrt-trt10.mlir new file mode 100644 index 000000000..1af1ba7b1 --- /dev/null +++ b/mlir-tensorrt/test/Conversion/StablehloToTensorRT/stablehlo-to-tensorrt-trt10.mlir @@ -0,0 +1,345 @@ +// RUN: mlir-tensorrt-opt -split-input-file %s --convert-stablehlo-to-tensorrt="trt-major-version=10" -allow-unregistered-dialect | FileCheck %s + +func.func @hlo_iota() -> tensor<128xi32> { + %0 = "stablehlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<128xi32> + return %0 : tensor<128xi32> +} + +// CHECK-LABEL: @hlo_iota +// CHECK: tensorrt.linspace +// CHECK-SAME: [ 0.00{{.+}}] [ static] [ 1.000{{.+}}] : tensor<128xi32> + +// ----- + +func.func @hlo_rem(%arg0: tensor<128xi32>, %arg1: tensor<128xi32>) -> tensor<128xi32> { + %0 = "stablehlo.remainder"(%arg0, %arg1) {} : (tensor<128xi32>, tensor<128xi32>) -> tensor<128xi32> + return %0 : tensor<128xi32> +} + +// CHECK-LABEL: @hlo_rem +// CHECK-SAME: (%[[lhs:.+]]: tensor<128xi32>, %[[rhs:.+]]: tensor<128xi32>) +// CHECK: %[[div:.+]] = tensorrt.element_wise (%[[lhs]], %[[rhs]] : +// CHECK: %[[prod:.+]] = tensorrt.element_wise (%[[div]], %[[rhs]] : +// CHECK: tensorrt.element_wise (%[[lhs]], %[[prod]] : + +// ----- + +func.func @hlo_pad_dynamic_non_sliced_dim(%arg0: tensor) -> tensor { + %0 = "stablehlo.constant"() {value = dense<0.0> : tensor} : () -> tensor + %1 = "stablehlo.pad"(%arg0, %0) { + edge_padding_high = array, + edge_padding_low = array, + interior_padding = array + } : (tensor, tensor) -> tensor + func.return %1 : tensor +} + +// CHECK-LABEL: @hlo_pad_dynamic_non_sliced_dim +// CHECK-SAME: (%[[arg0:.+]]: tensor +// CHECK: %[[fill:.+]] = tensorrt.constant dense<0.000000e+00> : tensor +// CHECK: %[[shape:.+]] = tensorrt.shape %[[arg0]] : tensor -> tensor<4xi32> +// CHECK: %[[pad_high:.+]] = tensorrt.constant dense<[0, 0, 0, 16]> : tensor<4xi32> +// CHECK: %[[pad_low:.+]] = tensorrt.constant dense<0> : tensor<4xi32> +// CHECK: %[[sum0:.+]] = tensorrt.element_wise (%[[pad_low]], %[[pad_high]] : tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> +// CHECK: %[[sum1:.+]] = tensorrt.element_wise (%[[sum0]], %[[shape]] : tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> +// CHECK: tensorrt.slice %arg0[0, 0, 0, 0][%[[sum1]]: tensor<4xi32>][1, 1, 1, 1] fill(%[[fill]] : tensor) {mode = #tensorrt.slice_mode} : tensor to tensor + +// ----- + +func.func @hlo_dynamic_pad(%arg0: tensor, %arg1: tensor, %arg2: tensor<2xi32>, %arg3: tensor<2xi32>) -> tensor { + %cst_0 = stablehlo.constant dense<[0, 0]> : tensor<2xi32> + %0 = "stablehlo.dynamic_pad"(%arg0, %arg1, %arg2, %arg3, %cst_0) : (tensor, tensor, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor + return %0: tensor +} + +// CHECK-LABEL: @hlo_dynamic_pad +// CHECK-SAME: (%[[arg0:.+]]: tensor, %[[arg1:.+]]: tensor, %[[arg2:.+]]: tensor<2xi32>, %[[arg3:.+]]: tensor<2xi32>) -> tensor { +// CHECK: %[[pad_interior:.+]] = tensorrt.constant dense<0> : tensor<2xi32> +// CHECK: %[[padding_low_f32:.+]] = tensorrt.identity %arg2 : tensor<2xi32> to tensor<2xf32> +// CHECK: %[[sliceOffset_f32:.+]] = tensorrt.unary {unaryOperation = #tensorrt.unary_operation} %[[padding_low_f32]] : tensor<2xf32> +// CHECK: %[[sliceOffset_i32:.+]] = tensorrt.identity %[[sliceOffset_f32]] : tensor<2xf32> to tensor<2xi32> +// CHECK: %[[shape:.+]] = tensorrt.shape %[[arg0]] : tensor -> tensor<2xi32> +// CHECK: %[[sum0:.+]] = tensorrt.element_wise (%[[arg2]], %[[arg3]] : tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: %[[sum1:.+]] = tensorrt.element_wise (%[[sum0]], %[[shape]] : tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: tensorrt.slice %arg0[%[[sliceOffset_i32]]: tensor<2xi32>][%[[sum1]]: tensor<2xi32>][1, 1] fill(%arg1 : tensor) {mode = #tensorrt.slice_mode} : tensor to tensor + +// ----- + +func.func @hlo_dynamic_iota_0(%arg0 : tensor<1xi32>) -> tensor { + %0 = "stablehlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<1xi32>) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: @hlo_dynamic_iota_0 +// CHECK: tensorrt.linspace +// CHECK-SAME: [ 0.00{{.+}}] [%arg0 : tensor<1xi32>] [ 1.000{{.+}}] : tensor + +// ----- + +func.func @dynamic_nd_iota_1(%arg0 : tensor<2xi32>) -> tensor { + %0 = "stablehlo.dynamic_iota"(%arg0) {iota_dimension = 1 : i64} : (tensor<2xi32>) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func.func @dynamic_nd_iota_1 +// CHECK-SAME: (%[[arg0:.+]]: tensor<2xi32>) -> tensor { +// CHECK: %[[cst_i32:.+]] = tensorrt.constant dense<0> : tensor +// CHECK: %[[cst_i32_0:.+]] = tensorrt.constant dense<[0, 1]> : tensor<2xi32> +// CHECK: %[[v0:.+]] = tensorrt.linspace[%[[cst_i32]] : tensor] [%[[arg0]] : tensor<2xi32>] [%[[cst_i32_0]] : tensor<2xi32>] : tensor +// CHECK: return %[[v0]] : tensor + +// ----- + +func.func @dynamic_nd_iota_2(%arg0 : tensor<2xi32>) -> tensor { + %0 = "stablehlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<2xi32>) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func.func @dynamic_nd_iota_2 +// CHECK-SAME: (%[[arg0:.+]]: tensor<2xi32>) -> tensor { +// CHECK: %[[cst_i32:.+]] = tensorrt.constant dense<0> : tensor +// CHECK: %[[cst_i32_0:.+]] = tensorrt.constant dense<[1, 0]> : tensor<2xi32> +// CHECK: %[[v0:.+]] = tensorrt.linspace[%[[cst_i32]] : tensor] [%[[arg0]] : tensor<2xi32>] [%[[cst_i32_0]] : tensor<2xi32>] : tensor +// CHECK: return %[[v0]] : tensor + +// ----- + +func.func @stablehlo_real_dynamic_slice( + %input: tensor, + %start_indices: tensor<2xindex>, + %limit_indices: tensor<2xindex>, + %strides: tensor<2xindex>) -> tensor { + %0 = "stablehlo.real_dynamic_slice"(%input, %start_indices, %limit_indices, %strides) : (tensor, tensor<2xindex>, tensor<2xindex>, tensor<2xindex>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: @stablehlo_real_dynamic_slice( +// CHECK-SAME: %[[arg0:.+]]: tensor, %[[arg1:.+]]: tensor<2xi32>, %[[arg2:.+]]: tensor<2xi32>, %[[arg3]]: tensor<2xi32>) -> tensor { +// CHECK: %[[num:.+]] = tensorrt.element_wise (%[[arg2]], %[[arg1]] +// CHECK: %[[cst_i32:.+]] = tensorrt.constant dense<1> : tensor<2xi32> +// CHECK: %[[bias:.+]] = tensorrt.element_wise (%[[arg3]], %[[cst_i32]] : tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: %[[num1:.+]] = tensorrt.element_wise (%[[num]], %[[bias]] : tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: %[[ceilDiv:.+]] = tensorrt.element_wise (%[[num1]], %[[arg3]] : tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: %[[result:.+]] = tensorrt.slice %[[arg0]][%[[arg1]]: tensor<2xi32>][%[[ceilDiv]]: tensor<2xi32>][%[[arg3]]: tensor<2xi32>] +// CHECK: return %[[result]] : tensor + +// ----- + +func.func @dynamic_update_slice_conversion_1(%arg0: tensor<1x6x12x64xf32>) -> tensor<1x20x12x64xf32> { + %0 = stablehlo.constant dense<0> : tensor + %1 = stablehlo.constant dense<0.000000e+00> : tensor<1x20x12x64xf32> + %2 = stablehlo.dynamic_update_slice %1, %arg0, %0, %0, %0, %0 : (tensor<1x20x12x64xf32>, tensor<1x6x12x64xf32>, tensor, tensor, tensor, tensor) -> tensor<1x20x12x64xf32> + return %2 : tensor<1x20x12x64xf32> +} + +// CHECK-LABEL: @dynamic_update_slice_conversion_1 +// CHECK-SAME: (%[[arg0:.+]]: tensor<1x6x12x64xf32>) +// CHECK: %[[cst_i32:.+]] = tensorrt.constant dense<0> +// CHECK: %[[cst_f32:.+]] = tensorrt.constant dense<0.000000e+00> +// CHECK: %[[v0:.+]] = tensorrt.expand_rank %[[cst_i32]] : tensor to tensor<1xi32> +// CHECK: %[[cst_i32_0:.+]] = tensorrt.constant dense<1> +// CHECK: %[[cst_i32_1:.+]] = tensorrt.constant dense<[12, 64]> +// CHECK: %[[v1:.+]] = tensorrt.concatenation {axis = 0 : i32} ins(%[[cst_i32_0]], %[[v0]], %[[cst_i32_1]] : +// CHECK: %[[v2:.+]] = tensorrt.slice %[[cst_f32]][0, 0, 0, 0][%[[v1]]: tensor<4xi32>][1, 1, 1, 1] +// CHECK: %[[cst_i32_2:.+]] = tensorrt.constant dense<6> +// CHECK: %[[v3:.+]] = tensorrt.element_wise (%[[v0]], %[[cst_i32_2]] : +// CHECK: %[[cst_i32_3:.+]] = tensorrt.constant dense<0> : tensor<1xi32> +// CHECK: %[[cst_i32_4:.+]] = tensorrt.constant dense<0> : tensor<2xi32> +// CHECK: %[[v4:.+]] = tensorrt.concatenation {axis = 0 : i32} ins(%[[cst_i32_3]], %[[v3]], %[[cst_i32_4]] +// CHECK: %[[cst_i32_5:.+]] = tensorrt.constant dense<20> +// CHECK: %[[v5:.+]] = tensorrt.element_wise (%[[cst_i32_5]], %[[v3]] : +// CHECK: %[[cst_i32_6:.+]] = tensorrt.constant dense<1> +// CHECK: %[[cst_i32_7:.+]] = tensorrt.constant dense<[12, 64]> +// CHECK: %[[v6:.+]] = tensorrt.concatenation {axis = 0 : i32} ins(%[[cst_i32_6]], %[[v5]], %[[cst_i32_7]] : +// CHECK: %[[v7:.+]] = tensorrt.slice %[[cst_f32]][%[[v4]]: tensor<4xi32>][%[[v6]]: tensor<4xi32>][1, 1, 1, 1] +// CHECK: %[[v8:.+]] = tensorrt.concatenation {axis = 1 : i32} ins(%[[v2]], %[[arg0]], %[[v7]] : +// CHECK: return %[[v8]] + +// ----- + +func.func @dynamic_update_slice_conversion_2(%arg0: tensor<1x20x12x64xf32>, %arg1: tensor<1x1x12x64xf32>, %arg2: tensor) -> tensor<1x20x12x64xf32> { + %0 = stablehlo.constant dense<0> : tensor + %1 = stablehlo.dynamic_update_slice %arg0, %arg1, %0, %arg2, %0, %0 : (tensor<1x20x12x64xf32>, tensor<1x1x12x64xf32>, tensor, tensor, tensor, tensor) -> tensor<1x20x12x64xf32> + return %1 : tensor<1x20x12x64xf32> +} + +// CHECK-LABEL: @dynamic_update_slice_conversion_2 +// CHECK-SAME: (%[[arg0:.+]]: tensor<1x20x12x64xf32>, %[[arg1:.+]]: tensor<1x1x12x64xf32>, %[[arg2:.+]]: tensor) +// CHECK: %[[cst_i32:.+]] = tensorrt.constant dense<0> : tensor +// CHECK: %[[v0:.+]] = tensorrt.expand_rank %[[arg2]] : tensor to tensor<1xi32> +// CHECK: %[[cst_i32_0:.+]] = tensorrt.constant dense<1> : tensor<1xi32> +// CHECK: %[[cst_i32_1:.+]] = tensorrt.constant dense<[12, 64]> : tensor<2xi32> +// CHECK: %[[v1:.+]] = tensorrt.concatenation +// CHECK-SAME: axis = 0 : i32 +// CHECK-SAME: ins(%[[cst_i32_0]], %[[v0]], %[[cst_i32_1]] : tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) +// CHECK: %[[v2:.+]] = tensorrt.slice %[[arg0]][0, 0, 0, 0][%[[v1]]: tensor<4xi32>][1, 1, 1, 1] : tensor<1x20x12x64xf32> to tensor<1x?x12x64xf32> +// CHECK: %[[cst_i32_2:.+]] = tensorrt.constant dense<1> : tensor<1xi32> +// CHECK: %[[v3:.+]] = tensorrt.element_wise (%[[v0]], %[[cst_i32_2]] : tensor<1xi32>, tensor<1xi32>) +// CHECK: %[[cst_i32_3:.+]] = tensorrt.constant dense<0> : tensor<1xi32> +// CHECK: %[[cst_i32_4:.+]] = tensorrt.constant dense<0> : tensor<2xi32> +// CHECK: %[[v4:.+]] = tensorrt.concatenation +// CHECK-SAME: axis = 0 : i32 +// CHECK-SAME: ins(%[[cst_i32_3]], %[[v3]], %[[cst_i32_4]] : tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) +// CHECK: %[[cst_i32_5:.+]] = tensorrt.constant dense<20> : tensor<1xi32> +// CHECK: %[[v5:.+]] = tensorrt.element_wise (%[[cst_i32_5]], %[[v3]] : tensor<1xi32>, tensor<1xi32>) +// CHECK: %[[cst_i32_6:.+]] = tensorrt.constant dense<1> : tensor<1xi32> +// CHECK: %[[cst_i32_7:.+]] = tensorrt.constant dense<[12, 64]> : tensor<2xi32> +// CHECK: %[[v6:.+]] = tensorrt.concatenation +// CHECK-SAME: axis = 0 : i32 +// CHECK-SAME: ins(%[[cst_i32_6]], %[[v5]], %[[cst_i32_7]] : tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) +// CHECK: %[[v7:.+]] = tensorrt.slice %[[arg0]][%[[v4]]: tensor<4xi32>][%[[v6]]: tensor<4xi32>][1, 1, 1, 1] : tensor<1x20x12x64xf32> to tensor<1x?x12x64xf32> +// CHECK: %[[v8:.+]] = tensorrt.concatenation +// CHECK-SAME: axis = 1 : i32 +// CHECK-SAME: ins(%[[v2]], %[[arg1]], %[[v7]] : tensor<1x?x12x64xf32>, tensor<1x1x12x64xf32>, tensor<1x?x12x64xf32>) +// CHECK: return %[[v8]] + +// ----- + +func.func @scatter_slice_update(%arg0: tensor<1x134xi32>, %arg1: tensor<1x2xi32>, %arg2: tensor<1x1x5xi32>) -> tensor<1x134xi32> { + %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + stablehlo.return %arg4 : tensor + }) {indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false} : (tensor<1x134xi32>, tensor<1x2xi32>, tensor<1x1x5xi32>) -> tensor<1x134xi32> + return %0 : tensor<1x134xi32> +} + +// CHECK-LABEL: @scatter_slice_update +// CHECK: %[[v0:.+]] = tensorrt.slice %arg1[0, 1][1, 1][1, 1] : tensor<1x2xi32> to tensor<1x1xi32> +// CHECK: %[[v1:.+]] = tensorrt.collapse_rank %[[v0]] : tensor<1x1xi32> to tensor +// CHECK: %[[v2:.+]] = tensorrt.collapse_rank %arg2 : tensor<1x1x5xi32> to tensor<1x5xi32> +// CHECK: %[[v3:.+]] = tensorrt.constant dense<1> : tensor<2xi32> +// CHECK: %[[v4:.+]] = tensorrt.linspace[%[[v1]] : tensor] [ static] [%[[v3]] : tensor<2xi32>] : tensor<1x5xi32> +// CHECK: %[[v5:.+]] = tensorrt.scatter_elements {axis = 1 : i64} data(%arg0 : tensor<1x134xi32>) indices(%[[v4]] : tensor<1x5xi32>) updates(%[[v2]] : tensor<1x5xi32>) +// CHECK: return %[[v5]] : tensor<1x134xi32> + +// ----- + +func.func @scatter_slice_update_f16_axis1(%arg0: tensor<1x134xf16>, %arg1: tensor<1x2xi32>, %arg2: tensor<1x1x5xf16>) -> tensor<1x134xf16> { + %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + stablehlo.return %arg4 : tensor + }) {indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false} : (tensor<1x134xf16>, tensor<1x2xi32>, tensor<1x1x5xf16>) -> tensor<1x134xf16> + return %0 : tensor<1x134xf16> +} + +// CHECK-LABEL: @scatter_slice_update_f16 +// CHECK: %[[v0:.+]] = tensorrt.slice %arg1[0, 1][1, 1][1, 1] : tensor<1x2xi32> to tensor<1x1xi32> +// CHECK: %[[v1:.+]] = tensorrt.collapse_rank %[[v0]] : tensor<1x1xi32> to tensor +// CHECK: %[[v2:.+]] = tensorrt.collapse_rank %arg2 : tensor<1x1x5xf16> to tensor<1x5xf16> +// CHECK: %[[v3:.+]] = tensorrt.constant dense<1> : tensor<2xi32> +// CHECK: %[[v4:.+]] = tensorrt.linspace[%[[v1]] : tensor] [ static] [%[[v3]] : tensor<2xi32>] : tensor<1x5xi32> +// CHECK: %[[v5:.+]] = tensorrt.scatter_elements {axis = 1 : i64} data(%arg0 : tensor<1x134xf16>) indices(%[[v4]] : tensor<1x5xi32>) updates(%[[v2]] : tensor<1x5xf16>) +// CHECK: return %[[v5]] : tensor<1x134xf16> + +// ----- + +func.func @scatter_slice_update_i1_axis1(%arg0: tensor<1x134xi1>, %arg1: tensor<1x2xi32>, %arg2: tensor<1x1x5xi1>) -> tensor<1x134xi1> { + %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + stablehlo.return %arg4 : tensor + }) {indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false} : (tensor<1x134xi1>, tensor<1x2xi32>, tensor<1x1x5xi1>) -> tensor<1x134xi1> + return %0 : tensor<1x134xi1> +} + +// CHECK-LABEL: @scatter_slice_update_i1 +// CHECK: %[[v0:.+]] = tensorrt.slice %arg1[0, 1][1, 1][1, 1] : tensor<1x2xi32> to tensor<1x1xi32> +// CHECK: %[[v1:.+]] = tensorrt.collapse_rank %[[v0]] : tensor<1x1xi32> to tensor +// CHECK: %[[v2:.+]] = tensorrt.collapse_rank %arg2 : tensor<1x1x5xi1> to tensor<1x5xi1> +// CHECK: %[[v3:.+]] = tensorrt.constant dense<1> : tensor<2xi32> +// CHECK: %[[v4:.+]] = tensorrt.linspace[%[[v1]] : tensor] [ static] [%[[v3]] : tensor<2xi32>] : tensor<1x5xi32> +// CHECK: %[[v5:.+]] = tensorrt.identity %arg0 : tensor<1x134xi1> to tensor<1x134xi32> +// CHECK: %[[v6:.+]] = tensorrt.identity %[[v2]] : tensor<1x5xi1> to tensor<1x5xi32> +// CHECK: %[[v7:.+]] = tensorrt.scatter_elements {axis = 1 : i64} data(%[[v5]] : tensor<1x134xi32>) indices(%[[v4]] : tensor<1x5xi32>) updates(%[[v6]] : tensor<1x5xi32>) +// CHECK: %[[v8:.+]] = tensorrt.identity %[[v7]] : tensor<1x134xi32> to tensor<1x134xi1> +// CHECK: return %[[v8]] : tensor<1x134xi1> + +// ----- + +func.func @scatter_slice_update_i1_axis0(%arg0: tensor<1024x1xi1>, %arg1: tensor<1x1xi32>, %arg2: tensor<1x134x1xi1>) -> tensor<1024x1xi1> { + %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg3: tensor, %arg4: tensor): + stablehlo.return %arg4 : tensor + }) : (tensor<1024x1xi1>, tensor<1x1xi32>, tensor<1x134x1xi1>) -> tensor<1024x1xi1> + return %0 : tensor<1024x1xi1> +} + +// CHECK-LABEL: @scatter_slice_update_i1_axis0 +// CHECK: %[[v0:.+]] = tensorrt.slice %arg1[0, 0][1, 1][1, 1] : tensor<1x1xi32> to tensor<1x1xi32> +// CHECK: %[[v1:.+]] = tensorrt.collapse_rank %[[v0]] : tensor<1x1xi32> to tensor +// CHECK: %[[v2:.+]] = tensorrt.collapse_rank %arg2 : tensor<1x134x1xi1> to tensor<134x1xi1> +// CHECK: %[[v3:.+]] = tensorrt.constant dense<1> : tensor<2xi32> +// CHECK: %[[v4:.+]] = tensorrt.linspace[%[[v1]] : tensor] [ static] [%[[v3]] : tensor<2xi32>] : tensor<134x1xi32> +// CHECK: %[[v5:.+]] = tensorrt.identity %arg0 : tensor<1024x1xi1> to tensor<1024x1xi32> +// CHECK: %[[v6:.+]] = tensorrt.identity %[[v2]] : tensor<134x1xi1> to tensor<134x1xi32> +// CHECK: %[[v7:.+]] = tensorrt.scatter_elements {axis = 0 : i64} data(%[[v5]] : tensor<1024x1xi32>) indices(%[[v4]] : tensor<134x1xi32>) updates(%[[v6]] : tensor<134x1xi32>) +// CHECK: %[[v8:.+]] = tensorrt.identity %[[v7]] : tensor<1024x1xi32> to tensor<1024x1xi1> +// CHECK: return %[[v8]] : tensor<1024x1xi1> + +// ----- + +func.func @large_weight() -> tensor<258x256xf32> { + %c = stablehlo.constant dense_resource<__elided__> : tensor<258x256xi4> + %0 = stablehlo.composite "tensorrt.block_dq" %c {composite_attributes = {axis = -1 : i32, scale = dense_resource<__elided__> : tensor<2x256xf32>}, decomposition = @block_dq} : (tensor<258x256xi4>) -> tensor<258x256xf32> + return %0 : tensor<258x256xf32> +} +func.func private @block_dq(%arg0: tensor<258x256xi4>) -> tensor<258x256xf32> attributes {plan.decomposition} { + %cst = stablehlo.constant dense_resource<__elided__> : tensor<2x256xf32> + %0 = stablehlo.broadcast_in_dim %cst, dims = [1, 2] : (tensor<2x256xf32>) -> tensor<129x2x256xf32> + %1 = stablehlo.reshape %0 : (tensor<129x2x256xf32>) -> tensor<258x256xf32> + %2 = stablehlo.convert %arg0 : (tensor<258x256xi4>) -> tensor<258x256xf32> + %3 = stablehlo.multiply %2, %1 : tensor<258x256xf32> + return %3 : tensor<258x256xf32> +} + +// CHECK-LABEL: large_weight +// CHECK-NEXT: %[[v0:.+]] = tensorrt.constant dense_resource<__elided__> : tensor<258x256xi4> +// CHECK-NEXT: %[[v1:.+]] = tensorrt.constant dense_resource<__elided__> : tensor<2x256xf32> +// CHECK-NEXT: %[[v2:.+]] = tensorrt.dequantize in(%[[v0]] : tensor<258x256xi4>) scale(%[[v1]] : tensor<2x256xf32>) -> tensor<258x256xf32> +// CHECK-NEXT: return %[[v2]] : tensor<258x256xf32> + +// ----- + +func.func @quantize_pt_bf16_to_fp8_static() -> tensor<2xf8E4M3FN> { + %cst = stablehlo.constant dense<[1.000000e+00, 2.000000e+00]> : tensor<2xbf16> + %0 = stablehlo.composite "tensorrt.pt_q" %cst {composite_attributes = {axis = -1 : i32, scale = dense<5.000000e-01> : tensor}, decomposition = @pt_q} : (tensor<2xbf16>) -> tensor<2xf8E4M3FN> + return %0 : tensor<2xf8E4M3FN> +} +func.func private @pt_q(%arg0: tensor<2xbf16>) -> tensor<2xf8E4M3FN> attributes {plan.decomposition} { + %cst = stablehlo.constant dense<-4.480000e+02> : tensor + %cst_0 = stablehlo.constant dense<4.480000e+02> : tensor + %cst_1 = stablehlo.constant dense<5.000000e-01> : tensor + %0 = stablehlo.broadcast_in_dim %cst_1, dims = [] : (tensor) -> tensor<2xbf16> + %1 = stablehlo.divide %arg0, %0 : tensor<2xbf16> + %2 = stablehlo.round_nearest_even %1 : tensor<2xbf16> + %3 = stablehlo.convert %cst_0 : (tensor) -> tensor + %4 = stablehlo.convert %cst : (tensor) -> tensor + %5 = stablehlo.clamp %4, %2, %3 : (tensor, tensor<2xbf16>, tensor) -> tensor<2xbf16> + %6 = stablehlo.convert %5 : (tensor<2xbf16>) -> tensor<2xf8E4M3FN> + return %6 : tensor<2xf8E4M3FN> +} + +// CHECK-LABEL: quantize_pt_bf16_to_fp8_static +// CHECK-NEXT: %[[v0:.+]] = tensorrt.constant dense<[1.000000e+00, 2.000000e+00]> : tensor<2xbf16> +// CHECK-NEXT: %[[v1:.+]] = tensorrt.constant dense<5.000000e-01> : tensor +// CHECK-NEXT: %[[v2:.+]] = tensorrt.quantize in(%[[v0]] : tensor<2xbf16>) scale(%[[v1]] : tensor) -> tensor<2xf8E4M3FN> +// CHECK-NEXT: return %[[v2]] : tensor<2xf8E4M3FN> + +// ----- + +func.func @compare_boolean_inputs(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = stablehlo.constant dense<1> : tensor + %1 = stablehlo.compare LT, %arg0, %0 : (tensor, tensor) -> tensor + %2 = stablehlo.compare LT, %arg1, %0 : (tensor, tensor) -> tensor + %3 = stablehlo.compare NE, %1, %2 : (tensor, tensor) -> tensor + return %3 : tensor +} + +// CHECK-LABEL: @compare_boolean_inputs +// CHECK: %[[v0:.+]] = tensorrt.element_wise +// CHECK-SAME: tensor, tensor) -> tensor +// CHECK: %[[v1:.+]] = tensorrt.element_wise +// CHECK-SAME: tensor, tensor) -> tensor +// CHECK: %[[v2:.+]] = tensorrt.identity %[[v0]] : tensor to tensor +// CHECK: %[[v3:.+]] = tensorrt.identity %[[v1]] : tensor to tensor +// CHECK: tensorrt.element_wise (%[[v2]], %[[v3]] : tensor, tensor) -> tensor +// CHECK: tensorrt.unary {unaryOperation = #tensorrt.unary_operation} diff --git a/mlir-tensorrt/test/Conversion/StablehloToTensorRT/stablehlo-to-tensorrt.mlir b/mlir-tensorrt/test/Conversion/StablehloToTensorRT/stablehlo-to-tensorrt.mlir index 1b2daee0e..c1b8d63ea 100644 --- a/mlir-tensorrt/test/Conversion/StablehloToTensorRT/stablehlo-to-tensorrt.mlir +++ b/mlir-tensorrt/test/Conversion/StablehloToTensorRT/stablehlo-to-tensorrt.mlir @@ -1,4 +1,5 @@ -// RUN: mlir-tensorrt-opt -split-input-file %s --convert-stablehlo-to-tensorrt -allow-unregistered-dialect | FileCheck %s +// RUN: mlir-tensorrt-opt -split-input-file %s --convert-stablehlo-to-tensorrt="trt-major-version=8" -allow-unregistered-dialect | FileCheck %s +// RUN: mlir-tensorrt-opt -split-input-file %s --convert-stablehlo-to-tensorrt="trt-major-version=10" -allow-unregistered-dialect | FileCheck %s func.func @hlo_add_f32_static(%lhs: tensor<128x128xf32>, %rhs: tensor<128x128xf32>) -> tensor<128x128xf32> { %0 = "stablehlo.add"(%lhs, %rhs) : (tensor<128x128xf32>, tensor<128x128xf32>) -> tensor<128x128xf32> @@ -954,30 +955,6 @@ func.func @hlo_argmin(%arg0: tensor<1x10x20xf32>) -> (tensor<1x10xf32>, tensor<1 // ----- -func.func @hlo_iota() -> tensor<128xi32> { - %0 = "stablehlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<128xi32> - return %0 : tensor<128xi32> -} - -// CHECK-LABEL: @hlo_iota -// CHECK: tensorrt.linspace -// CHECK-SAME: [ 0.00{{.+}}] [ static] [ 1.000{{.+}}] : tensor<128xi32> - -// ----- - -func.func @hlo_rem(%arg0: tensor<128xi32>, %arg1: tensor<128xi32>) -> tensor<128xi32> { - %0 = "stablehlo.remainder"(%arg0, %arg1) {} : (tensor<128xi32>, tensor<128xi32>) -> tensor<128xi32> - return %0 : tensor<128xi32> -} - -// CHECK-LABEL: @hlo_rem -// CHECK-SAME: (%[[lhs:.+]]: tensor<128xi32>, %[[rhs:.+]]: tensor<128xi32>) -// CHECK: %[[div:.+]] = tensorrt.element_wise (%[[lhs]], %[[rhs]] : -// CHECK: %[[prod:.+]] = tensorrt.element_wise (%[[div]], %[[rhs]] : -// CHECK: tensorrt.element_wise (%[[lhs]], %[[prod]] : - -// ----- - func.func @hlo_rem_fp(%arg0: tensor<128xf32>, %arg1: tensor<128xf32>) -> tensor<128xf32> { %0 = "stablehlo.remainder"(%arg0, %arg1) {} : (tensor<128xf32>, tensor<128xf32>) -> tensor<128xf32> return %0 : tensor<128xf32> @@ -1132,28 +1109,6 @@ func.func @hlo_pad_static_low_high(%arg0: tensor<10x48x48x32xf32>) -> tensor<10x // ----- -func.func @hlo_pad_dynamic_non_sliced_dim(%arg0: tensor) -> tensor { - %0 = "stablehlo.constant"() {value = dense<0.0> : tensor} : () -> tensor - %1 = "stablehlo.pad"(%arg0, %0) { - edge_padding_high = array, - edge_padding_low = array, - interior_padding = array - } : (tensor, tensor) -> tensor - func.return %1 : tensor -} - -// CHECK-LABEL: @hlo_pad_dynamic_non_sliced_dim -// CHECK-SAME: (%[[arg0:.+]]: tensor -// CHECK: %[[fill:.+]] = tensorrt.constant dense<0.000000e+00> : tensor -// CHECK: %[[shape:.+]] = tensorrt.shape %[[arg0]] : tensor -> tensor<4xi32> -// CHECK: %[[pad_high:.+]] = tensorrt.constant dense<[0, 0, 0, 16]> : tensor<4xi32> -// CHECK: %[[pad_low:.+]] = tensorrt.constant dense<0> : tensor<4xi32> -// CHECK: %[[sum0:.+]] = tensorrt.element_wise (%[[pad_low]], %[[pad_high]] : tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> -// CHECK: %[[sum1:.+]] = tensorrt.element_wise (%[[sum0]], %[[shape]] : tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> -// CHECK: tensorrt.slice %arg0[0, 0, 0, 0][%[[sum1]]: tensor<4xi32>][1, 1, 1, 1] fill(%[[fill]] : tensor) {mode = #tensorrt.slice_mode} : tensor to tensor - -// ----- - func.func @hlo_pad_interior_unsupported(%arg0: tensor<1x2x3xf16>, %arg1: tensor) -> tensor<2x4x7xf16> { %0 = "stablehlo.pad"(%arg0, %arg1) { edge_padding_high = array, @@ -1416,25 +1371,6 @@ func.func @hlo_clamp_i32(%lb : tensor<4xi32>, %x : tensor<4xi32>, %ub : tensor<4 // ----- -func.func @hlo_dynamic_pad(%arg0: tensor, %arg1: tensor, %arg2: tensor<2xi32>, %arg3: tensor<2xi32>) -> tensor { - %cst_0 = stablehlo.constant dense<[0, 0]> : tensor<2xi32> - %0 = "stablehlo.dynamic_pad"(%arg0, %arg1, %arg2, %arg3, %cst_0) : (tensor, tensor, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor - return %0: tensor -} - -// CHECK-LABEL: @hlo_dynamic_pad -// CHECK-SAME: (%[[arg0:.+]]: tensor, %[[arg1:.+]]: tensor, %[[arg2:.+]]: tensor<2xi32>, %[[arg3:.+]]: tensor<2xi32>) -> tensor { -// CHECK: %[[pad_interior:.+]] = tensorrt.constant dense<0> : tensor<2xi32> -// CHECK: %[[padding_low_f32:.+]] = tensorrt.identity %arg2 : tensor<2xi32> to tensor<2xf32> -// CHECK: %[[sliceOffset_f32:.+]] = tensorrt.unary {unaryOperation = #tensorrt.unary_operation} %[[padding_low_f32]] : tensor<2xf32> -// CHECK: %[[sliceOffset_i32:.+]] = tensorrt.identity %[[sliceOffset_f32]] : tensor<2xf32> to tensor<2xi32> -// CHECK: %[[shape:.+]] = tensorrt.shape %[[arg0]] : tensor -> tensor<2xi32> -// CHECK: %[[sum0:.+]] = tensorrt.element_wise (%[[arg2]], %[[arg3]] : tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK: %[[sum1:.+]] = tensorrt.element_wise (%[[sum0]], %[[shape]] : tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK: tensorrt.slice %arg0[%[[sliceOffset_i32]]: tensor<2xi32>][%[[sum1]]: tensor<2xi32>][1, 1] fill(%arg1 : tensor) {mode = #tensorrt.slice_mode} : tensor to tensor - -// ----- - func.func @hlo_dynamic_pad_interior_unsupported(%arg0: tensor, %arg1: tensor, %arg2: tensor<2xi32>, %arg3: tensor<2xi32>, %arg4: tensor<2xi32>) -> tensor { %0 = "stablehlo.dynamic_pad"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor, tensor, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor return %0: tensor @@ -1481,45 +1417,6 @@ func.func @hlo_round_nearest_even(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // ----- -func.func @hlo_dynamic_iota_0(%arg0 : tensor<1xi32>) -> tensor { - %0 = "stablehlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<1xi32>) -> tensor - return %0 : tensor -} - -// CHECK-LABEL: @hlo_dynamic_iota_0 -// CHECK: tensorrt.linspace -// CHECK-SAME: [ 0.00{{.+}}] [%arg0 : tensor<1xi32>] [ 1.000{{.+}}] : tensor - -// ----- - -func.func @dynamic_nd_iota_1(%arg0 : tensor<2xi32>) -> tensor { - %0 = "stablehlo.dynamic_iota"(%arg0) {iota_dimension = 1 : i64} : (tensor<2xi32>) -> tensor - return %0 : tensor -} - -// CHECK-LABEL: func.func @dynamic_nd_iota_1 -// CHECK-SAME: (%[[arg0:.+]]: tensor<2xi32>) -> tensor { -// CHECK: %[[cst_i32:.+]] = tensorrt.constant dense<0> : tensor -// CHECK: %[[cst_i32_0:.+]] = tensorrt.constant dense<[0, 1]> : tensor<2xi32> -// CHECK: %[[v0:.+]] = tensorrt.linspace[%[[cst_i32]] : tensor] [%[[arg0]] : tensor<2xi32>] [%[[cst_i32_0]] : tensor<2xi32>] : tensor -// CHECK: return %[[v0]] : tensor - -// ----- - -func.func @dynamic_nd_iota_2(%arg0 : tensor<2xi32>) -> tensor { - %0 = "stablehlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<2xi32>) -> tensor - return %0 : tensor -} - -// CHECK-LABEL: func.func @dynamic_nd_iota_2 -// CHECK-SAME: (%[[arg0:.+]]: tensor<2xi32>) -> tensor { -// CHECK: %[[cst_i32:.+]] = tensorrt.constant dense<0> : tensor -// CHECK: %[[cst_i32_0:.+]] = tensorrt.constant dense<[1, 0]> : tensor<2xi32> -// CHECK: %[[v0:.+]] = tensorrt.linspace[%[[cst_i32]] : tensor] [%[[arg0]] : tensor<2xi32>] [%[[cst_i32_0]] : tensor<2xi32>] : tensor -// CHECK: return %[[v0]] : tensor - -// ----- - func.func @stablehlo_broadcast(%arg0: tensor<8xf32>) -> tensor<4x8xf32> { %0 = "stablehlo.broadcast"(%arg0) { broadcast_sizes = array @@ -1534,27 +1431,6 @@ func.func @stablehlo_broadcast(%arg0: tensor<8xf32>) -> tensor<4x8xf32> { // ----- -func.func @stablehlo_real_dynamic_slice( - %input: tensor, - %start_indices: tensor<2xindex>, - %limit_indices: tensor<2xindex>, - %strides: tensor<2xindex>) -> tensor { - %0 = "stablehlo.real_dynamic_slice"(%input, %start_indices, %limit_indices, %strides) : (tensor, tensor<2xindex>, tensor<2xindex>, tensor<2xindex>) -> tensor - func.return %0 : tensor -} - -// CHECK-LABEL: @stablehlo_real_dynamic_slice( -// CHECK-SAME: %[[arg0:.+]]: tensor, %[[arg1:.+]]: tensor<2xi32>, %[[arg2:.+]]: tensor<2xi32>, %[[arg3]]: tensor<2xi32>) -> tensor { -// CHECK: %[[num:.+]] = tensorrt.element_wise (%[[arg2]], %[[arg1]] -// CHECK: %[[cst_i32:.+]] = tensorrt.constant dense<1> : tensor<2xi32> -// CHECK: %[[bias:.+]] = tensorrt.element_wise (%[[arg3]], %[[cst_i32]] : tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK: %[[num1:.+]] = tensorrt.element_wise (%[[num]], %[[bias]] : tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK: %[[ceilDiv:.+]] = tensorrt.element_wise (%[[num1]], %[[arg3]] : tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK: %[[result:.+]] = tensorrt.slice %[[arg0]][%[[arg1]]: tensor<2xi32>][%[[ceilDiv]]: tensor<2xi32>][%[[arg3]]: tensor<2xi32>] -// CHECK: return %[[result]] : tensor - -// ----- - func.func @hlo_log1p(%arg0: tensor<10x20x30xf32>) -> tensor<10x20x30xf32> { %0 = "stablehlo.log_plus_one"(%arg0) : (tensor<10x20x30xf32>) -> tensor<10x20x30xf32> return %0 : tensor<10x20x30xf32> @@ -1817,76 +1693,6 @@ func.func @disregard_non_tensor_funcs(%arg0: i32) -> i32 { // ----- -func.func @dynamic_update_slice_conversion_1(%arg0: tensor<1x6x12x64xf32>) -> tensor<1x20x12x64xf32> { - %0 = stablehlo.constant dense<0> : tensor - %1 = stablehlo.constant dense<0.000000e+00> : tensor<1x20x12x64xf32> - %2 = stablehlo.dynamic_update_slice %1, %arg0, %0, %0, %0, %0 : (tensor<1x20x12x64xf32>, tensor<1x6x12x64xf32>, tensor, tensor, tensor, tensor) -> tensor<1x20x12x64xf32> - return %2 : tensor<1x20x12x64xf32> -} - -// CHECK-LABEL: @dynamic_update_slice_conversion_1 -// CHECK-SAME: (%[[arg0:.+]]: tensor<1x6x12x64xf32>) -// CHECK: %[[cst_i32:.+]] = tensorrt.constant dense<0> -// CHECK: %[[cst_f32:.+]] = tensorrt.constant dense<0.000000e+00> -// CHECK: %[[v0:.+]] = tensorrt.expand_rank %[[cst_i32]] : tensor to tensor<1xi32> -// CHECK: %[[cst_i32_0:.+]] = tensorrt.constant dense<1> -// CHECK: %[[cst_i32_1:.+]] = tensorrt.constant dense<[12, 64]> -// CHECK: %[[v1:.+]] = tensorrt.concatenation {axis = 0 : i32} ins(%[[cst_i32_0]], %[[v0]], %[[cst_i32_1]] : -// CHECK: %[[v2:.+]] = tensorrt.slice %[[cst_f32]][0, 0, 0, 0][%[[v1]]: tensor<4xi32>][1, 1, 1, 1] -// CHECK: %[[cst_i32_2:.+]] = tensorrt.constant dense<6> -// CHECK: %[[v3:.+]] = tensorrt.element_wise (%[[v0]], %[[cst_i32_2]] : -// CHECK: %[[cst_i32_3:.+]] = tensorrt.constant dense<0> : tensor<1xi32> -// CHECK: %[[cst_i32_4:.+]] = tensorrt.constant dense<0> : tensor<2xi32> -// CHECK: %[[v4:.+]] = tensorrt.concatenation {axis = 0 : i32} ins(%[[cst_i32_3]], %[[v3]], %[[cst_i32_4]] -// CHECK: %[[cst_i32_5:.+]] = tensorrt.constant dense<20> -// CHECK: %[[v5:.+]] = tensorrt.element_wise (%[[cst_i32_5]], %[[v3]] : -// CHECK: %[[cst_i32_6:.+]] = tensorrt.constant dense<1> -// CHECK: %[[cst_i32_7:.+]] = tensorrt.constant dense<[12, 64]> -// CHECK: %[[v6:.+]] = tensorrt.concatenation {axis = 0 : i32} ins(%[[cst_i32_6]], %[[v5]], %[[cst_i32_7]] : -// CHECK: %[[v7:.+]] = tensorrt.slice %[[cst_f32]][%[[v4]]: tensor<4xi32>][%[[v6]]: tensor<4xi32>][1, 1, 1, 1] -// CHECK: %[[v8:.+]] = tensorrt.concatenation {axis = 1 : i32} ins(%[[v2]], %[[arg0]], %[[v7]] : -// CHECK: return %[[v8]] - -// ----- - -func.func @dynamic_update_slice_conversion_2(%arg0: tensor<1x20x12x64xf32>, %arg1: tensor<1x1x12x64xf32>, %arg2: tensor) -> tensor<1x20x12x64xf32> { - %0 = stablehlo.constant dense<0> : tensor - %1 = stablehlo.dynamic_update_slice %arg0, %arg1, %0, %arg2, %0, %0 : (tensor<1x20x12x64xf32>, tensor<1x1x12x64xf32>, tensor, tensor, tensor, tensor) -> tensor<1x20x12x64xf32> - return %1 : tensor<1x20x12x64xf32> -} - -// CHECK-LABEL: @dynamic_update_slice_conversion_2 -// CHECK-SAME: (%[[arg0:.+]]: tensor<1x20x12x64xf32>, %[[arg1:.+]]: tensor<1x1x12x64xf32>, %[[arg2:.+]]: tensor) -// CHECK: %[[cst_i32:.+]] = tensorrt.constant dense<0> : tensor -// CHECK: %[[v0:.+]] = tensorrt.expand_rank %[[arg2]] : tensor to tensor<1xi32> -// CHECK: %[[cst_i32_0:.+]] = tensorrt.constant dense<1> : tensor<1xi32> -// CHECK: %[[cst_i32_1:.+]] = tensorrt.constant dense<[12, 64]> : tensor<2xi32> -// CHECK: %[[v1:.+]] = tensorrt.concatenation -// CHECK-SAME: axis = 0 : i32 -// CHECK-SAME: ins(%[[cst_i32_0]], %[[v0]], %[[cst_i32_1]] : tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -// CHECK: %[[v2:.+]] = tensorrt.slice %[[arg0]][0, 0, 0, 0][%[[v1]]: tensor<4xi32>][1, 1, 1, 1] : tensor<1x20x12x64xf32> to tensor<1x?x12x64xf32> -// CHECK: %[[cst_i32_2:.+]] = tensorrt.constant dense<1> : tensor<1xi32> -// CHECK: %[[v3:.+]] = tensorrt.element_wise (%[[v0]], %[[cst_i32_2]] : tensor<1xi32>, tensor<1xi32>) -// CHECK: %[[cst_i32_3:.+]] = tensorrt.constant dense<0> : tensor<1xi32> -// CHECK: %[[cst_i32_4:.+]] = tensorrt.constant dense<0> : tensor<2xi32> -// CHECK: %[[v4:.+]] = tensorrt.concatenation -// CHECK-SAME: axis = 0 : i32 -// CHECK-SAME: ins(%[[cst_i32_3]], %[[v3]], %[[cst_i32_4]] : tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -// CHECK: %[[cst_i32_5:.+]] = tensorrt.constant dense<20> : tensor<1xi32> -// CHECK: %[[v5:.+]] = tensorrt.element_wise (%[[cst_i32_5]], %[[v3]] : tensor<1xi32>, tensor<1xi32>) -// CHECK: %[[cst_i32_6:.+]] = tensorrt.constant dense<1> : tensor<1xi32> -// CHECK: %[[cst_i32_7:.+]] = tensorrt.constant dense<[12, 64]> : tensor<2xi32> -// CHECK: %[[v6:.+]] = tensorrt.concatenation -// CHECK-SAME: axis = 0 : i32 -// CHECK-SAME: ins(%[[cst_i32_6]], %[[v5]], %[[cst_i32_7]] : tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -// CHECK: %[[v7:.+]] = tensorrt.slice %[[arg0]][%[[v4]]: tensor<4xi32>][%[[v6]]: tensor<4xi32>][1, 1, 1, 1] : tensor<1x20x12x64xf32> to tensor<1x?x12x64xf32> -// CHECK: %[[v8:.+]] = tensorrt.concatenation -// CHECK-SAME: axis = 1 : i32 -// CHECK-SAME: ins(%[[v2]], %[[arg1]], %[[v7]] : tensor<1x?x12x64xf32>, tensor<1x1x12x64xf32>, tensor<1x?x12x64xf32>) -// CHECK: return %[[v8]] - -// ----- - func.func @dynamic_update_slice_conversion_unsupported1(%arg0: tensor<1x20x12x64xf32>, %arg1: tensor<1x1x1x64xf32>, %arg2: tensor) -> tensor<1x20x12x64xf32> { %0 = stablehlo.constant dense<0> : tensor %1 = stablehlo.dynamic_update_slice %arg0, %arg1, %0, %arg2, %arg2, %0 : (tensor<1x20x12x64xf32>, tensor<1x1x1x64xf32>, tensor, tensor, tensor, tensor) -> tensor<1x20x12x64xf32> @@ -1909,88 +1715,6 @@ func.func @dynamic_update_slice_conversion_unsupported2(%arg0: tensor<1x20x12x64 // ----- -func.func @scatter_slice_update(%arg0: tensor<1x134xi32>, %arg1: tensor<1x2xi32>, %arg2: tensor<1x1x5xi32>) -> tensor<1x134xi32> { - %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ({ - ^bb0(%arg3: tensor, %arg4: tensor): - stablehlo.return %arg4 : tensor - }) {indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false} : (tensor<1x134xi32>, tensor<1x2xi32>, tensor<1x1x5xi32>) -> tensor<1x134xi32> - return %0 : tensor<1x134xi32> -} - -// CHECK-LABEL: @scatter_slice_update -// CHECK: %[[v0:.+]] = tensorrt.slice %arg1[0, 1][1, 1][1, 1] : tensor<1x2xi32> to tensor<1x1xi32> -// CHECK: %[[v1:.+]] = tensorrt.collapse_rank %[[v0]] : tensor<1x1xi32> to tensor -// CHECK: %[[v2:.+]] = tensorrt.collapse_rank %arg2 : tensor<1x1x5xi32> to tensor<1x5xi32> -// CHECK: %[[v3:.+]] = tensorrt.constant dense<1> : tensor<2xi32> -// CHECK: %[[v4:.+]] = tensorrt.linspace[%[[v1]] : tensor] [ static] [%[[v3]] : tensor<2xi32>] : tensor<1x5xi32> -// CHECK: %[[v5:.+]] = tensorrt.scatter_elements {axis = 1 : i64} data(%arg0 : tensor<1x134xi32>) indices(%[[v4]] : tensor<1x5xi32>) updates(%[[v2]] : tensor<1x5xi32>) -// CHECK: return %[[v5]] : tensor<1x134xi32> - -// ----- - -func.func @scatter_slice_update_f16_axis1(%arg0: tensor<1x134xf16>, %arg1: tensor<1x2xi32>, %arg2: tensor<1x1x5xf16>) -> tensor<1x134xf16> { - %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ({ - ^bb0(%arg3: tensor, %arg4: tensor): - stablehlo.return %arg4 : tensor - }) {indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false} : (tensor<1x134xf16>, tensor<1x2xi32>, tensor<1x1x5xf16>) -> tensor<1x134xf16> - return %0 : tensor<1x134xf16> -} - -// CHECK-LABEL: @scatter_slice_update_f16 -// CHECK: %[[v0:.+]] = tensorrt.slice %arg1[0, 1][1, 1][1, 1] : tensor<1x2xi32> to tensor<1x1xi32> -// CHECK: %[[v1:.+]] = tensorrt.collapse_rank %[[v0]] : tensor<1x1xi32> to tensor -// CHECK: %[[v2:.+]] = tensorrt.collapse_rank %arg2 : tensor<1x1x5xf16> to tensor<1x5xf16> -// CHECK: %[[v3:.+]] = tensorrt.constant dense<1> : tensor<2xi32> -// CHECK: %[[v4:.+]] = tensorrt.linspace[%[[v1]] : tensor] [ static] [%[[v3]] : tensor<2xi32>] : tensor<1x5xi32> -// CHECK: %[[v5:.+]] = tensorrt.scatter_elements {axis = 1 : i64} data(%arg0 : tensor<1x134xf16>) indices(%[[v4]] : tensor<1x5xi32>) updates(%[[v2]] : tensor<1x5xf16>) -// CHECK: return %[[v5]] : tensor<1x134xf16> - -// ----- - -func.func @scatter_slice_update_i1_axis1(%arg0: tensor<1x134xi1>, %arg1: tensor<1x2xi32>, %arg2: tensor<1x1x5xi1>) -> tensor<1x134xi1> { - %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ({ - ^bb0(%arg3: tensor, %arg4: tensor): - stablehlo.return %arg4 : tensor - }) {indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false} : (tensor<1x134xi1>, tensor<1x2xi32>, tensor<1x1x5xi1>) -> tensor<1x134xi1> - return %0 : tensor<1x134xi1> -} - -// CHECK-LABEL: @scatter_slice_update_i1 -// CHECK: %[[v0:.+]] = tensorrt.slice %arg1[0, 1][1, 1][1, 1] : tensor<1x2xi32> to tensor<1x1xi32> -// CHECK: %[[v1:.+]] = tensorrt.collapse_rank %[[v0]] : tensor<1x1xi32> to tensor -// CHECK: %[[v2:.+]] = tensorrt.collapse_rank %arg2 : tensor<1x1x5xi1> to tensor<1x5xi1> -// CHECK: %[[v3:.+]] = tensorrt.constant dense<1> : tensor<2xi32> -// CHECK: %[[v4:.+]] = tensorrt.linspace[%[[v1]] : tensor] [ static] [%[[v3]] : tensor<2xi32>] : tensor<1x5xi32> -// CHECK: %[[v5:.+]] = tensorrt.identity %arg0 : tensor<1x134xi1> to tensor<1x134xi32> -// CHECK: %[[v6:.+]] = tensorrt.identity %[[v2]] : tensor<1x5xi1> to tensor<1x5xi32> -// CHECK: %[[v7:.+]] = tensorrt.scatter_elements {axis = 1 : i64} data(%[[v5]] : tensor<1x134xi32>) indices(%[[v4]] : tensor<1x5xi32>) updates(%[[v6]] : tensor<1x5xi32>) -// CHECK: %[[v8:.+]] = tensorrt.identity %[[v7]] : tensor<1x134xi32> to tensor<1x134xi1> -// CHECK: return %[[v8]] : tensor<1x134xi1> - -// ----- - -func.func @scatter_slice_update_i1_axis0(%arg0: tensor<1024x1xi1>, %arg1: tensor<1x1xi32>, %arg2: tensor<1x134x1xi1>) -> tensor<1024x1xi1> { - %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ - ^bb0(%arg3: tensor, %arg4: tensor): - stablehlo.return %arg4 : tensor - }) : (tensor<1024x1xi1>, tensor<1x1xi32>, tensor<1x134x1xi1>) -> tensor<1024x1xi1> - return %0 : tensor<1024x1xi1> -} - -// CHECK-LABEL: @scatter_slice_update_i1_axis0 -// CHECK: %[[v0:.+]] = tensorrt.slice %arg1[0, 0][1, 1][1, 1] : tensor<1x1xi32> to tensor<1x1xi32> -// CHECK: %[[v1:.+]] = tensorrt.collapse_rank %[[v0]] : tensor<1x1xi32> to tensor -// CHECK: %[[v2:.+]] = tensorrt.collapse_rank %arg2 : tensor<1x134x1xi1> to tensor<134x1xi1> -// CHECK: %[[v3:.+]] = tensorrt.constant dense<1> : tensor<2xi32> -// CHECK: %[[v4:.+]] = tensorrt.linspace[%[[v1]] : tensor] [ static] [%[[v3]] : tensor<2xi32>] : tensor<134x1xi32> -// CHECK: %[[v5:.+]] = tensorrt.identity %arg0 : tensor<1024x1xi1> to tensor<1024x1xi32> -// CHECK: %[[v6:.+]] = tensorrt.identity %[[v2]] : tensor<134x1xi1> to tensor<134x1xi32> -// CHECK: %[[v7:.+]] = tensorrt.scatter_elements {axis = 0 : i64} data(%[[v5]] : tensor<1024x1xi32>) indices(%[[v4]] : tensor<134x1xi32>) updates(%[[v6]] : tensor<134x1xi32>) -// CHECK: %[[v8:.+]] = tensorrt.identity %[[v7]] : tensor<1024x1xi32> to tensor<1024x1xi1> -// CHECK: return %[[v8]] : tensor<1024x1xi1> - -// ----- - func.func @quantize_pt_to_i8_static(%arg0: tensor<2x3x300x300xf32>) -> tensor<2x3x300x300xi8> { %0 = stablehlo.composite "tensorrt.pt_q" %arg0 {composite_attributes = {axis = -1 : i32, scale = dense<8.000000e-01> : tensor}, decomposition = @pt_q} : (tensor<2x3x300x300xf32>) -> tensor<2x3x300x300xi8> return %0 : tensor<2x3x300x300xi8> @@ -2073,75 +1797,6 @@ func.func private @pc_q(%arg0: tensor<258x256xf32>) -> tensor<258x256xi8> attrib // ----- -func.func @large_weight() -> tensor<258x256xf32> { - %c = stablehlo.constant dense_resource<__elided__> : tensor<258x256xi4> - %0 = stablehlo.composite "tensorrt.block_dq" %c {composite_attributes = {axis = -1 : i32, scale = dense_resource<__elided__> : tensor<2x256xf32>}, decomposition = @block_dq} : (tensor<258x256xi4>) -> tensor<258x256xf32> - return %0 : tensor<258x256xf32> -} -func.func private @block_dq(%arg0: tensor<258x256xi4>) -> tensor<258x256xf32> attributes {plan.decomposition} { - %cst = stablehlo.constant dense_resource<__elided__> : tensor<2x256xf32> - %0 = stablehlo.broadcast_in_dim %cst, dims = [1, 2] : (tensor<2x256xf32>) -> tensor<129x2x256xf32> - %1 = stablehlo.reshape %0 : (tensor<129x2x256xf32>) -> tensor<258x256xf32> - %2 = stablehlo.convert %arg0 : (tensor<258x256xi4>) -> tensor<258x256xf32> - %3 = stablehlo.multiply %2, %1 : tensor<258x256xf32> - return %3 : tensor<258x256xf32> -} - -// CHECK-LABEL: large_weight -// CHECK-NEXT: %[[v0:.+]] = tensorrt.constant dense_resource<__elided__> : tensor<258x256xi4> -// CHECK-NEXT: %[[v1:.+]] = tensorrt.constant dense_resource<__elided__> : tensor<2x256xf32> -// CHECK-NEXT: %[[v2:.+]] = tensorrt.dequantize in(%[[v0]] : tensor<258x256xi4>) scale(%[[v1]] : tensor<2x256xf32>) -> tensor<258x256xf32> -// CHECK-NEXT: return %[[v2]] : tensor<258x256xf32> - -// ----- - -func.func @quantize_pt_bf16_to_fp8_static() -> tensor<2xf8E4M3FN> { - %cst = stablehlo.constant dense<[1.000000e+00, 2.000000e+00]> : tensor<2xbf16> - %0 = stablehlo.composite "tensorrt.pt_q" %cst {composite_attributes = {axis = -1 : i32, scale = dense<5.000000e-01> : tensor}, decomposition = @pt_q} : (tensor<2xbf16>) -> tensor<2xf8E4M3FN> - return %0 : tensor<2xf8E4M3FN> -} -func.func private @pt_q(%arg0: tensor<2xbf16>) -> tensor<2xf8E4M3FN> attributes {plan.decomposition} { - %cst = stablehlo.constant dense<-4.480000e+02> : tensor - %cst_0 = stablehlo.constant dense<4.480000e+02> : tensor - %cst_1 = stablehlo.constant dense<5.000000e-01> : tensor - %0 = stablehlo.broadcast_in_dim %cst_1, dims = [] : (tensor) -> tensor<2xbf16> - %1 = stablehlo.divide %arg0, %0 : tensor<2xbf16> - %2 = stablehlo.round_nearest_even %1 : tensor<2xbf16> - %3 = stablehlo.convert %cst_0 : (tensor) -> tensor - %4 = stablehlo.convert %cst : (tensor) -> tensor - %5 = stablehlo.clamp %4, %2, %3 : (tensor, tensor<2xbf16>, tensor) -> tensor<2xbf16> - %6 = stablehlo.convert %5 : (tensor<2xbf16>) -> tensor<2xf8E4M3FN> - return %6 : tensor<2xf8E4M3FN> -} - -// CHECK-LABEL: quantize_pt_bf16_to_fp8_static -// CHECK-NEXT: %[[v0:.+]] = tensorrt.constant dense<[1.000000e+00, 2.000000e+00]> : tensor<2xbf16> -// CHECK-NEXT: %[[v1:.+]] = tensorrt.constant dense<5.000000e-01> : tensor -// CHECK-NEXT: %[[v2:.+]] = tensorrt.quantize in(%[[v0]] : tensor<2xbf16>) scale(%[[v1]] : tensor) -> tensor<2xf8E4M3FN> -// CHECK-NEXT: return %[[v2]] : tensor<2xf8E4M3FN> - -// ----- - -func.func @compare_boolean_inputs(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = stablehlo.constant dense<1> : tensor - %1 = stablehlo.compare LT, %arg0, %0 : (tensor, tensor) -> tensor - %2 = stablehlo.compare LT, %arg1, %0 : (tensor, tensor) -> tensor - %3 = stablehlo.compare NE, %1, %2 : (tensor, tensor) -> tensor - return %3 : tensor -} - -// CHECK-LABEL: @compare_boolean_inputs -// CHECK: %[[v0:.+]] = tensorrt.element_wise -// CHECK-SAME: tensor, tensor) -> tensor -// CHECK: %[[v1:.+]] = tensorrt.element_wise -// CHECK-SAME: tensor, tensor) -> tensor -// CHECK: %[[v2:.+]] = tensorrt.identity %[[v0]] : tensor to tensor -// CHECK: %[[v3:.+]] = tensorrt.identity %[[v1]] : tensor to tensor -// CHECK: tensorrt.element_wise (%[[v2]], %[[v3]] : tensor, tensor) -> tensor -// CHECK: tensorrt.unary {unaryOperation = #tensorrt.unary_operation} - -// ----- - func.func @jnp_cumsum_2d_i32(%arg0: tensor<1x134xi32>) -> tensor<1x134xi32> { %cst = arith.constant dense<0> : tensor %4 = "stablehlo.reduce_window"(%arg0, %cst) <{base_dilations = array, padding = dense<[[0, 0], [133, 0]]> : tensor<2x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array}> ({ @@ -2272,4 +1927,4 @@ func.func @jnp_cumsum_2d_f16(%arg0: tensor<1x134xf16>) -> tensor<1x134xf16> { // CHECK-SAME: post_padding = array // CHECK-SAME: pre_padding = array // CHECK-SAME: in(%[[v1]] : tensor<1x1x1x134xf16>) kernel(%[[v2]] : tensor<1x1x1x134xf16>) -> tensor<1x1x1x134xf16> -// CHECK: %[[v4:.+]] = tensorrt.reshape %[[v3]] : tensor<1x1x1x134xf16> to tensor<1x134xf16> +// CHECK: %[[v4:.+]] = tensorrt.reshape %[[v3]] : tensor<1x1x1x134xf16> to tensor<1x134xf16> \ No newline at end of file diff --git a/mlir-tensorrt/test/Dialect/Plan/segmentation-pipeline.mlir b/mlir-tensorrt/test/Dialect/Plan/segmentation-pipeline.mlir index 216c20b65..ee6c825ac 100644 --- a/mlir-tensorrt/test/Dialect/Plan/segmentation-pipeline.mlir +++ b/mlir-tensorrt/test/Dialect/Plan/segmentation-pipeline.mlir @@ -1,5 +1,5 @@ // RUN: mlir-tensorrt-opt -split-input-file \ -// RUN: -plan-segmentation-pipeline -cse -verify-diagnostics %s | FileCheck %s +// RUN: -plan-segmentation-pipeline=trt-major-version=10 -cse -verify-diagnostics %s | FileCheck %s builtin.module attributes { plan.cluster_kinds = [ diff --git a/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_dynamic_iota.py b/mlir-tensorrt/test/python/IntegrationTests/TRT10/test_stablehlo_dynamic_iota.py similarity index 100% rename from mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_dynamic_iota.py rename to mlir-tensorrt/test/python/IntegrationTests/TRT10/test_stablehlo_dynamic_iota.py From 4f8fd901657b9e1b734813eaa99ba8c0e1944ce3 Mon Sep 17 00:00:00 2001 From: Parth Chadha Date: Wed, 20 Nov 2024 13:57:27 -0800 Subject: [PATCH 24/29] Update tripy version to 0.0.4 (#397) --- tripy/docs/packages.html | 25 ++++++++++++++++++++++++- tripy/pyproject.toml | 2 +- tripy/tripy/__init__.py | 2 +- 3 files changed, 26 insertions(+), 3 deletions(-) diff --git a/tripy/docs/packages.html b/tripy/docs/packages.html index 7cbef2ff4..28063782d 100644 --- a/tripy/docs/packages.html +++ b/tripy/docs/packages.html @@ -9,6 +9,9 @@

Package Index

+ tripy-0.0.4-py3-none-any.whl
+ tripy-0.0.3-py3-none-any.whl
@@ -102,6 +105,26 @@

Package Index

href="https://github.com/NVIDIA/TensorRT-Incubator/releases/download/mlir-tensorrt-v0.1.36/mlir_tensorrt_runtime-0.1.36+cuda12.trt102-cp312-cp312-linux_x86_64.whl">mlir_tensorrt_runtime-0.1.36+cuda12.trt102-cp312-cp312-linux_x86_64.whl
mlir_tensorrt_runtime-0.1.36+cuda12.trt102-cp39-cp39-linux_x86_64.whl
- + + + mlir_tensorrt_compiler-0.1.37+cuda12.trt102-cp310-cp310-linux_x86_64.whl
+ mlir_tensorrt_compiler-0.1.37+cuda12.trt102-cp311-cp311-linux_x86_64.whl
+ mlir_tensorrt_compiler-0.1.37+cuda12.trt102-cp312-cp312-linux_x86_64.whl
+ mlir_tensorrt_compiler-0.1.37+cuda12.trt102-cp39-cp39-linux_x86_64.whl
+ mlir_tensorrt_runtime-0.1.37+cuda12.trt102-cp310-cp310-linux_x86_64.whl
+ mlir_tensorrt_runtime-0.1.37+cuda12.trt102-cp311-cp311-linux_x86_64.whl
+ mlir_tensorrt_runtime-0.1.37+cuda12.trt102-cp312-cp312-linux_x86_64.whl
+ mlir_tensorrt_runtime-0.1.37+cuda12.trt102-cp39-cp39-linux_x86_64.whl
+ + + diff --git a/tripy/pyproject.toml b/tripy/pyproject.toml index 43f6bf6ee..a22dc06dc 100644 --- a/tripy/pyproject.toml +++ b/tripy/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "tripy" -version = "0.0.3" +version = "0.0.4" authors = [{name = "NVIDIA", email="svc_tensorrt@nvidia.com"}] description = "Tripy: A Python Programming Model For TensorRT" readme = "README.md" diff --git a/tripy/tripy/__init__.py b/tripy/tripy/__init__.py index 275e88820..642f0ec16 100644 --- a/tripy/tripy/__init__.py +++ b/tripy/tripy/__init__.py @@ -15,7 +15,7 @@ # limitations under the License. # -__version__ = "0.0.3" +__version__ = "0.0.4" # Import TensorRT to make sure all dependent libraries are loaded first. import tensorrt From 259ebf34e140f4563da23f06f408b09304e3eb98 Mon Sep 17 00:00:00 2001 From: Parth Chadha Date: Wed, 20 Nov 2024 16:28:07 -0800 Subject: [PATCH 25/29] Add compile fixture to test integration ops with compile mode (#387) --- tripy/tests/integration/conftest.py | 61 ++++++++++++++++++ tripy/tests/integration/test_batchnorm.py | 4 +- tripy/tests/integration/test_cast.py | 27 ++++---- tripy/tests/integration/test_concatenate.py | 8 +-- tripy/tests/integration/test_conv.py | 16 ++--- .../tests/integration/test_conv_transpose.py | 24 +++---- tripy/tests/integration/test_cumsum.py | 5 +- tripy/tests/integration/test_dequantize.py | 16 +++-- tripy/tests/integration/test_expand.py | 16 ++--- tripy/tests/integration/test_flatten.py | 16 ++--- tripy/tests/integration/test_flip.py | 18 +++--- tripy/tests/integration/test_full.py | 16 ++--- tripy/tests/integration/test_gather.py | 6 +- tripy/tests/integration/test_groupnorm.py | 4 +- tripy/tests/integration/test_iota.py | 24 +++---- tripy/tests/integration/test_layernorm.py | 4 +- tripy/tests/integration/test_linear.py | 14 ++-- .../integration/test_matrix_multiplication.py | 16 +++-- tripy/tests/integration/test_outer.py | 8 +-- tripy/tests/integration/test_pad.py | 8 +-- tripy/tests/integration/test_pooling.py | 12 ++-- tripy/tests/integration/test_quantize.py | 64 ++++++++++++++----- tripy/tests/integration/test_reduce.py | 24 +++---- tripy/tests/integration/test_repeat.py | 8 +-- tripy/tests/integration/test_reshape.py | 12 ++-- tripy/tests/integration/test_resize.py | 16 +++-- tripy/tests/integration/test_sequential.py | 12 ++-- tripy/tests/integration/test_slice.py | 18 ++++-- tripy/tests/integration/test_split.py | 11 +++- tripy/tests/integration/test_stack.py | 8 +-- .../integration/test_unary_elementwise.py | 4 +- tripy/tests/integration/test_unsqueeze.py | 4 +- tripy/tests/integration/test_where_op.py | 8 +-- 33 files changed, 317 insertions(+), 195 deletions(-) create mode 100644 tripy/tests/integration/conftest.py diff --git a/tripy/tests/integration/conftest.py b/tripy/tests/integration/conftest.py new file mode 100644 index 000000000..1229219f2 --- /dev/null +++ b/tripy/tests/integration/conftest.py @@ -0,0 +1,61 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2024-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest + +import tripy as tp + + +@pytest.fixture(params=["compile", "eager"]) +def eager_or_compiled(request): + def wrapper(func, *args, **kwargs): + def get_input_info(x: tp.Tensor): + return tp.InputInfo(list(map(int, x.shape)), dtype=x.dtype) + + if request.param == "eager": + return func(*args, **kwargs) + + assert request.param == "compile" + + compile_args = [] + for arg in args: + # We don't want to feed DimensionSize as a dynamic input to the compiler (https://github.com/NVIDIA/TensorRT-Incubator/issues/65). + if isinstance(arg, tp.Tensor) and not isinstance(arg, tp.DimensionSize): + compile_args.append(get_input_info(arg)) + else: + compile_args.append(arg) + compile_args = tuple(compile_args) + + compile_kwargs = dict( + ( + k, + ((get_input_info(v) if isinstance(v, tp.Tensor) and not isinstance(v, tp.DimensionSize) else v)), + ) + for k, v in kwargs.items() + ) + + compiled_func = tp.compile(func, args=compile_args, kwargs=compile_kwargs) + + tensor_args = tuple(x for x in args if isinstance(x, tp.Tensor) and not isinstance(x, tp.DimensionSize)) + + tensor_kwargs = { + k: v for k, v in kwargs.items() if isinstance(v, tp.Tensor) and not isinstance(v, tp.DimensionSize) + } + + return compiled_func(*tensor_args, **tensor_kwargs) + + return wrapper diff --git a/tripy/tests/integration/test_batchnorm.py b/tripy/tests/integration/test_batchnorm.py index 37f6cbf82..89a4c6715 100644 --- a/tripy/tests/integration/test_batchnorm.py +++ b/tripy/tests/integration/test_batchnorm.py @@ -26,7 +26,7 @@ class TestBatchNorm: @pytest.mark.parametrize("torch_dtype, tp_dtype", DTYPES) @pytest.mark.parametrize("input_shape", [(2, 2, 2, 2)]) - def test_batchnorm_accuracy(self, torch_dtype, tp_dtype, input_shape): + def test_batchnorm_accuracy(self, torch_dtype, tp_dtype, input_shape, eager_or_compiled): eps = 1e-5 num_features = input_shape[1] # Number of channels in the input tensor batchnorm = torch.nn.BatchNorm2d(num_features=num_features, eps=eps, dtype=torch_dtype) @@ -45,7 +45,7 @@ def test_batchnorm_accuracy(self, torch_dtype, tp_dtype, input_shape): input = torch.randn(input_shape, dtype=torch_dtype).to("cuda") tp_input = tp.Tensor(input, dtype=tp_dtype) - output = tp_batchnorm(tp_input) + output = eager_or_compiled(tp_batchnorm, tp_input) batchnorm.to("cuda").eval() with torch.no_grad(): diff --git a/tripy/tests/integration/test_cast.py b/tripy/tests/integration/test_cast.py index 3e5902924..634373237 100644 --- a/tripy/tests/integration/test_cast.py +++ b/tripy/tests/integration/test_cast.py @@ -30,54 +30,53 @@ class TestCast: [ (np.int32, np.float32), (np.float32, np.int32), - (np.int64, np.float32), - (np.float32, np.int64), - (np.int64, np.int32), - (np.int64, np.int8), (np.int32, np.int8), (np.float32, np.int8), - (np.int8, np.int64), (np.int8, np.int32), (np.int8, np.float32), # important to test conversion into bool because default StableHLO semantics # are simply to truncate to i1, which is not desirable (np.float32, bool), (np.int32, bool), - (np.int64, bool), # requires a dequantization first # TODO(#219): Dequantize fails with dynamic shapes # (np.int8, bool), ], ) - def test_cast(self, input_dtype, target_dtype): + def test_cast(self, input_dtype, target_dtype, eager_or_compiled): tp_input_dtype = NUMPY_TO_TRIPY[input_dtype] tp_target_dtype = NUMPY_TO_TRIPY[target_dtype] # TODO(#222): Integer casts with negative numbers fail in many cases input_tensor = tp.Tensor([0, 1, 2], dtype=tp_input_dtype) np_input = cp.from_dlpack(input_tensor).get() - output = tp.cast(input_tensor, tp_target_dtype) + output = eager_or_compiled(tp.cast, input_tensor, tp_target_dtype) assert np.array_equal(cp.from_dlpack(output).get(), np_input.astype(target_dtype)) # these dtypes don't have analogues in numpy @pytest.mark.parametrize("source_dtype", [pytest.param(tp.float8, marks=skip_if_older_than_sm89), tp.int4]) - def test_cast_quantized_dtypes_into_bool(self, source_dtype): + def test_cast_quantized_dtypes_into_bool(self, source_dtype, eager_or_compiled): # TODO(#223): Using an odd size leads to a strange crash, so can't just use [-1.0, 0.0, 1.0] input_tensor = tp.Tensor([-1.0, 0.0, 0.0, 1.0], dtype=tp.float32) - q = tp.quantize(input_tensor, scale=1.0, dtype=source_dtype) - output = tp.cast(q, tp.bool) + + def func(input): + q = tp.quantize(input, scale=1.0, dtype=source_dtype) + output = tp.cast(q, tp.bool) + return output + + output = eager_or_compiled(func, input_tensor) assert cp.from_dlpack(output).get().tolist() == [True, False, False, True] - @pytest.mark.parametrize("target_dtype", [np.float32, np.int32, np.int64, np.int8]) - def test_cast_from_bool(self, target_dtype): + @pytest.mark.parametrize("target_dtype", [np.float32, np.int32, np.int8]) + def test_cast_from_bool(self, target_dtype, eager_or_compiled): tp_target_dtype = NUMPY_TO_TRIPY[target_dtype] # in principle, it is not important what *specific* values we convert to, # so long as false is mapped to 0 and true to nonzero input_tensor = tp.Tensor([False, True], dtype=tp.bool) np_input = cp.from_dlpack(input_tensor).get() - output = tp.cast(input_tensor, tp_target_dtype) + output = eager_or_compiled(tp.cast, input_tensor, tp_target_dtype) tp_compare_to_zero = cp.from_dlpack(output).get() == 0 np_compare_to_zero = np_input.astype(target_dtype) == 0 diff --git a/tripy/tests/integration/test_concatenate.py b/tripy/tests/integration/test_concatenate.py index 01ea823b5..9df2d2f70 100644 --- a/tripy/tests/integration/test_concatenate.py +++ b/tripy/tests/integration/test_concatenate.py @@ -33,9 +33,9 @@ class TestConcatenate: ([(2, 3, 4)], 0), ], ) - def test_concat(self, tensor_shapes, dim): + def test_concat(self, tensor_shapes, dim, eager_or_compiled): tensors = [tp.ones(shape) for shape in tensor_shapes] - out = tp.concatenate(tensors, dim=dim) + out = eager_or_compiled(tp.concatenate, tensors, dim=dim) assert np.array_equal( cp.from_dlpack(out).get(), np.concatenate([np.ones(shape) for shape in tensor_shapes], axis=dim) ) @@ -44,8 +44,8 @@ def test_concat(self, tensor_shapes, dim): "tensor_shapes, dim", [([(2, 3, 4), (2, 4, 4)], 0), ([(4, 5, 6), (4, 1, 6)], -1)], ) - def test_negative_concat(self, tensor_shapes, dim): + def test_negative_concat(self, tensor_shapes, dim, eager_or_compiled): tensors = [tp.ones(shape) for shape in tensor_shapes] with helper.raises(tp.TripyException, match=f"not compatible at non-concat index"): - out = tp.concatenate(tensors, dim=dim) + out = eager_or_compiled(tp.concatenate, tensors, dim=dim) print(out) diff --git a/tripy/tests/integration/test_conv.py b/tripy/tests/integration/test_conv.py index 3f67c6629..078c2890d 100644 --- a/tripy/tests/integration/test_conv.py +++ b/tripy/tests/integration/test_conv.py @@ -75,7 +75,7 @@ class ConvTestCase: @pytest.mark.parametrize("torch_dtype,tp_dtype", DTYPES) class TestConvolution: @pytest.mark.parametrize("test_case", test_cases_1d) - def test_convolution_1d(self, torch_dtype, tp_dtype, test_case): + def test_convolution_1d(self, torch_dtype, tp_dtype, test_case, eager_or_compiled): if not test_case.torch_pad: test_case.torch_pad = 0 if not test_case.stride: @@ -122,7 +122,7 @@ def test_convolution_1d(self, torch_dtype, tp_dtype, test_case): conv_layer.bias = tp.cast(tp.Tensor(conv_layer_torch.bias.data), tp_dtype) expected = conv_layer_torch(input_torch).to(torch_dtype) - output = conv_layer(input) + output = eager_or_compiled(conv_layer, input) # FP32 kernel seems to lose some precision, and FP16 needs to be run in FP32 on torch rtol_ = 4e-5 if tp_dtype == tp.float32 else 1e-3 @@ -131,7 +131,7 @@ def test_convolution_1d(self, torch_dtype, tp_dtype, test_case): assert list(output_torch.shape) == list(expected.shape) @pytest.mark.parametrize("test_case", test_cases_2d) - def test_convolution_2d(self, torch_dtype, tp_dtype, test_case): + def test_convolution_2d(self, torch_dtype, tp_dtype, test_case, eager_or_compiled): if not test_case.torch_pad: test_case.torch_pad = 0 if not test_case.stride: @@ -178,7 +178,7 @@ def test_convolution_2d(self, torch_dtype, tp_dtype, test_case): conv_layer.bias = tp.cast(tp.Tensor(conv_layer_torch.bias.data), tp_dtype) expected = conv_layer_torch(input_torch).to(torch_dtype) - output = conv_layer(input) + output = eager_or_compiled(conv_layer, input) rtol_ = 2e-7 if tp_dtype == tp.float32 else 1.5e-3 output_torch = torch.from_dlpack(output) @@ -186,7 +186,7 @@ def test_convolution_2d(self, torch_dtype, tp_dtype, test_case): assert list(output_torch.shape) == list(expected.shape) @pytest.mark.parametrize("test_case", test_cases_3d) - def test_convolution_3d(self, torch_dtype, tp_dtype, test_case): + def test_convolution_3d(self, torch_dtype, tp_dtype, test_case, eager_or_compiled): pytest.skip("TODO (#260): Fix accuracy bugs in 3D conv") if not test_case.torch_pad: test_case.torch_pad = 0 @@ -245,14 +245,14 @@ def test_convolution_3d(self, torch_dtype, tp_dtype, test_case): return expected = conv_layer_torch(input_torch).to(torch_dtype) - output = conv_layer(input) + output = eager_or_compiled(conv_layer, input) rtol_ = 2e-4 if tp_dtype == tp.float32 else 1.4e-3 # 3d conv has greater accumulation error output_torch = torch.from_dlpack(output) assert torch.allclose(output_torch, expected, rtol=rtol_) assert list(output_torch.shape) == list(expected.shape) - def test_uneven_padding(self, torch_dtype, tp_dtype): + def test_uneven_padding(self, torch_dtype, tp_dtype, eager_or_compiled): input_torch = torch.arange(200, dtype=torch.float32, device=torch.device("cuda")).reshape(*(2, 4, 5, 5)) input = tp.cast(tp.Tensor(input_torch), tp_dtype) @@ -282,7 +282,7 @@ def test_uneven_padding(self, torch_dtype, tp_dtype): input_torch = torch_pad(input_torch) expected = conv_layer_torch(input_torch).to(torch_dtype) - output = conv_layer(input) + output = eager_or_compiled(conv_layer, input) rtol_ = 2e-7 if tp_dtype == tp.float32 else 2e-3 output_torch = torch.from_dlpack(output) diff --git a/tripy/tests/integration/test_conv_transpose.py b/tripy/tests/integration/test_conv_transpose.py index 9cc95f890..2245d024b 100644 --- a/tripy/tests/integration/test_conv_transpose.py +++ b/tripy/tests/integration/test_conv_transpose.py @@ -81,7 +81,7 @@ class ConvTestCase: @pytest.mark.parametrize("torch_dtype,tp_dtype", DTYPES) class TestConvolution: @pytest.mark.parametrize("test_case", test_cases_transpose_1d) - def test_transposed_convolution_1d(self, torch_dtype, tp_dtype, test_case): + def test_transposed_convolution_1d(self, torch_dtype, tp_dtype, test_case, eager_or_compiled): if not test_case.torch_pad: test_case.torch_pad = 0 if not test_case.stride: @@ -129,14 +129,14 @@ def test_transposed_convolution_1d(self, torch_dtype, tp_dtype, test_case): conv_layer.bias = tp.cast(tp.Tensor(conv_layer_torch.bias.data), tp_dtype) expected = conv_layer_torch(input_torch).to(torch_dtype) - output = conv_layer(input) + output = eager_or_compiled(conv_layer, input) - rtol_ = 1e-3 + rtol_ = 3e-3 assert tp.allclose(output, tp.Tensor(expected), rtol=rtol_) assert output.shape == list(expected.shape) @pytest.mark.parametrize("test_case", test_cases_transpose_2d) - def test_transposed_convolution_2d(self, torch_dtype, tp_dtype, test_case): + def test_transposed_convolution_2d(self, torch_dtype, tp_dtype, test_case, eager_or_compiled): if not test_case.torch_pad: test_case.torch_pad = 0 if not test_case.stride: @@ -184,14 +184,14 @@ def test_transposed_convolution_2d(self, torch_dtype, tp_dtype, test_case): conv_layer.bias = tp.cast(tp.Tensor(conv_layer_torch.bias.data), tp_dtype) expected = conv_layer_torch(input_torch).to(torch_dtype) - output = conv_layer(input) + output = eager_or_compiled(conv_layer, input) rtol_ = 1e-2 assert tp.allclose(output, tp.Tensor(expected), rtol=rtol_) assert output.shape == list(expected.shape) @pytest.mark.parametrize("test_case", test_cases_transpose_3d) - def test_transposed_convolution_3d(self, torch_dtype, tp_dtype, test_case): + def test_transposed_convolution_3d(self, torch_dtype, tp_dtype, test_case, eager_or_compiled): if not test_case.torch_pad: test_case.torch_pad = 0 if not test_case.stride: @@ -239,12 +239,12 @@ def test_transposed_convolution_3d(self, torch_dtype, tp_dtype, test_case): conv_layer.bias = tp.cast(tp.Tensor(conv_layer_torch.bias.data), tp_dtype) expected = conv_layer_torch(input_torch).to(torch_dtype) - output = conv_layer(input) + output = eager_or_compiled(conv_layer, input) rtol_ = 1.3e-6 if tp_dtype == tp.float32 else 1.6e-3 assert tp.allclose(output, tp.Tensor(expected), rtol=rtol_) assert output.shape == list(expected.shape) - def test_transposed_equivalency(self, torch_dtype, tp_dtype): + def test_transposed_equivalency(self, torch_dtype, tp_dtype, eager_or_compiled): input_torch = torch.arange(9, dtype=torch.float32, device=torch.device("cuda")).reshape(*(1, 1, 3, 3)) input = tp.cast(tp.Tensor(input_torch), tp_dtype) @@ -277,8 +277,8 @@ def test_transposed_equivalency(self, torch_dtype, tp_dtype): expected = conv_layer_torch(input_torch).to(torch_dtype) expected_transpose = conv_transpose_layer_torch(input_torch).to(torch_dtype) - output = conv_layer(input) - output_transpose = conv_transpose_layer(input) + output = eager_or_compiled(conv_layer, input) + output_transpose = eager_or_compiled(conv_transpose_layer, input) rtol_ = 2e-7 if tp_dtype == tp.float32 else 9e-4 assert tp.allclose(output, tp.Tensor(expected), rtol=rtol_) @@ -291,7 +291,7 @@ def test_transposed_equivalency(self, torch_dtype, tp_dtype): assert list(expected.shape) == list(expected_transpose.shape) @pytest.mark.parametrize("test_case", test_cases_transpose_downscale) - def test_transposed_downscale(self, torch_dtype, tp_dtype, test_case): + def test_transposed_downscale(self, torch_dtype, tp_dtype, test_case, eager_or_compiled): input_torch = torch.arange(9, dtype=torch.float32, device=torch.device("cuda")).reshape(*(1, 1, 3, 3)) input = tp.cast(tp.Tensor(input_torch), tp_dtype) @@ -320,7 +320,7 @@ def test_transposed_downscale(self, torch_dtype, tp_dtype, test_case): conv_layer.weight = tp.cast(tp.Tensor(conv_layer_torch.weight.data), tp_dtype) expected = conv_layer_torch(input_torch).to(torch_dtype) - output = conv_layer(input) + output = eager_or_compiled(conv_layer, input) rtol_ = 1e-15 if tp_dtype == tp.float32 else 1e-10 assert tp.allclose(output, tp.Tensor(expected), rtol=rtol_) diff --git a/tripy/tests/integration/test_cumsum.py b/tripy/tests/integration/test_cumsum.py index c8f8bbb7e..2360f3eaa 100644 --- a/tripy/tests/integration/test_cumsum.py +++ b/tripy/tests/integration/test_cumsum.py @@ -30,11 +30,10 @@ class TestCumsum: ([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], 0, [[[1, 2], [3, 4]], [[6, 8], [10, 12]]]), ], ) - def test_cumsum(self, data, dim, expected): + def test_cumsum(self, data, dim, expected, eager_or_compiled): inp = tp.Tensor(data, dtype=tp.float32) - out = tp.cumsum(inp, dim=dim) - + out = eager_or_compiled(tp.cumsum, inp, dim=dim) expected = tp.Tensor(expected, dtype=tp.float32) assert tp.allclose(out, expected) assert out.shape == expected.shape diff --git a/tripy/tests/integration/test_dequantize.py b/tripy/tests/integration/test_dequantize.py index f44f3a23b..4924ab9a6 100644 --- a/tripy/tests/integration/test_dequantize.py +++ b/tripy/tests/integration/test_dequantize.py @@ -29,12 +29,16 @@ class TestDequantize: @pytest.mark.parametrize( "dtype", [tp.float32, tp.float16, pytest.param(tp.bfloat16, marks=skip_if_older_than_sm80)] ) - def test_dequantize_int8_per_tensor(self, dtype): + def test_dequantize_int8_per_tensor(self, dtype, eager_or_compiled): data = [4, 8] input_tp = tp.Tensor(data, dtype=tp.int8) scale = torch.tensor(0.5, dtype=TORCH_DTYPES[dtype]) scale_tp = tp.Tensor(scale, dtype=dtype) - dequantized = tp.dequantize(input_tp, scale_tp, dtype) + + def func(input): + return tp.dequantize(input, scale_tp, dtype) + + dequantized = eager_or_compiled(func, input_tp) expected = torch.tensor(data) * scale output = torch.from_dlpack(dequantized) assert torch.allclose(expected, output.to("cpu")) @@ -42,7 +46,7 @@ def test_dequantize_int8_per_tensor(self, dtype): @pytest.mark.parametrize( "dtype", [tp.float32, tp.float16, pytest.param(tp.bfloat16, marks=skip_if_older_than_sm80)] ) - def test_dequantize_int8_per_channel(self, dtype): + def test_dequantize_int8_per_channel(self, dtype, eager_or_compiled): # TODO: Fix in #153 if dtype == tp.float16: pytest.skip("TRT does not support fp16->int8 per-channel dequant.") @@ -50,7 +54,11 @@ def test_dequantize_int8_per_channel(self, dtype): input_tp = tp.Tensor(data, dtype=tp.int8) scale = torch.tensor([0.8, 0.9], dtype=TORCH_DTYPES[dtype]) scale_tp = tp.Tensor(scale, dtype=dtype) - dequantized = tp.dequantize(input_tp, scale_tp, dtype, dim=0) + + def func(input): + return tp.dequantize(input, scale_tp, dtype, dim=0) + + dequantized = eager_or_compiled(func, input_tp) expected = torch.tensor(data) * scale.reshape((2, 1)) output = torch.from_dlpack(dequantized) assert torch.allclose(expected, output.to("cpu")) diff --git a/tripy/tests/integration/test_expand.py b/tripy/tests/integration/test_expand.py index d2ab402de..09b1fcfca 100644 --- a/tripy/tests/integration/test_expand.py +++ b/tripy/tests/integration/test_expand.py @@ -22,24 +22,24 @@ class TestExpand: - def test_int_sizes(self): + def test_int_sizes(self, eager_or_compiled): input = tp.ones((2, 1)) - out = tp.expand(input, (-1, 2)) + out = eager_or_compiled(tp.expand, input, (-1, 2)) assert np.array_equal(cp.from_dlpack(out).get(), np.ones((2, 2), dtype=np.float32)) - def test_shape_sizes(self): + def test_shape_sizes(self, eager_or_compiled): input = tp.ones((2, 1)) a = tp.ones((2, 4)) - out = tp.expand(input, a.shape) + out = eager_or_compiled(tp.expand, input, a.shape) assert np.array_equal(cp.from_dlpack(out).get(), np.ones((2, 4), dtype=np.float32)) - def test_extra_dims(self): + def test_extra_dims(self, eager_or_compiled): input = tp.ones((2, 1)) - out = tp.expand(input, (1, -1, 2)) + out = eager_or_compiled(tp.expand, input, (1, -1, 2)) assert np.array_equal(cp.from_dlpack(out).get(), np.ones((1, 2, 2), dtype=np.float32)) - def test_mixed_sizes(self): + def test_mixed_sizes(self, eager_or_compiled): input = tp.ones((2, 1, 1)) a = tp.ones((4, 4)) - out = tp.expand(input, (-1, a.shape[0], a.shape[1])) + out = eager_or_compiled(tp.expand, input, (-1, a.shape[0], a.shape[1])) assert np.array_equal(cp.from_dlpack(out).get(), np.ones((2, 4, 4), dtype=np.float32)) diff --git a/tripy/tests/integration/test_flatten.py b/tripy/tests/integration/test_flatten.py index da16c181b..59bc32f57 100644 --- a/tripy/tests/integration/test_flatten.py +++ b/tripy/tests/integration/test_flatten.py @@ -29,29 +29,29 @@ class TestFlatten: ((2, 3, 4, 5), 1, 3, (2, 60)), # Flatten dimensions 1 through 3 ], ) - def test_flatten(self, shape, start_dim, end_dim, expected_shape): + def test_flatten(self, shape, start_dim, end_dim, expected_shape, eager_or_compiled): cp_a = cp.arange(np.prod(shape)).reshape(shape).astype(np.float32) a = tp.Tensor(cp_a) - b = tp.flatten(a, start_dim=start_dim, end_dim=end_dim) + b = eager_or_compiled(tp.flatten, a, start_dim=start_dim, end_dim=end_dim) assert b.shape == list(expected_shape) assert np.array_equal(cp.from_dlpack(b).get(), cp_a.reshape(expected_shape).get()) - def test_flatten_invalid_dims(self): + def test_flatten_invalid_dims(self, eager_or_compiled): shape = (2, 3, 4) with pytest.raises(tp.TripyException, match="Invalid dimensions"): a = tp.ones(shape) # Invalid because end_dim < start_dim - tp.flatten(a, start_dim=2, end_dim=1) + eager_or_compiled(tp.flatten, a, start_dim=2, end_dim=1) - def test_flatten_single_dim(self): + def test_flatten_single_dim(self, eager_or_compiled): shape = (2, 3, 4) a = tp.ones(shape) # Flattening a single dimension should not change the output - b = tp.flatten(a, start_dim=1, end_dim=1) + b = eager_or_compiled(tp.flatten, a, start_dim=1, end_dim=1) assert b.shape == [2, 3, 4] assert np.array_equal(cp.from_dlpack(b).get(), np.ones(shape, dtype=np.float32)) - def test_flatten_with_unknown_dims(self): + def test_flatten_with_unknown_dims(self, eager_or_compiled): a = tp.ones((2, 3, 4, 5)) - b = tp.flatten(a, start_dim=1, end_dim=-1) + b = eager_or_compiled(tp.flatten, a, start_dim=1, end_dim=-1) assert np.array_equal(cp.from_dlpack(b).get(), np.ones((2, 60), dtype=np.float32)) diff --git a/tripy/tests/integration/test_flip.py b/tripy/tests/integration/test_flip.py index 8118716d5..ef53f6c1a 100644 --- a/tripy/tests/integration/test_flip.py +++ b/tripy/tests/integration/test_flip.py @@ -26,34 +26,34 @@ class TestFlip: "dims", [0, 1, None, [0, 1], [1, 0], -1, -2, [0, -1], [-2, 1]], ) - def test_flip(self, dims): + def test_flip(self, dims, eager_or_compiled): cp_a = cp.arange(16).reshape((4, 4)).astype(cp.float32) a = tp.Tensor(cp_a, device=tp.device("gpu")) f = tp.flip(a, dims=dims) assert np.array_equal(cp.from_dlpack(f).get(), np.flip(cp_a.get(), axis=dims)) # also ensure that flipping a second time restores the original value - f2 = tp.flip(f, dims=dims) + f2 = eager_or_compiled(tp.flip, f, dims=dims) assert cp.array_equal(cp.from_dlpack(f2), cp_a) - def test_no_op(self): + def test_no_op(self, eager_or_compiled): cp_a = cp.arange(16).reshape((4, 4)).astype(cp.float32) a = tp.Tensor(cp_a, device=tp.device("gpu")) - f = tp.flip(a, dims=[]) + f = eager_or_compiled(tp.flip, a, dims=[]) assert tp.equal(a, f) - def test_zero_rank(self): + def test_zero_rank(self, eager_or_compiled): t = tp.Tensor(1) - f = tp.flip(t) + f = eager_or_compiled(tp.flip, t) assert tp.equal(t, f) @pytest.mark.parametrize( "dims1, dims2", [(0, -2), (1, -1), ([0, 1], None), ([0, 1], [1, 0]), ([0, 1], [-2, -1])], ) - def test_equivalences(self, dims1, dims2): + def test_equivalences(self, dims1, dims2, eager_or_compiled): cp_a = cp.arange(16).reshape((4, 4)).astype(cp.float32) a = tp.Tensor(cp_a, device=tp.device("gpu")) - f1 = tp.flip(a, dims=dims1) - f2 = tp.flip(a, dims=dims2) + f1 = eager_or_compiled(tp.flip, a, dims=dims1) + f2 = eager_or_compiled(tp.flip, a, dims=dims2) assert tp.equal(f1, f2) diff --git a/tripy/tests/integration/test_full.py b/tripy/tests/integration/test_full.py index 9a04c1664..d96885ffa 100644 --- a/tripy/tests/integration/test_full.py +++ b/tripy/tests/integration/test_full.py @@ -22,21 +22,21 @@ class TestFull: - def test_normal_shape(self): - out = tp.full((2, 2), 5.0, tp.float32) + def test_normal_shape(self, eager_or_compiled): + out = eager_or_compiled(tp.full, (2, 2), 5.0, tp.float32) assert np.array_equal(cp.from_dlpack(out).get(), np.full((2, 2), 5.0, np.float32)) - def test_shape_tensor(self): + def test_shape_tensor(self, eager_or_compiled): a = tp.ones((2, 3)) - out = tp.full(a.shape, 5.0, tp.float32) + out = eager_or_compiled(tp.full, a.shape, 5.0, tp.float32) assert np.array_equal(cp.from_dlpack(out).get(), np.full((2, 3), 5.0, np.float32)) - def test_mixed_shape(self): + def test_mixed_shape(self, eager_or_compiled): a = tp.ones((2, 3)) - out = tp.full((a.shape[0], 4), 5.0, tp.float32) + out = eager_or_compiled(tp.full, (a.shape[0], 4), 5.0, tp.float32) assert np.array_equal(cp.from_dlpack(out).get(), np.full((2, 4), 5.0, np.float32)) - def test_value_as_tensor(self): + def test_value_as_tensor(self, eager_or_compiled): a = tp.ones((2, 3)) - out = tp.full((a.shape[0], 4), tp.Tensor(8.0), tp.float32) + out = eager_or_compiled(tp.full, (a.shape[0], 4), tp.Tensor(8.0), tp.float32) assert np.array_equal(cp.from_dlpack(out).get(), np.full((2, 4), 8.0, np.float32)) diff --git a/tripy/tests/integration/test_gather.py b/tripy/tests/integration/test_gather.py index e2f088346..d0e05a118 100644 --- a/tripy/tests/integration/test_gather.py +++ b/tripy/tests/integration/test_gather.py @@ -34,11 +34,11 @@ class TestGatherOp: ((2, 3, 4), 1, (2)), ], ) - def test_gather(self, x_shape, axis, indices): + def test_gather(self, x_shape, axis, indices, eager_or_compiled): x = np.arange(np.prod(x_shape)).reshape(x_shape) indices_tp = tp.Tensor(indices) a = tp.Tensor(x) a = tp.cast(a, tp.int32) - out = tp.gather(a, axis, indices_tp) - out.eval() + out = eager_or_compiled(tp.gather, a, axis, indices_tp) + assert np.array_equal(cp.from_dlpack(out).get(), np.take(x, indices, axis)) diff --git a/tripy/tests/integration/test_groupnorm.py b/tripy/tests/integration/test_groupnorm.py index 5f1cd8bc3..d56c15928 100644 --- a/tripy/tests/integration/test_groupnorm.py +++ b/tripy/tests/integration/test_groupnorm.py @@ -29,7 +29,7 @@ class TestGroupNorm: @pytest.mark.parametrize("input_shape", [(1, 10, 2)]) @pytest.mark.parametrize("num_groups", [2, 5]) @pytest.mark.parametrize("num_channels", [10]) - def test_groupnorm_accuracy(self, torch_dtype, tp_dtype, input_shape, num_groups, num_channels): + def test_groupnorm_accuracy(self, torch_dtype, tp_dtype, input_shape, num_groups, num_channels, eager_or_compiled): eps = 1e-5 groupnorm = torch.nn.GroupNorm( num_groups=num_groups, @@ -50,7 +50,7 @@ def test_groupnorm_accuracy(self, torch_dtype, tp_dtype, input_shape, num_groups input = torch.arange(torch.prod(torch.Tensor(input_shape))).reshape(input_shape).to(torch_dtype) tp_input = tp.Tensor(input, dtype=tp_dtype) - output = tp.copy(tp_groupnorm(tp_input), tp.device("cpu")) + output = eager_or_compiled(tp.copy, tp_groupnorm(tp_input), tp.device("cpu")) with torch.no_grad(): expected = groupnorm(input) diff --git a/tripy/tests/integration/test_iota.py b/tripy/tests/integration/test_iota.py index 2df779da2..44cb38ab6 100644 --- a/tripy/tests/integration/test_iota.py +++ b/tripy/tests/integration/test_iota.py @@ -49,17 +49,13 @@ def _compute_ref_iota(self, dtype, shape, dim): "shape, dim", [ ((2, 3), 1), - ((2, 3), None), + ((2, 3), 0), ((2, 3), -1), ((2, 3, 4), 2), ], ) - def test_iota(self, dtype, shape, dim): - if dim: - output = tp.iota(shape, dim, dtype[1]) - else: - output = tp.iota(shape, dtype=dtype[1]) - + def test_iota(self, dtype, shape, dim, eager_or_compiled): + output = eager_or_compiled(tp.iota, shape, dim, dtype[1]) assert np.array_equal(cp.from_dlpack(output).get(), self._compute_ref_iota(dtype[0], shape, dim)) @pytest.mark.parametrize("dtype", DTYPE_PARAMS) @@ -72,11 +68,11 @@ def test_iota(self, dtype, shape, dim): ((2, 3, 4), 2), ], ) - def test_iota_like(self, dtype, shape, dim): + def test_iota_like(self, dtype, shape, dim, eager_or_compiled): if dim: - output = tp.iota_like(tp.ones(shape), dim, dtype[1]) + output = eager_or_compiled(tp.iota_like, tp.ones(shape), dim, dtype[1]) else: - output = tp.iota_like(tp.ones(shape), dtype=dtype[1]) + output = eager_or_compiled(tp.iota_like, tp.ones(shape), dtype=dtype[1]) assert np.array_equal(cp.from_dlpack(output).get(), self._compute_ref_iota(dtype[0], shape, dim)) @@ -98,12 +94,12 @@ def test_negative_no_casting(self, dtype): ): print(out) - def test_iota_from_shape_tensor(self): + def test_iota_from_shape_tensor(self, eager_or_compiled): a = tp.ones((2, 2)) - output = tp.iota(a.shape) + output = eager_or_compiled(tp.iota, a.shape) assert np.array_equal(cp.from_dlpack(output).get(), self._compute_ref_iota("float32", (2, 2), 0)) - def test_iota_from_mixed_seqence(self): + def test_iota_from_mixed_seqence(self, eager_or_compiled): a = tp.ones((2, 2)) - output = tp.iota((3, a.shape[0])) + output = eager_or_compiled(tp.iota, (3, a.shape[0])) assert np.array_equal(cp.from_dlpack(output).get(), self._compute_ref_iota("float32", (3, 2), 0)) diff --git a/tripy/tests/integration/test_layernorm.py b/tripy/tests/integration/test_layernorm.py index 088054c39..b1304ae63 100644 --- a/tripy/tests/integration/test_layernorm.py +++ b/tripy/tests/integration/test_layernorm.py @@ -31,7 +31,7 @@ class TestLayerNorm: @pytest.mark.parametrize("torch_dtype, tp_dtype", DTYPES) @pytest.mark.parametrize("input_shape", [(2, 2, 2)]) @pytest.mark.parametrize("normalized_shape", [(2, 2), (2,)]) - def test_layernorm_accuracy(self, torch_dtype, tp_dtype, input_shape, normalized_shape): + def test_layernorm_accuracy(self, torch_dtype, tp_dtype, input_shape, normalized_shape, eager_or_compiled): eps = 1e-5 layernorm = torch.nn.LayerNorm( normalized_shape=normalized_shape, @@ -51,7 +51,7 @@ def test_layernorm_accuracy(self, torch_dtype, tp_dtype, input_shape, normalized input = torch.arange(torch.prod(torch.Tensor(input_shape))).reshape(input_shape).to(torch_dtype) tp_input = tp.Tensor(input, dtype=tp_dtype) - output = tp.copy(tp_layernorm(tp_input), tp.device("cpu")) + output = eager_or_compiled(tp.copy, tp_layernorm(tp_input), tp.device("cpu")) with torch.no_grad(): expected = layernorm(input) diff --git a/tripy/tests/integration/test_linear.py b/tripy/tests/integration/test_linear.py index ff4899a74..137d4a00d 100644 --- a/tripy/tests/integration/test_linear.py +++ b/tripy/tests/integration/test_linear.py @@ -25,7 +25,7 @@ class TestLinear: - def test_linear_module(self): + def test_linear_module(self, eager_or_compiled): class Network(tp.Module): def __init__(self): super().__init__() @@ -41,7 +41,7 @@ def __call__(self, x): cp_a1 = cp.ones((3, 4), dtype=cp.float32) a1 = tp.Tensor(cp_a1, device=tp.device("gpu")) - out = net(a1) + out = eager_or_compiled(net, a1) np_out = cp_a1.get() @ (np_weight.transpose()) + np_bias @@ -84,7 +84,7 @@ def __call__(self, x): @pytest.mark.parametrize("use_input_scale", [False, True]) @pytest.mark.parametrize("quant_dtype", [tp.int8, pytest.param(tp.float8, marks=skip_if_older_than_sm89)]) @pytest.mark.parametrize("weight_quant_dim", [None, 0, 1]) - def test_quant_linear(self, use_input_scale, quant_dtype, weight_quant_dim): + def test_quant_linear(self, use_input_scale, quant_dtype, weight_quant_dim, eager_or_compiled): net = self._create_network(use_input_scale, quant_dtype, weight_quant_dim) np_weight = cp.from_dlpack(net.linear.weight).get() np_bias = cp.from_dlpack(net.linear.bias).get() @@ -96,9 +96,9 @@ def test_quant_linear(self, use_input_scale, quant_dtype, weight_quant_dim): tp.TripyException, match="Unsupported quantization parameters for Linear module.", ): - out = net(a1) + out = eager_or_compiled(net, a1) else: - out = net(a1) + out = eager_or_compiled(net, a1) np_out = cp_a1.get() @ (np_weight.transpose()) + np_bias @@ -114,7 +114,7 @@ def test_quant_linear(self, use_input_scale, quant_dtype, weight_quant_dim): ], ids=["block-wise", "per-tensor", "per-channel-0", "per-channel-1"], ) - def test_quant_linear_int4_weight_only(self, weight_quant_dim, scale): + def test_quant_linear_int4_weight_only(self, weight_quant_dim, scale, eager_or_compiled): scale = tp.Parameter(scale) linear = tp.Linear(4, 8, quant_dtype=tp.int4, weight_quant_dim=weight_quant_dim) @@ -128,7 +128,7 @@ def test_quant_linear_int4_weight_only(self, weight_quant_dim, scale): cp_input = cp.ones((4, 4), dtype=np.float32) input = tp.Tensor(cp_input, device=tp.device("gpu")) - out = linear(input) + out = eager_or_compiled(linear, input) np_out = cp_input.get() @ (np_weight.transpose()) + np_bias diff --git a/tripy/tests/integration/test_matrix_multiplication.py b/tripy/tests/integration/test_matrix_multiplication.py index 57731b674..b19e38937 100644 --- a/tripy/tests/integration/test_matrix_multiplication.py +++ b/tripy/tests/integration/test_matrix_multiplication.py @@ -23,23 +23,27 @@ import tripy.common.datatype +def gemm(a, b): + return a @ b + + class TestMatrixMultiplication: - def test_2d_tensors(self): + def test_2d_tensors(self, eager_or_compiled): a_np = np.arange(6).reshape((2, 3)).astype(np.float32) b_np = np.arange(6).reshape((3, 2)).astype(np.float32) a = tp.Tensor(a_np) b = tp.Tensor(b_np) - out = a @ b + out = eager_or_compiled(gemm, a, b) assert tp.allclose(out, tp.Tensor(a_np @ b_np)) - def test_1d_tensors(self): + def test_1d_tensors(self, eager_or_compiled): a_np = np.arange(64).astype(np.float32) # 1D Tensor b_np = np.arange(64).astype(np.float32) # 1D Tensor a = tripy.Tensor(cp.asanyarray(a_np)) b = tripy.Tensor(cp.asanyarray(b_np)) - out = a @ b + out = eager_or_compiled(gemm, a, b) assert tp.allclose(out, tp.Tensor(cp.array(a_np @ b_np)), atol=1e-2) @pytest.mark.parametrize( @@ -53,11 +57,11 @@ def test_1d_tensors(self): ((1, 2, 3), (0, 0, 3, 2)), # Broadcasting batch dims with 0 ], ) - def test_broadcast_gemm(self, shape_a, shape_b): + def test_broadcast_gemm(self, shape_a, shape_b, eager_or_compiled): a_np = np.arange(np.prod(shape_a)).reshape(shape_a).astype(np.float32) b_np = np.arange(np.prod(shape_b)).reshape(shape_b).astype(np.float32) a = tp.Tensor(a_np) b = tp.Tensor(b_np) - out = a @ b + out = eager_or_compiled(gemm, a, b) assert tp.allclose(out, tp.Tensor(a_np @ b_np)) diff --git a/tripy/tests/integration/test_outer.py b/tripy/tests/integration/test_outer.py index 8ba7be979..53f8b5237 100644 --- a/tripy/tests/integration/test_outer.py +++ b/tripy/tests/integration/test_outer.py @@ -19,10 +19,10 @@ class TestOuter: - def test_outer(self): + def test_outer(self, eager_or_compiled): v1 = tp.arange(5, dtype=tp.float32) v2 = tp.arange(4, dtype=tp.float32) - output = tp.outer(v1, v2) + output = eager_or_compiled(tp.outer, v1, v2) t1 = torch.arange(5, dtype=torch.float32) t2 = torch.arange(4, dtype=torch.float32) @@ -30,9 +30,9 @@ def test_outer(self): assert output.shape == list(torch_out.shape) assert tp.allclose(output, tp.Tensor(torch_out)) - def test_empty(self): + def test_empty(self, eager_or_compiled): v1 = tp.Tensor([]) v2 = tp.arange(3, dtype=tp.float32) - output = tp.outer(v1, v2) + output = eager_or_compiled(tp.outer, v1, v2) assert output.shape == [0, 3] diff --git a/tripy/tests/integration/test_pad.py b/tripy/tests/integration/test_pad.py index 8843055ee..578cf4a12 100644 --- a/tripy/tests/integration/test_pad.py +++ b/tripy/tests/integration/test_pad.py @@ -29,19 +29,19 @@ class TestPad: (((1, 2), (2, 3)), 1), ], ) - def test_pad_constant(self, pad, value): + def test_pad_constant(self, pad, value, eager_or_compiled): inp = np.arange(4, dtype=np.int32).reshape((2, 2)) - out = tp.pad(tp.Tensor(inp), pad, value=value) + out = eager_or_compiled(tp.pad, tp.Tensor(inp), pad, value=value) expected = np.pad(inp, pad, constant_values=value) assert np.array_equal(cp.from_dlpack(out).get(), expected) - def test_pad_tensor(self): + def test_pad_tensor(self, eager_or_compiled): inp = np.arange(6, dtype=np.float32).reshape((2, 3)) inp_tp = tp.Tensor(inp) - out = tp.pad(tp.Tensor(inp), ((0, inp_tp.shape[0]), (inp_tp.shape[1], 0))) + out = eager_or_compiled(tp.pad, tp.Tensor(inp), ((0, inp_tp.shape[0]), (inp_tp.shape[1], 0))) expected = np.pad(inp, ((0, 2), (3, 0))) assert np.array_equal(cp.from_dlpack(out).get(), expected) diff --git a/tripy/tests/integration/test_pooling.py b/tripy/tests/integration/test_pooling.py index 86dd45a34..1e28f956e 100644 --- a/tripy/tests/integration/test_pooling.py +++ b/tripy/tests/integration/test_pooling.py @@ -32,7 +32,7 @@ class TestPooling: ) @pytest.mark.parametrize("dtype", [tp.float32, tp.float16, tp.int8]) @pytest.mark.parametrize("pool_type", ["max", "avg"]) - def test_pool_2d(self, kernel_dims, stride, padding, dtype, pool_type): + def test_pool_2d(self, kernel_dims, stride, padding, dtype, pool_type, eager_or_compiled): inp_tp = tp.reshape(tp.arange(64, dtype=dtype), (1, 1, 8, 8)) torch_padding = (padding[0][0], padding[1][0]) @@ -40,7 +40,7 @@ def test_pool_2d(self, kernel_dims, stride, padding, dtype, pool_type): pytest.skip("Torch average pool is not implemented for int8") if pool_type == "max": - out = tp.maxpool(inp_tp, kernel_dims=kernel_dims, stride=stride, padding=padding) + out = eager_or_compiled(tp.maxpool, inp_tp, kernel_dims=kernel_dims, stride=stride, padding=padding) pool_torch = torch.nn.MaxPool2d(kernel_size=kernel_dims, stride=stride, padding=torch_padding) elif pool_type == "avg": if torch_padding != (0, 0): @@ -48,7 +48,7 @@ def test_pool_2d(self, kernel_dims, stride, padding, dtype, pool_type): "https://github.com/NVIDIA/TensorRT-Incubator/issues/241: Tripy average pool is incorrect when padding != 0." ) - out = tp.avgpool(inp_tp, kernel_dims=kernel_dims, stride=stride, padding=padding) + out = eager_or_compiled(tp.avgpool, inp_tp, kernel_dims=kernel_dims, stride=stride, padding=padding) pool_torch = torch.nn.AvgPool2d(kernel_size=kernel_dims, stride=stride, padding=torch_padding) out_torch = torch.from_dlpack(out).to("cpu") @@ -64,7 +64,7 @@ def test_pool_2d(self, kernel_dims, stride, padding, dtype, pool_type): ) @pytest.mark.parametrize("dtype", [tp.float32, tp.float16]) @pytest.mark.parametrize("pool_type", ["max", "avg"]) - def test_pool_3d(self, kernel_dims, stride, padding, dtype, pool_type): + def test_pool_3d(self, kernel_dims, stride, padding, dtype, pool_type, eager_or_compiled): inp_tp = tp.reshape(tp.arange(512, dtype=dtype), (1, 1, 8, 8, 8)) torch_padding = (padding[0][0], padding[1][0], padding[2][0]) @@ -74,10 +74,10 @@ def test_pool_3d(self, kernel_dims, stride, padding, dtype, pool_type): ) if pool_type == "max": - out = tp.maxpool(inp_tp, kernel_dims=kernel_dims, stride=stride, padding=padding) + out = eager_or_compiled(tp.maxpool, inp_tp, kernel_dims=kernel_dims, stride=stride, padding=padding) pool_torch = torch.nn.MaxPool3d(kernel_size=kernel_dims, stride=stride, padding=torch_padding) elif pool_type == "avg": - out = tp.avgpool(inp_tp, kernel_dims=kernel_dims, stride=stride, padding=padding) + out = eager_or_compiled(tp.avgpool, inp_tp, kernel_dims=kernel_dims, stride=stride, padding=padding) pool_torch = torch.nn.AvgPool3d(kernel_size=kernel_dims, stride=stride, padding=torch_padding) out_torch = torch.from_dlpack(out).to("cpu") diff --git a/tripy/tests/integration/test_quantize.py b/tripy/tests/integration/test_quantize.py index 826db83bf..ee458d108 100644 --- a/tripy/tests/integration/test_quantize.py +++ b/tripy/tests/integration/test_quantize.py @@ -30,24 +30,42 @@ class TestQuantize: @pytest.mark.parametrize( "dtype", [tp.float32, tp.float16, pytest.param(tp.bfloat16, marks=skip_if_older_than_sm80)] ) - def test_quantize_int8_per_tensor(self, dtype): + def test_quantize_int8_per_tensor(self, dtype, eager_or_compiled): input = torch.tensor([1.0, 2.0], dtype=TORCH_DTYPES[dtype]) scale = torch.tensor(0.5, dtype=TORCH_DTYPES[dtype]) input_tp = tp.Tensor(input, dtype=dtype) scale_tp = tp.Tensor(scale, dtype=dtype) - quantized = tp.quantize(input_tp, scale_tp, tp.int8) + + def func(input): + return tp.quantize(input, scale_tp, tp.int8) + + quantized = eager_or_compiled(func, input_tp) expected = (input / scale).to(dtype=torch.int8) assert torch.equal(expected, torch.from_dlpack(quantized).to("cpu")) @pytest.mark.parametrize( - "dtype", [tp.float32, tp.float16, pytest.param(tp.bfloat16, marks=skip_if_older_than_sm80)] + "dtype", + [ + tp.float32, + pytest.param( + tp.float16, + marks=pytest.mark.skip( + reason="Known float16 precision issues due to https://github.com/NVIDIA/TensorRT-Incubator/issues/392" + ), + ), + pytest.param(tp.bfloat16, marks=skip_if_older_than_sm80), + ], ) - def test_quantize_int8_per_channel(self, dtype): + def test_quantize_int8_per_channel(self, dtype, eager_or_compiled): input = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=TORCH_DTYPES[dtype]) scale = torch.tensor([0.2, 0.1], dtype=TORCH_DTYPES[dtype]) input_tp = tp.Tensor(input, dtype=dtype) scale_tp = tp.Tensor(scale, dtype=dtype) - quantized = tp.quantize(input_tp, scale_tp, tp.int8, dim=0) + + def func(input): + return tp.quantize(input, scale_tp, tp.int8, dim=0) + + quantized = eager_or_compiled(func, input_tp) expected = (input / scale.reshape(2, 1)).to(dtype=torch.int8) assert torch.equal(expected, torch.from_dlpack(quantized).to("cpu")) @@ -55,12 +73,16 @@ def test_quantize_int8_per_channel(self, dtype): "dtype", [tp.float32, tp.float16, pytest.param(tp.bfloat16, marks=skip_if_older_than_sm80)] ) @skip_if_older_than_sm89 - def test_quantize_fp8_per_tensor(self, dtype): + def test_quantize_fp8_per_tensor(self, dtype, eager_or_compiled): input = torch.tensor([1.0, 2.0], dtype=TORCH_DTYPES[dtype]) scale = torch.tensor(0.5, dtype=TORCH_DTYPES[dtype]) input_tp = tp.Tensor(input, dtype=dtype) scale_tp = tp.Tensor(scale, dtype=dtype) - quantized = tp.quantize(input_tp, scale_tp, tp.float8) + + def func(input): + return tp.quantize(input, scale_tp, tp.float8) + + quantized = eager_or_compiled(func, input_tp) assert quantized.dtype == tp.float8 expected = (input / scale).to(dtype=torch.float32) with raises( @@ -74,12 +96,16 @@ def test_quantize_fp8_per_tensor(self, dtype): "dtype", [tp.float32, tp.float16, pytest.param(tp.bfloat16, marks=skip_if_older_than_sm80)] ) @skip_if_older_than_sm89 - def test_quantize_fp8_per_channel(self, dtype): + def test_quantize_fp8_per_channel(self, dtype, eager_or_compiled): input = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=TORCH_DTYPES[dtype]) scale = torch.tensor([0.2, 0.1], dtype=TORCH_DTYPES[dtype]) input_tp = tp.Tensor(input, dtype=dtype) scale_tp = tp.Tensor(scale, dtype=dtype) - quantized = tp.quantize(input_tp, scale_tp, tp.float8, dim=0) + + def func(input): + return tp.quantize(input, scale_tp, tp.float8, dim=0) + + quantized = eager_or_compiled(func, input_tp) assert quantized.dtype == tp.float8 expected = (input / scale.reshape(2, 1)).to(dtype=torch.float32) with raises( @@ -93,7 +119,7 @@ def test_quantize_fp8_per_channel(self, dtype): "dtype", [tp.float32, tp.float16, pytest.param(tp.bfloat16, marks=skip_if_older_than_sm80)] ) @pytest.mark.parametrize("quant_mode", ["block-wise", "per-tensor", "per-channel-0", "per-channel-1"]) - def test_qdq_int4(self, dtype, quant_mode): + def test_qdq_int4(self, dtype, quant_mode, eager_or_compiled): if quant_mode == "block-wise": dim = None scale = torch.ones((2, 4), dtype=TORCH_DTYPES[dtype]) @@ -109,14 +135,22 @@ def test_qdq_int4(self, dtype, quant_mode): input = torch.ones((4, 4), dtype=TORCH_DTYPES[dtype]) input_tp = tp.Tensor(input, dtype=dtype) - scale_tp = tp.Tensor(scale) - quantized = tp.quantize(input_tp, scale_tp, tp.int4, dim) - dequantized = tp.dequantize(quantized, scale_tp, dtype, dim) + + def func(inp): + scale_tp = tp.Tensor(scale) + quantized = tp.quantize(input_tp, scale_tp, tp.int4, dim) + dequantized = tp.dequantize(quantized, scale_tp, dtype, dim) + return dequantized + + dequantized = eager_or_compiled(func, input_tp) assert torch.equal(input, torch.from_dlpack(dequantized).to("cpu")) - def test_non_constant_scale(self): + def test_non_constant_scale(self, eager_or_compiled): input = tp.ones((4, 4)) scale = tp.ones((4,)) - quantized = tp.quantize(input, scale, tp.int8, dim=0) + def func(input): + return tp.quantize(input, scale, tp.int8, dim=0) + + quantized = eager_or_compiled(func, input) assert bool(cp.all(cp.from_dlpack(quantized) == cp.ones((4, 4), dtype=cp.int8))) diff --git a/tripy/tests/integration/test_reduce.py b/tripy/tests/integration/test_reduce.py index 66db0a0f4..bb922675b 100644 --- a/tripy/tests/integration/test_reduce.py +++ b/tripy/tests/integration/test_reduce.py @@ -36,10 +36,10 @@ class TestReduceOp: ((2, 3, 4, 5), (-2, -1), True), ], ) - def test_all(self, x_shape, axis, keepdim): + def test_all(self, x_shape, axis, keepdim, eager_or_compiled): x = np.array([i % 2 == 0 for i in np.arange(np.prod(x_shape))]).reshape(x_shape) a = tp.Tensor(x) - out = tp.all(a, dim=axis, keepdim=keepdim) + out = eager_or_compiled(tp.all, a, dim=axis, keepdim=keepdim) # np.array is necessary to deal with case where x.all returns a numpy scalar (5th case) expected = np.array(x.all(axis=axis, keepdims=keepdim)) assert np.array_equal(np.from_dlpack(tp.copy(out, device=tp.device("cpu"))), expected) @@ -56,10 +56,10 @@ def test_all(self, x_shape, axis, keepdim): ((2, 3, 4, 5), (-2, -1), True), ], ) - def test_any(self, x_shape, axis, keepdim): + def test_any(self, x_shape, axis, keepdim, eager_or_compiled): x = np.array([i % 2 == 0 for i in np.arange(np.prod(x_shape))]).reshape(x_shape) a = tp.Tensor(x) - out = tp.any(a, dim=axis, keepdim=keepdim) + out = eager_or_compiled(tp.any, a, dim=axis, keepdim=keepdim) expected = np.array(x.any(axis=axis, keepdims=keepdim)) assert np.array_equal(np.from_dlpack(tp.copy(out, device=tp.device("cpu"))), expected) @@ -81,11 +81,11 @@ def test_any(self, x_shape, axis, keepdim): ], ) @pytest.mark.parametrize("dtype", [tp.float32, tp.float16]) - def test_mean(self, x_shape, axis, keepdim: bool, dtype): + def test_mean(self, x_shape, axis, keepdim: bool, dtype, eager_or_compiled): np_dtype = np.float32 if dtype == tp.float32 else np.float16 x = np.arange(np.prod(x_shape)).reshape(x_shape).astype(np_dtype) a = tp.Tensor(x, dtype=dtype) - out = tp.mean(a, dim=axis, keepdim=keepdim) + out = eager_or_compiled(tp.mean, a, dim=axis, keepdim=keepdim) expected = tp.Tensor(cp.array(x.mean(axis=axis, keepdims=keepdim))) assert out.shape == expected.shape assert tp.allclose(out, expected, rtol=1e-3, atol=1e-3) @@ -102,10 +102,10 @@ def test_mean(self, x_shape, axis, keepdim: bool, dtype): ((2, 3, 4, 5), (-2, -1), True), ], ) - def test_var(self, x_shape, axis, keepdim: bool): + def test_var(self, x_shape, axis, keepdim: bool, eager_or_compiled): x = np.arange(np.prod(x_shape)).reshape(x_shape).astype(np.float32) a = tp.Tensor(x) - out = tp.var(a, dim=axis, keepdim=keepdim) + out = eager_or_compiled(tp.var, a, dim=axis, keepdim=keepdim) torch_tensor = torch.Tensor(x) expected = tp.Tensor(torch_tensor.var(dim=axis, keepdim=keepdim)) assert out.shape == expected.shape @@ -122,10 +122,10 @@ def test_var(self, x_shape, axis, keepdim: bool): ((2, 3, 4), None, True), ], ) - def test_argmax(self, x_shape, axis, keepdim: bool): + def test_argmax(self, x_shape, axis, keepdim: bool, eager_or_compiled): x = np.arange(np.prod(x_shape)).reshape(x_shape).astype(np.float32) a = tp.Tensor(x) - out = tp.argmax(a, dim=axis, keepdim=keepdim) + out = eager_or_compiled(tp.argmax, a, dim=axis, keepdim=keepdim) assert np.array_equal(cp.from_dlpack(out).get(), np.array(x.argmax(axis=axis, keepdims=keepdim))) @pytest.mark.parametrize( @@ -139,8 +139,8 @@ def test_argmax(self, x_shape, axis, keepdim: bool): ((2, 3, 4), None, True), ], ) - def test_argmin(self, x_shape, axis, keepdim: bool): + def test_argmin(self, x_shape, axis, keepdim: bool, eager_or_compiled): x = np.arange(np.prod(x_shape)).reshape(x_shape).astype(np.float32) a = tp.Tensor(x) - out = tp.argmin(a, dim=axis, keepdim=keepdim) + out = eager_or_compiled(tp.argmin, a, dim=axis, keepdim=keepdim) assert np.array_equal(cp.from_dlpack(out).get(), np.array(x.argmin(axis=axis, keepdims=keepdim))) diff --git a/tripy/tests/integration/test_repeat.py b/tripy/tests/integration/test_repeat.py index 89b34ca43..86bc54556 100644 --- a/tripy/tests/integration/test_repeat.py +++ b/tripy/tests/integration/test_repeat.py @@ -30,18 +30,18 @@ class TestRepeat: (0, 1), ], ) - def test_repeat(self, repeats, dim): + def test_repeat(self, repeats, dim, eager_or_compiled): inp = np.arange(4, dtype=np.int32).reshape((2, 2)) - out = tp.repeat(tp.Tensor(inp), repeats, dim) + out = eager_or_compiled(tp.repeat, tp.Tensor(inp), repeats, dim) expected = np.repeat(inp, repeats, dim) assert np.array_equal(np.from_dlpack(tp.copy(out, device=tp.device("cpu"))), expected) - def test_repeat_shape_scalar(self): + def test_repeat_shape_scalar(self, eager_or_compiled): inp = np.arange(4, dtype=np.int32).reshape((2, 2)) s = tp.ones((1, 2)) - out = tp.repeat(tp.Tensor(inp), s.shape[1], 0) + out = eager_or_compiled(tp.repeat, tp.Tensor(inp), repeats=s.shape[1], dim=0) expected = np.repeat(inp, 2, 0) assert np.array_equal(np.from_dlpack(tp.copy(out, device=tp.device("cpu"))), expected) diff --git a/tripy/tests/integration/test_reshape.py b/tripy/tests/integration/test_reshape.py index c30c01501..e7343c6b6 100644 --- a/tripy/tests/integration/test_reshape.py +++ b/tripy/tests/integration/test_reshape.py @@ -31,21 +31,21 @@ class TestReshape: ((2, 4), (1, -1)), # check negative dim ], ) - def test_static_reshape(self, shape, new_shape): + def test_static_reshape(self, shape, new_shape, eager_or_compiled): cp_a = cp.arange(np.prod(shape)).reshape(shape).astype(np.float32) a = tp.Tensor(cp_a, device=tp.device("gpu")) - b = tp.reshape(a, new_shape) + b = eager_or_compiled(tp.reshape, a, new_shape) if -1 in new_shape: new_shape = tuple(np.prod(shape) // -np.prod(new_shape) if d == -1 else d for d in new_shape) assert np.array_equal(cp.from_dlpack(b).get(), cp_a.reshape(new_shape).get()) - def test_reshape_shape_tensor(self): + def test_reshape_shape_tensor(self, eager_or_compiled): a = tp.ones((2, 3, 4)) b = tp.ones((2, 3, 2, 2)) - out = tp.reshape(a, (a.shape[0], a.shape[1], b.shape[2], b.shape[3])) + out = eager_or_compiled(tp.reshape, a, (a.shape[0], a.shape[1], b.shape[2], b.shape[3])) assert np.array_equal(cp.from_dlpack(out).get(), np.ones((2, 3, 2, 2), dtype=np.float32)) - def test_reshape_shape_with_unknown(self): + def test_reshape_shape_with_unknown(self, eager_or_compiled): a = tp.ones((2, 3, 4)) - out = tp.reshape(a, (2, a.shape[1], a.shape[2] / 2, -1)) + out = eager_or_compiled(tp.reshape, a, (2, a.shape[1], a.shape[2] / 2, -1)) assert np.array_equal(cp.from_dlpack(out).get(), np.ones((2, 3, 2, 2), dtype=np.float32)) diff --git a/tripy/tests/integration/test_resize.py b/tripy/tests/integration/test_resize.py index f080ef03b..137fb82d8 100644 --- a/tripy/tests/integration/test_resize.py +++ b/tripy/tests/integration/test_resize.py @@ -24,10 +24,14 @@ class TestResize: @pytest.mark.parametrize("mode", ["nearest", "linear", "cubic"]) - def test_scales(self, mode): + def test_scales(self, mode, eager_or_compiled): inp_torch = torch.arange(16, dtype=torch.float32).reshape((1, 1, 4, 4)) inp_tp = tp.Tensor(inp_torch) - out_tp = tp.resize(inp_tp, mode, scales=(1, 1, 2, 2)) + + def resize(inp, mode, scales): + return tp.resize(inp, mode=mode, scales=scales, align_corners=False) + + out_tp = eager_or_compiled(resize, inp_tp, mode=mode, scales=(1, 1, 2, 2)) torch_mode = { "nearest": "nearest", "linear": "bilinear", @@ -39,10 +43,14 @@ def test_scales(self, mode): assert torch.allclose(out_torch, expected) @pytest.mark.parametrize("mode", ["nearest", "linear", "cubic"]) - def test_output_shape(self, mode): + def test_output_shape(self, mode, eager_or_compiled): inp_torch = torch.arange(16, dtype=torch.float32).reshape((1, 1, 4, 4)) inp_tp = tp.Tensor(inp_torch) - out_tp = tp.resize(inp_tp, mode, output_shape=[1, 1, 8, 8]) + + def resize(inp, mode, output_shape): + return tp.resize(inp, mode=mode, output_shape=output_shape, align_corners=False) + + out_tp = eager_or_compiled(resize, inp_tp, mode, output_shape=[1, 1, 8, 8]) torch_mode = { "nearest": "nearest", "linear": "bilinear", diff --git a/tripy/tests/integration/test_sequential.py b/tripy/tests/integration/test_sequential.py index b6ef3e260..1429869cc 100644 --- a/tripy/tests/integration/test_sequential.py +++ b/tripy/tests/integration/test_sequential.py @@ -21,7 +21,7 @@ class TestSequential: - def test_basic_forward_pass_accuracy(self): + def test_basic_forward_pass_accuracy(self, eager_or_compiled): torch_model = torch.nn.Sequential( torch.nn.Linear(1, 3, dtype=torch.float32, device="cuda"), torch.nn.Linear(3, 2, dtype=torch.float32, device="cuda"), @@ -36,7 +36,7 @@ def test_basic_forward_pass_accuracy(self): input_tensor = torch.tensor([[1.0]], dtype=torch.float32, device="cuda") tp_input = tp.Tensor(input_tensor, dtype=tp.float32) - tp_output = tp_model(tp_input) + tp_output = eager_or_compiled(tp_model, tp_input) torch_model.eval() with torch.no_grad(): @@ -45,7 +45,7 @@ def test_basic_forward_pass_accuracy(self): rtol_ = 2e-6 assert torch.allclose(torch.from_dlpack(tp_output), torch_output, rtol=rtol_) - def test_dict_forward_pass_accuracy(self): + def test_dict_forward_pass_accuracy(self, eager_or_compiled): torch_model = torch.nn.Sequential( torch.nn.Linear(1, 3, dtype=torch.float32, device="cuda"), torch.nn.Linear(3, 2, dtype=torch.float32, device="cuda"), @@ -63,7 +63,7 @@ def test_dict_forward_pass_accuracy(self): input_tensor = torch.tensor([[1.0]], dtype=torch.float32, device="cuda") tp_input = tp.Tensor(input_tensor, dtype=tp.float32) - tp_output = tp_model(tp_input) + tp_output = eager_or_compiled(tp_model, tp_input) torch_model.eval() with torch.no_grad(): @@ -74,7 +74,7 @@ def test_dict_forward_pass_accuracy(self): torch.from_dlpack(tp_output), torch_output, rtol=rtol_ ), "Forward pass outputs do not match." - def test_nested_forward_pass_accuracy(self): + def test_nested_forward_pass_accuracy(self, eager_or_compiled): torch_model = torch.nn.Sequential( torch.nn.Linear(1, 3, dtype=torch.float32, device="cuda"), torch.nn.Sequential( @@ -97,7 +97,7 @@ def test_nested_forward_pass_accuracy(self): input_tensor = torch.tensor([[1.0]], dtype=torch.float32, device="cuda") tp_input = tp.Tensor(input_tensor, dtype=tp.float32) - tp_output = tp_model(tp_input) + tp_output = eager_or_compiled(tp_model, tp_input) torch_model.eval() with torch.no_grad(): diff --git a/tripy/tests/integration/test_slice.py b/tripy/tests/integration/test_slice.py index 063b0245c..534ac34db 100644 --- a/tripy/tests/integration/test_slice.py +++ b/tripy/tests/integration/test_slice.py @@ -69,25 +69,31 @@ class TestSliceOp: ((5,), lambda t: t[-12:-5:-1]), ], ) - def test_static_slice_op(self, dims_a, slice_func): + def test_static_slice_op(self, dims_a, slice_func, eager_or_compiled): a_cp = cp.arange(np.prod(dims_a)).reshape(dims_a).astype(np.float32) a = tp.Tensor(a_cp, device=tp.device("gpu")) def func(a): return slice_func(a) - out = func(a) + out = eager_or_compiled(func, a) assert np.array_equal(cp.from_dlpack(out).get(), slice_func(a_cp).get()) - def test_slice_as_gather(self): + def test_slice_as_gather(self, eager_or_compiled): x_data = [0, 1, 2] y_data = [3, 4, 5] x = tp.Tensor(x_data) y = tp.Tensor(y_data) + + def slice(y, x): + return y[x] + + output = eager_or_compiled(slice, y, x) + x_cp = cp.array(x_data) y_cp = cp.array(y_data) - assert np.array_equal(cp.from_dlpack(y[x]).get(), y_cp[x_cp].get()) + assert np.array_equal(cp.from_dlpack(output).get(), y_cp[x_cp].get()) x_shape = (2, 2) y_shape = (4, 3, 2) @@ -95,7 +101,9 @@ def test_slice_as_gather(self): y_vol = math.prod(y_shape) x = tp.reshape(tp.arange(x_vol, dtype=tp.int32), x_shape) y = tp.reshape(tp.arange(y_vol), y_shape) + output = eager_or_compiled(slice, y, x) + x_cp = cp.arange(x_vol, dtype=cp.int32).reshape(x_shape) y_cp = cp.arange(y_vol).reshape(y_shape) - assert np.array_equal(cp.from_dlpack(y[x]).get(), y_cp[x_cp].get()) + assert np.array_equal(cp.from_dlpack(output).get(), y_cp[x_cp].get()) diff --git a/tripy/tests/integration/test_split.py b/tripy/tests/integration/test_split.py index 9279c98fb..f6e7ad369 100644 --- a/tripy/tests/integration/test_split.py +++ b/tripy/tests/integration/test_split.py @@ -43,16 +43,21 @@ class TestSplitOp: ((12, 12), (3, 1), lambda t: (t[:, :4], t[:, 4:8], t[:, 8:])), ((12, 12), ([3], 1), lambda t: (t[:, :3], t[:, 3:])), ((12, 12), (4, 0), lambda t: (t[:3, :], t[3:6, :], t[6:9, :], t[9:12, :])), - ((3, 0), (5, 1), lambda t: (t[:, :0], t[:, 0:0], t[:, 0:0], t[:, 0:0], t[:, 0:0])), + pytest.param( + (3, 0), + (5, 1), + lambda t: (t[:, :0], t[:, 0:0], t[:, 0:0], t[:, 0:0], t[:, 0:0]), + marks=pytest.mark.skip(reason="https://github.com/NVIDIA/TensorRT-Incubator/issues/398"), + ), ], ) - def test_split_static(self, dims_a, split_params, reference_slices): + def test_split_static(self, dims_a, split_params, reference_slices, eager_or_compiled): a_cp = cp.arange(np.prod(dims_a)).reshape(dims_a).astype(cp.float32) a = tp.Tensor(a_cp, device=tp.device("gpu")) def func(t): return tp.split(t, split_params[0], split_params[1]) - out = func(a) + out = eager_or_compiled(func, a) reference_out = reference_slices(a_cp) compare_split_results(out, reference_out) diff --git a/tripy/tests/integration/test_stack.py b/tripy/tests/integration/test_stack.py index be1f724b5..796bcc26b 100644 --- a/tripy/tests/integration/test_stack.py +++ b/tripy/tests/integration/test_stack.py @@ -33,9 +33,9 @@ class TestStack: ([(2, 3, 4)], 0), ], ) - def test_stack(self, tensor_shapes, dim): + def test_stack(self, tensor_shapes, dim, eager_or_compiled): tensors = [tp.ones(shape) for shape in tensor_shapes] - out = tp.stack(tensors, dim=dim) + out = eager_or_compiled(tp.stack, tensors, dim=dim) # Create numpy arrays for comparison np_tensors = [np.ones(shape) for shape in tensor_shapes] @@ -44,13 +44,13 @@ def test_stack(self, tensor_shapes, dim): assert out.shape == list(expected.shape) assert np.array_equal(cp.from_dlpack(out).get(), expected) - def test_stack_different_ranks(self): + def test_stack_different_ranks(self, eager_or_compiled): tensors = [tp.ones((2, 3)), tp.ones((2, 3, 4))] with raises( tp.TripyException, match="Expected all input tensors to have the same rank.", ): - tp.stack(tensors) + eager_or_compiled(tp.stack, tensors) def test_stack_different_shapes(self): a = tp.ones((2, 3)) diff --git a/tripy/tests/integration/test_unary_elementwise.py b/tripy/tests/integration/test_unary_elementwise.py index e01ca3fff..e89a37d6c 100644 --- a/tripy/tests/integration/test_unary_elementwise.py +++ b/tripy/tests/integration/test_unary_elementwise.py @@ -35,7 +35,7 @@ class TestUnaryElementwise: @pytest.mark.parametrize("tp_func, np_func", [(tp_func, np_func) for tp_func, np_func in _UNARY_OPS.items()]) - def test_op_funcs(self, tp_func, np_func): + def test_op_funcs(self, tp_func, np_func, eager_or_compiled): input = tp.arange(1, 4, dtype=tp.float32) - output = tp_func(input) + output = eager_or_compiled(tp_func, input) assert tp.allclose(output, tp.Tensor(np_func(cp.from_dlpack(input).get()))) diff --git a/tripy/tests/integration/test_unsqueeze.py b/tripy/tests/integration/test_unsqueeze.py index e25d459b1..4402449fc 100644 --- a/tripy/tests/integration/test_unsqueeze.py +++ b/tripy/tests/integration/test_unsqueeze.py @@ -24,13 +24,13 @@ class TestUnsqueezeOp: @pytest.mark.parametrize("axis", [-1, 0, 2]) - def test_unsqueeze_dynamic_op(self, axis): + def test_unsqueeze_dynamic_op(self, axis, eager_or_compiled): def func(a): return tp.unsqueeze(a, dim=axis) inp = np.ones((4, 2, 2, 3), dtype=np.float32) - out = func(tp.Tensor(inp)) + out = eager_or_compiled(func, tp.Tensor(inp)) ref_out = np.expand_dims(inp, axis=axis) assert tp.allclose(out, tp.Tensor(ref_out)) diff --git a/tripy/tests/integration/test_where_op.py b/tripy/tests/integration/test_where_op.py index 36d4839f5..5f37b5724 100644 --- a/tripy/tests/integration/test_where_op.py +++ b/tripy/tests/integration/test_where_op.py @@ -35,19 +35,19 @@ class TestWhereOp: ((0,), (1,), (1,)), # 0 dim in the condition ], ) - def test_where_broadcast_shapes(self, cond, x, y): + def test_where_broadcast_shapes(self, cond, x, y, eager_or_compiled): x = np.arange(np.prod(x)).reshape(x).astype(np.float32) y = np.arange(np.prod(y)).reshape(y).astype(np.float32) t_cond = np.arange(np.prod(cond)).reshape(cond).astype(np.float32) a = Tensor(x) b = Tensor(y) condition = Tensor(t_cond % 2 == 0) - out = tp.where(condition, a, b) + out = eager_or_compiled(tp.where, condition, a, b) assert np.array_equal(cp.from_dlpack(out).get(), np.array(np.where((t_cond % 2 == 0), x, y))) - def test_explicit_condition(self): + def test_explicit_condition(self, eager_or_compiled): select_indices = tp.Tensor([True, False, True, False], dtype=tp.bool) ones = tp.ones((4,), dtype=tp.int32) zeros = tp.zeros((4,), dtype=tp.int32) - w = tp.where(select_indices, ones, zeros) + w = eager_or_compiled(tp.where, select_indices, ones, zeros) assert cp.from_dlpack(w).get().tolist() == [1, 0, 1, 0] From b04d42023f4903e59037d3fe0c044be56b5716aa Mon Sep 17 00:00:00 2001 From: Parth Chadha Date: Thu, 21 Nov 2024 10:43:46 -0800 Subject: [PATCH 26/29] Update mlir-tensorrt dependency version in Tripy (#399) --- tripy/docs/packages.html | 3 +++ tripy/pyproject.toml | 6 +++--- tripy/tripy/__init__.py | 2 +- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/tripy/docs/packages.html b/tripy/docs/packages.html index 28063782d..cc2716d6b 100644 --- a/tripy/docs/packages.html +++ b/tripy/docs/packages.html @@ -9,6 +9,9 @@

Package Index

+ tripy-0.0.5-py3-none-any.whl
+ tripy-0.0.4-py3-none-any.whl
diff --git a/tripy/pyproject.toml b/tripy/pyproject.toml index a22dc06dc..aed7e2c37 100644 --- a/tripy/pyproject.toml +++ b/tripy/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "tripy" -version = "0.0.4" +version = "0.0.5" authors = [{name = "NVIDIA", email="svc_tensorrt@nvidia.com"}] description = "Tripy: A Python Programming Model For TensorRT" readme = "README.md" @@ -8,8 +8,8 @@ requires-python = ">= 3.9" license = {text = "Apache 2.0"} dependencies = [ "tensorrt~=10.0", - "mlir-tensorrt-compiler==0.1.36+cuda12.trt102", - "mlir-tensorrt-runtime==0.1.36+cuda12.trt102", + "mlir-tensorrt-compiler==0.1.37+cuda12.trt102", + "mlir-tensorrt-runtime==0.1.37+cuda12.trt102", "colored==2.2.3", ] diff --git a/tripy/tripy/__init__.py b/tripy/tripy/__init__.py index 642f0ec16..48e88d7a8 100644 --- a/tripy/tripy/__init__.py +++ b/tripy/tripy/__init__.py @@ -15,7 +15,7 @@ # limitations under the License. # -__version__ = "0.0.4" +__version__ = "0.0.5" # Import TensorRT to make sure all dependent libraries are loaded first. import tensorrt From a2d3d11cb642d56ebd6ae06aa5691faa71c0cbc6 Mon Sep 17 00:00:00 2001 From: pranavm Date: Thu, 21 Nov 2024 13:01:05 -0800 Subject: [PATCH 27/29] Updates various guides - Updates `RELEASE.md` to mention things to be careful of when creating a new release. - Updates `debugging.md` to be more concise. --- tripy/RELEASE.md | 10 ++++- .../post0_developer_guides/architecture.md | 1 - .../docs/post0_developer_guides/debugging.md | 42 +++++++++++-------- 3 files changed, 33 insertions(+), 20 deletions(-) diff --git a/tripy/RELEASE.md b/tripy/RELEASE.md index 247464180..5be715e91 100644 --- a/tripy/RELEASE.md +++ b/tripy/RELEASE.md @@ -5,15 +5,21 @@ This document explains how to release a new version of Tripy. 1. Update version numbers in [`pyproject.toml`](./pyproject.toml) and [`__init__.py`](./tripy/__init__.py) (make sure they match!). + Often, updates to Tripy will also require updates to dependencies, + like MLIR-TRT, so make sure to update those version numbers as well. + 2. Add a new entry to [`packages.html`](./docs/packages.html). This ensures that we will be able to `pip install` Tripy. 3. If there were any other functional changes since the most recent L1, make sure to run L1 testing locally. -4. Create a PR with the above two changes. +4. Create a PR with the above changes. + +5. Once the PR created in (4) is merged, **WAIT FOR THE POST-MERGE PIPELINES TO COMPLETE**. + This is a very important step as otherwise the release pipeline could fail. -5. Once the PR created in (4) is merged, create a new tag with: + Once the post-merge pipelines have succeeded, create a new tag with: ```bash git tag tripy-vX.Y.Z ``` diff --git a/tripy/docs/post0_developer_guides/architecture.md b/tripy/docs/post0_developer_guides/architecture.md index d824703a4..914e29956 100644 --- a/tripy/docs/post0_developer_guides/architecture.md +++ b/tripy/docs/post0_developer_guides/architecture.md @@ -3,7 +3,6 @@ This document explains the overall architecture of Tripy. - ## Overview The main technical requirement of Tripy is twofold: diff --git a/tripy/docs/post0_developer_guides/debugging.md b/tripy/docs/post0_developer_guides/debugging.md index 64beb03dc..6b60c5c71 100644 --- a/tripy/docs/post0_developer_guides/debugging.md +++ b/tripy/docs/post0_developer_guides/debugging.md @@ -1,29 +1,37 @@ -# Debugging MLIR-TensorRT backend +# Debugging MLIR-TensorRT -1. Install new python bindings for compiler and runtime. Assuming `tripy/mlir-tensorrt` directory exists. No need to update `LD_LIBRARY_PATH`. +While developing Tripy features, you may need to debug MLIR-TRT code. +This guide outlines some methods of doing so. - - ```bash - python3 -m pip install --force-reinstall mlir-tensorrt/build/wheels/trt100/**/*.whl - ``` - -2. Set environment flags for debugging: +## Environment Variables + +We include some environment variables to enable extra debugging information from MLIR-TRT: + +- `export TRIPY_MLIR_DEBUG_ENABLED=1` will enable debug prints in MLIR-TRT and dump all intermediate IRs to a directory. +- `export TRIPY_MLIR_DEBUG_PATH=` sets the directory for IR dumps. The default path is `mlir-dumps`. +- `export TRIPY_TRT_DEBUG_ENABLED=1` will dump TensorRT engines and their layer information. +- `export TRIPY_TRT_DEBUG_PATH=` sets the directory for TensorRT dumps. Default path is `tensorrt-dumps`. -- `export TRIPY_MLIR_DEBUG_ENABLED=1` to enable MLIR-TRT debugging. It will enable debugging prints in MLIR-TRT as well as dump all intermediate IRs after each pass. -- `export TRIPY_MLIR_DEBUG_PATH=` to set debug path for MLIR-TRT dumps. Default path is `mlir-dumps` under the repo directory. This will create one or more folders named like `module_ins_t1_outs_t2_1`. -- `export TRIPY_TRT_DEBUG_ENABLED=1` to enable TensorRT debugging. It will dump TensorRT engines and their layer information (if there are any TensorRT built during compilation). -- `export TRIPY_TRT_DEBUG_PATH=` to set debug path for TensorRT dumps. Default path is `tensorrt-dumps` under the repo directory. +## Using A Debugger -3. Use LLDB for debugging MLIR-TensorRT backend. -In order to use `lldb` in tripy container, launch the container with extra security options: +For more involved bugs, it may be helpful to step into MLIR-TRT code. +To do so, you will need a debug build of MLIR-TRT; +see [CONTRIBUTING.md](source:/CONTRIBUTING.md) +for details on using custom builds of MLIR-TRT. + +Once you've installed the debug build in the container, you should be able to use `gdb` as normal. + +Alternatively, you can use [LLDB](https://lldb.llvm.org/) if you launch the container with extra security options: ```bash docker run --gpus all --cap-add=SYS_PTRACE \ - --security-opt seccomp=unconfined --security-opt apparmor=unconfined \ - -p 8080:8080 -v $(pwd):/tripy/ -it --rm tripy:latest + --security-opt seccomp=unconfined --security-opt apparmor=unconfined \ + -p 8080:8080 -v $(pwd):/tripy/ -it --rm tripy:latest ``` -See https://forums.swift.org/t/debugging-using-lldb/18046 for more details. + +See [this post](https://forums.swift.org/t/debugging-using-lldb/18046) for details on +why these security options are required. From 7fc38c87cb87c620898375e0b1282602a044b277 Mon Sep 17 00:00:00 2001 From: pranavm Date: Thu, 21 Nov 2024 15:15:48 -0800 Subject: [PATCH 28/29] Formats code in documentation with black This change updates our documentation generation to format all code blocks with `black`. This helps ensure that lines do not overflow and create scrollable elements in the rendered docs. For lines that `black` doesn't touch (e.g. comments), we include our own assertions to ensure they don't exceed the length limits. The logic to extract code from a code block (i.e. stripping out markup) has also been consolidated. Finally, this change removes the `manual` test cadence and related tests since they were redundant (the same things are tested during doc generation) and didn't add any value (the thought was they would make it easier to debug, but (1) they don't and (2) it's rarely difficult to figure out what's wrong with the code blocks in documentation) --- tripy/docs/conf.py | 3 +- tripy/docs/generate_rsts.py | 3 +- .../how-to-add-new-ops.md | 73 ++++---- tripy/pyproject.toml | 2 +- tripy/tests/README.md | 11 +- tripy/tests/helper.py | 157 ++++++++++-------- tripy/tests/test_examples.py | 10 +- tripy/tests/test_helper.py | 18 +- tripy/tests/test_internal_docs.py | 20 --- tripy/tests/test_ux.py | 14 -- tripy/tripy/backend/api/compile.py | 6 +- tripy/tripy/frontend/trace/ops/plugin.py | 7 +- 12 files changed, 164 insertions(+), 160 deletions(-) diff --git a/tripy/docs/conf.py b/tripy/docs/conf.py index 762e289ea..8b30df9a1 100644 --- a/tripy/docs/conf.py +++ b/tripy/docs/conf.py @@ -286,10 +286,9 @@ def allow_no_example(): code_block_lines, local_var_lines, output_lines, _ = helper.process_code_block_for_outputs_and_locals( block, - block.code(), format_contents=lambda title, contents, lang: f"\n\n.. code-block:: {lang}\n" + indent((f":caption: {title}" if title else "") + f"\n\n{contents}", prefix=" " * helper.TAB_SIZE), - err_msg=f"Failed while processing docstring for: {what}: {name} ({obj})", + err_msg=f"Failed while processing docstring for: {what}: {name} ({obj}): ", strip_assertions=True, ) diff --git a/tripy/docs/generate_rsts.py b/tripy/docs/generate_rsts.py index 6a31b8bbb..5f9283b98 100644 --- a/tripy/docs/generate_rsts.py +++ b/tripy/docs/generate_rsts.py @@ -208,9 +208,8 @@ def add_block(title, contents, lang): code_block_lines, local_var_lines, output_lines, code_locals = ( helper.process_code_block_for_outputs_and_locals( block.raw_str(), - str(block), format_contents=add_block, - err_msg=f"Error while executing code block from {guide_path}.", + err_msg=f"Error while executing code block {index} (line {block.line_number}) from {guide_path}. ", local_vars=code_locals, ) ) diff --git a/tripy/docs/post0_developer_guides/how-to-add-new-ops.md b/tripy/docs/post0_developer_guides/how-to-add-new-ops.md index 928c52e69..87abe6fd7 100644 --- a/tripy/docs/post0_developer_guides/how-to-add-new-ops.md +++ b/tripy/docs/post0_developer_guides/how-to-add-new-ops.md @@ -47,9 +47,10 @@ from tripy.flat_ir.ops.base import BaseFlatIROp class ThetaOp(BaseFlatIROp): dim: int - # `to_mlir()` is the trickiest bit. As the name implies, the method is meant to lower the - # `FlatIR` operator into MLIR. To figure out which MLIR operators to use, refer to - # the 'MLIR Python API Guide' (linked below). + # `to_mlir()` is the trickiest bit. As the name implies, the method is + # meant to lower the `FlatIR` operator into MLIR. To figure out which + # MLIR operators to use, refer to the 'MLIR Python API Guide' + # (linked below). def to_mlir(self, operands): out_type = self.outputs[0].to_mlir() theta_dim = ir.IntegerAttr.get(type=ir.IntegerType.get_signless(64), value=self.dim) @@ -116,29 +117,31 @@ from tripy.frontend.trace.ops.base import BaseTraceOp import tripy.frontend.trace.ops.utils as op_utils -# Just like with `FlatIR` operators, all `Trace` operators are implemented as `dataclass`es. -# As before, we want `repr=False` here. +# Just like with `FlatIR` operators, all `Trace` operators are implemented +# as `dataclass`es. As before, we want `repr=False` here. @dataclass(repr=False) class Theta(BaseTraceOp): - # Notice that we do *not* need to define a constructor and can rely on the default - # implementation provided by `dataclass`. + # Notice that we do *not* need to define a constructor and can rely on + # the default implementation provided by `dataclass`. dim: int dtype: datatype.dtype # `infer_rank()` populates the rank of the output `TraceTensor`s. - # Here we use one of the predefined policies to set the output rank to the same as the shape (i.e. the length) - # of the shape operand. + # Here we use one of the predefined policies to set the output rank + # to the same as the shape (i.e. the length) of the shape operand. infer_rank = op_utils.InferRankPolicies.same_as_shape_of_shape_input() # *Optional* `infer_dtypes()` populates the data types of the # output `TraceTensor`s. The default implementation copies the input - # data types if they are all the same, so you may not need to implement this. + # data types if they are all the same, so you may not need to implement + # this. def infer_dtypes(self): self.outputs[0].dtype = self.dtype # *Optional* `infer_devices()` populates the devices of the # output `TraceTensor`s. The default implementation copies the input - # devices if they are all the same, so you may not need to implement this either. + # devices if they are all the same, so you may not need to implement + # this either. def infer_devices(self): self.outputs[0].device = device("gpu") @@ -177,30 +180,35 @@ from tripy import export import tripy.frontend.utils as frontend_utils from tripy.types import ShapeLike -# We can use the `export.public_api()` decorator to automatically export this function into the -# top-level module. This means it will be accessible as `tripy.theta`. +# We can use the `export.public_api()` decorator to automatically export this +# function into the top-level module. This means it will be accessible as +# `tripy.theta`. # -# This decorator also controls how the API is exposed in the documentation - the `document_under` -# option determines where in the documentation hierarchy this API will show up. +# This decorator also controls how the API is exposed in the documentation - +# the `document_under` option determines where in the documentation hierarchy +# this API will show up. # -# If we needed to provide any special autodoc options, we could use the `autodoc_options` parameter. +# If we needed to provide any special autodoc options, we could use the +# `autodoc_options` parameter. @export.public_api(document_under="tensor_operations") -# The `convert_to_tensors` decorator automatically converts compatible arguments, -# like `TensorLike` or `ShapeLike`s, into tensors. +# The `convert_to_tensors` decorator automatically converts compatible +# arguments, like `TensorLike` or `ShapeLike`s, into tensors. @frontend_utils.convert_to_tensors() def theta(shape: ShapeLike, dim: int = 0, dtype: datatype.dtype = datatype.float32) -> "tripy.Tensor": - # For any public facing interfaces, we have documentation requirements which you can read - # about in the 'Docs README' (linked below). The docstring we've implemented here - # adheres to all of these requirements. Non-compliant docstrings will, in most cases, - # cause test failures; however, you should still manually ensure you're writing high-quality - # docstrings. + # For any public facing interfaces, we have documentation requirements which + # you can read about in the 'Docs README' (linked below). The docstring + # we've implemented here adheres to all of these requirements. Non-compliant + # docstrings will, in most cases, cause test failures; however, you should + # still manually ensure you're writing high-quality docstrings. # - # The examples in docstrings are run as part of our tests, so you should also add - # assertions to make sure things are functionally correct. In this case, we check - # that the `output` we create in the code example is what we expect. + # The examples in docstrings are run as part of our tests, so you should + # also add assertions to make sure things are functionally correct. In this + # case, we check that the `output` we create in the code example is what we + # expect. """ - Fills an output tensor with consecutive values starting from zero along the given dimension. + Fills an output tensor with consecutive values starting from zero + along the given dimension. Args: shape: The desired shape. @@ -217,12 +225,15 @@ def theta(shape: ShapeLike, dim: int = 0, dtype: datatype.dtype = datatype.float output = tp.theta([3]) - assert np.array_equal(cp.from_dlpack(output).get(), np.arange(0, 3, dtype=np.float32)) + assert np.array_equal( + cp.from_dlpack(output).get(), np.arange(0, 3, dtype=np.float32) + ) """ - # Next we build the trace operator. The `build()` function is also responsible for constructing - # the output frontend Tensors. All of the arguments that follow the inputs - # are forwarded directly to the constructor of the `Trace` operator. + # Next we build the trace operator. The `build()` function is also + # responsible for constructing the output frontend Tensors. All of the + # arguments that follow the inputs are forwarded directly to the + # constructor of the `Trace` operator. return Theta.build([shape], dim, dtype) ``` diff --git a/tripy/pyproject.toml b/tripy/pyproject.toml index aed7e2c37..e624cc74f 100644 --- a/tripy/pyproject.toml +++ b/tripy/pyproject.toml @@ -40,6 +40,7 @@ build = [ "mypy==1.11.0", ] doc_test_common = [ + "black==24.10.0", "torch==2.4.0+cu121", "numpy==1.25.0", # cupy requires NVRTC but does not specify it as a package dependency @@ -96,5 +97,4 @@ testpaths = [ addopts = "--strict-markers" markers = [ "l1: Indicates that the test should only be run in nightlies.", - "manual: Disables tests in automation", ] diff --git a/tripy/tests/README.md b/tripy/tests/README.md index c971a9c3d..91399d5f7 100644 --- a/tripy/tests/README.md +++ b/tripy/tests/README.md @@ -18,13 +18,12 @@ You can also provide marker arguments to only run specific test cadences L0 tests, use: ```bash -pytest tests/ -v -m "not l1 and not manual" -n 4 --dist worksteal --ignore tests/performance -pytest tests/performance -v -m "not l1 and not manual" +pytest tests/ -v -m "not l1" -n 4 --dist worksteal --ignore tests/performance +pytest tests/performance -v -m "not l1" ``` -Note that the L0/L1 tests can be parallelized, which is not necessarily -true of `manual` tests. In that case, performance tests are run separately -because they must run serially to ensure accurate measurements. +Note that the L0/L1 tests can be parallelized. In that case, performance tests +are run separately because they must run serially to ensure accurate measurements. ## Profiling @@ -36,7 +35,7 @@ tests together. For example, to profile L0 tests, run: ```bash -pytest tests/ -v -m "not l1 and not manual" --ignore tests/performance --profile +pytest tests/ -v -m "not l1" --ignore tests/performance --profile ``` You can visualize the results using `snakeviz`. diff --git a/tripy/tests/helper.py b/tripy/tests/helper.py index ebefc94bb..d072c8a00 100644 --- a/tripy/tests/helper.py +++ b/tripy/tests/helper.py @@ -27,6 +27,7 @@ from textwrap import dedent, indent from typing import Any, Callable, Dict, List, Optional, Sequence, Set +import black import cupy as cp import numpy as np import pytest @@ -117,20 +118,25 @@ def config(name: str, value: Any): } -class DocstringCodeBlock(str): - def code(self) -> str: - # Special directives can be used in the code blocks and they should be - # excluded from the actual code. - def is_directive(line): - if not line.strip().startswith(":"): - return False - tokens = line.strip().split(" ") - if not tokens: - return False - return tokens[0].endswith(":") +def get_code_bounds(lines): + # Returns the start and end index of lines of pure code in a block. The block may contain backticks + # or RST markup indicating a code block. + code_start = len(lines) + code_end = 0 + BLOCK_MARKUP = {"```", ".. code-block::", ":"} + for index, line in enumerate(lines): + line = line.strip() + if line and not any(line.startswith(markup) for markup in BLOCK_MARKUP): + code_start = min(index, code_start) + + if line != "```": + code_end = max(index, code_end) + code_end += 1 + return code_start, code_end + - text = self.replace(".. code-block:: python", "", 1) - return "\n".join([line for line in text.splitlines() if not is_directive(line)]) +class DocstringCodeBlock(str): + pass def consolidate_code_blocks(doc): @@ -241,45 +247,6 @@ def get_all_tripy_interfaces(): return all_objects -def get_all_docstrings_with_examples(): - def get_qualname(obj): - if isinstance(obj, property): - return obj.fget.__qualname__ - return obj.__qualname__ - - # Because of our complicated method registration logic, the free function and method - # might both be recognized as separate objects by `get_all_tripy_interfaces()`. - # In order to avoid redundant testing, we compare the docstrings directly instead. - seen_docstring_hashes = set() - docstrings = [] - ids = [] - tripy_interfaces = get_all_tripy_interfaces() - for obj in tripy_interfaces: - if not obj.__doc__: - print(f"Skipping {get_qualname(obj)} because no docstring was present") - continue - - doc_hash = hash(obj.__doc__) - if doc_hash in seen_docstring_hashes: - print(f"Skipping {get_qualname(obj)} because it duplicates the docstring of another interface") - continue - seen_docstring_hashes.add(doc_hash) - - blocks = [ - dedent(block.code()) - for block in consolidate_code_blocks(obj.__doc__) - if isinstance(block, DocstringCodeBlock) - ] - if blocks is None: - print(f"Skipping {get_qualname(obj)} because no example was present in the docstring") - continue - - docstrings.extend(blocks) - ids.extend([f"{get_qualname(obj)}:{idx}" for idx in range(len(blocks))]) - - return docstrings, ids - - ## ## Working with READMEs ## @@ -379,12 +346,11 @@ def exiting(self, marker): class ReadmeCodeBlock: - def __init__(self, markers: Set[Marker], lang: str): + def __init__(self, markers: Set[Marker], lang: str, line_number: int): self.content: str = None self.markers = markers self.lang = lang - self.start_line = "" - self.end_line = "" + self.line_number = line_number def add(self, line: str): if self.content is None: @@ -396,7 +362,10 @@ def has_marker(self, name: str): return AVAILABLE_MARKERS[name] in self.markers def __str__(self): - return self.content or "" + content = self.content or "" + lines = content.splitlines() + start, end = get_code_bounds(lines) + return "\n".join(lines[start:end]) def __bool__(self): return bool(self.content) @@ -404,35 +373,36 @@ def __bool__(self): # Returns the original raw contents of the block. # This will include the backticks that were stripped out by the consolidation function. def raw_str(self) -> str: - contents = str(self) - if self.lang == "text": - return contents - return f"{self.start_line}\n{contents}\n{self.end_line}" + return self.content or "" # Extract any ``` blocks from the README at the specified path def consolidate_code_blocks_from_readme(readme_path: str) -> List[ReadmeCodeBlock]: cmd_blocks = [] - current_block = ReadmeCodeBlock(markers=set(), lang="text") + current_block = ReadmeCodeBlock(markers=set(), lang="text", line_number=0) with MarkerTracker(readme_path) as tracker: previous_markers = copy.copy(tracker.active_markers) - for line in tracker: + for index, line in enumerate(tracker): # We use copy here so we don't accidentally alias. if tracker.entering(AVAILABLE_MARKERS["command"]): # Append previous text block before creating a new block for the command. cmd_blocks.append(copy.copy(current_block)) lang = line.strip().partition("```")[-1] - current_block = ReadmeCodeBlock(markers=copy.copy(tracker.active_markers), lang=lang) - current_block.start_line = line + current_block = ReadmeCodeBlock(markers=copy.copy(tracker.active_markers), lang=lang, line_number=index) + current_block.add(line) elif tracker.exiting(AVAILABLE_MARKERS["command"]): - current_block.end_line = line + current_block.add(line) cmd_blocks.append(copy.copy(current_block)) # Create new text block for contents between command blocks - current_block = ReadmeCodeBlock(markers=copy.copy(tracker.active_markers), lang="text") + current_block = ReadmeCodeBlock( + markers=copy.copy(tracker.active_markers), lang="text", line_number=index + ) elif tracker.active_markers != previous_markers: cmd_blocks.append(copy.copy(current_block)) # When markers change, create a new text block - current_block = ReadmeCodeBlock(markers=copy.copy(tracker.active_markers), lang="text") + current_block = ReadmeCodeBlock( + markers=copy.copy(tracker.active_markers), lang="text", line_number=index + ) else: current_block.add(line) @@ -453,7 +423,6 @@ def consolidate_code_blocks_from_readme(readme_path: str) -> List[ReadmeCodeBloc def process_code_block_for_outputs_and_locals( block: str, - code: str, format_contents: Callable[[str, str, str], str], err_msg: str = "", local_vars: Dict[str, Any] = None, @@ -483,7 +452,7 @@ def process_code_block_for_outputs_and_locals( # Set of variables *not* to print no_print_vars = set() - code_block_lines = [] + stripped_code_block_lines = [] # All code except what was requested to be omitted. output_lines = [] local_var_lines = [] @@ -510,12 +479,54 @@ def process_code_block_for_outputs_and_locals( if any(block_line.strip().startswith(tag) for tag in REMOVE_TAGS) or block_line.endswith(OMIT_COMMENT): continue - code_block_lines.append(block_line) + stripped_code_block_lines.append(block_line) + + # Format the code portion of the block with black. We use a shorter + # line length so it doesn't overflow in the rendered docs: + MAX_LINE_LENGTH = 80 + stripped_code_start, stripped_code_end = get_code_bounds(stripped_code_block_lines) + stripped_code_lines = stripped_code_block_lines[stripped_code_start:stripped_code_end] + + indentation = len(stripped_code_lines[0]) - len(stripped_code_lines[0].lstrip()) + try: + stripped_code_lines = indent( + black.format_file_contents( + dedent("\n".join(stripped_code_lines)), fast=False, mode=black.Mode(line_length=MAX_LINE_LENGTH) + ) + + "\n", + prefix=" " * indentation, + ).splitlines() + except black.NothingChanged: + pass + + # Check that comments don't exceed maximum line length. Note that `black` will not automatically split + # comments, so this needs to be done manually. Without this, each code block will become a scrollable + # element, making it very annoying to read. It is also annoying to fix this manually, but it is a one + # time cost and makes the reading experience so much better. + too_long_lines = [] + for line in stripped_code_lines: + # The indentation of the code block doesn't show up in the rendered documentation + # (indentation *within* the block obviously will.) + if len(line) - indentation > MAX_LINE_LENGTH: + too_long_lines.append(f">| {line}") + too_long_lines = "\n".join(too_long_lines) + assert ( + not too_long_lines + ), f"{err_msg}One or more lines exceed maximum line length ({MAX_LINE_LENGTH} characters). Note: lines were:\n{too_long_lines}\n" + + stripped_code_block_lines = ( + stripped_code_block_lines[:stripped_code_start] + + stripped_code_lines + + stripped_code_block_lines[stripped_code_end:] + ) if not should_eval: - return code_block_lines, local_var_lines, output_lines, local_vars + return stripped_code_block_lines, local_var_lines, output_lines, local_vars - code = dedent(code) + # When we run the code, we need to get the original code, not the strpiped one. + block_lines = block.splitlines() + code_start, code_end = get_code_bounds(block_lines) + code = dedent("\n".join(block_lines[code_start:code_end])) with capture_output() as outfile: try: @@ -604,4 +615,4 @@ def split_block_lines(title, contents, lang="python"): stdout = ANSI_ESCAPE.sub("", stdout) output_lines = split_block_lines("Output:", stdout, lang="") - return code_block_lines, local_var_lines, output_lines, code_locals + return stripped_code_block_lines, local_var_lines, output_lines, code_locals diff --git a/tripy/tests/test_examples.py b/tripy/tests/test_examples.py index e24b906fd..1faf6ed01 100644 --- a/tripy/tests/test_examples.py +++ b/tripy/tests/test_examples.py @@ -99,18 +99,20 @@ def test_examples(example, sandboxed_install_run): if block.has_marker("test: ignore") or not block.has_marker("command"): continue - block_text = str(block) + code = str(block) if block.has_marker("test: expected_stdout"): print("Checking command output against expected output: ", end="") out = statuses[-1].stdout.strip() - matched = re.match(dedent(block_text).strip(), out) + matched = re.match(dedent(code).strip(), out) print("matched!" if matched else "did not match!") print(f"==== STDOUT ====\n{out}") assert matched else: - status = example.run(block_text, sandboxed_install_run) + status = example.run(code, sandboxed_install_run) - details = f"Note: Command was: {block_text}.\n==== STDOUT ====\n{status.stdout}\n==== STDERR ====\n{status.stderr}" + details = ( + f"Note: Command was: {code}.\n==== STDOUT ====\n{status.stdout}\n==== STDERR ====\n{status.stderr}" + ) if block.has_marker("test: xfail"): assert not status.success, f"Command that was expected to fail did not fail. {details}" else: diff --git a/tripy/tests/test_helper.py b/tripy/tests/test_helper.py index 8e03c3f85..36d0ff211 100644 --- a/tripy/tests/test_helper.py +++ b/tripy/tests/test_helper.py @@ -1,3 +1,17 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from tests import helper @@ -13,7 +27,7 @@ def test_non_tripy_types_not_printed_as_locals(self): b = "42" """ - _, local_var_lines, _, _ = helper.process_code_block_for_outputs_and_locals(block, block, format_contents) + _, local_var_lines, _, _ = helper.process_code_block_for_outputs_and_locals(block, format_contents) assert not local_var_lines @@ -24,6 +38,6 @@ def test_no_print_locals(self): cpu = tp.device("cpu") """ - _, local_var_lines, _, _ = helper.process_code_block_for_outputs_and_locals(block, block, format_contents) + _, local_var_lines, _, _ = helper.process_code_block_for_outputs_and_locals(block, format_contents) assert not local_var_lines diff --git a/tripy/tests/test_internal_docs.py b/tripy/tests/test_internal_docs.py index 49a91703d..68680165c 100644 --- a/tripy/tests/test_internal_docs.py +++ b/tripy/tests/test_internal_docs.py @@ -42,7 +42,6 @@ # Guides may use inline pytest tests or regular Python code snippets. INLINE_PYTESTS = {} -CODE_BLOCKS = {} for readme, code_blocks in ALL_DOC_CODE_BLOCKS.items(): if not code_blocks: @@ -54,7 +53,6 @@ assert not any( block.has_marker("test: use_pytest") for block in code_blocks ), "Guides must not mix Pytest code blocks with non-Pytest code blocks" - CODE_BLOCKS[readme] = code_blocks @pytest.mark.parametrize( @@ -68,21 +66,3 @@ def test_inline_pytest(code_blocks): f.write(code) f.flush() assert pytest.main([f.name, "-vv", "-s"]) == 0 - - -@pytest.mark.manual # Code snippets in guides are executed during doc generation. -@pytest.mark.parametrize( - "code_blocks", - CODE_BLOCKS.values(), - ids=CODE_BLOCKS.keys(), -) -def test_python_code_snippets(code_blocks): - code_locals = {} - for block in code_blocks: - print(f"Checking code block:\n{str(block)}") - try: - new_locals = helper.exec_code(str(block), code_locals) - # Update code_locals with new variables - code_locals.update(new_locals) - except Exception as e: - raise AssertionError(f"Error while executing code block: {str(e)}") from e diff --git a/tripy/tests/test_ux.py b/tripy/tests/test_ux.py index 4cbe041df..39ec0a03e 100644 --- a/tripy/tests/test_ux.py +++ b/tripy/tests/test_ux.py @@ -99,21 +99,7 @@ def test_links_valid(self, readme): raise -DOCSTRING_TEST_CASES, DOCSTRING_IDS = helper.get_all_docstrings_with_examples() - - class TestDocstrings: - @pytest.mark.manual # This is already tested during doc generation. - @pytest.mark.parametrize("example_code", DOCSTRING_TEST_CASES, ids=DOCSTRING_IDS) - def test_examples_in_docstrings(self, example_code): - assert example_code, "Example code is empty! Is the formatting correct? Refer to `tests/README.md`." - for banned_module in ["numpy", "cupy", "tripy", "torch"]: - assert ( - f"import {banned_module}" not in example_code - ), f"Avoid importing {banned_module} in example docstrings" - assert f"from {banned_module}" not in example_code, f"Avoid importing {banned_module} in example docstrings" - - helper.exec_code(example_code) @pytest.mark.parametrize("api", PUBLIC_APIS, ids=lambda public_api: public_api.qualname) def test_all_public_apis_have_docstrings(self, api): diff --git a/tripy/tripy/backend/api/compile.py b/tripy/tripy/backend/api/compile.py index f1816e370..a7fe83d62 100644 --- a/tripy/tripy/backend/api/compile.py +++ b/tripy/tripy/backend/api/compile.py @@ -78,7 +78,8 @@ def add(a, b): # doc: no-print-locals compiled_add - # Support shapes in the range of (1, 2) to (3, 2), optimizing for a shape of (2, 2) + # Support shapes in the range of (1, 2) to (3, 2), optimizing for a + # shape of (2, 2) compiled_add = tp.compile( add, args=[ @@ -92,7 +93,8 @@ def add(a, b): small_out = compiled_add(small_a, small_b) - # Now we can reuse the compiled function for any shapes within the range: + # Now we can reuse the compiled function for any shapes within the + # range: big_a = tp.ones((3, 2), dtype=tp.float32) big_b = tp.ones((3, 2), dtype=tp.float32) diff --git a/tripy/tripy/frontend/trace/ops/plugin.py b/tripy/tripy/frontend/trace/ops/plugin.py index e88ed86e2..4651dc070 100644 --- a/tripy/tripy/frontend/trace/ops/plugin.py +++ b/tripy/tripy/frontend/trace/ops/plugin.py @@ -80,10 +80,11 @@ def plugin( out = tp.plugin( "CustomGeluPluginDynamic", [inp], - # GELU has a single output which always has the same rank and data type as the input. + # GELU has a single output which always has the same rank and data + # type as the input. output_info=[(inp.rank, inp.dtype)], - # The GELU plugin expects a `type_id` parameter indicating the precision to use. - # `0` indicates float32. + # The GELU plugin expects a `type_id` parameter indicating the precision + # to use. `0` indicates float32. type_id=0, ) From c8ee99f0490e4db896444031a29a81bbda42ceaf Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Thu, 21 Nov 2024 21:55:01 -0500 Subject: [PATCH 29/29] [Tripy] Eliminate need for `skip_num_stack_entries` argument in `convert_to_tensors` (#333) Addresses issue #310. The only use of `skip_num_stack_entries` was for `slice_helper` and addressing this issue in a systematic manner would likely require building in many hacks and assumptions, so the approach here is just to manually override the stack information in that one function. --- tripy/tripy/frontend/trace/ops/slice.py | 35 ++++++++++++++++++++++--- tripy/tripy/frontend/utils.py | 29 ++++---------------- tripy/tripy/utils/ast.py | 9 ++++--- 3 files changed, 42 insertions(+), 31 deletions(-) diff --git a/tripy/tripy/frontend/trace/ops/slice.py b/tripy/tripy/frontend/trace/ops/slice.py index 954be255a..5f82e49e0 100644 --- a/tripy/tripy/frontend/trace/ops/slice.py +++ b/tripy/tripy/frontend/trace/ops/slice.py @@ -250,8 +250,37 @@ def clamp_bound(bound: Union[int, Tensor]) -> Union[int, Tensor]: return out -# Because the helper is called inside another function, we need to skip one entry in the call stack to find -# the original call to user code. -@frontend_utils.convert_to_tensors(skip_num_stack_entries=1) +@frontend_utils.convert_to_tensors() def slice_helper(tensor, *slice_params: TensorLike): + from tripy.utils import get_arg_candidate_column_offsets + + # The default behavior of convert_to_tensors will not add the correct column info to the slice params + # because this call occurs *inside* the overridden call to __getitem__, so we adjust the column info manually. + + # Look for the stack frame index to __getitem__. We need to go one stack frame beyond to get to the *user* call of __getitem__. + # It will be the same for all the slice params + frame_index = -1 + assert slice_params + + for idx, source_info in enumerate(slice_params[0].stack_info): + if source_info._dispatch_target == "__getitem__": + frame_index = idx + 1 + break + + # convert_to_tensors should have taken care of this for us + assert frame_index >= 0, "No call to the __getitem__ dispatch found" + + arg_names = ["tensor"] + ["slice_params"] * len(slice_params) + for arg_index, arg in enumerate(slice_params): + source_info = arg.stack_info[frame_index] + + # Note: arg_index does not account for the positional arg, hence we add 1 for the index argument + candidates = get_arg_candidate_column_offsets( + source_info.code, 1 + arg_index, 1, "__getitem__", False, arg_names + ) + + # Now we can set the column range correctly + if len(candidates) == 1: + source_info.column_range = candidates[0] + return Slice.build(inputs=[tensor, *slice_params]) diff --git a/tripy/tripy/frontend/utils.py b/tripy/tripy/frontend/utils.py index 352a2fb40..6bc50f4ad 100644 --- a/tripy/tripy/frontend/utils.py +++ b/tripy/tripy/frontend/utils.py @@ -70,7 +70,8 @@ def empty_buffer(): # Try to include correct column offsets for non-tensor arguments. -def _add_column_info(arg, arg_index, is_kwarg, num_positional, func_name, skip_num_stack_entries, arg_names): +def _add_column_info(arg, arg_index, is_kwarg, num_positional, func_name, arg_names): + from tripy import function_registry from tripy.frontend.tensor import Tensor assert isinstance(arg, Tensor), f"This function should only be called for objects that are already Tensor instances" @@ -90,10 +91,10 @@ def _add_column_info(arg, arg_index, is_kwarg, num_positional, func_name, skip_n # Find the first caller of this function that is NOT the function registry. # Also save the last dispatch target we see. dispatch_target = None - for idx, source_info in enumerate(arg.stack_info[WRAPPER_STACK_DEPTH + skip_num_stack_entries :]): + for idx, source_info in enumerate(arg.stack_info[WRAPPER_STACK_DEPTH:]): dispatch_target = source_info._dispatch_target or dispatch_target if source_info.module not in utils.get_module_names_to_exclude_from_stack_info(): - frame_index = idx + WRAPPER_STACK_DEPTH + skip_num_stack_entries + frame_index = idx + WRAPPER_STACK_DEPTH break else: # Fallback path is just to look at the user code @@ -118,12 +119,6 @@ def _add_column_info(arg, arg_index, is_kwarg, num_positional, func_name, skip_n arg_index = 0 if arg_index == 1 else 1 dispatch_target = dispatch_target.replace("__r", "__") - # Special case for __getitem__: It is variadic. Argument 0 is the tensor, - # and all subsequent arguments are slice parameters (in start, stop, step order). - # Hence, we subtract one to get the index of the slice parameters - if dispatch_target == "__getitem__": - arg_index -= 1 - candidates = utils.get_arg_candidate_column_offsets( source_info.code, arg_index, num_positional, dispatch_target or func_name, is_kwarg, arg_names ) @@ -136,9 +131,7 @@ def _add_column_info(arg, arg_index, is_kwarg, num_positional, func_name, skip_n # NOTE: Conversion to tensors needs to be done via a decorator so that we can add stack information # for non-tensors. Without having full context of the function signature, it is otherwise difficult to do so. -def convert_to_tensors( - targets: Set[str] = None, skip_num_stack_entries: int = 0, preprocess_args: Optional[Callable] = None -): +def convert_to_tensors(targets: Set[str] = None, preprocess_args: Optional[Callable] = None): """ Decorator that converts specified arguments to Tensors or DimensionSizes. If the argument can be converted to a DimensionSize, it is. Otherwise, it is @@ -152,17 +145,6 @@ def convert_to_tensors( targets: Names of arguments to convert to tensors. If not supplied, any arguments annotated with `TensorLike` or `ShapeLike` are converted. - skip_num_stack_entries: If the decorator is used on a function that is *called by* - a function that the user invokes, it will be necessary to skip stack entries - in order to get the column info from the user code. The number of entries skipped - should be equal to the nesting depth from a function called by user code - (if the decorated function is called by the user the depth is 0; - if the decorated function is called from a user function, the depth is 1; etc.). - - NOTE: When using this, make sure any extra arguments to the decorated function are - passed as keyword arguments. Otherwise, the logic for determining column information - will break. - preprocess_args: A callback used to preprocess arguments before potential conversion. If provided, this is always called, regardless of whether the decorator actually needed to perform conversion. This will be called with all arguments that were passed to the decorated function and should @@ -242,7 +224,6 @@ def add_arg(arg): name in kwargs, len(args), func.__name__, - skip_num_stack_entries, [name for name, _ in all_args], ) diff --git a/tripy/tripy/utils/ast.py b/tripy/tripy/utils/ast.py index 6de1e2e88..90fc7a8de 100644 --- a/tripy/tripy/utils/ast.py +++ b/tripy/tripy/utils/ast.py @@ -139,12 +139,13 @@ def index_into_expr(node: ast.expr, index: int) -> ast.expr: return node # If we have multiple dimensions specified, then we have a tuple of slices. - # Indices are given in as a list of start, stop, step + # NOTE: We subtract num_positional from the index because the slice arguments would + # be passed as *variadic arguments* to slice_helper and so would come after the positional argument if isinstance(node.slice, ast.Tuple): - element = node.slice.elts[index // 3] - arg_node = index_into_expr(element, index % 3) + element = node.slice.elts[(index - num_positional) // 3] + arg_node = index_into_expr(element, (index - num_positional) % 3) else: - arg_node = index_into_expr(node.slice, index) + arg_node = index_into_expr(node.slice, (index - num_positional)) if arg_node is not None: candidates.append((indentation + arg_node.col_offset, indentation + arg_node.end_col_offset))