 import jaxtyping
 import vllm.envs as envs
 from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.distributed import get_pp_group
 from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
                                           has_kv_transfer_group)
 from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
 from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
 
 from tpu_inference import utils
+from tpu_inference.distributed import jax_parallel_state
 from tpu_inference.distributed.utils import (get_host_ip, get_kv_transfer_port,
                                              get_node_id)
 from tpu_inference.layers.jax.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.jax_intermediate_tensor import \
+    JaxIntermediateTensors
 from tpu_inference.runner.kv_cache import get_rpa_page_size_bytes
 from tpu_inference.runner.tpu_jax_runner import TPUModelRunner
 
 
 class TPUWorker:
 
-    def __init__(self,
-                 vllm_config: VllmConfig,
-                 local_rank: int,
-                 rank: int,
-                 distributed_init_method: str,
-                 is_driver_worker: bool = False,
-                 devices=None):
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        local_rank: int,
+        rank: int,
+        distributed_init_method: str,
+        is_driver_worker: bool = False,
+        devices=None,
+        ip: str = "localhost",
+        prev_worker_ip: str = "localhost",
+    ):
         # If we use vLLM's model implementation in PyTorch, we should set it
         # with torch version of the dtype.
         impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
@@ -74,6 +82,9 @@ def __init__(self,
         self.devices = devices if devices is not None else []
         self.device_ranks = set(device.id for device in self.devices
                                 if isinstance(device, jaxlib._jax.Device))
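+        # Pipeline-parallel (PP) wiring: this worker's address, the upstream
+        # worker's address, and the PP world size.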
+        self.ip = ip
+        self.prev_worker_ip = prev_worker_ip
+        self.pp_world_size = self.parallel_config.pipeline_parallel_size
 
         if self.model_config.trust_remote_code:
             # note: lazy import to avoid importing torch before initializing
@@ -86,14 +97,20 @@ def __init__(self,
         # TPU Worker is initialized. The profiler server needs to start after
         # MP runtime is initialized.
         self.profile_dir = None
-        if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
+        if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1 and self.pp_world_size == 1:
             if not self.devices or 0 in self.device_ranks:
                 # For TPU, we can only have 1 active profiler session for 1 profiler
                 # server. So we only profile on rank0.
                 self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR
                 logger.info("Profiling enabled. Traces will be saved to: %s",
                             self.profile_dir)
 
+        # For PP, we use MPMD so we want to profile every worker.
+        if self.pp_world_size > 1 and envs.VLLM_TORCH_PROFILER_DIR:
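+            # Give each PP rank its own trace subdirectory so per-worker
+            # profiles don't overwrite each other.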
+            self.profile_dir = os.path.join(envs.VLLM_TORCH_PROFILER_DIR,
+                                            f"rank_{self.rank}")
+            os.makedirs(self.profile_dir, exist_ok=True)
+
         use_jax_profiler_server = os.getenv("USE_JAX_PROFILER_SERVER", False)
         # Only one instance of profiler is allowed
         if use_jax_profiler_server and self.rank < 1:
@@ -105,19 +122,50 @@ def __init__(self,
             )
             jax.profiler.start_server(jax_profiler_server_port)
 
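+        # Counts executed model steps; combined with the scheduler output to
+        # build unique ids for PP intermediate-tensor transfers.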
+        self.step_counter = 0
+
     def initialize_cache(self, num_gpu_blocks: int,
                          num_cpu_blocks: int) -> None:
         self.cache_config.num_gpu_blocks = num_gpu_blocks
         self.cache_config.num_cpu_blocks = num_cpu_blocks
 
     def init_device(self):
+        # Set TPU visible devices for the JAX runtime in single-host PP.
+        multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
+        if multihost_backend != "ray" and self.parallel_config.pipeline_parallel_size > 1:
+            # Note: below is the setting for a v6e8 host (8 chips of v6e).
+            # There are 2 ways of subslicing a v6e:
+            # 1) 2 subslices with 4 TPU chips each: we can do PP=2, TP=1/2/3/4.
+            # 2) 1 chip per subslice, with at most 8 subslices:
+            #    we can do TP=1, PP=1/2/3/4/5/6/7/8.
+            # Replace with your own topology.
+
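+            # Give every PP rank its own TPU process port and task id so each
+            # local JAX runtime only sees its own sub-slice.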
+            tpu_ports = [
+                jax_parallel_state.BASE_JAX_PORT + i
+                for i in range(self.pp_world_size)
+            ]
+            os.environ["TPU_PROCESS_ADDRESSES"] = ",".join(
+                [f"localhost:{port}" for port in tpu_ports])
+            os.environ["TPU_PROCESS_PORT"] = f"{tpu_ports[self.rank]}"
+            os.environ["CLOUD_TPU_TASK_ID"] = f"{self.rank}"
+
+            # First way of subslicing:
+            # os.environ["TPU_PROCESS_BOUNDS"] = "1,1,1"
+            # os.environ["TPU_CHIPS_PER_PROCESS_BOUNDS"] = "1,4,1"
+            # os.environ["TPU_VISIBLE_CHIPS"] = "0,1,2,3" if self.rank == 0 else "4,5,6,7"
+
+            # Second way of subslicing:
+            os.environ["TPU_PROCESS_BOUNDS"] = f"1,{self.pp_world_size},1"
+            os.environ["TPU_CHIPS_PER_PROCESS_BOUNDS"] = "1,1,1"
+            os.environ["TPU_VISIBLE_CHIPS"] = f"{self.rank}"
+
         if not self.devices:
             sharding_config: ShardingConfigManager = self.vllm_config.sharding_config
             device_indexes = sharding_config.device_indexes
             if device_indexes is not None and len(device_indexes) > 0:
                 # Enforcing the devices sequence to be consistent with the specified device indexes
-                self.devices = [jax.devices()[i] for i in device_indexes]
-                all_devices = jax.devices()
+                self.devices = [jax.local_devices()[i] for i in device_indexes]
+                all_devices = jax.local_devices()
                 device_dict = {device.id: device for device in all_devices}
                 self.devices = []
                 for device_index in device_indexes:
@@ -128,10 +176,13 @@ def init_device(self):
128176 f"jax.devices() with IDs { list (device_dict .keys ())} !"
129177 )
130178 self .devices .append (device )
179+ assert len (self .devices ) >= sharding_config .total_devices
131180 self .devices = self .devices [:sharding_config .total_devices ]
132181 else :
133- self .devices = jax .devices ()[:sharding_config .total_devices ]
134-
182+ assert jax .local_device_count (
183+ ) >= sharding_config .total_devices
184+ self .devices = jax .local_devices ()[:sharding_config .
185+ total_devices ]
         # Initialize the vLLM distribution layer as a single chip environment,
         # we'll swap the model's parallel modules with TPU SPMD equivalents.
         with set_current_vllm_config(self.vllm_config):
@@ -147,15 +198,30 @@ def init_device(self):
                 tensor_model_parallel_size=1,
                 pipeline_model_parallel_size=1,
             )
+
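+        # Register this worker with the JAX-side PP state (ip, rank, world
+        # size, first local device); need_pp is only set when PP > 1.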
+        jax_parallel_state.init_pp_distributed_environment(
+            self.ip,
+            self.rank,
+            self.parallel_config.pipeline_parallel_size,
+            self.devices[0],
+            need_pp=self.parallel_config.pipeline_parallel_size > 1)
+
         ensure_kv_transfer_initialized(self.vllm_config)
-        self.model_runner = TPUModelRunner(self.vllm_config, self.devices)
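+        # Pass the PP rank plus first/last-stage flags (rank == 0,
+        # rank == pp_world_size - 1) to the runner.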
+        self.model_runner = TPUModelRunner(self.vllm_config, self.devices,
+                                           self.rank, self.rank == 0,
+                                           self.rank == self.pp_world_size - 1)
         logger.info(f"Init worker | "
                     f"rank={self.rank} | "
                     f"node_id={get_node_id()} | "
                     f"is_driver_worker={self.is_driver_worker} | "
                     f"hbm={utils.hbm_usage_gb(self.devices)} GiB")
         vllm_utils.report_usage_stats(self.vllm_config)
 
+    def initialize_pp_transfer_connect(self):
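+        """Connect to the previous PP rank's worker; no-op on rank 0."""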
+        if self.rank == 0:
+            return
+        jax_parallel_state.connect(self.prev_worker_ip, self.rank - 1)
+
     def determine_available_memory(self) -> int:
         gpu_memory_utilization = self.cache_config.gpu_memory_utilization
         hbm_usage = utils.hbm_usage_bytes(self.devices)
@@ -195,14 +261,39 @@ def execute_model(
         # deliberate, temporary compromise for the same reasons outlined in
         # the `get_kv_cache_spec` method.
 
-        output = self.model_runner.execute_model(scheduler_output)
-
-        # With a connector, the scheduler expects output from all workers
-        # TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866
-        if has_kv_transfer_group():
-            return output
-
-        return output if self.is_driver_worker else None
+        if self.parallel_config.pipeline_parallel_size == 1 or self.rank == 0:
+            intermediate_tensors = None
+        else:
+            # receive intermediate tensors
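+            # Non-first ranks rebuild the upstream stage's output from the
+            # matching transfer uuid and tensor spec.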
+            uuid = self.model_runner.get_uuid_for_jax_transfer(
+                scheduler_output, self.rank - 1, self.step_counter)
+            # TODO: this method might only work for vLLM models; not sure about JAX models.
+            tensor_spec = self.model_runner.get_intermediate_tensor_spec(
+                scheduler_output.total_num_scheduled_tokens)
+            intermediate_tensors_dict = get_pp_group().recv_tensor_dict(
+                uuid, tensor_spec)
+            intermediate_tensors = JaxIntermediateTensors(
+                intermediate_tensors_dict)
+
+        output = self.model_runner.execute_model(scheduler_output,
+                                                 intermediate_tensors)
+
+        if isinstance(output, JaxIntermediateTensors):
+            assert self.parallel_config.pipeline_parallel_size > 1
+            assert not get_pp_group().is_last_rank
+            # send intermediate tensors
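+            # Non-last ranks forward their intermediate tensors downstream and
+            # return no output for this step.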
+            uuid = self.model_runner.get_uuid_for_jax_transfer(
+                scheduler_output, self.rank, self.step_counter)
+            get_pp_group().send_tensor_dict(uuid, output.tensors)
+            self.step_counter += 1
+            return None
+        else:
+            self.step_counter += 1
+            # With a connector, the scheduler expects output from all workers
+            # TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866
+            if has_kv_transfer_group():
+                return output
+            return output if self.is_driver_worker else None
 
     def sample_tokens(self,
                       grammar_output: GrammarOutput) -> ModelRunnerOutput: