From f53203aec5b20ed8d8db37bbaa49545bd8fe556f Mon Sep 17 00:00:00 2001
From: Junwei Yang
Date: Thu, 9 May 2024 22:39:12 +0000
Subject: [PATCH] fix sharding on generate cache in prefill results.

---
 MaxText/layers/attentions.py | 23 +++++++++++++++++++++++
 MaxText/maxengine.py         |  1 +
 2 files changed, 24 insertions(+)

diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py
index 227267bc7..61f116375 100644
--- a/MaxText/layers/attentions.py
+++ b/MaxText/layers/attentions.py
@@ -401,6 +401,7 @@ def _get_prefill_cache(self, batch, heads, kv_head_size, quantize_kvcache):
         "cache_kv",
     )
     cache_logical_shape = (batch, self.max_prefill_predict_length, heads, kv_head_size)
+
     cached_key = self.variable(
         "cache",
         "cached_prefill_key",
@@ -457,6 +458,8 @@ def _get_ar_cache(self, batch, heads, kv_head_size, quantize_kvcache):
         "cache_kv",
     )
     cache_logical_shape = (batch, cache_length, heads, kv_head_size)
+
+    # TODO(b/339703100): investigate the issue why with_logical_partitioning doesn't enforce sharding
     cached_key = self.variable(
         "cache",
         "cached_ar_key",
@@ -464,6 +467,16 @@
         self.cached_kv_shape(cache_logical_shape),
         dtype,
     )
+    cached_key.value = nn.with_logical_constraint(
+        cached_key.value,
+        (
+            "cache_sequence",
+            "cache_heads",
+            "cache_batch",
+            "cache_kv",
+        ),
+    )
+
     cached_value = self.variable(
         "cache",
         "cached_ar_value",
@@ -471,6 +484,16 @@
         self.cached_kv_shape(cache_logical_shape),
         dtype,
     )
+    cached_value.value = nn.with_logical_constraint(
+        cached_value.value,
+        (
+            "cache_sequence",
+            "cache_heads",
+            "cache_batch",
+            "cache_kv",
+        ),
+    )
+
     cached_segment_id = self.variable(
         "cache",
         "cache_ar_segment_id",
diff --git a/MaxText/maxengine.py b/MaxText/maxengine.py
index fd067267d..ff217a5a8 100644
--- a/MaxText/maxengine.py
+++ b/MaxText/maxengine.py
@@ -171,6 +171,7 @@ def prefill(
         flat_logits, (0, true_length - 1, 0), (flat_logits.shape[0], 1, flat_logits.shape[2])
     )
     selected_logits = jax.lax.with_sharding_constraint(selected_logits, self.replicated_sharding)
+
     return {
         "logits": selected_logits,
         "cache": new_vars["cache"],
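
For reference, below is a minimal, self-contained sketch of the Flax logical-sharding call the added lines rely on (nn.with_logical_constraint). The one-axis "data" mesh, the axis rules, the array shape, and the constrain helper are illustrative assumptions for this sketch only; MaxText builds its mesh and logical_axis_rules from its config, and the patch applies the constraint to the autoregressive KV-cache variables inside the attention module.

    import jax
    import jax.numpy as jnp
    import flax.linen as nn
    from flax.linen import partitioning as nn_partitioning
    from jax.sharding import Mesh
    from jax.experimental import mesh_utils

    # Illustrative one-axis mesh over all local devices; MaxText builds its mesh from config.
    mesh = Mesh(mesh_utils.create_device_mesh((len(jax.devices()),)), ("data",))

    # Illustrative logical-to-physical axis rules; MaxText reads these from its config
    # (logical_axis_rules). Names mapped to None stay unsharded.
    rules = (
        ("cache_batch", "data"),
        ("cache_sequence", None),
        ("cache_heads", None),
        ("cache_kv", None),
    )

    @jax.jit
    def constrain(x):  # hypothetical helper, not a MaxText function
      # Resolves the logical names to a PartitionSpec via the active rules and inserts a
      # sharding constraint for the compiler; names must follow the array's axis order.
      return nn.with_logical_constraint(
          x, ("cache_batch", "cache_sequence", "cache_heads", "cache_kv"))

    # Batch (8) should be divisible by the number of devices on the "data" axis.
    x = jnp.zeros((8, 32, 4, 16))  # (batch, sequence, heads, kv_head_size)
    with mesh, nn_partitioning.axis_rules(rules):
      y = constrain(x)
    print(y.sharding)

The explicit nn.with_logical_constraint calls are needed because, per the TODO added in the patch, nn.with_logical_partitioning on the variable initializer does not appear to enforce the intended sharding on the allocated autoregressive cache.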