Skip to content

Commit

Permalink
happy lint
Browse files: browse the repository at this point in the history.
Signed-off-by: Michael Goldfarb <[email protected]>
  • Loading branch information
mgoldfarb-nvidia committed Aug 22, 2024
1 parent 0e35a02 commit 493e0c2
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions transformer_engine/jax/cpp_extensions/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,7 +946,7 @@ def ag(x):
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
return ag(k), ag(v)
case _:
raise ValueError(f"Unsupported layout {self.qkv_layout=}")
return k, v

def reduce_scatter_dkv(self, dk, dv):
"""Performs a reduce-scatter of dk and dv over context parallel ranks."""
Expand All @@ -960,7 +960,7 @@ def rs(x):
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
return rs(dk), rs(dv)
case _:
raise ValueError(f"Unsupported layout {self.qkv_layout=}")
return dk, dv

def kv_seqlens_for_rank(self, cp_rank, kv_max_seqlen, kv_seqlen_per_subrank):
"""Returns sequence lengths of KV to use for each sub rank of the given cp_rank.
Expand Down Expand Up @@ -1001,7 +1001,7 @@ def sliced(x):
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
return sliced(k), sliced(v)
case _:
raise ValueError(f"Unsupported layout {self.qkv_layout=}")
return k, v

def pad_kv(self, dk, dv, pad_seq_len):
"""Pads dk and dv tensors to a sequence length of pad_seq_len."""
Expand All @@ -1017,7 +1017,7 @@ def pad(x, npad):
npad = [[0, 0], [0, pad_seq_len], [0, 0], [0, 0]]
return pad(dk, npad), pad(dv, npad)
case _:
raise ValueError(f"Unsupported layout {self.qkv_layout=}")
return dk, dv


class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
Expand Down

0 comments on commit 493e0c2

Please sign in to comment.