diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 61b1f38242..a625c05791 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -692,7 +692,9 @@ def aten_ops_softmax( @dynamo_tensorrt_converter( torch.ops.aten.split.Tensor, - capability_validator=has_static_shapes_in_args([1]), + capability_validator=( + has_static_shapes_in_args([0]) and has_static_shapes_in_args([1]) + ), supports_dynamic_shapes=True, ) @dynamo_tensorrt_converter( diff --git a/tests/py/dynamo/conversion/test_chunk_aten.py b/tests/py/dynamo/conversion/test_chunk_aten.py index 1812165b43..64a0385b8d 100644 --- a/tests/py/dynamo/conversion/test_chunk_aten.py +++ b/tests/py/dynamo/conversion/test_chunk_aten.py @@ -1,6 +1,9 @@ +import unittest + import torch from parameterized import parameterized from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input from .harness import DispatchTestCase @@ -27,6 +30,7 @@ def forward(self, input): self.run_test( TestChunk(), input, + use_dynamo_tracer=True, ) @parameterized.expand( @@ -51,6 +55,7 @@ def forward(self, input): self.run_test( TestChunk(), input, + use_dynamo_tracer=True, ) @parameterized.expand( @@ -75,6 +80,104 @@ def forward(self, input): self.run_test( TestChunk(), input, + use_dynamo_tracer=True, + ) + + +#######################Dynamic cases################ +####The tests are skipped for now. Will be addressed once https://github.com/pytorch/pytorch/issues/134663 is addressed +@unittest.skip("Pending aten.split converter. Currently tested by E2E") +class TestChunkDynamicConverter(DispatchTestCase): + @parameterized.expand( + [ + ((1,), (1,), (3,), 3, 0), + ((3,), (3,), (4,), 3, 0), + ((4,), (4,), (6,), 3, 0), + ((6,), (6,), (9,), 3, 0), + ((3,), (3,), (4,), 1, -1), + ((3,), (3,), (4,), 3, -1), + ((3,), (3,), (4,), 4, -1), + ] + ) + def test_chunk_1D(self, min_shape, opt_shape, max_shape, chunks, dim): + class TestChunk(torch.nn.Module): + def forward(self, input): + out = torch.ops.aten.chunk.default(input, chunks, dim) + return out + + input_specs = [ + Input( + min_shape=min_shape, + opt_shape=opt_shape, + max_shape=max_shape, + ), + ] + self.run_test_with_dynamic_shape( + TestChunk(), + input_specs, + use_dynamo_tracer=True, + ) + + @parameterized.expand( + [ + ((3, 4), (3, 4), (4, 4), 1, 0), + ((3, 4), (3, 4), (4, 4), 3, 0), + ((3, 4), (3, 4), (4, 4), 4, 0), + ((3, 4), (3, 4), (4, 4), 2, -2), + ((3, 4), (3, 4), (4, 4), 6, -2), + ((3, 4), (3, 4), (4, 4), 3, 1), + ((3, 4), (3, 4), (4, 4), 4, 1), + ((3, 4), (3, 4), (4, 4), 5, -1), + ] + ) + def test_chunk_2D(self, min_shape, opt_shape, max_shape, chunks, dim): + class TestChunk(torch.nn.Module): + def forward(self, input): + out = torch.ops.aten.chunk.default(input, chunks, dim) + return out + + input_specs = [ + Input( + min_shape=min_shape, + opt_shape=opt_shape, + max_shape=max_shape, + ), + ] + self.run_test_with_dynamic_shape( + TestChunk(), + input_specs, + use_dynamo_tracer=True, + ) + + @parameterized.expand( + [ + ((3, 4, 2), (3, 4, 2), (4, 4, 2), 1, 0), + ((3, 4, 2), (3, 4, 2), (4, 4, 2), 3, -3), + ((3, 4, 2), (3, 4, 2), (4, 4, 2), 3, 1), + ((3, 4, 2), (3, 4, 2), (4, 4, 2), 4, 1), + ((3, 4, 2), (3, 4, 2), (4, 4, 2), 6, -2), + ((3, 4, 2), (3, 4, 2), (4, 4, 2), 1, 2), + ((3, 4, 2), (3, 4, 2), (4, 4, 2), 3, -1), + ((3, 4, 2), (3, 4, 2), (4, 4, 2), 4, -1), + ] + ) + def test_chunk_3D(self, min_shape, opt_shape, max_shape, chunks, dim): + class TestChunk(torch.nn.Module): + def forward(self, input): + out = torch.ops.aten.chunk.default(input, chunks, dim) + return out + + input_specs = [ + Input( + min_shape=min_shape, + opt_shape=opt_shape, + max_shape=max_shape, + ), + ] + self.run_test_with_dynamic_shape( + TestChunk(), + input_specs, + use_dynamo_tracer=True, )