Cherrypick #8562 (#8611)
tengyifei authored Jan 22, 2025
1 parent a954d92 commit ac56812
Showing 3 changed files with 148 additions and 25 deletions.
57 changes: 51 additions & 6 deletions test/scan/test_scan.py
@@ -444,6 +444,46 @@ def fn(carry, x):
self.assertEqual(bf16_ys.dtype, torch.bfloat16)
self.assertEqual(f32_ys.dtype, torch.float32)

def test_scan_activation_aliases_input(self):
"""Test that if an intermediate activation of fn aliases an input,
we directly save the input tensor into the context object, instead of
indexing into the leading dimension during the while loop and copying
those slices into a new output tensor. This is a memory-saving optimization.
"""

def fn(carry, x):
return carry, torch.sin(x)

carry = torch.randn(4, 4, requires_grad=True, device=self.device)
xs = torch.randn(20, 4, 4, requires_grad=True, device=self.device)
torch_xla.sync()

storage = []

def pack(x):
storage.append(x)
return len(storage) - 1

def unpack(x):
return storage[x]

# Intercept the tensors stored in the context object.
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
final_carry, ys = scan(fn, carry, xs)
ys.sum().backward()
torch_xla.sync()

# Find the input that is stored in the context object.
stored_xs = None
for s in storage:
if s.shape == xs.shape:
assert stored_xs is None
stored_xs = s

# Test that it's literally the same object as the input tensor,
# as opposed to just numerically identical but otherwise an extra copy.
assert id(stored_xs) == id(xs)


class PyTreeTest(TestBase):

@@ -469,12 +509,16 @@ def fn(carry, x):
xs = torch.tensor([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]],
requires_grad=True,
device=self.device)
forward, backward = value_and_grad_partitioned(fn, init, xs)
forward, alias_input, backward = value_and_grad_partitioned(fn, init, xs)

# Forward should return `(new_carry, (y, (carry, x)))`,
# because `(carry, x)` are the two intermediate activations (primals),
# and they will be packed alongside the original output `y`.
# Once we add back activations that are aliases of inputs, the result should
# be `(new_carry, (y, (carry, x)))`, because `(carry, x)` are the two
# intermediate activations (primals), and they will be packed alongside
# the original output `y`.
out = forward(init, xs[0])
new_carry, (y, partial_activations) = out
activations = alias_input(partial_activations, xs[0])
out = (new_carry, (y, activations))
torch_xla.sync()
carry = init
x = xs[0]
@@ -521,11 +565,12 @@ def fn(carry, x):
}

# Get the forward and backward functions using value_and_grad_partitioned
forward, backward = value_and_grad_partitioned(
forward, alias_input, backward = value_and_grad_partitioned(
fn, init, tree_map(lambda v: v.unsqueeze(0), x))

# Run the forward function
carry_out, (y_out, activations) = forward(init, x)
carry_out, (y_out, partial_activations) = forward(init, x)
activations = alias_input(partial_activations, x)
torch_xla.sync()

# Compute expected outputs and gradients using PyTorch autograd
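For reference, the forward/alias_input/backward contract exercised by the tests above can be sketched end to end as follows; the combine function, shapes, and gradient seeds here are illustrative placeholders rather than code from the commit:

import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.experimental.scan import value_and_grad_partitioned

device = xm.xla_device()

def fn(carry, x):
  return carry + x, torch.sin(x)

init = torch.randn(4, 4, requires_grad=True, device=device)
xs = torch.randn(20, 4, 4, requires_grad=True, device=device)

forward, alias_input, backward = value_and_grad_partitioned(fn, init, xs)

# One forward step returns only the partial activations ...
new_carry, (y, partial_activations) = forward(init, xs[0])
# ... and alias_input re-inserts the activations that alias inputs.
activations = alias_input(partial_activations, xs[0])

# The backward wrapper consumes the full activations plus output gradients.
grad_carry, grad_x = backward(torch.ones_like(new_carry),
                              (torch.ones_like(y), activations))
torch_xla.sync()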
2 changes: 1 addition & 1 deletion test/scan/test_scan_spmd.py
@@ -25,7 +25,7 @@ def test_scan_cumsum(self):
"""This test uses `scan` to implement `torch.cumsum`."""

def fn(carry, x):
new_carry = carry + x
new_carry = torch.sin(carry + x)
y = new_carry
return new_carry, y

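With `new_carry = torch.sin(carry + x)`, the scanned outputs now follow ys[0] = sin(init + xs[0]) and ys[i] = sin(ys[i-1] + xs[i]) rather than a plain `torch.cumsum`. A hedged eager-mode reference for what the updated `fn` computes (illustrative only):

import torch

def reference(init, xs):
  carry = init
  ys = []
  for x in xs:
    carry = torch.sin(carry + x)
    ys.append(carry)
  return carry, torch.stack(ys)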
114 changes: 96 additions & 18 deletions torch_xla/experimental/scan.py
@@ -46,6 +46,7 @@
import torch_xla
import torch_xla.core.xla_builder as xb
from torch_xla.experimental.pytreeify import pytreeify
import torch_xla.debug.profiler as xp

Carry = TypeVar('Carry')
X = TypeVar('X')
@@ -154,35 +155,44 @@ def scan(fn, init, xs):
if xs_length is None:
raise ValueError(f"`xs` {xs} is an empty PyTree.")

forward, backward = value_and_grad_partitioned(
forward, alias_input, backward = value_and_grad_partitioned(
fn, init, xs, partition_fn=partition_fn)
carry, ys = Scan.apply(forward, backward, init, xs) # type: ignore
carry, ys = Scan.apply(forward, alias_input, backward, init,
xs) # type: ignore
return carry, ys


def value_and_grad_partitioned(
fn: Callable[[Carry, X], tuple[Carry, Y]],
init: Carry,
xs: X,
partition_fn=default_partition) -> tuple[Callable, Callable]:
partition_fn=default_partition) -> tuple[Callable, Callable, Callable]:
"""
Given a user `fn` to be scanned over the leading dimension of the input `xs`
PyTree and an initial carry object `init`, symbolically traces `fn` and
returns two functions, `forward` and `backward`, which wrap the forward and
backward graphs of `fn` and plumbs through intermediate activations.
Specifically, given
returns three functions, `forward`, `alias_input`, and `backward`.
`forward` and `backward` wrap the forward and backward graphs of `fn` and
plumb through intermediate activations, while `alias_input` is a memory-saving
optimization. Specifically, given
`fn(carry, x) -> (new_carry, y)`
this function will build and return
`forward(carry, x) -> (new_carry, (y, activations))`
`forward(carry, x) -> (new_carry, (y, partial_activations))`
`alias_input(stack(partial_activations), xs) -> stack(activations)`
`backward(grad_new_carry, (grad_y, activations)) -> (grad_carry, grad_x)`
where `grad_y` is the gradient w.r.t `y`, and `grad_new_carry` is the gradient
w.r.t. `new_carry`.
The `partial_activations` returned by `forward` are intermediate activations
that do not alias any input tensors. You may pass a stack of `partial_activations`
and the original input `xs` PyTree to `alias_input` to reconstitute the full
list of `activations`.
`activations` will always be a flat list of tensors.
This is similar to the `value_and_grad` transform found in JAX, but additionally
@@ -201,7 +211,7 @@ def value_and_grad_partitioned(
forward and backward graphs.
Returns:
A tuple of `(forward, backward)`, detailed in the docstring of this function.
A tuple of `(forward, alias_input, backward)`, detailed in the docstring of this function.
"""

# Make some fake tensors to trace the user function and obtain the
@@ -253,24 +263,92 @@ def fn_no_output_aliasing(*args):
fwd_graph = get_fwd()
bwd_graph = get_bwd()

def forward(carry, x):
# Figure out which activations are aliases of the inputs. We don't need to
# pass those through the scan logic unchanged; doing so would use more memory.
input_activation_aliases = _find_input_activation_aliases(
fake_carry_pytree, fake_x_pytree, num_out, fwd_graph)
aliased_activations = set(input_activation_aliases.values())

def forward_core(carry, x):
flat_carry, _ = tree_flatten(carry)
flat_x, _ = tree_flatten(x)
out = fwd_graph(*flat_carry, *flat_x)
with xp.Trace('aot_forward'):
out = fwd_graph(*flat_carry, *flat_x)
actual_out, activations = split(out, num_out)
carry, y = unflatten_fwd_out(actual_out)
y = (y, activations)
return carry, y

def forward(carry, x):
carry, (y, activations) = forward_core(carry, x)

# Remove activations that alias inputs. Those will be added back
# in `alias_input`.
partial_activations = tuple(
v for i, v in enumerate(activations) if i not in aliased_activations)

y = (y, partial_activations)
return carry, y

def alias_input(partial_activations, xs):
"""
Add back activations that are aliases of input tensors.
In principle, we could have `forward` return all the intermediate activations,
including those that are aliases to an input tensor. However, those inputs will
then be duplicated as part of the output of a `scan` call, because we want to
save all activations during the forward pass of a `scan`. The XLA compiler can't
optimize away this duplication, likely because they're behind a DynamicSlice +
DynamicUpdateSlice, so we end up doubling the memory usage from those inputs.
To reduce memory usage, we can have `forward` return the activations that
don't alias inputs, called `partial_activations`. The autograd implementation
of `scan` will then call `alias_input` outside of the scan to add back the
activations that alias input tensors, turning the partial activations back into
full activations.
"""
activations = list(partial_activations)
aliased_inputs = [
v for i, v in enumerate(tree_iter(xs)) if i in input_activation_aliases
]
for (i, activation_idx) in enumerate(input_activation_aliases.values()):
activations.insert(activation_idx, aliased_inputs[i])
return tuple(activations)

def backward(carry, x):
grad_new_carry, _ = tree_flatten(carry)
(grad_y, activations) = x
grad_y, _ = tree_flatten_none(grad_y)
out = bwd_graph(*activations, *grad_new_carry, *grad_y)
with xp.Trace('aot_backward'):
out = bwd_graph(*activations, *grad_new_carry, *grad_y)
grad_carry, grad_x = unflatten_bwd_out(out)
return grad_carry, grad_x

return forward, backward
return forward, alias_input, backward


def _find_input_activation_aliases(fake_carry_pytree, fake_x_pytree, num_out,
fwd_graph):
"""
Find which activations are aliases to input tensors.
Returns:
A mapping from the index into the flattened
input pytree to the index into the list of intermediate activations.
"""
flat_carry, _ = tree_flatten(fake_carry_pytree)
flat_x, _ = tree_flatten(fake_x_pytree)
_actual_out, activations = split(fwd_graph(*flat_carry, *flat_x), num_out)
input_id_to_index = {
v: i for i, v in enumerate(id(v) for v in tree_iter(flat_x))
}
input_activation_aliases = {}
for idx, i in enumerate(id(v) for v in activations):
if i in input_id_to_index:
input_activation_aliases[input_id_to_index[i]] = idx
return input_activation_aliases


def _make_get_graph_compiler():
@@ -297,12 +375,12 @@ def get_graph():
class Scan(torch.autograd.Function):

@staticmethod
def forward(ctx, forward, backward, init, xs):
# Forward pass, save activations for backward
def forward(ctx, forward, alias_input, backward, init, xs):
ctx._backward = backward
with torch.no_grad():
carry, ys = _scan_impl_pytree(forward, init, xs)
ys, activations = ys
ys, partial_activations = ys
activations = alias_input(partial_activations, xs)
ctx.save_for_backward(*activations)
return carry, ys

@@ -314,7 +392,7 @@ def backward(ctx, grad_carry, grad_ys): # type: ignore
# Reverse loop to propagate gradients from last iteration to first.
grad_init, grad_xs = _scan_impl_pytree(
backward, grad_carry, (grad_ys, activations), reverse=True)
return None, None, grad_init, grad_xs
return None, None, None, grad_init, grad_xs


def _scan_impl_pytree(fn, init, xs, reverse: bool = False):
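As a self-contained illustration of the alias bookkeeping added in this file (hypothetical tensors; the dict comprehension mirrors `_find_input_activation_aliases` rather than reproducing it verbatim):

import torch

# Toy stand-ins: two flattened inputs, and forward-graph activations whose
# middle element is literally the second input (an alias).
x0, x1 = torch.randn(3), torch.randn(3)
a0, a2 = torch.randn(3), torch.randn(3)
flat_x = (x0, x1)
activations = (a0, x1, a2)

# Map input index -> activation index by comparing object identities.
input_id_to_index = {id(v): i for i, v in enumerate(flat_x)}
input_activation_aliases = {
    input_id_to_index[id(v)]: idx
    for idx, v in enumerate(activations)
    if id(v) in input_id_to_index
}
assert input_activation_aliases == {1: 1}

# `forward` drops the aliased slot and returns only partial activations.
aliased_slots = set(input_activation_aliases.values())
partial_activations = tuple(
    v for i, v in enumerate(activations) if i not in aliased_slots)

# `alias_input` re-inserts the aliased inputs at the recorded positions.
rebuilt = list(partial_activations)
aliased_inputs = [
    v for i, v in enumerate(flat_x) if i in input_activation_aliases
]
for i, activation_idx in enumerate(input_activation_aliases.values()):
  rebuilt.insert(activation_idx, aliased_inputs[i])
assert all(r is a for r, a in zip(rebuilt, activations))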
