
Add features needed for vllm #9092

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft · wants to merge 16 commits into base: master
30 changes: 30 additions & 0 deletions torchax/test/test_view.py
@@ -14,6 +14,36 @@ def setUp(self):
    torch.manual_seed(0)
    torchax.enable_globally()

  def test_index_copy_(self):
    x = torch.zeros((10, 10), device="jax")
    x_view = x[0, :]
    indices = torch.arange(5, device="jax")
    new_value = torch.ones((5,), device="jax")
    x_view.index_copy_(0, indices, new_value)
    self.assertEqual(type(x), Tensor)
    self.assertEqual(type(x_view), View)
    self.assertEqual(x.shape, (10, 10))
    self.assertEqual(x.sum(), 5)

  def test_flatten(self):
    x = torch.zeros((10, 10), device="jax")
    x1 = x.flatten(0, 1)
    y = torch.ones(100, device="jax")
    x1.copy_(y)
    self.assertEqual(type(x), Tensor)
    self.assertEqual(type(x1), View)
    self.assertEqual(x.shape, (10, 10))
    self.assertEqual(x.sum(), 100)

  def test_narrow(self):
    x = torch.zeros((10, 10), device="jax")
    x = x.narrow(0, 0, 5).narrow(0, 0, 5)
    y = torch.ones((5, 10), device="jax")
    x.copy_(y)
    self.assertEqual(type(x), View)
    self.assertEqual(x.shape, (5, 10))
    self.assertEqual(x.sum(), 50)

  def test_copy_(self):
    x = torch.zeros((10, 10), device="jax")
    y = torch.ones((5, 5), device="jax")
3 changes: 3 additions & 0 deletions torchax/torchax/__init__.py
@@ -127,3 +127,6 @@ def compile(fn, options: Optional[CompileOptions] = None):
    raise RuntimeError('dynamo mode is not supported yet')
  elif options.mode == 'export':
    raise RuntimeError('export mode is not supported yet')

# Intercept torch._sync as no-op
torch._sync = lambda *args, **kwargs: None
1 change: 1 addition & 0 deletions torchax/torchax/decompositions.py
@@ -766,4 +766,5 @@ def get_summand(ix: torch.Tensor, iy: torch.Tensor, iz: torch.Tensor,
MUTABLE_DECOMPOSITION = [
    torch.ops.aten.bernoulli_.Tensor,
    torch.ops.aten.bernoulli_.float,
    torch.ops.aten.index_copy_.default,
]
3 changes: 2 additions & 1 deletion torchax/torchax/interop.py
@@ -12,6 +12,7 @@
from torchax import tensor
from torchax import util
import torchax
from torchax.view import View

from torchax.types import JaxValue, TorchValue, JaxCallable, TorchCallable

@@ -179,7 +180,7 @@ def _jax_view(t: TorchValue) -> JaxValue:
  # t is an object from torch land
  # view it as-if it's a jax land object
  if isinstance(t, torch.Tensor):
    assert isinstance(t, tensor.Tensor), type(t)
    assert isinstance(t, tensor.Tensor) or isinstance(t, View), type(t)
    return t.jax()
  if isinstance(t, type(torch.int32)):
    return tensor.t2j_dtype(t)
23 changes: 13 additions & 10 deletions torchax/torchax/ops/jaten.py
@@ -15,7 +15,7 @@
from torchax.ops import op_base, mappings
from torchax import interop
from torchax.ops import jax_reimplement
from torchax.view import View
from torchax.view import View, NarrowInfo, ReshapeInfo
from torchax.tensor import Tensor
# Keys are OpOverload, value is a callable that takes
# Tensor
@@ -57,6 +57,7 @@
    torch.ops.aten.scatter_add_: torch.ops.aten.scatter_add,
    torch.ops.aten.scatter_reduce_.two: torch.ops.aten.scatter_reduce,
    torch.ops.aten.scatter_: torch.ops.aten.scatter,
    torch.ops.aten.index_put_: torch.ops.aten.index_put,
}

# Note: tuple comparisons work intuitively, e.g. `_jax_version >= (0, 4, 32)`.
@@ -102,13 +103,14 @@ def inner(func):


@op(
    torch.ops.aten.view_copy,
    torch.ops.aten.view,
    torch.ops.aten._unsafe_view,
    torch.ops.aten.reshape,
    torch.ops.aten.view_copy,
    torch.ops.aten.view,
    torch.ops.aten._unsafe_view,
    torch.ops.aten.reshape,
    is_jax_function=False,
)
def _aten_unsafe_view(x, shape):
  return jnp.reshape(x, shape)
  return View(x, ReshapeInfo(shape=shape), env=x._env)


@op(torch.ops.aten.add.Tensor)
@@ -131,6 +133,8 @@ def _aten_copy(x, y, memory_format=None):
  if isinstance(x, View):
    x.update(y)
    return x
  if isinstance(y, View):
    y = y.torch()

  if x.ndim == 1 and y.ndim == 0:
    # case of torch.empty((1,)).copy_(tensor(N))
@@ -402,8 +406,8 @@ def _aten_triu(m, k):
  return jnp.triu(m, k)


@op(torch.ops.aten.slice)
@op(torch.ops.aten.slice_copy)
@op(torch.ops.aten.slice, is_jax_function=False, is_view_op=True)
@op(torch.ops.aten.slice_copy, is_jax_function=False, is_view_op=True)
def _aten_slice(self, dim=0, start=None, end=None, step=1):
  if dim < 0:
    dim += self.ndim
@@ -416,7 +420,7 @@ def _aten_slice(self, dim=0, start=None, end=None, step=1):
      dims.append(sl)
    else:
      dims.append(slice(None, None, None))
  return self[tuple(dims)]
  return View(self, NarrowInfo(slices=tuple(dims)), env = self._env)


@op(torch.ops.aten.detach)
@@ -779,7 +783,6 @@ def _aten_empty_strided(sizes, stride, dtype=None, **kwargs):
  return jnp.empty(sizes, dtype=dtype)


@op(torch.ops.aten.index_put_)
@op(torch.ops.aten.index_put)
def _aten_index_put(self, indexes, values, accumulate=False):
  indexes = [slice(None, None, None) if i is None else i for i in indexes]
53 changes: 51 additions & 2 deletions torchax/torchax/ops/jtorch.py
@@ -269,15 +269,15 @@ def getitem(self, indexes):
  elif isinstance(indexes, list):
    indexes = tuple(indexes)

  def is_narrow_slicing():
  def is_view_slicing():
    tensor_free = not pytree.tree_any(
        lambda x: isinstance(x, torch.Tensor) or isinstance(x, jax.Array),
        indexes)
    list_free = not isinstance(indexes, tuple) or all(
        [False if isinstance(x, list) else True for x in indexes])
    return tensor_free and list_free

  if is_narrow_slicing():
  if is_view_slicing():
    return View(self, view_info=NarrowInfo(indexes), env=self._env)

  indexes = self._env.t2j_iso(indexes)
@@ -512,3 +512,52 @@ def functional_linear(self, weights, bias=None):
  if bias is not None:
    res += bias
  return res


try:
  # TODO: Currently the following ops are wrapped in the try/except
  # block because torch.ops.xla is not in the torch ops
  # registry. Either we import torch_xla at the top level,
  # or modify the register_function to support this.
  @register_function(torch.ops.xla.dynamo_set_buffer_donor_)
  def _dynamo_set_buffer_donor(self, donor):
    pass

  @register_function(torch.ops.xla.ragged_paged_attention)
  def _ragged_paged_attention(
      q: jax.Array,  # [max_num_batched_tokens, num_q_heads, head_dim]
      kv_pages: jax.Array,  # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
      kv_lens: jax.Array,  # i32[max_num_seqs]
      page_indices: jax.Array,  # i32[max_num_seqs, pages_per_seq]
      cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
      num_seqs: jax.Array,  # i32[1]
      use_kernel: bool = True,
      sm_scale: float = 1.0,
      sliding_window: int | None = None,
      soft_cap: float | None = None,
      mask_value: float | None = None,
      num_kv_pages_per_block: int | None = None,
      num_queries_per_block: int | None = None,
      vmem_limit_bytes: int | None = None,
  ):

    from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as ragged_paged_attention_kernel
    return ragged_paged_attention_kernel(
        q = q,
        kv_pages = kv_pages,
        kv_lens = kv_lens,
        page_indices = page_indices,
        cu_q_lens = cu_q_lens,
        num_seqs = num_seqs,
        sm_scale = sm_scale,
        sliding_window = sliding_window,
        soft_cap = soft_cap,
        mask_value = mask_value,
        num_kv_pages_per_block = num_kv_pages_per_block,
        num_queries_per_block = num_queries_per_block,
        vmem_limit_bytes = vmem_limit_bytes,
    )
except Exception as e:
  pass


10 changes: 1 addition & 9 deletions torchax/torchax/tensor.py
@@ -101,14 +101,6 @@ def shape(self):
  def ndim(self):
    return len(self._elem.shape)

  def flatten(self, start_dim=0, end_dim=-1):
    if end_dim == -1:
      end_dim = self.ndim
    new_shape = (
        self._elem.shape[:start_dim] + (-1,) + self._elem.shape[end_dim + 1:])
    new_elem = jnp.reshape(self._elem, new_shape)
    return Tensor(new_elem, self._env)
    # return torch.reshape(self, new_shape)

  def __setitem__(self, key, val):
    key, val = self._env.t2j_iso((key, val))
@@ -381,7 +373,7 @@ def load_ops(self):
)

  def _to_copy(self, the_tensor, new_dtype, new_device):
    if isinstance(the_tensor, Tensor):
    if isinstance(the_tensor, Tensor) or isinstance(the_tensor, View):
      arr = the_tensor.jax()
      if new_dtype is not None and new_dtype != arr.dtype:
        arr = arr.astype(mappings.t2j_dtype(new_dtype))
54 changes: 54 additions & 0 deletions torchax/torchax/view.py
@@ -4,6 +4,7 @@
from enum import Enum
from typing import Union, List, Tuple, Optional, Any, cast
from abc import ABC, abstractmethod
import torch.utils._pytree as pytree

# Reference to original PyTorch native functions
# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
@@ -114,6 +115,35 @@ def update_tensor(self, new_value: jax.Array,
  def calculate_output_shape(self, source: jax.Array) -> List[int]:
    return source[self.slices].shape

class ReshapeInfo(ViewInfo):
  """
  Represents a reshape operation on a tensor.
  Handles operations like tensor.reshape(1, 2, 3) and tensor.reshape(-1, 1)
  """

  def __init__(self, shape: Tuple[int, ...]) -> None:
    """
    Args:
      shape: The shape to reshape the tensor to.
        E.g. jax_array.reshape(shape) will return the transformed tensor.
    """
    super().__init__(ViewInfoType.RESHAPE)
    self.shape = shape

  def __eq__(self, other: object) -> bool:
    if not isinstance(other, ReshapeInfo):
      return False
    return self.shape == other.shape

  def transform_tensor(self, jax_array: jax.Array) -> jax.Array:
    return jax_array.reshape(self.shape)

  def update_tensor(self, new_value: jax.Array, jax_array: jax.Array) -> jax.Array:
    return new_value.reshape(jax_array.shape)

  def calculate_output_shape(self, source: jax.Array) -> List[int]:
    return source.reshape(self.shape).shape


class SelectInfo(ViewInfo):
  """
@@ -318,6 +348,8 @@ def update(
    # applying inverse transformations in reverse order
    for view_info, parent_array in zip(
        reversed(view_infos), reversed(intermediate_values)):
      assert isinstance(new_values, jax.Array)
      assert isinstance(parent_array, jax.Array)
      # Apply the inverse transformation to propagate changes back
      new_values = view_info.update_tensor(new_values, parent_array)

@@ -353,6 +385,10 @@ def create_sub_view(self, view_info: ViewInfo) -> "View":
  def __str__(self) -> str:
    return f"View({self.torch()})"

  @property
  def _elem(self) -> jax.Array:
    return self.jax()

  def jax(self) -> jax.Array:
    """
    Returns a copy of the source tensor after transformations.
@@ -363,6 +399,8 @@ def jax(self) -> jax.Array:
    return result

  def __setitem__(self, indexes, val):
    # Handle tensor indexing
    indexes = pytree.tree_map(lambda x: x.jax() if isinstance(x, torch.Tensor) else x, indexes)
    view_infos = self.get_transformation_chain() + [NarrowInfo(indexes)]
    self.update(view_infos=view_infos, new_values=val)

@@ -381,4 +419,20 @@ def jax_device(self):
  def ndim(self):
    return len(self.shape)

  @property
  def data(self):
    return self

  __repr__ = __str__


# FAILED test/test_ops.py::TestOpInfoCPU::test_reference_eager_masked_std_cpu_float32 - NotImplementedError: Cannot copy out of meta tensor; no data!
# FAILED test/test_ops.py::TestOpInfoCPU::test_reference_eager_masked_var_cpu_float32 - NotImplementedError: Cannot copy out of meta tensor; no data!
# FAILED test/test_ops.py::TestOpInfoCPU::test_reference_eager_masked_std_cpu_int64 - NotImplementedError: Cannot copy out of meta tensor; no data!
# FAILED test/test_ops.py::TestOpInfoCPU::test_reference_eager_masked_var_cpu_int64 - NotImplementedError: Cannot copy out of meta tensor; no data!
# FAILED test/test_ops.py::TestOpInfoCPU::test_reference_eager_nn_functional_interpolate_bilinear_cpu_float32 - NotImplementedError: Cannot copy out of meta tensor; no data!
# FAILED test/test_ops.py::TestOpInfoCPU::test_reference_eager_nn_functional_interpolate_linear_cpu_float32 - NotImplementedError: Cannot copy out of meta tensor; no data!
# FAILED test/test_ops.py::TestOpInfoCPU::test_reference_eager_nn_functional_interpolate_trilinear_cpu_float32 - NotImplementedError: Cannot copy out of meta tensor; no data!
# FAILED test/test_ops.py::TestOpInfoCPU::test_reference_eager_nn_functional_upsample_bilinear_cpu_float32 - NotImplementedError: Cannot copy out of meta tensor; no data!
# FAILED test/test_ops.py::TestOpInfoCPU::test_reference_eager_take_cpu_float32 - AttributeError: 'View' object has no attribute '_elem'
# FAILED test/test_ops.py::TestOpInfoCPU::test_reference_eager_take_cpu_int64 - AttributeError: 'View' object has no attribute '_elem'