Skip to content

Commit ad7c017

Browse files
committed
hacky way to test aot in jetstream
1 parent 9cb7785 commit ad7c017

File tree

1 file changed

+107
-11
lines changed

1 file changed

+107
-11
lines changed

jetstream/core/orchestrator.py

Lines changed: 107 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,11 @@
102102
from jetstream.core.metrics.prometheus import JetstreamMetricsCollector
103103
import numpy as np
104104

105-
log_level = os.getenv("LOG_LEVEL", "WARNING").upper()
105+
from jax.experimental import layout as jax_layout
106+
DLL = jax_layout.DeviceLocalLayout
107+
Layout = jax_layout.Layout
108+
109+
log_level = os.getenv("LOG_LEVEL", "DEBUG").upper()
106110

107111
logger = logging.getLogger("JetstreamLogger")
108112
logger.propagate = False
@@ -405,6 +409,29 @@ def __init__(
405409

406410
self._jax_padding = jax_padding
407411

412+
##### Auto layout compile for interleaved engine
413+
self._generate_executables = [None for _ in self._generate_engines]
414+
self._cached_insert = [None for _ in self._generate_engines]
415+
self._cached_prefill = [None for _ in self._prefill_engines]
416+
self._decode_states = [None for _ in self._generate_engines]
417+
if self._interleaved_mode:
418+
for idx in range(len(self._generate_engines)):
419+
logger.debug("Compiling interleaved engine {}".format(idx))
420+
engine = self._generate_engines[idx]
421+
params = self._generate_params[idx]
422+
engine, params, gen_fn, prefill_fn, insert_fn, decode_state = self._auto_layout_compile(engine, params)
423+
424+
self._prefill_engines[idx] = engine
425+
self._generate_engines[idx] = engine
426+
self._prefill_params[idx] = params
427+
self._generate_params[idx] = params
428+
self._cached_prefill[idx] = prefill_fn
429+
self._cached_insert[idx] = insert_fn
430+
self._generate_executables[idx] = gen_fn
431+
432+
self._decode_states[idx] = decode_state
433+
434+
408435
# Create all threads
409436
self._prefill_threads = [
410437
JetThread(
@@ -670,6 +697,56 @@ def _do_chunked_prefill(
670697

671698
return prefill_result, first_token
672699

700+
def _auto_layout_compile(self, engine, params):
701+
logger.debug("Compiling generate function")
702+
generate_executable, params, decode_state_executable = engine.aot_compile(
703+
params, pass_rng_shape=False
704+
)
705+
decode_state = decode_state_executable(None)
706+
707+
# prefill
708+
interesting_buckets = [
709+
64,
710+
128,
711+
256,
712+
512,
713+
1024,
714+
]
715+
716+
cached_prefill = {}
717+
cached_insert = {}
718+
for length in interesting_buckets:
719+
i32_scalar = jax.ShapeDtypeStruct((), int)
720+
logger.debug("Compiling prefill: %d", length)
721+
input_data = jax.ShapeDtypeStruct((length,), jax.numpy.dtype("int32"))
722+
723+
cached_prefill[length] = (
724+
jax.jit(
725+
engine.prefill_aot,
726+
in_shardings=(engine.param_layouts, None, None),
727+
out_shardings=(Layout(DLL.AUTO), Layout(DLL.AUTO)),
728+
).lower(params, input_data, i32_scalar)
729+
).compile(compiler_options=None)
730+
731+
logger.debug("Generate dummy prefix: %d", length)
732+
dummy_tokens = jax.numpy.ones(shape=(length,), dtype=jax.numpy.dtype("int32"))
733+
prefix_shapes = jax.eval_shape(engine.prefill_aot, params, dummy_tokens, 1)
734+
735+
logger.debug("Compiling insert: %d", length)
736+
prefill_output_layout, _ = cached_prefill[length].output_layouts
737+
logger.debug("Prefill output layout: {}".format(prefill_output_layout))
738+
logger.debug("Prefix shapes: {}".format(prefix_shapes))
739+
i32_scalar = jax.ShapeDtypeStruct((), int)
740+
cached_insert[length] = (
741+
jax.jit(
742+
engine.insert,
743+
in_shardings=(prefill_output_layout, engine.decode_state_layouts, None),
744+
out_shardings=(engine.decode_state_layouts),
745+
donate_argnames=("decode_state"),
746+
).lower(prefix_shapes[0], engine.decode_state_shapes, i32_scalar)
747+
).compile(compiler_options=None)
748+
return engine, params, generate_executable, cached_prefill, cached_insert, decode_state
749+
673750
def _prefill_thread(self, idx: int):
674751
"""Thread which runs in the background performing prefills."""
675752
logger.info("Spinning up prefill thread %d.", idx)
@@ -683,6 +760,13 @@ def _prefill_thread(self, idx: int):
683760
thread_name = f"Prefill thread {idx}"
684761
ThreadDebugLog(thread_name, f"Prefill params {idx} loaded.")
685762

763+
if not self._interleaved_mode:
764+
logger.debug("Compiling for disagg mode")
765+
prefill_engine, prefill_params, gen_fn, prefill_fn, insert_fn, _ = self._auto_layout_compile(
766+
prefill_engine, prefill_params
767+
)
768+
self._cached_prefill[idx] = prefill_fn
769+
686770
while self.live:
687771
my_transfer_backlog = self._transfer_backlogs[idx]
688772
# The prefill thread can just sleep until it has work to do.
@@ -759,10 +843,11 @@ def _prefill_thread(self, idx: int):
759843
)
760844
else:
761845
# Compute new kv cache for the prefill_content.
762-
prefill_result, first_token = prefill_engine.prefill(
763-
params=final_prefill_params,
764-
padded_tokens=padded_tokens,
765-
true_length=true_length,
846+
assert padded_tokens.shape[0] in self._cached_prefill[idx]
847+
prefill_result, first_token = self._cached_prefill[idx][padded_tokens.shape[0]](
848+
final_prefill_params,
849+
padded_tokens,
850+
true_length,
766851
)
767852

768853
request.complete = np.zeros(
@@ -967,10 +1052,11 @@ def _insert_if_possible(
9671052
else:
9681053
break
9691054

970-
decode_state = generate_engine.insert(
1055+
length = new_request.prefill_result['cache']['decoder']['layers_0']['self_attention']['KVCache_0']['cache_prefill_segment_id'].value.shape[1]
1056+
decode_state = self._cached_insert[idx][length](
9711057
new_request.prefill_result,
9721058
decode_state,
973-
slot=slot,
1059+
slot,
9741060
# request_id=new_request.request_id,
9751061
)
9761062
ThreadDebugLog(
@@ -1115,9 +1201,19 @@ def _generate_thread(self, idx: int):
11151201
# Keep track of what step tokens were generated at.
11161202
generate_timestep = 0
11171203
# State to store things like running kv cache in.
1118-
decode_state = generate_engine.init_decode_state()
1119-
11201204
generate_params = self._generate_params[idx]
1205+
1206+
if not self._interleaved_mode:
1207+
logger.debug("Compiling for disagg mode")
1208+
generate_engine, generate_params, gen_fn, prefill_fn, insert_fn, decode_state = self._auto_layout_compile(
1209+
generate_engine, generate_params
1210+
)
1211+
self._generate_executables[idx] = gen_fn
1212+
self._cached_insert[idx] = insert_fn
1213+
self._decode_states[idx] = decode_state
1214+
1215+
decode_state = self._decode_states[idx]
1216+
11211217
thread_name = f"Generate thread {idx}"
11221218
ThreadDebugLog(thread_name, f"Generate params {idx} loaded.")
11231219
time_of_last_generate = time.time()
@@ -1178,8 +1274,8 @@ def _generate_thread(self, idx: int):
11781274
), "At this point we must have some requests inserted into the slots."
11791275

11801276
# Now we actually take a generate step on requests in the slots.
1181-
decode_state, sampled_tokens = generate_engine.generate(
1182-
generate_params, decode_state
1277+
decode_state, sampled_tokens = self._generate_executables[idx](
1278+
generate_params, decode_state, None
11831279
)
11841280
sampled_tokens.copy_to_host_async()
11851281
# Respond to detokenization backpressure.

0 commit comments

Comments
 (0)