tensorrt_llm/_torch/models/modeling_nemotron_nas.py: 23 changes (15 additions & 8 deletions)
@@ -17,8 +17,9 @@
 from ..modules.gated_mlp import GatedMLP
 from ..modules.linear import Linear, TensorParallelMode
 from ..modules.rms_norm import RMSNorm
-from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
-                             register_auto_model)
+from ..speculative import SpecMetadata
+from .modeling_speculative import SpecDecOneEngineForCausalLM
+from .modeling_utils import DecoderModel, register_auto_model
 
 
 def _ffn_mult_to_intermediate_size(ffn_mult: float, n_embd: int) -> int:
@@ -117,6 +118,7 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
                  block_config: Dict[str, Any], layer_idx: int):
         super().__init__()
         config = model_config.pretrained_config
+        self.layer_idx = layer_idx
         self.block_config = block_config
         if not self.block_config.attention.no_op:
             self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
@@ -150,6 +152,7 @@ def forward(
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
         **kwargs,
     ) -> torch.Tensor:
         if not self.block_config.attention.no_op:
@@ -178,6 +181,11 @@ def forward(
                     hidden_states, residual)
             hidden_states = self.mlp(hidden_states, **kwargs)
 
+        # Capture hidden states for speculative decoding
+        if spec_metadata is not None:
+            spec_metadata.maybe_capture_hidden_states(self.layer_idx,
+                                                      hidden_states, residual)
+
         return hidden_states, residual
 
 
@@ -238,6 +246,7 @@ def forward(
         input_ids: Optional[torch.IntTensor] = None,
         position_ids: Optional[torch.IntTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
         lora_params: Optional[dict] = None,
         **kwargs,
     ) -> torch.Tensor:
@@ -259,6 +268,7 @@
                 hidden_states=hidden_states,
                 attn_metadata=attn_metadata,
                 residual=residual,
+                spec_metadata=spec_metadata,
                 lora_params=lora_params,
             )
 
@@ -267,11 +277,8 @@
 
 
 @register_auto_model("DeciLMForCausalLM")
-class NemotronNASForCausalLM(DecoderModelForCausalLM[NemotronNASModel,
-                                                     PretrainedConfig]):
+class NemotronNASForCausalLM(SpecDecOneEngineForCausalLM[NemotronNASModel,
+                                                         PretrainedConfig]):
 
     def __init__(self, model_config: ModelConfig[PretrainedConfig]):
-        super().__init__(NemotronNASModel(model_config),
-                         config=model_config,
-                         hidden_size=model_config.pretrained_config.hidden_size,
-                         vocab_size=model_config.pretrained_config.vocab_size)
+        super().__init__(NemotronNASModel(model_config), model_config)
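
For context, the diff threads an optional spec_metadata object from NemotronNASForCausalLM down through NemotronNASModel into every decoder layer, and each layer offers its output to maybe_capture_hidden_states just before returning. Below is a minimal, self-contained sketch of that capture-hook pattern, not the actual TensorRT-LLM SpecMetadata implementation: the _SpecMetadataSketch class and its capture_layers and captured fields are hypothetical names used only for illustration; the only thing taken from the diff is the maybe_capture_hidden_states(layer_idx, hidden_states, residual) call shape.

```python
from dataclasses import dataclass, field
from typing import Dict, Optional, Set, Tuple

import torch


@dataclass
class _SpecMetadataSketch:
    """Hypothetical stand-in for SpecMetadata, for illustration only."""

    # Layer indices whose outputs the speculative-decoding algorithm needs.
    capture_layers: Set[int] = field(default_factory=set)
    # layer_idx -> (hidden_states, residual) recorded during the forward pass.
    captured: Dict[int, Tuple[torch.Tensor, Optional[torch.Tensor]]] = field(
        default_factory=dict)

    def maybe_capture_hidden_states(
            self,
            layer_idx: int,
            hidden_states: torch.Tensor,
            residual: Optional[torch.Tensor] = None) -> None:
        # No-op for layers that were not requested, so decoder layers can
        # call the hook unconditionally at negligible cost.
        if layer_idx in self.capture_layers:
            self.captured[layer_idx] = (hidden_states, residual)


# Usage mirroring the pattern in the diff: a layer calls the hook after its
# MLP, and the driver reads back the captured states once the pass finishes.
spec_metadata = _SpecMetadataSketch(capture_layers={31})
hidden_states = torch.randn(4, 4096)
residual = torch.randn(4, 4096)
spec_metadata.maybe_capture_hidden_states(31, hidden_states, residual)
assert 31 in spec_metadata.captured
```

The point of the pattern is that the metadata object, not the model, decides what gets stored: layers always call the hook, so speculative-decoding schemes whose draft step conditions on target-model hidden states can change which layers they capture without touching the modeling code.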