From 4cac52f8d5d6f55603d0783bdb31cbc85e42e701 Mon Sep 17 00:00:00 2001
From: Chenyaaang
Date: Fri, 7 Nov 2025 17:59:12 +0000
Subject: [PATCH] worker changes for pp

Signed-off-by: Chenyaaang
---
 tpu_inference/worker/tpu_worker_jax.py | 133 +++++++++++++++++++++----
 1 file changed, 112 insertions(+), 21 deletions(-)

diff --git a/tpu_inference/worker/tpu_worker_jax.py b/tpu_inference/worker/tpu_worker_jax.py
index ca7cb680a..8cd74cc9d 100644
--- a/tpu_inference/worker/tpu_worker_jax.py
+++ b/tpu_inference/worker/tpu_worker_jax.py
@@ -10,6 +10,7 @@
 import jaxtyping
 import vllm.envs as envs
 from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.distributed import get_pp_group
 from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
                                           has_kv_transfer_group)
 from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
@@ -23,10 +24,13 @@
 from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
 
 from tpu_inference import utils
+from tpu_inference.distributed import jax_parallel_state
 from tpu_inference.distributed.utils import (get_host_ip, get_kv_transfer_port,
                                              get_node_id)
 from tpu_inference.layers.jax.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.jax_intermediate_tensor import \
+    JaxIntermediateTensors
 from tpu_inference.runner.kv_cache import get_rpa_page_size_bytes
 from tpu_inference.runner.tpu_jax_runner import TPUModelRunner
 
@@ -41,13 +45,17 @@
 
 
 class TPUWorker:
-    def __init__(self,
-                 vllm_config: VllmConfig,
-                 local_rank: int,
-                 rank: int,
-                 distributed_init_method: str,
-                 is_driver_worker: bool = False,
-                 devices=None):
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        local_rank: int,
+        rank: int,
+        distributed_init_method: str,
+        is_driver_worker: bool = False,
+        devices=None,
+        ip: str = "localhost",
+        prev_worker_ip: str = "localhost",
+    ):
         # If we use vLLM's model implementation in PyTorch, we should set it
         # with torch version of the dtype.
         impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
@@ -74,6 +82,9 @@ def __init__(self,
         self.devices = devices if devices is not None else []
         self.device_ranks = set(device.id for device in self.devices
                                 if isinstance(device, jaxlib._jax.Device))
+        self.ip = ip
+        self.prev_worker_ip = prev_worker_ip
+        self.pp_world_size = self.parallel_config.pipeline_parallel_size
 
         if self.model_config.trust_remote_code:
             # note: lazy import to avoid importing torch before initializing
@@ -86,7 +97,7 @@ def __init__(self,
         # TPU Worker is initialized. The profiler server needs to start after
         # MP runtime is initialized.
         self.profile_dir = None
-        if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
+        if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1 and self.pp_world_size == 1:
             if not self.devices or 0 in self.device_ranks:
                 # For TPU, we can only have 1 active profiler session for 1 profiler
                 # server. So we only profile on rank0.
@@ -94,6 +105,12 @@ def __init__(self,
                 self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR
                 logger.info("Profiling enabled. Traces will be saved to: %s",
                             self.profile_dir)
+        # For PP, we use MPMD so we want to profile every worker.
+        if self.pp_world_size > 1 and envs.VLLM_TORCH_PROFILER_DIR:
+            self.profile_dir = os.path.join(envs.VLLM_TORCH_PROFILER_DIR,
+                                            f"rank_{self.rank}")
+            os.makedirs(self.profile_dir, exist_ok=True)
+
         use_jax_profiler_server = os.getenv("USE_JAX_PROFILER_SERVER", False)
         # Only one instance of profiler is allowed
         if use_jax_profiler_server and self.rank < 1:
@@ -105,19 +122,50 @@ def __init__(self,
             )
             jax.profiler.start_server(jax_profiler_server_port)
 
+        self.step_counter = 0
+
     def initialize_cache(self, num_gpu_blocks: int,
                          num_cpu_blocks: int) -> None:
         self.cache_config.num_gpu_blocks = num_gpu_blocks
         self.cache_config.num_cpu_blocks = num_cpu_blocks
 
     def init_device(self):
+        # Set TPU visible devices for the JAX runtime in single-host PP.
+        multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
+        if multihost_backend != "ray" and self.parallel_config.pipeline_parallel_size > 1:
+            # Note: Below is the setting for a v6e8 host (8 chips of v6e).
+            # There are 2 ways of subslicing a v6e:
+            # 1) 2 slices with 4 TPU chips each: we can do PP=2, TP=1/2/3/4
+            # 2) 1 chip per subslice, with at most 8 subslices:
+            #    we can do TP=1, PP=1/2/3/4/5/6/7/8
+            # Replace with your own topology.
+
+            tpu_ports = [
+                jax_parallel_state.BASE_JAX_PORT + i
+                for i in range(self.pp_world_size)
+            ]
+            os.environ["TPU_PROCESS_ADDRESSES"] = ",".join(
+                [f"localhost:{port}" for port in tpu_ports])
+            os.environ["TPU_PROCESS_PORT"] = f"{tpu_ports[self.rank]}"
+            os.environ["CLOUD_TPU_TASK_ID"] = f"{self.rank}"
+
+            # First way of subslicing.
+            # os.environ["TPU_PROCESS_BOUNDS"] = "1,1,1"
+            # os.environ["TPU_CHIPS_PER_PROCESS_BOUNDS"] = f"1,4,1"
+            # os.environ["TPU_VISIBLE_CHIPS"] = "0,1,2,3" if self.rank == 0 else "4,5,6,7"
+
+            # Second way of subslicing.
+            os.environ["TPU_PROCESS_BOUNDS"] = f"1,{self.pp_world_size},1"
+            os.environ["TPU_CHIPS_PER_PROCESS_BOUNDS"] = "1,1,1"
+            os.environ["TPU_VISIBLE_CHIPS"] = f"{self.rank}"
+
         if not self.devices:
             sharding_config: ShardingConfigManager = self.vllm_config.sharding_config
             device_indexes = sharding_config.device_indexes
             if device_indexes is not None and len(device_indexes) > 0:
                 # Enforcing the devices sequence to be consistent with the specified device indexes
-                self.devices = [jax.devices()[i] for i in device_indexes]
-                all_devices = jax.devices()
+                self.devices = [jax.local_devices()[i] for i in device_indexes]
+                all_devices = jax.local_devices()
                 device_dict = {device.id: device for device in all_devices}
                 self.devices = []
                 for device_index in device_indexes:
@@ -128,10 +176,13 @@ def init_device(self):
                             f"jax.devices() with IDs {list(device_dict.keys())}!"
                         )
                     self.devices.append(device)
+                assert len(self.devices) >= sharding_config.total_devices
                 self.devices = self.devices[:sharding_config.total_devices]
             else:
-                self.devices = jax.devices()[:sharding_config.total_devices]
-
+                assert jax.local_device_count(
+                ) >= sharding_config.total_devices
+                self.devices = jax.local_devices()[:sharding_config.
+                                                   total_devices]
         # Initialize the vLLM distribution layer as a single chip environment,
         # we'll swap the model's parallel modules with TPU SPMD equivalents.
         with set_current_vllm_config(self.vllm_config):
@@ -147,8 +198,18 @@ def init_device(self):
                 tensor_model_parallel_size=1,
                 pipeline_model_parallel_size=1,
             )
+
+        jax_parallel_state.init_pp_distributed_environment(
+            self.ip,
+            self.rank,
+            self.parallel_config.pipeline_parallel_size,
+            self.devices[0],
+            need_pp=self.parallel_config.pipeline_parallel_size > 1)
+
         ensure_kv_transfer_initialized(self.vllm_config)
-        self.model_runner = TPUModelRunner(self.vllm_config, self.devices)
+        self.model_runner = TPUModelRunner(self.vllm_config, self.devices,
+                                           self.rank, self.rank == 0,
+                                           self.rank == self.pp_world_size - 1)
         logger.info(f"Init worker | "
                     f"rank={self.rank} | "
                     f"node_id={get_node_id()} | "
@@ -156,6 +217,11 @@ def init_device(self):
                     f"hbm={utils.hbm_usage_gb(self.devices)}GiB")
 
         vllm_utils.report_usage_stats(self.vllm_config)
 
+    def initialize_pp_transfer_connect(self):
+        if self.rank == 0:
+            return
+        jax_parallel_state.connect(self.prev_worker_ip, self.rank - 1)
+
     def determine_available_memory(self) -> int:
         gpu_memory_utilization = self.cache_config.gpu_memory_utilization
         hbm_usage = utils.hbm_usage_bytes(self.devices)
@@ -195,14 +261,39 @@ def execute_model(
         # deliberate, temporary compromise for the same reasons outlined in
         # the `get_kv_cache_spec` method.
 
-        output = self.model_runner.execute_model(scheduler_output)
-
-        # With a connector, the scheduler expects output from all workers
-        # TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866
-        if has_kv_transfer_group():
-            return output
-
-        return output if self.is_driver_worker else None
+        if self.parallel_config.pipeline_parallel_size == 1 or self.rank == 0:
+            intermediate_tensors = None
+        else:
+            # Receive intermediate tensors from the previous PP rank.
+            uuid = self.model_runner.get_uuid_for_jax_transfer(
+                scheduler_output, self.rank - 1, self.step_counter)
+            # TODO: this method might only work for vLLM models, not sure about JAX models.
+            tensor_spec = self.model_runner.get_intermediate_tensor_spec(
+                scheduler_output.total_num_scheduled_tokens)
+            intermediate_tensors_dict = get_pp_group().recv_tensor_dict(
+                uuid, tensor_spec)
+            intermediate_tensors = JaxIntermediateTensors(
+                intermediate_tensors_dict)
+
+        output = self.model_runner.execute_model(scheduler_output,
+                                                 intermediate_tensors)
+
+        if isinstance(output, JaxIntermediateTensors):
+            assert self.parallel_config.pipeline_parallel_size > 1
+            assert not get_pp_group().is_last_rank
+            # Send intermediate tensors to the next PP rank.
+            uuid = self.model_runner.get_uuid_for_jax_transfer(
+                scheduler_output, self.rank, self.step_counter)
+            get_pp_group().send_tensor_dict(uuid, output.tensors)
+            self.step_counter += 1
+            return None
+        else:
+            self.step_counter += 1
+            # With a connector, the scheduler expects output from all workers
+            # TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866
+            if has_kv_transfer_group():
+                return output
+            return output if self.is_driver_worker else None
 
     def sample_tokens(self,
                       grammar_output: GrammarOutput) -> ModelRunnerOutput:
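
Reviewer note (not part of the patch): the execute_model change implements a per-step MPMD hand-off between pipeline stages. Every rank except the first blocks on the previous stage's intermediate tensors, runs its own stage, and then either forwards its intermediates to the next rank or, on the last rank, returns the real ModelRunnerOutput. Both sides derive the same transfer uuid from (scheduler_output, sender rank, step_counter), and step_counter is bumped exactly once per call on every rank, so senders and receivers stay in lockstep. The sketch below is a minimal, self-contained illustration of that control flow only; pp_execute_step, run_stage, send_dict, recv_dict and transfer_key are hypothetical stand-ins for illustration, not the tpu_inference or vLLM APIs.

from typing import Any, Callable, Dict, Optional


def transfer_key(sender_rank: int, step: int) -> str:
    # Both sides derive the same key for a given (sender, step) pair, so the
    # receive issued by rank r always matches the send issued by rank r - 1.
    # In the patch this role is played by model_runner.get_uuid_for_jax_transfer.
    return f"step{step}-from{sender_rank}"


def pp_execute_step(
    rank: int,
    pp_world_size: int,
    step: int,
    batch: Any,
    run_stage: Callable[[Any, Optional[Dict[str, Any]]], Any],
    send_dict: Callable[[str, Any], None],
    recv_dict: Callable[[str], Dict[str, Any]],
) -> Optional[Any]:
    """One execute_model call on one pipeline rank (MPMD style)."""
    is_first = rank == 0
    is_last = rank == pp_world_size - 1

    # Every rank except the first blocks on its predecessor's activations.
    intermediates = None if is_first else recv_dict(transfer_key(rank - 1, step))

    result = run_stage(batch, intermediates)

    if not is_last:
        # Intermediate stages forward activations and report nothing upstream.
        send_dict(transfer_key(rank, step), result)
        return None
    # Only the last stage returns the real model output to the scheduler.
    return result


if __name__ == "__main__":
    # Toy 2-stage pipeline, played sequentially in one process; a plain dict
    # stands in for the device-to-device transfer channel.
    channel: Dict[str, Any] = {}

    def stage(batch, intermediates):
        if intermediates is None:             # first stage: produce activations
            return {"hidden": batch * 2}
        return intermediates["hidden"] + 1    # last stage: produce final output

    assert pp_execute_step(0, 2, 0, 10, stage, channel.__setitem__, channel.pop) is None
    assert pp_execute_step(1, 2, 0, 10, stage, channel.__setitem__, channel.pop) == 21
    print("ok")

Run as a script, the toy __main__ block plays both ranks of a 2-stage pipeline one after the other in a single process, which is enough to show why the step-keyed matching keeps each receive paired with the right send.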