Commit a9ce2e2

use weights iterator
RunningLeon committed Dec 12, 2024
1 parent af7157a commit a9ce2e2
Showing 2 changed files with 40 additions and 20 deletions.
59 changes: 39 additions & 20 deletions lmdeploy/pytorch/weight_loader/model_weight_loader.py
@@ -1,9 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import json
 import os.path as osp
 from typing import List
 
 import torch
 from transformers.modeling_utils import load_state_dict
+from safetensors.torch import safe_open
+from tqdm.auto import tqdm
 from transformers.utils import (SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME,
                                 WEIGHTS_INDEX_NAME, WEIGHTS_NAME)

@@ -90,6 +92,28 @@ def _get_weight_path(model_path: str, weight_type: str):
     return weight_path, weight_name
 
 
+def _get_safetensors_weights_iterator(hf_files: List[str], disable_tqdm: bool):
+    """get safetensors weights iterator."""
+    for file in tqdm(hf_files,
+                     desc='Loading weights from safetensors',
+                     disable=disable_tqdm):
+        with safe_open(file, framework='pt') as f:
+            for name in f.keys():
+                param = f.get_tensor(name)
+                yield name, param
+
+
+def _get_pt_weights_iterator(hf_files: List[str], disable_tqdm: bool):
+    """get pt weights iterator."""
+    for file in tqdm(hf_files,
+                     desc='Loading weights from pt ckpt',
+                     disable=disable_tqdm):
+        state = torch.load(file, weights_only=True, map_location='cpu')
+        yield from state.items()
+        del state
+        torch.cuda.empty_cache()
+
+
 class ModelWeightLoader:
     """model weight loader for sharded weights."""

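The two iterators above stream tensors one at a time instead of materializing a whole shard as a dict. A minimal standalone sketch of the safetensors pattern (the checkpoint path here is a placeholder, not a file from this commit):

from safetensors.torch import safe_open

def iter_safetensors(path):
    # safe_open maps the checkpoint lazily; get_tensor materializes one
    # tensor at a time, so peak memory tracks the largest tensor rather
    # than the whole shard.
    with safe_open(path, framework='pt') as f:
        for name in f.keys():
            yield name, f.get_tensor(name)

for name, tensor in iter_safetensors('model.safetensors'):
    print(name, tuple(tensor.shape))

The .bin path cannot stream within a file, so _get_pt_weights_iterator still loads one full shard per torch.load call, but releases it (del plus empty_cache) before moving to the next file.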
@@ -115,13 +139,17 @@ def _get_shard_paths(model_path: str, is_sharded: bool, weight_type: str):
             path, _ = _get_weight_path(model_path, weight_type)
             return (path, )
 
-    def _load_shard(self, path: str):
-        """load shards."""
-        state_dict = load_state_dict(path)
+    def _get_weights_iterator(self, paths: List[str], disable_tqdm: bool):
+        """get weights iterator."""
+        if self._weight_type == 'safetensors':
+            weights_iterator = _get_safetensors_weights_iterator(
+                paths, disable_tqdm)
+        else:
+            weights_iterator = _get_pt_weights_iterator(paths, disable_tqdm)
         if self._prefix is not None:
-            state_dict = dict(
-                (f'{self._prefix}{k}', v) for k, v in state_dict.items())
-        return state_dict
+            weights_iterator = ((self._prefix + name, tensor)
+                                for name, tensor in weights_iterator)
+        return weights_iterator
 
     def load_model_weights(
         self,
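_get_weights_iterator keeps the prefix handling lazy as well: instead of building a renamed dict, it wraps the stream in a generator expression. A hedged sketch of that re-keying step, with made-up names and tensors:

import torch

weights = iter([('weight', torch.zeros(2)), ('bias', torch.zeros(2))])
prefix = 'model.'  # illustrative prefix, not from the commit

# No intermediate dict is built; each pair is renamed only when the
# consumer pulls it from the generator.
weights = ((prefix + name, tensor) for name, tensor in weights)

for name, tensor in weights:
    print(name, tuple(tensor.shape))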
@@ -131,19 +159,10 @@ def load_model_weights(
         """load model weights implementation."""
         assert hasattr(model, 'load_weights')
         paths = self._shard_paths
-        world_size, rank = get_world_rank()
-        for path in paths:
-
-            # log
-            file_name = osp.split(path)[1]
-            msg = f'loading weights - "{file_name}"'
-            if world_size > 1:
-                msg = f'rank[{rank}] {msg}'
-            logger.info(msg)
-
-            # process
-            state_dict = self._load_shard(path)
-            model.load_weights(state_dict.items())
+        _, rank = get_world_rank()
+        disable_tqdm = rank != 0
+        weights_iterator = self._get_weights_iterator(paths, disable_tqdm)
+        model.load_weights(weights_iterator)
         if device is not None:
             device = model.to(device)

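load_model_weights now replaces the per-shard log messages with a single tqdm bar, gated on rank so multi-GPU runs do not print interleaved bars. A small sketch of that pattern, using torch.distributed directly as a stand-in for lmdeploy's get_world_rank helper:

import torch.distributed as dist
from tqdm.auto import tqdm

# Fall back to rank 0 when running single-process.
rank = dist.get_rank() if dist.is_initialized() else 0
disable_tqdm = rank != 0  # only rank 0 renders the bar

for _ in tqdm(range(3), desc='Loading weights', disable=disable_tqdm):
    pass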
1 change: 1 addition & 0 deletions requirements/runtime.txt
@@ -17,6 +17,7 @@ shortuuid
 tiktoken
 torch<=2.4.0,>=2.0.0
 torchvision<=0.19.0,>=0.15.0
+tqdm
 transformers
 triton>=2.2.0,<=3.0.0; sys_platform == "linux"
 uvicorn
