181 changes: 87 additions & 94 deletions tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Large diffs are not rendered by default.

50 changes: 42 additions & 8 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -56,7 +56,7 @@
set_torch_compiling, with_model_extra_attrs)
from .config import _construct_checkpoint_loader
from .config_utils import is_mla
from .cuda_graph_runner import CUDAGraphRunner
from .cuda_graph_runner import CUDAGraphRunner, CUDAGraphRunnerConfig
from .guided_decoder import CapturableGuidedDecoder
from .layerwise_nvtx_marker import LayerwiseNvtxMarker
from .llm_request import get_draft_token_length
@@ -371,9 +371,31 @@ def __init__(
# We look up this key in resource_manager during forward to find the
# kv cache manager. Can be changed to support multiple model engines
# with different KV cache managers.
self.kv_cache_manager_key = ResourceManagerType.KV_CACHE_MANAGER
self.kv_cache_manager_key = ResourceManagerType.DRAFT_KV_CACHE_MANAGER if is_draft_model else ResourceManagerType.KV_CACHE_MANAGER
self.lora_model_config: Optional[LoraModelConfig] = None
self.cuda_graph_runner = CUDAGraphRunner(self)

# Create config and runner
cuda_graph_runner_config = CUDAGraphRunnerConfig(
use_cuda_graph=self.cuda_graph_config is not None,
cuda_graph_padding_enabled=self._cuda_graph_padding_enabled,
cuda_graph_batch_sizes=self._cuda_graph_batch_sizes,
max_cuda_graph_batch_size=self._max_cuda_graph_batch_size,
max_beam_width=self.max_beam_width,
spec_config=self.spec_config,
cuda_graph_mem_pool=self._cuda_graph_mem_pool,
max_num_tokens=self.max_num_tokens,
use_mrope=self.use_mrope,
original_max_draft_len=self.original_max_draft_len,
original_max_total_draft_tokens=self.
original_max_total_draft_tokens,
is_draft_model=self.is_draft_model,
enable_attention_dp=self.enable_attention_dp,
batch_size=self.batch_size,
mapping=self.mapping,
dist=self.dist,
kv_cache_manager_key=self.kv_cache_manager_key,
)
self.cuda_graph_runner = CUDAGraphRunner(cuda_graph_runner_config)

# Setup the local cache indirection buffer only once and reuse it.
# This way it can also be used for CUDA graphs.
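
Editor's note: the cuda_graph_runner.py diff itself is collapsed above ("Large diffs are not rendered by default"), so here is a minimal sketch of what the new `CUDAGraphRunnerConfig` presumably looks like, inferred purely from the keyword arguments at this call site. Field names come from the call above; the types and the use of a dataclass are assumptions, not the actual definition.

```python
# Hypothetical sketch only: field names mirror the call site above, while the
# dataclass form and types are guesses. The real definition lives in the
# collapsed tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py diff.
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class CUDAGraphRunnerConfig:
    """Plain-data snapshot of everything the runner used to read off the engine."""
    use_cuda_graph: bool
    cuda_graph_padding_enabled: bool
    cuda_graph_batch_sizes: Optional[List[int]]
    max_cuda_graph_batch_size: int
    max_beam_width: int
    spec_config: Optional[Any]
    cuda_graph_mem_pool: Optional[Any]
    max_num_tokens: int
    use_mrope: bool
    original_max_draft_len: int
    original_max_total_draft_tokens: int
    is_draft_model: bool
    enable_attention_dp: bool
    batch_size: int
    mapping: Any
    dist: Any
    kv_cache_manager_key: Any
```

Passing a plain config instead of the engine itself means the runner no longer needs a back-reference to `PyTorchModelEngine`, which is what lets the test helper `create_mock_cuda_graph_runner` further down drop the old `MockEngine` machinery.
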
@@ -2320,10 +2342,19 @@ def forward(
return self._forward_step(inputs, gather_ids,
gather_context_logits)
with self.cuda_graph_runner.pad_batch(
scheduled_requests, resource_manager) as padded_requests:
scheduled_requests, resource_manager,
self.runtime_draft_len) as padded_requests:

maybe_graph, maybe_attn_metadata, maybe_spec_metadata, key = self.cuda_graph_runner.maybe_get_cuda_graph(
padded_requests, spec_resource_manager)
padded_requests,
iter_counter=self.iter_counter,
enable_spec_decode=self.enable_spec_decode,
attn_metadata=attn_metadata,
spec_metadata=spec_metadata,
draft_tokens_cuda=self.draft_tokens_cuda
if self.is_spec_decode else None,
spec_resource_manager=spec_resource_manager,
)
if maybe_graph:
attn_metadata = maybe_attn_metadata
spec_metadata = maybe_spec_metadata
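
The call above now hands `maybe_get_cuda_graph` everything it needs explicitly instead of letting the runner read state off the engine. A hedged sketch of the signature implied by this call site (the real one is in the collapsed cuda_graph_runner.py diff; return-value names are illustrative):

```python
# Assumed signature, reconstructed from the call site above. The method decides
# whether a captured (or capturable) CUDA graph applies to this padded batch and
# returns the metadata objects the graph was (or will be) captured with.
def maybe_get_cuda_graph(self,
                         padded_batch,
                         iter_counter: int,
                         enable_spec_decode: bool,
                         attn_metadata,
                         spec_metadata,
                         draft_tokens_cuda=None,
                         spec_resource_manager=None):
    """Return (can_use_graph, graph_attn_metadata, graph_spec_metadata, key)."""
    ...
```
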
@@ -2358,9 +2389,12 @@ def capture_forward_fn(inputs: Dict[str, Any]):
def capture_postprocess_fn(inputs: Dict[str, Any]):
self._postprocess_inputs(inputs)

self.cuda_graph_runner.capture(key, capture_forward_fn,
inputs,
capture_postprocess_fn)
self.cuda_graph_runner.capture(
key,
capture_forward_fn,
inputs,
enable_spec_decode=self.enable_spec_decode,
postprocess_fn=capture_postprocess_fn)

# here we don't need to use context since cuda graph capture didn't run kernel.
# maybe we need a cleaner way to do this.
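
Same pattern for `capture`: `enable_spec_decode` and the postprocess hook are now passed per call rather than read from the engine. A sketch of the signature this call implies (assumption, the actual definition is in the collapsed cuda_graph_runner.py diff):

```python
from typing import Any, Callable, Dict, Optional


# Assumed signature matching the updated call above.
def capture(self,
            key,
            forward_fn: Callable[[Dict[str, Any]], Any],
            initial_inputs: Dict[str, Any],
            enable_spec_decode: bool = False,
            postprocess_fn: Optional[Callable[[Dict[str, Any]], None]] = None) -> None:
    ...
```
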
1 change: 0 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -404,7 +404,6 @@ def drafting_loop_wrapper(model):
# For DeepseekV3 MTP, we need to set the num_hidden_layers to 1 for the draft model
if spec_config.spec_dec_mode.is_mtp_eagle():
draft_model_engine.model.model_config.pretrained_config.num_hidden_layers = 1
draft_model_engine.kv_cache_manager_key = ResourceManagerType.DRAFT_KV_CACHE_MANAGER
draft_model_engine.load_weights_from_target_model(
model_engine.model)
else:
60 changes: 22 additions & 38 deletions tests/unittest/_torch/helpers.py
@@ -3,7 +3,10 @@
import torch
import torch.nn.functional as F

from tensorrt_llm.llmapi.llm_args import TorchLlmArgs
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import (
CUDAGraphRunner, CUDAGraphRunnerConfig)
from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
from tensorrt_llm.mapping import Mapping


def ceil_div(x: int, y: int) -> int:
@@ -166,42 +169,23 @@ def block_scale_gemm(mat_a: torch.Tensor, mat_scale_a: torch.Tensor,
return results.view_as(x)


class MockPytorchBackendConfig:

def __init__(self, use_cuda_graph, cuda_graph_padding_enabled):
self.use_cuda_graph = use_cuda_graph
self.cuda_graph_padding_enabled = cuda_graph_padding_enabled


class MockEngine:
"""A replacement for SimpleNamespace that supports weak references."""

def __init__(self, **kwargs):
self.__dict__.update(kwargs)


def create_mock_engine(batch_size: int):

class MockSpecConfig:

class SpecDecMode:

def needs_kv_cache_recompute(self):
return False

spec_dec_mode = SpecDecMode()

return MockEngine(
llm_args=TorchLlmArgs(model="dummy"),
_cuda_graph_padding_enabled=True,
_cuda_graph_batch_sizes=[batch_size],
_max_cuda_graph_batch_size=batch_size,
def create_mock_cuda_graph_runner(batch_size: int, use_mrope: bool = False):
config = CUDAGraphRunnerConfig(
use_cuda_graph=True,
cuda_graph_padding_enabled=False,
cuda_graph_batch_sizes=[batch_size],
max_cuda_graph_batch_size=batch_size,
batch_size=batch_size,
max_beam_width=1,
max_num_tokens=8192,
is_spec_decode=False,
enable_spec_decode=False,
spec_config=MockSpecConfig(),
max_num_tokens=1,
use_mrope=use_mrope,
spec_config=None,
cuda_graph_mem_pool=None,
enable_attention_dp=False,
original_max_draft_len=0,
original_max_total_draft_tokens=0,
is_draft_model=False,
_cuda_graph_mem_pool=None,
use_mrope=False,
)
mapping=Mapping(),
dist=None,
kv_cache_manager_key=ResourceManagerType.KV_CACHE_MANAGER)
return CUDAGraphRunner(config)
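
The test-side change this enables is mechanical and repeats in every modeling test below: instead of building a fake engine for the runner to hold a (weak) reference to, the tests call the new helper directly. In outline, mirroring the hunks that follow:

```python
# Import as used by the updated tests below.
from _torch.helpers import create_mock_cuda_graph_runner

use_cuda_graph = True

# Old pattern (removed): a MockEngine stood in for PyTorchModelEngine because
# CUDAGraphRunner kept a weak reference back to the engine:
#   mock_engine = create_mock_engine(1)
#   graph_runner = CUDAGraphRunner(mock_engine)

# New pattern (added): the helper builds a CUDAGraphRunnerConfig and returns a
# ready-to-use runner; no engine object is involved.
graph_runner = create_mock_cuda_graph_runner(1) if use_cuda_graph else None
```
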
9 changes: 3 additions & 6 deletions tests/unittest/_torch/modeling/test_modeling_exaone4.py
@@ -22,7 +22,7 @@ class Exaone4Config(PretrainedConfig):
# TODO: Remove this once we have a proper config for Exaone4
SKIP_EXAONE4_HF_ACCURACY_TEST = True

from _torch.helpers import create_mock_engine
from _torch.helpers import create_mock_cuda_graph_runner
from transformers.cache_utils import HybridCache
from utils.util import getSMVersion

@@ -31,7 +31,6 @@ class Exaone4Config(PretrainedConfig):
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_exaone4 import Exaone4ForCausalLM
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.mapping import Mapping
@@ -338,10 +337,8 @@ def test_exaone4_allclose_to_hf(self, scenario: Scenario) -> None:
]
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()

graph_runner = None
if scenario.use_cuda_graph:
mock_engine = create_mock_engine(1)
graph_runner = CUDAGraphRunner(mock_engine)
graph_runner = create_mock_cuda_graph_runner(
1) if scenario.use_cuda_graph else None

def run_forward(input_ids, position_ids, attn_metadata):
attn_metadata.prepare()
9 changes: 3 additions & 6 deletions tests/unittest/_torch/modeling/test_modeling_llama.py
@@ -4,7 +4,7 @@
from typing import Any

import torch
from _torch.helpers import create_mock_engine
from _torch.helpers import create_mock_cuda_graph_runner
from parameterized import parameterized
from transformers import LlamaConfig
from transformers import LlamaForCausalLM as HFLlamaForCausalLM
@@ -16,7 +16,6 @@
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_llama import LlamaForCausalLM
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequestState
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests
@@ -331,10 +330,8 @@ def test_llama_allclose_to_hf(self, scenario: Scenario) -> None:
]
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()

graph_runner = None
if scenario.use_cuda_graph:
mock_engine = create_mock_engine(1)
graph_runner = CUDAGraphRunner(mock_engine)
graph_runner = create_mock_cuda_graph_runner(
1) if scenario.use_cuda_graph else None

def run_forward(input_ids, position_ids, attn_metadata):
attn_metadata.prepare()
@@ -4,7 +4,7 @@

import torch
import transformers
from _torch.helpers import create_mock_engine
from _torch.helpers import create_mock_cuda_graph_runner
from parameterized import parameterized
from transformers import Llama4Config
from transformers import \
@@ -21,7 +21,6 @@
from tensorrt_llm._torch.models.modeling_llama import \
Llama4ForConditionalGeneration
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.mapping import Mapping
@@ -405,10 +404,8 @@ def test_llama_allclose_to_hf(self, scenario: AllCloseScenario) -> None:
input_ids.size(-1) + gen_input_ids.size(-1))
]
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()
graph_runner = None
if scenario.use_cuda_graph:
mock_engine = create_mock_engine(1)
graph_runner = CUDAGraphRunner(mock_engine)
graph_runner = create_mock_cuda_graph_runner(
1) if scenario.use_cuda_graph else None

def run_forward(input_ids, position_ids, attn_metadata):
attn_metadata.prepare()
8 changes: 2 additions & 6 deletions tests/unittest/_torch/modeling/test_modeling_mistral.py
@@ -8,7 +8,7 @@
import torch
import transformers
import transformers.models.mistral3
from _torch.helpers import create_mock_engine
from _torch.helpers import create_mock_cuda_graph_runner
from PIL import Image
from utils.util import getSMVersion

@@ -19,7 +19,6 @@
from tensorrt_llm._torch.attention_backend import utils as attention_utils
from tensorrt_llm._torch.models import modeling_mistral
from tensorrt_llm._torch.pyexecutor import resource_manager
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
from tensorrt_llm.bindings import executor as executor_lib
from tensorrt_llm.models import modeling_utils

@@ -404,10 +403,7 @@ def test_mistral_3_vlm_allclose_to_hf(mistral_small_3_1_24b_config, backend, use
]
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()

graph_runner = None
if use_cuda_graph:
mock_engine = create_mock_engine(1)
graph_runner = CUDAGraphRunner(mock_engine)
graph_runner = create_mock_cuda_graph_runner(1) if use_cuda_graph else None

def run_forward(input_ids, position_ids, attn_metadata):
attn_metadata.prepare()
9 changes: 3 additions & 6 deletions tests/unittest/_torch/modeling/test_modeling_mixtral.py
@@ -3,7 +3,7 @@
from dataclasses import dataclass

import torch
from _torch.helpers import create_mock_engine
from _torch.helpers import create_mock_cuda_graph_runner
from parameterized import parameterized
from transformers import MixtralConfig
from transformers import MixtralForCausalLM as HFMixtralForCausalLM
@@ -16,7 +16,6 @@
from tensorrt_llm._torch.models.checkpoints.hf.mixtral_weight_mapper import \
MixtralHfWeightMapper
from tensorrt_llm._torch.models.modeling_mixtral import MixtralForCausalLM
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.mapping import Mapping
@@ -310,10 +309,8 @@ def test_mixtral_allclose_to_hf(self, scenario: Scenario):
]
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()

graph_runner = None
if scenario.use_cuda_graph:
mock_engine = create_mock_engine(1)
graph_runner = CUDAGraphRunner(mock_engine)
graph_runner = create_mock_cuda_graph_runner(
1) if scenario.use_cuda_graph else None

def run_forward(input_ids, position_ids, attn_metadata):
attn_metadata.prepare()
9 changes: 3 additions & 6 deletions tests/unittest/_torch/modeling/test_modeling_mllama.py
@@ -4,7 +4,7 @@

import pytest
import torch
from _torch.helpers import create_mock_engine
from _torch.helpers import create_mock_cuda_graph_runner
from parameterized import parameterized
from test_modeling_llama import Scenario, reduce_llama_config
from transformers import MllamaConfig
@@ -17,7 +17,6 @@
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_mllama import \
MllamaForConditionalGeneration
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.mapping import Mapping
@@ -420,10 +419,8 @@ def test_mllama_allclose_to_hf_text_only(self, scenario: Scenario) -> None:
]
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()

graph_runner = None
if scenario.use_cuda_graph:
mock_engine = create_mock_engine(1)
graph_runner = CUDAGraphRunner(mock_engine)
graph_runner = create_mock_cuda_graph_runner(
1) if scenario.use_cuda_graph else None

def run_forward(input_ids, position_ids, attn_metadata):
attn_metadata.prepare()
6 changes: 2 additions & 4 deletions tests/unittest/_torch/modeling/test_modeling_multimodal.py
@@ -8,7 +8,7 @@
from typing import Dict, List, Optional, Tuple, Type

import torch
from _torch.helpers import create_mock_engine
from _torch.helpers import create_mock_cuda_graph_runner
from transformers import AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel
from utils.llm_data import llm_models_root

@@ -17,7 +17,6 @@
from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm._utils import str_dtype_to_torch
from tensorrt_llm.bindings.executor import KvCacheConfig
@@ -425,8 +424,7 @@ def run_trtllm_forward(self, trtllm_inputs, use_cuda_graph: bool = False):
trtllm_inputs["attn_metadata"].prepare()
return self.trtllm_model.forward(**trtllm_inputs)
else:
mock_engine = create_mock_engine(1)
graph_runner = CUDAGraphRunner(mock_engine)
graph_runner = create_mock_cuda_graph_runner(1)
trtllm_inputs["attn_metadata"] = trtllm_inputs[
"attn_metadata"
].create_cuda_graph_metadata(1)
9 changes: 3 additions & 6 deletions tests/unittest/_torch/modeling/test_modeling_nemotron.py
@@ -4,7 +4,7 @@
from typing import Any

import torch
from _torch.helpers import create_mock_engine
from _torch.helpers import create_mock_cuda_graph_runner
from parameterized import parameterized
from transformers import NemotronConfig
from transformers import NemotronForCausalLM as HFNemotronForCausalLM
@@ -15,7 +15,6 @@
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_nemotron import NemotronForCausalLM
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.mapping import Mapping
@@ -318,10 +317,8 @@ def test_nemotron_allclose_to_hf(self, scenario: Scenario) -> None:
]
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()

graph_runner = None
if scenario.use_cuda_graph:
mock_engine = create_mock_engine(1)
graph_runner = CUDAGraphRunner(mock_engine)
graph_runner = create_mock_cuda_graph_runner(
1) if scenario.use_cuda_graph else None

def run_forward(input_ids, position_ids, attn_metadata):
attn_metadata.prepare()