fix sharding on generate cache in prefill results.
jwyang-google committed May 9, 2024
1 parent d590328 commit f53203a
Showing 2 changed files with 24 additions and 0 deletions.
23 changes: 23 additions & 0 deletions MaxText/layers/attentions.py
@@ -401,6 +401,7 @@ def _get_prefill_cache(self, batch, heads, kv_head_size, quantize_kvcache):
"cache_kv",
)
cache_logical_shape = (batch, self.max_prefill_predict_length, heads, kv_head_size)

cached_key = self.variable(
"cache",
"cached_prefill_key",
@@ -457,20 +458,42 @@ def _get_ar_cache(self, batch, heads, kv_head_size, quantize_kvcache):
"cache_kv",
)
cache_logical_shape = (batch, cache_length, heads, kv_head_size)

# TODO(b/339703100): investigate why with_logical_partitioning doesn't enforce sharding
cached_key = self.variable(
"cache",
"cached_ar_key",
nn.with_logical_partitioning(jnp.zeros, kv_cache_layout),
self.cached_kv_shape(cache_logical_shape),
dtype,
)
cached_key.value = nn.with_logical_constraint(
cached_key.value,
(
"cache_sequence",
"cache_heads",
"cache_batch",
"cache_kv",
),
)

cached_value = self.variable(
"cache",
"cached_ar_value",
nn.with_logical_partitioning(jnp.zeros, kv_cache_layout),
self.cached_kv_shape(cache_logical_shape),
dtype,
)
cached_value.value = nn.with_logical_constraint(
cached_value.value,
(
"cache_sequence",
"cache_heads",
"cache_batch",
"cache_kv",
),
)

cached_segment_id = self.variable(
"cache",
"cache_ar_segment_id",
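For context, the lines added to `_get_ar_cache` pin the live autoregressive KV-cache arrays to the `("cache_sequence", "cache_heads", "cache_batch", "cache_kv")` layout at trace time, since (per the TODO above) `nn.with_logical_partitioning` alone only attaches partitioning metadata to the variable. Below is a minimal, self-contained sketch of the same pattern, not MaxText code: the mesh, the logical-axis rules, and the `TinyARCache` module are hypothetical stand-ins (MaxText derives the real rules from its config).

```python
# Sketch of the pattern added in _get_ar_cache: create the cache variable with
# logical partitioning metadata, then also constrain the live value so XLA
# shards it as intended.
import jax
import jax.numpy as jnp
import numpy as np
import flax.linen as nn
from jax.sharding import Mesh

# Hypothetical logical-to-physical rules; MaxText builds these from its config.
LOGICAL_AXIS_RULES = (
    ("cache_sequence", None),
    ("cache_heads", "tensor"),
    ("cache_batch", None),
    ("cache_kv", None),
)
KV_CACHE_LAYOUT = ("cache_sequence", "cache_heads", "cache_batch", "cache_kv")


class TinyARCache(nn.Module):
  """Toy stand-in for the attention module's autoregressive KV cache."""

  cache_length: int = 16
  heads: int = 8
  kv_head_size: int = 4

  @nn.compact
  def __call__(self, batch: int):
    cache_shape = (self.cache_length, self.heads, batch, self.kv_head_size)
    cached_key = self.variable(
        "cache",
        "cached_ar_key",
        nn.with_logical_partitioning(jnp.zeros, KV_CACHE_LAYOUT),
        cache_shape,
        jnp.float32,
    )
    # The kind of line this commit adds: constrain the value itself, not just
    # the variable's metadata, to the KV-cache layout.
    cached_key.value = nn.with_logical_constraint(cached_key.value, KV_CACHE_LAYOUT)
    return cached_key.value


mesh = Mesh(np.array(jax.devices()), ("tensor",))


@jax.jit
def init_cache(rng):
  return TinyARCache().init(rng, batch=2)


with mesh, nn.logical_axis_rules(LOGICAL_AXIS_RULES):
  variables = init_cache(jax.random.PRNGKey(0))
```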
1 change: 1 addition & 0 deletions MaxText/maxengine.py
@@ -171,6 +171,7 @@ def prefill(
flat_logits, (0, true_length - 1, 0), (flat_logits.shape[0], 1, flat_logits.shape[2])
)
selected_logits = jax.lax.with_sharding_constraint(selected_logits, self.replicated_sharding)

return {
"logits": selected_logits,
"cache": new_vars["cache"],
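The `prefill()` context shown above (only partially visible in this hunk) selects the logits at position `true_length - 1` and then applies `jax.lax.with_sharding_constraint` with a replicated sharding. A self-contained sketch of that pattern follows, with assumptions labeled: `jax.lax.dynamic_slice` is assumed for the selection whose arguments appear in the hunk, and `replicated_sharding` with a one-axis "data" mesh stands in for `self.replicated_sharding` in maxengine.py.

```python
# Sketch: slice out the logits of the last real token and pin them to a
# fully replicated layout, mirroring the constraint in prefill().
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ("data",))
# An empty PartitionSpec replicates the array across every mesh axis.
replicated_sharding = NamedSharding(mesh, PartitionSpec())


@jax.jit
def select_last_token_logits(flat_logits, true_length):
  # Keep only the logits at position true_length - 1; batch and vocab dims stay intact.
  selected = jax.lax.dynamic_slice(
      flat_logits,
      (0, true_length - 1, 0),
      (flat_logits.shape[0], 1, flat_logits.shape[2]),
  )
  # Same kind of constraint as in prefill(): replicate the selected logits.
  return jax.lax.with_sharding_constraint(selected, replicated_sharding)


logits = jnp.zeros((1, 128, 32000), dtype=jnp.bfloat16)
selected = select_last_token_logits(logits, jnp.int32(5))
```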
