diff --git a/lmdeploy/pytorch/adapter/adapter.py b/lmdeploy/pytorch/adapter/adapter.py
index f8186123ae..f93ac5541a 100644
--- a/lmdeploy/pytorch/adapter/adapter.py
+++ b/lmdeploy/pytorch/adapter/adapter.py
@@ -79,10 +79,11 @@ def __get_targets():
             all_targets.update(targets)
         return all_targets

-    def __get_rank_and_start(target_names):
+    def __get_linear_meta(target_names):
         """get rank and start."""
         rank_map = dict()
         start_map = dict()
+        scaling_map = dict()
         for target in target_names:
             ranks = [0] + [
                 weight_map.target_modules[target].rank
@@ -92,15 +93,22 @@ def __get_rank_and_start(target_names):
                 weight_map.target_modules[target].block_start
                 for weight_map in weight_maps
             ]
+            scaling = [0] + [
+                weight_map.target_modules[target].scaling
+                for weight_map in weight_maps
+            ]
             rank_map[target] = torch.tensor(ranks)
             start_map[target] = torch.tensor(block_starts)
-        return rank_map, start_map
+            scaling_map[target] = torch.tensor(scaling)
+        return rank_map, start_map, scaling_map

-    def __update_linear(linear, idx, rank_map, start_map, adapter_names):
+    def __update_linear(linear, idx, rank_map, start_map, scaling_map,
+                        adapter_names):
         """update linear."""
         linear.layer_idx = idx
         linear.ranks = rank_map[target].to(device)
         linear.block_starts = start_map[target].to(device)
+        linear.scaling = scaling_map[target].to(device)
         for name in adapter_names:
             if name in linear.lora_A:
                 linear.lora_A.pop(name)
@@ -113,7 +121,7 @@ def __update_linear(linear, idx, rank_map, start_map, adapter_names):
     for weight_map in weight_maps:
         weight_map.expand_targets(all_targets)

-    rank_map, start_map = __get_rank_and_start(all_targets)
+    rank_map, start_map, scaling_map = __get_linear_meta(all_targets)

     for idx, lora_linear in lora_linears.items():
         for target, linear in lora_linear.items():
@@ -121,6 +129,7 @@ def __update_linear(linear, idx, rank_map, start_map, adapter_names):
                             idx,
                             rank_map=rank_map,
                             start_map=start_map,
+                            scaling_map=scaling_map,
                             adapter_names=adapter_names)


@@ -139,6 +148,7 @@ def get_max_lora_weight_size(model: torch.nn.Module):
 class TargetMeta:
     rank: int
     block_start: int
+    scaling: float


 @dataclass
@@ -149,12 +159,12 @@ class AdapterWeightMap:

     @classmethod
     def new(cls, adapter_name: str, rank: int, target_names: List[str],
-            block_table: Tensor):
+            block_table: Tensor, scaling: float):
         """create new weightmap."""
         block_start = 0
         target_modules: Dict[str, TargetMeta] = dict()
         for name in target_names:
-            target_modules[name] = TargetMeta(rank, block_start)
+            target_modules[name] = TargetMeta(rank, block_start, scaling)
             block_start += rank

         return AdapterWeightMap(adapter_name,
@@ -296,7 +306,8 @@ def build_weight_map(self, block_table: Tensor):
         return AdapterWeightMap.new(self.name,
                                     rank=self.rank,
                                     target_names=self.target_modules,
-                                    block_table=block_table)
+                                    block_table=block_table,
+                                    scaling=self.scaling)


 class AdapterManager:
diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py
index 0ace19efe7..ebaec203b6 100644
--- a/lmdeploy/pytorch/engine/engine.py
+++ b/lmdeploy/pytorch/engine/engine.py
@@ -370,7 +370,6 @@ def create_model_inputs(self, messages: SeqList, adapters: AdapterList):
         local_adapter_ids = None
         global_adapter_ids = None
         adapter_offsets = None
-        local_adapter_scalings = None
         max_rank = 0
         if ADAPTER_MANAGER.num_adapters() > 1:
             local_adapter_ids = _get_adapter_ids(messages, adapters)
@@ -381,10 +380,6 @@ def create_model_inputs(self, messages: SeqList, adapters: AdapterList):
             global_adapter_ids = seq_length.new_tensor(global_adapter_ids)
             ranks = [ada.rank for ada in adapters]
             max_rank = max(ranks)
-            local_adapter_scalings = [
-                adapters[ada_ids].scaling for ada_ids in local_adapter_ids
-            ]
-            local_adapter_scalings = torch.tensor(local_adapter_scalings)

         # add batch dim [bs=1, seq_len]
         if input_ids.ndim == 1:
@@ -401,7 +396,6 @@ def create_model_inputs(self, messages: SeqList, adapters: AdapterList):
                            local_adapter_ids=local_adapter_ids,
                            global_adapter_ids=global_adapter_ids,
                            adapter_offsets=adapter_offsets,
-                           local_adapter_scalings=local_adapter_scalings,
                            max_rank=max_rank,
                            meta=meta)

diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py
index 756d51bdea..f92b2bfb5f 100644
--- a/lmdeploy/pytorch/engine/model_agent.py
+++ b/lmdeploy/pytorch/engine/model_agent.py
@@ -82,7 +82,6 @@ class ModelInputs:
     local_adapter_ids: torch.LongTensor = None
     global_adapter_ids: torch.LongTensor = None
     adapter_offsets: torch.LongTensor = None
-    local_adapter_scalings: torch.Tensor = None
     max_rank: int = 0
     meta: Any = None

@@ -99,10 +98,8 @@ def slice(self, start: int, end: int):
         history_lengths = self.history_lengths[sli]

         local_adapter_ids = self.local_adapter_ids
-        local_adapter_scalings = self.local_adapter_scalings
         if local_adapter_ids is not None:
             local_adapter_ids = local_adapter_ids[sli]
-            local_adapter_scalings = local_adapter_scalings[sli]

         return ModelInputs(input_ids=input_ids,
                            seq_length=seq_length,
@@ -115,7 +112,6 @@ def slice(self, start: int, end: int):
                            local_adapter_ids=local_adapter_ids,
                            global_adapter_ids=self.global_adapter_ids,
                            adapter_offsets=self.adapter_offsets,
-                           local_adapter_scalings=local_adapter_scalings,
                            max_rank=self.max_rank,
                            meta=self.meta)

@@ -144,10 +140,8 @@ def split(self, split_size: int, block_size: int):
                 block_end += 1

             local_adapter_ids = self.local_adapter_ids
-            local_adapter_scalings = self.local_adapter_scalings
             if local_adapter_ids is not None:
                 local_adapter_ids = local_adapter_ids[:, start:end]
-                local_adapter_scalings = local_adapter_scalings[:, start:end]

             inp = ModelInputs(
                 input_ids=self.input_ids[:, start:end],
@@ -161,7 +155,6 @@ def split(self, split_size: int, block_size: int):
                 local_adapter_ids=local_adapter_ids,
                 global_adapter_ids=self.global_adapter_ids,
                 adapter_offsets=self.adapter_offsets,
-                local_adapter_scalings=local_adapter_scalings,
                 max_rank=self.max_rank,
                 meta=self.meta,
             )
@@ -205,7 +198,6 @@ class StepContext:
     local_adapter_ids: torch.LongTensor = None
     global_adapter_ids: torch.LongTensor = None
    adapter_offsets: torch.LongTensor = None
-    local_adapter_scalings: torch.Tensor = None
     max_rank: int = 0
     _outputs: Dict = field(default_factory=dict)

@@ -254,7 +246,6 @@ def new(
             local_adapter_ids=inputs.local_adapter_ids,
             global_adapter_ids=inputs.global_adapter_ids,
             adapter_offsets=inputs.adapter_offsets,
-            local_adapter_scalings=inputs.local_adapter_scalings,
             max_rank=inputs.max_rank)

         return ret
diff --git a/lmdeploy/pytorch/models/peft.py b/lmdeploy/pytorch/models/peft.py
index 74e5279d11..52fbc48df1 100644
--- a/lmdeploy/pytorch/models/peft.py
+++ b/lmdeploy/pytorch/models/peft.py
@@ -36,6 +36,7 @@ def _make_packed_lora_input(self, x):
         layer_idx = self.layer_idx
         ranks = self.ranks[global_adapter_ids]
         block_starts = self.block_starts[global_adapter_ids]
+        scaling = self.scaling[global_adapter_ids]
         k_cache, v_cache = context.kv_caches[layer_idx]
         cache_len = k_cache.size(0)
         a_cache = k_cache.view(cache_len, -1)
@@ -47,7 +48,7 @@ def _make_packed_lora_input(self, x):
             b_start_loc=context.q_start_loc,
             b_seq_lens=context.seq_length,
             b_adapter_ids=context.local_adapter_ids,
-            b_scaling=context.local_adapter_scalings,
+            b_scaling=scaling,
             rank_page_table=context.adapter_offsets,
             rank_page_start=block_starts,
             ranks=ranks,
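
For context, a minimal sketch of the lookup this diff switches to (not part of the patch; the adapter values below are made up): `scaling` is now stored per adapter next to `ranks` on each LoRA linear and gathered with the same `global_adapter_ids` indexing used in `_make_packed_lora_input`, instead of shipping a per-token `local_adapter_scalings` tensor through `ModelInputs`.

```python
import torch

# Per-adapter metadata as registered on the linear by __update_linear.
# Slot 0 is the "no adapter" entry, so its rank and scaling are both 0.
scaling = torch.tensor([0.0, 2.0, 0.5])   # hypothetical lora_alpha / r values
ranks = torch.tensor([0, 16, 8])

# Adapters taking part in the current step (StepContext.global_adapter_ids).
global_adapter_ids = torch.tensor([1, 2])

# _make_packed_lora_input gathers both tensors the same way and passes the
# result as b_scaling; per-token adapter selection still goes through
# b_adapter_ids=context.local_adapter_ids inside the kernel.
b_scaling = scaling[global_adapter_ids]   # tensor([2.0000, 0.5000])
b_ranks = ranks[global_adapter_ids]       # tensor([16,  8])
print(b_scaling, b_ranks)
```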