|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | + |
| 3 | +import functools |
| 4 | + |
| 5 | +import jax |
| 6 | +import jax.numpy as jnp |
| 7 | +import numpy as np |
| 8 | +from absl.testing import parameterized |
| 9 | +from jax._src import compilation_cache as cc |
| 10 | +from jax._src import test_util as jtu |
| 11 | +from jax.sharding import NamedSharding, PartitionSpec |
| 12 | + |
| 13 | +from tpu_inference.distributed.cache_util import get_kv_cache_swap_fn |
| 14 | + |
| 15 | + |
| 16 | +class TestGetKVCacheSwapFn(jtu.JaxTestCase): |
| 17 | + """Test the get_kv_cache_swap_fn functionality.""" |
| 18 | + |
| 19 | + def setUp(self): |
| 20 | + super().setUp() |
| 21 | + self.num_layers = 2 |
| 22 | + self.num_tokens = 128 |
| 23 | + self.num_heads = 8 |
| 24 | + self.head_size = 128 |
| 25 | + self.mesh = self.create_mesh((1, 8), ("data", "model")) |
| 26 | + if self.mesh is None: |
| 27 | + self.skipTest("Cannot create mesh. Must be run on a TPU node.") |
| 28 | + return |
| 29 | + |
| 30 | + # Define cache properties |
| 31 | + self.cache_shape = ( |
| 32 | + self.num_tokens, |
| 33 | + self.num_heads, |
| 34 | + 2, |
| 35 | + self.head_size, |
| 36 | + ) |
| 37 | + self.cache_dtype = jnp.bfloat16 |
| 38 | + |
| 39 | + # Define shardings, mirroring the setup in TPUConnectorWorker |
| 40 | + partition_spec = PartitionSpec(None, "model") |
| 41 | + self.device_sharding = NamedSharding(self.mesh, |
| 42 | + partition_spec, |
| 43 | + memory_kind="device") |
| 44 | + self.host_sharding = NamedSharding(self.mesh, |
| 45 | + partition_spec, |
| 46 | + memory_kind="pinned_host") |
| 47 | + |
| 48 | + def tearDown(self): |
| 49 | + super().tearDown() |
| 50 | + # Reset the cache after each test. |
| 51 | + # This can also be achieved by running with JAX_TEST_WITH_PERSISTENT_COMPILATION_CACHE=True |
| 52 | + cc.reset_cache() |
| 53 | + |
| 54 | + def create_mesh(self, axis_shapes, axis_names): |
| 55 | + """Creates a JAX device mesh with the default device order.""" |
| 56 | + try: |
| 57 | + num_required_devices = np.prod(axis_shapes) |
| 58 | + devices = np.array(jax.devices()) |
| 59 | + if len(devices) < num_required_devices: |
| 60 | + self.skipTest( |
| 61 | + f"Not enough devices to create mesh of shape {axis_shapes}." |
| 62 | + ) |
| 63 | + device_array = devices[:num_required_devices].reshape(axis_shapes) |
| 64 | + return jax.sharding.Mesh(device_array, axis_names) |
| 65 | + except RuntimeError: |
| 66 | + return None |
| 67 | + |
| 68 | + @parameterized.named_parameters( |
| 69 | + dict(testcase_name="_swap_op_jax_jitted", |
| 70 | + swap_op_type="jax", |
| 71 | + jitted=True), |
| 72 | + dict(testcase_name="_swap_op_pallas_jitted", |
| 73 | + swap_op_type="pallas", |
| 74 | + jitted=True), |
| 75 | + dict(testcase_name="_swap_op_jax_unjitted", |
| 76 | + swap_op_type="jax", |
| 77 | + jitted=False), |
| 78 | + dict(testcase_name="_swap_op_pallas_unjitted", |
| 79 | + swap_op_type="pallas", |
| 80 | + jitted=False), |
| 81 | + ) |
| 82 | + def test_kv_cache_swap_roundtrip(self, swap_op_type: str, jitted: bool): |
| 83 | + """ |
| 84 | + Tests the round-trip transfer of KV cache data: Device -> Host -> Device. |
| 85 | +
|
| 86 | + This test verifies that the `swap_in_fn` and `swap_out_fn` generated by |
| 87 | + `get_kv_cache_swap_fn` correctly transfer data between TPU HBM and |
| 88 | + host memory without corruption. It also exercises the code path that |
| 89 | + enables buffer donation for the device-to-host transfer. |
| 90 | + """ |
| 91 | + # 1. Get the swap functions to be tested. |
| 92 | + swap_in_fn, swap_out_fn = get_kv_cache_swap_fn( |
| 93 | + swap_op_type=swap_op_type, |
| 94 | + host_sharding=self.host_sharding, |
| 95 | + device_sharding=self.device_sharding, |
| 96 | + jitted=jitted, |
| 97 | + ) |
| 98 | + |
| 99 | + # 2. Create original source data on the TPU device. |
| 100 | + @functools.partial(jax.jit, out_shardings=self.device_sharding) |
| 101 | + def create_on_device(key): |
| 102 | + return jax.random.uniform(key, |
| 103 | + shape=self.cache_shape, |
| 104 | + dtype=self.cache_dtype) |
| 105 | + |
| 106 | + original_data_tpu = [ |
| 107 | + create_on_device(jax.random.key(i)) for i in range(self.num_layers) |
| 108 | + ] |
| 109 | + jax.block_until_ready(original_data_tpu) |
| 110 | + |
| 111 | + # 3. Perform Device-to-Host (D2H) transfer (swap out). |
| 112 | + # This call exercises the `donate_argnames` functionality when jitted. |
| 113 | + data_cpu = swap_out_fn(original_data_tpu) |
| 114 | + jax.block_until_ready(data_cpu) |
| 115 | + |
| 116 | + # 4. Verify the data on the host. |
| 117 | + for i in range(self.num_layers): |
| 118 | + self.assertIs(data_cpu[i].sharding.memory_kind, "pinned_host") |
| 119 | + self.assertEqual(data_cpu[i].sharding, self.host_sharding) |
| 120 | + self.assertArraysEqual(np.array(data_cpu[i]), |
| 121 | + np.array(original_data_tpu[i])) |
| 122 | + |
| 123 | + # 5. Perform Host-to-Device (H2D) transfer (swap in). |
| 124 | + roundtrip_data_tpu = swap_in_fn(data_cpu) |
| 125 | + jax.block_until_ready(roundtrip_data_tpu) |
| 126 | + |
| 127 | + # 6. Verify the round-tripped data on the device. |
| 128 | + for i in range(self.num_layers): |
| 129 | + self.assertIs(roundtrip_data_tpu[i].sharding.memory_kind, "device") |
| 130 | + self.assertEqual(roundtrip_data_tpu[i].sharding, |
| 131 | + self.device_sharding) |
| 132 | + self.assertArraysEqual(np.array(roundtrip_data_tpu[i]), |
| 133 | + np.array(original_data_tpu[i])) |
0 commit comments