Add cache support for scan_layers #9297


Status: Open · 9 commits · base: master
53 changes: 48 additions & 5 deletions test/scan/test_scan_layers.py
@@ -5,6 +5,7 @@
os.path.dirname(__file__))) + "/examples"
sys.path.append(example_folder)
from decoder_only_model import DecoderOnlyConfig, DecoderOnlyModel # type:ignore
from absl.testing import parameterized

import unittest
from copy import deepcopy
@@ -14,19 +15,23 @@
import torch.nn as nn

import torch_xla
import torch_xla.experimental.scan_layers as scan_layers_module
from torch_xla.experimental.scan_layers import scan_layers
from functorch.compile import default_partition, min_cut_rematerialization_partition

parent_folder = os.path.dirname(os.path.dirname(__file__))
sys.path.append(parent_folder)
from test_utils import XlaTestCase # type:ignore


class ScanLayersTest(XlaTestCase):
class ScanLayersTest(XlaTestCase, parameterized.TestCase):

def setUp(self):
super().setUp()

self.device = torch_xla.device()
# Clear the cache before each test
scan_layers_module._ONE_LAYER_CACHE.clear()

def assert_different_tensor(self, a: torch.Tensor, b: torch.Tensor):
assert a is not b, f"Expected {a} and {b} to be different tensors"
@@ -44,7 +49,8 @@ def test_empty_layers(self):
output = scan_layers(layers, input_data.clone())
super().compareResults(output, input_data, abs_err=0.0001, rel_err=0.001)

def test_linear_layers(self):
@parameterized.parameters(False, True)
def test_linear_layers(self, is_layer_pure: bool):
# Fix the random seed to avoid flakes.
with torch.random.fork_rng():
with torch_xla.xm.fork_rng():
@@ -59,7 +65,8 @@ def test_linear_layers(self):
layers_for_loop = deepcopy(layers)
torch_xla.sync()

output = scan_layers(layers_for_scan, input_data.clone())
output = scan_layers(
layers_for_scan, input_data.clone(), is_layer_pure=is_layer_pure)
self.assert_while_found_in_hlo(output)
output.sum().backward()
torch_xla.sync()
@@ -96,7 +103,8 @@ def test_linear_layers(self):
layer_loop.weight.grad)
self.assert_different_tensor(layer_scan.bias.grad, layer_loop.bias.grad)

def test_tuple_layers(self):
@parameterized.parameters(False, True)
def test_tuple_layers(self, is_layer_pure: bool):
"""Test applying layers that consume and return tuples. Construct a module
that transforms each element in the tuple.
"""
@@ -125,7 +133,8 @@ def forward(self, x, y, z):
torch.randn(64).to(self.device) * 300)
a = torch.randn(64).to(self.device)
input_data = tuple(t + a for t in input_data)
output = scan_layers(layers_for_scan, input_data)
output = scan_layers(
layers_for_scan, input_data, is_layer_pure=is_layer_pure)
self.assert_while_found_in_hlo(output[0])
self.assert_while_found_in_hlo(output[1])
output[0].sum().backward()
@@ -273,6 +282,40 @@ def test_mismatched_shapes(self):
with self.assertRaisesRegex(ValueError, "Shape mismatch"):
scan_layers([layer1, layer2], torch.zeros((128,), device='xla'))

@parameterized.parameters(default_partition,
min_cut_rematerialization_partition)
def test_scan_layers_cache(self, partition_fn):
# Test that the cache is used correctly.
layers = [nn.Linear(64, 64).to(self.device) for _ in range(10)]
input_data = torch.randn(64).to(self.device)
torch_xla.sync(wait=True)

scan_layers(
layers,
input_data.clone(),
partition_fn=partition_fn,
is_layer_pure=True)

# Check that the cache is correctly populated.
cache_key = (id(partition_fn), id(layers[0]))
self.assertIn(cache_key, scan_layers_module._ONE_LAYER_CACHE)

# Check that the cache key is derived from the layer and that repeated
# calls with the same layer and partition function hit the existing entry
# instead of adding a new one.
scan_layers(
layers, input_data.clone(), partition_fn=partition_fn, is_layer_pure=True)
scan_layers(
layers, input_data.clone(), partition_fn=partition_fn, is_layer_pure=True)
self.assertEqual(len(scan_layers_module._ONE_LAYER_CACHE), 1)

def test_scan_layers_cache_non_pure(self):
# Test that the cache is not used for non-pure layers.
layers = [nn.Linear(64, 64).to(self.device) for _ in range(10)]
input_data = torch.randn(64).to(self.device)
torch_xla.sync(wait=True)

scan_layers(layers, input_data.clone(), is_layer_pure=False)

# Check that the cache is not populated.
self.assertEqual(len(scan_layers_module._ONE_LAYER_CACHE), 0)


if __name__ == '__main__':
test = unittest.main()
59 changes: 46 additions & 13 deletions torch_xla/experimental/scan_layers.py
@@ -7,10 +7,43 @@

from torch_xla.experimental.scan import scan

# Because the given function (the first layer) needs to be wrapped to fit the
# `scan` API (see `_create_or_get_cached_one_layer_fn`), a fresh wrapper would
# otherwise be created on every call, even when the layers are identical. We
# cache the wrapped function so that the same layer always maps to the same
# wrapper, which lets the `scan` cache work correctly.
_ONE_LAYER_CACHE = {}
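# For illustration (hypothetical object ids): after one call with
# `is_layer_pure=True`, the cache holds a single entry keyed by the partition
# function and the first layer, e.g.
#   {(id(default_partition), id(layers[0])): <one_layer_fn closure>}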


def _create_or_get_cached_one_layer_fn(first_layer: nn.Module,
partition_fn,
is_layer_pure: bool = False):
cache_key = (id(partition_fn), id(first_layer))
if is_layer_pure and cache_key in _ONE_LAYER_CACHE:
return _ONE_LAYER_CACHE[cache_key]

# Use the first layer as the example/template layer.
from copy import deepcopy
example_layer = deepcopy(first_layer)

# Define the function to apply at each step
def one_layer_fn(carry, params_buffers):
# Apply the current layer's weights and biases to the example layer,
# then run the resulting layer.
output = torch.func.functional_call( # type: ignore
example_layer, params_buffers, carry, strict=True)
return output, None

if is_layer_pure:
# Cache the wrapper for pure layers so later calls reuse it and the
# `scan` cache can hit.
_ONE_LAYER_CACHE[cache_key] = one_layer_fn

return one_layer_fn


def scan_layers(layers: Iterable[torch.nn.Module],
input_data,
partition_fn=default_partition):
partition_fn=default_partition,
is_layer_pure=False):
"""Runs each layer in `layers` sequentially, starting with `input_data`.

`input_data` is provided as input to the first layer in `layers`. The output of one
@@ -41,6 +74,11 @@ def scan_layers(layers: Iterable[torch.nn.Module],
`functorch.compile.min_cut_rematerialization_partition` to use min-cut based
activation checkpointing. You may also write your own partitioner to insert any custom
logic such as host offloading of activations.

is_layer_pure: (Optional[bool]) If True, `scan_layers` assumes the layers are pure
functions: they have no side effects and do not depend on any external state.
This allows the traced computation to be cached and reused across calls.

Returns:
The output of the last layer from `layers`.
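
Example (illustrative; layer sizes and count are arbitrary):

>>> layers = [nn.Linear(64, 64).to('xla') for _ in range(4)]
>>> output = scan_layers(layers, torch.randn(64, device='xla'),
...                      is_layer_pure=True)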
@@ -76,21 +114,16 @@ def scan_layers(layers: Iterable[torch.nn.Module],
stacked_buffers = tree_map(lambda *tensors: torch.stack(tensors, dim=0),
*buffers_list)

# Use the first layer as the example/template layer.
from copy import deepcopy
example_layer = deepcopy(first_layer)

# Define the function to apply at each step
def one_layer(carry, params_buffers):
# Apply the current layer's weights and biases to the example layer,
# then run the resulting layer.
output = torch.func.functional_call( # type: ignore
example_layer, params_buffers, carry, strict=True)
return output, None
one_layer = _create_or_get_cached_one_layer_fn(first_layer, partition_fn,
is_layer_pure)

stacked_params_buffers = (stacked_params, stacked_buffers)
final_carry, _ = scan(
one_layer, input_data, stacked_params_buffers, partition_fn=partition_fn)
one_layer,
input_data,
stacked_params_buffers,
partition_fn=partition_fn,
is_fn_pure=is_layer_pure)

return final_carry
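
For reference, a minimal usage sketch of the new flag, based only on the API shown in this diff (layer sizes and counts are arbitrary; `_ONE_LAYER_CACHE` is inspected purely to illustrate the caching behavior):

```python
import torch
import torch.nn as nn
import torch_xla
import torch_xla.experimental.scan_layers as scan_layers_module
from torch_xla.experimental.scan_layers import scan_layers

device = torch_xla.device()
layers = [nn.Linear(64, 64).to(device) for _ in range(10)]
x = torch.randn(64).to(device)

# Pure layers: the wrapped one-layer function is cached and reused on
# subsequent calls with the same layer and partition function.
out = scan_layers(layers, x.clone(), is_layer_pure=True)
assert len(scan_layers_module._ONE_LAYER_CACHE) == 1

# Non-pure layers (the default): no cache entry is created.
scan_layers_module._ONE_LAYER_CACHE.clear()
out = scan_layers(layers, x.clone())
assert len(scan_layers_module._ONE_LAYER_CACHE) == 0
```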
