S-LoRA support (#894)
* WIP

* cache engine wip

* finish cache engine

* fix cache and scheduler

* add paged attention

* step and stop

* add infer

* add request process

* fix end

* request without schedulersession

* add logits processor

* better context

* update patch

* [Improve] Use 4d input in pytorch poc (#371)

* 4D input, model.eval and llama config

* use auto dtype

* tp wip

* almost

* update logger

* run_check=false

* little optimize

current best

redist w/o dtensor

host mem in que

less rewrite

less code

update model weight

* share attention forward

* fix end

* Support Baichuan (#382)

* add baichuan WIP

* support baichuan

* support baichuan-13b

* fix

* add chat template

* lint

* comments

* fix

* Move `q_seq_info` into `context` (#398)

* move q seq info into context

* remove debugs

* remove debugs

* alibi wip

* add alibi

* reduce logic block (#435)

* add docstring

* add baichuan lint (#445)

* add fill cache back

* support internlm

* fix path of weight index

* Support chatglm2 in pytorch_poc (#360)

* draft support for chatglm2

* debug llama

* gitignore

* update input_id

* better patching

* patch chatglm2 model

* fix after merge

* remove inits

* q_seq_info & remove some debug & orig_self

* remove old unsqueeze input_id

* update patch and model config

* remove debugs and clean codes

* clean codes

* add credit

* add update id / fix dependency

* rename modules (#504)

Co-authored-by: grimoire <[email protected]>

* optimize fill kv cache (#523)

* optimize fill kv cache

* update internlm

* faster embedding

* fix bias tp

* fix baichuan2

* fix fill kv cache

* fix lint

---------

* Make trust_remote_code as cli argument (#434)

* trust_remote_code_argument

* format

* update tokenizer

* optimize rotary

* wtf

* Support Falcon models (#406)

* move q seq info into context

* falcon aligned

* trust_remote_code_argument

* fix for falcon

* comment out debugs

* comment out debugs

* use position id in context

* remove codes in falcon model

* Revert "comment out debugs"

This reverts commit ee26a25.

* 7b correct

* 1b aligned

* remove debugs

* patch to ignore position ids

* remove debug in alibi, avoid empty inputs

* fix

* rename dir to replace to "models"

* use position_id and new fill kernel

* remove useless get_prompt func

* fix batch>2

* Refactor scheduler (#551)

* optimize block manager

* scheduler wip

* finish scheduler

* update engine

* profile pytorch poc (#455)

* profile pytorch poc

* update doc and import if need

* arg

* support profile_throughput.py

* reuse pytorch session

* end session

* Support Tensor parallel on Falcon models (#582)

* tp falcon 1b and 7b works

* remove debugs

* update copyright

* add some comments

* remove a debug

* support new hub models

* support 40b

* support 40b model config

* try

* recover

* fix remain len

* Apply rotary kernel (#572)

* apply rotary kernel

* format

* update rmsnorm

* update rms norm

* better unittest

* add docstring

---------

Co-authored-by: grimoire <[email protected]>

* fix(pytorch_poc): memory cal (#606)

* fix(pytorch_poc): memory cal

* Optimize attention (#597)

* add unittest

* add split k

* add docstring

* fast split k

* optimize load

* manually setup device and stream

* lint

---------

Co-authored-by: grimoire <[email protected]>

* feat(pytorch_poc): implement ReRoPE (#625)

* fix(pytorch_poc): memory cal

* style(pytorch_poc): lint

* style(.pre-commit-config.yaml): update

* style(pytorch_poc): remove useless

* feat(pytorch_poc): llama2 support rerope

* feat(pytorch_poc): fix long input generate

* feat(lmdeploy): add kernel

* feat(lmdeploy): update

* feat(lmdeploy): add rerope implementation

* fix(lmdeploy/pytorch_poc): apply rotary_emb

* fix(lmdeploy): update

* style(pytorch_poc): format

* style(lmdeploy): fix lint

* style(lmdeploy): typo

* style(pytorch_poc): format

* style(pytorch_poc): format

* fix(pytorch_poc): rms_norm add mask

* style(pytorch_poc/kernels): format rerope

* style(pytorch_poc): format rerope attn function description

* style(lmdeploy/pytorch_poc): format

* style(pytorch_poc): add code ref

* style(pytorch_poc): format rerope attn

* Refactor engine (#623)

* add agent

* optimize postprocess

* optimize decoding fill cache

* add docstring

* logit to cuda

* blocksize 128

* optimize pre/post process

* fix postprocess

* cpu pre/post process

* manually setup stream and device

* remove context

* update model agent

* update max session len

* remove tqdm

* update pre/post process

* inplace kernel

* avoid kv_len computation

* flash decoding with one cache

* remove comment

* add warning when no enough resources

* step if has unfinished

* add request manager

* better fill kv cache

* fix fill kv cache

* optimize prefill attention

* refactor

* refactoring...

* add custom output

* use cache

---------

Co-authored-by: grimoire <[email protected]>

* [Feature] w8a8 based on pytorch poc (#595)

* refactor smoothquant and support load w8a8 model by from_pretrained

* add w8a8 docs

* add w8a8 en docs

* add convert_to_qmodules function

---------

Co-authored-by: grimoire <[email protected]>

* feat(lmdeploy): add rerope quantization (#718)

* feat(lmdeploy): add rerope quantization

* feat(lmdeploy): fix review

* [Refactor & Doc] Improve w8a8 and add docstring (#768)

* WIP

* improve w8a8 and add doc string

* add docstring

* add docstring

* fix lint

* rename pytorch poc (#764)

* rename pytorch poc

* fix lint

* add docstring

* add docstring

* refactor patch

* add recompute eviction support

* recovery modeling

* add docstring

* Unified paging (#860)

* change 'model_format' to 'qwen' when 'model_name' starts with 'qwen' (#575)

* avoid split chinese characters during decoding (#566)

* add solar chat template (#576)

* robust incremental decode for leading space (#581)

* robust incremental decode for leading space

* speed up lookup as prefix_space_tokens is shorter than no_prefix_space_tokens

* add UT and fix qwen stuff

* update solar chat template (#587)

* Revert "[Docs] Simplify `build.md` (#370)" (#586)

This reverts commit 4b5c2bd.

* Fix crash and remove `sys_instruct` from `chat.py` and `client.py`(#591)

* fix crash

* update profile_generation.py

* format

* use self.bos_id

* remove sys_instruct

* bump version to v0.0.12 (#604)

* Add "build from docker" section (#602)

* add build from docker section

* update

* install python package

* update

* update

* update

* Add more user-friendly CLI  (#541)

* add

* import fire in main

* wrap to speed up fire cli

* update

* update docs

* update docs

* fix

* resolve comments

* resolve conflict and add test for cli

* support inference a batch of prompts (#467)

* support inference a batch of prompts

* docstring and assert

* bump version to v0.0.13 (#620)

* Improve api_server and webui usage (#544)

* make IPv6 compatible, safe run for coroutine interrupting

* instance_id -> session_id and fix api_client.py

* update doc

* remove useless faq

* safe ip mapping

* update app.py

* WIP completion

* completion

* update doc

* disable interactive mode for /v1/chat/completions

* docstring

* docstring

* refactor gradio

* update gradio

* update

* update doc

* rename

* session_id default -1

* missed two files

* add a APIClient

* add chat func for APIClient

* refine

* add concurrent function

* sequence_start, sequence_end --> interactive_mode

* update doc

* comments

* doc

* better text completion

* remove /v1/embeddings

* comments

* deprecate generate and use /v1/interactive/completions

* /v1/interactive/completion -> /v1/chat/interactive

* embeddings

* rename

* remove wrong arg description

* docstring

* fix

* update cli

* update doc

* strict session_len limit condition

* pass model args to api_server

* fix: gradio gr.Button.update deprecated after 4.0.0 (#637)

* add cli to list the supported model names (#639)

* update

* resolve comment

* Refactor model conversion (#296)

* split deploy.py

* fix get_cuda_tensor

* deploy qwen_awq

* fix lint

* add docstring

* fix

* support baichuan/baichuan-awq

* parameterizing size_per_head

* remove try/except

* limit input model_format

* add quant_path param

* remove old deploy.py

* fix path

* fix transformer layer range when load bins

* fix qwen init

* split & save log

* relative import

* update get_config

* WeightFileMgr -> Reader

* rename

* update

* fix init_layer_id

* rename llama.py -> meta_llama.py, hf.py -> llama.py

* reduce code

* update arg description

* fix meta llama

* manually cleanup meta model params

* [Enhance] internlm message to prompt (#499)

* update turbomind session_len with model.session_len (#634)

* [Fix] Qwen's quantization results are abnormal & Baichuan cannot be quantized (#605)

* fix awq

* adapt new qwen code

* adapt qwen 14b and baichuan2 7b

* add docstring

* add runtime error for qwen

* FIX: fix stop_session func bug (#578)

* FIX: fix stop_session func bug

* keep sequence_end = False

---------

Co-authored-by: honglei.yan <[email protected]>
Co-authored-by: AllentDan <[email protected]>

* Manage session id using random int for gradio local mode (#553)

* Use session id from gradio state

* use a new session id after reset

* rename session id like a state

* update comments

* reformat files

* init session id on block loaded

* use auto increased session id

* remove session id textbox

* apply to api_server and tritonserver

* update docstring

* add lock for safety

---------

Co-authored-by: AllentDan <[email protected]>

* fix benchmark serving computation mistake (#630)

* fix benchmark serving computation mistake

* fix timestamps computations

* remove speed up

* no mp

* mp seems faster?

* remove

* update

* remove

* fix

* update

* update print log

* typo

* print first token latency only when stream==True

* remove renew_session

* update AsyncEngine

* fix tokenizer_info when convert the model (#661)

* Add check env sub command (#654)

* add check env

* update issue template'

* remove some reqs from check env

* resolve comment

* fix Tokenizer load error when the path of the being-converted  model is not writable (#669)

* Add UltraCM and WizardLM chat templates (#599)

* add ultracm eval chat template

* add WizardLM chat template

* use ultrachat template instead of ultracm usecase

* bump version to v0.0.14 (#663)

* Add extra_requires to reduce dependencies (#580)

* update reqs

* update docs

* resolve comments

* upgrade pydantic

* fix rebase

* update doc

* update

* update

* update readme

* update

* add flash-attn

* TurboMind 2 (#590)

* refresh decoder attention kernel

* block-level kv cache

* `BlockManager` & `SequenceManager`

* update

* update

* update

* update

* rename

* GQA support

* fix context length

* GQA dispatch

* kv8

* tune

* async stream cb

* nvtx

* config parsing

* debug

* optimize output cost

* split-k decoding

* minor

* truncate `session_len` by available blocks

* minor

* license

* fix

* dispatch `cp.async`

* fix linking

* fix

* fix deadlock

* guard input length

* correct start offset

* fix prefill chunking

* fix `cache_block_seq_len` param passing

* fix `block_size` fmtstr

* fix output tokens

* fix batch resizing

* fix masking of finished sequences

* add debug util

* free unused block early

* add ntk scaling and logn scaling

* cmake flags

* fix typo

* w4a16 for sm75

* fix msvc build

* fix msvc build

* fix block verification

* fix msvc build

* use `std::shuffle`

* fix lint

* fix lint

* fix lint

* clear incoming buffer

* clear finished requests

* fix batch initialization

* fix typo

* fix typo

* fix comparison

* [Docs] Update Supported Matrix (#679)

* update supported matrix

* change the default shard size when saving quantized weights

* baichuan2 kv8

* update kv8 docs (#681)

* Fix init of batch state (#682)

* fix init of finished buf

* fix `finished_count`

* fix turbomind stream canceling (#686)

* fix

* instance for each forward

* [Fix] Fix load_checkpoint_in_model bug (#690)

* fix load_checkpoint_in_model bug

* fix comments

* fix comments

* fix bugs

* [Doc] Update restful api doc (#662)

* update restful_api.md

* add a hint

* repeat 3 times

* Fix Tokenizer encode (#645)

* same encode with HF

* sequence_start -> add_bos

* complement

* Fix wrong eos_id and bos_id obtained through grpc api (#644)

* Fix wrong eos_id and bos_id obtained through grpc api

* fix according to review comments

* update

* Optimize for throughput (#701)

* tmp

* update

* update

* optimize for throughput

* update

* fix eos

* clean up

* fix serving

* fix indexed copy

* minor

* minor

---------

Co-authored-by: lvhan028 <[email protected]>

* Check-in user guide about turbomind config (#680)

* update

* update config guide

* update guide

* update user guide according to review comments

* Replace mmengine with mmengine-lite (#715)

* Support loading hf model directly (#685)

* turbomind support export model params

* fix overflow

* support turbomind.from_pretrained

* fix tp

* support AutoModel

* support load kv qparams

* update auto_awq

* update docstring

* export lmdeploy version

* update doc

* remove download_hf_repo

* LmdeployForCausalLM -> LmdeployForCausalLM

* refactor turbomind.py

* update comment

* add bfloat16 convert back

* support gradio run_local load hf

* support restful api server load hf

* add docs

* support loading previous quantized model

* adapt pr 690

* update docs

* not export turbomind config when quantize a model

* check model_name when can not get it from config.json

* update readme

* remove model_name in auto_awq

* update

* update

* update

* fix build

* absolute import

* Fix cache/output length calculation (#738)

* bump version to v0.1.0a0 (#709)

* [Fix] Skip empty batch (#747)

* [Fix] build docker image failed since `packaging` is missing (#753)

* [Fix] Rollback the data type of input_ids to TYPE_UINT32 in preprocessor's proto (#758)

* Set the default value of `max_context_token_num` 1 (#761)

* rename pytorch poc

* fix lint

* add docstring

* add docstring

* refactor patch

* add recompute eviction support

* fix typo (#769)

* add triton server test and workflow yml (#760)

* add triton server test and workflow yml

* update

* revert changes in dockerfile

* update prompts

* recovery modeling

* fix turbomind build on sm<80 (#754)

* fix

* fix lint

* improvement(build): enable ninja and gold linker (#767)

* feat(build): enable ninja and lld

* fix(.github): add ninja installation

* fix(CI): remove dimsize=256

* fix(CI): add option for generate.sh

* fix(docs): update

* Report first-token-latency and token-latency percentiles (#736)

* update profile scripts

* add top_p, top_k and temperature as input arguments

* fix input_ids

* update profile_throughput

* update profile_restful_api

* update profile_serving

* update

* update

* add progress bar

* remove TODO comments

* update

* remove useless profile_* argument

* remove log level

* change concurrency default value to 64

* update restful_api.md

* update according to review comments

* fix docstring

* convert model with hf repo_id (#774)

* bump version to 0.1.0a1 (#776)

* Update benchmark user guide (#763)

* user guide of benchmark generation

* update benchmark generation guide

* update profiling throughput guide

* update profiling api_server guide

* rename file names

* update profile tis user guide

* update

* fix according to review comments

* update

* update according to review comments

* update

* add an example

* update

* add docstring

* add unified paging attention support

* refactor block manager

* do not alloc zero

* Fix early exit condition in attention kernel (#788)

* add chat template for Yi (#779)

* Fix missed arguments when benchmark static inference performance (#787)

* minor fix in the profile scripts and docs

* miss arguments

* typo

* fix lint

* update

* Unify prefill & decode passes (#775)

* Unify prefill and decode passes

* dynamic split-fuse

* refactor

* correct input count calculation

* remove unused

* lint

* lint

* fix msvc build

* fix msvc build

* fix msvc build

* fix msvc build

* fix msvc build

* fix msvc build

* fix msvc build

* fix msvc build

* fix msvc build

* add cuda12.1 build check ci (#782)

* update cuda12.1 build check ci

* use matrix

* auto upload cuda12.1 python pkg to release when create new tag (#784)

* add cuda12-whl-release ci

* enable environment

* test py310-311 windows wheel

* fix py310, py311 setup.py error on windows

* fix lint

* fix extra colon in InternLMChat7B (#796)

* fix local kv head num (#806)

* Report the inference benchmark of models with different size (#794)

* update test scripts for models with different sizes

* update

* only test after tuning gemm

* chmod +x

* fix typo

* benchmark on a100

* fix typo

* fix typo

* per-token latency percentile in profile_throughput

* fix

* fix

* rename

* make the script accept parameters

* minor fix

* indent

* reformat table

* change to 3000

* minor fix

* bump version to v0.1.0a2 (#807)

* fix out of bounds access (#809)

* update scheduler

* optimize request

* Simplify block manager (#812)

* simplify block manager

* fix lint

* set smem size for repetition penalty kernel (#818)

* add mbgemm&mbgemv

* fix recompute, fix mbgmm

---------

Co-authored-by: Lyu Han <[email protected]>
Co-authored-by: AllentDan <[email protected]>
Co-authored-by: pppppM <[email protected]>
Co-authored-by: Chen Xin <[email protected]>
Co-authored-by: RunningLeon <[email protected]>
Co-authored-by: Yam(长琴) <[email protected]>
Co-authored-by: liukuikun <[email protected]>
Co-authored-by: yunzhongyan0 <[email protected]>
Co-authored-by: honglei.yan <[email protected]>
Co-authored-by: AllentDan <[email protected]>
Co-authored-by: aisensiy <[email protected]>
Co-authored-by: Li Zhang <[email protected]>
Co-authored-by: whcao <[email protected]>
Co-authored-by: Zaida Zhou <[email protected]>
Co-authored-by: tpoisonooo <[email protected]>
Co-authored-by: Qian Zhao <[email protected]>

* [Fix] Adapt to the pyTorch poc branch (#863)

* Adapt to the pyTorch poc branch

* Adapt to the pyTorch poc branch

* fix comments

* update model

* wip

* wrong implementation

* s-lora single gpu

* refactor tp patch

* add tp support

* add tp gather

* recover profile generation

* daemon process

* inplace gather

* hf style

* add assert when input nothing

* find available port

---------

Co-authored-by: grimoire <[email protected]>
Co-authored-by: WRH <[email protected]>
Co-authored-by: AllentDan <[email protected]>
Co-authored-by: AllentDan <[email protected]>
Co-authored-by: tpoisonooo <[email protected]>
Co-authored-by: whcao <[email protected]>
Co-authored-by: Lyu Han <[email protected]>
Co-authored-by: pppppM <[email protected]>
Co-authored-by: Chen Xin <[email protected]>
Co-authored-by: RunningLeon <[email protected]>
Co-authored-by: Yam(长琴) <[email protected]>
Co-authored-by: liukuikun <[email protected]>
Co-authored-by: yunzhongyan0 <[email protected]>
Co-authored-by: honglei.yan <[email protected]>
Co-authored-by: aisensiy <[email protected]>
Co-authored-by: Li Zhang <[email protected]>
Co-authored-by: Zaida Zhou <[email protected]>
Co-authored-by: Qian Zhao <[email protected]>
19 people authored Jan 9, 2024
1 parent 1a76191 commit 18cf952
Showing 30 changed files with 2,124 additions and 650 deletions.
7 changes: 4 additions & 3 deletions benchmark/profile_torch_generation.py
@@ -102,7 +102,7 @@ def _infer(model, session_id):
     _start = time.perf_counter()
     procs = []
     for i in range(concurrency):
-        proc = Thread(target=_infer, args=(model, i + 1))
+        proc = Thread(target=_infer, args=(model, i + 1), daemon=True)
         procs.append(proc)
         proc.start()

@@ -139,7 +139,8 @@ def profile_throughput(model_path: str, concurrency: int, input_seqlen: int,
     for i in range(concurrency):
         proc = Thread(target=infer,
                       args=(tm_model, i + 1, input_ids, output_seqlen, top_k,
-                            top_p, temperature, test_round, que))
+                            top_p, temperature, test_round, que),
+                      daemon=True)
         procs.append(proc)
         proc.start()

@@ -256,7 +257,7 @@ def mem_monitor(cls):
     def start(cls):
         cls._running = True
         from multiprocessing import Process
-        cls.proc = Process(target=cls.mem_monitor)
+        cls.proc = Process(target=cls.mem_monitor, daemon=True)
         cls.proc.start()

     @classmethod
3 changes: 2 additions & 1 deletion benchmark/profile_torch_throughput.py
@@ -146,7 +146,8 @@ def process_request(self,
         # start threads
         for i in range(concurrency):
             t = Thread(target=self._inference,
-                       args=(req_queue, res_queue, i, stream_output))
+                       args=(req_queue, res_queue, i, stream_output),
+                       daemon=True)
             t.start()
             threads.append(t)

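The `daemon=True` arguments added in the two benchmark scripts above let the process exit even if a worker thread or the memory-monitor process is still blocked, since daemon threads and processes are terminated when the main thread finishes. A minimal sketch of that behavior (illustrative only; `worker` is a made-up function, not part of the diff):

```python
import time
from threading import Thread


def worker():
    """Simulate a worker that never returns, e.g. one blocked on a queue."""
    while True:
        time.sleep(1)


# With daemon=True the interpreter does not wait for the thread on exit,
# which is what the benchmark scripts rely on; without it, a stuck worker
# would keep the process alive forever.
t = Thread(target=worker, daemon=True)
t.start()
print('main thread exits; the daemon worker is discarded')
```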
346 changes: 346 additions & 0 deletions lmdeploy/pytorch/adapter/adapter.py
@@ -0,0 +1,346 @@
# Copyright (c) OpenMMLab. All rights reserved.

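# This module tracks LoRA adapters for the PyTorch engine. Adapter weights are
# copied ("paged") into pre-allocated 2-D cache tensors whose rows are selected
# through per-adapter block tables, mirroring the block-based management used
# for the KV cache, so multiple adapters can share the same GPU memory pool.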
import re
from dataclasses import dataclass
from typing import Any, Dict, List

import torch
from torch import Tensor

from ..block import LogicalTokenBlocks


def _cache_weight(cache: Tensor, weight: Tensor, block_table: Tensor):
    """cache weight."""
    assert cache.dim() == 2
    assert weight.dim() == 2
    assert block_table.dim() == 1

    rank, feat_size = weight.size()
    assert cache.size(-1) >= feat_size, ('cache.size(-1) >= feat_size failed.')
    assert rank <= block_table.size(0), ('rank <= block_table.size(0) failed.')
    block_table = block_table[:rank]
    cache[block_table, :feat_size] = weight.to(device=cache.device,
                                               dtype=cache.dtype)


def _get_named_loralinears(model: torch.nn.Module):
    """get all named loralinear."""
    from peft.tuners.lora import Linear as LoRALinear
    named_loralinear: Dict[str, torch.nn.Module] = dict()
    for name, module in model.named_modules():
        if isinstance(module, LoRALinear):
            named_loralinear[name] = module
    return named_loralinear


def _get_layer_index(key: str, config: Any):
    """get layer index of the lora linear."""
    from peft.utils.other import COMMON_LAYERS_PATTERN
    layer_indexing_pattern = getattr(config, 'layers_pattern', None)
    layers_pattern = layer_indexing_pattern or COMMON_LAYERS_PATTERN
    if isinstance(layers_pattern, str):
        layers_pattern = [layers_pattern]
    for pattern in layers_pattern:
        layer_index = re.match(f'.*.{pattern}\\.(\\d+)\\.*', key)

        if layer_index is not None:
            return int(layer_index[1])


def get_indexed_lora_linears(model: torch.nn.Module):
    """get indexed lora linear."""
    named_linears = _get_named_loralinears(model)

    config = None
    peft_config = getattr(model, 'peft_config', dict())
    if len(peft_config) > 0:
        config = next(iter(peft_config.values()))

    indexed_linears = dict()
    for name, layer in named_linears.items():
        index = _get_layer_index(name, config)
        target = name.split('.')[-1]
        indexed_linears.setdefault(index, dict())
        indexed_linears[index][target] = layer
    return indexed_linears


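# After adapter weights have been copied into the caches, update_lora_linears
# rewires every patched LoRA linear: it records the layer index and the
# per-adapter ranks and block starts (as tensors on `device`), and pops the
# original lora_A/lora_B modules for the cached adapters.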
def update_lora_linears(lora_linears: Dict,
                        weight_maps: List['AdapterWeightMap'],
                        device: str = 'cuda'):
    """update lora linears."""

    def __get_targets():
        """get targets."""
        all_targets = set()
        for weight_map in weight_maps:
            targets = weight_map.target_modules.keys()
            all_targets.update(targets)
        return all_targets

    def __get_rank_and_start(target_names):
        """get rank and start."""
        rank_map = dict()
        start_map = dict()
        for target in target_names:
            ranks = [0] + [
                weight_map.target_modules[target].rank
                for weight_map in weight_maps
            ]
            block_starts = [0] + [
                weight_map.target_modules[target].block_start
                for weight_map in weight_maps
            ]
            rank_map[target] = torch.tensor(ranks)
            start_map[target] = torch.tensor(block_starts)
        return rank_map, start_map

    def __update_linear(linear, idx, rank_map, start_map, adapter_names):
        """update linear."""
        linear.layer_idx = idx
        linear.ranks = rank_map[target].to(device)
        linear.block_starts = start_map[target].to(device)
        for name in adapter_names:
            if name in linear.lora_A:
                linear.lora_A.pop(name)
                linear.lora_B.pop(name)

    adapter_names = [weight_map.adapter_name for weight_map in weight_maps]

    all_targets = __get_targets()

    for weight_map in weight_maps:
        weight_map.expand_targets(all_targets)

    rank_map, start_map = __get_rank_and_start(all_targets)

    for idx, lora_linear in lora_linears.items():
        for target, linear in lora_linear.items():
            __update_linear(linear,
                            idx,
                            rank_map=rank_map,
                            start_map=start_map,
                            adapter_names=adapter_names)


@dataclass
class TargetMeta:
    rank: int
    block_start: int


@dataclass
class AdapterWeightMap:
    adapter_name: str
    block_table: Tensor
    target_modules: Dict[str, TargetMeta]

    @classmethod
    def new(cls, adapter_name: str, rank: int, target_names: List[str],
            block_table: Tensor):
        """create new weightmap."""
        block_start = 0
        target_modules: Dict[str, TargetMeta] = dict()
        for name in target_names:
            target_modules[name] = TargetMeta(rank, block_start)
            block_start += rank

        return AdapterWeightMap(adapter_name,
                                block_table=block_table,
                                target_modules=target_modules)

    def expand_targets(self,
                       target_names: List[str],
                       ignore_exists: bool = True):
        for name in target_names:
            if name in self.target_modules:
                if ignore_exists:
                    continue
                else:
                    raise RuntimeError(f'target {name} exists.')
            self.target_modules[name] = TargetMeta(0, 0)

    @classmethod
    def cache_lora_a(cls, cache: Tensor, weight: Tensor, block_table: Tensor):
        """cache lora a weight."""
        return _cache_weight(cache, weight, block_table)

    @classmethod
    def cache_lora_b(cls, cache: Tensor, weight: Tensor, block_table: Tensor):
        """cache lora b weight."""
        return _cache_weight(cache, weight.t(), block_table)

    def cache_lora_linear(self, lora_linear: torch.nn.Module, cache_a: Tensor,
                          cache_b: Tensor):
        """cache lora linear."""
        name = self.adapter_name
        target_modules = self.target_modules
        block_table = self.block_table
        block_start = 0
        for target, target_meta in target_modules.items():
            linear = lora_linear[target]
            if not (name in linear.lora_A and name in linear.lora_B):
                continue
            linear_a = linear.lora_A[name]
            linear_b = linear.lora_B[name]
            weight_a = linear_a.weight
            weight_b = linear_b.weight
            assert weight_a is not None
            assert weight_b is not None
            rank = target_meta.rank
            block_offset = block_table[block_start:block_start + rank]
            block_start += rank
            self.cache_lora_a(cache_a, weight_a, block_offset)
            self.cache_lora_b(cache_b, weight_b, block_offset)

    def cache_adapter(self, lora_linears: Dict, caches: List[List[Tensor]]):
        """cache all linear."""
        assert len(lora_linears) == len(caches), (
            'len(lora_linears) == len(caches)')

        for idx, lora_linear in lora_linears.items():
            assert idx < len(caches), 'idx < len(caches)'
            cache_a, cache_b = caches[idx]
            self.cache_lora_linear(lora_linear, cache_a, cache_b)


@dataclass
class SchedulerAdapter:
    """lora adapter."""

    idx: int
    adapter_path: str
    adapter_name: str
    config: Any
    target_modules: List[str]
    logical_blocks: LogicalTokenBlocks
    adapter_manager: 'AdapterManager'
    _active: bool = False

    @classmethod
    def from_pretrained(cls, adapter_path: str, adapter_name: str, idx: int,
                        manager: 'AdapterManager'):
        """from_pretrained."""
        from peft import PeftConfig
        config = PeftConfig.from_pretrained(adapter_path)

        return cls.from_config(config,
                               adapter_name=adapter_name,
                               idx=idx,
                               manager=manager)

    @classmethod
    def from_config(cls, config: Any, adapter_name: str, idx: int,
                    manager: 'AdapterManager'):
        """from config."""
        new_adapter = SchedulerAdapter(
            idx,
            adapter_path=config.base_model_name_or_path,
            adapter_name=adapter_name,
            config=config,
            target_modules=list(config.target_modules),
            logical_blocks=LogicalTokenBlocks(1),
            adapter_manager=manager)
        new_adapter._active = False
        return new_adapter

    @property
    def name(self):
        """get adapter name."""
        return self.adapter_name

    @property
    def rank(self):
        """get rank."""
        return self.config.r

    def is_actived(self):
        """check if adapter is active."""
        return self._active

    def active(self, flag: bool = True):
        """active adapter."""
        self.adapter_manager._on_active(self, flag)
        self._active = flag

    def num_blocks(self):
        """get num blocks."""
        # one block per rank per target; lora_a and lora_b reuse the same
        # block indices in two separate caches
        return self.rank * len(self.target_modules)

    def num_required_blocks(self):
        """get num required blocks."""
        if self.is_actived():
            return 0
        else:
            return self.num_blocks()

    def build_weight_map(self, block_table: Tensor):
        return AdapterWeightMap.new(self.name,
                                    rank=self.rank,
                                    target_names=self.target_modules,
                                    block_table=block_table)


class AdapterManager:
    """Adapter manager."""

    def __init__(self) -> None:
        self._adapters: Dict[str, SchedulerAdapter] = dict()
        self._adapter_count = 0
        self._active_count = 0

        self._add_non_adapter()

    def _add_non_adapter(self):
        """add non adapter."""
        from peft import LoraConfig
        adapter_name = None
        config = LoraConfig(r=0, target_modules=[])
        adapter = self.add_adapter_from_config(config,
                                               adapter_name=adapter_name)
        adapter.active()

    def _on_active(self, adapter: SchedulerAdapter, flag: bool):
        """on active."""
        if adapter._active != flag:
            if flag:
                self._active_count += 1
            else:
                self._active_count -= 1

    def _add_adapter(self, adapter: SchedulerAdapter):
        """add adapter."""
        assert adapter.adapter_name not in self._adapters
        self._adapters[adapter.adapter_name] = adapter
        self._adapter_count += 1
        return adapter

    def add_adapter_from_config(self, config: Any, adapter_name: str):
        """add adapter from config."""
        adapter = SchedulerAdapter.from_config(config,
                                               adapter_name=adapter_name,
                                               idx=self._adapter_count,
                                               manager=self)
        return self._add_adapter(adapter)

    def add_adapter_from_pretrained(self, adapter_path: str,
                                    adapter_name: str):
        """add adapter by path and name."""
        adapter = SchedulerAdapter.from_pretrained(adapter_path,
                                                   adapter_name=adapter_name,
                                                   idx=self._adapter_count,
                                                   manager=self)
        return self._add_adapter(adapter)

    def get_adapter(self, name: str, default=None):
        """get adapter."""
        return self._adapters.get(name, default)

    def num_adapters(self):
        """get num adapters."""
        return len(self._adapters)


ADAPTER_MANAGER = AdapterManager()
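For reference, a minimal sketch of one way this module could be driven, assuming `torch`, `peft`, and lmdeploy at this commit are installed; the path `./my_lora_adapter` and the name `my_adapter` are hypothetical and do not appear in the commit:

```python
import torch

from lmdeploy.pytorch.adapter.adapter import AdapterManager

manager = AdapterManager()  # slot 0 is the built-in "no adapter" entry
adapter = manager.add_adapter_from_pretrained('./my_lora_adapter',
                                              adapter_name='my_adapter')

# Each target module needs `rank` rows in the paged LoRA caches, so the
# scheduler has to reserve this many blocks before the adapter is activated.
num_blocks = adapter.num_required_blocks()

# Toy block table; in the engine these indices come from the block manager.
block_table = torch.arange(num_blocks)

# The weight map records, per target module, which cache rows hold its
# lora_A/lora_B weights.
weight_map = adapter.build_weight_map(block_table)
adapter.active()
print(weight_map.target_modules)
```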