Fix baichuan2 lora #1042

Merged 7 commits on Jan 26, 2024
3 changes: 3 additions & 0 deletions lmdeploy/pytorch/accel.py
@@ -13,6 +13,7 @@ def __init__(self):
         self.normal_ = torch.nn.init.normal_
         self.kaiming_uniform_ = torch.nn.init.kaiming_uniform_
         self.kaiming_normal_ = torch.nn.init.kaiming_normal_
+        self.tensor_normal_ = torch.Tensor.normal_

     def __enter__(self, *args, **kwargs):
         """Replace initializers with no-op."""
@@ -24,6 +25,7 @@ def __enter__(self, *args, **kwargs):
         torch.nn.init.normal_ = lambda *args, **kwargs: None
         torch.nn.init.kaiming_uniform_ = lambda *args, **kwargs: None
         torch.nn.init.kaiming_normal_ = lambda *args, **kwargs: None
+        torch.Tensor.normal_ = lambda *args, **kwargs: None

     def __exit__(self, *args, **kwargs):
         """Recover."""
@@ -35,3 +37,4 @@ def __exit__(self, *args, **kwargs):
         torch.nn.init.normal_ = self.normal_
         torch.nn.init.kaiming_uniform_ = self.kaiming_uniform_
         torch.nn.init.kaiming_normal_ = self.kaiming_normal_
+        torch.Tensor.normal_ = self.tensor_normal_
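For context: the accel.py hunks extend the skip-initialization context manager so that `torch.Tensor.normal_` is also stubbed out while the empty model is built, presumably because part of the Baichuan2 initialization path samples weights through `Tensor.normal_` directly rather than through `torch.nn.init`, so those tensors were still being randomly filled before the checkpoint was loaded. Below is a minimal standalone sketch of the same monkey-patching pattern; the class name `SkipTensorNormalInit` is illustrative and not part of lmdeploy.

```python
import torch


class SkipTensorNormalInit:
    """Temporarily make in-place Tensor.normal_ a no-op."""

    def __enter__(self):
        # Save the real method so it can be restored on exit.
        self._tensor_normal_ = torch.Tensor.normal_
        # Replace it with a stub; returning the tensor keeps call chaining
        # working (the patch in this PR simply returns None instead).
        torch.Tensor.normal_ = lambda self, *args, **kwargs: self
        return self

    def __exit__(self, *exc):
        # Restore the original initializer.
        torch.Tensor.normal_ = self._tensor_normal_


with SkipTensorNormalInit():
    w = torch.empty(4, 4)
    w.normal_()  # no-op inside the context; real weights are loaded later
```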
32 changes: 24 additions & 8 deletions lmdeploy/pytorch/adapter/adapter.py
@@ -79,10 +79,11 @@ def __get_targets():
             all_targets.update(targets)
         return all_targets

-    def __get_rank_and_start(target_names):
+    def __get_linear_meta(target_names):
         """get rank and start."""
         rank_map = dict()
         start_map = dict()
+        scaling_map = dict()
         for target in target_names:
             ranks = [0] + [
                 weight_map.target_modules[target].rank
@@ -92,15 +93,22 @@ def __get_rank_and_start(target_names):
                 weight_map.target_modules[target].block_start
                 for weight_map in weight_maps
             ]
+            scaling = [0] + [
+                weight_map.target_modules[target].scaling
+                for weight_map in weight_maps
+            ]
             rank_map[target] = torch.tensor(ranks)
             start_map[target] = torch.tensor(block_starts)
-        return rank_map, start_map
+            scaling_map[target] = torch.tensor(scaling)
+        return rank_map, start_map, scaling_map

-    def __update_linear(linear, idx, rank_map, start_map, adapter_names):
+    def __update_linear(linear, idx, rank_map, start_map, scaling_map,
+                        adapter_names):
         """update linear."""
         linear.layer_idx = idx
         linear.ranks = rank_map[target].to(device)
         linear.block_starts = start_map[target].to(device)
+        linear.scaling = scaling_map[target].to(device)
         for name in adapter_names:
             if name in linear.lora_A:
                 linear.lora_A.pop(name)
@@ -113,14 +121,15 @@ def __update_linear(linear, idx, rank_map, start_map, scaling_map,
     for weight_map in weight_maps:
         weight_map.expand_targets(all_targets)

-    rank_map, start_map = __get_rank_and_start(all_targets)
+    rank_map, start_map, scaling_map = __get_linear_meta(all_targets)

     for idx, lora_linear in lora_linears.items():
         for target, linear in lora_linear.items():
             __update_linear(linear,
                             idx,
                             rank_map=rank_map,
                             start_map=start_map,
+                            scaling_map=scaling_map,
                             adapter_names=adapter_names)


@@ -139,6 +148,7 @@ def get_max_lora_weight_size(model: torch.nn.Module):
 class TargetMeta:
     rank: int
     block_start: int
+    scaling: float


 @dataclass
@@ -149,12 +159,12 @@ class AdapterWeightMap:

     @classmethod
     def new(cls, adapter_name: str, rank: int, target_names: List[str],
-            block_table: Tensor):
+            block_table: Tensor, scaling: float):
         """create new weightmap."""
         block_start = 0
         target_modules: Dict[str, TargetMeta] = dict()
         for name in target_names:
-            target_modules[name] = TargetMeta(rank, block_start)
+            target_modules[name] = TargetMeta(rank, block_start, scaling)
             block_start += rank

         return AdapterWeightMap(adapter_name,
@@ -170,7 +180,7 @@ def expand_targets(self,
                     continue
                 else:
                     raise RuntimeError(f'target {name} exists.')
-            self.target_modules[name] = TargetMeta(0, 0)
+            self.target_modules[name] = TargetMeta(0, 0, 0.0)

     @classmethod
     def cache_lora_a(cls, cache: Tensor, weight: Tensor, block_table: Tensor):
@@ -266,6 +276,11 @@ def rank(self):
         """get rank."""
         return self.config.r

+    @property
+    def scaling(self):
+        """get scaling."""
+        return self.config.lora_alpha / self.rank
+
     def is_actived(self):
         """check if adapter is active."""
         return self._active
@@ -291,7 +306,8 @@ def build_weight_map(self, block_table: Tensor):
         return AdapterWeightMap.new(self.name,
                                     rank=self.rank,
                                     target_names=self.target_modules,
-                                    block_table=block_table)
+                                    block_table=block_table,
+                                    scaling=self.scaling)


 class AdapterManager:
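Taken together, the adapter.py changes thread the LoRA scaling factor from the adapter config (`scaling = lora_alpha / r`, the usual peft convention) through `AdapterWeightMap`/`TargetMeta` and onto each paged LoRA linear next to its `ranks` and `block_starts`, so the low-rank update is applied with the intended magnitude rather than an implicit scale of 1. Below is a hedged sketch of where that factor enters a plain LoRA forward pass; the function and tensor names are illustrative and not lmdeploy's paged kernel.

```python
import torch


def lora_forward(x: torch.Tensor,
                 base_weight: torch.Tensor,
                 lora_a: torch.Tensor,
                 lora_b: torch.Tensor,
                 lora_alpha: float,
                 rank: int) -> torch.Tensor:
    """y = x @ W^T + (alpha / r) * (x @ A^T) @ B^T."""
    scaling = lora_alpha / rank               # same formula as config.lora_alpha / r in the diff
    base_out = x @ base_weight.t()            # frozen base projection
    lora_out = (x @ lora_a.t()) @ lora_b.t()  # low-rank update
    return base_out + scaling * lora_out


# Example: a rank-8 adapter with lora_alpha=16 gives scaling = 2.0.
x = torch.randn(2, 32)
w = torch.randn(64, 32)   # base weight: (out_features, in_features)
a = torch.randn(8, 32)    # lora_A: (r, in_features)
b = torch.zeros(64, 8)    # lora_B: (out_features, r), zero-initialized
y = lora_forward(x, w, a, b, lora_alpha=16.0, rank=8)
print(y.shape)  # torch.Size([2, 64])
```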