-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e2f251c
commit 5cd01f9
Showing
18 changed files
with
1,897 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
__version__ = "0.0.2" | ||
|
||
from .train import loss as loss | ||
from .train.trainer import ContrastTrainer | ||
from .train.arguments import ContrastArguments | ||
from .datasets import * | ||
from .modelling import * | ||
|
||
def is_torch_available(): | ||
try: | ||
import torch | ||
return True | ||
except ImportError: | ||
return False | ||
|
||
def is_flax_available(): | ||
try: | ||
import flax | ||
return True | ||
except ImportError: | ||
return False | ||
|
||
def seed_everything(seed=42): | ||
import random | ||
import numpy as np | ||
import torch | ||
|
||
random.seed(seed) | ||
np.random.seed(seed) | ||
if is_torch_available(): | ||
torch.manual_seed(seed) | ||
torch.cuda.manual_seed_all(seed) | ||
torch.backends.cudnn.deterministic = True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
from collections import defaultdict | ||
import logging | ||
from typing import Optional | ||
import pandas as pd | ||
import pyterrier as pt | ||
import ir_datasets as irds | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
def _pivot(frame, negatives = None): | ||
new = [] | ||
for row in frame.itertuples(): | ||
new.append( | ||
{ | ||
"qid": row.query_id, | ||
"docno": row.doc_id_a, | ||
"pos": 1 | ||
}) | ||
if negatives: | ||
for doc in negatives[row.query_id]: | ||
new.append( | ||
{ | ||
"qid": row.query_id, | ||
"docno": doc | ||
}) | ||
else: | ||
new.append( | ||
{ | ||
"qid": row.query_id, | ||
"docno": row.doc_id_b | ||
}) | ||
return pd.DataFrame.from_records(new) | ||
|
||
def _qrel_pivot(frame): | ||
new = [] | ||
for row in frame.itertuples(): | ||
new.append( | ||
{ | ||
"qid": row.query_id, | ||
"docno": row.doc_id, | ||
"score": row.relevance | ||
}) | ||
return pd.DataFrame.from_records(new) | ||
|
||
def get_teacher_scores(model : pt.Transformer, | ||
corpus : Optional[pd.DataFrame] = None, | ||
ir_dataset : Optional[str] = None, | ||
subset : Optional[int] = None, | ||
negatives : Optional[dict] = None, | ||
seed : int = 42): | ||
assert corpus is not None or ir_dataset is not None, "Either corpus or ir_dataset must be provided" | ||
if corpus: | ||
for column in ["query", "text"]: assert column in corpus.columns, f"{column} not found in corpus" | ||
if ir_dataset: | ||
dataset = irds.load(ir_dataset) | ||
docs = pd.DataFrame(dataset.docs_iter()).set_index("doc_id")["text"].to_dict() | ||
queries = pd.DataFrame(dataset.queries_iter()).set_index("query_id")["text"].to_dict() | ||
corpus = pd.DataFrame(dataset.docpairs_iter()) | ||
if negatives: | ||
corpus = corpus[['query_id', 'doc_id_a']] | ||
corpus = _pivot(corpus, negatives) | ||
corpus['text'] = corpus['docno'].map(docs) | ||
corpus['query'] = corpus['qid'].map(queries) | ||
if subset: | ||
corpus = corpus.sample(n=subset, random_state=seed) | ||
|
||
logger.warning("Retrieving scores, this may take a while...") | ||
scores = model.transform(corpus) | ||
lookup = defaultdict(dict) | ||
for qid, group in scores.groupby('qid'): | ||
for docno, score in zip(group['docno'], group['score']): | ||
lookup[qid][docno] = score | ||
return lookup | ||
|
||
def initialise_irds_eval(dataset : irds.Dataset): | ||
qrels = pd.DataFrame(dataset.qrels_iter()) | ||
return _qrel_pivot(qrels) | ||
|
||
def load_json(file: str): | ||
import json | ||
import gzip | ||
""" | ||
Load a JSON or JSONL (optionally compressed with gzip) file. | ||
Parameters: | ||
file (str): The path to the file to load. | ||
Returns: | ||
dict or list: The loaded JSON content. Returns a list for JSONL files, | ||
and a dict for JSON files. | ||
Raises: | ||
ValueError: If the file extension is not recognized. | ||
""" | ||
if file.endswith(".json"): | ||
with open(file, 'r') as f: | ||
return json.load(f) | ||
elif file.endswith(".jsonl"): | ||
with open(file, 'r') as f: | ||
return [json.loads(line) for line in f] | ||
elif file.endswith(".json.gz"): | ||
with gzip.open(file, 'rt') as f: | ||
return json.load(f) | ||
elif file.endswith(".jsonl.gz"): | ||
with gzip.open(file, 'rt') as f: | ||
return [json.loads(line) for line in f] | ||
else: | ||
raise ValueError(f"Unknown file type for {file}") | ||
|
||
def save_json(data, file: str): | ||
import json | ||
import gzip | ||
""" | ||
Save data to a JSON or JSONL file (optionally compressed with gzip). | ||
Parameters: | ||
data (dict or list): The data to save. Must be a list for JSONL files. | ||
file (str): The path to the file to save. | ||
Raises: | ||
ValueError: If the file extension is not recognized. | ||
""" | ||
if file.endswith(".json"): | ||
with open(file, 'w') as f: | ||
json.dump(data, f) | ||
elif file.endswith(".jsonl"): | ||
with open(file, 'w') as f: | ||
for item in data: | ||
f.write(json.dumps(item) + '\n') | ||
elif file.endswith(".json.gz"): | ||
with gzip.open(file, 'wt') as f: | ||
json.dump(data, f) | ||
elif file.endswith(".jsonl.gz"): | ||
with gzip.open(file, 'wt') as f: | ||
for item in data: | ||
f.write(json.dumps(item) + '\n') | ||
else: | ||
raise ValueError(f"Unknown file type for {file}") | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .dataset import * | ||
from .loader import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import pandas as pd | ||
|
||
class Corpus: | ||
def __init__(self, | ||
documents : dict = None, | ||
queries : dict = None, | ||
qrels : pd.DataFrame = None | ||
) -> None: | ||
self.documents = documents | ||
self.queries = queries | ||
self.qrels = qrels | ||
|
||
self.__post_init__() | ||
|
||
def __post_init__(self): | ||
if self.qrels: | ||
for column in 'query_id', 'doc_id', 'relevance': | ||
if column not in self.qrels.columns: raise ValueError(f"Format not recognised, Column '{column}' not found in qrels dataframe") | ||
|
||
self.qrels = self.qrels[['query_id', 'doc_id', 'relevance']] | ||
|
||
def has_documents(self): | ||
return self.documents is not None | ||
|
||
def has_queries(self): | ||
return self.queries is not None | ||
|
||
def has_qrels(self): | ||
return self.qrels is not None | ||
|
||
def queries_iter(self): | ||
for queryid, text in self.queries.items(): | ||
yield {"query_id" : queryid, "text" : text} | ||
|
||
def docs_iter(self): | ||
for docid, text in self.documents.items(): | ||
yield {"doc_id" : docid, "text" : text} | ||
|
||
def qrels_iter(self): | ||
for queryid, docid, relevance in self.qrels.itertuples(index=False): | ||
yield {"query_id" : queryid, "doc_id" : docid, "relevance" : relevance} | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
import random | ||
from torch.utils.data import Dataset | ||
import pandas as pd | ||
import torch | ||
from typing import Optional, Union | ||
import ir_datasets as irds | ||
from .._util import load_json | ||
from .corpus import Corpus | ||
|
||
from contrast._util import initialise_irds_eval | ||
|
||
class TrainingDataset(Dataset): | ||
def __init__(self, | ||
training_data : pd.DataFrame, | ||
corpus : Union[Corpus, irds.Dataset], | ||
teacher_file : Optional[str] = None, | ||
group_size : int = 2, | ||
listwise : bool = False, | ||
) -> None: | ||
super().__init__() | ||
self.training_data = training_data | ||
self.corpus = corpus | ||
self.teacher_file = teacher_file | ||
self.group_size = group_size | ||
self.listwise = listwise | ||
|
||
self.__post_init__() | ||
|
||
|
||
def __post_init__(self): | ||
|
||
for column in 'query_id', 'doc_id_a', 'doc_id_b': | ||
if column not in self.training_data.columns: raise ValueError(f"Format not recognised, Column '{column}' not found in triples dataframe") | ||
self.docs = pd.DataFrame(self.corpus.docs_iter()).set_index("doc_id")["text"].to_dict() | ||
self.queries = pd.DataFrame(self.corpus.queries_iter()).set_index("query_id")["text"].to_dict() | ||
|
||
if self.teacher_file: self.teacher = load_json(self.teacher_file) | ||
|
||
self.labels = True if self.teacher_file else False | ||
self.multi_negatives = True if (type(self.training_data['doc_id_b'].iloc[0]) == list) else False | ||
|
||
if not self.listwise: | ||
if self.group_size > 2 and self.multi_negatives: | ||
self.training_data['doc_id_b'] = self.training_data['doc_id_b'].map(lambda x: random.sample(x, self.group_size-1)) | ||
elif self.group_size == 2 and self.multi_negatives: | ||
self.training_data['doc_id_b'] = self.training_data['doc_id_b'].map(lambda x: random.choice(x) if len(x) > 1 else x[0]) | ||
self.multi_negatives = False | ||
elif self.group_size > 2 and not self.multi_negatives: | ||
raise ValueError("Group size > 2 not supported for single negative samples") | ||
|
||
@classmethod | ||
def from_irds(cls, | ||
ir_dataset : str, | ||
teacher_file : Optional[str] = None, | ||
group_size : int = 2, | ||
collate_fn : Optional[callable] = lambda x : pd.DataFrame(x.docpairs_iter()) | ||
) -> 'TrainingDataset': | ||
dataset = irds.load(ir_dataset) | ||
assert dataset.has_docpairs(), "Dataset does not have docpairs, check you are not using a test collection" | ||
training_data = collate_fn(dataset) | ||
return cls(training_data, dataset, teacher_file, group_size) | ||
|
||
def __len__(self): | ||
return len(self.training_data) | ||
|
||
def _teacher(self, qid, doc_id, positive=False): | ||
assert self.labels, "No teacher file provided" | ||
try: return self.teacher[str(qid)][str(doc_id)] | ||
except KeyError: return 0. | ||
|
||
def __getitem__(self, idx): | ||
item = self.training_data.iloc[idx] | ||
qid, doc_id_a, doc_id_b = item['query_id'], item['doc_id_a'], item['doc_id_b'] | ||
query = self.queries[str(qid)] | ||
texts = [self.docs[str(doc_id_a)]] if not self.listwise else [] | ||
|
||
if self.multi_negatives: texts.extend([self.docs[str(doc)] for doc in doc_id_b]) | ||
else: texts.append(self.docs[str(doc_id_b)]) | ||
|
||
if self.labels: | ||
scores = [self._teacher(str(qid), str(doc_id_a), positive=True)] if not self.listwise else [] | ||
if self.multi_negatives: scores.extend([self._teacher(qid, str(doc)) for doc in doc_id_b]) | ||
else: scores.append(self._teacher(str(qid), str(doc_id_b))) | ||
return (query, texts, scores) | ||
else: | ||
return (query, texts) | ||
|
||
class EvaluationDataset(Dataset): | ||
def __init__(self, | ||
evaluation_data : Union[pd.DataFrame, str], | ||
corpus : Union[Corpus, irds.Dataset] | ||
) -> None: | ||
super().__init__() | ||
self.evaluation_data = evaluation_data | ||
self.corpus = corpus | ||
|
||
self.__post_init__() | ||
|
||
def __post_init__(self): | ||
if type(self.evaluation_data) == str: | ||
import pyterrier as pt | ||
self.evaluation_data = pt.io.read_results(self.evaluation_data) | ||
else: | ||
for column in 'qid', 'docno', 'score': | ||
if column not in self.evaluation_data.columns: raise ValueError(f"Format not recognised, Column '{column}' not found in dataframe") | ||
self.docs = pd.DataFrame(self.corpus.docs_iter()).set_index("doc_id")["text"].to_dict() | ||
self.queries = pd.DataFrame(self.corpus.queries_iter()).set_index("query_id")["text"].to_dict() | ||
self.qrels = pd.DataFrame(self.corpus.qrels_iter()) | ||
|
||
self.evaluation_data['text'] = self.evaluation_data['docno'].map(self.docs) | ||
self.evaluation_data['query'] = self.evaluation_data['qid'].map(self.queries) | ||
|
||
@classmethod | ||
def from_irds(cls, | ||
ir_dataset : str, | ||
) -> 'EvaluationDataset': | ||
dataset = irds.load(ir_dataset) | ||
evaluation_data = initialise_irds_eval(dataset) | ||
return cls(evaluation_data, dataset) | ||
|
||
def __len__(self): | ||
return len(self.evaluation_data.qid.unique()) |
Oops, something went wrong.