From 3de7b47b46c08e0af5051bac095bec6082f15149 Mon Sep 17 00:00:00 2001
From: murphy
Date: Thu, 9 Dec 2021 16:49:00 -0800
Subject: [PATCH 1/2] added debugTrain and debugCode args, added some docstrings

---
 cli.py        |  14 ++--
 data3.py      | 176 ++++++++++++++++++++++++--------------------
 run.py        |   8 +--
 span_utils.py | 115 ++++++++++++++++++++++++---------
 4 files changed, 181 insertions(+), 132 deletions(-)

diff --git a/cli.py b/cli.py
index cb7da4c..a691a21 100755
--- a/cli.py
+++ b/cli.py
@@ -98,7 +98,9 @@ def main():
                         help="Evaluate & save model")
     parser.add_argument('--prefix', type=str, default='',
                         help="Prefix for saving predictions")
-    parser.add_argument('--debug', action='store_true',
+    parser.add_argument('--debugCode', action='store_true',
+                        help="Use a tiny 'debug' data split for quick code checks")
+    parser.add_argument('--debugTrain', action='store_true',
                         help="Use a subset of data for debugging")
     parser.add_argument('--seed', type=int, default=42,
                         help="random seed for initialization")
@@ -111,7 +113,7 @@ def main():
                         help="some checkpoint pdb debug")

     # parameters for SpanSeqGen
-    parser.add_argument("--top_k_passages", default=10, type=int)
+    parser.add_argument("--num_top_passages", default=10, type=int)
     # "data/reranking_results/ambigqa"
     parser.add_argument("--ranking_folder_path", default=None)
     # "data/reranking_results/ambigqa"
@@ -130,7 +132,7 @@ def main():
     parser.add_argument("--passage_clustering",
                         default=False, action="store_true")
     parser.add_argument("--k_cluster", default = 10, type=int)
-    parser.add_argument("--rank_threshold", default=60, type=int)
+    parser.add_argument("--rank_threshold", default=100, type=int)
     parser.add_argument("--is_contrastive",
                         default=False, action="store_true")

@@ -199,7 +201,11 @@ def main():
     if args.model.lower() == "t5" and args.prepend_question_token == False:
         logger.warning("t5 model needs prepending, it's adjusted now")
         args.prepend_question_token = True
-
+    if args.debugCode and args.debugTrain:
+        raise ValueError("--debugCode and --debugTrain are both turned on. "
+                         "You need to remove either --debugCode or "
+                         "--debugTrain for the script to run.")
+
     logger.info("Using {} gpus".format(args.n_gpu))
     if args.device == "cuda":
         assert args.n_gpu > 1, "if there is only one gpu, set args.device=0"

diff --git a/data3.py b/data3.py
index ec565a5..f3394a1 100644
--- a/data3.py
+++ b/data3.py
@@ -9,6 +9,7 @@

 import torch
 from torch.utils.data import Dataset, TensorDataset, DataLoader, RandomSampler, SequentialSampler
+from typing import List
 import argparse
 import numpy as np

@@ -49,23 +50,29 @@ def __init__(self, logger, args, data_path, dataset_type):
         # determine is_training status now as dataset_type might be modified later for file accessing
         self.is_training = dataset_type == "train"
         self.dataset_type = dataset_type
+        # args.debugCode uses a small sample of the data
+        # args.debugTrain uses dev data in place of train data

-        if args.debug:
+        if args.debugTrain:
             self.data_path = data_path.replace("train", "dev")
             # under debug
             # we don't want to save train file as dev
             # we want to load dev file as train (we simply don't save)
             dataset_type_for_file_accessing = "dev"
         else:
+            if args.debugCode:
+                dataset_type_for_file_accessing = "debug"
+            else:
+
+                if args.fine_tune:
+                    logger.info(
+                        "No AmbigQA test dataset available, using dev dataset")
+                    if not self.is_training:
+                        dataset_type_for_file_accessing = "dev"  # fine tuning stage
+                    else:
+                        dataset_type_for_file_accessing = dataset_type
                 else:
                     dataset_type_for_file_accessing = dataset_type
-            else:
-                dataset_type_for_file_accessing = dataset_type

         # NOTE: self.data is the original data. Not tokenized nor encoded.
         with open(self.data_path, "r") as f:
             # format example: [ {'id': '-8178292525996414464', 'question': 'big little lies season 2 how many episodes', 'answer': ['seven']}, ..... ]
@@ -73,7 +80,7 @@ def __init__(self, logger, args, data_path, dataset_type):
         if type(self.data) == dict:
             self.data = self.data["data"]
         self.processed_data = None
-        if args.debug :
+        if args.debugTrain:
             if self.is_training == False:
                 logger.warn("[DEBUG MODE] Load all dev data")
                 self.data = self.data[:100]
@@ -93,7 +100,7 @@ def __init__(self, logger, args, data_path, dataset_type):
         # TODO: correct it back
         self.load = True  # debug mode also needs load
-        # self.load = not args.debug # do not load the large tokenized dataset
+        # self.load = not args.debugTrain # do not load the large tokenized dataset
         self.logger = logger
         self.args = args
         if "test" in self.data_path:
@@ -110,18 +117,19 @@ def __init__(self, logger, args, data_path, dataset_type):
         self.dataset = None
         self.dataloader = None
         self.cache = None
-        self.debug = args.debug
-        self.answer_type = "span" if "extraction" in args.predict_type.lower() else "seq"
+        self.debugTrain = args.debugTrain
+        self.debugCode = args.debugCode
+        self.answer_type = "span" if "extraction" in args.predict_type.lower() else "seq"
         self.dataset_name = None  # ambig or nq
         self.passages = None
         if self.args.passage_clustering:  # only need to load when using passage clustering
             self.clustered_passages_path = "data/clustering_results/AmbigQA_"
             postfix = ["top", self.args.num_top_passages, "passages",
                        self.data_type, "is_training", self.is_training, "is_contrastive", self.args.is_contrastive, "rank_threshold", self.args.rank_threshold]
             postfix = [str(x) for x in postfix]
             postfix = "_".join(postfix)
-            if self.args.debug:
+            if self.args.debugTrain:
                 postfix += "_debug"  # it might affect the number of data
             self.clustered_passages_path += postfix
@@ -156,10 +164,10 @@ def __init__(self, logger, args, data_path, dataset_type):
                 args.ranking_folder_path, f"{ranking_file_name}{dataset_type_for_file_accessing}.json")
             self.data_path = os.path.join(
                 args.data_folder_path, f"{data_file_n}{dataset_type_for_file_accessing}.json")
-        self.top_k_passages = args.top_k_passages
+        self.top_k_passages = args.num_top_passages
         self.metric = "EM" if self.dataset_name == "nq" else "F1"
-        self.sep_token = "<SEP>"
-        self.spaced_sep_token = " " + self.sep_token + " "
+        self.sep_token = self.tokenizer.sep_token
+        # self.sep_token = " " + self.sep_token + " "

         self.logging_prefix = None

@@ -234,6 +242,21 @@ def init_top_k_passages(self):
             self.passages = topKPassasages(self.args.k_cluster, self.wiki_passage_path, self.ranking_path, self.data_path)

     def load_dataset(self, tokenizer, do_return=False):
+        """
+        Load encoded data into the dataset class.
+        The pipeline below starts loading/processing from the first
+        stage whose data is available:
+        text data -> token data -> encoded data
+
+        Args:
+            tokenizer: Pre-trained tokenizer
+            do_return (bool, optional): True for returning dataset,
+                False otherwise. Defaults to False.
+
+        Returns:
+            tokenized dataset.
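+
+        Example (an illustrative sketch only; the logger/args objects and
+        the file path are hypothetical):
+            >>> data = QAData(logger, args, "data/nqopen-train.json", "train")
+            >>> dataset = data.load_dataset(tokenizer, do_return=True)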
+ """ self.logging_prefix = f"[{self.dataset_type} data]\t".upper() self.tokenizer = tokenizer @@ -255,8 +278,10 @@ def load_dataset(self, tokenizer, do_return=False): self.top_k_passages, "rank_threshold", self.args.rank_threshold ,self.answer_type, "answers", self.args.augment_k_times, "augmentation", "is_training", self.is_training] postfix = [str(x) for x in postfix] postfix = "_".join(postfix) - if self.debug: - postfix += "_debug" + if self.debugTrain: + postfix += "_debugTrain" + if self.debugCode: + postfix += "_debugCode" if self.args.passage_clustering: postfix += "_clustered" @@ -284,7 +309,7 @@ def safe_remove(file_path): def remove_confirmation_prompt(file_name): prompt = input( f"Confirm to remove {file_name}? (y/n) ").lower() - return prompt == "yes" or prompt == "y" + return prompt.lower() == "yes" or prompt.lower() == "y" if self.args.retokenize == True: if remove_confirmation_prompt("tokenization file"): safe_remove(tokenized_path) @@ -308,8 +333,9 @@ def remove_confirmation_prompt(file_name): # General procedure: # 1. check if pickle cache exists # 2. if not, check if tokenized data exists - # 3. if not, preprocess(load passages and encode) from scratch + # 3. if not, preprocess(load passages and encode) from raw text data if self.load and self.cache: + # found encoded data self.logger.info( self.logging_prefix + f"Found pickle cache, start loading {encoded_input_path}") if self.answer_type == "seq": @@ -328,7 +354,6 @@ def remove_confirmation_prompt(file_name): if self.dataset_name == "ambig": for (idx, joined_answers) in enumerate(joined_answers_l): self.data[idx]["answers"] = joined_answers - # inputs are lists of integers elif self.answer_type == "span": d = preprocess_span_input( @@ -349,12 +374,14 @@ def remove_confirmation_prompt(file_name): else: self.logger.warn("wrong answer type") exit() - else: # not found pickle cache + else: + # not found pickle cache self.logger.info(self.logging_prefix + "Not found pickle cache, start preprocessing...") - if self.load and os.path.exists(tokenized_path): # found tokenized path + if self.load and os.path.exists(tokenized_path): + # not found pickle cache -> found tokenized path self.logger.info( self.logging_prefix + "Loading pre-tokenized data from {}".format(tokenized_path)) with open(tokenized_path, "r") as f: @@ -372,7 +399,8 @@ def remove_confirmation_prompt(file_name): exit() self.logger.info( self.logging_prefix + f"Passage kept rate(after truncation): {passage_coverage_rate * 100} %") - else: # not found tokenized data + else: + # not found pickle cache -> not found tokenized data self.logger.info( self.logging_prefix + "Not found tokenized data, start tokenizing...") @@ -391,7 +419,6 @@ def remove_confirmation_prompt(file_name): for (idx, data_entry) in enumerate(self.data): cur_answer = [] - # Q: does data_entry has more than one annotations? 
                     for qa_d in data_entry["annotations"]:
                         # import pdb
@@ -437,10 +464,10 @@ def remove_confirmation_prompt(file_name):

                 self.logger.info(self.logging_prefix +
                                  "Start concatenating question and passages ")
-
+                # tokenize questions, passages and answers
                 if self.answer_type == "seq":
                     if self.dataset_name == "nq":  # nq seq answer
-                        qp = ["<BOS> " + q for q in questions]
+                        qp = [self.tokenizer.bos_token + q for q in questions]
                         # TODO: add them to arguments
                         # note that after this questions are actually a concatenation of questions and passages
                         self.logger.info(self.logging_prefix + f"Start concatenating question and passages for top {self.top_k_passages} passages")
                         self.passages = topKPassasages(
                             self.top_k_passages, self.wiki_passage_path, self.ranking_path, self.data_path)
                         for i in tqdm(range(len(qp))):
                             # mark the beginning of passages
-                            qp[i] += " "
                             # add passage one by one
-                            for p in self.passages.get_passages(i, self.args.top_k_passages):
+                            for p in self.passages.get_passages(i, self.args.num_top_passages):
                                 # format: [CLS] question [SEP] title 1 [SEP] passages
-                                qp[i] += self.spaced_sep_token + \
-                                    p["title"] + self.spaced_sep_token + p["text"]
+                                qp[i] += self.sep_token + \
+                                    p["title"] + self.sep_token + p["text"]
                             # mark the end of passages
-                            qp[i] += " <EOS>"
+                            qp[i] += self.tokenizer.eos_token
                         question_metadata = None
                         answer_metadata = None
                         # NOTE: no need to rename
@@ -466,15 +492,13 @@ def remove_confirmation_prompt(file_name):
                     elif self.dataset_name == "ambig":  # ambig seq answer
                         # TODO: add function pre_process in utils.py
                         if prepend_question_token:  # T5
-                            qp = ["<BOS> question: " +
+                            qp = [tokenizer.bos_token + "question: " +
                                   question for question in questions]  # t5 tokenizer doesn't have
                         else:
-                            qp = ["<BOS> " + q for q in questions]  # Bart
-                            qp = [q + " <EOS>" for q in qp]
-                        questions_with_clustered_passages = []
-                        # TODO: add them to arguments
-                        # note that after this questions are actually a concatenation of questions and passages
-                        all_qp_concatenation_list = []
+                            qp = [tokenizer.bos_token + q for q in questions]
+                            # Bart
+                            qp = [q + tokenizer.eos_token for q in qp]
+
                         self.logger.info(
                             self.logging_prefix + f"Start concatenating question and passages for top {self.top_k_passages} passages")
                         # import pdb; pdb.set_trace()
@@ -564,7 +588,7 @@ def remove_confirmation_prompt(file_name):
                     cur_titles = []
                     cur_passages = []
-                    for p in self.passages.get_passages(i, self.args.top_k_passages):
+                    for p in self.passages.get_passages(i, self.args.num_top_passages):
                         cur_titles.append(p["title"])
                         cur_passages.append(p["text"])
                     all_titles.append(cur_titles)
@@ -942,17 +966,27 @@ def __init__(self, k_cluster, passages_path, rank_path, data_path, passage_embed
     def set_passage_embeddings(self, passage_embeddings):
         self.passage_embeddings = passage_embeddings

-    def get_clustered_passages(self, i, rank_threshold):
-        """Indexed on quesiton id and return clusters of passages
+    def get_clustered_passages(self, q_id: int,
+                               clustering_method: str="kmeans",
+                               rank_threshold=100):
+        """
+        Get a list of clustered passages given a question id.

         Args:
-            i ([type]): [description]
+            q_id (int): zero-based question id.
+            clustering_method: kmeans, spectral, x-means
+            rank_threshold (int): hard rank threshold in the range 0-100.
+                Probably not useful now that a reranker is being added;
+                keep it at 100 so that all passages are kept.

         Returns:
-            [type]: [description]
+            List[List[numpy.array]]: a list of lists of passages.
+                the size is (num_cluster, num_passages_in_the_cluster).
+                clusters are ordered by closeness to the question.
+                passages within one cluster are also ordered by
+                closeness to the question.
         """
         passage_embeddings = self.get_passage_embeddings(
-            i)
+            q_id)
         kmeans_1 = KMeans(n_clusters=self.k_cluster,
                           random_state=0).fit(passage_embeddings)
         # compute stat of clusters
         cluster_pts_count = dict()
         for j in range(self.k_cluster):
             cluster_pts_count[j] = sum(
                 kmeans_1.labels_ == j)

-        cluster_ranks = dict()
-
-
-        # add up ranks
-        for j in range(len(kmeans_1.labels_)):  # count up to the number of points
-            cluster_label = kmeans_1.labels_[j]
-            # TODO: defaultdict
-            if cluster_label in cluster_ranks.keys():
-                cluster_ranks[cluster_label] += j
-            else:
-                cluster_ranks[cluster_label] = j
-        # average ranks
-        for j in range(self.k_cluster):
-            cluster_ranks[j] /= cluster_pts_count[j]
-
-        sorted_cluster_ranks = sorted(cluster_ranks.items(),
-                                      key=lambda item: item[1])
-        # we want the smallest few. (ranked higher)
-        print(sorted_cluster_ranks)
-
-
-
-        # add top-k cluster
-        filtered_clusters = []
-        for (cluster_label, avg_rank) in sorted_cluster_ranks:
-            if avg_rank < rank_threshold:
-                filtered_clusters.append(
-                    cluster_label)
-            else:
-                break
-        if len(filtered_clusters) == 0:
-            filtered_clusters.append(sorted_cluster_ranks[0][0])  # append the first cluster label
-
-        passages = []
-        passage_ids = self.ranks[i]
+        passages = []
+        passage_ids = self.ranks[q_id]

         num_cluster_passages_l = []
         # add top 5 passages' ids from each cluster
         for cluster_label in filtered_clusters:
-            cluster_passages = []
             for j in range(len(kmeans_1.labels_)):  # iterate all cluster labels and keep the order
                 if kmeans_1.labels_[j] == cluster_label:
                     passage_id = passage_ids[j]
-                    cluster_passages.append(
                         self.passages[passage_id])
-                    if len(cluster_passages) == 5:  # TODO: change 5 to top_k argument
-                        break
             num_cluster_passages_l.append(len(cluster_passages))
-            assert len(cluster_passages) > 0 and len(cluster_passages) <= 5, "each cluster should have at least one passage and at most five passages"
+            assert len(cluster_passages) > 0, "each cluster should have at least one passage"
             passages.append(cluster_passages)
         assert len(passages) >= 1, "There should be at least one cluster"
-        num_cluster_for_question_i = len(num_cluster_passages_l)
-        num_selected_cluster_passages_for_question_i = sum(num_cluster_passages_l)
-
-        return passages, num_cluster_for_question_i, num_selected_cluster_passages_for_question_i
+
+        return passages

     def get_passages(self, i, k):

diff --git a/run.py b/run.py
index 46d0224..9ec7160 100644
--- a/run.py
+++ b/run.py
@@ -97,20 +97,18 @@ def run(args, logger):
         logger.info("Add token into tokenizer")
         # add extra token for BART
         tokenizer.add_tokens(["<SEP>"], special_tokens=True)
+        tokenizer.sep_token = "<SEP>"
+        # for tokenizers that don't have a bos token
         if tokenizer.bos_token_id == None:
             tokenizer.add_tokens(["<BOS>"], special_tokens=True)
-
-
-
-
     if args.do_tokenize:
         # during the process train_data will be overwritten, so memory will be collected
         for k in range(5, 15):
             for l in range(600, 800, 50):
                 print("Evaluate passage coverage for top ", k,
                       "passages for max input sequence length ", l)
-                args.top_k_passages = k
+                args.num_top_passages = k
                 args.max_input_length = l
                 train_data = QAData(logger, args, args.train_file, "train")
     if args.do_predict:

diff --git a/span_utils.py b/span_utils.py
index af63573..9c386c5 100644
--- a/span_utils.py
+++ b/span_utils.py
@@ -200,17 +200,60 @@ def is_answer_a_date_or_infreq_digit(answer_str):

 def preprocess_qpa(questions, question_ids, passages, answers, metadata, data,
-                   top_k_passages, tokenizer,
+                   num_top_passages, tokenizer,
                    answer_type, is_training, is_ambig, args,
                    logging_prefix, logger,
                    rank_threshold=None, clustered_passages_path=None) -> dict():
-    """ Process question, passages and answers.
+    """
+    Concatenate questions, passages and answers.
+    Tokenize first, then concatenate using the tokenizer's sep id.
+
+    Args:
+        questions (List[str]): a list of question text data.
+            list length is the number of questions.
+        question_ids (List[str]): a list of question ids.
+            list length is the number of questions.
+        answers (List[str]): a list of flattened answers.
+            list length is the number of flattened answers.
+        metadata (List[List[tuple]]): a list of lists of tuples.
+            (num_questions, num_answer_semantics, num_answers_per_semantic)
+        data (Dict[str, str]):
+            annotation: a list of answer semantics in dictionary format
+                with annotated answer type. For example,
+                'annotations': [{'type': 'singleAnswer', 'answer': ['usually continues uninterrupted until death']}, {'type': 'singleAnswer', 'answer': ['constant', 'usually continues uninterrupted until death']}]
+            id: question id.
+            question: question text.
+        num_top_passages (int):
+            number of top passages to select.
+        tokenizer: Pre-trained tokenizer
+        answer_type (str): 'seq' for seq2seq model output and 'span' for
+            SpanExtraction model output. It affects how answers are
+            preprocessed.
+        is_training (bool): True for training mode, where only answers
+            present in the passages are kept. False for evaluation/test
+            mode, where all answers are kept.
+        is_ambig (bool): True when preprocessing the AmbigQA dataset,
+            False when preprocessing the NQ dataset.
+        args: parsed argument.
+        logging_prefix (str): prefix that logs the dataset type,
+            for example, '[TEST DATA]\t'.
+        logger: logger.
+        rank_threshold: might be deleted later as we are adding a reranker.
+        clustered_passages_path:
+            clustered passage tokens pickle data.

     Returns:
-        [type]: [description]
+        Dict[str, Data]: Dictionary that stores token data for later encoding.
+            qpa_dict["qp"]: concatenation of the question and passages
+            qpa_dict["question_ids"] = question_ids
+            qpa_dict["answers"] = answers
+            qpa_dict["question_metadata"] = question_metadata
+            qpa_dict["answer_metadata"] = answer_metadata
+            qpa_dict["joined_answers_l"] = joined_answers_l
+            qpa_dict["data"] = data
     """
+
+    import pdb; pdb.set_trace()
+    print('check preprocess_qpa args')
+
@@ -223,12 +266,6 @@ def preprocess_qpa(questions, question_ids, passages, answers, metadata, data,

     # TODO: check where the 5gb GPU memory comes from by setting pdb
     # something in training mode (as in prediction mode there is no such memory)
-    # TODO: contrastive
-    # encode file name (add contrastive, if not contrastive, don't add contrastive (keep it the same))
-    # tokenize file name
-    # each cluster will have at most one positive example and at most one negative example
-    #
-    #

     # TODO: provide more clustering analytics. Given a question, how is the clustering?
     # top-k passages contain the answer and top-k passages doesn't contain the answer.

@@ -248,8 +285,10 @@ def preprocess_qpa(questions, question_ids, passages, answers, metadata, data,
     questions_n_passages = []
     if args.passage_clustering:
         assert is_ambig == True, "PC mode: must be for ambig or multi-answer datasets"
         assert rank_threshold is not None, "PC mode: there should be a PC rank threshold."
         assert clustered_passages_path is not None, "PC mode: there should be a clustered_passages_path"

-    sep_token = "<SEP>"
-    spaced_sep_token = " " + sep_token + " "
+    sep_token = tokenizer.sep_token
+    eos_token = tokenizer.eos_token
+
+    sep_token = " " + sep_token + " "

     question_metadata = []
     joined_answers_l = []
     empty_answer_str = " "

     if not is_ambig:  # nq dataset
         for i in tqdm(range(len(questions))):
             # mark the beginning of passages
             # end of question and start of passages
             questions[i] += " "
             # add passage one by one
-            for p in passages.get_passages(i, top_k_passages):
+            for p in passages.get_passages(i, num_top_passages):
                 # format: [CLS] question [SEP] title 1 [SEP] passages
-                questions[i] += spaced_sep_token + \
-                    p["title"] + spaced_sep_token + p["text"]
+                questions[i] += sep_token + \
+                    p["title"] + sep_token + p["text"]
             # mark the end of passages
             questions[i] += " "
     else:
         if args.passage_clustering:
+
+            clustered_raw_data_path = ""
+            if not os.path.exists(clustered_raw_data_path):
+                # clustering
+                # store all clustering data in tokens to a json file
+                for (i, cur_md) in enumerate(metadata):
+                    clusters_passages = passages.get_clustered_passages(
+                        i, rank_threshold=100)  # 2-d list
+                # load json file.
+
+
+
             logger.info(
                 logging_prefix + "Concatenating clustering results...")
             assert len(question_ids) == len(
@@ -303,8 +355,8 @@ def preprocess_qpa(questions, question_ids, passages, answers, metadata, data,
                     qp_d = questions_n_passages[-1]
                     qp_d["pos"] = []
                     qp_d["neg"] = []
-                    # 1. needs truncation here? probably not, we can directly check.
-                    # 2. check presence of answer.
+                    # 1. needs truncation here? probably not, we can directly check.
+                    # 2. check presence of answer.
                     for p_cluster in clusters_passages:  # it's ordered
                         # reset qp concatenation
                         cluster_qp_concatenation = questions[i]
                         pos_cluster_qp_concatenation = cluster_qp_concatenation + " "
                         neg_cluster_qp_concatenation = cluster_qp_concatenation + " "
                         pos_start = True
                         neg_start = True
                         for p in p_cluster:
                             # format: [CLS] question [SEP] title 1 [SEP] passages
                             if is_answer_set_in_passsages(cur_md, p["text"], answers):
                                 if pos_start:
                                     pos_cluster_qp_concatenation += p["title"] + \
-                                        spaced_sep_token + \
+                                        sep_token + \
                                         p["text"]
                                     pos_start = False
                                 else:
-                                    pos_cluster_qp_concatenation += spaced_sep_token + \
+                                    pos_cluster_qp_concatenation += sep_token + \
                                         p["title"] + \
-                                        spaced_sep_token + \
+                                        sep_token + \
                                         p["text"]
                             else:
                                 if neg_start:
                                     neg_cluster_qp_concatenation += p["title"] + \
-                                        spaced_sep_token + \
+                                        sep_token + \
                                         p["text"]
                                     neg_start = False
                                 else:
-                                    neg_cluster_qp_concatenation += spaced_sep_token + \
+                                    neg_cluster_qp_concatenation += sep_token + \
                                         p["title"] + \
-                                        spaced_sep_token + \
+                                        sep_token + \
                                         p["text"]
                         pos_cluster_qp_concatenation += " "
                         neg_cluster_qp_concatenation += " "
                         qp_d["pos"].append(
                             pos_cluster_qp_concatenation)
                         qp_d["neg"].append(
                             neg_cluster_qp_concatenation)
@@ -375,12 +427,12 @@ def preprocess_qpa(questions, question_ids, passages, answers, metadata, data,
                             # format: [CLS] question [SEP] title 1 [SEP] passages
                             if start:
                                 cluster_qp_concatenation += p["title"] + \
-                                    spaced_sep_token + \
+                                    sep_token + \
                                     p["text"]
                             else:
-                                cluster_qp_concatenation += spaced_sep_token + \
+                                cluster_qp_concatenation += sep_token + \
                                     p["title"] + \
-                                    spaced_sep_token + \
+                                    sep_token + \
                                     p["text"]
                             start = False
                         cluster_qp_concatenation += " "
@@ -411,15 +463,15 @@ def preprocess_qpa(questions, question_ids, passages, answers, metadata, data,
             # add passage one by one
             start = True
             # NOTE: get passage clustering
-            for p in passages.get_passages(i, args.top_k_passages):
+            for p in passages.get_passages(i, num_top_passages):
                 # format: [CLS] question [SEP] title 1 [SEP] passages
                 if start:
                     questions_n_passages[i] += p["title"] + \
-                        spaced_sep_token + p["text"]
+                        sep_token + p["text"]
                 else:
-                    questions_n_passages[i] += spaced_sep_token + \
-                        p["title"] + spaced_sep_token + p["text"]
+                    questions_n_passages[i] += sep_token + \
+                        p["title"] + sep_token + p["text"]
                 start = False
             questions_n_passages[i] += " "

@@ -440,7 +492,6 @@ def preprocess_qpa(questions, question_ids, passages, answers, metadata, data,
         metadata), (len(questions_n_passages), len(metadata))
     # format QP and A
     for idx, (cur_qp, cur_md) in enumerate(zip(questions_n_passages, metadata)):
-        # import pdb; pdb.set_trace()
         found_answers_for_one_question = []
         # check existence of answers for later joining (for evaluation)
         for cur_md_for_qa_pair in cur_md:
                     p_str = get_p_str(cur_qp_str, tokenizer,
                                       args.max_input_length)

-                    if is_training and not args.debug:
+                    if is_training and not args.debugTrain:
                         if is_answer_in_passages(cur_a_str, p_str):
                             found_answer_for_qa_pair.append(
                                 cur_a_str)
@@ -469,7 +520,7 @@ def preprocess_qpa(questions, question_ids, passages, answers, metadata, data,
                 p_str = get_p_str(cur_qp, tokenizer,
                                   args.max_input_length)

-                if is_training and not args.debug:
+                if is_training and not args.debugTrain:
                     if is_answer_in_passages(cur_a_str, p_str):
                         found_answer_for_qa_pair.append(
                             cur_a_str)

From bab98ea023eb897e175600c4fd203a8ba65c203c Mon Sep 17 00:00:00 2001
From: murphy
Date: Fri, 10 Dec 2021 10:56:18 -0800
Subject: [PATCH 2/2] implemented clustering text results saving

---
 cli.py        |   1 -
 data3.py      | 148 ++++++++++++++++++---------------
 span_utils.py | 224 ++++++++++++++++++++++----------------------------
 3 files changed, 179 insertions(+), 194 deletions(-)

diff --git a/cli.py b/cli.py
index a691a21..6be40b5 100755
--- a/cli.py
+++ b/cli.py
@@ -132,7 +132,6 @@ def main():
     parser.add_argument("--passage_clustering",
                         default=False, action="store_true")
     parser.add_argument("--k_cluster", default = 10, type=int)
-    parser.add_argument("--rank_threshold", default=100, type=int)
     parser.add_argument("--is_contrastive",
                         default=False, action="store_true")

diff --git a/data3.py b/data3.py
index f3394a1..f04f2ac 100644
--- a/data3.py
+++ b/data3.py
@@ -22,11 +22,9 @@

 from span_utils import preprocess_span_input, eval, preprocess_qpa, dump_pickle, load_pickle
 from numpy import random
-from sklearn.cluster import KMeans
+from sklearn.cluster import KMeans, SpectralClustering
+from pyclustering.cluster.xmeans import xmeans
+from pyclustering.cluster.center_initializer import kmeans_plusplus_initializer
+from collections import defaultdict
 import multiprocessing as mp
-
-
-
 import csv

@@ -49,23 +50,28 @@ def __init__(self, logger, args, data_path, dataset_type):
         # determine is_training status now as dataset_type might be modified later for file accessing
         self.is_training = dataset_type == "train"
         self.dataset_type = dataset_type
         # args.debugCode uses a small sample of the data
         # args.debugTrain uses dev data in place of train data
-
-        if args.debugTrain:
-            self.data_path = data_path.replace("train", "dev")
-            # under debug
-            # we don't want to save train file as dev
-            # we want to load dev file as train (we simply don't save)
-            dataset_type_for_file_accessing = "dev"
+
+        if args.debugTrain or args.debugCode:
+            if args.debugTrain:
+                self.data_path = data_path.replace("train", "dev")
+                # under debug
+                # we don't want to save train file as dev
+                # we want to load dev file as train (we simply don't save)
+                dataset_type_for_file_accessing = "dev"
             if args.debugCode:
                 dataset_type_for_file_accessing = "debug"
+                self.data_path = data_path.replace("dev", "debug")
data_path.replace("dev", "debug") + else: + # neither debugTrain nor debugCode + if args.fine_tune: + logger.info( + "Not AmbigQA test dataset available, using dev dataset") + if not self.is_training: + dataset_type_for_file_accessing = "dev" # fine tuning stage else: dataset_type_for_file_accessing = dataset_type + else: + dataset_type_for_file_accessing = dataset_type # NOTE: self.data is the original data. Not tokenized nor encoded. with open(self.data_path, "r") as f: # format example: [ {'id': '-8178292525996414464', 'question': 'big little lies season 2 how many episodes', 'answer': ['seven']}, ..... ] @@ -103,12 +102,15 @@ def __init__(self, logger, args, data_path, dataset_type): # self.load = not args.debugTrain # do not load the large tokenized dataset self.logger = logger self.args = args + # set self.data_type for logging if "test" in self.data_path: self.data_type = "test" elif "dev" in self.data_path: self.data_type = "dev" elif "train" in self.data_path: self.data_type = "train" + elif "debug" in self.data_path: + self.data_type = "debug" else: raise NotImplementedError() @@ -126,7 +128,7 @@ def __init__(self, logger, args, data_path, dataset_type): if self.args.passage_clustering: # only need to load when using passage clustering self.clustered_passages_path = "data/clustering_results/AmbigQA_" postfix = ["top", self.args.num_top_passages, "passages", - self.data_type, "is_training", self.is_training, "is_contrastive", self.args.is_contrastive, "rank_threshold", self.args.rank_threshold] + self.data_type, "is_training", self.is_training, "is_contrastive", self.args.is_contrastive] postfix = [str(x) for x in postfix] postfix = "_".join(postfix) if self.args.debugTrain: @@ -166,10 +168,9 @@ def __init__(self, logger, args, data_path, dataset_type): args.data_folder_path, f"{data_file_n}{dataset_type_for_file_accessing}.json") self.top_k_passages = args.num_top_passages self.metric = "EM" if self.dataset_name == "nq" else "F1" - self.sep_token = self.tokenizer.sep_token + self.sep_token = "" # self.sep_token = " " + self.sep_token + " " - self.logging_prefix = None def __len__(self): @@ -272,10 +273,10 @@ def load_dataset(self, tokenizer, do_return=False): prepend_question_token = True if self.args.augment_k_times == 1: postfix = [postfix, "max_input_length", self.max_input_length, "top", - self.top_k_passages, "rank_threshold", self.args.rank_threshold, self.answer_type, "is_training", self.is_training] # TODO: can be written more elegantly by using dictionary + self.top_k_passages, self.answer_type, "is_training", self.is_training] # TODO: can be written more elegantly by using dictionary else: postfix = [postfix, "max_input_length", self.max_input_length, "top", - self.top_k_passages, "rank_threshold", self.args.rank_threshold ,self.answer_type, "answers", self.args.augment_k_times, "augmentation", "is_training", self.is_training] + self.top_k_passages, self.answer_type, "answers", self.args.augment_k_times, "augmentation", "is_training", self.is_training] postfix = [str(x) for x in postfix] postfix = "_".join(postfix) if self.debugTrain: @@ -455,7 +456,7 @@ def remove_confirmation_prompt(file_name): # flatten answer list answers, metadata = self.flatten( answers, self.dataset_name == "ambig") - + if self.args.do_lowercase: questions = [question.lower() for question in questions] answers = [answer.lower() for answer in answers] @@ -528,7 +529,7 @@ def remove_confirmation_prompt(file_name): self.top_k_passages, self.tokenizer, self.answer_type, self.is_training, True, 
                                   self.args, self.logging_prefix, self.logger,
-                                  self.args.rank_threshold, clustered_passages_path)
+                                  clustered_passages_path)
             qp = qpa_dict["qp"]
             self.question_ids = qpa_dict["question_ids"]
             answers = qpa_dict["answers"]
@@ -930,7 +931,23 @@ class topKPassasages():
     This class serves as a modular way of retrieving top k passages of a question for reader
     """

-    def __init__(self, k_cluster, passages_path, rank_path, data_path, passage_embedding = None, evaluate=False):
+    def __init__(self, k_cluster, passages_path, rank_path, data_path,
+                 passage_embedding = None, evaluate=False):
+        """
+        Load passages and their embeddings.
+
+        self.passages is a list of passage dicts with 'title' and 'text' keys.
+
+        Args:
+            k_cluster (int): the number of clusters for clustering.
+            passages_path (str): the path to load raw wiki passage data.
+            rank_path (str): the path to load (DPR) ranking results.
+            data_path (str): the path to load QA data.
+            passage_embedding (str, optional): the path to load (DPR)
+                passage embeddings.
+            evaluate (bool, optional): True for evaluating top k passage MACRO
+                recall. False otherwise. Defaults to False.
+        """
         # load wiki passages and store in dictionary
         # a list of lists of passages ids [ [ 3,5, ], ... ]
@@ -967,61 +984,63 @@ def set_passage_embeddings(self, passage_embeddings):
         self.passage_embeddings = passage_embeddings

     def get_clustered_passages(self, q_id: int,
-                               clustering_method: str="kmeans",
-                               rank_threshold=100):
+                               clustering_method: str="kmeans"):
         """
         Get a list of clustered passages given a question id.
+        It first retrieves the top k passage embeddings for the question id,
+        then runs the clustering method on those embeddings, and finally
+        adds passages based on their indices in self.passages.

         Args:
             q_id (int): zero-based question id.
             clustering_method: kmeans, spectral, x-means
-            rank_threshold (int): hard rank threshold in the range 0-100.
-                Probably not useful now that a reranker is being added;
-                keep it at 100 so that all passages are kept.

         Returns:
-            List[List[numpy.array]]: a list of lists of passages.
-                the size is (num_cluster, num_passages_in_the_cluster).
-                clusters are ordered by closeness to the question.
-                passages within one cluster are also ordered by
-                closeness to the question.
+            Dict[str, List[numpy.array]]: key is the cluster label in str.
+                value is a list of clustered passages.
+                clusters are ordered by closeness to the question.
+                passages within one cluster are also ordered by
+                closeness to the question.
         """
         passage_embeddings = self.get_passage_embeddings(
             q_id)
-        kmeans_1 = KMeans(n_clusters=self.k_cluster,
-                          random_state=0).fit(passage_embeddings)
-        # compute stat of clusters
-        cluster_pts_count = dict()
-        for j in range(self.k_cluster):
-            cluster_pts_count[j] = sum(
-                kmeans_1.labels_ == j)
-
-        passages = []
+        if clustering_method == "kmeans":
+            clustering_obj = KMeans(n_clusters=self.k_cluster,
+                                    random_state=0).fit(passage_embeddings)
+        elif clustering_method == "spectral":
+            clustering_obj = SpectralClustering(n_clusters=self.k_cluster).fit(passage_embeddings)
+        elif clustering_method == "xmeans":
+            # X-Means starts its analysis from a few initial centers and
+            # splits them as needed, up to kmax clusters.
+            initial_centers = kmeans_plusplus_initializer(
+                passage_embeddings, 2).initialize()
+            clustering_obj = xmeans(passage_embeddings, initial_centers,
+                                    kmax=self.k_cluster)
+            clustering_obj.process()
+        else:
+            raise NotImplementedError(f"clustering_method {clustering_method} "
Please input" + " implemented ones such as kmeans, " + "spectral or xmeans.") passage_ids = self.ranks[q_id] - - num_cluster_passages_l = [] - # add top 5 passages' ids from each cluster - for cluster_label in filtered_clusters: - cluster_passages = [] - for j in range(len(kmeans_1.labels_)): # iterate all cluster labels and keep the order - - if kmeans_1.labels_[j] == cluster_label: - passage_id = passage_ids[j] - cluster_passages.append( - self.passages[passage_id]) - num_cluster_passages_l.append(len(cluster_passages) ) - assert len(cluster_passages) > 0, "each cluster should have more than one passages and less than five passages" - passages.append(cluster_passages) - assert len(passages) >= 1, "There should be more than one cluster" - - return passages + # clustering results might have more or less clusters + # we need to be adaptive to the various number of clusters + passage_clusters = defaultdict(list) + for c_label, p_id in zip(clustering_obj.labels_, passage_ids): + passage = self.passages[p_id] + # cluster ranking is inherantly in c_label + # passage ranking (w.r.t question) is also ordered in list + passage_clusters[str(c_label)].append(passage) + return passage_clusters def get_passages(self, i, k): """ 0-indexed based retrieval to get top k passages. Note that rank, answers and passages are lists with the same length - :param i: index - :return: a list of passage dictionary {title:str, text:str} + Args: + i (int): index of question id. + k (int): the number of passages to be retrieved. + + Returns: + List: a list of passage dictionary {title:str, text:str} """ top_k_ranked_passage_ids = self.ranks[i][:k] # get rank prediction @@ -1030,7 +1049,6 @@ def get_passages(self, i, k): def get_passage_embeddings(self, i): assert self.passage_embeddings is not None, "passage embedding is not loaded" passage_embeddings = [] - for passage_id in self.ranks[i]: try: passage_embeddings.append(self.passage_embeddings[passage_id]) diff --git a/span_utils.py b/span_utils.py index 9c386c5..004d26a 100644 --- a/span_utils.py +++ b/span_utils.py @@ -1,4 +1,5 @@ +from json.encoder import py_encode_basestring_ascii import pdb from re import A from IPython import embed @@ -101,7 +102,6 @@ def eval(predictions, data, eval_fn, normaliza_fn, for (prediction, dp) in zip(predictions, data): # there are many concatenation of answers and they are all correct # we append the one with the highest score - eval_scores.append(eval_fn( prediction, dp["answer"])) else: @@ -113,41 +113,42 @@ def eval(predictions, data, eval_fn, normaliza_fn, -def is_answer_set_in_passsages(answer_md, p_str, answers, remove_answer = False): - """check if a passage contain any answer in the answer set +def check_answer_presence(answer_md, p_str, answers): + """ + Check if a passage contains any answer in the answer set. Args: - answer_md ([type]): [description] - p_str ([type]): [description] - answers ([type]): [description] - remove_answer (bool): remove answer from matadata so as to + answer_md (Tuple[int, int]): answer metadata. (start_idx, end_idx) + p_str (str): passage string which is where the answer is searched. + answers (List[str]): a list of all flattened answers + remove_answer (bool): True for removing answer from metadata so as to + not checking repetitive answers. Defaults to False. Returns: - [type]: [description] + Union[str, None]: answer text found in the input passage. + If not found, it will return None. 
""" for cur_md_for_qa_pair in answer_md: for start, end in cur_md_for_qa_pair: answer_for_qa_pair = answers[start:end] for cur_a_str in answer_for_qa_pair: if is_answer_in_passages(cur_a_str, p_str): - if remove_answer: - answer_md.remove(cur_md_for_qa_pair) - return True, answer_md - else: - return True - if remove_answer: - return False, answer_md - else: - return False + return cur_a_str + return None def is_answer_in_passages(answer_str, p_str): - """check the existance of answer in passages by comparing string + """ + check the existence of an answer in passages by comparing string Args: - idx ([type]): [description] + answer_str(str): answering string. + p_str(str): passage string. + + Returns: + bool: True for answer found in passage string. False otherwise. """ return answer_str.lower() in p_str.lower() @@ -202,8 +203,7 @@ def is_answer_a_date_or_infreq_digit(answer_str): def preprocess_qpa(questions, question_ids, passages, answers, metadata, data, num_top_passages, tokenizer, answer_type, is_training, is_ambig, args, - logging_prefix, logger, - rank_threshold=None, clustered_passages_path=None) -> dict(): + logging_prefix, logger,clustered_passages_path=None) -> dict(): """ Concatenate question and passages, answers. First tokenize and then concatenate them with using tokenizer sep id. @@ -213,6 +213,7 @@ def preprocess_qpa(questions, question_ids, passages, answers, metadata, data, list length is the number of question. question_ids (List[str]): a list of question ids. list length is the number of question. + passages: topKpassage instance. answers (List[str]): a list of flattened answers. list length is the number of flattened answers. metadata (List[List[tuple]]): a list of lists of tuples. @@ -237,7 +238,6 @@ def preprocess_qpa(questions, question_ids, passages, answers, metadata, data, logging_prefix(str): prefix logging dataset types. for example, '[TEST DATA]\t'. logger: logger. - rank_threshold: might be deleted later as we are adding reranker. clustered_passages_path: clustered passage tokens pickle data. @@ -251,11 +251,6 @@ def preprocess_qpa(questions, question_ids, passages, answers, metadata, data, qpa_dict["joined_answers_l"] = joined_answers_l qpa_dict["data"] = data """ - - import pdb; pdb.set_trace() - print('check preprocess_qpa args') - - # TODO: test ambig bart first # TODO: dump dictionary @@ -282,20 +277,19 @@ def preprocess_qpa(questions, question_ids, passages, answers, metadata, data, questions_n_passages = [] if args.passage_clustering: assert is_ambig == True, "PC mode: must be for ambig or multi-answer datasets" - assert rank_threshold is not None, "PC mode: there should be a PC rank threhold." 
         assert clustered_passages_path is not None, "PC mode: there should be a clustered_passages_path"

     sep_token = tokenizer.sep_token
     eos_token = tokenizer.eos_token
-
+
     sep_token = " " + sep_token + " "

     question_metadata = []
     joined_answers_l = []
     empty_answer_str = " "
-
-    if not is_ambig:  # nq dataset
+    if not is_ambig:
+        # nq dataset (no pc mode)
         for i in tqdm(range(len(questions))):
             # mark the beginning of passages
             # end of question and start of passages
             questions[i] += " "
             # add passage one by one
             for p in passages.get_passages(i, num_top_passages):
                 # format: [CLS] question [SEP] title 1 [SEP] passages
                 questions[i] += sep_token + \
                     p["title"] + sep_token + p["text"]
             # mark the end of passages
             questions[i] += " "
     else:
+        # ambig -> pc
         if args.passage_clustering:
-            clustered_raw_data_path = ""
+            clustered_raw_data_path = "/home/murphy/Downloads/2021Winter/bart-closed-book-qa-4.3.3/data/clustering_text_results/test.json"
+            qpca = {}  # ordered question, passage clusters and answers
-            if not os.path.exists(clustered_raw_data_path):
+            if not os.path.exists(clustered_raw_data_path) or True:
                 # clustering
                 # store all clustering data in tokens to a json file
-                for (i, cur_md) in enumerate(metadata):
-                    clusters_passages = passages.get_clustered_passages(
-                        i, rank_threshold=100)  # 2-d list
-                # load json file.
+                for (q_id, cur_md) in enumerate(metadata):
+                    passage_clusters = passages.get_clustered_passages(
+                        q_id)  # dict: cluster label -> passages
+                    qpca[str(q_id)] = {}
+                    qpca[str(q_id)]["question"] = questions[q_id]
+                    qpca[str(q_id)]["answers"] = answers[cur_md[0]:cur_md[1]]
+                    qpca[str(q_id)]["passage_clusters"] = []
+                    # add answers to passage clusters
+                    for c_label in sorted(passage_clusters.keys()):
+                        p_clusters = passage_clusters[c_label]
+                        # p_clusters is a list of passage dictionaries
+                        # dict format: [str, str]
+                        # keys are title, text
+                        # we want to add an answer data field
+                        for (j, p) in enumerate(p_clusters):
+                            found_answer = check_answer_presence(cur_md, p["text"], answers)
+                            if found_answer:
+                                p_clusters[j]['answer'] = found_answer
+                    qpca[str(q_id)]["passage_clusters"].append(passage_clusters)
+                with open(clustered_raw_data_path, "w") as f:
+                    json.dump(qpca, f)
+
+                import pdb; pdb.set_trace()
+                print('current checkpoint for the current work: save qpca')

             logger.info(
                 logging_prefix + "Concatenating clustering results...")
             assert len(question_ids) == len(

             # iterate answer metadata
             for (i, cur_md) in enumerate(metadata):
-                clusters_passages, num_cluster_for_question_i, num_passages_for_question_i = passages.get_clustered_passages(
-                    i, rank_threshold)  # 2-d list
-                num_clusters += num_cluster_for_question_i
-                num_passages += num_passages_for_question_i
+                passage_clusters = passages.get_clustered_passages(
+                    i)  # dict: cluster label -> passages
                 # make questions[i] a list, put index 0 a concatenation of all passages
                 # we want all passages because we want a joined_answer list for evaluation
                 # Problem: they are not constrained by max_input_length correctly
                 # and are not the actual input
-                # add
-                if args.is_contrastive:
-                    questions_n_passages.append(dict())
-                    qp_d = questions_n_passages[-1]
-                    qp_d["pos"] = []
-                    qp_d["neg"] = []
-                    # 1. needs truncation here? probably not, we can directly check.
-                    # 2. check presence of answer.
-                    for p_cluster in clusters_passages:  # it's ordered
-                        # reset qp concatenation
-                        cluster_qp_concatenation = questions[i]
-                        pos_cluster_qp_concatenation = cluster_qp_concatenation + " "
-                        neg_cluster_qp_concatenation = cluster_qp_concatenation + " "
-                        pos_start = True
-                        neg_start = True
-                        for p in p_cluster:
-                            # format: [CLS] question [SEP] title 1 [SEP] passages
-                            if is_answer_set_in_passsages(cur_md, p["text"], answers):
-                                if pos_start:
-                                    pos_cluster_qp_concatenation += p["title"] + \
-                                        sep_token + \
-                                        p["text"]
-                                    pos_start = False
-                                else:
-                                    pos_cluster_qp_concatenation += sep_token + \
-                                        p["title"] + \
-                                        sep_token + \
-                                        p["text"]
-                            else:
-                                if neg_start:
-                                    neg_cluster_qp_concatenation += p["title"] + \
-                                        sep_token + \
-                                        p["text"]
-                                    neg_start = False
-                                else:
-                                    neg_cluster_qp_concatenation += sep_token + \
-                                        p["title"] + \
-                                        sep_token + \
-                                        p["text"]
-                        pos_cluster_qp_concatenation += " "
-                        neg_cluster_qp_concatenation += " "
-                        qp_d["pos"].append(
-                            pos_cluster_qp_concatenation)
-                        qp_d["neg"].append(
-                            neg_cluster_qp_concatenation)
-                else:
-                    questions_n_passages.append([])
-                    qp_l = questions_n_passages[-1]
-
-                    updated_md = copy.deepcopy(cur_md)
-
-                    # i is question index and j is cluster index
-                    for (j, p_cluster) in enumerate(clusters_passages):  # it's ordered
-                        # reset qp concatenation
-                        cluster_qp_concatenation = questions[i]
-                        cluster_qp_concatenation += " "
-                        title_distribution_d[j] += len(set([p["title"] for p in p_cluster]))
-                        start = True
-                        for p in p_cluster:
-
-                            # updated md
-                            found_answer, updated_md = is_answer_set_in_passsages(
-                                updated_md, p["text"], answers, True)
-                            if found_answer:
-                                exclusive_answer_distribution_d[j] += 1
-
-                            found_answer = is_answer_set_in_passsages(
-                                cur_md, p["text"], answers)
-                            if found_answer:
-                                answer_distribution_d[j] += 1
-
-                            # format: [CLS] question [SEP] title 1 [SEP] passages
-                            if start:
-                                cluster_qp_concatenation += p["title"] + \
-                                    sep_token + \
-                                    p["text"]
-                            else:
-                                cluster_qp_concatenation += sep_token + \
-                                    p["title"] + \
-                                    sep_token + \
-                                    p["text"]
-                            start = False
-                        cluster_qp_concatenation += " "
-                        qp_l.append(
-                            cluster_qp_concatenation)
+                questions_n_passages.append([])
+                qp_l = questions_n_passages[-1]
+                updated_md = copy.deepcopy(cur_md)
+
+                import pdb; pdb.set_trace()
+                print('add every chunk of 5 passages and answers')
+
+                # TODO: add every chunk of 5 passages and answers
+
+                # i is question index and j is cluster index
+                for (j, c_label) in enumerate(sorted(passage_clusters.keys())):  # it's ordered
+                    p_cluster = passage_clusters[c_label]
+                    # reset qp concatenation
+                    cluster_qp_concatenation = questions[i]
+                    cluster_qp_concatenation += sep_token
+                    title_distribution_d[j] += len(set([p["title"] for p in p_cluster]))
+                    start = True
+                    for p in p_cluster:
+                        found_answer = check_answer_presence(
+                            cur_md, p["text"], answers)
+                        # updated md
+                        # found_answer, updated_md = check_answer_presence(
+                        #     updated_md, p["text"], answers, True)
+                        # if found_answer:
+                        #     exclusive_answer_distribution_d[j] += 1
+
+                        # found_answer = check_answer_presence(
+                        #     cur_md, p["text"], answers)
+                        # if found_answer:
+                        #     answer_distribution_d[j] += 1
+
+                        # format: [CLS] question [SEP] title 1 [SEP] passages
+                        if start:
+                            cluster_qp_concatenation += p["title"] + \
+                                sep_token + \
+                                p["text"]
+                        else:
+                            cluster_qp_concatenation += sep_token + \
+                                p["title"] + \
+                                sep_token + \
+                                p["text"]
+                        start = False
+                    cluster_qp_concatenation += eos_token
+                    qp_l.append(
+                        cluster_qp_concatenation)

     for j in range(num_clusters):
         title_distribution_d[j] /= len(metadata)
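
A minimal usage sketch (not part of the patch series) of how the clustering
text results saved by PATCH 2/2 might be consumed. The file path below is a
placeholder, and the qpca layout (question / answers / passage_clusters) is
inferred from the saving code above, so treat field names as assumptions:

    import json

    # get_clustered_passages() now returns Dict[str, List[dict]]:
    # cluster label -> passages ({"title": ..., "text": ...}), with clusters
    # and passages within a cluster ordered by closeness to the question.
    with open("data/clustering_text_results/test.json") as f:
        qpca = json.load(f)

    entry = qpca["0"]  # zero-based question id, stored as a string key
    print(entry["question"], entry["answers"])
    for clusters in entry["passage_clusters"]:
        for c_label in sorted(clusters):
            for p in clusters[c_label]:
                # the 'answer' field is only present when an answer string
                # was found in the passage text during preprocessing
                print(c_label, p["title"], p.get("answer"))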