Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Fix load_checkpoint_in_model bug #690

Merged
merged 4 commits into from
Nov 16, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 6 additions & 26 deletions lmdeploy/lite/apis/auto_awq.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@
from pathlib import Path

import torch
from accelerate import (infer_auto_device_map, init_empty_weights,
load_checkpoint_in_model)
from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers import AutoConfig, AutoTokenizer

from lmdeploy.lite.quantization.awq import (FC_FCS_MAP, NORM_FCS_MAP,
quant_weights, smooth_layers)
from lmdeploy.lite.utils import collect_target_modules
from lmdeploy.lite.utils import collect_target_modules, load_hf_from_pretrained

LAYER_TYPE_MAP = {
'InternLMForCausalLM': 'InternLMDecoderLayer',
Expand Down Expand Up @@ -40,37 +38,19 @@ def auto_awq(model: str,
use_fast=False,
trust_remote_code=True)
hf_config = AutoConfig.from_pretrained(model, trust_remote_code=True)
checkpoint = hf_config._name_or_path

# hard code for qwen, other configs do not have the `fp16` attribute.
hf_config.fp16 = True

with init_empty_weights():
# Load model
model = AutoModelForCausalLM.from_pretrained(model,
torch_dtype=torch.float16,
trust_remote_code=True)
model.config.use_cache = False
model = load_hf_from_pretrained(model,
config=hf_config,
torch_dtype=torch.float16,
trust_remote_code=True)

layer_type = LAYER_TYPE_MAP[type(model).__name__]
fc2fcs = FC_FCS_MAP[layer_type]
norm2fcs = NORM_FCS_MAP[layer_type]

decoder_layers = collect_target_modules(model, layer_type)

# Infer device map
device_map = infer_auto_device_map(model,
no_split_module_classes=[layer_type])
for name in device_map.keys():
if name in decoder_layers or 'lm_head' in name:
device_map[name] = 'cpu'
else:
device_map[name] = 0
load_checkpoint_in_model(model,
checkpoint,
device_map,
dtype=torch.float16)

work_dir = Path(work_dir)

act_scales = torch.load(work_dir / 'inputs_stats.pth')['absmax']
Expand Down
36 changes: 8 additions & 28 deletions lmdeploy/lite/apis/calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
from typing import Union

import torch
from accelerate import (infer_auto_device_map, init_empty_weights,
load_checkpoint_in_model)
from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers import AutoConfig, AutoTokenizer

from lmdeploy.lite.quantization import CalibrationContext
from lmdeploy.lite.utils import collect_target_modules, get_calib_loaders
from lmdeploy.lite.utils import (collect_target_modules, get_calib_loaders,
load_hf_from_pretrained)

LAYER_TYPE_MAP = {
'InternLMForCausalLM': 'InternLMDecoderLayer',
Expand Down Expand Up @@ -109,7 +108,7 @@ def calibrate(model: str,
given dataset.

Args:
model (str): The model to be loaded.
model (str): The name or path of the model to be loaded.
calib_dataset (str, optional): The calibration dataset name.
Defaults to 'c4'.
calib_samples (int, optional): The number of samples for calibration.
Expand All @@ -132,18 +131,14 @@ def calibrate(model: str,
hf_config = AutoConfig.from_pretrained(model,
torch_dtype=torch.float16,
trust_remote_code=True)
checkpoint = hf_config._name_or_path

# hard code for qwen, other configs do not have the `fp16` attribute.
hf_config.fp16 = True

with init_empty_weights():
# Load model
model = AutoModelForCausalLM.from_pretrained(model,
config=hf_config,
torch_dtype=torch.float16,
trust_remote_code=True)
model.config.use_cache = False
model = load_hf_from_pretrained(model,
config=hf_config,
torch_dtype=torch.float16,
trust_remote_code=True)

model_type = type(model).__name__
if model_type not in LAYER_TYPE_MAP or model_type not in NORM_TYPE_MAP:
Expand All @@ -164,21 +159,6 @@ def calibrate(model: str,
layer_type = LAYER_TYPE_MAP[type(model).__name__]
norm_type = NORM_TYPE_MAP[type(model).__name__]

decoder_layers = collect_target_modules(model, layer_type)

# Infer device map
device_map = infer_auto_device_map(model,
no_split_module_classes=[layer_type])
for name in device_map.keys():
if name in decoder_layers or 'lm_head' in name:
device_map[name] = 'cpu'
else:
device_map[name] = 0
load_checkpoint_in_model(model,
checkpoint,
device_map,
dtype=torch.float16)

_prepare_for_calibrate(model, layer_type, 'lm_head', device)

print('Loading calibrate dataset ...')
Expand Down
3 changes: 2 additions & 1 deletion lmdeploy/lite/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
from .collect import (bimap_name_mod, collect_target_modules,
collect_target_weights)
from .global_avail import GlobalAvailMixin
from .load import load_hf_from_pretrained

__all__ = [
'cal_qparams_per_channel_absmax', 'cal_qparams_per_channel_minmax',
'cal_qparams_per_group_absmax', 'cal_qparams_per_group_minmax',
'cal_qparams_per_tensor_absmax', 'cal_qparams_per_tensor_minmax',
'QParams', 'get_calib_loaders', 'collect_target_modules', 'precise_round',
'collect_target_weights', 'GlobalAvailMixin', 'split_decoder_layer_inputs',
'bimap_name_mod', 'concat_decoder_layer_outputs'
'bimap_name_mod', 'concat_decoder_layer_outputs', 'load_hf_from_pretrained'
]
40 changes: 40 additions & 0 deletions lmdeploy/lite/utils/load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright (c) OpenMMLab. All rights reserved.

from accelerate import infer_auto_device_map, init_empty_weights
from transformers import AutoModelForCausalLM

from lmdeploy.lite.utils import collect_target_modules
from lmdeploy.pytorch.model import LoadWoInit

LAYER_TYPE_MAP = {
'InternLMForCausalLM': 'InternLMDecoderLayer',
'QWenLMHeadModel': 'QWenBlock',
'BaiChuanForCausalLM': 'DecoderLayer',
lvhan028 marked this conversation as resolved.
Show resolved Hide resolved
'LlamaForCausalLM': 'LlamaDecoderLayer',
}


def load_hf_from_pretrained(pretrained_model_name_or_path, **kwargs):
with init_empty_weights():
# Load model
model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path, **kwargs)
model.config.use_cache = False
layer_type = LAYER_TYPE_MAP[type(model).__name__]
decoder_layers = collect_target_modules(model, layer_type)
# Infer device map
device_map = infer_auto_device_map(model,
no_split_module_classes=[layer_type])
for name in device_map.keys():
if name in decoder_layers or 'lm_head' in name:
device_map[name] = 'cpu'
else:
device_map[name] = 0
if 'device_map' in kwargs:
kwargs.pop('device_map')
with LoadWoInit():
model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path, device_map=device_map, **kwargs)
model.config.use_cache = False

return model
Loading