
Commit

Fix to.dtype
will-cromar committed Jul 16, 2024
1 parent 353a0ff commit 6ad843c
Showing 1 changed file with 13 additions and 2 deletions.
experimental/torch_xla2/torch_xla2/ops/jaten.py: 13 additions & 2 deletions
@@ -2002,15 +2002,26 @@ def _aten_where(condition, x, y):
 
 # aten.to.dtype
 # Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None
-@op(torch.ops.aten.to.dtype, torch.ops.aten.to.dtype_layout)
+@op(torch.ops.aten.to.dtype)
 def _aten_to_dtype(
-    a, *, dtype=None, layout=None, device=None, pin_memory=None, non_blocking=False, copy=False, memory_format=None
+    a, dtype, non_blocking=False, copy=False, memory_format=None
 ):
   if dtype:
     jaxdtype = mappings.t2j_dtype(dtype)
   return a.astype(jaxdtype)
 
 
+@op(torch.ops.aten.to.dtype_layout)
+def _aten_to_dtype_layout(
+    a, *, dtype=None, layout=None, device=None, pin_memory=None, non_blocking=False, copy=False, memory_format=None
+):
+  return _aten_to_dtype(
+      a,
+      dtype,
+      non_blocking=non_blocking,
+      copy=copy,
+      memory_format=memory_format)
+
 # aten.to.device
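Not part of the commit: a minimal sketch of why the two overloads need separate handlers. Per the ATen schemas quoted above, to.dtype takes dtype as a positional argument, while to.dtype_layout takes only keyword arguments; the call sites below are illustrative assumptions, not code from the repository.

import torch

x = torch.ones(2, 2)

# aten.to.dtype: dtype is passed positionally, matching the new
# _aten_to_dtype(a, dtype, ...) signature in this commit.
y = torch.ops.aten.to.dtype(x, torch.float64)

# aten.to.dtype_layout: arguments after self are keyword-only, matching
# the keyword-only signature kept in _aten_to_dtype_layout.
z = torch.ops.aten.to.dtype_layout(x, dtype=torch.float64, layout=torch.strided)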
