from jetstream.core.metrics.prometheus import JetstreamMetricsCollector
import numpy as np

-log_level = os.getenv("LOG_LEVEL", "WARNING").upper()
+from jax.experimental import layout as jax_layout
+
+DLL = jax_layout.DeviceLocalLayout
+Layout = jax_layout.Layout
+
+log_level = os.getenv("LOG_LEVEL", "DEBUG").upper()

logger = logging.getLogger("JetstreamLogger")
logger.propagate = False
@@ -405,6 +409,29 @@ def __init__(

    self._jax_padding = jax_padding

+    ##### Auto layout compile for interleaved engine
+    self._generate_executables = [None for _ in self._generate_engines]
+    self._cached_insert = [None for _ in self._generate_engines]
+    self._cached_prefill = [None for _ in self._prefill_engines]
+    self._decode_states = [None for _ in self._generate_engines]
+    if self._interleaved_mode:
+      for idx in range(len(self._generate_engines)):
+        logger.debug("Compiling interleaved engine {}".format(idx))
+        engine = self._generate_engines[idx]
+        params = self._generate_params[idx]
+        engine, params, gen_fn, prefill_fn, insert_fn, decode_state = self._auto_layout_compile(engine, params)
+
+        self._prefill_engines[idx] = engine
+        self._generate_engines[idx] = engine
+        self._prefill_params[idx] = params
+        self._generate_params[idx] = params
+        self._cached_prefill[idx] = prefill_fn
+        self._cached_insert[idx] = insert_fn
+        self._generate_executables[idx] = gen_fn
+
+        self._decode_states[idx] = decode_state
+
+
    # Create all threads
    self._prefill_threads = [
        JetThread(
@@ -670,6 +697,56 @@ def _do_chunked_prefill(

    return prefill_result, first_token

+  def _auto_layout_compile(self, engine, params):
+    logger.debug("Compiling generate function")
+    generate_executable, params, decode_state_executable = engine.aot_compile(
+        params, pass_rng_shape=False
+    )
+    decode_state = decode_state_executable(None)
+
+    # prefill
+    interesting_buckets = [
+        64,
+        128,
+        256,
+        512,
+        1024,
+    ]
+
+    cached_prefill = {}
+    cached_insert = {}
+    for length in interesting_buckets:
+      i32_scalar = jax.ShapeDtypeStruct((), int)
+      logger.debug("Compiling prefill: %d", length)
+      input_data = jax.ShapeDtypeStruct((length,), jax.numpy.dtype("int32"))
+
+      cached_prefill[length] = (
+          jax.jit(
+              engine.prefill_aot,
+              in_shardings=(engine.param_layouts, None, None),
+              out_shardings=(Layout(DLL.AUTO), Layout(DLL.AUTO)),
+          ).lower(params, input_data, i32_scalar)
+      ).compile(compiler_options=None)
+
+      logger.debug("Generate dummy prefix: %d", length)
+      dummy_tokens = jax.numpy.ones(shape=(length,), dtype=jax.numpy.dtype("int32"))
+      prefix_shapes = jax.eval_shape(engine.prefill_aot, params, dummy_tokens, 1)
+
+      logger.debug("Compiling insert: %d", length)
+      prefill_output_layout, _ = cached_prefill[length].output_layouts
+      logger.debug("Prefill output layout: {}".format(prefill_output_layout))
+      logger.debug("Prefix shapes: {}".format(prefix_shapes))
+      i32_scalar = jax.ShapeDtypeStruct((), int)
+      cached_insert[length] = (
+          jax.jit(
+              engine.insert,
+              in_shardings=(prefill_output_layout, engine.decode_state_layouts, None),
+              out_shardings=(engine.decode_state_layouts),
+              donate_argnames=("decode_state"),
+          ).lower(prefix_shapes[0], engine.decode_state_shapes, i32_scalar)
+      ).compile(compiler_options=None)
+    return engine, params, generate_executable, cached_prefill, cached_insert, decode_state
+
  def _prefill_thread(self, idx: int):
    """Thread which runs in the background performing prefills."""
    logger.info("Spinning up prefill thread %d.", idx)
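Note (not part of the diff): a minimal standalone sketch of the auto-layout AOT pattern that _auto_layout_compile relies on, shown here because the helper packs several steps into one method. XLA is allowed to pick the layout of one executable's output (Layout(DLL.AUTO)), and that chosen layout is then declared as the input layout of the next executable, mirroring how cached_prefill[length].output_layouts is threaded into cached_insert[length]. The toy prefill/insert functions, shapes, and argument names below are placeholders rather than JetStream APIs, and the sketch assumes a JAX version where jax.experimental.layout exposes Layout/DeviceLocalLayout, as the diff imports at the top of the file.

import jax
import jax.numpy as jnp
from jax.experimental import layout as jax_layout

DLL = jax_layout.DeviceLocalLayout
Layout = jax_layout.Layout


def prefill(tokens):
  # Stand-in for engine.prefill_aot: build a "prefix" from the tokens.
  return jnp.outer(tokens, tokens).astype(jnp.float32)


def insert(prefix, state):
  # Stand-in for engine.insert: write the prefix into the decode state.
  return state.at[: prefix.shape[0], : prefix.shape[1]].set(prefix)


tokens = jax.ShapeDtypeStruct((128,), jnp.int32)

# Compile prefill with an auto-chosen output layout.
prefill_exe = (
    jax.jit(prefill, out_shardings=Layout(DLL.AUTO))
    .lower(tokens)
    .compile()
)

# The layout XLA chose for the prefill output becomes the declared input
# layout of the insert executable, so its output can be consumed directly.
prefix_layout = prefill_exe.output_layouts
prefix_shape = jax.eval_shape(prefill, tokens)
state = jax.ShapeDtypeStruct((256, 256), jnp.float32)
insert_exe = (
    jax.jit(insert, in_shardings=(prefix_layout, None), donate_argnames=("state",))
    .lower(prefix_shape, state)
    .compile()
)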
@@ -683,6 +760,13 @@ def _prefill_thread(self, idx: int):
    thread_name = f"Prefill thread {idx}"
    ThreadDebugLog(thread_name, f"Prefill params {idx} loaded.")

+    if not self._interleaved_mode:
+      logger.debug("Compiling for disagg mode")
+      prefill_engine, prefill_params, gen_fn, prefill_fn, insert_fn, _ = self._auto_layout_compile(
+          prefill_engine, prefill_params
+      )
+      self._cached_prefill[idx] = prefill_fn
+
    while self.live:
      my_transfer_backlog = self._transfer_backlogs[idx]
      # The prefill thread can just sleep until it has work to do.
@@ -759,10 +843,11 @@ def _prefill_thread(self, idx: int):
        )
      else:
        # Compute new kv cache for the prefill_content.
-        prefill_result, first_token = prefill_engine.prefill(
-            params=final_prefill_params,
-            padded_tokens=padded_tokens,
-            true_length=true_length,
+        assert padded_tokens.shape[0] in self._cached_prefill[idx]
+        prefill_result, first_token = self._cached_prefill[idx][padded_tokens.shape[0]](
+            final_prefill_params,
+            padded_tokens,
+            true_length,
        )

      request.complete = np.zeros(
@@ -967,10 +1052,14 @@ def _insert_if_possible(
        else:
          break

-      decode_state = generate_engine.insert(
+      if 'decoder' in new_request.prefill_result['cache']:
+        length = new_request.prefill_result['cache']['decoder']['layers_0']['self_attention']['KVCache_0']['cache_prefill_segment_id'].value.shape[1]
+      else:
+        length = new_request.prefill_result['cache']['self_attention']['KVCache_0']['cache_prefill_segment_id'].value.shape[2]
+      decode_state = self._cached_insert[idx][length](
          new_request.prefill_result,
          decode_state,
-          slot=slot,
+          slot,
          # request_id=new_request.request_id,
      )
      ThreadDebugLog(
@@ -1115,9 +1204,19 @@ def _generate_thread(self, idx: int):
    # Keep track of what step tokens were generated at.
    generate_timestep = 0
    # State to store things like running kv cache in.
-    decode_state = generate_engine.init_decode_state()
-
    generate_params = self._generate_params[idx]
+
+    if not self._interleaved_mode:
+      logger.debug("Compiling for disagg mode")
+      generate_engine, generate_params, gen_fn, prefill_fn, insert_fn, decode_state = self._auto_layout_compile(
+          generate_engine, generate_params
+      )
+      self._generate_executables[idx] = gen_fn
+      self._cached_insert[idx] = insert_fn
+      self._decode_states[idx] = decode_state
+
+    decode_state = self._decode_states[idx]
+
    thread_name = f"Generate thread {idx}"
    ThreadDebugLog(thread_name, f"Generate params {idx} loaded.")
    time_of_last_generate = time.time()
@@ -1178,8 +1277,8 @@ def _generate_thread(self, idx: int):
      ), "At this point we must have some requests inserted into the slots."

      # Now we actually take a generate step on requests in the slots.
-      decode_state, sampled_tokens = generate_engine.generate(
-          generate_params, decode_state
+      decode_state, sampled_tokens = self._generate_executables[idx](
+          generate_params, decode_state, None
      )
      sampled_tokens.copy_to_host_async()
      # Respond to detokenization backpressure.
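Note (not part of the diff): the serving paths above dispatch on exact padded length. The prefill thread asserts that padded_tokens.shape[0] is one of the compiled buckets, and _insert_if_possible recovers the bucket length from the prefill cache shape. The padding logic itself is not shown in this diff; the helper below is a hypothetical sketch of the assumed pad-up-to-nearest-bucket behaviour, using an illustrative name rather than anything from the source.

import bisect

# Mirrors interesting_buckets in _auto_layout_compile above.
PREFILL_BUCKETS = [64, 128, 256, 512, 1024]


def bucket_for(true_length: int) -> int:
  """Smallest compiled bucket that fits the prompt, else the largest bucket."""
  i = bisect.bisect_left(PREFILL_BUCKETS, true_length)
  return PREFILL_BUCKETS[min(i, len(PREFILL_BUCKETS) - 1)]


# e.g. a 200-token prompt pads to 256, so cached_prefill[256] and
# cached_insert[256] are the executables that serve it.
assert bucket_for(200) == 256
assert bucket_for(64) == 64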