optimizations for offline mlperf inference #1017

Merged: 1 commit, Nov 8, 2024
Changes from all commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 92 additions & 34 deletions MaxText/inference_mlperf/offline_inference.py
@@ -18,6 +18,12 @@
import jax
from jax import numpy as jnp
import numpy as np
import queue
import os
import functools
import threading
import traceback
import signal

from jetstream.engine import engine_api

@@ -35,9 +41,21 @@ class InputData:
true_length: int


class JetThread(threading.Thread):

def run(self):
try:
super().run()
except Exception as e: # pylint: disable=broad-exception-caught
print(f"Thread {self.name} encountered an error: {e}")
traceback.print_exc()
os.kill(os.getpid(), signal.SIGKILL)
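
Note: JetThread turns any uncaught exception in a worker thread into a hard process kill. Without it, a crashed detokenize thread would leave the main loop blocked forever on the bounded queue; with it, the benchmark fails loudly instead of hanging. A minimal sketch of the behavior (the bad_worker target is hypothetical):

def bad_worker():
  raise RuntimeError("detokenizer died")  # hypothetical failure

t = JetThread(target=bad_worker, name="example")
t.start()  # prints the traceback, then SIGKILLs the whole process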


class OfflineInference:

def __init__(self, engine: engine_api.Engine, params, base_engine: engine_api.Engine):
self.live = False
self.engine = engine
self.decode_state = None
if params is None:
@@ -55,6 +73,7 @@ def __init__(self, engine: engine_api.Engine, params, base_engine: engine_api.Engine):

self._cached_pref = {}
self._cached_generate = None
self.detokenize_backlog = queue.Queue(10)
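
Note: the backlog is deliberately small. Every put(..., block=True) on the decode side blocks once ten results are pending, so device outputs cannot pile up faster than the detokenize thread drains them. A toy illustration of that backpressure, independent of the engine:

import queue

backlog = queue.Queue(2)  # tiny capacity for the example
backlog.put("step-0", block=True)
backlog.put("step-1", block=True)
# A third blocking put() would now stall this thread until a consumer
# calls backlog.get() -- exactly how the decode loop is throttled.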

def init_decode_state(self):
if self.decode_state is None:
@@ -75,8 +94,17 @@ def warmup(self, max_length, warmup_samples):
for length in interesting_buckets:
if length > max_length:
break

log.info(f"Compiling prefill: {length}")
input_data = jax.ShapeDtypeStruct((length,), jnp.dtype("int32"))
self._cached_pref[length] = (
jax.jit(self._prefill_insert, donate_argnums=(4,))
.lower(self.params, tokens=input_data, slot=0, true_length=length - 1, decode_state=self.decode_state)
.compile()
)
self.batch_inference(warmup_samples, desc="warmup")
self._cached_generate = (
jax.jit(self.engine.generate, donate_argnums=(1,)).lower(self.params, self.decode_state).compile()
)
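
Note: warmup now uses JAX's ahead-of-time compilation path, jax.jit(...).lower(...).compile(), once per prefill bucket length and once for generate, so no measured request pays XLA compile time. A self-contained sketch of the same pattern on a toy function (the names below are illustrative, not MaxText APIs):

import jax
from jax import numpy as jnp

def toy_prefill(tokens):
  return jnp.cumsum(tokens)  # stand-in for the real prefill computation

spec = jax.ShapeDtypeStruct((128,), jnp.dtype("int32"))
compiled = jax.jit(toy_prefill).lower(spec).compile()  # compile cost paid here
out = compiled(jnp.zeros((128,), jnp.int32))           # runs the cached executable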

def _prefill_insert(self, params, tokens, slot, true_length, decode_state):
"""return decodestate."""
@@ -99,7 +127,7 @@ def batch_inference_with_callback(
def prefill(slot, tokens, true_length):
nonlocal self
if self.dummy:
        log.info("dummy prefill")
return 123

prefill_fn = self._prefill_insert
@@ -109,7 +137,7 @@ def prefill(slot, tokens, true_length):
first_token, self.decode_state = prefill_fn(
self.params, tokens=tokens, slot=slot, true_length=true_length, decode_state=self.decode_state
)
      return first_token
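
Note: prefill now returns the raw first_token result instead of calling .item() inline. .item() forces a blocking device-to-host transfer, so deferring it to the detokenize thread keeps the prefill/decode hot path free of host round-trips. A two-line illustration of the sync being avoided:

from jax import numpy as jnp

x = jnp.arange(8) * 2  # dispatched asynchronously; returns immediately
v = x[0].item()        # blocks until the device result reaches the host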

empty_slots = list(range(self.batch_size))
slot_to_id = {}
@@ -119,12 +147,10 @@ def prefill(slot, tokens, true_length):
dummy_length = 1

def decode():
nonlocal self
nonlocal dummy_length
if self.dummy:
        log.info("Dummy generate")
res = engine_api.ResultTokens(
data=np.array([[123, 1, dummy_length]] * self.batch_size),
tokens_idx=(0, 0),
@@ -138,51 +164,80 @@ def decode():
gen_fn = self.engine.generate
if self._cached_generate is not None:
gen_fn = self._cached_generate
result_tokens_l = []
for i in range(5):
self.decode_state, result_tokens = gen_fn(self.params, self.decode_state)
result_tokens_l.append(result_tokens)
for i in range(5):
result_tokens = result_tokens_l[i].convert_to_numpy()
self.detokenize_backlog.put((result_tokens, False, 0, 0), block=True)
# log.info(f"Decode put result {i} to queue")
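
Note: decode dispatches five generate steps before converting any result to numpy. JAX dispatch is asynchronous, so the device keeps stepping while the host converts and enqueues earlier results. A sketch of the idea under that assumption (the factor of five mirrors the loop above but is otherwise a tuning choice):

import jax
import numpy as np
from jax import numpy as jnp

@jax.jit
def step(x):
  return x * 2 + 1  # stand-in for one generate step

x = jnp.zeros((8,))
device_results = []
for _ in range(5):
  x = step(x)
  device_results.append(x)  # no host sync; dispatch runs ahead of execution
host_results = [np.asarray(r) for r in device_results]  # transfers happen here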

def detokenize():
nonlocal self
nonlocal slot_to_id
nonlocal empty_slots
while self.live:
# log.info("Detokenize start")
newly_empty = []
result_tokens, is_first_token, row_id, _slot = self.detokenize_backlog.get(block=True)
# log.info("Detokenize get from queue")
if is_first_token:
first_token = result_tokens.data[0][0].item()
should_terminate = emit_first_token(row_id, first_token)
if not should_terminate:
slot_to_id[_slot] = row_id
else:
empty_slots.append(_slot)
continue
for slot, id_ in slot_to_id.items():
token, is_valid, length = result_tokens.data[slot]
log.debug(f"slot is {slot}, length is {length}")
should_finish = False
if is_valid:
should_finish = emit_token(id_, token.item())
if should_finish or length >= self.max_decode_length:
newly_empty.append(slot)
log.info(f"Detokenize free up {slot}, length {length}")
# Add slots of those that are empty to empty
for slot in newly_empty:
del slot_to_id[slot]
empty_slots.append(slot)
if newly_empty and self.detokenize_backlog.qsize() == 0 and len(slot_to_id.items()) == 0:
break
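
Note: structurally this is a single-producer/single-consumer pipeline: decode enqueues device results, only detokenize frees slots, and queue.Queue provides all the locking. The generic shape, with a sentinel standing in for this file's self.live flag and drain check:

import queue
import threading

backlog = queue.Queue(10)
STOP = object()  # sentinel; the real code uses a live flag plus a drain check

def consumer():
  while True:
    item = backlog.get(block=True)
    if item is STOP:
      break
    # ... detokenize item, emit tokens, free slots ...

t = threading.Thread(target=consumer, name="detokenize")
t.start()
for step_result in range(20):  # stand-in for the decode loop
  backlog.put(step_result, block=True)
backlog.put(STOP)
t.join()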

detokenize_thread = JetThread(
target=functools.partial(
detokenize,
),
name="detokenize",
)
self.live = True
detokenize_thread.start()
for row in data:
log.debug(f"empty_slots {len(empty_slots)}")
while not empty_slots:
# If slots are all full, decode until there are free slots
# to insert
num_decodes += 1
        log.info(f"decode-{desc}-{num_decodes}")
decode()
# do one insert
num_tokens = len(row.tokens)
num_prefills[num_tokens] = 0 if num_tokens not in num_prefills else num_prefills[num_tokens] + 1
      log.info(
          f"prefill-{desc}-{num_prefills} num_prefills {sum(num_prefills.values())} num_tokens {num_tokens} true_length {row.true_length} num_empty_slots {len(empty_slots)} num_decodes {num_decodes}"
)
slot = empty_slots.pop()
first_token = prefill(slot, row.tokens, row.true_length)
self.detokenize_backlog.put((first_token, True, row.id, slot), block=True)

while slot_to_id:
      log.info(f"decode-{desc}-{num_decodes} num_filled_slots {len(slot_to_id)}")
num_decodes += 1
decode()

self.live = False
detokenize_thread.join()
log.info(f"summary-{desc}-prefills-{num_prefills}-decodes-{num_decodes} completed.")

def batch_inference(self, data: List[InputData], desc=""):
@@ -191,7 +246,10 @@ def batch_inference(self, data: List[InputData], desc=""):

def callback(id_, token):
nonlocal res
      if token == self.tokenizer.eos_id:
        log.info(f"res[{id_}] eos")
      if not res[id_] or res[id_][-1] != self.tokenizer.eos_id:
        res[id_].append(token)
return token == self.tokenizer.eos_id

self.batch_inference_with_callback(data, emit_first_token=callback, emit_token=callback, desc=desc)
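
Note: the callback now records at most one trailing EOS per sequence. A finished slot can keep emitting EOS until the detokenize thread actually frees it, and the guard keeps those repeats out of the returned token lists. A tiny worked example with a stand-in eos id:

eos_id = 2  # stand-in; the real id comes from self.tokenizer.eos_id
res = {0: []}

def callback(id_, token):
  if not res[id_] or res[id_][-1] != eos_id:
    res[id_].append(token)
  return token == eos_id  # True tells the caller to free the slot

for tok in (5, 7, 2, 2, 2):  # decode keeps emitting EOS for a finished slot
  callback(0, tok)

assert res[0] == [5, 7, 2]  # exactly one terminal EOS is kept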