fix various pylint complaints
mgoldfarb-nvidia committed Aug 21, 2024
1 parent c4cba9b commit 050bce8
Showing 1 changed file with 36 additions and 23 deletions.
59 changes: 36 additions & 23 deletions transformer_engine/jax/cpp_extensions/attention.py
@@ -1056,50 +1056,59 @@ class _FusedAttnCPWithAllGatherHelper:
context_parallel_load_balanced: bool

def check_supported(self):
"""Checks if the context parallel implementation is supported by the given arguments."""
header = "Context parallel fused attention only supports"

allowed_layouts = [NVTE_QKV_Layout.NVTE_BSHD_BS2HD, NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD]
assert self.qkv_layout in allowed_layouts, (
f"{header} supports layouts: {','.join([str(x) for x in cp_allowed_layouts])} got:"
f" {qkv_layout}"
f"{header} supports layouts: {','.join([str(x) for x in allowed_layouts])} got:"
f" {self.qkv_layout}"
)

allowed_masks = [NVTE_Mask_Type.NVTE_NO_MASK, NVTE_Mask_Type.NVTE_CAUSAL_MASK]
assert self.attn_mask_type in allowed_masks, (
f"{header} masking types: {','.join([str(x) for x in allowed_masks])} got:"
f" {attn_mask_type}"
f" {self.attn_mask_type}"
)

assert (
self.max_segments_per_seq == 1
), f"{header} max_segments_per_seq == 1 got: {max_segments_per_seq}"
), f"{header} max_segments_per_seq == 1 got: {self.max_segments_per_seq}"
assert self.dropout_probability == 0.0, f"{header} does not support dropout"

def get_adjusted_mask(self):
"""Converts the mask for context parallelism."""
if self.attn_mask_type == NVTE_Mask_Type.NVTE_CAUSAL_MASK:
return NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK
return mask
return self.attn_mask_type
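(Illustration, not part of the diff.) The mask adjustment above reflects that, with this all-gather scheme, each rank's local Q chunk attends against a gathered K/V span that ends where the chunk ends, so the causal diagonal has to be anchored at the bottom-right corner of the [q_len, kv_len] score matrix rather than the top-left. A minimal sketch with toy sizes:

```python
import jax.numpy as jnp

# Toy sizes (assumptions): a local Q chunk of 2 tokens attending over a
# gathered KV span of 4 tokens that ends at the same position.
q_len, kv_len = 2, 4
offset = kv_len - q_len  # shift the causal diagonal to the bottom-right corner
bottom_right_causal = jnp.tril(jnp.ones((q_len, kv_len), dtype=bool), k=offset)
print(bottom_right_causal.astype(jnp.int32))
# [[1 1 1 0]
#  [1 1 1 1]]
```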

def all_gather_kv(self, k, v):
ag = lambda x: all_gather_along_cp_axis(x, self.mesh, tiled=True, axis=1)
"""Performs a all-gather of k and v over context parallel ranks."""

def ag(x):
return all_gather_along_cp_axis(x, self.mesh, tiled=True, axis=1)

match self.qkv_layout:
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
return ag(k), v
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
return ag(k), ag(v)
case _:
raise ValueError(f"Unsupported layout {qkv_layout=}")
raise ValueError(f"Unsupported layout {self.qkv_layout=}")
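(Illustration, not part of the diff.) The lambda-to-def rewrites throughout this change presumably address pylint's unnecessary-lambda-assignment check, which flags binding a lambda expression to a name. A minimal sketch of the pattern on a toy helper:

```python
# Flagged by pylint (unnecessary-lambda-assignment): a lambda bound to a name.
#   double = lambda x: x * 2

# Preferred: a small named def, which also gets a real __name__ and can
# carry a docstring.
def double(x):
    """Return x doubled."""
    return x * 2

print(double(21))  # 42
```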

def reduce_scatter_dkv(self, dk, dv):
rs = lambda x: reduce_scatter_along_cp_axis(x, self.mesh, axis=1, tiled=True)
"""Performs a reduce-scatter of dk and dv over context parallel ranks."""

def rs(x):
return reduce_scatter_along_cp_axis(x, self.mesh, axis=1, tiled=True)

match self.qkv_layout:
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
return rs(dk), dv
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
return rs(dk), rs(dv)
case _:
raise ValueError(f"Unsupported layout {qkv_layout=}")
raise ValueError(f"Unsupported layout {self.qkv_layout=}")
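(Illustration, not part of the diff.) `all_gather_kv` and `reduce_scatter_dkv` form the usual pair for this style of context parallelism: the forward pass gathers the full K/V sequence onto every rank, and the backward pass sums the K/V gradients across ranks while handing each rank back only its own sequence shard. The repo's `all_gather_along_cp_axis` / `reduce_scatter_along_cp_axis` helpers are not shown in this hunk, so the sketch below stands in with stock JAX collectives on a toy two-device mesh:

```python
import os
# Assumption for the sketch only: fake two CPU devices so it runs anywhere.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"

from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()[:2]), ("cp",))

# K in BSHD layout [batch, seq, heads, dim], sharded on the sequence axis.
k = jnp.arange(2 * 8 * 2 * 4, dtype=jnp.float32).reshape(2, 8, 2, 4)

@partial(shard_map, mesh=mesh, in_specs=P(None, "cp"), out_specs=P(None, "cp"))
def roundtrip(k_local):
    # Forward: every rank gathers the full K sequence (axis=1, tiled).
    k_full = jax.lax.all_gather(k_local, "cp", axis=1, tiled=True)
    # Backward (conceptually, applied to dK): sum across ranks and keep
    # only this rank's sequence shard.
    dk_local = jax.lax.psum_scatter(k_full, "cp", scatter_dimension=1, tiled=True)
    return dk_local

print(roundtrip(k).shape)  # (2, 8, 2, 4) globally, (2, 4, 2, 4) per rank
```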

def kv_seqlens_for_rank(self, cp_rank, kv_max_seqlen, kv_seqlen_per_subrank):
"""Returns sequence lengths of KV to use for each sub rank of the given cp_rank.
@@ -1129,17 +1138,25 @@ def kv_seqlens_for_rank(self, cp_rank, kv_max_seqlen, kv_seqlen_per_subrank):
return kv_seq_this_rank
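(Illustration, not part of the diff; the body of `kv_seqlens_for_rank` is collapsed above.) The `context_parallel_load_balanced` flag and the two sub-ranks per rank suggest the common load-balancing scheme for causal context parallelism: split the sequence into 2 * cp_size chunks and give each rank one early and one late chunk so the attended KV work evens out. A toy sketch of that pairing, as an assumption about the scheme rather than this code:

```python
# Hypothetical load-balanced chunk assignment: rank r handles chunks
# r and (2 * cp_size - 1 - r) out of 2 * cp_size sequence chunks.
cp_size = 4
assignment = {r: (r, 2 * cp_size - 1 - r) for r in range(cp_size)}
print(assignment)  # {0: (0, 7), 1: (1, 6), 2: (2, 5), 3: (3, 4)}
```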

def slice_kv(self, k, v, slice_seq_len):
sliced = lambda x: lax.dynamic_slice_in_dim(x, 0, slice_seq_len, axis=1)
"""Slices k and v tensors to a sequence length of slice_seq_len."""

def sliced(x):
return lax.dynamic_slice_in_dim(x, 0, slice_seq_len, axis=1)

match self.qkv_layout:
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
return sliced(k), v
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
return sliced(k), sliced(v)
case _:
raise ValueError(f"Unsupported layout {qkv_layout=}")
raise ValueError(f"Unsupported layout {self.qkv_layout=}")

def pad_kv(self, dk, dv, pad_seq_len):
pad = lambda x, npad: jnp.pad(x, npad, "constant", constant_values=0.0)
"""Pads dk and dv tensors to a sequence length of pad_seq_len."""

def pad(x, npad):
return jnp.pad(x, npad, "constant", constant_values=0.0)

match self.qkv_layout:
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
npad = [[0, 0], [0, pad_seq_len], [0, 0], [0, 0], [0, 0]]
@@ -1148,7 +1165,7 @@ def pad_kv(self, dk, dv, pad_seq_len):
npad = [[0, 0], [0, pad_seq_len], [0, 0], [0, 0]]
return pad(dk, npad), pad(dv, npad)
case _:
raise ValueError(f"Unsupported layout {qkv_layout=}")
raise ValueError(f"Unsupported layout {self.qkv_layout=}")
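(Illustration, not part of the diff.) `slice_kv` and `pad_kv` are inverses along the sequence axis: the backward pass trims K/V to the length a sub-rank actually needs, then zero-pads the resulting gradients back to the gathered length so they line up for the reduce-scatter. A standalone sketch with toy BSHD shapes (shapes are assumptions):

```python
import jax.numpy as jnp
from jax import lax

# Toy BSHD key tensor [batch=2, seq=8, heads=2, dim=4].
k = jnp.ones((2, 8, 2, 4))

# Slice: keep only the first 6 tokens along the sequence axis (axis=1).
k_sliced = lax.dynamic_slice_in_dim(k, 0, 6, axis=1)   # -> (2, 6, 2, 4)

# Pad: restore the original sequence length with zeros on the right.
npad = [[0, 0], [0, 2], [0, 0], [0, 0]]
k_padded = jnp.pad(k_sliced, npad, "constant", constant_values=0.0)
print(k_sliced.shape, k_padded.shape)  # (2, 6, 2, 4) (2, 8, 2, 4)
```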


class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
@@ -1176,7 +1193,6 @@ def partition(

# Call base implementation for non-context parallel mesh.
if not is_context_parallel:
print(f"not context parallel falling back to base impl")
return FusedAttnFwdPrimitive.partition(
attn_bias_type,
attn_mask_type,
@@ -1212,7 +1228,6 @@ def partition(

def impl(q, k, v, bias, q_seqlen, kv_seqlen, q_seq_offsets, k_seq_offsets, seed):

_not_used = jnp.zeros(0, dtype=q.dtype)
cp_size = get_cp_axis_size(mesh)
cp_rank = get_cp_axis_rank(mesh)

@@ -1248,8 +1263,8 @@ def _cross_attn(idx, q, k, v, bias, q_seqlen, kv_seqlen, seed):
bias,
q_seqlen,
kv_seqlen,
_not_used, # q_seq_offsets
_not_used, # k_seq_offsets
q_seq_offsets,
k_seq_offsets,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=helper.get_adjusted_mask(),
@@ -1371,8 +1386,6 @@ def impl(
q_seq_offsets,
k_seq_offsets,
):

_not_used = jnp.zeros(0, dtype=q.dtype)
cp_size = get_cp_axis_size(mesh)
cp_rank = get_cp_axis_rank(mesh)

@@ -1398,9 +1411,9 @@ def _cross_attn_bwd(
for sub_idx in range(2):
k_sliced, v_sliced = helper.slice_kv(k, v, kv_seqlens_for_rank[sub_idx])

q_seqlen = q_seqlen / (cp_size * 2)
q_seqlen = q_seqlen // (cp_size * 2)
num_kv_chunks = kv_max_seqlen // kv_seqlens_for_rank[sub_idx]
kv_seqlen = (kv_seqlen / (cp_size * 2)) * num_kv_chunks
kv_seqlen = (kv_seqlen // (cp_size * 2)) * num_kv_chunks

dq_local, dk_local, dv_local, dbias_local = FusedAttnBwdPrimitive.impl(
q_split[sub_idx],
@@ -1413,8 +1426,8 @@ def _cross_attn_bwd(
doutput_split[sub_idx],
q_seqlen,
kv_seqlen,
_not_used, # q_seq_offsets
_not_used, # k_seq_offsets
q_seq_offsets,
k_seq_offsets,
attn_bias_type=attn_bias_type,
attn_mask_type=helper.get_adjusted_mask(),
qkv_layout=qkv_layout,
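(Illustration, not part of the diff.) One behavioral fix tucked in with the pylint cleanups above: the per-rank sequence lengths switch from true division (/) to floor division (//). With JAX integer arrays, true division promotes to floating point, which is presumably the wrong type for the kernel's sequence-length inputs. A quick dtype check with toy values:

```python
import jax.numpy as jnp

q_seqlen = jnp.asarray([4096], dtype=jnp.int32)
cp_size = 4

# True division promotes the int32 sequence length to float32.
print((q_seqlen / (cp_size * 2)).dtype)   # float32

# Floor division keeps it an integer.
print((q_seqlen // (cp_size * 2)).dtype)  # int32
```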
