1+ from dataclasses import dataclass
2+ from typing import TYPE_CHECKING , Any , Dict , Union
3+
4+ import jax
5+ from jax .tree_util import register_pytree_node_class
6+ from torchax .interop import jax_view , torch_view
7+ from vllm .sequence import IntermediateTensors
8+
9+ if TYPE_CHECKING :
10+ from vllm .v1 .worker .kv_connector_model_runner_mixin import \
11+ KVConnectorOutput
12+ else :
13+ KVConnectorOutput = Any
14+
15+
16+ @register_pytree_node_class
17+ @dataclass
18+ class JaxIntermediateTensors :
19+ """For all pipeline stages except the last, we need to return the
20+ intermediate tensor which is the hidden states (and residuals) to be
21+ sent to the next stage. This data structure contains the
22+ intermediate tensor for a request.
23+
24+ There is a PyTorch IntermediateTensors (in vllm/sequence.py) class in vllm
25+ for the same purpose.
26+
27+ Each stage also needs to handle its own kv_connector_output.
28+
29+ This class also contains the from_torch and to_torch functions, the goal is
30+ to convert between pytorch's intermediate tensor
31+ and Jax's intermediate tensor in torchax path.
32+ """
33+
34+ tensors : Dict [str , Any ]
35+ kv_connector_output : KVConnectorOutput = None
36+
37+ def tree_flatten (self ):
38+ children = (self .tensors , )
39+ aux_data = self .kv_connector_output
40+ return (children , aux_data )
41+
42+ @classmethod
43+ def tree_unflatten (cls , aux_data , children ):
44+ return cls (children [0 ], aux_data )
45+
46+ @classmethod
47+ def from_torch (cls , torch_obj : IntermediateTensors ):
48+ kv_connector_output = getattr (torch_obj , 'kv_connector_output' , None )
49+ jax_tensors = {k : jax_view (v ) for k , v in torch_obj .tensors .items ()}
50+ return cls (jax_tensors , kv_connector_output )
51+
52+ def to_torch (self ) -> IntermediateTensors :
53+ torch_tensors = {k : torch_view (v ) for k , v in self .tensors .items ()}
54+ return IntermediateTensors (torch_tensors )
55+
56+ def __getitem__ (self , key : Union [str , slice ]):
57+ if isinstance (key , str ):
58+ return self .tensors [key ]
59+ elif isinstance (key , slice ):
60+ return self .__class__ ({k : v [key ] for k , v in self .tensors .items ()})
61+
62+ def __setitem__ (self , key : str , value : Any ):
63+ self .tensors [key ] = value
64+
65+ def keys (self ):
66+ return self .tensors .keys ()
67+
68+ def items (self ):
69+ return self .tensors .items ()
70+
71+ def __len__ (self ):
72+ return len (self .tensors )
73+
74+ def block_until_ready (self ):
75+ for tensor in self .tensors .values ():
76+ assert isinstance (
77+ tensor , jax .Array
78+ ), "block_until_ready needs to be applied on jax arrays"
79+ tensor .block_until_ready ()
0 commit comments