
Upstream 3 features for torch-trt: max aux streams, version compatibility, and optimization level #1926

Closed
wants to merge 11 commits
46 changes: 46 additions & 0 deletions py/torch_tensorrt/fx/fx2trt.py
@@ -163,6 +163,9 @@ def run(
         timing_cache=None,
         profiling_verbosity=None,
         tactic_sources=None,
+        max_aux_streams=None,
+        version_compatible=False,
+        optimization_level=None,
     ) -> TRTInterpreterResult:
         """
         Build TensorRT engine with some configs.
@@ -225,6 +228,18 @@ def run(
             if profiling_verbosity
             else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
         )
+
+        if trt.__version__ >= "8.6":
+            if max_aux_streams is not None:
+                _LOGGER.info(f"Setting max aux streams to {max_aux_streams}")
+                builder_config.max_aux_streams = max_aux_streams
+            if version_compatible:
+                _LOGGER.info("Using version compatible")
+                builder_config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE)
+            if optimization_level is not None:
+                _LOGGER.info(f"Using optimization level {optimization_level}")
+                builder_config.builder_optimization_level = optimization_level
+
         if lower_precision == LowerPrecision.FP16:
             builder_config.set_flag(trt.BuilderFlag.FP16)
 
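For reference, the hunk above maps directly onto the TensorRT 8.6 builder-config API. A minimal standalone sketch of the same three knobs follows; the network construction is elided and the chosen values are illustrative, not defaults taken from this PR.

```python
import tensorrt as trt  # these options require TensorRT >= 8.6

logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
# ... populate `network` with layers (omitted) ...

config = builder.create_builder_config()
config.max_aux_streams = 2  # cap the auxiliary CUDA streams TRT may use
config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE)  # forward-compatible engine
config.builder_optimization_level = 5  # 0 = fastest build, 5 = best tactics

serialized_engine = builder.build_serialized_network(network, config)
```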
@@ -251,6 +266,34 @@ def run(
         engine = self.builder.build_engine(self.network, builder_config)
         assert engine
 
+        import os
+        def get_file_name(org):
+            file_name = org
+            i = 0
+            while os.path.exists(os.path.abspath(file_name)):
+                i += 1
+                file_name = org + str(i)
+            return file_name
+
+        engine_file = os.environ.get('TORCH_FX_DUMP_ENGINE')
+        if engine_file:
+            dump_file = get_file_name(engine_file)
+            print(f'Dumping engine to {dump_file}')
+            s = engine.serialize()
+            with open(dump_file, 'wb') as f:
+                f.write(s)
+        engine_info_file = os.environ.get('TORCH_FX_DUMP_ENGINE_INFO')
+        if engine_info_file:
+            inspector = engine.create_engine_inspector()
+            engine_info = inspector.get_engine_information(trt.LayerInformationFormat.JSON)
+            if engine_info is None or len(engine_info) == 0:
+                raise Exception('Engine info is empty')
+            else:
+                dump_file = get_file_name(engine_info_file)
+                print(f'Dumping engine info to {dump_file}')
+                with open(dump_file, 'w') as f:
+                    f.write(engine_info)
+
         serialized_cache = (
             bytearray(cache.serialize())
             if builder_config.get_timing_cache()
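The two environment variables added here are the only user-facing switch for the dump feature, and `get_file_name` appends an integer suffix rather than overwriting an existing file. A sketch of how one might exercise it (the paths are examples, not conventions from this PR):

```python
import os
import tensorrt as trt

# Opt in before the FX lowering runs; both paths are illustrative.
os.environ["TORCH_FX_DUMP_ENGINE"] = "/tmp/fx_engine.trt"
os.environ["TORCH_FX_DUMP_ENGINE_INFO"] = "/tmp/fx_engine_info.json"

# ... run the lowering so TRTInterpreter.run() builds the engine ...

# The serialized engine can later be reloaded with the TensorRT runtime.
runtime = trt.Runtime(trt.Logger(trt.Logger.INFO))
with open("/tmp/fx_engine.trt", "rb") as f:
    engine = runtime.deserialize_cuda_engine(f.read())
```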
@@ -259,6 +302,9 @@ def run(
         _LOGGER.info(
             f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
         )
+        _LOGGER.info(
+            f"TRT Engine uses: {engine.device_memory_size} bytes of Memory"
+        )
 
         return TRTInterpreterResult(
             engine, self._input_names, self._output_names, serialized_cache
3 changes: 3 additions & 0 deletions py/torch_tensorrt/fx/lower.py
@@ -138,6 +138,9 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
             if self.lower_setting.verbose_profile
             else trt.ProfilingVerbosity.LAYER_NAMES_ONLY,
             tactic_sources=self.lower_setting.tactic_sources,
+            max_aux_streams=self.lower_setting.max_aux_streams,
+            version_compatible=self.lower_setting.version_compatible,
+            optimization_level=self.lower_setting.optimization_level,
         )
 
         # Update timing cache file if needed
6 changes: 6 additions & 0 deletions py/torch_tensorrt/fx/lower_setting.py
@@ -74,6 +74,9 @@ class LowerSetting(LowerSettingBasic):
         correctness_atol: absolute tolerance for correctness check
         correctness_rtol: relative tolerance for correctness check
         use_experimental_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
+        max_aux_streams: max number of aux streams to use
+        version_compatible: enable version compatible feature
+        optimization_level: builder optimization level
     """
 
     input_specs: List[InputTensorSpec] = dc.field(default_factory=list)
@@ -101,3 +104,6 @@ class LowerSetting(LowerSettingBasic):
     correctness_atol: float = 0.1
     correctness_rtol: float = 0.1
     use_experimental_rt: bool = False
+    max_aux_streams: Optional[int] = None
+    version_compatible: bool = False
+    optimization_level: Optional[int] = None
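Putting the three new fields together, a hedged sketch of how a caller might drive them through `LowerSetting`; the values are illustrative, and this assumes the existing `Lowerer.create` entry point in lower.py consumes the setting as before:

```python
from torch_tensorrt.fx.lower import Lowerer
from torch_tensorrt.fx.lower_setting import LowerSetting

# Illustrative values; None / False leave TensorRT's defaults untouched.
setting = LowerSetting(
    max_aux_streams=2,        # limit TensorRT auxiliary streams
    version_compatible=True,  # request a forward-compatible engine
    optimization_level=3,     # TRT 8.6 builder optimization level (0-5)
)
lowerer = Lowerer.create(lower_setting=setting)
```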
87 changes: 45 additions & 42 deletions py/torch_tensorrt/fx/passes/pass_utils.py
@@ -100,7 +100,7 @@ def parent_pass(module: fx.GraphModule, input: Input) -> fx.GraphModule:
 
 # (TODO(shirongwu): Add exception notification for fblearner flow when available, notify oncall
 # on pass that failed accuracy check.
-def validate_inference(rtol=None, atol=None):
+def validate_inference(rtol=None, atol=None, suppress_accuracy_check_failure=True):
     def _validate_inference(pass_: PassFunc) -> PassFunc:
         """
         Wraps a pass function to validate that its inference results before and
@@ -114,48 +114,51 @@ def pass_with_validation(
             *args,
             **kwargs,
         ) -> fx.GraphModule:
-            res0 = module(*input)
-            processed_module = pass_(module, input, *args, **kwargs)
-            res1 = processed_module(*input)
-
-            tensor_res_0 = _collect_tensors(res0)
-            tensor_res_1 = _collect_tensors(res1)
-            relax_accuracy_check_failure = RELAX_ACCURACY_FAILURE
-
-            for kk, (x, y) in enumerate(zip(tensor_res_0, tensor_res_1)):
-                kwargs2 = {"equal_nan": True}
-                if rtol:
-                    kwargs2["rtol"] = rtol
-                if atol:
-                    kwargs2["atol"] = atol
-                kwargs2[
-                    "msg"
-                ] = (
-                    lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}"
-                )
-                # If tensors are on different devices, make sure to compare
-                # their copies that are on the same device.
-                if x.get_device() != y.get_device():
-                    x = x.cpu()
-                    y = y.cpu()
-                try:
-                    torch.testing.assert_close(x, y, **kwargs2)
-                except Exception as e:
-                    if relax_accuracy_check_failure:
-                        _LOGGER.error(f"{e}")
-                        kwargs2["rtol"] *= FINAL_CHECK_RTOL_MULTIPLIER
-                        kwargs2["atol"] *= FINAL_CHECK_ATOL_MULTIPLIER
-                        new_atol = kwargs2["atol"]
-                        new_rtol = kwargs2["rtol"]
-                        _LOGGER.info(
-                            f"Do a sanity check to see whether things are completely wrong with {new_atol=}, {new_rtol=}"
-                        )
-                        torch.testing.assert_close(x, y, **kwargs2)
-                        return processed_module
-                    else:
-                        raise e
-
-            return processed_module
+            if suppress_accuracy_check_failure:
+                return pass_(module, input, *args, **kwargs)
+            else:
+                res0 = module(*input)
+                processed_module = pass_(module, input, *args, **kwargs)
+                res1 = processed_module(*input)
+
+                tensor_res_0 = _collect_tensors(res0)
+                tensor_res_1 = _collect_tensors(res1)
+                relax_accuracy_check_failure = RELAX_ACCURACY_FAILURE
+
+                for kk, (x, y) in enumerate(zip(tensor_res_0, tensor_res_1)):
+                    kwargs2 = {"equal_nan": True}
+                    if rtol:
+                        kwargs2["rtol"] = rtol
+                    if atol:
+                        kwargs2["atol"] = atol
+                    kwargs2[
+                        "msg"
+                    ] = (
+                        lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}"
+                    )
+                    # If tensors are on different devices, make sure to compare
+                    # their copies that are on the same device.
+                    if x.get_device() != y.get_device():
+                        x = x.cpu()
+                        y = y.cpu()
+                    try:
+                        torch.testing.assert_close(x, y, **kwargs2)
+                    except Exception as e:
+                        if relax_accuracy_check_failure:
+                            _LOGGER.error(f"{e}")
+                            kwargs2["rtol"] *= FINAL_CHECK_RTOL_MULTIPLIER
+                            kwargs2["atol"] *= FINAL_CHECK_ATOL_MULTIPLIER
+                            new_atol = kwargs2["atol"]
+                            new_rtol = kwargs2["rtol"]
+                            _LOGGER.info(
+                                f"Do a sanity check to see whether things are completely wrong with {new_atol=}, {new_rtol=}"
+                            )
+                            torch.testing.assert_close(x, y, **kwargs2)
+                            return processed_module
+                        else:
+                            raise e
+
+                return processed_module
 
     return pass_with_validation
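Usage note: with the new flag defaulting to True, decorated passes now skip the before/after inference comparison entirely unless a caller opts back in. A small hypothetical sketch (the pass name and body are placeholders, not code from this PR):

```python
import torch.fx as fx

# Hypothetical pass; passing suppress_accuracy_check_failure=False opts back
# in to the strict before/after inference comparison.
@validate_inference(rtol=1e-2, atol=1e-2, suppress_accuracy_check_failure=False)
def my_rewrite_pass(module: fx.GraphModule, input) -> fx.GraphModule:
    # ... rewrite `module` in place (omitted) ...
    return module
```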
