add style checking workflow and reformat some files
qidanrui committed Nov 25, 2023
1 parent 28dd6a9 commit 10aaf9d
Showing 8 changed files with 195 additions and 107 deletions.
56 changes: 56 additions & 0 deletions .github/workflows/ci.yml
@@ -0,0 +1,56 @@
name: CI

on:
  push:
    branches:
      - main
      - release
  pull_request:
    branches:
      - main

jobs:
  build:
    name: ${{ matrix.os }} x py${{ matrix.python }}
    strategy:
      fail-fast: false
      matrix:
        python: ["3.10", "3.11"]
        os: [ubuntu-latest, windows-latest]
    runs-on: ${{ matrix.os }}
    steps:
      - uses: actions/checkout@v2

      - uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python }}

      - name: Cache venv
        uses: actions/cache@v2
        with:
          path: ~/.cache/pypoetry/virtualenvs
          key: ${{ runner.os }}-build-${{ matrix.python }}-${{ secrets.CACHE_VERSION }}-${{ hashFiles('poetry.lock') }}

      - name: Install dependencies
        run: |
          echo "Cache Version ${{ secrets.CACHE_VERSION }}"
          pip install poetry
          poetry install
          poetry run pip install ERAlchemy
          poetry config --list
      - name: Print tool versions
        run: |
          poetry run black --version
      - name: Check if the code is formatted
        run: poetry run black --check dataprep

      # - name: Type check the project
      #   run: poetry run pyright dataprep || true

      # - name: Style check the project
      #   run: poetry run pylint dataprep || true

      - name: Build binary dependencies
        run: poetry build
56 changes: 56 additions & 0 deletions .github/workflows/release.yml
@@ -0,0 +1,56 @@
name: release

on:
  push:
    branches:
      - release

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
        with:
          fetch-depth: "0"

      - uses: actions/setup-python@v1
        with:
          python-version: ${{ matrix.python }}

      - name: Cache venv
        uses: actions/cache@v2
        with:
          path: ~/.cache/pypoetry/virtualenvs
          key: ${{ runner.os }}-build-${{ matrix.python }}-${{ secrets.CACHE_VERSION }}-${{ hashFiles('poetry.lock') }}

      - name: Install dependencies
        run: |
          echo "Cache Version ${{ secrets.CACHE_VERSION }}"
          pip install poetry toml-cli
          poetry install
          poetry config --list
      - name: Print tool versions
        run: |
          poetry run black --version
      - name: Build wheels
        run: poetry build

      - name: Parse version from pyproject.toml
        run: echo "DBGPT_HUB_VERSION=`toml get --toml-path pyproject.toml tool.poetry.version`" >> $GITHUB_ENV

      - name: Create release note
        run: poetry run python scripts/release-note.py $(git rev-parse --short HEAD) > RELEASE.md

      - uses: ncipollo/release-action@v1
        with:
          artifacts: "dist/*.whl,dist/*.tar.gz"
          bodyFile: "RELEASE.md"
          token: ${{ secrets.GITHUB_TOKEN }}
          draft: true
          tag: v${{ env.DBGPT_HUB_VERSION }}
          commit: ${{ env.GITHUB_SHA }}

      - name: Upload wheels
        run: poetry publish --username __token__ --password ${{ secrets.PYPI_TOKEN }}
4 changes: 1 addition & 3 deletions dbgpt_hub/data_process/__init__.py
@@ -5,6 +5,4 @@

from .sql_data_process_api import preprocess_sft_data

__all__ = [
"preprocess_sft_data"
]
__all__ = ["preprocess_sft_data"]
16 changes: 7 additions & 9 deletions dbgpt_hub/data_process/connectors/anydb_connector.py
@@ -1,12 +1,13 @@
from typing import Any
from .mysql_connector import MySQLConnector


class AnyDBConnector(BaseConnector):
def __init__(
self,
db_type: str = "mysql",
host: str ="127.0.0.1",
port: str =3306,
host: str = "127.0.0.1",
port: str = 3306,
user: Optional[str] = None,
passwd: Optional[str] = None,
db: Optional[str] = None,
@@ -15,15 +16,12 @@ def __init__(
**kwargs
) -> Any:
super().__init__(db_type, host, port, user, passwd, db, charset, args, kwargs)
if self.db_type == 'mysql':
if self.db_type == "mysql":
self.connector = MySQLConnector(
host=self.host,
port=self.port,
user=self.user,
passwd=self.passwd
host=self.host, port=self.port, user=self.user, passwd=self.passwd
)
"""TO DO: postgres, bigquery, etc."""

def __del__(self) -> Any:
super().__del__()

@@ -51,7 +49,7 @@ def get_version(self, args=None):

def get_all_table_metadata(self, args=None):
"""查询所有表的元数据信息"""
return self.connector.get_all_table_metadata(args)
return self.connector.get_all_table_metadata(args)

def get_table_metadata(self, db, table, args=None):
"""查询指定表的元数据信息"""
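For orientation, a hypothetical usage sketch of the connector changed in this hunk; the import path simply mirrors the file path above, and the credentials are placeholders rather than values from the commit:

    # Hypothetical sketch only: instantiate the MySQL-backed connector with the
    # defaults visible in the hunk above, then fetch table metadata.
    from dbgpt_hub.data_process.connectors.anydb_connector import AnyDBConnector

    connector = AnyDBConnector(
        db_type="mysql",
        host="127.0.0.1",
        port=3306,
        user="root",       # placeholder credential
        passwd="secret",   # placeholder credential
    )

    # get_all_table_metadata(args=None) is delegated to the underlying MySQLConnector.
    all_tables = connector.get_all_table_metadata()
    print(all_tables)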
95 changes: 38 additions & 57 deletions dbgpt_hub/data_process/specialized_scenario_data_processing.py
@@ -13,18 +13,14 @@ class SpecialScenarioDataProcessor(object):
def __init__(
self,
db_type: str = "mysql",
host: str ="127.0.0.1",
port: str =3306,
host: str = "127.0.0.1",
port: str = 3306,
user: Optional[str] = None,
passwd: Optional[str] = None,
) -> Any:
self.db_type = db_type
self.connector = AnyDBConnector(
db_type = db_type,
host = host,
port = port,
user = user,
passwd = passwd
db_type=db_type, host=host, port=port, user=user, passwd=passwd
)

def generate_spider_nl2sql_metadata(
@@ -33,81 +29,66 @@ def generate_spider_nl2sql_metadata(
output_folder: str = "",
training_ratio: float = 0.7,
valid_ratio: float = 0.3,
test_ratio: float = 0.0
test_ratio: float = 0.0,
):
# Generate table.json file according to Database Info
"""
TO DO:
TO DO:
1. need to solve config data
"""
table_meta_data_processor = TableMetaDataProcessor(
conn = self.connector,
config_data = None
conn=self.connector, config_data=None
)
spider_table_metadata = \
spider_table_metadata = (
table_meta_data_processor.generate_spider_table_metadata()
)

with open(os.path.join(output_folder, 'tables.json'), 'w', encoding='utf-8') as json_file:
with open(
os.path.join(output_folder, "tables.json"), "w", encoding="utf-8"
) as json_file:
json.dump(spider_table_metadata, json_file, ensure_ascii=False, indent=4)

text2sql_pairs = self._parse_text2sql_pairs(
input_folder = input_folder
)
text2sql_pairs = self._parse_text2sql_pairs(input_folder=input_folder)

metadata_with_tokens = self._parse_tokens(
text2sql_pairs = text2sql_pairs
)
metadata_with_tokens = self._parse_tokens(text2sql_pairs=text2sql_pairs)

# split metadata into training valid and test data
train_data, remain_data = train_test_split(
metadata_with_tokens,
test_size = (1 - training_ratio),
random_state = 42
metadata_with_tokens, test_size=(1 - training_ratio), random_state=42
)

if test_ratio > 0:
valid_data, test_data = train_test_split(
remain_data,
test_size = test_ratio / (valid_ratio + test_ratio),
random_state = 42
remain_data,
test_size=test_ratio / (valid_ratio + test_ratio),
random_state=42,
)

with open(
os.path.join(output_folder, 'train.json'),
'w',
encoding='utf-8'
os.path.join(output_folder, "train.json"), "w", encoding="utf-8"
) as json_file:
json.dump(train_data, json_file, ensure_ascii=False, indent=4)

with open(
os.path.join(output_folder, 'valid.json'),
'w',
encoding='utf-8'
os.path.join(output_folder, "valid.json"), "w", encoding="utf-8"
) as json_file:
json.dump(valid_data, json_file, ensure_ascii=False, indent=4)

with open(
os.path.join(output_folder, 'test.json'),
'w',
encoding='utf-8'
os.path.join(output_folder, "test.json"), "w", encoding="utf-8"
) as json_file:
json.dump(test_data, json_file, ensure_ascii=False, indent=4)
else:
with open(
os.path.join(output_folder, 'train.json'),
'w',
encoding='utf-8'
os.path.join(output_folder, "train.json"), "w", encoding="utf-8"
) as json_file:
json.dump(train_data, json_file, ensure_ascii=False, indent=4)

with open(
os.path.join(output_folder, 'valid.json'),
'w',
encoding='utf-8'
os.path.join(output_folder, "valid.json"), "w", encoding="utf-8"
) as json_file:
json.dump(remain_data, json_file, ensure_ascii=False, indent=4)


def _parse_text2sql_pairs(
self,
input_folder: Optional[str] = None,
Expand All @@ -123,35 +104,35 @@ def _parse_text2sql_pairs(

# Iterate through all files in the folder
for filename in os.listdir(input_folder):
if filename.endswith('.csv'):
if filename.endswith(".csv"):
file_path = os.path.join(input_folder, filename)

# Open and read the CSV file
with open(file_path, 'r', encoding='utf-8') as csv_file:
with open(file_path, "r", encoding="utf-8") as csv_file:
csv_reader = csv.DictReader(csv_file)

# Iterate through each row in the CSV file
for row in csv_reader:
db_name = row['db_name']
nl_question = row['nl_question']
query = row['query']
db_name = row["db_name"]
nl_question = row["nl_question"]
query = row["query"]

# Create a dictionary for the current row and append it to the list
data_list.append({'db_id': db_name, 'question': nl_question, 'query': query})
data_list.append(
{"db_id": db_name, "question": nl_question, "query": query}
)

# Now, data_list contains all the rows from all CSV files as dictionaries
return data_list

def _parse_tokens(
self,
text2sql_pairs: Optional[Dict] = None
) -> Any:

def _parse_tokens(self, text2sql_pairs: Optional[Dict] = None) -> Any:
"""
TO DO:
need to parse SQL and NL into list of tokens
need to parse SQL and NL into list of tokens
"""
pass


if __name__ == "__main__":
data_processor = SpecialScenarioDataProcessor()
data_processor.generate_spider_nl2sql_metadata()
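For reference, a minimal standalone sketch of the two-stage train/valid/test split that generate_spider_nl2sql_metadata performs above; the sample records and ratios are illustrative assumptions, not part of the commit:

    # Illustrative sketch of the split logic above, assuming scikit-learn is installed.
    from sklearn.model_selection import train_test_split

    # Placeholder records standing in for the parsed text-to-SQL metadata.
    metadata = [
        {"db_id": f"db_{i}", "question": f"question {i}", "query": f"SELECT {i}"}
        for i in range(100)
    ]

    training_ratio, valid_ratio, test_ratio = 0.7, 0.2, 0.1

    # First split: carve out the training portion.
    train_data, remain_data = train_test_split(
        metadata, test_size=(1 - training_ratio), random_state=42
    )

    # Second split: divide the remainder between valid and test when a test split is requested.
    if test_ratio > 0:
        valid_data, test_data = train_test_split(
            remain_data,
            test_size=test_ratio / (valid_ratio + test_ratio),
            random_state=42,
        )
    else:
        valid_data, test_data = remain_data, []

    print(len(train_data), len(valid_data), len(test_data))  # roughly 70 / 20 / 10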
3 changes: 2 additions & 1 deletion dbgpt_hub/data_process/sql_data_process.py
@@ -15,6 +15,7 @@
INSTRUCTION_PROMPT,
)


class ProcessSqlData:
def __init__(self, train_file=None, dev_file=None) -> None:
self.train_file = train_file
@@ -165,4 +166,4 @@ def create_sft_raw_data(self):
precess = ProcessSqlData(
train_file=all_in_one_train_file, dev_file=all_in_one_dev_file
)
precess.create_sft_raw_data()
precess.create_sft_raw_data()