Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upstream 3 features for torch-trt: ms, VC, and optimisation level #1926

Closed
wants to merge 11 commits into from
16 changes: 16 additions & 0 deletions py/torch_tensorrt/fx/fx2trt.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,9 @@ def run(
timing_cache=None,
profiling_verbosity=None,
tactic_sources=None,
max_aux_streams=None,
version_compatible=False,
optimization_level=None,
) -> TRTInterpreterResult:
"""
Build TensorRT engine with some configs.
Expand Down Expand Up @@ -231,6 +234,18 @@ def run(
if profiling_verbosity
else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
)

if trt.__version__ >= "8.6":
if max_aux_streams is not None:
_LOGGER.info(f"Setting max aux streams to {max_aux_streams}")
builder_config.max_aux_streams = max_aux_streams
if version_compatible:
_LOGGER.info(f"Using version compatible")
builder_config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE)
if optimization_level is not None:
_LOGGER.info(f"Using optimization level {optimization_level}")
builder_config.builder_optimization_level = optimization_level

if lower_precision == LowerPrecision.FP16:
builder_config.set_flag(trt.BuilderFlag.FP16)

Expand Down Expand Up @@ -265,6 +280,7 @@ def run(
_LOGGER.info(
f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
)
_LOGGER.info(f"TRT Engine uses: {engine.device_memory_size} bytes of Memory")

return TRTInterpreterResult(
engine, self._input_names, self._output_names, serialized_cache
Expand Down
3 changes: 3 additions & 0 deletions py/torch_tensorrt/fx/lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,9 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
if self.lower_setting.verbose_profile
else trt.ProfilingVerbosity.LAYER_NAMES_ONLY,
tactic_sources=self.lower_setting.tactic_sources,
max_aux_streams=self.lower_setting.max_aux_streams,
version_compatible=self.lower_setting.version_compatible,
optimization_level=self.lower_setting.optimization_level,
)

# Update timing cache file if needed
Expand Down
6 changes: 6 additions & 0 deletions py/torch_tensorrt/fx/lower_setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ class LowerSetting(LowerSettingBasic):
correctness_atol: absolute tolerance for correctness check
correctness_rtol: relative tolerance for correctness check
use_experimental_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
max_aux_streams: maximum number of auxiliary streams to use
version_compatible: enable version compatible feature
optimization_level: builder optimization level
"""

input_specs: List[InputTensorSpec] = dc.field(default_factory=list)
Expand Down Expand Up @@ -101,3 +104,6 @@ class LowerSetting(LowerSettingBasic):
correctness_atol: float = 0.1
correctness_rtol: float = 0.1
use_experimental_rt: bool = False
max_aux_streams: Optional[int] = None
version_compatible: bool = False
optimization_level: Optional[int] = None
91 changes: 49 additions & 42 deletions py/torch_tensorrt/fx/passes/pass_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,10 @@ def parent_pass(module: fx.GraphModule, input: Input) -> fx.GraphModule:
# (TODO(shirongwu): Add exception notification for fblearner flow when available, notify oncall
# on pass that failed accuracy check.
def validate_inference(
rtol=None, atol=None, run_alternative_batch_size: int = -1
rtol=None,
atol=None,
suppress_accuracy_check_failure=True,
run_alternative_batch_size: int = -1,
) -> "Decorator":
"""
Returns a decorator on a PassFunc to sanity check the model outputs
Expand All @@ -160,6 +163,7 @@ def validate_inference(
Args:
rtol: relative tolerance
atol: absolute tolerance
suppress_accuracy_check_failure: if True, skip the accuracy check entirely and apply the pass without validating its outputs
run_alternative_batch_size (int):
In addition to running inference at original batch size in the
input, also run at an alternative batch size. If set to -1, do not
Expand All @@ -181,48 +185,51 @@ def pass_with_validation(
*args,
**kwargs,
) -> fx.GraphModule:
res0 = module(*input)
processed_module = pass_(module, input, *args, **kwargs)
res1 = processed_module(*input)

tensor_res_0 = _collect_tensors(res0)
tensor_res_1 = _collect_tensors(res1)
relax_accuracy_check_failure = RELAX_ACCURACY_FAILURE

for kk, (x, y) in enumerate(zip(tensor_res_0, tensor_res_1)):
kwargs2 = {"equal_nan": True}
if rtol:
kwargs2["rtol"] = rtol
if atol:
kwargs2["atol"] = atol
kwargs2[
"msg"
] = (
lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}"
)
# If tensors are on different devices, make sure to compare
# their copies that are on the same device.
if x.get_device() != y.get_device():
x = x.cpu()
y = y.cpu()
try:
torch.testing.assert_close(x, y, **kwargs2)
except Exception as e:
if relax_accuracy_check_failure:
_LOGGER.error(f"{e}")
kwargs2["rtol"] *= FINAL_CHECK_RTOL_MULTIPLIER
kwargs2["atol"] *= FINAL_CHECK_ATOL_MULTIPLIER
new_atol = kwargs2["atol"]
new_rtol = kwargs2["rtol"]
_LOGGER.info(
f"Do a sanity check to see whether things are completely wrong with {new_atol=}, {new_rtol=}"
)
if suppress_accuracy_check_failure:
return pass_(module, input, *args, **kwargs)
else:
res0 = module(*input)
processed_module = pass_(module, input, *args, **kwargs)
res1 = processed_module(*input)

tensor_res_0 = _collect_tensors(res0)
tensor_res_1 = _collect_tensors(res1)
relax_accuracy_check_failure = RELAX_ACCURACY_FAILURE

for kk, (x, y) in enumerate(zip(tensor_res_0, tensor_res_1)):
kwargs2 = {"equal_nan": True}
if rtol:
kwargs2["rtol"] = rtol
if atol:
kwargs2["atol"] = atol
kwargs2[
"msg"
] = (
lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}"
)
# If tensors are on different devices, make sure to compare
# their copies that are on the same device.
if x.get_device() != y.get_device():
x = x.cpu()
y = y.cpu()
try:
torch.testing.assert_close(x, y, **kwargs2)
return processed_module
else:
raise e

return processed_module
except Exception as e:
if relax_accuracy_check_failure:
_LOGGER.error(f"{e}")
kwargs2["rtol"] *= FINAL_CHECK_RTOL_MULTIPLIER
kwargs2["atol"] *= FINAL_CHECK_ATOL_MULTIPLIER
new_atol = kwargs2["atol"]
new_rtol = kwargs2["rtol"]
_LOGGER.info(
f"Do a sanity check to see whether things are completely wrong with {new_atol=}, {new_rtol=}"
)
torch.testing.assert_close(x, y, **kwargs2)
return processed_module
else:
raise e

return processed_module

return pass_with_validation

Expand Down