diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py
index 25116ce865..2d0992ca8b 100644
--- a/tests/py/dynamo/models/test_models_export.py
+++ b/tests/py/dynamo/models/test_models_export.py
@@ -11,7 +11,6 @@ import torchvision.models as models
 from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
 from transformers import BertModel
-from transformers.utils.fx import symbolic_trace as transformers_trace
 
 from packaging.version import Version
@@ -196,16 +195,18 @@ def test_resnet18_half(ir):
 
 
 @unittest.skipIf(
-    torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
-    "FP8 compilation in Torch-TRT is not supported on cards older than Hopper",
+    torch.cuda.get_device_capability() < (8, 9),
+    "FP8 quantization requires compute capability 8.9 or later",
 )
 @unittest.skipIf(
     not importlib.util.find_spec("modelopt"),
-    reason="ModelOpt is necessary to run this test",
+    "ModelOpt is required to run this test",
 )
 @pytest.mark.unit
 def test_base_fp8(ir):
-    import modelopt
+    import modelopt.torch.quantization as mtq
+    from modelopt.torch.quantization.utils import export_torch_mode
+    from torch.export._trace import _export
 
     class SimpleNetwork(torch.nn.Module):
         def __init__(self):
@@ -219,9 +220,6 @@ def forward(self, x):
             x = self.linear2(x)
             return x
 
-    import modelopt.torch.quantization as mtq
-    from modelopt.torch.quantization.utils import export_torch_mode
-
     def calibrate_loop(model):
         """Simple calibration function for testing."""
         model(input_tensor)
@@ -236,7 +234,7 @@ def calibrate_loop(model):
 
     with torch.no_grad():
         with export_torch_mode():
-            exp_program = torch.export.export(model, (input_tensor,))
+            exp_program = _export(model, (input_tensor,))
             trt_model = torchtrt.dynamo.compile(
                 exp_program,
                 inputs=[input_tensor],
@@ -247,7 +245,7 @@ def calibrate_loop(model):
                 reuse_cached_engines=False,
             )
             outputs_trt = trt_model(input_tensor)
-            assert torch.allclose(output_pyt, outputs_trt, rtol=1e-3, atol=1e-2)
+            assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2)
 
 
 @unittest.skipIf(
@@ -258,7 +256,9 @@ def calibrate_loop(model):
 )
 @pytest.mark.unit
 def test_base_int8(ir):
-    import modelopt
+    import modelopt.torch.quantization as mtq
+    from modelopt.torch.quantization.utils import export_torch_mode
+    from torch.export._trace import _export
 
     class SimpleNetwork(torch.nn.Module):
         def __init__(self):
@@ -272,9 +272,6 @@ def forward(self, x):
             x = self.linear2(x)
             return x
 
-    import modelopt.torch.quantization as mtq
-    from modelopt.torch.quantization.utils import export_torch_mode
-
     def calibrate_loop(model):
         """Simple calibration function for testing."""
         model(input_tensor)
@@ -289,8 +286,6 @@ def calibrate_loop(model):
 
     with torch.no_grad():
         with export_torch_mode():
-            from torch.export._trace import _export
-
             exp_program = _export(model, (input_tensor,))
             trt_model = torchtrt.dynamo.compile(
                 exp_program,
@@ -302,4 +297,4 @@ def calibrate_loop(model):
                 reuse_cached_engines=False,
             )
             outputs_trt = trt_model(input_tensor)
-            assert torch.allclose(output_pyt, outputs_trt, rtol=1e-3, atol=1e-2)
+            assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2)
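
For reference, a minimal sketch of the quantize/export/compile flow this diff converges on (assumes nvidia-modelopt is installed and a GPU with compute capability 8.9+; the network, input shape, and some compile kwargs are illustrative stand-ins, not verbatim from the test):

import modelopt.torch.quantization as mtq
import torch
import torch_tensorrt as torchtrt
from modelopt.torch.quantization.utils import export_torch_mode
from torch.export._trace import _export

# Small stand-in model and input, analogous to SimpleNetwork in the test.
model = torch.nn.Sequential(
    torch.nn.Linear(10, 5),
    torch.nn.ReLU(),
    torch.nn.Linear(5, 1),
).cuda().eval()
input_tensor = torch.randn(1, 10).cuda()  # illustrative shape


def calibrate_loop(model):
    # Single-pass calibration, mirroring the test's calibrate_loop.
    model(input_tensor)


# Insert FP8 fake-quant nodes and calibrate (use mtq.INT8_DEFAULT_CFG for the int8 variant).
mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop=calibrate_loop)
output_pyt = model(input_tensor)

with torch.no_grad(), export_torch_mode():
    # The diff swaps torch.export.export for the private _export entry point
    # and hoists the imports to the top of each test.
    exp_program = _export(model, (input_tensor,))
    trt_model = torchtrt.dynamo.compile(
        exp_program,
        inputs=[input_tensor],
        enabled_precisions={torch.float8_e4m3fn},
        min_block_size=1,
        reuse_cached_engines=False,
    )
    outputs_trt = trt_model(input_tensor)
    # Loosened rtol (1e-3 -> 5e-3) matches the updated assertion in the diff.
    assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2)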