From 050bce824d6a854d8af844631df3e37db6f4d05d Mon Sep 17 00:00:00 2001
From: Michael Goldfarb
Date: Wed, 21 Aug 2024 21:03:21 +0000
Subject: [PATCH] fix various pylint complaints

---
 .../jax/cpp_extensions/attention.py | 59 +++++++++++--------
 1 file changed, 36 insertions(+), 23 deletions(-)

diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py
index e07cad4e94..81de2cf126 100644
--- a/transformer_engine/jax/cpp_extensions/attention.py
+++ b/transformer_engine/jax/cpp_extensions/attention.py
@@ -1056,50 +1056,59 @@ class _FusedAttnCPWithAllGatherHelper:
     context_parallel_load_balanced: bool
 
     def check_supported(self):
+        """Checks that the given arguments are supported by the context parallel implementation."""
         header = "Context parallel fused attention only supports"
 
        allowed_layouts = [NVTE_QKV_Layout.NVTE_BSHD_BS2HD, NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD]
         assert self.qkv_layout in allowed_layouts, (
-            f"{header} supports layouts: {','.join([str(x) for x in cp_allowed_layouts])} got:"
-            f" {qkv_layout}"
+            f"{header} layouts: {','.join([str(x) for x in allowed_layouts])} got:"
+            f" {self.qkv_layout}"
         )
 
         allowed_masks = [NVTE_Mask_Type.NVTE_NO_MASK, NVTE_Mask_Type.NVTE_CAUSAL_MASK]
         assert self.attn_mask_type in allowed_masks, (
             f"{header} masking types: {','.join([str(x) for x in allowed_masks])} got:"
-            f" {attn_mask_type}"
+            f" {self.attn_mask_type}"
         )
 
         assert (
             self.max_segments_per_seq == 1
-        ), f"{header} max_segments_per_seq == 1 got: {max_segments_per_seq}"
+        ), f"{header} max_segments_per_seq == 1 got: {self.max_segments_per_seq}"
         assert self.dropout_probability == 0.0, f"{header} does not support dropout"
 
     def get_adjusted_mask(self):
         """Converts the mask for context parallelism."""
         if self.attn_mask_type == NVTE_Mask_Type.NVTE_CAUSAL_MASK:
             return NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK
-        return mask
+        return self.attn_mask_type
 
     def all_gather_kv(self, k, v):
-        ag = lambda x: all_gather_along_cp_axis(x, self.mesh, tiled=True, axis=1)
+        """Performs an all-gather of k and v over context parallel ranks."""
+
+        def ag(x):
+            return all_gather_along_cp_axis(x, self.mesh, tiled=True, axis=1)
+
         match self.qkv_layout:
             case NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
                 return ag(k), v
             case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
                 return ag(k), ag(v)
             case _:
-                raise ValueError(f"Unsupported layout {qkv_layout=}")
+                raise ValueError(f"Unsupported layout {self.qkv_layout=}")
 
     def reduce_scatter_dkv(self, dk, dv):
-        rs = lambda x: reduce_scatter_along_cp_axis(x, self.mesh, axis=1, tiled=True)
+        """Performs a reduce-scatter of dk and dv over context parallel ranks."""
+
+        def rs(x):
+            return reduce_scatter_along_cp_axis(x, self.mesh, axis=1, tiled=True)
+
         match self.qkv_layout:
             case NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
                 return rs(dk), dv
             case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
                 return rs(dk), rs(dv)
             case _:
-                raise ValueError(f"Unsupported layout {qkv_layout=}")
+                raise ValueError(f"Unsupported layout {self.qkv_layout=}")
 
     def kv_seqlens_for_rank(self, cp_rank, kv_max_seqlen, kv_seqlen_per_subrank):
         """Returns sequence lengths of KV to use for each sub rank of the given cp_rank.
@@ -1129,17 +1138,25 @@ def kv_seqlens_for_rank(self, cp_rank, kv_max_seqlen, kv_seqlen_per_subrank):
         return kv_seq_this_rank
 
     def slice_kv(self, k, v, slice_seq_len):
-        sliced = lambda x: lax.dynamic_slice_in_dim(x, 0, slice_seq_len, axis=1)
+        """Slices k and v tensors to a sequence length of slice_seq_len."""
+
+        def sliced(x):
+            return lax.dynamic_slice_in_dim(x, 0, slice_seq_len, axis=1)
+
         match self.qkv_layout:
             case NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
                 return sliced(k), v
             case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
                 return sliced(k), sliced(v)
             case _:
-                raise ValueError(f"Unsupported layout {qkv_layout=}")
+                raise ValueError(f"Unsupported layout {self.qkv_layout=}")
 
     def pad_kv(self, dk, dv, pad_seq_len):
-        pad = lambda x, npad: jnp.pad(x, npad, "constant", constant_values=0.0)
+        """Pads dk and dv tensors to a sequence length of pad_seq_len."""
+
+        def pad(x, npad):
+            return jnp.pad(x, npad, "constant", constant_values=0.0)
+
         match self.qkv_layout:
             case NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
                 npad = [[0, 0], [0, pad_seq_len], [0, 0], [0, 0], [0, 0]]
@@ -1148,7 +1165,7 @@ def pad_kv(self, dk, dv, pad_seq_len):
                 npad = [[0, 0], [0, pad_seq_len], [0, 0], [0, 0]]
                 return pad(dk, npad), pad(dv, npad)
             case _:
-                raise ValueError(f"Unsupported layout {qkv_layout=}")
+                raise ValueError(f"Unsupported layout {self.qkv_layout=}")
 
 
 class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
@@ -1176,7 +1193,6 @@ def partition(
 
         # Call base implementation for non-context parallel mesh.
         if not is_context_parallel:
-            print(f"not context parallel falling back to base impl")
             return FusedAttnFwdPrimitive.partition(
                 attn_bias_type,
                 attn_mask_type,
@@ -1212,7 +1228,6 @@ def partition(
 
         def impl(q, k, v, bias, q_seqlen, kv_seqlen, q_seq_offsets, k_seq_offsets, seed):
-            _not_used = jnp.zeros(0, dtype=q.dtype)
 
             cp_size = get_cp_axis_size(mesh)
             cp_rank = get_cp_axis_rank(mesh)
 
@@ -1248,8 +1263,8 @@ def _cross_attn(idx, q, k, v, bias, q_seqlen, kv_seqlen, seed):
                     bias,
                     q_seqlen,
                     kv_seqlen,
-                    _not_used,  # q_seq_offsets
-                    _not_used,  # k_seq_offsets
+                    q_seq_offsets,
+                    k_seq_offsets,
                     seed,
                     attn_bias_type=attn_bias_type,
                     attn_mask_type=helper.get_adjusted_mask(),
@@ -1371,8 +1386,6 @@ def impl(
         q_seq_offsets,
        k_seq_offsets,
     ):
-
-        _not_used = jnp.zeros(0, dtype=q.dtype)
         cp_size = get_cp_axis_size(mesh)
         cp_rank = get_cp_axis_rank(mesh)
 
@@ -1398,9 +1411,9 @@ def _cross_attn_bwd(
             for sub_idx in range(2):
                 k_sliced, v_sliced = helper.slice_kv(k, v, kv_seqlens_for_rank[sub_idx])
 
-                q_seqlen = q_seqlen / (cp_size * 2)
+                q_seqlen = q_seqlen // (cp_size * 2)
                 num_kv_chunks = kv_max_seqlen // kv_seqlens_for_rank[sub_idx]
-                kv_seqlen = (kv_seqlen / (cp_size * 2)) * num_kv_chunks
+                kv_seqlen = (kv_seqlen // (cp_size * 2)) * num_kv_chunks
 
                 dq_local, dk_local, dv_local, dbias_local = FusedAttnBwdPrimitive.impl(
                     q_split[sub_idx],
@@ -1413,8 +1426,8 @@ def _cross_attn_bwd(
                     doutput_split[sub_idx],
                     q_seqlen,
                     kv_seqlen,
-                    _not_used,  # q_seq_offsets
-                    _not_used,  # k_seq_offsets
+                    q_seq_offsets,
+                    k_seq_offsets,
                     attn_bias_type=attn_bias_type,
                     attn_mask_type=helper.get_adjusted_mask(),
                     qkv_layout=qkv_layout,
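
A quick illustration of why the backward-pass hunks above switch the sequence-length
arithmetic from "/" to "//": the seqlen tensors handed to the fused attention primitives
are integer arrays, and true division in JAX promotes them to floating point, while floor
division keeps the integer dtype. The sketch below is not part of the patch; the shapes
and the cp_size value are arbitrary example values.

    import jax.numpy as jnp

    # Per-batch sequence lengths, as integer tensors like the primitives expect.
    q_seqlen = jnp.array([4096, 4096], dtype=jnp.int32)
    cp_size = 4  # example context-parallel world size

    bad = q_seqlen / (cp_size * 2)    # true division promotes to floating point
    good = q_seqlen // (cp_size * 2)  # floor division keeps the integer dtype

    print(bad.dtype, good.dtype)  # e.g. float32 int32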