Commit 9ea442c
Fix Gemma2 Attention Args (NVIDIA#11365)
* fix gemma2 attention args

* pylint

* kwargs

* Apply isort and black reformatting

Signed-off-by: suiyoubi <[email protected]>

---------

Signed-off-by: suiyoubi <[email protected]>
Co-authored-by: suiyoubi <[email protected]>
suiyoubi authored Nov 21, 2024
1 parent 8ab46ff commit 9ea442c
Showing 1 changed file with 13 additions and 1 deletion.
@@ -49,7 +49,8 @@ class Gemma2DotProductAttention(MegatronModule):
     Region where selective activation recomputation is applied.
     This region is memory intensive but less compute intensive which
     makes activation checkpointing more efficient for LLMs (20B+).
-    See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details.
+    See Reducing Activation Recomputation in Large Transformer Models:
+    https://arxiv.org/abs/2205.05198 for more details.

     We use the following notation:
      h: hidden size
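
The docstring above describes selective activation recomputation; as a generic illustration only (not NeMo's implementation), wrapping the memory-heavy attention region in activation checkpointing looks roughly like this in PyTorch:

import torch
from torch.utils.checkpoint import checkpoint

def attention_region(q, k, v):
    # Memory-intensive but compute-light region: scores, softmax, context.
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    probs = torch.softmax(scores, dim=-1)
    return probs @ v

def checkpointed_attention(q, k, v):
    # Activations inside the region are recomputed during backward instead of stored.
    return checkpoint(attention_region, q, k, v, use_reentrant=False)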
@@ -126,7 +127,12 @@ def forward(
         attention_mask: Tensor,
         attn_mask_type: AttnMaskType = None,
         packed_seq_params: PackedSeqParams = None,
+        **kwargs,
     ):
+        """Forward.
+        Modified from mcore.transformer.dot_product_attention to support Gemma2-specific
+        final_logit_softcapping.
+        """
         assert packed_seq_params is None, (
             "Packed sequence is not supported by DotProductAttention." "Please use TEDotProductAttention instead."
         )
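
To see why the added **kwargs matters, here is a minimal sketch (not the NeMo source; the class and argument names are illustrative): if a newer base-class call site starts passing an extra keyword argument, an override with a fixed signature raises TypeError, while one that accepts **kwargs keeps working.

class BaseAttention:
    def forward(self, scores, attention_mask, extra_arg=None):
        return scores

class StrictGemma2Attention(BaseAttention):
    # Fixed signature: fails if the caller passes extra_arg.
    def forward(self, scores, attention_mask):
        return super().forward(scores, attention_mask)

class TolerantGemma2Attention(BaseAttention):
    # **kwargs absorbs arguments added by newer call sites.
    def forward(self, scores, attention_mask, **kwargs):
        return super().forward(scores, attention_mask)

# StrictGemma2Attention().forward(1.0, None, extra_arg=0)    # TypeError
# TolerantGemma2Attention().forward(1.0, None, extra_arg=0)  # works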
@@ -243,6 +249,8 @@ def forward(


 class TERowParallelLinearLayerNorm(TERowParallelLinear):
+    """Modified From TERowParallelLinear with an additional Post-LN."""
+
     def __init__(
         self,
         input_size: int,
@@ -270,12 +278,16 @@ def __init__(
         self.post_layernorm = TENorm(config, output_size)
 
     def forward(self, x):
+        """Forward with additional Post LN on output"""
         output, bias = super().forward(x)
         return self.post_layernorm(output), bias
 
 
 class Gemma2OutputLayer(ColumnParallelLinear):
+    """Extends from ColumnParallelLinear with logit soft capping."""
+
     def forward(self, *args, **kwargs):
+        """Forward with logit soft capping."""
        output, bias = super().forward(*args, **kwargs)
        output = logit_softcapping(output, self.config.final_logit_softcapping)
        return output, bias
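
The logit_softcapping helper used above is defined elsewhere in NeMo and is not part of this diff; as a point of reference, Gemma 2 style soft capping squashes logits with a scaled tanh, roughly as in this sketch (assumed implementation, illustrative name):

from typing import Optional

import torch

def logit_softcapping_sketch(logits: torch.Tensor, cap: Optional[float]) -> torch.Tensor:
    # Gemma 2 soft capping: cap * tanh(logits / cap) keeps values in (-cap, cap).
    if cap is None:
        return logits
    return cap * torch.tanh(logits / cap)

# For example, final_logit_softcapping = 30.0 bounds output logits to (-30, 30).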
