merge the main branch (#186)
Co-authored-by: qidanrui <[email protected]>
Co-authored-by: junewgl <[email protected]>
Co-authored-by: wangzaistone <[email protected]>
4 people authored Dec 14, 2023
1 parent f79ade1 commit f31c00e
Showing 9 changed files with 856 additions and 163 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
@@ -41,7 +41,7 @@ jobs:
run: echo "DBGPT_HUB_VERSION=`toml get --toml-path pyproject.toml tool.poetry.version`" >> $GITHUB_ENV

- name: Create release note
run: poetry run python scripts/release-note.py $(git rev-parse --short HEAD) > RELEASE.md
run: poetry run python release_scripts/release-note.py $(git rev-parse --short HEAD) > RELEASE.md

- uses: ncipollo/release-action@v1
with:
2 changes: 1 addition & 1 deletion README.md
@@ -36,7 +36,7 @@
- [2.2. Model](#22-model)
- [3. Usage](#3-usage)
- [3.1. Environment preparation](#31-environment-preparation)
- [3.2. Quick Start](#33-quick-start)
- [3.2. Quick Start](#32-quick-start)
- [3.3. Data preparation](#33-data-preparation)
- [3.4. Model fine-tuning](#34-model-fine-tuning)
- [3.5. Model Predict](#35-model-predict)
103 changes: 97 additions & 6 deletions README.zh.md
@@ -35,12 +35,13 @@
- [2.2. Base model](#22基座模型)
- [3. Usage](#三使用方法)
  - [3.1. Environment preparation](#31环境准备)
  - [3.2. Data preparation](#32数据准备)
  - [3.3. Model fine-tuning](#33模型微调)
  - [3.4. Model prediction](#34模型预测)
  - [3.5. Model weights](#35模型权重)
    - [3.5.1 Merging base model and fine-tuned weights](#351-模型和微调权重合并)
  - [3.6. Model evaluation](#36模型评估)
  - [3.2. Quick start](#32快速开始)
  - [3.3. Data preparation](#33数据准备)
  - [3.4. Model fine-tuning](#34模型微调)
  - [3.5. Model prediction](#35模型预测)
  - [3.6. Model weights](#36模型权重)
    - [3.6.1 Merging base model and fine-tuned weights](#361-模型和微调权重合并)
  - [3.7. Model evaluation](#37模型评估)
- [4. Roadmap](#四发展路线)
- [5. Contributing](#五贡献)
- [6. Acknowledgements](#六感谢)
@@ -136,6 +137,96 @@ poetry run sh dbgpt_hub/scripts/gen_train_eval_data.sh
```
The project's data processing code already has built-in handling for the `chase`, `cosql`, and `sparc` datasets. After downloading those datasets into the data path via the links above, simply uncomment the corresponding entries in `SQL_DATA_INFO` in `dbgpt_hub/configs/config.py`.
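For illustration, uncommenting an entry in `SQL_DATA_INFO` might look like the sketch below. The field names mirror the spider `data_info` example in the quick start; the exact keys in the real `dbgpt_hub/configs/config.py` may differ.

```python
# Hypothetical sketch of SQL_DATA_INFO in dbgpt_hub/configs/config.py.
# Field names are illustrative; check the actual config file.
SQL_DATA_INFO = [
    {
        "data_source": "spider",
        "train_file": ["train_spider.json", "train_others.json"],
        "dev_file": ["dev.json"],
    },
    # Uncomment after downloading the dataset into the data path:
    # {
    #     "data_source": "sparc",
    #     "train_file": ["train.json"],
    #     "dev_file": ["dev.json"],
    # },
]
```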

### 3.2 Quick Start

First, install `dbgpt-hub` with the following command:

`pip install dbgpt-hub`

Then, specify the parameters and complete the whole Text2SQL fine-tuning workflow in just a few lines of code:
```python
from dbgpt_hub.data_process import preprocess_sft_data
from dbgpt_hub.train import start_sft
from dbgpt_hub.predict import start_predict
from dbgpt_hub.eval import start_evaluate

# Configure the training and validation set paths and parameters
data_folder = "dbgpt_hub/data"
data_info = [
{
"data_source": "spider",
"train_file": ["train_spider.json", "train_others.json"],
"dev_file": ["dev.json"],
"tables_file": "tables.json",
"db_id_name": "db_id",
"is_multiple_turn": False,
"train_output": "spider_train.json",
"dev_output": "spider_dev.json",
}
]

# Configure the fine-tuning parameters
train_args = {
"model_name_or_path": "codellama/CodeLlama-13b-Instruct-hf",
"do_train": True,
"dataset": "example_text2sql_train",
"max_source_length": 2048,
"max_target_length": 512,
"finetuning_type": "lora",
"lora_target": "q_proj,v_proj",
"template": "llama2",
"lora_rank": 64,
"lora_alpha": 32,
"output_dir": "dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora",
"overwrite_cache": True,
"overwrite_output_dir": True,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 16,
"lr_scheduler_type": "cosine_with_restarts",
"logging_steps": 50,
"save_steps": 2000,
"learning_rate": 2e-4,
"num_train_epochs": 8,
"plot_loss": True,
"bf16": True,
}

# Configure the prediction parameters
predict_args = {
"model_name_or_path": "codellama/CodeLlama-13b-Instruct-hf",
"template": "llama2",
"finetuning_type": "lora",
"checkpoint_dir": "dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora",
"predict_file_path": "dbgpt_hub/data/eval_data/dev_sql.json",
"predict_out_dir": "dbgpt_hub/output/",
"predicted_out_filename": "pred_sql.sql",
}

# Configure the evaluation parameters
evaluate_args = {
"input": "./dbgpt_hub/output/pred/pred_sql_dev_skeleton.sql",
"gold": "./dbgpt_hub/data/eval_data/gold.txt",
"gold_natsql": "./dbgpt_hub/data/eval_data/gold_natsql2sql.txt",
"db": "./dbgpt_hub/data/spider/database",
"table": "./dbgpt_hub/data/eval_data/tables.json",
"table_natsql": "./dbgpt_hub/data/eval_data/tables_for_natsql2sql.json",
"etype": "exec",
"plug_value": True,
    "keep_distinct": False,
"progress_bar_for_each_datapoint": False,
"natsql": False,
}

# Run the whole fine-tuning workflow
preprocess_sft_data(
    data_folder=data_folder,
    data_info=data_info,
)

start_sft(train_args)
start_predict(predict_args)
start_evaluate(evaluate_args)
```
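One detail worth noting in `train_args`: the effective batch size is the product of the per-device batch size, the gradient accumulation steps, and the number of devices — a common source of confusion when adjusting `learning_rate`. With the values above on a single GPU (an assumption), that works out to:

```python
# Effective batch size for the train_args above, assuming a single GPU.
per_device_train_batch_size = 1
gradient_accumulation_steps = 16
num_devices = 1  # assumption: one GPU

effective_batch_size = (
    per_device_train_batch_size * gradient_accumulation_steps * num_devices
)
print(effective_batch_size)  # 16
```

So each optimizer step sees 16 examples; scale `gradient_accumulation_steps` down if you raise `per_device_train_batch_size` and want to keep the effective batch size unchanged.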

### 3.3 Model fine-tuning

