feat (hub): add style checking workflow and reformat some files #146

Merged · 10 commits · Nov 25, 2023
Changes from all commits
66 changes: 66 additions & 0 deletions .github/workflows/ci.yml
@@ -0,0 +1,66 @@
name: CI

on:
  push:
    branches:
      - main
      - release
  pull_request:
    branches:
      - main

jobs:
  build:
    name: ${{ matrix.os }} x py${{ matrix.python }}
    strategy:
      fail-fast: false
      matrix:
        python: ["3.10", "3.11"]
        os: [ubuntu-latest, windows-latest]
        include:
          - os: ubuntu-latest
            install_graphviz:
              sudo apt-get install graphviz libgraphviz-dev
          - os: windows-latest
            install_graphviz:
              choco install graphviz --version=2.48.0;
              poetry run pip install --global-option=build_ext --global-option="-IC:\Program Files\Graphviz\include" --global-option="-LC:\Program Files\Graphviz\lib" pygraphviz;

    runs-on: ${{ matrix.os }}
    steps:
      - uses: actions/checkout@v2

      - uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python }}

      - name: Cache venv
        uses: actions/cache@v2
        with:
          path: ~/.cache/pypoetry/virtualenvs
          key: ${{ runner.os }}-build-${{ matrix.python }}-${{ secrets.CACHE_VERSION }}-${{ hashFiles('poetry.lock') }}

      - name: Install dependencies
        run: |
          echo "Cache Version ${{ secrets.CACHE_VERSION }}"
          pip install poetry
          ${{ matrix.install_graphviz }}
          poetry install
          poetry run pip install ERAlchemy
          poetry config --list

      - name: Print tool versions
        run: |
          poetry run black --version

      - name: Check if the code is formatted
        run: poetry run black --check dbgpt_hub

      # - name: Type check the project
      #   run: poetry run pyright dbgpt_hub || true

      # - name: Style check the project
      #   run: poetry run pylint dbgpt_hub || true

      - name: Build binary dependencies
        run: poetry build
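
The style gate above can be reproduced locally before pushing. A minimal sketch, assuming Poetry and Black are already installed in the project environment and the command is run from the repository root:

    # Rough local equivalent of the "Check if the code is formatted" step.
    import subprocess
    import sys

    result = subprocess.run(["poetry", "run", "black", "--check", "dbgpt_hub"])
    sys.exit(result.returncode)  # non-zero exit means some files need reformatting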
56 changes: 56 additions & 0 deletions .github/workflows/release.yml
@@ -0,0 +1,56 @@
name: release

on:
  push:
    branches:
      - release

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
        with:
          fetch-depth: "0"

      - uses: actions/setup-python@v1
        with:
          python-version: ${{ matrix.python }}

      - name: Cache venv
        uses: actions/cache@v2
        with:
          path: ~/.cache/pypoetry/virtualenvs
          key: ${{ runner.os }}-build-${{ matrix.python }}-${{ secrets.CACHE_VERSION }}-${{ hashFiles('poetry.lock') }}

      - name: Install dependencies
        run: |
          echo "Cache Version ${{ secrets.CACHE_VERSION }}"
          pip install poetry toml-cli
          poetry install
          poetry config --list

      - name: Print tool versions
        run: |
          poetry run black --version

      - name: Build wheels
        run: poetry build

      - name: Parse version from pyproject.toml
        run: echo "DBGPT_HUB_VERSION=`toml get --toml-path pyproject.toml tool.poetry.version`" >> $GITHUB_ENV

      - name: Create release note
        run: poetry run python scripts/release-note.py $(git rev-parse --short HEAD) > RELEASE.md

      - uses: ncipollo/release-action@v1
        with:
          artifacts: "dist/*.whl,dist/*.tar.gz"
          bodyFile: "RELEASE.md"
          token: ${{ secrets.GITHUB_TOKEN }}
          draft: true
          tag: v${{ env.DBGPT_HUB_VERSION }}
          commit: ${{ env.GITHUB_SHA }}

      - name: Upload wheels
        run: poetry publish --username __token__ --password ${{ secrets.PYPI_TOKEN }}
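
The "Parse version from pyproject.toml" step uses toml-cli to read the Poetry version and export it as DBGPT_HUB_VERSION. A rough Python equivalent of that lookup, sketched under the assumption of Python 3.11+ where tomllib is in the standard library:

    # Read the Poetry version that the release will be tagged with.
    import tomllib

    with open("pyproject.toml", "rb") as f:
        version = tomllib.load(f)["tool"]["poetry"]["version"]

    print(f"DBGPT_HUB_VERSION=v{version}")  # the release is tagged v<version>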
4 changes: 1 addition & 3 deletions dbgpt_hub/data_process/__init__.py
@@ -5,6 +5,4 @@

from .sql_data_process_api import preprocess_sft_data

__all__ = [
"preprocess_sft_data"
]
__all__ = ["preprocess_sft_data"]
16 changes: 7 additions & 9 deletions dbgpt_hub/data_process/connectors/anydb_connector.py
@@ -1,12 +1,13 @@
from typing import Any
from .mysql_connector import MySQLConnector


class AnyDBConnector(BaseConnector):
def __init__(
self,
db_type: str = "mysql",
host: str ="127.0.0.1",
port: str =3306,
host: str = "127.0.0.1",
port: str = 3306,
user: Optional[str] = None,
passwd: Optional[str] = None,
db: Optional[str] = None,
@@ -15,15 +16,12 @@ def __init__(
**kwargs
) -> Any:
super().__init__(db_type, host, port, user, passwd, db, charset, args, kwargs)
if self.db_type == 'mysql':
if self.db_type == "mysql":
self.connector = MySQLConnector(
host=self.host,
port=self.port,
user=self.user,
passwd=self.passwd
host=self.host, port=self.port, user=self.user, passwd=self.passwd
)
"""TO DO: postgres, bigquery, etc."""

def __del__(self) -> Any:
super().__del__()

@@ -51,7 +49,7 @@ def get_version(self, args=None):

def get_all_table_metadata(self, args=None):
"""查询所有表的元数据信息"""
return self.connector.get_all_table_metadata(args)
return self.connector.get_all_table_metadata(args)

def get_table_metadata(self, db, table, args=None):
"""查询指定表的元数据信息"""
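
For reference, a hypothetical usage sketch of the reformatted connector, based only on the constructor and methods visible in this diff; the import path and credentials are placeholders, and a reachable MySQL instance is assumed:

    from dbgpt_hub.data_process.connectors.anydb_connector import AnyDBConnector

    # Delegates to MySQLConnector when db_type is "mysql".
    conn = AnyDBConnector(
        db_type="mysql",
        host="127.0.0.1",
        port=3306,
        user="root",        # placeholder credential
        passwd="password",  # placeholder credential
    )
    print(conn.get_version())             # server version via the underlying connector
    print(conn.get_all_table_metadata())  # metadata for every table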
95 changes: 38 additions & 57 deletions dbgpt_hub/data_process/specialized_scenario_data_processing.py
@@ -13,18 +13,14 @@ class SpecialScenarioDataProcessor(object):
def __init__(
self,
db_type: str = "mysql",
host: str ="127.0.0.1",
port: str =3306,
host: str = "127.0.0.1",
port: str = 3306,
user: Optional[str] = None,
passwd: Optional[str] = None,
) -> Any:
self.db_type = db_type
self.connector = AnyDBConnector(
db_type = db_type,
host = host,
port = port,
user = user,
passwd = passwd
db_type=db_type, host=host, port=port, user=user, passwd=passwd
)

def generate_spider_nl2sql_metadata(
@@ -33,81 +29,66 @@ def generate_spider_nl2sql_metadata(
output_folder: str = "",
training_ratio: float = 0.7,
valid_ratio: float = 0.3,
test_ratio: float = 0.0
test_ratio: float = 0.0,
):
# Generate table.json file according to Database Info
"""
TO DO:
TO DO:
1. need to solve config data
"""
table_meta_data_processor = TableMetaDataProcessor(
conn = self.connector,
config_data = None
conn=self.connector, config_data=None
)
spider_table_metadata = \
spider_table_metadata = (
table_meta_data_processor.generate_spider_table_metadata()
)

with open(os.path.join(output_folder, 'tables.json'), 'w', encoding='utf-8') as json_file:
with open(
os.path.join(output_folder, "tables.json"), "w", encoding="utf-8"
) as json_file:
json.dump(spider_table_metadata, json_file, ensure_ascii=False, indent=4)

text2sql_pairs = self._parse_text2sql_pairs(
input_folder = input_folder
)
text2sql_pairs = self._parse_text2sql_pairs(input_folder=input_folder)

metadata_with_tokens = self._parse_tokens(
text2sql_pairs = text2sql_pairs
)
metadata_with_tokens = self._parse_tokens(text2sql_pairs=text2sql_pairs)

# split metadata into training valid and test data
train_data, remain_data = train_test_split(
metadata_with_tokens,
test_size = (1 - training_ratio),
random_state = 42
metadata_with_tokens, test_size=(1 - training_ratio), random_state=42
)

if test_ratio > 0:
valid_data, test_data = train_test_split(
remain_data,
test_size = test_ratio / (valid_ratio + test_ratio),
random_state = 42
remain_data,
test_size=test_ratio / (valid_ratio + test_ratio),
random_state=42,
)

with open(
os.path.join(output_folder, 'train.json'),
'w',
encoding='utf-8'
os.path.join(output_folder, "train.json"), "w", encoding="utf-8"
) as json_file:
json.dump(train_data, json_file, ensure_ascii=False, indent=4)

with open(
os.path.join(output_folder, 'valid.json'),
'w',
encoding='utf-8'
os.path.join(output_folder, "valid.json"), "w", encoding="utf-8"
) as json_file:
json.dump(valid_data, json_file, ensure_ascii=False, indent=4)

with open(
os.path.join(output_folder, 'test.json'),
'w',
encoding='utf-8'
os.path.join(output_folder, "test.json"), "w", encoding="utf-8"
) as json_file:
json.dump(test_data, json_file, ensure_ascii=False, indent=4)
else:
with open(
os.path.join(output_folder, 'train.json'),
'w',
encoding='utf-8'
os.path.join(output_folder, "train.json"), "w", encoding="utf-8"
) as json_file:
json.dump(train_data, json_file, ensure_ascii=False, indent=4)

with open(
os.path.join(output_folder, 'valid.json'),
'w',
encoding='utf-8'
os.path.join(output_folder, "valid.json"), "w", encoding="utf-8"
) as json_file:
json.dump(remain_data, json_file, ensure_ascii=False, indent=4)


def _parse_text2sql_pairs(
self,
input_folder: Optional[str] = None,
@@ -123,35 +104,35 @@ def _parse_text2sql_pairs(

# Iterate through all files in the folder
for filename in os.listdir(input_folder):
if filename.endswith('.csv'):
if filename.endswith(".csv"):
file_path = os.path.join(input_folder, filename)

# Open and read the CSV file
with open(file_path, 'r', encoding='utf-8') as csv_file:
with open(file_path, "r", encoding="utf-8") as csv_file:
csv_reader = csv.DictReader(csv_file)

# Iterate through each row in the CSV file
for row in csv_reader:
db_name = row['db_name']
nl_question = row['nl_question']
query = row['query']
db_name = row["db_name"]
nl_question = row["nl_question"]
query = row["query"]

# Create a dictionary for the current row and append it to the list
data_list.append({'db_id': db_name, 'question': nl_question, 'query': query})
data_list.append(
{"db_id": db_name, "question": nl_question, "query": query}
)

# Now, data_list contains all the rows from all CSV files as dictionaries
return data_list

def _parse_tokens(
self,
text2sql_pairs: Optional[Dict] = None
) -> Any:

def _parse_tokens(self, text2sql_pairs: Optional[Dict] = None) -> Any:
"""
TO DO:
need to parse SQL and NL into list of tokens
need to parse SQL and NL into list of tokens
"""
pass


if __name__ == "__main__":
data_processor = SpecialScenarioDataProcessor()
data_processor.generate_spider_nl2sql_metadata()
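
The split in generate_spider_nl2sql_metadata is two-staged: the training portion is carved out first, then the remainder is divided into valid/test with test_size = test_ratio / (valid_ratio + test_ratio). A standalone sketch of the same logic on toy data, assuming scikit-learn is installed:

    from sklearn.model_selection import train_test_split

    records = list(range(100))
    training_ratio, valid_ratio, test_ratio = 0.7, 0.2, 0.1

    # Stage 1: take the training portion.
    train, remain = train_test_split(records, test_size=1 - training_ratio, random_state=42)
    # Stage 2: split what remains into valid and test.
    valid, test = train_test_split(
        remain, test_size=test_ratio / (valid_ratio + test_ratio), random_state=42
    )
    print(len(train), len(valid), len(test))  # roughly 70 / 20 / 10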
3 changes: 2 additions & 1 deletion dbgpt_hub/data_process/sql_data_process.py
@@ -15,6 +15,7 @@
INSTRUCTION_PROMPT,
)


class ProcessSqlData:
def __init__(self, train_file=None, dev_file=None) -> None:
self.train_file = train_file
@@ -165,4 +166,4 @@ def create_sft_raw_data(self):
precess = ProcessSqlData(
train_file=all_in_one_train_file, dev_file=all_in_one_dev_file
)
precess.create_sft_raw_data()
precess.create_sft_raw_data()