From b50d12f4217c50ab0290c5e3fc52082f97b8a299 Mon Sep 17 00:00:00 2001 From: George S <113141689+gs-olive@users.noreply.github.com> Date: Tue, 6 Feb 2024 09:32:14 -0800 Subject: [PATCH] small fix: Index validator enable int64 (#2642) (#2643) --- examples/dynamo/torch_compile_advanced_usage.py | 7 +++++-- examples/dynamo/torch_compile_transformers_example.py | 1 + .../dynamo/conversion/aten_ops_converters.py | 2 +- tests/py/dynamo/conversion/test_index_aten.py | 8 ++------ 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/dynamo/torch_compile_advanced_usage.py b/examples/dynamo/torch_compile_advanced_usage.py index 96146a43d8..8ebedab111 100644 --- a/examples/dynamo/torch_compile_advanced_usage.py +++ b/examples/dynamo/torch_compile_advanced_usage.py @@ -43,7 +43,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): # For the default settings, we can simply call torch.compile # with the backend "torch_tensorrt", and run the model on an # input to cause compilation, as so: -optimized_model = torch.compile(model, backend="torch_tensorrt") +optimized_model = torch.compile(model, backend="torch_tensorrt", dynamic=False) optimized_model(*sample_inputs) # %% @@ -81,7 +81,10 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): # Run the model on an input to cause compilation, as so: optimized_model_custom = torch.compile( - model_half, backend="torch_tensorrt", options=backend_kwargs + model_half, + backend="torch_tensorrt", + options=backend_kwargs, + dynamic=False, ) optimized_model_custom(*sample_inputs_half) diff --git a/examples/dynamo/torch_compile_transformers_example.py b/examples/dynamo/torch_compile_transformers_example.py index 5422f9cc1d..01d46e96f6 100644 --- a/examples/dynamo/torch_compile_transformers_example.py +++ b/examples/dynamo/torch_compile_transformers_example.py @@ -61,6 +61,7 @@ optimized_model = torch.compile( model, backend="torch_tensorrt", + dynamic=False, options=compilation_kwargs, ) optimized_model(*inputs) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 69cb6cecb1..e8ee7d2bae 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -397,7 +397,7 @@ def index_dtype_validator(node: Node) -> bool: for ind in index: if ind is not None: val = ind.meta.get("val") - if val is not None and val.dtype != torch.int32: + if val is not None and val.dtype not in (torch.int32, torch.int64): return False return True diff --git a/tests/py/dynamo/conversion/test_index_aten.py b/tests/py/dynamo/conversion/test_index_aten.py index 88db7f0817..56762cd4bc 100644 --- a/tests/py/dynamo/conversion/test_index_aten.py +++ b/tests/py/dynamo/conversion/test_index_aten.py @@ -1,9 +1,8 @@ -import operator - import torch import torch.nn as nn from torch.testing._internal.common_utils import run_tests -from torch_tensorrt import Input + +from .harness import DispatchTestCase from .harness import DispatchTestCase @@ -16,7 +15,6 @@ def __init__(self): super().__init__() def forward(self, x): - index0 = torch.randint(0, 1, (1, 1)) indices = [None, self.index0] out = torch.ops.aten.index.Tensor(x, indices) return out @@ -159,8 +157,6 @@ def __init__(self): super().__init__() def forward(self, x): - index0 = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7]) - index1 = index0.unsqueeze(0).T.long() indices = [None, None, self.index0, self.index1] out = torch.ops.aten.index.Tensor(x, indices) return out