@@ -1,11 +1,11 @@
 import jax
+import jax.numpy as jnp
 import torch
 import torchax
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
 from torch.nn import Parameter
 from torch.utils import _pytree as pytree
-from torchax.interop import jax_view, torch_view
-from torchax.ops.mappings import t2j
+from torchax.interop import torch_view
 from vllm.lora.layers import (MergedColumnParallelLinearWithLoRA,
                               MergedQKVParallelLinearWithLoRA,
                               RowParallelLinearWithLoRA)
@@ -19,6 +19,12 @@
 
 logger = init_logger(__name__)
 
+TORCH_TO_JAX_DTYPE_MAP = {
+    torch.float32: jnp.float32,
+    torch.float16: jnp.float16,
+    torch.bfloat16: jnp.bfloat16,
+}
+
 
 def shard_model_to_tpu(model: torch.nn.Module,
                        mesh: Mesh) -> dict[str, torchax.torch.Tensor]:
@@ -75,11 +81,9 @@ def _tensor_is_in_cpu(tensor: torch.tensor) -> bool:
 
 def _convert_to_torchax_and_shard(tensor: torch.Tensor,
                                   sharding: NamedSharding) -> torch.Tensor:
-    if isinstance(tensor, torchax.tensor.Tensor):
-        tensor = jax_view(tensor)
-    else:
-        tensor = t2j(tensor)
-    return torch_view(_sharded_device_put(tensor, sharding))
+    np_tensor = tensor.detach().cpu().to(torch.float32).numpy()
+    dtype = TORCH_TO_JAX_DTYPE_MAP.get(tensor.dtype, jnp.float32)
+    return torch_view(jax.device_put(np_tensor, sharding).astype(dtype))
 
 
 def _shard_tensor_to_tpu_replicated(tensor: torch.Tensor,
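For readers following along, here is a minimal, self-contained sketch of the conversion path this commit introduces. The helper name `convert_and_shard`, the 1-D demo mesh, and the replicated `PartitionSpec()` are illustrative assumptions, not part of the commit; the staging through a float32 NumPy array, the dtype map lookup, `jax.device_put`, and `torch_view` mirror the diff above.

```python
# Sketch (not part of the commit): how the new conversion path behaves.
import jax
import jax.numpy as jnp
import numpy as np
import torch
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from torchax.interop import torch_view

# Mirrors the TORCH_TO_JAX_DTYPE_MAP added in this commit.
TORCH_TO_JAX_DTYPE_MAP = {
    torch.float32: jnp.float32,
    torch.float16: jnp.float16,
    torch.bfloat16: jnp.bfloat16,
}


def convert_and_shard(tensor: torch.Tensor, sharding: NamedSharding):
    # NumPy has no bfloat16, so stage the data through float32 on the host ...
    np_tensor = tensor.detach().cpu().to(torch.float32).numpy()
    # ... then place it under the requested sharding and cast back to the
    # JAX dtype that corresponds to the original torch dtype.
    dtype = TORCH_TO_JAX_DTYPE_MAP.get(tensor.dtype, jnp.float32)
    return torch_view(jax.device_put(np_tensor, sharding).astype(dtype))


if __name__ == "__main__":
    # Illustrative 1-D mesh over whatever devices are available (CPU or TPU);
    # an empty PartitionSpec replicates the array across the mesh.
    mesh = Mesh(np.array(jax.devices()), axis_names=("x",))
    sharding = NamedSharding(mesh, PartitionSpec())
    out = convert_and_shard(torch.randn(4, 8, dtype=torch.bfloat16), sharding)
    print(out.shape, out.dtype)  # torch-like view backed by a sharded jax.Array
```

The float32 staging replaces the removed `t2j` path, while the final on-device dtype is still derived from the original torch dtype via `TORCH_TO_JAX_DTYPE_MAP` (falling back to float32 for unmapped dtypes).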