Commit c3b62d2
TensorRT-LLM import fix; specify aot_joint_export as an explicit setting in dynamo.compile
- Added TRT-LLM installation utilities and test cases
- Added the option in _compiler.py
- Changes in the TRT-LLM loading tool: removed install_wget, install_unzip, install_mpi
- Further changes in error logging of the TRT-LLM installation tool
- Moved load_tensorrt_llm to dynamo/utils.py
- Corrected a misprint in the TRT-LLM load
- Used a Python library for the download to make it platform agnostic
- Updated the .dll file path for Windows
- Corrected a non-critical lint error
- Included the version in versions.txt
1 parent 79083b6 commit c3b62d2
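Taken together, the changes give users two ways to make the TRT-LLM plugin available (TRTLLM_PLUGINS_PATH pointing at an existing library, or USE_TRTLLM_PLUGINS to download it) plus the new use_distributed_mode_trace compile setting. A minimal end-to-end sketch follows; the model, input shapes, and plugin path are illustrative placeholders, not part of the commit.

import os
import torch
import torch.nn as nn
import torch_tensorrt

# Option 1: point at an existing plugin build (path is a placeholder)
os.environ["TRTLLM_PLUGINS_PATH"] = "/opt/trtllm/libs/libnvinfer_plugin_tensorrt_llm.so"
# Option 2: let torch_tensorrt download the TRT-LLM wheel and load the plugin from it
# os.environ["USE_TRTLLM_PLUGINS"] = "1"

class TinyModel(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x)

model = TinyModel().eval().cuda()
inputs = [torch.randn(2, 4, device="cuda")]

exp_program = torch.export.export(model, tuple(inputs))
trt_module = torch_tensorrt.dynamo.compile(
    exp_program,
    inputs=inputs,
    # new explicit setting from this commit; enable when DTensors or
    # distributed tensors are present in the model
    use_distributed_mode_trace=True,
)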

File tree

6 files changed: +144 -65 lines changed

dev_dep_versions.yml
py/torch_tensorrt/dynamo/_compiler.py
py/torch_tensorrt/dynamo/conversion/converter_utils.py
py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py
py/torch_tensorrt/dynamo/utils.py
setup.py

dev_dep_versions.yml

Lines changed: 1 addition & 0 deletions

@@ -1,2 +1,3 @@
 __cuda_version__: "12.8"
 __tensorrt_version__: "10.9.0"
+__tensorrt_llm_version__: "0.17.0.post1"

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 9 additions & 0 deletions

@@ -98,6 +98,7 @@ def cross_compile_for_windows(
     enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
     tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
     l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
+    use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows
@@ -173,6 +174,7 @@ def cross_compile_for_windows(
         enable_weight_streaming (bool): Enable weight streaming.
         tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
         l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
+        use_distributed_mode_trace (bool): Use aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in the model.
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -332,6 +334,7 @@ def cross_compile_for_windows(
         "enable_weight_streaming": enable_weight_streaming,
         "tiling_optimization_level": tiling_optimization_level,
         "l2_limit_for_tiling": l2_limit_for_tiling,
+        "use_distributed_mode_trace": use_distributed_mode_trace,
     }

     # disable the following settings is not supported for cross compilation for windows feature
@@ -421,6 +424,7 @@ def compile(
     enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
     tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
     l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
+    use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -498,6 +502,7 @@ def compile(
         enable_weight_streaming (bool): Enable weight streaming.
         tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
         l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
+        use_distributed_mode_trace (bool): Use aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in the model.
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -674,6 +679,7 @@ def compile(
         "enable_weight_streaming": enable_weight_streaming,
         "tiling_optimization_level": tiling_optimization_level,
         "l2_limit_for_tiling": l2_limit_for_tiling,
+        "use_distributed_mode_trace": use_distributed_mode_trace,
     }

     settings = CompilationSettings(**compilation_options)
@@ -964,6 +970,7 @@ def convert_exported_program_to_serialized_trt_engine(
     enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
     tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
     l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
+    use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
     **kwargs: Any,
 ) -> bytes:
     """Convert an ExportedProgram to a serialized TensorRT engine
@@ -1029,6 +1036,7 @@ def convert_exported_program_to_serialized_trt_engine(
         enable_weight_streaming (bool): Enable weight streaming.
         tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
         l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
+        use_distributed_mode_trace (bool): Use aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in the model.
     Returns:
         bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
     """
@@ -1147,6 +1155,7 @@ def convert_exported_program_to_serialized_trt_engine(
         "enable_weight_streaming": enable_weight_streaming,
         "tiling_optimization_level": tiling_optimization_level,
         "l2_limit_for_tiling": l2_limit_for_tiling,
+        "use_distributed_mode_trace": use_distributed_mode_trace,
     }

     settings = CompilationSettings(**compilation_options)
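Because the new key is passed through compilation_options and unpacked into the settings dataclass, it only works if CompilationSettings also carries the field. A quick sketch of the setting reaching the dataclass, assuming CompilationSettings in _settings.py defines use_distributed_mode_trace with default _defaults.USE_DISTRIBUTED_MODE_TRACE (that file is not part of this diff):

from torch_tensorrt.dynamo._settings import CompilationSettings

# the options dict built in compile() is unpacked straight into the dataclass
settings = CompilationSettings(use_distributed_mode_trace=True)
print(settings.use_distributed_mode_trace)  # True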

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 0 additions & 63 deletions

@@ -1002,66 +1002,3 @@ def args_bounds_check(
     args: Tuple[Argument, ...], i: int, replacement: Optional[Any] = None
 ) -> Any:
     return args[i] if len(args) > i and args[i] is not None else replacement
-
-
-def load_tensorrt_llm() -> bool:
-    """
-    Attempts to load the TensorRT-LLM plugin and initialize it.
-
-    Returns:
-        bool: True if the plugin was successfully loaded and initialized, False otherwise.
-    """
-    try:
-        import tensorrt_llm as trt_llm  # noqa: F401
-
-        _LOGGER.info("TensorRT-LLM successfully imported")
-        return True
-    except (ImportError, AssertionError) as e_import_error:
-        # Check for environment variable for the plugin library path
-        plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
-        if not plugin_lib_path:
-            _LOGGER.warning(
-                "TensorRT-LLM is not installed. Please install TensorRT-LLM or set TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops",
-            )
-            return False
-
-        _LOGGER.info(f"TensorRT-LLM Plugin lib path found: {plugin_lib_path}")
-        try:
-            # Load the shared library
-            handle = ctypes.CDLL(plugin_lib_path)
-            _LOGGER.info(f"Successfully loaded plugin library: {plugin_lib_path}")
-        except OSError as e_os_error:
-            _LOGGER.error(
-                f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}"
-                f"Ensure the path is correct and the library is compatible",
-                exc_info=e_os_error,
-            )
-            return False
-
-        try:
-            # Configure plugin initialization arguments
-            handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
-            handle.initTrtLlmPlugins.restype = ctypes.c_bool
-        except AttributeError as e_plugin_unavailable:
-            _LOGGER.warning(
-                "Unable to initialize the TensorRT-LLM plugin library",
-                exc_info=e_plugin_unavailable,
-            )
-            return False
-
-        try:
-            # Initialize the plugin
-            TRT_LLM_PLUGIN_NAMESPACE = "tensorrt_llm"
-            if handle.initTrtLlmPlugins(None, TRT_LLM_PLUGIN_NAMESPACE.encode("utf-8")):
-                _LOGGER.info("TensorRT-LLM plugin successfully initialized")
-                return True
-            else:
-                _LOGGER.warning("TensorRT-LLM plugin library failed in initialization")
-                return False
-        except Exception as e_initialization_error:
-            _LOGGER.warning(
-                "Exception occurred during TensorRT-LLM plugin library initialization",
-                exc_info=e_initialization_error,
-            )
-            return False
-    return False

py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py

Lines changed: 1 addition & 1 deletion

@@ -11,11 +11,11 @@
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     dynamo_tensorrt_converter,
 )
-from torch_tensorrt.dynamo.conversion.converter_utils import load_tensorrt_llm
 from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
     tensorrt_fused_nccl_all_gather_op,
     tensorrt_fused_nccl_reduce_scatter_op,
 )
+from torch_tensorrt.dynamo.utils import load_tensorrt_llm

 _LOGGER: logging.Logger = logging.getLogger(__name__)

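The import move matters because this module registers converters for the fused NCCL ops only when the plugin actually loads. A minimal sketch of that gating pattern (the real converter bodies are omitted; this mirrors the file's structure rather than reproducing it):

import logging

from torch_tensorrt.dynamo.utils import load_tensorrt_llm

_LOGGER = logging.getLogger(__name__)

if load_tensorrt_llm():
    # the @dynamo_tensorrt_converter registrations for
    # tensorrt_fused_nccl_all_gather_op / tensorrt_fused_nccl_reduce_scatter_op
    # live inside this branch
    _LOGGER.info("TRT-LLM plugins loaded; distributed op converters registered")
else:
    _LOGGER.debug("TRT-LLM plugins unavailable; distributed op converters skipped")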
py/torch_tensorrt/dynamo/utils.py

Lines changed: 129 additions & 1 deletion

@@ -1,7 +1,10 @@
 from __future__ import annotations

+import ctypes
 import gc
 import logging
+import os
+import urllib.request
 import warnings
 from dataclasses import fields, replace
 from enum import Enum
@@ -14,9 +17,10 @@
 from torch._subclasses.fake_tensor import FakeTensor
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
 from torch_tensorrt._Device import Device
-from torch_tensorrt._enums import dtype
+from torch_tensorrt._enums import Platform, dtype
 from torch_tensorrt._features import ENABLED_FEATURES
 from torch_tensorrt._Input import Input
+from torch_tensorrt._version import __tensorrt_llm_version__
 from torch_tensorrt.dynamo import _defaults
 from torch_tensorrt.dynamo._defaults import default_device
 from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
@@ -812,3 +816,127 @@ def is_tegra_platform() -> bool:
     if torch.cuda.get_device_capability() in [(8, 7), (7, 2)]:
         return True
     return False
+
+
+def download_plugin_lib_path(py_version: str, platform: str) -> str:
+    plugin_lib_path = None
+
+    # Downloading TRT-LLM lib
+    base_url = "https://pypi.nvidia.com/tensorrt-llm/"
+    file_name = f"tensorrt_llm-{__tensorrt_llm_version__}-{py_version}-{py_version}-{platform}.whl"
+    download_url = base_url + file_name
+    if not (os.path.exists(file_name)):
+        try:
+            logger.debug(f"Downloading {download_url} ...")
+            urllib.request.urlretrieve(download_url, file_name)
+            logger.debug("Download succeeded and TRT-LLM wheel is now present")
+        except urllib.error.HTTPError as e:
+            logger.error(
+                f"HTTP error {e.code} when trying to download {download_url}: {e.reason}"
+            )
+        except urllib.error.URLError as e:
+            logger.error(
+                f"URL error when trying to download {download_url}: {e.reason}"
+            )
+        except OSError as e:
+            logger.error(f"Local file write error: {e}")
+
+    # Proceeding with the unzip of the wheel file
+    # This will exist if the filename was already downloaded
+    if "linux" in platform:
+        lib_filename = "libnvinfer_plugin_tensorrt_llm.so"
+    else:
+        lib_filename = "libnvinfer_plugin_tensorrt_llm.dll"
+    plugin_lib_path = os.path.join("./tensorrt_llm/libs", lib_filename)
+    if os.path.exists(plugin_lib_path):
+        return plugin_lib_path
+    try:
+        import zipfile
+    except ImportError as e:
+        raise ImportError(
+            "zipfile module is required but not found. Please install zipfile"
+        )
+    with zipfile.ZipFile(file_name, "r") as zip_ref:
+        zip_ref.extractall(".")  # Extract to a folder named 'tensorrt_llm'
+    plugin_lib_path = "./tensorrt_llm/libs/" + lib_filename
+    return plugin_lib_path
+
+
+def load_tensorrt_llm() -> bool:
+    """
+    Attempts to load the TensorRT-LLM plugin and initialize it.
+    Either the environment variable TRTLLM_PLUGINS_PATH can specify the plugin path,
+    or the user can set USE_TRTLLM_PLUGINS to one of (1, true, yes, on) to download the TRT-LLM distribution and load it.
+
+    Returns:
+        bool: True if the plugin was successfully loaded and initialized, False otherwise.
+    """
+    plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
+    if not plugin_lib_path:
+        # this option can be used if TRTLLM_PLUGINS_PATH is not set by the user
+        use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in (
+            "1",
+            "true",
+            "yes",
+            "on",
+        )
+        if not use_trtllm_plugin:
+            logger.warning(
+                "Neither TRTLLM_PLUGINS_PATH is set nor is the download of the shared library enabled. Please set one of the two to use TRT-LLM libraries in torchTRT"
+            )
+            return False
+        else:
+            # this is used as the default py version
+            py_version = "cp310"
+            platform = Platform.current_platform()
+
+            platform = str(platform).lower()
+            plugin_lib_path = download_plugin_lib_path(py_version, platform)
+
+    try:
+        # Load the shared TRT-LLM library
+        handle = ctypes.CDLL(plugin_lib_path)
+        logger.info(f"Successfully loaded plugin library: {plugin_lib_path}")
+    except OSError as e_os_error:
+        if "libmpi" in str(e_os_error):
+            logger.warning(
+                f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}. "
+                f"The dependency libmpi.so is missing. "
+                f"Please install the packages libmpich-dev and libopenmpi-dev.",
+                exc_info=e_os_error,
+            )
+        else:
+            logger.warning(
+                f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}. "
+                f"Ensure the path is correct and the library is compatible.",
+                exc_info=e_os_error,
+            )
+        return False
+
+    try:
+        # Configure plugin initialization arguments
+        handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
+        handle.initTrtLlmPlugins.restype = ctypes.c_bool
+    except AttributeError as e_plugin_unavailable:
+        logger.warning(
+            "Unable to initialize the TensorRT-LLM plugin library",
+            exc_info=e_plugin_unavailable,
+        )
+        return False
+
+    try:
+        # Initialize the plugin
+        TRT_LLM_PLUGIN_NAMESPACE = "tensorrt_llm"
+        if handle.initTrtLlmPlugins(None, TRT_LLM_PLUGIN_NAMESPACE.encode("utf-8")):
+            logger.info("TensorRT-LLM plugin successfully initialized")
+            return True
+        else:
+            logger.warning("TensorRT-LLM plugin library failed in initialization")
+            return False
+    except Exception as e_initialization_error:
+        logger.warning(
+            "Exception occurred during TensorRT-LLM plugin library initialization",
+            exc_info=e_initialization_error,
+        )
+        return False
+    return False
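For reference, here is what download_plugin_lib_path assembles on a Linux x86_64 host with the hardcoded cp310 tag, assuming str(Platform.current_platform()).lower() yields "linux_x86_64" (a reasonable but unverified assumption about the Platform enum's string form):

version = "0.17.0.post1"   # __tensorrt_llm_version__ from dev_dep_versions.yml
py_version = "cp310"       # the default interpreter/ABI tag used above
platform = "linux_x86_64"  # assumed Platform string, lowercased

file_name = f"tensorrt_llm-{version}-{py_version}-{py_version}-{platform}.whl"
print(file_name)
# tensorrt_llm-0.17.0.post1-cp310-cp310-linux_x86_64.whl
print("https://pypi.nvidia.com/tensorrt-llm/" + file_name)
# the .so is then extracted from the wheel at ./tensorrt_llm/libs/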

setup.py

Lines changed: 4 additions & 0 deletions

@@ -28,6 +28,7 @@
 __version__: str = "0.0.0"
 __cuda_version__: str = "0.0"
 __tensorrt_version__: str = "0.0"
+__tensorrt_llm_version__: str = "0.0"

 LEGACY_BASE_VERSION_SUFFIX_PATTERN = re.compile("a0$")

@@ -63,6 +64,7 @@ def get_base_version() -> str:
 def load_dep_info():
     global __cuda_version__
     global __tensorrt_version__
+    global __tensorrt_llm_version__
     with open("dev_dep_versions.yml", "r") as stream:
         versions = yaml.safe_load(stream)
         if (gpu_arch_version := os.environ.get("CU_VERSION")) is not None:
@@ -72,6 +74,7 @@ def load_dep_info():
         else:
             __cuda_version__ = versions["__cuda_version__"]
         __tensorrt_version__ = versions["__tensorrt_version__"]
+        __tensorrt_llm_version__ = versions["__tensorrt_llm_version__"]


 load_dep_info()
@@ -249,6 +252,7 @@ def gen_version_file():
         f.write('__version__ = "' + __version__ + '"\n')
         f.write('__cuda_version__ = "' + __cuda_version__ + '"\n')
         f.write('__tensorrt_version__ = "' + __tensorrt_version__ + '"\n')
+        f.write('__tensorrt_llm_version__ = "' + __tensorrt_llm_version__ + '"\n')


 def copy_libtorchtrt(multilinux=False, rt_only=False):
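With the dev_dep_versions.yml entry above, gen_version_file() would emit something like the following into the generated version module (the TensorRT/CUDA values are from this commit; __version__ is a placeholder, since its actual value is computed at build time):

# generated _version.py (illustrative)
__version__ = "0.0.0"  # placeholder; real value comes from the versioning logic
__cuda_version__ = "12.8"
__tensorrt_version__ = "10.9.0"
__tensorrt_llm_version__ = "0.17.0.post1"  # new in this commit

This is the module that py/torch_tensorrt/dynamo/utils.py imports __tensorrt_llm_version__ from.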
