133 changes: 112 additions & 21 deletions tpu_inference/worker/tpu_worker_jax.py
@@ -10,6 +10,7 @@
import jaxtyping
import vllm.envs as envs
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed import get_pp_group
from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
has_kv_transfer_group)
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
@@ -23,10 +24,13 @@
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput

from tpu_inference import utils
from tpu_inference.distributed import jax_parallel_state
from tpu_inference.distributed.utils import (get_host_ip, get_kv_transfer_port,
get_node_id)
from tpu_inference.layers.jax.sharding import ShardingConfigManager
from tpu_inference.logger import init_logger
from tpu_inference.models.jax.jax_intermediate_tensor import \
JaxIntermediateTensors
from tpu_inference.runner.kv_cache import get_rpa_page_size_bytes
from tpu_inference.runner.tpu_jax_runner import TPUModelRunner

@@ -41,13 +45,17 @@

class TPUWorker:

def __init__(self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False,
devices=None):
def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False,
devices=None,
ip: str = "localhost",
prev_worker_ip: str = "localhost",
):
# If we use vLLM's model implementation in PyTorch, we should set it
# with the torch version of the dtype.
impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
@@ -74,6 +82,9 @@ def __init__(self,
self.devices = devices if devices is not None else []
self.device_ranks = set(device.id for device in self.devices
if isinstance(device, jaxlib._jax.Device))
self.ip = ip
self.prev_worker_ip = prev_worker_ip
self.pp_world_size = self.parallel_config.pipeline_parallel_size

if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
@@ -86,14 +97,20 @@ def __init__(self,
# TPU Worker is initialized. The profiler server needs to start after
# MP runtime is initialized.
self.profile_dir = None
if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1 and self.pp_world_size == 1:
if not self.devices or 0 in self.device_ranks:
# For TPU, we can only have 1 active profiler session per profiler
# server, so we only profile on rank 0.
self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR
logger.info("Profiling enabled. Traces will be saved to: %s",
self.profile_dir)

# For PP, we use MPMD so we want to profile every worker.
if self.pp_world_size > 1 and envs.VLLM_TORCH_PROFILER_DIR:
self.profile_dir = os.path.join(envs.VLLM_TORCH_PROFILER_DIR,
f"rank_{self.rank}")
os.makedirs(self.profile_dir, exist_ok=True)

use_jax_profiler_server = os.getenv("USE_JAX_PROFILER_SERVER", False)
# Only one instance of profiler is allowed
if use_jax_profiler_server and self.rank < 1:
@@ -105,19 +122,50 @@ def __init__(self,
)
jax.profiler.start_server(jax_profiler_server_port)

self.step_counter = 0

def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks

def init_device(self):
# Set TPU visible devices for the JAX runtime in single-host PP.
multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
if multihost_backend != "ray" and self.parallel_config.pipeline_parallel_size > 1:
# Note: below is the setting for a v6e-8 host (8 v6e chips).
# There are 2 ways of subslicing a v6e-8:
# 1) 2 subslices with 4 TPU chips each: we can do PP=2, TP=1/2/3/4.
# 2) 1 chip per subslice, with at most 8 subslices:
#    we can do TP=1, PP=1/2/3/4/5/6/7/8.
# Replace with your own topology.

tpu_ports = [
jax_parallel_state.BASE_JAX_PORT + i
for i in range(self.pp_world_size)
]
os.environ["TPU_PROCESS_ADDRESSES"] = ",".join(
[f"localhost:{port}" for port in tpu_ports])
os.environ["TPU_PROCESS_PORT"] = f"{tpu_ports[self.rank]}"
os.environ["CLOUD_TPU_TASK_ID"] = f"{self.rank}"

# first way of subslicing.
# os.environ["TPU_PROCESS_BOUNDS"] = "1,1,1"
# os.environ["TPU_CHIPS_PER_PROCESS_BOUNDS"] = f"1,4,1"
# os.environ["TPU_VISIBLE_CHIPS"] = "0,1,2,3" if self.rank == 0 else "4,5,6,7"

# second way of subslicing.
os.environ["TPU_PROCESS_BOUNDS"] = f"1,{self.pp_world_size},1"
os.environ["TPU_CHIPS_PER_PROCESS_BOUNDS"] = "1,1,1"
os.environ["TPU_VISIBLE_CHIPS"] = f"{self.rank}"

if not self.devices:
sharding_config: ShardingConfigManager = self.vllm_config.sharding_config
device_indexes = sharding_config.device_indexes
if device_indexes is not None and len(device_indexes) > 0:
# Enforcing the devices sequence to be consistent with the specified device indexes
self.devices = [jax.devices()[i] for i in device_indexes]
all_devices = jax.devices()
self.devices = [jax.local_devices()[i] for i in device_indexes]
all_devices = jax.local_devices()
device_dict = {device.id: device for device in all_devices}
self.devices = []
for device_index in device_indexes:
@@ -128,10 +176,13 @@ def init_device(self):
f"jax.devices() with IDs {list(device_dict.keys())}!"
)
self.devices.append(device)
assert len(self.devices) >= sharding_config.total_devices
self.devices = self.devices[:sharding_config.total_devices]
else:
self.devices = jax.devices()[:sharding_config.total_devices]

assert jax.local_device_count(
) >= sharding_config.total_devices
self.devices = jax.local_devices()[:sharding_config.
total_devices]
# Initialize the vLLM distribution layer as a single-chip environment;
# we'll swap the model's parallel modules with TPU SPMD equivalents.
with set_current_vllm_config(self.vllm_config):
@@ -147,15 +198,30 @@ def init_device(self):
tensor_model_parallel_size=1,
pipeline_model_parallel_size=1,
)

jax_parallel_state.init_pp_distributed_environment(
self.ip,
self.rank,
self.parallel_config.pipeline_parallel_size,
self.devices[0],
need_pp=self.parallel_config.pipeline_parallel_size > 1)

ensure_kv_transfer_initialized(self.vllm_config)
self.model_runner = TPUModelRunner(self.vllm_config, self.devices)
self.model_runner = TPUModelRunner(self.vllm_config, self.devices,
self.rank, self.rank == 0,
self.rank == self.pp_world_size - 1)
logger.info(f"Init worker | "
f"rank={self.rank} | "
f"node_id={get_node_id()} | "
f"is_driver_worker={self.is_driver_worker} | "
f"hbm={utils.hbm_usage_gb(self.devices)}GiB")
vllm_utils.report_usage_stats(self.vllm_config)

def initialize_pp_transfer_connect(self):
if self.rank == 0:
return
jax_parallel_state.connect(self.prev_worker_ip, self.rank - 1)
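For orientation, the connection pattern established by `init_pp_distributed_environment` plus `initialize_pp_transfer_connect` is a simple chain: every worker is brought up first, then each rank r > 0 dials rank r - 1 at `prev_worker_ip`. A small sketch of that wiring, assuming a hypothetical `worker_ips` list supplied by the launcher:

```python
def pp_connection_plan(worker_ips: list[str]) -> list[tuple[int, int, str]]:
    """Illustrative only: (src_rank, dst_rank, dst_ip) pairs that
    initialize_pp_transfer_connect would establish. Rank 0 dials nobody."""
    return [(rank, rank - 1, worker_ips[rank - 1])
            for rank in range(1, len(worker_ips))]

# Example: 4 single-host PP workers.
print(pp_connection_plan(["localhost"] * 4))
# -> [(1, 0, 'localhost'), (2, 1, 'localhost'), (3, 2, 'localhost')]
```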

def determine_available_memory(self) -> int:
gpu_memory_utilization = self.cache_config.gpu_memory_utilization
hbm_usage = utils.hbm_usage_bytes(self.devices)
@@ -195,14 +261,39 @@ def execute_model(
# deliberate, temporary compromise for the same reasons outlined in
# the `get_kv_cache_spec` method.

output = self.model_runner.execute_model(scheduler_output)

# With a connector, the scheduler expects output from all workers
# TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866
if has_kv_transfer_group():
return output

return output if self.is_driver_worker else None
if self.parallel_config.pipeline_parallel_size == 1 or self.rank == 0:
intermediate_tensors = None
else:
# Receive intermediate tensors from the previous PP rank.
uuid = self.model_runner.get_uuid_for_jax_transfer(
scheduler_output, self.rank - 1, self.step_counter)
# TODO: this method might only work for vLLM models; not sure about JAX models.
tensor_spec = self.model_runner.get_intermediate_tensor_spec(
scheduler_output.total_num_scheduled_tokens)
intermediate_tensors_dict = get_pp_group().recv_tensor_dict(
uuid, tensor_spec)
intermediate_tensors = JaxIntermediateTensors(
intermediate_tensors_dict)

output = self.model_runner.execute_model(scheduler_output,
intermediate_tensors)

if isinstance(output, JaxIntermediateTensors):
assert self.parallel_config.pipeline_parallel_size > 1
assert not get_pp_group().is_last_rank
# Send intermediate tensors to the next PP rank.
uuid = self.model_runner.get_uuid_for_jax_transfer(
scheduler_output, self.rank, self.step_counter)
get_pp_group().send_tensor_dict(uuid, output.tensors)
self.step_counter += 1
return None
else:
self.step_counter += 1
# With a connector, the scheduler expects output from all workers
# TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866
if has_kv_transfer_group():
return output
return output if self.is_driver_worker else None
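To summarize the pipeline-parallel control flow above: the first rank computes and sends, middle ranks receive, compute, and send, and only the last rank returns runner output to the scheduler. Here is a minimal, self-contained sketch of that dataflow, keyed by a per-step transfer id; the in-memory channel and helper names are toy stand-ins, not the repo's `get_pp_group()` transfer API.

```python
from typing import Any, Callable

# Toy in-memory channel standing in for the PP group's tensor-dict transfer.
_channel: dict[tuple[str, int], Any] = {}

def transfer_uuid(step: int, src_rank: int) -> str:
    # Hypothetical stand-in for get_uuid_for_jax_transfer: one id per
    # (step, sending rank) so the receiver matches the right payload.
    return f"step{step}-from-rank{src_rank}"

def pp_step(rank: int, pp_world_size: int, step: int,
            run_model: Callable[[Any], Any]) -> Any:
    """One execute_model step on one PP rank (ranks run sequentially here)."""
    if rank == 0:
        intermediate = None
    else:
        # Non-first ranks first receive activations from rank - 1.
        intermediate = _channel.pop((transfer_uuid(step, rank - 1), rank))
    output = run_model(intermediate)
    if rank < pp_world_size - 1:
        # Non-last ranks forward activations to rank + 1 and return nothing
        # to the scheduler, mirroring the `return None` branch above.
        _channel[(transfer_uuid(step, rank), rank + 1)] = output
        return None
    return output  # Only the last rank returns a ModelRunnerOutput-like result.

# Example: 3 ranks, where run_model just wraps whatever it received.
outs = [pp_step(r, 3, step=0, run_model=lambda t, r=r: {"rank": r, "prev": t})
        for r in range(3)]
assert outs[:2] == [None, None] and outs[2]["rank"] == 2
```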

def sample_tokens(self,
grammar_output: GrammarOutput) -> ModelRunnerOutput: