diff --git a/benchmark/profile_torch_generation.py b/benchmark/profile_torch_generation.py index de134e73d6..8eaf1e7235 100644 --- a/benchmark/profile_torch_generation.py +++ b/benchmark/profile_torch_generation.py @@ -102,7 +102,7 @@ def _infer(model, session_id): _start = time.perf_counter() procs = [] for i in range(concurrency): - proc = Thread(target=_infer, args=(model, i + 1)) + proc = Thread(target=_infer, args=(model, i + 1), daemon=True) procs.append(proc) proc.start() @@ -139,7 +139,8 @@ def profile_throughput(model_path: str, concurrency: int, input_seqlen: int, for i in range(concurrency): proc = Thread(target=infer, args=(tm_model, i + 1, input_ids, output_seqlen, top_k, - top_p, temperature, test_round, que)) + top_p, temperature, test_round, que), + daemon=True) procs.append(proc) proc.start() @@ -256,7 +257,7 @@ def mem_monitor(cls): def start(cls): cls._running = True from multiprocessing import Process - cls.proc = Process(target=cls.mem_monitor) + cls.proc = Process(target=cls.mem_monitor, daemon=True) cls.proc.start() @classmethod diff --git a/benchmark/profile_torch_throughput.py b/benchmark/profile_torch_throughput.py index e5b8142a14..f3f305a3ed 100644 --- a/benchmark/profile_torch_throughput.py +++ b/benchmark/profile_torch_throughput.py @@ -146,7 +146,8 @@ def process_request(self, # start threads for i in range(concurrency): t = Thread(target=self._inference, - args=(req_queue, res_queue, i, stream_output)) + args=(req_queue, res_queue, i, stream_output), + daemon=True) t.start() threads.append(t) diff --git a/lmdeploy/pytorch/adapter/adapter.py b/lmdeploy/pytorch/adapter/adapter.py new file mode 100644 index 0000000000..0eb73aef9e --- /dev/null +++ b/lmdeploy/pytorch/adapter/adapter.py @@ -0,0 +1,346 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+ +import re +from dataclasses import dataclass +from typing import Any, Dict, List + +import torch +from torch import Tensor + +from ..block import LogicalTokenBlocks + + +def _cache_weight(cache: Tensor, weight: Tensor, block_table: Tensor): + """cache weight.""" + assert cache.dim() == 2 + assert weight.dim() == 2 + assert block_table.dim() == 1 + + rank, feat_size = weight.size() + assert cache.size(-1) >= feat_size, ('cache.size(-1) >= feat_size failed.') + assert rank <= block_table.size(0), ('rank <= block_table.size(0) failed.') + block_table = block_table[:rank] + cache[block_table, :feat_size] = weight.to(device=cache.device, + dtype=cache.dtype) + + +def _get_named_loralinears(model: torch.nn.Module): + """get all named loralinear.""" + from peft.tuners.lora import Linear as LoRALinear + named_loralinear: Dict[str, torch.nn.Module] = dict() + for name, module in model.named_modules(): + if isinstance(module, LoRALinear): + named_loralinear[name] = module + return named_loralinear + + +def _get_layer_index(key: str, config: Any): + """get layer index of the lora linear.""" + from peft.utils.other import COMMON_LAYERS_PATTERN + layer_indexing_pattern = getattr(config, 'layers_pattern', None) + layers_pattern = layer_indexing_pattern or COMMON_LAYERS_PATTERN + if isinstance(layers_pattern, str): + layers_pattern = [layers_pattern] + for pattern in layers_pattern: + layer_index = re.match(f'.*.{pattern}\\.(\\d+)\\.*', key) + + if layer_index is not None: + return int(layer_index[1]) + + +def get_indexed_lora_linears(model: torch.nn.Module): + """get indexed lora linear.""" + named_linears = _get_named_loralinears(model) + + config = None + peft_config = getattr(model, 'peft_config', dict) + if len(peft_config) > 0: + config = next(iter(peft_config.values())) + + indexed_linears = dict() + for name, layer in named_linears.items(): + index = _get_layer_index(name, config) + target = name.split('.')[-1] + indexed_linears.setdefault(index, dict()) + indexed_linears[index][target] = layer + return indexed_linears + + +def update_lora_linears(lora_linears: Dict, + weight_maps: List['AdapterWeightMap'], + device: str = 'cuda'): + """update lora linears.""" + + def __get_targets(): + """get targets.""" + all_targets = set() + for weight_map in weight_maps: + targets = weight_map.target_modules.keys() + all_targets.update(targets) + return all_targets + + def __get_rank_and_start(target_names): + """get rank and start.""" + rank_map = dict() + start_map = dict() + for target in target_names: + ranks = [0] + [ + weight_map.target_modules[target].rank + for weight_map in weight_maps + ] + block_starts = [0] + [ + weight_map.target_modules[target].block_start + for weight_map in weight_maps + ] + rank_map[target] = torch.tensor(ranks) + start_map[target] = torch.tensor(block_starts) + return rank_map, start_map + + def __update_linear(linear, idx, rank_map, start_map, adapter_names): + """update linear.""" + linear.layer_idx = idx + linear.ranks = rank_map[target].to(device) + linear.block_starts = start_map[target].to(device) + for name in adapter_names: + if name in linear.lora_A: + linear.lora_A.pop(name) + linear.lora_B.pop(name) + + adapter_names = [weight_map.adapter_name for weight_map in weight_maps] + + all_targets = __get_targets() + + for weight_map in weight_maps: + weight_map.expand_targets(all_targets) + + rank_map, start_map = __get_rank_and_start(all_targets) + + for idx, lora_linear in lora_linears.items(): + for target, linear in lora_linear.items(): + __update_linear(linear, 
+ idx, + rank_map=rank_map, + start_map=start_map, + adapter_names=adapter_names) + + +@dataclass +class TargetMeta: + rank: int + block_start: int + + +@dataclass +class AdapterWeightMap: + adapter_name: str + block_table: Tensor + target_modules: Dict[str, TargetMeta] + + @classmethod + def new(cls, adapter_name: str, rank: int, target_names: List[str], + block_table: Tensor): + """create new weightmap.""" + block_start = 0 + target_modules: Dict[str, TargetMeta] = dict() + for name in target_names: + target_modules[name] = TargetMeta(rank, block_start) + block_start += rank + + return AdapterWeightMap(adapter_name, + block_table=block_table, + target_modules=target_modules) + + def expand_targets(self, + target_names: List[str], + ignore_exists: bool = True): + for name in target_names: + if name in self.target_modules: + if ignore_exists: + continue + else: + raise RuntimeError(f'target {name} exists.') + self.target_modules[name] = TargetMeta(0, 0) + + @classmethod + def cache_lora_a(cls, cache: Tensor, weight: Tensor, block_table: Tensor): + """cache lora a weight.""" + return _cache_weight(cache, weight, block_table) + + @classmethod + def cache_lora_b(cls, cache: Tensor, weight: Tensor, block_table: Tensor): + """cache lora b weight.""" + return _cache_weight(cache, weight.t(), block_table) + + def cache_lora_linear(self, lora_linear: torch.nn.Module, cache_a: Tensor, + cache_b: Tensor): + """cache lora linear.""" + name = self.adapter_name + target_modules = self.target_modules + block_table = self.block_table + block_start = 0 + for target, target_meta in target_modules.items(): + linear = lora_linear[target] + if not (name in linear.lora_A and name in linear.lora_B): + continue + linear_a = linear.lora_A[name] + linear_b = linear.lora_B[name] + weight_a = linear_a.weight + weight_b = linear_b.weight + assert weight_a is not None + assert weight_b is not None + rank = target_meta.rank + block_offset = block_table[block_start:block_start + rank] + block_start += rank + self.cache_lora_a(cache_a, weight_a, block_offset) + self.cache_lora_b(cache_b, weight_b, block_offset) + + def cache_adapter(self, lora_linears: Dict, caches: List[List[Tensor]]): + """cache all linear.""" + assert len(lora_linears) == len(caches), ( + 'len(lora_linears) == len(caches)') + + for idx, lora_linear in lora_linears.items(): + assert idx < len(caches), 'idx < len(caches)' + cache_a, cache_b = caches[idx] + self.cache_lora_linear(lora_linear, cache_a, cache_b) + + +@dataclass +class SchedulerAdapter: + """lora adapter.""" + + idx: int + adapter_path: str + adapter_name: str + config: Any + target_modules: List[str] + logical_blocks: LogicalTokenBlocks + adapter_manager: 'AdapterManager' + _active: bool = False + + @classmethod + def from_pretrained(cls, adapter_path: str, adapter_name: str, idx: int, + manager: 'AdapterManager'): + """from_pretrained.""" + from peft import PeftConfig + config = PeftConfig.from_pretrained(adapter_path) + + return cls.from_config(config, + adapter_name=adapter_name, + idx=idx, + manager=manager) + + @classmethod + def from_config(cls, config: Any, adapter_name: str, idx: int, + manager: 'AdapterManager'): + """from config.""" + new_adapter = SchedulerAdapter( + idx, + adapter_path=config.base_model_name_or_path, + adapter_name=adapter_name, + config=config, + target_modules=list(config.target_modules), + logical_blocks=LogicalTokenBlocks(1), + adapter_manager=manager) + new_adapter._active = False + return new_adapter + + @property + def name(self): + """get adapter 
name.""" + return self.adapter_name + + @property + def rank(self): + """get rank.""" + return self.config.r + + def is_actived(self): + """check if adapter is active.""" + return self._active + + def active(self, flag: bool = True): + """active adapter.""" + self.adapter_manager._on_active(self, flag) + self._active = flag + + def num_blocks(self): + """get num blocks.""" + # ranks * (lora_a + lora_b) * num_targets + return self.rank * len(self.target_modules) + + def num_required_blocks(self): + """get num required blocks.""" + if self.is_actived(): + return 0 + else: + return self.num_blocks() + + def build_weight_map(self, block_table: Tensor): + return AdapterWeightMap.new(self.name, + rank=self.rank, + target_names=self.target_modules, + block_table=block_table) + + +class AdapterManager: + """Adapter manager.""" + + def __init__(self) -> None: + self._adapters: Dict[str, SchedulerAdapter] = dict() + self._adapter_count = 0 + self._active_count = 0 + + self._add_non_adapter() + + def _add_non_adapter(self): + """add non adapter.""" + from peft import LoraConfig + adapter_name = None + config = LoraConfig(r=0, target_modules=[]) + adapter = self.add_adapter_from_config(config, + adapter_name=adapter_name) + adapter.active() + + def _on_active(self, adapter: SchedulerAdapter, flag: bool): + """on active.""" + if adapter._active != flag: + if flag: + self._active_count += 1 + else: + self._active_count -= 1 + + def _add_adapter(self, adapter: SchedulerAdapter): + """add adapter.""" + assert adapter.adapter_name not in self._adapters + self._adapters[adapter.adapter_name] = adapter + self._adapter_count += 1 + return adapter + + def add_adapter_from_config(self, config: Any, adapter_name: str): + """add adapter from config.""" + adapter = SchedulerAdapter.from_config(config, + adapter_name=adapter_name, + idx=self._adapter_count, + manager=self) + return self._add_adapter(adapter) + + def add_adapter_from_pretrained(self, adapter_path: str, + adapter_name: str): + """add adapter by path and name.""" + adapter = SchedulerAdapter.from_pretrained(adapter_path, + adapter_name=adapter_name, + idx=self._adapter_count, + manager=self) + return self._add_adapter(adapter) + + def get_adapter(self, name: str, default=None): + """get adapter.""" + return self._adapters.get(name, default) + + def num_adapters(self): + """get num adapters.""" + return len(self._adapters) + + +ADAPTER_MANAGER = AdapterManager() diff --git a/lmdeploy/pytorch/block.py b/lmdeploy/pytorch/block.py index f41e251575..434fb82150 100644 --- a/lmdeploy/pytorch/block.py +++ b/lmdeploy/pytorch/block.py @@ -5,33 +5,6 @@ import numpy as np -class LogicalTokenBlock: - """Logical block used to count tokens per block.""" - - def __init__(self, block_id: int, block_size: int): - self.block_id = block_id - self.block_size = block_size - - self.num_tokens = 0 - - def get_num_empty_slots(self): - """get num empty slots.""" - return self.block_size - self.num_tokens - - def is_empty(self): - """is empty.""" - return self.num_tokens == 0 - - def is_full(self): - """is full.""" - return self.num_tokens == self.block_size - - def append_tokens(self, num_tokens: int = 1): - """append tokens.""" - assert num_tokens <= self.get_num_empty_slots() - self.num_tokens += num_tokens - - def _div_up(x, n): """perform div up.""" return (x + n - 1) // n diff --git a/lmdeploy/pytorch/chat.py b/lmdeploy/pytorch/chat.py index c18413b5ab..491a432b40 100644 --- a/lmdeploy/pytorch/chat.py +++ b/lmdeploy/pytorch/chat.py @@ -52,7 +52,7 @@ def 
run_chat(model_path, engine_config: EngineConfig, gen_config: EngineGenerationConfig = None, session_id: int = 1, - trust_remote_code=True): + trust_remote_code: bool = True): """An example to perform model inference through the command line interface. @@ -64,12 +64,16 @@ def run_chat(model_path, trust_remote_code (bool): trust remote code. """ from lmdeploy.pytorch.engine import Engine - tm_model = Engine(model_path, - engine_config=engine_config, - trust_remote_code=trust_remote_code) + tm_model = Engine.from_pretrained(model_path, + engine_config=engine_config, + trust_remote_code=trust_remote_code) tokenizer = tm_model.tokenizer generator = tm_model.create_instance() + adapter_name = None + if engine_config.adapters is not None: + adapter_name = next(iter(engine_config.adapters.keys())) + nth_round = 1 step = 0 seed = random.getrandbits(64) @@ -107,7 +111,8 @@ def run_chat(model_path, gen_config.stop_words = stop_words for outputs in generator.stream_infer(session_id=session_id, input_ids=input_ids, - gen_config=gen_config): + gen_config=gen_config, + adapter_name=adapter_name): status, res, tokens = outputs # decode res response = tokenizer.decode(res, offset=response_size) @@ -136,7 +141,8 @@ def main(model_path, repetition_penalty: float = 1.0, tp: int = 1, stream_output: bool = True, - trust_remote_code=True): + adapter: str = None, + trust_remote_code: bool = True): """An example to perform model inference through the command line interface. @@ -150,9 +156,15 @@ def main(model_path, repetition_penalty (float): parameter to penalize repetition tp (int): GPU number used in tensor parallelism stream_output (bool): indicator for streaming output or not + adapter (str): path to lora adapter. trust_remote_code (bool): Trust remote code. """ - engine_config = EngineConfig(model_name=model_name, tp=tp) + adapters = None + if adapter is not None: + adapters = dict(default=adapter) + engine_config = EngineConfig(model_name=model_name, + tp=tp, + adapters=adapters) gen_config = EngineGenerationConfig(max_new_tokens=512, top_k=top_k, top_p=top_p, diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index 5130aa3137..4f2f90210a 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -1,5 +1,19 @@ # Copyright (c) OpenMMLab. All rights reserved. from dataclasses import dataclass, field +from typing import Any, Dict + +import torch + + +def _get_torch_dtype(config: Any, default: str = 'float16'): + """Get the torch dtype from the model config. + + Args: + config: Config of the hf model. + default (str): default device type. 
+ """ + torch_dtype = getattr(config, 'torch_dtype', default) + return eval(f'torch.{torch_dtype}') @dataclass @@ -30,6 +44,7 @@ class EngineConfig: block_size: int = 64 num_cpu_blocks: int = 0 num_gpu_blocks: int = 0 + adapters: Dict[str, str] = None @dataclass @@ -41,6 +56,7 @@ class SchedulerConfig: max_request_output_len: int = 512 eviction_type: str = 'recompute' prefill_interval: int = 16 + max_active_adapters: int = 64 @dataclass @@ -61,10 +77,80 @@ class ModelConfig: num_heads: int bos_token_id: int eos_token_id: int - dtype: str + dtype: torch.dtype = torch.float16 multi_query_attention: bool = False json_config: dict = field(default_factory=dict) + hf_config: Any = None def get_head_size(self): """get head size.""" return self.hidden_size // self.num_heads + + @classmethod + def from_pretrained(cls, + pretrained_model_name_or_path: str, + trust_remote_code: bool = True): + """build ModelConfig from model path or name.""" + from transformers import AutoConfig + hf_config = AutoConfig.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=trust_remote_code) + return cls.from_hf_config(hf_config, pretrained_model_name_or_path) + + @classmethod + def from_hf_config(cls, hf_config: Any, model_path: str = None): + """from huggingface config.""" + if model_path is None: + model_path = '' + + def __build_falcon(): + """build falcon.""" + if hf_config.new_decoder_architecture: + # 40b-instruct, GQA + kv_dim = hf_config.hidden_size // hf_config.num_attention_heads + kv_dim *= hf_config.num_kv_heads + kv_head = hf_config.num_kv_heads + if hf_config.multi_query: + # 7b-instruct, MQA + kv_dim = hf_config.hidden_size // hf_config.num_attention_heads + kv_head = 1 + else: + # rw-1b, MHA + kv_dim = hf_config.hidden_size + kv_head = hf_config.num_attention_heads + return ModelConfig( + kv_dim, + hf_config.num_hidden_layers, + kv_head, + bos_token_id=hf_config.bos_token_id, + eos_token_id=hf_config.eos_token_id, + multi_query_attention=hf_config.multi_query, + ) + + def __build_chatglm(): + """build chatglm.""" + return ModelConfig(hf_config.hidden_size // + hf_config.num_attention_heads * + hf_config.multi_query_group_num, + hf_config.num_layers, + hf_config.multi_query_group_num, + bos_token_id=hf_config.bos_token_id, + eos_token_id=hf_config.eos_token_id) + + def __build_default(): + return ModelConfig(hf_config.hidden_size, + hf_config.num_hidden_layers, + hf_config.num_attention_heads, + bos_token_id=hf_config.bos_token_id, + eos_token_id=hf_config.eos_token_id) + + if 'falcon' in model_path: + model_config = __build_falcon() + elif 'chatglm' in model_path: + model_config = __build_chatglm() + else: + model_config = __build_default() + + model_config.dtype = _get_torch_dtype(hf_config) + model_config.hf_config = hf_config + model_config.json_config = hf_config.to_dict() + return model_config diff --git a/lmdeploy/pytorch/dist_utils.py b/lmdeploy/pytorch/dist_utils.py index 0b0f63f83c..62d80155d6 100644 --- a/lmdeploy/pytorch/dist_utils.py +++ b/lmdeploy/pytorch/dist_utils.py @@ -6,6 +6,13 @@ from torch.distributed._tensor import (DeviceMesh, DTensor, Replicate, Shard, distribute_tensor) +try: + from peft.tuners.lora import Linear as LoRALinear +except ImportError: + + class LoRALinear: + pass + def try_to_local(tensor: Union[Tensor, DTensor]): """Try to convert DTensor to Tensor. 
@@ -34,9 +41,9 @@ def module_to_local(module: nn.Module): module.register_buffer(name, try_to_local(buf)) -def rowwise_parallelize_linear_fn(module: nn.Module, - device_mesh: DeviceMesh, - to_local: bool = False) -> None: +def rowwise_parallelize_linear(module: nn.Module, + device_mesh: DeviceMesh, + to_local: bool = False) -> None: """ This function parallelizes the input :class:`nn.Linear` module in :class:`RowwiseParallel` style. @@ -84,10 +91,59 @@ def rowwise_parallelize_linear_fn(module: nn.Module, module.register_buffer(name, dist_tensor) -def colwise_parallelize_linear_fn(module: nn.Module, +def rowwise_parallelize_loralinear(module: LoRALinear, + device_mesh: DeviceMesh, + to_local: bool = False) -> None: + """rowwize parallelize lora linear. + + Read S-LoRA for more detail. + """ + rowwise_parallelize_linear(module.base_layer, + device_mesh=device_mesh, + to_local=to_local) + for mod in module.lora_A.values(): + rowwise_parallelize_linear(mod, + device_mesh=device_mesh, + to_local=to_local) + for mod in module.lora_B.values(): + colwise_parallelize_linear(mod, + device_mesh=device_mesh, + to_local=to_local) + module._tp_mode = 'rowwise' + + +def rowwise_parallelize_linear_fn(module: nn.Module, device_mesh: DeviceMesh, to_local: bool = False) -> None: """ + This function parallelizes the input :Linear module in + :class:`RowwiseParallel` style. + + Args: + module (:class:`nn.Module`): + The :class:`nn.Linear` module to be parallelized. + device_mesh (:class:`DeviceMesh`): + Object which describes the mesh topology of devices. + + Returns: + None + """ + if isinstance(module, torch.nn.Linear): + return rowwise_parallelize_linear(module, + device_mesh=device_mesh, + to_local=to_local) + elif isinstance(module, LoRALinear): + return rowwise_parallelize_loralinear(module, + device_mesh=device_mesh, + to_local=to_local) + else: + raise TypeError(f'Unsupported module: {type(module)}') + + +def colwise_parallelize_linear(module: nn.Module, + device_mesh: DeviceMesh, + to_local: bool = False) -> None: + """ This function parallelizes the input :class:`nn.Linear` module in :class:`ColwiseParallel` style. @@ -107,7 +163,6 @@ def colwise_parallelize_linear_fn(module: nn.Module, dist_tensor = try_to_local(dist_tensor) dist_param = torch.nn.Parameter(dist_tensor) module.register_parameter(name, dist_param) - # Weight, bias and scale are registered as buffer in QLinear for name, buffer in module.named_buffers(): dist_tensor = distribute_tensor(buffer, device_mesh, [Shard(0)]) @@ -116,6 +171,52 @@ def colwise_parallelize_linear_fn(module: nn.Module, module.register_buffer(name, dist_tensor) +def colwise_parallelize_loralinear(module: nn.Module, + device_mesh: DeviceMesh, + to_local: bool = False) -> None: + """colwise parallelize lora linear.""" + colwise_parallelize_linear(module.base_layer, + device_mesh=device_mesh, + to_local=to_local) + for mod in module.lora_A.values(): + colwise_parallelize_linear(mod, + device_mesh=device_mesh, + to_local=to_local) + for mod in module.lora_B.values(): + colwise_parallelize_linear(mod, + device_mesh=device_mesh, + to_local=to_local) + module._tp_mode = 'colwise' + + +def colwise_parallelize_linear_fn(module: nn.Module, + device_mesh: DeviceMesh, + to_local: bool = False) -> None: + """ + This function parallelizes the input :Linear module in + :class:`ColwiseParallel` style. + + Args: + module (:class:`nn.Module`): + The :class:`nn.Linear` module to be parallelized. 
+ device_mesh (:class:`DeviceMesh`): + Object which describes the mesh topology of devices. + + Returns: + None + """ + if isinstance(module, torch.nn.Linear): + return colwise_parallelize_linear(module, + device_mesh=device_mesh, + to_local=to_local) + elif isinstance(module, LoRALinear): + return colwise_parallelize_loralinear(module, + device_mesh=device_mesh, + to_local=to_local) + else: + raise TypeError(f'Unsupported module: {type(module)}') + + def _partition_module( mod_name: str, prefix: str, diff --git a/lmdeploy/pytorch/engine/cache_engine.py b/lmdeploy/pytorch/engine/cache_engine.py index ef76fe68bb..57de049991 100644 --- a/lmdeploy/pytorch/engine/cache_engine.py +++ b/lmdeploy/pytorch/engine/cache_engine.py @@ -34,6 +34,8 @@ def __init__( world_size: int = 1, device_mesh: DeviceMesh = None, ) -> None: + if rank == 0: + logger.info(f'build CacheEngine with config:{cache_config}') self.rank = rank self.world_size = world_size if device_mesh is None and self.world_size > 1: @@ -44,8 +46,6 @@ def __init__( self.model_config = model_config self.block_size = cache_config.block_size - self.num_gpu_blocks = cache_config.num_gpu_blocks - self.num_cpu_blocks = cache_config.num_cpu_blocks self.head_size = model_config.get_head_size() self.num_layers = model_config.num_layers @@ -71,11 +71,26 @@ def __init__( f'Initialize cache engine with {cache_config.num_gpu_blocks}' f' gpu blocks and {cache_config.num_cpu_blocks} cpu blocks.') + @property + def cpu_cache(self): + """gpu cache.""" + return self.local_cpu_cache + @property def gpu_cache(self): """gpu cache.""" return self.local_gpu_cache + @property + def num_gpu_blocks(self): + """num gpu blocks.""" + return self.cache_config.num_gpu_blocks + + @property + def num_cpu_blocks(self): + """num gpu blocks.""" + return self.cache_config.num_cpu_blocks + def get_key_block_shape(self, local: bool = False) -> Tuple[int, int, int]: """get shape of key block.""" num_heads = self.num_heads @@ -127,8 +142,8 @@ def allocate_gpu_cache(self): def allocate_cpu_cache(self): """allocate caches on Host.""" cpu_cache: List[KVCache] = [] - key_block_shape = self.get_key_block_shape() - value_block_shape = self.get_value_block_shape() + key_block_shape = self.get_key_block_shape(local=True) + value_block_shape = self.get_value_block_shape(local=True) # TODO: pin memory might need be banned on wsl pin_memory = True diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 82e0811664..f80a15ef61 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -6,23 +6,26 @@ from typing import Any, Dict, List import torch -from transformers import AutoConfig from lmdeploy.messages import EngineGenerationConfig from lmdeploy.tokenizer import Tokenizer from lmdeploy.utils import get_logger -from ..config import CacheConfig, EngineConfig, ModelConfig, SchedulerConfig +from ..adapter.adapter import ADAPTER_MANAGER, SchedulerAdapter +from ..config import CacheConfig, EngineConfig, SchedulerConfig from ..messages import (MessageStatus, SamplingParam, SchedulerSequence, SchedulerSession) from ..paging import Scheduler from .logits_process import FusedLogitsProcessor -from .model_agent import BaseModelAgent, ModelInputs, TPModelAgent +from .model_agent import AutoModelAgent, ModelInputs from .request import (Request, RequestManager, RequestType, Response, ResponseType) logger = get_logger('lmdeploy') +SeqList = List[SchedulerSequence] +AdapterList = List[SchedulerAdapter] + @dataclass class InferOutput: @@ -50,82 
+53,15 @@ def _check_resp_success(resp: Response, warning_msg: str = None): return _check_resp(resp, ResponseType.SUCCESS, warning_msg) -def _get_torch_dtype(config: Any, default: str = 'float16'): - """Get the torch dtype from the model config. - - Args: - config: Config of the hf model. - default (str): default device type. - """ - torch_dtype = getattr(config, 'torch_dtype', default) - return eval(f'torch.{torch_dtype}') - - -def _build_model_config(model_path: str, hf_config: Any): - """build model config.""" - torch_dtype = _get_torch_dtype(hf_config) - if 'falcon' in model_path: - if hf_config.new_decoder_architecture: - # 40b-instruct, GQA - kv_dim = hf_config.hidden_size // hf_config.num_attention_heads - kv_dim *= hf_config.num_kv_heads - kv_head = hf_config.num_kv_heads - if hf_config.multi_query: - # 7b-instruct, MQA - kv_dim = hf_config.hidden_size // hf_config.num_attention_heads - kv_head = 1 - else: - # rw-1b, MHA - kv_dim = hf_config.hidden_size - kv_head = hf_config.num_attention_heads - model_config = ModelConfig(kv_dim, - hf_config.num_hidden_layers, - kv_head, - bos_token_id=hf_config.bos_token_id, - eos_token_id=hf_config.eos_token_id, - dtype=torch_dtype, - multi_query_attention=hf_config.multi_query, - json_config=hf_config.to_dict()) - elif 'chatglm' in model_path: - model_config = ModelConfig(hf_config.hidden_size // - hf_config.num_attention_heads * - hf_config.multi_query_group_num, - hf_config.num_layers, - hf_config.multi_query_group_num, - bos_token_id=hf_config.bos_token_id, - eos_token_id=hf_config.eos_token_id, - dtype=torch_dtype, - json_config=hf_config.to_dict()) - else: - model_config = ModelConfig(hf_config.hidden_size, - hf_config.num_hidden_layers, - hf_config.num_attention_heads, - bos_token_id=hf_config.bos_token_id, - eos_token_id=hf_config.eos_token_id, - dtype=torch_dtype, - json_config=hf_config.to_dict()) - - return model_config - - -def _build_model_agent(model_path: str, - model_config: ModelConfig, - cache_config: CacheConfig, - trust_remote_code: bool, - tp: int = 1): - """create model agent.""" - if tp == 1: - model_agent = BaseModelAgent(model_path, - model_config=model_config, - cache_config=cache_config, - trust_remote_code=trust_remote_code) - else: - model_agent = TPModelAgent(model_path, - model_config=model_config, - cache_config=cache_config, - world_size=tp, - trust_remote_code=trust_remote_code) - return model_agent +def _paging_adapters(adapters: dict, model_agent: AutoModelAgent, + scheduler: Scheduler): + adapters = adapters or dict() + weight_maps = [] + for name, path in adapters.items(): + weight_map = scheduler.add_adapter(path, name) + weight_map.block_table = torch.tensor(weight_map.block_table) + weight_maps.append(weight_map) + model_agent.paging_adapters(weight_maps) def _tensorlize_block_offsets(block_offsets): @@ -136,6 +72,26 @@ def _tensorlize_block_offsets(block_offsets): return block_offsets +def _get_adapter_ids(seqs: SeqList, adapters: AdapterList): + """get adapter ids.""" + adapter_names_map = dict( + (ada.name, idx) for idx, ada in enumerate(adapters)) + adapter_ids = [adapter_names_map[seq.adapter_name] for seq in seqs] + return adapter_ids + + +def _update_blocksize(cache_config: CacheConfig, adapters: List[str], tp: int): + """update blocksize for adapters.""" + if adapters is None: + return cache_config + + if cache_config.block_size != tp: + logger.warning('Lora adapter require block size ' + f'= tp({tp}).') + cache_config.block_size = tp + return cache_config + + class Engine: """The inference engine 
of lmdeploy pytorch. @@ -147,13 +103,12 @@ class Engine: def __init__(self, model_path: str, engine_config: EngineConfig, - trust_remote_code=True) -> None: + trust_remote_code: bool = True) -> None: self.engine_config = engine_config model_name = engine_config.model_name tp = engine_config.tp self.tp = tp - self.gpu_count = tp self.model_name = model_name scheduler_config = SchedulerConfig( @@ -162,46 +117,78 @@ def __init__(self, eviction_type='recompute') # block_size = 1 to enable unified paging + adapters = engine_config.adapters cache_config = CacheConfig(block_size=engine_config.block_size, num_cpu_blocks=engine_config.num_cpu_blocks, num_gpu_blocks=engine_config.num_gpu_blocks) + cache_config = _update_blocksize(cache_config, + adapters=adapters, + tp=tp) - hf_config = AutoConfig.from_pretrained( - model_path, trust_remote_code=trust_remote_code) - - torch_dtype = _get_torch_dtype(hf_config) - self.torch_dtype = torch_dtype - - model_config = _build_model_config(model_path, hf_config) - - self.model_agent = _build_model_agent( + self.model_agent = AutoModelAgent.from_pretrained( model_path, - model_config=model_config, cache_config=cache_config, trust_remote_code=trust_remote_code, + adapters=adapters, tp=tp) cache_config = self.model_agent.cache_config self.scheduler = Scheduler(scheduler_config, cache_config) + if adapters: + _paging_adapters(adapters, + model_agent=self.model_agent, + scheduler=self.scheduler) + self.scheduler_config = scheduler_config self.cache_config = cache_config - self.model_config = model_config - self.session_len = scheduler_config.max_session_len self.stream = torch.cuda.Stream() - self._bind_request_manager() + self.req_manager = self._bind_request_manager() self.owned_sessions = [] # create main thread - self._start_loop() + self.loop_threads = self._start_loop() + self.req_sender = self.req_manager.build_sender(self.loop_threads) self._create_buffers() self.tokenizer = Tokenizer(model_path) + @classmethod + def from_pretrained(cls, + pretrained_model_name_or_path: str, + engine_config: EngineConfig, + trust_remote_code: bool = True, + **kwargs): + """lmdeploy python inference engine. + + Args: + pretrained_model_name_or_path (str): + It could be one of the following options: + - i) A local directory path of a turbomind model which is + converted by `lmdeploy convert` command or download from + ii) and iii) + - ii) The model_id of a lmdeploy-quantized model hosted + inside a model repo on huggingface.co, such as + "InternLM/internlm-chat-20b-4bit", + "lmdeploy/llama2-chat-70b-4bit", etc. + - iii) The model_id of a model hosted inside a model repo + on huggingface.co, such as "InternLM/internlm-chat-7b", + "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" + and so on. + scheduler_config (SchedulerConfig): The config of the scheduler. + cache_config (CacheConfig): The config of the cache info. + tp (int): Number of tensor parallel. + model_name (str): needed when pretrained_model_name_or_path is c) + adapters (dict): named lora adapters. 
+ """ + logger.debug(f'Get unexpected kwargs: {kwargs}') + return cls(model_path=pretrained_model_name_or_path, + engine_config=engine_config, + trust_remote_code=trust_remote_code) + def _create_buffers(self): - scheduler_config = self.scheduler_config - max_batches = scheduler_config.max_batches + max_batches = self.scheduler_config.max_batches # buffers to create inputs self._q_start_loc_buf = torch.arange(max_batches) @@ -215,14 +202,13 @@ def _bind_request_manager(self): req_manager.bind_func(RequestType.STOP_SESSION, self._on_stop_session) req_manager.bind_func(RequestType.END_SESSION, self._on_end_session) req_manager.bind_func(RequestType.ADD_MESSAGE, self._on_add_message) - self.req_manager = req_manager - self.req_sender = req_manager.build_sender() + return req_manager def _start_loop(self): """start loop.""" loop_threads = Thread(target=self.loop, daemon=True) loop_threads.start() - self.loop_threads = loop_threads + return loop_threads def _on_add_session(self, reqs: Request, **kwargs): """on add session callback.""" @@ -279,10 +265,13 @@ def _on_add_message(self, reqs: Request, **kwargs): sess = self.scheduler.sessions[session_id] # TODO: support 1 session n sequence if len(sess.sequences) == 0: + assert len( + req.data['token_ids']) > 0, ('Empty input is not allowed.') sess.add_sequence( req.data['token_ids'], max_output_len=req.data['max_request_output_len'], - sampling_param=req.data['sampling_param']) + sampling_param=req.data['sampling_param'], + adapter_name=req.data['adapter_name']) msg = next(iter(sess.sequences.values())) self.scheduler.add_sequence(msg) else: @@ -296,6 +285,19 @@ def _on_add_message(self, reqs: Request, **kwargs): msg.req_id = req.req_id self.scheduler.update() + @property + def model_config(self): + """model config.""" + return self.model_agent.model_config + + @property + def gpu_count(self): + return self.tp + + @property + def session_len(self): + return self.scheduler_config.max_session_len + def create_instance(self, cuda_stream_id=0): """Create a turbomind instance. @@ -337,11 +339,12 @@ def end_session(self, session_id: int): self.owned_sessions.remove(session_id) @torch.inference_mode() - def create_model_inputs(self, messages: List[SchedulerSequence]): + def create_model_inputs(self, messages: SeqList, adapters: AdapterList): """create model inputs from messages. Args: - messages (List[SchedulerSequence]): The input messages. + messages (SeqList): The input messages. + adapters (AdapterList): Adapters. 
""" history_lengths = [msg.history_len for msg in messages] @@ -377,6 +380,20 @@ def create_model_inputs(self, messages: List[SchedulerSequence]): block_offsets = self.scheduler.get_block_tables(messages) block_offsets = _tensorlize_block_offsets(block_offsets) + local_adapter_ids = None + global_adapter_ids = None + adapter_offsets = None + max_rank = 0 + if ADAPTER_MANAGER.num_adapters() > 1: + local_adapter_ids = _get_adapter_ids(messages, adapters) + local_adapter_ids = seq_length.new_tensor(local_adapter_ids) + adapter_offsets = self.scheduler.get_block_tables(adapters) + adapter_offsets = _tensorlize_block_offsets(adapter_offsets) + global_adapter_ids = [ada.idx for ada in adapters] + global_adapter_ids = seq_length.new_tensor(global_adapter_ids) + ranks = [ada.rank for ada in adapters] + max_rank = max(ranks) + # add batch dim [bs=1, seq_len] if input_ids.ndim == 1: input_ids = input_ids.unsqueeze(0) @@ -389,6 +406,10 @@ def create_model_inputs(self, messages: List[SchedulerSequence]): q_start_loc=q_start_loc, history_lengths=history_lengths, is_decoding=is_decoding, + local_adapter_ids=local_adapter_ids, + global_adapter_ids=global_adapter_ids, + adapter_offsets=adapter_offsets, + max_rank=max_rank, meta=meta) def _stoping_criteria(self, msg: SchedulerSequence, next_token_id: int): @@ -432,8 +453,8 @@ def _check_session_len(msg, max_session_len): return True return False - def sampling_logits(self, logits: torch.Tensor, - running: List[SchedulerSequence], inputs: ModelInputs): + def sampling_logits(self, logits: torch.Tensor, running: SeqList, + inputs: ModelInputs): """sampling logits.""" def _group_params(running): @@ -481,8 +502,8 @@ def _sampling(grouped_params, split_logits, inputs): return next_token_ids, split_logits - def update_running(self, running: List[SchedulerSequence], - next_token_ids: torch.Tensor, meta: Any): + def update_running(self, running: SeqList, next_token_ids: torch.Tensor, + meta: Any): """update scheduler.""" for token, msg in zip(next_token_ids, running): msg.meta = meta @@ -504,13 +525,14 @@ def step(self, is_prefill: bool, return_logits: bool = False): # schedule schedule_output = self.scheduler.schedule(is_prefill=is_prefill) - running: List[SchedulerSequence] = schedule_output.running + running: SeqList = schedule_output.running swap_in_map = schedule_output.swap_in_map swap_out_map = schedule_output.swap_out_map + adapters = schedule_output.adapters if len(running) == 0: return dict() - inputs = self.create_model_inputs(running) + inputs = self.create_model_inputs(running, adapters) # inference output = self.model_agent.forward(inputs, @@ -547,8 +569,8 @@ def step(self, is_prefill: bool, return_logits: bool = False): def batched_infer(self, session_ids: List[int], token_ids: List[List[int]] = None, - request_output_len: int = 512, - sampling_param: SamplingParam = SamplingParam(), + gen_config: EngineGenerationConfig = None, + adapter_names: List[str] = None, keep_cache: bool = False): """Send inference request. @@ -558,6 +580,7 @@ def batched_infer(self, request_output_len (int): The max output length of this request. step (int): No use for now. sampling_param (SamplingParam): The sampling param of the output. + adapter_names (List[str]): The name of the adapters. keep_cache (bool): Keep kv cache after infer. 
Returns: @@ -567,6 +590,10 @@ def batched_infer(self, """ batch_size = len(token_ids) assert len(session_ids) == batch_size + if adapter_names is not None: + assert len(adapter_names) == batch_size + else: + adapter_names = [None for _ in range(batch_size)] def _add_sessions(session_ids, owned_sessions): for session_id in session_ids: @@ -575,13 +602,15 @@ def _add_sessions(session_ids, owned_sessions): def _add_messages(session_ids, token_ids): add_msgs = [] - for session_id, token_id in zip(session_ids, token_ids): - msg = dict( - token_ids=token_id, - session_id=session_id, - max_request_output_len=request_output_len, - sampling_param=sampling_param, - ) + request_output_len = gen_config.max_new_tokens + sampling_param = SamplingParam.from_gen_config(gen_config) + for session_id, token_id, adapter_name in zip( + session_ids, token_ids, adapter_names): + msg = dict(token_ids=token_id, + session_id=session_id, + max_request_output_len=request_output_len, + sampling_param=sampling_param, + adapter_name=adapter_name) add_msgs.append(msg) req_types = [RequestType.ADD_MESSAGE] * batch_size req_ids = self.req_sender.batched_send_async(req_types, @@ -647,7 +676,7 @@ def decode(self, prompt_token_ids: List[List[int]]): sessions.append(sess) self.add_session(sess) - msgs: List[SchedulerSequence] = [] + msgs: SeqList = [] for token_ids, sess in zip(prompt_token_ids, sessions): msg = sess.add_sequence(token_ids=token_ids) msgs.append(msg) @@ -727,13 +756,15 @@ class EngineInstance: def __init__(self, engine: Engine): self.engine = engine - self.req_sender = engine.req_manager.build_sender() + self.req_sender = engine.req_manager.build_sender(engine.loop_threads) self.owned_sessions: List[int] = list() def __del__(self): """Destructor.""" - for session_id in self.owned_sessions: - self.end(session_id) + if self.req_sender.is_thread_alive(): + for session_id in self.owned_sessions: + self.end(session_id) + self.engine.req_manager.senders.pop(self.req_sender.sender_id) def _try_add_session(self, session_id: int): """Add new session. @@ -752,6 +783,7 @@ def stream_infer(self, session_id: int, input_ids: List[int], gen_config: EngineGenerationConfig = None, + adapter_name: str = None, **kwargs): """Send stream inference request. @@ -771,21 +803,14 @@ def stream_infer(self, # TODO: support input embedding, step gen_config = gen_config or EngineGenerationConfig() request_output_len = gen_config.max_new_tokens - sampling_param = SamplingParam( - top_p=gen_config.top_p, - top_k=gen_config.top_k, - temperature=gen_config.temperature, - repetition_penalty=gen_config.repetition_penalty, - ignore_eos=gen_config.ignore_eos, - random_seed=gen_config.random_seed, - stop_words=gen_config.stop_words, - bad_words=gen_config.bad_words) + sampling_param = SamplingParam.from_gen_config(gen_config=gen_config) self._try_add_session(session_id) msg = dict( token_ids=input_ids, session_id=session_id, max_request_output_len=request_output_len, sampling_param=sampling_param, + adapter_name=adapter_name, ) req_id = self.req_sender.send_async(RequestType.ADD_MESSAGE, msg) diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index cb4dda8c3d..cf99e8a1c7 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -1,20 +1,19 @@ # Copyright (c) OpenMMLab. All rights reserved. 
- -import json import os -from dataclasses import asdict, dataclass +from dataclasses import asdict, dataclass, field from typing import Any, Callable, Dict, List, Union import torch import torch.distributed as dist from torch import multiprocessing as mp from torch.distributed._tensor import DeviceMesh, Replicate, distribute_tensor -from transformers import AutoConfig, AutoModelForCausalLM -from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME, cached_file +from transformers import AutoModelForCausalLM from lmdeploy.pytorch.accel import LoadNoInit from lmdeploy.utils import get_logger +from ..adapter.adapter import (AdapterWeightMap, get_indexed_lora_linears, + update_lora_linears) from ..config import CacheConfig, ModelConfig from ..models import patch from ..utils import get_gpu_memory @@ -27,7 +26,9 @@ def _update_cache_config(model_config: ModelConfig, cache_config: CacheConfig, - gpu_id: int = 0): + gpu_id: int = 0, + gpu_mem_percent: float = 0.7, + host_mem_size: int = 4 * (1 << 30)): """Update the gpu mem and cpu mem according to model info. Args: @@ -35,11 +36,10 @@ def _update_cache_config(model_config: ModelConfig, cache_config (CacheConfig): The config of the cache info. gpu_id (int): The GPU id to use. """ - GPU_MEM_PERCENT = 0.7 - SWAP_SPACE = 8 * (1 << 30) + torch.cuda.empty_cache() gpu_mem_physical_free, _ = get_gpu_memory(gpu_id) - gpu_mem = gpu_mem_physical_free * GPU_MEM_PERCENT - cpu_mem = SWAP_SPACE + gpu_mem = gpu_mem_physical_free * gpu_mem_percent + cpu_mem = host_mem_size cache_block_size = CacheEngine.get_cache_block_size( cache_config.block_size, model_config) if cache_config.num_cpu_blocks == 0: @@ -47,18 +47,7 @@ def _update_cache_config(model_config: ModelConfig, if cache_config.num_gpu_blocks == 0: cache_config.num_gpu_blocks = int(gpu_mem / cache_block_size) - logger.info('block num: {}'.format(cache_config.num_gpu_blocks)) - - -def _get_torch_dtype(config: Any, default: str = 'float16'): - """Get the torch dtype from the model config. - - Args: - config: Config of the hf model. - default (str): default device type. - """ - torch_dtype = getattr(config, 'torch_dtype', default) - return eval(f'torch.{torch_dtype}') + logger.debug('block num: {}'.format(cache_config.num_gpu_blocks)) @dataclass @@ -72,6 +61,10 @@ class ModelInputs: q_start_loc: torch.LongTensor history_lengths: List[int] is_decoding: bool + local_adapter_ids: torch.LongTensor + global_adapter_ids: torch.LongTensor + adapter_offsets: torch.LongTensor + max_rank: int meta: Any def to_device(self, device: str): @@ -86,42 +79,79 @@ def to_device(self, device: str): return ModelInputs(**out_dict) +@dataclass class StepContext: """context of Model. - patched model might need extra information to perform inference. - This dataclass provide these infos and tools. - - Args: - inputs (ModelInputs): packaged model inputs. - world_size (int): The distribution world size. - device (str): The device of the tensors. + patched model might need extra information to perform inference. This + dataclass provide these infos and tools. 
""" + inputs: ModelInputs + block_offsets: torch.LongTensor + position_ids: torch.LongTensor + position_ids_1d: torch.LongTensor + q_start_loc: torch.LongTensor + history_lengths: torch.LongTensor + seq_length: torch.LongTensor + max_seq_length: int + kv_seq_length: torch.LongTensor + kv_caches: List + is_decoding: bool + world_size: int = 1 + json_config: Dict = None + local_adapter_ids: torch.LongTensor = None + global_adapter_ids: torch.LongTensor = None + adapter_offsets: torch.LongTensor = None + max_rank: int = 0 - def __init__( - self, + _outputs: Dict = field(default_factory=dict) + + @classmethod + def new( + cls, inputs: ModelInputs, world_size: int = 1, device: str = 'cuda', json_config: dict = None, + kv_caches: List = None, ): - self.inputs = inputs - self.block_offsets = inputs.block_offsets - self.position_ids = inputs.position_ids - self.q_start_loc = inputs.q_start_loc - self.history_lengths = inputs.history_lengths - self.seq_length = inputs.seq_length - self.q_seq_length = self.seq_length - self.world_size = world_size - self.json_config = json_config + """build step context. - # seq_len + history_length - self.kv_seq_length = self.position_ids[..., -1] + 1 + Args: + inputs (ModelInputs): packaged model inputs. + world_size (int): The distribution world size. + device (str): The device of the tensors. + """ - self.position_ids_1d = self.get_position_ids_1d( - self.position_ids, self.seq_length, device) + position_ids = inputs.position_ids + max_seq_length = position_ids.size(-1) - self._outputs = dict() + # seq_len + history_length + kv_seq_length = position_ids[..., -1] + 1 + + # position ids 1d + seq_length = inputs.seq_length + position_ids_1d = cls.get_position_ids_1d(position_ids, seq_length, + device) + + ret = StepContext(inputs=inputs, + block_offsets=inputs.block_offsets, + position_ids=inputs.position_ids, + position_ids_1d=position_ids_1d, + q_start_loc=inputs.q_start_loc, + history_lengths=inputs.history_lengths, + seq_length=inputs.seq_length, + max_seq_length=max_seq_length, + kv_seq_length=kv_seq_length, + kv_caches=kv_caches, + is_decoding=inputs.is_decoding, + world_size=world_size, + json_config=json_config, + local_adapter_ids=inputs.local_adapter_ids, + global_adapter_ids=inputs.global_adapter_ids, + adapter_offsets=inputs.adapter_offsets, + max_rank=inputs.max_rank) + return ret @classmethod def tensorlize_block_offsets(cls, block_offsets, device): @@ -200,12 +230,13 @@ def model_forward( with torch.inference_mode(), torch.cuda.stream(stream): # forward inputs = inputs.to_device('cuda') - context = StepContext( + context = StepContext.new( inputs=inputs, world_size=world_size, json_config=json_config, + kv_caches=cache_engine.gpu_cache, ) - output = patched_model( + output = patched_model.patched_forward( input_ids=inputs.input_ids, position_ids=inputs.position_ids, attention_mask=inputs.attention_mask, @@ -220,7 +251,91 @@ def model_forward( return dict(logits=output['logits'], custom_outputs=context._outputs) -class BaseModelAgent: +def _load_adapters(hf_model: torch.nn.Module, + adapters: Dict[str, str], + device_map: str = 'cpu'): + """load adapters.""" + if not adapters: + return + for name, path in adapters.items(): + logger.info(f'load adapter <{name}> from "{path}".') + hf_model.load_adapter(path, name, device_map=device_map) + + +def _add_adapters(hf_model: torch.nn.Module, adapters: Dict[str, str]): + """add adapters.""" + if not adapters: + return + from peft import PeftConfig, inject_adapter_in_model + for name, path in adapters.items(): 
+ config = PeftConfig.from_pretrained(path) + inject_adapter_in_model(config, model=hf_model, adapter_name=name) + + +def _unparam_lora_weight(model: torch.nn.Module): + """unparam lora weight. + + We don't want to move weight of lora to gpu. + """ + from peft.tuners.lora import Linear as LoRALinear + + def _tensorize_weight(linear): + """tensorize weight.""" + w = linear.weight + del linear.weight + linear.weight = w.data + + for _, mod in model.named_modules(): + if isinstance(mod, LoRALinear): + lora_A = mod.lora_A + lora_B = mod.lora_B + for linear in lora_A.values(): + _tensorize_weight(linear) + for linear in lora_B.values(): + _tensorize_weight(linear) + + +SwapMap = Dict[int, int] + + +class AutoModelAgent: + """Base model agent.""" + + def __init__(self, model_config: ModelConfig, cache_config: CacheConfig): + self.model_config = model_config + self.cache_config = cache_config + + def paging_adapters(self, weight_maps: List[AdapterWeightMap]): + """paging adapter.""" + raise NotImplementedError('Not implemented.') + + def forward(self, inputs: ModelInputs, swap_in_map: SwapMap, + swap_out_map: SwapMap): + """model forward. + + Args: + inputs (Dict): The input data comes from _make_inputs. + swap_in_map (SwapMap): Cache maps to swap in. + swap_out_map (SwapMap): Cache maps to swap out. + """ + raise NotImplementedError('Not implemented.') + + @classmethod + def from_pretrained(cls, + pretrained_model_name_or_path: str, + cache_config: CacheConfig, + trust_remote_code: bool, + adapters: Dict[str, str] = None, + tp: int = 1): + """from pretrained.""" + return build_model_agent(pretrained_model_name_or_path, + cache_config=cache_config, + trust_remote_code=trust_remote_code, + adapters=adapters, + tp=tp) + + +class BaseModelAgent(AutoModelAgent): """Base model agent. 
load model on local gpu @@ -236,14 +351,15 @@ def __init__(self, model_path: str, model_config: ModelConfig, cache_config: CacheConfig, + adapters: Dict[str, str] = None, trust_remote_code: bool = True): - self.model_config = model_config - self.cache_config = cache_config + super().__init__(model_config=model_config, cache_config=cache_config) torch_dtype = model_config.dtype self.patched_model = self._build_model( model_path, torch_dtype=torch_dtype, + adapters=adapters, trust_remote_code=trust_remote_code) _update_cache_config(model_config, cache_config) @@ -254,6 +370,7 @@ def __init__(self, def _build_model(self, model_path: str, torch_dtype: torch.dtype, + adapters: Dict[str, str] = None, trust_remote_code: bool = True): """build patched model.""" with LoadNoInit(): @@ -263,17 +380,39 @@ def _build_model(self, trust_remote_code=trust_remote_code) hf_model.eval() hf_model.config.use_cache = True - patched_model = patch(hf_model, _PATCH_ARG_NAMES).cuda() + + if adapters: + _load_adapters(hf_model, adapters) + + patched_model = patch(hf_model, _PATCH_ARG_NAMES) + + if adapters: + _unparam_lora_weight(patched_model) + + patched_model = patched_model.cuda() return patched_model - def forward(self, inputs: ModelInputs, swap_in_map: Dict[int, int], - swap_out_map: Dict[int, int]): + def paging_adapters(self, weight_maps: List[AdapterWeightMap]): + """paging adapter.""" + logger.info('paging adapters.') + lora_linears = get_indexed_lora_linears(self.patched_model) + cpu_caches = self.cache_engine.cpu_cache + num_blocks = self.cache_engine.num_cpu_blocks + cpu_caches = [(kcache.view(num_blocks, + -1), vcache.view(num_blocks, -1)) + for kcache, vcache in cpu_caches] + for weight_map in weight_maps: + weight_map.cache_adapter(lora_linears, cpu_caches) + update_lora_linears(lora_linears, weight_maps, device='cuda') + + def forward(self, inputs: ModelInputs, swap_in_map: SwapMap, + swap_out_map: SwapMap): """model forward. Args: inputs (Dict): The input data comes from _make_inputs. - swap_in_map (Dict[int, int]): Cache maps to swap in. - swap_out_map (Dict[int, int]): Cache maps to swap out. + swap_in_map (SwapMap): Cache maps to swap in. + swap_out_map (SwapMap): Cache maps to swap out. 
""" cache_swapping(self.cache_engine, @@ -290,36 +429,49 @@ def forward(self, inputs: ModelInputs, swap_in_map: Dict[int, int], return output -def _get_checkpoints(model_path: str): - """get checkpoints.""" - try: - torch_model_json_path = cached_file(model_path, WEIGHTS_INDEX_NAME) - with open(torch_model_json_path, mode='r') as f: - torch_model_json = json.load(f) - - weight_map = torch_model_json['weight_map'] - - checkpoints = list(set(weight_map.values())) - checkpoints = [cached_file(model_path, ckpt) for ckpt in checkpoints] - except Exception: - logger.warning(f'load failed, try load from {WEIGHTS_NAME}.') - checkpoints = [cached_file(model_path, WEIGHTS_NAME)] - - return checkpoints - - @dataclass class TPResponse: ret_code: int error: Union[Exception, List[Exception]] = None data: Any = None + def gather_error(self): + """gather error.""" + rank = dist.get_rank() + world_size = dist.get_world_size() + + # gather errors + error_count = torch.tensor(self.ret_code).cuda(rank) + dist.all_reduce(error_count) + if error_count.item() > 0: + all_errors = [None] * world_size + dist.all_gather_object(all_errors, self.error) + self.ret_code = 1 + self.error = all_errors + + def raise_error(self, default_error: Exception): + """raise error.""" + if self.error is None: + raise default_error + elif isinstance(self.error, Exception): + raise self.error + else: + assert isinstance(self.error, List), ('expect error type list, ' + f'got {type(self.error)}') + rank = dist.get_rank() + err = self.error[rank] + if err is None: + raise default_error + else: + raise err + def _tp_build_model( rank: int, model_path: str, model_config: ModelConfig, cache_config: CacheConfig, + adapters: Dict[str, str], out_que: mp.Queue, world_size: int, trust_remote_code=True, @@ -356,24 +508,36 @@ def _broadcast_config(cache_config): return config_list[0] try: - config = AutoConfig.from_pretrained( - model_path, trust_remote_code=trust_remote_code) - torch_dtype = _get_torch_dtype(config) + config = model_config.hf_config + # config = AutoConfig.from_pretrained( + # model_path, trust_remote_code=trust_remote_code) + torch_dtype = model_config.dtype with init_empty_weights(): model = AutoModelForCausalLM.from_config( config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code) - model.eval() - model.config.use_cache = True + _add_adapters(model, adapters) + model.eval() + model.config.use_cache = True + + if rank == 0: + with LoadNoInit(): + device_map = 'auto' + param_model = AutoModelForCausalLM.from_pretrained( + model_path, + torch_dtype=torch_dtype, + device_map=device_map, + trust_remote_code=trust_remote_code) + _load_adapters(param_model, adapters, device_map=device_map) + model.load_state_dict(param_model.state_dict(), assign=True) + del param_model - checkpoints = _get_checkpoints(model_path) patched_model = patch( model, extra_args=_PATCH_ARG_NAMES, rank=rank, world_size=world_size, - checkpoints=checkpoints, ) _update_cache_config(model_config, cache_config) @@ -387,18 +551,12 @@ def _broadcast_config(cache_config): error_type = e # response - error_code = torch.tensor(error_code).cuda(rank) - dist.all_reduce(error_code) - error_code = error_code.item() - if error_code > 0: - all_errors = [None] * world_size - dist.all_gather_object(all_errors, error_type) - if rank == 0: - out_que.put(TPResponse(1, all_errors, cache_config)) - return - else: - if rank == 0: - out_que.put(TPResponse(0, None, cache_config)) + resp = TPResponse(error_code, error_type, cache_config) + resp.gather_error() + if rank == 
0: + out_que.put(resp) + if resp.ret_code != 0: + resp.raise_error(RuntimeError('failed to init model.')) return patched_model, cache_engine @@ -446,11 +604,62 @@ def _tp_get_input(rank: int, in_que: mp.Queue, world_size: int): return inputs, swap_in_map, swap_out_map +def _tp_paging_adapters( + rank: int, + patched_model: torch.nn.Module, + cache_engine: CacheEngine, + in_que: mp.Queue, + out_que: mp.Queue, +): + """tp paging adapters.""" + + def __get_weight_map(): + """get weight map.""" + if rank == 0: + weight_maps = in_que.get() + dist_obj = [weight_maps] + else: + dist_obj = [None] + dist.broadcast_object_list(dist_obj) + return dist_obj[0] + + def __paging(weight_maps): + """paging.""" + lora_linears = get_indexed_lora_linears(patched_model) + cpu_caches = cache_engine.cpu_cache + num_blocks = cache_engine.num_cpu_blocks + cpu_caches = [(kcache.view(num_blocks, + -1), vcache.view(num_blocks, -1)) + for kcache, vcache in cpu_caches] + for weight_map in weight_maps: + weight_map.cache_adapter(lora_linears, cpu_caches) + update_lora_linears(lora_linears, weight_maps, device='cuda') + + weight_maps = __get_weight_map() + + resp = TPResponse(0) + try: + if rank == 0: + logger.info('tp paging adapters.') + if len(weight_maps) > 0: + __paging(weight_maps) + except Exception as e: + resp.ret_code = 1 + resp.error = e + + resp.gather_error() + if rank == 0: + out_que.put(resp) + if resp.ret_code != 0: + resp.raise_error(RuntimeError('tp paging adapters failed.')) + + def _tp_model_loop( rank: int, model_path: str, model_config: ModelConfig, cache_config: CacheConfig, + adapters: Dict[str, str], in_que: mp.Queue, out_que: mp.Queue, world_size: int, @@ -469,10 +678,22 @@ def _tp_model_loop( world_size (int): The distribution world size. """ stream = torch.cuda.Stream() - patched_model, cache_engine = _tp_build_model(rank, model_path, - model_config, cache_config, - out_que, world_size, - trust_remote_code) + patched_model, cache_engine = _tp_build_model( + rank, + model_path, + model_config, + cache_config, + adapters, + out_que=out_que, + world_size=world_size, + trust_remote_code=trust_remote_code) + + if adapters: + _tp_paging_adapters(rank, + patched_model, + cache_engine=cache_engine, + in_que=in_que, + out_que=out_que) while True: inputs, swap_in_map, swap_out_map = _tp_get_input( @@ -499,7 +720,8 @@ def _start_tp_process(rank: int, world_size: int, func: Callable, args: List = None, - kwargs: Dict = None): + kwargs: Dict = None, + port: int = 29500): """Start the tensor parallel process. 
Args: @@ -511,7 +733,7 @@ def _start_tp_process(rank: int, """ try: os.environ['MASTER_ADDR'] = '127.0.0.1' - os.environ['MASTER_PORT'] = '29500' + os.environ['MASTER_PORT'] = str(port) dist.init_process_group('nccl', rank=rank, world_size=world_size) with torch.cuda.device(rank), torch.no_grad(): @@ -520,13 +742,33 @@ def _start_tp_process(rank: int, func(rank, *args, **kwargs) except Exception as e: from traceback import print_exc - - logger.error(f'rank[{rank}]: {e}') + logger.error(f'Rank[{rank}] failed.') print_exc() raise e -class TPModelAgent: +def _queue_get_response(que: mp.Queue, + mp_context: mp.ProcessContext, + interval: float = 1.0): + """get response.""" + from multiprocessing.queues import Empty + + def __check_context_alive(): + """check context alive.""" + procs = mp_context.processes + for idx, p in enumerate(procs): + if not p.is_alive(): + raise RuntimeError(f'Rank[{idx}] failed.') + + while True: + try: + return que.get(timeout=interval) + except Empty: + pass + __check_context_alive() + + +class TPModelAgent(AutoModelAgent): """Tensor Parallelism model agent. load model on multiple GPUs @@ -543,25 +785,26 @@ def __init__(self, model_config: ModelConfig, cache_config: CacheConfig, world_size: int, + adapters: Dict[str, str] = None, trust_remote_code: bool = True) -> None: mp.set_start_method('spawn') + super().__init__(model_config=model_config, cache_config=cache_config) self.world_size = world_size - self.model_config = model_config - self.cache_config = cache_config self.tp_model_in_que = mp.Queue(10) self.tp_model_out_que = mp.Queue(10) self.patch_model_tp(model_path, model_config=model_config, cache_config=cache_config, + adapters=adapters, in_que=self.tp_model_in_que, out_que=self.tp_model_out_que, world_size=world_size, trust_remote_code=trust_remote_code) def patch_model_tp(self, model_path: str, model_config: ModelConfig, - cache_config: CacheConfig, in_que: mp.Queue, - out_que: mp.Queue, world_size: int, + cache_config: CacheConfig, adapters: Dict[str, str], + in_que: mp.Queue, out_que: mp.Queue, world_size: int, trust_remote_code: bool): """Start tensor parallel sub process. @@ -576,6 +819,17 @@ def patch_model_tp(self, model_path: str, model_config: ModelConfig, out_que (mp.Queue): Output queue. Used to send the model output. world_size (int): The distribution world size. 
""" + + def __find_available_port() -> bool: + """find available port.""" + import socket + port = 29500 + while True: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + if s.connect_ex(('localhost', port)) != 0: + return port + port += 1 + self.mp_context = mp.spawn( _start_tp_process, args=( @@ -584,23 +838,35 @@ def patch_model_tp(self, model_path: str, model_config: ModelConfig, (model_path, ), dict(model_config=model_config, cache_config=cache_config, + adapters=adapters, in_que=in_que, out_que=out_que, world_size=world_size, trust_remote_code=trust_remote_code), + __find_available_port(), ), nprocs=world_size, join=False, daemon=True, ) - resp: TPResponse = out_que.get() + resp: TPResponse = _queue_get_response(out_que, self.mp_context) if resp.ret_code != 0: logger.error(f'Init tp model failed with error: {resp.error}') raise next(err for err in resp.error if err is not None) self.cache_config = resp.data - def forward(self, inputs: Dict, swap_in_map: Dict[int, int], - swap_out_map: Dict[int, int]): + def paging_adapters(self, weight_maps: List[AdapterWeightMap]): + """load adapter.""" + if not weight_maps: + return + self.tp_model_in_que.put(weight_maps) + resp: TPResponse = self.tp_model_out_que.get() + if resp.ret_code != 0: + logger.error(f'paging adapters failed with error: {resp.error}') + raise next(err for err in resp.error if err is not None) + + def forward(self, inputs: ModelInputs, swap_in_map: SwapMap, + swap_out_map: SwapMap): """model forward. Args: @@ -611,8 +877,33 @@ def forward(self, inputs: Dict, swap_in_map: Dict[int, int], with torch.no_grad(): self.tp_model_in_que.put((inputs, swap_in_map, swap_out_map)) - resp: TPResponse = self.tp_model_out_que.get() + resp: TPResponse = _queue_get_response(self.tp_model_out_que, + self.mp_context) if resp.ret_code != 0: raise RuntimeError('tp forward failed.') return resp.data + + +def build_model_agent(model_path: str, + cache_config: CacheConfig, + trust_remote_code: bool, + adapters: Dict[str, str] = None, + tp: int = 1): + """create model agent.""" + model_config = ModelConfig.from_pretrained( + model_path, trust_remote_code=trust_remote_code) + if tp == 1: + model_agent = BaseModelAgent(model_path, + model_config=model_config, + cache_config=cache_config, + adapters=adapters, + trust_remote_code=trust_remote_code) + else: + model_agent = TPModelAgent(model_path, + model_config=model_config, + cache_config=cache_config, + world_size=tp, + adapters=adapters, + trust_remote_code=trust_remote_code) + return model_agent diff --git a/lmdeploy/pytorch/engine/request.py b/lmdeploy/pytorch/engine/request.py index ad1d154fa4..7bf9871029 100644 --- a/lmdeploy/pytorch/engine/request.py +++ b/lmdeploy/pytorch/engine/request.py @@ -1,9 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. import enum -from dataclasses import dataclass -from queue import Queue -from threading import Lock -from typing import Any, Callable, Dict, List +from dataclasses import dataclass, field +from queue import Empty, Queue +from threading import Lock, Thread, ThreadError +from typing import Any, Callable, ClassVar, Dict, List from lmdeploy.utils import get_logger @@ -53,6 +53,7 @@ class Response: err_msg: str = '' +@dataclass class RequestSender: """Request sender. 
@@ -60,12 +61,33 @@ class RequestSender: sender_id (int): The id of the sender """ - def __init__(self, sender_id: int, req_que: Queue): - self._next_req_id = 0 - self.sender_id = sender_id - self.req_que = req_que - self.resp_que = Queue() - self.resp_dict = dict() + sender_id: int + req_que: Queue + resp_que: Queue = field(default_factory=Queue) + resp_dict: Dict[int, List[Response]] = field(default_factory=dict) + THREAD_ALIVE_INTERVAL: ClassVar[float] = 1.0 + _next_req_id: int = 0 + _thread: Thread = None + + @classmethod + def new(cls, sender_id: int, req_que: Queue, thread: Thread): + """new sender.""" + return cls(sender_id=sender_id, req_que=req_que, _thread=thread) + + def _resp_que_get(self, block: bool = True, timeout: float = None): + """warp of resp_que.get.""" + if not block: + return self.resp_que(block=block, timeout=timeout) + timeout_counter = timeout or float(1 << 30) + while timeout_counter > self.THREAD_ALIVE_INTERVAL: + try: + return self.resp_que.get(timeout=self.THREAD_ALIVE_INTERVAL) + except Empty: + timeout_counter -= self.THREAD_ALIVE_INTERVAL + if self._thread and not self._thread.is_alive(): + raise ThreadError('Engine main loop stopped.') + + return self.resp_que.get(timeout=timeout_counter) def _push_resp(self, req_id: int, resp: Response): """push response.""" @@ -86,13 +108,19 @@ def _prefetch_resps(self): """prefetch from resp que.""" num_resps = self.resp_que.qsize() for _ in range(num_resps): - resp: Response = self.resp_que.get() + resp: Response = self._resp_que_get() req_id = resp.req_id self._push_resp(req_id, resp) + def is_thread_alive(self): + """is thread alive.""" + return self._thread and self._thread.is_alive() + def batched_send_async(self, req_types: List[RequestType], data: List[Any]) -> List[int]: """Batched send request asynchronize.""" + if self._thread and not self._thread.is_alive(): + raise ThreadError('Engine main loop stopped.') assert len(req_types) == len(data) batch_size = len(req_types) @@ -125,7 +153,7 @@ def recv_any(self, que_timeout: float = None) -> Response: return ret # check resp que - return self.resp_que.get(timeout=que_timeout) + return self._resp_que_get(timeout=que_timeout) def recv_all(self, req_id: int): """revceive all response with req_id.""" @@ -142,7 +170,7 @@ def recv(self, req_id: int, que_timeout: float = None) -> Response: # check resp que while True: - resp: Response = self.resp_que.get(timeout=que_timeout) + resp: Response = self._resp_que_get(timeout=que_timeout) if resp.req_id != req_id: self._push_resp(req_id, resp) else: @@ -173,12 +201,12 @@ def __init__(self): self.requests = Queue() self.mutex = Lock() - def build_sender(self): + def build_sender(self, thread: Thread = None): """create a new sender.""" with self.mutex: sender_id = self._next_sender_id self._next_sender_id += 1 - new_sender = RequestSender(sender_id, self.requests) + new_sender = RequestSender.new(sender_id, self.requests, thread) self.senders[sender_id] = new_sender return new_sender diff --git a/lmdeploy/pytorch/kernels/fill_kv_cache.py b/lmdeploy/pytorch/kernels/fill_kv_cache.py index d0f72558af..ac999f9d96 100644 --- a/lmdeploy/pytorch/kernels/fill_kv_cache.py +++ b/lmdeploy/pytorch/kernels/fill_kv_cache.py @@ -179,11 +179,20 @@ def _create_fill_cache_info(is_decoding: bool, block_size: int, 4. block_offset1d: which block we want to perform the filling. 
""" if not is_decoding: - return _prefilling_cache_info(block_size, seq_length, block_offsets, - history_lengths, device) + cache_info = _prefilling_cache_info(block_size, seq_length, + block_offsets, history_lengths, + device) else: - return _decoding_cache_info(block_size, start_loc, seq_length, - block_offsets, history_lengths, device) + cache_info = _decoding_cache_info(block_size, start_loc, seq_length, + block_offsets, history_lengths, + device) + + state_len = cache_info['state_len'] + block_offsets1d = cache_info['block_offsets1d'] + assert state_len.size(0) == block_offsets1d.size(0), ( + f'len(state_len)=={state_len.size(0)} not equal to ' + f'len(block_offsets1d)=={block_offsets1d.size(0)}') + return cache_info @torch.inference_mode() diff --git a/lmdeploy/pytorch/kernels/mbgmm.py b/lmdeploy/pytorch/kernels/mbgmm.py index 206deeb50d..24be1e234e 100644 --- a/lmdeploy/pytorch/kernels/mbgmm.py +++ b/lmdeploy/pytorch/kernels/mbgmm.py @@ -3,6 +3,7 @@ import triton import triton.language as tl from torch import Tensor +from triton.runtime.jit import get_cuda_stream def _next_pow_of_2(x): @@ -17,8 +18,9 @@ def _x_a_mm_kernel( XA, B_start_loc, B_seq_lens, - B_rank_id, + B_adapter_id, Rank_page_table, + Rank_page_start, Ranks, stride_xs, stride_xh, @@ -27,6 +29,7 @@ def _x_a_mm_kernel( stride_xas, stride_xar, stride_ptb, + rank_step, BLOCK_M: tl.constexpr, BLOCK_R: tl.constexpr, BLOCK_H: tl.constexpr, @@ -43,17 +46,16 @@ def _x_a_mm_kernel( return start_loc = tl.load(B_start_loc + cur_batch) - rank_id = tl.load(B_rank_id + cur_batch) - rank = tl.load(Ranks + rank_id) + adapter_id = tl.load(B_adapter_id + cur_batch) + rank = tl.load(Ranks + adapter_id) // rank_step + page_start = tl.load(Rank_page_start + adapter_id) - page_table_off = rank_id * stride_ptb + r_off + page_table_off = adapter_id * stride_ptb + r_off + page_start rank_mask = r_off < rank page_table = tl.load(Rank_page_table + page_table_off, mask=rank_mask) m_off = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - r_off = tl.arange(0, BLOCK_R) dm_off = tl.arange(0, BLOCK_DMODEL) - rank_mask = r_off < rank x_off = (start_loc + m_off) * stride_xs xs_mask = m_off < seq_len @@ -98,8 +100,9 @@ def _acc_b_mm_kernel( Out, B_start_loc, B_seq_lens, - B_rank_id, + B_adapter_id, Rank_page_table, + Rank_page_start, Ranks, stride_xas, stride_xar, @@ -123,16 +126,16 @@ def _acc_b_mm_kernel( return start_loc = tl.load(B_start_loc + cur_batch) - rank_id = tl.load(B_rank_id + cur_batch) - rank = tl.load(Ranks + rank_id) + adapter_id = tl.load(B_adapter_id + cur_batch) + rank = tl.load(Ranks + adapter_id) + page_start = tl.load(Rank_page_start + adapter_id) - page_table_off = rank_id * stride_ptb + r_off + page_table_off = adapter_id * stride_ptb + r_off + page_start rank_mask = r_off < rank page_table = tl.load(Rank_page_table + page_table_off, mask=rank_mask) m_off = start_m * BLOCK_M + tl.arange(0, BLOCK_M) dm_off = tl.arange(0, BLOCK_DMODEL) - rank_mask = r_off < rank lb_page_off = page_table * stride_lbs xs_mask = m_off < seq_len @@ -169,16 +172,33 @@ def _acc_b_mm_kernel( @torch.inference_mode() -def mbgmm_a(x: Tensor, lora_a: Tensor, b_start_loc: Tensor, b_seq_lens: Tensor, - b_rank_ids: Tensor, rank_page_table: Tensor, ranks: Tensor, - max_seq_len: int, max_rank: int): +def mbgmm_a(x: Tensor, + lora_a: Tensor, + b_start_loc: Tensor, + b_seq_lens: Tensor, + b_adapter_ids: Tensor, + rank_page_table: Tensor, + ranks: Tensor, + rank_page_start: Tensor, + max_seq_len: int, + max_rank: int, + rank_step: int = 1): """mbgmm_a.""" + + def 
_kernel_meta(): + device = x.device + device_idx = device.index + device_type = device.type + stream = get_cuda_stream(device_idx) + return dict(device=device, device_type=device_type, stream=stream) + assert x.dim() == 2 assert lora_a.dim() == 2 assert rank_page_table.dim() == 2 head_size = x.size(-1) batch_size = len(b_seq_lens) + max_rank = max_rank // rank_step BLOCK_M = 32 BLOCK_R = _next_pow_of_2(max_rank) @@ -189,79 +209,97 @@ def mbgmm_a(x: Tensor, lora_a: Tensor, b_start_loc: Tensor, b_seq_lens: Tensor, num_warps = 4 grid = [batch_size, triton.cdiv(max_seq_len, BLOCK_M)] - xa = x.new_empty((x.size(0), BLOCK_R)) - _x_a_mm_kernel[grid]( - x, - lora_a, - xa, - b_start_loc, - b_seq_lens, - b_rank_ids, - Rank_page_table=rank_page_table, - Ranks=ranks, - stride_xs=x.stride(0), - stride_xh=x.stride(1), - stride_las=lora_a.stride(0), - stride_lah=lora_a.stride(1), - stride_xas=xa.stride(0), - stride_xar=xa.stride(1), - stride_ptb=rank_page_table.stride(0), - BLOCK_M=BLOCK_M, - BLOCK_R=BLOCK_R, - BLOCK_H=BLOCK_H, - BLOCK_DMODEL=BLOCK_DMODEL, - num_warps=num_warps, - num_stages=1, - ) + xa = x.new_empty((x.size(0), max_rank)) + kernel_meta = _kernel_meta() + _x_a_mm_kernel[grid](x, + lora_a, + xa, + b_start_loc, + b_seq_lens, + b_adapter_ids, + Rank_page_table=rank_page_table, + Rank_page_start=rank_page_start, + Ranks=ranks, + stride_xs=x.stride(0), + stride_xh=x.stride(1), + stride_las=lora_a.stride(0), + stride_lah=lora_a.stride(1), + stride_xas=xa.stride(0), + stride_xar=xa.stride(1), + stride_ptb=rank_page_table.stride(0), + rank_step=rank_step, + BLOCK_M=BLOCK_M, + BLOCK_R=BLOCK_R, + BLOCK_H=BLOCK_H, + BLOCK_DMODEL=BLOCK_DMODEL, + num_warps=num_warps, + num_stages=1, + **kernel_meta) return xa @torch.inference_mode() -def mbgmm_b(xa: Tensor, lora_b: Tensor, b_start_loc: Tensor, - b_seq_lens: Tensor, b_rank_ids: Tensor, rank_page_table: Tensor, - ranks: Tensor, max_seq_len: int, max_rank: int): +def mbgmm_b(xa: Tensor, + lora_b: Tensor, + b_start_loc: Tensor, + b_seq_lens: Tensor, + b_adapter_ids: Tensor, + rank_page_table: Tensor, + ranks: Tensor, + rank_page_start: Tensor, + max_seq_len: int, + max_rank: int, + out_size: int = None): """mbgmm_b.""" + def _kernel_meta(): + device = xa.device + device_idx = device.index + device_type = device.type + stream = get_cuda_stream(device_idx) + return dict(device=device, device_type=device_type, stream=stream) + assert xa.dim() == 2 assert lora_b.dim() == 2 assert rank_page_table.dim() == 2 - head_o_size = lora_b.size(-1) + if out_size is None: + out_size = lora_b.size(-1) batch_size = len(b_seq_lens) BLOCK_M = 32 BLOCK_R = _next_pow_of_2(max_rank) if BLOCK_R < 16: BLOCK_R = 16 - BLOCK_HO = head_o_size + BLOCK_HO = out_size BLOCK_DMODEL = 64 num_warps = 4 grid = [batch_size, triton.cdiv(max_seq_len, BLOCK_M)] output = xa.new_empty((xa.size(0), BLOCK_HO)) - - _acc_b_mm_kernel[grid]( - xa, - lora_b, - output, - b_start_loc, - b_seq_lens, - b_rank_ids, - Rank_page_table=rank_page_table, - Ranks=ranks, - stride_xas=xa.stride(0), - stride_xar=xa.stride(1), - stride_os=output.stride(0), - stride_oh=output.stride(1), - stride_lbs=lora_b.stride(0), - stride_lbh=lora_b.stride(1), - stride_ptb=rank_page_table.stride(0), - BLOCK_M=BLOCK_M, - BLOCK_R=BLOCK_R, - BLOCK_HO=BLOCK_HO, - BLOCK_DMODEL=BLOCK_DMODEL, - num_warps=num_warps, - num_stages=1, - ) + kernel_meta = _kernel_meta() + _acc_b_mm_kernel[grid](xa, + lora_b, + output, + b_start_loc, + b_seq_lens, + b_adapter_ids, + Rank_page_table=rank_page_table, + Rank_page_start=rank_page_start, + 
Ranks=ranks, + stride_xas=xa.stride(0), + stride_xar=xa.stride(1), + stride_os=output.stride(0), + stride_oh=output.stride(1), + stride_lbs=lora_b.stride(0), + stride_lbh=lora_b.stride(1), + stride_ptb=rank_page_table.stride(0), + BLOCK_M=BLOCK_M, + BLOCK_R=BLOCK_R, + BLOCK_HO=BLOCK_HO, + BLOCK_DMODEL=BLOCK_DMODEL, + num_warps=num_warps, + num_stages=1, + **kernel_meta) return output diff --git a/lmdeploy/pytorch/kernels/mbgmv.py b/lmdeploy/pytorch/kernels/mbgmv.py index b45d2290ba..e8c92d63d9 100644 --- a/lmdeploy/pytorch/kernels/mbgmv.py +++ b/lmdeploy/pytorch/kernels/mbgmv.py @@ -3,6 +3,7 @@ import triton import triton.language as tl from torch import Tensor +from triton.runtime.jit import get_cuda_stream def _next_pow_of_2(x): @@ -15,8 +16,9 @@ def _x_a_mv_kernel( X, LoRA_A, XA, - B_rank_id, + B_adapter_id, Rank_page_table, + Rank_page_start, Ranks, stride_xs, stride_xh, @@ -25,6 +27,7 @@ def _x_a_mv_kernel( stride_xas, stride_xar, stride_ptb, + rank_step, BLOCK_R: tl.constexpr, BLOCK_H: tl.constexpr, BLOCK_DMODEL: tl.constexpr, @@ -33,15 +36,15 @@ def _x_a_mv_kernel( cur_batch = tl.program_id(0) r_off = tl.arange(0, BLOCK_R) - rank_id = tl.load(B_rank_id + cur_batch) - rank = tl.load(Ranks + rank_id) + adapter_id = tl.load(B_adapter_id + cur_batch) + rank = tl.load(Ranks + adapter_id) // rank_step + page_start = tl.load(Rank_page_start + adapter_id) - page_table_off = rank_id * stride_ptb + r_off + page_table_off = adapter_id * stride_ptb + r_off + page_start rank_mask = r_off < rank page_table = tl.load(Rank_page_table + page_table_off, mask=rank_mask) dm_off = tl.arange(0, BLOCK_DMODEL) - rank_mask = r_off < rank x_off = cur_batch * stride_xs la_page_off = page_table * stride_las @@ -55,7 +58,7 @@ def _x_a_mv_kernel( # load x xh_off = cur_dm_off * stride_xh x_mask = h_mask - x = tl.load(X + x_off + xh_off, mask=x_mask, other=0.0).to(tl.float32) + x = tl.load(X + x_off + xh_off, mask=x_mask, other=0.0) # load lora a lah_off = cur_dm_off * stride_lah @@ -77,8 +80,9 @@ def _acc_b_mv_kernel( XA, LoRA_B, Out, - B_rank_id, + B_adapter_id, Rank_page_table, + Rank_page_start, Ranks, stride_xas, stride_xar, @@ -95,15 +99,15 @@ def _acc_b_mv_kernel( cur_batch = tl.program_id(0) r_off = tl.arange(0, BLOCK_R) - rank_id = tl.load(B_rank_id + cur_batch) - rank = tl.load(Ranks + rank_id) + adapter_id = tl.load(B_adapter_id + cur_batch) + rank = tl.load(Ranks + adapter_id) + page_start = tl.load(Rank_page_start + adapter_id) - page_table_off = rank_id * stride_ptb + r_off + page_table_off = adapter_id * stride_ptb + r_off + page_start rank_mask = r_off < rank page_table = tl.load(Rank_page_table + page_table_off, mask=rank_mask) dm_off = tl.arange(0, BLOCK_DMODEL) - rank_mask = r_off < rank lb_page_off = page_table * stride_lbs o_off = cur_batch * stride_os @@ -133,91 +137,116 @@ def _acc_b_mv_kernel( @torch.inference_mode() -def mbgmv_a(x: Tensor, lora_a: Tensor, b_rank_ids: Tensor, - rank_page_table: Tensor, ranks: Tensor, max_rank: int): +def mbgmv_a(x: Tensor, + lora_a: Tensor, + b_adapter_ids: Tensor, + rank_page_table: Tensor, + ranks: Tensor, + rank_page_start: Tensor, + max_rank: int, + rank_step: int = 1): """mbgmv_a.""" + def _kernel_meta(): + device = x.device + device_idx = device.index + device_type = device.type + stream = get_cuda_stream(device_idx) + return dict(device=device, device_type=device_type, stream=stream) + assert x.dim() == 2 assert lora_a.dim() == 2 assert rank_page_table.dim() == 2 head_size = x.size(-1) batch_size = x.size(0) + max_rank = max_rank // rank_step BLOCK_R 
= _next_pow_of_2(max_rank) - if BLOCK_R < 16: - BLOCK_R = 16 BLOCK_H = head_size - BLOCK_DMODEL = 64 + BLOCK_DMODEL = 512 num_warps = 4 grid = [batch_size] xa = x.new_empty((x.size(0), BLOCK_R)) - - _x_a_mv_kernel[grid]( - x, - lora_a, - xa, - b_rank_ids, - Rank_page_table=rank_page_table, - Ranks=ranks, - stride_xs=x.stride(0), - stride_xh=x.stride(1), - stride_las=lora_a.stride(0), - stride_lah=lora_a.stride(1), - stride_xas=xa.stride(0), - stride_xar=xa.stride(1), - stride_ptb=rank_page_table.stride(0), - BLOCK_R=BLOCK_R, - BLOCK_H=BLOCK_H, - BLOCK_DMODEL=BLOCK_DMODEL, - num_warps=num_warps, - num_stages=1, - ) + kernel_meta = _kernel_meta() + _x_a_mv_kernel[grid](x, + lora_a, + xa, + b_adapter_ids, + Rank_page_table=rank_page_table, + Rank_page_start=rank_page_start, + Ranks=ranks, + stride_xs=x.stride(0), + stride_xh=x.stride(1), + stride_las=lora_a.stride(0), + stride_lah=lora_a.stride(1), + stride_xas=xa.stride(0), + stride_xar=xa.stride(1), + stride_ptb=rank_page_table.stride(0), + rank_step=rank_step, + BLOCK_R=BLOCK_R, + BLOCK_H=BLOCK_H, + BLOCK_DMODEL=BLOCK_DMODEL, + num_warps=num_warps, + num_stages=1, + **kernel_meta) return xa @torch.inference_mode() -def mbgmv_b(xa: Tensor, lora_b: Tensor, b_rank_ids: Tensor, - rank_page_table: Tensor, ranks: Tensor, max_rank: int): +def mbgmv_b(xa: Tensor, + lora_b: Tensor, + b_adapter_ids: Tensor, + rank_page_table: Tensor, + ranks: Tensor, + rank_page_start: Tensor, + max_rank: int, + out_size: int = None): """mbgmv_b.""" + def _kernel_meta(): + device = xa.device + device_idx = device.index + device_type = device.type + stream = get_cuda_stream(device_idx) + return dict(device=device, device_type=device_type, stream=stream) + assert xa.dim() == 2 assert lora_b.dim() == 2 assert rank_page_table.dim() == 2 - head_o_size = lora_b.size(-1) + if out_size is None: + out_size = lora_b.size(-1) batch_size = xa.size(0) BLOCK_R = _next_pow_of_2(max_rank) - if BLOCK_R < 16: - BLOCK_R = 16 - BLOCK_HO = head_o_size - BLOCK_DMODEL = 64 + BLOCK_HO = out_size + BLOCK_DMODEL = 512 num_warps = 4 grid = [batch_size] output = xa.new_empty((xa.size(0), BLOCK_HO)) - - _acc_b_mv_kernel[grid]( - xa, - lora_b, - output, - b_rank_ids, - Rank_page_table=rank_page_table, - Ranks=ranks, - stride_xas=xa.stride(0), - stride_xar=xa.stride(1), - stride_lbs=lora_b.stride(0), - stride_lbh=lora_b.stride(1), - stride_os=output.stride(0), - stride_oh=output.stride(1), - stride_ptb=rank_page_table.stride(0), - BLOCK_R=BLOCK_R, - BLOCK_HO=BLOCK_HO, - BLOCK_DMODEL=BLOCK_DMODEL, - num_warps=num_warps, - num_stages=1, - ) + kernel_meta = _kernel_meta() + _acc_b_mv_kernel[grid](xa, + lora_b, + output, + b_adapter_ids, + Rank_page_table=rank_page_table, + Rank_page_start=rank_page_start, + Ranks=ranks, + stride_xas=xa.stride(0), + stride_xar=xa.stride(1), + stride_lbs=lora_b.stride(0), + stride_lbh=lora_b.stride(1), + stride_os=output.stride(0), + stride_oh=output.stride(1), + stride_ptb=rank_page_table.stride(0), + BLOCK_R=BLOCK_R, + BLOCK_HO=BLOCK_HO, + BLOCK_DMODEL=BLOCK_DMODEL, + num_warps=num_warps, + num_stages=1, + **kernel_meta) return output diff --git a/lmdeploy/pytorch/kernels/pagedattention.py b/lmdeploy/pytorch/kernels/pagedattention.py index 9212eb9568..e88cd1f874 100644 --- a/lmdeploy/pytorch/kernels/pagedattention.py +++ b/lmdeploy/pytorch/kernels/pagedattention.py @@ -10,12 +10,16 @@ @triton.jit -def _load_block_offsets(offset_ptr, block_id, is_unified_paging: tl.constexpr, +def _load_block_offsets(offset_ptr, block_id, num_sub_blocks: tl.constexpr, BLOCK: 
tl.constexpr): - offs_n = tl.arange(0, BLOCK) - if is_unified_paging: - return tl.load(offset_ptr + block_id * BLOCK + offs_n) + if num_sub_blocks > 1: + offs_sub = tl.arange(0, num_sub_blocks) + offs_n = tl.arange(0, BLOCK // num_sub_blocks) + ret = tl.load(offset_ptr + block_id * num_sub_blocks + offs_sub)[ + None, :] * BLOCK // num_sub_blocks + offs_n[:, None] + return tl.ravel(ret) else: + offs_n = tl.arange(0, BLOCK) return tl.load(offset_ptr + block_id) * BLOCK + offs_n @@ -44,7 +48,7 @@ def _fwd_split_kernel( stride_boffb, kv_group_num, block_per_cta, - is_unified_paging: tl.constexpr, + num_sub_blocks: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, ): @@ -86,7 +90,7 @@ def _fwd_split_kernel( # load block offset start_block_id = loop_start // BLOCK_N b_offset = _load_block_offsets(block_offset_ptrs, start_block_id, - is_unified_paging, BLOCK_N) + num_sub_blocks, BLOCK_N) for start_n in range(loop_start, loop_end, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) @@ -110,7 +114,7 @@ def _fwd_split_kernel( if start_n + BLOCK_N < loop_end: start_block_id += 1 b_offset = _load_block_offsets(block_offset_ptrs, start_block_id, - is_unified_paging, BLOCK_N) + num_sub_blocks, BLOCK_N) qk = tl.sum(q[None, :] * k, 1) qk *= sm_scale @@ -219,7 +223,7 @@ def _fwd_kernel( stride_od, stride_boffb, kv_group_num, - is_unified_paging: tl.constexpr, + num_sub_blocks: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, @@ -261,7 +265,7 @@ def _fwd_kernel( block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) - b_offset = _load_block_offsets(block_offset_ptrs, 0, is_unified_paging, + b_offset = _load_block_offsets(block_offset_ptrs, 0, num_sub_blocks, BLOCK_N) for start_n in range(0, block_mask * cur_batch_kv_len, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) @@ -281,7 +285,7 @@ def _fwd_kernel( if start_n + BLOCK_N < cur_batch_kv_len: start_block_id = start_n // BLOCK_N + 1 b_offset = _load_block_offsets(block_offset_ptrs, start_block_id, - is_unified_paging, BLOCK_N) + num_sub_blocks, BLOCK_N) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) @@ -362,8 +366,8 @@ def _kernel_meta(): num_warps = 4 if Lk <= 64 else 8 - is_unified_paging = k.size(1) == 1 - BLOCK = 64 if is_unified_paging else k.size(1) + BLOCK = 64 if k.size(1) < 16 else k.size(1) + num_sub_blocks = BLOCK // k.size(1) kernel_meta = _kernel_meta() is_decoding = q.shape[-3] == b_seq_len.size(0) @@ -392,7 +396,7 @@ def _kernel_meta(): o.stride(-1), block_offsets.stride(0), kv_group_num=kv_group_num, - is_unified_paging=is_unified_paging, + num_sub_blocks=num_sub_blocks, BLOCK_M=BLOCK, BLOCK_DMODEL=Lk, BLOCK_N=BLOCK, @@ -427,7 +431,7 @@ def _kernel_meta(): stride_boffb=block_offsets.stride(0), kv_group_num=kv_group_num, block_per_cta=block_per_cta, - is_unified_paging=is_unified_paging, + num_sub_blocks=num_sub_blocks, BLOCK_DMODEL=Lk, BLOCK_N=BLOCK, num_warps=4, diff --git a/lmdeploy/pytorch/kernels/rearange_all_gather.py b/lmdeploy/pytorch/kernels/rearange_all_gather.py new file mode 100644 index 0000000000..868f0558a0 --- /dev/null +++ b/lmdeploy/pytorch/kernels/rearange_all_gather.py @@ -0,0 +1,134 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
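+# Rearrange rows produced by an all-gather across tensor-parallel ranks:
+# each token row arrives as `world_size` partitions padded to
+# max_rank // world_size columns, and the kernels below pack the valid
+# rank // world_size columns of that token's adapter into contiguous
+# columns of the output.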
+import torch +import triton +import triton.language as tl +from triton.runtime.jit import get_cuda_stream + + +@triton.jit +def _rearange_all_gather_kernel(X, StartLoc, SeqLen, AdapterIds, Ranks, Out, + stride_x, stride_o, world_size, + BLOCK: tl.constexpr, BLOCK_P: tl.constexpr): + """rearange all gather kernel.""" + batch_id = tl.program_id(0) + block_id = tl.program_id(1) + + start_loc = tl.load(StartLoc + batch_id) + block_id * BLOCK + seq_len = tl.load(SeqLen + batch_id) + + if block_id * BLOCK >= seq_len: + return + + block_off = start_loc + tl.arange(0, BLOCK) + block_mask = block_id * BLOCK + tl.arange(0, BLOCK) < seq_len + + adapter_id = tl.load(AdapterIds + batch_id) + rank = tl.load(Ranks + adapter_id) + prank = rank // world_size + p_off = tl.arange(0, BLOCK_P) + + for p_id in range(world_size): + ip_off = p_id * BLOCK_P + p_off + i_mask = block_mask[:, None] and (p_off < prank)[None, :] + i_off = block_off[:, None] * stride_x + ip_off[None, :] + x = tl.load(X + i_off, mask=i_mask) + + op_off = p_id * prank + p_off + o_mask = i_mask + o_off = block_off[:, None] * stride_o + op_off[None, :] + tl.store(Out + o_off, x, mask=o_mask) + + +@triton.jit +def _rearange_all_gather_decoding_kernel(X, AdapterIds, Ranks, Out, stride_x, + stride_o, world_size, seq_len, + BLOCK: tl.constexpr, + BLOCK_P: tl.constexpr): + """rearange all gather kernel.""" + block_id = tl.program_id(0) + block_off = block_id * BLOCK + tl.arange(0, BLOCK) + block_mask = block_off < seq_len + + adapter_ids = tl.load(AdapterIds + block_off, mask=block_mask) + ranks = tl.load(Ranks + adapter_ids) + pranks = ranks // world_size + p_off = tl.arange(0, BLOCK_P) + + for p_id in range(world_size): + ip_off = p_id * BLOCK_P + p_off + i_mask = block_mask[:, None] and (p_off[None, :] < pranks[:, None]) + i_off = block_off[:, None] * stride_x + ip_off[None, :] + x = tl.load(X + i_off, mask=i_mask) + + op_off = p_id * pranks[:, None] + p_off[None, :] + o_mask = i_mask + o_off = block_off[:, None] * stride_o + op_off + tl.store(Out + o_off, x, mask=o_mask) + + +def rearange_all_gather(x: torch.Tensor, + b_start_loc: torch.Tensor, + b_seq_lens: torch.Tensor, + adapter_ids: torch.LongTensor, + ranks: torch.Tensor, + world_size: int, + max_seq_len: int, + output: torch.Tensor = None): + """rearange all gather.""" + + def _kernel_meta(): + device = x.device + device_idx = device.index + device_type = device.type + stream = get_cuda_stream(device_idx) + return dict(device=device, device_type=device_type, stream=stream) + + max_rank = x.size(1) + batch_size = len(b_seq_lens) + partition_size = max_rank // world_size + + if output is None: + output = torch.empty_like(x) + + num_warps = 4 + kernel_meta = _kernel_meta() + + is_decoding = batch_size == x.size(0) + if not is_decoding: + BLOCK = 128 + BLOCK_P = partition_size + grid = (batch_size, triton.cdiv(max_seq_len, BLOCK)) + _rearange_all_gather_kernel[grid](x, + b_start_loc, + b_seq_lens, + adapter_ids, + ranks, + output, + stride_x=x.stride(0), + stride_o=output.stride(0), + world_size=world_size, + BLOCK=BLOCK, + BLOCK_P=BLOCK_P, + num_warps=num_warps, + num_stages=1, + **kernel_meta) + else: + BLOCK = 64 + BLOCK_P = partition_size + seq_len = x.size(0) + grid = (triton.cdiv(seq_len, BLOCK), ) + _rearange_all_gather_decoding_kernel[grid](x, + adapter_ids, + ranks, + output, + stride_x=x.stride(0), + stride_o=output.stride(0), + world_size=world_size, + seq_len=seq_len, + BLOCK=BLOCK, + BLOCK_P=BLOCK_P, + num_warps=num_warps, + num_stages=1, + **kernel_meta) + + return output 
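For reference, a minimal sketch of driving the prefill path of `rearange_all_gather` above, assuming a CUDA device with Triton available. The input tensor is a dummy stand-in for the gathered LoRA activations (in the tensor-parallel LoRALinear path it is the flattened result of `dist.all_gather_into_tensor`), so only the shapes and the per-sequence `adapter_ids`/`ranks` layout matter:

import torch

from lmdeploy.pytorch.kernels.rearange_all_gather import rearange_all_gather

world_size = 2
ranks = torch.tensor([4, 8]).cuda()         # LoRA rank of each adapter
seq_lens = torch.tensor([3, 5]).cuda()      # tokens per sequence (prefill)
start_loc = seq_lens.cumsum(0) - seq_lens   # first token of each sequence
adapter_ids = torch.tensor([0, 1]).cuda()   # adapter used by each sequence

# gathered activations: world_size partitions per row, each padded to
# max_rank // world_size columns
max_rank = int(ranks.max())
x = torch.rand(int(seq_lens.sum()), max_rank).cuda()

out = rearange_all_gather(x,
                          b_start_loc=start_loc,
                          b_seq_lens=seq_lens,
                          adapter_ids=adapter_ids,
                          ranks=ranks,
                          world_size=world_size,
                          max_seq_len=int(seq_lens.max()))
# out has the same shape as x; for each token, only the first `rank` columns
# of its adapter hold meaningful values after the repack.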
diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index 14649124d3..791c7d8721 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -8,6 +8,8 @@ import torch from torch import Tensor +from lmdeploy.messages import EngineGenerationConfig + from .block import LogicalTokenBlocks @@ -34,6 +36,19 @@ def __init__( self.stop_words = stop_words self.bad_words = bad_words + @classmethod + def from_gen_config(self, gen_config: EngineGenerationConfig): + """from gen config.""" + + return SamplingParam(top_p=gen_config.top_p, + top_k=gen_config.top_k, + temperature=gen_config.temperature, + repetition_penalty=gen_config.repetition_penalty, + ignore_eos=gen_config.ignore_eos, + random_seed=gen_config.random_seed, + stop_words=gen_config.stop_words, + bad_words=gen_config.bad_words) + class MessageStatus(enum.Enum): """Status of a sequence.""" @@ -69,7 +84,7 @@ def add_sequence(self, token_ids: Tensor, max_output_len: int = 512, sampling_param: SamplingParam = None, - adapter_id: int = -1) -> 'SchedulerSequence': + adapter_name: str = None) -> 'SchedulerSequence': """Add a new message.""" if not isinstance(token_ids, Tensor): token_ids = torch.tensor(token_ids) @@ -85,7 +100,7 @@ def add_sequence(self, status=MessageStatus.WAITING, remain_output_len=max_output_len, sampling_param=sampling_param, - adapter_id=adapter_id, + adapter_name=adapter_name, arrive_time=time.time()) self.sequences[seq.seq_id] = seq return seq @@ -115,7 +130,7 @@ def fork_sequence( sampling_param=sampling_param, status=seq.status, logical_blocks=seq.logical_blocks.clone(), - adapter_id=seq.adapter_id, + adapter_name=seq.adapter_name, arrive_time=time.time(), meta=deepcopy(seq.meta)) @@ -137,7 +152,7 @@ class SchedulerSequence: logical_blocks: LogicalTokenBlocks = None sender_id: int = -1 req_id: int = -1 - adapter_id: int = -1 + adapter_name: str = None arrive_time: float = 0.0 meta: Any = None diff --git a/lmdeploy/pytorch/models/chatglm2.py b/lmdeploy/pytorch/models/chatglm2.py index 98b18fe4ff..1ebb259339 100644 --- a/lmdeploy/pytorch/models/chatglm2.py +++ b/lmdeploy/pytorch/models/chatglm2.py @@ -184,7 +184,7 @@ def _contiguous_batching_forward( if kv_cache is not None: cache_k, cache_v = kv_cache q_start_loc = context.q_start_loc - q_seq_length = context.q_seq_length + q_seq_length = context.seq_length q_start_loc: torch.Tensor history_lengths = q_seq_length.new_tensor(history_lengths) diff --git a/lmdeploy/pytorch/models/falcon.py b/lmdeploy/pytorch/models/falcon.py index a76a653fc1..2b2a4c9c3f 100644 --- a/lmdeploy/pytorch/models/falcon.py +++ b/lmdeploy/pytorch/models/falcon.py @@ -235,7 +235,7 @@ def _contiguous_batching_forward( history_lengths = context.history_lengths q_start_loc = context.q_start_loc - q_seq_length = context.q_seq_length + q_seq_length = context.seq_length history_lengths = q_seq_length.new_tensor(history_lengths) kv_seq_length = q_seq_length + history_lengths max_seq_len = q_seq_length.max().item() diff --git a/lmdeploy/pytorch/models/functional.py b/lmdeploy/pytorch/models/functional.py index 2ca7589fe4..23d0e3c1a5 100644 --- a/lmdeploy/pytorch/models/functional.py +++ b/lmdeploy/pytorch/models/functional.py @@ -155,22 +155,16 @@ def attention_forward_with_paged_attention( kv_seq_length = getattr(context, 'kv_seq_length', None) if kv_seq_length is None: kv_seq_length = position_ids[..., -1] + 1 - if context is not None: - context.kv_seq_length = kv_seq_length q_seq_length = getattr(context, 'seq_length', None) if q_seq_length is None: q_seq_length = 
kv_seq_length - kv_seq_length.new_tensor( history_lengths) - if context is not None: - context.q_seq_length = q_seq_length q_start_loc = getattr(context, 'q_start_loc', None) if q_start_loc is None: q_start_loc = q_seq_length.cumsum(0) q_start_loc = torch.cat([q_start_loc.new_zeros(1), q_start_loc[:-1]]) - if context is not None: - context.q_start_loc = q_start_loc fill_kv_cache(key_states, value_states, diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index 6bd91c393c..bc167fdae4 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -84,3 +84,9 @@ 'modeling_internlm.InternLMMLP': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP', }) + +# peft +MODULE_MAP.update({ + 'peft.tuners.lora.layer.Linear': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.peft.LoRALinear' +}) diff --git a/lmdeploy/pytorch/models/patch.py b/lmdeploy/pytorch/models/patch.py index 791160e440..037abd1132 100644 --- a/lmdeploy/pytorch/models/patch.py +++ b/lmdeploy/pytorch/models/patch.py @@ -6,7 +6,6 @@ from typing import Any, Dict, Sequence import torch -import torch.distributed as dist from addict import Addict from torch.distributed._tensor import DeviceMesh @@ -154,117 +153,35 @@ def _update_model(model: torch.nn.Module): model._update_model_fn() -def _params_to_meta(model: torch.nn.Module): - """move parameters to meta device.""" - # recursive over children - for _, child in model.named_children(): - _params_to_meta(child) - - for k, v in model.named_parameters(recurse=False): - model.register_parameter( - k, torch.nn.Parameter(v.to('meta'), requires_grad=False)) - - -def _load_state_dict( - model: torch.nn.Module, - state_dict: Dict[str, torch.Tensor] = None, - rank: int = 0, - world_size: int = 1, - device_mesh: DeviceMesh = None, - state_prefix: str = '', -): - """Load state dict by rank. - - Load full state dict into device memory is not possible in LLM. - This method load shard and partition weights for different - distribution rank - - Args: - model (Module): Model to load weight. - state_dict (Dict[str, Tensor]): State dict object. - rank (int): Distribution rank. - world_size (int): Distribution world size. - device_mesh (DeviceMesh): Distribution device mesh. - state_prefix (str): The prefix of state dict. - - Returns: - Module: Updated model - """ +def _dist_model(model: torch.nn.Module, + rank: int = 0, + device_mesh: DeviceMesh = None): + """distribute model parameters.""" - def _recursive_children(): - """recursive children.""" - for name, child in model.named_children(): - loaded_child = _load_state_dict( - child, - state_dict, - rank, - world_size, - device_mesh=device_mesh, - state_prefix=f'{state_prefix}{name}.', - ) - if loaded_child != child: - model.register_module(name, loaded_child) - - def _init_parameters(): - """init parameters.""" - model_state_dict = model.state_dict() + def _init_params(): + """init params.""" device = torch.device(f'cuda:{rank}') - for k, v in model_state_dict.items(): - if '.' 
in k or not v.is_meta: - # only process weight that directly owned by module - # already initialized - continue + for name, param in model.named_parameters(recurse=False): + if rank == 0: + if device != param.device: + new_param = param.to(device) + model.register_parameter(name, + torch.nn.Parameter(new_param)) + else: + new_param = torch.empty_like(param, device=device) + model.register_parameter(name, torch.nn.Parameter(new_param)) - full_k = state_prefix + k + for name, param in model.named_buffers(recurse=False): if rank == 0: - objs = [full_k in state_dict] + if device != param.device: + new_param = param.to(device) + model.register_buffer(name, new_param) else: - objs = [None] - dist.broadcast_object_list(objs, 0) - in_state_dict = objs[0] - - if not in_state_dict: - continue - - param_names = [ - name for name, _ in model.named_parameters(recurse=False) - ] - if k in param_names: - if rank == 0: - new_param = torch.nn.Parameter( - state_dict[full_k].to(v.dtype), - requires_grad=False).to(device) - else: - new_param = torch.nn.Parameter(torch.empty_like( - v, device=device), - requires_grad=False) - model.register_parameter(k, new_param) - - # Weight, bias and scale are registered as buffer in QLinear - buffer_names = [ - name for name, _ in model.named_buffers(recurse=False) - ] - if k in buffer_names: - if rank == 0: - new_buffer = state_dict[full_k].to(v.dtype).to(device) - else: - new_buffer = torch.empty_like(v, device=device) - model.register_buffer(k, new_buffer) - - def _check_need_dist(model): - """check need dist.""" - need_dist = not getattr(model, '__tp_distributed__', False) - finish_param_init = all(not v.is_meta - for v in model.state_dict().values()) - return need_dist and finish_param_init - - _recursive_children() - _init_parameters() - - # distribute module - if world_size > 1 and _check_need_dist(model): - model.__tp_distributed__ = True + new_param = torch.empty_like(param, device=device) + model.register_buffer(name, new_param) + def _dist_params(): + """dist params.""" if hasattr(model, '_distribute_partition_fn'): partition_module( model, @@ -275,6 +192,8 @@ def _check_need_dist(model): else: replicate_module(model, device_mesh=device_mesh) + def _register_hooks(): + """register hooks.""" if hasattr(model, '_distribute_input_fn'): input_fn = model._distribute_input_fn model.register_forward_pre_hook( @@ -288,15 +207,43 @@ def _check_need_dist(model): model.register_forward_hook( lambda mod, inputs, outputs: output_fn(outputs, device_mesh)) + for name, child in model.named_children(): + new_child = _dist_model(child, rank, device_mesh) + if new_child != child: + model.register_module(name, child) + + _init_params() + _dist_params() + _register_hooks() + return model +class PatchedForward: + """patched forward.""" + + def __init__(self, model, context, extra_args): + self._model = model + self._patch_context: Dict = context + self._extra_args: list = extra_args + + def __call__(self, *args, **kwargs): + for arg_name in self._extra_args: + extra_arg = kwargs.pop(arg_name, None) + self._patch_context[arg_name] = extra_arg + + output = self._model(*args, **kwargs) + + self._patch_context.clear() + + return output + + def patch( model: torch.nn.Module, extra_args: Sequence[str] = None, rank: int = 0, world_size: int = 1, - checkpoints: Sequence[str] = None, ): """Patch the model with rewrite modules. @@ -308,33 +255,11 @@ def patch( extra_args (Sequence[str]): Extra arguments of model forward. rank (int): Distribution rank. 
world_size (int): Distribution world size. - checkpoints (Sequence[str]): checkpoints of the model. Returns: Module: The patched model. """ - def _load_checkpoints(model, checkpoints, rank, world_size): - """load checkpoints.""" - _params_to_meta(model) - device_mesh = DeviceMesh('cuda', list(range(world_size))) - for ckpt in checkpoints: - if rank == 0: - logger = get_logger('lmdeploy') - logger.info(f'loading checkpoint from: {ckpt}') - state_dict = torch.load(ckpt, map_location=f'cuda:{rank}') - else: - state_dict = None - - with torch.cuda.device(rank): - _load_state_dict( - model, - state_dict, - rank=rank, - world_size=world_size, - device_mesh=device_mesh, - ) - if extra_args is None: extra_args = [] @@ -342,37 +267,17 @@ def _load_checkpoints(model, checkpoints, rank, world_size): model = _patch(model, _patch_context) - # load checkpoint - if checkpoints is not None: - _load_checkpoints(model, checkpoints, rank, world_size) + if world_size > 1: + if rank == 0: + logger.info('distribute model parameters.') + device_mesh = DeviceMesh('cuda', list(range(world_size))) + model = _dist_model(model, rank, device_mesh=device_mesh) _update_model(model) - extra_args_str = ' '.join(f'{arg}=None,' for arg in extra_args) - context_update_str = ' '.join(f'{arg}={arg},' for arg in extra_args) - - wrap_forward_src = f""" -from functools import wraps -# old_forward = model.forward -old_forward = type(model).forward -@wraps(old_forward) -def wrap_forward(self, *args, {extra_args_str} **kwargs): - global _patch_context - _patch_context.update({context_update_str}) - - output = old_forward(self, *args, **kwargs) - - _patch_context.clear() - - return output -# model.forward = wrap_forward - -attrs = dict(type(model).__dict__) -attrs.update(dict(forward=wrap_forward)) -class_name = model.__class__.__name__ -new_type = type(class_name, (type(model), ), attrs) -model.__class__ = new_type -""" - exec(wrap_forward_src, dict(_patch_context=_patch_context, model=model)) + patched_forward = PatchedForward(model, + _patch_context, + extra_args=extra_args) + model.patched_forward = patched_forward return model diff --git a/lmdeploy/pytorch/models/peft.py b/lmdeploy/pytorch/models/peft.py new file mode 100644 index 0000000000..fc092df504 --- /dev/null +++ b/lmdeploy/pytorch/models/peft.py @@ -0,0 +1,273 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
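+# LoRALinear rewrite used when patching peft's `lora.Linear` layers
+# (registered in models/module_map.py). Adapter A/B weights are read from
+# the paged k/v cache blocks of the current layer; the low-rank deltas are
+# computed with the mbgmm kernels during prefill and the mbgmv kernels
+# during decoding, then added to the base layer output. Tensor-parallel
+# forwards are dispatched through `_tp_mode` ('rowwise' or 'colwise').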
+from dataclasses import dataclass + +import torch +import torch.distributed as dist + +from ..kernels.mbgmm import mbgmm_a, mbgmm_b +from ..kernels.mbgmv import mbgmv_a, mbgmv_b +from ..kernels.rearange_all_gather import rearange_all_gather + + +@dataclass +class PackedLoRAInput: + x: torch.Tensor + a_cache: torch.Tensor + b_cache: torch.Tensor + b_start_loc: torch.Tensor + b_seq_lens: torch.Tensor + b_adapter_ids: torch.Tensor + rank_page_table: torch.Tensor + rank_page_start: torch.Tensor + ranks: torch.Tensor + max_seq_len: int + max_rank: int + is_decoding: bool + + +class LoRALinear(torch.nn.Module): + + def _make_packed_lora_input(self, x): + context = self.context.context + + # adapter cache + global_adapter_ids = context.global_adapter_ids + layer_idx = self.layer_idx + ranks = self.ranks[global_adapter_ids] + block_starts = self.block_starts[global_adapter_ids] + k_cache, v_cache = context.kv_caches[layer_idx] + cache_len = k_cache.size(0) + a_cache = k_cache.view(cache_len, -1) + b_cache = v_cache.view(cache_len, -1) + + return PackedLoRAInput(x=x.flatten(0, -2).contiguous(), + a_cache=a_cache, + b_cache=b_cache, + b_start_loc=context.q_start_loc, + b_seq_lens=context.seq_length, + b_adapter_ids=context.local_adapter_ids, + rank_page_table=context.adapter_offsets, + rank_page_start=block_starts, + ranks=ranks, + max_seq_len=context.max_seq_length, + max_rank=context.max_rank, + is_decoding=context.is_decoding) + + def _lora_forward_local(self, x): + """lora forward no tp.""" + + lora_input = self._make_packed_lora_input(x) + + out_size = self.base_layer.weight.size(0) + if not lora_input.is_decoding: + xa = mbgmm_a(lora_input.x, + lora_input.a_cache, + b_start_loc=lora_input.b_start_loc, + b_seq_lens=lora_input.b_seq_lens, + b_adapter_ids=lora_input.b_adapter_ids, + rank_page_table=lora_input.rank_page_table, + rank_page_start=lora_input.rank_page_start, + ranks=lora_input.ranks, + max_seq_len=lora_input.max_seq_len, + max_rank=lora_input.max_rank) + lora_out = mbgmm_b(xa, + lora_input.b_cache, + b_start_loc=lora_input.b_start_loc, + b_seq_lens=lora_input.b_seq_lens, + b_adapter_ids=lora_input.b_adapter_ids, + rank_page_table=lora_input.rank_page_table, + rank_page_start=lora_input.rank_page_start, + ranks=lora_input.ranks, + max_seq_len=lora_input.max_seq_len, + max_rank=lora_input.max_rank, + out_size=out_size) + else: + xa = mbgmv_a(lora_input.x, + lora_input.a_cache, + b_adapter_ids=lora_input.b_adapter_ids, + rank_page_table=lora_input.rank_page_table, + rank_page_start=lora_input.rank_page_start, + ranks=lora_input.ranks, + max_rank=lora_input.max_rank) + lora_out = mbgmv_b(xa, + lora_input.b_cache, + b_adapter_ids=lora_input.b_adapter_ids, + rank_page_table=lora_input.rank_page_table, + rank_page_start=lora_input.rank_page_start, + ranks=lora_input.ranks, + max_rank=lora_input.max_rank, + out_size=out_size) + + base_out = self.base_layer(x) + lora_out = lora_out.reshape(base_out.shape) + output = base_out + lora_out + + return output + + def _lora_forward_tp_rowwise(self, x): + """lora forward tp rowwise.""" + + lora_input = self._make_packed_lora_input(x) + rank = dist.get_world_size() + world_size = dist.get_world_size() + out_size = self.base_layer.weight.size(0) // world_size + if not lora_input.is_decoding: + xa = mbgmm_a(lora_input.x, + lora_input.a_cache, + b_start_loc=lora_input.b_start_loc, + b_seq_lens=lora_input.b_seq_lens, + b_adapter_ids=lora_input.b_adapter_ids, + rank_page_table=lora_input.rank_page_table, + rank_page_start=lora_input.rank_page_start, + 
ranks=lora_input.ranks, + max_seq_len=lora_input.max_seq_len, + max_rank=lora_input.max_rank) + lora_out = mbgmm_b(xa, + lora_input.b_cache, + b_start_loc=lora_input.b_start_loc, + b_seq_lens=lora_input.b_seq_lens, + b_adapter_ids=lora_input.b_adapter_ids, + rank_page_table=lora_input.rank_page_table, + rank_page_start=lora_input.rank_page_start, + ranks=lora_input.ranks, + max_seq_len=lora_input.max_seq_len, + max_rank=lora_input.max_rank, + out_size=out_size) + else: + xa = mbgmv_a(lora_input.x, + lora_input.a_cache, + b_adapter_ids=lora_input.b_adapter_ids, + rank_page_table=lora_input.rank_page_table, + rank_page_start=lora_input.rank_page_start, + ranks=lora_input.ranks, + max_rank=lora_input.max_rank) + lora_out = mbgmv_b(xa, + lora_input.b_cache, + b_adapter_ids=lora_input.b_adapter_ids, + rank_page_table=lora_input.rank_page_table, + rank_page_start=lora_input.rank_page_start, + ranks=lora_input.ranks, + max_rank=lora_input.max_rank, + out_size=out_size) + + base_out = self.base_layer(x) + out_shape = base_out.shape + base_out = base_out.flatten(0, -2) + + slice_start = rank * out_size + slice_end = slice_start + out_size + base_out[:, slice_start:slice_end] += lora_out + base_out = base_out.reshape(out_shape) + + return base_out + + def _lora_forward_tp_colwise(self, x): + """lora forward tp colwise.""" + + def __gather_xa(xa): + """gather xa.""" + gathered_xa = xa.new_empty(world_size, xa.size(0), xa.size(1)) + dist.all_gather_into_tensor(gathered_xa, xa) + # TODO: gather would failed when adapters have different ranks. + gathered_xa = gathered_xa.permute(1, 0, 2).flatten(-2, -1) + return gathered_xa + + lora_input = self._make_packed_lora_input(x) + world_size = dist.get_world_size() + out_size = self.base_layer.weight.size(0) + if not lora_input.is_decoding: + xa = mbgmm_a(lora_input.x, + lora_input.a_cache, + b_start_loc=lora_input.b_start_loc, + b_seq_lens=lora_input.b_seq_lens, + b_adapter_ids=lora_input.b_adapter_ids, + rank_page_table=lora_input.rank_page_table, + rank_page_start=lora_input.rank_page_start, + ranks=lora_input.ranks, + max_seq_len=lora_input.max_seq_len, + max_rank=lora_input.max_rank, + rank_step=world_size) + gathered_xa = __gather_xa(xa) + if len(lora_input.ranks) > 1: + gathered_xa = rearange_all_gather( + gathered_xa, + b_start_loc=lora_input.b_start_loc, + b_seq_lens=lora_input.b_seq_lens, + adapter_ids=lora_input.b_adapter_ids, + ranks=lora_input.ranks, + world_size=world_size, + max_seq_len=lora_input.max_seq_len, + output=gathered_xa) + lora_out = mbgmm_b(gathered_xa, + lora_input.b_cache, + b_start_loc=lora_input.b_start_loc, + b_seq_lens=lora_input.b_seq_lens, + b_adapter_ids=lora_input.b_adapter_ids, + rank_page_table=lora_input.rank_page_table, + rank_page_start=lora_input.rank_page_start, + ranks=lora_input.ranks, + max_seq_len=lora_input.max_seq_len, + max_rank=lora_input.max_rank, + out_size=out_size) + else: + xa = mbgmv_a(lora_input.x, + lora_input.a_cache, + b_adapter_ids=lora_input.b_adapter_ids, + rank_page_table=lora_input.rank_page_table, + rank_page_start=lora_input.rank_page_start, + ranks=lora_input.ranks, + max_rank=lora_input.max_rank, + rank_step=world_size) + gathered_xa = __gather_xa(xa) + if len(lora_input.ranks) > 1: + gathered_xa = rearange_all_gather( + gathered_xa, + b_start_loc=lora_input.b_start_loc, + b_seq_lens=lora_input.b_seq_lens, + adapter_ids=lora_input.b_adapter_ids, + ranks=lora_input.ranks, + world_size=world_size, + max_seq_len=lora_input.max_seq_len, + output=gathered_xa) + lora_out = mbgmv_b(gathered_xa, + 
lora_input.b_cache, + b_adapter_ids=lora_input.b_adapter_ids, + rank_page_table=lora_input.rank_page_table, + rank_page_start=lora_input.rank_page_start, + ranks=lora_input.ranks, + max_rank=lora_input.max_rank, + out_size=out_size) + + base_out = self.base_layer(x) + lora_out = lora_out.reshape(base_out.shape) + output = base_out + lora_out + + return output + + def _lora_forward_tp(self, x): + """lora forward tp.""" + tp_mode = getattr(self, '_tp_mode', None) + if tp_mode == 'rowwise': + return self._lora_forward_tp_rowwise(x) + elif tp_mode == 'colwise': + return self._lora_forward_tp_colwise(x) + else: + assert tp_mode is None, 'tp_mode == None failed.' + return self._lora_forward_tp(x) + + def _lora_forward(self, x): + """lora forward.""" + if dist.is_initialized(): + return self._lora_forward_tp(x) + else: + return self._lora_forward_local(x) + + def forward(self, x): + """forward.""" + context = self.context.context + max_rank = context.max_rank + + if max_rank == 0: + return self.origin_mod.forward(x) + else: + return self._lora_forward(x) diff --git a/lmdeploy/pytorch/paging/block_manager.py b/lmdeploy/pytorch/paging/block_manager.py index 15fb39d643..555b7e00ac 100644 --- a/lmdeploy/pytorch/paging/block_manager.py +++ b/lmdeploy/pytorch/paging/block_manager.py @@ -1,9 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. # modify from: https://github.com/vllm-project/vllm -from typing import Dict +from typing import Dict, Union import numpy as np +from ..adapter.adapter import ADAPTER_MANAGER, SchedulerAdapter from ..messages import SchedulerSequence @@ -17,6 +18,9 @@ def __init__(self, num_blocks: int) -> None: def get_physical_blocks(self, logical_address: np.ndarray): """get physical address.""" + if isinstance(logical_address, + np.ndarray) and len(logical_address) == 0: + return np.empty((0, ), dtype=np.int64) return self.phy_map[logical_address] def num_blocks(self): @@ -233,7 +237,7 @@ def __init__(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: self.block_tables: Dict[int, BlockTable] = {} - def get_block_table(self, msg: SchedulerSequence): + def get_block_table(self, msg: Union[SchedulerSequence, SchedulerAdapter]): """Get the block table of given msg. 
Args: @@ -247,9 +251,12 @@ def can_allocate(self, msg: SchedulerSequence): """Return if physical block can be allocated for given message.""" num_required_blocks = msg.num_required_blocks() num_free_phy = self.get_num_free_gpu_blocks() + if msg.adapter_name is not None: + adapter = ADAPTER_MANAGER.get_adapter(msg.adapter_name) + num_required_blocks += adapter.num_required_blocks() return num_required_blocks <= num_free_phy - def allocate(self, msg: SchedulerSequence): + def allocate_msg(self, msg: SchedulerSequence): """Allocate physical blocks for given message according to logical blocks.""" logical_blocks = msg.logical_blocks @@ -261,6 +268,22 @@ def allocate(self, msg: SchedulerSequence): logical_blocks.append(blocks) logical_blocks.add_tokens(num_required_tokens) + def allocate_adapter(self, adapter: SchedulerAdapter): + """Allocate cpu blocks for given adapter.""" + num_required_blocks = adapter.num_required_blocks() + if num_required_blocks > 0: + blocks = self.allocator.allocate(num_required_blocks, 'cpu') + adapter.logical_blocks.append(blocks) + + def allocate(self, data: Union[SchedulerSequence, SchedulerAdapter]): + """allocate stuff.""" + if isinstance(data, SchedulerSequence): + return self.allocate_msg(data) + elif isinstance(data, SchedulerAdapter): + return self.allocate_adapter(data) + else: + raise TypeError(f'Unsupported allocate type: {type(data)}') + def free(self, msg: SchedulerSequence): """Free all physical blocks allocated for the session.""" self.allocator.free(msg.logical_blocks.get_real_blocks()) @@ -317,7 +340,7 @@ def _copy_lask_block(logical_blocks, copy_map): return copy_map - def try_swap_out(self, msg: SchedulerSequence): + def try_swap_out(self, msg: Union[SchedulerSequence, SchedulerAdapter]): """Try swap msg out.""" swap_map = dict() logical_blocks = msg.logical_blocks @@ -357,6 +380,8 @@ def _do_swap(): gpu_allocator.free(old_blocks) self.allocator.update_phy_map(logical_blocks.get_real_blocks(), new_blocks) + if isinstance(msg, SchedulerAdapter): + msg.active(False) return True, swap_map if not _can_swap(): @@ -364,7 +389,7 @@ def _do_swap(): else: return _do_swap() - def try_swap_in(self, msg: SchedulerSequence): + def try_swap_in(self, msg: Union[SchedulerSequence, SchedulerAdapter]): """Try swap msg in.""" swap_map = dict() logical_blocks = msg.logical_blocks @@ -404,6 +429,8 @@ def _do_swap(): cpu_allocator.free(old_blocks) self.allocator.update_phy_map(logical_blocks.get_real_blocks(), new_blocks) + if isinstance(msg, SchedulerAdapter): + msg.active(True) return True, swap_map if not _can_swap(): diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py index 8ee3e337d6..2ac81c0b84 100644 --- a/lmdeploy/pytorch/paging/scheduler.py +++ b/lmdeploy/pytorch/paging/scheduler.py @@ -2,10 +2,11 @@ # modify from: https://github.com/vllm-project/vllm from collections import OrderedDict from dataclasses import dataclass -from typing import Dict, List +from typing import Dict, List, Set, Union from lmdeploy.utils import get_logger +from ..adapter.adapter import ADAPTER_MANAGER, SchedulerAdapter from ..config import CacheConfig, SchedulerConfig from ..messages import MessageStatus, SchedulerSequence, SchedulerSession from .block_manager import BlockManager @@ -13,6 +14,7 @@ logger = get_logger('lmdeploy') SeqList = List[SchedulerSequence] +AdapterList = List[SchedulerAdapter] def _find_seq_with_session_id(group: SeqList, session_id: int): @@ -27,6 +29,7 @@ class SchedulerOutput: swap_in_map: Dict[int, int] swap_out_map: 
Dict[int, int] copy_map: Dict[int, int] + adapters: AdapterList class Scheduler: @@ -46,6 +49,7 @@ def __init__(self, scheduler_config: SchedulerConfig, self.running: SeqList = [] self.hanging: SeqList = [] self.sessions: Dict[int, SchedulerSession] = OrderedDict() + self.actived_adapters: Set[str] = set() self.block_manager = BlockManager( cache_config.num_gpu_blocks, @@ -103,6 +107,20 @@ def add_sequence(self, seq: SchedulerSequence): self.scheduler_config.max_request_output_len self.waiting.append(seq) + def add_adapter(self, adapter_path: str, adapter_name: str): + """Add adapter. + + Args: + adapter_path (str): The path of adapter. + adapter_name (str): The name of the adapter. + """ + adapter = ADAPTER_MANAGER.add_adapter_from_pretrained( + adapter_path, adapter_name=adapter_name) + self.block_manager.allocate_adapter(adapter) + block_table = self.block_manager.get_block_table( + adapter) - self.block_manager.num_gpu_blocks + return adapter.build_weight_map(block_table) + def _schedule_prefill(self): """Schedule for prefilling.""" @@ -113,6 +131,9 @@ def _schedule_prefill(self): swap_in_map: Dict[int, int] = dict() copy_map: Dict[int, int] = dict() running: SeqList = [] + required_adapters = set(seq.adapter_name for seq in self.running) + max_adapters = self.scheduler_config.max_active_adapters - len( + required_adapters) def _to_running(seq: SchedulerSequence): """to running.""" @@ -133,6 +154,29 @@ def _reorder_waiting(): self.waiting = sorted(self.waiting, key=lambda seq: seq.arrive_time) + def _active_adapter(adapter_name): + """active adapter of a seq.""" + if adapter_name is None: + required_adapters.add(adapter_name) + return + if adapter_name not in required_adapters: + adapter = ADAPTER_MANAGER.get_adapter(adapter_name) + if not adapter.is_actived(): + success, tmp_map = self.block_manager.try_swap_in(adapter) + assert success + swap_in_map.update(tmp_map) + required_adapters.add(adapter_name) + + def _deactive_adapter(adapter_name): + """deactive_adapter.""" + if adapter_name is None: + return + adapter = ADAPTER_MANAGER.get_adapter(adapter_name) + if adapter.is_actived(): + success, tmp_map = self.block_manager.try_swap_out(adapter) + assert success + swap_out_map.update(tmp_map) + if len(running) >= max_batches or len(self.waiting) == 0: return running, swap_in_map, swap_out_map, copy_map @@ -140,6 +184,11 @@ def _reorder_waiting(): while len(self.waiting) > 0 and len(running) < max_batches: seq = self.waiting[0] + # limit number of adapters + if len(required_adapters) >= max_adapters: + if seq.adapter_name not in required_adapters: + break + if not block_manager.can_allocate(seq): if not _evict_until_can_append(seq): break @@ -149,9 +198,16 @@ def _reorder_waiting(): break # allocate session memory block_manager.allocate(seq) + _active_adapter(seq.adapter_name) self.waiting.pop(0) _to_running(seq) + deactive_adapters = self.actived_adapters.difference(required_adapters) + for adapter_name in deactive_adapters: + _deactive_adapter(adapter_name) + + self.actived_adapters = required_adapters + self.running += running return running, swap_in_map, swap_out_map, copy_map @@ -215,6 +271,13 @@ def _evict_until_can_append(seq: SchedulerSequence): self.running = running return running, swap_in_map, swap_out_map, copy_map + @classmethod + def _get_adapter_list(cls, adapter_names: List[str]): + adapters = [ + ADAPTER_MANAGER.get_adapter(name) for name in adapter_names + ] + return adapters + def schedule(self, is_prefill: bool): """Schedule inputs for next steps.""" if is_prefill: 
@@ -223,12 +286,13 @@ def schedule(self, is_prefill: bool): output = self._schedule_decoding() running, swap_in_map, swap_out_map, copy_map = output - return SchedulerOutput( - running=running, - swap_in_map=swap_in_map, - swap_out_map=swap_out_map, - copy_map=copy_map, - ) + adapters = self._get_adapter_list(self.actived_adapters) + + return SchedulerOutput(running=running, + swap_in_map=swap_in_map, + swap_out_map=swap_out_map, + copy_map=copy_map, + adapters=adapters) def _set_session_status(self, session_id: int, status: MessageStatus): """Setup the status of session. @@ -324,6 +388,6 @@ def _update_queue(group: SeqList, expect_status: MessageStatus): for session_id in session_id_to_remove: self.sessions.pop(session_id) - def get_block_tables(self, seqs: SeqList): + def get_block_tables(self, seqs: Union[SeqList, AdapterList]): """get block table of the sequences.""" return [self.block_manager.get_block_table(seq) for seq in seqs] diff --git a/requirements/runtime.txt b/requirements/runtime.txt index e5d322f52f..26aae42080 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -2,6 +2,7 @@ fastapi fire mmengine-lite numpy +peft pydantic>2.0.0 pynvml safetensors diff --git a/tests/pytorch/kernel/test_mbgmm.py b/tests/pytorch/kernel/test_mbgmm.py index 54f820f86a..8151a9bc72 100644 --- a/tests/pytorch/kernel/test_mbgmm.py +++ b/tests/pytorch/kernel/test_mbgmm.py @@ -27,6 +27,10 @@ def seq_lens(self): def ranks(self): yield torch.tensor([2, 4]).cuda() + @pytest.fixture + def page_start(self, ranks): + yield torch.zeros_like(ranks) + @pytest.fixture def start_loc(self, seq_lens): yield seq_lens.cumsum(0) - seq_lens @@ -37,7 +41,7 @@ def input(self, seq_lens, head_size, dtype): yield torch.rand(total_len, head_size, dtype=dtype).cuda() @pytest.fixture - def rank_ids(self, seq_lens, ranks): + def adapter_ids(self, seq_lens, ranks): num_ranks = len(ranks) num_seqs = len(seq_lens) ret = torch.randint(0, num_ranks, (num_seqs, )).cuda() @@ -84,9 +88,9 @@ def paged_lora_b(self, lora_b, ranks, page_table, head_size, out_head_size, yield cache @pytest.fixture - def gt(self, input, start_loc, seq_lens, rank_ids, lora_a, lora_b): + def gt(self, input, start_loc, seq_lens, adapter_ids, lora_a, lora_b): out = [] - for loc, s_len, r_id in zip(start_loc, seq_lens, rank_ids): + for loc, s_len, r_id in zip(start_loc, seq_lens, adapter_ids): inp = input[loc:loc + s_len] l_a = lora_a[r_id] l_b = lora_b[r_id] @@ -95,7 +99,8 @@ def gt(self, input, start_loc, seq_lens, rank_ids, lora_a, lora_b): yield torch.cat(out) def test_mbgmm(self, input, paged_lora_a, paged_lora_b, out_head_size, - start_loc, seq_lens, rank_ids, page_table, ranks, gt): + start_loc, seq_lens, adapter_ids, page_table, ranks, + page_start, gt): max_seq_len = max(seq_lens).item() max_rank = page_table.size(-1) @@ -103,8 +108,9 @@ def test_mbgmm(self, input, paged_lora_a, paged_lora_b, out_head_size, paged_lora_a, b_start_loc=start_loc, b_seq_lens=seq_lens, - b_rank_ids=rank_ids, + b_adapter_ids=adapter_ids, rank_page_table=page_table, + rank_page_start=page_start, ranks=ranks, max_seq_len=max_seq_len, max_rank=max_rank) @@ -113,8 +119,9 @@ def test_mbgmm(self, input, paged_lora_a, paged_lora_b, out_head_size, paged_lora_b[..., :out_head_size], b_start_loc=start_loc, b_seq_lens=seq_lens, - b_rank_ids=rank_ids, + b_adapter_ids=adapter_ids, rank_page_table=page_table, + rank_page_start=page_start, ranks=ranks, max_seq_len=max_seq_len, max_rank=max_rank) diff --git a/tests/pytorch/kernel/test_mbgmv.py 
b/tests/pytorch/kernel/test_mbgmv.py index 08866ba294..95e042a3fc 100644 --- a/tests/pytorch/kernel/test_mbgmv.py +++ b/tests/pytorch/kernel/test_mbgmv.py @@ -27,6 +27,10 @@ def batch_size(self): def ranks(self): yield torch.tensor([2, 4]).cuda() + @pytest.fixture + def page_start(self, ranks): + yield torch.zeros_like(ranks) + @pytest.fixture def input(self, batch_size, head_size, dtype): x = torch.rand(batch_size, head_size, dtype=dtype).cuda() @@ -34,7 +38,7 @@ def input(self, batch_size, head_size, dtype): yield x @pytest.fixture - def rank_ids(self, batch_size, ranks): + def adapter_ids(self, batch_size, ranks): num_ranks = len(ranks) ret = torch.randint(0, num_ranks, (batch_size, )).cuda() yield ret @@ -82,9 +86,9 @@ def paged_lora_b(self, lora_b, ranks, page_table, head_size, out_head_size, yield cache @pytest.fixture - def gt(self, input, rank_ids, lora_a, lora_b): + def gt(self, input, adapter_ids, lora_a, lora_b): out = [] - for inp, r_id in zip(input, rank_ids): + for inp, r_id in zip(input, adapter_ids): inp = inp.unsqueeze(0) l_a = lora_a[r_id] l_b = lora_b[r_id] @@ -93,20 +97,22 @@ def gt(self, input, rank_ids, lora_a, lora_b): yield torch.cat(out) def test_mbgmv(self, input, paged_lora_a, paged_lora_b, out_head_size, - rank_ids, page_table, ranks, gt): + adapter_ids, page_table, ranks, page_start, gt): max_rank = page_table.size(-1) xa = mbgmv_a(input, paged_lora_a, - b_rank_ids=rank_ids, + b_adapter_ids=adapter_ids, rank_page_table=page_table, + rank_page_start=page_start, ranks=ranks, max_rank=max_rank) output = mbgmv_b(xa, paged_lora_b[..., :out_head_size], - b_rank_ids=rank_ids, + b_adapter_ids=adapter_ids, rank_page_table=page_table, + rank_page_start=page_start, ranks=ranks, max_rank=max_rank) torch.testing.assert_close(gt, output, atol=1e-3, rtol=1e-5) diff --git a/tests/pytorch/kernel/test_paged_attention.py b/tests/pytorch/kernel/test_paged_attention.py index d943581a7b..e669338d9b 100644 --- a/tests/pytorch/kernel/test_paged_attention.py +++ b/tests/pytorch/kernel/test_paged_attention.py @@ -192,7 +192,7 @@ def conti_gt(self, gt, seq_lens): [([30, 50, 70, 90], [50, 40, 30, 20]), ([1, 1, 1, 1], [50, 40, 30, 20])], indirect=True) - @pytest.mark.parametrize('block_size', [1, 16], indirect=True) + @pytest.mark.parametrize('block_size', [2, 16], indirect=True) def test_paged_attention(self, conti_q, blocked_kv, block_offsets, start_loc, seq_lens, history_lens, block_size, conti_gt): @@ -212,7 +212,7 @@ def test_paged_attention(self, conti_q, blocked_kv, block_offsets, b_seq_len=seq_lens, b_kv_seq_len=kv_seq_lens, max_input_len=max_seq_len) - torch.testing.assert_close(out, conti_gt, atol=5e-4, rtol=1e-5) + torch.testing.assert_close(out, conti_gt, atol=1e-3, rtol=1e-5) @pytest.mark.parametrize(['num_heads_q', 'num_heads_k'], [(4, 2)], indirect=True) diff --git a/tests/pytorch/kernel/test_rearange_all_gather.py b/tests/pytorch/kernel/test_rearange_all_gather.py new file mode 100644 index 0000000000..643c425fca --- /dev/null +++ b/tests/pytorch/kernel/test_rearange_all_gather.py @@ -0,0 +1,83 @@ +import pytest +import torch + +from lmdeploy.pytorch.kernels.rearange_all_gather import rearange_all_gather + + +class TestRearangeAllGather: + + @pytest.fixture + def seq_lens(self, request): + yield torch.tensor(request.param, device='cuda') + + @pytest.fixture + def start_loc(self, seq_lens): + yield seq_lens.cumsum(0) - seq_lens + + @pytest.fixture + def ranks(self): + yield torch.tensor([4, 8]).cuda() + + @pytest.fixture + def adapter_ids(self, seq_lens, ranks): + num_ranks 
= len(ranks) + num_seqs = len(seq_lens) + ret = torch.randint(0, num_ranks, (num_seqs, )).cuda() + yield ret + + @pytest.fixture + def world_size(self): + yield 2 + + @pytest.fixture + def input(self, seq_lens, ranks): + max_rank = max(ranks) + total_len = seq_lens.sum() + yield torch.rand(total_len, max_rank).cuda() + + @pytest.fixture + def rank_per_input(self, seq_lens, ranks, adapter_ids): + token_adapter_ids = [ + torch.full((slen, ), ada_id) + for slen, ada_id in zip(seq_lens, adapter_ids) + ] + token_adapter_ids = torch.cat(token_adapter_ids).cuda() + yield ranks[token_adapter_ids] + + @pytest.fixture + def valid_mask(self, rank_per_input, seq_lens, ranks): + max_rank = max(ranks) + total_len = seq_lens.sum() + mask = torch.zeros(total_len, max_rank).to(bool) + for r, m in zip(rank_per_input, mask): + m[:r] = True + yield mask.cuda() + + @pytest.fixture + def gt(self, input, rank_per_input, ranks, world_size): + max_rank = max(ranks) + pranks = rank_per_input // world_size + pmax_rank = max_rank // world_size + output = torch.empty_like(input) + for pr, inp, out in zip(pranks, input, output): + pindex = torch.arange(pr).cuda() + index = [pindex + ws * pmax_rank for ws in range(world_size)] + index = torch.cat(index) + out[:index.size(0)] = inp[index] + yield output + + @pytest.mark.parametrize('seq_lens', [[30, 50, 70, 90], [1, 1, 1, 1]], + indirect=True) + def test_gather(self, input, start_loc, seq_lens, adapter_ids, ranks, + world_size, gt, valid_mask): + max_seq_len = max(seq_lens) + output = rearange_all_gather(input, + start_loc, + seq_lens, + adapter_ids, + ranks, + world_size, + max_seq_len=max_seq_len) + output = output.where(valid_mask, output.new_tensor(0)) + gt = gt.where(valid_mask, gt.new_tensor(0)) + torch.testing.assert_close(output, gt)
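Editor's note: as a companion to the `gt` fixture above, the row reordering that `rearange_all_gather` is expected to produce can be written as a short pure-PyTorch reference. This is a sketch of the expected output layout only, not the Triton kernel; shapes in the usage check are assumed, and the helper name `gather_reference` is illustrative.

# --- illustrative reference, not part of the patch -------------------------
import torch


def gather_reference(x: torch.Tensor, rank_per_row: torch.Tensor,
                     max_rank: int, world_size: int) -> torch.Tensor:
    """Pack, for every row, the first rank/world_size columns of each
    tensor-parallel shard into a contiguous prefix of the output row,
    mirroring the `gt` fixture of the new test."""
    pmax_rank = max_rank // world_size            # columns owned by one shard
    out = torch.zeros_like(x)                     # the test masks the tail anyway
    for row, rank, dst in zip(x, rank_per_row.tolist(), out):
        prank = rank // world_size                # valid columns per shard
        index = torch.cat([
            torch.arange(prank, device=x.device) + ws * pmax_rank
            for ws in range(world_size)
        ])
        dst[:index.numel()] = row[index]
    return out


# Tiny usage check with assumed shapes: 3 rows, max_rank=8, world_size=2.
x = torch.arange(24, dtype=torch.float32).reshape(3, 8)
print(gather_reference(x, torch.tensor([4, 8, 4]), max_rank=8, world_size=2))
# ---------------------------------------------------------------------------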