Skip to content

Commit

Permalink
Cherrypick #8530 (#8604)
Browse files Browse the repository at this point in the history
  • Loading branch information
tengyifei authored Jan 22, 2025
1 parent ff75e1f commit a954d92
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
26 changes: 26 additions & 0 deletions test/scan/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,32 @@ def count_number_of_sines(partition_fn):
count_number_of_sines(min_cut_rematerialization_partition), 10)
self.assertEqual(count_number_of_sines(default_partition), 0)

def test_scan_different_dtypes(self):
    """Check that each output of the combine function keeps its own dtype.

    The carry starts as float16 and the scanned inputs are a
    (bfloat16, float32) pair. After running `fn` through `self.run_test`,
    the final carry and both stacked outputs are asserted to have the
    same dtypes as the corresponding inputs.
    """

    def fn(carry, x):
        # `x` is one step of the `(bf16_xs, f32_xs)` pair passed below.
        low_precision, full_precision = x
        outputs = (torch.sin(low_precision), torch.sin(full_precision))
        return torch.sin(carry), outputs

    # Both scanned inputs hold the same values and differ only in dtype.
    rows = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

    init = torch.tensor(
        [0.0, 0.0],
        dtype=torch.float16,
        device=self.device,
        requires_grad=True,
    )
    bf16_xs = torch.tensor(
        rows,
        dtype=torch.bfloat16,
        device=self.device,
        requires_grad=True,
    )
    f32_xs = torch.tensor(
        rows,
        dtype=torch.float32,
        device=self.device,
        requires_grad=True,
    )

    final_carry, (bf16_ys, f32_ys) = self.run_test(fn, init, (bf16_xs, f32_xs))

    self.assertEqual(final_carry.dtype, torch.float16)
    self.assertEqual(bf16_ys.dtype, torch.bfloat16)
    self.assertEqual(f32_ys.dtype, torch.float32)


class PyTreeTest(TestBase):

Expand Down
5 changes: 2 additions & 3 deletions torch_xla/experimental/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,16 +501,15 @@ def make_fake_tensor(v: torch.Tensor) -> torch.Tensor:
fn_carry_out, fn_y_out = split(fn_outputs, carry_len)
assert carry_len + y_len == len(fn_outputs)
fn_carry_shapes = [v.shape for v in fn_carry_out]
fn_y_shapes = [v.shape for v in fn_y_out]
for fn_carry_shape, init_leaf in zip(fn_carry_shapes, init):
assert fn_carry_shape == init_leaf.shape, f"`fn` must keep the `carry` shape unchanged. \
Got {fn_carry_shape} but expected {init_leaf.shape}"

builder = Builder('scan')
num_iters = next(iter(tree_iter(xs))).size(0)
ys = [
torch.zeros((num_iters, *fn_y_shape), device=device)
for fn_y_shape in fn_y_shapes
torch.zeros((num_iters, *v.shape), device=device, dtype=v.dtype)
for v in fn_y_out
]
# Start the `curr_iter` loop variable at zero.
zero = torch.tensor(0, device=device)
Expand Down

0 comments on commit a954d92

Please sign in to comment.