Commit

Merge pull request #885 from google:shralex_pylint_nit
PiperOrigin-RevId: 674330000
maxtext authors committed Sep 13, 2024
2 parents 7cc7f2c + 5677957 commit a2553fa
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions MaxText/layers/attentions.py
@@ -36,6 +36,10 @@
from layers import quantizations


# pylint: disable=line-too-long, g-doc-args, g-doc-return-or-yield, bad-continuation, g-inconsistent-quotes
# pytype: disable=attribute-error


class AttentionType(enum.Enum):
GLOBAL = "global"
LOCAL_SLIDING = "local_sliding"
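For context on this hunk: both pylint and pytype treat a standalone disable comment as applying from that line to the end of the enclosing scope, so at module level the directives cover everything below them. Hoisting the block up to just after the imports therefore extends the suppression to the class and helper definitions that previously sat above its old location near line 79. A minimal sketch of the pattern, with an invented constant purely for illustration:

    # pylint: disable=line-too-long, g-inconsistent-quotes
    # pytype: disable=attribute-error
    # From here to the end of the module, the listed pylint checks and
    # pytype attribute-error reports are suppressed.

    EXAMPLE_CONSTANT = "a deliberately long string that would otherwise trip the line-too-long check in this module"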
@@ -79,9 +83,6 @@ class AttentionType(enum.Enum):

dynamic_vector_slice_in_dim = jax.vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))

# pylint: disable=line-too-long, g-doc-args, g-doc-return-or-yield, bad-continuation, g-inconsistent-quotes
# pytype: disable=attribute-error


def validate_compute_axis_order(s: AxisIdxes) -> None:
valid_compute_axis_order = ((0,1,2,3), (0,2,1,3))
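As a side note on the unchanged context above: dynamic_vector_slice_in_dim vmaps lax.dynamic_slice_in_dim over the start-index argument only, so a single shared operand is sliced at a different position for each batch element while the slice size and axis stay fixed. A small self-contained sketch, with shapes invented for illustration:

    import jax
    import jax.numpy as jnp
    from jax import lax

    # Map over the start index (axis 0); share the operand, slice size, and axis.
    dynamic_vector_slice_in_dim = jax.vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))

    x = jnp.arange(12).reshape(6, 2)   # shared operand: (seq=6, dim=2)
    starts = jnp.array([0, 2, 4])      # one start index per batch element
    out = dynamic_vector_slice_in_dim(x, starts, 2, 0)  # length-2 slice along axis 0 per start
    print(out.shape)                   # (3, 2, 2)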
@@ -235,7 +236,7 @@ def apply_attention(self, query: Array, key: Array | KVTensor, value: Array | KV
else:
raise ValueError(f"Unexpected attention kernel {self.attention_kernel=}.")


def ragged_attention(self, query: Array, key: Array | KVTensor, value: Array | KVTensor, lengths: Array, block_size: int) -> tuple[Array, Array, Array]:
"""Ragged Attention."""
if isinstance(query, KVTensor) or isinstance(query, KVTensor):
Expand Down Expand Up @@ -753,9 +754,9 @@ def update_ar_key_value(
ar_cache_batch_axis = ar_cache_axis_names.index(CACHE_BATCH)

if use_ragged_attention:
cache_locations = [slice(None)] * 4
cache_locations = [slice(None)] * 4
new_token_locations = [slice(None)] * 4
new_token_locations[ar_cache_sequence_axis] = 0
new_token_locations[ar_cache_sequence_axis] = 0

def key_body(i, val):
cache_locations[ar_cache_batch_axis] = i
@@ -772,7 +773,7 @@ def value_body(i, val):
cached_key_var.value = jax.lax.fori_loop(0, one_token_key_shaped_for_cache.shape[0], key_body, cached_key_var.value, unroll=8)
cached_value_var.value = jax.lax.fori_loop(0, one_token_value_shaped_for_cache.shape[0], value_body, cached_value_var.value, unroll=8)

else:
else:
one_hot_indices = one_hot_indices.astype(int)
cached_key_var.value = jax.lax.dynamic_update_index_in_dim(
cached_key_var.value, one_token_key_shaped_for_cache, ar_cache_update_idx, ar_cache_update_axis)
@@ -781,7 +782,7 @@ def value_body(i, val):

cached_key_var.value = nn.with_logical_constraint(cached_key_var.value, ar_cache_axis_names)
cached_value_var.value = nn.with_logical_constraint(cached_value_var.value, ar_cache_axis_names)


if self.kv_quant:
ar_cache_scale_axis_names = self.transpose_tuple(self.cache_scale_logical_axis_names, self.ar_cache_axis_order)
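For orientation on the cache-update code above: the non-ragged else-branch writes a single decode step into the preallocated autoregressive cache with jax.lax.dynamic_update_index_in_dim, while the ragged branch loops over the batch with jax.lax.fori_loop so each sequence can write at its own length. A minimal sketch of the non-ragged write, with invented shapes (a batch, sequence, heads, head_dim layout is assumed here only for illustration; in the real code the cache axis order is configurable):

    import jax
    import jax.numpy as jnp

    cache = jnp.zeros((2, 8, 4, 16))   # hypothetical AR cache: (batch, max_seq, heads, head_dim)
    new_kv = jnp.ones((2, 1, 4, 16))   # one new token's key (or value) per batch entry
    step = 3                           # current autoregressive position

    # Overwrite position `step` along the sequence axis (axis 1) with the new token,
    # analogous to the dynamic_update_index_in_dim call in the else-branch above.
    cache = jax.lax.dynamic_update_index_in_dim(cache, new_kv, step, 1)
    print(cache[0, step, 0, 0])        # 1.0 — that slot now holds the new value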
