
Commit

rename is_first_module_in_mha to fp8_output
Signed-off-by: Xin Yao <[email protected]>
yaox12 committed Aug 14, 2024
1 parent abbcd9b commit a1ba977
Showing 3 changed files with 18 additions and 18 deletions.
10 changes: 5 additions & 5 deletions transformer_engine/pytorch/attention.py
@@ -6592,7 +6592,7 @@ def forward(
layernorm_qkv_outputs = self.layernorm_qkv(
hidden_states,
is_first_microbatch=is_first_microbatch,
-is_first_module_in_mha=rotary_pos_emb is None, # specific to FP8 MHA
+fp8_output=rotary_pos_emb is None, # specific to FP8 MHA
)
if self.return_layernorm_output:
mixed_x_layer, layernorm_output = layernorm_qkv_outputs
@@ -6602,7 +6602,7 @@ def forward(
mixed_x_layer = self.qkv(
hidden_states,
is_first_microbatch=is_first_microbatch,
-is_first_module_in_mha=rotary_pos_emb is None, # specific to FP8 MHA
+fp8_output=rotary_pos_emb is None, # specific to FP8 MHA
)

num_queries_per_key_value = (
@@ -6658,7 +6658,7 @@ def forward(
mixed_kv_layer = self.key_value(
encoder_output,
is_first_microbatch=is_first_microbatch,
-is_first_module_in_mha=rotary_pos_emb is None, # specific to FP8 MHA
+fp8_output=rotary_pos_emb is None, # specific to FP8 MHA
)

if self.qkv_weight_interleaved:
@@ -6708,7 +6708,7 @@ def forward(
layernorm_query_outputs = self.layernorm_query(
hidden_states,
is_first_microbatch=is_first_microbatch,
-is_first_module_in_mha=rotary_pos_emb is None, # specific to FP8 MHA
+fp8_output=rotary_pos_emb is None, # specific to FP8 MHA
)
if self.return_layernorm_output:
query_layer, layernorm_output = layernorm_query_outputs
@@ -6718,7 +6718,7 @@ def forward(
query_layer = self.query_layer(
hidden_states,
is_first_microbatch=is_first_microbatch,
-is_first_module_in_mha=rotary_pos_emb is None, # specific to FP8 MHA
+fp8_output=rotary_pos_emb is None, # specific to FP8 MHA
)

# [sq, b, hp] --> [sq, b, np, hn]
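
Note: in attention.py the renamed keyword is passed as fp8_output=rotary_pos_emb is None, so the QKV/KV/query projections only keep their output in FP8 when no rotary embedding has to be applied to it afterwards. The flag also only takes effect when the active recipe enables FP8 MHA (each module ANDs it with self.fp8_meta["recipe"].fp8_mha, see the next two files). A minimal, hedged usage sketch assuming Transformer Engine's public recipe/autocast API; the fp8_dpa/fp8_mha recipe flags, shapes, and hardware requirements below are illustrative assumptions and not part of this commit:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# FP8 recipe with FP8 attention enabled; fp8_mha is the attribute the modules
# read from self.fp8_meta["recipe"]. fp8_dpa=True is assumed to be needed alongside it.
recipe = DelayedScaling(fp8_format=Format.HYBRID, fp8_dpa=True, fp8_mha=True)

mha = te.MultiheadAttention(hidden_size=1024, num_attention_heads=16, params_dtype=torch.bfloat16)
x = torch.randn(128, 2, 1024, device="cuda", dtype=torch.bfloat16)  # [seq, batch, hidden]

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    # No rotary_pos_emb is passed, so the internal QKV projection is called with
    # fp8_output=True and hands an FP8 tensor straight to the fused attention.
    out = mha(x)
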
12 changes: 6 additions & 6 deletions transformer_engine/pytorch/module/layernorm_linear.py
@@ -90,7 +90,7 @@ def forward(
ub_overlap_rs_dgrad: bool,
ub_overlap_ag: bool,
ub_name: str,
-is_first_module_in_mha: bool,
+fp8_output: bool,
fsdp_group: Union[dist_group_type, None],
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible
@@ -200,7 +200,7 @@ def forward(

assert isinstance(weight_fp8, Float8Tensor)

-if is_first_module_in_mha:
+if fp8_output:
out_index, meta_tensor, output_te_dtype, output_dtype = (
tex.FP8FwdTensors.GEMM1_OUTPUT,
fp8_meta["scaling_fwd"],
@@ -745,7 +745,7 @@ def backward(
None, # ub_overlap_rs_dgrad
None, # ub_overlap_ag
None, # ub_name
-None, # is_first_module_in_mha
+None, # fp8_output
None, # fsdp_group
)

@@ -1098,7 +1098,7 @@ def forward(
self,
inp: torch.Tensor,
is_first_microbatch: Optional[bool] = None,
-is_first_module_in_mha: Optional[bool] = False,
+fp8_output: Optional[bool] = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
"""
Apply layer normalization to the input followed by a linear transformation.
Expand Down Expand Up @@ -1128,7 +1128,7 @@ def forward(

with self.prepare_forward(inp, is_first_microbatch) as inp:

-is_first_module_in_mha = is_first_module_in_mha and self.fp8_meta["recipe"].fp8_mha
+fp8_output = fp8_output and self.fp8_meta["recipe"].fp8_mha

# Get concatenated weight and bias tensors
unfused_weights = [getattr(self, name) for name in self.weight_names]
@@ -1228,7 +1228,7 @@ def forward(
self.ub_overlap_rs_dgrad,
self.ub_overlap_ag,
self.ub_name,
-is_first_module_in_mha,
+fp8_output,
self.fsdp_group,
)
out = fwd_fn(*args)
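
The same keyword is exposed on LayerNormLinear.forward, where it selects the GEMM1_OUTPUT FP8 slot for the GEMM result instead of dequantizing it back to the activation dtype. A self-contained, hedged sketch of calling the module directly; constructor arguments and shapes are illustrative, and when the recipe does not set fp8_mha the flag is simply ignored and a regular high-precision tensor is returned:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

recipe = DelayedScaling(fp8_format=Format.HYBRID, fp8_dpa=True, fp8_mha=True)  # assumed flags, as above
qkv_proj = te.LayerNormLinear(1024, 3 * 1024, bias=True, params_dtype=torch.bfloat16)
hidden = torch.randn(128, 2, 1024, device="cuda", dtype=torch.bfloat16)

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    # fp8_output=True asks for the fused LayerNorm + GEMM result to stay in FP8
    # rather than being cast back to bf16.
    qkv = qkv_proj(hidden, is_first_microbatch=True, fp8_output=True)
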
14 changes: 7 additions & 7 deletions transformer_engine/pytorch/module/linear.py
@@ -81,7 +81,7 @@ def forward(
ub_overlap_rs: bool,
ub_overlap_ag: bool,
ub_name: str,
-is_first_module_in_mha: bool,
+fp8_output: bool,
fsdp_group: Union[dist_group_type, None],
) -> torch.Tensor:
is_input_fp8 = isinstance(inp, Float8Tensor)
@@ -153,7 +153,7 @@ def forward(

assert isinstance(weight_fp8, Float8Tensor)

-if is_first_module_in_mha:
+if fp8_output:
proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = (
tex.FP8FwdTensors.GEMM1_OUTPUT,
fp8_meta["scaling_fwd"],
@@ -222,7 +222,7 @@ def forward(
fp8_meta_tensor=meta_tensor,
D_dtype=proj_out_tetype,
)
-if is_first_module_in_mha:
+if fp8_output:
out = Float8Tensor(
data=out,
fp8_meta=fp8_meta,
@@ -621,7 +621,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
None, # ub_overlap_rs
None, # ub_overlap_ag
None, # ub_name
-None, # is_first_module_in_mha
+None, # fp8_output
None, # fsdp_group
)

@@ -899,7 +899,7 @@ def forward(
self,
inp: torch.Tensor,
is_first_microbatch: Optional[bool] = None,
-is_first_module_in_mha: Optional[bool] = False,
+fp8_output: Optional[bool] = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
"""
Apply the linear transformation to the input.
Expand Down Expand Up @@ -933,7 +933,7 @@ def forward(
allow_non_contiguous=isinstance(inp, Float8Tensor),
) as inp:

-is_first_module_in_mha = is_first_module_in_mha and self.fp8_meta["recipe"].fp8_mha
+fp8_output = fp8_output and self.fp8_meta["recipe"].fp8_mha

# Get concatenated weight and bias tensors
unfused_weights = [getattr(self, name) for name in self.weight_names]
@@ -1019,7 +1019,7 @@ def forward(
self.ub_overlap_rs,
self.ub_overlap_ag,
self.ub_name,
-is_first_module_in_mha,
+fp8_output,
self.fsdp_group,
)
out = linear_fn(*args)
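
In linear.py, when fp8_output is set the FP8 GEMM result is wrapped in a Float8Tensor before being returned, as the hunk above shows; otherwise the module returns a plain torch.Tensor in the activation dtype. A hedged sketch of observing that from the caller, under the same assumed recipe flags as above:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

recipe = DelayedScaling(fp8_format=Format.HYBRID, fp8_dpa=True, fp8_mha=True)  # assumed flags
proj = te.Linear(1024, 1024, params_dtype=torch.bfloat16)
x = torch.randn(128, 2, 1024, device="cuda", dtype=torch.bfloat16)

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    out_fp8 = proj(x, fp8_output=True)   # expected: Float8Tensor (FP8 data plus scaling metadata)
    out_hp = proj(x, fp8_output=False)   # expected: plain torch.Tensor in bf16

# Under the assumptions above, the first type is the Float8Tensor wrapper and
# the second is torch.Tensor.
print(type(out_fp8).__name__, type(out_hp).__name__)
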
