[torch_xla2] Fix some duplicate and incorrect op registrations
will-cromar committed Jul 16, 2024
1 parent b2c7f65 commit 353a0ff
Showing 2 changed files with 24 additions and 84 deletions.
89 changes: 13 additions & 76 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -23,7 +23,6 @@
torch.ops.aten.add_: torch.ops.aten.add,
torch.ops.aten.sub_: torch.ops.aten.sub,
torch.ops.aten.mul_: torch.ops.aten.mul,
torch.ops.aten.div_: torch.ops.aten.div,
torch.ops.aten.pow_: torch.ops.aten.pow,
torch.ops.aten.lt_: torch.ops.aten.lt,
torch.ops.aten.le_: torch.ops.aten.le,
@@ -99,11 +98,6 @@ def _aten_clone(x, memory_format=None):
return x


@op(torch.ops.aten.full)
def _aten_full(size, value, **kwargs):
return jnp.full(size, value)


@op(torch.ops.aten.index_copy)
def _aten_index_copy(x, dim, indexes, source):
# return jax.lax.scatter(x, index, dim)
@@ -434,7 +428,7 @@ def _eye(n: int, m: Optional[int] = None, *, dtype=None, **kwargs):
return jnp.eye(n, m, dtype=dtype)


@op(torch.full)
@op(torch.ops.aten.full)
@op_base.convert_dtype()
def _full(size: Sequence[int], fill_value, *, dtype=None, **kwargs):
# TODO: handle torch.Size
@@ -446,15 +440,15 @@ def _full(size: Sequence[int], fill_value, *, dtype=None, **kwargs):
@op_base.convert_dtype()
def _aten_empty_permuted(sizes, physical_layout, dtype=None, **kwargs):
# Ignore the physical layout,
# since JAX and torch tensors don't share the same memory.
return jnp.empty(sizes, dtype=dtype)


@op(torch.ops.aten.empty_strided)
@op(torch.ops.aten.empty_strided.default)
@op_base.convert_dtype()
def _aten_empty_strided(sizes, stride, dtype=None, **kwargs):
# Ignore stride, since JAX and torch tensors don't share the same memory.
return jnp.empty(sizes, dtype=dtype)


@@ -1604,8 +1598,8 @@ def _aten_constant_pad_nd(input, padding, value=0):
# means the last dim gets padded 1 in front and 1 in back,
# and the second-to-last dim gets padded 2 in front and 2 in back.
# JAX padding is a tuple of 3-tuples: the same padding is
# [(0, 0, 0), ..., (2, 2, 0), (1, 1, 0)], where the last element of each
# 3-tuple is the amount of interior padding added between any two elements in that dimension.
m = len(padding)
rev_padding = [(padding[i - 1], padding[i], 0) for i in range(m - 1, 0, -2)]
pad_dim = tuple(([(0, 0, 0)] * (len(input.shape) - m // 2)) + rev_padding)
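
An aside, not part of the commit: the comment above describes how the flat torch padding list is reversed into JAX's per-dimension (low, high, interior) config. A minimal worked sketch of that mapping, assuming (since the rest of the function is collapsed here) that pad_dim is ultimately handed to jax.lax.pad:

import jax.numpy as jnp
from jax import lax

# torch-style flat padding list: last dim padded (1, 1), second-to-last (2, 2).
x = jnp.zeros((2, 3, 4))
padding = [1, 1, 2, 2]

m = len(padding)
rev_padding = [(padding[i - 1], padding[i], 0) for i in range(m - 1, 0, -2)]
pad_dim = tuple(([(0, 0, 0)] * (len(x.shape) - m // 2)) + rev_padding)
# pad_dim == ((0, 0, 0), (2, 2, 0), (1, 1, 0))

y = lax.pad(x, jnp.array(0.0, dtype=x.dtype), pad_dim)
print(y.shape)  # (2, 7, 6)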
@@ -1614,12 +1608,16 @@ def _aten_constant_pad_nd(input, padding, value=0):


# aten.convolution_backward
@op(torch.ops.aten.copy)
@op(torch.ops.aten.lift_fresh_copy)
def _aten_copy(x):
def _aten_lift_fresh_copy(x):
return jnp.copy(x)


@op(torch.ops.aten.copy, is_jax_function=False)
def _aten_copy(self, src):
return self.copy_(src)


@op(torch.ops.aten._cdist_forward)
def _aten_cdist_forward(x1, x2, p, compute_mode=""):
# x1 is B x P x M
@@ -2004,9 +2002,9 @@ def _aten_where(condition, x, y):

# aten.to.dtype
# Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None
@op(torch.ops.aten.to.dtype)
@op(torch.ops.aten.to.dtype, torch.ops.aten.to.dtype_layout)
def _aten_to_dtype(
a, dtype, non_blocking=False, copy=False, memory_format=None
a, *, dtype=None, layout=None, device=None, pin_memory=None, non_blocking=False, copy=False, memory_format=None
):
if dtype:
jaxdtype = mappings.t2j_dtype(dtype)
@@ -2164,67 +2162,6 @@ def _aten_scalar_tensor(val, **kwargs):
return mappings.t2j(p)


@op(torch.ops.aten.to.device)
def _aten_to_device(x, device, dtype):
return x


@op(torch.ops.aten.max_pool2d_with_indices_backward)
def max_pool2d_with_indices_backward_custom(
grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices
):
"""
Approximates the gradient calculation of PyTorch's max_pool2d_with_indices_backward.
Args:
grad_output: The gradient tensor from the preceding layer.
self: The input tensor on which the original max pooling was performed.
kernel_size: The size of the pooling window.
stride: The stride of the pooling window.
padding: The padding applied during max pooling.
dilation: The dilation factor for the pooling operation.
ceil_mode: Whether to use ceil or floor when calculating output shapes.
indices: The indices of the maximum values, as produced by max_pool2d_with_indices.
Returns:
The calculated gradient with respect to the input (grad_input).
"""

kH, kW = kernel_size
dH, dW = stride
padH, padW = padding
dilH, dilW = dilation

# Calculate output shape (may need adjustment based on ceil_mode)
out_shape = jnp.array(self.shape)
grad_input = jnp.zeros_like(self)

# Iterate over the flattened input and output tensors
for i, idx in enumerate(indices.flatten()):
# Calculate input coordinates corresponding to the maximum value
out_y, out_x = i // grad_output.shape[3], i % grad_output.shape[3]
in_y = out_y * dH - padH + out_y * (dilH - 1)
in_x = out_x * dW - padW + out_x * (dilW - 1)

# Scatter the gradient to the appropriate input locations (handling potential overlaps)
for y in range(in_y, in_y + kH):
for x in range(in_x, in_x + kW):
if 0 <= y < grad_input.shape[2] and 0 <= x < grad_input.shape[3]:
grad_input = grad_input.at[y, x].add(grad_output.flatten()[i])

return grad_input


@op(torch.ops.aten._local_scalar_dense)
def _aten_local_scalar_dense(x):
return x.item()


@op(torch.ops.aten.tensor_split.sections)
def _aten_tensor_split(ary, indices_or_sections, axis=0):
return jnp.array_split(ary, indices_or_sections, axis)


@op(torch.ops.aten.outer)
def _aten_outer(a, b):
return jnp.outer(a, b)
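An aside for context, not part of the commit: aten.copy takes both a destination and a source tensor, while aten.lift_fresh_copy takes a single tensor, so routing both through jnp.copy(x) silently dropped the source argument for aten.copy; the new handler is registered with is_jax_function=False so it receives torch-level tensors and can defer to copy_. A rough eager-mode illustration of the two schemas (assuming both functional overloads dispatch normally under plain torch; this is not torch_xla2 code):

import torch

dst = torch.zeros(3)
src = torch.arange(3.0)

fresh = torch.ops.aten.lift_fresh_copy(src)  # one input: returns a fresh copy of src
out = torch.ops.aten.copy(dst, src)          # two inputs: dst's values replaced by src's

print(fresh)  # expected: tensor([0., 1., 2.])
print(out)    # expected: tensor([0., 1., 2.])
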
19 changes: 11 additions & 8 deletions experimental/torch_xla2/torch_xla2/ops/ops_registry.py
@@ -1,4 +1,5 @@
import dataclasses
import logging
from torch_xla2.types import JaxCallable, TorchCallable

from typing import Union, Dict
@@ -18,30 +19,32 @@ class Operator:


def register_torch_dispatch_op(
aten_op, impl_callable,
is_jax_function=True,
is_user_defined=False,
needs_env=False,
):
op = Operator(
aten_op, impl_callable,
is_jax_function=is_jax_function,
is_user_defined=is_user_defined,
needs_env=needs_env)
if aten_op in all_aten_ops:
logging.warning(f'Duplicate op registration for {aten_op}')
all_aten_ops[aten_op] = op
return impl_callable


def register_torch_function_op(
torch_func, impl_callable,
is_jax_function=True,
is_user_defined=False,
needs_env=False,
):
op = Operator(
torch_func, impl_callable,
is_jax_function=is_jax_function,
is_user_defined=is_user_defined,
needs_env=needs_env)
all_torch_functions[torch_func] = op
return impl_callable
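
For context, not part of the diff: the check added above logs a warning whenever the same ATen op is registered twice, which is what exposes duplicate registrations like the ones removed from jaten.py. A minimal standalone sketch of the pattern (simplified; hypothetical op names, not the real registry):

import logging

all_aten_ops = {}

def register_op(aten_op, impl_callable):
    # Warn on re-registration, then let the newest handler win.
    if aten_op in all_aten_ops:
        logging.warning(f'Duplicate op registration for {aten_op}')
    all_aten_ops[aten_op] = impl_callable
    return impl_callable

register_op('aten.full', lambda size, value: ...)
register_op('aten.full', lambda size, value: ...)  # logs: Duplicate op registration for aten.full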
