diff --git a/bert4torch/models/base.py b/bert4torch/models/base.py index 745988d6..f11128c8 100644 --- a/bert4torch/models/base.py +++ b/bert4torch/models/base.py @@ -430,7 +430,7 @@ def set_outputs(self, outputs): else: self.output = outputs[0] - def quantize(self, quantization_method, **kwargs): + def quantize(self, quantization_method:Literal['cpm_kernels', 'load_in_8bit', 'load_in_4bit'], **kwargs): '''量化''' if self.quantized: print("Already quantized.") diff --git a/bert4torch/quantization.py b/bert4torch/quantization.py index 3ef1cbde..ef0cb776 100644 --- a/bert4torch/quantization.py +++ b/bert4torch/quantization.py @@ -3,17 +3,18 @@ from torch.nn import Linear, Embedding from torch.nn.parameter import Parameter +from torch import nn import torch.nn.functional as F import bz2 import torch import base64 import ctypes -from typing import List +from typing import List, Union, Dict import re from tqdm import tqdm from functools import partial import inspect -from bert4torch.snippets import log_error +from bert4torch.snippets import is_package_available try: from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up @@ -41,7 +42,6 @@ def __init__(self, code: bytes, function_names: List[str]): ) except Exception as exception: kernels = None - log_error("Failed to load cpm_kernels:" + str(exception)) class W8A16Linear(torch.autograd.Function): @@ -222,7 +222,8 @@ def forward(self, input): return output -def quantize_cpm_kernels(model, quantization_bit, use_quantization_cache=False, empty_init=False, target_modules=None, **kwargs): +def quantize_cpm_kernels(model:nn.Module, quantization_bit:int, use_quantization_cache:bool=False, + empty_init:bool=False, target_modules:Union[str, List]=None, **kwargs): """从chagglm-6b移植过来的的量化,方便以int8和int4进行推理 源链接:https://huggingface.co/THUDM/chatglm-6b/blob/main/quantization.py @@ -230,6 +231,9 @@ def quantize_cpm_kernels(model, quantization_bit, use_quantization_cache=False, 这里修改了hard code, 可以适配其他模型 target_modules: str/list, 指定对某些层做量化 """ + if not is_package_available('cpm_kernels'): + raise ModuleNotFoundError('Module `cpm_kernels` not found, you may use `pip install cpm_kernels`') + modules_trans = {} for name, module in model.named_modules(): # target_modules=None, 表示对所有Linear层替换 @@ -257,7 +261,7 @@ def quantize_cpm_kernels(model, quantization_bit, use_quantization_cache=False, cache = dict() for name, module in tqdm(modules_trans.items(), desc='Quantize linear layers'): - cache_name = re.sub('\.[0-9]+\.', '.', name) + cache_name = re.sub(r'\.[0-9]+\.', '.', name) if use_quantization_cache and (cache_name not in cache): n, m = module.weight.size(0), module.weight.size(1) cache[cache_name] = CacheTensor(n, m, dtype=dtype, device=current_device, requires_grad=False) @@ -273,7 +277,7 @@ def quantize_cpm_kernels(model, quantization_bit, use_quantization_cache=False, del module # 赋值 name_new = list(name) - for iter_ in re.finditer('\.[0-9]+\.', name): + for iter_ in re.finditer(r'\.[0-9]+\.', name): iter_str = name[iter_.start():iter_.end()] name_new[iter_.start():iter_.end()] = [''] * (iter_.end()-iter_.start()) name_new[iter_.start()] = '[' + iter_str[1:-1] + '].' 
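Note on the hunk above: switching the patterns to raw strings (`r'\.[0-9]+\.'`) silences Python's invalid-escape-sequence warning without changing what they match, and the `finditer` loop rewrites dotted module paths (e.g. `layers.0.query`) into attribute/index expressions so the quantized replacement can be assigned back onto the parent module. A minimal standalone sketch of that name rewrite, mirroring the patch's logic (module names below are illustrative only, not taken from the patch):

```python
import re

def dotted_to_indexed(name: str) -> str:
    """Rewrite 'encoder.layers.3.query' -> 'encoder.layers[3].query',
    following the finditer logic in quantize_cpm_kernels."""
    name_new = list(name)
    for match in re.finditer(r'\.[0-9]+\.', name):            # raw string: no escape warning
        digits = name[match.start():match.end()]              # e.g. '.3.'
        name_new[match.start():match.end()] = [''] * (match.end() - match.start())
        name_new[match.start()] = '[' + digits[1:-1] + '].'
    return ''.join(name_new)

# Hypothetical module names, for illustration only
assert dotted_to_indexed('encoder.layers.3.attention.query') == 'encoder.layers[3].attention.query'
assert dotted_to_indexed('dense') == 'dense'

# The cache key collapses layer indices so sibling layers can share one quantization cache entry
assert re.sub(r'\.[0-9]+\.', '.', 'encoder.layers.3.query') == 'encoder.layers.query'
```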
@@ -281,9 +285,14 @@ def quantize_cpm_kernels(model, quantization_bit, use_quantization_cache=False, return model -def quantize_load_in_kbit(model, load_in_8bit=False, load_in_4bit=False, keep_in_fp32_modules=None, llm_int8_skip_modules=None, quantization_config=None, **kwargs): +def quantize_load_in_kbit(model:nn.Module, load_in_8bit:bool=False, load_in_4bit:bool=False, keep_in_fp32_modules:List=None, + llm_int8_skip_modules:List=None, quantization_config:Dict=None, **kwargs): '''transformer的load_in_8bit, 源自transformer源代码''' - from transformers.utils.bitsandbytes import replace_with_bnb_linear, set_module_quantized_tensor_to_device + # 兼容transformers新旧版本 + try: + from transformers.integrations import replace_with_bnb_linear, set_module_quantized_tensor_to_device + except: + from transformers.utils.bitsandbytes import replace_with_bnb_linear, set_module_quantized_tensor_to_device from transformers.utils.quantization_config import BitsAndBytesConfig if quantization_config is None: quantization_config, kwargs = BitsAndBytesConfig.from_dict( @@ -317,7 +326,7 @@ def quantize_load_in_kbit(model, load_in_8bit=False, load_in_4bit=False, keep_in for key, param in model.named_parameters(): if param.device == torch.device("meta"): - set_module_quantized_tensor_to_device(model, key, 'cpu', value=state_dict[key], fp16_statistics=None) + set_module_quantized_tensor_to_device(model, key, 'cpu', value=state_dict[key]) model.is_loaded_in_8bit = load_in_8bit model.is_loaded_in_4bit = load_in_4bit diff --git a/examples/README.md b/examples/README.md index 1c8c9b39..14d1882c 100644 --- a/examples/README.md +++ b/examples/README.md @@ -9,7 +9,6 @@ | | [basic_gibbs_sampling_via_mlm.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/bert/basic_gibbs_sampling_via_mlm.py):利用BERT+Gibbs采样进行文本随机生成,参考[这里](https://kexue.fm/archives/8119)。 | | [basic_language_model_bert-base-multilingual-cased.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/bert/basic_language_model_bert-base-multilingual-cased.py):[bert-base-multilingual-cased](https://huggingface.co/bert-base-multilingual-cased) | | [basic_language_model_bert.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/bert/basic_language_model_bert.py):测试BERT的MLM模型效果。 -| |[basic_language_model_guwenbert.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/roberta/basic_language_model_guwenbert.py): 测试[古文bert](https://huggingface.co/ethanyt/guwenbert-base)模型。 | | [basic_make_uncased_model_cased.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/bert/basic_make_uncased_model_cased.py):通过简单修改词表,使得不区分大小写的模型有区分大小写的能力。 | |[basic_language_model_macbert.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/macbert/basic_language_model_macbert.py):测试macbert的MLM模型效果。 |bloom | [basic_language_model_bloom.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/bloom/basic_language_model_bloom.py):测试[bloom](https://huggingface.co/bigscience)。 @@ -51,6 +50,7 @@ |roberta|[basic_language_model_chinese-roberta-wwm.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/roberta/basic_language_model_chinese-roberta-wwm.py):测试HIT的chinese-roberta-wwm的MLM模型效果。 | |[basic_language_model_roberta_small_tiny.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/roberta/basic_language_model_roberta_small_tiny.py):测试Roberta-small的MLM模型效果。 | 
|[basic_language_model_roberta-base-english.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/roberta/basic_language_model_roberta-base-english.py):测试英文版roberta-base的MLM模型效果。 +| |[basic_language_model_guwenbert.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/roberta/basic_language_model_guwenbert.py): 测试[古文bert](https://huggingface.co/ethanyt/guwenbert-base)模型。 |roformer|[basic_language_model_roformer.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/roformer/basic_language_model_roformer.py):测试roformer的MLM模型效果。 |simbert|[basic_language_model_simbert.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/simbert/basic_language_model_simbert.py):测试[simbert](https://github.com/ZhuiyiTechnology/simbert)和[roformer-sim](https://github.com/ZhuiyiTechnology/roformer-sim)的生成效果和句子相似度效果。 | t5 |[basic_language_model_chatyuan.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/t5/basic_language_model_chatyuan.py): 测试[ChatYuan](https://github.com/clue-ai/ChatYuan)模型。 diff --git a/examples/llm/README.md b/examples/llm/README.md index f2f3e9ee..18a43fc9 100644 --- a/examples/llm/README.md +++ b/examples/llm/README.md @@ -8,13 +8,10 @@ ## 其他 - [instruct_gpt](https://github.com/Tongjilibo/bert4torch/blob/master/examples/llm/instruct_gpt): 按照三个步骤复现rlhf的实现 - [task_chatglm_nbce.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/llm/task_chatglm_nbce.py): 测试[chatglm-6b](https://github.com/THUDM/ChatGLM-6B)模型, 使用朴素贝叶斯增加LLM的Context处理长度。 -- [eval](https://github.com/Tongjilibo/bert4torch/blob/master/examples/llm/eval): 大模型的评估eval ## 微调 -- [task_chatglm_ptuning_v2.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/llm/task_chatglm_ptuning_v2.py): [chatglm-6b](https://github.com/THUDM/ChatGLM-6B)的ptuning_v2微调。 -- [task_chatglm_lora.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/llm/task_chatglm_lora.py): [chatglm-6b](https://github.com/THUDM/ChatGLM-6B)的lora微调(基于peft)。 -- [task_chatglm2_ptuning_v2.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/llm/task_chatglm2_ptuning_v2.py): [chatglm2-6b](https://github.com/THUDM/ChatGLM2-6B)的ptuning_v2微调。 -- [task_chatglm2_lora.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/llm/task_chatglm2_lora.py): [chatglm2-6b](https://github.com/THUDM/ChatGLM2-6B)的lora微调(基于peft)。 +- [task_chatglm_ptuning_v2.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/llm/task_chatglm_ptuning_v2.py): [chatglm-6b](https://github.com/THUDM/ChatGLM-6B)/[chatglm2-6b](https://github.com/THUDM/ChatGLM2-6B)的ptuning_v2微调。 +- [task_chatglm_lora.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/llm/task_chatglm_lora.py): [chatglm-6b](https://github.com/THUDM/ChatGLM-6B)/[chatglm2-6b](https://github.com/THUDM/ChatGLM2-6B)的lora微调(基于peft)。 - [task_llama-2_lora.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/llm/task_llama-2_lora.py): [llama-2](https://github.com/facebookresearch/llama)的lora微调(基于peft)。 - [task_chatglm_deepspeed](https://github.com/Tongjilibo/bert4torch/blob/master/examples/llm/task_chatglm_deepspeed): [chatglm](https://github.com/THUDM/ChatGLM-6B)的lora微调(peft+deepspeed)。 - [task_llama_deepspeed](https://github.com/Tongjilibo/bert4torch/blob/master/examples/llm/task_llama_deepspeed): [llama-2](https://github.com/facebookresearch/llama)的lora微调(peft+deepspeed)。 diff --git a/examples/llm/task_chatglm2_lora.py b/examples/llm/task_chatglm2_lora.py deleted file mode 
100644 index acecc366..00000000 --- a/examples/llm/task_chatglm2_lora.py +++ /dev/null @@ -1,230 +0,0 @@ -#! -*- coding: utf-8 -*- -# chatglm2的指令微调, 基于lora/qlora -# peft和transformer包是耦合的,因此这里用法和hf的略有不同 -# 参考项目:lora: https://github.com/mymusise/ChatGLM-Tuning -# qlora: https://github.com/shuxueslpi/chatGLM-6B-QLoRA - -# | chatglm | gpu | Time/epoch(s)| Rouge-L | Rouge-1 | Rouge-2 | BLEU | comment | -# | ---------------------- | --------- | ------------ | ------------- | ----------- | ----------- | --------- | ------- | - -from bert4torch.models import build_transformer_model -from bert4torch.snippets import sequence_padding, text_segmentate -import torch.nn as nn -import torch -import torch.optim as optim -from torch.utils.data import DataLoader -import torch -from bert4torch.models import build_transformer_model, BaseModel -from bert4torch.snippets import ListDataset -from bert4torch.generation import SeqGeneration -from bert4torch.callbacks import Callback, Logger -from bert4torch.optimizers import get_linear_schedule_with_warmup -from transformers import AutoTokenizer -import json -import jieba -from rouge_chinese import Rouge -from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction -import numpy as np -from tqdm import tqdm -from peft import LoraConfig, prepare_model_for_kbit_training # 需要pip install git+https://github.com/huggingface/peft.git -import os - - -# 基本参数 -mode = 'train' -max_source_length = 64 -max_target_length = 64 -lr = 5e-4 -batch_size = 16 # 根据显存大小调整 -eval_batch_size = 4 -grad_accumulation_steps = 1 # 根据显存大小调整 -max_seq_length = max_source_length + max_target_length -ignore_pad_token_for_loss = True -epochs = 1 -steps_per_epoch = 3000 -prefix = '' -prompt_column = 'content' -response_column = 'summary' -history_column = None - -# 模型配置 -dir_path = "E:\\pretrain_ckpt\\glm\\chatglm2-6B" -config_path = dir_path + '\\bert4torch_config.json' -checkpoint_path = [os.path.join(dir_path, i) for i in os.listdir(dir_path) if i.endswith('.bin')] -device = 'cuda' if torch.cuda.is_available() else 'cpu' - -tokenizer = AutoTokenizer.from_pretrained(dir_path, trust_remote_code=True) - -# 加载数据集 -class MyDataset(ListDataset): - @staticmethod - def load_data(filename): - """加载数据,并尽量分为不超过maxlen的句子 - """ - D = [] - with open(filename, encoding='utf-8') as f: - for l in f: - l = json.loads(l) - prompt, response = l[prompt_column], l[response_column] - history = l.get('history_column', None) - D.append((prompt, response, history)) - return D - -def build_prompt(query, history=None): - if history is None: - history = [] - prompt = "" - for i, (old_query, response) in enumerate(history): - prompt += "[Round {}]\n\n问:{}\n\n答:{}\n\n".format(i + 1, old_query, response) - prompt += "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query) - return prompt - -def collate_train_fn(batch): - batch_token_ids, batch_labels = [], [] - for query, answer, history in batch: - prompt = build_prompt(query, history) - prompt = prefix + prompt - a_ids = tokenizer.encode(text=prompt, add_special_tokens=True, truncation=True, max_length=max_source_length) - b_ids = tokenizer.encode(text=answer, add_special_tokens=False, truncation=True, max_length=max_target_length) - - context_length = len(a_ids) - input_ids = a_ids + b_ids + [tokenizer.eos_token_id] - labels = [tokenizer.pad_token_id] * context_length + b_ids + [tokenizer.eos_token_id] - batch_token_ids.append(input_ids) - batch_labels.append(labels) - - batch_token_ids = torch.tensor(sequence_padding(batch_token_ids, value=tokenizer.pad_token_id), 
dtype=torch.long, device=device) - batch_labels = torch.tensor(sequence_padding(batch_labels, value=tokenizer.pad_token_id), dtype=torch.long, device=device) - return [batch_token_ids], batch_labels - -def collate_dev_fn(batch): - batch_prompt, batch_labels = [], [] - for query, labels, history in batch: - batch_prompt.append(prefix + build_prompt(query, history)) - - label_ids = tokenizer(text_target=labels, max_length=max_target_length, truncation=True)['input_ids'] - batch_labels.append(tokenizer.decode(label_ids, skip_special_tokens=True)) - return batch_prompt, batch_labels - -train_dataloader = DataLoader(MyDataset('F:/data/corpus/sft/AdvertiseGen/train.json'), batch_size=batch_size, shuffle=True, collate_fn=collate_train_fn) -dev_dataloader = DataLoader(MyDataset('F:/data/corpus/sft/AdvertiseGen/dev.json'), batch_size=eval_batch_size, shuffle=False, collate_fn=collate_dev_fn) - -# 建立模型,加载权重 -model = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path, add_trainer=True, - tie_emb_prj_weight=True, # 绑定embedding和dense/lm_head的权重,transformers中有绑定 - ).half() - -# 量化 -load_in_nbit = None # 设置为True在3060卡上loss能正常下降,在v100上loss就是nan -if load_in_nbit == 8: - model.gradient_checkpointing_enable() - model.enable_input_require_grads() - - class CastOutputToFloat(nn.Sequential): - def forward(self, x): - return super().forward(x).to(torch.float32) - model = model.quantize(quantization_method='load_in_8bit', llm_int8_skip_modules=['model.embeddings.word_embeddings', 'lm_head']) - model.lm_head = CastOutputToFloat(model.lm_head) - -elif load_in_nbit == 4: - from transformers import BitsAndBytesConfig - q_config = BitsAndBytesConfig(load_in_4bit=True, - bnb_4bit_quant_type='nf4', - bnb_4bit_use_double_quant=True, - bnb_4bit_compute_dtype=torch.float16, # 可选 torch.float32, torch.float16, torch.bfloat16 - llm_int8_skip_modules=['model.embeddings.word_embeddings', 'lm_head'] - ) - model = model.quantize(quantization_method='load_in_4bit', quantization_config=q_config) - model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True) - -# lora -peft_config = LoraConfig( - inference_mode=False, - r=8, - lora_alpha=32, - lora_dropout=0.1, - target_modules=['q', 'k', 'v'] - ) -model = model.get_peft_model(peft_config).to(device) - -class CrossEntropyLoss(nn.CrossEntropyLoss): - def __init__(self, **kwargs): - super().__init__(**kwargs) - def forward(self, logits, labels): - ''' - logits: [btz, seq_len, vocab_size] - labels: token_ids: [btz, seq_len] - ''' - raw_dtyps = logits.dtype - logits = logits.to(torch.float32) - logits = logits[:, :-1, :].contiguous() # 预测序列,错开一位 - labels = labels[:, 1:].contiguous() # 目标token_ids - - logits = logits.reshape(-1, logits.shape[-1]) - labels = labels.flatten() - loss = super().forward(logits, labels) - - return loss.to(raw_dtyps) - -optimizer = optim.AdamW(model.parameters(), lr) -scheduler = get_linear_schedule_with_warmup(optimizer, 0, steps_per_epoch*epochs) # torch4keras<0.0.8需要设置为(steps_per_epoch*epochs)//grad_accumulation_steps -model.compile(loss=CrossEntropyLoss(ignore_index=tokenizer.pad_token_id), optimizer=optimizer, scheduler=scheduler, grad_accumulation_steps=grad_accumulation_steps, clip_grad_norm=1.0) - -class Chat(SeqGeneration): - def pre_process(self, text): - return [tokenizer(text, max_length=max_source_length, truncation=True)['input_ids']] - def post_process(self, output_ids): - return [tokenizer.decode(output_id.cpu().numpy()) for output_id in output_ids] -generation = Chat(model, tokenizer, 
start_id=None, end_id=tokenizer.eos_token_id, pad_id=tokenizer.pad_token_id, - mode='random_sample', maxlen=512, default_rtype='logits', use_states=True) - -class Evaluator(Callback): - """评估与保存 - """ - def __init__(self): - self.best = 0 - - def on_epoch_end(self, steps, epoch, logs=None): - model.save_weights(f'./model.pt', trainable_only=True) - - def evaluate(self, data, epoch='final'): - preds, labels = [], [] - for prompt, label in tqdm(data, desc='Evaluating'): - pred = generation.generate(prompt, topk=50, topp=0.7, temperature=0.95) - preds.extend(pred) - labels.extend(label) - with open(f'./preds_{epoch}.txt', 'a+', encoding='utf-8') as f: - for pred_i, label_i in zip(pred, label): - f.write(json.dumps({'pred': pred_i, 'label': label_i}, ensure_ascii=False) + '\n') - - score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []} - for pred, label in zip(preds, labels): - hypothesis = list(jieba.cut(pred)) - reference = list(jieba.cut(label)) - rouge = Rouge() - scores = rouge.get_scores(' '.join(hypothesis) , ' '.join(reference)) - result = scores[0] - - for k, v in result.items(): - score_dict[k].append(round(v["f"] * 100, 4)) - bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3) - score_dict["bleu-4"].append(round(bleu_score * 100, 4)) - - for k, v in score_dict.items(): - score_dict[k] = float(np.mean(v)) - return score_dict - - -if __name__ == '__main__': - evaluator = Evaluator() - logger = Logger('./log.log', interval=100) - - if mode == 'train': - model.fit(train_dataloader, steps_per_epoch=steps_per_epoch, epochs=epochs, callbacks=[evaluator, logger]) - score_dict = evaluator.evaluate(dev_dataloader) - print(score_dict) - - else: - model.load_weights('./model.pt', strict=False) - score_dict = evaluator.evaluate(dev_dataloader) - print(score_dict) diff --git a/examples/llm/task_chatglm2_ptuning_v2.py b/examples/llm/task_chatglm2_ptuning_v2.py deleted file mode 100644 index 3a630e76..00000000 --- a/examples/llm/task_chatglm2_ptuning_v2.py +++ /dev/null @@ -1,214 +0,0 @@ -#! 
-*- coding: utf-8 -*- -# chatglm2的指令微调, 基于ptuning_v2,性能和官方项目给出的指标相当 -# | chatglm2 | gpu | Time/epoch(s)| Rouge-L | Rouge-1 | Rouge-2 | BLEU | comment | -# | ---------------------- | --------- | ------------ | ------------- | ----------- | ----------- | --------- | ------- | -# | b4t+pt2+v100+int4+bs1 | 7G | —— | 24.36 | 29.97 | 6.66 | 7.89 | | - -from bert4torch.models import build_transformer_model -from bert4torch.snippets import sequence_padding, text_segmentate -from bert4torch.callbacks import Callback -import torch.nn as nn -import torch -import torch.optim as optim -from torch.utils.data import DataLoader -import torch -from bert4torch.models import build_transformer_model, BaseModel -from transformers import AutoTokenizer -from bert4torch.snippets import ListDataset, seed_everything -from bert4torch.callbacks import Logger -from bert4torch.generation import SeqGeneration -from bert4torch.optimizers import get_linear_schedule_with_warmup -from bert4torch.trainer import PtuningV2Trainer -import json -import jieba -from rouge_chinese import Rouge -from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction -import numpy as np -from tqdm import tqdm -import os - - -# 基本参数 -mode = 'train' -max_source_length = 64 -max_target_length = 64 -lr = 2e-2 -batch_size = 1 -eval_batch_size = 16 -grad_accumulation_steps = 16 -steps_per_epoch = 3000 -epochs = 1 -max_seq_length = max_source_length + max_target_length -ignore_pad_token_for_loss = True -prefix = '' -prompt_column = 'content' -response_column = 'summary' -history_column = None -use_states = True - -seed_everything(42) - -# 模型配置 -choice = 'default' # chatglm2, int4, int8 -if choice == 'default': - dir_path = "E:/pretrain_ckpt/glm/chatglm2-6B" - config_path = dir_path + '/bert4torch_config.json' - checkpoint_path = [os.path.join(dir_path, i) for i in os.listdir(dir_path) if i.endswith('.bin')] -elif choice == 'int4': - dir_path = "E:/pretrain_ckpt/glm/chatglm2-6B-int4" - config_path = dir_path + '/bert4torch_config.json' - checkpoint_path = [os.path.join(dir_path, i) for i in os.listdir(dir_path) if i.endswith('.bin')] -elif choice == 'int8': - dir_path = "E:/pretrain_ckpt/glm/chatglm2-6B-int8" - config_path = dir_path + '/bert4torch_config.json' - checkpoint_path = [os.path.join(dir_path, i) for i in os.listdir(dir_path) if i.endswith('.bin')] - -device = 'cuda' if torch.cuda.is_available() else 'cpu' -tokenizer = AutoTokenizer.from_pretrained(dir_path.replace('/', '\\'), trust_remote_code=True) - -# 加载数据集 -class MyDataset(ListDataset): - @staticmethod - def load_data(filename): - """加载数据,并尽量分为不超过maxlen的句子 - """ - D = [] - with open(filename, encoding='utf-8') as f: - for l in f: - l = json.loads(l) - prompt, response = l[prompt_column], l[response_column] - history = l.get('history_column', None) - D.append((prompt, response, history)) - return D - -def build_prompt(query, history=None): - if history is None: - history = [] - prompt = "" - for i, (old_query, response) in enumerate(history): - prompt += "[Round {}]\n\n问:{}\n\n答:{}\n\n".format(i + 1, old_query, response) - prompt += "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query) - return prompt - -def collate_train_fn(batch): - batch_token_ids, batch_labels = [], [] - for query, answer, history in batch: - prompt = build_prompt(query, history) - prompt = prefix + prompt - a_ids = tokenizer.encode(text=prompt, add_special_tokens=True, truncation=True, max_length=max_source_length) - b_ids = tokenizer.encode(text=answer, add_special_tokens=False, truncation=True, 
max_length=max_target_length) - - context_length = len(a_ids) - input_ids = a_ids + b_ids + [tokenizer.eos_token_id] - labels = [tokenizer.pad_token_id] * context_length + b_ids + [tokenizer.eos_token_id] - batch_token_ids.append(input_ids) - batch_labels.append(labels) - - batch_token_ids = torch.tensor(sequence_padding(batch_token_ids, value=tokenizer.pad_token_id), dtype=torch.long, device=device) - batch_labels = torch.tensor(sequence_padding(batch_labels, value=tokenizer.pad_token_id), dtype=torch.long, device=device) - return [batch_token_ids], batch_labels - -def collate_dev_fn(batch): - batch_prompt, batch_labels = [], [] - for query, labels, history in batch: - batch_prompt.append(prefix + build_prompt(query, history)) - - label_ids = tokenizer(text_target=labels, max_length=max_target_length, truncation=True)['input_ids'] - batch_labels.append(tokenizer.decode(label_ids, skip_special_tokens=True)) - return batch_prompt, batch_labels - -train_dataloader = DataLoader(MyDataset('F:/data/corpus/sft/AdvertiseGen/train.json'), batch_size=batch_size, shuffle=True, collate_fn=collate_train_fn) -dev_dataloader = DataLoader(MyDataset('F:/data/corpus/sft/AdvertiseGen/dev.json'), batch_size=eval_batch_size, shuffle=False, collate_fn=collate_dev_fn) - -if choice == 'default': - encoder = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path).half() - encoder = encoder.quantize(quantization_method='cpm_kernels', quantization_bit=4, - target_modules=['q', 'k', 'v', 'o', 'intermediateDense', 'outputDense']).to(device) -else: - # 在config中已经写入了量化的配置参数 - encoder = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path).to(device) - -model = PtuningV2Trainer(encoder).to(device) -model.print_trainable_parameters() - -class CrossEntropyLoss(nn.CrossEntropyLoss): - def __init__(self, **kwargs): - super().__init__(**kwargs) - def forward(self, logits, labels): - ''' - logits: [btz, seq_len, vocab_size] - labels: token_ids: [btz, seq_len] - ''' - raw_dtyps = logits.dtype - logits = logits.to(torch.float32) - logits = logits[:, :-1, :].contiguous() # 预测序列,错开一位 - labels = labels[:, 1:].contiguous() # 目标token_ids - - logits = logits.reshape(-1, logits.shape[-1]) - labels = labels.flatten() - loss = super().forward(logits, labels) - - return loss.to(raw_dtyps) - -optimizer = optim.AdamW(model.parameters(), lr) -scheduler = get_linear_schedule_with_warmup(optimizer, 0, steps_per_epoch*epochs) # torch4keras<0.0.8需要设置为(steps_per_epoch*epochs)//grad_accumulation_steps -model.compile(loss=CrossEntropyLoss(ignore_index=tokenizer.pad_token_id), optimizer=optimizer, scheduler=scheduler, grad_accumulation_steps=grad_accumulation_steps, clip_grad_norm=1.0) - -class Chat(SeqGeneration): - def pre_process(self, text): - return [tokenizer(text, max_length=max_source_length, truncation=True)['input_ids']] - def post_process(self, output_ids): - return [tokenizer.decode(output_id.cpu().numpy()) for output_id in output_ids] -generation = Chat(model, tokenizer, start_id=None, end_id=tokenizer.eos_token_id, pad_id=tokenizer.pad_token_id, - mode='random_sample', maxlen=512, default_rtype='logits', use_states=use_states) - -class Evaluator(Callback): - """评估与保存 - """ - def __init__(self): - self.best = 0 - - def on_epoch_end(self, steps, epoch, logs=None): - model.save_weights(f'./model.pt', trainable_only=True) - - def evaluate(self, data, epoch='final'): - preds, labels = [], [] - for prompt, label in tqdm(data, desc='Evaluating'): - pred = generation.generate(prompt, 
topk=50, topp=0.7, temperature=0.95) - preds.extend(pred) - labels.extend(label) - with open(f'./preds_{epoch}.txt', 'a+', encoding='utf-8') as f: - for pred_i, label_i in zip(pred, label): - f.write(json.dumps({'pred': pred_i, 'label': label_i}, ensure_ascii=False) + '\n') - - score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []} - for pred, label in zip(preds, labels): - hypothesis = list(jieba.cut(pred)) - reference = list(jieba.cut(label)) - rouge = Rouge() - scores = rouge.get_scores(' '.join(hypothesis) , ' '.join(reference)) - result = scores[0] - - for k, v in result.items(): - score_dict[k].append(round(v["f"] * 100, 4)) - bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3) - score_dict["bleu-4"].append(round(bleu_score * 100, 4)) - - for k, v in score_dict.items(): - score_dict[k] = float(np.mean(v)) - return score_dict - - -if __name__ == '__main__': - evaluator = Evaluator() - logger = Logger('./log.log', interval=100) - - if mode == 'train': - model.fit(train_dataloader, steps_per_epoch=steps_per_epoch, epochs=epochs, callbacks=[evaluator, logger]) - score_dict = evaluator.evaluate(dev_dataloader) - print(score_dict) - - else: - model.load_weights('./model.pt', strict=False) - score_dict = evaluator.evaluate(dev_dataloader) - print(score_dict) diff --git a/examples/llm/task_chatglm_lora.py b/examples/llm/task_chatglm_lora.py index 2db0b335..c16ee1d3 100644 --- a/examples/llm/task_chatglm_lora.py +++ b/examples/llm/task_chatglm_lora.py @@ -1,13 +1,13 @@ #! -*- coding: utf-8 -*- -# chatglm的指令微调, 基于lora/qlora +# chatglm/chatglm2的指令微调, 基于lora/qlora # peft和transformer包是耦合的,因此这里用法和hf的略有不同 # 参考项目:lora: https://github.com/mymusise/ChatGLM-Tuning # qlora: https://github.com/shuxueslpi/chatGLM-6B-QLoRA -# | chatglm | gpu | Time/epoch(s)| Rouge-L | Rouge-1 | Rouge-2 | BLEU | comment | -# | ---------------------- | --------- | ------------ | ------------- | ----------- | ----------- | --------- | ------- | -# | b4t+lora+V100-fp16-bs16 | 28G | 2570 | 24.89 | 31.38 | 7.17 | 8.15 | | -# | b4t+qlora+V100-bs16 | 26G | 5381 | 23.99 | 29.52 | 6.47 | 7.74 | | +# | 模型 | gpu | Time/epoch(s)| Rouge-L | Rouge-1 | Rouge-2 | BLEU | comment | +# | ------------------------------ | --------- | ------------ | ------------- | ----------- | ----------- | --------- | ------- | +# | chatglm+b4t+lora+V100-fp16-bs16 | 28G | 2570 | 24.89 | 31.38 | 7.17 | 8.15 | | +# | chatglm+b4t+qlora+V100-bs16 | 26G | 5381 | 23.99 | 29.52 | 6.47 | 7.74 | | from bert4torch.models import build_transformer_model from bert4torch.snippets import sequence_padding, text_segmentate @@ -21,6 +21,7 @@ from bert4torch.generation import SeqGeneration from bert4torch.callbacks import Callback, Logger from bert4torch.optimizers import get_linear_schedule_with_warmup +from bert4torch.losses import CausalLMLoss from transformers import AutoTokenizer import json import jieba @@ -32,32 +33,34 @@ import os -# 基本参数 +# ====================================基本参数==================================== +model_name = 'chatglm2' # 可选chatglm, chatglm2 mode = 'train' +load_in_nbit = None # 量化, 可选None, 8, 4 max_source_length = 64 max_target_length = 64 +max_seq_length = max_source_length + max_target_length lr = 5e-4 -batch_size = 16 # 根据显存大小调整 +batch_size = 4 # 根据显存大小调整 eval_batch_size = 4 grad_accumulation_steps = 1 # 根据显存大小调整 -max_seq_length = max_source_length + max_target_length -ignore_pad_token_for_loss = True epochs = 1 steps_per_epoch = 3000 prefix = '' prompt_column = 'content' 
response_column = 'summary' history_column = None - -# 模型配置 -dir_path = "E:\\pretrain_ckpt\\glm\\chatglm-6B" -config_path = dir_path + '\\bert4torch_config.json' -checkpoint_path = [os.path.join(dir_path, i) for i in os.listdir(dir_path) if i.endswith('.bin')] +data_dir = '/data/corpus/sft/AdvertiseGen' # 数据路径 +if model_name == 'chatglm2': + model_dir = "/data/pretrain_ckpt/glm/chatglm2-6B" +elif model_name == 'chatglm': + model_dir = "/data/pretrain_ckpt/glm/chatglm-6B" device = 'cuda' if torch.cuda.is_available() else 'cpu' -tokenizer = AutoTokenizer.from_pretrained(dir_path, trust_remote_code=True) -# 加载数据集 +# ====================================加载数据集==================================== +tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True) + class MyDataset(ListDataset): @staticmethod def load_data(filename): @@ -72,40 +75,69 @@ def load_data(filename): D.append((prompt, response, history)) return D -def build_prompt(query, history): - if history_column is None: - prompt = query - else: +if model_name == 'chatglm': + def build_prompt(query, history): + if history_column is None: + prompt = query + else: + prompt = "" + for i, (old_query, answer) in enumerate(history): + prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, answer) + prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) + return prompt + + def collate_train_fn(batch): + batch_token_ids, batch_labels = [], [] + for query, answer, history in batch: + prompt = build_prompt(query, history) + prompt = prefix + prompt + a_ids = tokenizer.encode(text=prompt, add_special_tokens=False) + b_ids = tokenizer.encode(text=answer, add_special_tokens=False) + + if len(a_ids) > max_source_length - 1: + a_ids = a_ids[:max_source_length - 1] + + if len(b_ids) > max_target_length - 2: + b_ids = b_ids[:max_target_length - 2] + + input_ids = tokenizer.build_inputs_with_special_tokens(a_ids, b_ids) + context_length = input_ids.index(tokenizer.bos_token_id) + mask_position = context_length - 1 + labels = [tokenizer.pad_token_id] * context_length + input_ids[mask_position+1:] + batch_token_ids.append(input_ids) + batch_labels.append(labels) + + batch_token_ids = torch.tensor(sequence_padding(batch_token_ids, value=tokenizer.pad_token_id), dtype=torch.long, device=device) + batch_labels = torch.tensor(sequence_padding(batch_labels, value=tokenizer.pad_token_id), dtype=torch.long, device=device) + return [batch_token_ids], batch_labels + +elif model_name == 'chatglm2': + def build_prompt(query, history=None): + if history is None: + history = [] prompt = "" - for i, (old_query, answer) in enumerate(history): - prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, answer) - prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) - return prompt - -def collate_train_fn(batch): - batch_token_ids, batch_labels = [], [] - for query, answer, history in batch: - prompt = build_prompt(query, history) - prompt = prefix + prompt - a_ids = tokenizer.encode(text=prompt, add_special_tokens=False) - b_ids = tokenizer.encode(text=answer, add_special_tokens=False) - - if len(a_ids) > max_source_length - 1: - a_ids = a_ids[:max_source_length - 1] - - if len(b_ids) > max_target_length - 2: - b_ids = b_ids[:max_target_length - 2] - - input_ids = tokenizer.build_inputs_with_special_tokens(a_ids, b_ids) - context_length = input_ids.index(tokenizer.bos_token_id) - mask_position = context_length - 1 - labels = [-100] * context_length + input_ids[mask_position+1:] - batch_token_ids.append(input_ids) - 
batch_labels.append(labels) - - batch_token_ids = torch.tensor(sequence_padding(batch_token_ids, value=tokenizer.pad_token_id), dtype=torch.long, device=device) - batch_labels = torch.tensor(sequence_padding(batch_labels, value=-100), dtype=torch.long, device=device) - return [batch_token_ids], batch_labels + for i, (old_query, response) in enumerate(history): + prompt += "[Round {}]\n\n问:{}\n\n答:{}\n\n".format(i + 1, old_query, response) + prompt += "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query) + return prompt + + def collate_train_fn(batch): + batch_token_ids, batch_labels = [], [] + for query, answer, history in batch: + prompt = build_prompt(query, history) + prompt = prefix + prompt + a_ids = tokenizer.encode(text=prompt, add_special_tokens=True, truncation=True, max_length=max_source_length) + b_ids = tokenizer.encode(text=answer, add_special_tokens=False, truncation=True, max_length=max_target_length) + + context_length = len(a_ids) + input_ids = a_ids + b_ids + [tokenizer.eos_token_id] + labels = [tokenizer.pad_token_id] * context_length + b_ids + [tokenizer.eos_token_id] + batch_token_ids.append(input_ids) + batch_labels.append(labels) + + batch_token_ids = torch.tensor(sequence_padding(batch_token_ids, value=tokenizer.pad_token_id), dtype=torch.long, device=device) + batch_labels = torch.tensor(sequence_padding(batch_labels, value=tokenizer.pad_token_id), dtype=torch.long, device=device) + return [batch_token_ids], batch_labels def collate_dev_fn(batch): batch_prompt, batch_labels = [], [] @@ -116,16 +148,16 @@ def collate_dev_fn(batch): batch_labels.append(tokenizer.decode(label_ids, skip_special_tokens=True)) return batch_prompt, batch_labels -train_dataloader = DataLoader(MyDataset('F:/data/corpus/sft/AdvertiseGen/train.json'), batch_size=batch_size, shuffle=True, collate_fn=collate_train_fn) -dev_dataloader = DataLoader(MyDataset('F:/data/corpus/sft/AdvertiseGen/dev.json'), batch_size=eval_batch_size, shuffle=False, collate_fn=collate_dev_fn) +train_dataloader = DataLoader(MyDataset(os.path.join(data_dir, 'train.json')), batch_size=batch_size, shuffle=True, collate_fn=collate_train_fn) +dev_dataloader = DataLoader(MyDataset(os.path.join(data_dir, 'dev.json')), batch_size=eval_batch_size, shuffle=False, collate_fn=collate_dev_fn) -# 建立模型,加载权重 -model = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path, add_trainer=True, - tie_emb_prj_weight=True, # 绑定embedding和dense/lm_head的权重,transformers中有绑定 - ).half() + +# ====================================建立模型==================================== +# 原使用peft=0.5.0时候下面可.half(),高版本peft.half()发现loss为nan,排查发现是高版本会把lora_A转成和base_layer(Linear)的dtype=fp16 +# 把.half()去掉,使用原来的bf16训练,lora_A还是fp32 +model = build_transformer_model(config_path=model_dir, checkpoint_path=model_dir, add_trainer=True, tie_emb_prj_weight=True) # 量化 -load_in_nbit = None # 设置为True在3060卡上loss能正常下降,在v100上loss就是nan if load_in_nbit == 8: model.gradient_checkpointing_enable() model.enable_input_require_grads() @@ -157,28 +189,10 @@ def forward(self, x): ) model = model.get_peft_model(peft_config).to(device) -class CrossEntropyLoss(nn.CrossEntropyLoss): - def __init__(self, **kwargs): - super().__init__(**kwargs) - def forward(self, logits, labels): - ''' - logits: [btz, seq_len, vocab_size] - labels: token_ids: [btz, seq_len] - ''' - raw_dtyps = logits.dtype - logits = logits.to(torch.float32) - logits = logits[:, :-1, :].contiguous() # 预测序列,错开一位 - labels = labels[:, 1:].contiguous() # 目标token_ids - - logits = logits.reshape(-1, 
logits.shape[-1]) - labels = labels.flatten() - loss = super().forward(logits, labels) - - return loss.to(raw_dtyps) - optimizer = optim.AdamW(model.parameters(), lr) scheduler = get_linear_schedule_with_warmup(optimizer, 0, steps_per_epoch*epochs) # torch4keras<0.0.8需要设置为(steps_per_epoch*epochs)//grad_accumulation_steps -model.compile(loss=CrossEntropyLoss(ignore_index=-100), optimizer=optimizer, scheduler=scheduler, grad_accumulation_steps=grad_accumulation_steps, clip_grad_norm=1.0) +model.compile(loss=CausalLMLoss(offset=True, ignore_index=tokenizer.pad_token_id), optimizer=optimizer, scheduler=scheduler, + grad_accumulation_steps=grad_accumulation_steps, clip_grad_norm=1.0) class Chat(SeqGeneration): def pre_process(self, text): diff --git a/examples/llm/task_chatglm_ptuning_v2.py b/examples/llm/task_chatglm_ptuning_v2.py index d714a343..8c974efd 100644 --- a/examples/llm/task_chatglm_ptuning_v2.py +++ b/examples/llm/task_chatglm_ptuning_v2.py @@ -1,15 +1,16 @@ #! -*- coding: utf-8 -*- -# chatglm的指令微调, 基于ptuning_v2,性能和官方项目给出的指标相当 -# | chatglm | gpu | Time/epoch(s)| Rouge-L | Rouge-1 | Rouge-2 | BLEU | comment | -# | ---------------------- | --------- | ------------ | ------------- | ----------- | ----------- | --------- | ------- | -# | hf+pt2 official+v100-int4-bs1 | —— | —— | 24.97 | 31.12 | 7.11 | 8.10 | | -# | hf+pt2 reappear+v100-int4-bs1 | —— | —— | 24.80 | 30.97 | 6.98 | 7.85 | | -# | b4t+pt2+v100+int4+bs1 | —— | —— | 24.58 | 30.76 | 7.12 | 8.12 | | -# | b4t+pt2+T4-int8-bs1 | 10G | 1470 | 24.87 | 30.83 | 7.14 | 8.05 | | -# | b4t+pt2+A100(pcie 40G)-fp16-bs1 | 15G | 287 | 25.10 | 31.43 | 7.30 | 8.28 | | -# | b4t+pt2+A100(pcie 40G)-fp16-bs8 | 22G | 705 | 25.22 | 31.22 | 7.38 | 8.35 | | -# | b4t+pt2+A100(pcie 40G)-fp32-bs1 | 29G | 760 | 24.83 | 30.95 | 7.18 | 8.08 | | -# | b4t+pt2+A100(pcie 40G)-fp32-bs4 | 32G | 2600 | 25.12 | 31.55 | 7.21 | 8.02 | | +# chatglm/chatglm2的指令微调, 基于ptuning_v2,性能和官方项目给出的指标相当 +# | model | gpu | Time/epoch(s)| Rouge-L | Rouge-1 | Rouge-2 | BLEU | comment | +# | ------------------------------ | --------- | ------------ | ------------- | ----------- | ----------- | --------- | ------- | +# | chatglm+hf+pt2 official+v100-int4-bs1 | —— | —— | 24.97 | 31.12 | 7.11 | 8.10 | | +# | chatglm+hf+pt2 reappear+v100-int4-bs1 | —— | —— | 24.80 | 30.97 | 6.98 | 7.85 | | +# | chatglm+b4t+pt2+v100+int4+bs1 | —— | —— | 24.58 | 30.76 | 7.12 | 8.12 | | +# | chatglm+b4t+pt2+T4-int8-bs1 | 10G | 1470 | 24.87 | 30.83 | 7.14 | 8.05 | | +# | chatglm+b4t+pt2+A100(pcie 40G)-fp16-bs1 | 15G | 287 | 25.10 | 31.43 | 7.30 | 8.28 | | +# | chatglm+b4t+pt2+A100(pcie 40G)-fp16-bs8 | 22G | 705 | 25.22 | 31.22 | 7.38 | 8.35 | | +# | chatglm+b4t+pt2+A100(pcie 40G)-fp32-bs1 | 29G | 760 | 24.83 | 30.95 | 7.18 | 8.08 | | +# | chatglm+b4t+pt2+A100(pcie 40G)-fp32-bs4 | 32G | 2600 | 25.12 | 31.55 | 7.21 | 8.02 | | +# | chatglm2+b4t+pt2+v100+int4+bs1 | 7G | —— | 24.36 | 29.97 | 6.66 | 7.89 | | from bert4torch.snippets import sequence_padding from bert4torch.callbacks import Callback @@ -20,11 +21,12 @@ import torch from bert4torch.models import build_transformer_model, BaseModel from transformers import AutoTokenizer -from bert4torch.snippets import ListDataset +from bert4torch.snippets import ListDataset, seed_everything from bert4torch.callbacks import Logger from bert4torch.generation import SeqGeneration from bert4torch.optimizers import get_linear_schedule_with_warmup from bert4torch.trainer import PtuningV2Trainer +from bert4torch.losses import CausalLMLoss import json import jieba from 
rouge_chinese import Rouge @@ -34,43 +36,33 @@ import os -# 基本参数 +# ====================================基本参数==================================== +mode = 'train' # train evaluate inference +model_name = 'chatglm2-6B' # chatglm-6B, chatglm-6B-int4, chatglm-6B-int8, chatglm2-6B, chatglm2-6B-int4, chatglm2-6B-int8 max_source_length = 64 max_target_length = 64 +max_seq_length = max_source_length + max_target_length lr = 2e-2 batch_size = 1 eval_batch_size = 16 grad_accumulation_steps = 16 steps_per_epoch = 3000 epochs = 1 # torch4keras<0.0.8后需要设置为16,因为1个batch_step不包含grad_accumulation_steps -max_seq_length = max_source_length + max_target_length ignore_pad_token_for_loss = True prefix = '' prompt_column = 'content' response_column = 'summary' history_column = None use_states = True - -# 模型配置 -choice = 'int4' # default, int4, int8 -if choice == 'default': - dir_path = "E:/pretrain_ckpt/glm/chatglm-6B" - config_path = dir_path + '/bert4torch_config.json' - checkpoint_path = [os.path.join(dir_path, i) for i in os.listdir(dir_path) if i.endswith('.bin')] -elif choice == 'int4': - dir_path = "E:/pretrain_ckpt/glm/chatglm-6B-int4" - config_path = dir_path + '/bert4torch_config.json' - checkpoint_path = [os.path.join(dir_path, i) for i in os.listdir(dir_path) if i.endswith('.bin')] -elif choice == 'int8': - dir_path = "E:/pretrain_ckpt/glm/chatglm-6B-int8" - config_path = dir_path + '/bert4torch_config.json' - checkpoint_path = [os.path.join(dir_path, i) for i in os.listdir(dir_path) if i.endswith('.bin')] - +data_dir = '/data/corpus/sft/AdvertiseGen' # 数据路径 +model_dir = f"/data/pretrain_ckpt/glm/{model_name}" # 模型路径 device = 'cuda' if torch.cuda.is_available() else 'cpu' -tokenizer = AutoTokenizer.from_pretrained(dir_path.replace('/', '\\'), trust_remote_code=True) +seed_everything(42) + +# ====================================加载数据集==================================== +tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True) -# 加载数据集 class MyDataset(ListDataset): @staticmethod def load_data(filename): @@ -85,40 +77,68 @@ def load_data(filename): D.append((prompt, response, history)) return D -def build_prompt(query, history): - if history_column is None: - prompt = query - else: +if model_name in {'chatglm-6B', 'chatglm-6B-int8', 'chatglm-6B-int4'}: + def build_prompt(query, history): + if history_column is None: + prompt = query + else: + prompt = "" + for i, (old_query, answer) in enumerate(history): + prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, answer) + prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) + return prompt + + def collate_train_fn(batch): + batch_token_ids, batch_labels = [], [] + for query, answer, history in batch: + prompt = build_prompt(query, history) + prompt = prefix + prompt + a_ids = tokenizer.encode(text=prompt, add_special_tokens=False) + b_ids = tokenizer.encode(text=answer, add_special_tokens=False) + + if len(a_ids) > max_source_length - 1: + a_ids = a_ids[:max_source_length - 1] + + if len(b_ids) > max_target_length - 2: + b_ids = b_ids[:max_target_length - 2] + + input_ids = tokenizer.build_inputs_with_special_tokens(a_ids, b_ids) + context_length = input_ids.index(tokenizer.bos_token_id) + mask_position = context_length - 1 + labels = [tokenizer.pad_token_id] * context_length + input_ids[mask_position+1:] + batch_token_ids.append(input_ids) + batch_labels.append(labels) + + batch_token_ids = torch.tensor(sequence_padding(batch_token_ids, value=tokenizer.pad_token_id), dtype=torch.long, device=device) + batch_labels = 
torch.tensor(sequence_padding(batch_labels, value=tokenizer.pad_token_id), dtype=torch.long, device=device) + return [batch_token_ids], batch_labels +else: + def build_prompt(query, history=None): + if history is None: + history = [] prompt = "" - for i, (old_query, answer) in enumerate(history): - prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, answer) - prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) - return prompt - -def collate_train_fn(batch): - batch_token_ids, batch_labels = [], [] - for query, answer, history in batch: - prompt = build_prompt(query, history) - prompt = prefix + prompt - a_ids = tokenizer.encode(text=prompt, add_special_tokens=False) - b_ids = tokenizer.encode(text=answer, add_special_tokens=False) - - if len(a_ids) > max_source_length - 1: - a_ids = a_ids[:max_source_length - 1] - - if len(b_ids) > max_target_length - 2: - b_ids = b_ids[:max_target_length - 2] - - input_ids = tokenizer.build_inputs_with_special_tokens(a_ids, b_ids) - context_length = input_ids.index(tokenizer.bos_token_id) - mask_position = context_length - 1 - labels = [-100] * context_length + input_ids[mask_position+1:] - batch_token_ids.append(input_ids) - batch_labels.append(labels) - - batch_token_ids = torch.tensor(sequence_padding(batch_token_ids, value=tokenizer.pad_token_id), dtype=torch.long, device=device) - batch_labels = torch.tensor(sequence_padding(batch_labels, value=-100), dtype=torch.long, device=device) - return [batch_token_ids], batch_labels + for i, (old_query, response) in enumerate(history): + prompt += "[Round {}]\n\n问:{}\n\n答:{}\n\n".format(i + 1, old_query, response) + prompt += "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query) + return prompt + + def collate_train_fn(batch): + batch_token_ids, batch_labels = [], [] + for query, answer, history in batch: + prompt = build_prompt(query, history) + prompt = prefix + prompt + a_ids = tokenizer.encode(text=prompt, add_special_tokens=True, truncation=True, max_length=max_source_length) + b_ids = tokenizer.encode(text=answer, add_special_tokens=False, truncation=True, max_length=max_target_length) + + context_length = len(a_ids) + input_ids = a_ids + b_ids + [tokenizer.eos_token_id] + labels = [tokenizer.pad_token_id] * context_length + b_ids + [tokenizer.eos_token_id] + batch_token_ids.append(input_ids) + batch_labels.append(labels) + + batch_token_ids = torch.tensor(sequence_padding(batch_token_ids, value=tokenizer.pad_token_id), dtype=torch.long, device=device) + batch_labels = torch.tensor(sequence_padding(batch_labels, value=tokenizer.pad_token_id), dtype=torch.long, device=device) + return [batch_token_ids], batch_labels def collate_dev_fn(batch): batch_prompt, batch_labels = [], [] @@ -129,42 +149,26 @@ def collate_dev_fn(batch): batch_labels.append(tokenizer.decode(label_ids, skip_special_tokens=True)) return batch_prompt, batch_labels -train_dataloader = DataLoader(MyDataset('F:/data/corpus/sft/AdvertiseGen/train.json'), batch_size=batch_size, shuffle=True, collate_fn=collate_train_fn) -dev_dataloader = DataLoader(MyDataset('F:/data/corpus/sft/AdvertiseGen/dev.json'), batch_size=eval_batch_size, shuffle=False, collate_fn=collate_dev_fn) +train_dataloader = DataLoader(MyDataset(os.path.join(data_dir, 'train.json')), batch_size=batch_size, shuffle=True, collate_fn=collate_train_fn) +dev_dataloader = DataLoader(MyDataset(os.path.join(data_dir, 'dev.json')), batch_size=eval_batch_size, shuffle=False, collate_fn=collate_dev_fn) + -if choice == 'default': - encoder = 
build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path).half() +# ====================================建立模型==================================== +if model_name in {'chatglm-6B', 'chatglm2-6B'}: + encoder = build_transformer_model(config_path=model_dir, checkpoint_path=model_dir).half() encoder = encoder.quantize(quantization_method='cpm_kernels', quantization_bit=4, - target_modules=['q', 'k', 'v', 'o', 'intermediateDense', 'outputDense']).to(device) + target_modules=['q', 'k', 'v', 'o', 'intermediateDense', 'outputDense']).to(device) else: # 在config中已经写入了量化的配置参数 - encoder = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path).to(device) + encoder = build_transformer_model(config_path=model_dir, checkpoint_path=model_dir).to(device) model = PtuningV2Trainer(encoder).to(device) model.print_trainable_parameters() -class CrossEntropyLoss(nn.CrossEntropyLoss): - def __init__(self, **kwargs): - super().__init__(**kwargs) - def forward(self, logits, labels): - ''' - logits: [btz, seq_len, vocab_size] - labels: token_ids: [btz, seq_len] - ''' - raw_dtyps = logits.dtype - logits = logits.to(torch.float32) - logits = logits[:, :-1, :].contiguous() # 预测序列,错开一位 - labels = labels[:, 1:].contiguous() # 目标token_ids - - logits = logits.reshape(-1, logits.shape[-1]) - labels = labels.flatten() - loss = super().forward(logits, labels) - - return loss.to(raw_dtyps) - optimizer = optim.AdamW(model.parameters(), lr) scheduler = get_linear_schedule_with_warmup(optimizer, 0, steps_per_epoch*epochs) # torch4keras<0.0.8需要设置为(steps_per_epoch*epochs)//grad_accumulation_steps -model.compile(loss=CrossEntropyLoss(ignore_index=-100), optimizer=optimizer, scheduler=scheduler, grad_accumulation_steps=grad_accumulation_steps, clip_grad_norm=1.0) +model.compile(loss=CausalLMLoss(offset=True, ignore_index=tokenizer.pad_token_id), optimizer=optimizer, + scheduler=scheduler, grad_accumulation_steps=grad_accumulation_steps, clip_grad_norm=1.0) class Chat(SeqGeneration): def pre_process(self, text): @@ -222,7 +226,6 @@ def evaluate(self, data, epoch='final'): if __name__ == '__main__': evaluator = Evaluator() logger = Logger('./log.log', interval=100) - mode = 'train' # train evaluate inference if mode == 'train': model.fit(train_dataloader, steps_per_epoch=steps_per_epoch, epochs=epochs, callbacks=[evaluator, logger])
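Across the example scripts this PR replaces the hand-rolled `CrossEntropyLoss` subclass with `bert4torch.losses.CausalLMLoss(offset=True, ignore_index=tokenizer.pad_token_id)`. The removed class documents what that offset means: cast logits to fp32, drop the last logit position, drop the first label, and compute token-level cross entropy on the shifted pair. A minimal restatement of that behaviour (a standalone sketch of the removed code, not the library class itself):

```python
import torch
import torch.nn as nn

def shifted_causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int) -> torch.Tensor:
    """Next-token loss as in the removed CrossEntropyLoss subclass.
    logits: [btz, seq_len, vocab_size]; labels: [btz, seq_len] token ids."""
    raw_dtype = logits.dtype
    logits = logits.to(torch.float32)
    logits = logits[:, :-1, :].contiguous()   # predictions for positions 0..n-2
    labels = labels[:, 1:].contiguous()       # targets are the next tokens 1..n-1
    loss = nn.functional.cross_entropy(
        logits.reshape(-1, logits.size(-1)), labels.flatten(), ignore_index=ignore_index
    )
    return loss.to(raw_dtype)

# Toy check with a hypothetical vocab of 10 and pad_token_id = 0
logits = torch.randn(2, 5, 10)
labels = torch.randint(1, 10, (2, 5))
print(shifted_causal_lm_loss(logits, labels, ignore_index=0))
```

Since the chatglm/chatglm2 collate functions now pad labels with `tokenizer.pad_token_id` instead of `-100`, passing the same id as `ignore_index` keeps padded positions out of the loss, matching the previous behaviour.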