From 5cd01f9dac39e5b0393e8e621c6a5ec0145de446 Mon Sep 17 00:00:00 2001 From: Andrew Parry Date: Fri, 13 Sep 2024 14:35:40 +0100 Subject: [PATCH] rename --- rankers/__init__.py | 33 +++ rankers/_util.py | 140 ++++++++++++ rankers/datasets/__init__.py | 2 + rankers/datasets/corpus.py | 43 ++++ rankers/datasets/dataset.py | 122 ++++++++++ rankers/datasets/loader.py | 201 ++++++++++++++++ rankers/modelling/__init__.py | 2 + rankers/modelling/cat.py | 185 +++++++++++++++ rankers/modelling/dot.py | 394 ++++++++++++++++++++++++++++++++ rankers/modelling/seq2seq.py | 162 +++++++++++++ rankers/train/__init__.py | 0 rankers/train/arguments.py | 10 + rankers/train/loss/__init__.py | 166 ++++++++++++++ rankers/train/loss/listwise.py | 182 +++++++++++++++ rankers/train/loss/pairwise.py | 69 ++++++ rankers/train/loss/pointwise.py | 16 ++ rankers/train/trainer.py | 169 ++++++++++++++ setup.py | 2 +- 18 files changed, 1897 insertions(+), 1 deletion(-) create mode 100644 rankers/__init__.py create mode 100644 rankers/_util.py create mode 100644 rankers/datasets/__init__.py create mode 100644 rankers/datasets/corpus.py create mode 100644 rankers/datasets/dataset.py create mode 100644 rankers/datasets/loader.py create mode 100644 rankers/modelling/__init__.py create mode 100644 rankers/modelling/cat.py create mode 100644 rankers/modelling/dot.py create mode 100644 rankers/modelling/seq2seq.py create mode 100644 rankers/train/__init__.py create mode 100644 rankers/train/arguments.py create mode 100644 rankers/train/loss/__init__.py create mode 100644 rankers/train/loss/listwise.py create mode 100644 rankers/train/loss/pairwise.py create mode 100644 rankers/train/loss/pointwise.py create mode 100644 rankers/train/trainer.py diff --git a/rankers/__init__.py b/rankers/__init__.py new file mode 100644 index 0000000..9ccb6c7 --- /dev/null +++ b/rankers/__init__.py @@ -0,0 +1,33 @@ +__version__ = "0.0.2" + +from .train import loss as loss +from .train.trainer import ContrastTrainer +from .train.arguments import ContrastArguments +from .datasets import * +from .modelling import * + +def is_torch_available(): + try: + import torch + return True + except ImportError: + return False + +def is_flax_available(): + try: + import flax + return True + except ImportError: + return False + +def seed_everything(seed=42): + import random + import numpy as np + import torch + + random.seed(seed) + np.random.seed(seed) + if is_torch_available(): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True \ No newline at end of file diff --git a/rankers/_util.py b/rankers/_util.py new file mode 100644 index 0000000..0de1c3b --- /dev/null +++ b/rankers/_util.py @@ -0,0 +1,140 @@ +from collections import defaultdict +import logging +from typing import Optional +import pandas as pd +import pyterrier as pt +import ir_datasets as irds + +logger = logging.getLogger(__name__) + +def _pivot(frame, negatives = None): + new = [] + for row in frame.itertuples(): + new.append( + { + "qid": row.query_id, + "docno": row.doc_id_a, + "pos": 1 + }) + if negatives: + for doc in negatives[row.query_id]: + new.append( + { + "qid": row.query_id, + "docno": doc + }) + else: + new.append( + { + "qid": row.query_id, + "docno": row.doc_id_b + }) + return pd.DataFrame.from_records(new) + +def _qrel_pivot(frame): + new = [] + for row in frame.itertuples(): + new.append( + { + "qid": row.query_id, + "docno": row.doc_id, + "score": row.relevance + }) + return pd.DataFrame.from_records(new) + +def get_teacher_scores(model : pt.Transformer, + corpus : Optional[pd.DataFrame] = None, + ir_dataset : Optional[str] = None, + subset : Optional[int] = None, + negatives : Optional[dict] = None, + seed : int = 42): + assert corpus is not None or ir_dataset is not None, "Either corpus or ir_dataset must be provided" + if corpus: + for column in ["query", "text"]: assert column in corpus.columns, f"{column} not found in corpus" + if ir_dataset: + dataset = irds.load(ir_dataset) + docs = pd.DataFrame(dataset.docs_iter()).set_index("doc_id")["text"].to_dict() + queries = pd.DataFrame(dataset.queries_iter()).set_index("query_id")["text"].to_dict() + corpus = pd.DataFrame(dataset.docpairs_iter()) + if negatives: + corpus = corpus[['query_id', 'doc_id_a']] + corpus = _pivot(corpus, negatives) + corpus['text'] = corpus['docno'].map(docs) + corpus['query'] = corpus['qid'].map(queries) + if subset: + corpus = corpus.sample(n=subset, random_state=seed) + + logger.warning("Retrieving scores, this may take a while...") + scores = model.transform(corpus) + lookup = defaultdict(dict) + for qid, group in scores.groupby('qid'): + for docno, score in zip(group['docno'], group['score']): + lookup[qid][docno] = score + return lookup + +def initialise_irds_eval(dataset : irds.Dataset): + qrels = pd.DataFrame(dataset.qrels_iter()) + return _qrel_pivot(qrels) + +def load_json(file: str): + import json + import gzip + """ + Load a JSON or JSONL (optionally compressed with gzip) file. + + Parameters: + file (str): The path to the file to load. + + Returns: + dict or list: The loaded JSON content. Returns a list for JSONL files, + and a dict for JSON files. + + Raises: + ValueError: If the file extension is not recognized. + """ + if file.endswith(".json"): + with open(file, 'r') as f: + return json.load(f) + elif file.endswith(".jsonl"): + with open(file, 'r') as f: + return [json.loads(line) for line in f] + elif file.endswith(".json.gz"): + with gzip.open(file, 'rt') as f: + return json.load(f) + elif file.endswith(".jsonl.gz"): + with gzip.open(file, 'rt') as f: + return [json.loads(line) for line in f] + else: + raise ValueError(f"Unknown file type for {file}") + +def save_json(data, file: str): + import json + import gzip + """ + Save data to a JSON or JSONL file (optionally compressed with gzip). + + Parameters: + data (dict or list): The data to save. Must be a list for JSONL files. + file (str): The path to the file to save. + + Raises: + ValueError: If the file extension is not recognized. + """ + if file.endswith(".json"): + with open(file, 'w') as f: + json.dump(data, f) + elif file.endswith(".jsonl"): + with open(file, 'w') as f: + for item in data: + f.write(json.dumps(item) + '\n') + elif file.endswith(".json.gz"): + with gzip.open(file, 'wt') as f: + json.dump(data, f) + elif file.endswith(".jsonl.gz"): + with gzip.open(file, 'wt') as f: + for item in data: + f.write(json.dumps(item) + '\n') + else: + raise ValueError(f"Unknown file type for {file}") + + \ No newline at end of file diff --git a/rankers/datasets/__init__.py b/rankers/datasets/__init__.py new file mode 100644 index 0000000..6ca2d9c --- /dev/null +++ b/rankers/datasets/__init__.py @@ -0,0 +1,2 @@ +from .dataset import * +from .loader import * \ No newline at end of file diff --git a/rankers/datasets/corpus.py b/rankers/datasets/corpus.py new file mode 100644 index 0000000..f0e1b8a --- /dev/null +++ b/rankers/datasets/corpus.py @@ -0,0 +1,43 @@ +import pandas as pd + +class Corpus: + def __init__(self, + documents : dict = None, + queries : dict = None, + qrels : pd.DataFrame = None + ) -> None: + self.documents = documents + self.queries = queries + self.qrels = qrels + + self.__post_init__() + + def __post_init__(self): + if self.qrels: + for column in 'query_id', 'doc_id', 'relevance': + if column not in self.qrels.columns: raise ValueError(f"Format not recognised, Column '{column}' not found in qrels dataframe") + + self.qrels = self.qrels[['query_id', 'doc_id', 'relevance']] + + def has_documents(self): + return self.documents is not None + + def has_queries(self): + return self.queries is not None + + def has_qrels(self): + return self.qrels is not None + + def queries_iter(self): + for queryid, text in self.queries.items(): + yield {"query_id" : queryid, "text" : text} + + def docs_iter(self): + for docid, text in self.documents.items(): + yield {"doc_id" : docid, "text" : text} + + def qrels_iter(self): + for queryid, docid, relevance in self.qrels.itertuples(index=False): + yield {"query_id" : queryid, "doc_id" : docid, "relevance" : relevance} + + \ No newline at end of file diff --git a/rankers/datasets/dataset.py b/rankers/datasets/dataset.py new file mode 100644 index 0000000..3ca8e43 --- /dev/null +++ b/rankers/datasets/dataset.py @@ -0,0 +1,122 @@ +import random +from torch.utils.data import Dataset +import pandas as pd +import torch +from typing import Optional, Union +import ir_datasets as irds +from .._util import load_json +from .corpus import Corpus + +from contrast._util import initialise_irds_eval + +class TrainingDataset(Dataset): + def __init__(self, + training_data : pd.DataFrame, + corpus : Union[Corpus, irds.Dataset], + teacher_file : Optional[str] = None, + group_size : int = 2, + listwise : bool = False, + ) -> None: + super().__init__() + self.training_data = training_data + self.corpus = corpus + self.teacher_file = teacher_file + self.group_size = group_size + self.listwise = listwise + + self.__post_init__() + + + def __post_init__(self): + + for column in 'query_id', 'doc_id_a', 'doc_id_b': + if column not in self.training_data.columns: raise ValueError(f"Format not recognised, Column '{column}' not found in triples dataframe") + self.docs = pd.DataFrame(self.corpus.docs_iter()).set_index("doc_id")["text"].to_dict() + self.queries = pd.DataFrame(self.corpus.queries_iter()).set_index("query_id")["text"].to_dict() + + if self.teacher_file: self.teacher = load_json(self.teacher_file) + + self.labels = True if self.teacher_file else False + self.multi_negatives = True if (type(self.training_data['doc_id_b'].iloc[0]) == list) else False + + if not self.listwise: + if self.group_size > 2 and self.multi_negatives: + self.training_data['doc_id_b'] = self.training_data['doc_id_b'].map(lambda x: random.sample(x, self.group_size-1)) + elif self.group_size == 2 and self.multi_negatives: + self.training_data['doc_id_b'] = self.training_data['doc_id_b'].map(lambda x: random.choice(x) if len(x) > 1 else x[0]) + self.multi_negatives = False + elif self.group_size > 2 and not self.multi_negatives: + raise ValueError("Group size > 2 not supported for single negative samples") + + @classmethod + def from_irds(cls, + ir_dataset : str, + teacher_file : Optional[str] = None, + group_size : int = 2, + collate_fn : Optional[callable] = lambda x : pd.DataFrame(x.docpairs_iter()) + ) -> 'TrainingDataset': + dataset = irds.load(ir_dataset) + assert dataset.has_docpairs(), "Dataset does not have docpairs, check you are not using a test collection" + training_data = collate_fn(dataset) + return cls(training_data, dataset, teacher_file, group_size) + + def __len__(self): + return len(self.training_data) + + def _teacher(self, qid, doc_id, positive=False): + assert self.labels, "No teacher file provided" + try: return self.teacher[str(qid)][str(doc_id)] + except KeyError: return 0. + + def __getitem__(self, idx): + item = self.training_data.iloc[idx] + qid, doc_id_a, doc_id_b = item['query_id'], item['doc_id_a'], item['doc_id_b'] + query = self.queries[str(qid)] + texts = [self.docs[str(doc_id_a)]] if not self.listwise else [] + + if self.multi_negatives: texts.extend([self.docs[str(doc)] for doc in doc_id_b]) + else: texts.append(self.docs[str(doc_id_b)]) + + if self.labels: + scores = [self._teacher(str(qid), str(doc_id_a), positive=True)] if not self.listwise else [] + if self.multi_negatives: scores.extend([self._teacher(qid, str(doc)) for doc in doc_id_b]) + else: scores.append(self._teacher(str(qid), str(doc_id_b))) + return (query, texts, scores) + else: + return (query, texts) + +class EvaluationDataset(Dataset): + def __init__(self, + evaluation_data : Union[pd.DataFrame, str], + corpus : Union[Corpus, irds.Dataset] + ) -> None: + super().__init__() + self.evaluation_data = evaluation_data + self.corpus = corpus + + self.__post_init__() + + def __post_init__(self): + if type(self.evaluation_data) == str: + import pyterrier as pt + self.evaluation_data = pt.io.read_results(self.evaluation_data) + else: + for column in 'qid', 'docno', 'score': + if column not in self.evaluation_data.columns: raise ValueError(f"Format not recognised, Column '{column}' not found in dataframe") + self.docs = pd.DataFrame(self.corpus.docs_iter()).set_index("doc_id")["text"].to_dict() + self.queries = pd.DataFrame(self.corpus.queries_iter()).set_index("query_id")["text"].to_dict() + self.qrels = pd.DataFrame(self.corpus.qrels_iter()) + + self.evaluation_data['text'] = self.evaluation_data['docno'].map(self.docs) + self.evaluation_data['query'] = self.evaluation_data['qid'].map(self.queries) + + @classmethod + def from_irds(cls, + ir_dataset : str, + ) -> 'EvaluationDataset': + dataset = irds.load(ir_dataset) + evaluation_data = initialise_irds_eval(dataset) + return cls(evaluation_data, dataset) + + def __len__(self): + return len(self.evaluation_data.qid.unique()) \ No newline at end of file diff --git a/rankers/datasets/loader.py b/rankers/datasets/loader.py new file mode 100644 index 0000000..f0107c4 --- /dev/null +++ b/rankers/datasets/loader.py @@ -0,0 +1,201 @@ +from typing import Any +import torch + +class DotDataCollator: + def __init__(self, + tokenizer, + special_mask=False, + q_max_length=30, + d_max_length=200, + ) -> None: + self.tokenizer = tokenizer + self.q_max_length = q_max_length + self.d_max_length = d_max_length + self.special_mask = special_mask + + def __call__(self, batch) -> dict: + batch_queries = [] + batch_docs = [] + batch_scores = [] + for (q, dx, *args) in batch: + batch_queries.append(q) + batch_docs.extend(dx) + if len(args) == 0: + continue + batch_scores.extend(args[0]) + + tokenized_queries = self.tokenizer( + batch_queries, + padding=True, + truncation=True, + max_length=self.q_max_length, + return_tensors="pt", + return_special_tokens_mask=self.special_mask, + ) + tokenized_docs = self.tokenizer( + batch_docs, + padding=True, + truncation=True, + max_length=self.d_max_length, + return_tensors="pt", + return_special_tokens_mask=self.special_mask + ) + + return { + "queries": dict(tokenized_queries), + "docs_batch": dict(tokenized_docs), + "labels": torch.tensor(batch_scores) if len(batch_scores) > 0 else None, + } + +class CatDataCollator: + def __init__(self, + tokenizer, + q_max_length=30, + d_max_length=200, + ) -> None: + self.tokenizer = tokenizer + self.q_max_length = q_max_length + self.d_max_length = d_max_length + + def __call__(self, batch) -> dict: + batch_queries = [] + batch_docs = [] + batch_scores = [] + for (q, dx, *args) in batch: + batch_queries.extend([q]*len(dx)) + batch_docs.extend(dx) + if len(args) == 0: + continue + batch_scores.extend(args[0]) + + tokenized_sequences = self.tokenizer( + batch_queries, + batch_docs, + padding=True, + truncation='only_second', + max_length=self.q_max_length + self.d_max_length, + return_tensors="pt", + ) + return { + "sequences": dict(tokenized_sequences), + "labels": torch.tensor(batch_scores) if len(batch_scores) > 0 else None, + } + +def _make_pos_pairs(texts) -> list: + output = [] + pos = texts[0] + for i in range(1, len(texts)): + output.append([pos, texts[i]]) + return output + +class PairDataCollator: + def __init__(self, + tokenizer, + max_length=512) -> None: + self.tokenizer = tokenizer + self.max_length = max_length + + def __call__(self, batch) -> dict: + batch_queries = [] + batch_docs = [] + batch_scores = [] + for (q, dx, *args) in batch: + batch_queries.append(q) + batch_document_pairs = _make_pos_pairs(dx) + batch_docs.append(batch_document_pairs) + if len(args) == 0: + continue + batch_score_pairs = _make_pos_pairs(args[0]) + batch_scores.extend(batch_score_pairs) + + # tokenize each pair with each query + sequences = [f"[CLS] {query} [SEP] {pair[0]} [SEP] {pair[1]}" for query, pairs in zip(batch_queries, batch_docs) for pair in pairs] + + tokenized_sequences = self.tokenizer( + sequences, + padding=True, + truncation=True, + max_length=self.max_length, + return_tensors="pt", + add_special_tokens=True, + ) + + return { + "sequences": dict(tokenized_sequences), + "labels": torch.tensor(batch_scores).squeeze() if len(batch_scores) > 0 else None, + } + +class PromptDataCollator: + def __init__(self, + tokenizer, + prompt : Any, + max_length=512, + ) -> None: + self.tokenizer = tokenizer + self.prompt = prompt + self.max_length = max_length + + def __call__(self, batch) -> dict: + batch_queries = [] + batch_docs = [] + batch_scores = [] + for (q, dx, *args) in batch: + batch_queries.extend([q]*len(dx)) + batch_docs.extend(dx) + if len(args) == 0: + continue + batch_scores.extend(args[0]) + + sequences = [self.prompt(query=q, doc=d) for q, d in zip(batch_queries, batch_docs)] + + tokenized_sequences = self.tokenizer( + sequences, + padding=True, + truncation=True, + max_length=self.max_length, + return_tensors="pt", + add_special_tokens=True, + ) + return { + "sequences": dict(tokenized_sequences), + "labels": torch.tensor(batch_scores) if len(batch_scores) > 0 else None, + } + +class PairPromptDataCollator: + def __init__(self, + tokenizer, + prompt : Any, + max_length=512) -> None: + self.tokenizer = tokenizer + self.max_length = max_length + self.prompt = prompt + + def __call__(self, batch) -> dict: + batch_queries = [] + batch_docs = [] + batch_scores = [] + for (q, dx, *args) in batch: + batch_queries.append(q) + batch_document_pairs = _make_pos_pairs(dx) + batch_docs.append(batch_document_pairs) + if len(args) == 0: + continue + batch_score_pairs = _make_pos_pairs(args[0]) + batch_scores.extend(batch_score_pairs) + + # tokenize each pair with each query + sequences = [self.prompt(query=query, document_1=pair[0], document_2=pair[1]) for query, pairs in zip(batch_queries, batch_docs) for pair in pairs] + + tokenized_sequences = self.tokenizer( + sequences, + padding=True, + truncation=True, + max_length=self.max_length, + return_tensors="pt", + add_special_tokens=True, + ) + + return { + "sequences": dict(tokenized_sequences), + "labels": torch.tensor(batch_scores).squeeze() if len(batch_scores) > 0 else None, + } \ No newline at end of file diff --git a/rankers/modelling/__init__.py b/rankers/modelling/__init__.py new file mode 100644 index 0000000..e7fafed --- /dev/null +++ b/rankers/modelling/__init__.py @@ -0,0 +1,2 @@ +from .cat import CatTransformer, Cat +from .dot import Dot, DotTransformer \ No newline at end of file diff --git a/rankers/modelling/cat.py b/rankers/modelling/cat.py new file mode 100644 index 0000000..675969a --- /dev/null +++ b/rankers/modelling/cat.py @@ -0,0 +1,185 @@ +import pyterrier as pt +if not pt.started(): + pt.init() +from transformers import PreTrainedModel, PreTrainedTokenizer, AutoModelForSequenceClassification, AutoTokenizer, AutoConfig +from typing import Union +import torch +import pandas as pd +from more_itertools import chunked +import numpy as np +import torch.nn.functional as F + + +class Cat(PreTrainedModel): + """Wrapper for Cat Model + + Parameters + ---------- + classifier : PreTrainedModel + the classifier model + config : AutoConfig + the configuration for the model + """ + model_architecture = 'Cat' + def __init__( + self, + classifier: PreTrainedModel, + tokenizer: PreTrainedTokenizer, + config: AutoConfig, + ): + super().__init__(config) + self.classifier = classifier + self.tokenizer = tokenizer + + def prepare_outputs(self, logits): + """Prepare outputs""" + return F.log_softmax(logits.reshape(-1, self.config.group_size, 2), dim=-1)[:, :, 1] + + def forward(self, loss, sequences, labels=None): + """Compute the loss given (pairs, labels)""" + sequences = {k: v.to(self.classifier.device) for k, v in sequences.items()} + labels = labels.to(self.classifier.device) if labels is not None else None + logits = self.classifier(**sequences).logits + pred = self.prepare_outputs(logits) + loss_value = loss(pred) if labels is None else loss(pred, labels) + return (loss_value, pred) + + def save_pretrained(self, model_dir, **kwargs): + """Save classifier""" + self.config.save_pretrained(model_dir) + self.classifier.save_pretrained(model_dir) + self.tokenizer.save_pretrained(model_dir) + + + def load_state_dict(self, model_dir): + """Load state dict from a directory""" + return self.classifier.load_state_dict(AutoModelForSequenceClassification.from_pretrained(model_dir).state_dict()) + + + def to_pyterrier(self) -> "pt.Transformer": + return CatTransformer.from_model(self.classifier, self.tokenizer, text_field='text') + + @classmethod + def from_pretrained(cls, model_dir_or_name : str, num_labels=2): + """Load classifier from a directory""" + config = AutoConfig.from_pretrained(model_dir_or_name) + classifier = AutoModelForSequenceClassification.from_pretrained(model_dir_or_name, num_labels=num_labels) + tokenizer = AutoTokenizer.from_pretrained(model_dir_or_name) + return cls(classifier, tokenizer, config) + +class CatTransformer(pt.Transformer): + def __init__(self, + model : PreTrainedModel, + tokenizer : PreTrainedTokenizer, + config : AutoConfig, + batch_size : int, + text_field : str = 'text', + device : Union[str, torch.device] = None, + verbose : bool = False + ) -> None: + super().__init__() + self.model = model + self.tokenizer = tokenizer + self.config = config + self.batch_size = batch_size + self.text_field = text_field + self.device = device if device is not None else 'cuda' if torch.cuda.is_available() else 'cpu' + self.verbose = verbose + + @classmethod + def from_pretrained(cls, + model_name_or_path : str, + batch_size : int = 64, + text_field : str = 'text', + device : Union[str, torch.device] = None, + verbose : bool = False + ): + model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path).cuda().eval() + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + config = AutoConfig.from_pretrained(model_name_or_path) + return cls(model, tokenizer, config, batch_size, text_field, device, verbose) + + @classmethod + def from_model(cls, + model : PreTrainedModel, + tokenizer : PreTrainedTokenizer, + batch_size : int = 64, + text_field : str = 'text', + verbose : bool = False + ): + config = model.config + return cls(model, tokenizer, config, batch_size, text_field, model.device, verbose) + + def transform(self, inp : pd.DataFrame) -> pd.DataFrame: + scores = [] + it = inp[['query', self.text_field]].itertuples(index=False) + if self.verbose: + it = pt.tqdm(it, total=len(inp), unit='record', desc='Cat scoring') + with torch.no_grad(): + for chunk in chunked(it, self.batch_size): + queries, texts = map(list, zip(*chunk)) + inps = self.tokenizer(queries, texts, return_tensors='pt', padding=True, truncation=True) + inps = {k: v.to(self.model.device) for k, v in inps.items()} + scores.append(F.log_softmax(self.model(**inps).logits, dim=-1)[:, 1].cpu().detach().numpy()) + res = inp.assign(score=np.concatenate(scores)) + res = res.sort_values(['qid', 'score'], ascending=[True, False]) + return pt.model.add_ranks(res) + +class PairTransformer(pt.Transformer): + def __init__(self, + model : PreTrainedModel, + tokenizer : PreTrainedTokenizer, + config : AutoConfig, + batch_size : int, + text_field : str = 'text', + device : Union[str, torch.device] = None, + verbose : bool = False + ) -> None: + super().__init__() + self.model = model + self.tokenizer = tokenizer + self.config = config + self.batch_size = batch_size + self.text_field = text_field + self.device = device if device is not None else 'cuda' if torch.cuda.is_available() else 'cpu' + + @classmethod + def from_model(cls, + model : PreTrainedModel, + tokenizer : PreTrainedTokenizer, + batch_size : int = 64, + text_field : str = 'text', + verbose : bool = False + ): + config = model.config + return cls(model, tokenizer, config, batch_size, text_field, model.device, verbose) + + @classmethod + def from_pretrained(cls, + model_name_or_path : str, + batch_size : int = 64, + text_field : str = 'text', + device : Union[str, torch.device] = None, + verbose : bool = False + ): + model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path) + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + config = AutoConfig.from_pretrained(model_name_or_path) + return cls(model, tokenizer, config, batch_size, text_field, device, verbose) + + def transform(self, inp : pd.DataFrame) -> pd.DataFrame: + # TODO: Switch this to a pair-wise scoring + scores = [] + it = inp[['query', self.text_field]].itertuples(index=False) + if self.verbose: + it = pt.tqdm(it, total=len(inp), unit='record', desc='Duo scoring') + with torch.no_grad(): + for chunk in chunked(it, self.batch_size): + queries, texts = map(list, zip(*chunk)) + inps = self.tokenizer(queries, texts, return_tensors='pt', padding=True, truncation=True) + inps = {k: v.to(self.device) for k, v in inps.items()} + scores.append(self.model(**inps).logits.cpu().detach().numpy()) + res = inp.assign(score=np.concatenate(scores)) + res = inp.assign(score=np.concatenate(scores)) + res = res.sort_values(['qid', 'score'], ascending=[True, False]) + return pt.model.add_ranks(res) \ No newline at end of file diff --git a/rankers/modelling/dot.py b/rankers/modelling/dot.py new file mode 100644 index 0000000..cfe8f4f --- /dev/null +++ b/rankers/modelling/dot.py @@ -0,0 +1,394 @@ +from copy import deepcopy +import os +import torch +from torch import nn +import pyterrier as pt +if not pt.started(): + pt.init() +from transformers import PreTrainedModel, PreTrainedTokenizer, PretrainedConfig, AutoModel, AutoTokenizer +from typing import Union +import pandas as pd +import numpy as np +from more_itertools import chunked +from ..train.loss import batched_dot_product, cross_dot_product, LOSSES + +class DotConfig(PretrainedConfig): + """Configuration for Dot Model + + Parameters + ---------- + model_name_or_path : str + the model name or path + mode : str + the pooling mode for the model + encoder_tied : bool + whether the encoder is tied + use_pooler : bool + whether to use the pooler + pooler_dim_in : int + the input dimension for the pooler + pooler_dim_out : int + the output dimension for the pooler + pooler_tied : bool + whether the pooler is tied + """ + model_architecture = "Dot" + def __init__(self, + model_name_or_path : str='bert-base-uncased', + mode='cls', + inbatch_loss = None, + encoder_tied=True, + use_pooler=False, + pooler_dim_in=768, + pooler_dim_out=768, + pooler_tied=True, + **kwargs): + self.model_name_or_path = model_name_or_path + self.mode = mode + self.inbatch_loss = inbatch_loss + self.encoder_tied = encoder_tied + self.use_pooler = use_pooler + self.pooler_dim_in = pooler_dim_in + self.pooler_dim_out = pooler_dim_out + self.pooler_tied = pooler_tied + super().__init__(**kwargs) + + @classmethod + def from_pretrained(cls, + model_name_or_path : str='bert-base-uncased', + mode='cls', + inbatch_loss = None, + encoder_tied=True, + use_pooler=False, + pooler_dim_in=768, + pooler_dim_out=768, + pooler_tied=True, + ) -> 'DotConfig': + config = super().from_pretrained(model_name_or_path) + config.model_name_or_path = model_name_or_path + config.mode = mode + config.inbatch_loss = inbatch_loss + config.encoder_tied = encoder_tied + config.use_pooler = use_pooler + config.pooler_dim_in = pooler_dim_in + config.pooler_dim_out = pooler_dim_out + config.pooler_tied = pooler_tied + return config + +class Pooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense_q = nn.Linear(config.pooler_dim_in, config.pooler_dim_out) + self.dense_d = nn.Linear(config.pooler_dim_in, config.pooler_dim_out) if not config.pooler_tied else self.dense_q + + @classmethod + def from_pretrained(cls, model_name_or_path : str='bert-base-uncased') -> 'Pooler': + config = DotConfig.from_pretrained(model_name_or_path) + model = cls(config) + return model + + def forward(self, hidden_states, d=False): + return self.dense_d(hidden_states) if d else self.dense_q(hidden_states) + +class Dot(PreTrainedModel): + """ + Dot Model for Fine-Tuning + + Parameters + ---------- + encoder : PreTrainedModel + the encoder model + config : DotConfig + the configuration for the model + encoder_d : PreTrainedModel + the document encoder model + pooler : Pooler + the pooling layer + """ + def __init__( + self, + encoder : PreTrainedModel, + tokenizer : PreTrainedTokenizer, + config : DotConfig, + encoder_d : PreTrainedModel = None, + pooler : Pooler = None, + ): + super().__init__(config) + self.encoder = encoder + self.tokenizer = tokenizer + if encoder_d: self.encoder_d = encoder_d + else: self.encoder_d = self.encoder if config.encoder_tied else deepcopy(self.encoder) + self.pooling = { + 'mean': self._mean, + 'cls' : self._cls, + }[config.mode] + + if config.use_pooler: self.pooler = Pooler(config) if pooler is None else pooler + else: self.pooler = lambda x, y =True : x + + if config.inbatch_loss is not None: + if config.inbatch_loss not in LOSSES: + raise ValueError(f"Unknown loss: {config.inbatch_loss}") + self.inbatch_loss_fn = LOSSES[config.inbatch_loss]() + else: + self.inbatch_loss_fn = None + + def prepare_outputs(self, query_reps, docs_batch_reps, labels=None): + batch_size = query_reps.size(0) + emb_q = query_reps.reshape(batch_size, 1, -1) + emb_d = docs_batch_reps.reshape(batch_size, self.config.group_size, -1) + pred = batched_dot_product(emb_q, emb_d) + + if self.config.inbatch_loss is not None: + inbatch_d = emb_d[:, 0] + inbatch_pred = cross_dot_product(emb_q.view(batch_size, -1), inbatch_d) + else: + inbatch_pred = None + + if labels is not None: labels = labels.reshape(batch_size, self.config.group_size) + + return pred, labels, inbatch_pred + + def _cls(self, x : torch.Tensor) -> torch.Tensor: + return self.pooler(x[:, 0]) + + def _mean(self, x : torch.Tensor) -> torch.Tensor: + return self.pooler(x.mean(dim=1)) + + def _encode_d(self, **text): + return self.pooling(self.encoder_d(**text).last_hidden_state) + + def _encode_q(self, **text): + return self.pooling(self.encoder(**text).last_hidden_state) + + def forward(self, + loss = None, + queries = None, + docs_batch = None, + labels=None): + """Compute the loss given (queries, docs, labels)""" + queries = {k: v.to(self.encoder.device) for k, v in queries.items()} if queries is not None else None + docs_batch = {k: v.to(self.encoder_d.device) for k, v in docs_batch.items()} if docs_batch is not None else None + labels = labels.to(self.encoder_d.device) if labels is not None else None + + query_reps = self._encode_q(**queries) if queries is not None else None + docs_batch_reps = self._encode_d(**docs_batch) if docs_batch is not None else None + + pred, labels, inbatch_pred = self.prepare_outputs(query_reps, docs_batch_reps, labels) + inbatch_loss = self.inbatch_loss_fn(inbatch_pred, torch.eye(inbatch_pred.shape[0]).to(inbatch_pred.device)) if inbatch_pred is not None else 0. + + loss_value = loss(pred, labels) if labels is not None else loss(pred) + loss_value += inbatch_loss + return (loss_value, pred) + + def save_pretrained(self, model_dir, **kwargs): + """Save both query and document encoder""" + self.config.save_pretrained(model_dir) + self.encoder.save_pretrained(model_dir) + if not self.config.encoder_tied: self.encoder_d.save_pretrained(model_dir + "/encoder_d") + if self.config.use_pooler: self.pooler.save_pretrained(model_dir + "/pooler") + self.tokenizer.save_pretrained(model_dir) + + + def load_state_dict(self, model_dir): + """Load state dict from a directory""" + self.config = DotConfig.from_pretrained(model_dir) + self.encoder.load_state_dict(AutoModel.from_pretrained(model_dir).state_dict()) + if not self.config.encoder_tied: self.encoder_d.load_state_dict(AutoModel.from_pretrained(model_dir + "/encoder_d").state_dict()) + if self.config.use_pooler: self.pooler.load_state_dict(AutoModel.from_pretrained(model_dir + "/pooler").state_dict()) + + @classmethod + def from_pretrained(cls, model_dir_or_name, **kwargs): + """Load encoder""" + if os.path.isdir(model_dir_or_name): + config = DotConfig.from_pretrained(model_dir_or_name, **kwargs) + encoder = AutoModel.from_pretrained(model_dir_or_name) + encoder_d = None if config.encoder_tied else AutoModel.from_pretrained(model_dir_or_name + "/encoder_d") + pooler = None if not config.use_pooler else Pooler.from_pretrained(model_dir_or_name + "/pooler") + + return cls(encoder, config, encoder_d, pooler) + config = DotConfig(model_dir_or_name, **kwargs) + tokenizer = AutoTokenizer.from_pretrained(model_dir_or_name) + encoder = AutoModel.from_pretrained(model_dir_or_name) + return cls(encoder, tokenizer, config) + + def to_pyterrier(self) -> "DotTransformer": + return DotTransformer.from_model(self, self.tokenizer, text_field='text') + +class DotTransformer(pt.Transformer): + def __init__(self, + model : PreTrainedModel, + tokenizer : PreTrainedTokenizer, + config : DotConfig, + batch_size : int, + text_field : str = 'text', + device : Union[str, torch.device] = None, + verbose : bool = False + ) -> None: + super().__init__() + self.device = device if device is not None else 'cuda' if torch.cuda.is_available() else 'cpu' + self.model = model.eval().to(self.device) + self.tokenizer = tokenizer + self.config = config + self.batch_size = batch_size + self.text_field = text_field + self.pooling = { + 'mean': lambda x: x.mean(dim=1), + 'cls' : lambda x: x[:, 0], + 'none': lambda x: x, + }[config.mode] + self.verbose = verbose + + @classmethod + def from_pretrained(cls, + model_name_or_path : str, + batch_size : int = 64, + pooling : str = 'cls', + text_field : str = 'text', + device : Union[str, torch.device] = None, + verbose : bool = False + ): + config = DotConfig.from_pretrained(model_name_or_path) + config.mode = pooling + pooler = None if not config.use_pooler else Pooler.from_pretrained(model_name_or_path+"/pooler") + encoder_d = None if config.encoder_tied else AutoModel.from_pretrained(model_name_or_path + "/encoder_d") + encoder_q = AutoModel.from_pretrained(model_name_or_path) + model = Dot(encoder_q, config, encoder_d, pooler) + return cls(model, AutoTokenizer.from_pretrained(model_name_or_path), config, batch_size, text_field, device, verbose) + + @classmethod + def from_model(cls, + model : PreTrainedModel, + tokenizer : PreTrainedTokenizer, + batch_size : int = 64, + text_field : str = 'text', + verbose : bool = False + ): + config = model.config + return cls(model, tokenizer, config, batch_size, text_field, model.device, verbose) + + def encode_queries(self, texts, batch_size=None) -> np.ndarray: + results = [] + with torch.no_grad(): + for chunk in chunked(texts, batch_size or self.batch_size): + inps = self.tokenizer(list(chunk), return_tensors='pt', padding=True, truncation=True) + inps = {k: v.to(self.device) for k, v in inps.items()} + res = self.model._encode_q(**inps) + results.append(res.cpu().numpy()) + if not results: + return np.empty(shape=(0, 0)) + return np.concatenate(results, axis=0) + + def encode_docs(self, texts, batch_size=None) -> np.ndarray: + results = [] + with torch.no_grad(): + for chunk in chunked(texts, batch_size or self.batch_size): + inps = self.tokenizer(list(chunk), return_tensors='pt', padding=True, truncation=True) + inps = {k: v.to(self.device) for k, v in inps.items()} + res = self.model._encode_d(**inps) + results.append(res.cpu().numpy()) + if not results: + return np.empty(shape=(0, 0)) + return np.concatenate(results, axis=0) + + def transform(self, inp: pd.DataFrame) -> pd.DataFrame: + columns = set(inp.columns) + modes = [ + (['qid', 'query', self.text_field], self.scorer), + (['qid', 'query_vec', self.text_field], self.scorer), + (['qid', 'query', 'doc_vec'], self.scorer), + (['qid', 'query_vec', 'doc_vec'], self.scorer), + (['query'], self.query_encoder), + ([self.text_field], self.doc_encoder), + ] + for fields, fn in modes: + if all(f in columns for f in fields): + return fn()(inp) + message = f'Unexpected input with columns: {inp.columns}. Supports:' + for fields, fn in modes: + message += f'\n - {fn.__doc__.strip()}: {fields}' + raise RuntimeError(message) + + def query_encoder(self, verbose=None, batch_size=None) -> pt.Transformer: + """ + Query encoding + """ + return BiQueryEncoder(self, verbose=verbose, batch_size=batch_size) + + def doc_encoder(self, verbose=None, batch_size=None) -> pt.Transformer: + """ + Doc encoding + """ + return BiDocEncoder(self, verbose=verbose, batch_size=batch_size) + + def scorer(self, verbose=None, batch_size=None) -> pt.Transformer: + """ + Scoring (re-ranking) + """ + return BiScorer(self, verbose=verbose, batch_size=batch_size) + +class BiQueryEncoder(pt.Transformer): + def __init__(self, bi_encoder_model: DotTransformer, verbose=None, batch_size=None): + self.bi_encoder_model = bi_encoder_model + self.verbose = verbose if verbose is not None else bi_encoder_model.verbose + self.batch_size = batch_size if batch_size is not None else bi_encoder_model.batch_size + + def encode(self, texts, batch_size=None) -> np.array: + return self.bi_encoder_model.encode_queries(texts, batch_size=batch_size or self.batch_size) + + def transform(self, inp: pd.DataFrame) -> pd.DataFrame: + assert all(c in inp.columns for c in ['query']) + it = inp['query'].values + it, inv = np.unique(it, return_inverse=True) + if self.verbose: + it = pt.tqdm(it, desc='Encoding Queries', unit='query') + enc = self.encode(it) + return inp.assign(query_vec=[enc[i] for i in inv]) + + def __repr__(self): + return f'{repr(self.bi_encoder_model)}.query_encoder()' + +class BiDocEncoder(pt.Transformer): + def __init__(self, bi_encoder_model: DotTransformer, verbose=None, batch_size=None, text_field=None): + self.bi_encoder_model = bi_encoder_model + self.verbose = verbose if verbose is not None else bi_encoder_model.verbose + self.batch_size = batch_size if batch_size is not None else bi_encoder_model.batch_size + self.text_field = text_field if text_field is not None else bi_encoder_model.text_field + + def encode(self, texts, batch_size=None) -> np.array: + return self.bi_encoder_model.encode_docs(texts, batch_size=batch_size or self.batch_size) + + def transform(self, inp: pd.DataFrame) -> pd.DataFrame: + assert all(c in inp.columns for c in [self.text_field]) + it = inp[self.text_field] + if self.verbose: + it = pt.tqdm(it, desc='Encoding Docs', unit='doc') + return inp.assign(doc_vec=list(self.encode(it))) + + def __repr__(self): + return f'{repr(self.bi_encoder_model)}.doc_encoder()' + +class BiScorer(pt.Transformer): + def __init__(self, bi_encoder_model: DotTransformer, verbose=None, batch_size=None, text_field=None): + self.bi_encoder_model = bi_encoder_model + self.verbose = verbose if verbose is not None else bi_encoder_model.verbose + self.batch_size = batch_size if batch_size is not None else bi_encoder_model.batch_size + self.text_field = text_field if text_field is not None else bi_encoder_model.text_field + + def transform(self, inp: pd.DataFrame) -> pd.DataFrame: + assert 'query_vec' in inp.columns or 'query' in inp.columns + assert 'doc_vec' in inp.columns or self.text_field in inp.columns + if 'query_vec' in inp.columns: + query_vec = inp['query_vec'] + else: + query_vec = self.bi_encoder_model.query_encoder(batch_size=self.batch_size, verbose=self.verbose)(inp)['query_vec'] + if 'doc_vec' in inp.columns: + doc_vec = inp['doc_vec'] + else: + doc_vec = self.bi_encoder_model.doc_encoder(batch_size=self.batch_size, verbose=self.verbose)(inp)['doc_vec'] + scores = (query_vec * doc_vec).apply(np.sum) + outp = inp.assign(score=scores) + return pt.model.add_ranks(outp) + + def __repr__(self): + return f'{repr(self.bi_encoder_model)}.scorer()' \ No newline at end of file diff --git a/rankers/modelling/seq2seq.py b/rankers/modelling/seq2seq.py new file mode 100644 index 0000000..f52e64b --- /dev/null +++ b/rankers/modelling/seq2seq.py @@ -0,0 +1,162 @@ +import pyterrier as pt +if not pt.started(): + pt.init() +from transformers import PreTrainedModel, PreTrainedTokenizer, AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM +from typing import Union +import torch +import pandas as pd +from more_itertools import chunked +import numpy as np + + +DEFAULT_MONO_PROMPT = r'query: {query} document: {text} relevant:' +DEFAULT_DUO_PROMPT = r'query: {query} positive: {text} negative: {text} relevant:' + +class Seq2Seq(PreTrainedModel): + """Wrapper for ConditionalGenerationCat Model + + Parameters + ---------- + classifier : AutoModelForSeq2SeqLM + the classifier model + config : AutoConfig + the configuration for the model + """ + model_architecture = 'Seq2Seq' + def __init__( + self, + classifier: AutoModelForSeq2SeqLM, + tokenizer: PreTrainedTokenizer, + config: AutoConfig, + ): + super().__init__(config) + self.classifier = classifier + self.tokenizer = tokenizer + + def prepare_outputs(self, logits): + raise NotImplementedError + + def forward(self, loss, sequences, labels=None): + """Compute the loss given (pairs, labels)""" + sequences = {k: v.to(self.classifier.device) for k, v in sequences.items()} + labels = labels.to(self.classifier.device) if labels is not None else None + logits = self.classifier(**sequences).logits + pred = self.prepare_outputs(logits) + loss_value = loss(pred) if labels is None else loss(pred, labels) + return (loss_value, pred) + + + def save_pretrained(self, model_dir, **kwargs): + """Save classifier""" + self.config.save_pretrained(model_dir) + self.classifier.save_pretrained(model_dir) + self.tokenizer.save_pretrained(model_dir) + + def load_state_dict(self, model_dir): + """Load state dict from a directory""" + return self.classifier.load_state_dict(AutoModelForSeq2SeqLM.from_pretrained(model_dir).state_dict()) + + def to_pyterrier(self) -> "Seq2SeqTransformer": + return Seq2SeqTransformer.from_model(self.classifier, self.tokenizer, text_field='text') + + @classmethod + def from_pretrained(cls, model_dir_or_name : str, num_labels=2): + """Load classifier from a directory""" + config = AutoConfig.from_pretrained(model_dir_or_name) + classifier = AutoModelForSeq2SeqLM.from_pretrained(model_dir_or_name, num_labels=num_labels) + return cls(classifier, config) + +class Seq2SeqTransformer(pt.Transformer): + def __init__(self, + model : PreTrainedModel, + tokenizer : PreTrainedTokenizer, + config : AutoConfig, + batch_size : int, + text_field : str = 'text', + device : Union[str, torch.device] = None, + pos_token : str = 'true', + neg_token : str = 'false', + prompt : str = None + ) -> None: + super().__init__() + self.model = model + self.tokenizer = tokenizer + self.config = config + self.batch_size = batch_size + self.text_field = text_field + self.device = device if device is not None else 'cuda' if torch.cuda.is_available() else 'cpu' + self.pos_token = self.tokenizer.encode(pos_token)[0] + self.neg_token = self.tokenizer.encode(neg_token)[0] + self.prompt = prompt if prompt is not None else DEFAULT_MONO_PROMPT + + @classmethod + def from_pretrained(cls, + model_name_or_path : str, + batch_size : int = 64, + text_field : str = 'text', + device : Union[str, torch.device] = None + ): + model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path) + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + config = AutoConfig.from_pretrained(model_name_or_path) + return cls(model, tokenizer, config, batch_size, text_field, device) + + @classmethod + def from_model(cls, + model : PreTrainedModel, + tokenizer : PreTrainedTokenizer, + batch_size : int = 64, + text_field : str = 'text', + ): + config = model.config + return cls(model, tokenizer, config, batch_size, text_field, model.device) + + def transform(self, inp : pd.DataFrame) -> pd.DataFrame: + scores = [] + it = inp[['query', self.text_field]].itertuples(index=False) + if self.verbose: + it = pt.tqdm(it, total=len(inp), unit='record', desc='Cat scoring') + with torch.no_grad(): + for chunk in chunked(it, self.batch_size): + queries, texts = map(list, zip(*chunk)) + prompts = [self.prompt.format(query=q, text=t) for q, t in zip(queries, texts)] + inps = self.tokenizer(prompts, return_tensors='pt', padding=True, truncation=True) + inps = {k: v.to(self.device) for k, v in inps.items()} + scores.append(self.model(**inps).logits[:, (self.pos_token, self.neg_token)].softmax(dim=-1)[0].cpu().detach().numpy()) + res = inp.assign(score=np.concatenate(scores)) + pt.model.add_ranks(res) + res = res.sort_values(['qid', 'rank']) + return res + +class Seq2SeqDuoTransformer(Seq2SeqTransformer): + def __init__(self, + model : PreTrainedModel, + tokenizer : PreTrainedTokenizer, + config : AutoConfig, + batch_size : int, + text_field : str = 'text', + device : Union[str, torch.device] = None, + pos_token : str = 'true', + neg_token : str = 'false', + prompt : str = None + ) -> None: + super().__init__(model, tokenizer, config, batch_size, text_field, device, pos_token, neg_token, prompt) + self.prompt = prompt if prompt is not None else DEFAULT_DUO_PROMPT + + def transform(self, inp : pd.DataFrame) -> pd.DataFrame: + # TODO: Fix this mess + scores = [] + it = inp[['query', self.text_field]].itertuples(index=False) + if self.verbose: + it = pt.tqdm(it, total=len(inp), unit='record', desc='Cat scoring') + with torch.no_grad(): + for chunk in chunked(it, self.batch_size): + queries, texts = map(list, zip(*chunk)) + prompts = [self.prompt.format(query=q, text1=t1, text2=t2) for q, t1, t2 in zip(queries, texts, texts) if t1 != t2] + inps = self.tokenizer(prompts, return_tensors='pt', padding=True, truncation=True) + inps = {k: v.to(self.device) for k, v in inps.items()} + scores.append(self.model(**inps).logits[:, (self.pos_token, self.neg_token)].softmax(dim=-1)[0].cpu().detach().numpy()) + res = inp.assign(score=np.concatenate(scores)) + pt.model.add_ranks(res) + res = res.sort_values(['qid', 'rank']) + return res \ No newline at end of file diff --git a/rankers/train/__init__.py b/rankers/train/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/rankers/train/arguments.py b/rankers/train/arguments.py new file mode 100644 index 0000000..0be8495 --- /dev/null +++ b/rankers/train/arguments.py @@ -0,0 +1,10 @@ +from transformers import TrainingArguments + +class ContrastArguments(TrainingArguments): + def __init__(self, + mode : str = None, + group_size : int = 2, + **kwargs): + self.mode = mode + self.group_size = group_size + super().__init__(**kwargs) \ No newline at end of file diff --git a/rankers/train/loss/__init__.py b/rankers/train/loss/__init__.py new file mode 100644 index 0000000..18ecb67 --- /dev/null +++ b/rankers/train/loss/__init__.py @@ -0,0 +1,166 @@ +from collections import defaultdict +import torch.nn as nn +import torch.nn.functional as F +import torch +from torch import Tensor + +def reduce(a : torch.Tensor, reduction : str): + """ + Reducing a tensor along a given dimension. + Parameters + ---------- + a: torch.Tensor + the input tensor + reduction: str + the reduction type + Returns + ------- + torch.Tensor + the reduced tensor + """ + if reduction == 'none': + return a + if reduction == 'mean': + return a.mean() + if reduction == 'sum': + return a.sum() + if reduction == 'batchmean': + return a.mean(dim=0).sum() + raise ValueError(f"Unknown reduction type: {reduction}") + +class BaseLoss(nn.Module): + """ + Base class for Losses + + Parameters + ---------- + reduction: str + the reduction type + """ + def __init__(self, reduction : str = 'mean') -> None: + super(BaseLoss, self).__init__() + self.reduction = reduction + + def _reduce(self, a : torch.Tensor): + return reduce(a, self.reduction) + + def forward(self, *args, **kwargs): + raise NotImplementedError + +def normalize(a: Tensor, dim: int = -1): + """ + Normalizing a tensor along a given dimension. + Parameters + ---------- + a: torch.Tensor + the input tensor + dim: int + the dimension to normalize along + Returns + ------- + torch.Tensor + the normalized tensor + """ + min_values = a.min(dim=dim, keepdim=True)[0] + max_values = a.max(dim=dim, keepdim=True)[0] + return (a - min_values) / (max_values - min_values + 1e-10) + +def residual(a : Tensor): + """ + Calculating the residual between a positive sample and multiple negatives. + Parameters + ---------- + a: torch.Tensor + the input tensor + Returns + ------- + torch.Tensor + the residuals + """ + if a.size(1) == 1: return a + if len(a.size()) == 3: + assert a.size(2) == 1, "Expected scalar values for residuals." + a = a.squeeze(2) + + positive = a[:, 0] + negative = a[:, 1] + + return positive - negative + +def dot_product(a: Tensor, b: Tensor): + """ + Calculating row-wise dot product between two tensors a and b. + a and b must have the same dimensionality. + Parameters + ---------- + a: torch.Tensor + size: batch_size x vector_dim + b: torch.Tensor + size: batch_size x vector_dim + Returns + ------- + torch.Tensor: size of (batch_size x 1) + dot product for each pair of vectors + """ + return (a * b).sum(dim=-1) + + +def cross_dot_product(a: Tensor, b: Tensor): + """ + Calculating the cross doc product between each row in a with every row in b. a and b must have the same number of columns, but can have varied nuber of rows. + Parameters + ---------- + a: torch.Tensor + size: (batch_size_1, vector_dim) + b: torch.Tensor + size: (batch_size_2, vector_dim) + Returns + ------- + torch.Tensor: of size (batch_size_1, batch_size_2) where the value at (i,j) is dot product of a[i] and b[j]. + """ + return torch.mm(a, b.transpose(0, 1)) + +def batched_dot_product(a: Tensor, b: Tensor): + """ + Calculating the dot product between two tensors a and b. + + Parameters + ---------- + a: torch.Tensor + size: batch_size x vector_dim + b: torch.Tensor + size: batch_size x group_size x vector_dim + Returns + ------- + torch.Tensor: size of (batch_size x group_size) + dot product for each group of vectors + """ + if len(b.shape) == 2: + return torch.matmul(a, b.transpose(0, 1)) + + # Ensure `a` is of shape (batch_size, 1, vector_dim) + if len(a.shape) == 2: + a = a.unsqueeze(1) + + # Compute batched dot product, result shape: (batch_size, 1, group_size) + return torch.bmm(b, a.transpose(1, 2)).squeeze() + +def num_non_zero(a: Tensor): + """ + Calculating the average number of non-zero columns in each row. + Parameters + ---------- + a: torch.Tensor + the input tensor + """ + return (a > 0).float().sum(dim=1).mean() + +from .listwise import * +from .pointwise import * +from .pairwise import * + +LOSSES = { + **POINTWISE_LOSSES, + **PAIRWISE_LOSSES, + **LISTWISE_LOSSES, +} \ No newline at end of file diff --git a/rankers/train/loss/listwise.py b/rankers/train/loss/listwise.py new file mode 100644 index 0000000..e54ae3c --- /dev/null +++ b/rankers/train/loss/listwise.py @@ -0,0 +1,182 @@ +import torch +from torch import Tensor +from torch.nn import functional as F +from . import BaseLoss + +class KL_DivergenceLoss(BaseLoss): + """KL Divergence loss""" + + def __init__(self, reduction='batchmean', temperature=1.): + super().__init__(reduction) + self.temperature = temperature + self.kl_div = torch.nn.KLDivLoss(reduction=self.reduction) + + def forward(self, pred: Tensor, labels: Tensor) -> Tensor: + return self.kl_div(F.log_softmax(pred / self.temperature, dim=1), F.softmax(labels / self.temperature, dim=1)) + + +class RankNetLoss(BaseLoss): + """RankNet loss + https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/MSR-TR-2010-82.pdf + """ + + def __init__(self, reduction='mean', temperature=1.): + super().__init__(reduction) + self.temperature = temperature + self.bce = torch.nn.BCEWithLogitsLoss(reduction=reduction) + + def forward(self, pred: Tensor, labels: Tensor=None) -> Tensor: + _, g = pred.shape + i1, i2 = torch.triu_indices(g, g, offset=1) + pred_diff = pred[:, i1] - pred[:, i2] + if labels is None: + labels = torch.zeros_like(pred_diff) + labels[:, 0] = 1. + else: + label_diff = labels[:, i1] - labels[:, i2] + labels = (label_diff > 0).float() + + return self.bce(pred_diff, labels) + + +class DistillRankNetLoss(BaseLoss): + """DistillRankNet loss + Very much a WIP from https://arxiv.org/pdf/2402.10769 + DO NOT USE + """ + def __init__(self, reduction='mean', temperature=1., base_margin=300., increment_margin=100.): + super().__init__(reduction) + self.temperature = temperature + self.base_margin = base_margin + self.increment_margin = increment_margin + + def forward(self, pred: Tensor, labels: Tensor) -> Tensor: + _, g = pred.shape + i1, i2 = torch.triu_indices(g, g, offset=1) + + pred_diff = pred[:, i1] - pred[:, i2] + + label_diff = labels[:, i1] - labels[:, i2] + label_margin = (label_diff -1) * self.increment_margin + self.base_margin + + final_margin = pred_diff + label_margin + labels = (label_diff > 0).float() + + return self._reduce(final_margin[labels]) + +class ListNetLoss(BaseLoss): + """ListNet loss + """ + + def __init__(self, reduction='mean', temperature=1., epsilon=1e-8): + super().__init__(reduction) + self.temperature = temperature + self.epsilon = epsilon + + def forward(self, pred: Tensor, labels: Tensor) -> Tensor: + if not torch.all((labels >= 0) & (labels <= 1)): + labels = F.softmax(labels / self.temperature, dim=1) + return self._reduce(-torch.sum(labels * F.log_softmax(pred + self.epsilon / self.temperature, dim=1), dim=-1)) + +class Poly1SoftmaxLoss(BaseLoss): + """Poly1 softmax loss with automatic softmax handling and reduction.""" + + def __init__(self, reduction='mean', epsilon : float = 1., temperature=1.): + super().__init__(reduction) + self.epsilon = epsilon + self.temperature = temperature + self.ce = torch.nn.CrossEntropyLoss(reduction='none') + + def forward(self, pred: Tensor, labels: Tensor) -> Tensor: + labels_for_softmax = torch.divide(labels, labels.sum(dim=1)) + expansion = (labels_for_softmax * F.softmax(pred / self.temperature, dim=1)).sum(dim=-1) + ce = self.ce(pred / self.temperature, labels_for_softmax) + return self._reduce(ce + (1 - expansion) * self.epsilon) + +def get_approx_ranks(pred: torch.Tensor, temperature: float) -> torch.Tensor: + score_diff = pred[:, None] - pred[..., None] + normalized_score_diff = torch.sigmoid(score_diff / temperature) + # set diagonal to 0 + normalized_score_diff = normalized_score_diff * (1 - torch.eye(pred.shape[1], device=pred.device)) + approx_ranks = normalized_score_diff.sum(-1) + 1 + return approx_ranks + +# Taken from https://github.com/webis-de/lightning-ir/blob/main/lightning_ir/loss/loss.py + +def get_dcg( + ranks: torch.Tensor, + labels: torch.Tensor, + k: int | None = None, + scale_gains: bool = True, + ) -> torch.Tensor: + log_ranks = torch.log2(1 + ranks) + discounts = 1 / log_ranks + if scale_gains: + gains = 2**labels - 1 + else: + gains = labels + dcgs = gains * discounts + if k is not None: + dcgs = dcgs.masked_fill(ranks > k, 0) + return dcgs.sum(dim=-1) + +def get_ndcg( + ranks: torch.Tensor, + labels: torch.Tensor, + k: int | None = None, + scale_gains: bool = True, + optimal_labels: torch.Tensor | None = None, + ) -> torch.Tensor: + labels = labels.clamp(min=0) + if optimal_labels is None: + optimal_labels = labels + optimal_ranks = torch.argsort(torch.argsort(optimal_labels, descending=True)) + optimal_ranks = optimal_ranks + 1 + dcg = get_dcg(ranks, labels, k, scale_gains) + idcg = get_dcg(optimal_ranks, optimal_labels, k, scale_gains) + ndcg = dcg / (idcg.clamp(min=1e-12)) + return ndcg + +class ApproxNDCGLoss(BaseLoss): + def __init__(self, reduction: str = 'mean', temperature=1., scale_gains: bool = True) -> None: + super().__init__(reduction) + self.temperature = temperature + self.scale_gains = scale_gains + + def forward(self, pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + labels = self.process_labels(pred, labels) + approx_ranks = get_approx_ranks(pred, self.temperature) + ndcg = get_ndcg(approx_ranks, labels, k=None, scale_gains=self.scale_gains) + loss = 1 - ndcg + return self._reduce(loss) + +def get_mrr(ranks: torch.Tensor, labels: torch.Tensor, k: int | None = None) -> torch.Tensor: + labels = labels.clamp(None, 1) + reciprocal_ranks = 1 / ranks + mrr = reciprocal_ranks * labels + if k is not None: + mrr = mrr.masked_fill(ranks > k, 0) + mrr = mrr.max(dim=-1)[0] + return mrr + +class ApproxMRRLoss(BaseLoss): + def __init__(self, reduction: str = 'mean', temperature=1.) -> None: + super().__init__(reduction) + self.temperature = temperature + + def forward(self, pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + approx_ranks = get_approx_ranks(pred, self.temperature) + mrr = get_mrr(approx_ranks, labels, k=None) + loss = 1 - mrr + return self._reduce(loss) + + +LISTWISE_LOSSES = { + 'kl_div': KL_DivergenceLoss, + 'ranknet': RankNetLoss, + 'distill_ranknet': DistillRankNetLoss, + 'listnet': ListNetLoss, + 'poly1': Poly1SoftmaxLoss, + 'approx_ndcg': ApproxNDCGLoss, + 'approx_mrr': ApproxMRRLoss, +} \ No newline at end of file diff --git a/rankers/train/loss/pairwise.py b/rankers/train/loss/pairwise.py new file mode 100644 index 0000000..7b8d167 --- /dev/null +++ b/rankers/train/loss/pairwise.py @@ -0,0 +1,69 @@ +import torch +from torch import Tensor +import torch.nn.functional as F +from . import BaseLoss + +residual = lambda x : x[:, 0].unsqueeze(1) - x[:, 1:] + +class MarginMSELoss(BaseLoss): + """Margin MSE loss with residual calculation.""" + + def forward(self, pred: Tensor, labels: Tensor) -> Tensor: + residual_pred = pred[:, 0].unsqueeze(1) - pred[:, 1:] + residual_label = labels[:, 0].unsqueeze(1) - labels[:, 1:] + return F.mse_loss(residual_pred, residual_label, reduction=self.reduction) + + +class HingeLoss(BaseLoss): + """Hinge loss with sigmoid activation and residual calculation.""" + + def __init__(self, margin=1, reduction='mean'): + super().__init__(reduction) + self.margin = margin + + def forward(self, pred: Tensor, labels: Tensor) -> Tensor: + pred_residuals = F.relu(residual(F.sigmoid(pred))) + label_residuals = torch.sign(residual(F.sigmoid(labels))) + return self._reduce(F.relu(self.margin - (label_residuals * pred_residuals))) + + +class ClearLoss(BaseLoss): + """Clear loss with margin and residual calculation.""" + + def __init__(self, margin=1, reduction='mean'): + super().__init__(reduction) + self.margin = margin + + def forward(self, pred: Tensor, labels: Tensor) -> Tensor: + margin_b = self.margin - residual(labels) + return self._reduce(F.relu(margin_b - residual(pred))) + +class LCELoss(BaseLoss): + """LCE loss: Cross Entropy for NCE with localised examples.""" + def forward(self, pred: Tensor, labels: Tensor=None) -> Tensor: + if labels is not None: + labels = labels.argmax(dim=1) + else: + labels = torch.zeros(pred.size(0), dtype=torch.long, device=pred.device) + return F.cross_entropy(pred, labels, reduction=self.reduction) + + +class ContrastiveLoss(BaseLoss): + """Contrastive loss with log_softmax and negative log likelihood.""" + + def __init__(self, reduction='mean', temperature=1.): + super().__init__(reduction) + self.temperature = temperature + + def forward(self, pred: Tensor, labels : Tensor = None) -> Tensor: + softmax_scores = F.log_softmax(pred / self.temperature, dim=1) + labels = labels.argmax(dim=1) if labels is not None else torch.zeros(pred.size(0), dtype=torch.long, device=pred.device).view(-1, 1) + return F.nll_loss(softmax_scores, labels, reduction=self.reduction) + +PAIRWISE_LOSSES = { + 'margin_mse': MarginMSELoss, + 'hinge': HingeLoss, + 'clear': ClearLoss, + 'lce': LCELoss, + 'contrastive': ContrastiveLoss, +} \ No newline at end of file diff --git a/rankers/train/loss/pointwise.py b/rankers/train/loss/pointwise.py new file mode 100644 index 0000000..45526fd --- /dev/null +++ b/rankers/train/loss/pointwise.py @@ -0,0 +1,16 @@ +import torch +from torch import Tensor +from torch.nn import functional as F +from . import BaseLoss + +class PointwiseMSELoss(BaseLoss): + """Pointwise MSE loss""" + + def forward(self, pred: Tensor, labels: Tensor) -> Tensor: + flattened_pred = pred.view(-1) + flattened_labels = labels.view(-1) + return F.mse_loss(flattened_pred, flattened_labels, reduction=self.reduction) + +POINTWISE_LOSSES = { + 'mse': PointwiseMSELoss, +} \ No newline at end of file diff --git a/rankers/train/trainer.py b/rankers/train/trainer.py new file mode 100644 index 0000000..f0bd68c --- /dev/null +++ b/rankers/train/trainer.py @@ -0,0 +1,169 @@ +import torch +import os +import logging +from transformers import Trainer +import math +import time +import pandas as pd +from typing import Optional, Union, Dict, List +from datasets import Dataset +from transformers.trainer_utils import EvalLoopOutput, speed_metrics +from transformers.integrations.deepspeed import deepspeed_init +from .loss import LOSSES + +logger = logging.getLogger(__name__) + +LOSS_NAME = "loss.pt" + +class ContrastTrainer(Trainer): + """Customized Trainer from Huggingface's Trainer""" + + def __init__(self, *args, loss=None, **kwargs) -> None: + super(ContrastTrainer, self).__init__(*args, **kwargs) + if isinstance(loss, str): + if loss not in LOSSES: raise ValueError(f"Unknown loss: {loss}") + self.loss = LOSSES[loss]() + else: + self.loss = loss + self.tokenizer = self.data_collator.tokenizer + self.model.config.group_size = self.args.group_size + + def compute_loss(self, model, inputs, return_outputs=False): + outputs = model(self.loss, **inputs) + # Save past state if it exists + if self.args.past_index >= 0: + self._past = outputs[self.args.past_index] + + if isinstance(outputs, dict) and "loss" not in outputs: + raise ValueError( + "The model did not return a loss from the inputs, only the following keys: " + f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}." + ) + # We don't use .loss here since the model may return tuples instead of ModelOutput. + loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] + + return (loss, outputs) if return_outputs else loss + + def compute_metrics(self, result_frame : pd.DataFrame): + from ir_measures import evaluator, RR + qrels = self.eval_dataset.qrels + metrics = self.args.eval_metrics if self.args.eval_metrics else [RR@10] + evaluator = evaluator(metrics, qrels) + + return evaluator.calc_aggregate(result_frame) + + def evaluation_loop( + self, + dataset: Dataset, + description: str, + metric_key_prefix: str = "eval", + ) -> EvalLoopOutput: + """ + Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`. + + Works both with or without labels. + """ + args = self.args + + # if eval is called w/o train, handle model prep here + if self.is_deepspeed_enabled and self.deepspeed is None: + _, _ = deepspeed_init(self, num_training_steps=0, inference=True) + + if len(self.accelerator._models) == 0 and model is self.model: + start_time = time.time() + model = ( + self.accelerator.prepare(model) + if self.is_deepspeed_enabled + else self.accelerator.prepare_model(model, evaluation_mode=True) + ) + self.model_preparation_time = round(time.time() - start_time, 4) + + if self.is_fsdp_enabled: + self.model = model + + # for the rest of this function `model` is the outside model, whether it was wrapped or not + if model is not self.model: + self.model_wrapped = model + + # backward compatibility + if self.is_deepspeed_enabled: + self.deepspeed = self.model_wrapped + + # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called + # while ``train`` is running, cast it to the right dtype first and then put on device + if not self.is_in_train: + if args.fp16_full_eval: + model = model.to(dtype=torch.float16, device=args.device) + elif args.bf16_full_eval: + model = model.to(dtype=torch.bfloat16, device=args.device) + + batch_size = self.args.eval_batch_size + + logger.info(f"\n***** Running {description} *****") + logger.info(f" Num queries = {len(dataset)}") + logger.info(f" Batch size = {batch_size}") + + eval_model = model.to_pyterrier() + result_frame = eval_model.transform(dataset.evaluation_data) + metrics = self.compute_metrics(result_frame) + + num_samples = len(dataset) + metrics = {f"{metric_key_prefix}_{k}": v for k, v in metrics.items()} + + return EvalLoopOutput(predictions=result_frame, label_ids=None, metrics=metrics, num_samples=num_samples) + + def evaluate( + self, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + ignore_keys: Optional[List[str]] = None, + metric_key_prefix: str = "eval", + ) -> Dict[str, float]: + # handle multipe eval datasets + override = eval_dataset is not None + eval_dataset = eval_dataset if override else self.eval_dataset + + # memory metrics - must set up as early as possible + self._memory_tracker.start() + + start_time = time.time() + + eval_loop = self.evaluation_loop + output = eval_loop( + eval_dataset, + description="Evaluation", + # No point gathering the predictions if there are no metrics, otherwise we defer to + # self.args.prediction_loss_only + prediction_loss_only=None, + ignore_keys=ignore_keys, + metric_key_prefix=metric_key_prefix, + ) + + total_batch_size = self.args.eval_batch_size * self.args.world_size + output.metrics.update( + speed_metrics( + metric_key_prefix, + start_time, + num_samples=output.num_samples, + num_steps=math.ceil(output.num_samples / total_batch_size), + ) + ) + + self.log(output.metrics) + + self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics) + + self._memory_tracker.stop_and_update_metrics(output.metrics) + + return output.metrics + + def _load_optimizer_and_scheduler(self, checkpoint): + super()._load_optimizer_and_scheduler(checkpoint) + if checkpoint is None: + return + if os.path.exists(os.path.join(checkpoint, LOSS_NAME)): + self.loss.load_state_dict(torch.load(os.path.join(checkpoint, LOSS_NAME))) + + def _load_from_checkpoint(self, resume_from_checkpoint, model=None): + logger.info("Loading model's weight from %s", resume_from_checkpoint) + if model: return model.load_state_dict(resume_from_checkpoint) + else: self.model.load_state_dict(resume_from_checkpoint) \ No newline at end of file diff --git a/setup.py b/setup.py index 0c3bed0..191cce6 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ long_description = fh.read() setuptools.setup( - name="contrast", + name="rankers", version="0.0.1", author="Andrew Parry", description="A framework for training and evaluating neural IR models.",