forked from eosphoros-ai/DB-GPT
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
docs: add dbgpt_hub usage documents (eosphoros-ai#955)
- Loading branch information
Showing
13 changed files
with
339 additions
and
67 deletions.
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
""" | ||
Run unit test with command: pytest dbgpt/datasource/rdbms/tests/test_conn_duckdb.py | ||
""" | ||
|
||
import pytest | ||
import tempfile | ||
|
||
from dbgpt.datasource.rdbms.conn_duckdb import DuckDbConnect | ||
|
||
|
||
@pytest.fixture
def db():
    """Yield a DuckDbConnect backed by a file inside a throwaway directory.

    The database path names a file that does not exist yet, located in a
    freshly created temporary directory. The directory and its contents are
    removed when the test using this fixture finishes.
    """
    import os

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Previously a NamedTemporaryFile(delete=False) plus a second
        # "<name>duckdb.db" file were left behind after every test run.
        # NOTE(review): the connection is not closed explicitly here; on
        # platforms that lock open files, directory cleanup may need the
        # connection closed first — confirm.
        yield DuckDbConnect.from_file_path(os.path.join(tmp_dir, "duckdb.db"))
|
||
|
||
def test_get_users(db):
    """Check that get_users() returns an empty list for the fixture database."""
    users = db.get_users()
    assert users == []
|
||
|
||
def test_get_table_names(db):
    """Check that get_table_names() yields no names for the fixture database."""
    table_names = list(db.get_table_names())
    assert table_names == []
|
||
|
||
def test_get_users(db):
    """Check that get_users() returns an empty list for the fixture database."""
    # NOTE(review): a test with this exact name is defined earlier in this
    # module; this later definition replaces it at import time.
    found_users = db.get_users()
    assert found_users == []
|
||
|
||
def test_get_charset(db):
    """Check that get_charset() reports "UTF-8" for the fixture database."""
    charset = db.get_charset()
    assert charset == "UTF-8"
|
||
|
||
def test_get_table_comments(db):
    """Check that get_table_comments("test") returns an empty list."""
    comments = db.get_table_comments("test")
    assert comments == []
|
||
|
||
def test_table_simple_info(db):
    """Check that table_simple_info() returns an empty list."""
    simple_info = db.table_simple_info()
    assert simple_info == []
|
||
|
||
def test_execute(db):
    """Run "SELECT 42" and check the first returned entry equals ["42"]."""
    result = db.run("SELECT 42")
    # NOTE(review): the expected value is the string "42", so the first
    # entry is presumably a header row of column names — confirm.
    first_entry = list(result[0])
    assert first_entry == ["42"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
# Fine-Tuning use dbgpt_hub | ||
|
||
The DB-GPT-Hub project has released a pip package to lower the threshold for Text2SQL training. In addition to fine-tuning through the scripts provided in the repository, you can also use the Python package we provide | ||
for fine-tuning. | ||
|
||
## Install | ||
``` | ||
pip install dbgpt_hub | ||
``` | ||
|
||
## Show Baseline | ||
```python | ||
from dbgpt_hub.baseline import show_scores | ||
show_scores() | ||
``` | ||
<p align="left"> | ||
<img src={'/img/ft/baseline.png'} width="720px" /> | ||
</p> | ||
|
||
## Fine-tuning | ||
|
||
```python | ||
from dbgpt_hub.data_process import preprocess_sft_data | ||
from dbgpt_hub.train import start_sft | ||
from dbgpt_hub.predict import start_predict | ||
from dbgpt_hub.eval import start_evaluate | ||
``` | ||
|
||
|
||
Preprocess the data into the format required for fine-tuning. | ||
``` | ||
data_folder = "dbgpt_hub/data" | ||
data_info = [ | ||
{ | ||
"data_source": "spider", | ||
"train_file": ["train_spider.json", "train_others.json"], | ||
"dev_file": ["dev.json"], | ||
"tables_file": "tables.json", | ||
"db_id_name": "db_id", | ||
"is_multiple_turn": False, | ||
"train_output": "spider_train.json", | ||
"dev_output": "spider_dev.json", | ||
} | ||
] | ||
preprocess_sft_data( | ||
data_folder = data_folder, | ||
data_info = data_info | ||
) | ||
``` | ||
|
||
Fine-tune the base model and generate the model weights. | ||
``` | ||
train_args = { | ||
"model_name_or_path": "codellama/CodeLlama-13b-Instruct-hf", | ||
"do_train": True, | ||
"dataset": "example_text2sql_train", | ||
"max_source_length": 2048, | ||
"max_target_length": 512, | ||
"finetuning_type": "lora", | ||
"lora_target": "q_proj,v_proj", | ||
"template": "llama2", | ||
"lora_rank": 64, | ||
"lora_alpha": 32, | ||
"output_dir": "dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora", | ||
"overwrite_cache": True, | ||
"overwrite_output_dir": True, | ||
"per_device_train_batch_size": 1, | ||
"gradient_accumulation_steps": 16, | ||
"lr_scheduler_type": "cosine_with_restarts", | ||
"logging_steps": 50, | ||
"save_steps": 2000, | ||
"learning_rate": 2e-4, | ||
"num_train_epochs": 8, | ||
"plot_loss": True, | ||
"bf16": True, | ||
} | ||
start_sft(train_args) | ||
``` | ||
|
||
Run prediction with the fine-tuned model to generate the output results. | ||
``` | ||
predict_args = { | ||
"model_name_or_path": "codellama/CodeLlama-13b-Instruct-hf", | ||
"template": "llama2", | ||
"finetuning_type": "lora", | ||
"checkpoint_dir": "dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora", | ||
"predict_file_path": "dbgpt_hub/data/eval_data/dev_sql.json", | ||
"predict_out_dir": "dbgpt_hub/output/", | ||
"predicted_out_filename": "pred_sql.sql", | ||
} | ||
start_predict(predict_args) | ||
``` | ||
|
||
Evaluate the accuracy of the output results on the test datasets | ||
|
||
``` | ||
evaluate_args = { | ||
"input": "./dbgpt_hub/output/pred/pred_sql_dev_skeleton.sql", | ||
"gold": "./dbgpt_hub/data/eval_data/gold.txt", | ||
"gold_natsql": "./dbgpt_hub/data/eval_data/gold_natsql2sql.txt", | ||
"db": "./dbgpt_hub/data/spider/database", | ||
"table": "./dbgpt_hub/data/eval_data/tables.json", | ||
"table_natsql": "./dbgpt_hub/data/eval_data/tables_for_natsql2sql.json", | ||
"etype": "exec", | ||
"plug_value": True, | ||
"keep_distict": False, | ||
"progress_bar_for_each_datapoint": False, | ||
"natsql": False, | ||
} | ||
start_evaluate(evaluate_args) | ||
``` | ||
|
||
|
||
|
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
import pytest | ||
|
||
from dbgpt.datasource.rdbms.conn_clickhouse import ClickhouseConnect | ||
|
||
|
||
@pytest.fixture
def db():
    """Yield a ClickhouseConnect built from fixed localhost settings."""
    # Positional arguments are passed through to from_uri_db unchanged.
    settings = ("localhost", 8123, "default", "", "default")
    yield ClickhouseConnect.from_uri_db(*settings)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
""" | ||
Run unit test with command: pytest dbgpt/datasource/rdbms/tests/test_conn_doris.py | ||
""" | ||
|
||
import pytest | ||
from dbgpt.datasource.rdbms.conn_doris import DorisConnect | ||
|
||
|
||
@pytest.fixture
def db():
    """Yield a DorisConnect built from fixed localhost settings."""
    # Positional arguments are passed through to from_uri_db unchanged.
    settings = ("localhost", 9030, "root", "", "test")
    yield DorisConnect.from_uri_db(*settings)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
""" | ||
Run unit test with command: pytest dbgpt/datasource/rdbms/tests/test_conn_mysql.py | ||
docker run -itd --name mysql-test -p 3307:3306 -e MYSQL_ROOT_PASSWORD=12345678 mysql:5.7 | ||
mysql -h 127.0.0.1 -uroot -p -P3307 | ||
Enter password: | ||
Welcome to the MySQL monitor. Commands end with ; or \g. | ||
Your MySQL connection id is 2 | ||
Server version: 5.7.41 MySQL Community Server (GPL) | ||
Copyright (c) 2000, 2023, Oracle and/or its affiliates. | ||
Oracle is a registered trademark of Oracle Corporation and/or its | ||
affiliates. Other names may be trademarks of their respective | ||
owners. | ||
Type 'help;' or '\h' for help. Type '\c' to clear the current input statement. | ||
> create database test; | ||
""" | ||
|
||
import pytest | ||
from dbgpt.datasource.rdbms.conn_mysql import MySQLConnect | ||
|
||
_create_table_sql = """ | ||
CREATE TABLE IF NOT EXISTS `test` ( | ||
`id` int(11) DEFAULT NULL | ||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; | ||
""" | ||
|
||
|
||
@pytest.fixture
def db():
    """Yield a MySQLConnect for the local `test` database on port 3307."""
    # Force the utf8mb4 charset on the underlying connection.
    engine_options = {"connect_args": {"charset": "utf8mb4"}}
    connection = MySQLConnect.from_uri_db(
        "localhost", 3307, "root", "******", "test", engine_args=engine_options
    )
    yield connection
|
||
|
||
def test_get_usable_table_names(db):
    """Create the `test` table, resync, and expect no usable table names."""
    db.run(_create_table_sql)
    synced = db._sync_tables_from_db()
    print(synced)
    usable_names = list(db.get_usable_table_names())
    assert usable_names == []
|
||
|
||
def test_get_table_info(db):
    """Check that get_table_info() mentions "CREATE TABLE test"."""
    table_info = db.get_table_info()
    assert "CREATE TABLE test" in table_info
|
||
|
||
def test_get_table_info_with_table(db):
    """Create the `test` table, resync, then check the reported table info."""
    db.run(_create_table_sql)
    sync_result = db._sync_tables_from_db()
    print(sync_result)
    assert "CREATE TABLE test" in db.get_table_info()
|
||
|
||
def test_run_no_throw(db):
    """Check that run_no_throw() returns an "Error:"-prefixed string for bad SQL."""
    outcome = db.run_no_throw("this is a error sql")
    assert outcome.startswith("Error:")
|
||
|
||
def test_get_index_empty(db):
    """Create the `test` table and expect get_indexes("test") to be empty."""
    db.run(_create_table_sql)
    indexes = db.get_indexes("test")
    assert indexes == []
|
||
|
||
def test_get_fields(db):
    """Create the `test` table and expect its first field to start with "id"."""
    db.run(_create_table_sql)
    first_field = db.get_fields("test")[0]
    leading_value = list(first_field)[0]
    assert leading_value == "id"
|
||
|
||
def test_get_charset(db):
    """Check that get_charset() reports either "utf8mb4" or "latin1"."""
    # Call get_charset() once and test membership, instead of calling it a
    # second time when the first comparison fails.
    assert db.get_charset() in ("utf8mb4", "latin1")
|
||
|
||
def test_get_collation(db):
    """Check that get_collation() is "utf8mb4_general_ci" or "latin1_swedish_ci"."""
    # Call get_collation() once and test membership, instead of calling it a
    # second time when the first comparison fails.
    assert db.get_collation() in ("utf8mb4_general_ci", "latin1_swedish_ci")
|
||
|
||
def test_get_users(db):
    """Check that ("root", "%") appears among the entries from get_users()."""
    users = db.get_users()
    assert ("root", "%") in users
|
||
|
||
def test_get_database_lists(db):
    """Check that get_database_list() returns exactly ["test"]."""
    databases = db.get_database_list()
    assert databases == ["test"]
Oops, something went wrong.