Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Run test_base_fp8 for compute capability 8.9 or later #3164

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 12 additions & 17 deletions tests/py/dynamo/models/test_models_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import torchvision.models as models
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
from transformers import BertModel
from transformers.utils.fx import symbolic_trace as transformers_trace

from packaging.version import Version

Expand Down Expand Up @@ -196,16 +195,18 @@ def test_resnet18_half(ir):


@unittest.skipIf(
torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
"FP8 compilation in Torch-TRT is not supported on cards older than Hopper",
torch.cuda.get_device_capability() < (8, 9),
"FP8 quantization requires compute capability 8.9 or later",
)
@unittest.skipIf(
not importlib.util.find_spec("modelopt"),
reason="ModelOpt is necessary to run this test",
"ModelOpt is required to run this test",
)
@pytest.mark.unit
def test_base_fp8(ir):
import modelopt
import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.utils import export_torch_mode
from torch.export._trace import _export

class SimpleNetwork(torch.nn.Module):
def __init__(self):
Expand All @@ -219,9 +220,6 @@ def forward(self, x):
x = self.linear2(x)
return x

import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.utils import export_torch_mode

def calibrate_loop(model):
"""Simple calibration function for testing."""
model(input_tensor)
Expand All @@ -236,7 +234,7 @@ def calibrate_loop(model):

with torch.no_grad():
with export_torch_mode():
exp_program = torch.export.export(model, (input_tensor,))
exp_program = _export(model, (input_tensor,))
trt_model = torchtrt.dynamo.compile(
exp_program,
inputs=[input_tensor],
Expand All @@ -247,7 +245,7 @@ def calibrate_loop(model):
reuse_cached_engines=False,
)
outputs_trt = trt_model(input_tensor)
assert torch.allclose(output_pyt, outputs_trt, rtol=1e-3, atol=1e-2)
assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2)


@unittest.skipIf(
Expand All @@ -258,7 +256,9 @@ def calibrate_loop(model):
)
@pytest.mark.unit
def test_base_int8(ir):
import modelopt
import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.utils import export_torch_mode
from torch.export._trace import _export

class SimpleNetwork(torch.nn.Module):
def __init__(self):
Expand All @@ -272,9 +272,6 @@ def forward(self, x):
x = self.linear2(x)
return x

import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.utils import export_torch_mode

def calibrate_loop(model):
"""Simple calibration function for testing."""
model(input_tensor)
Expand All @@ -289,8 +286,6 @@ def calibrate_loop(model):

with torch.no_grad():
with export_torch_mode():
from torch.export._trace import _export

exp_program = _export(model, (input_tensor,))
trt_model = torchtrt.dynamo.compile(
exp_program,
Expand All @@ -302,4 +297,4 @@ def calibrate_loop(model):
reuse_cached_engines=False,
)
outputs_trt = trt_model(input_tensor)
assert torch.allclose(output_pyt, outputs_trt, rtol=1e-3, atol=1e-2)
assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2)
Loading