better kv allocate (#2814)
* better allocate

* update max session len
grimoire authored Dec 2, 2024
1 parent c158d18 commit 986ad17
Showing 4 changed files with 80 additions and 98 deletions.
135 changes: 58 additions & 77 deletions lmdeploy/pytorch/engine/cache_engine.py
@@ -54,7 +54,7 @@ def __init__(
         self.cache_stream = torch.cuda.Stream()
         assert self.cache_stream != torch.cuda.current_stream()
         # Initialize the events for stream synchronization.
-        self.events = [torch.cuda.Event() for _ in range(self.num_layers)]
+        self.events = torch.cuda.Event()
 
         logger.debug(
             f'Initialize cache engine with {cache_config.num_gpu_blocks}'
@@ -156,80 +156,60 @@ def get_value_block_shape(self,
             local=local,
         )
 
-    def allocate_gpu_cache(self):
-        """allocate caches on GPU."""
-        gpu_cache: List[KVCache] = []
+    def _allocate_cache(self, num_blocks: int, device: torch.device):
+        """allocate cache implement."""
         key_block_shape = self.get_key_block_shape(local=True)
         value_block_shape = self.get_value_block_shape(local=True)
 
-        for _ in range(self.num_layers):
-            key_blocks = torch.empty(
-                size=(self.num_gpu_blocks, *key_block_shape),
-                dtype=self.kv_cache_dtype,
-                device='cuda',
-            )
-            value_blocks = torch.empty(
-                size=(self.num_gpu_blocks, *value_block_shape),
-                dtype=self.kv_cache_dtype,
-                device='cuda',
-            )
-            if self.cache_config.quant_policy in (4, 8):
-                key_scales_zeros = torch.empty(
-                    size=(self.num_gpu_blocks, *key_block_shape[:-1], 2),
-                    dtype=self.model_config.dtype,
-                    device='cuda',
-                )
-                value_scales_zeros = torch.empty(
-                    size=(self.num_gpu_blocks, *value_block_shape[:-1], 2),
-                    dtype=self.model_config.dtype,
-                    device='cuda',
-                )
-                gpu_cache.append((key_blocks, value_blocks, key_scales_zeros,
-                                  value_scales_zeros))
-            else:
-                gpu_cache.append((key_blocks, value_blocks))
-
-        return gpu_cache
+        num_layers = self.num_layers
+        kv_cache_dtype = self.kv_cache_dtype
+
+        key_cache = torch.empty(
+            size=(num_layers, num_blocks, *key_block_shape),
+            dtype=kv_cache_dtype,
+            device=device,
+        )
+        value_cache = torch.empty(
+            size=(num_layers, num_blocks, *value_block_shape),
+            dtype=kv_cache_dtype,
+            device=device,
+        )
+
+        output = (key_cache, value_cache)
+
+        if self.cache_config.quant_policy in (4, 8):
+            dtype = self.model_config.dtype
+            key_sz_cache = torch.empty(
+                size=(num_layers, num_blocks, *key_block_shape[:-1], 2),
+                dtype=dtype,
+                device=device,
+            )
+            val_sz_cache = torch.empty(
+                size=(num_layers, num_blocks, *value_block_shape[:-1], 2),
+                dtype=dtype,
+                device=device,
+            )
+            output = output + (key_sz_cache, val_sz_cache)
+
+        return output
+
+    def allocate_gpu_cache(self):
+        """allocate caches on GPU."""
+        caches = self._allocate_cache(self.num_gpu_blocks, 'cuda')
+        self.full_gpu_cache = caches
+        self.local_gpu_cache = list(zip(*caches))
+        return self.local_gpu_cache
 
     def allocate_cpu_cache(self):
         """allocate caches on Host."""
-        cpu_cache: List[KVCache] = []
-        key_block_shape = self.get_key_block_shape(local=True)
-        value_block_shape = self.get_value_block_shape(local=True)
-
-        # TODO: pin memory might need be banned on wsl
-        pin_memory = True
-
-        for _ in range(self.num_layers):
-            key_blocks = torch.empty(
-                size=(self.num_cpu_blocks, *key_block_shape),
-                dtype=self.kv_cache_dtype,
-                pin_memory=pin_memory,
-            )
-            value_blocks = torch.empty(
-                size=(self.num_cpu_blocks, *value_block_shape),
-                dtype=self.kv_cache_dtype,
-                pin_memory=pin_memory,
-            )
-            if self.cache_config.quant_policy in (4, 8):
-                key_scales_zeros = torch.empty(
-                    size=(self.num_cpu_blocks, *key_block_shape[:-1], 2),
-                    dtype=self.model_config.dtype,
-                    pin_memory=pin_memory,
-                )
-                value_scales_zeros = torch.empty(
-                    size=(self.num_cpu_blocks, *value_block_shape[:-1], 2),
-                    dtype=self.model_config.dtype,
-                    pin_memory=pin_memory,
-                )
-                cpu_cache.append((key_blocks, value_blocks, key_scales_zeros,
-                                  value_scales_zeros))
-            else:
-                cpu_cache.append((key_blocks, value_blocks))
-        return cpu_cache
+        caches = self._allocate_cache(self.num_cpu_blocks, 'cpu')
+        self.full_cpu_cache = caches
+        self.local_cpu_cache = list(zip(*caches))
+        return self.local_cpu_cache
 
     @torch.inference_mode()
-    def _swap(self, src: List[KVCache], dst: List[KVCache],
+    def _swap(self, src: List[torch.Tensor], dst: List[torch.Tensor],
               src_to_dst: Dict[int, int]):
         """Move caches from src memory to dst memory.
@@ -238,34 +218,35 @@ def _swap(self, src: List[KVCache], dst: List[KVCache],
             dst (List[KVCache]): Destination cache.
             src_to_dst (Dict[int, int]): Map between src and dst.
         """
+        BLOCKS_PER_COPY = 2
+        num_copy = len(src_to_dst)
+        src_idx, dst_idx = list(zip(*src_to_dst.items()))
+        src_idx = torch.tensor(src_idx, device=src[0].device)
+        dst_idx = torch.tensor(dst_idx, device=dst[0].device)
         with torch.cuda.stream(self.cache_stream):
-            for i in range(self.num_layers):
-                src_key_cache, src_value_cache = src[i]
-                dst_key_cache, dst_value_cache = dst[i]
-
-                for src_id, dst_id in src_to_dst.items():
-                    if isinstance(dst_key_cache[dst_id], torch.Tensor):
-                        dst_key_cache[dst_id].copy_(src_key_cache[src_id])
-                        dst_value_cache[dst_id].copy_(src_value_cache[src_id])
-
-                event = self.events[i]
-                event.record(stream=self.cache_stream)
+            for scache, dcache in zip(src, dst):
+                for idx in range(0, num_copy, BLOCKS_PER_COPY):
+                    sidx = src_idx[idx:idx + BLOCKS_PER_COPY]
+                    didx = dst_idx[idx:idx + BLOCKS_PER_COPY]
+                    sdata = scache[:, sidx]
+                    dcache.index_copy_(1, didx, sdata.to(dcache.device))
+            self.events.record(stream=self.cache_stream)
 
     def swap_in(self, src_to_dst: Dict[int, int]) -> None:
         """Move cache from Host to Device.
         Args:
             src_to_dst (Dict[int, int]): Map between src and dst.
         """
-        self._swap(self.local_cpu_cache, self.local_gpu_cache, src_to_dst)
+        self._swap(self.full_cpu_cache, self.full_gpu_cache, src_to_dst)
 
     def swap_out(self, src_to_dst: Dict[int, int]) -> None:
         """Move cache from Device to Host.
         Args:
             src_to_dst (Dict[int, int]): Map between src and dst.
         """
-        self._swap(self.local_gpu_cache, self.local_cpu_cache, src_to_dst)
+        self._swap(self.full_gpu_cache, self.full_cpu_cache, src_to_dst)
 
     @classmethod
     def get_cache_block_size(cls,
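
The core of the allocator change above: each cache is now one layer-stacked tensor of shape (num_layers, num_blocks, ...) instead of a per-layer list of block tensors, so a swap can move the same block slots across every layer with one index_copy_ per chunk rather than a per-layer, per-block copy_. Below is a minimal, standalone sketch of that batched-copy idea; the shapes, dtype, and block map are made-up example values, and both tensors stay on the CPU so the snippet runs without a GPU.

import torch

# Made-up example dimensions; the real shapes come from get_key_block_shape().
num_layers, num_blocks, block_size, num_heads, head_dim = 4, 16, 64, 2, 32

src = torch.randn(num_layers, num_blocks, block_size, num_heads, head_dim)  # e.g. host cache
dst = torch.zeros_like(src)                                                 # e.g. device cache

src_to_dst = {0: 3, 1: 7, 5: 2}  # hypothetical block map, as passed to _swap
src_idx, dst_idx = (torch.tensor(t) for t in zip(*src_to_dst.items()))

BLOCKS_PER_COPY = 2
for i in range(0, len(src_to_dst), BLOCKS_PER_COPY):
    sidx = src_idx[i:i + BLOCKS_PER_COPY]
    didx = dst_idx[i:i + BLOCKS_PER_COPY]
    # dim 1 is the block dimension, so one call updates these blocks in all layers.
    dst.index_copy_(1, didx, src[:, sidx].to(dst.device))

assert torch.equal(dst[:, 3], src[:, 0])  # block 0 landed in slot 3, in every layer
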
26 changes: 20 additions & 6 deletions lmdeploy/pytorch/engine/engine.py
@@ -164,6 +164,7 @@ def __init__(self,
         self.cache_config = cache_config
         self.backend_config = backend_config
         self.stream = self.model_agent.stream
+        self.max_session_len = self._get_max_session_len()
 
         self.req_manager = self._bind_request_manager()
 
@@ -261,6 +262,20 @@ def _response(self,
                 data=data,
                 err_msg=err_msg))
 
+    def _get_max_session_len(self):
+        """get max session len."""
+        session_len = self.scheduler_config.max_session_len
+        max_tokens = (self.cache_config.num_gpu_blocks *
+                      self.cache_config.block_size)
+        window_size = self.cache_config.window_size
+        if window_size > 0 and window_size <= max_tokens:
+            max_tokens = (1 << 63) - 1
+        if session_len is None:
+            session_len = max_tokens
+        else:
+            session_len = min(max_tokens, session_len)
+        return session_len
+
     def _on_add_session(self, reqs: Request, **kwargs):
         """on add session callback."""
         for req in reqs:
@@ -315,12 +330,11 @@ def __update_bad_words(msg):
 
         def __update_max_new_tokens(msg):
             """update max new tokens."""
-            max_session_len = self.scheduler_config.max_session_len
-            if max_session_len is not None:
-                sampling_param = msg.sampling_param
-                sampling_param.max_new_tokens = min(
-                    sampling_param.max_new_tokens,
-                    max_session_len - msg.num_all_tokens())
+            max_session_len = self.max_session_len
+            sampling_param = msg.sampling_param
+            sampling_param.max_new_tokens = min(
+                sampling_param.max_new_tokens,
+                max_session_len - msg.num_all_tokens())
 
         for req in reqs:
             session_id = req.data['session_id']
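
The `_get_max_session_len` helper added above caps the session length by what the KV cache can actually hold, and treats a sliding window that fits in the cache as effectively unbounded. A hedged restatement of that rule as a standalone function; the `_Cfg` stand-in and the example numbers are illustrative, not part of the engine's API.

from dataclasses import dataclass


@dataclass
class _Cfg:
    """Hypothetical stand-in for the cache_config fields the rule reads."""
    num_gpu_blocks: int = 1024
    block_size: int = 64
    window_size: int = -1  # <= 0 means no sliding-window attention


def max_session_len(cfg: _Cfg, configured_len=None) -> int:
    # Tokens that fit in the GPU KV cache.
    max_tokens = cfg.num_gpu_blocks * cfg.block_size
    # A sliding window that fits in the cache recycles old blocks,
    # so the cache no longer bounds the session length.
    if 0 < cfg.window_size <= max_tokens:
        max_tokens = (1 << 63) - 1
    if configured_len is None:
        return max_tokens
    return min(max_tokens, configured_len)


print(max_session_len(_Cfg()))                  # 65536: bounded by cache capacity
print(max_session_len(_Cfg(window_size=4096)))  # huge sentinel: window fits, no cache bound
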
13 changes: 1 addition & 12 deletions lmdeploy/pytorch/engine/engine_instance.py
@@ -89,21 +89,10 @@ class EngineInstance:
     """
 
     def __init__(self, engine: Engine):
-
-        def __get_max_input_len(engine):
-            """get max input len."""
-            cache_config = engine.cache_config
-            max_input_len = (cache_config.block_size *
-                             cache_config.num_gpu_blocks)
-            window_size = cache_config.window_size
-            if window_size > 0 and window_size <= max_input_len:
-                max_input_len = (1 << 63) - 1
-            return max_input_len
-
         self.engine = engine
         self.req_sender = engine.req_manager.build_sender()
 
-        self.max_input_len = __get_max_input_len(self.engine)
+        self.max_input_len = self.engine.max_session_len
 
     def __del__(self):
         """Destructor."""
4 changes: 1 addition & 3 deletions lmdeploy/pytorch/engine/model_agent.py
@@ -120,9 +120,7 @@ def cache_swapping(cache_engine: CacheEngine, swap_in_map: dict,
         issued_cache_op = True
 
     if issued_cache_op:
-        cache_events = cache_engine.events
-        for event in cache_events:
-            event.wait()
+        cache_engine.events.wait()
 
 
 @torch.inference_mode()
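
With all block copies issued back-to-back on the cache stream, a single CUDA event is enough to order the swap before the next model step, which is why the per-layer event loop above collapses to one `wait()`. A minimal sketch of that record-then-wait pattern, assuming a CUDA device is available; the tensors are arbitrary stand-ins for the swap traffic, not the engine's code path.

import torch

if torch.cuda.is_available():
    cache_stream = torch.cuda.Stream()
    copy_done = torch.cuda.Event()

    host = torch.randn(1 << 20, pin_memory=True)
    device = torch.empty(1 << 20, device='cuda')

    # Issue the copies on the side stream and record one event after them.
    with torch.cuda.stream(cache_stream):
        device.copy_(host, non_blocking=True)
        copy_done.record(stream=cache_stream)

    # The current (main) stream waits on that single event before continuing,
    # e.g. before launching the forward pass that reads the swapped-in blocks.
    copy_done.wait()
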
