diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml
index a8c954518..dfa03bc28 100644
--- a/MaxText/configs/base.yml
+++ b/MaxText/configs/base.yml
@@ -124,7 +124,7 @@ cast_logits_to_fp32: True # whether to cast the logits to fp32. The higher preci
 # mixture of experts (moe)
 num_experts: 1
 num_experts_per_tok: 1
-megablox: True
+megablox: False
 capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default
 load_balance_loss_weight: 0.01 # weight for the load balance loss
@@ -370,7 +370,11 @@ learning_rate_schedule_steps: -1 # By default the length of the schedule is set
 max_target_length: 2048 # Maximum sequence length
 max_prefill_predict_length: 64 # Maximum length for the prefill when doing autoregression
-prompt: "I love to" # Prompt for language model sampling.
+# prompt: "I love to"
+prompt: "The old lighthouse keeper, Silas, squinted at the churning grey sea. For fifty years, he'd watched the waves crash against the jagged cliffs of this remote island, his only company the rhythmic pulse of the lamp and the cries of gulls. But tonight, something was different. A strange, ethereal glow emanated from the depths, pulsing in sync with a low hum that vibrated through the stone tower. It wasn't the familiar bioluminescence of sea creatures; this was something... other.\n\nSilas, a man weathered by solitude and sea storms, felt a prickle of unease, a feeling he hadn't experienced since the day his wife, Elara, disappeared into the sea decades ago, lured by a siren's song, or so the villagers whispered. He dismissed it then as grief and folklore, but now, staring into the glowing abyss, he wasn't so sure.\n\nSuddenly, the lighthouse beam flickered and died, plunging the island into darkness. The humming intensified, and the glow from the sea grew brighter, illuminating a swirling vortex of iridescent colors. From the center of the vortex, a shape began to emerge - a vessel, unlike any he had ever seen. It was not of wood or steel, but of a shimmering, crystalline material that seemed to absorb and reflect the surrounding light, constantly shifting in hue and form.\n\nAs the vessel drew closer, Silas could make out figures moving within. They were humanoid, but with an ethereal grace that defied human movement. Their skin shimmered with the same iridescent light as the vessel, and their eyes glowed with an otherworldly luminescence.\n\nFear warred with curiosity in Silas's heart. He had lived a solitary life, but he had never turned away a soul in need. Yet, these beings were clearly not of this world. Could they be the source of the strange energy, the cause of the lighthouse's failure? Were they friend or foe?\n\nOne of the figures raised a hand, and a voice, melodic and ancient, echoed in Silas's mind, not through his ears, but directly into his consciousness. \"Greetings, Silas, keeper of the light. We come in peace, seeking knowledge and understanding. We are the Luminians, and we have traveled far to reach this place.\"\n\nSilas, despite his fear, felt compelled to respond. He had always been drawn to the mysteries of the sea, and now, the greatest mystery of all had come to his doorstep. He gripped the railing of the lighthouse balcony, his knuckles white, and prepared to face whatever the dawn might bring.\n\n**Consider the following in your narrative:**\n\n* **Silas's Backstory:** Expand on Silas's past, his relationship with Elara, and the impact of her disappearance. How does his past influence his reaction to the Luminians?
What drove him to become a lighthouse keeper?\n* **The Luminians:** Describe the Luminians in more detail. What is their culture like? What is their technology based on? What is their purpose in coming to Earth, and specifically to Silas's island? What do they mean by \"seeking knowledge and understanding\"?\n* **The Island's Secret:** The island is more than just a rock in the sea. It holds a secret, a connection to the Luminians, that Silas is about to discover. What is this secret, and how is it tied to the strange energy emanating from the sea?\n* **Conflict and Choice:** Silas must choose whether to trust the Luminians or to resist them. What factors influence his decision? What are the potential consequences of his choice?\n* **The Siren's Song:** Is there a connection between Elara's disappearance and the Luminians? Was the \"siren's song\" a real phenomenon, and if so, what is its nature?\n\n**Possible Story Arcs:**\n\n* **Revelation:** Silas learns the truth about Elara's disappearance and the history of the island.\n* **Alliance:** Silas chooses to help the Luminians, leading to a partnership that could change the fate of both their peoples.\n* **Betrayal:** The Luminians' motives are not what they seem, and Silas must fight to protect the island and perhaps even the world.\n* **Transformation:** Silas undergoes a personal transformation, shedding his old life and embracing a new destiny connected to the Luminians and the island's secret.\n\nThis prompt provides a foundation for a compelling story filled with mystery, suspense, and the potential for profound change. Explore the themes of isolation vs. connection, fear of the unknown vs. the allure of discovery, and the power of choice in the face of the extraordinary. Let the story unfold, revealing the secrets of the island, the truth about the Luminians, and the destiny of Silas, the old lighthouse keeper who found himself at the center of an interdimensional encounter. What will he do now that the light has gone out, only to be replaced by something far stranger and more profound? The answers lie within your narrative. Let the journey begin. Remember the sea holds secrets deeper than any man can fathom, and some are best left undisturbed. But destiny, it seems, has other plans for Silas. This is his story now. This is the next chapter. What will it be? This is where the past meets the future. This is where the ordinary meets the extraordinary. This is where one man's life changes forever. This is the moment everything changes. This is the turning point. This is just the beginning. The old lighthouse keeper, Silas, squinted at the churning grey sea. For fifty years, he'd watched the waves crash against the jagged cliffs of this remote island, his only company the rhythmic pulse of the lamp and the cries of gulls. But tonight, something was different. A strange, ethereal glow emanated from the depths, pulsing in sync with a low hum that vibrated through the stone tower. It wasn't the familiar bioluminescence of sea creatures; this was something... other.\n\nSilas, a man weathered by solitude and sea storms, felt a prickle of unease, a feeling he hadn't experienced since the day his wife, Elara, disappeared into the sea decades ago, lured by a siren's song, or so the villagers whispered. He dismissed it then as grief and folklore, but now, staring into the glowing abyss, he wasn't so sure.\n\nSuddenly, the lighthouse beam flickered and died, plunging the island into darkness. 
The humming intensified, and the glow from the sea grew brighter, illuminating a swirling vortex of iridescent colors. From the center of the vortex, a shape began to emerge - a vessel, unlike any he had ever seen. It was not of wood or steel, but of a shimmering, crystalline material that seemed to absorb and reflect the surrounding light, constantly shifting in hue and form.\n\nAs the vessel drew closer, Silas could make out figures moving within. They were humanoid, but with an ethereal grace that defied human movement. Their skin shimmered with the same iridescent light as the vessel, and their eyes glowed with an otherworldly luminescence.\n\nFear warred with curiosity in Silas's heart. He had lived a solitary life, but he had never turned away a soul in need. Yet, these beings were clearly not of this world. Could they be the source of the strange energy, the cause of the lighthouse's failure? Were they friend or foe?\n\nOne of the figures raised a hand, and a voice, melodic and ancient, echoed in Silas's mind, not through his ears, but directly into his consciousness. \"Greetings, Silas, keeper of the light. We come in peace, seeking knowledge and understanding. We are the Luminians, and we have traveled far to reach this place.\"\n\nSilas, despite his fear, felt compelled to respond. He had always been drawn to the mysteries of the sea, and now, the greatest mystery of all had come to his doorstep. He gripped the railing of the lighthouse balcony, his knuckles white, and prepared to face whatever the dawn might bring.\n\n**Consider the following in your narrative:**\n\n* **Silas's Backstory:** Expand on Silas's past, his relationship with Elara, and the impact of her disappearance. How does his past influence his reaction to the Luminians? What drove him to become a lighthouse keeper?\n* **The Luminians:** Describe the Luminians in more detail. What is their culture like? What is their technology based on? What is their purpose in coming to Earth, and specifically to Silas's island? What do they mean by \"seeking knowledge and understanding\"?\n* **The Island's Secret:** The island is more than just a rock in the sea. It holds a secret, a connection to the Luminians, that Silas is about to discover. What is this secret, and how is it tied to the strange energy emanating from the sea?\n* **Conflict and Choice:** Silas must choose whether to trust the Luminians or to resist them. What factors influence his decision? What are the potential consequences of his choice?\n* **The Siren's Song:** Is there a connection between Elara's disappearance and the Luminians? Was the \"siren's song\" a real phenomenon, and if so, what is its nature?\n\n**Possible Story Arcs:**\n\n* **Revelation:** Silas learns the truth about Elara's disappearance and the history of the island.\n* **Alliance:** Silas chooses to help the Luminians, leading to a partnership that could change the fate of both their peoples.\n* **Betrayal:** The Luminians' motives are not what they seem, and Silas must fight to protect the island and perhaps even the world.\n* **Transformation:** Silas undergoes a personal transformation, shedding his old life and embracing a new destiny connected to the Luminians and the island's secret.\n\nThis prompt provides a foundation for a compelling story filled with mystery, suspense, and the potential for profound change. Explore the themes of isolation vs. connection, fear of the unknown vs. the allure of discovery, and the power of choice in the face of the extraordinary. 
Let the story unfold, revealing the secrets of the island, the truth about the Luminians, and the destiny of Silas, the old lighthouse keeper who found himself at the center of an interdimensional encounter. What will he do now that the light has gone out, only to be replaced by something far stranger and more profound? The answers lie within your narrative. Let the journey begin. Remember the sea holds secrets deeper than any man can fathom, and some are best left undisturbed. But destiny, it seems, has other plans for Silas. This is his story now. This is the next chapter. What will it be? This is where the past meets the future. This is where the ordinary meets the extraordinary. This is where one man's life changes forever. This is the moment everything changes. This is the turning point. This is just the beginning." # Prompt for language model sampling.
+
+# prompt: ""
+
 load_from_prefill_dir: False # If true, decode.py doesn't "prefill" but just reads from directory
 prefill_cache_dir: "" # If set and load_from_prefill_dir, decode.py reads from directory. If set, decode.py writes to directory
 autoregressive_decode_assert: ""
@@ -533,3 +537,6 @@ sa_use_fused_bwd_kernel: False
 sa_q_layout: "HEAD_DIM_MINOR"
 sa_k_layout: "HEAD_DIM_MINOR"
 sa_v_layout: "HEAD_DIM_MINOR"
+
+chunk_size: 512
+use_chunked_prefill: True
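With these two new keys, a prompt that decode.py pads to max_prefill_predict_length tokens is prefilled in ceil(max_prefill_predict_length / chunk_size) passes. A quick sketch of that arithmetic (illustrative values only; the decode.py change below overrides the prompt length at runtime, while the base.yml default above is still 64):

    import math

    chunk_size = 512
    max_prefill_predict_length = 1024  # assumed runtime value, not the base.yml default
    num_chunks = math.ceil(max_prefill_predict_length / chunk_size)
    print(num_chunks)  # 2 prefill passes of 512 tokens each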
diff --git a/MaxText/decode.py b/MaxText/decode.py
index ef2f2fc79..176ec918f 100644
--- a/MaxText/decode.py
+++ b/MaxText/decode.py
@@ -24,6 +25,38 @@
 from typing import Sequence
 from absl import app
 
+from flax import struct
+import common_types
+import jax.numpy as jnp
+
+Array = common_types.Array
+
+
+@struct.dataclass
+class ChunkMetadata:
+  """Metadata describing one chunk of a longer prefill sequence."""
+  tokens_entire_sequence: Array
+  true_length: int  # true length of the full sequence
+  true_length_chunk: int  # number of real (unpadded) tokens in this chunk
+  chunk_padded: Array
+  processed: bool
+  chunk_seq_start_index: int  # offset of this chunk within the full sequence
+
+
+def create_chunked_metadata(tokens, true_length, chunk_size):
+  """Splits `tokens` into chunks of at most `chunk_size` tokens."""
+  start = 0
+  chunk_metadata_list = []
+  while start < true_length:
+    end = min(start + chunk_size, true_length)
+    chunk_metadata_list.append(
+        ChunkMetadata(
+            tokens_entire_sequence=tokens,
+            true_length=true_length,
+            true_length_chunk=end - start,
+            chunk_padded=tokens[start:end],
+            processed=False,
+            chunk_seq_start_index=start,
+        )
+    )
+    start += chunk_size
+  return chunk_metadata_list
+
 
 def main(argv: Sequence[str]) -> None:
@@ -44,29 +77,57 @@ def main(argv: Sequence[str]) -> None:
   metadata = engine.get_tokenizer()
   tokenizer_model = engine.build_tokenizer(metadata)
   tokens, true_length = tokenizer_model.encode(text, is_bos=True, prefill_lengths=[config.max_prefill_predict_length])
+  chunk_size = config.chunk_size
+  # Force the prompt to exactly max_prefill_predict_length tokens so it
+  # splits into an integral number of full chunks.
+  tokens = tokens[:config.max_prefill_predict_length]
+  true_length = config.max_prefill_predict_length
+  chunked_metadata_list = create_chunked_metadata(tokens, true_length, chunk_size)
   assert true_length <= config.max_prefill_predict_length, "can't take too many tokens"
   assert config.quantization != "fp8", "fp8 on NVIDIA GPUs is not supported in decode.py yet"
 
   # Split RNG before calling prefill
   rng, rng_prefill = jax.random.split(rng)
-  prefill_result, first_token = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length, rng=rng_prefill)
   slot = 0
 
+  # Prefill one chunk at a time; each call extends the kv cache produced by
+  # the previous call via `existing_prefix`.
+  prefill_result = None
+  for i, chunk_metadata in enumerate(chunked_metadata_list):
+    chunk_params = params if i == 0 else params | {"cache": prefill_result["cache"]}
+    prefill_result, first_token = engine.prefill(
+        existing_prefix=prefill_result,
+        params=chunk_params,
+        padded_tokens=chunk_metadata.chunk_padded,
+        true_length=chunk_size,
+        rng=rng_prefill,
+        position_mask_cur=None,
+    )
+
+  rng, rng_init_decode = jax.random.split(rng)
   decode_state = engine.init_decode_state(rng_init_decode)
   decode_state = engine.insert(prefill_result, decode_state, slot=slot)
 
   steps = range(config.max_prefill_predict_length, config.max_target_length)
   sampled_tokens_list = []
   sampled_tokens_list.append(first_token)
   for _ in steps:
     rng, rng_generate = jax.random.split(rng)
     decode_state, sampled_tokens = engine.generate(params, decode_state, rng=rng_generate)
     sampled_tokens_list.append(sampled_tokens)
 
   results = [sampled_tokens.get_result_at_slot(slot).tokens.item() for sampled_tokens in sampled_tokens_list]
   output = tokenizer_model.decode(results)
-  print(f"Input `{text}` -> `{output}`")
+  print(f"Output: `{output}`")
 
   if config.autoregressive_decode_assert != "":
     assert (
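As a sanity check on create_chunked_metadata above, a minimal usage sketch (plain lists stand in for real tokenizer output):

    tokens = list(range(1024))  # stand-in token ids
    chunks = create_chunked_metadata(tokens, true_length=1024, chunk_size=512)
    assert len(chunks) == 2
    assert chunks[0].chunk_seq_start_index == 0 and chunks[0].true_length_chunk == 512
    assert chunks[1].chunk_seq_start_index == 512 and chunks[1].true_length_chunk == 512
    # With true_length=1000 the final chunk would hold only 488 real tokens.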
diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py
index f5990e7d9..607e88a05 100644
--- a/MaxText/layers/attentions.py
+++ b/MaxText/layers/attentions.py
@@ -173,7 +173,7 @@ def check_attention_inputs(self, query: Array, key: Array | KVTensor, value: Arr
   # Following Pallas MHA Flash Attention Reference.
   # https://github.com/jax-ml/jax/blob/main/jax/experimental/pallas/ops/tpu/flash_attention.py
   # This mask models (1) separate sequences (decoder_segment_ids) and (2) causality
-  def generate_attention_mask(self, query, key, decoder_segment_ids: Array | None, model_mode: str) -> Array | None:
+  def generate_attention_mask(self, query, key, decoder_segment_ids: Array | None, model_mode: str, existing_prefix=None) -> Array | None:
     mask = None
     if model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE:
       mask = decoder_segment_ids[:, None, None, None, :] == common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR
@@ -190,11 +190,38 @@ def generate_attention_mask(self, query, key, decoder_segment_ids: Array | None,
       row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
       col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
       causal_mask = (col_ids <= row_ids)[None, None, None, :, :]
 
     output_mask = None
-    if (mask is not None) and (causal_mask is not None):
+    if self.config.use_chunked_prefill and model_mode == common_types.MODEL_MODE_PREFILL:
+      # Chunked prefill: the query covers only the current chunk while the key
+      # covers the whole prefill cache. Columns before `next_pos` (tokens from
+      # earlier chunks) are fully visible, the current chunk is causal, and
+      # columns past the end of the chunk (not yet written) stay masked.
+      _, q_seq_len, _, _ = query.shape
+      _, kv_seq_len, _, _ = key.shape
+      mask_shape = (q_seq_len, q_seq_len)
+      row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
+      col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
+      causal_mask = (col_ids <= row_ids).astype(jnp.int32)
+
+      next_pos = 0
+      if existing_prefix is not None:
+        next_pos = existing_prefix["next_pos"][0][0]
+      full_col_ids = jax.lax.broadcasted_iota(jnp.int32, (q_seq_len, kv_seq_len), 1)
+      output_mask = (full_col_ids < next_pos).astype(jnp.int32)
+      output_mask = jax.lax.dynamic_update_slice(output_mask, causal_mask, (0, next_pos))
+      output_mask = output_mask[None, None, None, :, :]
+      return jnp.where(output_mask, 0.0, DEFAULT_MASK_VALUE)
+    elif (mask is not None) and (causal_mask is not None):
       output_mask = jnp.logical_and(mask, causal_mask)
     elif mask is not None:
       output_mask = mask
     elif causal_mask is not None:
@@ -219,6 +246,7 @@ def apply_attention(
     lengths: Array | None,
     model_mode: str,
     use_ragged_attention: bool = False,
+    existing_prefix=None,
   ):
     self.check_attention_inputs(query, key, value)
     length = query.shape[-3]
@@ -232,7 +260,7 @@ def apply_attention(
       or (self.attention_kernel == "autoselected" and model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE)
       or (self.attention_kernel == "autoselected" and length < 128)
     ):
-      return self.apply_attention_dot(query, key, value, decoder_segment_ids, model_mode)
+      return self.apply_attention_dot(query, key, value, decoder_segment_ids, model_mode, existing_prefix)
     elif self.attention_kernel == "flash" or self.attention_kernel == "autoselected":
       if isinstance(key, KVTensor):
         key = key.dequant()
@@ -473,6 +501,7 @@ def apply_attention_dot(
     value: Array | KVTensor,
     decoder_segment_ids: Array | None,
     model_mode: str = common_types.MODEL_MODE_TRAIN,
+    existing_prefix=None,
   ):
     """Apply Attention."""
     validate_compute_axis_order(self.compute_axis_order)
@@ -485,6 +514,7 @@ def apply_attention_dot(
     q_seq_len = query.shape[1]
     attn_weights = self.qk_product(query, key, q_seq_len, model_mode)
+    # In chunked prefill the q and kv lengths differ, e.g. (1, 8, 4, 512, 1024)
+    # for a 512-token chunk attending over a 1024-slot prefill cache.
     if self.attn_logits_soft_cap:
       attn_weights = jnp.tanh(attn_weights / self.attn_logits_soft_cap)
@@ -493,7 +523,11 @@ def apply_attention_dot(
     # Casting softmaxt computation for float32 for model stability.
     if model_mode == common_types.MODEL_MODE_TRAIN and self.float32_logits:
       attn_weights = attn_weights.astype(jnp.float32)
-    attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode)
+    attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode, existing_prefix=existing_prefix)
     if attn_mask is not None:
       attn_weights = apply_mask_to_logits(attn_weights, attn_mask)
     return self.compute_local_attention(attn_weights, value, q_seq_len, model_mode)
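To make the intended chunked-prefill mask concrete, here is a small self-contained sketch (illustrative shapes only, not engine code) for a 4-token chunk written at next_pos=4 into an 8-slot cache:

    import jax
    import jax.numpy as jnp

    q_seq_len, kv_seq_len, next_pos = 4, 8, 4

    row_ids = jax.lax.broadcasted_iota(jnp.int32, (q_seq_len, q_seq_len), 0)
    col_ids = jax.lax.broadcasted_iota(jnp.int32, (q_seq_len, q_seq_len), 1)
    causal = (col_ids <= row_ids).astype(jnp.int32)

    full_cols = jax.lax.broadcasted_iota(jnp.int32, (q_seq_len, kv_seq_len), 1)
    mask = (full_cols < next_pos).astype(jnp.int32)  # earlier chunks fully visible
    mask = jax.lax.dynamic_update_slice(mask, causal, (0, next_pos))  # causal block for this chunk

    print(mask)
    # [[1 1 1 1 1 0 0 0]
    #  [1 1 1 1 1 1 0 0]
    #  [1 1 1 1 1 1 1 0]
    #  [1 1 1 1 1 1 1 1]]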
@@ -749,6 +783,7 @@ def kv_cache_prefill(
       key: Array,
       value: Array,
       decoder_segment_ids: Array,
+      existing_prefix=None,
   ):
     """In prefill mode, we zero out the existing cache, run the computation and
     prepare the cache as necessary.
@@ -783,14 +820,31 @@ def kv_cache_prefill(
       cached_prefill_key_vars[1].value = key_scale_shaped_for_cache
       cached_prefill_value_vars[1].value = value_scale_shaped_for_cache
 
-    cached_prefill_key_vars[0].value = key_shaped_for_cache
-    cached_prefill_value_vars[0].value = value_shaped_for_cache
-
-    if decoder_segment_ids is not None:
-      cached_prefill_segment_id_var.value = decoder_segment_ids
-
-    return key, value, decoder_segment_ids
-
+    # Write this chunk's keys/values into the prefill cache at next_pos.
+    # The cache layout is (s, n, b, d); get_cached_values returns (b, s, n, d).
+    next_pos = 0
+    if existing_prefix is not None:
+      next_pos = existing_prefix["next_pos"][0][0]
+      cached_key = self.get_cached_values(cached_prefill_key_vars, key.dtype, self.prefill_cache_axis_order)
+      cached_value = self.get_cached_values(cached_prefill_value_vars, value.dtype, self.prefill_cache_axis_order)
+      key_cache = jnp.transpose(cached_key, (1, 2, 0, 3))
+      value_cache = jnp.transpose(cached_value, (1, 2, 0, 3))
+    else:
+      key_cache = cached_prefill_key_vars[0].value
+      value_cache = cached_prefill_value_vars[0].value
+    cached_prefill_key_vars[0].value = jax.lax.dynamic_update_slice(key_cache, key_shaped_for_cache, (next_pos, 0, 0, 0))
+    cached_prefill_value_vars[0].value = jax.lax.dynamic_update_slice(value_cache, value_shaped_for_cache, (next_pos, 0, 0, 0))
+    if decoder_segment_ids is not None:
+      cached_prefill_segment_id_var.value = decoder_segment_ids
+    # Return the full (partially filled) cache in (b, s, n, d) order so that
+    # this chunk's attention can also see the earlier chunks.
+    return (
+        jnp.transpose(cached_prefill_key_vars[0].value, (2, 0, 1, 3)),
+        jnp.transpose(cached_prefill_value_vars[0].value, (2, 0, 1, 3)),
+        cached_prefill_segment_id_var.value,
+    )
+
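The cache update above is an in-place slice write along the sequence axis; a toy sketch of the same pattern (shapes are illustrative, not the engine's):

    import jax
    import jax.numpy as jnp

    seq, heads, batch, head_dim = 8, 2, 1, 4  # cache layout (s, n, b, d)
    chunk = 4
    cache = jnp.zeros((seq, heads, batch, head_dim))
    new_kv = jnp.ones((chunk, heads, batch, head_dim))  # this chunk's keys (or values)

    next_pos = 4  # offset carried in existing_prefix["next_pos"]
    cache = jax.lax.dynamic_update_slice(cache, new_kv, (next_pos, 0, 0, 0))

    assert bool((cache[:4] == 0).all()) and bool((cache[4:] == 1).all())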
@@ -966,7 +1029,7 @@ def kv_cache(
     return cached_prefill, cached_ar
 
   def kv_cache(
-      self, key: Array, value: Array, decoder_segment_ids: Array, model_mode: str, use_ragged_attention: bool = False
+      self, key: Array, value: Array, decoder_segment_ids: Array, model_mode: str, use_ragged_attention: bool = False, existing_prefix=None
   ) -> tuple:
     """KV cache takes the current state and updates the state accordingly.
@@ -991,7 +1054,7 @@ def kv_cache(
     if model_mode == common_types.MODEL_MODE_TRAIN:
       return (key, value, decoder_segment_ids), None
     elif model_mode == common_types.MODEL_MODE_PREFILL:
-      return self.kv_cache_prefill(key, value, decoder_segment_ids), None
+      return self.kv_cache_prefill(key, value, decoder_segment_ids, existing_prefix), None
     elif model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE:
       return self.kv_cache_autoregressive(key, value, use_ragged_attention)
     else:
@@ -1021,11 +1084,10 @@ def normalize_attention(self, local_outs, local_maxes, local_sums):
     return attn_out
 
   @nn.compact
-  def __call__(self, query, key, value, decoder_segment_ids, model_mode):
-    prefill_kv_cache, ar_kv_cache = self.kv_cache(
-        key, value, decoder_segment_ids, model_mode, use_ragged_attention=self.use_ragged_attention
-    )
+  def __call__(self, query, key, value, decoder_segment_ids, model_mode, existing_prefix=None):
+    prefill_kv_cache, ar_kv_cache = self.kv_cache(
+        key, value, decoder_segment_ids, model_mode, use_ragged_attention=self.use_ragged_attention, existing_prefix=existing_prefix
+    )
     prefill_unnormalized_output, prefill_exponentials_max, prefill_exponentials_sum = self.apply_attention(
         query=query,
         key=prefill_kv_cache[0],
@@ -1034,12 +1096,17 @@ def __call__(self, query, key, value, decoder_segment_ids, model_mode):
         lengths=None,
         model_mode=model_mode,
         use_ragged_attention=self.use_ragged_attention,
+        existing_prefix=existing_prefix,
     )
 
     # Return the "prefill" cache if it actually the combined prefill+ar kv cache
     if ar_kv_cache is None:
       if prefill_exponentials_sum is not None:
         return prefill_unnormalized_output / prefill_exponentials_sum
       return prefill_unnormalized_output
 
     ar_unnormalized_output, ar_exponentials_max, ar_exponentials_sum = self.apply_attention(
@@ -1050,6 +1117,7 @@ def __call__(self, query, key, value, decoder_segment_ids, model_mode):
         lengths=ar_kv_cache[3],
         model_mode=model_mode,
         use_ragged_attention=self.use_ragged_attention,
+        existing_prefix=existing_prefix,
     )
 
     if ar_unnormalized_output is not None:
@@ -1244,6 +1312,7 @@ def __call__(
       *,
       model_mode: str = common_types.MODEL_MODE_TRAIN,
       deterministic: bool = False,
+      existing_prefix=None,
   ):
     """Applies Attention on the input data.
 
@@ -1267,10 +1336,13 @@ def __call__(
     Returns:
       output of shape `[batch, length, q_features]`.
     """
     inputs_q = nn.with_logical_constraint(inputs_q, self.input_axis_names)
     inputs_kv = nn.with_logical_constraint(inputs_kv, self.input_axis_names)
 
     # apply projection.
     if self.config.fused_qkv:
       query, key, value = self.qkv_projection(inputs_q, proj_name="qkv_proj")
     else:
@@ -1320,7 +1392,7 @@ def __call__(
         ragged_block_size=self.ragged_block_size,
     )
 
-    out = attention_op(query, key, value, decoder_segment_ids, model_mode)
+    out = attention_op(query, key, value, decoder_segment_ids, model_mode, existing_prefix=existing_prefix)
 
     out = nn.with_logical_constraint(out, self.out_axis_names)
diff --git a/MaxText/layers/mistral.py b/MaxText/layers/mistral.py
index 5fbb9e1e1..1313bc45c 100644
--- a/MaxText/layers/mistral.py
+++ b/MaxText/layers/mistral.py
@@ -66,6 +66,7 @@ def __call__(
       decoder_positions,
       deterministic,
       model_mode,
+      existing_prefix=None,
   ):
     cfg = self.config
     mesh = self.mesh
@@ -100,7 +101,8 @@ def __call__(
         quant=self.quant,
         kv_quant=quantizations.configure_kv_quant(cfg),
     )
 
     attention_lnx = attention_layer(
         lnx,
         lnx,
         decoder_positions,
         decoder_segment_ids=decoder_segment_ids,
         deterministic=deterministic,
         model_mode=model_mode,
+        existing_prefix=existing_prefix,
     )
 
     attention_lnx = nn.with_logical_constraint(
diff --git a/MaxText/layers/models.py b/MaxText/layers/models.py
index 4c2046c1f..cb43aadf5 100644
--- a/MaxText/layers/models.py
+++ b/MaxText/layers/models.py
@@ -274,11 +274,11 @@ def __call__(
       decoder_segment_ids=None,
       deterministic=False,
       model_mode=common_types.MODEL_MODE_TRAIN,
+      existing_prefix=None,
   ):
     cfg = self.config
     mesh = self.mesh
     assert decoder_input_tokens.ndim == 2  # [batch, len]
 
     # [batch, length] -> [batch, length, emb_dim]
     y = self.shared_embedding(decoder_input_tokens.astype("int32"))
     y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(y, deterministic=deterministic)
@@ -399,6 +399,7 @@ def __call__(
           decoder_positions,
           deterministic,
           model_mode,
+          existing_prefix=existing_prefix,
       )
 
     y = self.get_norm_layer()(
@@ -472,6 +473,7 @@ def __call__(
       decoder_segment_ids=None,
       enable_dropout=True,
       model_mode=common_types.MODEL_MODE_TRAIN,
+      existing_prefix=None,
   ):
     """Applies Transformer decoder-branch on encoded-input and target."""
 
@@ -487,5 +489,6 @@ def __call__(
         decoder_segment_ids=decoder_segment_ids,
         deterministic=not enable_dropout,
         model_mode=model_mode,
+        existing_prefix=existing_prefix,
     )
     return logits
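The mistral.py and models.py changes are pure plumbing: an optional existing_prefix kwarg threaded from Transformer down to the attention op, defaulting to None so existing callers are unaffected. A minimal self-contained Flax sketch of the same pattern (toy modules, not the MaxText classes):

    import flax.linen as nn
    import jax
    import jax.numpy as jnp

    class Inner(nn.Module):
      @nn.compact
      def __call__(self, x, existing_prefix=None):
        # Real code would use existing_prefix["next_pos"] to offset cache writes.
        offset = 0 if existing_prefix is None else existing_prefix["next_pos"]
        return nn.Dense(4)(x) + offset

    class Outer(nn.Module):
      @nn.compact
      def __call__(self, x, existing_prefix=None):
        # Forward the kwarg unchanged; None preserves the old behavior.
        return Inner()(x, existing_prefix=existing_prefix)

    x = jnp.ones((1, 4))
    params = Outer().init(jax.random.PRNGKey(0), x)
    y = Outer().apply(params, x, existing_prefix={"next_pos": 2})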
diff --git a/MaxText/layers/test.py b/MaxText/layers/test.py
new file mode 100644
index 000000000..2b4089f87
--- /dev/null
+++ b/MaxText/layers/test.py
@@ -0,0 +1,135 @@
+# Scratch copies of the kv-cache helpers used while developing chunked prefill.
+# Not imported anywhere; kept for reference.
+
+
+def kv_cache_prefill_chunked(
+    self,
+    key: Array,
+    value: Array,
+    use_ragged_attention: bool = False,
+):
+  """In chunked prefill mode, we write the current chunk's key/value into
+  the prefill cache and return the full (partially filled) cache.
+
+  Args:
+    key: in shape [b, s_chunk, n, d].
+    value: in shape [b, s_chunk, n, d].
+
+  Returns:
+    tuple of (key, value, segment_id) for the prefill cache; the ar cache
+    is untouched during prefill, so None is returned in its place.
+  """
+  batch, _, heads, kv_head_size = key.shape
+  assert key.dtype == value.dtype, "Key and Value Dtypes should match."
+
+  cached_prefill_key_vars, cached_prefill_value_vars, cached_prefill_segment_id_var = self._get_prefill_cache_vars(
+      batch, heads, kv_head_size, common_types.MODEL_MODE_PREFILL
+  )
+
+  self.update_prefill_key_value(
+      key,
+      value,
+      cached_prefill_key_vars,
+      cached_prefill_value_vars,
+      cached_prefill_segment_id_var.value,
+  )
+
+  cached_prefill = (
+      self.get_cached_values(cached_prefill_key_vars, key.dtype, self.prefill_cache_axis_order),
+      self.get_cached_values(cached_prefill_value_vars, value.dtype, self.prefill_cache_axis_order),
+      cached_prefill_segment_id_var.value,
+  )
+  return cached_prefill, None
+ """ + batch, sequence, heads, kv_head_size = key.shape + if sequence != 1: + raise ValueError(f"Sequence length should be 1 during autoregression, got {sequence=}") + + cached_ar_key_vars, cached_ar_value_vars, cached_ar_segment_id_var, cache_ar_index_var, cache_ar_lengths_var = ( + self._get_ar_cache_vars(batch, heads, kv_head_size, common_types.MODEL_MODE_AUTOREGRESSIVE) + ) + + self.update_ar_key_value( + key, + value, + cached_ar_key_vars, + cached_ar_value_vars, + cache_ar_index_var.value, + cache_ar_lengths_var.value, + use_ragged_attention, + ) + active_indicator = jnp.zeros((batch, 1), dtype=jnp.int32) + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR + cached_ar_segment_id_var.value = jax.lax.dynamic_update_index_in_dim( + cached_ar_segment_id_var.value, active_indicator, jnp.squeeze(cache_ar_index_var.value), 1 + ) + cache_ar_index_var.value = jnp.mod( + cache_ar_index_var.value + 1, self.max_target_length - self.max_prefill_predict_length + ) + cache_ar_lengths_var.value = cache_ar_lengths_var.value.at[:].add(1) + + # The below retrieves the existing prefill cache variables, not creating new ones + cached_prefill_key_vars, cached_prefill_value_vars, cached_prefill_segment_id_var = self._get_prefill_cache_vars( + batch, heads, kv_head_size, common_types.MODEL_MODE_AUTOREGRESSIVE + ) + + cached_prefill = ( + self.get_cached_values(cached_prefill_key_vars, key.dtype, self.prefill_cache_axis_order), + self.get_cached_values(cached_prefill_value_vars, value.dtype, self.prefill_cache_axis_order), + cached_prefill_segment_id_var.value, + ) + + import pdb + pdb.set_trace() + + cached_ar = ( + self.get_cached_values(cached_ar_key_vars, key.dtype, self.ar_cache_axis_order), + self.get_cached_values(cached_ar_value_vars, value.dtype, self.ar_cache_axis_order), + cached_ar_segment_id_var.value, + cache_ar_lengths_var.value, + ) + return cached_prefill, cached_ar \ No newline at end of file diff --git a/MaxText/maxengine.py b/MaxText/maxengine.py index fca2fdb9e..03b3e0fba 100644 --- a/MaxText/maxengine.py +++ b/MaxText/maxengine.py @@ -226,6 +226,7 @@ def prefill( true_length: int, sampler: Optional[Callable[[Any], Any]] = None, # pylint: disable=unused-argument rng: Optional[jax.random.PRNGKey] = None, + position_mask_cur: Optional[jax.Array] =None, ) -> Tuple[Prefix, engine_api.ResultTokens]: """Computes a kv-cache for a new generate request. @@ -239,21 +240,29 @@ def prefill( Returns: kv_cache: For the resulting text. 
""" - if existing_prefix: - raise ValueError("We don't know what to do with existing_prefix") + # if existing_prefix: + # raise ValueError("We don't know what to do with existing_prefix") if rng is None: rng = jax.random.PRNGKey(0) + mul = 0 + to_add = jnp.zeros((1,512,32000)) + to_add_pos = 0 + if existing_prefix is not None: + mul = 512 + to_add = existing_prefix['flat_logits'] + to_add_pos = 512 + input_tokens = jnp.expand_dims(padded_tokens, 0) # [BATCH, SEQUENCE] - positions = jnp.expand_dims(jnp.arange(0, input_tokens.shape[1]), 0) + positions = jnp.expand_dims(jnp.arange(mul, mul+input_tokens.shape[1]), 0) - zero_to_n = jnp.arange(0, padded_tokens.shape[0]) - ones_to_keep = zero_to_n < true_length + zero_to_n = jnp.arange(0, 1024) + ones_to_keep = zero_to_n < 1024 one_d_output = ones_to_keep * common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR sequence_indicator = jnp.expand_dims(one_d_output, 0) - rng, new_rng = jax.random.split(rng) + # rng, new_rng = jax.random.split(rng) with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): flat_logits, new_vars = self.model.apply( params, @@ -262,12 +271,16 @@ def prefill( decoder_segment_ids=sequence_indicator, enable_dropout=False, model_mode=common_types.MODEL_MODE_PREFILL, - rngs={"params": new_rng}, + rngs={"params": rng}, mutable=["cache"], + existing_prefix=existing_prefix, ) + jax.debug.print("flat_logits {flat_logits} {shape} ", flat_logits=flat_logits, shape=flat_logits.shape) - next_pos = jnp.full((1, 1), true_length, dtype=jnp.int32) + # if existing_prefix is None: + next_pos = jnp.full((1, 1), true_length + mul, dtype=jnp.int32) generated_tokens = jnp.zeros((1, 1), dtype=jnp.int32) + selected_logits = jax.lax.dynamic_slice( flat_logits, (0, true_length - 1, 0), @@ -308,6 +321,7 @@ def prefill( "next_pos": next_pos, "generated_tokens": generated_tokens, "tokens": first_generated_token, + "flat_logits": flat_logits, }, result @functools.partial(jax.jit, static_argnums=(0,), donate_argnums=(2,))