Commit f77df5a

Squash the commits

1 parent 1d038a1 commit f77df5a

9 files changed: +248 -32 lines changed
Lines changed: 77 additions & 0 deletions

@@ -0,0 +1,77 @@
.. _resource_management:

Resource Management
===================

Overview
--------

Efficient control of CPU and GPU memory is essential for successful model compilation,
especially when working with large models such as LLMs or diffusion models.
Uncontrolled memory growth can cause compilation failures or process termination.
This guide describes the symptoms of excessive memory usage and provides methods
to reduce both CPU and GPU memory consumption.

Memory Usage Control
--------------------

CPU Memory
^^^^^^^^^^

By default, Torch-TensorRT may consume up to **** the model size in CPU memory.
This can exceed system limits when compiling large models.

**Common symptoms of high CPU memory usage:**

- The program freezes
- The process is terminated by the operating system

**Ways to lower CPU memory usage:**

1. **Enable memory trimming**

   Set the following environment variable:

   .. code-block:: bash

      export TRIM_CPU_MEMORY=1

   This removes approximately **** of the redundant model copies, limiting
   total CPU memory usage to at most **** the model size.

2. **Disable CPU offloading**

   In the compilation settings, set:

   .. code-block:: python

      offload_module_to_cpu = False

   This removes another **** model copy, reducing peak CPU memory
   usage to about **** the model size.

GPU Memory
^^^^^^^^^^

By default, Torch-TensorRT may consume up to **** the model size in GPU memory.

**Common symptoms of high GPU memory usage:**

- CUDA out-of-memory errors
- TensorRT compilation errors

**Ways to lower GPU memory usage:**

1. **Enable offloading to CPU**

   In the compilation settings, set:

   .. code-block:: python

      offload_module_to_cpu = True

   This shifts one model copy from GPU to CPU memory.
   As a result, peak GPU memory usage decreases to about ****
   the model size, while CPU memory usage increases by roughly ****.
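For illustration, a minimal sketch of how the documented settings could be combined in a single dynamo compile call. The toy model, the input shape, and the assumption that setting TRIM_CPU_MEMORY from Python before compilation behaves like the shell export are placeholders, not part of this commit:

import os

import torch
import torch_tensorrt

# Assumption: setting the variable in-process before compilation is equivalent to the shell export.
os.environ["TRIM_CPU_MEMORY"] = "1"

# Placeholder model and input; a real workload would be an LLM or a diffusion model.
model = (
    torch.nn.Sequential(
        torch.nn.Linear(1024, 4096),
        torch.nn.ReLU(),
        torch.nn.Linear(4096, 1024),
    )
    .eval()
    .cuda()
)
inputs = [torch.randn(8, 1024).cuda()]

trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    offload_module_to_cpu=True,  # trade GPU memory for CPU memory during the build
)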
py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 12 additions & 3 deletions
@@ -42,6 +42,7 @@
 )
 from torch_tensorrt.dynamo.utils import (
     deallocate_module,
+    get_cpu_memory_usage,
     get_flat_args_with_check,
     get_output_metadata,
     parse_graph_io,
@@ -681,7 +682,7 @@ def compile(
         "offload_module_to_cpu": offload_module_to_cpu,
         "use_distributed_mode_trace": use_distributed_mode_trace,
     }
-
+    logger.debug(f"CPU memory usage before lowering: {get_cpu_memory_usage()} MB")
     settings = CompilationSettings(**compilation_options)
     logger.info("Compilation Settings: %s\n", settings)
     exported_program = pre_export_lowering(exported_program, settings)
@@ -695,14 +696,17 @@ def compile(

     # Apply lowering on the graph module
     gm = post_lowering(gm, settings)
+    logger.debug(f"CPU memory usage after post_lowering: {get_cpu_memory_usage()} MB")
     logger.debug("Lowered Input graph: " + str(gm.graph))

     # Move the weights in the state_dict to CPU
     if offload_module_to_cpu:
+        deallocate_module(gm, delete_module=False)
         deallocate_module(exported_program.module(), delete_module=False)
         logger.info(
             "The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
         )
+        logger.debug(f"CPU memory usage after CPU offload: {get_cpu_memory_usage()} MB")
     else:
         remaining_memory, total_memory = torch.cuda.mem_get_info()
         if remaining_memory < total_memory // 2:
@@ -868,6 +872,11 @@ def preserve_module_specs(
     # Iterate over all components that can be accelerated
     # Generate the corresponding TRT Module for those

+    # Here we delete the frozen parameters from the graph module. Note this does not affect the submodules. We are going to delete the frozen parameters from the submodules in the convert_module function.
+    # This is done to release CPU memory.
+    for attr in dir(gm):
+        if attr.startswith("_frozen_param"):
+            delattr(gm, attr)
     for name, _ in partitioned_module.named_children():
         submodule = getattr(partitioned_module, name)
         # filter on the GraphModule
@@ -1243,7 +1252,7 @@ def convert_exported_program_to_serialized_trt_engine(

     # Prepare torch_trt inputs
     trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs)
-    trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs)
+    trt_kwarg_inputs: Optional[dict[str, Any]] = prepare_inputs(kwarg_inputs)
     device = to_torch_tensorrt_device(device)
     enabled_precisions = {dtype._from(p) for p in enabled_precisions}

@@ -1330,7 +1339,7 @@ def convert_exported_program_to_serialized_trt_engine(
     )

     flattened_input_list = get_flat_args_with_check(
-        exported_program, list(trt_arg_inputs), trt_kwarg_inputs
+        exported_program, list(trt_arg_inputs), trt_kwarg_inputs  # type: ignore
     )[0]

     try:
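The helper get_cpu_memory_usage used in the new debug logs is imported from torch_tensorrt.dynamo.utils, but its body is not shown in this diff. A minimal sketch of what such a helper might look like, assuming it reports the resident set size (RSS) of the current process in MB via psutil:

import os

import psutil


def get_cpu_memory_usage() -> float:
    """Resident set size (RSS) of the current process, in MB. Sketch only; the real helper may differ."""
    return psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)


print(f"CPU memory usage: {get_cpu_memory_usage():.1f} MB")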
py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 22 additions & 18 deletions
@@ -50,7 +50,12 @@
 from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig
 from torch_tensorrt.dynamo.debug._supports_debugger import cls_supports_debugger
 from torch_tensorrt.dynamo.observer import Observer
-from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, deallocate_module, to_torch_device
+from torch_tensorrt.dynamo.utils import (
+    DYNAMIC_DIM,
+    deallocate_module,
+    get_cpu_memory_usage,
+    to_torch_device,
+)
 from torch_tensorrt.logging import TRT_LOGGER

 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -65,7 +70,7 @@ class UnsupportedOperatorException(RuntimeError):


 class TRTInterpreterResult(NamedTuple):
-    serialized_engine: bytes
+    engine: trt.ICudaEngine
     input_names: Sequence[str]
     output_names: Sequence[str]
     weight_name_map: Optional[dict[Any, Any]]
@@ -512,8 +517,7 @@ def _save_weight_mapping(self) -> None:
         _LOGGER.info("Building weight name mapping...")
         # Stage 1: Name mapping
         torch_device = to_torch_device(self.compilation_settings.device)
-        self.module.to(torch_device)
-        sd = self.module.state_dict()
+        sd = {k: v.to(torch_device) for k, v in self.module.state_dict().items()}
         weight_name_map: dict[str, Any] = {}
         weight_refit_map = self.ctx.weight_refit_map
         constant_mapping = {k: v for k, v in weight_refit_map.items() if v.size == 1}
@@ -592,13 +596,11 @@ def _save_weight_mapping(self) -> None:
         torch.cuda.empty_cache()

     @needs_refit  # type: ignore[misc]
-    def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> None:
+    def _insert_engine_to_cache(self, hash_val: str, engine: trt.ICudaEngine) -> None:
+        serialized_engine = engine.serialize()
         # TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine
         # if not self.compilation_settings.strip_engine_weights:
         #     # set EXCLUDE_WEIGHTS flag to strip weights
-        #     runtime = trt.Runtime(TRT_LOGGER)
-        #     engine = runtime.deserialize_cuda_engine(serialized_engine)
-
         #     serialization_config = engine.create_serialization_config()
         #     serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
         #     serialized_engine = engine.serialize_with_config(
@@ -733,6 +735,9 @@ def run(
             return interpreter_result  # type: ignore[no-any-return]

         self._construct_trt_network_def()
+        _LOGGER.debug(
+            f"CPU memory usage after network construction: {get_cpu_memory_usage()} MB"
+        )

         if not self.compilation_settings.immutable_weights:
             self._save_weight_mapping()
@@ -750,16 +755,19 @@
         self._create_timing_cache(
             builder_config, self.compilation_settings.timing_cache_path
         )
-        serialized_engine = self.builder.build_serialized_network(
+
+        cuda_engine = self.builder.build_engine_with_config(
             self.ctx.net, builder_config
         )
-        assert serialized_engine
+        assert cuda_engine
+
+        _LOGGER.debug(
+            f"CPU memory usage after engine building: {get_cpu_memory_usage()} MB"
+        )

         _LOGGER.info(
             f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
         )
-        _LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory")
-
         self.ctx.clear_cpu_weights_reference_holder()

         self._save_timing_cache(
@@ -772,14 +780,10 @@ def run(
             and self.compilation_settings.cache_built_engines
             and self.engine_cache is not None
         ):
-            self._insert_engine_to_cache(hash_val, serialized_engine)
-
-        with io.BytesIO() as engine_bytes:
-            engine_bytes.write(serialized_engine)
-            engine_str = engine_bytes.getvalue()
+            self._insert_engine_to_cache(hash_val, cuda_engine)

         return TRTInterpreterResult(
-            engine_str,
+            cuda_engine,
             self._input_names,
             self._output_names,
             self.weight_name_map,
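For context, returning the live trt.ICudaEngine from the interpreter and serializing it later (in _conversion.py below) appears intended to avoid keeping both the engine and an extra serialized copy alive inside run(). A minimal sketch of the two TensorRT build paths, assuming an already-populated builder, network, and builder config:

import tensorrt as trt


def build_trt_engine_bytes(
    builder: trt.Builder,
    network: trt.INetworkDefinition,
    config: trt.IBuilderConfig,
    defer_serialization: bool = True,
) -> bytes:
    """Sketch only; network construction and config setup are assumed to be done elsewhere."""
    if defer_serialization:
        # Path used after this commit: build a live engine, serialize only when needed.
        engine = builder.build_engine_with_config(network, config)
        assert engine is not None
        host_memory = engine.serialize()
    else:
        # Previous path: build directly into a serialized blob.
        host_memory = builder.build_serialized_network(network, config)
        assert host_memory is not None
    return bytes(host_memory)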
py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 44 additions & 9 deletions
@@ -1,24 +1,34 @@
 from __future__ import annotations

+import io
 import logging
-from typing import Any, List, Optional, Sequence
+from typing import Any, List, NamedTuple, Optional, Sequence

 import torch
 from torch_tensorrt._enums import dtype
 from torch_tensorrt._features import ENABLED_FEATURES
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
 from torch_tensorrt.dynamo._settings import CompilationSettings
-from torch_tensorrt.dynamo.conversion._TRTInterpreter import (
-    TRTInterpreter,
-    TRTInterpreterResult,
-)
+from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter
 from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
-from torch_tensorrt.dynamo.utils import get_output_dtypes
+from torch_tensorrt.dynamo.utils import (
+    get_cpu_memory_usage,
+    get_output_dtypes,
+    release_memory,
+)

 logger = logging.getLogger(__name__)


+class SerializedInterpreterResult(NamedTuple):
+    serialized_engine: bytes
+    input_names: Sequence[str]
+    output_names: Sequence[str]
+    weight_name_map: Optional[dict[Any, Any]]
+    requires_output_allocator: bool
+
+
 def infer_module_output_dtypes(
     module: torch.fx.GraphModule,
     truncate_double: bool = False,
@@ -29,7 +39,7 @@ def infer_module_output_dtypes(
     """
     outputs = [node for node in module.graph.nodes if node.op == "output"]
     outputs = outputs[0].args
-    return get_output_dtypes(outputs, truncate_double)  # type: ignore[no-any-return]
+    return get_output_dtypes(outputs, truncate_double)


 def interpret_module_to_result(
@@ -39,7 +49,7 @@ def interpret_module_to_result(
     arg_inputs: Optional[Sequence[Input]] = None,
     kwarg_inputs: Optional[dict[str, Any]] = None,
     engine_cache: Optional[BaseEngineCache] = None,
-) -> TRTInterpreterResult:
+) -> SerializedInterpreterResult:
     """Interpret an FX module to a TRTInterpreterResult
     Args:
         module: FX GraphModule to interpret
@@ -65,7 +75,32 @@ def interpret_module_to_result(
     )

     interpreter_result = interpreter.run()
-    return interpreter_result
+    # Delete the frozen parameters from the module to release CPU memory
+    del interpreter
+    for attr in dir(module):
+        if attr.startswith("_frozen_param"):
+            delattr(module, attr)
+    release_memory()
+    logger.debug(
+        f"CPU memory usage after clearing frozen parameters and building memory in conversion: {get_cpu_memory_usage()} MB"
+    )
+
+    serialized_engine = interpreter_result.engine.serialize()
+    with io.BytesIO() as engine_bytes:
+        engine_bytes.write(serialized_engine)
+        serialized_engine = engine_bytes.getvalue()
+    logger.debug(
+        f"CPU memory usage after serializing engine: {get_cpu_memory_usage()} MB"
+    )
+    serialized_interpreter_result = SerializedInterpreterResult(
+        serialized_engine=serialized_engine,
+        input_names=interpreter_result.input_names,
+        output_names=interpreter_result.output_names,
+        weight_name_map=interpreter_result.weight_name_map,
+        requires_output_allocator=interpreter_result.requires_output_allocator,
+    )
+
+    return serialized_interpreter_result


 def convert_module(
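release_memory is another new utility imported above whose implementation is not part of this diff. A speculative minimal sketch, assuming it combines a garbage-collection pass, a CUDA cache flush, and (on glibc-based Linux) a malloc_trim call to hand freed heap pages back to the OS:

import ctypes
import gc

import torch


def release_memory() -> None:
    """Speculative sketch; the real torch_tensorrt.dynamo.utils.release_memory may differ."""
    gc.collect()  # drop unreachable Python objects such as the deleted frozen parameters
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # return cached CUDA blocks to the driver
    try:
        # Ask the glibc allocator to release freed heap pages back to the operating system.
        ctypes.CDLL("libc.so.6").malloc_trim(0)
    except (OSError, AttributeError):
        pass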
py/torch_tensorrt/dynamo/debug/_Debugger.py

Lines changed: 1 addition & 0 deletions
@@ -220,6 +220,7 @@ def get_logging_config(self, log_level: Optional[int] = None) -> dict[str, Any]:
                 "class": "logging.FileHandler",
                 "filename": f"{self.cfg.logging_dir}/torch_tensorrt_logging.log",
                 "formatter": "standard",
+                "mode": "w",  # This will clear the previous content
             }
             config["loggers"][""]["handlers"].append("file")
         return config
py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py

Lines changed: 3 additions & 1 deletion
@@ -37,7 +37,9 @@ def constant_fold(
     # For TRT INetwork construction the constants are moved to CPU in get_attr call.
     for node, constant in cf.node_replacements.items():
         replace_node_with_constant(
-            gm, node, torch.nn.Parameter(constant, requires_grad=False)
+            gm,
+            node,
+            torch.nn.Parameter(constant.cpu().contiguous(), requires_grad=False),
         )

     erased_params = []
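Folding each constant as constant.cpu().contiguous() keeps it on the host for TRT network construction; a plausible additional effect is that the contiguous copy detaches a non-contiguous view from its larger source storage. A small illustrative sketch with arbitrary shapes:

import torch

# A folded constant can be a non-contiguous view into a much larger tensor.
weight = torch.randn(1024, 1024)
view = weight[:, :16]             # shares the full 1024x1024 storage
folded = view.cpu().contiguous()  # compact, standalone host copy

print(view.is_contiguous(), folded.is_contiguous())  # False True
print(weight.untyped_storage().nbytes())             # ~4 MB backing the view
print(folded.untyped_storage().nbytes())             # ~64 KB for the copy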