From d895e516e506a11ec6ad34c6b07134f69a7ca333 Mon Sep 17 00:00:00 2001 From: grimoire Date: Fri, 13 Sep 2024 18:29:50 +0800 Subject: [PATCH] refactor lora tp1 --- lmdeploy/pytorch/adapter/adapter.py | 401 +++--------------- lmdeploy/pytorch/backends/base.py | 2 +- lmdeploy/pytorch/backends/cuda/lora.py | 80 ++++ lmdeploy/pytorch/backends/cuda/op_backend.py | 6 +- lmdeploy/pytorch/backends/cuda/slora.py | 224 ---------- .../pytorch/backends/{slora.py => lora.py} | 26 +- lmdeploy/pytorch/engine/engine.py | 64 +-- lmdeploy/pytorch/engine/model_agent.py | 151 ++----- lmdeploy/pytorch/kernels/__init__.py | 8 - lmdeploy/pytorch/kernels/cuda/__init__.py | 8 - lmdeploy/pytorch/kernels/cuda/fused_lora.py | 180 ++++++++ lmdeploy/pytorch/kernels/cuda/mbgmm.py | 276 ------------ lmdeploy/pytorch/kernels/cuda/mbgmv.py | 223 ---------- .../kernels/cuda/rearange_all_gather.py | 130 ------ lmdeploy/pytorch/kernels/mbgmm.py | 5 - lmdeploy/pytorch/kernels/mbgmv.py | 6 - .../pytorch/kernels/rearange_all_gather.py | 4 - lmdeploy/pytorch/model_inputs.py | 44 -- lmdeploy/pytorch/models/patch.py | 78 ++-- lmdeploy/pytorch/nn/linear.py | 68 ++- .../pytorch/paging/block_manager/__init__.py | 11 +- .../block_manager/base_block_manager.py | 30 +- .../block_manager/default_block_manager.py | 35 +- .../block_manager/window_block_manager.py | 29 +- lmdeploy/pytorch/paging/scheduler.py | 79 +--- tests/pytorch/kernel/test_fused_lora.py | 108 +++++ tests/pytorch/kernel/test_mbgmm.py | 134 ------ tests/pytorch/kernel/test_mbgmv.py | 122 ------ .../kernel/test_rearange_all_gather.py | 83 ---- 29 files changed, 605 insertions(+), 2010 deletions(-) create mode 100644 lmdeploy/pytorch/backends/cuda/lora.py delete mode 100644 lmdeploy/pytorch/backends/cuda/slora.py rename lmdeploy/pytorch/backends/{slora.py => lora.py} (58%) create mode 100644 lmdeploy/pytorch/kernels/cuda/fused_lora.py delete mode 100644 lmdeploy/pytorch/kernels/cuda/mbgmm.py delete mode 100644 lmdeploy/pytorch/kernels/cuda/mbgmv.py delete mode 100644 lmdeploy/pytorch/kernels/cuda/rearange_all_gather.py delete mode 100644 lmdeploy/pytorch/kernels/mbgmm.py delete mode 100644 lmdeploy/pytorch/kernels/mbgmv.py delete mode 100644 lmdeploy/pytorch/kernels/rearange_all_gather.py create mode 100644 tests/pytorch/kernel/test_fused_lora.py delete mode 100644 tests/pytorch/kernel/test_mbgmm.py delete mode 100644 tests/pytorch/kernel/test_mbgmv.py delete mode 100644 tests/pytorch/kernel/test_rearange_all_gather.py diff --git a/lmdeploy/pytorch/adapter/adapter.py b/lmdeploy/pytorch/adapter/adapter.py index 3e39cea6bb..2b9387eed2 100644 --- a/lmdeploy/pytorch/adapter/adapter.py +++ b/lmdeploy/pytorch/adapter/adapter.py @@ -1,20 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import re -from dataclasses import dataclass, field -from typing import Any, Dict, Iterable, List +from typing import Dict, Iterable, List, Tuple -import numpy as np import torch import torch.distributed as dist -from torch import Tensor - -from ..block import LogicalTokenBlocks - - -def _div_up(a, b): - """div up.""" - return (a + b - 1) // b +from torch import nn def get_ranks_and_scalings(target_name: str, @@ -57,47 +48,6 @@ def find_all_target(model: torch.nn.Module, target_name: str): return found_mods, pack_idx -def get_max_ranks_per_block(block_numel: int, rank_stride: int): - assert block_numel >= rank_stride, ( - 'LoRA Adapter requires larger block_size.') - return block_numel // rank_stride - - -def get_ranks_per_block(block_numel: int, rank_stride: int, rank: int): - """ranks per blocks.""" - max_ranks_per_block = get_max_ranks_per_block(block_numel, rank_stride) - return min(rank, max_ranks_per_block) - - -def get_num_required_blocks(block_numel: int, rank_stride: int, rank: int): - """get num required blocks.""" - ranks_per_block = get_ranks_per_block(block_numel, rank_stride, rank) - if rank == 0: - return 0 - return _div_up(rank, ranks_per_block) - - -def get_inblock_offset(block_numel: int, rank_stride: int, rank: int): - """in block offset.""" - ranks_per_block = get_ranks_per_block(block_numel, rank_stride, rank) - num_required_blocks = get_num_required_blocks(block_numel, rank_stride, - rank) - ret = np.arange(ranks_per_block) * rank_stride - ret = ret.repeat(num_required_blocks)[:rank] - return ret - - -def get_block_idx_per_rank(block_numel: int, rank_stride: int, rank: int): - """out block idx.""" - ranks_per_block = get_ranks_per_block(block_numel, rank_stride, rank) - num_required_blocks = get_num_required_blocks(block_numel, rank_stride, - rank) - ret = np.arange(num_required_blocks) - ret = ret[:, None].repeat(ranks_per_block, 1) - ret = ret.flatten()[:rank] - return ret - - def get_layer_index(key: str, layers_pattern: str = None): """get layer index of the lora linear.""" if isinstance(layers_pattern, str): @@ -113,19 +63,6 @@ def get_layer_index(key: str, layers_pattern: str = None): return int(layer_index[1]) -@dataclass -class LoRATargetInfo: - """lora linear info.""" - in_features: int - out_features: int - colwise: bool - rank_stride: int = field(default=0, init=False) - - def __post_init__(self): - """post init.""" - self.rank_stride = max(self.in_features, self.out_features) - - def _get_rank_and_world(): """get rank and world size.""" rank = 0 @@ -137,298 +74,72 @@ def _get_rank_and_world(): return rank, world_size -@dataclass -class AdapterWeightMap: - adapter_name: str - path: str - rank: List[int] - rank_offset: np.ndarray - max_rank: int - target_modules: List[str] - colwise: List[bool] - - @staticmethod - def _get_weight(weight: torch.Tensor, is_lora_a: bool, is_col: bool, - rank: int, world_size: int): - """get sliced weight.""" - if world_size == 1: - return weight - - if not is_col and is_lora_a: - # rowwise - weight = weight.chunk(world_size, dim=1)[rank] - else: - # colwise - weight = weight.chunk(world_size, dim=0)[rank] - return weight - - @staticmethod - def _fill_a_cache(weight: torch.Tensor, cache: torch.Tensor, - rank_off: torch.Tensor): - """fill a cache.""" - num_ranks, feat_size = weight.shape - - for rank in range(num_ranks): - off = rank_off[rank] - cache[off:off + feat_size].copy_(weight[rank]) - - @staticmethod - def _fill_b_cache(weight: torch.Tensor, cache: torch.Tensor, - rank_off: torch.Tensor): - """fill a cache.""" - 
feat_size, num_ranks = weight.shape - - for rank in range(num_ranks): - off = rank_off[rank] - cache[off:off + feat_size].copy_(weight[:, rank]) - - def cache_adapter(self, caches: List[List[Tensor]]): - """cache all linear.""" - if self.path is None: - return - checkpoint_path = f'{self.path}/adapter_model.bin' - state_dict = torch.load(checkpoint_path, map_location='cpu') - - dist_rank, world_size = _get_rank_and_world() - - target_modules = self.target_modules - target_map = dict( - (name, idx) for idx, name in enumerate(target_modules)) - num_targets = len(target_modules) - rank_offset = self.rank_offset.view(num_targets, -1) - for key, weight in state_dict.items(): - layer_idx = get_layer_index(key, None) - a_cache, b_cache = caches[layer_idx] - a_cache = a_cache.view(-1) - b_cache = b_cache.view(-1) - - split_key = key.split('.') - assert split_key[-1] == 'weight' - target_name = split_key[-3] - if split_key[-2] == 'lora_A': - is_lora_a = True - elif split_key[-2] == 'lora_B': - is_lora_a = False - else: - raise RuntimeError(f'Unexpected key: {key}') - - target_id = target_map[target_name] - rank_off = rank_offset[target_id] - is_col = self.colwise[target_id] - weight = self._get_weight(weight, - is_lora_a, - is_col, - rank=dist_rank, - world_size=world_size) - if is_lora_a: - self._fill_a_cache(weight, a_cache, rank_off) - else: - self._fill_b_cache(weight, b_cache, rank_off) - - -@dataclass -class SchedulerAdapter: - """lora adapter.""" - - adapter_id: int - adapter_name: str - rank: List[int] - scaling: List[int] - target_modules: List[str] - target_infos: List[LoRATargetInfo] - logical_blocks: LogicalTokenBlocks - inblock_offset: np.ndarray - block_idx_per_rank: np.ndarray - adapter_path: str = None - block_stride: int = 0 - max_rank: int = 0 - num_required_blocks: int = 0 - rank_offset: np.ndarray = field(default=None, init=False) - _active: bool = field(default=False, init=False) - - @classmethod - def new(cls, adapter_id: int, adapter_name: str, adapter_path: str, - adapter_cfg: Any, target_infos: Dict[str, LoRATargetInfo], - block_numel: int, max_rank: int): - """new.""" - - target_modules = list(target_infos.keys()) - - rank = [] - scaling = [] - inblock_offset = [np.empty((0, ), dtype=np.int64)] - block_idx_per_rank = [np.empty((0, ), dtype=np.int64)] - num_required_blocks = 0 - for target_name in target_modules: - - # get rank and scaling - r = 0 - s = 1.0 - if target_name in adapter_cfg.target_modules: - r = adapter_cfg.r - if r != 0: - s = adapter_cfg.lora_alpha / r - rank.append(r) - scaling.append(s) - - info = target_infos[target_name] - rank_stride = info.rank_stride - ib_offset = get_inblock_offset(block_numel, rank_stride, r) - pad_ib_offset = np.zeros((max_rank, ), dtype=np.int64) - pad_ib_offset[:ib_offset.shape[0]] = ib_offset - inblock_offset.append(pad_ib_offset) - bidx_p_rank = get_block_idx_per_rank(block_numel, rank_stride, - r) + num_required_blocks - pad_bidx_p_rank = np.zeros((max_rank, ), dtype=np.int64) - pad_bidx_p_rank[:bidx_p_rank.shape[0]] = bidx_p_rank - block_idx_per_rank.append(pad_bidx_p_rank) - num_required_blocks += get_num_required_blocks( - block_numel, rank_stride, r) - inblock_offset = np.concatenate(inblock_offset) - block_idx_per_rank = np.concatenate(block_idx_per_rank) - - ret = cls( - adapter_id=adapter_id, - adapter_name=adapter_name, - rank=rank, - scaling=scaling, - target_modules=target_modules, - target_infos=target_infos, - logical_blocks=LogicalTokenBlocks(), - inblock_offset=inblock_offset, - 
block_idx_per_rank=block_idx_per_rank, - adapter_path=adapter_path, - block_stride=block_numel, - max_rank=max_rank, - num_required_blocks=num_required_blocks, - ) - - return ret - - def update_rank_offset(self, phy_blocks: np.ndarray): - """update rank offset.""" - if len(phy_blocks) > 0: - rank_offset = phy_blocks[ - self.block_idx_per_rank] * self.block_stride - rank_offset += self.inblock_offset +def _get_reverse_pack_map(model: nn.Module): + """get reverse pack map.""" + packed_modules_mapping = getattr(model, 'packed_modules_mapping', dict()) + reverse_map = dict() + for pack_name, names in packed_modules_mapping.items(): + for name in names: + reverse_map[name] = pack_name + return reverse_map + + +def _get_key_map(reverse_map: Dict[str, str]): + """get key map.""" + key_map = dict() + for name, pack_name in reverse_map.items(): + key = f'.{name}' + val = f'.{pack_name}.lora_adapters.{name}' + key_map[key] = val + + return key_map + + +def load_lora_weights(model: nn.Module, weights: Iterable[Tuple[str, + torch.Tensor]], + adapter_id: int): + """load lora weights.""" + from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight + prefix_len = len('base_model.model.') + w_len = len('.weight') + reverse_map = _get_reverse_pack_map(model) + key_map = _get_key_map(reverse_map) + + params_dict = dict(model.named_parameters()) + for name, loaded_weight in weights: + name = name[prefix_len:] + splited_name = name.split('.') + assert splited_name[-1] == 'weight' + assert splited_name[-2] in ['lora_A', 'lora_B'] + mod_name = splited_name[-3] + dot_mod_name = f'.{mod_name}' + if dot_mod_name in key_map: + replace_name = key_map[dot_mod_name] else: - rank_offset = np.zeros_like(self.inblock_offset) - self.rank_offset = rank_offset - return rank_offset + replace_name = f'.{mod_name}.lora_adapters.{mod_name}' + name = name[:-w_len] + param_name = name.replace(dot_mod_name, replace_name) - def is_actived(self): - """check if adapter is active.""" - return self._active - - def active(self, flag: bool = True): - """active adapter.""" - self._active = flag - - @property - def name(self): - return self.adapter_name - - def build_weight_map(self): - """build weight map.""" - assert self.rank_offset is not None - colwise = [ - self.target_infos[name].colwise for name in self.target_modules - ] - return AdapterWeightMap( - adapter_name=self.name, - path=self.adapter_path, - rank=self.rank, - rank_offset=self.rank_offset, - max_rank=self.max_rank, - target_modules=self.target_modules, - colwise=colwise, - ) - - -class NoneLoraConfig: - - def __init__(self): - self.r = 0 - self.lora_alpha = 8 - self.target_modules = [] + param = params_dict[param_name] + load_weight(param, loaded_weight, adapter_id=adapter_id) class AdapterManager: """adapter manager.""" - def __init__(self, adapters: Dict[str, str], - target_infos: Dict[str, LoRATargetInfo], block_numel: int): - self.target_infos = target_infos - self.block_numel = block_numel + def __init__(self, adapters: Dict[str, str]): if adapters is None: adapters = dict() - self.adapter_paths = dict( - (name, path) for name, path in adapters.items()) - self.adapter_paths[None] = None - - self.adapter_cfgs = self._get_adapter_cfgs(adapters) - adapter_names = list(adapters.keys()) - self.adapter_id_map = dict( - (name, idx + 1) for idx, name in enumerate(adapter_names)) - self.adapter_id_map[None] = 0 - - self._adapters: Dict[str, SchedulerAdapter] = dict() - self.max_rank = self._get_max_rank() - self._add_non_adapter() - - @staticmethod - def 
_get_adapter_cfgs(adapters: Dict[str, str]): - """get adapter cfgs.""" - if len(adapters) == 0: - return {None: NoneLoraConfig()} - from peft import PeftConfig - adapter_cfgs = dict((name, PeftConfig.from_pretrained(path)) - for name, path in adapters.items()) - adapter_cfgs[None] = NoneLoraConfig() - return adapter_cfgs + adapter_names = sorted(adapter_names) + adapter_names = [None] + adapter_names - def _get_max_rank(self): - """get max rank.""" - max_rank = 0 - for cfg in self.adapter_cfgs.values(): - max_rank = max(max_rank, cfg.r) - return max_rank + adapter_id_map = dict(zip(adapter_names, range(len(adapter_names)))) + self.adapter_id_map = adapter_id_map - def _add_non_adapter(self): - """add non adapter.""" - adapter = self.add_adapter(None) - rank_offset = adapter.inblock_offset.copy() - adapter.update_rank_offset(rank_offset) - - def _register_adapter(self, adapter: SchedulerAdapter): - """register adapter.""" - assert adapter.adapter_name not in self._adapters - self._adapters[adapter.adapter_name] = adapter - return adapter - - def get_adapter(self, name: str, default=None): - """get adapter.""" - return self._adapters.get(name, default) + def get_adapter_ids(self, names: List[str]): + return [self.adapter_id_map[name] for name in names] def num_adapters(self): - """get num adapters.""" - return len(self._adapters) - - def add_adapter(self, adapter_name: str): - """add adapter.""" - adapter_id = self.adapter_id_map[adapter_name] - adapter_cfg = self.adapter_cfgs[adapter_name] - adapter_path = self.adapter_paths[adapter_name] - adapter = SchedulerAdapter.new( - adapter_id, - adapter_name, - adapter_path, - adapter_cfg, - self.target_infos, - self.block_numel, - max_rank=self.max_rank, - ) - self._register_adapter(adapter) - return adapter + return len(self.adapter_id_map) diff --git a/lmdeploy/pytorch/backends/base.py b/lmdeploy/pytorch/backends/base.py index db651996d3..ef538f7a3d 100644 --- a/lmdeploy/pytorch/backends/base.py +++ b/lmdeploy/pytorch/backends/base.py @@ -20,7 +20,7 @@ class OpType(Enum): GeluAndMul = auto() RMSNorm = auto() LayerNorm = auto() - SLoRA = auto() + LoRA = auto() LinearW8A8 = auto() RMSNormW8A8 = auto() MultinomialSampling = auto() diff --git a/lmdeploy/pytorch/backends/cuda/lora.py b/lmdeploy/pytorch/backends/cuda/lora.py new file mode 100644 index 0000000000..ea4504fff1 --- /dev/null +++ b/lmdeploy/pytorch/backends/cuda/lora.py @@ -0,0 +1,80 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from dataclasses import dataclass + +import torch + +from lmdeploy.pytorch.kernels.cuda.fused_lora import fused_lora +from lmdeploy.pytorch.model_inputs import StepContextManager + +from ..lora import AdapterInfo, LoRABuilder, LoRAImpl + + +@dataclass +class PackedLoRAInput: + """packed lora input.""" + x: torch.Tensor + q_start_loc: torch.Tensor + q_seqlens: torch.Tensor + adapter_ids: torch.Tensor + max_seq_len: int + is_decoding: bool + + +class TritonLoRAImpl(LoRAImpl): + """triton lora implementation.""" + + @staticmethod + def _make_packed_lora_input(x, ctx_mgr): + """make PackedLoRAInput.""" + context = ctx_mgr.current_context() + + # adapter cache + max_q_seq_length = x.numel() // x.size(-1) + + return PackedLoRAInput(x=x.flatten(0, -2).contiguous(), + q_start_loc=context.q_start_loc, + q_seqlens=context.q_seqlens, + adapter_ids=context.local_adapter_ids, + max_seq_len=max_q_seq_length, + is_decoding=context.is_decoding) + + def forward(self, + x: torch.Tensor, + lora_A: torch.Tensor, + lora_B: torch.Tensor, + base_output: torch.Tensor, + adapter_info: AdapterInfo, + ctx_mgr: StepContextManager, + colwise: bool, + is_tp: bool = True): + """forward.""" + lora_input = self._make_packed_lora_input(x, ctx_mgr) + + lora_out = fused_lora(lora_input.x, + lora_A, + lora_B, + scaling=adapter_info.scalings, + rank_start=adapter_info.rank_offsets, + ranks=adapter_info.ranks, + seq_start=lora_input.q_start_loc, + seq_lens=lora_input.q_seqlens, + adapter_ids=lora_input.adapter_ids, + max_rank=adapter_info.max_rank, + max_seqlen=lora_input.max_seq_len, + ) + + base_slice = adapter_info.base_slice + sliced_base = base_output[..., base_slice] + lora_out = lora_out.reshape(sliced_base.shape) + sliced_base.add_(lora_out) + output = base_output + return output + + +class TritonLoRABuilder(LoRABuilder): + """triton lora layer builder.""" + + @staticmethod + def build(): + """build.""" + return TritonLoRAImpl() diff --git a/lmdeploy/pytorch/backends/cuda/op_backend.py b/lmdeploy/pytorch/backends/cuda/op_backend.py index c7da9c40e1..216aac5a77 100644 --- a/lmdeploy/pytorch/backends/cuda/op_backend.py +++ b/lmdeploy/pytorch/backends/cuda/op_backend.py @@ -32,9 +32,9 @@ def get_layer_impl_builder(cls, layer_type: OpType): elif layer_type == OpType.RMSNorm: from .norm import TritonRMSNormBuilder return TritonRMSNormBuilder - elif layer_type == OpType.SLoRA: - from .slora import TritonSLoRABuilder - return TritonSLoRABuilder + elif layer_type == OpType.LoRA: + from .lora import TritonLoRABuilder + return TritonLoRABuilder elif layer_type == OpType.LinearW8A8: from .qmodules import TritonLinearW8A8Builder return TritonLinearW8A8Builder diff --git a/lmdeploy/pytorch/backends/cuda/slora.py b/lmdeploy/pytorch/backends/cuda/slora.py deleted file mode 100644 index 84319446d4..0000000000 --- a/lmdeploy/pytorch/backends/cuda/slora.py +++ /dev/null @@ -1,224 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
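The AdapterManager introduced above no longer pages adapter weights through the KV cache; it only assigns stable integer ids, with id 0 reserved for the base model and named adapters following in sorted order, while the weights themselves are loaded onto the LoRA linears by load_lora_weights. A minimal sketch of that mapping, using illustrative adapter names and paths that are not taken from this patch:

    from lmdeploy.pytorch.adapter.adapter import AdapterManager

    manager = AdapterManager({'coder': '/path/to/coder', 'chat': '/path/to/chat'})
    # names are sorted and offset by the reserved base-model slot:
    # manager.adapter_id_map == {None: 0, 'chat': 1, 'coder': 2}
    manager.get_adapter_ids(['coder', None, 'chat'])  # -> [2, 0, 1]

These per-request ids are what the engine later stores in local_adapter_ids when building model inputs.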
-from dataclasses import dataclass - -import torch -import torch.distributed as dist - -from lmdeploy.pytorch.kernels.cuda.mbgmm import mbgmm_a, mbgmm_b -from lmdeploy.pytorch.kernels.cuda.mbgmv import mbgmv_a, mbgmv_b -from lmdeploy.pytorch.kernels.rearange_all_gather import rearange_all_gather -from lmdeploy.pytorch.model_inputs import StepContextManager - -from ..slora import AdapterInfo, SLoRABuilder, SLoRAImpl - - -@dataclass -class PackedLoRAInput: - """packed lora input.""" - x: torch.Tensor - q_start_loc: torch.Tensor - q_seqlens: torch.Tensor - adapter_ids: torch.Tensor - max_seq_len: int - is_decoding: bool - - -class TritonSLoRAImpl(SLoRAImpl): - """triton slora implementation.""" - - @staticmethod - def _make_packed_lora_input(x, ctx_mgr): - """make PackedLoRAInput.""" - context = ctx_mgr.current_context() - - # adapter cache - max_q_seq_length = x.numel() // x.size(-1) - - return PackedLoRAInput(x=x.flatten(0, -2).contiguous(), - q_start_loc=context.q_start_loc, - q_seqlens=context.q_seqlens, - adapter_ids=context.local_adapter_ids, - max_seq_len=max_q_seq_length, - is_decoding=context.is_decoding) - - def _forward_rowwise(self, - lora_input: PackedLoRAInput, - base_output: torch.Tensor, - adapter_info: AdapterInfo, - is_tp: bool = True): - """forward_rowwise.""" - sliced_base = base_output[..., adapter_info.base_slice] - out_size = sliced_base.size(-1) - if is_tp: - rank = dist.get_rank() - world_size = dist.get_world_size() - out_size //= world_size - if not lora_input.is_decoding: - xa = mbgmm_a(lora_input.x, - adapter_info.a_cache, - q_start_loc=lora_input.q_start_loc, - q_seqlens=lora_input.q_seqlens, - adapter_ids=lora_input.adapter_ids, - rank_offset=adapter_info.rank_offsets, - ranks=adapter_info.ranks, - max_seq_len=lora_input.max_seq_len, - max_rank=adapter_info.max_rank) - lora_out = mbgmm_b(xa, - adapter_info.b_cache, - q_start_loc=lora_input.q_start_loc, - q_seqlens=lora_input.q_seqlens, - adapter_ids=lora_input.adapter_ids, - scaling=adapter_info.scalings, - rank_offset=adapter_info.rank_offsets, - ranks=adapter_info.ranks, - max_seq_len=lora_input.max_seq_len, - max_rank=adapter_info.max_rank, - out_size=out_size) - else: - xa = mbgmv_a(lora_input.x, - adapter_info.a_cache, - adapter_ids=lora_input.adapter_ids, - rank_offset=adapter_info.rank_offsets, - ranks=adapter_info.ranks, - max_rank=adapter_info.max_rank) - lora_out = mbgmv_b(xa, - adapter_info.b_cache, - adapter_ids=lora_input.adapter_ids, - scaling=adapter_info.scalings, - rank_offset=adapter_info.rank_offsets, - ranks=adapter_info.ranks, - max_rank=adapter_info.max_rank, - out_size=out_size) - - if is_tp: - out_shape = base_output.shape - out = base_output.flatten(0, -2) - slice_off = adapter_info.base_slice.start - slice_off = 0 if slice_off is None else slice_off - slice_start = slice_off + rank * out_size - slice_end = slice_start + out_size - out[:, slice_start:slice_end] += lora_out - out = out.reshape(out_shape) - else: - lora_out = lora_out.reshape(sliced_base.shape) - sliced_base.add_(lora_out) - out = base_output - - return out - - def _forward_colwise( - self, - lora_input: PackedLoRAInput, - base_output: torch.Tensor, - adapter_info: AdapterInfo, - ): - """forward_colwise.""" - - def __gather_xa(xa): - """gather xa.""" - gathered_xa = xa.new_empty(world_size, xa.size(0), xa.size(1)) - dist.all_gather_into_tensor(gathered_xa, xa) - # TODO: gather would failed when adapters have different ranks. 
- gathered_xa = gathered_xa.permute(1, 0, 2).flatten(-2, -1) - return gathered_xa - - base_slice = adapter_info.base_slice - a_cache = adapter_info.a_cache - b_cache = adapter_info.b_cache - rank_offsets = adapter_info.rank_offsets - ranks = adapter_info.ranks - max_rank = adapter_info.max_rank - scalings = adapter_info.scalings - sliced_base = base_output[..., base_slice] - out_size = sliced_base.size(-1) - world_size = dist.get_world_size() - - if not lora_input.is_decoding: - xa = mbgmm_a(lora_input.x, - a_cache, - q_start_loc=lora_input.q_start_loc, - q_seqlens=lora_input.q_seqlens, - adapter_ids=lora_input.adapter_ids, - rank_offset=rank_offsets, - ranks=ranks, - max_seq_len=lora_input.max_seq_len, - max_rank=max_rank, - rank_step=world_size) - gathered_xa = __gather_xa(xa) - if len(ranks) > 1: - gathered_xa = rearange_all_gather( - gathered_xa, - b_start_loc=lora_input.q_start_loc, - b_seq_lens=lora_input.q_seqlens, - adapter_ids=lora_input.adapter_ids, - ranks=ranks, - world_size=world_size, - max_seq_len=lora_input.max_seq_len, - output=gathered_xa) - lora_out = mbgmm_b(gathered_xa, - b_cache, - q_start_loc=lora_input.q_start_loc, - q_seqlens=lora_input.q_seqlens, - adapter_ids=lora_input.adapter_ids, - scaling=scalings, - rank_offset=rank_offsets, - ranks=ranks, - max_seq_len=lora_input.max_seq_len, - max_rank=max_rank, - out_size=out_size) - else: - xa = mbgmv_a(lora_input.x, - a_cache, - adapter_ids=lora_input.adapter_ids, - rank_offset=rank_offsets, - ranks=ranks, - max_rank=max_rank, - rank_step=world_size) - gathered_xa = __gather_xa(xa) - if len(ranks) > 1: - gathered_xa = rearange_all_gather( - gathered_xa, - b_start_loc=lora_input.q_start_loc, - b_seq_lens=lora_input.q_seqlens, - adapter_ids=lora_input.adapter_ids, - ranks=ranks, - world_size=world_size, - max_seq_len=lora_input.max_seq_len, - output=gathered_xa) - lora_out = mbgmv_b(gathered_xa, - b_cache, - adapter_ids=lora_input.adapter_ids, - scaling=scalings, - rank_offset=rank_offsets, - ranks=ranks, - max_rank=max_rank, - out_size=out_size) - - lora_out = lora_out.reshape(sliced_base.shape) - sliced_base.add_(lora_out) - output = base_output - return output - - def forward(self, - x: torch.Tensor, - base_output: torch.Tensor, - adapter_info: AdapterInfo, - ctx_mgr: StepContextManager, - colwise: bool, - is_tp: bool = True): - """forward.""" - lora_input = self._make_packed_lora_input(x, ctx_mgr) - if colwise and is_tp: - return self._forward_colwise(lora_input, base_output, adapter_info) - else: - return self._forward_rowwise(lora_input, base_output, adapter_info, - is_tp) - - -class TritonSLoRABuilder(SLoRABuilder): - """triton slora layer builder.""" - - @staticmethod - def build(): - """build.""" - return TritonSLoRAImpl() diff --git a/lmdeploy/pytorch/backends/slora.py b/lmdeploy/pytorch/backends/lora.py similarity index 58% rename from lmdeploy/pytorch/backends/slora.py rename to lmdeploy/pytorch/backends/lora.py index 6fc606cbb8..f6b2af9f05 100644 --- a/lmdeploy/pytorch/backends/slora.py +++ b/lmdeploy/pytorch/backends/lora.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from abc import ABC, abstractmethod -from dataclasses import dataclass +from dataclasses import dataclass, field import torch @@ -14,20 +14,28 @@ class AdapterInfo: out_features: int ranks: torch.Tensor scalings: torch.Tensor - rank_offsets: torch.Tensor - a_cache: torch.Tensor - b_cache: torch.Tensor base_slice: slice - max_rank: int + rank_offsets: torch.Tensor = field(init=False) + max_rank: int = field(init=False) + def __post_init__(self): + """post init.""" + ranks = self.ranks + rank_offsets = ranks.cumsum(0) - ranks + max_rank = ranks.max().item() + self.rank_offsets = rank_offsets + self.max_rank = max_rank -class SLoRAImpl(ABC): - """slora implementation api.""" + +class LoRAImpl(ABC): + """lora implementation.""" @abstractmethod def forward(self, x: torch.Tensor, base_output: torch.Tensor, + lora_A: torch.Tensor, + lora_B: torch.Tensor, adapter_info: AdapterInfo, ctx_mgr: StepContextManager, colwise: bool, @@ -36,8 +44,8 @@ def forward(self, raise NotImplementedError -class SLoRABuilder(ABC): - """slora implementation builder.""" +class LoRABuilder(ABC): + """lora implementation builder.""" @staticmethod @abstractmethod diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 995fea6cb2..8ecadd5de1 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -13,22 +13,21 @@ from lmdeploy.utils import (get_logger, get_max_batch_size, get_model, logging_timer) -from ..adapter.adapter import AdapterManager, SchedulerAdapter +from ..adapter.adapter import AdapterManager from ..check_env import check_adapters, check_env, check_model from ..config import BackendConfig, CacheConfig, SchedulerConfig from ..devices import DeviceContext, get_device_manager from ..messages import (InputEmbeddingRangeType, InputEmbeddingType, MessageStatus, SchedulerSequence) -from ..model_inputs import AdapterInfo, ModelInputs, VisionModelInputs +from ..model_inputs import ModelInputs, VisionModelInputs from ..paging import Scheduler from .logits_process import FusedLogitsProcessor, SamplingInputs -from .model_agent import AutoModelAgent, build_model_agent +from .model_agent import build_model_agent from .request import Request, RequestManager, RequestType, Response logger = get_logger('lmdeploy') SeqList = List[SchedulerSequence] -AdapterList = List[SchedulerAdapter] _EMPTY_TOKEN = np.empty((0, ), dtype=np.int64) @@ -61,20 +60,6 @@ class InferOutput: logits: torch.Tensor = None -def _paging_adapters(adapters: dict, model_agent: AutoModelAgent, - scheduler: Scheduler): - adapters = adapters or dict() - weight_maps = [] - adapter_manager = scheduler.adapter_manager - non_adapter = adapter_manager.get_adapter(None) - weight_maps.append(non_adapter.build_weight_map()) - for name in adapters: - weight_map = scheduler.add_adapter(name) - weight_map.rank_offset = torch.tensor(weight_map.rank_offset) - weight_maps.append(weight_map) - model_agent.paging_adapters(weight_maps) - - def _tensorlize_block_offsets(block_offsets): """tensorlize block_offsets.""" from torch.nn.utils.rnn import pad_sequence @@ -83,13 +68,6 @@ def _tensorlize_block_offsets(block_offsets): return block_offsets -def _get_adapter_ids(seqs: SeqList, adapters: AdapterList): - """get adapter ids.""" - adapter_names_map = dict((ada.name, ada.adapter_id) for ada in adapters) - adapter_ids = [adapter_names_map[seq.adapter_name] for seq in seqs] - return adapter_ids - - def _check_finish(scheduler: Scheduler, current_iter: int): """dynamic prefill interval.""" if not 
scheduler.has_waiting(): @@ -127,8 +105,9 @@ def __init__(self, if engine_config.max_batch_size is None: engine_config.max_batch_size = get_max_batch_size( engine_config.device_type) - if engine_config.adapters is not None: - check_adapters(list(engine_config.adapters.values())) + adapters = engine_config.adapters + if adapters is not None: + check_adapters(list(adapters.values())) self.engine_config = engine_config self.tp = engine_config.tp @@ -142,7 +121,6 @@ def __init__(self, prefill_interval=engine_config.prefill_interval) # block_size = 1 to enable unified paging - adapters = engine_config.adapters cache_config = CacheConfig( max_batches=engine_config.max_batch_size, block_size=engine_config.block_size, @@ -175,13 +153,7 @@ def __init__(self, cache_config = self.model_agent.cache_config self.adapter_manager = self._build_adapter_manager(adapters) - self.scheduler = Scheduler(scheduler_config, cache_config, - self.adapter_manager) - - if adapters: - _paging_adapters(adapters, - model_agent=self.model_agent, - scheduler=self.scheduler) + self.scheduler = Scheduler(scheduler_config, cache_config) self.scheduler_config = scheduler_config self.cache_config = cache_config @@ -238,12 +210,7 @@ def _create_buffers(self): self._seq_length_buf = torch.ones(max_batches, dtype=torch.long) def _build_adapter_manager(self, adapters): - if adapters is not None and len(adapters) > 0: - linear_infos = self.model_agent.get_lora_target_info() - else: - linear_infos = dict() - block_numel = self.model_agent.get_block_numel() - return AdapterManager(adapters, linear_infos, block_numel) + return AdapterManager(adapters) def _bind_request_manager(self): """bind request manager.""" @@ -380,13 +347,11 @@ def gpu_count(self): return self.tp @logging_timer('CreateModelInputs', logger) - def create_model_inputs(self, messages: SeqList, adapters: AdapterList, - is_prefill: bool): + def create_model_inputs(self, messages: SeqList, is_prefill: bool): """create model inputs from messages. Args: messages (SeqList): The input messages. - adapters (AdapterList): Adapters. 
""" history_lengths = [msg.history_len for msg in messages] history_lengths = torch.tensor(history_lengths) @@ -412,11 +377,11 @@ def create_model_inputs(self, messages: SeqList, adapters: AdapterList, block_offsets = _tensorlize_block_offsets(block_offsets) local_adapter_ids = None - adapter_info = None if self.adapter_manager.num_adapters() > 1: - local_adapter_ids = _get_adapter_ids(messages, adapters) + adapter_names = [msg.adapter_name for msg in messages] + local_adapter_ids = self.adapter_manager.get_adapter_ids( + adapter_names) local_adapter_ids = seq_length.new_tensor(local_adapter_ids) - adapter_info = AdapterInfo.from_adapters(adapters) # add batch dim [bs=1, seq_len] if input_ids.ndim == 1: @@ -491,7 +456,6 @@ def __get_vlm_embeddings(): is_decoding=is_decoding, num_ignored_history=num_ignored_history, local_adapter_ids=local_adapter_ids, - adapter_info=adapter_info, vision_inputs=vision_embedding_inputs, ) @@ -831,14 +795,12 @@ def __need_logits(seqs: SeqList): schedule_output = self.scheduler.schedule( is_prefill=is_prefill, prealloc_size=prefill_interval) running: SeqList = schedule_output.running - adapters = schedule_output.adapters loop_count = 1 if is_prefill else (prefill_interval - 1) if len(running) == 0: raise NoRunningSeqs() # create inputs - inputs = self.create_model_inputs(running, adapters, - is_prefill) + inputs = self.create_model_inputs(running, is_prefill) sampling_inputs = SamplingInputs.from_sampling_params(running) all_ids = __gather_all_ids(running, sampling_inputs) guided_input_ids = __gather_guided_input_ids( diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 46d17b5c89..fa36317162 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -11,7 +11,6 @@ from lmdeploy.utils import get_logger -from ..adapter.adapter import AdapterWeightMap from ..backends import get_backend from ..config import BackendConfig, CacheConfig, ModelConfig from ..devices import DeviceContext, get_device_manager @@ -28,7 +27,7 @@ def _update_cache_config(model_config: ModelConfig, cache_config: CacheConfig, gpu_id: int = 0, - host_mem_size: int = 4 * (1 << 30), + host_mem_size: int = 1 * (1 << 30), world_size: int = 1): """Update the gpu mem and cpu mem according to model info. 
@@ -144,8 +143,6 @@ def model_forward( world_size=world_size, kv_caches=cache_engine.gpu_cache, ) - if inputs.adapter_info is not None: - inputs.adapter_info.update_offsets(model.rank_offsets) with ctx_mgr.context(context): input_dict = model.prepare_inputs_for_generation( past_key_values=cache_engine.gpu_cache, @@ -155,22 +152,6 @@ def model_forward( return dict(logits=output) -def _get_indexed_lora_linears(model): - """get indexed lora linears.""" - from ..adapter.adapter import get_indexed_lora_linears - if hasattr(model, 'get_model'): - model = model.get_model() - return get_indexed_lora_linears(model) - - -def _get_lora_target_info(model, adapters: Dict[str, str]): - """get lora linear info.""" - from ..adapter.adapter import get_lora_target_info - if hasattr(model, 'get_model'): - model = model.get_model() - return get_lora_target_info(model, adapters) - - SwapMap = Dict[int, int] @@ -181,18 +162,10 @@ def __init__(self, model_config: ModelConfig, cache_config: CacheConfig): self.model_config = model_config self.cache_config = cache_config - def get_lora_target_info(self): - """get lora linear info.""" - raise NotImplementedError('Not implemented') - def get_block_numel(self): """get block nelement.""" raise NotImplementedError('Not implemented') - def paging_adapters(self, weight_maps: List[AdapterWeightMap]): - """paging adapter.""" - raise NotImplementedError('Not implemented.') - async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap): """model forward. @@ -240,7 +213,9 @@ def __init__(self, self.backend_config = backend_config self._adapters = adapters - self.patched_model = self._build_model(model_path, device=device) + self.patched_model = self._build_model(model_path, + adapters, + device=device) _update_cache_config(model_config, cache_config) @@ -254,15 +229,12 @@ def __init__(self, self.cache_engine = CacheEngine(cache_config, model_config) - self._target_infos = None - if adapters is not None: - self._target_infos = add_adapters(self.patched_model, - self.cache_engine.gpu_cache, - adapters=adapters) - self.stream = torch.cuda.Stream() - def _build_model(self, model_path: str, device: torch.device = 'cuda'): + def _build_model(self, + model_path: str, + adapters: Dict[str, str] = None, + device: torch.device = 'cuda'): """build patched model.""" custom_module_map = self.model_config.custom_module_map if custom_module_map is not None: @@ -271,26 +243,19 @@ def _build_model(self, model_path: str, device: torch.device = 'cuda'): patched_model = build_patched_model(self.model_config, device=device) logger.info('loading weights.') load_model_weights(patched_model, model_path, device=device) + logger.info('loading adapters.') + if adapters is not None: + add_adapters(patched_model, + adapters, + dtype=self.model_config.dtype, + device=device) return patched_model - def get_lora_target_info(self): - """get lora linear info.""" - return self._target_infos - def get_block_numel(self): """get block nelement.""" k_cache = self.cache_engine.local_gpu_cache[0][0] return k_cache[0].numel() - def paging_adapters(self, weight_maps: List[AdapterWeightMap]): - """paging adapter.""" - logger.info('paging adapters.') - cpu_caches = self.cache_engine.cpu_cache - cpu_caches = [(kcache.flatten(1, -1), vcache.flatten(1, -1)) - for kcache, vcache in cpu_caches] - for weight_map in weight_maps: - weight_map.cache_adapter(cpu_caches) - def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap): cache_swapping(self.cache_engine, 
@@ -388,6 +353,14 @@ def _broadcast_config(cache_config): logger.info('loading weights.') load_model_weights(patched_model, model_path, device=device_map) + if adapters is not None: + if rank == 0: + logger.info('loading adapters.') + add_adapters(patched_model, + adapters, + dtype=model_config.dtype, + device=device_map) + _update_cache_config(model_config, cache_config, gpu_id=rank, @@ -406,16 +379,11 @@ def _broadcast_config(cache_config): model_config, rank=rank, world_size=world_size) - target_infos = None - if adapters is not None: - target_infos = add_adapters(patched_model, - cache_engine.gpu_cache, - adapters=adapters) except Exception as e: raise e - return patched_model, cache_engine, cache_config, target_infos + return patched_model, cache_engine, cache_config def _broadcast_inputs(rank: int, inputs: Any, stream: torch.cuda.Stream): @@ -429,38 +397,6 @@ def _broadcast_inputs(rank: int, inputs: Any, stream: torch.cuda.Stream): return inputs -@torch.inference_mode() -def _tp_paging_adapters( - rank: int, - cache_engine: CacheEngine, - weight_maps: AdapterWeightMap = None, -): - """tp paging adapters.""" - - def __get_weight_map(weight_maps): - """get weight map.""" - if rank == 0: - assert weight_maps is not None - dist_obj = [weight_maps] - else: - dist_obj = [None] - dist.broadcast_object_list(dist_obj) - return dist_obj[0] - - def __paging(weight_maps): - """paging.""" - cpu_caches = cache_engine.cpu_cache - cpu_caches = [(kcache.flatten(1, -1), vcache.flatten(1, -1)) - for kcache, vcache in cpu_caches] - for weight_map in weight_maps: - weight_map.cache_adapter(cpu_caches) - - weight_maps = __get_weight_map(weight_maps) - - if len(weight_maps) > 0: - __paging(weight_maps) - - def _tp_model_loop( rank: int, model_path: str, @@ -469,7 +405,6 @@ def _tp_model_loop( backend_config: BackendConfig, adapters: Dict[str, str], world_size: int, - trust_remote_code: bool = True, ): """Start model loops for tensor parallel model inference. @@ -484,16 +419,13 @@ def _tp_model_loop( world_size (int): The distribution world size. 
""" stream = torch.cuda.Stream() - patched_model, cache_engine, _, _ = _tp_build_model(rank, - model_path, - model_config, - cache_config, - backend_config, - adapters=adapters, - world_size=world_size) - - if adapters: - _tp_paging_adapters(rank, cache_engine=cache_engine, weight_maps=None) + patched_model, cache_engine, _ = _tp_build_model(rank, + model_path, + model_config, + cache_config, + backend_config, + adapters=adapters, + world_size=world_size) while True: inputs, swap_in_map, swap_out_map, exit_flag = _broadcast_inputs( @@ -627,7 +559,7 @@ def __signal_term_handler(sig, frame): world_size=world_size, trust_remote_code=trust_remote_code) - model, cache_engine, cache_config, target_infos = self._build_model( + model, cache_engine, cache_config = self._build_model( model_path=model_path, model_config=model_config, cache_config=cache_config, @@ -638,7 +570,6 @@ def __signal_term_handler(sig, frame): self.patched_model = model self.cache_config = cache_config self.cache_engine = cache_engine - self._target_infos = target_infos self.stream = torch.cuda.Stream() def _start_sub_process(self, model_path: str, model_config: ModelConfig, @@ -666,8 +597,7 @@ def _start_sub_process(self, model_path: str, model_config: ModelConfig, cache_config=cache_config, backend_config=backend_config, adapters=adapters, - world_size=world_size, - trust_remote_code=trust_remote_code), + world_size=world_size), ), nprocs=world_size - 1, join=False, @@ -704,7 +634,7 @@ def _build_model( """build model.""" _check_context_alive(self.mp_context) rank = 0 - model, cache_engine, cache_config, target_infos = _tp_build_model( + model, cache_engine, cache_config = _tp_build_model( rank, model_path=model_path, model_config=model_config, @@ -714,26 +644,13 @@ def _build_model( world_size=world_size, ) - return model, cache_engine, cache_config, target_infos - - def get_lora_target_info(self): - """get lora linear info.""" - return self._target_infos + return model, cache_engine, cache_config def get_block_numel(self): """get block nelement.""" k_cache = self.cache_engine.local_gpu_cache[0][0] return k_cache[0].numel() - def paging_adapters(self, weight_maps: List[AdapterWeightMap]): - """load adapter.""" - if not weight_maps: - return - _check_context_alive(self.mp_context) - rank = 0 - logger.info('paging adapters.') - _tp_paging_adapters(rank, self.cache_engine, weight_maps) - def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap): """forward impl.""" diff --git a/lmdeploy/pytorch/kernels/__init__.py b/lmdeploy/pytorch/kernels/__init__.py index 89d0076342..23cf5d33ea 100644 --- a/lmdeploy/pytorch/kernels/__init__.py +++ b/lmdeploy/pytorch/kernels/__init__.py @@ -4,11 +4,8 @@ from .fill_kv_cache import fill_kv_cache from .fused_moe import fused_moe from .fused_rotary_emb import fused_rotary_emb -from .mbgmm import mbgmm_a, mbgmm_b -from .mbgmv import mbgmv_a, mbgmv_b from .multinomial_sampling import multinomial_sampling from .pagedattention import paged_attention_fwd -from .rearange_all_gather import rearange_all_gather from .rms_norm import rms_norm from .w8a8_triton_kernels import (matmul_kernel_dynamic_quant, per_channel_quant, per_token_quant_int8, @@ -23,11 +20,6 @@ 'fill_kv_cache', 'multinomial_sampling', 'rms_norm', - 'mbgmv_a', - 'mbgmv_b', - 'mbgmm_a', - 'mbgmm_b', - 'rearange_all_gather', 'matmul_kernel_dynamic_quant', 'per_channel_quant', 'per_token_quant_int8', diff --git a/lmdeploy/pytorch/kernels/cuda/__init__.py b/lmdeploy/pytorch/kernels/cuda/__init__.py index 
89d0076342..23cf5d33ea 100644 --- a/lmdeploy/pytorch/kernels/cuda/__init__.py +++ b/lmdeploy/pytorch/kernels/cuda/__init__.py @@ -4,11 +4,8 @@ from .fill_kv_cache import fill_kv_cache from .fused_moe import fused_moe from .fused_rotary_emb import fused_rotary_emb -from .mbgmm import mbgmm_a, mbgmm_b -from .mbgmv import mbgmv_a, mbgmv_b from .multinomial_sampling import multinomial_sampling from .pagedattention import paged_attention_fwd -from .rearange_all_gather import rearange_all_gather from .rms_norm import rms_norm from .w8a8_triton_kernels import (matmul_kernel_dynamic_quant, per_channel_quant, per_token_quant_int8, @@ -23,11 +20,6 @@ 'fill_kv_cache', 'multinomial_sampling', 'rms_norm', - 'mbgmv_a', - 'mbgmv_b', - 'mbgmm_a', - 'mbgmm_b', - 'rearange_all_gather', 'matmul_kernel_dynamic_quant', 'per_channel_quant', 'per_token_quant_int8', diff --git a/lmdeploy/pytorch/kernels/cuda/fused_lora.py b/lmdeploy/pytorch/kernels/cuda/fused_lora.py new file mode 100644 index 0000000000..e3ca030914 --- /dev/null +++ b/lmdeploy/pytorch/kernels/cuda/fused_lora.py @@ -0,0 +1,180 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import triton +import triton.language as tl + + +def get_autotune_config(): + """get autotune config.""" + return [ + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + ] + + +@triton.autotune( + configs=get_autotune_config(), + key=['N', 'K'], +) +@triton.jit +def _fused_lora_kernel( + a_ptr, + lora_a_ptr, + lora_b_ptr, + c_ptr, + scaling_ptr, + rank_start_ptr, + ranks_ptr, + seq_start_ptr, + seq_lens_ptr, + adapter_ids_ptr, + N: tl.constexpr, + K: tl.constexpr, + stride_am: tl.constexpr, + stride_ak: tl.constexpr, + stride_lar: tl.constexpr, + stride_lak: tl.constexpr, + stride_lbr: tl.constexpr, + stride_lbn: tl.constexpr, + stride_cm: tl.constexpr, + stride_cn: tl.constexpr, + BLOCK_SIZE_R: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + """fused lora kernel.""" + pid = tl.program_id(axis=0) + bid = tl.program_id(axis=1) + + M = tl.load(seq_lens_ptr + bid) + if M <= 0: + return + + seq_start = tl.load(seq_start_ptr + bid) + adapter_id = tl.load(adapter_ids_ptr + bid) + rank_start = tl.load(rank_start_ptr + adapter_id) + rank = tl.load(ranks_ptr + adapter_id) + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + GROUP_SIZE_M: tl.constexpr = 1 + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + if pid_m * BLOCK_SIZE_M >= M: + return + + offs_m = (seq_start + pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) + offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) + + mask_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) < M + if rank == 0: + offs_cm = offs_m + offs_cn = offs_n + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[ + None, :] + c_mask = mask_cm[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, 0, mask=c_mask) + return + + offs_am = (seq_start + + (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M) + offs_r = rank_start + tl.arange(0, BLOCK_SIZE_R) % rank + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + + offs_k[None, :] * stride_ak) + la_ptrs = lora_a_ptr + 
(offs_k[:, None] * stride_lak + + offs_r[None, :] * stride_lar) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_R), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B + # If it is out of bounds, set it to 0. + a = tl.load(a_ptrs, + mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, + other=0.0) + la = tl.load(la_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0) + # We accumulate along the K dimension. + accumulator += tl.dot(a, la) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + la_ptrs += BLOCK_SIZE_K * stride_lak + ar = accumulator.to(lora_b_ptr.dtype.element_ty) + + offs_lbn = offs_n % N + lb_ptrs = lora_b_ptr + (offs_r[:, None] * stride_lbr + + offs_lbn * stride_lbn) + lb = tl.load(lb_ptrs, mask=tl.arange(0, BLOCK_SIZE_R)[:, None] < rank) + + c = tl.dot(ar, lb) + + scaling = tl.load(scaling_ptr + adapter_id) + c *= scaling + + c = c.to(c_ptr.dtype.element_ty) + offs_cm = offs_m + offs_cn = offs_n + c_ptrs = c_ptr + stride_cm * offs_cm[:, + None] + stride_cn * offs_cn[None, :] + c_mask = mask_cm[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def fused_lora(input: torch.Tensor, lora_a: torch.Tensor, lora_b: torch.Tensor, + scaling: torch.LongTensor, rank_start: torch.LongTensor, + ranks: torch.LongTensor, seq_start: torch.LongTensor, + seq_lens: torch.LongTensor, adapter_ids: torch.LongTensor, + max_rank: int, max_seqlen: int): + """fused lora.""" + + def grid(META): + ret = ((triton.cdiv(max_seqlen, META['BLOCK_SIZE_M']) * + triton.cdiv(N, META['BLOCK_SIZE_N'])), batch_size) + return ret + + assert input.dim() == 2 + batch_size = seq_lens.numel() + M, K = input.shape + BLOCK_SIZE_R = max(16, max_rank) + N = lora_b.size(1) + + output = torch.empty((M, N), dtype=input.dtype, device=input.device) + + _fused_lora_kernel[grid]( + input, + lora_a, + lora_b, + output, + scaling, + rank_start, + ranks, + seq_start, + seq_lens, + adapter_ids, + N, + K, + stride_am=input.stride(0), + stride_ak=input.stride(1), + stride_lar=lora_a.stride(0), + stride_lak=lora_a.stride(1), + stride_lbr=lora_b.stride(0), + stride_lbn=lora_b.stride(1), + stride_cm=output.stride(0), + stride_cn=output.stride(1), + BLOCK_SIZE_R=BLOCK_SIZE_R, + ) + + return output diff --git a/lmdeploy/pytorch/kernels/cuda/mbgmm.py b/lmdeploy/pytorch/kernels/cuda/mbgmm.py deleted file mode 100644 index d91cacd751..0000000000 --- a/lmdeploy/pytorch/kernels/cuda/mbgmm.py +++ /dev/null @@ -1,276 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
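For readers following the Triton code, a dense PyTorch reference of what fused_lora computes is sketched below. It assumes lora_a and lora_b are the per-adapter A and B weights stacked along the rank dimension, shaped (total_rank, K) and (total_rank, N), which matches how the kernel indexes them through rank_start and ranks; this is a reading aid under those assumptions, not part of the patch, and exact numerics differ slightly because the kernel accumulates in fp32.

    import torch

    def fused_lora_reference(x, lora_a, lora_b, scaling, rank_start, ranks,
                             seq_start, seq_lens, adapter_ids):
        """Per sequence: out[rows] = scaling[a] * x[rows] @ A_a.T @ B_a."""
        out = x.new_zeros(x.size(0), lora_b.size(1))
        for i in range(seq_lens.numel()):
            start, length = int(seq_start[i]), int(seq_lens[i])
            rows = slice(start, start + length)
            a = int(adapter_ids[i])
            r0, r = int(rank_start[a]), int(ranks[a])
            if r == 0:
                continue  # base-model request: the LoRA delta stays zero
            A = lora_a[r0:r0 + r]  # (r, K) slice of the stacked lora_A
            B = lora_b[r0:r0 + r]  # (r, N) slice of the stacked lora_B
            out[rows] = float(scaling[a]) * (x[rows] @ A.t() @ B)
        return out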
-import triton -import triton.language as tl -from torch import Tensor - -from .triton_utils import get_kernel_meta, wrap_jit_func - - -def _next_pow_of_2(x): - """get next power of 2.""" - return 1 << (x - 1).bit_length() - - -@wrap_jit_func -@triton.jit -def _x_a_mm_kernel( - X, - LoRA_A, - XA, - B_start_loc, - B_seq_lens, - B_adapter_id, - Rank_offset, - Ranks, - stride_xs, - stride_xh, - stride_xas, - stride_xar, - stride_ptb, - stride_r, - rank_step, - BLOCK_M: tl.constexpr, - BLOCK_R: tl.constexpr, - BLOCK_H: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, -): - """xa mm kernel.""" - cur_batch = tl.program_id(0) - start_m = tl.program_id(1) - - r_off = tl.arange(0, BLOCK_R) - - seq_len = tl.load(B_seq_lens + cur_batch) - if start_m * BLOCK_M >= seq_len: - return - - start_loc = tl.load(B_start_loc + cur_batch) - adapter_id = tl.load(B_adapter_id + cur_batch) - rank = tl.load(Ranks + adapter_id * stride_r) // rank_step - - rank_off = adapter_id * stride_ptb + r_off - rank_mask = r_off < rank - - m_off = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - dm_off = tl.arange(0, BLOCK_DMODEL) - - x_off = (start_loc + m_off) * stride_xs - xs_mask = m_off < seq_len - la_page_off = tl.load(Rank_offset + rank_off, mask=rank_mask) - acc = tl.zeros((BLOCK_M, BLOCK_R), dtype=tl.float32) - - # compute acc - for start_h in range(0, BLOCK_H, BLOCK_DMODEL): - cur_dm_off = start_h + dm_off - h_mask = cur_dm_off < BLOCK_H - - # load x - xh_off = cur_dm_off * stride_xh - x_mask = xs_mask[:, None] and h_mask[None, :] - x = tl.load(X + x_off[:, None] + xh_off[None, :], - mask=x_mask, - other=0.0) - - # load lora a - lah_off = cur_dm_off - la_mask = rank_mask[None, :] and h_mask[:, None] - la = tl.load(LoRA_A + la_page_off[None, :] + lah_off[:, None], - mask=la_mask, - other=0.0) - - # compute - acc += tl.dot(x, la) - - acc = acc.to(X.dtype.element_ty) - xa_off = (start_loc + m_off) * stride_xas - xas_mask = xs_mask - xa_mask = xas_mask[:, None] and rank_mask[None, :] - tl.store(XA + xa_off[:, None] + r_off[None, :] * stride_xar, - acc, - mask=xa_mask) - - -@wrap_jit_func -@triton.jit -def _acc_b_mm_kernel( - XA, - LoRA_B, - Out, - B_start_loc, - B_seq_lens, - B_adapter_id, - B_scaling, - Rank_offset, - Ranks, - stride_xas, - stride_xar, - stride_os, - stride_oh, - stride_ptb, - stride_r, - stride_s, - BLOCK_M: tl.constexpr, - BLOCK_R: tl.constexpr, - BLOCK_HO: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, -): - cur_batch = tl.program_id(0) - start_m = tl.program_id(1) - - r_off = tl.arange(0, BLOCK_R) - - seq_len = tl.load(B_seq_lens + cur_batch) - if start_m * BLOCK_M >= seq_len: - return - - start_loc = tl.load(B_start_loc + cur_batch) - adapter_id = tl.load(B_adapter_id + cur_batch) - scaling = tl.load(B_scaling + adapter_id * stride_s) - rank = tl.load(Ranks + adapter_id * stride_r) - - rank_off = adapter_id * stride_ptb + r_off - rank_mask = r_off < rank - - m_off = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - dm_off = tl.arange(0, BLOCK_DMODEL) - lb_page_off = tl.load(Rank_offset + rank_off, mask=rank_mask) - - xs_mask = m_off < seq_len - o_off = (start_loc + m_off) * stride_os - os_mask = xs_mask - - xa_off = (start_loc + m_off) * stride_xas - xa_mask = xs_mask[:, None] and rank_mask[None, :] - acc = tl.load(XA + xa_off[:, None] + r_off[None, :] * stride_xar, - mask=xa_mask, - other=0.0) - acc = acc.to(LoRA_B.dtype.element_ty) - - # compute output - for start_h in range(0, BLOCK_HO, BLOCK_DMODEL): - cur_dm_off = start_h + dm_off - h_mask = cur_dm_off < BLOCK_HO - - # load lora b - lbh_off = cur_dm_off - lb_mask 
= rank_mask[:, None] and h_mask[None, :] - lb = tl.load(LoRA_B + lb_page_off[:, None] + lbh_off[None, :], - mask=lb_mask, - other=0) - - # compute - out = tl.dot(acc, lb) - out = out.to(lb.dtype) - out = out * scaling - - # store o - oh_off = cur_dm_off * stride_oh - o_mask = os_mask[:, None] and h_mask[None, :] - tl.store(Out + o_off[:, None] + oh_off[None, :], out, mask=o_mask) - - -def mbgmm_a(x: Tensor, - lora_a: Tensor, - q_start_loc: Tensor, - q_seqlens: Tensor, - adapter_ids: Tensor, - rank_offset: Tensor, - ranks: Tensor, - max_seq_len: int, - max_rank: int, - rank_step: int = 1): - """mbgmm_a.""" - - head_size = x.size(-1) - batch_size = len(q_seqlens) - max_rank = max_rank // rank_step - - BLOCK_M = 32 - BLOCK_R = _next_pow_of_2(max_rank) - if BLOCK_R < 16: - BLOCK_R = 16 - BLOCK_H = head_size - BLOCK_DMODEL = 64 - - num_warps = 4 - grid = [batch_size, triton.cdiv(max_seq_len, BLOCK_M)] - xa = x.new_empty((x.size(0), max_rank)) - kernel_meta = get_kernel_meta(x) - _x_a_mm_kernel[grid](x, - lora_a, - xa, - q_start_loc, - q_seqlens, - adapter_ids, - Rank_offset=rank_offset, - Ranks=ranks, - stride_xs=x.stride(0), - stride_xh=x.stride(1), - stride_xas=xa.stride(0), - stride_xar=xa.stride(1), - stride_ptb=rank_offset.stride(0), - stride_r=ranks.stride(0), - rank_step=rank_step, - BLOCK_M=BLOCK_M, - BLOCK_R=BLOCK_R, - BLOCK_H=BLOCK_H, - BLOCK_DMODEL=BLOCK_DMODEL, - num_warps=num_warps, - num_stages=1, - **kernel_meta) - return xa - - -def mbgmm_b(xa: Tensor, - lora_b: Tensor, - q_start_loc: Tensor, - q_seqlens: Tensor, - adapter_ids: Tensor, - scaling: Tensor, - rank_offset: Tensor, - ranks: Tensor, - max_seq_len: int, - max_rank: int, - out_size: int = None): - """mbgmm_b.""" - - if out_size is None: - out_size = lora_b.size(-1) - batch_size = len(q_seqlens) - - BLOCK_M = 32 - BLOCK_R = _next_pow_of_2(max_rank) - if BLOCK_R < 16: - BLOCK_R = 16 - BLOCK_HO = out_size - BLOCK_DMODEL = 64 - - num_warps = 4 - grid = [batch_size, triton.cdiv(max_seq_len, BLOCK_M)] - output = xa.new_empty((xa.size(0), BLOCK_HO)) - kernel_meta = get_kernel_meta(xa) - _acc_b_mm_kernel[grid](xa, - lora_b, - output, - q_start_loc, - q_seqlens, - adapter_ids, - scaling, - Rank_offset=rank_offset, - Ranks=ranks, - stride_xas=xa.stride(0), - stride_xar=xa.stride(1), - stride_os=output.stride(0), - stride_oh=output.stride(1), - stride_ptb=rank_offset.stride(0), - stride_r=ranks.stride(0), - stride_s=scaling.stride(0), - BLOCK_M=BLOCK_M, - BLOCK_R=BLOCK_R, - BLOCK_HO=BLOCK_HO, - BLOCK_DMODEL=BLOCK_DMODEL, - num_warps=num_warps, - num_stages=1, - **kernel_meta) - - return output diff --git a/lmdeploy/pytorch/kernels/cuda/mbgmv.py b/lmdeploy/pytorch/kernels/cuda/mbgmv.py deleted file mode 100644 index c042e71d5d..0000000000 --- a/lmdeploy/pytorch/kernels/cuda/mbgmv.py +++ /dev/null @@ -1,223 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-import triton -import triton.language as tl -from torch import Tensor - -from .triton_utils import get_kernel_meta, wrap_jit_func - - -def _next_pow_of_2(x): - """get next power of 2.""" - return 1 << (x - 1).bit_length() - - -@wrap_jit_func -@triton.jit -def _x_a_mv_kernel( - X, - LoRA_A, - XA, - B_adapter_id, - Rank_offset, - Ranks, - stride_xs, - stride_xh, - stride_xas, - stride_xar, - stride_ptb, - stride_r, - rank_step, - BLOCK_R: tl.constexpr, - BLOCK_H: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, -): - """xa mv kernel.""" - cur_batch = tl.program_id(0) - - r_off = tl.arange(0, BLOCK_R) - adapter_id = tl.load(B_adapter_id + cur_batch) - rank = tl.load(Ranks + adapter_id * stride_r) // rank_step - - rank_off = adapter_id * stride_ptb + r_off - rank_mask = r_off < rank - - dm_off = tl.arange(0, BLOCK_DMODEL) - - x_off = cur_batch * stride_xs - la_page_off = tl.load(Rank_offset + rank_off, mask=rank_mask) - acc = tl.zeros((BLOCK_R, ), dtype=tl.float32) - - # compute acc - for start_h in range(0, BLOCK_H, BLOCK_DMODEL): - cur_dm_off = start_h + dm_off - h_mask = cur_dm_off < BLOCK_H - - # load x - xh_off = cur_dm_off * stride_xh - x_mask = h_mask - x = tl.load(X + x_off + xh_off, mask=x_mask, other=0.0) - - # load lora a - lah_off = cur_dm_off - la_mask = rank_mask[:, None] and h_mask[None, :] - la = tl.load(LoRA_A + la_page_off[:, None] + lah_off[None, :], - mask=la_mask, - other=0.0) - - # compute - acc += tl.sum(x[None, :] * la, 1) - - acc = acc.to(X.dtype.element_ty) - xa_off = cur_batch * stride_xas - tl.store(XA + xa_off + r_off * stride_xar, acc, mask=rank_mask) - - -@wrap_jit_func -@triton.jit -def _acc_b_mv_kernel( - XA, - LoRA_B, - Out, - B_adapter_id, - B_scaling, - Rank_offset, - Ranks, - stride_xas, - stride_xar, - stride_os, - stride_oh, - stride_ptb, - stride_r, - stride_s, - BLOCK_R: tl.constexpr, - BLOCK_HO: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, -): - """acc b mv kernel.""" - cur_batch = tl.program_id(0) - - r_off = tl.arange(0, BLOCK_R) - adapter_id = tl.load(B_adapter_id + cur_batch) - scaling = tl.load(B_scaling + adapter_id * stride_s) - rank = tl.load(Ranks + adapter_id * stride_r) - - rank_off = adapter_id * stride_ptb + r_off - rank_mask = r_off < rank - - dm_off = tl.arange(0, BLOCK_DMODEL) - lb_page_off = tl.load(Rank_offset + rank_off, mask=rank_mask) - - o_off = cur_batch * stride_os - - xa_off = cur_batch * stride_xas - acc = tl.load(XA + xa_off + r_off * stride_xar, mask=rank_mask, other=0.0) - - # compute output - for start_h in range(0, BLOCK_HO, BLOCK_DMODEL): - cur_dm_off = start_h + dm_off - h_mask = cur_dm_off < BLOCK_HO - - # load lora b - lbh_off = cur_dm_off - lb_mask = rank_mask[:, None] and h_mask[None, :] - lb = tl.load(LoRA_B + lb_page_off[:, None] + lbh_off[None, :], - mask=lb_mask, - other=0) - - # compute - out = tl.sum(acc[:, None] * lb, 0) - out = out.to(lb.dtype) - out = out * scaling - - # store o - oh_off = cur_dm_off * stride_oh - tl.store(Out + o_off + oh_off, out, mask=h_mask) - - -def mbgmv_a(x: Tensor, - lora_a: Tensor, - adapter_ids: Tensor, - rank_offset: Tensor, - ranks: Tensor, - max_rank: int, - rank_step: int = 1): - """mbgmv_a.""" - - head_size = x.size(-1) - batch_size = x.size(0) - max_rank = max_rank // rank_step - - BLOCK_R = _next_pow_of_2(max_rank) - BLOCK_H = head_size - BLOCK_DMODEL = 512 - - num_warps = 4 - grid = [batch_size] - xa = x.new_empty((x.size(0), BLOCK_R)) - kernel_meta = get_kernel_meta(x) - _x_a_mv_kernel[grid](x, - lora_a, - xa, - adapter_ids, - Rank_offset=rank_offset, - Ranks=ranks, - 
stride_xs=x.stride(0), - stride_xh=x.stride(1), - stride_xas=xa.stride(0), - stride_xar=xa.stride(1), - stride_ptb=rank_offset.stride(0), - stride_r=ranks.stride(0), - rank_step=rank_step, - BLOCK_R=BLOCK_R, - BLOCK_H=BLOCK_H, - BLOCK_DMODEL=BLOCK_DMODEL, - num_warps=num_warps, - num_stages=1, - **kernel_meta) - return xa - - -def mbgmv_b(xa: Tensor, - lora_b: Tensor, - adapter_ids: Tensor, - scaling: Tensor, - rank_offset: Tensor, - ranks: Tensor, - max_rank: int, - out_size: int = None): - """mbgmv_b.""" - - if out_size is None: - out_size = lora_b.size(-1) - batch_size = xa.size(0) - - BLOCK_R = _next_pow_of_2(max_rank) - BLOCK_HO = out_size - BLOCK_DMODEL = 512 - - num_warps = 4 - grid = [batch_size] - output = xa.new_empty((xa.size(0), BLOCK_HO)) - kernel_meta = get_kernel_meta(xa) - _acc_b_mv_kernel[grid](xa, - lora_b, - output, - adapter_ids, - scaling, - Rank_offset=rank_offset, - Ranks=ranks, - stride_xas=xa.stride(0), - stride_xar=xa.stride(1), - stride_os=output.stride(0), - stride_oh=output.stride(1), - stride_ptb=rank_offset.stride(0), - stride_r=ranks.stride(0), - stride_s=scaling.stride(0), - BLOCK_R=BLOCK_R, - BLOCK_HO=BLOCK_HO, - BLOCK_DMODEL=BLOCK_DMODEL, - num_warps=num_warps, - num_stages=1, - **kernel_meta) - - return output diff --git a/lmdeploy/pytorch/kernels/cuda/rearange_all_gather.py b/lmdeploy/pytorch/kernels/cuda/rearange_all_gather.py deleted file mode 100644 index 7eddc4839c..0000000000 --- a/lmdeploy/pytorch/kernels/cuda/rearange_all_gather.py +++ /dev/null @@ -1,130 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch -import triton -import triton.language as tl - -from .triton_utils import get_kernel_meta, wrap_jit_func - - -@wrap_jit_func -@triton.jit -def _rearange_all_gather_kernel(X, StartLoc, SeqLen, AdapterIds, Ranks, Out, - stride_x, stride_o, world_size, - BLOCK: tl.constexpr, BLOCK_P: tl.constexpr): - """rearange all gather kernel.""" - batch_id = tl.program_id(0) - block_id = tl.program_id(1) - - start_loc = tl.load(StartLoc + batch_id) + block_id * BLOCK - seq_len = tl.load(SeqLen + batch_id) - - if block_id * BLOCK >= seq_len: - return - - block_off = start_loc + tl.arange(0, BLOCK) - block_mask = block_id * BLOCK + tl.arange(0, BLOCK) < seq_len - - adapter_id = tl.load(AdapterIds + batch_id) - rank = tl.load(Ranks + adapter_id) - prank = rank // world_size - p_off = tl.arange(0, BLOCK_P) - - for p_id in range(world_size): - ip_off = p_id * BLOCK_P + p_off - i_mask = block_mask[:, None] and (p_off < prank)[None, :] - i_off = block_off[:, None] * stride_x + ip_off[None, :] - x = tl.load(X + i_off, mask=i_mask) - - op_off = p_id * prank + p_off - o_mask = i_mask - o_off = block_off[:, None] * stride_o + op_off[None, :] - tl.store(Out + o_off, x, mask=o_mask) - - -@wrap_jit_func -@triton.jit -def _rearange_all_gather_decoding_kernel(X, AdapterIds, Ranks, Out, stride_x, - stride_o, world_size, seq_len, - BLOCK: tl.constexpr, - BLOCK_P: tl.constexpr): - """rearange all gather kernel.""" - block_id = tl.program_id(0) - block_off = block_id * BLOCK + tl.arange(0, BLOCK) - block_mask = block_off < seq_len - - adapter_ids = tl.load(AdapterIds + block_off, mask=block_mask) - ranks = tl.load(Ranks + adapter_ids) - pranks = ranks // world_size - p_off = tl.arange(0, BLOCK_P) - - for p_id in range(world_size): - ip_off = p_id * BLOCK_P + p_off - i_mask = block_mask[:, None] and (p_off[None, :] < pranks[:, None]) - i_off = block_off[:, None] * stride_x + ip_off[None, :] - x = tl.load(X + i_off, mask=i_mask) - - op_off = p_id * pranks[:, 
None] + p_off[None, :] - o_mask = i_mask - o_off = block_off[:, None] * stride_o + op_off - tl.store(Out + o_off, x, mask=o_mask) - - -def rearange_all_gather(x: torch.Tensor, - b_start_loc: torch.Tensor, - b_seq_lens: torch.Tensor, - adapter_ids: torch.LongTensor, - ranks: torch.Tensor, - world_size: int, - max_seq_len: int, - output: torch.Tensor = None): - """rearange all gather.""" - - max_rank = x.size(1) - batch_size = len(b_seq_lens) - partition_size = max_rank // world_size - - if output is None: - output = torch.empty_like(x) - - num_warps = 4 - kernel_meta = get_kernel_meta(x) - - is_decoding = batch_size == x.size(0) - if not is_decoding: - BLOCK = 128 - BLOCK_P = partition_size - grid = (batch_size, triton.cdiv(max_seq_len, BLOCK)) - _rearange_all_gather_kernel[grid](x, - b_start_loc, - b_seq_lens, - adapter_ids, - ranks, - output, - stride_x=x.stride(0), - stride_o=output.stride(0), - world_size=world_size, - BLOCK=BLOCK, - BLOCK_P=BLOCK_P, - num_warps=num_warps, - num_stages=1, - **kernel_meta) - else: - BLOCK = 64 - BLOCK_P = partition_size - seq_len = x.size(0) - grid = (triton.cdiv(seq_len, BLOCK), ) - _rearange_all_gather_decoding_kernel[grid](x, - adapter_ids, - ranks, - output, - stride_x=x.stride(0), - stride_o=output.stride(0), - world_size=world_size, - seq_len=seq_len, - BLOCK=BLOCK, - BLOCK_P=BLOCK_P, - num_warps=num_warps, - num_stages=1, - **kernel_meta) - - return output diff --git a/lmdeploy/pytorch/kernels/mbgmm.py b/lmdeploy/pytorch/kernels/mbgmm.py deleted file mode 100644 index ddc5bacf43..0000000000 --- a/lmdeploy/pytorch/kernels/mbgmm.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .dispatcher import FunctionDispatcher - -mbgmm_a = FunctionDispatcher('mbgmm_a').make_caller() -mbgmm_b = FunctionDispatcher('mbgmm_b').make_caller() diff --git a/lmdeploy/pytorch/kernels/mbgmv.py b/lmdeploy/pytorch/kernels/mbgmv.py deleted file mode 100644 index 5252119667..0000000000 --- a/lmdeploy/pytorch/kernels/mbgmv.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .dispatcher import FunctionDispatcher - -mbgmv_a = FunctionDispatcher('mbgmv_a').make_caller() - -mbgmv_b = FunctionDispatcher('mbgmv_b').make_caller() diff --git a/lmdeploy/pytorch/kernels/rearange_all_gather.py b/lmdeploy/pytorch/kernels/rearange_all_gather.py deleted file mode 100644 index aec3fcb04b..0000000000 --- a/lmdeploy/pytorch/kernels/rearange_all_gather.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-from .dispatcher import FunctionDispatcher - -rearange_all_gather = FunctionDispatcher('rearange_all_gather').make_caller() diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py index 27ed28defb..10052825e6 100644 --- a/lmdeploy/pytorch/model_inputs.py +++ b/lmdeploy/pytorch/model_inputs.py @@ -7,45 +7,6 @@ from lmdeploy.pytorch.backends import get_backend -from .adapter.adapter import SchedulerAdapter - - -@dataclass -class AdapterInfo: - adapter_ids: torch.LongTensor - rank_offsets: torch.LongTensor - - @classmethod - def from_adapters(cls, adapters: List[SchedulerAdapter]): - """from adapters.""" - if len(adapters) == 0: - return None - adapter_ids = [ada.adapter_id for ada in adapters] - adapter_ids = torch.tensor(adapter_ids) - rank_offsets = [torch.from_numpy(ada.rank_offset) for ada in adapters] - rank_offsets = torch.stack(rank_offsets) - - return cls( - adapter_ids=adapter_ids, - rank_offsets=rank_offsets, - ) - - def update_offsets(self, rank_offsets: torch.LongTensor): - """update rank offsets.""" - rank_offsets[self.adapter_ids] = self.rank_offsets - - def to_device(self, device: str): - """to device.""" - out_dict = dict() - for f in fields(self): - k = f.name - v = getattr(self, k) - if isinstance(v, torch.Tensor): - v = v.to(device) - out_dict[k] = v - - return AdapterInfo(**out_dict) - @dataclass class VisionModelInputs: @@ -118,7 +79,6 @@ class ModelInputs: is_decoding: bool num_ignored_history: torch.LongTensor local_adapter_ids: torch.LongTensor = None - adapter_info: AdapterInfo = None vision_inputs: VisionModelInputs = None def update(self, input_ids: torch.LongTensor): @@ -162,7 +122,6 @@ def split(self, split_size: int, block_size: int): is_decoding=self.is_decoding, num_ignored_history=self.num_ignored_history, local_adapter_ids=self.local_adapter_ids, - adapter_info=self.adapter_info, vision_inputs=self.vision_inputs, ) ret.append(inp) @@ -180,8 +139,6 @@ def to_device(self, device: str): v = v.to(device) elif isinstance(v, VisionModelInputs): v = v.to_device(device) - elif isinstance(v, AdapterInfo): - v = v.to_device(device) out_dict[k] = v return ModelInputs(**out_dict) @@ -205,7 +162,6 @@ class StepContext: is_decoding: bool world_size: int = 1 local_adapter_ids: torch.LongTensor = None - adapter_params: Dict[str, AdapterInfo] = None input_embeddings: torch.Tensor = None input_embedding_indexing: torch.Tensor = None vision_inputs: VisionModelInputs = None diff --git a/lmdeploy/pytorch/models/patch.py b/lmdeploy/pytorch/models/patch.py index cae8f43cf1..d6c6d69c5c 100644 --- a/lmdeploy/pytorch/models/patch.py +++ b/lmdeploy/pytorch/models/patch.py @@ -4,7 +4,7 @@ import os.path as osp import re import sys -from typing import Any, Dict, List +from typing import Any, Dict import torch from transformers.configuration_utils import PretrainedConfig @@ -202,18 +202,17 @@ def build_patched_model(config: ModelConfig, device: torch.device = None): @torch.inference_mode() def add_adapters(model: torch.nn.Module, - kv_caches: List[List[torch.Tensor]], adapters: Dict[str, str], + dtype: torch.dtype = torch.float16, device: torch.device = None): """add adapters.""" from peft import PeftConfig from peft.tuners.lora import LoraConfig - from lmdeploy.pytorch.adapter.adapter import (LoRATargetInfo, - find_all_target, - get_layer_index, - get_ranks_and_scalings) - from lmdeploy.pytorch.nn.linear import SLoRA + from lmdeploy.pytorch.adapter.adapter import (find_all_target, + get_ranks_and_scalings, + load_lora_weights) + from lmdeploy.pytorch.nn.linear 
import LoRA num_adapters = len(adapters) if num_adapters == 0: return @@ -222,56 +221,43 @@ def add_adapters(model: torch.nn.Module, device = torch.device('cuda') # model could be graph runner - origin_model = model if hasattr(model, 'get_model'): model = model.get_model() ctx_mgr = model.ctx_mgr + adapter_names = list(adapters.keys()) + adapter_names = sorted(adapter_names) + adapter_cfgs = [ - PeftConfig.from_pretrained(path) for path in adapters.values() + PeftConfig.from_pretrained(adapters[name]) for name in adapter_names ] - # get layer pattern (should be same between different adapter) - config = next(iter(adapter_cfgs)) - layers_pattern = getattr(config, 'layers_pattern', None) # insert one for no adapter adapter_cfgs = [LoraConfig(r=0, target_modules=[])] + adapter_cfgs + adapter_names = [None] + adapter_names + adapter_id_map = dict(zip(adapter_names, range(len(adapter_names)))) # target layer name to add adapter target_names = set() - max_rank = 0 for cfg in adapter_cfgs: target_names = target_names.union(cfg.target_modules) - max_rank = max(max_rank, cfg.r) target_names = list(target_names) target_names = sorted(target_names) - num_targets = len(target_names) - - # get rank offsets - # add 1 for none adapter - rank_offsets = torch.zeros(num_adapters + 1, - num_targets * max_rank, - dtype=torch.int64, - device=device) target_infos = dict() - for target_idx, target_name in enumerate(target_names): + for _, target_name in enumerate(target_names): # get ranks and scalings ranks, scalings = get_ranks_and_scalings(target_name, adapter_cfgs, device=device) found_mods, pack_idx = find_all_target(model, target_name) - r_start = target_idx * max_rank - r_end = r_start + max_rank - r_offs = rank_offsets[:, r_start:r_end] + sum_rank = ranks.sum().item() in_features = 0 out_features = 0 colwise = True - for name, mod in found_mods: + for _, mod in found_mods: assert hasattr(mod, 'lora_adapters') - layer_idx = get_layer_index(name, layers_pattern) - k_cache, v_cache = kv_caches[layer_idx] in_features = mod.in_features colwise = mod.colwise if pack_idx is None: @@ -281,28 +267,36 @@ def add_adapters(model: torch.nn.Module, prev_feats = sum(mod.all_out_features[:pack_idx]) out_features = mod.all_out_features[pack_idx] base_slice = slice(prev_feats, prev_feats + out_features) - - slora = SLoRA( + lora_a = torch.empty((sum_rank, in_features), + dtype=dtype, + device=device) + lora_b = torch.empty((sum_rank, out_features), + dtype=dtype, + device=device) + + lora = LoRA( in_features, out_features, ranks=ranks, scalings=scalings, - rank_offsets=r_offs, - a_cache=k_cache, - b_cache=v_cache, + lora_a=lora_a, + lora_b=lora_b, base_slice=base_slice, - max_rank=max_rank, ctx_mgr=ctx_mgr, colwise=colwise, is_tp=mod.is_tp, ) - mod.lora_adapters.append(slora) + mod.lora_adapters[target_name] = lora + + # fill adapter data + for name, path in adapters.items(): + adapter_id = adapter_id_map[name] + checkpoint_path = f'{path}/adapter_model.bin' + state_dict = torch.load(checkpoint_path, map_location=device) - target_info = LoRATargetInfo(in_features=in_features, - out_features=out_features, - colwise=colwise) - target_infos[target_name] = target_info + if hasattr(model, 'load_lora_weights'): + model.load_lora_weights(state_dict.items(), adapter_id=adapter_id) + else: + load_lora_weights(model, state_dict.items(), adapter_id=adapter_id) - # add rank_offsets - setattr(origin_model, 'rank_offsets', rank_offsets) return target_infos diff --git a/lmdeploy/pytorch/nn/linear.py b/lmdeploy/pytorch/nn/linear.py index 
c1927df623..8f3b0bfc1a 100644 --- a/lmdeploy/pytorch/nn/linear.py +++ b/lmdeploy/pytorch/nn/linear.py @@ -10,7 +10,7 @@ from lmdeploy.utils import get_logger from ..backends import OpType, get_backend -from ..backends.slora import AdapterInfo +from ..backends.lora import AdapterInfo from .utils import div_up, get_distribute_size, get_world_rank logger = get_logger('lmdeploy') @@ -73,19 +73,17 @@ def split_qkv(self, x: torch.Tensor): return q, k, v -class SLoRA(nn.Module): - """SLoRA layer.""" +class LoRA(nn.Module): + """LoRA layer.""" def __init__(self, in_features: int, out_features: int, ranks: torch.Tensor, scalings: torch.Tensor, - rank_offsets: torch.Tensor, - a_cache: torch.Tensor, - b_cache: torch.Tensor, + lora_a: torch.Tensor, + lora_b: torch.Tensor, base_slice: slice, - max_rank: int, ctx_mgr: Any = None, colwise: bool = True, is_tp: bool = True): @@ -95,14 +93,17 @@ def __init__(self, out_features=out_features, ranks=ranks, scalings=scalings, - rank_offsets=rank_offsets, - a_cache=a_cache, - b_cache=b_cache, base_slice=base_slice, - max_rank=max_rank, ) - impl_builder = get_backend().get_layer_impl_builder(OpType.SLoRA) + impl_builder = get_backend().get_layer_impl_builder(OpType.LoRA) self.impl = impl_builder.build() + + lora_A = nn.Parameter(lora_a, requires_grad=False) + lora_B = nn.Parameter(lora_b, requires_grad=False) + self.register_parameter('lora_A', lora_A) + self.register_parameter('lora_B', lora_B) + lora_A.weight_loader = self.weight_loader_A + lora_B.weight_loader = self.weight_loader_B self.is_tp = is_tp self.ctx_mgr = ctx_mgr self.colwise = colwise @@ -110,12 +111,43 @@ def __init__(self, def forward(self, x, base_output=None): """forward of loraA@loraB.""" return self.impl.forward(x, + self.lora_A, + self.lora_B, base_output, self.adapter_info, ctx_mgr=self.ctx_mgr, colwise=self.colwise, is_tp=self.is_tp) + def weight_loader_A(self, param: nn.Parameter, loaded_weight: torch.Tensor, + adapter_id: int): + """weight loader.""" + rank = self.adapter_info.ranks[adapter_id].item() + r_start = self.adapter_info.rank_offsets[adapter_id].item() + r_end = r_start + rank + param_r = param.data[r_start:r_end] + + if self.is_tp and not self.colwise: + world_size, rank = get_world_rank() + loaded_weight = loaded_weight.chunk(world_size, dim=1)[rank] + + param_r.copy_(loaded_weight) + + def weight_loader_B(self, param: nn.Parameter, loaded_weight: torch.Tensor, + adapter_id: int): + """weight loader.""" + rank = self.adapter_info.ranks[adapter_id].item() + r_start = self.adapter_info.rank_offsets[adapter_id].item() + r_end = r_start + rank + param_r = param.data[r_start:r_end] + + loaded_weight = loaded_weight.t() + if self.is_tp and self.colwise: + world_size, rank = get_world_rank() + loaded_weight = loaded_weight.chunk(world_size, dim=1)[rank] + + param_r.copy_(loaded_weight) + class AwqLinear(nn.Module): """w4a16 linear.""" @@ -171,7 +203,7 @@ def __init__( self.w_bit = w_bit self.group_size = group_size self.elem_per_int = 32 // self.w_bit - self.lora_adapters = [] + self.lora_adapters = nn.ModuleDict() self.is_tp = is_tp self.colwise = colwise self.all_reduce = all_reduce @@ -305,7 +337,7 @@ def forward(self, x): out = self.impl.forward(x, self.qweight, self.scales, self.qzeros, self.bias, False) if self.lora_adapters is not None: - for lora_adapter in self.lora_adapters: + for lora_adapter in self.lora_adapters.values(): out = lora_adapter(x, out) if all_reduce: dist.all_reduce(out) @@ -560,7 +592,7 @@ def __init__( self.in_features = in_features self.out_features = 
out_features - self.lora_adapters = [] + self.lora_adapters = nn.ModuleDict() self.is_tp = is_tp self.colwise = colwise self.all_reduce = all_reduce @@ -651,7 +683,7 @@ def forward(self, x): all_reduce) out = self.impl.forward(x, self.weight, self.scale, self.bias, False) - for lora_adapter in self.lora_adapters: + for lora_adapter in self.lora_adapters.values(): out = lora_adapter(x, out) if all_reduce: dist.all_reduce(out) @@ -833,7 +865,7 @@ def __init__( self.in_features = in_features self.out_features = out_features - self.lora_adapters = [] + self.lora_adapters = nn.ModuleDict() self.is_tp = is_tp self.colwise = colwise self.all_reduce = all_reduce @@ -921,7 +953,7 @@ def forward(self, x): return self.impl.forward(x, self.weight, self.bias, all_reduce) out = self.impl.forward(x, self.weight, self.bias, False) - for lora_adapter in self.lora_adapters: + for lora_adapter in self.lora_adapters.values(): out = lora_adapter(x, out) if all_reduce: dist.all_reduce(out) diff --git a/lmdeploy/pytorch/paging/block_manager/__init__.py b/lmdeploy/pytorch/paging/block_manager/__init__.py index becc1f2836..82b4a9bace 100644 --- a/lmdeploy/pytorch/paging/block_manager/__init__.py +++ b/lmdeploy/pytorch/paging/block_manager/__init__.py @@ -1,14 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any - from ...config import CacheConfig from .base_block_manager import BaseBlockManager from .default_block_manager import DefaultBlockManager from .window_block_manager import WindowBlockManager -def build_block_manager(cache_config: CacheConfig, - adapter_manager: Any = None) -> BaseBlockManager: +def build_block_manager(cache_config: CacheConfig) -> BaseBlockManager: """build block manager. Args: @@ -20,10 +17,8 @@ def build_block_manager(cache_config: CacheConfig, window_size = cache_config.window_size if window_size < 0: - return DefaultBlockManager(num_gpu_blocks, num_cpu_blocks, - adapter_manager) + return DefaultBlockManager(num_gpu_blocks, num_cpu_blocks) else: return WindowBlockManager(num_gpu_blocks, num_cpu_blocks, - window_size=window_size, - adapter_manager=adapter_manager) + window_size=window_size) diff --git a/lmdeploy/pytorch/paging/block_manager/base_block_manager.py b/lmdeploy/pytorch/paging/block_manager/base_block_manager.py index 0630c7a312..ef6709624b 100644 --- a/lmdeploy/pytorch/paging/block_manager/base_block_manager.py +++ b/lmdeploy/pytorch/paging/block_manager/base_block_manager.py @@ -1,10 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. import time -from typing import Dict, Union +from typing import Dict import numpy as np -from ...adapter.adapter import AdapterManager, SchedulerAdapter from ...messages import SchedulerSequence @@ -236,10 +235,7 @@ class BaseBlockManager: num_cpu_blocks (int): number of cpu blocks. 
""" - def __init__(self, - num_gpu_blocks: int, - num_cpu_blocks: int, - adapter_manager: AdapterManager = None) -> None: + def __init__(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: self.num_gpu_blocks = num_gpu_blocks self.num_cpu_blocks = num_cpu_blocks @@ -247,13 +243,9 @@ def __init__(self, self.block_tables: Dict[int, BlockTable] = {} - if adapter_manager is None: - adapter_manager = AdapterManager(dict(), dict(), 0) - self.adapter_manager = adapter_manager - @classmethod def num_required_blocks(cls, - obj: Union[SchedulerSequence, SchedulerAdapter], + obj: SchedulerSequence, prealloc_size: int = 0): """get num required blocks.""" raise NotImplementedError('Not implemented.') @@ -272,23 +264,19 @@ def allocate_msg(self, msg: SchedulerSequence, prealloc_size: int = 0): blocks.""" raise NotImplementedError('Not implemented.') - def allocate_adapter(self, adapter: SchedulerAdapter): - """Allocate cpu blocks for given adapter.""" - raise NotImplementedError('Not implemented.') - def free(self, msg: SchedulerSequence): """Free all physical blocks allocated for the session.""" raise NotImplementedError('Not implemented.') - def try_swap_out(self, msg: Union[SchedulerSequence, SchedulerAdapter]): + def try_swap_out(self, msg: SchedulerSequence): """Try swap msg out.""" raise NotImplementedError('Not implemented.') - def try_swap_in(self, msg: Union[SchedulerSequence, SchedulerAdapter]): + def try_swap_in(self, msg: SchedulerSequence): """Try swap msg in.""" raise NotImplementedError('Not implemented.') - def get_block_table(self, msg: Union[SchedulerSequence, SchedulerAdapter]): + def get_block_table(self, msg: SchedulerSequence): """Get the block table of given msg. Args: @@ -298,14 +286,10 @@ def get_block_table(self, msg: Union[SchedulerSequence, SchedulerAdapter]): return self.allocator.get_physical_blocks( logical_blocks.get_real_blocks()) - def allocate(self, - data: Union[SchedulerSequence, SchedulerAdapter], - prealloc_size: int = 0): + def allocate(self, data: SchedulerSequence, prealloc_size: int = 0): """allocate stuff.""" if isinstance(data, SchedulerSequence): return self.allocate_msg(data, prealloc_size) - elif isinstance(data, SchedulerAdapter): - return self.allocate_adapter(data) else: raise TypeError(f'Unsupported allocate type: {type(data)}') diff --git a/lmdeploy/pytorch/paging/block_manager/default_block_manager.py b/lmdeploy/pytorch/paging/block_manager/default_block_manager.py index 49c6786762..9a5ff0136d 100644 --- a/lmdeploy/pytorch/paging/block_manager/default_block_manager.py +++ b/lmdeploy/pytorch/paging/block_manager/default_block_manager.py @@ -1,10 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. 
# modify from: https://github.com/vllm-project/vllm -from typing import Union - import numpy as np -from ...adapter.adapter import SchedulerAdapter from ...messages import SchedulerSequence from .base_block_manager import BaseBlockManager @@ -27,18 +24,12 @@ class DefaultBlockManager(BaseBlockManager): @classmethod def num_required_blocks(cls, - obj: Union[SchedulerSequence, SchedulerAdapter], + obj: SchedulerSequence, prealloc_size: int = 0): """get num required blocks.""" - if isinstance(obj, SchedulerSequence): - num_tokens = obj.num_all_tokens() + prealloc_size - num_all_blocks = _div_up(num_tokens, obj.block_size) - return max(0, num_all_blocks - len(obj.logical_blocks)) - else: - if obj.is_actived(): - return 0 - else: - return obj.num_required_blocks + num_tokens = obj.num_all_tokens() + prealloc_size + num_all_blocks = _div_up(num_tokens, obj.block_size) + return max(0, num_all_blocks - len(obj.logical_blocks)) @classmethod def last_block_size(cls, seq: SchedulerSequence) -> int: @@ -54,9 +45,6 @@ def can_allocate(self, msg: SchedulerSequence, prealloc_size: int = 0): """Return if physical block can be allocated for given message.""" num_required_blocks = self.num_required_blocks(msg, prealloc_size) num_free_phy = self.get_num_free_gpu_blocks() - if msg.adapter_name is not None: - adapter = self.adapter_manager.get_adapter(msg.adapter_name) - num_required_blocks += self.num_required_blocks(adapter) return num_required_blocks <= num_free_phy def allocate_msg(self, msg: SchedulerSequence, prealloc_size: int = 0): @@ -68,19 +56,12 @@ def allocate_msg(self, msg: SchedulerSequence, prealloc_size: int = 0): blocks = self.allocator.allocate(num_required_blocks, 'gpu') logical_blocks.append(blocks) - def allocate_adapter(self, adapter: SchedulerAdapter): - """Allocate cpu blocks for given adapter.""" - num_required_blocks = self.num_required_blocks(adapter) - if num_required_blocks > 0: - blocks = self.allocator.allocate(num_required_blocks, 'cpu') - adapter.logical_blocks.append(blocks) - def free(self, msg: SchedulerSequence): """Free all physical blocks allocated for the session.""" self.allocator.free(msg.logical_blocks.get_real_blocks()) msg.logical_blocks.reset() - def try_swap_out(self, msg: Union[SchedulerSequence, SchedulerAdapter]): + def try_swap_out(self, msg: SchedulerSequence): """Try swap msg out.""" swap_map = dict() logical_blocks = msg.logical_blocks @@ -120,8 +101,6 @@ def _do_swap(): gpu_allocator.free(old_blocks) self.allocator.update_phy_map(logical_blocks.get_real_blocks(), new_blocks) - if isinstance(msg, SchedulerAdapter): - msg.active(False) return True, swap_map if not _can_swap(): @@ -129,7 +108,7 @@ def _do_swap(): else: return _do_swap() - def try_swap_in(self, msg: Union[SchedulerSequence, SchedulerAdapter]): + def try_swap_in(self, msg: SchedulerSequence): """Try swap msg in.""" swap_map = dict() logical_blocks = msg.logical_blocks @@ -169,8 +148,6 @@ def _do_swap(): cpu_allocator.free(old_blocks) self.allocator.update_phy_map(logical_blocks.get_real_blocks(), new_blocks) - if isinstance(msg, SchedulerAdapter): - msg.active(True) return True, swap_map if not _can_swap(): diff --git a/lmdeploy/pytorch/paging/block_manager/window_block_manager.py b/lmdeploy/pytorch/paging/block_manager/window_block_manager.py index 96039f9296..7b5a7c5285 100644 --- a/lmdeploy/pytorch/paging/block_manager/window_block_manager.py +++ b/lmdeploy/pytorch/paging/block_manager/window_block_manager.py @@ -1,9 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-from typing import Union - import numpy as np -from ...adapter.adapter import AdapterManager, SchedulerAdapter from ...block import LogicalTokenBlocks from ...messages import SchedulerSequence from .default_block_manager import DefaultBlockManager @@ -44,19 +41,16 @@ class WindowBlockManager(DefaultBlockManager): num_cpu_blocks (int): number of cpu blocks. """ - def __init__(self, - num_gpu_blocks: int, - num_cpu_blocks: int, - window_size: int, - adapter_manager: AdapterManager = None): - super().__init__(num_gpu_blocks, num_cpu_blocks, adapter_manager) + def __init__(self, num_gpu_blocks: int, num_cpu_blocks: int, + window_size: int): + super().__init__(num_gpu_blocks, num_cpu_blocks) assert window_size > 0, ('expect window size > 0, ' f'but get window_size = {window_size}') self.window_size = window_size @classmethod def num_required_blocks(cls, - obj: Union[SchedulerSequence, SchedulerAdapter], + obj: SchedulerSequence, prealloc_size: int = 0): """get num required blocks.""" @@ -71,17 +65,7 @@ def __num_req_seq(seq: SchedulerSequence): num_req_tokens = max(0, num_input_tokens - lb_remain_tokens) return _div_up(num_req_tokens, block_size) - def __num_req_adapter(adapter: SchedulerAdapter): - """get num required adapter blocks.""" - if adapter.is_actived(): - return 0 - else: - return obj.num_required_blocks - - if isinstance(obj, SchedulerSequence): - return __num_req_seq(obj) - else: - return __num_req_adapter(obj) + return __num_req_seq(obj) @classmethod def last_block_size(cls, seq: SchedulerSequence) -> int: @@ -96,9 +80,6 @@ def can_allocate(self, msg: SchedulerSequence, prealloc_size: int = 0): num_drop_blocks = _num_blocks_to_drop(msg, self.window_size) num_required_blocks = self.num_required_blocks(msg, prealloc_size) num_free_phy = self.get_num_free_gpu_blocks() - if msg.adapter_name is not None: - adapter = self.adapter_manager.get_adapter(msg.adapter_name) - num_required_blocks += self.num_required_blocks(adapter) return num_required_blocks <= num_free_phy + num_drop_blocks def allocate_msg(self, msg: SchedulerSequence, prealloc_size: int = 0): diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py index 79b4833d8f..8879863092 100644 --- a/lmdeploy/pytorch/paging/scheduler.py +++ b/lmdeploy/pytorch/paging/scheduler.py @@ -2,11 +2,10 @@ # modify from: https://github.com/vllm-project/vllm from collections import OrderedDict from dataclasses import dataclass -from typing import Dict, List, Set, Union +from typing import Dict, List from lmdeploy.utils import get_logger, logging_timer -from ..adapter.adapter import AdapterManager, SchedulerAdapter from ..config import CacheConfig, SchedulerConfig from ..messages import (MessageStatus, SchedulerSequence, SchedulerSession, SequenceManager) @@ -16,7 +15,6 @@ logger = get_logger('lmdeploy') SeqList = List[SchedulerSequence] -AdapterList = List[SchedulerAdapter] @dataclass @@ -27,7 +25,6 @@ class SchedulerOutput: swap_in_map: Dict[int, int] swap_out_map: Dict[int, int] copy_map: Dict[int, int] - adapters: AdapterList class Scheduler: @@ -38,21 +35,14 @@ class Scheduler: cache_config (CacheConfig): The config of cache info. 
""" - def __init__(self, - scheduler_config: SchedulerConfig, - cache_config: CacheConfig, - adapter_manager: AdapterManager = None) -> None: + def __init__(self, scheduler_config: SchedulerConfig, + cache_config: CacheConfig) -> None: self.scheduler_config = scheduler_config self.cache_config = cache_config self.sessions: Dict[int, SchedulerSession] = OrderedDict() - self.actived_adapters: Set[str] = set() - if adapter_manager is None: - adapter_manager = AdapterManager(dict(), dict(), 0) - self.adapter_manager = adapter_manager - - self.block_manager = build_block_manager(cache_config, adapter_manager) + self.block_manager = build_block_manager(cache_config) self.block_trie = BlockTrie(self.cache_config, self.block_manager) self.eviction_helper = self.build_eviction_helper( @@ -124,19 +114,6 @@ def add_sequence(self, seq: SchedulerSequence): # push message to waiting queue self._set_message_status(seq, MessageStatus.WAITING) - def add_adapter(self, adapter_name: str): - """Add adapter. - - Args: - adapter_name (str): The name of the adapter. - """ - adapter = self.adapter_manager.add_adapter(adapter_name) - self.block_manager.allocate_adapter(adapter) - block_table = self.block_manager.get_block_table( - adapter) - self.block_manager.num_gpu_blocks - adapter.update_rank_offset(block_table) - return adapter.build_weight_map() - @logging_timer('SchedulePrefilling', logger) def _schedule_prefill(self): """Schedule for prefilling.""" @@ -148,9 +125,6 @@ def _schedule_prefill(self): swap_in_map: Dict[int, int] = dict() copy_map: Dict[int, int] = dict() running: SeqList = [] - required_adapters = set(seq.adapter_name for seq in current_running) - max_adapters = self.scheduler_config.max_active_adapters - len( - required_adapters) token_count = 0 def _to_running(seq: SchedulerSequence): @@ -172,26 +146,6 @@ def _reorder_waiting(): """reorder waiting.""" return sorted(self.waiting, key=lambda seq: seq.arrive_time) - def _active_adapter(adapter_name): - """active adapter of a seq.""" - if adapter_name not in required_adapters: - adapter = self.adapter_manager.get_adapter(adapter_name) - if not adapter.is_actived(): - _, tmp_map = self.block_manager.try_swap_in(adapter) - swap_in_map.update(tmp_map) - block_table = self.block_manager.get_block_table(adapter) - adapter.update_rank_offset(block_table) - adapter.active(True) - required_adapters.add(adapter_name) - - def _deactive_adapter(adapter_name): - """deactive_adapter.""" - adapter = self.adapter_manager.get_adapter(adapter_name) - if adapter.is_actived(): - _, tmp_map = self.block_manager.try_swap_out(adapter) - swap_out_map.update(tmp_map) - adapter.active(False) - num_waiting = self.seq_manager.num_sequences(MessageStatus.WAITING) if (len(running) >= max_batches or num_waiting == 0): return running, swap_in_map, swap_out_map, copy_map @@ -204,11 +158,6 @@ def _deactive_adapter(adapter_name): self.cache_config.max_prefill_token_num): break - # limit number of adapters - if len(required_adapters) >= max_adapters: - if seq.adapter_name not in required_adapters: - break - self.block_trie.match(seq) if not __evict_for_seq(seq, waiting): @@ -216,15 +165,8 @@ def _deactive_adapter(adapter_name): # allocate session memory self.block_manager.allocate(seq) - _active_adapter(seq.adapter_name) _to_running(seq) - deactive_adapters = self.actived_adapters.difference(required_adapters) - for adapter_name in deactive_adapters: - _deactive_adapter(adapter_name) - - self.actived_adapters = required_adapters - return running, swap_in_map, swap_out_map, 
copy_map @logging_timer('ScheduleDecoding', logger) @@ -270,12 +212,6 @@ def __evict_for_seq(seq: SchedulerSequence): return self.running, swap_in_map, swap_out_map, copy_map - def _get_adapter_list(self, adapter_names: List[str]): - adapters = [ - self.adapter_manager.get_adapter(name) for name in adapter_names - ] - return adapters - def schedule(self, is_prefill: bool, prealloc_size: int = 0): """Schedule inputs for next steps.""" if is_prefill: @@ -284,13 +220,10 @@ def schedule(self, is_prefill: bool, prealloc_size: int = 0): output = self._schedule_decoding(prealloc_size) running, swap_in_map, swap_out_map, copy_map = output - adapters = self._get_adapter_list(self.actived_adapters) - return SchedulerOutput(running=running, swap_in_map=swap_in_map, swap_out_map=swap_out_map, - copy_map=copy_map, - adapters=adapters) + copy_map=copy_map) def _set_session_status(self, session_id: int, status: MessageStatus): """Setup the status of session. @@ -345,6 +278,6 @@ def has_running(self): def has_waiting(self): return self.seq_manager.num_sequences(MessageStatus.WAITING) > 0 - def get_block_tables(self, seqs: Union[SeqList, AdapterList]): + def get_block_tables(self, seqs: SeqList): """get block table of the sequences.""" return [self.block_manager.get_block_table(seq) for seq in seqs] diff --git a/tests/pytorch/kernel/test_fused_lora.py b/tests/pytorch/kernel/test_fused_lora.py new file mode 100644 index 0000000000..78ea302937 --- /dev/null +++ b/tests/pytorch/kernel/test_fused_lora.py @@ -0,0 +1,108 @@ +import pytest +import torch +from torch.nn.utils.rnn import pad_sequence + +from lmdeploy.pytorch.kernels.cuda.fused_lora import fused_lora + + +class TestFusedLoRA: + + @pytest.fixture + def dtype(self): + yield torch.float16 + + @pytest.fixture + def head_size(self): + yield 32 + + @pytest.fixture + def out_head_size(self): + yield 16 + + @pytest.fixture + def seq_lens(self): + yield torch.tensor([2, 4, 6, 8]).cuda() + + @pytest.fixture + def ranks(self): + yield torch.tensor([2, 4]).cuda() + + @pytest.fixture + def start_loc(self, seq_lens): + yield seq_lens.cumsum(0) - seq_lens + + @pytest.fixture + def input(self, seq_lens, head_size, dtype): + total_len = seq_lens.sum() + yield torch.rand(total_len, head_size, dtype=dtype).cuda() + + @pytest.fixture + def adapter_ids(self, seq_lens, ranks): + num_ranks = len(ranks) + num_seqs = len(seq_lens) + ret = torch.arange(0, num_seqs) % num_ranks + ret = ret.cuda() + yield ret + + @pytest.fixture + def scaling(self, ranks): + yield torch.arange(ranks.size(0)).cuda() + 1 + + @pytest.fixture + def lora_a(self, ranks, head_size, dtype): + out = [] + for rank in ranks: + w = torch.rand(head_size, rank, dtype=dtype).cuda() + out.append(w) + yield out + + @pytest.fixture + def lora_b(self, ranks, out_head_size, dtype): + out = [] + for rank in ranks: + w = torch.rand(rank, out_head_size, dtype=dtype).cuda() + out.append(w) + yield out + + @pytest.fixture + def fused_lora_a(self, lora_a): + yield torch.cat(lora_a, dim=1).t().contiguous() + + @pytest.fixture + def fused_lora_b(self, lora_b): + yield torch.cat(lora_b, dim=0).contiguous() + + @pytest.fixture + def gt(self, input, start_loc, seq_lens, adapter_ids, lora_a, lora_b, + scaling): + out = [] + for loc, s_len, r_id in zip(start_loc, seq_lens, adapter_ids): + inp = input[loc:loc + s_len] + l_a = lora_a[r_id] + l_b = lora_b[r_id] + s = scaling[r_id] + out.append(inp @ l_a @ l_b * s) + + yield torch.cat(out) + + def test_fused_lora(self, input, fused_lora_a, fused_lora_b, + start_loc, seq_lens, 
adapter_ids, scaling, + ranks, gt): + max_seq_len = max(seq_lens).item() + max_rank = max(ranks).item() + rank_offset = ranks.cumsum(0) - ranks + + + output = fused_lora( + input, fused_lora_a, fused_lora_b, + scaling=scaling, + rank_start=rank_offset, + ranks=ranks, + seq_start=start_loc, + seq_lens=seq_lens, + adapter_ids=adapter_ids, + max_rank=max_rank, + max_seqlen=max_seq_len, + ) + + torch.testing.assert_close(gt, output) diff --git a/tests/pytorch/kernel/test_mbgmm.py b/tests/pytorch/kernel/test_mbgmm.py deleted file mode 100644 index e68ef18576..0000000000 --- a/tests/pytorch/kernel/test_mbgmm.py +++ /dev/null @@ -1,134 +0,0 @@ -import pytest -import torch -from torch.nn.utils.rnn import pad_sequence - -from lmdeploy.pytorch.kernels.mbgmm import mbgmm_a, mbgmm_b - - -class TestMBGMM: - - @pytest.fixture - def dtype(self): - yield torch.float16 - - @pytest.fixture - def head_size(self): - yield 32 - - @pytest.fixture - def out_head_size(self): - yield 16 - - @pytest.fixture - def seq_lens(self): - yield torch.tensor([2, 4, 6, 8]).cuda() - - @pytest.fixture - def ranks(self): - yield torch.tensor([2, 4]).cuda() - - @pytest.fixture - def start_loc(self, seq_lens): - yield seq_lens.cumsum(0) - seq_lens - - @pytest.fixture - def input(self, seq_lens, head_size, dtype): - total_len = seq_lens.sum() - yield torch.rand(total_len, head_size, dtype=dtype).cuda() - - @pytest.fixture - def adapter_ids(self, seq_lens, ranks): - num_ranks = len(ranks) - num_seqs = len(seq_lens) - ret = torch.randint(0, num_ranks, (num_seqs, )).cuda() - yield ret - - @pytest.fixture - def scaling(self, ranks): - yield torch.arange(ranks.size(0)).cuda() + 1 - - @pytest.fixture - def lora_a(self, ranks, head_size, dtype): - out = [] - for rank in ranks: - w = torch.rand(head_size, rank, dtype=dtype).cuda() - out.append(w) - yield out - - @pytest.fixture - def lora_b(self, ranks, out_head_size, dtype): - out = [] - for rank in ranks: - w = torch.rand(rank, out_head_size, dtype=dtype).cuda() - out.append(w) - yield out - - @pytest.fixture - def page_table(self, ranks): - total_ranks = sum(ranks) - index = torch.randperm(total_ranks) - index = index.split(ranks.tolist()) - yield pad_sequence(index, batch_first=True).cuda() - - @pytest.fixture - def rank_offset(self, page_table, head_size): - yield page_table * head_size - - @pytest.fixture - def paged_lora_a(self, lora_a, ranks, page_table, head_size, dtype): - num_pages = sum(ranks) - cache = torch.empty(num_pages, head_size, dtype=dtype).cuda() - for index, r, w in zip(page_table, ranks, lora_a): - cache[index[:r]] = w.t() - yield cache - - @pytest.fixture - def paged_lora_b(self, lora_b, ranks, page_table, head_size, out_head_size, - dtype): - num_pages = sum(ranks) - cache = torch.empty(num_pages, head_size, dtype=dtype).cuda() - for index, r, w in zip(page_table, ranks, lora_b): - cache[index[:r], :out_head_size] = w - yield cache - - @pytest.fixture - def gt(self, input, start_loc, seq_lens, adapter_ids, lora_a, lora_b, - scaling): - out = [] - for loc, s_len, r_id in zip(start_loc, seq_lens, adapter_ids): - inp = input[loc:loc + s_len] - l_a = lora_a[r_id] - l_b = lora_b[r_id] - s = scaling[r_id] - out.append(inp @ l_a @ l_b * s) - - yield torch.cat(out) - - def test_mbgmm(self, input, paged_lora_a, paged_lora_b, out_head_size, - start_loc, seq_lens, adapter_ids, scaling, rank_offset, - ranks, gt): - max_seq_len = max(seq_lens).item() - max_rank = rank_offset.size(-1) - - xa = mbgmm_a(input, - paged_lora_a, - q_start_loc=start_loc, - q_seqlens=seq_lens, - 
adapter_ids=adapter_ids, - rank_offset=rank_offset, - ranks=ranks, - max_seq_len=max_seq_len, - max_rank=max_rank) - - output = mbgmm_b(xa, - paged_lora_b[..., :out_head_size], - q_start_loc=start_loc, - q_seqlens=seq_lens, - adapter_ids=adapter_ids, - scaling=scaling, - rank_offset=rank_offset, - ranks=ranks, - max_seq_len=max_seq_len, - max_rank=max_rank) - - torch.testing.assert_close(gt, output) diff --git a/tests/pytorch/kernel/test_mbgmv.py b/tests/pytorch/kernel/test_mbgmv.py deleted file mode 100644 index 8751b2a891..0000000000 --- a/tests/pytorch/kernel/test_mbgmv.py +++ /dev/null @@ -1,122 +0,0 @@ -import pytest -import torch -from torch.nn.utils.rnn import pad_sequence - -from lmdeploy.pytorch.kernels.mbgmv import mbgmv_a, mbgmv_b - - -class TestMBGMV: - - @pytest.fixture - def dtype(self): - yield torch.float16 - - @pytest.fixture - def head_size(self): - yield 64 - - @pytest.fixture - def out_head_size(self): - yield 32 - - @pytest.fixture - def batch_size(self): - yield 8 - - @pytest.fixture - def ranks(self): - yield torch.tensor([2, 4]).cuda() - - @pytest.fixture - def input(self, batch_size, head_size, dtype): - x = torch.rand(batch_size, head_size, dtype=dtype).cuda() - x -= 0.5 - yield x - - @pytest.fixture - def adapter_ids(self, batch_size, ranks): - num_ranks = len(ranks) - ret = torch.randint(0, num_ranks, (batch_size, )).cuda() - yield ret - - @pytest.fixture - def scaling(self, ranks): - yield torch.arange(ranks.size(0)).cuda() + 1 - - @pytest.fixture - def lora_a(self, ranks, head_size, dtype): - out = [] - for rank in ranks: - w = torch.rand(head_size, rank, dtype=dtype).cuda() - w -= 0.5 - out.append(w) - yield out - - @pytest.fixture - def lora_b(self, ranks, out_head_size, dtype): - out = [] - for rank in ranks: - w = torch.rand(rank, out_head_size, dtype=dtype).cuda() - w -= 0.5 - out.append(w) - yield out - - @pytest.fixture - def page_table(self, ranks): - total_ranks = sum(ranks) - index = torch.randperm(total_ranks) - index = index.split(ranks.tolist()) - yield pad_sequence(index, batch_first=True).cuda() - - @pytest.fixture - def rank_offset(self, page_table, head_size): - yield page_table * head_size - - @pytest.fixture - def paged_lora_a(self, lora_a, ranks, page_table, head_size, dtype): - num_pages = sum(ranks) - cache = torch.empty(num_pages, head_size, dtype=dtype).cuda() - for index, r, w in zip(page_table, ranks, lora_a): - cache[index[:r]] = w.t() - yield cache - - @pytest.fixture - def paged_lora_b(self, lora_b, ranks, page_table, head_size, out_head_size, - dtype): - num_pages = sum(ranks) - cache = torch.empty(num_pages, head_size, dtype=dtype).cuda() - for index, r, w in zip(page_table, ranks, lora_b): - cache[index[:r], :out_head_size] = w - yield cache - - @pytest.fixture - def gt(self, input, adapter_ids, lora_a, lora_b, scaling): - out = [] - for inp, r_id in zip(input, adapter_ids): - inp = inp.unsqueeze(0) - l_a = lora_a[r_id] - l_b = lora_b[r_id] - s = scaling[r_id] - out.append(inp @ l_a @ l_b * s) - - yield torch.cat(out) - - def test_mbgmv(self, input, paged_lora_a, paged_lora_b, out_head_size, - adapter_ids, scaling, rank_offset, ranks, gt): - max_rank = rank_offset.size(-1) - - xa = mbgmv_a(input, - paged_lora_a, - adapter_ids=adapter_ids, - rank_offset=rank_offset, - ranks=ranks, - max_rank=max_rank) - - output = mbgmv_b(xa, - paged_lora_b[..., :out_head_size], - adapter_ids=adapter_ids, - scaling=scaling, - rank_offset=rank_offset, - ranks=ranks, - max_rank=max_rank) - torch.testing.assert_close(gt, output, atol=4e-3, 
rtol=1e-5) diff --git a/tests/pytorch/kernel/test_rearange_all_gather.py b/tests/pytorch/kernel/test_rearange_all_gather.py deleted file mode 100644 index 643c425fca..0000000000 --- a/tests/pytorch/kernel/test_rearange_all_gather.py +++ /dev/null @@ -1,83 +0,0 @@ -import pytest -import torch - -from lmdeploy.pytorch.kernels.rearange_all_gather import rearange_all_gather - - -class TestRearangeAllGather: - - @pytest.fixture - def seq_lens(self, request): - yield torch.tensor(request.param, device='cuda') - - @pytest.fixture - def start_loc(self, seq_lens): - yield seq_lens.cumsum(0) - seq_lens - - @pytest.fixture - def ranks(self): - yield torch.tensor([4, 8]).cuda() - - @pytest.fixture - def adapter_ids(self, seq_lens, ranks): - num_ranks = len(ranks) - num_seqs = len(seq_lens) - ret = torch.randint(0, num_ranks, (num_seqs, )).cuda() - yield ret - - @pytest.fixture - def world_size(self): - yield 2 - - @pytest.fixture - def input(self, seq_lens, ranks): - max_rank = max(ranks) - total_len = seq_lens.sum() - yield torch.rand(total_len, max_rank).cuda() - - @pytest.fixture - def rank_per_input(self, seq_lens, ranks, adapter_ids): - token_adapter_ids = [ - torch.full((slen, ), ada_id) - for slen, ada_id in zip(seq_lens, adapter_ids) - ] - token_adapter_ids = torch.cat(token_adapter_ids).cuda() - yield ranks[token_adapter_ids] - - @pytest.fixture - def valid_mask(self, rank_per_input, seq_lens, ranks): - max_rank = max(ranks) - total_len = seq_lens.sum() - mask = torch.zeros(total_len, max_rank).to(bool) - for r, m in zip(rank_per_input, mask): - m[:r] = True - yield mask.cuda() - - @pytest.fixture - def gt(self, input, rank_per_input, ranks, world_size): - max_rank = max(ranks) - pranks = rank_per_input // world_size - pmax_rank = max_rank // world_size - output = torch.empty_like(input) - for pr, inp, out in zip(pranks, input, output): - pindex = torch.arange(pr).cuda() - index = [pindex + ws * pmax_rank for ws in range(world_size)] - index = torch.cat(index) - out[:index.size(0)] = inp[index] - yield output - - @pytest.mark.parametrize('seq_lens', [[30, 50, 70, 90], [1, 1, 1, 1]], - indirect=True) - def test_gather(self, input, start_loc, seq_lens, adapter_ids, ranks, - world_size, gt, valid_mask): - max_seq_len = max(seq_lens) - output = rearange_all_gather(input, - start_loc, - seq_lens, - adapter_ids, - ranks, - world_size, - max_seq_len=max_seq_len) - output = output.where(valid_mask, output.new_tensor(0)) - gt = gt.where(valid_mask, gt.new_tensor(0)) - torch.testing.assert_close(output, gt)
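
A rough eager-PyTorch sketch of what the new fused_lora kernel computes, mirroring the gt fixture in tests/pytorch/kernel/test_fused_lora.py: the packed layout of fused_lora_a/fused_lora_b (rank-major concatenation of each adapter's A^T and B, indexed by rank_start/ranks) follows the test fixtures above, and fused_lora_reference is a hypothetical helper named here for illustration only, not an API exported by lmdeploy.

import torch


def fused_lora_reference(x, lora_a, lora_b, *, scaling, rank_start, ranks,
                         seq_start, seq_lens, adapter_ids):
    """Per-sequence out = x @ A @ B * scaling over packed LoRA weights.

    lora_a: (sum(ranks), in_features); rows rank_start[i]:rank_start[i]+ranks[i]
            hold adapter i's A matrix transposed.
    lora_b: (sum(ranks), out_features); same row layout for adapter i's B.
    """
    out_features = lora_b.size(-1)
    output = x.new_empty((x.size(0), out_features))
    for loc, s_len, adapter_id in zip(seq_start, seq_lens, adapter_ids):
        r = ranks[adapter_id]
        r_start = rank_start[adapter_id]
        a = lora_a[r_start:r_start + r]      # (r, in_features)
        b = lora_b[r_start:r_start + r]      # (r, out_features)
        xs = x[loc:loc + s_len]              # (s_len, in_features)
        output[loc:loc + s_len] = (xs @ a.t() @ b) * scaling[adapter_id]
    return output

With the fixtures from the test (rank_start = ranks.cumsum(0) - ranks), fused_lora_reference(input, fused_lora_a, fused_lora_b, scaling=scaling, rank_start=rank_start, ranks=ranks, seq_start=start_loc, seq_lens=seq_lens, adapter_ids=adapter_ids) should reproduce gt, and hence the Triton fused_lora output, up to fp16 tolerance.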