Skip to content

Commit

Permalink
Merge pull request #859 from google:ragged_attn_update
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 669498336
  • Loading branch information
maxtext authors committed Aug 31, 2024
2 parents bd50865 + 3200a8f commit 30c0d0d
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions MaxText/layers/attentions.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -749,14 +749,25 @@ def update_ar_key_value(


ar_cache_update_idx = jnp.squeeze(one_hot_indices)
ar_cache_update_axis = ar_cache_axis_names.index(CACHE_SEQUENCE)
ar_cache_sequence_axis = ar_cache_update_axis = ar_cache_axis_names.index(CACHE_SEQUENCE)
ar_cache_batch_axis = ar_cache_axis_names.index(CACHE_BATCH)

if use_ragged_attention:
cache_locations = [slice(None)] * 4
new_token_locations = [slice(None)] * 4
new_token_locations[ar_cache_sequence_axis] = 0

def key_body(i, val):
return val.at[i, :, lengths[i], :].set(one_token_key_shaped_for_cache[i, :, 0, :])
cache_locations[ar_cache_batch_axis] = i
cache_locations[ar_cache_sequence_axis] = lengths[i]
new_token_locations[ar_cache_batch_axis] = i
return val.at[tuple(cache_locations)].set(one_token_key_shaped_for_cache[tuple(new_token_locations)])

def value_body(i, val):
return val.at[i, :, lengths[i], :].set(one_token_value_shaped_for_cache[i, :, 0, :])
cache_locations[ar_cache_batch_axis] = i
cache_locations[ar_cache_sequence_axis] = lengths[i]
new_token_locations[ar_cache_batch_axis] = i
return val.at[tuple(cache_locations)].set(one_token_value_shaped_for_cache[tuple(new_token_locations)])

cached_key_var.value = jax.lax.fori_loop(0, one_token_key_shaped_for_cache.shape[0], key_body, cached_key_var.value, unroll=8)
cached_value_var.value = jax.lax.fori_loop(0, one_token_value_shaped_for_cache.shape[0], value_body, cached_value_var.value, unroll=8)
Expand Down

0 comments on commit 30c0d0d

Please sign in to comment.