
Commit d14f687

init
Signed-off-by: junq <[email protected]>
1 parent 76a47c7 commit d14f687

14 files changed: +175 -202 lines changed

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Lines changed: 75 additions & 84 deletions
@@ -1,6 +1,6 @@
 import bisect
 import contextlib
-import weakref
+from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple
 
 import torch
@@ -16,12 +16,35 @@
 from .scheduler import ScheduledRequests
 
 if TYPE_CHECKING:
-    from .model_engine import PyTorchModelEngine
+    from ..distributed import MPIDist
+    from ..mapping import Mapping
+    from ..speculative import DecodingBaseConfig
 
 # A large prime number used for dummy request IDs to avoid collisions
 CUDA_GRAPH_DUMMY_REQUEST_ID = (1 << 64) - 1
 
 
+@dataclass
+class CUDAGraphRunnerConfig:
+    """Configuration for the CUDAGraphRunner, passed from the ModelEngine."""
+    use_cuda_graph: bool
+    cuda_graph_padding_enabled: bool
+    cuda_graph_batch_sizes: list[int]
+    max_cuda_graph_batch_size: int
+    max_beam_width: int
+    max_num_tokens: int
+    spec_config: Optional["DecodingBaseConfig"]
+    cuda_graph_mem_pool: Any
+    use_mrope: bool
+    original_max_draft_len: int
+    is_draft_model: bool
+    enable_attention_dp: bool
+    batch_size: int
+    mapping: Optional["Mapping"]
+    dist: Optional["MPIDist"]
+    kv_cache_manager_key: Any
+
+
 class CUDAGraphRunner:
     """
     Manages the lifecycle and execution of CUDA graphs for the model engine.
@@ -32,23 +55,22 @@ class CUDAGraphRunner:
     """
     WARMUP_STEPS = 2
 
-    def __init__(self, engine: "PyTorchModelEngine"):
-        self.engine_ref = weakref.ref(engine)
+    def __init__(self, config: CUDAGraphRunnerConfig):
+        self.config = config
 
-        # High-level configuration
-        config = engine.pytorch_backend_config
+        # High-level configuration from the config object
         self.enabled = config.use_cuda_graph
         self.padding_enabled = config.cuda_graph_padding_enabled
-        self.supported_batch_sizes = engine._cuda_graph_batch_sizes
-        self.max_supported_batch_size = engine._max_cuda_graph_batch_size
-        self.max_beam_width = engine.max_beam_width
-        self.spec_config = engine.spec_config
+        self.supported_batch_sizes = config.cuda_graph_batch_sizes
+        self.max_supported_batch_size = config.max_cuda_graph_batch_size
+        self.max_beam_width = config.max_beam_width
+        self.spec_config = config.spec_config
 
         self.graphs: Dict[Tuple[int, int, int], torch.cuda.CUDAGraph] = {}
         self.graph_outputs: Dict[Tuple[int, int, int],
                                  Callable[[], Optional[torch.Tensor]]] = {}
         self.graph_metadata: Dict[Tuple[int, int, int], Dict[str, Any]] = {}
-        self.memory_pool = engine._cuda_graph_mem_pool
+        self.memory_pool = config.cuda_graph_mem_pool
         self.padding_dummy_request: Optional["Request"] = None
 
         self.shared_static_tensors: Dict[str, torch.Tensor] = {}
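The hunk above replaces the weakref back to PyTorchModelEngine with an explicit CUDAGraphRunnerConfig. A minimal sketch of how the engine side might now wire this up, assuming the engine still exposes the attributes the removed code used to read (pytorch_backend_config, _cuda_graph_batch_sizes, _cuda_graph_mem_pool, etc.) — these call-site names are assumptions taken from the deleted lines, not part of this commit:

# Hypothetical call site: build the config from values the runner used to
# pull off the engine, then hand it to the runner.
backend_cfg = engine.pytorch_backend_config
config = CUDAGraphRunnerConfig(
    use_cuda_graph=backend_cfg.use_cuda_graph,
    cuda_graph_padding_enabled=backend_cfg.cuda_graph_padding_enabled,
    cuda_graph_batch_sizes=engine._cuda_graph_batch_sizes,
    max_cuda_graph_batch_size=engine._max_cuda_graph_batch_size,
    max_beam_width=engine.max_beam_width,
    max_num_tokens=engine.max_num_tokens,
    spec_config=engine.spec_config,
    cuda_graph_mem_pool=engine._cuda_graph_mem_pool,
    use_mrope=engine.use_mrope,
    original_max_draft_len=engine.original_max_draft_len,
    is_draft_model=engine.is_draft_model,
    enable_attention_dp=engine.enable_attention_dp,
    batch_size=engine.batch_size,
    mapping=engine.mapping,
    dist=engine.dist,
    kv_cache_manager_key=engine.kv_cache_manager_key,
)
cuda_graph_runner = CUDAGraphRunner(config)

The design point of the diff is that the runner no longer dereferences the engine at all; its configuration is frozen at construction time.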
@@ -58,12 +80,11 @@ def __init__(self, engine: "PyTorchModelEngine"):
 
     def _create_shared_static_tensors(self):
         """Allocates static tensors sized for the largest possible batch."""
-        engine = self._get_engine()
-
-        token_per_request = self.max_possible_draft_len + 1
+        max_draft_len = self.config.original_max_draft_len if self.config.is_spec_decode else 0
+        token_per_request = max_draft_len + 1
         max_total_tokens = (self.max_supported_batch_size *
                             self.max_beam_width * token_per_request)
-        max_total_tokens = min(max_total_tokens, engine.max_num_tokens)
+        max_total_tokens = min(max_total_tokens, self.config.max_num_tokens)
 
         self.shared_static_tensors = {
             "input_ids":
@@ -72,7 +93,7 @@ def _create_shared_static_tensors(self):
             torch.zeros((1, max_total_tokens), device="cuda",
                         dtype=torch.int32),
         }
-        if engine.use_mrope:
+        if self.config.use_mrope:
             self.shared_static_tensors["position_ids"] = torch.zeros(
                 (3, 1, max_total_tokens), device="cuda", dtype=torch.int32)
             self.shared_static_tensors["multimodal_params"] = [
@@ -86,55 +107,31 @@ def _create_shared_static_tensors(self):
                 }) for _ in range(max_total_tokens)
             ]
 
-    @property
-    def enable_spec_decode(self):
-        return self._get_engine().enable_spec_decode
-
-    @property
-    def max_possible_draft_len(self):
-        engine = self._get_engine()
-        return (engine.original_max_draft_len if self.enable_spec_decode else 0)
-
     def get_graph_key(
             self,
             batch_size,
+            enable_spec_decode: bool,
            spec_resource_manager: Optional[BaseResourceManager] = None):
-        engine = self._get_engine()
-        if engine.is_draft_model and spec_resource_manager is not None and isinstance(
+        if self.config.is_draft_model and spec_resource_manager is not None and isinstance(
                 spec_resource_manager, Eagle3ResourceManager):
-            draft_len = engine.original_max_draft_len if spec_resource_manager.is_first_draft else 0
+            draft_len = self.config.original_max_draft_len if spec_resource_manager.is_first_draft else 0
             key = (batch_size, draft_len, spec_resource_manager.is_first_draft)
         else:
-            draft_len = self.spec_config.max_draft_len if self.enable_spec_decode else 0
+            draft_len = self.spec_config.max_draft_len if enable_spec_decode else 0
             key = (batch_size, draft_len, False)
         return key
 
-    @property
-    def spec_metadata(self):
-        return self._get_engine().spec_metadata
-
-    @property
-    def draft_tokens_cuda(self):
-        return self._get_engine().draft_tokens_cuda
-
-    @property
-    def attn_metadata(self):
-        return self._get_engine().attn_metadata
-
     def __del__(self):
         self.clear()
 
-    def _get_engine(self) -> "PyTorchModelEngine":
-        """Safely dereferences the weak reference to the engine."""
-        engine = self.engine_ref()
-        if engine is None:
-            raise RuntimeError(
-                "The parent PyTorchModelEngine has been garbage collected.")
-        return engine
-
     def maybe_get_cuda_graph(
             self,
             batch: ScheduledRequests,
+            iter_counter: int,
+            enable_spec_decode: bool,
+            attn_metadata: Any,
+            spec_metadata: Optional[Any],
+            draft_tokens_cuda: torch.Tensor,
             spec_resource_manager: Optional[BaseResourceManager] = None):
         """
         Determines if the current batch can be run with a CUDA graph.
@@ -145,17 +142,14 @@ def maybe_get_cuda_graph(
         - The spec_metadata for the graph, if applicable.
         - The key for the graph.
         """
-        engine = self._get_engine()
-
         # disable when doing statistic
-        if hasattr(engine, 'iter_counter') and ExpertStatistic.set_iter(
-                engine.iter_counter):
+        if ExpertStatistic.set_iter(iter_counter):
             return False, None, None, None
 
         can_run_cuda_graph = batch.can_run_cuda_graph
         batch_size = batch.batch_size
-        if self.enabled and engine.enable_attention_dp and engine.mapping.tp_size > 1:
-            all_can_graph_batch = engine.dist.tp_allgather(
+        if self.enabled and self.config.enable_attention_dp and self.config.mapping.tp_size > 1:
+            all_can_graph_batch = self.config.dist.tp_allgather(
                 [can_run_cuda_graph, batch_size])
             is_all_gen_only = all(all_can_graph[0]
                                   for all_can_graph in all_can_graph_batch)
@@ -168,7 +162,8 @@ def maybe_get_cuda_graph(
 
         if not self.enabled or not can_run_cuda_graph:
             return False, None, None, None
-        key = self.get_graph_key(batch_size, spec_resource_manager)
+        key = self.get_graph_key(batch_size, enable_spec_decode,
+                                 spec_resource_manager)
 
         if key in self.graphs:
             return True, self.graph_metadata[key][
@@ -178,29 +173,28 @@ def maybe_get_cuda_graph(
             return False, None, None, None
 
         num_sequences_in_batch = batch_size * self.max_beam_width
-        attn_metadata = self.attn_metadata.create_cuda_graph_metadata(
+        graph_attn_metadata = attn_metadata.create_cuda_graph_metadata(
             num_sequences_in_batch, False, key[1], self.cuda_graph_meta_buffers)
-        assert attn_metadata.is_cuda_graph
+        assert graph_attn_metadata.is_cuda_graph
 
-        if self.enable_spec_decode:
-            spec_metadata = self.spec_metadata.create_cuda_graph_metadata(
+        if enable_spec_decode:
+            graph_spec_metadata = spec_metadata.create_cuda_graph_metadata(
                 num_sequences_in_batch)
-            spec_metadata.draft_tokens = self.draft_tokens_cuda
+            graph_spec_metadata.draft_tokens = draft_tokens_cuda
         else:
-            spec_metadata = None
-        return True, attn_metadata, spec_metadata, key
+            graph_spec_metadata = None
+        return True, graph_attn_metadata, graph_spec_metadata, key
 
     def needs_capture(self, key: Tuple[int, int, int]):
-
         return key not in self.graph_outputs
 
     def capture(self,
                 key: Tuple[int, int, int],
                 forward_fn: Callable,
                 initial_inputs: Dict[str, Any],
+                enable_spec_decode: bool,
                 postprocess_fn: Optional[Callable] = None):
         """Captures the forward pass for a given batch size."""
-        engine = self._get_engine()
         batch_size = key[0]
         # [CUDA graph spec decode padding]
         # We pad input IDs/position IDs to the maximum draft length (token per request).
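With the engine-backed properties (spec_metadata, draft_tokens_cuda, attn_metadata) removed, per-iteration state is now passed into maybe_get_cuda_graph and capture explicitly. A hedged sketch of the resulting call pattern — the engine attribute names come from the deleted property bodies and are assumptions about the call site, not part of this commit:

# Hypothetical call site: values previously read via runner properties are
# now handed over as arguments.
can_graph, graph_attn_metadata, graph_spec_metadata, key = \
    cuda_graph_runner.maybe_get_cuda_graph(
        batch,
        iter_counter=engine.iter_counter,
        enable_spec_decode=engine.enable_spec_decode,
        attn_metadata=engine.attn_metadata,
        spec_metadata=engine.spec_metadata,
        draft_tokens_cuda=engine.draft_tokens_cuda,
        spec_resource_manager=spec_resource_manager)

if can_graph and cuda_graph_runner.needs_capture(key):
    cuda_graph_runner.capture(key,
                              forward_fn,
                              initial_inputs,
                              enable_spec_decode=engine.enable_spec_decode)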
@@ -217,7 +211,7 @@ def capture(self,
                 self.shared_static_tensors["position_ids"]
                 [:, :num_tokens_for_capture],
         }
-        if engine.use_mrope:
+        if self.config.use_mrope:
             sliced_static_tensors["position_ids"] = self.shared_static_tensors[
                 "position_ids"][:, :, :num_tokens_for_capture],
             sliced_static_tensors[
@@ -235,12 +229,10 @@ def capture(self,
         def _setup_spec_decoding_and_forward(key: Tuple[int, int, int],
                                              forward_fn: Callable,
                                              capture_inputs: Dict[str, Any]):
-            engine = self._get_engine()
-            # for the first inference of draft model, we need to set the use_spec_decoding to True when capture the graph for multiple runs.
             is_first_draft = key[2]
-            needs_kv_cache_recompute = True if engine.enable_spec_decode and engine.spec_config.spec_dec_mode.needs_kv_cache_recompute(
+            needs_kv_cache_recompute = True if enable_spec_decode and self.config.spec_config.spec_dec_mode.needs_kv_cache_recompute(
             ) else False
-            if is_first_draft and engine.is_draft_model and needs_kv_cache_recompute:
+            if is_first_draft and self.config.is_draft_model and needs_kv_cache_recompute:
                 capture_inputs['attn_metadata'].use_spec_decoding = True
             return forward_fn(capture_inputs)
 
@@ -268,7 +260,6 @@ def _setup_spec_decoding_and_forward(key: Tuple[int, int, int],
     def replay(self, key: Tuple[int, int, int],
                current_inputs: Dict[str, Any]) -> Optional[torch.Tensor]:
         """Replays a previously captured graph."""
-        engine = self._get_engine()
         stored_meta = self.graph_metadata[key]
         assert current_inputs["attn_metadata"] is stored_meta["attn_metadata"]
         if stored_meta["spec_metadata"] is not None:
@@ -282,7 +273,7 @@ def replay(self, key: Tuple[int, int, int],
         static_tensors["input_ids"][:seqlen].copy_(input_ids)
 
         position_ids = current_inputs["position_ids"]
-        if engine.use_mrope and current_inputs.get(
+        if self.config.use_mrope and current_inputs.get(
                 'multimodal_params') is not None:
             static_tensors["position_ids"][:, :, :seqlen].copy_(position_ids)
             for i, multimodal_param in enumerate(
@@ -302,16 +293,16 @@ def replay(self, key: Tuple[int, int, int],
         return output_ref
 
     def _get_padded_batch(self, batch: ScheduledRequests,
-                          resource_manager: ResourceManager) -> int:
-        engine = self._get_engine()
+                          resource_manager: ResourceManager,
+                          runtime_draft_len: int) -> int:
         kv_cache_manager = resource_manager.get_resource_manager(
-            engine.kv_cache_manager_key)
+            self.config.kv_cache_manager_key)
         can_run_cuda_graph = batch.can_run_cuda_graph
         batch_size = batch.batch_size
         new_batch_size = batch_size
 
-        if self.enabled and engine.enable_attention_dp and engine.mapping.tp_size > 1:
-            graph_batch_size = engine.dist.tp_allgather(
+        if self.enabled and self.config.enable_attention_dp and self.config.mapping.tp_size > 1:
+            graph_batch_size = self.config.dist.tp_allgather(
                 [can_run_cuda_graph, batch_size])
             all_can_graph = all(graph_batch[0]
                                 for graph_batch in graph_batch_size)
@@ -329,7 +320,7 @@ def _get_padded_batch(self, batch: ScheduledRequests,
             return 0
 
         padding_size = padded_batch_size - batch_size
-        if padding_size + batch.batch_size > engine.batch_size:
+        if padding_size + batch.batch_size > self.config.batch_size:
             return 0
 
         # No padding if it would create too many concurrent requests.
@@ -344,9 +335,9 @@ def _get_padded_batch(self, batch: ScheduledRequests,
             self.padding_dummy_request = kv_cache_manager.add_dummy_requests(
                 [CUDA_GRAPH_DUMMY_REQUEST_ID],
                 is_gen=True,
-                max_num_draft_tokens=engine.runtime_draft_len,
-                use_mrope=engine.use_mrope,
-                max_beam_width=engine.max_beam_width)[0]
+                max_num_draft_tokens=runtime_draft_len,
+                use_mrope=self.config.use_mrope,
+                max_beam_width=self.config.max_beam_width)[0]
             self.padding_dummy_request.is_cuda_graph_dummy = True
             spec_res_mgr = resource_manager.get_resource_manager(
                 ResourceManagerType.SPEC_RESOURCE_MANAGER)
@@ -368,11 +359,11 @@ def _round_up_batch_size(self, batch_size: int) -> int:
 
     @contextlib.contextmanager
     def pad_batch(self, scheduled_requests: ScheduledRequests,
-                  resource_manager: ResourceManager):
+                  resource_manager: ResourceManager, runtime_draft_len: int):
         """Context manager to pad a batch to a graph-compatible size."""
-
         padding_size = self._get_padded_batch(scheduled_requests,
-                                              resource_manager)
+                                              resource_manager,
+                                              runtime_draft_len)
         try:
             yield scheduled_requests
         finally:
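The padding path likewise stops reading engine.runtime_draft_len itself; the draft length is threaded through pad_batch and _get_padded_batch. A short sketch of the updated usage, assuming the caller still has it available as engine.runtime_draft_len (a name taken from the removed line, not defined by this commit):

# Hypothetical call site: the runtime draft length is passed explicitly.
with cuda_graph_runner.pad_batch(scheduled_requests, resource_manager,
                                 engine.runtime_draft_len) as padded_batch:
    # Run the normal forward / graph capture / replay path on padded_batch here.
    ...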
