fix baichuan2
grimoire committed Oct 8, 2023
1 parent 3271d31 commit c27aa34
Showing 3 changed files with 19 additions and 9 deletions.
20 changes: 13 additions & 7 deletions lmdeploy/pytorch_poc/engine/engine.py
@@ -18,7 +18,7 @@
                                          TemperatureLogitsWarper,
                                          TopKLogitsWarper,
                                          TopPLogitsWarper)
-from transformers.utils import WEIGHTS_INDEX_NAME, cached_file
+from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME, cached_file
 
 from lmdeploy.pytorch.accel import LoadNoInit
 from lmdeploy.pytorch_poc.config import (CacheConfig, ModelConfig,
@@ -261,14 +261,20 @@ def _tp_model_loop(
             torch_dtype=torch_dtype,
             trust_remote_code=True)
 
-        torch_model_json_path = cached_file(model_path, WEIGHTS_INDEX_NAME)
-        with open(torch_model_json_path, mode='r') as f:
-            torch_model_json = json.load(f)
+        try:
+            torch_model_json_path = cached_file(model_path, WEIGHTS_INDEX_NAME)
+            with open(torch_model_json_path, mode='r') as f:
+                torch_model_json = json.load(f)
 
-        weight_map = torch_model_json['weight_map']
+            weight_map = torch_model_json['weight_map']
 
-        checkpoints = list(set(weight_map.values()))
-        checkpoints = [cached_file(model_path, ckpt) for ckpt in checkpoints]
+            checkpoints = list(set(weight_map.values()))
+            checkpoints = [
+                cached_file(model_path, ckpt) for ckpt in checkpoints
+            ]
+        except Exception:
+            logger.warning(f'load failed, try load from {WEIGHTS_NAME}.')
+            checkpoints = [cached_file(model_path, WEIGHTS_NAME)]
         patched_model = patch(
             model,
             extra_args=extra_args,
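The engine change above makes checkpoint discovery tolerant of repositories that ship a single weight file instead of a sharded index. A minimal standalone sketch of that fallback, using only the transformers file-resolution helpers (the function name resolve_checkpoints is illustrative, not part of lmdeploy):

import json
import logging

from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME, cached_file

logger = logging.getLogger(__name__)


def resolve_checkpoints(model_path: str):
    """Hypothetical helper: return the list of weight files for a model.

    Sharded checkpoints ship `pytorch_model.bin.index.json`, which maps every
    tensor to its shard; single-file checkpoints only ship
    `pytorch_model.bin`, so a failed index lookup falls back to that.
    """
    try:
        index_path = cached_file(model_path, WEIGHTS_INDEX_NAME)
        with open(index_path, mode='r') as f:
            weight_map = json.load(f)['weight_map']
        shards = sorted(set(weight_map.values()))
        return [cached_file(model_path, shard) for shard in shards]
    except Exception:
        logger.warning(f'load failed, try load from {WEIGHTS_NAME}.')
        return [cached_file(model_path, WEIGHTS_NAME)]

When the index file is absent, as appears to be the case for the Baichuan2 checkpoint this commit targets, cached_file raises and the except branch returns the single pytorch_model.bin instead.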
4 changes: 2 additions & 2 deletions lmdeploy/pytorch_poc/patch/functional.py
@@ -33,8 +33,8 @@ def apply_rotary_pos_emb(q: Tensor, k: Tensor, cos: Tensor, sin: Tensor,
     """
     # The first two dimensions of cos and sin are always 1,
     # so we can `squeeze` them.
-    cos = cos.to(q.device)
-    sin = sin.to(q.device)
+    cos = cos.to(device=q.device, dtype=q.dtype)
+    sin = sin.to(device=q.device, dtype=q.dtype)
     cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
     sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
     seq_length = position_ids[..., -1] + 1
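The functional.py change casts the rotary tables to the query's dtype as well as its device. A simplified sketch of rotary position embedding that isolates why the cast matters; it is not the patched lmdeploy function, and shape handling is reduced to plain broadcasting:

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in two and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary(q, k, cos, sin):
    # Move the rotary tables to the query's device *and* dtype. If cos/sin
    # stay in float32 while q/k are fp16/bf16, PyTorch promotes the products
    # to float32 and the attention inputs no longer match the model's dtype.
    cos = cos.to(device=q.device, dtype=q.dtype)
    sin = sin.to(device=q.device, dtype=q.dtype)
    q_embed = q * cos + rotate_half(q) * sin
    k_embed = k * cos + rotate_half(k) * sin
    return q_embed, k_embed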
4 changes: 4 additions & 0 deletions lmdeploy/pytorch_poc/patch/patch.py
@@ -9,6 +9,7 @@
 import torch.distributed as dist
 from addict import Addict
 from torch.distributed._tensor import DeviceMesh
+from transformers.utils import TRANSFORMERS_DYNAMIC_MODULE_NAME
 
 from lmdeploy.pytorch_poc.dist_utils import partition_module, replicate_module
 from lmdeploy.utils import get_logger
@@ -27,6 +28,9 @@
 MODULE_MAP.update({
     'modeling_baichuan.Model':
     'lmdeploy.pytorch_poc.patch.llama.LlamaModel',  # noqa
+    (f'{TRANSFORMERS_DYNAMIC_MODULE_NAME}.Baichuan2-7B-Chat'
+     '.modeling_baichuan.BaichuanModel'):
+    'lmdeploy.pytorch_poc.patch.llama.LlamaModel',  # noqa
     'modeling_baichuan.BaichuanModel':
     'lmdeploy.pytorch_poc.patch.baichuan.BaichuanModel',  # noqa
     'modeling_baichuan.Attention':
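The extra MODULE_MAP key is needed because models loaded with trust_remote_code=True are materialised under transformers' dynamic module namespace, so the class's qualified name starts with TRANSFORMERS_DYNAMIC_MODULE_NAME rather than a bare modeling_baichuan module. A hypothetical sketch of how such a map could be matched against a live module; lookup_rewrite and the trimmed-down map are illustrative, not lmdeploy's actual patch logic:

import importlib

from torch import nn

# Illustrative map only: keys are suffixes of '<module>.<ClassName>' for the
# original implementation, values are dotted paths of the rewrite class.
EXAMPLE_MODULE_MAP = {
    'modeling_baichuan.BaichuanModel':
    'lmdeploy.pytorch_poc.patch.baichuan.BaichuanModel',
}


def lookup_rewrite(module: nn.Module):
    """Hypothetical lookup: return the rewrite class for `module`, or None."""
    qualname = f'{type(module).__module__}.{type(module).__name__}'
    for suffix, target in EXAMPLE_MODULE_MAP.items():
        if qualname.endswith(suffix):
            mod_path, _, cls_name = target.rpartition('.')
            return getattr(importlib.import_module(mod_path), cls_name)
    return None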
