102102from jetstream .core .metrics .prometheus import JetstreamMetricsCollector
103103import numpy as np
104104
105- log_level = os .getenv ("LOG_LEVEL" , "WARNING" ).upper ()
105+ from jax .experimental import layout as jax_layout
106+ DLL = jax_layout .DeviceLocalLayout
107+ Layout = jax_layout .Layout
108+
109+ log_level = os .getenv ("LOG_LEVEL" , "DEBUG" ).upper ()
106110
107111logger = logging .getLogger ("JetstreamLogger" )
108112logger .propagate = False
@@ -405,6 +409,29 @@ def __init__(
405409
406410 self ._jax_padding = jax_padding
407411
412+ ##### Auto layout compile for interleaved engine
413+ self ._generate_executables = [None for _ in self ._generate_engines ]
414+ self ._cached_insert = [None for _ in self ._generate_engines ]
415+ self ._cached_prefill = [None for _ in self ._prefill_engines ]
416+ self ._decode_states = [None for _ in self ._generate_engines ]
417+ if self ._interleaved_mode :
418+ for idx in range (len (self ._generate_engines )):
419+ logger .debug ("Compiling interleaved engine {}" .format (idx ))
420+ engine = self ._generate_engines [idx ]
421+ params = self ._generate_params [idx ]
422+ engine , params , gen_fn , prefill_fn , insert_fn , decode_state = self ._auto_layout_compile (engine , params )
423+
424+ self ._prefill_engines [idx ] = engine
425+ self ._generate_engines [idx ] = engine
426+ self ._prefill_params [idx ] = params
427+ self ._generate_params [idx ] = params
428+ self ._cached_prefill [idx ] = prefill_fn
429+ self ._cached_insert [idx ] = insert_fn
430+ self ._generate_executables [idx ] = gen_fn
431+
432+ self ._decode_states [idx ] = decode_state
433+
434+
408435 # Create all threads
409436 self ._prefill_threads = [
410437 JetThread (
@@ -670,6 +697,56 @@ def _do_chunked_prefill(
670697
671698 return prefill_result , first_token
672699
def _auto_layout_compile(
    self,
    engine,
    params,
    *,
    prefill_lengths=(64, 128, 256, 512, 1024),
):
  """Ahead-of-time compiles generate, prefill and insert for one engine.

  The generate step is compiled through ``engine.aot_compile``. For every
  entry of ``prefill_lengths`` a prefill executable is compiled with
  compiler-chosen (``DLL.AUTO``) output layouts, and a matching insert
  executable is compiled whose first input layout is the layout the
  prefill executable reports for its first output.

  Args:
    engine: Engine object providing ``aot_compile``, ``prefill_aot``,
      ``insert``, ``param_layouts``, ``decode_state_layouts`` and
      ``decode_state_shapes``.
    params: Parameters handed to ``engine.aot_compile``; the value returned
      by that call replaces them for all later lowering steps.
    prefill_lengths: Token lengths to compile a prefill/insert pair for.
      Callers look executables up by exact length, so any length used at
      runtime must be listed here.

  Returns:
    A tuple ``(engine, params, generate_executable, cached_prefill,
    cached_insert, decode_state)`` where ``cached_prefill`` and
    ``cached_insert`` are dicts mapping each length to its compiled
    executable and ``decode_state`` is the result of calling the decode
    state executable returned by ``engine.aot_compile``.
  """
  logger.debug("Compiling generate function")
  generate_executable, params, decode_state_executable = engine.aot_compile(
      params, pass_rng_shape=False
  )
  decode_state = decode_state_executable(None)

  # Loop-invariant abstract values: build them once instead of per bucket.
  i32_scalar = jax.ShapeDtypeStruct((), int)
  token_dtype = jax.numpy.dtype("int32")

  cached_prefill = {}
  cached_insert = {}
  for length in prefill_lengths:
    logger.debug("Compiling prefill: %d", length)
    input_data = jax.ShapeDtypeStruct((length,), token_dtype)
    cached_prefill[length] = (
        jax.jit(
            engine.prefill_aot,
            in_shardings=(engine.param_layouts, None, None),
            out_shardings=(Layout(DLL.AUTO), Layout(DLL.AUTO)),
        )
        .lower(params, input_data, i32_scalar)
        .compile(compiler_options=None)
    )

    logger.debug("Generate dummy prefix: %d", length)
    dummy_tokens = jax.numpy.ones(shape=(length,), dtype=token_dtype)
    prefix_shapes = jax.eval_shape(engine.prefill_aot, params, dummy_tokens, 1)

    logger.debug("Compiling insert: %d", length)
    # The insert executable must consume the prefix in exactly the layout
    # the compiled prefill produces, so read it back from the executable.
    prefill_output_layout, _ = cached_prefill[length].output_layouts
    # Lazy %-args: these pytrees are only stringified when DEBUG is enabled.
    logger.debug("Prefill output layout: %s", prefill_output_layout)
    logger.debug("Prefix shapes: %s", prefix_shapes)
    cached_insert[length] = (
        jax.jit(
            engine.insert,
            in_shardings=(
                prefill_output_layout,
                engine.decode_state_layouts,
                None,
            ),
            out_shardings=engine.decode_state_layouts,
            donate_argnames=("decode_state",),
        )
        .lower(prefix_shapes[0], engine.decode_state_shapes, i32_scalar)
        .compile(compiler_options=None)
    )

  return (
      engine,
      params,
      generate_executable,
      cached_prefill,
      cached_insert,
      decode_state,
  )
749+
673750 def _prefill_thread (self , idx : int ):
674751 """Thread which runs in the background performing prefills."""
675752 logger .info ("Spinning up prefill thread %d." , idx )
@@ -683,6 +760,13 @@ def _prefill_thread(self, idx: int):
683760 thread_name = f"Prefill thread { idx } "
684761 ThreadDebugLog (thread_name , f"Prefill params { idx } loaded." )
685762
763+ if not self ._interleaved_mode :
764+ logger .debug ("Compiling for disagg mode" )
765+ prefill_engine , prefill_params , gen_fn , prefill_fn , insert_fn , _ = self ._auto_layout_compile (
766+ prefill_engine , prefill_params
767+ )
768+ self ._cached_prefill [idx ] = prefill_fn
769+
686770 while self .live :
687771 my_transfer_backlog = self ._transfer_backlogs [idx ]
688772 # The prefill thread can just sleep until it has work to do.
@@ -759,10 +843,11 @@ def _prefill_thread(self, idx: int):
759843 )
760844 else :
761845 # Compute new kv cache for the prefill_content.
762- prefill_result , first_token = prefill_engine .prefill (
763- params = final_prefill_params ,
764- padded_tokens = padded_tokens ,
765- true_length = true_length ,
846+ assert padded_tokens .shape [0 ] in self ._cached_prefill [idx ]
847+ prefill_result , first_token = self ._cached_prefill [idx ][padded_tokens .shape [0 ]](
848+ final_prefill_params ,
849+ padded_tokens ,
850+ true_length ,
766851 )
767852
768853 request .complete = np .zeros (
@@ -967,10 +1052,14 @@ def _insert_if_possible(
9671052 else :
9681053 break
9691054
970- decode_state = generate_engine .insert (
1055+ if 'decoder' in new_request .prefill_result ['cache' ]:
1056+ length = new_request .prefill_result ['cache' ]['decoder' ]['layers_0' ]['self_attention' ]['KVCache_0' ]['cache_prefill_segment_id' ].value .shape [1 ]
1057+ else :
1058+ length = new_request .prefill_result ['cache' ]['self_attention' ]['KVCache_0' ]['cache_prefill_segment_id' ].value .shape [2 ]
1059+ decode_state = self ._cached_insert [idx ][length ](
9711060 new_request .prefill_result ,
9721061 decode_state ,
973- slot = slot ,
1062+ slot ,
9741063 # request_id=new_request.request_id,
9751064 )
9761065 ThreadDebugLog (
@@ -1115,9 +1204,19 @@ def _generate_thread(self, idx: int):
11151204 # Keep track of what step tokens were generated at.
11161205 generate_timestep = 0
11171206 # State to store things like running kv cache in.
1118- decode_state = generate_engine .init_decode_state ()
1119-
11201207 generate_params = self ._generate_params [idx ]
1208+
1209+ if not self ._interleaved_mode :
1210+ logger .debug ("Compiling for disagg mode" )
1211+ generate_engine , generate_params , gen_fn , prefill_fn , insert_fn , decode_state = self ._auto_layout_compile (
1212+ generate_engine , generate_params
1213+ )
1214+ self ._generate_executables [idx ] = gen_fn
1215+ self ._cached_insert [idx ] = insert_fn
1216+ self ._decode_states [idx ] = decode_state
1217+
1218+ decode_state = self ._decode_states [idx ]
1219+
11211220 thread_name = f"Generate thread { idx } "
11221221 ThreadDebugLog (thread_name , f"Generate params { idx } loaded." )
11231222 time_of_last_generate = time .time ()
@@ -1178,8 +1277,8 @@ def _generate_thread(self, idx: int):
11781277 ), "At this point we must have some requests inserted into the slots."
11791278
11801279 # Now we actually take a generate step on requests in the slots.
1181- decode_state , sampled_tokens = generate_engine . generate (
1182- generate_params , decode_state
1280+ decode_state , sampled_tokens = self . _generate_executables [ idx ] (
1281+ generate_params , decode_state , None
11831282 )
11841283 sampled_tokens .copy_to_host_async ()
11851284 # Respond to detokenization backpressure.
0 commit comments