Skip to content

Commit

Permalink
fix path of weight index
Browse files Browse the repository at this point in the history
  • Loading branch information
grimoire committed Sep 22, 2023
1 parent 6252b99 commit cdbea77
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
7 changes: 3 additions & 4 deletions lmdeploy/pytorch_poc/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import itertools
import json
import os
import os.path as osp
import time
from dataclasses import dataclass
from queue import Queue
Expand All @@ -19,6 +18,7 @@
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper)
from transformers.utils import WEIGHTS_INDEX_NAME, cached_file

from lmdeploy.pytorch.accel import LoadNoInit
from lmdeploy.pytorch_poc.config import (CacheConfig, ModelConfig,
Expand Down Expand Up @@ -257,15 +257,14 @@ def _tp_model_loop(
torch_dtype=torch_dtype,
trust_remote_code=True)

torch_model_json_path = osp.join(model_path,
'pytorch_model.bin.index.json')
torch_model_json_path = cached_file(model_path, WEIGHTS_INDEX_NAME)
with open(torch_model_json_path, mode='r') as f:
torch_model_json = json.load(f)

weight_map = torch_model_json['weight_map']

checkpoints = list(set(weight_map.values()))
checkpoints = [osp.join(model_path, ckpt) for ckpt in checkpoints]
checkpoints = [cached_file(model_path, ckpt) for ckpt in checkpoints]
patched_model = patch(
model,
extra_args=extra_args,
Expand Down
2 changes: 1 addition & 1 deletion lmdeploy/pytorch_poc/patch/internlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,4 +95,4 @@ def forward(
past_key_value,
output_attentions,
world_size=world_size,
)
)

0 comments on commit cdbea77

Please sign in to comment.