fp8 mha with rope
Signed-off-by: Xin Yao <[email protected]>
yaox12 committed Aug 13, 2024
1 parent def4d1c commit b10d27f
Showing 2 changed files with 37 additions and 28 deletions.
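In rough terms, the attention autograd functions below stop asserting that FP8-MHA inputs are Float8Tensor and instead detect the input type at runtime, so the QKV projection may hand over high-precision tensors when a rotary embedding still has to be applied before attention. A minimal, self-contained sketch of the detection pattern (Float8Tensor here is only a stand-in for transformer_engine.pytorch.float8_tensor.Float8Tensor):

```python
import torch


class Float8Tensor(torch.Tensor):
    """Stand-in for Transformer Engine's FP8 tensor wrapper (illustration only)."""


def detect_fp8_input(qkv: torch.Tensor, fp8: bool) -> bool:
    # Previously: assert isinstance(qkv, Float8Tensor) whenever the recipe set
    # fp8_mha=True. Now the input type chooses the code path, so a bf16/fp16 qkv
    # is also accepted (e.g. when RoPE is applied after the projection).
    return fp8 and isinstance(qkv, Float8Tensor)
```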
57 changes: 30 additions & 27 deletions transformer_engine/pytorch/attention.py
@@ -3560,19 +3560,19 @@ def forward(
fp8_meta,
deterministic,
):
is_input_fp8 = False
if fp8:
if fp8_meta["recipe"].fp8_mha:
assert isinstance(qkv, Float8Tensor), "qkv must be Float8Tensors for FP8 MHA."
is_input_fp8 = isinstance(qkv, Float8Tensor)
if is_input_fp8:
fp8_meta["scaling_fwd"].scale_inv[META_QKV] = qkv._scale_inv
fused_attention_backend = FusedAttnBackend["FP8"]
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
# 1: qkv packed, 2: kv packed, 3: qkv separate
qkv_group = len(qkv_layout.split("_"))
assert qkv_group == 1, (
"qkv layout should conform to 3hd or h3d, e.g. sb3hd, but found"
f" {qkv_layout}."
f"qkv layout should conform to 3hd or h3d, e.g. sb3hd, but found {qkv_layout}."
)
if fp8_meta["recipe"].fp8_mha:
if is_input_fp8:
qkv_fp8 = qkv._data
else:
qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
@@ -3621,7 +3621,7 @@ def forward(
qkv_dtype,
).view(out_fp8.shape)
out_save = out_ret
if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
if is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
qkv = cast_from_fp8(
qkv_c._data,
@@ -3672,6 +3672,7 @@ def forward(
out_save = out_ret

ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
ctx.is_input_fp8 = is_input_fp8
qkvo_tensors = (qkv, out_save) if not ctx.fp8 else (None, None)
ctx.save_for_backward(
*qkvo_tensors, cu_seqlens, cu_seqlens_padded, *fp8_tensors, *aux_ctx_tensors
@@ -3793,7 +3794,7 @@ def backward(ctx, d_out):
ctx.window_size,
ctx.deterministic,
)
if ctx.fp8_meta["recipe"].fp8_mha:
if ctx.is_input_fp8:
dqkv = Float8Tensor(
data=dqkv_fp8,
fp8_meta=ctx.fp8_meta,
@@ -3931,22 +3932,22 @@ def forward(
fp8_meta,
deterministic,
):
is_input_fp8 = False
if fp8:
if fp8_meta["recipe"].fp8_mha:
assert isinstance(q, Float8Tensor) and isinstance(
kv, Float8Tensor
), "q/kv must be Float8Tensors for FP8 MHA."
assert type(q) == type(kv), "q and kv must have the same type."
is_input_fp8 = isinstance(q, Float8Tensor)
if is_input_fp8:
fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
fused_attention_backend = FusedAttnBackend["FP8"]
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if fp8_meta["recipe"].fp8_mha:
if is_input_fp8:
q_fp8, kv_fp8 = q._data, kv._data
else:
# 1: qkv packed, 2: kv packed, 3: qkv separate
qkv_group = len(qkv_layout.split("_"))
assert qkv_group == 2, (
"qkv layout should conform to hd_2hd or hd_h2d, e.g. sbhd_sb2hd, "
f" but found {qkv_layout}."
"qkv layout should conform to hd_2hd or hd_h2d, e.g. sbhd_sb2hd, "
f"but found {qkv_layout}."
)
q_fp8 = cast_to_fp8(q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward).view(
q.shape
@@ -4001,7 +4002,7 @@ def forward(
qkv_dtype,
).view(out_fp8.shape)
out_save = out_ret
if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
if is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
q = cast_from_fp8(
q._data, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward, TE_DType[q.dtype]
).view(q.shape)
@@ -4060,6 +4061,7 @@ def forward(
fp8_tensors = (None, None, None, None, None)

ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
ctx.is_input_fp8 = is_input_fp8
qkvo_tensors = (q, kv, out_save) if not ctx.fp8 else (None, None, None)
ctx.save_for_backward(
*qkvo_tensors,
@@ -4196,7 +4198,7 @@ def backward(ctx, d_out):
ctx.window_size,
ctx.deterministic,
)
if ctx.fp8_meta["recipe"].fp8_mha:
if ctx.is_input_fp8:
dq = Float8Tensor(
data=dq_fp8,
fp8_meta=ctx.fp8_meta,
@@ -4362,15 +4364,13 @@ def forward(
fp8_meta,
deterministic,
):
is_input_fp8 = False
if fp8:
fused_attention_backend = FusedAttnBackend["FP8"]
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if fp8_meta["recipe"].fp8_mha:
assert (
isinstance(q, Float8Tensor)
and isinstance(k, Float8Tensor)
and isinstance(v, Float8Tensor)
), "q/k/v must be Float8Tensors for FP8 MHA."
assert type(q) == type(k) == type(v), "q, k, and v must have the same type."
is_input_fp8 = isinstance(q, Float8Tensor)
if is_input_fp8:
fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
q_fp8, k_fp8, v_fp8 = q._data, k._data, v._data
else:
@@ -4455,7 +4455,7 @@ def forward(
).view(out_fp8.shape)
out_save = out_ret

if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
if is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
# 1: qkv packed, 2: kv packed, 3: qkv separate
qkv_group = len(qkv_layout.split("_"))
if qkv_group == 1:
@@ -4572,6 +4572,7 @@ def forward(
tensor.activation_offloading = True

ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
ctx.is_input_fp8 = is_input_fp8
qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None)
ctx.save_for_backward(
*qkvo_tensors,
@@ -4714,7 +4715,7 @@ def backward(ctx, d_out):
ctx.deterministic,
)

if ctx.fp8_meta["recipe"].fp8_mha:
if ctx.is_input_fp8:
dq = Float8Tensor(
data=dq_fp8,
fp8_meta=ctx.fp8_meta,
@@ -6579,6 +6580,7 @@ def forward(
layernorm_qkv_outputs = self.layernorm_qkv(
hidden_states,
is_first_microbatch=is_first_microbatch,
is_first_module_in_mha=rotary_pos_emb is None, # specific to FP8 MHA
)
if self.return_layernorm_output:
mixed_x_layer, layernorm_output = layernorm_qkv_outputs
@@ -6588,7 +6590,7 @@ def forward(
mixed_x_layer = self.qkv(
hidden_states,
is_first_microbatch=is_first_microbatch,
is_first_module_in_mha=True, # specific to FP8 MHA
is_first_module_in_mha=rotary_pos_emb is None, # specific to FP8 MHA
)

num_queries_per_key_value = (
@@ -6644,7 +6646,7 @@ def forward(
mixed_kv_layer = self.key_value(
encoder_output,
is_first_microbatch=is_first_microbatch,
is_first_module_in_mha=True, # specific to FP8 MHA
is_first_module_in_mha=rotary_pos_emb is None, # specific to FP8 MHA
)

if self.qkv_weight_interleaved:
@@ -6694,6 +6696,7 @@ def forward(
layernorm_query_outputs = self.layernorm_query(
hidden_states,
is_first_microbatch=is_first_microbatch,
is_first_module_in_mha=rotary_pos_emb is None, # specific to FP8 MHA
)
if self.return_layernorm_output:
query_layer, layernorm_output = layernorm_query_outputs
@@ -6703,7 +6706,7 @@ def forward(
query_layer = self.query_layer(
hidden_states,
is_first_microbatch=is_first_microbatch,
is_first_module_in_mha=True, # specific to FP8 MHA
is_first_module_in_mha=rotary_pos_emb is None, # specific to FP8 MHA
)

# [sq, b, hp] --> [sq, b, np, hn]
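The attention.py hunks above also switch the projection calls from is_first_module_in_mha=True to is_first_module_in_mha=rotary_pos_emb is None: with RoPE in use, Q and K still need the rotary rotation after the projection, so the projection must not emit FP8 output directly. Together with the gate added in layernorm_linear.py below, the effective condition can be sketched as follows (a simplification under the names used in the diff):

```python
from typing import Any, Optional


def projection_emits_fp8(fp8_mha_recipe: bool, rotary_pos_emb: Optional[Any]) -> bool:
    # FP8 output from the QKV/KV/query projection is requested only when the recipe
    # enables FP8 MHA *and* no rotary embedding remains to be applied afterwards.
    return fp8_mha_recipe and rotary_pos_emb is None


assert projection_emits_fp8(True, rotary_pos_emb=None) is True       # RoPE-free path
assert projection_emits_fp8(True, rotary_pos_emb=object()) is False  # RoPE path stays high precision
```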
8 changes: 7 additions & 1 deletion transformer_engine/pytorch/module/layernorm_linear.py
@@ -90,6 +90,7 @@ def forward(
ub_overlap_rs_dgrad: bool,
ub_overlap_ag: bool,
ub_name: str,
is_first_module_in_mha: bool,
fsdp_group: Union[dist_group_type, None],
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible
@@ -199,7 +200,7 @@ def forward(

assert isinstance(weight_fp8, Float8Tensor)

if fp8_meta["recipe"].fp8_mha:
if is_first_module_in_mha:
out_index, meta_tensor, output_te_dtype, output_dtype = (
tex.FP8FwdTensors.GEMM1_OUTPUT,
fp8_meta["scaling_fwd"],
@@ -744,6 +745,7 @@ def backward(
None, # ub_overlap_rs_dgrad
None, # ub_overlap_ag
None, # ub_name
None, # is_first_module_in_mha
None, # fsdp_group
)

@@ -1096,6 +1098,7 @@ def forward(
self,
inp: torch.Tensor,
is_first_microbatch: Optional[bool] = None,
is_first_module_in_mha: Optional[bool] = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
"""
Apply layer normalization to the input followed by a linear transformation.
@@ -1125,6 +1128,8 @@

with self.prepare_forward(inp, is_first_microbatch) as inp:

is_first_module_in_mha = is_first_module_in_mha and self.fp8_meta["recipe"].fp8_mha

# Get concatenated weight and bias tensors
unfused_weights = [getattr(self, name) for name in self.weight_names]
if any(isinstance(w, Float8Tensor) for w in unfused_weights):
@@ -1223,6 +1228,7 @@ def forward(
self.ub_overlap_rs_dgrad,
self.ub_overlap_ag,
self.ub_name,
is_first_module_in_mha,
self.fsdp_group,
)
out = fwd_fn(*args)
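For existing callers of LayerNormLinear nothing changes, since the new keyword defaults to False and is additionally gated on the recipe's fp8_mha flag inside forward. A hypothetical toy call (sizes made up; assumes Transformer Engine and a CUDA device are available):

```python
import torch
import transformer_engine.pytorch as te

ln_qkv = te.LayerNormLinear(1024, 3 * 1024)  # hypothetical sizes
x = torch.randn(8, 2, 1024, device="cuda")

rotary_pos_emb = None  # pretend this attention block uses no RoPE
# Outside fp8_autocast (or with fp8_mha disabled) the flag is a no-op; under an
# FP8-MHA recipe it lets this first GEMM of the block produce FP8 output directly.
out = ln_qkv(x, is_first_module_in_mha=rotary_pos_emb is None)
```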
