feat: support amax dynamo converter (#2241)
zewenli98 authored Aug 25, 2023
1 parent b774440 commit a65c95c
Showing 5 changed files with 167 additions and 0 deletions.
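For context, the new converter targets torch.ops.aten.amax.default, i.e. a max reduction over one or more dimensions. A minimal sketch of the op semantics being lowered (plain PyTorch, not part of this commit):

    import torch

    x = torch.randn(2, 3, 4)
    # reduce over a single dimension, keeping it in the output shape
    print(torch.amax(x, dim=2, keepdim=True).shape)   # torch.Size([2, 3, 1])
    # reduce over several dimensions at once, dropping them
    print(torch.amax(x, dim=(0, 2)).shape)            # torch.Size([3])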
31 changes: 31 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -440,6 +440,37 @@ def aten_ops_expand(
)


def amax_param_validator(amax_node: Node) -> bool:
    if len(amax_node.args) < 2:
        _LOGGER.debug(
            f"At least two args input and dim should be provided, but only got {len(amax_node.args)} args."
        )
        return False

    return True


@dynamo_tensorrt_converter(
    torch.ops.aten.amax.default, capability_validator=amax_param_validator
)
def aten_ops_amax(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.reduce.amax(
        network,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        args[1],
        args_bounds_check(args, 2, replacement=False),
    )
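The converter above unpacks the FX node arguments positionally: args[0] is the input, args[1] is dim, and the optional args[2] (keepdim) falls back to False via args_bounds_check. A rough sketch of that fallback, assuming the helper behaves the way its call site suggests (the real implementation lives elsewhere in this module):

    from typing import Any, Sequence

    def args_bounds_check_sketch(args: Sequence[Any], i: int, replacement: Any = None) -> Any:
        # Return args[i] if the traced node actually supplied it, else the default.
        return args[i] if len(args) > i else replacement

    # torch.amax(x, dim=1) traces to a node with only two args, so keepdim -> False.
    print(args_bounds_check_sketch(("input_trt_tensor", 1), 2, replacement=False))  # False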


@dynamo_tensorrt_converter(torch.ops.aten.exp.default)  # type: ignore[misc]
def aten_ops_exp(
    network: TRTNetwork,
Expand Down
7 changes: 7 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -1,3 +1,4 @@
import functools
import logging
import re
from typing import List, Optional
@@ -7,6 +8,7 @@
from torch.fx.node import Target
from torch_tensorrt.fx.converters.converter_utils import (
    Frameworks,
    get_axes_for_reduce_op,
    unified_dtype_converter,
)
from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor
@@ -157,3 +159,8 @@ def broadcastable(
        if not (a_shape[i] == b_shape[i] or a_shape[i] == 1 or b_shape[i] == 1):
            return False
    return True


get_axes_for_reduce_op = functools.partial(
    get_axes_for_reduce_op, has_implicit_batch_dimension=False
)
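The partial application above pins has_implicit_batch_dimension=False, since dynamo networks are built with an explicit batch dimension. The FX helper turns a dim (or tuple of dims) into the axes bitmask TensorRT's reduce layer expects; a rough sketch of that encoding, assuming the usual one-bit-per-axis scheme and non-negative dims:

    from typing import Sequence, Union

    def axes_bitmask_sketch(dim: Union[int, Sequence[int]]) -> int:
        # Set bit i for every axis i that should be reduced.
        dims = (dim,) if isinstance(dim, int) else tuple(dim)
        axes = 0
        for d in dims:
            axes |= 1 << d
        return axes

    print(axes_bitmask_sketch(2))       # 4  (0b100)
    print(axes_bitmask_sketch((0, 3)))  # 9  (0b1001)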
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/__init__.py
@@ -9,6 +9,7 @@
    matmul,
    normalization,
    permutation,
    reduce,
    select,
    shape,
    slice,
35 changes: 35 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/reduce.py
@@ -0,0 +1,35 @@
from typing import Optional, Tuple, Union

import tensorrt as trt
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion.converter_utils import (
    cast_trt_tensor,
    get_axes_for_reduce_op,
)
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor


def amax(
    network: TRTNetwork,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input_val: TRTTensor,
    dim: Union[int, Tuple[int]],
    keepdim: bool = False,
) -> TRTTensor:
    if (isinstance(input_val, TRTTensor)) and (
        input_val.dtype == trt.int8 or input_val.dtype == trt.int32
    ):
        input_val = cast_trt_tensor(network, input_val, trt.float32, name)

    layer = network.add_reduce(
        input_val,
        trt.ReduceOperation.MAX,
        axes=get_axes_for_reduce_op(dim),
        keep_dims=keepdim,
    )
    set_layer_name(layer, target, name, source_ir)
    return layer.get_output(0)
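Two details in the implementation above are worth calling out: int8/int32 inputs are cast to float32 before the reduce, apparently to sidestep TensorRT's limited integer support in IReduceLayer, and keep_dims mirrors torch's keepdim so the engine reproduces eager-mode shapes. For reference, the shapes the tests below check against (plain PyTorch):

    import torch

    x = torch.randn(2, 3, 4, 5)
    print(torch.amax(x, dim=3, keepdim=True).shape)   # torch.Size([2, 3, 4, 1])
    print(torch.amax(x, dim=2, keepdim=False).shape)  # torch.Size([2, 3, 4])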
93 changes: 93 additions & 0 deletions tests/py/dynamo/converters/test_amax_aten.py
@@ -0,0 +1,93 @@
import torch
import torch.nn as nn
from harness import DispatchTestCase
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests


class TestAmaxConverter(DispatchTestCase):
    @parameterized.expand(
        [
            ((3, 2, 4), 1, True),
            ((2, 3, 4, 5), 3, True),
            ((2, 3, 4, 5), 2, False),
            ((6, 7, 5, 4, 5), 4, False),
        ]
    )
    def test_amax_dim_int_default(self, input_shape, dim, keep_dims):
        class Amax(nn.Module):
            def forward(self, x):
                return torch.amax(x, dim=dim, keepdim=keep_dims)

        inputs = [torch.randn(*input_shape)]
        self.run_test(
            Amax(),
            inputs,
            expected_ops={torch.ops.aten.amax.default},
        )

    @parameterized.expand(
        [
            ((3, 2, 4), [1], True),
            ((2, 1, 4, 5), [0, 3], True),
            ((2, 3, 4, 5), [0, 1, 2, 3], False),
            ((6, 7, 5, 4, 5), [1, 3, 4], False),
        ]
    )
    def test_amax_dim_tuple_default(self, input_shape, dim, keep_dims):
        class Amax(nn.Module):
            def forward(self, x):
                return torch.amax(x, dim=dim, keepdim=keep_dims)

        inputs = [torch.randn(*input_shape)]
        self.run_test(
            Amax(),
            inputs,
            expected_ops={torch.ops.aten.amax.default},
        )

    @parameterized.expand(
        [
            ((3, 2, 4), 1, True, torch.int, 0, 5),
            ((2, 3, 4, 5), 3, True, torch.int, -10, 10),
            ((2, 3, 4, 5), 2, False, torch.int32, -5, 0),
            ((6, 7, 5, 4, 5), 4, False, torch.int32, -5, 5),
        ]
    )
    def test_amax_dim_int_int(self, input_shape, dim, keep_dims, dtype, low, high):
        class Amax(nn.Module):
            def forward(self, x):
                return torch.amax(x, dim=dim, keepdim=keep_dims)

        inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
        self.run_test(
            Amax(),
            inputs,
            expected_ops={torch.ops.aten.amax.default},
            check_dtype=False,
        )

    @parameterized.expand(
        [
            ((3, 2, 4), [1], True, torch.int, 0, 5),
            ((2, 1, 4, 5), [0, 3], True, torch.int, -10, 10),
            ((2, 3, 4, 5), [0, 1, 2, 3], False, torch.int32, -5, 0),
            ((6, 7, 5, 4, 5), [1, 3, 4], False, torch.int32, -5, 5),
        ]
    )
    def test_amax_dim_tuple_int(self, input_shape, dim, keep_dims, dtype, low, high):
        class Amax(nn.Module):
            def forward(self, x):
                return torch.amax(x, dim=dim, keepdim=keep_dims)

        inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
        self.run_test(
            Amax(),
            inputs,
            expected_ops={torch.ops.aten.amax.default},
            check_dtype=False,
        )


if __name__ == "__main__":
    run_tests()
