diff --git a/test/scan/test_scan_layers.py b/test/scan/test_scan_layers.py
index ba193ea1eb3..459d28c3b13 100644
--- a/test/scan/test_scan_layers.py
+++ b/test/scan/test_scan_layers.py
@@ -5,6 +5,7 @@
     os.path.dirname(__file__))) + "/examples"
 sys.path.append(example_folder)
 from decoder_only_model import DecoderOnlyConfig, DecoderOnlyModel  # type:ignore
+from absl.testing import parameterized
 
 import unittest
 from copy import deepcopy
@@ -14,19 +15,23 @@
 import torch.nn as nn
 
 import torch_xla
+import torch_xla.experimental.scan_layers as scan_layers_module
 from torch_xla.experimental.scan_layers import scan_layers
+from functorch.compile import default_partition, min_cut_rematerialization_partition
 
 parent_folder = os.path.dirname(os.path.dirname(__file__))
 sys.path.append(parent_folder)
 from test_utils import XlaTestCase  # type:ignore
 
 
-class ScanLayersTest(XlaTestCase):
+class ScanLayersTest(XlaTestCase, parameterized.TestCase):
 
   def setUp(self):
     super().setUp()
     self.device = torch_xla.device()
+    # Clear the cache before each test
+    scan_layers_module._ONE_LAYER_CACHE.clear()
 
   def assert_different_tensor(self, a: torch.Tensor, b: torch.Tensor):
     assert a is not b, f"Expected {a} and {b} to be different tensors"
@@ -44,7 +49,8 @@ def test_empty_layers(self):
     output = scan_layers(layers, input_data.clone())
     super().compareResults(output, input_data, abs_err=0.0001, rel_err=0.001)
 
-  def test_linear_layers(self):
+  @parameterized.parameters(False, True)
+  def test_linear_layers(self, is_layer_pure: bool):
     # Fix the random seed to avoid flakes.
     with torch.random.fork_rng():
       with torch_xla.xm.fork_rng():
@@ -59,7 +65,8 @@ def test_linear_layers(self):
       layers_for_loop = deepcopy(layers)
       torch_xla.sync()
 
-      output = scan_layers(layers_for_scan, input_data.clone())
+      output = scan_layers(
+          layers_for_scan, input_data.clone(), is_layer_pure=is_layer_pure)
       self.assert_while_found_in_hlo(output)
       output.sum().backward()
       torch_xla.sync()
@@ -96,7 +103,8 @@ def test_linear_layers(self):
                                    layer_loop.weight.grad)
       self.assert_different_tensor(layer_scan.bias.grad, layer_loop.bias.grad)
 
-  def test_tuple_layers(self):
+  @parameterized.parameters(False, True)
+  def test_tuple_layers(self, is_layer_pure: bool):
     """Test applying layers that consume and return tuples.
 
     Construct a module that transforms each element in the tuple.
     """
@@ -125,7 +133,8 @@ def forward(self, x, y, z):
                   torch.randn(64).to(self.device) * 300)
     a = torch.randn(64).to(self.device)
     input_data = tuple(t + a for t in input_data)
-    output = scan_layers(layers_for_scan, input_data)
+    output = scan_layers(
+        layers_for_scan, input_data, is_layer_pure=is_layer_pure)
     self.assert_while_found_in_hlo(output[0])
     self.assert_while_found_in_hlo(output[1])
     output[0].sum().backward()
@@ -273,6 +282,45 @@ def test_mismatched_shapes(self):
     with self.assertRaisesRegex(ValueError, "Shape mismatch"):
       scan_layers([layer1, layer2], torch.zeros((128,), device='xla'))
 
+  @parameterized.parameters(default_partition,
+                            min_cut_rematerialization_partition)
+  def test_scan_layers_cache(self, partition_fn):
+    # Test that the cache is used correctly.
+    layers = [nn.Linear(64, 64).to(self.device) for _ in range(10)]
+    input_data = torch.randn(64).to(self.device)
+    torch_xla.sync(wait=True)
+
+    scan_layers(
+        layers,
+        input_data.clone(),
+        partition_fn=partition_fn,
+        is_layer_pure=True)
+
+    # Check that the cache is correctly populated.
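+    # The key pairs id(partition_fn) with id(layers[0]).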
+    cache_key = (id(partition_fn), id(layers[0]))
+    self.assertIn(cache_key, scan_layers_module._ONE_LAYER_CACHE)
+
+    # Check that repeated pure calls hit the cache instead of adding entries.
+    scan_layers(
+        layers, input_data.clone(),
+        partition_fn=partition_fn, is_layer_pure=True)
+    scan_layers(
+        layers, input_data.clone(),
+        partition_fn=partition_fn, is_layer_pure=True)
+    self.assertEqual(len(scan_layers_module._ONE_LAYER_CACHE), 1)
+
+  def test_scan_layers_cache_non_pure(self):
+    # Test that the cache is not used for non-pure layers.
+    layers = [nn.Linear(64, 64).to(self.device) for _ in range(10)]
+    input_data = torch.randn(64).to(self.device)
+    torch_xla.sync(wait=True)
+
+    scan_layers(layers, input_data.clone(), is_layer_pure=False)
+
+    # Check that the cache is not populated.
+    self.assertEqual(len(scan_layers_module._ONE_LAYER_CACHE), 0)
+
 
 if __name__ == '__main__':
   test = unittest.main()
diff --git a/torch_xla/experimental/scan_layers.py b/torch_xla/experimental/scan_layers.py
index cfc5de5d1ed..871b36e23de 100644
--- a/torch_xla/experimental/scan_layers.py
+++ b/torch_xla/experimental/scan_layers.py
@@ -7,10 +7,43 @@
 from torch_xla.experimental.scan import scan
 
+# Because the first layer needs to be wrapped to fit the `scan` API (see
+# `_create_or_get_cached_one_layer_fn`), freshly created wrappers differ even
+# when the given layers are the same. Caching the wrapper ensures that the
+# same layer maps to the same wrapped function, so the `scan` cache works.
+_ONE_LAYER_CACHE = {}
+
+
+def _create_or_get_cached_one_layer_fn(first_layer: nn.Module,
+                                       partition_fn,
+                                       is_layer_pure: bool = False):
+  cache_key = (id(partition_fn), id(first_layer))
+  if is_layer_pure and cache_key in _ONE_LAYER_CACHE:
+    return _ONE_LAYER_CACHE[cache_key]
+
+  # Use the first layer as the example/template layer.
+  from copy import deepcopy
+  example_layer = deepcopy(first_layer)
+
+  # Define the function to apply at each step
+  def one_layer_fn(carry, params_buffers):
+    # Apply the current layer's weights and biases to the example layer,
+    # then run the resulting layer.
+    output = torch.func.functional_call(  # type: ignore
+        example_layer, params_buffers, carry, strict=True)
+    return output, None
+
+  if is_layer_pure:
+    # Cache the function for pure layers to avoid recomputing it.
+    _ONE_LAYER_CACHE[cache_key] = one_layer_fn
+
+  return one_layer_fn
+
 
 def scan_layers(layers: Iterable[torch.nn.Module],
                 input_data,
-                partition_fn=default_partition):
+                partition_fn=default_partition,
+                is_layer_pure=False):
   """Runs each layer in `layers` sequentially, starting with `input_data`.
 
   `input_data` is provided as input to the first layer in `layers`. The output of one
@@ -41,6 +74,11 @@ def scan_layers(layers: Iterable[torch.nn.Module],
       `functorch.compile.min_cut_rematerialization_partition` to use min-cut based
       activation checkpointing. You may also write your own partitioner to insert
       any custom logic such as host offloading of activations.
+
+    is_layer_pure: (Optional[bool]) If True, the function assumes the layers are
+      pure: they have no side effects and do not depend on any external state.
+      This allows the traced layer computation to be cached and reused.
+
   Returns:
 
     The output of the last layer from `layers`.
@@ -76,21 +114,16 @@ def scan_layers(layers: Iterable[torch.nn.Module],
   stacked_buffers = tree_map(lambda *tensors: torch.stack(tensors, dim=0),
                              *buffers_list)
 
-  # Use the first layer as the example/template layer.
-  from copy import deepcopy
-  example_layer = deepcopy(first_layer)
-
-  # Define the function to apply at each step
-  def one_layer(carry, params_buffers):
-    # Apply the current layer's weights and biases to the example layer,
-    # then run the resulting layer.
-    output = torch.func.functional_call(  # type: ignore
-        example_layer, params_buffers, carry, strict=True)
-    return output, None
+  one_layer = _create_or_get_cached_one_layer_fn(first_layer, partition_fn,
+                                                 is_layer_pure)
 
   stacked_params_buffers = (stacked_params, stacked_buffers)
   final_carry, _ = scan(
-      one_layer, input_data, stacked_params_buffers, partition_fn=partition_fn)
+      one_layer,
+      input_data,
+      stacked_params_buffers,
+      partition_fn=partition_fn,
+      is_fn_pure=is_layer_pure)
   return final_carry
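For context when reviewing, below is a minimal usage sketch of the new `is_layer_pure` flag and the wrapper cache it introduces. It is not part of the diff; it assumes an XLA device is available, and the layer count and sizes are illustrative.

```python
import torch
import torch.nn as nn
import torch_xla
import torch_xla.experimental.scan_layers as scan_layers_module
from torch_xla.experimental.scan_layers import scan_layers

device = torch_xla.device()
layers = [nn.Linear(64, 64).to(device) for _ in range(10)]
x = torch.randn(64).to(device)

# The first pure call wraps the first layer to fit the `scan` API and stores
# the wrapper under the key (id(partition_fn), id(layers[0])).
out = scan_layers(layers, x.clone(), is_layer_pure=True)
assert len(scan_layers_module._ONE_LAYER_CACHE) == 1

# A later pure call with the same first layer and partition function reuses
# the cached wrapper, so `scan` sees the identical function object and its
# own tracing cache can hit as well.
out = scan_layers(layers, x.clone(), is_layer_pure=True)
assert len(scan_layers_module._ONE_LAYER_CACHE) == 1

# Non-pure calls (the default) bypass the wrapper cache entirely.
out = scan_layers(layers, x.clone())
```

Because the cache keys on object identity, entries persist for the life of those objects; the test suite therefore clears `_ONE_LAYER_CACHE` in `setUp` to keep tests independent.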