Skip to content
This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Commit

Permalink
[Bugfix][TPU] Fix CPU cache allocation (vllm-project#5869)
Browse files Browse the repository at this point in the history
  • Loading branch information
WoosukKwon authored and robertgshaw2-neuralmagic committed Jul 1, 2024
1 parent 1653293 commit e423b2c
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
5 changes: 2 additions & 3 deletions vllm/attention/backends/pallas.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,10 @@ def swap_blocks(
) -> None:
src_k_cache, src_v_cache = src_kv_cache
dst_k_cache, dst_v_cache = dst_kv_cache
src_indices, dst_indices = src_to_dst
device = dst_k_cache.device
torch.ops.xla.dynamo_set_buffer_donor_(dst_k_cache, True)
torch.ops.xla.dynamo_set_buffer_donor_(dst_v_cache, True)

device = dst_k_cache.device
src_indices, dst_indices = src_to_dst
dst_k_cache[:, dst_indices] = src_k_cache[:, src_indices].to(device)
dst_v_cache[:, dst_indices] = src_v_cache[:, src_indices].to(device)

Expand Down
8 changes: 6 additions & 2 deletions vllm/worker/tpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,14 +156,18 @@ def initialize_cache(
self.tpu_cache = []
tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
num_gpu_blocks, self.block_size, num_kv_heads, head_size)
cpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
num_cpu_blocks, self.block_size, num_kv_heads, head_size)
for _ in range(num_layers):
tpu_k_cache = torch.zeros(tpu_cache_shape,
dtype=dtype,
device=self.device)
tpu_v_cache = torch.zeros_like(tpu_k_cache)
self.tpu_cache.append((tpu_k_cache, tpu_v_cache))
cpu_k_cache = torch.zeros_like(tpu_k_cache, device="cpu")
cpu_v_cache = torch.zeros_like(tpu_v_cache, device="cpu")
cpu_k_cache = torch.zeros(cpu_cache_shape,
dtype=dtype,
device="cpu")
cpu_v_cache = torch.zeros_like(cpu_k_cache)
self.cpu_cache.append((cpu_k_cache, cpu_v_cache))
self._warmup_model()

Expand Down

0 comments on commit e423b2c

Please sign in to comment.