Commit 77a7b1e: hide local scaling

grimoire committed Jan 26, 2024
1 parent 5ae19a1

Showing 4 changed files with 20 additions and 23 deletions.
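
In effect, the commit stops carrying a per-token local_adapter_scalings tensor through the engine inputs and instead stores the LoRA scaling on each adapter's target metadata, where the fused LoRA linear can gather it with global_adapter_ids. The sketch below contrasts the two lookups; the tensor values and the reserved slot 0 for "no adapter" are illustrative assumptions, not lmdeploy code.

import torch

# Per-adapter LoRA scalings (alpha / rank); slot 0 stands for "no adapter".
scaling = torch.tensor([0.0, 2.0, 0.5])

# Before: a scaling entry was materialized for every sequence in the batch
# and shipped around as ModelInputs.local_adapter_scalings.
local_adapter_ids = torch.tensor([1, 1, 2, 0])
local_adapter_scalings = scaling[local_adapter_ids]    # shape [num_seqs]

# After: only adapter ids travel with the inputs; each LoRA linear keeps its
# own `scaling` buffer and gathers it per step with the global adapter ids.
global_adapter_ids = torch.tensor([1, 2])
per_adapter_scaling = scaling[global_adapter_ids]      # shape [num_adapters]
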
lmdeploy/pytorch/adapter/adapter.py: 25 changes (18 additions, 7 deletions)

@@ -79,10 +79,11 @@ def __get_targets():
             all_targets.update(targets)
         return all_targets
 
-    def __get_rank_and_start(target_names):
+    def __get_linear_meta(target_names):
         """get rank and start."""
         rank_map = dict()
         start_map = dict()
+        scaling_map = dict()
         for target in target_names:
             ranks = [0] + [
                 weight_map.target_modules[target].rank
@@ -92,15 +93,22 @@ def __get_rank_and_start(target_names):
                 weight_map.target_modules[target].block_start
                 for weight_map in weight_maps
             ]
+            scaling = [0] + [
+                weight_map.target_modules[target].scaling
+                for weight_map in weight_maps
+            ]
             rank_map[target] = torch.tensor(ranks)
             start_map[target] = torch.tensor(block_starts)
-        return rank_map, start_map
+            scaling_map[target] = torch.tensor(scaling)
+        return rank_map, start_map, scaling_map
 
-    def __update_linear(linear, idx, rank_map, start_map, adapter_names):
+    def __update_linear(linear, idx, rank_map, start_map, scaling_map,
+                        adapter_names):
         """update linear."""
         linear.layer_idx = idx
         linear.ranks = rank_map[target].to(device)
         linear.block_starts = start_map[target].to(device)
+        linear.scaling = scaling_map[target].to(device)
         for name in adapter_names:
             if name in linear.lora_A:
                 linear.lora_A.pop(name)
@@ -113,14 +121,15 @@ def __update_linear(linear, idx, rank_map, start_map, adapter_names):
     for weight_map in weight_maps:
         weight_map.expand_targets(all_targets)
 
-    rank_map, start_map = __get_rank_and_start(all_targets)
+    rank_map, start_map, scaling_map = __get_linear_meta(all_targets)
 
     for idx, lora_linear in lora_linears.items():
         for target, linear in lora_linear.items():
             __update_linear(linear,
                             idx,
                             rank_map=rank_map,
                             start_map=start_map,
+                            scaling_map=scaling_map,
                             adapter_names=adapter_names)
 
 
@@ -139,6 +148,7 @@ def get_max_lora_weight_size(model: torch.nn.Module):
 class TargetMeta:
     rank: int
     block_start: int
+    scaling: float
 
 
 @dataclass
@@ -149,12 +159,12 @@ class AdapterWeightMap:
 
     @classmethod
     def new(cls, adapter_name: str, rank: int, target_names: List[str],
-            block_table: Tensor):
+            block_table: Tensor, scaling: float):
         """create new weightmap."""
         block_start = 0
         target_modules: Dict[str, TargetMeta] = dict()
         for name in target_names:
-            target_modules[name] = TargetMeta(rank, block_start)
+            target_modules[name] = TargetMeta(rank, block_start, scaling)
             block_start += rank
 
         return AdapterWeightMap(adapter_name,
@@ -296,7 +306,8 @@ def build_weight_map(self, block_table: Tensor):
         return AdapterWeightMap.new(self.name,
                                     rank=self.rank,
                                     target_names=self.target_modules,
-                                    block_table=block_table)
+                                    block_table=block_table,
+                                    scaling=self.scaling)
 
 
 class AdapterManager:
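
With scaling added to TargetMeta, each target module ends up with one rank, block start and scaling value per adapter, packed into parallel tensors the way __get_linear_meta does above. A minimal sketch of that layout, using made-up adapter metadata rather than real weight maps:

import torch
from dataclasses import dataclass


@dataclass
class TargetMeta:
    rank: int
    block_start: int
    scaling: float


# Hypothetical metadata for one target (say 'q_proj'); entry 0 mirrors the
# [0] + [...] padding used above for the "no adapter" case.
metas = [TargetMeta(0, 0, 0.0), TargetMeta(16, 0, 2.0), TargetMeta(8, 16, 0.5)]

ranks = torch.tensor([m.rank for m in metas])
block_starts = torch.tensor([m.block_start for m in metas])
scaling = torch.tensor([m.scaling for m in metas])

# __update_linear then hangs these tensors on the fused linear, so the layer
# can index them by adapter id at forward time.
global_adapter_ids = torch.tensor([2, 1])
print(ranks[global_adapter_ids], scaling[global_adapter_ids])
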
lmdeploy/pytorch/engine/engine.py: 6 changes (0 additions, 6 deletions)

@@ -370,7 +370,6 @@ def create_model_inputs(self, messages: SeqList, adapters: AdapterList):
         local_adapter_ids = None
         global_adapter_ids = None
         adapter_offsets = None
-        local_adapter_scalings = None
         max_rank = 0
         if ADAPTER_MANAGER.num_adapters() > 1:
             local_adapter_ids = _get_adapter_ids(messages, adapters)
@@ -381,10 +380,6 @@ def create_model_inputs(self, messages: SeqList, adapters: AdapterList):
             global_adapter_ids = seq_length.new_tensor(global_adapter_ids)
             ranks = [ada.rank for ada in adapters]
             max_rank = max(ranks)
-            local_adapter_scalings = [
-                adapters[ada_ids].scaling for ada_ids in local_adapter_ids
-            ]
-            local_adapter_scalings = torch.tensor(local_adapter_scalings)
 
         # add batch dim [bs=1, seq_len]
         if input_ids.ndim == 1:
@@ -401,7 +396,6 @@ def create_model_inputs(self, messages: SeqList, adapters: AdapterList):
                           local_adapter_ids=local_adapter_ids,
                           global_adapter_ids=global_adapter_ids,
                           adapter_offsets=adapter_offsets,
-                          local_adapter_scalings=local_adapter_scalings,
                           max_rank=max_rank,
                           meta=meta)
 
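
On the engine side the only change is a deletion: create_model_inputs no longer builds a scaling tensor alongside the adapter ids. A rough sketch of the adapter bookkeeping that remains, with a toy Adapter class standing in for lmdeploy's scheduler adapters:

import torch


class Adapter:
    """Toy stand-in carrying only the fields used below."""

    def __init__(self, rank: int, scaling: float):
        self.rank = rank
        self.scaling = scaling


adapters = [Adapter(16, 2.0), Adapter(8, 0.5)]
local_adapter_ids = torch.tensor([0, 0, 1])      # one id per sequence
global_adapter_ids = torch.tensor([0, 1])        # adapters active this step
max_rank = max(ada.rank for ada in adapters)

# The removed lines used to do roughly:
#     local_adapter_scalings = torch.tensor(
#         [adapters[i].scaling for i in local_adapter_ids])
# That lookup now happens inside the LoRA linear via its own scaling buffer.
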
lmdeploy/pytorch/engine/model_agent.py: 9 changes (0 additions, 9 deletions)

@@ -82,7 +82,6 @@ class ModelInputs:
     local_adapter_ids: torch.LongTensor = None
     global_adapter_ids: torch.LongTensor = None
     adapter_offsets: torch.LongTensor = None
-    local_adapter_scalings: torch.Tensor = None
     max_rank: int = 0
     meta: Any = None
 
@@ -99,10 +98,8 @@ def slice(self, start: int, end: int):
         history_lengths = self.history_lengths[sli]
 
         local_adapter_ids = self.local_adapter_ids
-        local_adapter_scalings = self.local_adapter_scalings
         if local_adapter_ids is not None:
             local_adapter_ids = local_adapter_ids[sli]
-            local_adapter_scalings = local_adapter_scalings[sli]
 
         return ModelInputs(input_ids=input_ids,
                            seq_length=seq_length,
@@ -115,7 +112,6 @@ def slice(self, start: int, end: int):
                            local_adapter_ids=local_adapter_ids,
                            global_adapter_ids=self.global_adapter_ids,
                            adapter_offsets=self.adapter_offsets,
-                           local_adapter_scalings=local_adapter_scalings,
                            max_rank=self.max_rank,
                            meta=self.meta)
 
@@ -144,10 +140,8 @@ def split(self, split_size: int, block_size: int):
                 block_end += 1
 
         local_adapter_ids = self.local_adapter_ids
-        local_adapter_scalings = self.local_adapter_scalings
         if local_adapter_ids is not None:
             local_adapter_ids = local_adapter_ids[:, start:end]
-            local_adapter_scalings = local_adapter_scalings[:, start:end]
 
         inp = ModelInputs(
             input_ids=self.input_ids[:, start:end],
@@ -161,7 +155,6 @@ def split(self, split_size: int, block_size: int):
             local_adapter_ids=local_adapter_ids,
             global_adapter_ids=self.global_adapter_ids,
             adapter_offsets=self.adapter_offsets,
-            local_adapter_scalings=local_adapter_scalings,
            max_rank=self.max_rank,
            meta=self.meta,
        )
@@ -205,7 +198,6 @@ class StepContext:
     local_adapter_ids: torch.LongTensor = None
     global_adapter_ids: torch.LongTensor = None
     adapter_offsets: torch.LongTensor = None
-    local_adapter_scalings: torch.Tensor = None
     max_rank: int = 0
 
     _outputs: Dict = field(default_factory=dict)
@@ -254,7 +246,6 @@ def new(
             local_adapter_ids=inputs.local_adapter_ids,
             global_adapter_ids=inputs.global_adapter_ids,
             adapter_offsets=inputs.adapter_offsets,
-            local_adapter_scalings=inputs.local_adapter_scalings,
             max_rank=inputs.max_rank)
         return ret
 
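
Dropping the field also simplifies slice and split: there is no second per-token tensor that has to stay aligned with local_adapter_ids. A toy version of the slicing, with a deliberately stripped-down stand-in for ModelInputs:

import torch
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyInputs:
    # Simplified stand-in: only the adapter-related field that survives
    # this commit is kept.
    input_ids: torch.Tensor
    local_adapter_ids: Optional[torch.Tensor] = None

    def slice(self, start: int, end: int) -> 'ToyInputs':
        ids = self.local_adapter_ids
        if ids is not None:
            ids = ids[:, start:end]
        return ToyInputs(input_ids=self.input_ids[:, start:end],
                         local_adapter_ids=ids)


inp = ToyInputs(torch.arange(8).reshape(1, 8),
                torch.tensor([[1, 1, 1, 2, 2, 0, 0, 0]]))
print(inp.slice(2, 6).local_adapter_ids)
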
lmdeploy/pytorch/models/peft.py: 3 changes (2 additions, 1 deletion)

@@ -36,6 +36,7 @@ def _make_packed_lora_input(self, x):
         layer_idx = self.layer_idx
         ranks = self.ranks[global_adapter_ids]
         block_starts = self.block_starts[global_adapter_ids]
+        scaling = self.scaling[global_adapter_ids]
         k_cache, v_cache = context.kv_caches[layer_idx]
         cache_len = k_cache.size(0)
         a_cache = k_cache.view(cache_len, -1)
@@ -47,7 +48,7 @@ def _make_packed_lora_input(self, x):
             b_start_loc=context.q_start_loc,
             b_seq_lens=context.seq_length,
             b_adapter_ids=context.local_adapter_ids,
-            b_scaling=context.local_adapter_scalings,
+            b_scaling=scaling,
             rank_page_table=context.adapter_offsets,
             rank_page_start=block_starts,
             ranks=ranks,
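
At the layer level the scaling is now gathered once per step from the module's own buffer and handed to the paged LoRA kernel as b_scaling. A dense, non-paged sketch of where that factor enters the math, using a single made-up adapter weight pair for all tokens:

import torch

batch, hidden, rank = 4, 32, 8
x = torch.randn(batch, hidden)
lora_a = torch.randn(rank, hidden)    # LoRA A of one adapter
lora_b = torch.randn(hidden, rank)    # LoRA B of one adapter

scaling = torch.tensor([0.0, 2.0, 0.5])     # per adapter, slot 0 = no adapter
adapter_ids = torch.tensor([1, 1, 2, 0])    # one id per token

# Dense equivalent of the kernel call: y = scaling(adapter) * (x @ A.T) @ B.T.
# In the real paged path, different adapters also use different A/B weights.
delta = (x @ lora_a.t()) @ lora_b.t()
y = scaling[adapter_ids].unsqueeze(-1) * delta
print(y.shape)    # torch.Size([4, 32])
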
