Skip to content

Commit

Permalink
Merge pull request #756 from google:decode_prefill_fix
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 650409674
  • Loading branch information
maxtext authors committed Jul 8, 2024
2 parents b83a7a4 + 85c105e commit 1b4cd15
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion MaxText/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,15 @@ def main(config):
)
assert true_length <= config.max_prefill_predict_length, "can't take too many tokens"
assert config.quantization != "fp8", "fp8 on NVIDIA GPUs is not supported in decode.py yet"
prefill_result, _ = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
prefill_result, first_token = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
slot = 0

decode_state = engine.init_decode_state()
decode_state = engine.insert(prefill_result, decode_state, slot=slot)

steps = range(config.max_prefill_predict_length, config.max_target_length)
sampled_tokens_list = []
sampled_tokens_list.append(first_token)
for _ in steps:
decode_state, sampled_tokens = engine.generate(params, decode_state)
sampled_tokens_list.append(sampled_tokens)
Expand Down

0 comments on commit 1b4cd15

Please sign in to comment.