[Fix] Fix load_checkpoint_in_model bug (#690)
* fix load_checkpoint_in_model bug

* fix comments

* fix comments

* fix bugs
HIT-cwh authored Nov 16, 2023
1 parent 7d40d19 commit 0fcc303
Showing 4 changed files with 70 additions and 66 deletions.
35 changes: 5 additions & 30 deletions lmdeploy/lite/apis/auto_awq.py
@@ -3,14 +3,12 @@
from pathlib import Path

import torch
from accelerate import (infer_auto_device_map, init_empty_weights,
load_checkpoint_in_model)
from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers import AutoTokenizer

from lmdeploy.lite.quantization.awq import (FC_FCS_MAP, NORM_FCS_MAP,
quant_weights, smooth_layers)
from lmdeploy.lite.utils import collect_target_modules
from lmdeploy.lite.utils import collect_target_modules, load_hf_from_pretrained

LAYER_TYPE_MAP = {
'InternLMForCausalLM': 'InternLMDecoderLayer',
@@ -39,38 +37,15 @@ def auto_awq(model: str,
tokenizer = AutoTokenizer.from_pretrained(model,
use_fast=False,
trust_remote_code=True)
hf_config = AutoConfig.from_pretrained(model, trust_remote_code=True)
checkpoint = hf_config._name_or_path

# hard code for qwen, other configs do not have the `fp16` attribute.
hf_config.fp16 = True

with init_empty_weights():
# Load model
model = AutoModelForCausalLM.from_pretrained(model,
torch_dtype=torch.float16,
trust_remote_code=True)
model.config.use_cache = False
model = load_hf_from_pretrained(model,
torch_dtype=torch.float16,
trust_remote_code=True)

layer_type = LAYER_TYPE_MAP[type(model).__name__]
fc2fcs = FC_FCS_MAP[layer_type]
norm2fcs = NORM_FCS_MAP[layer_type]

decoder_layers = collect_target_modules(model, layer_type)

# Infer device map
device_map = infer_auto_device_map(model,
no_split_module_classes=[layer_type])
for name in device_map.keys():
if name in decoder_layers or 'lm_head' in name:
device_map[name] = 'cpu'
else:
device_map[name] = 0
load_checkpoint_in_model(model,
checkpoint,
device_map,
dtype=torch.float16)

work_dir = Path(work_dir)

act_scales = torch.load(work_dir / 'inputs_stats.pth')['absmax']
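
Note: the device-map logic removed above is not dropped; it moves into the new shared helper load_hf_from_pretrained (see lmdeploy/lite/utils/load.py below), which both auto_awq.py and calibrate.py now call. For orientation, the policy pins every decoder layer and the lm_head to CPU and places the remaining modules on GPU 0. A hypothetical sketch of the resulting map for a Llama-style model follows; the module names are assumptions for illustration, not taken from this commit.

    # Illustrative shape of the device_map after the CPU/GPU assignment loop.
    # Module names are hypothetical (Llama-style); only the placement policy
    # reflects the code in load_hf_from_pretrained.
    device_map = {
        'model.embed_tokens': 0,   # non-decoder modules go to GPU 0
        'model.layers.0': 'cpu',   # decoder layers are pinned to CPU
        'model.layers.1': 'cpu',
        # ... one entry per decoder layer, all mapped to 'cpu' ...
        'model.norm': 0,
        'lm_head': 'cpu',          # lm_head is also pinned to CPU
    }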
43 changes: 8 additions & 35 deletions lmdeploy/lite/apis/calibrate.py
@@ -4,13 +4,12 @@
from typing import Union

import torch
from accelerate import (infer_auto_device_map, init_empty_weights,
load_checkpoint_in_model)
from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers import AutoTokenizer

from lmdeploy.lite.quantization import CalibrationContext
from lmdeploy.lite.utils import collect_target_modules, get_calib_loaders
from lmdeploy.lite.utils import (collect_target_modules, get_calib_loaders,
load_hf_from_pretrained)

LAYER_TYPE_MAP = {
'InternLMForCausalLM': 'InternLMDecoderLayer',
@@ -109,7 +108,7 @@ def calibrate(model: str,
given dataset.
Args:
model (str): The model to be loaded.
model (str): The name or path of the model to be loaded.
calib_dataset (str, optional): The calibration dataset name.
Defaults to 'c4'.
calib_samples (int, optional): The number of samples for calibration.
@@ -129,21 +128,10 @@ def calibrate(model: str,
tokenizer = AutoTokenizer.from_pretrained(model,
use_fast=False,
trust_remote_code=True)
hf_config = AutoConfig.from_pretrained(model,
torch_dtype=torch.float16,
trust_remote_code=True)
checkpoint = hf_config._name_or_path

# hard code for qwen, other configs do not have the `fp16` attribute.
hf_config.fp16 = True

with init_empty_weights():
# Load model
model = AutoModelForCausalLM.from_pretrained(model,
config=hf_config,
torch_dtype=torch.float16,
trust_remote_code=True)
model.config.use_cache = False

model = load_hf_from_pretrained(model,
torch_dtype=torch.float16,
trust_remote_code=True)

model_type = type(model).__name__
if model_type not in LAYER_TYPE_MAP or model_type not in NORM_TYPE_MAP:
Expand All @@ -164,21 +152,6 @@ def calibrate(model: str,
layer_type = LAYER_TYPE_MAP[type(model).__name__]
norm_type = NORM_TYPE_MAP[type(model).__name__]

decoder_layers = collect_target_modules(model, layer_type)

# Infer device map
device_map = infer_auto_device_map(model,
no_split_module_classes=[layer_type])
for name in device_map.keys():
if name in decoder_layers or 'lm_head' in name:
device_map[name] = 'cpu'
else:
device_map[name] = 0
load_checkpoint_in_model(model,
checkpoint,
device_map,
dtype=torch.float16)

_prepare_for_calibrate(model, layer_type, 'lm_head', device)

print('Loading calibrate dataset ...')
3 changes: 2 additions & 1 deletion lmdeploy/lite/utils/__init__.py
@@ -11,12 +11,13 @@
from .collect import (bimap_name_mod, collect_target_modules,
collect_target_weights)
from .global_avail import GlobalAvailMixin
from .load import load_hf_from_pretrained

__all__ = [
'cal_qparams_per_channel_absmax', 'cal_qparams_per_channel_minmax',
'cal_qparams_per_group_absmax', 'cal_qparams_per_group_minmax',
'cal_qparams_per_tensor_absmax', 'cal_qparams_per_tensor_minmax',
'QParams', 'get_calib_loaders', 'collect_target_modules', 'precise_round',
'collect_target_weights', 'GlobalAvailMixin', 'split_decoder_layer_inputs',
'bimap_name_mod', 'concat_decoder_layer_outputs'
'bimap_name_mod', 'concat_decoder_layer_outputs', 'load_hf_from_pretrained'
]
55 changes: 55 additions & 0 deletions lmdeploy/lite/utils/load.py
@@ -0,0 +1,55 @@
# Copyright (c) OpenMMLab. All rights reserved.

import torch
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

from lmdeploy.lite.utils import collect_target_modules
from lmdeploy.pytorch.model import LoadWoInit

LAYER_TYPE_MAP = {
'InternLMForCausalLM': 'InternLMDecoderLayer',
'QWenLMHeadModel': 'QWenBlock',
'BaiChuanForCausalLM': 'DecoderLayer', # Baichuan 7B
'BaichuanForCausalLM': 'DecoderLayer', # Baichuan2 7B
'LlamaForCausalLM': 'LlamaDecoderLayer',
}


def load_hf_from_pretrained(pretrained_model_name_or_path, **kwargs):

kwargs.pop('config', None)

hf_config = AutoConfig.from_pretrained(pretrained_model_name_or_path,
torch_dtype=torch.float16,
trust_remote_code=True)

# hard code for qwen, other configs do not have the `fp16` attribute.
hf_config.fp16 = True

with init_empty_weights():
# Load model
model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path, config=hf_config, **kwargs)
model.config.use_cache = False
layer_type = LAYER_TYPE_MAP[type(model).__name__]
decoder_layers = collect_target_modules(model, layer_type)
# Infer device map
device_map = infer_auto_device_map(model,
no_split_module_classes=[layer_type])
for name in device_map.keys():
if name in decoder_layers or 'lm_head' in name:
device_map[name] = 'cpu'
else:
device_map[name] = 0
if 'device_map' in kwargs:
kwargs.pop('device_map')
with LoadWoInit():
model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path,
device_map=device_map,
config=hf_config,
**kwargs)
model.config.use_cache = False

return model
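
A minimal usage sketch of the new helper, assuming a hypothetical model path (the placeholder below is not from this commit): as in the two call sites above, the returned model is loaded in float16 with use_cache disabled, and its decoder layers and lm_head are placed on CPU by the inferred device map.

    import torch

    from lmdeploy.lite.utils import load_hf_from_pretrained

    # Hypothetical call site; '/path/to/hf_model' is a placeholder path.
    model = load_hf_from_pretrained('/path/to/hf_model',
                                    torch_dtype=torch.float16,
                                    trust_remote_code=True)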
