diff --git a/dbgpt_hub/configs/data_args.py b/dbgpt_hub/configs/data_args.py
index c861723..9281e44 100644
--- a/dbgpt_hub/configs/data_args.py
+++ b/dbgpt_hub/configs/data_args.py
@@ -156,6 +156,13 @@ class DataArguments:
             "help": "Size of the development set, should be an integer or a float in range `[0,1)`."
         },
     )
+    predicted_out_filename: Optional[str] = field(
+        default="pred_sql.sql",
+        metadata={
+            "help": "Filename to save predicted outcomes"
+        },
+    )
+
     def init_for_training(self):
         # support mixing multiple datasets
         dataset_names = [ds.strip() for ds in self.dataset.split(",")]
diff --git a/dbgpt_hub/llm_base/chat_model.py b/dbgpt_hub/llm_base/chat_model.py
index b9fc15c..1d14384 100644
--- a/dbgpt_hub/llm_base/chat_model.py
+++ b/dbgpt_hub/llm_base/chat_model.py
@@ -16,12 +16,12 @@ class ChatModel:
     def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
-        model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
+        model_args, self.data_args, finetuning_args, self.generating_args = get_infer_args(args)
         self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
         self.tokenizer.padding_side = "left"
         self.model = dispatch_model(self.model)
-        self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
-        self.system_prompt = data_args.system_prompt
+        self.template = get_template_and_fix_tokenizer(self.data_args.template, self.tokenizer)
+        self.system_prompt = self.data_args.system_prompt
 
     def process_args(
         self,
diff --git a/dbgpt_hub/predict/predict.py b/dbgpt_hub/predict/predict.py
index d7da933..4c89a30 100644
--- a/dbgpt_hub/predict/predict.py
+++ b/dbgpt_hub/predict/predict.py
@@ -18,8 +18,8 @@ def prepare_dataset() -> List[Dict]:
 def inference(model: ChatModel, predict_data: List[Dict], **input_kwargs):
     res = []
     #test
-    for item in predict_data[:20]:
-    # for item in predict_data:
+    # for item in predict_data[:20]:
+    for item in predict_data:
         response, _ = model.chat(
             query=item["input"],
             history=[],
@@ -36,7 +36,7 @@ def main():
     if not os.path.exists(predict_out_dir):
         os.mkdir(predict_out_dir)
 
-    predict_output_dir_name = os.path.join(predict_out_dir, PREDICTED_OUT_FILENAME)
+    predict_output_dir_name = os.path.join(predict_out_dir, model.data_args.predicted_out_filename)
     print(f"predict_output_dir_name \t{predict_output_dir_name}")
 
     with open(predict_output_dir_name, "w") as f:
diff --git a/dbgpt_hub/scripts/predict_sft.sh b/dbgpt_hub/scripts/predict_sft.sh
index 953e747..5f2c4c9 100644
--- a/dbgpt_hub/scripts/predict_sft.sh
+++ b/dbgpt_hub/scripts/predict_sft.sh
@@ -1,17 +1,18 @@
 ## shijian llama2 test
-# CUDA_VISIBLE_DEVICES=0,1 python dbgpt_hub/predict/predict.py \
-#     --model_name_or_path Llama-2-13b-chat-hf \
-#     --template llama2 \
-#     --finetuning_type lora \
-#     --checkpoint_dir dbgpt_hub/output/adapter
+CUDA_VISIBLE_DEVICES=0,1 python dbgpt_hub/predict/predict.py \
+    --model_name_or_path Llama-2-13b-chat-hf \
+    --template llama2 \
+    --finetuning_type lora \
+    --checkpoint_dir dbgpt_hub/output/adapter/llama2-13b-qlora \
+    --predicted_out_filename pred_sql.sql
 
 # # wangzai baichua2_eval test
-CUDA_VISIBLE_DEVICES=0 python dbgpt_hub/predict/predict.py \
-    --model_name_or_path /home/model/Baichuan2-13B-Chat \
-    --template baichuan2_eval \
-    --quantization_bit 4 \
-    --finetuning_type lora \
-    --checkpoint_dir dbgpt_hub/output/adapter/baichuan2-13b-qlora
+# CUDA_VISIBLE_DEVICES=0 python dbgpt_hub/predict/predict.py \
+#     --model_name_or_path /home/model/Baichuan2-13B-Chat \
+#     --template baichuan2_eval \
+#     --quantization_bit 4 \
+#     --finetuning_type lora \
+#     --checkpoint_dir dbgpt_hub/output/adapter/baichuan2-13b-qlora
 
 ## wangzai codellama2_pred test a100
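
For reference, a minimal sketch of the mechanism this patch relies on: a dataclass field declared with a `default` and a `metadata["help"]` entry is exposed as a CLI flag once the arguments are parsed, which is why the shell script can pass `--predicted_out_filename` without any extra parsing code. The sketch assumes the pipeline behind `get_infer_args` uses `transformers.HfArgumentParser`, as LLaMA-Factory-derived code typically does; `DemoArgs` is a hypothetical stand-in, not DB-GPT-Hub source.

```python
# Hypothetical, self-contained demo -- not DB-GPT-Hub source. Assumes
# get_infer_args ultimately parses CLI flags via transformers.HfArgumentParser.
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class DemoArgs:
    # Mirrors the field added to DataArguments in this patch.
    predicted_out_filename: Optional[str] = field(
        default="pred_sql.sql",
        metadata={"help": "Filename to save predicted outcomes"},
    )


if __name__ == "__main__":
    parser = HfArgumentParser(DemoArgs)
    # HfArgumentParser turns each dataclass field into a CLI flag, so
    # `--predicted_out_filename my_preds.sql` overrides the default.
    (args,) = parser.parse_args_into_dataclasses()
    print(args.predicted_out_filename)  # "pred_sql.sql" unless overridden
```

With that in place, the new flag flows from `predict_sft.sh` through `get_infer_args` into `ChatModel.data_args`, and `predict.py` reads `model.data_args.predicted_out_filename` instead of the hard-coded `PREDICTED_OUT_FILENAME` constant.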