diff --git a/example/drag.py b/example/drag.py new file mode 100644 index 000000000000..f82b7654a174 --- /dev/null +++ b/example/drag.py @@ -0,0 +1,588 @@ + +#!/usr/bin/python +# -*- coding: UTF-8 -*- + +import os +import paddle +# os.environ['CUDA_VISIBLE_DEVICES'] = '4' +device = "gpu:1" + +paddle.set_device(device) + +# import jsonlines +# import paddlenlp.transformers +from paddlenlp.transformers import AutoTokenizer, AutoModelForCausalLM + +import random +# import torch +import numpy as np +from tqdm import tqdm +import json +import argparse +import re +from tqdm import tqdm +import time +from paddlenlp.drag.utils import PROMPT_DICT, TASK_INST, load_jsonlines, control_tokens, load_special_tokens +from paddlenlp.drag.metrics import match, accuracy + +# +seed = 633 +# os.environ['TORCH_NCCL_AVOID_RECORD_STREAMS'] = '0' + +# torch.backends.cudnn.deterministic = True +random.seed(seed) +np.random.seed(seed) +paddle.seed(seed) + + + + +model_path = '/root/zzg/self-rag-main/model/llama3-8b-instruct-paddle' +model_path = 'meta-llama/Meta-Llama-3-8B-Instruct' +# model_path = '/root/zzg/llama3-8b-instruct' +# data_path = '/data/home/scv6140/run/hs/self-rag/retrieval_lm/eval_data/popqa/popqa_part_1.jsonl' +# data_path = '/data/home/scv6140/run/hs/self-rag/retrieval_lm/eval_data/popqa/popqa_part_2.jsonl' +# data_path = '/data/home/scv6140/run/hs/self-rag/retrieval_lm/eval_data/popqa/popqa_part_3.jsonl' +# data_path = '/data/home/scv6140/run/hs/self-rag/retrieval_lm/eval_data/popqa/popqa_part_4.jsonl' +# data_path = '/root/zzg/self-rag-main/datasets/eval_data/popqa_longtail.jsonl' +# data_path = '/data/home/scv6140/run/hs/self-rag/retrieval_lm/eval_data/popqa_longtail_retrieval_20.jsonl' +# data_path = '/root/zzg/self-rag-main/datasets/eval_data/popqa_longtail.jsonl' +# data_path = '/data/home/scv6140/run/hs/self-rag/retrieval_lm/eval_data/triviaqa_retrieval.jsonl' +# data_path = '/data/home/scv6140/run/hs/self-rag/retrieval_lm/eval_data/triviaqa_test.jsonl' +# data_path = '/data/home/scv6140/run/hs/self-rag/retrieval_lm/eval_data/test.jsonl' +out_path = '/root/zzg/self-rag-main/retrieval_lm/output/tmp.json' +metric = 'match' +ndocs = 10 +show_prev = False +# Qwen = "/root/zzg/self-rag-main/model/Qwen2-0.5B-instruct-paddle" +Qwen = "/root/zzg/self-rag-main/model/Qwen2-0.5B-instruct-paddle-v1" +# Qwen = "/root/zzg/self-rag-main/model/llama3-8b-instruct-paddle" +llama3_Tokenizer = AutoTokenizer.from_pretrained(model_path) +# ---------------------------------------------------------------------------------------don't modify content before this line +wo_decoding = False +wo_anayzer = False +w_irr_fix1 = False +use_conf = False + +qwen_model = AutoModelForCausalLM.from_pretrained(Qwen) +qwen_tokenizer = AutoTokenizer.from_pretrained(Qwen) + + +def postprocess_answer_option_conditioned(answer): + for token in control_tokens: + answer = answer.replace(token, "") + + if "" in answer: + answer = answer.replace("", "") + if "\n" in answer: + answer = answer.replace("\n", "") + + if "<|endoftext|>" in answer: + answer = answer.replace("<|endoftext|>", "") + + return answer + +def extract_elements(question, max_new_tokens=125): + sys_instruction = "You are an assistant in extracting key elements from a given question." + user_instruction = "Question: " + messages = [ + {"role": "system", "content": sys_instruction}, + {"role": "user", "content": user_instruction + '\n' + question} + ] + + text = qwen_tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + model_inputs = qwen_tokenizer([text], return_tensors="pd") + + # Generate response + generated_ids = qwen_model.generate( + **model_inputs, + max_new_tokens=max_new_tokens + ) + # print(tokenizer.batch_decode(outputs[0], skip_special_tokens=True)) + + # Process the generated output + response = qwen_tokenizer.batch_decode(generated_ids[0], skip_special_tokens=True)[0] + print("ELEMENTS: " + response, flush=True) + return response + +alpha_beta = None +def process_explanation(text): + essential_scores = list(map(float, re.findall(r'\(essential\): (\d+\.\d+)', text, re.IGNORECASE))) + initial_scores = list(map(float, re.findall(r'\(initial\): (\d+\.\d+)', text, re.IGNORECASE))) + supplementary_scores = list(map(float, re.findall(r'\(supplementary\): (\d+\.\d+)', text, re.IGNORECASE))) + + if alpha_beta is not None: + alpha, beta = alpha_beta + score_max = len(essential_scores) + alpha * len(initial_scores) + beta * len(supplementary_scores) + std_max = len(essential_scores) + 0.6 * len(initial_scores) + 0.3 * len(supplementary_scores) + score = sum(essential_scores) + alpha * sum(initial_scores) + beta * sum(supplementary_scores) + # print(len(essential_scores), len(initial_scores), len(supplementary_scores), score_max) + try: + final_score = score * std_max / score_max + except: + final_score = score + else: + final_score = sum(essential_scores) + 0.6 * sum(initial_scores) + 0.3 * sum(supplementary_scores) + lex_diver = len(essential_scores) + 1.2 * len(initial_scores) + len(supplementary_scores) + return final_score, lex_diver + +def score_paragraph(question, paragraph, elements, max_new_tokens=125): + sys_instruction = "You are an assistant in scoring paragraphs based on a given question and its associated elements." + user_instruction = "Question:\n Elements;\n Paragraphs:" + + paragraph_text = f"{paragraph['title']}\n{paragraph['text']}" + + my_input = ( + f"### Question: {question}\n" + f"### Element: {elements}\n" + f"### Paragraphs: {paragraph_text}\n" + ) + + messages = [ + {"role": "system", "content": sys_instruction}, + {"role": "user", "content": user_instruction + '\n' + my_input} + ] + + text = qwen_tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + # print(text) + model_inputs = qwen_tokenizer([text], return_tensors="pd") + + # Generate response + generated_ids = qwen_model.generate( + **model_inputs, + max_new_tokens=max_new_tokens + ) + + # Process the generated output + response = qwen_tokenizer.batch_decode(generated_ids[0], skip_special_tokens=True)[0] + print("SCORES: " + response, flush=True) + score, lex_diver = process_explanation(response) + return (score, lex_diver) + +def sort_para(question, paragraphs, max_new_tokens=125): + # paragraphs = paragraphs[:1] + elements = extract_elements(question, max_new_tokens) + scored_texts = [ + (para, *score_paragraph(question, para, elements, max_new_tokens)) + for para in paragraphs + ] + lex_diver = scored_texts[0][2] if len(scored_texts) != 0 else None + scored_texts = [(para, score) for para, score, ld in scored_texts if score is not False] + sorted_texts = sorted(scored_texts, key=lambda x: x[1], reverse=True) + # return [text for text, score in sorted_texts[:5]], [text for text, score in sorted_texts[-1:]] + return sorted_texts[:5], sorted_texts[-1:], lex_diver + + +def get_score_unsort(question, paragraphs, max_new_tokens=125): + elements = extract_elements(question, max_new_tokens) + scored_texts = [ + (para, *score_paragraph(question, para, elements, max_new_tokens)) + for para in paragraphs[:5] + ] + lex_diver = scored_texts[0][2] if len(scored_texts) != 0 else None + scored_texts = scored_texts[:5] + scores = [] + for para, score, ld in scored_texts: + scores.append(score if score is not False else 0.) + return scores, lex_diver + +def my_model_rank(my_input, evidences, model, tokenizer, max_new_tokens=50): + # subquery_num, subquery = Generate_Subquery(my_input, model, max_new_tokens) + if 1: + # extracted_info = generate_Triplet(my_input, max_new_tokens) + # subject = extracted_info['subject'] + # relationship = extracted_info['relationship'] + + paddle.device.cuda.synchronize() + part1_time = time.perf_counter() + relevant_para = [] + irrelevant_para = [] + results = {} + decoding_flag = not wo_decoding + + + if not wo_anayzer: + # evidences, irr_evidences=sort_para(my_input,evidences) + top5, bottom1, lex_diver = sort_para(my_input,evidences) + evidences = [text for text, score in top5] + scores = [score for text, score in top5] + # if top5[0][1] < 1.6 and bottom1[0][1] < 1: + if 1: + irr_evidences = [text for text, score in bottom1] + else: + if decoding_flag: + # TODO + scores, lex_diver = get_score_unsort(my_input, evidences) + pass + evidences = evidences[:5] + irr_evidences = evidences[-1:] + if decoding_flag: + model.config.decode_strategy = "sampling_irr" + model.generation_config.decode_strategy = "sampling_irr" + set_para_ano( + lex_diver if not use_conf else -1, + scores + ) + + + if w_irr_fix1: + # irr_evidences=[ + # {"title": "Ethnic groups in Rwanda", "text": "divert the emphasis from ethnicity to a division of the population into categories of victim, victors, survivors, and perpetrators. However, in identifying victims and survivors, some Rwandans are left to be identified as perpetrators. This becomes increasingly problematic as all Hutus are deemed perpetrators—where their survival of the genocide seems to imply some form of complicity with the former government. Thus, in this process of rebuilding and bringing guilty parties to justice, the current government is providing dangling linkages back to the very ethnicities they wish to abolish and is risking further entrenching supposed “past” ethnic divisions. Furthermore, government policy"} + # ] + irr_evidences =[ + {"title": "Rebirth (Buddhism)", "text": "Rebirth (Buddhism) Rebirth in Buddhism refers to its teaching that the actions of a person lead to a new existence after death, in endless cycles called \"saṃsāra\". This cycle is considered to be \"dukkha\", unsatisfactory and painful. The cycle stops only if liberation is achieved by insight and the extinguishing of desire. Rebirth is one of the foundational doctrines of Buddhism, along with Karma, nirvana and moksha. The rebirth doctrine in Buddhism, sometimes referred to as reincarnation or metempsychosis, asserts that rebirth does not necessarily take place as another human being, but as an existence in one of the six"} + ] + for evidence in evidences: + relevant_para.append("[Retrieval]{0}\n{1}".format(evidence["title"], evidence["text"])) + + + paddle.device.cuda.synchronize + part1_time = time.perf_counter() - part1_time + # relevant_para = relevant_para[:5] + + paddle.device.cuda.synchronize + part2_time = time.perf_counter() + if not relevant_para: + sys_msg=PROMPT_DICT["prompt_for_combine_no_retrieval"][0]['content'] + user_msg=PROMPT_DICT["prompt_for_combine_no_retrieval"][1]['content'].format(instruction=my_input) + msgs = [ + {"role":"system","content":sys_msg}, + {"role":"user","content":user_msg} + ] + prompt = llama3_Tokenizer.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False) + else: + sys_msg=PROMPT_DICT["prompt_for_combine"][0]['content'] + user_msg=PROMPT_DICT["prompt_for_combine"][1]['content'].format(instruction=my_input,paragraphs='\n'.join(relevant_para)) + msgs = [ + {"role":"system","content":sys_msg}, + {"role":"user","content":user_msg} + ] + prompt = llama3_Tokenizer.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False) + + token_inputs = tokenizer([prompt], return_tensors="pd") + + inputs_ids_irr = None + + + if decoding_flag and lex_diver is not None: + for irr_evidence in irr_evidences: + irrelevant_para.append("[Retrieval]{0}\n{1}".format(irr_evidence["title"], irr_evidence["text"])) + sys_msg_irr = PROMPT_DICT["prompt_for_combine"][0]['content'] + user_msg_irr = PROMPT_DICT["prompt_for_combine"][1]['content'].format(instruction=my_input,paragraphs='\n'.join(irrelevant_para)) + # user_msg_irr = PROMPT_DICT["prompt_for_combine"][1]['content'].format(instruction=my_input,subject=subject,relationship=relationship,paragraphs=irrelevant_para[0]) + msgs_irr = [ + {"role":"system","content":sys_msg_irr}, + {"role":"user","content":user_msg_irr} + ] + prompt_irr = llama3_Tokenizer.apply_chat_template(msgs_irr, add_generation_prompt=True, tokenize=False) + # prompt_irr = PROMPT_DICT['prompt_no_input_retrieval'].format(paragraph='\n'.join(irrelevant_para),instruction=my_input) + token_inputs_irr = tokenizer([prompt_irr], return_tensors="pd") + inputs_ids_irr = token_inputs_irr.input_ids + + # 计算question(含)之前的token长度 + question_only_msg = PROMPT_DICT["prompt_for_combine"][1]['content'].format(instruction=my_input,paragraphs='\n') + question_msgs = [ + {"role":"system","content":sys_msg}, + {"role":"user","content":question_only_msg} + ] + question_only_prompt = llama3_Tokenizer.apply_chat_template(question_msgs, add_generation_prompt=True, tokenize=False) + question_token_inputs = tokenizer([question_only_prompt], return_tensors="pd") + question_token_len = question_token_inputs.input_ids.shape[1] + + + question_and_doc_len = [] + for i in range(len(relevant_para)): + _sys_msg = PROMPT_DICT["prompt_for_combine"][0]['content'] + _user_msg = PROMPT_DICT["prompt_for_combine"][1]['content'].format(instruction=my_input,paragraphs='\n'.join(relevant_para[:(i + 1)])) + _msgs = [ + {"role":"system","content":_sys_msg}, + {"role":"user","content":_user_msg} + ] + _prompt = llama3_Tokenizer.apply_chat_template(_msgs, add_generation_prompt=True, tokenize=False) + + _token_inputs = tokenizer([_prompt], return_tensors="pd") + + question_and_doc_len.append(_token_inputs.input_ids.shape[1]) + + print(question_token_len, question_and_doc_len, flush=True) + generated_ids = model.generate( + **token_inputs, + max_new_tokens=512, + use_irr = decoding_flag, + inputs_irr=inputs_ids_irr, + alpha=3, + question_token_len = question_token_len, + question_and_doc_len = question_and_doc_len, + ) + # generated_ids = [ + # output_ids[len(input_ids):] for input_ids, output_ids in zip(token_inputs.input_ids, generated_ids) + # ] + output_text = tokenizer.batch_decode(generated_ids[0], skip_special_tokens=True)[0] + else: + generated_ids_prev = model.generate( + **token_inputs, + max_new_tokens=512, + ) + # generated_ids_prev = [ + # output_ids[len(input_ids):] for input_ids, output_ids in zip(token_inputs.input_ids, generated_ids_prev) + # ] + output_text_prev = tokenizer.batch_decode(generated_ids_prev[0], skip_special_tokens=True)[0] + output_text = output_text_prev + + paddle.device.cuda.synchronize + part2_time = time.perf_counter() - part2_time + + print('-------------------------------PROMPT-------------------------------', flush=True) + print(prompt, flush=True) + if 'prompt_irr' in locals().keys(): + print('-----------------------------IRR_PROMPT-----------------------------', flush=True) + print(prompt_irr, flush=True) + + print(">>>>>>>>>>>>>>>OUTPUT:", output_text, flush=True) + print(f"this_time: {part1_time} | {part2_time}", flush=True) + # print("OUTPUT_TOKEN:", generated_ids or generated_ids_prev, flush=True) + + return output_text, results, part1_time, part2_time + +def process_data_evidences(demonstration, top_n): + ctx_key = "ctxs" if "ctxs" in demonstration else "top_contexts" + # prompt = PROMPT_DICT["prompt_no_input"].format_map(demonstration) + evidences = demonstration[ctx_key][:top_n] + return ctx_key, evidences + + +def preprocess_input_data(dataset, task=None): + new_data = [] + if task in TASK_INST: + instruction = TASK_INST[task] + else: + instruction = None + for item in dataset: + if task == "arc_c": + choices = item["choices"] + answer_labels = {} + for i in range(len(choices["label"])): + answer_key = choices["label"][i] + text = choices["text"][i] + if answer_key == "1": + answer_labels["A"] = text + if answer_key == "2": + answer_labels["B"] = text + if answer_key == "3": + answer_labels["C"] = text + if answer_key == "4": + answer_labels["D"] = text + if answer_key in ["A", "B", "C", "D"]: + answer_labels[answer_key] = text + + if "D" not in answer_labels: + answer_labels["D"] = "" + choices = "\nA: {0}\nB: {1}\nC: {2}\nD: {3}".format( + answer_labels["A"], answer_labels["B"], answer_labels["C"], answer_labels["D"]) + if "E" in answer_labels: + choices += "\nE: {}".format(answer_labels["E"]) + item["instruction"] = instruction + \ + "\n\n### Input:\n" + item["question"] + choices + item["answers"] = [item["answerKey"]] + else: + prompt = instruction + "\n\n## Input:\n\n" + \ + item["question"] if instruction is not None else item["question"] + item["instruction"] = prompt + new_data.append(item) + + return new_data + + +from paddlenlp.drag.drag_generate import set_para, set_para_ano, Drag +# def set_para(*args): + # pass +# def set_para_ano(*args): + # pass + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--model_name', type=str) + parser.add_argument('--input_file', type=str, default='/data/home/scv6140/run/hs/self-rag/retrieval_lm/eval_data/triviaqa/triviaqa_part_1.jsonl') + parser.add_argument('--output_file', type=str) + parser.add_argument('--task', type=str) + parser.add_argument('--device', type=str, default="cuda") + parser.add_argument('--max_new_tokens', type=int, default=15) + parser.add_argument('--tokenizer_path', type=str) + parser.add_argument('--download_dir', type=str, help="specify vllm model download dir", + default=".cache") + parser.add_argument("--ndocs", type=int, default=10, + help="Number of documents to retrieve per questions") + parser.add_argument("--world_size", type=int, default=1, + help="world size to use multiple GPUs.") + parser.add_argument("--dtype", type=str, default="half", + help="We use bfloat16 for training. If you run inference on GPUs that do not support BF16, please set this to be `half`.") + # Decoding hyperparams + parser.add_argument('--threshold', type=float, + default=None, help="Adaptive threshold.") + parser.add_argument("--use_seqscore", action="store_true") + parser.add_argument("--use_groundness", action="store_true", + help="use ground score") + parser.add_argument( + "--use_utility", action="store_true", help="tree search") + parser.add_argument("--beam_width", type=int, + default=2, help="beam search width") + parser.add_argument("--max_depth", type=int, + default=2, help="tree depth width") + parser.add_argument("--w_rel", type=float, default=1.0, + help="reward weight for document relevance") + parser.add_argument("--w_sup", type=float, default=1.0, + help="reward weight for generation support (attribution)") + parser.add_argument("--w_use", type=float, default=1.0, + help="reward weight for overall completeness / utility.") + parser.add_argument('--mode', type=str, help="mode to control retrieval.", + default="default", choices=['adaptive_retrieval', 'no_retrieval', 'always_retrieve'],) + parser.add_argument('--metric', type=str, help="metric to be used during evaluation") + parser.add_argument("--thresh", type=float, default=0.8) + parser.add_argument("--use_conf", action="store_true") + parser.add_argument('--extra', type=str, default='') + parser.add_argument('--alpha_beta', type=str, default=None) + parser.add_argument("--wo_dec", action="store_true") + parser.add_argument("--wo_ana", action="store_true") + parser.add_argument("--start_round", type=int, default=0) + args = parser.parse_args() + + ## debug + args.model_name = model_path + # args.input_file = data_path + if args.alpha_beta is not None: + alpha, beta = args.alpha_beta.split('_', 1) + args.extra += f"alpha{alpha}_beta{beta}" + try: + global alpha_beta + alpha_beta = float(alpha), float(beta) + except: + print("alpha_beta format wrong, which is", args.alpha_beta) + exit() + + if args.wo_dec: + args.extra += f"_nodec" + global wo_decoding + wo_decoding = True + if args.wo_ana: + args.extra += f"_noana" + global wo_anayzer + wo_anayzer = True + if args.start_round != 0: + args.extra += f"_st_round_{args.start_round}" + + + args.max_new_tokens = 100 + args.output_file = out_path + args.metric = 'match' + args.ndocs = ndocs + args.dtype = 'half' + print(args, flush=True) + set_para(args.thresh, "full", 0, 0, 0, False, args.extra) + if args.use_conf: + global use_conf + use_conf = True + + gpt = args.model_name + input_path = args.input_file + if input_path.endswith(".json"): + input_data = json.load(open(input_path)) + else: + input_data = load_jsonlines(input_path) + + input_data = preprocess_input_data( + input_data, task=args.task) + + tokenizer = AutoTokenizer.from_pretrained(gpt, padding_side="left") + model = AutoModelForCausalLM.from_pretrained(gpt, low_cpu_mem_usage=True) + print(model.__class__) + model.__class__ = Drag + model.check() + + + def generate(prompt, evidences, max_new_tokens): + return my_model_rank(prompt, evidences, model=model, tokenizer=tokenizer, max_new_tokens=max_new_tokens) + + preds = [] + prompts = [] + golds = [] + metric_results = [] + scores = [] + all_results = [] + count = 0 + sum_time = [0, 0] + run_round = 0 + missing_round = [] + for i, row in tqdm(enumerate(input_data)): + if i < args.start_round: + continue + results = {} + my_input = row['instruction'] + _, evidences = process_data_evidences(row, top_n=args.ndocs) + + # try: + # pred, results, t1, t2 = generate( + # my_input, evidences, max_new_tokens=args.max_new_tokens,) + # except: + # missing_round.append(i) + # print("MISSING ROUND:", missing_round, flush=True) + # continue + + + pred, results, t1, t2 = generate( + my_input, evidences, max_new_tokens=args.max_new_tokens,) + + sum_time[0] += t1 + sum_time[1] += t2 + run_round += 1 + print(f"avg_time: {sum_time[0] / run_round} | {sum_time[1] / run_round}", flush=True) + if type(pred) is str and len(pred)>0 and (pred[0] == "#" or pred[0] == ":"): + pred = pred[1:] + prompts.append(my_input) + preds.append(pred) + all_results.append(results) + # if do_retrieve is True: + # count += 1 + if "answers" not in row and "answer" in row: + row["answers"] = [row["answer"]] if type( + row["answer"]) is str else row["answer"] + if args.metric == "accuracy": + metric_result = accuracy(pred, row["output"]) + + elif args.metric == "match": + if "SUPPORTS" in pred: + pred = "true" + elif "REFUTES" in pred: + pred = "false" + metric_result = match(pred, row["answers"]) + else: + raise NotImplementedError + + metric_results.append(metric_result) + if i % 10 == 0: + print("average: {}".format(np.mean(metric_results)), flush=True) + final_results = {"preds": preds, "prompts": prompts, "metric_results": metric_results, "all_results": all_results, + "golds": golds, "metric": args.metric, "metric_mean": np.mean(metric_results), "scores": scores} + with open(args.output_file + "_tmp", "w") as outfile: + json.dump(final_results, outfile) + + final_results = {"preds": preds, "prompts": prompts, "metric_results": metric_results, "all_results": all_results, + "golds": golds, "metric": args.metric, "metric_mean": np.mean(metric_results), "scores": scores} + with open(args.output_file, "w") as outfile: + json.dump(final_results, outfile) + + print("Final result: {0}".format(np.mean(metric_results)), flush=True) + print("MISSING ROUND:", missing_round, flush=True) + #print("Retrieval Frequencies: {0}".format(count / len(final_results))) + print(metric_results, flush=True) + +if __name__ == "__main__": + main() diff --git a/paddlenlp/drag/__init__.py b/paddlenlp/drag/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/paddlenlp/drag/drag.py b/paddlenlp/drag/drag.py new file mode 100644 index 000000000000..6cd6c79462f3 --- /dev/null +++ b/paddlenlp/drag/drag.py @@ -0,0 +1,584 @@ + +#!/usr/bin/python +# -*- coding: UTF-8 -*- + +import os +import paddle +# os.environ['CUDA_VISIBLE_DEVICES'] = '4' +device = "gpu:0" + +paddle.set_device(device) + +# import jsonlines +# import paddlenlp.transformers +from paddlenlp.transformers import AutoTokenizer, AutoModelForCausalLM + +import random +# import torch +import numpy as np +from tqdm import tqdm +import json +import argparse +import re +from tqdm import tqdm +import time +from paddlenlp.drag.utils import PROMPT_DICT, TASK_INST, load_jsonlines, control_tokens, load_special_tokens +from paddlenlp.drag.metrics import match, accuracy + +# +seed = 633 +# os.environ['TORCH_NCCL_AVOID_RECORD_STREAMS'] = '0' + +# torch.backends.cudnn.deterministic = True +random.seed(seed) +np.random.seed(seed) +paddle.seed(seed) + + + + +model_path = '/root/zzg/self-rag-main/model/llama3-8b-instruct-paddle' +model_path = 'meta-llama/Meta-Llama-3-8B-Instruct' +# model_path = '/root/zzg/llama3-8b-instruct' +# data_path = '/data/home/scv6140/run/hs/self-rag/retrieval_lm/eval_data/popqa/popqa_part_1.jsonl' +# data_path = '/data/home/scv6140/run/hs/self-rag/retrieval_lm/eval_data/popqa/popqa_part_2.jsonl' +# data_path = '/data/home/scv6140/run/hs/self-rag/retrieval_lm/eval_data/popqa/popqa_part_3.jsonl' +# data_path = '/data/home/scv6140/run/hs/self-rag/retrieval_lm/eval_data/popqa/popqa_part_4.jsonl' +# data_path = '/root/zzg/self-rag-main/datasets/eval_data/popqa_longtail.jsonl' +# data_path = '/data/home/scv6140/run/hs/self-rag/retrieval_lm/eval_data/popqa_longtail_retrieval_20.jsonl' +# data_path = '/root/zzg/self-rag-main/datasets/eval_data/popqa_longtail.jsonl' +# data_path = '/data/home/scv6140/run/hs/self-rag/retrieval_lm/eval_data/triviaqa_retrieval.jsonl' +# data_path = '/data/home/scv6140/run/hs/self-rag/retrieval_lm/eval_data/triviaqa_test.jsonl' +# data_path = '/data/home/scv6140/run/hs/self-rag/retrieval_lm/eval_data/test.jsonl' +out_path = '/root/zzg/self-rag-main/retrieval_lm/output/tmp.json' +metric = 'match' +ndocs = 10 +show_prev = False +# Qwen = "/root/zzg/self-rag-main/model/Qwen2-0.5B-instruct-paddle" +Qwen = "/root/zzg/self-rag-main/model/Qwen2-0.5B-instruct-paddle-v1" +# Qwen = "/root/zzg/self-rag-main/model/llama3-8b-instruct-paddle" +llama3_Tokenizer = AutoTokenizer.from_pretrained(model_path) +# ---------------------------------------------------------------------------------------don't modify content before this line +wo_decoding = False +wo_anayzer = False +w_irr_fix1 = False +use_conf = False + +qwen_model = AutoModelForCausalLM.from_pretrained(Qwen) +qwen_tokenizer = AutoTokenizer.from_pretrained(Qwen) + + +def postprocess_answer_option_conditioned(answer): + for token in control_tokens: + answer = answer.replace(token, "") + + if "" in answer: + answer = answer.replace("", "") + if "\n" in answer: + answer = answer.replace("\n", "") + + if "<|endoftext|>" in answer: + answer = answer.replace("<|endoftext|>", "") + + return answer + +def extract_elements(question, max_new_tokens=125): + sys_instruction = "You are an assistant in extracting key elements from a given question." + user_instruction = "Question: " + messages = [ + {"role": "system", "content": sys_instruction}, + {"role": "user", "content": user_instruction + '\n' + question} + ] + + text = qwen_tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + model_inputs = qwen_tokenizer([text], return_tensors="pd") + + # Generate response + generated_ids = qwen_model.generate( + **model_inputs, + max_new_tokens=max_new_tokens + ) + # print(tokenizer.batch_decode(outputs[0], skip_special_tokens=True)) + + # Process the generated output + response = qwen_tokenizer.batch_decode(generated_ids[0], skip_special_tokens=True)[0] + print("ELEMENTS: " + response, flush=True) + return response + +alpha_beta = None +def process_explanation(text): + essential_scores = list(map(float, re.findall(r'\(essential\): (\d+\.\d+)', text, re.IGNORECASE))) + initial_scores = list(map(float, re.findall(r'\(initial\): (\d+\.\d+)', text, re.IGNORECASE))) + supplementary_scores = list(map(float, re.findall(r'\(supplementary\): (\d+\.\d+)', text, re.IGNORECASE))) + + if alpha_beta is not None: + alpha, beta = alpha_beta + score_max = len(essential_scores) + alpha * len(initial_scores) + beta * len(supplementary_scores) + std_max = len(essential_scores) + 0.6 * len(initial_scores) + 0.3 * len(supplementary_scores) + score = sum(essential_scores) + alpha * sum(initial_scores) + beta * sum(supplementary_scores) + # print(len(essential_scores), len(initial_scores), len(supplementary_scores), score_max) + try: + final_score = score * std_max / score_max + except: + final_score = score + else: + final_score = sum(essential_scores) + 0.6 * sum(initial_scores) + 0.3 * sum(supplementary_scores) + lex_diver = len(essential_scores) + 1.2 * len(initial_scores) + len(supplementary_scores) + return final_score, lex_diver + +def score_paragraph(question, paragraph, elements, max_new_tokens=125): + sys_instruction = "You are an assistant in scoring paragraphs based on a given question and its associated elements." + user_instruction = "Question:\n Elements;\n Paragraphs:" + + paragraph_text = f"{paragraph['title']}\n{paragraph['text']}" + + my_input = ( + f"### Question: {question}\n" + f"### Element: {elements}\n" + f"### Paragraphs: {paragraph_text}\n" + ) + + messages = [ + {"role": "system", "content": sys_instruction}, + {"role": "user", "content": user_instruction + '\n' + my_input} + ] + + text = qwen_tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + # print(text) + model_inputs = qwen_tokenizer([text], return_tensors="pd") + + # Generate response + generated_ids = qwen_model.generate( + **model_inputs, + max_new_tokens=max_new_tokens + ) + + # Process the generated output + response = qwen_tokenizer.batch_decode(generated_ids[0], skip_special_tokens=True)[0] + print("SCORES: " + response, flush=True) + score, lex_diver = process_explanation(response) + return (score, lex_diver) + +def sort_para(question, paragraphs, max_new_tokens=125): + # paragraphs = paragraphs[:1] + elements = extract_elements(question, max_new_tokens) + scored_texts = [ + (para, *score_paragraph(question, para, elements, max_new_tokens)) + for para in paragraphs + ] + lex_diver = scored_texts[0][2] if len(scored_texts) != 0 else None + scored_texts = [(para, score) for para, score, ld in scored_texts if score is not False] + sorted_texts = sorted(scored_texts, key=lambda x: x[1], reverse=True) + # return [text for text, score in sorted_texts[:5]], [text for text, score in sorted_texts[-1:]] + return sorted_texts[:5], sorted_texts[-1:], lex_diver + + +def get_score_unsort(question, paragraphs, max_new_tokens=125): + elements = extract_elements(question, max_new_tokens) + scored_texts = [ + (para, *score_paragraph(question, para, elements, max_new_tokens)) + for para in paragraphs[:5] + ] + lex_diver = scored_texts[0][2] if len(scored_texts) != 0 else None + scored_texts = scored_texts[:5] + scores = [] + for para, score, ld in scored_texts: + scores.append(score if score is not False else 0.) + return scores, lex_diver + +def my_model_rank(my_input, evidences, model, tokenizer, max_new_tokens=50): + # subquery_num, subquery = Generate_Subquery(my_input, model, max_new_tokens) + if 1: + # extracted_info = generate_Triplet(my_input, max_new_tokens) + # subject = extracted_info['subject'] + # relationship = extracted_info['relationship'] + + paddle.device.cuda.synchronize() + part1_time = time.perf_counter() + relevant_para = [] + irrelevant_para = [] + results = {} + decoding_flag = not wo_decoding + + + if not wo_anayzer: + # evidences, irr_evidences=sort_para(my_input,evidences) + top5, bottom1, lex_diver = sort_para(my_input,evidences) + evidences = [text for text, score in top5] + scores = [score for text, score in top5] + # if top5[0][1] < 1.6 and bottom1[0][1] < 1: + if 1: + irr_evidences = [text for text, score in bottom1] + else: + if decoding_flag: + # TODO + scores, lex_diver = get_score_unsort(my_input, evidences) + pass + evidences = evidences[:5] + irr_evidences = evidences[-1:] + if decoding_flag: + model.config.decode_strategy = "sampling_irr" + model.generation_config.decode_strategy = "sampling_irr" + set_para_ano( + lex_diver if not use_conf else -1, + scores + ) + + + if w_irr_fix1: + # irr_evidences=[ + # {"title": "Ethnic groups in Rwanda", "text": "divert the emphasis from ethnicity to a division of the population into categories of victim, victors, survivors, and perpetrators. However, in identifying victims and survivors, some Rwandans are left to be identified as perpetrators. This becomes increasingly problematic as all Hutus are deemed perpetrators—where their survival of the genocide seems to imply some form of complicity with the former government. Thus, in this process of rebuilding and bringing guilty parties to justice, the current government is providing dangling linkages back to the very ethnicities they wish to abolish and is risking further entrenching supposed “past” ethnic divisions. Furthermore, government policy"} + # ] + irr_evidences =[ + {"title": "Rebirth (Buddhism)", "text": "Rebirth (Buddhism) Rebirth in Buddhism refers to its teaching that the actions of a person lead to a new existence after death, in endless cycles called \"saṃsāra\". This cycle is considered to be \"dukkha\", unsatisfactory and painful. The cycle stops only if liberation is achieved by insight and the extinguishing of desire. Rebirth is one of the foundational doctrines of Buddhism, along with Karma, nirvana and moksha. The rebirth doctrine in Buddhism, sometimes referred to as reincarnation or metempsychosis, asserts that rebirth does not necessarily take place as another human being, but as an existence in one of the six"} + ] + for evidence in evidences: + relevant_para.append("[Retrieval]{0}\n{1}".format(evidence["title"], evidence["text"])) + + + paddle.device.cuda.synchronize + part1_time = time.perf_counter() - part1_time + # relevant_para = relevant_para[:5] + + paddle.device.cuda.synchronize + part2_time = time.perf_counter() + if not relevant_para: + sys_msg=PROMPT_DICT["prompt_for_combine_no_retrieval"][0]['content'] + user_msg=PROMPT_DICT["prompt_for_combine_no_retrieval"][1]['content'].format(instruction=my_input) + msgs = [ + {"role":"system","content":sys_msg}, + {"role":"user","content":user_msg} + ] + prompt = llama3_Tokenizer.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False) + else: + sys_msg=PROMPT_DICT["prompt_for_combine"][0]['content'] + user_msg=PROMPT_DICT["prompt_for_combine"][1]['content'].format(instruction=my_input,paragraphs='\n'.join(relevant_para)) + msgs = [ + {"role":"system","content":sys_msg}, + {"role":"user","content":user_msg} + ] + prompt = llama3_Tokenizer.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False) + + token_inputs = tokenizer([prompt], return_tensors="pd") + + inputs_ids_irr = None + + + if decoding_flag and lex_diver is not None: + for irr_evidence in irr_evidences: + irrelevant_para.append("[Retrieval]{0}\n{1}".format(irr_evidence["title"], irr_evidence["text"])) + sys_msg_irr = PROMPT_DICT["prompt_for_combine"][0]['content'] + user_msg_irr = PROMPT_DICT["prompt_for_combine"][1]['content'].format(instruction=my_input,paragraphs='\n'.join(irrelevant_para)) + # user_msg_irr = PROMPT_DICT["prompt_for_combine"][1]['content'].format(instruction=my_input,subject=subject,relationship=relationship,paragraphs=irrelevant_para[0]) + msgs_irr = [ + {"role":"system","content":sys_msg_irr}, + {"role":"user","content":user_msg_irr} + ] + prompt_irr = llama3_Tokenizer.apply_chat_template(msgs_irr, add_generation_prompt=True, tokenize=False) + # prompt_irr = PROMPT_DICT['prompt_no_input_retrieval'].format(paragraph='\n'.join(irrelevant_para),instruction=my_input) + token_inputs_irr = tokenizer([prompt_irr], return_tensors="pd") + inputs_ids_irr = token_inputs_irr.input_ids + + # 计算question(含)之前的token长度 + question_only_msg = PROMPT_DICT["prompt_for_combine"][1]['content'].format(instruction=my_input,paragraphs='\n') + question_msgs = [ + {"role":"system","content":sys_msg}, + {"role":"user","content":question_only_msg} + ] + question_only_prompt = llama3_Tokenizer.apply_chat_template(question_msgs, add_generation_prompt=True, tokenize=False) + question_token_inputs = tokenizer([question_only_prompt], return_tensors="pd") + question_token_len = question_token_inputs.input_ids.shape[1] + + + question_and_doc_len = [] + for i in range(len(relevant_para)): + _sys_msg = PROMPT_DICT["prompt_for_combine"][0]['content'] + _user_msg = PROMPT_DICT["prompt_for_combine"][1]['content'].format(instruction=my_input,paragraphs='\n'.join(relevant_para[:(i + 1)])) + _msgs = [ + {"role":"system","content":_sys_msg}, + {"role":"user","content":_user_msg} + ] + _prompt = llama3_Tokenizer.apply_chat_template(_msgs, add_generation_prompt=True, tokenize=False) + + _token_inputs = tokenizer([_prompt], return_tensors="pd") + + question_and_doc_len.append(_token_inputs.input_ids.shape[1]) + + print(question_token_len, question_and_doc_len, flush=True) + generated_ids = model.generate( + **token_inputs, + max_new_tokens=512, + use_irr = decoding_flag, + inputs_irr=inputs_ids_irr, + alpha=3, + question_token_len = question_token_len, + question_and_doc_len = question_and_doc_len, + ) + # generated_ids = [ + # output_ids[len(input_ids):] for input_ids, output_ids in zip(token_inputs.input_ids, generated_ids) + # ] + output_text = tokenizer.batch_decode(generated_ids[0], skip_special_tokens=True)[0] + else: + generated_ids_prev = model.generate( + **token_inputs, + max_new_tokens=512, + ) + # generated_ids_prev = [ + # output_ids[len(input_ids):] for input_ids, output_ids in zip(token_inputs.input_ids, generated_ids_prev) + # ] + output_text_prev = tokenizer.batch_decode(generated_ids_prev[0], skip_special_tokens=True)[0] + output_text = output_text_prev + + paddle.device.cuda.synchronize + part2_time = time.perf_counter() - part2_time + + print('-------------------------------PROMPT-------------------------------', flush=True) + print(prompt, flush=True) + if 'prompt_irr' in locals().keys(): + print('-----------------------------IRR_PROMPT-----------------------------', flush=True) + print(prompt_irr, flush=True) + + print(">>>>>>>>>>>>>>>OUTPUT:", output_text, flush=True) + print(f"this_time: {part1_time} | {part2_time}", flush=True) + # print("OUTPUT_TOKEN:", generated_ids or generated_ids_prev, flush=True) + + return output_text, results, part1_time, part2_time + +def process_data_evidences(demonstration, top_n): + ctx_key = "ctxs" if "ctxs" in demonstration else "top_contexts" + # prompt = PROMPT_DICT["prompt_no_input"].format_map(demonstration) + evidences = demonstration[ctx_key][:top_n] + return ctx_key, evidences + + +def preprocess_input_data(dataset, task=None): + new_data = [] + if task in TASK_INST: + instruction = TASK_INST[task] + else: + instruction = None + for item in dataset: + if task == "arc_c": + choices = item["choices"] + answer_labels = {} + for i in range(len(choices["label"])): + answer_key = choices["label"][i] + text = choices["text"][i] + if answer_key == "1": + answer_labels["A"] = text + if answer_key == "2": + answer_labels["B"] = text + if answer_key == "3": + answer_labels["C"] = text + if answer_key == "4": + answer_labels["D"] = text + if answer_key in ["A", "B", "C", "D"]: + answer_labels[answer_key] = text + + if "D" not in answer_labels: + answer_labels["D"] = "" + choices = "\nA: {0}\nB: {1}\nC: {2}\nD: {3}".format( + answer_labels["A"], answer_labels["B"], answer_labels["C"], answer_labels["D"]) + if "E" in answer_labels: + choices += "\nE: {}".format(answer_labels["E"]) + item["instruction"] = instruction + \ + "\n\n### Input:\n" + item["question"] + choices + item["answers"] = [item["answerKey"]] + else: + prompt = instruction + "\n\n## Input:\n\n" + \ + item["question"] if instruction is not None else item["question"] + item["instruction"] = prompt + new_data.append(item) + + return new_data + + +from paddlenlp.generation.utils import set_para, set_para_ano +# def set_para(*args): + # pass +# def set_para_ano(*args): + # pass + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--model_name', type=str) + parser.add_argument('--input_file', type=str, default='/data/home/scv6140/run/hs/self-rag/retrieval_lm/eval_data/triviaqa/triviaqa_part_1.jsonl') + parser.add_argument('--output_file', type=str) + parser.add_argument('--task', type=str) + parser.add_argument('--device', type=str, default="cuda") + parser.add_argument('--max_new_tokens', type=int, default=15) + parser.add_argument('--tokenizer_path', type=str) + parser.add_argument('--download_dir', type=str, help="specify vllm model download dir", + default=".cache") + parser.add_argument("--ndocs", type=int, default=10, + help="Number of documents to retrieve per questions") + parser.add_argument("--world_size", type=int, default=1, + help="world size to use multiple GPUs.") + parser.add_argument("--dtype", type=str, default="half", + help="We use bfloat16 for training. If you run inference on GPUs that do not support BF16, please set this to be `half`.") + # Decoding hyperparams + parser.add_argument('--threshold', type=float, + default=None, help="Adaptive threshold.") + parser.add_argument("--use_seqscore", action="store_true") + parser.add_argument("--use_groundness", action="store_true", + help="use ground score") + parser.add_argument( + "--use_utility", action="store_true", help="tree search") + parser.add_argument("--beam_width", type=int, + default=2, help="beam search width") + parser.add_argument("--max_depth", type=int, + default=2, help="tree depth width") + parser.add_argument("--w_rel", type=float, default=1.0, + help="reward weight for document relevance") + parser.add_argument("--w_sup", type=float, default=1.0, + help="reward weight for generation support (attribution)") + parser.add_argument("--w_use", type=float, default=1.0, + help="reward weight for overall completeness / utility.") + parser.add_argument('--mode', type=str, help="mode to control retrieval.", + default="default", choices=['adaptive_retrieval', 'no_retrieval', 'always_retrieve'],) + parser.add_argument('--metric', type=str, help="metric to be used during evaluation") + parser.add_argument("--thresh", type=float, default=0.8) + parser.add_argument("--use_conf", action="store_true") + parser.add_argument('--extra', type=str, default='') + parser.add_argument('--alpha_beta', type=str, default=None) + parser.add_argument("--wo_dec", action="store_true") + parser.add_argument("--wo_ana", action="store_true") + parser.add_argument("--start_round", type=int, default=0) + args = parser.parse_args() + + ## debug + args.model_name = model_path + # args.input_file = data_path + if args.alpha_beta is not None: + alpha, beta = args.alpha_beta.split('_', 1) + args.extra += f"alpha{alpha}_beta{beta}" + try: + global alpha_beta + alpha_beta = float(alpha), float(beta) + except: + print("alpha_beta format wrong, which is", args.alpha_beta) + exit() + + if args.wo_dec: + args.extra += f"_nodec" + global wo_decoding + wo_decoding = True + if args.wo_ana: + args.extra += f"_noana" + global wo_anayzer + wo_anayzer = True + if args.start_round != 0: + args.extra += f"_st_round_{args.start_round}" + + + args.max_new_tokens = 100 + args.output_file = out_path + args.metric = 'match' + args.ndocs = ndocs + args.dtype = 'half' + print(args, flush=True) + set_para(args.thresh, "full", 0, 0, 0, False, args.extra) + if args.use_conf: + global use_conf + use_conf = True + + gpt = args.model_name + input_path = args.input_file + if input_path.endswith(".json"): + input_data = json.load(open(input_path)) + else: + input_data = load_jsonlines(input_path) + + input_data = preprocess_input_data( + input_data, task=args.task) + + tokenizer = AutoTokenizer.from_pretrained(gpt, padding_side="left") + model = AutoModelForCausalLM.from_pretrained(gpt, low_cpu_mem_usage=True) + + def generate(prompt, evidences, max_new_tokens): + return my_model_rank(prompt, evidences, model=model, tokenizer=tokenizer, max_new_tokens=max_new_tokens) + + preds = [] + prompts = [] + golds = [] + metric_results = [] + scores = [] + all_results = [] + count = 0 + sum_time = [0, 0] + run_round = 0 + missing_round = [] + for i, row in tqdm(enumerate(input_data)): + if i < args.start_round: + continue + results = {} + my_input = row['instruction'] + _, evidences = process_data_evidences(row, top_n=args.ndocs) + + # try: + # pred, results, t1, t2 = generate( + # my_input, evidences, max_new_tokens=args.max_new_tokens,) + # except: + # missing_round.append(i) + # print("MISSING ROUND:", missing_round, flush=True) + # continue + + + pred, results, t1, t2 = generate( + my_input, evidences, max_new_tokens=args.max_new_tokens,) + + sum_time[0] += t1 + sum_time[1] += t2 + run_round += 1 + print(f"avg_time: {sum_time[0] / run_round} | {sum_time[1] / run_round}", flush=True) + if type(pred) is str and len(pred)>0 and (pred[0] == "#" or pred[0] == ":"): + pred = pred[1:] + prompts.append(my_input) + preds.append(pred) + all_results.append(results) + # if do_retrieve is True: + # count += 1 + if "answers" not in row and "answer" in row: + row["answers"] = [row["answer"]] if type( + row["answer"]) is str else row["answer"] + if args.metric == "accuracy": + metric_result = accuracy(pred, row["output"]) + + elif args.metric == "match": + if "SUPPORTS" in pred: + pred = "true" + elif "REFUTES" in pred: + pred = "false" + metric_result = match(pred, row["answers"]) + else: + raise NotImplementedError + + metric_results.append(metric_result) + if i % 10 == 0: + print("average: {}".format(np.mean(metric_results)), flush=True) + final_results = {"preds": preds, "prompts": prompts, "metric_results": metric_results, "all_results": all_results, + "golds": golds, "metric": args.metric, "metric_mean": np.mean(metric_results), "scores": scores} + with open(args.output_file + "_tmp", "w") as outfile: + json.dump(final_results, outfile) + + final_results = {"preds": preds, "prompts": prompts, "metric_results": metric_results, "all_results": all_results, + "golds": golds, "metric": args.metric, "metric_mean": np.mean(metric_results), "scores": scores} + with open(args.output_file, "w") as outfile: + json.dump(final_results, outfile) + + print("Final result: {0}".format(np.mean(metric_results)), flush=True) + print("MISSING ROUND:", missing_round, flush=True) + #print("Retrieval Frequencies: {0}".format(count / len(final_results))) + print(metric_results, flush=True) + +if __name__ == "__main__": + main() diff --git a/paddlenlp/drag/drag_generate.py b/paddlenlp/drag/drag_generate.py new file mode 100644 index 000000000000..66c95f87fcc9 --- /dev/null +++ b/paddlenlp/drag/drag_generate.py @@ -0,0 +1,645 @@ +import paddle + + +import copy +import random +import paddle.distributed as dist +import paddle.nn.functional as F +import numpy as np +from paddlenlp.transformers import LlamaForCausalLM +from paddlenlp.transformers.model_outputs import ModelOutput +from paddlenlp.transformers.utils import get_scale_by_dtype + + +from paddlenlp.utils.log import logger + +from paddlenlp.generation.configuration_utils import DEFAULT_MAX_NEW_TOKENS, GenerationConfig +from paddlenlp.generation.streamers import BaseStreamer +from paddlenlp.generation.utils import get_unfinished_flag + + + +from paddlenlp.generation.logits_process import ( + LogitsProcessorList, + TopKProcess, + TopPProcess, +) +from paddlenlp.generation.stopping_criteria import ( + StoppingCriteria, + StoppingCriteriaList, + validate_stopping_criteria, +) +from typing import Optional, Union + +threshold = None +small_model = None +uncer_w1 = None +uncer_w2 = None +uncer_w3 = None +randm = None +extra = None + +def set_para(thres, small_mod, w1, w2, w3, _randm: bool, ext): + global threshold + global small_model + global uncer_w1 + global uncer_w2 + global uncer_w3 + global randm + global extra + threshold = thres + small_model = small_mod + uncer_w1 = w1 + uncer_w2 = w2 + uncer_w3 = w3 + randm = _randm + extra = ext + +lex_diver = None +rel_scores = None +def set_para_ano(_lex_diver, _scores): + global lex_diver + lex_diver = _lex_diver + global rel_scores + rel_scores = _scores + + + +class Drag(LlamaForCausalLM): + @staticmethod + def prepare_attention_mask_for_generation(input_ids, pad_token_id, eos_token_id): + is_pad_token_in_inputs_ids = (pad_token_id is not None) and paddle.any(input_ids == pad_token_id).item() + is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ( + (eos_token_id is not None) and (pad_token_id != eos_token_id) + ) + if is_pad_token_in_inputs_ids and is_pad_token_not_equal_to_eos_token_id: + attention_mask = (input_ids == pad_token_id).astype(paddle.get_default_dtype()) * get_scale_by_dtype( + return_positive=False + ) + else: + attention_mask = paddle.zeros_like(input_ids, dtype=paddle.get_default_dtype()) + return attention_mask + + def relative_top_filter(self, scores: paddle.Tensor, relative_top: float = 0.1, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1) -> paddle.Tensor: + scores_normalized = F.log_softmax(scores, axis=-1) + sorted_logits = paddle.sort(scores_normalized, descending=True) + min_thresh = sorted_logits[..., min_tokens_to_keep-1] + probs_max = paddle.max(scores_normalized, axis=-1) + probs_thresh = probs_max + np.log(relative_top) + probs_thresh = paddle.minimum(min_thresh, probs_thresh) + probs_thresh = probs_thresh.unsqueeze(-1) + scores_normalized[scores_normalized < probs_thresh] = filter_value + return scores_normalized + + def sample_irr( + self, + input_ids, + logits_processors, + max_length, + pad_token_id, + eos_token_id, + top_k=None, + top_p=None, + temperature=None, + min_tokens_to_keep=1, + stopping_criteria=None, + streamer=None, + fast_ptq_sampling=False, + trunc_input=True, + synced_gpus=False, + input_ids_irr=None, + attention_mask_irr=None, + alpha = 10, + relative_top = 0.1, + question_token_len = None, + question_and_doc_len = None, + **model_kwargs + ): + # output_attentions = self.config.output_attentions + # output_hidden_states = self.config.output_hidden_states + output_attentions = True + output_hidden_states = True + model_kwargs["use_cache"] = model_kwargs.get("use_cache", True) + + model_kwargs_irr = copy.deepcopy(model_kwargs) + model_kwargs_irr["attention_mask"] = attention_mask_irr + model_kwargs_irr["use_cache"] = model_kwargs_irr.get("use_cache", True) + + logits_processors = logits_processors if logits_processors is not None else LogitsProcessorList() + + # max_length will be convert to MaxLengthCriteria + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + if max_length is not None: + # logger.warning( + # "`max_length` is deprecated in this function, use" + # " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead." + # ) + stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) + + batch_size, cur_len = input_ids.shape + origin_len = cur_len + unfinished_flag = paddle.full([batch_size, 1], True, dtype="bool") + scores = paddle.full([batch_size, 1], 0.0, dtype=paddle.get_default_dtype()) + + generate_end = False + + num_generation = 0 + num_decoding = 0 + decoding_list = [] + risk_scores = [] + + def calculate_attention_uncertainty(next_new_token_attn, logits): + """ + 计算基于多头注意力和logits的token不确定性 + + :param next_new_token_attn: Tensor of shape (1, 32, n), attention output + :param logits: Tensor of shape (1, n, vocab_size), logits output for the token prediction + :return: 不确定性度量的张量 (1, n) + """ + # 1. 计算多头注意力的权重差异 + # 我们可以计算32个头的标准差或方差 + # 计算标准差, 为每个token在不同头之间的注意力权重计算差异 + # attention_variance = paddle.var(next_new_token_attn, dim=1) # Variance across the 32 attention heads + # uncertainty_1 = paddle.sqrt(attention_variance).sum() # 可以选择使用标准差作为不确定性的度量 + + # 2. 计算单个头的注意力权重分布 + # 我们通过计算每个token在每个头的权重分布的熵来衡量分散度 + # 熵较高表示该token的权重分布较为均匀,从而不确定性较高 + + if randm: + return random.random() + + def compute_entropy(attn): + # 使用softmax来计算概率分布,然后计算熵 + prob_dist = F.softmax(attn, axis=-1) + entropy = -(prob_dist * prob_dist.log()).sum(axis=-1) + return entropy + if uncer_w1 != 0: + entropy_values = compute_entropy(next_new_token_attn.sum(1).squeeze(0)) # 计算每个token在不同头的熵 + uncertainty_2 = entropy_values # 熵越大,表示不确定性越高 + else: + uncertainty_2 = 0 + + # 3. 输出token的概率(基于logits) + # 使用softmax对logits进行归一化,计算每个token的生成概率 + # logits的形状为 (1, n, vocab_size),我们需要对每个token的logits进行softmax + if uncer_w2 != 0: + softmax_probs = F.softmax(logits.squeeze(0), axis=-1) # (n, vocab_size) + + # # 对每个token,选出最大概率对应的词汇的概率 + max_probs = softmax_probs.max(axis=-1) # 获取每个token的最大生成概率 + uncertainty_3 = 1 - max_probs # 概率越大,不确定性越小,取反即为不确定性 + else: + uncertainty_3 = 0 + + if uncer_w3 != 0: + uncertainty_4 = compute_entropy(logits) + else: + uncertainty_4 = 0 + # 综合不确定性:可以选择加权平均或者简单地加总每部分的不确定性 + # total_uncertainty = uncertainty_2#*0.3 + uncertainty_3 + total_uncertainty = uncertainty_2 * uncer_w1 + uncertainty_3 * uncer_w2 + uncertainty_4 * uncer_w3 + + return total_uncertainty + + def calc_risk(attn, logits): + if randm: + return random.random() + ret = 0 + attn = attn.sum(1).squeeze(0) + st = question_token_len + for j in range(len(question_and_doc_len)): + en = question_and_doc_len[j] + assert en <= attn.shape[0], f"{j}--- en:{en} attn_len:{attn.shape[0]}" + ret += attn[st:en].sum().item() / (1 + rel_scores[j]) + st = en + + softmax_probs = F.softmax(logits.squeeze(0), axis=-1) + max_probs = softmax_probs.max(axis=-1) + ret *= (1- max_probs.item()) + + if lex_diver > 0: + ret *= lex_diver + # else: + + return ret + + while True: + if synced_gpus: + # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. + # The following logic allows an early break if all peers finished generating their sequence + this_peer_finished_flag = paddle.to_tensor(0.0 if generate_end else 1.0) + # send 0.0 if we finished, 1.0 otherwise + dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) + # did all peers finish? the reduced sum will be 0.0 then + if this_peer_finished_flag.item() == 0.0: + break + # prepare model inputs & get model output + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + # NOTE: to decrease ref-count and clear outdate cache in-time + model_kwargs["cache"] = None + model_kwargs["past_key_values"] = None + #outputs = self(**model_inputs) + + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + model_kwargs_irr["cache"] = None + model_kwargs_irr["past_key_values"] = None + model_inputs_irr = self.prepare_inputs_for_generation(input_ids_irr, **model_kwargs_irr) + # NOTE: to decrease ref-count and clear outdate cache in-time + model_kwargs_irr["cache"] = None + model_kwargs_irr["past_key_values"] = None + # outputs = self(**model_inputs) + + if synced_gpus and generate_end: + continue # don't waste resources running the code we don't need + + if isinstance(outputs, tuple): + ori_logits = outputs[0] + elif isinstance(outputs, ModelOutput): + ori_logits = outputs.logits + else: + ori_logits = outputs + + # [batch_size, vocab_size] + ori_logits = ori_logits[:, -1, :] + + if lex_diver is None: + risk_score = calculate_attention_uncertainty(outputs['attentions'][-1][:,:,-1,:], logits) # TODO + else: + risk_score = calc_risk(outputs['attentions'][-1][:,:,-1,:], ori_logits) + + risk_scores.append(risk_score) + if risk_score > threshold: + irr_flag= True + num_decoding += 1 + decoding_list.append(num_generation) + else: + irr_flag = False + num_generation += 1 + + if irr_flag: + #outputs_irr = self(**model_inputs_irr) + outputs_irr = self( + **model_inputs_irr, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + if isinstance(outputs_irr, tuple): + logits_irr = outputs_irr[0] + elif isinstance(outputs_irr, ModelOutput): + logits_irr = outputs_irr.logits + else: + logits_irr = outputs_irr + + logits_irr = logits_irr[:, -1, :] + + if relative_top > 0.0: + ori_logits = self.relative_top_filter(ori_logits, relative_top) + logits_irr = F.log_softmax(logits_irr, axis=-1) + mask = ori_logits[0] < -1e3 + logits_irr = paddle.where(mask, paddle.to_tensor(-1e3), logits_irr) + # logits_irr[0][mask] = -1e3 + else: + ori_logits = F.log_softmax(ori_logits, axis=-1) + logits_irr = F.log_softmax(ori_logits, axis=-1) + + logits = ori_logits + alpha * (ori_logits - logits_irr) + + else: + logits = ori_logits + + + # pre-process distribution + logits = self.adjust_logits_during_generation(logits) + logits = logits_processors(input_ids, logits) + + # sample + origin_probs = F.softmax(logits) + origin_probs = paddle.log(origin_probs) + if temperature is not None and temperature != 1.0: + logits = logits / temperature + probs = F.softmax(logits) + if top_k is not None and top_k != 0: + probs = TopKProcess(probs, top_k, min_tokens_to_keep) + if top_p is not None and top_p < 1.0: + probs = TopPProcess(probs, top_p, min_tokens_to_keep) + if paddle.device.is_compiled_with_custom_device("gcu"): + probs = paddle.cast(probs, "float32") + if paddle.device.is_compiled_with_xpu(): + probs = paddle.cast(probs, "float32") + + # multinomial already support fp16 and bf16 currently, fix issue: https://github.com/PaddlePaddle/Paddle/issues/51852 + next_tokens = paddle.multinomial(probs) + + if self.config.tensor_parallel_degree > 1: + # Maybe no need to broadcast if seed is set correclty. + from paddle.distributed import fleet + + try: + hcg = fleet.get_hybrid_communicate_group() + group = hcg.get_model_parallel_group() + src = hcg.get_model_parallel_group_src_rank() + except: + group, src = None, 0 + paddle.distributed.broadcast(next_tokens, src=src, group=group) + # config does not include pipeline_parallel_degree, and pipeline parallel + # uses trainer.model_wrapped to run in both train and predict mode + # which has pp_group as a attribute + # TODO(guosheng): only let the last stage of pipeline to do softmax + # and sampling, and then broadcast to avoid broadcast logits. + if getattr(self, "pp_group", None) is not None: + paddle.distributed.broadcast( + next_tokens, src=self.pp_group.ranks[0], group=self.pp_group # use rank 0 for same seed to check + ) + + next_scores = paddle.index_sample(origin_probs, next_tokens) + if eos_token_id is not None: + next_tokens = paddle.where(unfinished_flag, next_tokens, paddle.full_like(next_tokens, pad_token_id)) + + scores = self.update_scores_for_generation(scores, next_scores, cur_len - origin_len, unfinished_flag) + + cur_len += 1 + input_ids = paddle.concat([input_ids, next_tokens], axis=1) + input_ids_irr = paddle.concat([input_ids_irr, next_tokens], axis=1) + + if streamer is not None: + if self.config.tensor_parallel_rank == 0: + streamer.put(next_tokens.cpu()) + + if stopping_criteria(input_ids, scores): + generate_end = True + + if eos_token_id is not None: + unfinished_flag = get_unfinished_flag(input_ids, unfinished_flag, eos_token_id) + if not paddle.any(unfinished_flag): + generate_end = True + + # Stop when there is a in all sentences + if generate_end and not synced_gpus: + break + + model_kwargs = self.update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.is_encoder_decoder + ) + outputs_copy = copy.deepcopy(outputs) + model_kwargs_irr = self.update_model_kwargs_for_generation( + outputs_copy, model_kwargs_irr, is_encoder_decoder=self.is_encoder_decoder + ) + if fast_ptq_sampling: + break + + if irr_flag: + del outputs_irr + + if streamer is not None: + streamer.end() + + return input_ids[:, origin_len:] if trunc_input else input_ids, scores + + @paddle.no_grad() + def generate( + self, + input_ids: paddle.Tensor = None, + generation_config: GenerationConfig = None, + stopping_criteria: StoppingCriteria = None, + streamer: BaseStreamer = None, + synced_gpus: Optional[bool] = None, + use_irr: Optional[str] = None, + inputs_irr: Optional[paddle.Tensor] = None, + alpha: Optional[float] = 10, + question_token_len: Optional[int] = None, + question_and_doc_len: list = None, + **kwargs, + ): + if generation_config is None: + if self.generation_config is None or self.generation_config._from_model_config: + new_generation_config = GenerationConfig.from_model_config(self.config) + if new_generation_config != self.generation_config: + logger.warning( + "model.generation_config is in conflict with model.config, " "model.config is used." + ) + self.generation_config = new_generation_config + generation_config = self.generation_config + + # without update model.generation_config + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) + + assert generation_config.decode_strategy in [ + "greedy_search", + "sampling", + "beam_search", + "sampling_irr", + ], "`decode_strategy` must be one of 'greedy_search', 'sampling' or 'beam_search' but received {}.".format( + generation_config.decode_strategy + ) + + if getattr(self, "deprecated_warnings", None) is None: + self.deprecated_warnings = {} + + use_fast = False + if "use_faster" in model_kwargs: + raise ValueError("`use_faster` is deprecated now.") + + if "use_fast" in model_kwargs: + raise ValueError("`use_fast` is deprecated now.") + + bos_token_id = ( + generation_config.bos_token_id if generation_config.bos_token_id is not None else self.config.bos_token_id + ) + eos_token_id = ( + generation_config.eos_token_id if generation_config.eos_token_id is not None else self.config.eos_token_id + ) + pad_token_id = ( + generation_config.pad_token_id if generation_config.pad_token_id is not None else self.config.pad_token_id + ) + forced_bos_token_id = ( + generation_config.forced_bos_token_id + if generation_config.forced_bos_token_id is not None + else self.config.forced_bos_token_id + ) + forced_eos_token_id = ( + generation_config.forced_eos_token_id + if generation_config.forced_eos_token_id is not None + else self.config.forced_eos_token_id + ) + decoder_start_token_id = ( + generation_config.decoder_start_token_id + if generation_config.decoder_start_token_id is not None + else self.config.decoder_start_token_id + ) + no_repeat_ngram_size = ( + generation_config.no_repeat_ngram_size + if generation_config.no_repeat_ngram_size is not None + else self.config.no_repeat_ngram_size + ) + + if getattr(self, "_fast_entry", None) is not False and use_fast: + fg_args = locals() + fg_args.pop("self") + fg_args.pop("__class__", None) + model_kwargs = fg_args.pop("model_kwargs") + fg_args.update(model_kwargs) + try: + if getattr(self, "_fast_entry", None) is None: + self._build_fast(fg_args) + if self._fast_entry: + output = self._fast_entry(**fg_args) + if isinstance(output, tuple): + output_ids, dummy_srore = output + else: + output_ids = output + # make result and fast result oneconsistent + dummy_srore = None + if generation_config.decode_strategy == "beam_search": + output_ids = output_ids.transpose([1, 2, 0]) + output_ids = output_ids[:, : generation_config.num_return_sequences, :].reshape( + [-1, output_ids.shape[-1]] + ) + if dummy_srore is not None: + dummy_srore = dummy_srore[:, : generation_config.num_return_sequences].flatten() + else: + output_ids = output_ids.transpose([1, 0]) + return output_ids, dummy_srore + + except Exception as e: + fg_args["model_kwargs"] = model_kwargs + # TODO + # Prevent self._convert_to_fast to throw Exception + self._convert_to_fast(fg_args) + logger.warning(e) + logger.warning("FastGeneration is not available, " "and the original version would be used instead.") + + # input_ids in model_kwargs is supported + if "input_ids" in model_kwargs: + _input_ids = model_kwargs.pop("input_ids") + if input_ids is None: + input_ids = _input_ids + + # params check + if input_ids is None and "inputs_embeds" not in model_kwargs: + # Init `input_ids` with bos_token_id + input_ids = self.prepare_input_ids_for_generation(bos_token_id) + elif "inputs_embeds" in model_kwargs: + # Add input embeds support + input_ids = self.prepare_input_ids_for_generation( + bos_token_id, encoder_output=model_kwargs["inputs_embeds"] + ) + + if model_kwargs.get("attention_mask", None) is None: + # TODO + # Init `attention_mask` depending on `pad_token_id` + model_kwargs["attention_mask"] = self.prepare_attention_mask_for_generation( + input_ids, pad_token_id, eos_token_id + ) + self.is_encoder_decoder = self.config.is_encoder_decoder + + if self.is_encoder_decoder: + model_kwargs = self.prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs) + # set input_ids as decoder_input_ids + if "decoder_input_ids" in model_kwargs: + input_ids = model_kwargs.pop("decoder_input_ids") + else: + input_ids = self.prepare_decoder_input_ids_for_generation( + input_ids, decoder_start_token_id, bos_token_id + ) + # streamer + if streamer is not None: + # streamer couldn't support beam_search strategy + if generation_config.decode_strategy == "beam_search" or generation_config.num_beams > 1: + raise ValueError( + "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1." + ) + + pad_token_id = self.set_pad_token_id(pad_token_id, eos_token_id) + + if generation_config.max_length != 0 and generation_config.max_new_tokens == DEFAULT_MAX_NEW_TOKENS: + logger.warning("`max_length` will be deprecated in future releases, use `max_new_tokens` instead.") + generation_config.max_new_tokens = generation_config.max_length + + if generation_config.min_length != 0 and generation_config.min_new_tokens == 0: + logger.warning("`min_length` will be deprecated in future releases, use `min_new_tokens` instead.") + generation_config.min_new_tokens = generation_config.min_length + + max_length = generation_config.max_new_tokens + min_length = generation_config.min_new_tokens + + input_len = input_ids.shape[-1] + min_len = input_len + min_length + max_len = input_len + max_length + + logits_processors = self.get_logits_processor( + min_length=min_len if min_length > 0 else None, + max_length=max_len, + eos_token_id=eos_token_id, + forced_bos_token_id=forced_bos_token_id, + forced_eos_token_id=forced_eos_token_id, + num_beams=generation_config.num_beams, + num_beam_groups=generation_config.num_beam_groups, + diversity_rate=generation_config.diversity_rate, + repetition_penalty=generation_config.repetition_penalty, + no_repeat_ngram_size=generation_config.no_repeat_ngram_size, + logits_processors=model_kwargs["logits_processors"] + if "logits_processors" in model_kwargs + and isinstance(model_kwargs["logits_processors"], LogitsProcessorList) + else None, + ) + if "logits_processors" in model_kwargs: + model_kwargs.pop("logits_processors") + + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + + if use_irr is not None: + if self.config.is_encoder_decoder: + logger.warning( + "Using irrelevant doc not implemented for encoder-decoder architecture yet." + ) + + input_ids_irr = inputs_irr + attention_mask_irr = self.prepare_attention_mask_for_generation( + input_ids_irr, generation_config.pad_token_id, generation_config.eos_token_id + ) + + + if generation_config.decode_strategy == "sampling_irr": + if generation_config.num_return_sequences > 1: + input_ids, model_kwargs = self.expand_inputs_for_generation( + input_ids, expand_size=generation_config.num_return_sequences, **model_kwargs + ) + + if generation_config.num_return_sequences > 1: + input_ids_irr, model_kwargs_irr = self.expand_inputs_for_generation( + input_ids, expand_size=generation_config.num_return_sequences, **model_kwargs + ) + + return self.sample_irr( + input_ids, + logits_processors, + max_len, + pad_token_id, + eos_token_id, + generation_config.top_k, + generation_config.top_p, + generation_config.temperature, + stopping_criteria=stopping_criteria, + streamer=streamer, + fast_ptq_sampling=generation_config.fast_ptq_sampling, + trunc_input=generation_config.trunc_input, + synced_gpus=synced_gpus, + input_ids_irr=input_ids_irr, + attention_mask_irr=attention_mask_irr, + alpha=alpha, + question_token_len=question_token_len, + question_and_doc_len=question_and_doc_len, + **model_kwargs, + ) + def check(self): + print("Make sure the model is a Drag class") diff --git a/paddlenlp/drag/metrics.py b/paddlenlp/drag/metrics.py new file mode 100644 index 000000000000..e989a1687818 --- /dev/null +++ b/paddlenlp/drag/metrics.py @@ -0,0 +1,91 @@ +import numpy as np +import string +import re +from collections import Counter +import re + + +def exact_match_score(prediction, ground_truth): + return (normalize_answer(prediction) == normalize_answer(ground_truth)) + +def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): + scores_for_ground_truths = [] + for ground_truth in ground_truths: + score = metric_fn(prediction, ground_truth) + scores_for_ground_truths.append(score) + return max(scores_for_ground_truths) + +def accuracy(preds, labels): + match_count = 0 + for pred, label in zip(preds, labels): + target = label[0] + if pred == target: + match_count += 1 + + return 100 * (match_count / len(preds)) + + +def f1(decoded_preds, decoded_labels): + f1_all = [] + for prediction, answers in zip(decoded_preds, decoded_labels): + if type(answers) == list: + if len(answers) == 0: + return 0 + f1_all.append(np.max([qa_f1_score(prediction, gt) + for gt in answers])) + else: + f1_all.append(qa_f1_score(prediction, answers)) + return 100 * np.mean(f1_all) + + +def qa_f1_score(prediction, ground_truth): + prediction_tokens = normalize_answer(prediction).split() + ground_truth_tokens = normalize_answer(ground_truth).split() + common = Counter(prediction_tokens) & Counter(ground_truth_tokens) + num_same = sum(common.values()) + if num_same == 0: + return 0 + precision = 1.0 * num_same / len(prediction_tokens) + recall = 1.0 * num_same / len(ground_truth_tokens) + f1 = (2 * precision * recall) / (precision + recall) + return f1 + + +def normalize_answer(s): + def remove_articles(text): + return re.sub(r'\b(a|an|the)\b', ' ', text) + + def white_space_fix(text): + return ' '.join(text.split()) + + def remove_punc(text): + exclude = set(string.punctuation) + return ''.join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + return white_space_fix(remove_articles(remove_punc(lower(s)))) + +def find_entity_tags(sentence): + entity_regex = r'(.+?)(?=\s<|$)' + tag_regex = r'<(.+?)>' + entity_names = re.findall(entity_regex, sentence) + tags = re.findall(tag_regex, sentence) + + results = {} + for entity, tag in zip(entity_names, tags): + if "<" in entity: + results[entity.split("> ")[1]] = tag + else: + results[entity] = tag + return results + +def match(prediction, ground_truth): + + prediction = prediction.lower() + + for gt in ground_truth: + gt = gt.lower() + if gt in prediction: + return 1 + return 0 \ No newline at end of file diff --git a/paddlenlp/drag/utils.py b/paddlenlp/drag/utils.py new file mode 100644 index 000000000000..25467a728d5c --- /dev/null +++ b/paddlenlp/drag/utils.py @@ -0,0 +1,296 @@ +import jsonlines +import json +import copy +import re + +PROMPT_DICT = { + "prompt_input": ( + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" + ), + "prompt_no_input": [{"role":"system", "content":"Write a response that appropriately completes the request."}, + {"role":"user", "content":"### Instruction:\n{instruction}\n\n### Response:\n"} + ], + "prompt_no_input_retrieval": [{"role":"system", "content": + "Below is an instruction that describes a task. " + "Write a response that appropriately completes the request.\n\n"}, + {"role":"user", "content":"### Paragraph:\n{paragraph}\n\n### Instruction:\n{instruction}\n\n### Response:"} + ], + "prompt_open_instruct": ( + "\n{instruction}\n" + "\n" + ), + "prompt_open_instruct_retrieval": ( + "\nReference:{paragraph}\n{instruction}\n" + "\n" + ), + "llama_chat_prompt": ( + "[INST]{instruction}[/INST]" + ), + "llama_chat_prompt_retrieval": ( + "[INST]{paragraph}\n{instruction}[/INST]" + ), + "prompt_for_triplet": [{"role":"system", "content": + "You will receive an instruction ." + "Analyze the instruction and extracted two specific types of information:" + "1.Subject: The main entity or instance mentioned in the question. " + "This could be a person, object, organization, event, or activity, etc. It may consist of one or more words." + "2.Relationship: The summarization of the relative relationship to the subject and the expected answer" #and also reflect the nature of the expected answer. + "Summarize explicit relationship information from the input instruction. " + "Do not add additional guesses about the nature of the subject, but do not omit information explicitly mentioned in the original instruction. It may consist of one or more words.\n\n" + "Provide the output in the following format:\n" + "Subject: XXX\n" + "Relationship: XXX\n" + "Explanation: XXX\n" + "END\n\n" + }, + {"role":"user", "content": + "### Instruction: \n{instruction}\n"} + ], + "prompt_for_same": [{"role":"system", "content": "You will receive a paragraph and some information (including a subject and a relationship) extracted from a query." + "Firstly, determine if the subject extracted from the query is exactly the same as an entity mentioned in the paragraph." + "Only consider the entity as the same if the full name exactly match. Partial matches, such as different middle names or additional qualifiers, should be considered as not the same." + "Secondly, if the entities are exactly the same, determine if the paragraph contains any description that match the given relationship." + "If there is a description that match the given relationship, extract the corresponding sentence(s) from the paragraph." + "If there is no matched description, simply state 'No relevant description found'.\n" + "Finally, explain the two matching judgments you made.\n\n" + "Provide the output in the following format:\n" + "Subject Match: true/false\n" + "Relationship Match: true/false\n" + "Description: XXX\n" + "Explanation: XXX\n\n"}, + {"role": "user", "content": "### Subject: {subject}\n" + "Relationship: {relationship}\n" + "Paragraph: {paragraph}\n\n### Output: \n"} + ], + "prompt_for_combine":[{"role": "system", "content": + "You are provided with an instruction and some retrieved texts (wrapped in tags). " + "Your task is to select credible texts to answer the instruction. " + "You should categorize the texts and reason out if they refer to the same entity based on the content. " + "Integrate the texts referring to the same entity to provide a comprehensive answer." + "If the texts refer to entities that are the same name but different, answer in separate paragraphs." + "Provide the detailed answer based on the paragraphs in the following format:\n" + "Answer: XXX\n\n"}, + {"role": "user", "content": + "### Instruction: {instruction}\n" + "### Retrieved Texts:\n{paragraphs}\n\n" + "### Answer:"} + ], + "prompt_for_combine_no_retrieval":[{"role": "system", "content": + "You are provided with an instruction from the instruction. " + "Your task is to answer the instruction. " + "Provide the detailed answer in the following format:\n" + "Answer: XXX\n\n"}, + {"role": "user", "content": + "### Instruction: {instruction}\n" + "### Answer:"} + ], + "prompt_for_splitting": [ + { + "role": "system", + "content": + "You will receive an instruction.The instruction may contain one or more entities, and additional knowledge about those entities needs to be retrieved in order to answer the query." + "Analyze the instruction to determine if it requires multiple retrievals based on the number of entities or pieces of knowledge needed." + "If the instruction requires multiple retrievals, split the instruction into sub-questions." + "Each sub-question should target a specific entity or piece of knowledge needed to answer the main question." + "If the instruction contains only one entity to be retrieved, then the sub-question count is 1 and the content of the sub-question should be exactly the same as the original instruction." + "Note that count is the same as the number of entities to be retrieved in instruction and is a positive integer with a minimum value of 1." + "Please give an explanation of your entire reasoning process." + "Output the explanation of the instruction, the number of sub-questions and their content in the following format:\n" + "Explanation: {XXX}" + "Sub-question Count: {count}\n" + "Sub-questions:\n" + "1. {sub-question-1}\n" + "2. {sub-question-2}\n" + "... (and so on)" + + }, + { + "role": "user", + "content": "### Instruction: \n{instruction}\n" + } + ] + + +} + +TASK_INST = {"wow": "Given a chat history separated by new lines, generates an informative, knowledgeable and engaging response. ", + "fever": "Is the following statement correct or not? Say true if it's correct; otherwise say false.", + "eli5": "Provide a paragraph-length response using simple words to answer the following question.", + "obqa": "Given four answer candidates, A, B, C and D, choose the best answer choice.", + "arc_easy": "Given four answer candidates, A, B, C and D, choose the best answer choice.", + "arc_c": "Given four answer candidates, A, B, C and D, choose the best answer choice.", + "trex": "Given the input format 'Subject Entity [SEP] Relationship Type,' predict the target entity.", + "asqa": "Answer the following question. The question may be ambiguous and have multiple correct answers, and in that case, you have to provide a long-form answer including all correct answers.", + "med": "Given four answer candidates, A, B, C and D, choose the best answer choice."} + +rel_tokens_names = ["[Irrelevant]", "[Relevant]"] +retrieval_tokens_names = ["[No Retrieval]", + "[Retrieval]", "[Continue to Use Evidence]"] +utility_tokens_names = ["[Utility:1]", "[Utility:2]", + "[Utility:3]", "[Utility:4]", "[Utility:5]"] +ground_tokens_names = ["[Fully supported]", + "[Partially supported]", "[No support / Contradictory]"] +other_special_tokens = ["", "", "[PAD]", + "", "", ""] +control_tokens = ["[Fully supported]", "[Partially supported]", "[No support / Contradictory]", "[No Retrieval]", "[Retrieval]", + "[Irrelevant]", "[Relevant]", "", "", "[Utility:1]", "[Utility:2]", "[Utility:3]", "[Utility:4]", "[Utility:5]"] + + +def load_special_tokens(tokenizer, use_grounding=False, use_utility=False): + ret_tokens = {token: tokenizer.convert_tokens_to_ids( + token) for token in retrieval_tokens_names} + rel_tokens = {} + for token in ["[Irrelevant]", "[Relevant]"]: + rel_tokens[token] = tokenizer.convert_tokens_to_ids(token) + + grd_tokens = None + if use_grounding is True: + grd_tokens = {} + for token in ground_tokens_names: + grd_tokens[token] = tokenizer.convert_tokens_to_ids(token) + + ut_tokens = None + if use_utility is True: + ut_tokens = {} + for token in utility_tokens_names: + ut_tokens[token] = tokenizer.convert_tokens_to_ids(token) + + return ret_tokens, rel_tokens, grd_tokens, ut_tokens + + +def fix_spacing(input_text): + # Add a space after periods that lack whitespace + output_text = re.sub(r'(?<=\w)([.!?])(?=\w)', r'\1 ', input_text) + return output_text + + +def postprocess(pred): + special_tokens = ["[Fully supported]", "[Partially supported]", "[No support / Contradictory]", "[No Retrieval]", "[Retrieval]", + "[Irrelevant]", "[Relevant]", "", "", "[Utility:1]", "[Utility:2]", "[Utility:3]", "[Utility:4]", "[Utility:5]"] + for item in special_tokens: + pred = pred.replace(item, "") + pred = pred.replace("", "") + + if len(pred) == 0: + return "" + if pred[0] == " ": + pred = pred[1:] + return pred + + +def load_jsonlines(file): + with jsonlines.open(file, 'r') as jsonl_f: + lst = [obj for obj in jsonl_f] + return lst + + +def load_file(input_fp): + if input_fp.endswith(".json"): + input_data = json.load(open(input_fp)) + else: + input_data = load_jsonlines(input_fp) + return input_data + + +def save_file_jsonl(data, fp): + with jsonlines.open(fp, mode='w') as writer: + writer.write_all(data) + + +def preprocess_input(input_data, task): + if task == "factscore": + for item in input_data: + item["instruction"] = item["input"] + item["output"] = [item["output"] + ] if "output" in item else [item["topic"]] + return input_data + + elif task == "qa": + for item in input_data: + if "instruction" not in item: + item["instruction"] = item["question"] + if "answers" not in item and "output" in item: + item["answers"] = "output" + return input_data + + elif task in ["asqa", "eli5"]: + processed_input_data = [] + for instance_idx, item in enumerate(input_data["data"]): + prompt = item["question"] + instructions = TASK_INST[task] + prompt = instructions + "## Input:\n\n" + prompt + entry = copy.deepcopy(item) + entry["instruction"] = prompt + processed_input_data.append(entry) + return processed_input_data + + +def postprocess_output(input_instance, prediction, task, intermediate_results=None): + if task == "factscore": + return {"input": input_instance["input"], "output": prediction, "topic": input_instance["topic"], "cat": input_instance["cat"]} + + elif task == "qa": + input_instance["pred"] = prediction + return input_instance + + elif task in ["asqa", "eli5"]: + # ALCE datasets require additional postprocessing to compute citation accuracy. + final_output = "" + docs = [] + if "splitted_sentences" not in intermediate_results: + input_instance["output"] = postprocess(prediction) + + else: + for idx, (sent, doc) in enumerate(zip(intermediate_results["splitted_sentences"][0], intermediate_results["ctxs"][0])): + if len(sent) == 0: + continue + postprocessed_result = postprocess(sent) + final_output += postprocessed_result[:- + 1] + " [{}]".format(idx) + ". " + docs.append(doc) + if final_output[-1] == " ": + final_output = final_output[:-1] + input_instance["output"] = final_output + input_instance["docs"] = docs + return input_instance + +def process_arc_instruction(item, instruction): + choices = item["choices"] + answer_labels = {} + for i in range(len(choices["label"])): + answer_key = choices["label"][i] + text = choices["text"][i] + if answer_key == "1": + answer_labels["A"] = text + if answer_key == "2": + answer_labels["B"] = text + if answer_key == "3": + answer_labels["C"] = text + if answer_key == "4": + answer_labels["D"] = text + if answer_key in ["A", "B", "C", "D"]: + answer_labels[answer_key] = text + + if "D" not in answer_labels: + answer_labels["D"] = "" + choices = "\nA: {0}\nB: {1}\nC: {2}\nD: {3}".format(answer_labels["A"], answer_labels["B"], answer_labels["C"], answer_labels["D"]) + if "E" in answer_labels: + choices += "\nE: {}".format(answer_labels["E"]) + processed_instruction = instruction + "\n\n### Input:\n" + item["instruction"] + choices + return processed_instruction + + +def postprocess_answers_closed(output, task, choices=None): + final_output = None + if choices is not None: + for c in choices.split(" "): + if c in output: + final_output = c + if task == "fever" and output in ["REFUTES", "SUPPORTS"]: + final_output = "true" if output == "SUPPORTS" else "REFUTES" + if task == "fever" and output.lower() in ["true", "false"]: + final_output = output.lower() + if final_output is None: + return output + else: + return final_output