Support for loading LoRA adapter weights in safetensors format (#2860)
Co-authored-by: Ping <[email protected]>
Galaxy-Husky and Ping authored Dec 9, 2024
1 parent 14b64c7 commit 47fa7cf
Showing 1 changed file with 4 additions and 1 deletion.
lmdeploy/pytorch/models/patch.py (4 additions, 1 deletion)
@@ -8,6 +8,7 @@
 
 import torch
 from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_utils import load_state_dict
 
 from lmdeploy.utils import get_logger
 
@@ -295,7 +296,9 @@ def add_adapters(model: torch.nn.Module,
     for name, path in adapters.items():
         adapter_id = adapter_id_map[name]
         checkpoint_path = f'{path}/adapter_model.bin'
-        state_dict = torch.load(checkpoint_path, map_location=device)
+        if not osp.exists(checkpoint_path):
+            checkpoint_path = f'{path}/adapter_model.safetensors'
+        state_dict = load_state_dict(checkpoint_path, map_location=device)
 
         if hasattr(model, 'load_lora_weights'):
             model.load_lora_weights(state_dict.items(), adapter_id=adapter_id)
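In short, the loader still prefers the legacy adapter_model.bin and only falls back to adapter_model.safetensors when the .bin file is missing; loading is then delegated to transformers' load_state_dict helper, which dispatches on the file extension. Below is a minimal standalone sketch of the same fallback, using torch.load and safetensors.torch.load_file directly rather than the transformers helper; the load_adapter_state_dict name and signature are illustrative and not part of lmdeploy.

import os.path as osp

import torch
from safetensors.torch import load_file


def load_adapter_state_dict(path: str, device: str = 'cpu') -> dict:
    """Load a LoRA adapter state dict, preferring .bin with a .safetensors fallback.

    Illustrative helper only; it mirrors the fallback introduced by this
    commit but is not part of lmdeploy.
    """
    bin_path = osp.join(path, 'adapter_model.bin')
    if osp.exists(bin_path):
        # Legacy pickle-based checkpoint produced by older PEFT versions.
        return torch.load(bin_path, map_location=device)
    # Newer PEFT versions save adapters in safetensors format by default.
    st_path = osp.join(path, 'adapter_model.safetensors')
    return load_file(st_path, device=device)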
