Commit

add docstring
grimoire committed Sep 20, 2023
1 parent 6eb6aa1 commit 1231302
Showing 16 changed files with 738 additions and 215 deletions.
3 changes: 3 additions & 0 deletions lmdeploy/pytorch_poc/block.py
@@ -4,6 +4,7 @@


class LogicalTokenBlock:
"""Logical block used to count tokens per block."""

def __init__(self, block_id: int, block_size: int):
self.block_id = block_id
@@ -27,6 +28,8 @@ def append_tokens(self, num_tokens: int = 1):

@dataclass
class PhysicalTokenBlock:
"""Physical block used to schedule key value cache."""

device: str
block_id: int
block_size: int
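For reference, a minimal self-contained sketch of the logical-block bookkeeping these docstrings describe. Only __init__(block_id, block_size) and append_tokens() appear in this diff; the num_tokens counter and the is_full() helper are assumptions added for illustration.

class _LogicalBlockSketch:
    """Toy stand-in for LogicalTokenBlock: counts tokens held by one block."""

    def __init__(self, block_id: int, block_size: int):
        self.block_id = block_id
        self.block_size = block_size
        self.num_tokens = 0  # assumed counter of occupied token slots

    def append_tokens(self, num_tokens: int = 1):
        assert self.num_tokens + num_tokens <= self.block_size
        self.num_tokens += num_tokens

    def is_full(self) -> bool:
        return self.num_tokens == self.block_size


block = _LogicalBlockSketch(block_id=0, block_size=4)
block.append_tokens(3)
print(block.is_full())  # False: one token slot still free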
6 changes: 6 additions & 0 deletions lmdeploy/pytorch_poc/config.py
@@ -4,6 +4,8 @@

@dataclass
class SchedulerConfig:
"""Config of scheduler."""

max_batches: int
max_session_len: int
max_request_output_len: int
@@ -12,13 +14,17 @@ class SchedulerConfig:

@dataclass
class CacheConfig:
"""Config of key value cache."""

block_size: int
num_cpu_blocks: int
num_gpu_blocks: int


@dataclass
class ModelConfig:
"""Config of model."""

hidden_size: int
num_layers: int
num_heads: int
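A hedged usage sketch of the cache config above: the three CacheConfig fields are fully visible in this diff, so the constructor call should look like this, with placeholder numbers. SchedulerConfig and ModelConfig may have additional fields hidden in the collapsed regions, so they are left out here.

from lmdeploy.pytorch_poc.config import CacheConfig

# Placeholder values: 16 tokens per block, 512 host blocks, 2048 device blocks.
cache_config = CacheConfig(block_size=16, num_cpu_blocks=512, num_gpu_blocks=2048)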
52 changes: 47 additions & 5 deletions lmdeploy/pytorch_poc/dist_utils.py
@@ -1,19 +1,29 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable
from typing import Callable, Union

import torch
from torch import nn
from torch import Tensor, nn
from torch.distributed._tensor import (DeviceMesh, DTensor, Replicate, Shard,
distribute_tensor)


def try_to_local(tensor):
def try_to_local(tensor: Union[Tensor, DTensor]):
"""Try to convert DTensor to Tensor.
Args:
tensor (Tensor|DTensor): Tensor to convert.
"""
if isinstance(tensor, DTensor):
tensor = tensor.to_local()
return tensor


def module_to_local(module: nn.Module):
"""convert all DTensor parameters to Tensor parameters in module.
Args:
module (Module): Module to convert.
"""
for name, mod in module.named_children():
module_to_local(mod)

@@ -77,8 +87,24 @@ def colwise_parallelize_linear_fn(module: nn.Module,
module.register_parameter(name, dist_param)


def _partition_module(mod_name: str, prefix: str, module: nn.Module,
device_mesh: DeviceMesh, func: Callable):
def _partition_module(
mod_name: str,
prefix: str,
module: nn.Module,
device_mesh: DeviceMesh,
func: Callable,
):
"""partition module.
Parameters in module won't be force Replicated.
Args:
mod_name (str): module name.
prefix (str): Parameter prefix.
module (Module): Module to be partitioned.
device_mesh (DeviceMesh): The device mesh.
func (Callable): partition callback
"""
for name, mod in module.named_children():
child_name = f'{prefix}{name}'
_partition_module(child_name,
@@ -94,6 +120,16 @@ def partition_module(module: nn.Module,
device_mesh: DeviceMesh,
func: Callable,
to_local: bool = False):
"""partition module.
Parameters in module won't be force Replicated.
Args:
module (Module): Module to be partitioned.
device_mesh (DeviceMesh): The device mesh.
func (Callable): partition callback.
to_local (bool): Convert all DTensor parameters to Tensor parameters.
"""
_partition_module('',
'',
module=module,
@@ -105,6 +141,12 @@ def partition_module(module: nn.Module,


def replicate_module(model: nn.Module, device_mesh: DeviceMesh):
"""Replicate all parameters in module.
Args:
model (Module): Module to perform replicate.
device_mesh (DeviceMesh): The distribution device mesh.
"""
for name, param in model.named_parameters(recurse=False):
param = distribute_tensor(param,
device_mesh=device_mesh,
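A small sanity-check sketch of try_to_local on a non-distributed tensor; the DTensor branch calls .to_local() and needs an initialized process group and DeviceMesh, so it is not exercised here. partition_module and replicate_module would likewise be driven with a real DeviceMesh and a callback such as colwise_parallelize_linear_fn.

import torch

from lmdeploy.pytorch_poc.dist_utils import try_to_local

t = torch.randn(4, 4)
# A plain Tensor is not a DTensor, so it is returned unchanged.
assert try_to_local(t) is t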
49 changes: 49 additions & 0 deletions lmdeploy/pytorch_poc/engine/cache_engine.py
@@ -11,6 +11,16 @@


class CacheEngine:
"""Host and Device memory maintainer.
Args:
cache_config (CacheConfig): Config of the cache information.
model_config (ModelConfig): Config of the model.
rank (int): Distribution rank, 0 in a non-distributed environment.
world_size (int): Distribution world size, 1 in a non-distributed
environment.
device_mesh (DeviceMesh): Distribution device mesh.
"""

def __init__(
self,
@@ -50,9 +60,11 @@ def __init__(

@property
def gpu_cache(self):
"""gpu cache."""
return self.local_gpu_cache

def get_key_block_shape(self, local: bool = False) -> Tuple[int, int, int]:
"""get shape of key block."""
num_heads = self.num_heads
if local:
assert self.num_heads % self.world_size == 0
@@ -65,6 +77,7 @@ def get_key_block_shape(self, local: bool = False) -> Tuple[int, int, int]:

def get_value_block_shape(self,
local: bool = False) -> Tuple[int, int, int]:
"""get shape of value block."""
num_heads = self.num_heads
if local:
assert self.num_heads % self.world_size == 0
@@ -76,6 +89,7 @@ def get_value_block_shape(self,
)
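The local=True path divides the attention heads across ranks, which is what the assert above guards. A hedged sketch of that arithmetic; the ordering of the returned shape tuple is hidden in the collapsed region, so the tuple below is an assumption.

num_heads, head_size, block_size, world_size = 32, 128, 16, 2
assert num_heads % world_size == 0
local_heads = num_heads // world_size                   # 16 heads per rank
key_block_shape = (local_heads, head_size, block_size)  # assumed layout
print(key_block_shape)                                  # (16, 128, 16)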

def allocate_gpu_cache(self):
"""allocate caches on GPU."""
gpu_cache: List[KVCache] = []
key_block_shape = self.get_key_block_shape(local=True)
value_block_shape = self.get_value_block_shape(local=True)
@@ -95,6 +109,7 @@ def allocate_gpu_cache(self):
return gpu_cache

def allocate_cpu_cache(self):
"""allocate caches on Host."""
cpu_cache: List[KVCache] = []
key_block_shape = self.get_key_block_shape()
value_block_shape = self.get_value_block_shape()
@@ -118,6 +133,13 @@ def allocate_cpu_cache(self):

def _swap(self, src: List[KVCache], dst: List[KVCache],
src_to_dst: Dict[int, int]):
"""Move caches from src memory to dst memory.
Args:
src (List[KVCache]): Source cache.
dst (List[KVCache]): Destination cache.
src_to_dst (Dict[int, int]): Map between src and dst.
"""
with torch.cuda.stream(self.cache_stream):
for i in range(self.num_layers):
src_key_cache, src_value_cache = src[i]
@@ -131,14 +153,33 @@ def _swap(self, src: List[KVCache], dst: List[KVCache],
event.record(stream=self.cache_stream)

def swap_in(self, src_to_dst: Dict[int, int]) -> None:
"""Move cache from Host to Device.
Args:
src_to_dst (Dict[int, int]): Map between src and dst.
"""
self._swap(self.local_cpu_cache, self.local_gpu_cache, src_to_dst)

def swap_out(self, src_to_dst: Dict[int, int]) -> None:
"""Move cache from Device to Host.
Args:
src_to_dst (Dict[int, int]): Map between src and dst.
"""
self._swap(self.local_gpu_cache, self.local_cpu_cache, src_to_dst)
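The src_to_dst argument in both directions is just a {source_block_index: destination_block_index} map. Below is a self-contained sketch of that block-copy semantics on dummy tensors; the real engine does the copies per layer, for key and value caches separately, with CUDA events recorded on its dedicated cache_stream.

import torch

num_blocks, block_numel = 4, 8
src = torch.arange(num_blocks * block_numel, dtype=torch.float32).reshape(num_blocks, block_numel)
dst = torch.zeros_like(src)

src_to_dst = {0: 2, 1: 3}            # copy block 0 -> slot 2, block 1 -> slot 3
for src_block, dst_block in src_to_dst.items():
    dst[dst_block].copy_(src[src_block])

assert torch.equal(dst[2], src[0])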

@staticmethod
def get_cache_block_size(block_size: int,
model_config: ModelConfig) -> int:
"""Get the required cache size of the model.
Args:
block_size (int): The number of tokens in a block.
model_config (ModelConfig): The config of the model.
Returns:
int: Required memory size in bytes.
"""
head_size = model_config.get_head_size()
num_layers = model_config.num_layers
num_heads = model_config.num_heads
@@ -151,4 +192,12 @@ def get_cache_block_size(block_size: int,


def _get_dtype_size(dtype: torch.dtype) -> int:
"""get size of the given dtype.
Args:
dtype (torch.dtype): Data type.
Returns:
int: Size in bytes.
"""
return torch.tensor([], dtype=dtype).element_size()
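Putting the last two pieces together: a hedged, back-of-the-envelope version of get_cache_block_size. The visible part of the method gathers head_size, num_layers and num_heads; the factor of 2 below assumes the usual key-plus-value accounting, which the collapsed body does not confirm.

import torch


def dtype_size(dtype: torch.dtype) -> int:
    # Same trick as _get_dtype_size above: an empty tensor still reports its element size.
    return torch.tensor([], dtype=dtype).element_size()


block_size, num_layers, num_heads, head_size = 16, 32, 32, 128
cache_block_bytes = 2 * block_size * num_layers * num_heads * head_size * dtype_size(torch.float16)
print(cache_block_bytes)  # 8388608 bytes, i.e. 8 MiB per block of 16 tokens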