
Commit 4cac52f

worker changes for pp
Signed-off-by: Chenyaaang <[email protected]>
1 parent b43ea79 commit 4cac52f

File tree

1 file changed: +112 -21 lines

tpu_inference/worker/tpu_worker_jax.py

Lines changed: 112 additions & 21 deletions
@@ -10,6 +10,7 @@
 import jaxtyping
 import vllm.envs as envs
 from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.distributed import get_pp_group
 from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
                                           has_kv_transfer_group)
 from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
@@ -23,10 +24,13 @@
 from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput

 from tpu_inference import utils
+from tpu_inference.distributed import jax_parallel_state
 from tpu_inference.distributed.utils import (get_host_ip, get_kv_transfer_port,
                                              get_node_id)
 from tpu_inference.layers.jax.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.jax_intermediate_tensor import \
+    JaxIntermediateTensors
 from tpu_inference.runner.kv_cache import get_rpa_page_size_bytes
 from tpu_inference.runner.tpu_jax_runner import TPUModelRunner

@@ -41,13 +45,17 @@

 class TPUWorker:

-    def __init__(self,
-                 vllm_config: VllmConfig,
-                 local_rank: int,
-                 rank: int,
-                 distributed_init_method: str,
-                 is_driver_worker: bool = False,
-                 devices=None):
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        local_rank: int,
+        rank: int,
+        distributed_init_method: str,
+        is_driver_worker: bool = False,
+        devices=None,
+        ip: str = "localhost",
+        prev_worker_ip: str = "localhost",
+    ):
         # If we use vLLM's model implementation in PyTorch, we should set it
         # with torch version of the dtype.
         impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
@@ -74,6 +82,9 @@ def __init__(self,
         self.devices = devices if devices is not None else []
         self.device_ranks = set(device.id for device in self.devices
                                 if isinstance(device, jaxlib._jax.Device))
+        self.ip = ip
+        self.prev_worker_ip = prev_worker_ip
+        self.pp_world_size = self.parallel_config.pipeline_parallel_size

         if self.model_config.trust_remote_code:
             # note: lazy import to avoid importing torch before initializing
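
The new ip and prev_worker_ip arguments let a pipeline-parallel launcher tell each stage its own address and the address of the stage before it. A minimal sketch of how a driver could wire these constructor kwargs, assuming one worker process per PP rank and hypothetical per-stage IPs (the launcher itself is not part of this commit):

# Hypothetical wiring only; the remaining TPUWorker arguments are omitted.
worker_ips = ["10.0.0.1", "10.0.0.2", "10.0.0.3"]  # assumed per-stage hosts

pp_worker_kwargs = []
for rank, ip in enumerate(worker_ips):
    pp_worker_kwargs.append({
        "rank": rank,
        "ip": ip,
        # Rank 0 has no predecessor, so the "localhost" default stands in.
        "prev_worker_ip": worker_ips[rank - 1] if rank > 0 else "localhost",
    })

print(pp_worker_kwargs)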
@@ -86,14 +97,20 @@ def __init__(self,
         # TPU Worker is initialized. The profiler server needs to start after
         # MP runtime is initialized.
         self.profile_dir = None
-        if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
+        if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1 and self.pp_world_size == 1:
             if not self.devices or 0 in self.device_ranks:
                 # For TPU, we can only have 1 active profiler session for 1 profiler
                 # server. So we only profile on rank0.
                 self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR
                 logger.info("Profiling enabled. Traces will be saved to: %s",
                             self.profile_dir)

+        # For PP, we use MPMD so we want to profile every worker.
+        if self.pp_world_size > 1 and envs.VLLM_TORCH_PROFILER_DIR:
+            self.profile_dir = os.path.join(envs.VLLM_TORCH_PROFILER_DIR,
+                                            f"rank_{self.rank}")
+            os.makedirs(self.profile_dir, exist_ok=True)
+
         use_jax_profiler_server = os.getenv("USE_JAX_PROFILER_SERVER", False)
         # Only one instance of profiler is allowed
         if use_jax_profiler_server and self.rank < 1:
@@ -105,19 +122,50 @@ def __init__(self,
             )
             jax.profiler.start_server(jax_profiler_server_port)

+        self.step_counter = 0
+
     def initialize_cache(self, num_gpu_blocks: int,
                          num_cpu_blocks: int) -> None:
         self.cache_config.num_gpu_blocks = num_gpu_blocks
         self.cache_config.num_cpu_blocks = num_cpu_blocks

     def init_device(self):
+        # Set TPU visible devices for the JAX runtime in single-host PP.
+        multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
+        if multihost_backend != "ray" and self.parallel_config.pipeline_parallel_size > 1:
+            # Note: below is the setting for a v6e-8 host (8 chips of v6e).
+            # There are 2 ways of subslicing a v6e:
+            # 1) 2 subslices with 4 TPU chips each: we can do PP=2, TP=1/2/3/4.
+            # 2) 1 chip per subslice, with at most 8 subslices:
+            #    we can do TP=1, PP=1/2/3/4/5/6/7/8.
+            # Replace with your own topology.
+
+            tpu_ports = [
+                jax_parallel_state.BASE_JAX_PORT + i
+                for i in range(self.pp_world_size)
+            ]
+            os.environ["TPU_PROCESS_ADDRESSES"] = ",".join(
+                [f"localhost:{port}" for port in tpu_ports])
+            os.environ["TPU_PROCESS_PORT"] = f"{tpu_ports[self.rank]}"
+            os.environ["CLOUD_TPU_TASK_ID"] = f"{self.rank}"
+
+            # First way of subslicing.
+            # os.environ["TPU_PROCESS_BOUNDS"] = "1,1,1"
+            # os.environ["TPU_CHIPS_PER_PROCESS_BOUNDS"] = "1,4,1"
+            # os.environ["TPU_VISIBLE_CHIPS"] = "0,1,2,3" if self.rank == 0 else "4,5,6,7"
+
+            # Second way of subslicing.
+            os.environ["TPU_PROCESS_BOUNDS"] = f"1,{self.pp_world_size},1"
+            os.environ["TPU_CHIPS_PER_PROCESS_BOUNDS"] = "1,1,1"
+            os.environ["TPU_VISIBLE_CHIPS"] = f"{self.rank}"
+
         if not self.devices:
             sharding_config: ShardingConfigManager = self.vllm_config.sharding_config
             device_indexes = sharding_config.device_indexes
             if device_indexes is not None and len(device_indexes) > 0:
                 # Enforcing the devices sequence to be consistent with the specified device indexes
-                self.devices = [jax.devices()[i] for i in device_indexes]
-                all_devices = jax.devices()
+                self.devices = [jax.local_devices()[i] for i in device_indexes]
+                all_devices = jax.local_devices()
                 device_dict = {device.id: device for device in all_devices}
                 self.devices = []
                 for device_index in device_indexes:
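
As a concrete illustration of the second subslicing scheme, here is a small standalone sketch of the environment each worker process ends up with on a v6e-8 host. The port base is an assumption (the real value comes from jax_parallel_state.BASE_JAX_PORT); only the shape of the settings matters:

BASE_JAX_PORT = 8476  # assumed stand-in for jax_parallel_state.BASE_JAX_PORT

def single_host_pp_env(rank: int, pp_world_size: int) -> dict[str, str]:
    # One chip per PP stage (second scheme): TP=1, PP up to 8 on a v6e-8 host.
    ports = [BASE_JAX_PORT + i for i in range(pp_world_size)]
    return {
        "TPU_PROCESS_ADDRESSES": ",".join(f"localhost:{p}" for p in ports),
        "TPU_PROCESS_PORT": str(ports[rank]),
        "CLOUD_TPU_TASK_ID": str(rank),
        "TPU_PROCESS_BOUNDS": f"1,{pp_world_size},1",
        "TPU_CHIPS_PER_PROCESS_BOUNDS": "1,1,1",
        "TPU_VISIBLE_CHIPS": str(rank),
    }

# With PP=2, rank 0 sees only chip 0 and rank 1 sees only chip 1.
for r in range(2):
    print(r, single_host_pp_env(r, pp_world_size=2))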
@@ -128,10 +176,13 @@ def init_device(self):
                             f"jax.devices() with IDs {list(device_dict.keys())}!"
                         )
                     self.devices.append(device)
+                assert len(self.devices) >= sharding_config.total_devices
                 self.devices = self.devices[:sharding_config.total_devices]
             else:
-                self.devices = jax.devices()[:sharding_config.total_devices]
-
+                assert jax.local_device_count(
+                ) >= sharding_config.total_devices
+                self.devices = jax.local_devices()[:sharding_config.
+                                                   total_devices]
         # Initialize the vLLM distribution layer as a single chip environment,
         # we'll swap the model's parallel modules with TPU SPMD equivalents.
         with set_current_vllm_config(self.vllm_config):
@@ -147,15 +198,30 @@ def init_device(self):
                 tensor_model_parallel_size=1,
                 pipeline_model_parallel_size=1,
             )
+
+            jax_parallel_state.init_pp_distributed_environment(
+                self.ip,
+                self.rank,
+                self.parallel_config.pipeline_parallel_size,
+                self.devices[0],
+                need_pp=self.parallel_config.pipeline_parallel_size > 1)
+
         ensure_kv_transfer_initialized(self.vllm_config)
-        self.model_runner = TPUModelRunner(self.vllm_config, self.devices)
+        self.model_runner = TPUModelRunner(self.vllm_config, self.devices,
+                                           self.rank, self.rank == 0,
+                                           self.rank == self.pp_world_size - 1)
         logger.info(f"Init worker | "
                     f"rank={self.rank} | "
                     f"node_id={get_node_id()} | "
                     f"is_driver_worker={self.is_driver_worker} | "
                     f"hbm={utils.hbm_usage_gb(self.devices)}GiB")
         vllm_utils.report_usage_stats(self.vllm_config)

+    def initialize_pp_transfer_connect(self):
+        if self.rank == 0:
+            return
+        jax_parallel_state.connect(self.prev_worker_ip, self.rank - 1)
+
     def determine_available_memory(self) -> int:
         gpu_memory_utilization = self.cache_config.gpu_memory_utilization
         hbm_usage = utils.hbm_usage_bytes(self.devices)
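
To summarize the PP-specific wiring added here: each rank now passes first-stage and last-stage flags to TPUModelRunner (self.rank == 0 and self.rank == self.pp_world_size - 1), and initialize_pp_transfer_connect() dials the predecessor only on ranks above 0. A toy, self-contained summary of that per-rank behavior (the dictionary keys below are illustrative, not API names):

def pp_init_summary(pp_world_size: int) -> list[dict]:
    # For each pipeline rank, record its runner flags and whether
    # initialize_pp_transfer_connect() makes an outgoing connection.
    rows = []
    for rank in range(pp_world_size):
        rows.append({
            "rank": rank,
            "is_first_rank": rank == 0,
            "is_last_rank": rank == pp_world_size - 1,
            "connects_to_prev": rank > 0,  # rank 0 returns early
        })
    return rows

for row in pp_init_summary(4):
    print(row)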
@@ -195,14 +261,39 @@ def execute_model(
         # deliberate, temporary compromise for the same reasons outlined in
         # the `get_kv_cache_spec` method.

-        output = self.model_runner.execute_model(scheduler_output)
-
-        # With a connector, the scheduler expects output from all workers
-        # TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866
-        if has_kv_transfer_group():
-            return output
-
-        return output if self.is_driver_worker else None
+        if self.parallel_config.pipeline_parallel_size == 1 or self.rank == 0:
+            intermediate_tensors = None
+        else:
+            # Receive intermediate tensors from the previous PP rank.
+            uuid = self.model_runner.get_uuid_for_jax_transfer(
+                scheduler_output, self.rank - 1, self.step_counter)
+            # TODO: this method might only work for vLLM models; not sure about JAX models.
+            tensor_spec = self.model_runner.get_intermediate_tensor_spec(
+                scheduler_output.total_num_scheduled_tokens)
+            intermediate_tensors_dict = get_pp_group().recv_tensor_dict(
+                uuid, tensor_spec)
+            intermediate_tensors = JaxIntermediateTensors(
+                intermediate_tensors_dict)
+
+        output = self.model_runner.execute_model(scheduler_output,
+                                                 intermediate_tensors)
+
+        if isinstance(output, JaxIntermediateTensors):
+            assert self.parallel_config.pipeline_parallel_size > 1
+            assert not get_pp_group().is_last_rank
+            # Send intermediate tensors to the next PP rank.
+            uuid = self.model_runner.get_uuid_for_jax_transfer(
+                scheduler_output, self.rank, self.step_counter)
+            get_pp_group().send_tensor_dict(uuid, output.tensors)
+            self.step_counter += 1
+            return None
+        else:
+            self.step_counter += 1
+            # With a connector, the scheduler expects output from all workers
+            # TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866
+            if has_kv_transfer_group():
+                return output
+            return output if self.is_driver_worker else None

     def sample_tokens(self,
                       grammar_output: GrammarOutput) -> ModelRunnerOutput:
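
Putting the execute_model change together, one step under PP behaves roughly as follows. The sketch below is a self-contained summary of which ranks receive, send, or return scheduler-visible output; the real transfers go through get_pp_group().recv_tensor_dict / send_tensor_dict keyed by a per-step uuid, which this toy code does not reproduce:

def pp_step_roles(pp_world_size: int) -> dict[int, str]:
    # For each pipeline rank, describe one execute_model() step.
    roles = {}
    for rank in range(pp_world_size):
        recv = rank > 0                      # non-first ranks receive activations
        send = rank < pp_world_size - 1      # non-last ranks forward activations
        steps = (["recv intermediate tensors"] if recv else []) + ["run model"]
        if send:
            steps += ["send intermediate tensors", "return None to scheduler"]
        else:
            steps += ["return ModelRunnerOutput"]
        roles[rank] = " -> ".join(steps)
    return roles

for rank, role in pp_step_roles(4).items():
    print(f"rank {rank}: {role}")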
