Commit 74f7e29

Fix Torchax backend on Pathways (#1052)

Signed-off-by: Richard Liu <[email protected]>
1 parent 5c04ffa commit 74f7e29

3 files changed: +13 / -9 lines

tpu_inference/kernels/ragged_paged_attention/v3/util.py

Lines changed: 1 addition & 1 deletion

@@ -43,7 +43,7 @@ def get_tpu_version() -> int:
         return -1
     if kind.endswith(' lite'):
         kind = kind[:-len(' lite')]
-    if kind.endswith('p'):
+    if kind.endswith('p') or kind.endswith('e'):
         kind = kind[:-1]
     if kind == 'TPU7x':
         return 7
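
This one-line change makes the version parser strip a trailing 'e' as well as 'p', presumably because Pathways can report device kinds like 'TPU v5e' rather than the 'TPU v5 lite' spelling the old code handled. A minimal standalone sketch of the patched parsing, assuming the function ends by reading the trailing version digit as the hunk context suggests (the kind strings below are illustrative, not an exhaustive list of what JAX reports):

    # Sketch of the patched suffix handling, not the exact upstream function;
    # the final return is assumed from the surrounding hunk.
    def parse_tpu_version(kind: str) -> int:
        if 'TPU' not in kind:
            return -1
        if kind.endswith(' lite'):                  # e.g. 'TPU v5 lite'
            kind = kind[:-len(' lite')]
        if kind.endswith('p') or kind.endswith('e'):
            kind = kind[:-1]                        # 'TPU v5p'/'TPU v5e' -> 'TPU v5'
        if kind == 'TPU7x':
            return 7
        return int(kind[-1])                        # 'TPU v5' -> 5

    assert parse_tpu_version('TPU v5p') == 5
    assert parse_tpu_version('TPU v5e') == 5        # failed before this commit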

tpu_inference/layers/vllm/sharding.py

Lines changed: 11 additions & 7 deletions

@@ -1,11 +1,11 @@
 import jax
+import jax.numpy as jnp
 import torch
 import torchax
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
 from torch.nn import Parameter
 from torch.utils import _pytree as pytree
-from torchax.interop import jax_view, torch_view
-from torchax.ops.mappings import t2j
+from torchax.interop import torch_view
 from vllm.lora.layers import (MergedColumnParallelLinearWithLoRA,
                               MergedQKVParallelLinearWithLoRA,
                               RowParallelLinearWithLoRA)
@@ -19,6 +19,12 @@
 
 logger = init_logger(__name__)
 
+TORCH_TO_JAX_DTYPE_MAP = {
+    torch.float32: jnp.float32,
+    torch.float16: jnp.float16,
+    torch.bfloat16: jnp.bfloat16,
+}
+
 
 def shard_model_to_tpu(model: torch.nn.Module,
                        mesh: Mesh) -> dict[str, torchax.torch.Tensor]:
@@ -75,11 +81,9 @@ def _tensor_is_in_cpu(tensor: torch.tensor) -> bool:
 
 def _convert_to_torchax_and_shard(tensor: torch.Tensor,
                                   sharding: NamedSharding) -> torch.Tensor:
-    if isinstance(tensor, torchax.tensor.Tensor):
-        tensor = jax_view(tensor)
-    else:
-        tensor = t2j(tensor)
-    return torch_view(_sharded_device_put(tensor, sharding))
+    np_tensor = tensor.detach().cpu().to(torch.float32).numpy()
+    dtype = TORCH_TO_JAX_DTYPE_MAP.get(tensor.dtype, jnp.float32)
+    return torch_view(jax.device_put(np_tensor, sharding).astype(dtype))
 
 
 def _shard_tensor_to_tpu_replicated(tensor: torch.Tensor,
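
The rewritten _convert_to_torchax_and_shard routes every weight through host NumPy and a plain jax.device_put instead of torchax's t2j/jax_view plus the old _sharded_device_put helper, presumably because the torchax path does not work with Pathways' remote devices. Since NumPy has no native bfloat16, the tensor is first widened to float32 on the host and the target dtype is restored on-device via astype, using the new TORCH_TO_JAX_DTYPE_MAP. A minimal self-contained sketch of the same round trip (the sharding and tensor shape are illustrative, and torch_view is omitted so it runs without torchax):

    import jax
    import jax.numpy as jnp
    import numpy as np
    import torch
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    TORCH_TO_JAX_DTYPE_MAP = {
        torch.float32: jnp.float32,
        torch.float16: jnp.float16,
        torch.bfloat16: jnp.bfloat16,
    }

    # Illustrative 1-D mesh over whatever devices are available.
    mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
    sharding = NamedSharding(mesh, PartitionSpec())  # replicated, for simplicity

    t = torch.randn(8, 4, dtype=torch.bfloat16)
    # Widen to float32 on the host: NumPy cannot represent bfloat16.
    np_t = t.detach().cpu().to(torch.float32).numpy()
    # Place on devices under the sharding, then restore the original dtype.
    arr = jax.device_put(np_t, sharding).astype(
        TORCH_TO_JAX_DTYPE_MAP.get(t.dtype, jnp.float32))
    assert arr.dtype == jnp.bfloat16
    # In the real helper, `arr` is then wrapped with
    # torchax.interop.torch_view before being handed back to torch code.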

tpu_inference/models/vllm/vllm_model_wrapper.py

Lines changed: 1 addition & 1 deletion

@@ -110,7 +110,7 @@ def load_weights(self):
 
         # Load the vLLM model and wrap it into a new model whose forward
         # function can calculate the hidden_state and logits.
-        with load_context, jax.default_device(jax.devices('cpu')[0]):
+        with load_context:
             vllm_model = vllm_get_model(vllm_config=vllm_config_for_load)
         lora_manager = None
         if vllm_config_for_load.lora_config is not None:
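
The removed jax.default_device(...) pinned all JAX array creation to the host CPU while the vLLM weights were loaded; on Pathways the client may not expose a local CPU device through jax.devices('cpu'), so the pin is dropped and only load_context remains. For reference, a small sketch of what the removed context manager does, assuming a backend where a local CPU device exists:

    import jax
    import jax.numpy as jnp

    # Pin array creation to the host CPU for the duration of the block.
    # On a Pathways proxy backend, jax.devices('cpu') may fail instead.
    with jax.default_device(jax.devices('cpu')[0]):
        x = jnp.ones((4, 4))

    print(x.devices())  # e.g. {CpuDevice(id=0)}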
