Commit

rename

Parry-Parry committed Sep 13, 2024
1 parent e2f251c commit 5cd01f9
Showing 18 changed files with 1,897 additions and 1 deletion.
33 changes: 33 additions & 0 deletions rankers/__init__.py
@@ -0,0 +1,33 @@
__version__ = "0.0.2"

from .train import loss as loss
from .train.trainer import ContrastTrainer
from .train.arguments import ContrastArguments
from .datasets import *
from .modelling import *

def is_torch_available():
    try:
        import torch
        return True
    except ImportError:
        return False

def is_flax_available():
    try:
        import flax
        return True
    except ImportError:
        return False

def seed_everything(seed=42):
    import random
    import numpy as np

    random.seed(seed)
    np.random.seed(seed)
    if is_torch_available():
        # Import torch only inside the guard, otherwise this function
        # crashes on environments without torch before the check runs.
        import torch

        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
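
For reference, a minimal usage sketch of the helpers above (assuming the package is importable as rankers; torch seeding only takes effect when torch is installed):

import rankers

rankers.seed_everything(42)  # seeds random, numpy, and (if present) torch
if rankers.is_torch_available():
    print("torch backend available")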
140 changes: 140 additions & 0 deletions rankers/_util.py
@@ -0,0 +1,140 @@
from collections import defaultdict
import logging
from typing import Optional
import pandas as pd
import pyterrier as pt
import ir_datasets as irds

logger = logging.getLogger(__name__)

def _pivot(frame, negatives=None):
    new = []
    for row in frame.itertuples():
        new.append({
            "qid": row.query_id,
            "docno": row.doc_id_a,
            "pos": 1,
        })
        if negatives:
            for doc in negatives[row.query_id]:
                new.append({
                    "qid": row.query_id,
                    "docno": doc,
                })
        else:
            new.append({
                "qid": row.query_id,
                "docno": row.doc_id_b,
            })
    return pd.DataFrame.from_records(new)

def _qrel_pivot(frame):
    new = []
    for row in frame.itertuples():
        new.append({
            "qid": row.query_id,
            "docno": row.doc_id,
            "score": row.relevance,
        })
    return pd.DataFrame.from_records(new)

def get_teacher_scores(model: pt.Transformer,
                       corpus: Optional[pd.DataFrame] = None,
                       ir_dataset: Optional[str] = None,
                       subset: Optional[int] = None,
                       negatives: Optional[dict] = None,
                       seed: int = 42):
    assert corpus is not None or ir_dataset is not None, "Either corpus or ir_dataset must be provided"
    # 'if corpus:' raises ValueError on a DataFrame (ambiguous truth value),
    # so compare against None explicitly.
    if corpus is not None:
        for column in ["query", "text"]:
            assert column in corpus.columns, f"{column} not found in corpus"
    if ir_dataset:
        dataset = irds.load(ir_dataset)
        docs = pd.DataFrame(dataset.docs_iter()).set_index("doc_id")["text"].to_dict()
        queries = pd.DataFrame(dataset.queries_iter()).set_index("query_id")["text"].to_dict()
        corpus = pd.DataFrame(dataset.docpairs_iter())
        if negatives:
            corpus = corpus[['query_id', 'doc_id_a']]
        corpus = _pivot(corpus, negatives)
        corpus['text'] = corpus['docno'].map(docs)
        corpus['query'] = corpus['qid'].map(queries)
    if subset:
        corpus = corpus.sample(n=subset, random_state=seed)

    logger.warning("Retrieving scores, this may take a while...")
    scores = model.transform(corpus)
    lookup = defaultdict(dict)
    for qid, group in scores.groupby('qid'):
        for docno, score in zip(group['docno'], group['score']):
            lookup[qid][docno] = score
    return lookup
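
A hedged sketch of calling get_teacher_scores with a PyTerrier cross-encoder; the MonoT5ReRanker teacher and the dataset id are assumptions, and any pt.Transformer that adds a 'score' column over (qid, query, docno, text) rows should work:

from pyterrier_t5 import MonoT5ReRanker  # assumed teacher model
from rankers._util import get_teacher_scores

teacher = MonoT5ReRanker()
lookup = get_teacher_scores(teacher, ir_dataset="msmarco-passage/train/triples-small")
# lookup is a nested dict: lookup[qid][docno] -> teacher score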

def initialise_irds_eval(dataset: irds.Dataset):
    qrels = pd.DataFrame(dataset.qrels_iter())
    return _qrel_pivot(qrels)

def load_json(file: str):
    """
    Load a JSON or JSONL file (optionally compressed with gzip).

    Parameters:
        file (str): The path to the file to load.

    Returns:
        dict or list: The loaded JSON content. Returns a list for JSONL files,
        and a dict for JSON files.

    Raises:
        ValueError: If the file extension is not recognized.
    """
    # The docstring must precede these imports, otherwise it is a no-op
    # string expression and the function loses its __doc__.
    import json
    import gzip

    if file.endswith(".json"):
        with open(file, 'r') as f:
            return json.load(f)
    elif file.endswith(".jsonl"):
        with open(file, 'r') as f:
            return [json.loads(line) for line in f]
    elif file.endswith(".json.gz"):
        with gzip.open(file, 'rt') as f:
            return json.load(f)
    elif file.endswith(".jsonl.gz"):
        with gzip.open(file, 'rt') as f:
            return [json.loads(line) for line in f]
    else:
        raise ValueError(f"Unknown file type for {file}")

def save_json(data, file: str):
    """
    Save data to a JSON or JSONL file (optionally compressed with gzip).

    Parameters:
        data (dict or list): The data to save. Must be a list for JSONL files.
        file (str): The path to the file to save.

    Raises:
        ValueError: If the file extension is not recognized.
    """
    import json
    import gzip

    if file.endswith(".json"):
        with open(file, 'w') as f:
            json.dump(data, f)
    elif file.endswith(".jsonl"):
        with open(file, 'w') as f:
            for item in data:
                f.write(json.dumps(item) + '\n')
    elif file.endswith(".json.gz"):
        with gzip.open(file, 'wt') as f:
            json.dump(data, f)
    elif file.endswith(".jsonl.gz"):
        with gzip.open(file, 'wt') as f:
            for item in data:
                f.write(json.dumps(item) + '\n')
    else:
        raise ValueError(f"Unknown file type for {file}")
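
A small round-trip sketch for the two helpers above; the file name is illustrative, and the extension selects the format (JSONL for line-delimited lists, .gz for gzip):

from rankers._util import save_json, load_json

records = [{"qid": "q1", "docno": "d1", "score": 0.9}]
save_json(records, "scores.jsonl.gz")            # gzip-compressed JSONL
assert load_json("scores.jsonl.gz") == records   # round-trips the records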


2 changes: 2 additions & 0 deletions rankers/datasets/__init__.py
@@ -0,0 +1,2 @@
from .dataset import *
from .loader import *
43 changes: 43 additions & 0 deletions rankers/datasets/corpus.py
@@ -0,0 +1,43 @@
import pandas as pd

class Corpus:
    def __init__(self,
                 documents: dict = None,
                 queries: dict = None,
                 qrels: pd.DataFrame = None
                 ) -> None:
        self.documents = documents
        self.queries = queries
        self.qrels = qrels

        self.__post_init__()

    def __post_init__(self):
        # 'if self.qrels:' raises ValueError on a DataFrame, so check for None.
        if self.qrels is not None:
            for column in ('query_id', 'doc_id', 'relevance'):
                if column not in self.qrels.columns:
                    raise ValueError(f"Format not recognised, column '{column}' not found in qrels dataframe")
            self.qrels = self.qrels[['query_id', 'doc_id', 'relevance']]

    def has_documents(self):
        return self.documents is not None

    def has_queries(self):
        return self.queries is not None

    def has_qrels(self):
        return self.qrels is not None

    def queries_iter(self):
        for queryid, text in self.queries.items():
            yield {"query_id": queryid, "text": text}

    def docs_iter(self):
        for docid, text in self.documents.items():
            yield {"doc_id": docid, "text": text}

    def qrels_iter(self):
        for queryid, docid, relevance in self.qrels.itertuples(index=False):
            yield {"query_id": queryid, "doc_id": docid, "relevance": relevance}


122 changes: 122 additions & 0 deletions rankers/datasets/dataset.py
@@ -0,0 +1,122 @@
import random
from torch.utils.data import Dataset
import pandas as pd
import torch
from typing import Optional, Union
import ir_datasets as irds
from .._util import load_json, initialise_irds_eval
from .corpus import Corpus

class TrainingDataset(Dataset):
    def __init__(self,
                 training_data: pd.DataFrame,
                 corpus: Union[Corpus, irds.Dataset],
                 teacher_file: Optional[str] = None,
                 group_size: int = 2,
                 listwise: bool = False,
                 ) -> None:
        super().__init__()
        self.training_data = training_data
        self.corpus = corpus
        self.teacher_file = teacher_file
        self.group_size = group_size
        self.listwise = listwise

        self.__post_init__()

    def __post_init__(self):
        for column in ('query_id', 'doc_id_a', 'doc_id_b'):
            if column not in self.training_data.columns:
                raise ValueError(f"Format not recognised, column '{column}' not found in triples dataframe")
        self.docs = pd.DataFrame(self.corpus.docs_iter()).set_index("doc_id")["text"].to_dict()
        self.queries = pd.DataFrame(self.corpus.queries_iter()).set_index("query_id")["text"].to_dict()

        if self.teacher_file:
            self.teacher = load_json(self.teacher_file)

        self.labels = bool(self.teacher_file)
        self.multi_negatives = isinstance(self.training_data['doc_id_b'].iloc[0], list)

        if not self.listwise:
            if self.group_size > 2 and self.multi_negatives:
                self.training_data['doc_id_b'] = self.training_data['doc_id_b'].map(lambda x: random.sample(x, self.group_size - 1))
            elif self.group_size == 2 and self.multi_negatives:
                self.training_data['doc_id_b'] = self.training_data['doc_id_b'].map(lambda x: random.choice(x) if len(x) > 1 else x[0])
                self.multi_negatives = False
            elif self.group_size > 2 and not self.multi_negatives:
                raise ValueError("Group size > 2 not supported for single negative samples")

    @classmethod
    def from_irds(cls,
                  ir_dataset: str,
                  teacher_file: Optional[str] = None,
                  group_size: int = 2,
                  collate_fn: Optional[callable] = lambda x: pd.DataFrame(x.docpairs_iter())
                  ) -> 'TrainingDataset':
        dataset = irds.load(ir_dataset)
        assert dataset.has_docpairs(), "Dataset does not have docpairs, check you are not using a test collection"
        training_data = collate_fn(dataset)
        return cls(training_data, dataset, teacher_file, group_size)

    def __len__(self):
        return len(self.training_data)

    def _teacher(self, qid, doc_id, positive=False):
        assert self.labels, "No teacher file provided"
        try:
            return self.teacher[str(qid)][str(doc_id)]
        except KeyError:
            return 0.

    def __getitem__(self, idx):
        item = self.training_data.iloc[idx]
        qid, doc_id_a, doc_id_b = item['query_id'], item['doc_id_a'], item['doc_id_b']
        query = self.queries[str(qid)]
        texts = [self.docs[str(doc_id_a)]] if not self.listwise else []

        if self.multi_negatives:
            texts.extend([self.docs[str(doc)] for doc in doc_id_b])
        else:
            texts.append(self.docs[str(doc_id_b)])

        if self.labels:
            scores = [self._teacher(str(qid), str(doc_id_a), positive=True)] if not self.listwise else []
            if self.multi_negatives:
                scores.extend([self._teacher(str(qid), str(doc)) for doc in doc_id_b])
            else:
                scores.append(self._teacher(str(qid), str(doc_id_b)))
            return (query, texts, scores)
        else:
            return (query, texts)
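
A hedged sketch of constructing this dataset from an ir_datasets id with doc pairs (the id below is illustrative; any id exposing docpairs should work):

from rankers.datasets.dataset import TrainingDataset

train = TrainingDataset.from_irds("msmarco-passage/train/triples-small")
query, texts = train[0]  # without teacher_file, items are (query, texts) pairs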

class EvaluationDataset(Dataset):
    def __init__(self,
                 evaluation_data: Union[pd.DataFrame, str],
                 corpus: Union[Corpus, irds.Dataset]
                 ) -> None:
        super().__init__()
        self.evaluation_data = evaluation_data
        self.corpus = corpus

        self.__post_init__()

    def __post_init__(self):
        if isinstance(self.evaluation_data, str):
            import pyterrier as pt
            self.evaluation_data = pt.io.read_results(self.evaluation_data)
        else:
            for column in ('qid', 'docno', 'score'):
                if column not in self.evaluation_data.columns:
                    raise ValueError(f"Format not recognised, column '{column}' not found in dataframe")
        self.docs = pd.DataFrame(self.corpus.docs_iter()).set_index("doc_id")["text"].to_dict()
        self.queries = pd.DataFrame(self.corpus.queries_iter()).set_index("query_id")["text"].to_dict()
        self.qrels = pd.DataFrame(self.corpus.qrels_iter())

        self.evaluation_data['text'] = self.evaluation_data['docno'].map(self.docs)
        self.evaluation_data['query'] = self.evaluation_data['qid'].map(self.queries)

    @classmethod
    def from_irds(cls,
                  ir_dataset: str,
                  ) -> 'EvaluationDataset':
        dataset = irds.load(ir_dataset)
        evaluation_data = initialise_irds_eval(dataset)
        return cls(evaluation_data, dataset)

    def __len__(self):
        return len(self.evaluation_data.qid.unique())
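
And a matching sketch for evaluation: from_irds builds the frame from the dataset's qrels, while a TREC-style results path is also accepted by the constructor (the dataset id is illustrative):

from rankers.datasets.dataset import EvaluationDataset

eval_ds = EvaluationDataset.from_irds("msmarco-passage/trec-dl-2019/judged")
print(len(eval_ds))  # number of unique query ids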