
Commit 07e7fd3

save changes

Signed-off-by: Juncheng Gu <[email protected]>
1 parent 205e474 commit 07e7fd3

3 files changed: +161 -23 lines changed

Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import functools
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from absl.testing import parameterized
+from jax._src import compilation_cache as cc
+from jax._src import test_util as jtu
+from jax.sharding import NamedSharding, PartitionSpec
+
+from tpu_inference.distributed.cache_util import get_kv_cache_swap_fn
+
+
+class TestGetKVCacheSwapFn(jtu.JaxTestCase):
+    """Test the get_kv_cache_swap_fn functionality."""
+
+    def setUp(self):
+        super().setUp()
+        self.num_layers = 2
+        self.num_tokens = 128
+        self.num_heads = 8
+        self.head_size = 128
+        self.mesh = self.create_mesh((1, 8), ("data", "model"))
+        if self.mesh is None:
+            self.skipTest("Cannot create mesh. Must be run on a TPU node.")
+            return
+
+        # Define cache properties
+        self.cache_shape = (
+            self.num_tokens,
+            self.num_heads,
+            2,
+            self.head_size,
+        )
+        self.cache_dtype = jnp.bfloat16
+
+        # Define shardings, mirroring the setup in TPUConnectorWorker
+        partition_spec = PartitionSpec(None, "model")
+        self.device_sharding = NamedSharding(self.mesh,
+                                             partition_spec,
+                                             memory_kind="device")
+        self.host_sharding = NamedSharding(self.mesh,
+                                           partition_spec,
+                                           memory_kind="pinned_host")
+
+    def tearDown(self):
+        super().tearDown()
+        # Reset the compilation cache after each test. This can also be
+        # achieved by running with JAX_TEST_WITH_PERSISTENT_COMPILATION_CACHE=True.
+        cc.reset_cache()
+
+    def create_mesh(self, axis_shapes, axis_names):
+        """Creates a JAX device mesh with the default device order."""
+        try:
+            num_required_devices = np.prod(axis_shapes)
+            devices = np.array(jax.devices())
+            if len(devices) < num_required_devices:
+                self.skipTest(
+                    f"Not enough devices to create mesh of shape {axis_shapes}."
+                )
+            device_array = devices[:num_required_devices].reshape(axis_shapes)
+            return jax.sharding.Mesh(device_array, axis_names)
+        except RuntimeError:
+            return None
+
+    @parameterized.named_parameters(
+        dict(testcase_name="_swap_op_jax_jitted",
+             swap_op_type="jax",
+             jitted=True),
+        dict(testcase_name="_swap_op_pallas_jitted",
+             swap_op_type="pallas",
+             jitted=True),
+        dict(testcase_name="_swap_op_jax_unjitted",
+             swap_op_type="jax",
+             jitted=False),
+        dict(testcase_name="_swap_op_pallas_unjitted",
+             swap_op_type="pallas",
+             jitted=False),
+    )
+    def test_kv_cache_swap_roundtrip(self, swap_op_type: str, jitted: bool):
+        """
+        Tests the round-trip transfer of KV cache data: Device -> Host -> Device.
+
+        This test verifies that the `swap_in_fn` and `swap_out_fn` generated by
+        `get_kv_cache_swap_fn` correctly transfer data between TPU HBM and
+        host memory without corruption. It also exercises the code path that
+        enables buffer donation for the device-to-host transfer.
+        """
+        # 1. Get the swap functions to be tested.
+        swap_in_fn, swap_out_fn = get_kv_cache_swap_fn(
+            swap_op_type=swap_op_type,
+            host_sharding=self.host_sharding,
+            device_sharding=self.device_sharding,
+            jitted=jitted,
+        )
+
+        # 2. Create original source data on the TPU device.
+        @functools.partial(jax.jit, out_shardings=self.device_sharding)
+        def create_on_device(key):
+            return jax.random.uniform(key,
+                                      shape=self.cache_shape,
+                                      dtype=self.cache_dtype)
+
+        original_data_tpu = [
+            create_on_device(jax.random.key(i)) for i in range(self.num_layers)
+        ]
+        jax.block_until_ready(original_data_tpu)
+
+        # 3. Perform Device-to-Host (D2H) transfer (swap out).
+        # This call exercises the `donate_argnames` functionality when jitted.
+        data_cpu = swap_out_fn(original_data_tpu)
+        jax.block_until_ready(data_cpu)
+
+        # 4. Verify the data on the host. (assertEqual, not assertIs:
+        # memory_kind is a plain string, so identity checks are fragile.)
+        for i in range(self.num_layers):
+            self.assertEqual(data_cpu[i].sharding.memory_kind, "pinned_host")
+            self.assertEqual(data_cpu[i].sharding, self.host_sharding)
+            self.assertArraysEqual(np.array(data_cpu[i]),
+                                   np.array(original_data_tpu[i]))
+
+        # 5. Perform Host-to-Device (H2D) transfer (swap in).
+        roundtrip_data_tpu = swap_in_fn(data_cpu)
+        jax.block_until_ready(roundtrip_data_tpu)
+
+        # 6. Verify the round-tripped data on the device.
+        for i in range(self.num_layers):
+            self.assertEqual(roundtrip_data_tpu[i].sharding.memory_kind,
+                             "device")
+            self.assertEqual(roundtrip_data_tpu[i].sharding,
+                             self.device_sharding)
+            self.assertArraysEqual(np.array(roundtrip_data_tpu[i]),
+                                   np.array(original_data_tpu[i]))
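
Note: the round trip above is built on JAX's memory-kind-aware shardings. A minimal standalone sketch of that underlying mechanism (illustrative only, not part of this commit; it assumes a backend such as TPU that supports the "pinned_host" memory kind, and an example mesh/axis layout):

# Re-sharding with jax.device_put between two NamedShardings that differ
# only in memory_kind moves buffers between device HBM and pinned host
# memory; that is the primitive a d2h/h2d swap reduces to.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()).reshape(1, -1), ("data", "model"))
spec = PartitionSpec(None, "model")
device_sharding = NamedSharding(mesh, spec, memory_kind="device")
host_sharding = NamedSharding(mesh, spec, memory_kind="pinned_host")

x = jax.device_put(jnp.ones((128, 8), jnp.bfloat16), device_sharding)
x_host = jax.device_put(x, host_sharding)         # d2h offload
x_back = jax.device_put(x_host, device_sharding)  # h2d restore
assert x_back.sharding.memory_kind == "device"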

tpu_inference/distributed/cache_util.py

Lines changed: 24 additions & 23 deletions
@@ -191,40 +191,41 @@ def get_kv_cache_swap_fn(
     jitted: bool = True,
 ) -> Tuple[KVCacheSwapFn, KVCacheSwapFn]:
     """get the right swap_in and swap_out functions
-
     Args:
         swap_op_type : (str) pallas or jax
         host_sharding:
         device_sharding:
-
     Returns:
         A tuple containing the jitted swap-in and swap-out functions.
     """
     _swap_fn: SwapFn = pallas_swap_kv_caches if swap_op_type == "pallas" else jax_swap_kv_caches
-    if jitted:
-        _swap_in_fn = jax.jit(
-            _swap_fn,
-            static_argnames=["src_sharding", "dst_sharding", "direction"],
-            out_shardings=device_sharding)
-        _swap_out_fn = jax.jit(
-            _swap_fn,
-            static_argnames=["src_sharding", "dst_sharding", "direction"],
-            out_shardings=host_sharding)
-    else:
-        _swap_in_fn = _swap_fn
-        _swap_out_fn = _swap_fn
 
     # swap_in (h2d)
-    swap_in_fn = functools.partial(_swap_in_fn,
-                                   src_sharding=host_sharding,
-                                   dst_sharding=device_sharding,
-                                   direction="h2d")
+    _swap_in_partial = functools.partial(_swap_fn,
+                                         src_sharding=host_sharding,
+                                         dst_sharding=device_sharding,
+                                         direction="h2d")
     # swap_out (d2h)
-    swap_out_fn = functools.partial(_swap_out_fn,
-                                    src_sharding=device_sharding,
-                                    dst_sharding=host_sharding,
-                                    direction="d2h")
-    return swap_in_fn, swap_out_fn
+    _swap_out_partial = functools.partial(_swap_fn,
+                                          src_sharding=device_sharding,
+                                          dst_sharding=host_sharding,
+                                          direction="d2h")
+
+    if jitted:
+
+        def swap_in_fn(src_kv_caches: List[jax.Array]) -> List[jax.Array]:
+            return _swap_in_partial(src_kv_caches=src_kv_caches)
+
+        def swap_out_fn(src_kv_caches: List[jax.Array]) -> List[jax.Array]:
+            return _swap_out_partial(src_kv_caches=src_kv_caches)
+
+        swap_in_fn = jax.jit(swap_in_fn, out_shardings=device_sharding)
+        swap_out_fn = jax.jit(swap_out_fn,
+                              donate_argnames=["src_kv_caches"],
+                              out_shardings=host_sharding)
+        return swap_in_fn, swap_out_fn
+    else:
+        return _swap_in_partial, _swap_out_partial
 
 
 @functools.partial(
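
The shape of this refactor follows from how donation is specified: jax.jit's donate_argnames matches Python parameter names, so the sharding and direction arguments are bound first with functools.partial and the jitted wrappers expose only src_kv_caches as a donatable argument. A standalone sketch of the pattern (illustrative names, not the project's code):

import jax
import jax.numpy as jnp

def swap_out(src_kv_caches):
    # Stand-in body; the real swap functions re-shard each layer's cache
    # into pinned host memory.
    return [x + 0 for x in src_kv_caches]

# Donation lets XLA reuse the input buffers instead of allocating new ones.
swap_out = jax.jit(swap_out, donate_argnames=["src_kv_caches"])

caches = [jnp.ones((4, 4)) for _ in range(2)]
out = swap_out(caches)
# On backends that implement donation (e.g. TPU), the donated inputs are
# invalidated after the call; on CPU, JAX warns and leaves them intact.
print(caches[0].is_deleted())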

tpu_inference/distributed/tpu_connector_local.py

Lines changed: 4 additions & 0 deletions
@@ -1038,11 +1038,15 @@ def _save_blocks_to_cpu(self, req_id: ReqId, full_block_ids: list[int],
             self.runner.kv_caches, blocks_to_save)
 
         jax.block_until_ready(flat_kv_caches_tpu)
+        flat_kv_caches_tpu_copy = flat_kv_caches_tpu
         logger.info(
             f"extracted_blocks_tpu: {flat_kv_caches_tpu[0].shape}, {flat_kv_caches_tpu[0].sharding}"
         )
 
         flat_kv_caches_cpu = self.swap_out_fn(flat_kv_caches_tpu)
+        logger.info(
+            f"---debug----: flat_kv_caches_tpu_copy: {flat_kv_caches_tpu_copy[0].shape}, {flat_kv_caches_tpu_copy[0].sharding}"
+        )
         # Block until the transfer is complete
         if flat_kv_caches_cpu:
             jax.block_until_ready(flat_kv_caches_cpu)
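
A caveat on the debug lines above: flat_kv_caches_tpu_copy = flat_kv_caches_tpu binds a second name to the same list rather than copying any buffers, and the jitted swap_out_fn donates src_kv_caches. After donation the arrays' data is gone, though metadata such as .shape and .sharding (which is all the debug log reads) typically remains accessible. A hedged sketch, assuming the donating (jitted) path, of how to make the donation observable:

# Illustrative only; names reuse those from the method above.
donated = flat_kv_caches_tpu[0]
print(donated.is_deleted())             # True if the buffer was donated
print(donated.shape, donated.sharding)  # metadata survives donation
# np.asarray(donated) would raise here: the underlying data was donated.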
