diff --git a/dbgpt_hub/predict/predict_lora.py b/dbgpt_hub/predict/predict_lora.py deleted file mode 100644 index 9f3679a..0000000 --- a/dbgpt_hub/predict/predict_lora.py +++ /dev/null @@ -1,183 +0,0 @@ -# import re -# import os -# import torch -# import argparse -# import transformers -# from transformers import AutoTokenizer -# from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer -# from dbgpt_hub.configs import GenerationArguments, ModelInferenceArguments -# from datasets import load_dataset -# from dbgpt_hub.utils.model_utils import get_logits_processor -# from dbgpt_hub.utils.model_utils import smart_tokenizer_and_embedding_resize -# from peft import PeftModel - -# from dbgpt_hub.configs.config import OUT_DIR, MODEL_PATH, DEFAULT_FT_MODEL_NAME -# from dbgpt_hub.configs.data_args import DEFAULT_PROMPT_DICT,ALPACA_PROMPT_DICT,SQL_PROMPT_DICT - - -# model_path = os.path.join(MODEL_PATH, DEFAULT_FT_MODEL_NAME) - - -# def get_args(): -# parser = argparse.ArgumentParser() -# parser.add_argument("--base_model_name_or_path", type=str, default=model_path) -# parser.add_argument("--peft_ckpt_path", type=str, default="Your lora ckpt path") -# parser.add_argument("--input_data_json", type=str, default="dev_sql.json") -# parser.add_argument( -# "--output_name", -# type=str, -# default=OUT_DIR + "/pre_lora_8_lr_2e4_drop1e1.sql", -# ) -# return parser.parse_args() - - -# local_parser = get_args() -# # print(f"loca {local_parser.base_model_name_or_path}") - - -# def extract_sql_dataset(example): -# if example.get("input", "") != "": -# prompt_format = SQL_PROMPT_DICT["prompt_input"] -# else: -# prompt_format = SQL_PROMPT_DICT["prompt_no_input"] -# return {"input": prompt_format.format(**example)} - - -# def predict(): -# # parameters -# parser = transformers.HfArgumentParser( -# (ModelInferenceArguments, GenerationArguments) -# ) -# model_server_args, generation_args = parser.parse_args_into_dataclasses() - -# device = "cuda" if torch.cuda.is_available() else "cpu" -# print(f"Loading base model: {model_server_args.model_name_or_path}") - -# base_model = AutoModelForCausalLM.from_pretrained( -# local_parser.base_model_name_or_path, -# trust_remote_code=True, -# low_cpu_mem_usage=True, -# torch_dtype=torch.float16, -# device_map={"": 0}, -# ) - -# print(f"Loading PEFT LoRA: {local_parser.peft_ckpt_path}") -# model = PeftModel.from_pretrained(base_model, local_parser.peft_ckpt_path) - -# # args = get_args() - -# # print(f"Loading base model: {args.base_model_name_or_path}") -# # base_model = AutoModelForCausalLM.from_pretrained( -# # args.base_model_name_or_path, -# # return_dict=True, -# # torch_dtype=torch.float16, -# # trust_remote_code=True -# # ) - -# # print(f"Loading PEFT: {args.peft_model_path}") -# # model = PeftModel.from_pretrained(base_model, checkpoint_dir) -# # model.to(args.device) - -# tokenizer = AutoTokenizer.from_pretrained( -# local_parser.base_model_name_or_path, -# trust_remote_code=True, -# use_fast=False, -# ) -# if tokenizer._pad_token is None: -# smart_tokenizer_and_embedding_resize( -# special_tokens_dict=dict(pad_token="[PAD]"), -# tokenizer=tokenizer, -# model=model, -# ) -# if "llama" in model_server_args.model_name_or_path or isinstance( -# tokenizer, LlamaTokenizer -# ): -# # LLaMA tokenizer may not have correct special tokens set. -# # Check and add them if missing to prevent them from being parsed into different tokens. -# # Note that these are present in the vocabulary. -# # Note also that `model.config.pad_token_id` is 0 which corresponds to `` token. -# print("Adding special tokens.") -# tokenizer.add_special_tokens( -# { -# "eos_token": tokenizer.convert_ids_to_tokens(model.config.eos_token_id), -# "bos_token": tokenizer.convert_ids_to_tokens(model.config.bos_token_id), -# "unk_token": tokenizer.convert_ids_to_tokens( -# model.config.pad_token_id -# if model.config.pad_token_id != -1 -# else tokenizer.pad_token_id -# ), -# } -# ) -# model.config.use_cache = False -# # model.to(device) - -# # Load dataset. -# dataset = load_dataset("json", data_files=local_parser.input_data_json) -# dataset = dataset.map(extract_sql_dataset, remove_columns=["instruction"]) -# # dataset_labels = dataset["train"]["output"] -# dataset = dataset["train"]["input"] - -# result = [] -# predict_batchsize = 1 -# idx = 0 -# nums_examples = len(dataset) -# while idx < nums_examples: -# if idx + predict_batchsize < nums_examples: -# inputs = dataset[idx : idx + predict_batchsize] -# idx += predict_batchsize -# else: -# inputs = dataset[idx:nums_examples] -# idx = nums_examples -# encoded_inputs = tokenizer.batch_encode_plus( -# inputs, return_tensors="pt", padding=True, truncation=True, max_length=512 -# ) -# encoded_inputs = { -# name: tensor.to(device) for name, tensor in encoded_inputs.items() -# } -# outputs = model.generate( -# **encoded_inputs, -# **generation_args.to_dict(), -# logits_processor=get_logits_processor(), -# ) -# # ## support different type LLM -# # if re.search(r'(?i)falcon', model_path): -# # generate_kwargs = { -# # "input_ids": encoded_inputs["input_ids"], -# # "attention_mask": encoded_inputs["attention_mask"] -# # } -# # outputs = model.generate(**generate_kwargs, max_length=512) -# # elif re.search(r'(?i)llama', model_path): -# # outputs = model.generate( -# # **encoded_inputs, -# # max_new_tokens=512, -# # generation_config = training_args.generation_config, -# # logits_processor=get_logits_processor() -# # ) -# # else: -# # print("right now,not support well") - -# # support the compared format directly ,like origin inputs: \n orgin outputs labels \n predict; -# for output in outputs: -# prediction = tokenizer.decode(output, skip_special_tokens=True) -# response = re.split(r"Response:\s*", prediction)[-1] -# result.append(response) -# print(response) -# print(idx) -# # origin only predict format -# # for output in outputs: -# # prediction = tokenizer.decode(output, skip_special_tokens=True) -# # response = re.split(r"Response:\s*", prediction)[-1] -# # result.append(response.replace("\n", "")) -# return result - - -# if __name__ == "__main__": -# result = predict() - -# # Judge path exists, if not need create -# if not os.path.exists(OUT_DIR): -# os.mkdir(OUT_DIR) - -# with open(local_parser.output_name, "w") as f: -# for p in result: -# f.write(p + "\n") diff --git a/dbgpt_hub/predict/predict_no_peft_llama2_13b_hf.py b/dbgpt_hub/predict/predict_no_peft_llama2_13b_hf.py deleted file mode 100644 index 378889e..0000000 --- a/dbgpt_hub/predict/predict_no_peft_llama2_13b_hf.py +++ /dev/null @@ -1,150 +0,0 @@ -# """predict only base model ,no peft sft""" -# import re -# import os -# import torch -# import argparse -# import transformers -# from datasets import load_dataset -# from transformers import AutoTokenizer -# from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer -# from dbgpt_hub.configs import GenerationArguments, ModelInferenceArguments -# from dbgpt_hub.configs.config import MODEL_PATH, OUT_DIR, DEFAULT_FT_MODEL_NAME -# from dbgpt_hub.utils.model_utils import get_logits_processor -# from dbgpt_hub.utils.model_utils import smart_tokenizer_and_embedding_resize -# from dbgpt_hub.configs.data_args import DEFAULT_PROMPT_DICT,ALPACA_PROMPT_DICT,SQL_PROMPT_DICT - - -# def get_args(): -# parser = argparse.ArgumentParser() -# parser.add_argument( -# "--base_model_name_or_path", -# type=str, -# default=os.path.join(MODEL_PATH, DEFAULT_FT_MODEL_NAME), -# ) -# parser.add_argument("--input_data_json", type=str, default="dev_sql.json") -# parser.add_argument( -# "--output_name", -# type=str, -# default=OUT_DIR + "/predict_no_peft_llama2_13b_hf_new.sql", -# ) - -# return parser.parse_args() - - -# local_parser = get_args() - - - - -# def extract_sql_dataset(example): -# if example.get("input", "") != "": -# prompt_format = SQL_PROMPT_DICT["prompt_input"] -# else: -# prompt_format = SQL_PROMPT_DICT["prompt_no_input"] -# return {"input": prompt_format.format(**example)} - - -# def predict(): -# # parameters -# parser = transformers.HfArgumentParser( -# (ModelInferenceArguments, GenerationArguments) -# ) -# model_server_args, generation_args = parser.parse_args_into_dataclasses() - -# device = "cuda" if torch.cuda.is_available() else "cpu" -# print(f"Loading base model: {model_server_args.model_name_or_path}") - -# base_model = AutoModelForCausalLM.from_pretrained( -# local_parser.base_model_name_or_path, -# trust_remote_code=True, -# low_cpu_mem_usage=True, -# torch_dtype=torch.float16, -# device_map={"": 0}, -# ) - -# model = base_model - -# tokenizer = AutoTokenizer.from_pretrained( -# local_parser.base_model_name_or_path, -# trust_remote_code=True, -# use_fast=False, -# ) -# if tokenizer._pad_token is None: -# smart_tokenizer_and_embedding_resize( -# special_tokens_dict=dict(pad_token="[PAD]"), -# tokenizer=tokenizer, -# model=model, -# ) -# if "llama" in model_server_args.model_name_or_path or isinstance( -# tokenizer, LlamaTokenizer -# ): -# # LLaMA tokenizer may not have correct special tokens set. -# # Check and add them if missing to prevent them from being parsed into different tokens. -# # Note that these are present in the vocabulary. -# # Note also that `model.config.pad_token_id` is 0 which corresponds to `` token. -# print("Adding special tokens.") -# tokenizer.add_special_tokens( -# { -# "eos_token": tokenizer.convert_ids_to_tokens(model.config.eos_token_id), -# "bos_token": tokenizer.convert_ids_to_tokens(model.config.bos_token_id), -# "unk_token": tokenizer.convert_ids_to_tokens( -# model.config.pad_token_id -# if model.config.pad_token_id != -1 -# else tokenizer.pad_token_id -# ), -# } -# ) -# model.config.use_cache = False -# # model.to(device) - -# # Load dataset. -# dataset = load_dataset("json", data_files=local_parser.input_data_json) -# dataset = dataset.map(extract_sql_dataset, remove_columns=["instruction"]) -# dataset = dataset["train"]["input"] - -# result = [] -# predict_batchsize = 1 -# idx = 0 -# nums_examples = len(dataset) -# # if nums_examples > 6: -# # nums_examples = 6 -# print(f"just test {nums_examples} examples\n") -# while idx < nums_examples: -# if idx + predict_batchsize < nums_examples: -# inputs = dataset[idx : idx + predict_batchsize] -# idx += predict_batchsize -# else: -# inputs = dataset[idx:nums_examples] -# idx = nums_examples -# encoded_inputs = tokenizer.batch_encode_plus( -# inputs, return_tensors="pt", padding=True, truncation=True, max_length=512 -# ) -# encoded_inputs = { -# name: tensor.to(device) for name, tensor in encoded_inputs.items() -# } -# outputs = model.generate( -# **encoded_inputs, -# **generation_args.to_dict(), -# logits_processor=get_logits_processor(), -# ) - -# # support the compared format directly ,like origin inputs: \n orgin outputs labels \n predict; -# for output in outputs: -# prediction = tokenizer.decode(output, skip_special_tokens=True) -# response = re.split(r"Response:\s*", prediction)[-1] -# print("response replace \n", response.replace("\n", "")) -# result.append(response.replace("\n", "")) - -# return result - - -# if __name__ == "__main__": -# result = predict() - -# # Judge path exists, if not need create -# if not os.path.exists(OUT_DIR): -# os.mkdir(OUT_DIR) - -# with open(local_parser.output_name, "w") as f: -# for p in result: -# f.write(p + "\n") diff --git a/dbgpt_hub/predict/predict_qlora.py b/dbgpt_hub/predict/predict_qlora.py deleted file mode 100644 index 90f3e2a..0000000 --- a/dbgpt_hub/predict/predict_qlora.py +++ /dev/null @@ -1,213 +0,0 @@ -# import re -# import os -# import argparse -# import torch -# import transformers -# from transformers import AutoTokenizer -# from transformers import set_seed, Seq2SeqTrainer, GenerationConfig -# from datasets import load_dataset - -# from dbgpt_hub.configs import ( -# DataArguments, -# GenerationArguments, -# LoraArguments, -# ModelArguments, -# QuantArguments, -# TrainingArguments, -# ) - -# from dbgpt_hub.llms import get_accelerate_model -# from dbgpt_hub.configs.config import MODEL_PATH, DEFAULT_FT_MODEL_NAME, OUT_DIR -# from dbgpt_hub.configs.data_args import DEFAULT_PROMPT_DICT,ALPACA_PROMPT_DICT,SQL_PROMPT_DICT - - -# def get_args(): -# parser = argparse.ArgumentParser() -# parser.add_argument( -# "--base_model_name_or_path", -# type=str, -# default=os.path.join(MODEL_PATH, DEFAULT_FT_MODEL_NAME), -# ) -# parser.add_argument( -# "--peft_ckpt_path", type=str, default="Your peft qlora ckpt path" -# ) -# parser.add_argument("--input_data_json", type=str, default="dev_sql.json") -# parser.add_argument( -# "--output_name", type=str, default=OUT_DIR + "/qlora_8_lr_2e4_drop1e1.sql" -# ) - -# return parser.parse_args() - - -# local_parser = get_args() - - - -# def extract_alpaca_dataset(example): -# if example.get("input", "") != "": -# prompt_format = ALPACA_PROMPT_DICT["prompt_input"] -# else: -# prompt_format = ALPACA_PROMPT_DICT["prompt_no_input"] -# return {"input": prompt_format.format(**example)} - - -# def extract_sql_dataset(example): -# if example.get("input", "") != "": -# prompt_format = SQL_PROMPT_DICT["prompt_input"] -# else: -# prompt_format = SQL_PROMPT_DICT["prompt_no_input"] -# return {"input": prompt_format.format(**example)} - - -# def predict(): -# # parameters -# parser = transformers.HfArgumentParser( -# ( -# ModelArguments, -# DataArguments, -# TrainingArguments, -# LoraArguments, -# QuantArguments, -# GenerationArguments, -# ) -# ) -# ( -# model_args, -# data_args, -# training_args, -# lora_args, -# quant_args, -# generation_args, -# ) = parser.parse_args_into_dataclasses() -# # Check arguments (do not check finetuning_args since it may be loaded from checkpoints) -# # data_args.init_for_training() -# training_args.generation_config = GenerationConfig(**vars(generation_args)) -# import argparse - -# args = argparse.Namespace( -# **vars(model_args), -# **vars(data_args), -# **vars(training_args), -# **vars(lora_args), -# **vars(quant_args), -# ) - -# device = "cuda" if torch.cuda.is_available() else "cpu" -# model, tokenizer = get_accelerate_model(args, local_parser.peft_ckpt_path) -# model.config.use_cache = False - -# def format_dataset(dataset, dataset_format): -# if ( -# dataset_format == "alpaca" -# or dataset_format == "alpaca-clean" -# or (dataset_format is None and args.dataset in ["alpaca", "alpaca-clean"]) -# ): -# dataset = dataset.map( -# extract_alpaca_dataset, remove_columns=["instruction"] -# ) -# elif dataset_format == "spider": -# dataset = dataset.map(extract_sql_dataset, remove_columns=["instruction"]) -# elif dataset_format == "chip2" or ( -# dataset_format is None and args.dataset == "chip2" -# ): -# dataset = dataset.map( -# lambda x: { -# "input": x["text"].split("\n: ")[0].replace(": ", ""), -# "output": x["text"].split("\n: ")[1], -# } -# ) -# elif dataset_format == "self-instruct" or ( -# dataset_format is None and args.dataset == "self-instruct" -# ): -# for old, new in [["prompt", "input"], ["completion", "output"]]: -# dataset = dataset.rename_column(old, new) -# elif dataset_format == "hh-rlhf" or ( -# dataset_format is None and args.dataset == "hh-rlhf" -# ): -# dataset = dataset.map(lambda x: {"input": "", "output": x["chosen"]}) -# elif dataset_format == "oasst1" or ( -# dataset_format is None and args.dataset == "oasst1" -# ): -# dataset = dataset.map( -# lambda x: { -# "input": "", -# "output": x["text"], -# } -# ) -# elif dataset_format == "input-output": -# pass -# dataset = dataset.remove_columns( -# [ -# col -# for col in dataset.column_names["train"] -# if col not in ["input", "output"] -# ] -# ) -# return dataset - -# # Load dataset. -# dataset = load_dataset("json", data_files=local_parser.input_data_json) -# dataset = format_dataset(dataset, args.dataset_format) -# dataset_labels = dataset["train"]["output"] - -# dataset = dataset["train"]["input"] - -# result = [] -# idx = 0 -# predict_batchsize = 2 -# nums_examples = len(dataset) -# # if nums_examples > 6: -# # nums_examples = 6 -# print(f"just test {nums_examples} examples\n") -# while idx < nums_examples: -# if idx + predict_batchsize < nums_examples: -# inputs = dataset[idx : idx + predict_batchsize] -# idx += predict_batchsize -# else: -# inputs = dataset[idx:nums_examples] -# idx = nums_examples -# encoded_inputs = tokenizer.batch_encode_plus( -# inputs, return_tensors="pt", padding=True, truncation=True, max_length=512 -# ) -# encoded_inputs = { -# name: tensor.to(device) for name, tensor in encoded_inputs.items() -# } - -# # support different type LLM -# if re.search(r"(?i)falcon", local_parser.base_model_name_or_path): -# generate_kwargs = { -# "input_ids": encoded_inputs["input_ids"], -# "attention_mask": encoded_inputs["attention_mask"], -# } -# outputs = model.generate(**generate_kwargs, max_length=512) -# elif re.search(r"(?i)llama", local_parser.base_model_name_or_path): -# outputs = model.generate(**encoded_inputs, max_length=512) -# else: -# print("right now,not support well") - -# # support the compared format directly ,like origin inputs: \n orgin outputs labels \n predict; -# for i, output in enumerate(outputs): -# input_idx = idx - predict_batchsize + i -# prediction = tokenizer.decode(output, skip_special_tokens=True) -# response = re.split(r"Response:\s*", prediction)[-1] -# # compose_i = "origin inputs:\t" + dataset[input_idx].replace("\n", "") + "\n"+"orgin outputs labels:\t" + dataset_labels[input_idx].replace( -# # "\n", "") + "\n"+"predict outputs labels:\t" + response.replace("\n", "") -# # test -# compose_i = response.replace("\n", "") -# print(f"compos_i \t {compose_i}") -# result.append(compose_i) -# print(result) -# print(idx) -# return args.dataset, result - - -# if __name__ == "__main__": -# dataset_name, result = predict() - -# # Judge path exists, if not need create -# if not os.path.exists(OUT_DIR): -# os.mkdir(OUT_DIR) - -# with open(local_parser.output_name, "w") as f: -# for p in result: -# f.write(p + "\n") diff --git a/dbgpt_hub/predict/predict_qlora_nf4_bit4.py b/dbgpt_hub/predict/predict_qlora_nf4_bit4.py deleted file mode 100644 index 479f49b..0000000 --- a/dbgpt_hub/predict/predict_qlora_nf4_bit4.py +++ /dev/null @@ -1,218 +0,0 @@ -# import re -# import os -# import argparse -# import torch -# import transformers -# from transformers import AutoTokenizer -# from transformers import set_seed, Seq2SeqTrainer, GenerationConfig -# from datasets import load_dataset - -# from dbgpt_hub.configs import ( -# DataArguments, -# GenerationArguments, -# LoraArguments, -# ModelArguments, -# QuantArguments, -# TrainingArguments, -# ) - -# from dbgpt_hub.llm_base import get_accelerate_model -# from dbgpt_hub.configs.config import MODEL_PATH, DEFAULT_FT_MODEL_NAME, OUT_DIR -# from dbgpt_hub.configs.data_args import DEFAULT_PROMPT_DICT,ALPACA_PROMPT_DICT,SQL_PROMPT_DICT - - -# def get_args(): -# parser = argparse.ArgumentParser() -# parser.add_argument( -# "--base_model_name_or_path", -# type=str, -# # default=os.path.join("MODEL_PATH, DEFAULT_FT_MODEL_NAME") -# default="/home/model_files/codellama/CodeLlama-7b-Instruct-hf", # 这个默认值没有用上。 -# ) -# parser.add_argument( -# "--peft_ckpt_path", type=str, default="output_pred/qlora_bit4/checkpoint-500" -# ) -# parser.add_argument("--input_data_json", type=str, default="dev_sql.json") -# parser.add_argument( -# "--output_name", -# type=str, -# # default=OUT_DIR + "/qlora_64_nf4_bit4.sql" -# default="output_pred/qlora_bit4/pred/qlora_64_nf4_bit4.sql", -# ) - -# return parser.parse_args() - - -# local_parser = get_args() - - - - -# def extract_alpaca_dataset(example): -# if example.get("input", "") != "": -# prompt_format = ALPACA_PROMPT_DICT["prompt_input"] -# else: -# prompt_format = ALPACA_PROMPT_DICT["prompt_no_input"] -# return {"input": prompt_format.format(**example)} - - -# def extract_sql_dataset(example): -# if example.get("input", "") != "": -# prompt_format = SQL_PROMPT_DICT["prompt_input"] -# else: -# prompt_format = SQL_PROMPT_DICT["prompt_no_input"] -# return {"input": prompt_format.format(**example)} - - -# def predict(): -# # parameters -# parser = transformers.HfArgumentParser( -# ( -# ModelArguments, -# DataArguments, -# TrainingArguments, -# LoraArguments, -# QuantArguments, -# GenerationArguments, -# ) -# ) -# ( -# model_args, -# data_args, -# training_args, -# lora_args, -# quant_args, -# generation_args, -# ) = parser.parse_args_into_dataclasses() -# # Check arguments (do not check finetuning_args since it may be loaded from checkpoints) -# # data_args.init_for_training() -# training_args.generation_config = GenerationConfig(**vars(generation_args)) -# import argparse - -# args = argparse.Namespace( -# **vars(model_args), -# **vars(data_args), -# **vars(training_args), -# **vars(lora_args), -# **vars(quant_args), -# ) - -# device = "cuda" if torch.cuda.is_available() else "cpu" -# model, tokenizer = get_accelerate_model(args, local_parser.peft_ckpt_path) -# model.config.use_cache = False - -# def format_dataset(dataset, dataset_format): -# if ( -# dataset_format == "alpaca" -# or dataset_format == "alpaca-clean" -# or (dataset_format is None and args.dataset in ["alpaca", "alpaca-clean"]) -# ): -# dataset = dataset.map( -# extract_alpaca_dataset, remove_columns=["instruction"] -# ) -# elif dataset_format == "spider": -# dataset = dataset.map(extract_sql_dataset, remove_columns=["instruction"]) -# elif dataset_format == "chip2" or ( -# dataset_format is None and args.dataset == "chip2" -# ): -# dataset = dataset.map( -# lambda x: { -# "input": x["text"].split("\n: ")[0].replace(": ", ""), -# "output": x["text"].split("\n: ")[1], -# } -# ) -# elif dataset_format == "self-instruct" or ( -# dataset_format is None and args.dataset == "self-instruct" -# ): -# for old, new in [["prompt", "input"], ["completion", "output"]]: -# dataset = dataset.rename_column(old, new) -# elif dataset_format == "hh-rlhf" or ( -# dataset_format is None and args.dataset == "hh-rlhf" -# ): -# dataset = dataset.map(lambda x: {"input": "", "output": x["chosen"]}) -# elif dataset_format == "oasst1" or ( -# dataset_format is None and args.dataset == "oasst1" -# ): -# dataset = dataset.map( -# lambda x: { -# "input": "", -# "output": x["text"], -# } -# ) -# elif dataset_format == "input-output": -# pass -# dataset = dataset.remove_columns( -# [ -# col -# for col in dataset.column_names["train"] -# if col not in ["input", "output"] -# ] -# ) -# return dataset - -# # Load dataset. -# dataset = load_dataset("json", data_files=local_parser.input_data_json) -# dataset = format_dataset(dataset, args.dataset_format) -# dataset_labels = dataset["train"]["output"] - -# dataset = dataset["train"]["input"] - -# result = [] -# idx = 0 -# predict_batchsize = 2 -# nums_examples = len(dataset) -# # if nums_examples > 6: -# # nums_examples = 6 -# print(f"just test {nums_examples} examples\n") -# while idx < nums_examples: -# if idx + predict_batchsize < nums_examples: -# inputs = dataset[idx : idx + predict_batchsize] -# idx += predict_batchsize -# else: -# inputs = dataset[idx:nums_examples] -# idx = nums_examples -# encoded_inputs = tokenizer.batch_encode_plus( -# inputs, return_tensors="pt", padding=True, truncation=True, max_length=512 -# ) -# encoded_inputs = { -# name: tensor.to(device) for name, tensor in encoded_inputs.items() -# } - -# # support different type LLM -# if re.search(r"(?i)falcon", local_parser.base_model_name_or_path): -# generate_kwargs = { -# "input_ids": encoded_inputs["input_ids"], -# "attention_mask": encoded_inputs["attention_mask"], -# } -# outputs = model.generate(**generate_kwargs, max_length=512) -# elif re.search(r"(?i)llama", local_parser.base_model_name_or_path): -# outputs = model.generate(**encoded_inputs, max_length=512) -# else: -# print("right now,not support well") - -# # support the compared format directly ,like origin inputs: \n orgin outputs labels \n predict; -# for i, output in enumerate(outputs): -# input_idx = idx - predict_batchsize + i -# prediction = tokenizer.decode(output, skip_special_tokens=True) -# response = re.split(r"Response:\s*", prediction)[-1] -# # compose_i = "origin inputs:\t" + dataset[input_idx].replace("\n", "") + "\n"+"orgin outputs labels:\t" + dataset_labels[input_idx].replace( -# # "\n", "") + "\n"+"predict outputs labels:\t" + response.replace("\n", "") -# # test -# compose_i = response.replace("\n", "") -# print(f"compos_i \t {compose_i}") -# result.append(compose_i) -# print(result) -# print(idx) -# return args.dataset, result - - -# if __name__ == "__main__": -# dataset_name, result = predict() - -# # Judge path exists, if not need create -# if not os.path.exists(OUT_DIR): -# os.mkdir(OUT_DIR) - -# with open(local_parser.output_name, "w") as f: -# for p in result: -# f.write(p + "\n")