Skip to content

Commit

Permalink
docs: add dbgpt_hub usage documents (eosphoros-ai#955)
Browse files Browse the repository at this point in the history
  • Loading branch information
csunny authored and Hopshine committed Sep 10, 2024
1 parent 589ac4c commit d67a6fa
Show file tree
Hide file tree
Showing 13 changed files with 339 additions and 67 deletions.
Binary file modified assets/wechat.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
44 changes: 44 additions & 0 deletions dbgpt/datasource/rdbms/tests/test_conn_duckdb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""
Run unit test with command: pytest dbgpt/datasource/rdbms/tests/test_conn_duckdb.py
"""

import pytest
import tempfile

from dbgpt.datasource.rdbms.conn_duckdb import DuckDbConnect


@pytest.fixture
def db():
temp_db_file = tempfile.NamedTemporaryFile(delete=False)
temp_db_file.close()
conn = DuckDbConnect.from_file_path(temp_db_file.name + "duckdb.db")
yield conn


def test_get_users(db):
assert db.get_users() == []


def test_get_table_names(db):
assert list(db.get_table_names()) == []


def test_get_users(db):
assert db.get_users() == []


def test_get_charset(db):
assert db.get_charset() == "UTF-8"


def test_get_table_comments(db):
assert db.get_table_comments("test") == []


def test_table_simple_info(db):
assert db.table_simple_info() == []


def test_execute(db):
assert list(db.run("SELECT 42")[0]) == ["42"]
8 changes: 1 addition & 7 deletions docs/README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# DB-GPT Website
# DB-GPT documentation

## Quick Start

Expand All @@ -17,9 +17,3 @@ yarn start

The default service starts on port `3000`, visit `localhost:3000`

## Docker development

```commandline
docker build -t dbgptweb .
docker run --restart=unless-stopped -d -p 3000:3000 dbgptweb
```
118 changes: 118 additions & 0 deletions docs/docs/application/fine_tuning_manual/dbgpt_hub.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# Fine-Tuning use dbgpt_hub

The DB-GPT-Hub project has released a pip package to lower the threshold for Text2SQL training. In addition to fine-tuning through the scripts provided in the warehouse, you can alse use the Python package we provide
for fine-tuning.

## Install
```
pip install dbgpt_hub
```

## Show Baseline
```python
from dbgpt_hub.baseline import show_scores
show_scores()
```
<p align="left">
<img src={'/img/ft/baseline.png'} width="720px" />
</p>

## Fine-tuning

```python
from dbgpt_hub.data_process import preprocess_sft_data
from dbgpt_hub.train import train_sft
from dbgpt_hub.predict import start_predict
from dbgpt_hub.eval import start_evaluate
```


Preprocessing data into fine-tuned data format.
```
data_folder = "dbgpt_hub/data"
data_info = [
{
"data_source": "spider",
"train_file": ["train_spider.json", "train_others.json"],
"dev_file": ["dev.json"],
"tables_file": "tables.json",
"db_id_name": "db_id",
"is_multiple_turn": False,
"train_output": "spider_train.json",
"dev_output": "spider_dev.json",
}
]
preprocess_sft_data(
data_folder = data_folder,
data_info = data_info
)
```

Fine-tune the basic model and generate model weights
```
train_args = {
"model_name_or_path": "codellama/CodeLlama-13b-Instruct-hf",
"do_train": True,
"dataset": "example_text2sql_train",
"max_source_length": 2048,
"max_target_length": 512,
"finetuning_type": "lora",
"lora_target": "q_proj,v_proj",
"template": "llama2",
"lora_rank": 64,
"lora_alpha": 32,
"output_dir": "dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora",
"overwrite_cache": True,
"overwrite_output_dir": True,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 16,
"lr_scheduler_type": "cosine_with_restarts",
"logging_steps": 50,
"save_steps": 2000,
"learning_rate": 2e-4,
"num_train_epochs": 8,
"plot_loss": True,
"bf16": True,
}
start_sft(train_args)
```

Predictive model output results
```
predict_args = {
"model_name_or_path": "codellama/CodeLlama-13b-Instruct-hf",
"template": "llama2",
"finetuning_type": "lora",
"checkpoint_dir": "dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora",
"predict_file_path": "dbgpt_hub/data/eval_data/dev_sql.json",
"predict_out_dir": "dbgpt_hub/output/",
"predicted_out_filename": "pred_sql.sql",
}
start_predict(predict_args)
```

Evaluate the accuracy of the output results on the test datasets

```
evaluate_args = {
"input": "./dbgpt_hub/output/pred/pred_sql_dev_skeleton.sql",
"gold": "./dbgpt_hub/data/eval_data/gold.txt",
"gold_natsql": "./dbgpt_hub/data/eval_data/gold_natsql2sql.txt",
"db": "./dbgpt_hub/data/spider/database",
"table": "./dbgpt_hub/data/eval_data/tables.json",
"table_natsql": "./dbgpt_hub/data/eval_data/tables_for_natsql2sql.json",
"etype": "exec",
"plug_value": True,
"keep_distict": False,
"progress_bar_for_each_datapoint": False,
"natsql": False,
}
start_evaluate(evaluate_args)
```



55 changes: 0 additions & 55 deletions docs/docs/faq/chatdata.md

This file was deleted.

8 changes: 7 additions & 1 deletion docs/docs/overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,10 @@ Connect various data sources
Observing & monitoring

- [Evaluation](/docs/modules/eval)
Evaluate framework performance and accuracy
Evaluate framework performance and accuracy

## Community
If you encounter any problems during the process, you can submit an [issue](https://github.com/eosphoros-ai/DB-GPT/issues) and communicate with us.

We welcome [discussions](https://github.com/orgs/eosphoros-ai/discussions) in the community

8 changes: 4 additions & 4 deletions docs/sidebars.js
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ const sidebars = {
type: 'doc',
id: 'application/fine_tuning_manual/text_to_sql',
},
{
type: 'doc',
id: 'application/fine_tuning_manual/dbgpt_hub',
},
],
},
],
Expand Down Expand Up @@ -224,10 +228,6 @@ const sidebars = {
type: 'doc',
id: 'faq/kbqa',
}
,{
type: 'doc',
id: 'faq/chatdata',
},
],
},

Expand Down
Binary file added docs/static/img/ft/baseline.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file.
9 changes: 9 additions & 0 deletions tests/intetration_tests/datasource/test_conn_clickhouse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import pytest

from dbgpt.datasource.rdbms.conn_clickhouse import ClickhouseConnect


@pytest.fixture
def db():
conn = ClickhouseConnect.from_uri_db("localhost", 8123, "default", "", "default")
yield conn
12 changes: 12 additions & 0 deletions tests/intetration_tests/datasource/test_conn_doris.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""
Run unit test with command: pytest dbgpt/datasource/rdbms/tests/test_conn_doris.py
"""

import pytest
from dbgpt.datasource.rdbms.conn_doris import DorisConnect


@pytest.fixture
def db():
conn = DorisConnect.from_uri_db("localhost", 9030, "root", "", "test")
yield conn
91 changes: 91 additions & 0 deletions tests/intetration_tests/datasource/test_conn_mysql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""
Run unit test with command: pytest dbgpt/datasource/rdbms/tests/test_conn_mysql.py
docker run -itd --name mysql-test -p 3307:3306 -e MYSQL_ROOT_PASSWORD=12345678 mysql:5.7
mysql -h 127.0.0.1 -uroot -p -P3307
Enter password:
Welcome to the MySQL monitor. Commands end with ; or \g.
Your MySQL connection id is 2
Server version: 5.7.41 MySQL Community Server (GPL)
Copyright (c) 2000, 2023, Oracle and/or its affiliates.
Oracle is a registered trademark of Oracle Corporation and/or its
affiliates. Other names may be trademarks of their respective
owners.
Type 'help;' or '\h' for help. Type '\c' to clear the current input statement.
> create database test;
"""

import pytest
from dbgpt.datasource.rdbms.conn_mysql import MySQLConnect

_create_table_sql = """
CREATE TABLE IF NOT EXISTS `test` (
`id` int(11) DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
"""


@pytest.fixture
def db():
conn = MySQLConnect.from_uri_db(
"localhost",
3307,
"root",
"******",
"test",
engine_args={"connect_args": {"charset": "utf8mb4"}},
)
yield conn


def test_get_usable_table_names(db):
db.run(_create_table_sql)
print(db._sync_tables_from_db())
assert list(db.get_usable_table_names()) == []


def test_get_table_info(db):
assert "CREATE TABLE test" in db.get_table_info()


def test_get_table_info_with_table(db):
db.run(_create_table_sql)
print(db._sync_tables_from_db())
table_info = db.get_table_info()
assert "CREATE TABLE test" in table_info


def test_run_no_throw(db):
assert db.run_no_throw("this is a error sql").startswith("Error:")


def test_get_index_empty(db):
db.run(_create_table_sql)
assert db.get_indexes("test") == []


def test_get_fields(db):
db.run(_create_table_sql)
assert list(db.get_fields("test")[0])[0] == "id"


def test_get_charset(db):
assert db.get_charset() == "utf8mb4" or db.get_charset() == "latin1"


def test_get_collation(db):
assert (
db.get_collation() == "utf8mb4_general_ci"
or db.get_collation() == "latin1_swedish_ci"
)


def test_get_users(db):
assert ("root", "%") in db.get_users()


def test_get_database_lists(db):
assert db.get_database_list() == ["test"]
Loading

0 comments on commit d67a6fa

Please sign in to comment.