diff --git a/README.md b/README.md
index 5887181..fa49d43 100644
--- a/README.md
+++ b/README.md
@@ -36,12 +36,13 @@
   - [2.2. Model](#22-model)
 - [3. Usage](#3-usage)
   - [3.1. Environment preparation](#31-environment-preparation)
-  - [3.2. Data preparation](#32-data-preparation)
-  - [3.3. Model fine-tuning](#33-model-fine-tuning)
-  - [3.4. Model Predict](#34-model-predict)
-  - [3.5 Model Weights](#35-model-weights)
-    - [3.5.1 Model and fine-tuned weight merging](#351-model-and-fine-tuned-weight-merging)
-  - [3.6 Model Evaluation](#36-model-evaluation)
+  - [3.2. Quick Start](#32-quick-start)
+  - [3.3. Data preparation](#33-data-preparation)
+  - [3.4. Model fine-tuning](#34-model-fine-tuning)
+  - [3.5. Model Predict](#35-model-predict)
+  - [3.6 Model Weights](#36-model-weights)
+    - [3.6.1 Model and fine-tuned weight merging](#361-model-and-fine-tuned-weight-merging)
+  - [3.7 Model Evaluation](#37-model-evaluation)
 - [4. RoadMap](#4-roadmap)
 - [5. Contributions](#5-contributions)
 - [6. Acknowledgements](#6-acknowledgements)
@@ -118,8 +119,98 @@ conda activate dbgpt_hub
 pip install poetry
 poetry install
 ```
+### 3.2. Quick Start
+
+First, install `dbgpt_hub` with the following command:
+
+`pip install dbgpt_hub`
+
+Then, set up the arguments and run the whole process:
+```python
+from dbgpt_hub.data_process import preprocess_sft_data
+from dbgpt_hub.train import start_sft
+from dbgpt_hub.predict import start_predict
+from dbgpt_hub.eval import start_evaluate
+
+# Configure the input datasets
+data_folder = "dbgpt_hub/data"
+data_info = [
+    {
+        "data_source": "spider",
+        "train_file": ["train_spider.json", "train_others.json"],
+        "dev_file": ["dev.json"],
+        "tables_file": "tables.json",
+        "db_id_name": "db_id",
+        "is_multiple_turn": False,
+        "train_output": "spider_train.json",
+        "dev_output": "spider_dev.json",
+    }
+]
+
+# Configure training parameters
+train_args = {
+    "model_name_or_path": "codellama/CodeLlama-13b-Instruct-hf",
+    "do_train": True,
+    "dataset": "example_text2sql_train",
+    "max_source_length": 2048,
+    "max_target_length": 512,
+    "finetuning_type": "lora",
+    "lora_target": "q_proj,v_proj",
+    "template": "llama2",
+    "lora_rank": 64,
+    "lora_alpha": 32,
+    "output_dir": "dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora",
+    "overwrite_cache": True,
+    "overwrite_output_dir": True,
+    "per_device_train_batch_size": 1,
+    "gradient_accumulation_steps": 16,
+    "lr_scheduler_type": "cosine_with_restarts",
+    "logging_steps": 50,
+    "save_steps": 2000,
+    "learning_rate": 2e-4,
+    "num_train_epochs": 8,
+    "plot_loss": True,
+    "bf16": True,
+}
+
+# Configure prediction parameters
+predict_args = {
+    "model_name_or_path": "codellama/CodeLlama-13b-Instruct-hf",
+    "template": "llama2",
+    "finetuning_type": "lora",
+    "checkpoint_dir": "dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora",
+    "predict_file_path": "dbgpt_hub/data/eval_data/dev_sql.json",
+    "predict_out_dir": "dbgpt_hub/output/",
+    "predicted_out_filename": "pred_sql.sql",
+}
+
+# Configure evaluation parameters
+evaluate_args = {
+    "input": "./dbgpt_hub/output/pred/pred_sql_dev_skeleton.sql",
+    "gold": "./dbgpt_hub/data/eval_data/gold.txt",
+    "gold_natsql": "./dbgpt_hub/data/eval_data/gold_natsql2sql.txt",
+    "db": "./dbgpt_hub/data/spider/database",
+    "table": "./dbgpt_hub/data/eval_data/tables.json",
+    "table_natsql": "./dbgpt_hub/data/eval_data/tables_for_natsql2sql.json",
+    "etype": "exec",
+    "plug_value": True,
+    "keep_distict": False,
+    "progress_bar_for_each_datapoint": False,
+    "natsql": False,
+}
+
+# Run the whole fine-tuning workflow
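+#
+# Summary of the four steps below, based on the configs above: preprocess_sft_data()
+# converts the raw Spider files listed in data_info into the train/dev JSON files
+# named by train_output/dev_output; start_sft() fine-tunes the base model with LoRA
+# and saves the adapter to train_args["output_dir"]; start_predict() writes the
+# generated SQL to predict_out_dir/predicted_out_filename; start_evaluate() scores
+# the predictions against the gold file using execution accuracy (etype="exec").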
+preprocess_sft_data(
+    data_folder = data_folder,
+    data_info = data_info
+)
+
+start_sft(train_args)
+start_predict(predict_args)
+start_evaluate(evaluate_args)
+```
-### 3.2. Data preparation
+### 3.3. Data preparation
 DB-GPT-Hub uses the information matching generation method for data preparation, i.e. the SQL + Repository generation method that combines table information. By combining table information, the method can better capture the structure of and relationships between data tables, and is well suited to generating SQL statements that meet the requirements.
@@ -146,7 +237,7 @@ The data in the generated JSON looks something like this:
 ```
 The data processing code for `chase`, `cosql` and `sparc` is already embedded in the project's data processing code. After downloading these datasets from the links above, you only need to uncomment the corresponding entries in `SQL_DATA_INFO` in `dbgpt_hub/configs/config.py`.
-### 3.3. Model fine-tuning
+### 3.4. Model fine-tuning
 The model fine-tuning supports both LoRA and QLoRA methods. We can run the following command to fine-tune the model. By default, the parameter --quantization_bit is set, which enables the QLoRA fine-tuning method; to switch to LoRA, simply remove that parameter from the script. Run the command:
@@ -206,7 +297,7 @@ In the script, during fine-tuning, different models correspond to key parameters
 > num_train_epochs: The number of epochs for training on the dataset.
-### 3.4. Model Predict
+### 3.5. Model Predict
 The directory ./dbgpt_hub/output/pred/ under the project root is the default output location for model predictions (create it with mkdir if it does not exist).
@@ -217,11 +308,11 @@ poetry run sh ./dbgpt_hub/scripts/predict_sft.sh
 ```
 In the script, the parameter `--quantization_bit` is set by default, so prediction uses QLoRA; removing it switches to the LoRA prediction method. The parameter `predicted_input_filename` is the test dataset file to predict on, and `--predicted_out_filename` is the file name of the model's predicted results.
-### 3.5 Model Weights
+### 3.6 Model Weights
 You can find the corresponding model weights on Hugging Face at [hg-eosphoros-ai](https://huggingface.co/Wangzaistone123/CodeLlama-13b-sql-lora). We uploaded the LoRA weights in October; their execution accuracy on the Spider evaluation set reached 0.789.
-#### 3.5.1 Model and fine-tuned weight merging
+#### 3.6.1 Model and fine-tuned weight merging
 If you need to merge the weights of the trained base model and the fine-tuned PEFT module to export a complete model, execute the following model export script:
@@ -231,7 +322,7 @@ poetry run sh ./dbgpt_hub/scripts/export_merge.sh
 ```
 Be sure to replace the parameter path values in the script with the paths corresponding to your project.
-### 3.6 Model Evaluation
+### 3.7 Model Evaluation
 To evaluate model performance on a given dataset (the Spider dev set by default), run the following command:
 ```bash