
Commit 114fd74

utils functions for PP
Signed-off-by: Chenyaaang <[email protected]>
1 parent 2392503 commit 114fd74

File tree

2 files changed (+146, -0 lines)

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
from typing import Any, Optional

import jax
from jax.experimental import transfer

BASE_JAX_PORT = 5000

# Process-global handle to the pipeline-parallel group coordinator,
# populated by init_pp_distributed_environment().
_PP: Optional["GroupCoordinator"] = None


class GroupCoordinator:
    """
    JAX ProcessGroup wrapper for a group of pipeline-parallel processes.

    This is a simplified version that aligns its APIs with PyTorch's
    GroupCoordinator in vllm/distributed/parallel_state.py.

    GroupCoordinator takes charge of the communication operations among
    the processes in the group. Currently the communication is
    sending/receiving the intermediate tensors (tensor_dict) between
    consecutive PP processes.
    """
    rank_in_group: int
    world_size: int
    transfer_server: Optional[Any]
    connection: Optional[Any]

    def __init__(self, rank_in_group: int, world_size: int):
        self.rank_in_group = rank_in_group
        self.world_size = world_size
        self.transfer_server = None
        self.connection = None

    def send_tensor_dict(self, uuid: int, tensor_dict: dict[str, jax.Array]):
        # Register the tensors with the local transfer server; the next PP
        # rank pulls them by the same uuid.
        self.transfer_server.await_pull(uuid, tensor_dict)

    def recv_tensor_dict(self, uuid: int,
                         tensor_spec: dict[str, jax.ShapeDtypeStruct]):
        # Pull the tensors registered under uuid from the previous PP rank.
        return self.connection.pull(uuid, tensor_spec)

    @property
    def is_first_rank(self):
        return self.rank_in_group == 0

    @property
    def is_last_rank(self):
        return self.rank_in_group == self.world_size - 1


def init_pp_distributed_environment(ip: str, rank: int, world_size: int,
                                    device: Any, need_pp: bool):
    global _PP
    _PP = GroupCoordinator(rank, world_size)
    if need_pp:
        port_number = BASE_JAX_PORT + rank
        server_address = f"{ip}:{port_number}"
        transfer_server = transfer.start_transfer_server(
            device.client, server_address, [f"{ip}:0", f"{ip}:0"])
        _PP.transfer_server = transfer_server


def connect(prev_ip: str, prev_rank: int):
    prev_port_number = BASE_JAX_PORT + prev_rank
    connection = _PP.transfer_server.connect(f'{prev_ip}:{prev_port_number}')
    _PP.connection = connection


def get_pp_group() -> GroupCoordinator:
    assert _PP is not None, (
        "pipeline model parallel group is not initialized")
    return _PP
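
A minimal sketch of how these helpers could be wired together across two pipeline stages. The commit does not show the new file's name, so the module is passed in as a parameter rather than imported; the IP addresses, tensor shapes, and uuid are illustrative placeholders, not values from the commit.

import jax
import jax.numpy as jnp


def run_first_stage(pp, ip: str):
    # pp is the module defined above (import path not shown in the commit).
    # Rank 0 starts a transfer server and registers the tensor dict so the
    # next stage can pull it by uuid.
    pp.init_pp_distributed_environment(ip, rank=0, world_size=2,
                                       device=jax.devices()[0], need_pp=True)
    hidden = {"hidden_states": jnp.zeros((8, 4096), dtype=jnp.bfloat16)}
    pp.get_pp_group().send_tensor_dict(uuid=0, tensor_dict=hidden)


def run_second_stage(pp, ip: str, prev_ip: str):
    # Rank 1 starts its own transfer server, dials rank 0, and pulls the
    # tensors described by the ShapeDtypeStruct spec under the same uuid.
    pp.init_pp_distributed_environment(ip, rank=1, world_size=2,
                                       device=jax.devices()[0], need_pp=True)
    pp.connect(prev_ip=prev_ip, prev_rank=0)
    spec = {"hidden_states": jax.ShapeDtypeStruct((8, 4096), jnp.bfloat16)}
    return pp.get_pp_group().recv_tensor_dict(uuid=0, tensor_spec=spec)

As the method names suggest, send_tensor_dict only registers the dict for pulling; the actual transfer happens when the downstream rank calls recv_tensor_dict.
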
Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, Union

import jax
from jax.tree_util import register_pytree_node_class
from torchax.interop import jax_view, torch_view
from vllm.sequence import IntermediateTensors

if TYPE_CHECKING:
    from vllm.v1.worker.kv_connector_model_runner_mixin import \
        KVConnectorOutput
else:
    KVConnectorOutput = Any


@register_pytree_node_class
@dataclass
class JaxIntermediateTensors:
    """For all pipeline stages except the last, we need to return the
    intermediate tensors (the hidden states and residuals) to be sent to
    the next stage. This data structure holds those intermediate tensors
    for a request.

    vLLM has an equivalent PyTorch IntermediateTensors class (in
    vllm/sequence.py) for the same purpose.

    Each stage also needs to handle its own kv_connector_output.

    This class also provides from_torch and to_torch helpers to convert
    between PyTorch's and JAX's intermediate tensors on the torchax path.
    """

    tensors: Dict[str, Any]
    kv_connector_output: KVConnectorOutput = None

    def tree_flatten(self):
        # Only the tensors are treated as pytree children; kv_connector_output
        # travels as static aux data.
        children = (self.tensors, )
        aux_data = self.kv_connector_output
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], aux_data)

    @classmethod
    def from_torch(cls, torch_obj: IntermediateTensors):
        kv_connector_output = getattr(torch_obj, 'kv_connector_output', None)
        jax_tensors = {k: jax_view(v) for k, v in torch_obj.tensors.items()}
        return cls(jax_tensors, kv_connector_output)

    def to_torch(self) -> IntermediateTensors:
        torch_tensors = {k: torch_view(v) for k, v in self.tensors.items()}
        return IntermediateTensors(torch_tensors)

    def __getitem__(self, key: Union[str, slice]):
        if isinstance(key, str):
            return self.tensors[key]
        elif isinstance(key, slice):
            return self.__class__({k: v[key] for k, v in self.tensors.items()})

    def __setitem__(self, key: str, value: Any):
        self.tensors[key] = value

    def keys(self):
        return self.tensors.keys()

    def items(self):
        return self.tensors.items()

    def __len__(self):
        return len(self.tensors)

    def block_until_ready(self):
        for tensor in self.tensors.values():
            assert isinstance(
                tensor, jax.Array
            ), "block_until_ready needs to be applied on jax arrays"
            tensor.block_until_ready()
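
A short usage sketch for the class above, assuming it has been imported from the new file (the commit does not show the module path). It builds the container from plain JAX arrays and passes it through jax.jit, which works because the pytree registration flattens the object into its tensors and carries kv_connector_output as auxiliary data; shapes and dtypes are placeholders.

import jax
import jax.numpy as jnp

it = JaxIntermediateTensors(tensors={
    "hidden_states": jnp.ones((8, 4096), dtype=jnp.bfloat16),
    "residual": jnp.zeros((8, 4096), dtype=jnp.bfloat16),
})


@jax.jit
def scale(x: JaxIntermediateTensors) -> JaxIntermediateTensors:
    # The object is flattened on entry and rebuilt on exit via
    # tree_flatten / tree_unflatten.
    return JaxIntermediateTensors({k: v * 2 for k, v in x.items()})


out = scale(it)
out.block_until_ready()
assert out["hidden_states"].shape == (8, 4096)

On the torchax path, from_torch and to_torch play the same role at the model-runner boundary, reusing the underlying buffers via jax_view and torch_view rather than copying them.
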
