Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wip-prototype]chunked prefill on 1k prompt len #1237

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ cast_logits_to_fp32: True # whether to cast the logits to fp32. The higher preci
# mixture of experts (moe)
num_experts: 1
num_experts_per_tok: 1
megablox: True
megablox: False
capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default
load_balance_loss_weight: 0.01 # weight for the load balance loss

Expand Down Expand Up @@ -370,7 +370,11 @@ learning_rate_schedule_steps: -1 # By default the length of the schedule is set

max_target_length: 2048 # Maximum sequence length
max_prefill_predict_length: 64 # Maximum length for the prefill when doing autoregression
prompt: "I love to" # Prompt for language model sampling.
# prompt: "I love to"
prompt: "The old lighthouse keeper, Silas, squinted at the churning grey sea. For fifty years, he'd watched the waves crash against the jagged cliffs of this remote island, his only company the rhythmic pulse of the lamp and the cries of gulls. But tonight, something was different. A strange, ethereal glow emanated from the depths, pulsing in sync with a low hum that vibrated through the stone tower. It wasn't the familiar bioluminescence of sea creatures; this was something... other.\n\nSilas, a man weathered by solitude and sea storms, felt a prickle of unease, a feeling he hadn't experienced since the day his wife, Elara, disappeared into the sea decades ago, lured by a siren's song, or so the villagers whispered. He dismissed it then as grief and folklore, but now, staring into the glowing abyss, he wasn't so sure.\n\nSuddenly, the lighthouse beam flickered and died, plunging the island into darkness. The humming intensified, and the glow from the sea grew brighter, illuminating a swirling vortex of iridescent colors. From the center of the vortex, a shape began to emerge - a vessel, unlike any he had ever seen. It was not of wood or steel, but of a shimmering, crystalline material that seemed to absorb and reflect the surrounding light, constantly shifting in hue and form.\n\nAs the vessel drew closer, Silas could make out figures moving within. They were humanoid, but with an ethereal grace that defied human movement. Their skin shimmered with the same iridescent light as the vessel, and their eyes glowed with an otherworldly luminescence.\n\nFear warred with curiosity in Silas's heart. He had lived a solitary life, but he had never turned away a soul in need. Yet, these beings were clearly not of this world. Could they be the source of the strange energy, the cause of the lighthouse's failure? Were they friend or foe?\n\nOne of the figures raised a hand, and a voice, melodic and ancient, echoed in Silas's mind, not through his ears, but directly into his consciousness. \"Greetings, Silas, keeper of the light. We come in peace, seeking knowledge and understanding. We are the Luminians, and we have traveled far to reach this place.\"\n\nSilas, despite his fear, felt compelled to respond. He had always been drawn to the mysteries of the sea, and now, the greatest mystery of all had come to his doorstep. He gripped the railing of the lighthouse balcony, his knuckles white, and prepared to face whatever the dawn might bring.\n\n**Consider the following in your narrative:**\n\n* **Silas's Backstory:** Expand on Silas's past, his relationship with Elara, and the impact of her disappearance. How does his past influence his reaction to the Luminians? What drove him to become a lighthouse keeper?\n* **The Luminians:** Describe the Luminians in more detail. What is their culture like? What is their technology based on? What is their purpose in coming to Earth, and specifically to Silas's island? What do they mean by \"seeking knowledge and understanding\"?\n* **The Island's Secret:** The island is more than just a rock in the sea. It holds a secret, a connection to the Luminians, that Silas is about to discover. What is this secret, and how is it tied to the strange energy emanating from the sea?\n* **Conflict and Choice:** Silas must choose whether to trust the Luminians or to resist them. What factors influence his decision? What are the potential consequences of his choice?\n* **The Siren's Song:** Is there a connection between Elara's disappearance and the Luminians? Was the \"siren's song\" a real phenomenon, and if so, what is its nature?\n\n**Possible Story Arcs:**\n\n* **Revelation:** Silas learns the truth about Elara's disappearance and the history of the island.\n* **Alliance:** Silas chooses to help the Luminians, leading to a partnership that could change the fate of both their peoples.\n* **Betrayal:** The Luminians' motives are not what they seem, and Silas must fight to protect the island and perhaps even the world.\n* **Transformation:** Silas undergoes a personal transformation, shedding his old life and embracing a new destiny connected to the Luminians and the island's secret.\n\nThis prompt provides a foundation for a compelling story filled with mystery, suspense, and the potential for profound change. Explore the themes of isolation vs. connection, fear of the unknown vs. the allure of discovery, and the power of choice in the face of the extraordinary. Let the story unfold, revealing the secrets of the island, the truth about the Luminians, and the destiny of Silas, the old lighthouse keeper who found himself at the center of an interdimensional encounter. What will he do now that the light has gone out, only to be replaced by something far stranger and more profound? The answers lie within your narrative. Let the journey begin. Remember the sea holds secrets deeper than any man can fathom, and some are best left undisturbed. But destiny, it seems, has other plans for Silas. This is his story now. This is the next chapter. What will it be? This is where the past meets the future. This is where the ordinary meets the extraordinary. This is where one man's life changes forever. This is the moment everything changes. This is the turning point. This is just the beginning. The old lighthouse keeper, Silas, squinted at the churning grey sea. For fifty years, he'd watched the waves crash against the jagged cliffs of this remote island, his only company the rhythmic pulse of the lamp and the cries of gulls. But tonight, something was different. A strange, ethereal glow emanated from the depths, pulsing in sync with a low hum that vibrated through the stone tower. It wasn't the familiar bioluminescence of sea creatures; this was something... other.\n\nSilas, a man weathered by solitude and sea storms, felt a prickle of unease, a feeling he hadn't experienced since the day his wife, Elara, disappeared into the sea decades ago, lured by a siren's song, or so the villagers whispered. He dismissed it then as grief and folklore, but now, staring into the glowing abyss, he wasn't so sure.\n\nSuddenly, the lighthouse beam flickered and died, plunging the island into darkness. The humming intensified, and the glow from the sea grew brighter, illuminating a swirling vortex of iridescent colors. From the center of the vortex, a shape began to emerge - a vessel, unlike any he had ever seen. It was not of wood or steel, but of a shimmering, crystalline material that seemed to absorb and reflect the surrounding light, constantly shifting in hue and form.\n\nAs the vessel drew closer, Silas could make out figures moving within. They were humanoid, but with an ethereal grace that defied human movement. Their skin shimmered with the same iridescent light as the vessel, and their eyes glowed with an otherworldly luminescence.\n\nFear warred with curiosity in Silas's heart. He had lived a solitary life, but he had never turned away a soul in need. Yet, these beings were clearly not of this world. Could they be the source of the strange energy, the cause of the lighthouse's failure? Were they friend or foe?\n\nOne of the figures raised a hand, and a voice, melodic and ancient, echoed in Silas's mind, not through his ears, but directly into his consciousness. \"Greetings, Silas, keeper of the light. We come in peace, seeking knowledge and understanding. We are the Luminians, and we have traveled far to reach this place.\"\n\nSilas, despite his fear, felt compelled to respond. He had always been drawn to the mysteries of the sea, and now, the greatest mystery of all had come to his doorstep. He gripped the railing of the lighthouse balcony, his knuckles white, and prepared to face whatever the dawn might bring.\n\n**Consider the following in your narrative:**\n\n* **Silas's Backstory:** Expand on Silas's past, his relationship with Elara, and the impact of her disappearance. How does his past influence his reaction to the Luminians? What drove him to become a lighthouse keeper?\n* **The Luminians:** Describe the Luminians in more detail. What is their culture like? What is their technology based on? What is their purpose in coming to Earth, and specifically to Silas's island? What do they mean by \"seeking knowledge and understanding\"?\n* **The Island's Secret:** The island is more than just a rock in the sea. It holds a secret, a connection to the Luminians, that Silas is about to discover. What is this secret, and how is it tied to the strange energy emanating from the sea?\n* **Conflict and Choice:** Silas must choose whether to trust the Luminians or to resist them. What factors influence his decision? What are the potential consequences of his choice?\n* **The Siren's Song:** Is there a connection between Elara's disappearance and the Luminians? Was the \"siren's song\" a real phenomenon, and if so, what is its nature?\n\n**Possible Story Arcs:**\n\n* **Revelation:** Silas learns the truth about Elara's disappearance and the history of the island.\n* **Alliance:** Silas chooses to help the Luminians, leading to a partnership that could change the fate of both their peoples.\n* **Betrayal:** The Luminians' motives are not what they seem, and Silas must fight to protect the island and perhaps even the world.\n* **Transformation:** Silas undergoes a personal transformation, shedding his old life and embracing a new destiny connected to the Luminians and the island's secret.\n\nThis prompt provides a foundation for a compelling story filled with mystery, suspense, and the potential for profound change. Explore the themes of isolation vs. connection, fear of the unknown vs. the allure of discovery, and the power of choice in the face of the extraordinary. Let the story unfold, revealing the secrets of the island, the truth about the Luminians, and the destiny of Silas, the old lighthouse keeper who found himself at the center of an interdimensional encounter. What will he do now that the light has gone out, only to be replaced by something far stranger and more profound? The answers lie within your narrative. Let the journey begin. Remember the sea holds secrets deeper than any man can fathom, and some are best left undisturbed. But destiny, it seems, has other plans for Silas. This is his story now. This is the next chapter. What will it be? This is where the past meets the future. This is where the ordinary meets the extraordinary. This is where one man's life changes forever. This is the moment everything changes. This is the turning point. This is just the beginning." # Prompt for language model sampling.

# prompt: ""

load_from_prefill_dir: False # If true, decode.py doesn't "prefill" but just reads from directory
prefill_cache_dir: "" # If set and load_from_prefill_dir, decode.py reads from directory. If set, decode.py writes to directory
autoregressive_decode_assert: ""
Expand Down Expand Up @@ -533,3 +537,6 @@ sa_use_fused_bwd_kernel: False
sa_q_layout: "HEAD_DIM_MINOR"
sa_k_layout: "HEAD_DIM_MINOR"
sa_v_layout: "HEAD_DIM_MINOR"

chunk_size: 512
use_chunked_prefill: True
69 changes: 65 additions & 4 deletions MaxText/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""CLI utility for running inference on a single stream"""

from collections import defaultdict
import jax

import max_utils
Expand All @@ -24,6 +25,38 @@

from typing import Sequence
from absl import app
from flax import struct
import common_types
Array = common_types.Array
import jax.numpy as jnp

@struct.dataclass
class ChunkMetadata:
tokens_entire_sequence: Array
true_length: int
true_length_chunk: int
chunk_padded: Array
processed: bool
chunk_seq_start_index: int


def create_chunked_metadata(tokens, true_length, chunk_size):
start = 0
chunk_metadata_list = []

while start < len(tokens):
end = min(start + chunk_size, true_length)
cur_chunk_tokens = tokens[start:end]

chunk_metadata_list.append(ChunkMetadata(tokens_entire_sequence=tokens,
true_length=true_length,
true_length_chunk=chunk_size,
chunk_padded=cur_chunk_tokens,
processed=False,
chunk_seq_start_index=start))

start = start + chunk_size
return chunk_metadata_list


def main(argv: Sequence[str]) -> None:
Expand All @@ -44,29 +77,57 @@ def main(argv: Sequence[str]) -> None:
metadata = engine.get_tokenizer()
tokenizer_model = engine.build_tokenizer(metadata)
tokens, true_length = tokenizer_model.encode(text, is_bos=True, prefill_lengths=[config.max_prefill_predict_length])
chunk_size = config.chunk_size
tokens = tokens[:config.max_prefill_predict_length]
print("tokens are ", tokens, len(tokens))
true_length = config.max_prefill_predict_length
chunked_metadata_list = create_chunked_metadata(tokens, true_length, chunk_size)
assert true_length <= config.max_prefill_predict_length, "can't take too many tokens"
assert config.quantization != "fp8", "fp8 on NVIDIA GPUs is not supported in decode.py yet"

# Split RNG before calling prefill
rng, rng_prefill = jax.random.split(rng)
prefill_result, first_token = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length, rng=rng_prefill)
slot = 0
rng, rng_init_decode = jax.random.split(rng)

prefill_result = None
for i,chunk_metadata in enumerate(chunked_metadata_list):
if i == 0:
prefill_result, first_token = engine.prefill(existing_prefix=prefill_result,
params=params,
padded_tokens=chunk_metadata.chunk_padded,
true_length=chunk_size,
rng=rng,
position_mask_cur=None)
else:
prefill_result, first_token = engine.prefill(existing_prefix=prefill_result,
params=params | {"cache": prefill_result["cache"]},
padded_tokens=chunk_metadata.chunk_padded,
true_length=chunk_size,
rng=rng,
position_mask_cur=None)

import pdb
pdb.set_trace()

rng, rng_init_decode = jax.random.split(rng)
decode_state = engine.init_decode_state(rng_init_decode)
decode_state = engine.insert(prefill_result, decode_state, slot=slot)

steps = range(config.max_prefill_predict_length, config.max_target_length)
print("total steps ", steps)
sampled_tokens_list = []
sampled_tokens_list.append(first_token)
for _ in steps:
for s in steps:
rng, rng_generate = jax.random.split(rng)

decode_state, sampled_tokens = engine.generate(params, decode_state, rng=rng_generate)
sampled_tokens_list.append(sampled_tokens)

results = [sampled_tokens.get_result_at_slot(slot).tokens.item() for sampled_tokens in sampled_tokens_list]
print("len results ", len(results))
output = tokenizer_model.decode(results)
print(f"Input `{text}` -> `{output}`")
print(f"Output ************************** -> `{output}`")

if config.autoregressive_decode_assert != "":
assert (
Expand Down
Loading