Documentation | Torch4keras | Examples
安装稳定版
pip install bert4torch
安装最新版
pip install git+https://github.com/Tongjilibo/bert4torch
- 注意事项:pip包的发布慢于git上的开发版本,git clone注意引用路径,注意权重是否需要转换
- 测试用例:
git clone https://github.com/Tongjilibo/bert4torch
,修改example中的预训练模型文件路径和数据路径即可启动脚本 - 自行训练:针对自己的数据,修改相应的数据处理代码块
- 开发环境:原使用
torch==1.10
版本进行开发,现已切换到torch2.0
开发,如其他版本遇到不适配,欢迎反馈
-
LLM模型: 加载chatglm、llama、 baichuan、ziya、bloom等开源大模型权重进行推理和微调
-
核心功能:加载bert、roberta、albert、xlnet、nezha、bart、RoFormer、RoFormer_V2、ELECTRA、GPT、GPT2、T5、GAU-alpha、ERNIE等预训练权重继续进行finetune、并支持在bert基础上灵活定义自己模型
-
丰富示例:包含llm、pretrain、sentence_classfication、sentence_embedding、sequence_labeling、relation_extraction、seq2seq、serving等多种解决方案
-
实验验证:已在公开数据集实验验证,使用如下examples数据集
-
易用trick:集成了常见的trick,即插即用
-
其他特性:加载transformers库模型一起使用;调用方式简洁高效;有训练进度条动态展示;配合torchinfo打印参数量;默认Logger和Tensorboard简便记录训练过程;自定义fit过程,满足高阶需求
-
训练过程:
2022-10-28 23:16:10 - Start Training 2022-10-28 23:16:10 - Epoch: 1/2 5000/5000 [==============================] - 13s 3ms/step - loss: 0.1351 - acc: 0.9601 Evaluate: 100%|██████████████████████████████████████████████████| 2500/2500 [00:03<00:00, 798.09it/s] test_acc: 0.98045. best_test_acc: 0.98045 2022-10-28 23:16:27 - Epoch: 2/2 5000/5000 [==============================] - 13s 3ms/step - loss: 0.0465 - acc: 0.9862 Evaluate: 100%|██████████████████████████████████████████████████| 2500/2500 [00:03<00:00, 635.78it/s] test_acc: 0.98280. best_test_acc: 0.98280 2022-10-28 23:16:44 - Finish Training
功能 | bert4torch | transformers | 备注 |
---|---|---|---|
训练进度条 | ✅ | ✅ | 进度条打印loss和定义的metrics |
分布式训练dp/ddp | ✅ | ✅ | torch自带dp/ddp |
各类callbacks | ✅ | ✅ | 日志/tensorboard/earlystop/wandb等 |
大模型推理,stream/batch输出 | ✅ | ✅ | 各个模型是通用的,无需单独维护脚本 |
大模型微调 | ✅ | ✅ | lora依赖peft库,pv2自带 |
丰富tricks | ✅ | ❌ | 对抗训练等tricks即插即用 |
代码简洁易懂,自定义空间大 | ✅ | ❌ | 代码复用度高, keras代码训练风格 |
仓库的维护能力/影响力/使用量/兼容性 | ❌ | ✅ | 目前仓库个人维护 |
更新日期 | bert4torch | torch4keras | 版本说明 |
---|---|---|---|
20231126 | 0.4.0 | 0.1.5 | 修复flash_attn的bug, stream_generate支持仅输出last_token |
20231119 | 0.3.9 | 0.1.5 | 修复random_sample采样n>1, 新增Yi-6B, 支持flash_attn |
20231112 | 0.3.8 | 0.1.5 | 支持chatglm 32k的rope_ratio,config中可以指定mapping, 增加m3e和bge |
20231106 | 0.3.7 | 0.1.5 | 大部分模型文件无需convert,修复multi_query_group_num在int4/int8下bug, 简化build_transformer_model 中配置到config 中 |
- 20231126:修复flash_attn的bug, stream_generate支持仅输出last_token
- 20231119:修复random_sample采样n>1, 新增Yi-6B, 支持flash_attn
- 20231112:支持chatglm 32k的rope_ratio,config中可以指定mapping, 增加m3e和bge
- 20231106:🔥大部分模型文件无需convert,修复multi_query_group_num在int4/int8下bug, 简化
build_transformer_model
中配置到config
中
- 若无说明则使用权重自带的
pytorch_model.bin
和config.json
模型分类 | 模型名称 | 权重来源 | 权重链接 | 备注(若有) |
---|---|---|---|---|
bert | bert-base-chinese | 谷歌bert的torch版 | torch | config |
chinese_L-12_H-768_A-12 | 谷歌 | github, tf | 转换命令, config | |
chinese-bert-wwm-ext | HFL | tf/torch,torch | ||
bert-base-multilingual-cased | huggingface | torch | config | |
macbert | HFL | tf/torch,torch | ||
wobert | 追一科技 | tf,torch_base,torch_plus_base | ||
guwenbert | ethanyt | torch | config | |
roberta | chinese-roberta-wwm-ext | HFL | tf/torch,torch | |
roberta-small/tiny | 追一科技 & UER | tf,torch | 转换脚本 | |
roberta-base-english | huggingface | torch | config | |
albert | albert | brightmart | tf,torch,torch | |
nezha | NEZHA | 华为 | tf,torch | |
xlnet | chinese-xlnet | HFL | tf/torch | config |
deberta | Erlangshen-DeBERTa-v2 | IDEA | torch | |
electra | Chinese-ELECTRA | HFL | tf,torch | |
ernie | ernie | 百度文心 | paddle,torch | |
roformer | roformer | 追一科技 | tf,torch | |
roformer_v2 | 追一科技 | tf,torch | ||
simbert | simbert | 追一科技 | tf,torch_base | 转换脚本 |
simbert_v2/roformer-sim | 追一科技 | tf,torch | ||
gau | GAU-alpha | 追一科技 | tf | 转换脚本 |
gpt | CDial-GPT | thu-coai | torch | config |
gpt2 | cmp_lm(26亿) | 清华 | torch | config |
gpt2-chinese-cluecorpussmall | UER | torch | config | |
gpt2-ml | imcaspar | tf,torch | config | |
bart | bart_base_chinese | 复旦fnlp | torch, v1.0, v2.0 | config |
t5 | t5 | UER | torch | config_base, config_small |
mt5 | 谷歌 | torch | config | |
t5_pegasus | 追一科技 | tf | config_base, config_small | |
chatyuan v1&v2 | clue-ai | torch | config | |
PromptCLUE | clue-ai | torch | config | |
chatglm | chatglm-6b | THUDM | github, v0.1.0, v1.1.0, int8, int4 | config |
chatglm2-6b | THUDM | github, v2, int4, 32k | config | |
chatglm3-6b | THUDM | github, v3, 32k | config | |
llama | llama | github | config | |
llama-2 | github, 7b, 7b-chat, 13b, 13b-chat | config | ||
chinese_llama_alpaca | HFL | github | config | |
Belle_llama | LianjiaTech | github, 7B-2M-enc | 合成说明、config | |
Ziya | IDEA-CCNL | v1, v1.1, pretrain-v1 | config | |
Baichuan | baichuan-inc | github, 7B, 13B-Base, 13B-Chat | config | |
Baichuan2 | baichuan-inc | github, 7B-Base, 7B-Chat, 13B-Base, 13B-Chat | config | |
vicuna | lmsys | 7b-v1.5 | config | |
Yi | 01-ai | github, 6B, 6B-200K | config | |
bloom | bloom | bigscience | bloom-560m, bloomz-560m | config |
Qwen | Qwen | 阿里云 | github, 7B, 7B-Chat | config |
InternLM | InternLM | 上海人工智能实验室 | github, 7B-Chat, 7B | config |
Falcon | Falcon | tiiuae | hf, RW-1B, 7B, 7B-Instruct | config |
embedding | text2vec-base-chinese | shibing624 | torch | |
m3e | moka-ai | torch | config | |
bge | BAAI | torch | config |
- 感谢苏神实现的bert4keras,本实现有不少地方参考了bert4keras的源码,在此衷心感谢大佬的无私奉献;
- 其次感谢项目bert4pytorch,也是在该项目的指引下给了我用pytorch来复现bert4keras的想法和思路。
@misc{bert4torch,
title={bert4torch},
author={Bo Li},
year={2022},
howpublished={\url{https://github.com/Tongjilibo/bert4torch}},
}
- Wechat & Star History Chart
微信号 |
微信群 |
Star History Chart |