From ac568122928e2b5e195d39e35ca37e2c9472facb Mon Sep 17 00:00:00 2001
From: Yifei Teng
Date: Wed, 22 Jan 2025 13:49:16 -0800
Subject: [PATCH] Cherrypick #8562 (#8611)

---
 test/scan/test_scan.py         |  57 +++++++++++++++--
 test/scan/test_scan_spmd.py    |   2 +-
 torch_xla/experimental/scan.py | 114 +++++++++++++++++++++++++++------
 3 files changed, 148 insertions(+), 25 deletions(-)

diff --git a/test/scan/test_scan.py b/test/scan/test_scan.py
index eb292ee5185..d0bb6b08e82 100644
--- a/test/scan/test_scan.py
+++ b/test/scan/test_scan.py
@@ -444,6 +444,46 @@ def fn(carry, x):
     self.assertEqual(bf16_ys.dtype, torch.bfloat16)
     self.assertEqual(f32_ys.dtype, torch.float32)

+  def test_scan_activation_aliases_input(self):
+    """Test that if an intermediate activation of fn aliases an input,
+    we directly save the input tensor into the context object, instead of
+    indexing into the leading dimension during the while loop and copying
+    those slices into a new output tensor. This is a memory-saving optimization.
+    """
+
+    def fn(carry, x):
+      return carry, torch.sin(x)
+
+    carry = torch.randn(4, 4, requires_grad=True, device=self.device)
+    xs = torch.randn(20, 4, 4, requires_grad=True, device=self.device)
+    torch_xla.sync()
+
+    storage = []
+
+    def pack(x):
+      storage.append(x)
+      return len(storage) - 1
+
+    def unpack(x):
+      return storage[x]
+
+    # Intercept the tensors stored in the context object.
+    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
+      final_carry, ys = scan(fn, carry, xs)
+      ys.sum().backward()
+      torch_xla.sync()
+
+    # Find the input that is stored in the context object.
+    stored_xs = None
+    for s in storage:
+      if s.shape == xs.shape:
+        assert stored_xs is None
+        stored_xs = s
+
+    # Test that it's literally the same object as the input tensor,
+    # as opposed to just numerically identical but otherwise an extra copy.
+    assert id(stored_xs) == id(xs)
+

 class PyTreeTest(TestBase):

@@ -469,12 +509,16 @@ def fn(carry, x):
     xs = torch.tensor([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]],
                       requires_grad=True,
                       device=self.device)
-    forward, backward = value_and_grad_partitioned(fn, init, xs)
+    forward, alias_input, backward = value_and_grad_partitioned(fn, init, xs)

-    # Forward should return `(new_carry, (y, (carry, x)))`,
-    # because `(carry, x)` are the two intermediate activations (primals),
-    # and they will be packed alongside the original output `y`.
+    # Once we add back activations that are aliases of inputs, the result should
+    # be `(new_carry, (y, (carry, x)))`, because `(carry, x)` are the two
+    # intermediate activations (primals), and they will be packed alongside
+    # the original output `y`.
     out = forward(init, xs[0])
+    new_carry, (y, partial_activations) = out
+    activations = alias_input(partial_activations, xs[0])
+    out = (new_carry, (y, activations))
     torch_xla.sync()

     carry = init
     x = xs[0]
@@ -521,11 +565,12 @@ def fn(carry, x):
     }

     # Get the forward and backward functions using value_and_grad_partitioned
-    forward, backward = value_and_grad_partitioned(
+    forward, alias_input, backward = value_and_grad_partitioned(
         fn, init, tree_map(lambda v: v.unsqueeze(0), x))

     # Run the forward function
-    carry_out, (y_out, activations) = forward(init, x)
+    carry_out, (y_out, partial_activations) = forward(init, x)
+    activations = alias_input(partial_activations, x)
     torch_xla.sync()

     # Compute expected outputs and gradients using PyTorch autograd
diff --git a/test/scan/test_scan_spmd.py b/test/scan/test_scan_spmd.py
index 19ec991cbb3..6afe4fb196c 100644
--- a/test/scan/test_scan_spmd.py
+++ b/test/scan/test_scan_spmd.py
@@ -25,7 +25,7 @@ def test_scan_cumsum(self):
     """This test uses `scan` to implement `torch.cumsum`."""

     def fn(carry, x):
-      new_carry = carry + x
+      new_carry = torch.sin(carry + x)
       y = new_carry
       return new_carry, y

diff --git a/torch_xla/experimental/scan.py b/torch_xla/experimental/scan.py
index 8f956faebb0..63d75f3da3a 100644
--- a/torch_xla/experimental/scan.py
+++ b/torch_xla/experimental/scan.py
@@ -46,6 +46,7 @@
 import torch_xla
 import torch_xla.core.xla_builder as xb
 from torch_xla.experimental.pytreeify import pytreeify
+import torch_xla.debug.profiler as xp

 Carry = TypeVar('Carry')
 X = TypeVar('X')
@@ -154,9 +155,10 @@ def scan(fn, init, xs):
   if xs_length is None:
     raise ValueError(f"`xs` {xs} is an empty PyTree.")

-  forward, backward = value_and_grad_partitioned(
+  forward, alias_input, backward = value_and_grad_partitioned(
       fn, init, xs, partition_fn=partition_fn)
-  carry, ys = Scan.apply(forward, backward, init, xs)  # type: ignore
+  carry, ys = Scan.apply(forward, alias_input, backward, init,
+                         xs)  # type: ignore
   return carry, ys


@@ -164,25 +166,33 @@ def value_and_grad_partitioned(
     fn: Callable[[Carry, X], tuple[Carry, Y]],
     init: Carry,
     xs: X,
-    partition_fn=default_partition) -> tuple[Callable, Callable]:
+    partition_fn=default_partition) -> tuple[Callable, Callable, Callable]:
   """
   Given a user `fn` to be scanned over the leading dimension of the input
   `xs` PyTree and an initial carry object `init`, symbolically traces `fn` and
-  returns two functions, `forward` and `backward`, which wrap the forward and
-  backward graphs of `fn` and plumbs through intermediate activations.
-  Specifically, given
+  returns three functions, `forward`, `alias_input`, and `backward`.
+  `forward` and `backward` wrap the forward and backward graphs of `fn` and
+  plumb through intermediate activations, while `alias_input` is a
+  memory-saving optimization. Specifically, given

     `fn(carry, x) -> (new_carry, y)`
-
+
   this function will build and return

-    `forward(carry, x) -> (new_carry, (y, activations))`
+    `forward(carry, x) -> (new_carry, (y, partial_activations))`
+
+    `alias_input(stack(partial_activations), xs) -> stack(activations)`

     `backward(grad_new_carry, (grad_y, activations)) -> (grad_carry, grad_x)`

   where `grad_y` is the gradient w.r.t `y`, and `grad_new_carry` is the gradient
   w.r.t. `new_carry`.
-
+
+  The `partial_activations` returned by `forward` are intermediate activations
+  that do not alias any input tensors. You may pass a stack of
+  `partial_activations` and the original input `xs` PyTree to `alias_input` to
+  reconstitute the full list of `activations`.
+
+  `activations` will always be a flat list of tensors.

   This is similar to the `value_and_grad` transform found in JAX, but additionally
@@ -201,7 +211,7 @@ def value_and_grad_partitioned(
       forward and backward graphs.

   Returns:
-    A tuple of `(forward, backward)`, detailed in the docstring of this function.
+    A tuple of `(forward, alias_input, backward)`, detailed in the docstring of this function.
   """

   # Make some fake tensors to trace the user function and obtain the
@@ -253,24 +263,92 @@ def fn_no_output_aliasing(*args):
   fwd_graph = get_fwd()
   bwd_graph = get_bwd()

-  def forward(carry, x):
+  # Figure out which activations are aliases of the inputs. We don't need to
+  # pass them through the scan logic unchanged; that would use more memory.
+  input_activation_aliases = _find_input_activation_aliases(
+      fake_carry_pytree, fake_x_pytree, num_out, fwd_graph)
+  aliased_activations = set(input_activation_aliases.values())
+
+  def forward_core(carry, x):
     flat_carry, _ = tree_flatten(carry)
     flat_x, _ = tree_flatten(x)
-    out = fwd_graph(*flat_carry, *flat_x)
+    with xp.Trace('aot_forward'):
+      out = fwd_graph(*flat_carry, *flat_x)
     actual_out, activations = split(out, num_out)
     carry, y = unflatten_fwd_out(actual_out)
     y = (y, activations)
     return carry, y

+  def forward(carry, x):
+    carry, (y, activations) = forward_core(carry, x)
+
+    # Remove activations that alias inputs. Those will be added back
+    # in `alias_input`.
+    partial_activations = tuple(
+        v for i, v in enumerate(activations) if i not in aliased_activations)
+
+    y = (y, partial_activations)
+    return carry, y
+
+  def alias_input(partial_activations, xs):
+    """
+    Add back activations that are aliases of input tensors.
+
+    In principle, we could have `forward` return all the intermediate
+    activations, including those that alias an input tensor. However, those
+    inputs would then be duplicated in the output of a `scan` call, because we
+    save all activations during the forward pass of a `scan`. The XLA compiler
+    can't optimize away this duplication, likely because the copies sit behind
+    a DynamicSlice + DynamicUpdateSlice, so we would double the memory usage
+    from those inputs.
+
+    To reduce memory usage, we have `forward` return only the activations that
+    don't alias inputs, called `partial_activations`. The autograd
+    implementation of `scan` then calls `alias_input` outside of the scan to
+    add back the activations that alias input tensors, turning the partial
+    activations back into full activations.
+    """
+    activations = list(partial_activations)
+    aliased_inputs = [
+        v for i, v in enumerate(tree_iter(xs)) if i in input_activation_aliases
+    ]
+    for (i, activation_idx) in enumerate(input_activation_aliases.values()):
+      activations.insert(activation_idx, aliased_inputs[i])
+    return tuple(activations)
+
   def backward(carry, x):
     grad_new_carry, _ = tree_flatten(carry)
     (grad_y, activations) = x
     grad_y, _ = tree_flatten_none(grad_y)
-    out = bwd_graph(*activations, *grad_new_carry, *grad_y)
+    with xp.Trace('aot_backward'):
+      out = bwd_graph(*activations, *grad_new_carry, *grad_y)
     grad_carry, grad_x = unflatten_bwd_out(out)
     return grad_carry, grad_x

-  return forward, backward
+  return forward, alias_input, backward
+
+
+def _find_input_activation_aliases(fake_carry_pytree, fake_x_pytree, num_out,
+                                   fwd_graph):
+  """
+  Find which activations are aliases of input tensors.
+
+  Returns:
+
+  A mapping from index into the flattened
+  input pytree to the index into the list of intermediate activations.
+
+  """
+  flat_carry, _ = tree_flatten(fake_carry_pytree)
+  flat_x, _ = tree_flatten(fake_x_pytree)
+  _actual_out, activations = split(fwd_graph(*flat_carry, *flat_x), num_out)
+  input_id_to_index = {
+      v: i for i, v in enumerate(id(v) for v in tree_iter(flat_x))
+  }
+  input_activation_aliases = {}
+  for idx, i in enumerate(id(v) for v in activations):
+    if i in input_id_to_index:
+      input_activation_aliases[input_id_to_index[i]] = idx
+  return input_activation_aliases
+

 def _make_get_graph_compiler():
@@ -297,12 +375,12 @@ def get_graph():
 class Scan(torch.autograd.Function):

   @staticmethod
-  def forward(ctx, forward, backward, init, xs):
-    # Forward pass, save activations for backward
+  def forward(ctx, forward, alias_input, backward, init, xs):
     ctx._backward = backward
     with torch.no_grad():
       carry, ys = _scan_impl_pytree(forward, init, xs)
-      ys, activations = ys
+      ys, partial_activations = ys
+      activations = alias_input(partial_activations, xs)
     ctx.save_for_backward(*activations)
     return carry, ys
@@ -314,7 +392,7 @@ def backward(ctx, grad_carry, grad_ys):  # type: ignore
     # Reverse loop to propagate gradients from last iteration to first.
     grad_init, grad_xs = _scan_impl_pytree(
         backward, grad_carry, (grad_ys, activations), reverse=True)
-    return None, None, grad_init, grad_xs
+    return None, None, None, grad_init, grad_xs


 def _scan_impl_pytree(fn, init, xs, reverse: bool = False):
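
The aliasing bookkeeping in this patch can be modeled with plain Python lists. The sketch below is illustrative only: `find_aliases`, `drop_aliases`, and `reinsert_aliases` are hypothetical stand-ins for `_find_input_activation_aliases`, the filtering inside `forward`, and `alias_input`, with object identity (`id`) standing in for tensor aliasing:

# Simplified model of the optimization: detect which "activations" are the
# same objects as inputs, drop them from what the loop carries, and splice
# them back in afterwards.

def find_aliases(inputs, activations):
  # Map index into `inputs` -> index into `activations`, keyed on identity.
  input_id_to_index = {id(v): i for i, v in enumerate(inputs)}
  aliases = {}
  for act_idx, act in enumerate(activations):
    if id(act) in input_id_to_index:
      aliases[input_id_to_index[id(act)]] = act_idx
  return aliases

def drop_aliases(activations, aliases):
  # What `forward` returns: only activations that do not alias an input.
  aliased = set(aliases.values())
  return [v for i, v in enumerate(activations) if i not in aliased]

def reinsert_aliases(partial_activations, inputs, aliases):
  # What `alias_input` does: splice the aliased inputs back into place.
  activations = list(partial_activations)
  aliased_inputs = [v for i, v in enumerate(inputs) if i in aliases]
  for i, act_idx in enumerate(aliases.values()):
    activations.insert(act_idx, aliased_inputs[i])
  return activations

x = object()
inputs = [x]
activations = [object(), x, object()]  # the middle activation aliases `x`
aliases = find_aliases(inputs, activations)
partial = drop_aliases(activations, aliases)
assert reinsert_aliases(partial, inputs, aliases) == activations

The round trip matters because only the partial activations travel through the scan's DynamicSlice/DynamicUpdateSlice machinery; the aliased inputs are attached again outside the loop, where no copy is needed.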
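
The new test leans on `torch.autograd.graph.saved_tensors_hooks` to observe which tensors autograd saves for the backward pass. A minimal, device-independent sketch of that interception pattern, assuming only stock PyTorch and not tied to `scan`:

import torch

storage = []

def pack(t):
  # Called whenever autograd saves a tensor for backward; keep the tensor
  # and hand autograd back an index to redeem later.
  storage.append(t)
  return len(storage) - 1

def unpack(handle):
  return storage[handle]

x = torch.randn(4, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
  y = torch.sin(x)  # sin saves its input for the backward pass
y.sum().backward()

# The saved tensor shares storage with `x`: no extra copy was made.
assert any(s.data_ptr() == x.data_ptr() for s in storage)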
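
One detail behind the `return None, None, None, grad_init, grad_xs` change: `torch.autograd.Function.backward` must return exactly one value per argument of `forward` (besides `ctx`), with `None` for non-differentiable arguments, so adding `alias_input` to `Scan.forward` requires one more leading `None`. A toy example of this contract; the `Scale` class below is illustrative, not part of the patch:

import torch

class Scale(torch.autograd.Function):

  @staticmethod
  def forward(ctx, factor, x):
    # `factor` is a plain float, not a tensor.
    ctx.factor = factor
    return x * factor

  @staticmethod
  def backward(ctx, grad_out):
    # One entry per forward argument: None for `factor`, a gradient for `x`.
    return None, grad_out * ctx.factor

x = torch.ones(3, requires_grad=True)
Scale.apply(2.0, x).sum().backward()
assert torch.equal(x.grad, torch.full((3,), 2.0))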