
Commit 9eac744

[https://nvbugs/5464088] [fix] dequantize fp8 activation input to lora forward; update perf test config (NVIDIA#7014)
Signed-off-by: Venky Ganesh <[email protected]>
1 parent a875e50 commit 9eac744


2 files changed: +25 -6 lines changed


tensorrt_llm/_torch/modules/gated_mlp.py

Lines changed: 15 additions & 5 deletions
@@ -5,6 +5,7 @@
 import torch.nn.functional as F
 from torch import nn
 
+from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping
 
 from ..distributed import AllReduceParams
@@ -95,12 +96,21 @@ def __init__(self,
             [LoraModuleType.MLP_GATE_UP],
             [2 * self.intermediate_size // mapping.tp_size])
 
-    def _apply_activation(self, x):
+    def _apply_activation(self, x, *, has_lora: bool = False):
         if self.activation == F.silu:
             if self.down_proj.has_fp8_qdq:
-                return swiglu(x,
-                              quant_scale=self.down_proj.input_scale,
-                              quant_type=torch.float8_e4m3fn)
+                if has_lora:
+                    # NOTE: This is a WAR, since LoRA grouped_gemm does not support FP8 yet.
+                    # TODO: Remove this path when LoRA grouped_gemm supports FP8
+                    # see: cpp/tensorrt_llm/thop/loraOp.cpp::lora_grouped_gemm
+                    logger.warning(
+                        f"GatedMLP._apply_activation: LoRA path active; forcing non-FP8 activation dtype bf16/fp16, layer_idx={self.layer_idx}"
+                    )
+                    return swiglu(x)
+                else:
+                    return swiglu(x,
+                                  quant_scale=self.down_proj.input_scale,
+                                  quant_type=torch.float8_e4m3fn)
             else:
                 return swiglu(x)
         elif callable(self.activation):
@@ -152,7 +162,7 @@ def forward_lora(
         if h1_lora is not None:
             h1 = h1 + h1_lora
 
-        h2 = self._apply_activation(h1)
+        h2 = self._apply_activation(h1, has_lora=True)
         output = self.down_proj(h2,
                                 all_reduce_params=final_all_reduce_params,
                                 lora_params=lora_params,

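For context, the effect of this change can be illustrated with a small standalone sketch. Assumptions: the swiglu reference and the apply_activation helper below are hypothetical stand-ins for the module's real swiglu op and GatedMLP._apply_activation; only the dispatch mirrors the diff. On the dense path the SwiGLU output is still quantized to FP8 for down_proj, while the LoRA path keeps it in bf16/fp16 because lora_grouped_gemm cannot consume FP8 activations yet.

# Hypothetical standalone sketch; not the real tensorrt_llm module.
import torch
import torch.nn.functional as F

def swiglu(x, quant_scale=None, quant_type=None):
    # Reference SwiGLU: split the fused gate/up projection and gate with SiLU.
    gate, up = x.chunk(2, dim=-1)
    out = F.silu(gate) * up
    if quant_scale is not None and quant_type is not None:
        # FP8 path: scale and cast so the down projection sees quantized activations.
        out = (out / quant_scale).to(quant_type)
    return out

def apply_activation(x, *, has_fp8_qdq: bool, input_scale=None, has_lora: bool = False):
    if has_fp8_qdq and not has_lora:
        # Dense path: quantize the activation to FP8 for down_proj.
        return swiglu(x, quant_scale=input_scale, quant_type=torch.float8_e4m3fn)
    # LoRA WAR from the diff: grouped_gemm has no FP8 support, so stay in bf16/fp16.
    return swiglu(x)

x = torch.randn(2, 8, dtype=torch.bfloat16)
scale = torch.tensor(1.0)
print(apply_activation(x, has_fp8_qdq=True, input_scale=scale).dtype)                 # torch.float8_e4m3fn
print(apply_activation(x, has_fp8_qdq=True, input_scale=scale, has_lora=True).dtype)  # torch.bfloat16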
tests/integration/defs/perf/pytorch_model_config.py

Lines changed: 10 additions & 1 deletion
@@ -181,10 +181,19 @@ def get_model_yaml_config(model_label: str,
 
     # lora-specific change for pytorch
     if 'pytorch' in model_label and 'loras' in model_label:
+        # Derive the requested number of adapters from model_label (segment like "loras:X")
+        lora_count = 1
+        for part in model_label.split('-'):
+            if part.startswith('loras:'):
+                lora_count = max(1, int(part.split(':', 1)[1]))
+                break
+
         lora_config = {
             'lora_config': {
                 'lora_dir': lora_dirs if lora_dirs is not None else [],
-                'max_lora_rank': 64
+                'max_lora_rank': 64,
+                'max_loras': lora_count,
+                'max_cpu_loras': lora_count,
             }
         }
         if 'phi_4_multimodal_instruct' in model_label:

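For illustration, a minimal sketch of the label parsing this perf-test change adds. The parse_lora_count helper and the example label below are made up for illustration; the real logic lives inline in get_model_yaml_config above. A "loras:X" segment in the model label sets max_loras and max_cpu_loras to X, defaulting to 1 when the segment is absent and clamping to at least 1 otherwise.

# Hypothetical helper mirroring the label parsing added in pytorch_model_config.py above.
def parse_lora_count(model_label: str) -> int:
    # Return the adapter count from a 'loras:X' segment, defaulting to 1.
    for part in model_label.split('-'):
        if part.startswith('loras:'):
            return max(1, int(part.split(':', 1)[1]))
    return 1

# Example label (made up for illustration):
label = 'llama_v3.1_8b_instruct-bench-pytorch-bfloat16-loras:8-maxbs:2'
count = parse_lora_count(label)
lora_config = {
    'lora_config': {
        'lora_dir': [],
        'max_lora_rank': 64,
        'max_loras': count,      # 8
        'max_cpu_loras': count,  # 8
    }
}
print(lora_config)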