
Commit 12ecb86

[None][chore] share input_ids buffers among different cuda graphs (#7236)
Signed-off-by: junq <[email protected]>
1 parent 12c66f7 commit 12ecb86

File tree

3 files changed (+44, -23 lines)
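The commit replaces the per-graph static input allocations with one set of shared buffers: the runner allocates input_ids/position_ids once at the maximum size, and each CUDA graph captures a slice (a view) of those buffers, so graphs for different batch sizes reuse the same device memory. Below is a minimal, self-contained sketch of that pattern; the names and the toy "forward" computation are illustrative assumptions, not the repo's API.

# Minimal sketch of the buffer-sharing pattern (illustrative only; these names
# and the toy "forward" are assumptions, not the repo's API).
import torch

max_total_tokens = 256  # assumed worst-case capacity for this sketch
shared_input_ids = torch.ones(max_total_tokens, device="cuda", dtype=torch.int32)

graphs = {}
for batch_size in (1, 2, 4):
    view = shared_input_ids[:batch_size]  # a view: shares storage with the big buffer
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        out = view * 2  # stand-in for the model forward pass
    graphs[batch_size] = (graph, view, out)

# Replay: copy live inputs into the shared slice, then replay the captured graph.
live_inputs = torch.tensor([5, 7], device="cuda", dtype=torch.int32)
graph, view, out = graphs[2]
view.copy_(live_inputs)
graph.replay()
# `out` now holds the result computed from the freshly copied inputs.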

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Lines changed: 43 additions & 22 deletions
@@ -40,14 +40,42 @@ def __init__(self, engine: "PyTorchModelEngine"):
         self.max_beam_width = engine.max_beam_width
         self.spec_config = engine.spec_config
 
+        self.max_possible_draft_len = (self.spec_config.max_draft_len
+                                       if self.enable_spec_decode else 0)
+
         self.graphs: Dict[Tuple[int, int], torch.cuda.CUDAGraph] = {}
-        self.static_inputs: Dict[Tuple[int, int], Dict[str, torch.Tensor]] = {}
         self.graph_outputs: Dict[Tuple[int, int],
                                  Callable[[], Optional[torch.Tensor]]] = {}
         self.graph_metadata: Dict[Tuple[int, int], Dict[str, Any]] = {}
         self.memory_pool = engine._cuda_graph_mem_pool
         self.padding_dummy_request: Optional["Request"] = None
 
+        self.shared_static_tensors: Dict[str, torch.Tensor] = {}
+        if self.enabled:
+            self._create_shared_static_tensors()
+
+    def _create_shared_static_tensors(self):
+        """Allocates static tensors sized for the largest possible batch."""
+        engine = self._get_engine()
+
+        token_per_request = self.max_possible_draft_len + 1
+        max_total_tokens = (self.max_supported_batch_size *
+                            self.max_beam_width * token_per_request)
+        max_total_tokens = min(max_total_tokens, engine.max_num_tokens)
+
+        self.shared_static_tensors = {
+            "input_ids":
+            torch.ones((max_total_tokens, ), device="cuda", dtype=torch.int32),
+            "position_ids":
+            torch.zeros((1, max_total_tokens), device="cuda",
+                        dtype=torch.int32),
+        }
+        if engine.use_mrope:
+            self.shared_static_tensors["mrope_position_deltas"] = torch.zeros(
+                (self.max_supported_batch_size, 1),
+                device="cuda",
+                dtype=torch.int32)
+
     @property
     def enable_spec_decode(self):
         return self._get_engine().is_spec_decode
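The shared buffers are sized for the worst case: max_supported_batch_size × max_beam_width × (max_draft_len + 1) tokens, capped at the engine's max_num_tokens. With assumed values (not taken from the repo's configs), the arithmetic works out as follows.

# Worked sizing example with assumed values (not taken from the repo's configs):
max_supported_batch_size = 128
max_beam_width = 1
max_possible_draft_len = 3          # 0 when speculative decoding is disabled
engine_max_num_tokens = 8192

token_per_request = max_possible_draft_len + 1                      # 4
max_total_tokens = (max_supported_batch_size * max_beam_width
                    * token_per_request)                            # 512
max_total_tokens = min(max_total_tokens, engine_max_num_tokens)     # still 512, cap not hit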
@@ -139,38 +167,32 @@ def needs_capture(self, batch_size: int):
     def capture(self, batch_size: int, forward_fn: Callable,
                 initial_inputs: Dict[str, Any]):
         """Captures the forward pass for a given batch size."""
-        engine = self._get_engine()
         key = (batch_size, self.draft_len)
-        spec_metadata = initial_inputs.get("spec_metadata", None)
         # [CUDA graph spec decode padding]
         # We pad input IDs/position IDs to the maximum draft length (token per request).
         # We're forced to do this because we cannot reallocate inputs over many graph runs.
-        token_per_request = spec_metadata.max_draft_len + 1 if spec_metadata is not None else 1
+        token_per_request = self.max_possible_draft_len + 1
+        num_tokens_for_capture = (batch_size * self.max_beam_width *
+                                  token_per_request)
 
-        static_tensors = {
+        sliced_static_tensors = {
             "input_ids":
-            torch.ones((batch_size * self.max_beam_width * token_per_request, ),
-                       device="cuda",
-                       dtype=torch.int32),
+            self.shared_static_tensors["input_ids"][:num_tokens_for_capture],
             "position_ids":
-            torch.zeros((
-                1,
-                batch_size * self.max_beam_width * token_per_request,
-            ),
-                        device="cuda",
-                        dtype=torch.int32),
+            self.shared_static_tensors["position_ids"]
+            [:, :num_tokens_for_capture],
         }
-        if engine.use_mrope:
-            static_tensors["mrope_position_deltas"] = torch.zeros(
-                (batch_size, 1), device="cuda", dtype=torch.int32)
-        self.static_inputs[key] = static_tensors
+        if "mrope_position_deltas" in self.shared_static_tensors:
+            sliced_static_tensors["mrope_position_deltas"] = \
+                self.shared_static_tensors["mrope_position_deltas"][:batch_size]
 
+        # Use the sliced tensors for capture
         capture_inputs = initial_inputs.copy()
-        capture_inputs.update(static_tensors)
+        capture_inputs.update(sliced_static_tensors)
 
         self.graph_metadata[key] = {
             "attn_metadata": initial_inputs["attn_metadata"],
-            "spec_metadata": spec_metadata,
+            "spec_metadata": initial_inputs.get("spec_metadata", None),
         }
 
         # We have to do warm up runs to initialize PyTorch's
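Slicing is enough here because basic indexing in PyTorch returns views, not copies: every per-graph tensor built in capture() aliases the single shared allocation. A short illustration with assumed shapes and names follows.

# Why slicing is enough: torch basic indexing returns views, so the tensors each
# graph captures all alias the single shared allocation (illustrative names/shapes).
import torch

shared = torch.zeros(1, 64, device="cuda", dtype=torch.int32)
small = shared[:, :8]    # slice captured by a small-batch graph
large = shared[:, :32]   # slice captured by a larger-batch graph

assert small.data_ptr() == shared.data_ptr() == large.data_ptr()
small.fill_(1)                           # writes land in the shared buffer...
assert bool(shared[0, :8].eq(1).all())   # ...and are visible through the parent tensor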
@@ -198,7 +220,7 @@ def replay(self, batch_size: int,
             assert current_inputs.get(
                 "spec_metadata") is stored_meta["spec_metadata"]
 
-        static_tensors = self.static_inputs[key]
+        static_tensors = self.shared_static_tensors
 
         input_ids = current_inputs["input_ids"]
         seqlen = input_ids.shape[0]
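Since every graph now captures views of the same buffers, replay() only needs to point at self.shared_static_tensors; the rest of the method is outside this diff. The sketch below shows the generic replay-time flow under that assumption: copy live inputs into the leading slice of the shared buffers the graph was captured against, then replay.

# Hedged sketch of the replay-time flow (generic pattern, not the file's exact
# code): copy live inputs into the leading slice of the shared static buffers,
# then replay the captured graph.
def replay_sketch(graph, static_tensors, current_inputs):
    input_ids = current_inputs["input_ids"]
    seqlen = input_ids.shape[0]
    static_tensors["input_ids"][:seqlen].copy_(input_ids)
    static_tensors["position_ids"][:, :seqlen].copy_(current_inputs["position_ids"])
    graph.replay()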
@@ -301,7 +323,6 @@ def clear(self):
         for graph in self.graphs.values():
             graph.reset()
         self.graphs.clear()
-        self.static_inputs.clear()
         self.graph_outputs.clear()
         self.graph_metadata.clear()
         self.padding_dummy_request = None

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 0 additions & 1 deletion
@@ -426,7 +426,6 @@ def __init__(
         # the model engine.
         self.attn_metadata = None
         self.iter_states = {}
-        self._cuda_graphs = {}
         self._cuda_graph_mem_pool = self._torch_compile_backend._graph_pool_handle if self._torch_compile_enabled else None
 
         self._cuda_graph_padding_enabled = pytorch_backend_config.cuda_graph_padding_enabled

tests/unittest/_torch/helpers.py

Lines changed: 1 addition & 0 deletions
@@ -186,6 +186,7 @@ def create_mock_engine(batch_size: int):
         _cuda_graph_batch_sizes=[batch_size],
         _max_cuda_graph_batch_size=batch_size,
         max_beam_width=1,
+        max_num_tokens=8192,
         is_spec_decode=False,
         spec_config=None,
         _cuda_graph_mem_pool=None,
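The mock engine gains max_num_tokens because the new _create_shared_static_tensors() dereferences engine.max_num_tokens when capping the shared buffer size. A tiny hedged illustration (SimpleNamespace stands in for the test helper's mock object, and the batch/beam/draft values are assumptions):

# A mock engine without max_num_tokens would raise AttributeError once CUDA
# graphs are enabled; 8192 simply keeps the cap from interfering in unit tests.
from types import SimpleNamespace

engine = SimpleNamespace(max_num_tokens=8192)
max_total_tokens = min(128 * 1 * 1, engine.max_num_tokens)  # assumed batch * beam * tokens/request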
