16 changes: 16 additions & 0 deletions py/torch_tensorrt/fx/fx2trt.py
@@ -164,6 +164,9 @@ def run(
timing_cache=None,
profiling_verbosity=None,
tactic_sources=None,
max_aux_streams=None,
version_compatible=False,
optimization_level=None,
) -> TRTInterpreterResult:
"""
Build a TensorRT engine with the given configuration options.
@@ -231,6 +234,18 @@ def run(
if profiling_verbosity
else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
)

if trt.__version__ >= "8.6":
if max_aux_streams is not None:
_LOGGER.info(f"Setting max aux streams to {max_aux_streams}")
builder_config.max_aux_streams = max_aux_streams
if version_compatible:
_LOGGER.info(f"Using version compatible")
builder_config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE)
if optimization_level is not None:
_LOGGER.info(f"Using optimization level {optimization_level}")
builder_config.builder_optimization_level = optimization_level

if lower_precision == LowerPrecision.FP16:
builder_config.set_flag(trt.BuilderFlag.FP16)

@@ -265,6 +280,7 @@ def run(
_LOGGER.info(
f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
)
_LOGGER.info(f"TRT Engine uses: {engine.device_memory_size} bytes of Memory")

return TRTInterpreterResult(
engine, self._input_names, self._output_names, serialized_cache
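Review note: the three new options are applied only on TensorRT 8.6 and later, but `trt.__version__ >= "8.6"` is a plain string comparison, which would misorder a hypothetical "8.10" release ("8.10" < "8.6" lexicographically). Below is a minimal sketch of the same gate behind a semantic version parse; `packaging` is an assumed (and common) dependency, and the attribute and flag names are taken from the hunk above, not from any new API.

```python
# Sketch only (not part of this PR): apply the TRT >= 8.6 builder options
# behind a semantic version check instead of a string comparison.
from typing import Optional

import tensorrt as trt
from packaging.version import Version  # assumed available


def apply_trt_86_options(
    builder_config: trt.IBuilderConfig,
    max_aux_streams: Optional[int],
    version_compatible: bool,
    optimization_level: Optional[int],
) -> None:
    if Version(trt.__version__) < Version("8.6"):
        return  # the options below only exist on TensorRT 8.6+
    if max_aux_streams is not None:
        # Cap the auxiliary CUDA streams TensorRT may run in parallel.
        builder_config.max_aux_streams = max_aux_streams
    if version_compatible:
        # Engine stays deserializable by newer TensorRT runtimes.
        builder_config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE)
    if optimization_level is not None:
        # 0-5; higher levels spend more build time searching tactics.
        builder_config.builder_optimization_level = optimization_level
```

The new log line is also useful context here: `engine.device_memory_size` reports the device activation memory the built engine will request at execution time.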
3 changes: 3 additions & 0 deletions py/torch_tensorrt/fx/lower.py
@@ -142,6 +142,9 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
if self.lower_setting.verbose_profile
else trt.ProfilingVerbosity.LAYER_NAMES_ONLY,
tactic_sources=self.lower_setting.tactic_sources,
max_aux_streams=self.lower_setting.max_aux_streams,
version_compatible=self.lower_setting.version_compatible,
optimization_level=self.lower_setting.optimization_level,
)

# Update timing cache file if needed
6 changes: 6 additions & 0 deletions py/torch_tensorrt/fx/lower_setting.py
@@ -74,6 +74,9 @@ class LowerSetting(LowerSettingBasic):
correctness_atol: absolute tolerance for correctness check
correctness_rtol: relative tolerance for correctness check
use_experimental_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
max_aux_streams: maximum number of auxiliary streams TensorRT may use
version_compatible: build an engine that remains deserializable by newer TensorRT runtimes
optimization_level: TensorRT builder optimization level (0-5; higher levels spend more build time searching for better tactics)
"""

input_specs: List[InputTensorSpec] = dc.field(default_factory=list)
@@ -101,3 +104,6 @@ class LowerSetting(LowerSettingBasic):
correctness_atol: float = 0.1
correctness_rtol: float = 0.1
use_experimental_rt: bool = False
max_aux_streams: Optional[int] = None
version_compatible: bool = False
optimization_level: Optional[int] = None
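For context, all three new fields default to no-ops (None/False), so existing LowerSetting callers are unaffected. A hypothetical usage sketch follows; the field names come from this diff, the values are illustrative:

```python
# Hypothetical usage sketch; field names come from this diff.
from torch_tensorrt.fx.lower_setting import LowerSetting

setting = LowerSetting(
    max_aux_streams=2,        # let TensorRT use up to 2 auxiliary streams
    version_compatible=True,  # engine stays loadable on newer TRT runtimes
    optimization_level=5,     # maximum builder effort (default is 3)
)
# lower.py forwards these straight into TRTInterpreter.run(), as shown
# in the hunk above.
```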
91 changes: 49 additions & 42 deletions py/torch_tensorrt/fx/passes/pass_utils.py
@@ -151,7 +151,10 @@ def parent_pass(module: fx.GraphModule, input: Input) -> fx.GraphModule:
# TODO(shirongwu): Add exception notification for fblearner flow when available; notify oncall
# on pass that failed accuracy check.
def validate_inference(
rtol=None, atol=None, run_alternative_batch_size: int = -1
rtol=None,
atol=None,
suppress_accuracy_check_failure=True,
run_alternative_batch_size: int = -1,
) -> "Decorator":
"""
Returns a decorator on a PassFunc to sanity check the model outputs
Expand All @@ -160,6 +163,7 @@ def validate_inference(
Args:
rtol: relative tolerance
atol: absolute tolerance
suppress_accuracy_check_failure: if True, skip the accuracy check entirely and return the transformed module as-is
run_alternative_batch_size (int):
In addition to running inference at original batch size in the
input, also run at an alternative batch size. If set to -1, do not
@@ -181,48 +185,51 @@ def pass_with_validation(
*args,
**kwargs,
) -> fx.GraphModule:
if suppress_accuracy_check_failure:
    return pass_(module, input, *args, **kwargs)
else:
    res0 = module(*input)
    processed_module = pass_(module, input, *args, **kwargs)
    res1 = processed_module(*input)

    tensor_res_0 = _collect_tensors(res0)
    tensor_res_1 = _collect_tensors(res1)
    relax_accuracy_check_failure = RELAX_ACCURACY_FAILURE

    for kk, (x, y) in enumerate(zip(tensor_res_0, tensor_res_1)):
        kwargs2 = {"equal_nan": True}
        if rtol:
            kwargs2["rtol"] = rtol
        if atol:
            kwargs2["atol"] = atol
        kwargs2["msg"] = (
            lambda msg: f"Pass {pass_} failed correctness check at output {kk}:\n{msg}"
        )
        # If tensors are on different devices, make sure to compare
        # their copies that are on the same device.
        if x.get_device() != y.get_device():
            x = x.cpu()
            y = y.cpu()
        try:
            torch.testing.assert_close(x, y, **kwargs2)
        except Exception as e:
            if relax_accuracy_check_failure:
                _LOGGER.error(f"{e}")
                kwargs2["rtol"] *= FINAL_CHECK_RTOL_MULTIPLIER
                kwargs2["atol"] *= FINAL_CHECK_ATOL_MULTIPLIER
                new_atol = kwargs2["atol"]
                new_rtol = kwargs2["rtol"]
                _LOGGER.info(
                    f"Do a sanity check to see whether things are completely wrong with {new_atol=}, {new_rtol=}"
                )
                torch.testing.assert_close(x, y, **kwargs2)
                return processed_module
            else:
                raise e

    return processed_module

return pass_with_validation
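Two behavioral notes worth flagging: suppress_accuracy_check_failure defaults to True, so decorated passes now skip the inference check unless the caller opts back in with False; and the relaxed-retry path multiplies kwargs2["rtol"] and kwargs2["atol"] in place, so it assumes both tolerances were supplied. Below is a minimal sketch of a strict opt-in, assuming the decorator is applied the same way as elsewhere in this file; the identity pass body is a placeholder.

```python
# Hypothetical pass opting back into strict validation.
@validate_inference(rtol=1e-2, atol=1e-2, suppress_accuracy_check_failure=False)
def my_pass(module: fx.GraphModule, input: Input) -> fx.GraphModule:
    # ... graph rewrites would go here; identity keeps the sketch runnable.
    return module
```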
