Commit
Merge the chatglm lora and pv2 examples
Tongjilibo committed Apr 18, 2024
1 parent c36dc93 commit dbe0b95
Showing 8 changed files with 207 additions and 628 deletions.
2 changes: 1 addition & 1 deletion bert4torch/models/base.py
@@ -430,7 +430,7 @@ def set_outputs(self, outputs):
else:
self.output = outputs[0]

def quantize(self, quantization_method, **kwargs):
def quantize(self, quantization_method:Literal['cpm_kernels', 'load_in_8bit', 'load_in_4bit'], **kwargs):
'''Quantize the model'''
if self.quantized:
print("Already quantized.")
27 changes: 18 additions & 9 deletions bert4torch/quantization.py
@@ -3,17 +3,18 @@

from torch.nn import Linear, Embedding
from torch.nn.parameter import Parameter
from torch import nn
import torch.nn.functional as F
import bz2
import torch
import base64
import ctypes
from typing import List
from typing import List, Union, Dict
import re
from tqdm import tqdm
from functools import partial
import inspect
from bert4torch.snippets import log_error
from bert4torch.snippets import is_package_available

try:
from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up
@@ -41,7 +42,6 @@ def __init__(self, code: bytes, function_names: List[str]):
)
except Exception as exception:
kernels = None
log_error("Failed to load cpm_kernels:" + str(exception))


class W8A16Linear(torch.autograd.Function):
@@ -222,14 +222,18 @@ def forward(self, input):
return output


def quantize_cpm_kernels(model, quantization_bit, use_quantization_cache=False, empty_init=False, target_modules=None, **kwargs):
def quantize_cpm_kernels(model:nn.Module, quantization_bit:int, use_quantization_cache:bool=False,
empty_init:bool=False, target_modules:Union[str, List]=None, **kwargs):
"""从chagglm-6b移植过来的的量化,方便以int8和int4进行推理
源链接:https://huggingface.co/THUDM/chatglm-6b/blob/main/quantization.py
Replace fp16 linear with quantized linear
这里修改了hard code, 可以适配其他模型
target_modules: str/list, 指定对某些层做量化
"""
if not is_package_available('cpm_kernels'):
raise ModuleNotFoundError('Module `cpm_kernels` not found, you can install it with `pip install cpm_kernels`')

modules_trans = {}
for name, module in model.named_modules():
# target_modules=None means all Linear layers are replaced
@@ -257,7 +261,7 @@ def quantize_cpm_kernels(model, quantization_bit, use_quantization_cache=False,

cache = dict()
for name, module in tqdm(modules_trans.items(), desc='Quantize linear layers'):
cache_name = re.sub('\.[0-9]+\.', '.', name)
cache_name = re.sub(r'\.[0-9]+\.', '.', name)
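# e.g. 'transformer.layers.0.dense' and 'transformer.layers.7.dense' both collapse
# to 'transformer.layers.dense', so sibling layers can share one cache tensor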
if use_quantization_cache and (cache_name not in cache):
n, m = module.weight.size(0), module.weight.size(1)
cache[cache_name] = CacheTensor(n, m, dtype=dtype, device=current_device, requires_grad=False)
@@ -273,17 +277,22 @@
del module
# Assign the quantized module back to its parent path
name_new = list(name)
for iter_ in re.finditer('\.[0-9]+\.', name):
for iter_ in re.finditer(r'\.[0-9]+\.', name):
iter_str = name[iter_.start():iter_.end()]
name_new[iter_.start():iter_.end()] = [''] * (iter_.end()-iter_.start())
name_new[iter_.start()] = '[' + iter_str[1:-1] + '].'
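# e.g. 'transformer.layers.0.attention' -> 'transformer.layers[0].attention',
# giving the exec below a valid attribute path into the ModuleList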
exec('model.' + ''.join(name_new) + ' = module_quant')
return model
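A hedged usage sketch of `quantize_cpm_kernels` (per the docstring, `target_modules` is a str or list naming the layers to quantize; the module names below are illustrative and depend on the model's actual `named_modules()`):

```python
from bert4torch.quantization import quantize_cpm_kernels

# Replace only the matching Linear layers with int8 quantized linears
model = quantize_cpm_kernels(model, quantization_bit=8,
                             target_modules=['query_key_value', 'dense'])
```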


def quantize_load_in_kbit(model, load_in_8bit=False, load_in_4bit=False, keep_in_fp32_modules=None, llm_int8_skip_modules=None, quantization_config=None, **kwargs):
def quantize_load_in_kbit(model:nn.Module, load_in_8bit:bool=False, load_in_4bit:bool=False, keep_in_fp32_modules:List=None,
llm_int8_skip_modules:List=None, quantization_config:Dict=None, **kwargs):
'''transformers' load_in_8bit, adapted from the transformers source code'''
from transformers.utils.bitsandbytes import replace_with_bnb_linear, set_module_quantized_tensor_to_device
# Compatible with both old and new transformers versions
try:
from transformers.integrations import replace_with_bnb_linear, set_module_quantized_tensor_to_device
except ImportError:
from transformers.utils.bitsandbytes import replace_with_bnb_linear, set_module_quantized_tensor_to_device
from transformers.utils.quantization_config import BitsAndBytesConfig
if quantization_config is None:
quantization_config, kwargs = BitsAndBytesConfig.from_dict(
@@ -317,7 +326,7 @@ def quantize_load_in_kbit(model, load_in_8bit=False, load_in_4bit=False, keep_in

for key, param in model.named_parameters():
if param.device == torch.device("meta"):
set_module_quantized_tensor_to_device(model, key, 'cpu', value=state_dict[key], fp16_statistics=None)
set_module_quantized_tensor_to_device(model, key, 'cpu', value=state_dict[key])

model.is_loaded_in_8bit = load_in_8bit
model.is_loaded_in_4bit = load_in_4bit
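A sketch of the bitsandbytes path, assuming `transformers` and `bitsandbytes` are installed (the skip list is illustrative; `lm_head` is commonly kept unquantized):

```python
from transformers.utils.quantization_config import BitsAndBytesConfig
from bert4torch.quantization import quantize_load_in_kbit

# Load in 8-bit with an explicit config instead of the bare load_in_8bit flag
config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=['lm_head'])
model = quantize_load_in_kbit(model, load_in_8bit=True, quantization_config=config)
```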
2 changes: 1 addition & 1 deletion examples/README.md
@@ -9,7 +9,6 @@
| | [basic_gibbs_sampling_via_mlm.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/bert/basic_gibbs_sampling_via_mlm.py): random text generation with BERT + Gibbs sampling, see [here](https://kexue.fm/archives/8119)
| | [basic_language_model_bert-base-multilingual-cased.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/bert/basic_language_model_bert-base-multilingual-cased.py): [bert-base-multilingual-cased](https://huggingface.co/bert-base-multilingual-cased)
| | [basic_language_model_bert.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/bert/basic_language_model_bert.py): test BERT's MLM performance.
| |[basic_language_model_guwenbert.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/roberta/basic_language_model_guwenbert.py): test the [classical-Chinese BERT (guwenbert)](https://huggingface.co/ethanyt/guwenbert-base) model.
| | [basic_make_uncased_model_cased.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/bert/basic_make_uncased_model_cased.py): give a case-insensitive model the ability to distinguish case by simply modifying the vocabulary.
| |[basic_language_model_macbert.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/macbert/basic_language_model_macbert.py): test macbert's MLM performance.
|bloom | [basic_language_model_bloom.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/bloom/basic_language_model_bloom.py): test [bloom](https://huggingface.co/bigscience)
@@ -51,6 +50,7 @@
|roberta|[basic_language_model_chinese-roberta-wwm.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/roberta/basic_language_model_chinese-roberta-wwm.py): test the MLM performance of HIT's chinese-roberta-wwm.
| |[basic_language_model_roberta_small_tiny.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/roberta/basic_language_model_roberta_small_tiny.py): test the MLM performance of Roberta-small.
| |[basic_language_model_roberta-base-english.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/roberta/basic_language_model_roberta-base-english.py): test the MLM performance of the English roberta-base.
| |[basic_language_model_guwenbert.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/roberta/basic_language_model_guwenbert.py): test the [classical-Chinese BERT (guwenbert)](https://huggingface.co/ethanyt/guwenbert-base) model.
|roformer|[basic_language_model_roformer.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/roformer/basic_language_model_roformer.py): test the MLM performance of roformer.
|simbert|[basic_language_model_simbert.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/simbert/basic_language_model_simbert.py): test the generation and sentence-similarity performance of [simbert](https://github.com/ZhuiyiTechnology/simbert) and [roformer-sim](https://github.com/ZhuiyiTechnology/roformer-sim).
| t5 |[basic_language_model_chatyuan.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/t5/basic_language_model_chatyuan.py): test the [ChatYuan](https://github.com/clue-ai/ChatYuan) model.
7 changes: 2 additions & 5 deletions examples/llm/README.md
@@ -8,13 +8,10 @@
## Others
- [instruct_gpt](https://github.com/Tongjilibo/bert4torch/blob/master/examples/llm/instruct_gpt): reproduce the rlhf implementation in its three steps
- [task_chatglm_nbce.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/llm/task_chatglm_nbce.py): test the [chatglm-6b](https://github.com/THUDM/ChatGLM-6B) model, using naive Bayes (NBCE) to extend the context length the LLM can handle.
- [eval](https://github.com/Tongjilibo/bert4torch/blob/master/examples/llm/eval): evaluation of large language models

## Fine-tuning
- [task_chatglm_ptuning_v2.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/llm/task_chatglm_ptuning_v2.py): ptuning_v2 fine-tuning of [chatglm-6b](https://github.com/THUDM/ChatGLM-6B).
- [task_chatglm_lora.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/llm/task_chatglm_lora.py): lora fine-tuning of [chatglm-6b](https://github.com/THUDM/ChatGLM-6B) (based on peft).
- [task_chatglm2_ptuning_v2.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/llm/task_chatglm2_ptuning_v2.py): ptuning_v2 fine-tuning of [chatglm2-6b](https://github.com/THUDM/ChatGLM2-6B).
- [task_chatglm2_lora.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/llm/task_chatglm2_lora.py): lora fine-tuning of [chatglm2-6b](https://github.com/THUDM/ChatGLM2-6B) (based on peft).
- [task_chatglm_ptuning_v2.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/llm/task_chatglm_ptuning_v2.py): ptuning_v2 fine-tuning of [chatglm-6b](https://github.com/THUDM/ChatGLM-6B)/[chatglm2-6b](https://github.com/THUDM/ChatGLM2-6B).
- [task_chatglm_lora.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/llm/task_chatglm_lora.py): lora fine-tuning of [chatglm-6b](https://github.com/THUDM/ChatGLM-6B)/[chatglm2-6b](https://github.com/THUDM/ChatGLM2-6B) (based on peft).
- [task_llama-2_lora.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/llm/task_llama-2_lora.py): lora fine-tuning of [llama-2](https://github.com/facebookresearch/llama) (based on peft).
- [task_chatglm_deepspeed](https://github.com/Tongjilibo/bert4torch/blob/master/examples/llm/task_chatglm_deepspeed): lora fine-tuning of [chatglm](https://github.com/THUDM/ChatGLM-6B) (peft + deepspeed).
- [task_llama_deepspeed](https://github.com/Tongjilibo/bert4torch/blob/master/examples/llm/task_llama_deepspeed): lora fine-tuning of [llama-2](https://github.com/facebookresearch/llama) (peft + deepspeed).
