diff --git a/.gitignore b/.gitignore index b293180..aa4f887 100644 --- a/.gitignore +++ b/.gitignore @@ -184,4 +184,4 @@ cython_debug/ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ +.idea/ diff --git a/dbgpt_hub/eval/__init__.py b/dbgpt_hub/eval/__init__.py index e69de29..b38ecaf 100644 --- a/dbgpt_hub/eval/__init__.py +++ b/dbgpt_hub/eval/__init__.py @@ -0,0 +1,8 @@ +""" +dbgpt_hub.eval +============== +""" + +from .evaluation_api import start_evaluate + +__all__ = ["start_evaluate"] diff --git a/dbgpt_hub/eval/evaluation.py b/dbgpt_hub/eval/evaluation.py index ecf374f..a96c7f9 100644 --- a/dbgpt_hub/eval/evaluation.py +++ b/dbgpt_hub/eval/evaluation.py @@ -12,6 +12,7 @@ import subprocess import json +from typing import Optional, Dict, Any from process_sql import get_schema, Schema, get_sql from exec_eval import eval_exec_match from func_timeout import func_timeout, FunctionTimedOut @@ -1152,6 +1153,42 @@ def build_foreign_key_map_from_json(table): return tables +def evaluate_api(args: Optional[Dict[str, Any]] = None): + # Prepare output file path by appending "2sql" before ".txt" if --natsql is true + if args["natsql"]: + pred_file_path = ( + args["input"].rsplit(".", 1)[0] + "2sql." + args["input"].rsplit(".", 1)[1] + ) + gold_file_path = args["gold_natsql"] + table_info_path = args["table_natsql"] + else: + pred_file_path = args["input"] + gold_file_path = args["gold"] + table_info_path = args["table"] + + # only evaluating exact match needs this argument + kmaps = None + if args["etype"] in ["all", "match"]: + assert ( + args.table is not None + ), "table argument must be non-None if exact set match is evaluated" + kmaps = build_foreign_key_map_from_json(args["table"]) + + # Print args + print(f"params as fllows \n {args}") + + evaluate( + gold_file_path, + pred_file_path, + args["db"], + args["etype"], + kmaps, + args["plug_value"], + args["keep_distinct"], + args["progress_bar_for_each_datapoint"], + ) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( diff --git a/dbgpt_hub/eval/evaluation_api.py b/dbgpt_hub/eval/evaluation_api.py new file mode 100644 index 0000000..77ab00f --- /dev/null +++ b/dbgpt_hub/eval/evaluation_api.py @@ -0,0 +1,32 @@ +from typing import Optional, Dict, Any + +from dbgpt_hub.eval import evaluation + + +def start_evaluate( + args: Optional[Dict[str, Any]] = None, +): + # Arguments for evaluation + if args is None: + args = { + "input": "./dbgpt_hub/output/pred/pred_sql_dev_skeleton.sql", + "gold": "./dbgpt_hub/data/eval_data/gold.txt", + "gold_natsql": "./dbgpt_hub/data/eval_data/gold_natsql2sql.txt", + "db": "./dbgpt_hub/data/spider/database", + "table": "./dbgpt_hub/data/eval_data/tables.json", + "table_natsql": "./dbgpt_hub/data/eval_data/tables_for_natsql2sql.json", + "etype": "exec", + "plug_value": True, + "keep_distict": False, + "progress_bar_for_each_datapoint": False, + "natsql": False, + } + else: + args = args + + # Execute evaluation + evaluation.evaluate_api(args) + + +if __name__ == "__main__": + start_evaluate() diff --git a/dbgpt_hub/predict/__init__.py b/dbgpt_hub/predict/__init__.py index e69de29..d9cb30e 100644 --- a/dbgpt_hub/predict/__init__.py +++ b/dbgpt_hub/predict/__init__.py @@ -0,0 +1,8 @@ +""" +dbgpt_hub.predict +============== +""" + +from .predict_api import start_predict + +__all__ = ["start_predict"] diff --git a/dbgpt_hub/predict/predict.py b/dbgpt_hub/predict/predict.py index 6d67722..e21a120 100644 --- a/dbgpt_hub/predict/predict.py +++ b/dbgpt_hub/predict/predict.py @@ -1,23 +1,21 @@ import os import json import sys -from tqdm import tqdm ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(ROOT_PATH) -from typing import List, Dict + +from tqdm import tqdm +from typing import List, Dict, Optional, Any from dbgpt_hub.data_process.data_utils import extract_sql_prompt_dataset from dbgpt_hub.llm_base.chat_model import ChatModel -from dbgpt_hub.configs.config import ( - PREDICTED_DATA_PATH, - OUT_DIR, - PREDICTED_OUT_FILENAME, -) -def prepare_dataset() -> List[Dict]: - with open(PREDICTED_DATA_PATH, "r") as fp: +def prepare_dataset( + predict_file_path: Optional[str] = None, +) -> List[Dict]: + with open(predict_file_path, "r") as fp: data = json.load(fp) predict_data = [extract_sql_prompt_dataset(item) for item in data] return predict_data @@ -33,21 +31,34 @@ def inference(model: ChatModel, predict_data: List[Dict], **input_kwargs): return res -def main(): - predict_data = prepare_dataset() +def predict(args: Optional[Dict[str, Any]] = None): + predict_file_path = "" + if args is None: + predict_file_path = os.path.join( + ROOT_PATH, "dbgpt_hub/data/eval_data/dev_sql.json" + ) + predict_out_dir = os.path.join( + os.path.join(ROOT_PATH, "dbgpt_hub/output/"), "pred" + ) + if not os.path.exists(predict_out_dir): + os.mkdir(predict_out_dir) + predict_output_filename = os.path.join(predict_out_dir, "pred_sql.sql") + print(f"predict_output_filename \t{predict_output_filename}") + else: + predict_file_path = os.path.join(ROOT_PATH, args["predict_file_path"]) + predict_out_dir = os.path.join( + os.path.join(ROOT_PATH, args["predict_out_dir"]), "pred" + ) + if not os.path.exists(predict_out_dir): + os.mkdir(predict_out_dir) + predict_output_filename = os.path.join(predict_out_dir, args["pred_sql.sql"]) + print(f"predict_output_filename \t{predict_output_filename}") + + predict_data = prepare_dataset(predict_file_path=predict_file_path) model = ChatModel() result = inference(model, predict_data) - predict_out_dir = os.path.join(OUT_DIR, "pred") - if not os.path.exists(predict_out_dir): - os.mkdir(predict_out_dir) - - predict_output_dir_name = os.path.join( - predict_out_dir, model.data_args.predicted_out_filename - ) - print(f"predict_output_dir_name \t{predict_output_dir_name}") - - with open(predict_output_dir_name, "w") as f: + with open(predict_output_filename, "w") as f: for p in result: try: f.write(p.replace("\n", " ") + "\n") @@ -56,4 +67,4 @@ def main(): if __name__ == "__main__": - main() + predict() diff --git a/dbgpt_hub/predict/predict_api.py b/dbgpt_hub/predict/predict_api.py new file mode 100644 index 0000000..11d916d --- /dev/null +++ b/dbgpt_hub/predict/predict_api.py @@ -0,0 +1,31 @@ +import os +from dbgpt_hub.predict import predict +from typing import Optional, Dict, Any + + +def start_predict( + args: Optional[Dict[str, Any]] = None, cuda_visible_devices: Optional[str] = "0" +): + # Setting CUDA Device + os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices + + # Default Arguments + if args is None: + args = { + "model_name_or_path": "codellama/CodeLlama-13b-Instruct-hf", + "template": "llama2", + "finetuning_type": "lora", + "checkpoint_dir": "dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora", + "predict_file_path": "dbgpt_hub/data/eval_data/dev_sql.json", + "predict_out_dir": "dbgpt_hub/output/", + "predicted_out_filename": "pred_sql.sql", + } + else: + args = args + + # Execute prediction + predict.predict(args) + + +if __name__ == "__main__": + start_predict() diff --git a/dbgpt_hub/train/__init__.py b/dbgpt_hub/train/__init__.py index e69de29..4b57815 100644 --- a/dbgpt_hub/train/__init__.py +++ b/dbgpt_hub/train/__init__.py @@ -0,0 +1,8 @@ +""" +dbgpt_hub.train +============== +""" + +from .sft_train_api import start_sft + +__all__ = ["start_sft"] diff --git a/dbgpt_hub/train/sft_train_api.py b/dbgpt_hub/train/sft_train_api.py new file mode 100644 index 0000000..39e207b --- /dev/null +++ b/dbgpt_hub/train/sft_train_api.py @@ -0,0 +1,47 @@ +import os + +from typing import Optional, Dict, Any +from dbgpt_hub.train import sft_train + + +def start_sft( + args: Optional[Dict[str, Any]] = None, cuda_visible_devices: Optional[str] = "0" +): + # Setting CUDA Device + os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices + + # Default Arguments + if args is None: + args = { + "model_name_or_path": "codellama/CodeLlama-13b-Instruct-hf", + "do_train": True, + "dataset": "example_text2sql_train", + "max_source_length": 2048, + "max_target_length": 512, + "finetuning_type": "lora", + "lora_target": "q_proj,v_proj", + "template": "llama2", + "lora_rank": 64, + "lora_alpha": 32, + "output_dir": "dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora", + "overwrite_cache": True, + "overwrite_output_dir": True, + "per_device_train_batch_size": 1, + "gradient_accumulation_steps": 16, + "lr_scheduler_type": "cosine_with_restarts", + "logging_steps": 50, + "save_steps": 2000, + "learning_rate": 2e-4, + "num_train_epochs": 8, + "plot_loss": True, + "bf16": True, + } + else: + args = args + + # Run SFT + sft_train.train(args) + + +if __name__ == "__main__": + start_sft()