From a9ce2e2a4126768a7796686f7d517f0ece468ff3 Mon Sep 17 00:00:00 2001 From: RunningLeon Date: Tue, 10 Dec 2024 11:20:20 +0800 Subject: [PATCH] use weights iterator --- .../weight_loader/model_weight_loader.py | 59 ++++++++++++------- requirements/runtime.txt | 1 + 2 files changed, 40 insertions(+), 20 deletions(-) diff --git a/lmdeploy/pytorch/weight_loader/model_weight_loader.py b/lmdeploy/pytorch/weight_loader/model_weight_loader.py index cb548614c..0a8ac9bb4 100644 --- a/lmdeploy/pytorch/weight_loader/model_weight_loader.py +++ b/lmdeploy/pytorch/weight_loader/model_weight_loader.py @@ -1,9 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. import json import os.path as osp +from typing import List import torch -from transformers.modeling_utils import load_state_dict +from safetensors.torch import safe_open +from tqdm.auto import tqdm from transformers.utils import (SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME) @@ -90,6 +92,28 @@ def _get_weight_path(model_path: str, weight_type: str): return weight_path, weight_name +def _get_safetensors_weights_iterator(hf_files: List[str], disable_tqdm: bool): + """get safeternsors weights iterator.""" + for file in tqdm(hf_files, + desc='Loading weights from safetensors', + disable=disable_tqdm): + with safe_open(file, framework='pt') as f: + for name in f.keys(): + param = f.get_tensor(name) + yield name, param + + +def _get_pt_weights_iterator(hf_files: List[str], disable_tqdm: bool): + """get pt weights iterator.""" + for file in tqdm(hf_files, + desc='Loading weights from pt ckpt', + disable=disable_tqdm): + state = torch.load(file, weights_only=True, map_location='cpu') + yield from state.items() + del state + torch.cuda.empty_cache() + + class ModelWeightLoader: """model weight loader for sharded weights.""" @@ -115,13 +139,17 @@ def _get_shard_paths(model_path: str, is_sharded: bool, weight_type: str): path, _ = _get_weight_path(model_path, weight_type) return (path, ) - def _load_shard(self, path: str): - """load shards.""" - state_dict = load_state_dict(path) + def _get_weights_iterator(self, paths: List[str], disable_tqdm: bool): + """get weights iterator.""" + if self._weight_type == 'safetensors': + weights_iterator = _get_safetensors_weights_iterator( + paths, disable_tqdm) + else: + weights_iterator = _get_pt_weights_iterator(paths, disable_tqdm) if self._prefix is not None: - state_dict = dict( - (f'{self._prefix}{k}', v) for k, v in state_dict.items()) - return state_dict + weights_iterator = ((self._prefix + name, tensor) + for name, tensor in weights_iterator) + return weights_iterator def load_model_weights( self, @@ -131,19 +159,10 @@ def load_model_weights( """load model weights implementation.""" assert hasattr(model, 'load_weights') paths = self._shard_paths - world_size, rank = get_world_rank() - for path in paths: - - # log - file_name = osp.split(path)[1] - msg = f'loading weights - "{file_name}"' - if world_size > 1: - msg = f'rank[{rank}] {msg}' - logger.info(msg) - - # process - state_dict = self._load_shard(path) - model.load_weights(state_dict.items()) + _, rank = get_world_rank() + disable_tqdm = rank != 0 + weights_iterator = self._get_weights_iterator(paths, disable_tqdm) + model.load_weights(weights_iterator) if device is not None: device = model.to(device) diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 400c492b0..fca862f7a 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -17,6 +17,7 @@ shortuuid tiktoken torch<=2.4.0,>=2.0.0 torchvision<=0.19.0,>=0.15.0 +tqdm transformers triton>=2.2.0,<=3.0.0; sys_platform == "linux" uvicorn