Skip to content

Commit

Permalink
chunk_validator
Browse files Browse the repository at this point in the history
  • Loading branch information
apbose committed Sep 4, 2024
1 parent 4dbeafd commit cd5815b
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 1 deletion.
4 changes: 3 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,9 @@ def aten_ops_softmax(

@dynamo_tensorrt_converter(
torch.ops.aten.split.Tensor,
capability_validator=has_static_shapes_in_args([1]),
capability_validator=(
has_static_shapes_in_args([0]) and has_static_shapes_in_args([1])
),
supports_dynamic_shapes=True,
)
@dynamo_tensorrt_converter(
Expand Down
103 changes: 103 additions & 0 deletions tests/py/dynamo/conversion/test_chunk_aten.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import unittest

import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase

Expand All @@ -27,6 +30,7 @@ def forward(self, input):
self.run_test(
TestChunk(),
input,
use_dynamo_tracer=True,
)

@parameterized.expand(
Expand All @@ -51,6 +55,7 @@ def forward(self, input):
self.run_test(
TestChunk(),
input,
use_dynamo_tracer=True,
)

@parameterized.expand(
Expand All @@ -75,6 +80,104 @@ def forward(self, input):
self.run_test(
TestChunk(),
input,
use_dynamo_tracer=True,
)


#######################Dynamic cases################
####The tests are skipped for now. Will be addressed once https://github.com/pytorch/pytorch/issues/134663 is addressed
@unittest.skip("Pending aten.split converter. Currently tested by E2E")
class TestChunkDynamicConverter(DispatchTestCase):
@parameterized.expand(
[
((1,), (1,), (3,), 3, 0),
((3,), (3,), (4,), 3, 0),
((4,), (4,), (6,), 3, 0),
((6,), (6,), (9,), 3, 0),
((3,), (3,), (4,), 1, -1),
((3,), (3,), (4,), 3, -1),
((3,), (3,), (4,), 4, -1),
]
)
def test_chunk_1D(self, min_shape, opt_shape, max_shape, chunks, dim):
class TestChunk(torch.nn.Module):
def forward(self, input):
out = torch.ops.aten.chunk.default(input, chunks, dim)
return out

input_specs = [
Input(
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
),
]
self.run_test_with_dynamic_shape(
TestChunk(),
input_specs,
use_dynamo_tracer=True,
)

@parameterized.expand(
[
((3, 4), (3, 4), (4, 4), 1, 0),
((3, 4), (3, 4), (4, 4), 3, 0),
((3, 4), (3, 4), (4, 4), 4, 0),
((3, 4), (3, 4), (4, 4), 2, -2),
((3, 4), (3, 4), (4, 4), 6, -2),
((3, 4), (3, 4), (4, 4), 3, 1),
((3, 4), (3, 4), (4, 4), 4, 1),
((3, 4), (3, 4), (4, 4), 5, -1),
]
)
def test_chunk_2D(self, min_shape, opt_shape, max_shape, chunks, dim):
class TestChunk(torch.nn.Module):
def forward(self, input):
out = torch.ops.aten.chunk.default(input, chunks, dim)
return out

input_specs = [
Input(
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
),
]
self.run_test_with_dynamic_shape(
TestChunk(),
input_specs,
use_dynamo_tracer=True,
)

@parameterized.expand(
[
((3, 4, 2), (3, 4, 2), (4, 4, 2), 1, 0),
((3, 4, 2), (3, 4, 2), (4, 4, 2), 3, -3),
((3, 4, 2), (3, 4, 2), (4, 4, 2), 3, 1),
((3, 4, 2), (3, 4, 2), (4, 4, 2), 4, 1),
((3, 4, 2), (3, 4, 2), (4, 4, 2), 6, -2),
((3, 4, 2), (3, 4, 2), (4, 4, 2), 1, 2),
((3, 4, 2), (3, 4, 2), (4, 4, 2), 3, -1),
((3, 4, 2), (3, 4, 2), (4, 4, 2), 4, -1),
]
)
def test_chunk_3D(self, min_shape, opt_shape, max_shape, chunks, dim):
class TestChunk(torch.nn.Module):
def forward(self, input):
out = torch.ops.aten.chunk.default(input, chunks, dim)
return out

input_specs = [
Input(
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
),
]
self.run_test_with_dynamic_shape(
TestChunk(),
input_specs,
use_dynamo_tracer=True,
)


Expand Down

0 comments on commit cd5815b

Please sign in to comment.