Skip to content

Commit

Permalink
Merge pull request #31 from UKPLab/development
Browse files Browse the repository at this point in the history
Merging latest development branch to main branch
  • Loading branch information
thakur-nandan authored Jul 19, 2021
2 parents 8ab9bbf + 17a5a63 commit aacc574
Show file tree
Hide file tree
Showing 10 changed files with 116 additions and 43 deletions.
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ To easily understand and get your hands dirty with BEIR, we invite you to try ou
| ------------------------------------------- | ---------- |
| Hybrid sparse retrieval using SPARTA | [evaluate_sparta.py](https://github.com/UKPLab/beir/blob/main/examples/retrieval/evaluation/sparse/evaluate_sparta.py) |
| Sparse retrieval using docT5query and Pyserini | [evaluate_anserini_docT5query.py](https://github.com/UKPLab/beir/blob/main/examples/retrieval/evaluation/sparse/evaluate_anserini_docT5query.py) |
| Sparse retrieval using docT5query (MultiGPU) and Pyserini | [evaluate_anserini_docT5query_parallel.py](https://github.com/UKPLab/beir/blob/main/examples/retrieval/evaluation/sparse/evaluate_anserini_docT5query_parallel.py) :new: |
| Sparse retrieval using DeepCT and Pyserini :new: | [evaluate_deepct.py](https://github.com/UKPLab/beir/blob/main/examples/retrieval/evaluation/sparse/evaluate_deepct.py) |

### :beers: Reranking (Evaluation)
Expand Down Expand Up @@ -212,6 +213,8 @@ For other datasets, just use one of the datasets names, mention below.

## :beers: Available Datasets

Command to generate md5hash using Terminal: ``md5hash filename.zip``.

| Dataset | Website| BEIR-Name | Queries | Corpus | Rel D/Q | Down-load | md5 |
| -------- | -----| ---------| ----------- | ---------| ---------| :----------: | :------:|
| MSMARCO | [Homepage](https://microsoft.github.io/msmarco/)| ``msmarco`` | 6,980 | 8.84M | 1.1 | [Link](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/msmarco.zip) | ``444067daf65d982533ea17ebd59501e4`` |
Expand All @@ -229,7 +232,7 @@ For other datasets, just use one of the datasets names, mention below.
| Quora| [Homepage](https://www.quora.com/q/quoradata/First-Quora-Dataset-Release-Question-Pairs) | ``quora``| 10,000 | 523K | 1.6 | [Link](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/quora.zip) | ``18fb154900ba42a600f84b839c173167`` |
| DBPedia | [Homepage](https://github.com/iai-group/DBpedia-Entity/) | ``dbpedia-entity``| 400 | 4.63M | 38.2 | [Link](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/dbpedia-entity.zip) | ``c2a39eb420a3164af735795df012ac2c`` |
| SCIDOCS| [Homepage](https://allenai.org/data/scidocs) | ``scidocs``| 1,000 | 25K | 4.9 | [Link](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/scidocs.zip) | ``38121350fc3a4d2f48850f6aff52e4a9`` |
| FEVER| [Homepage](http://fever.ai) | ``fever``| 6,666 | 5.42M | 1.2| [Link](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/fever.zip) | ``88591ef8eb2913126d0c93ecbde6285f`` |
| FEVER | [Homepage](http://fever.ai) | ``fever``| 6,666 | 5.42M | 1.2| [Link](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/fever.zip) | ``5a818580227bfb4b35bb6fa46d9b6c03`` |
| Climate-FEVER| [Homepage](http://climatefever.ai) | ``climate-fever``| 1,535 | 5.42M | 3.0 | [Link](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/climate-fever.zip) | ``8b66f0a9126c521bae2bde127b4dc99d`` |
| SciFact| [Homepage](https://github.com/allenai/scifact) | ``scifact``| 300 | 5K | 1.1 | [Link](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/scifact.zip) | ``5f7d1de60b170fc8027bb7898e2efca1`` |
| Robust04 | [Homepage](https://trec.nist.gov/data/robust/04.guidelines.html) | ``robust04``| 249 | 528K | 69.9 | No | [How to Reproduce?](https://github.com/UKPLab/beir/blob/main/examples/dataset#3-robust04) |
Expand Down Expand Up @@ -429,6 +432,7 @@ We also include custom-metrics now which can be used for evaluation, please refe
- MRR (``MRR@k``)
- Capped Recall (``R_cap@k``)
- Hole (``Hole@k``): % of top-k docs retrieved unseen by annotators
- Top-K Accuracy (``Accuracy@k``): % of relevant docs present in top-k results

## :beers: Citing & Authors

Expand Down
32 changes: 31 additions & 1 deletion beir/retrieval/custom_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,34 @@ def hole(qrels: Dict[str, Dict[str, int]],
Hole[f"Hole@{k}"] = round(Hole[f"Hole@{k}"]/len(results), 5)
logging.info("Hole@{}: {:.4f}".format(k, Hole[f"Hole@{k}"]))

return Hole
return Hole

def top_k_accuracy(
qrels: Dict[str, Dict[str, int]],
results: Dict[str, Dict[str, float]],
k_values: List[int]) -> Tuple[Dict[str, float]]:

top_k_acc = {}

for k in k_values:
top_k_acc[f"Accuracy@{k}"] = 0.0

k_max, top_hits = max(k_values), {}
logging.info("\n")

for query_id, doc_scores in results.items():
top_hits[query_id] = sorted(doc_scores.keys(), key=lambda item: item[1], reverse=True)[0:k_max]

for query_id in qrels:
query_relevant_docs = set([doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0])
for k in k_values:
for relevant_doc_id in query_relevant_docs:
if relevant_doc_id in top_hits[query_id][0:k]:
top_k_acc[f"Accuracy@{k}"] += 1.0
break

for k in k_values:
top_k_acc[f"Accuracy@{k}"] = round(top_k_acc[f"Accuracy@{k}"]/len(qrels), 5)
logging.info("Accuracy@{}: {:.4f}".format(k, top_k_acc[f"Accuracy@{k}"]))

return top_k_acc
9 changes: 6 additions & 3 deletions beir/retrieval/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .search.dense import DenseRetrievalFaissSearch as DRFS
from .search.lexical import BM25Search as BM25
from .search.sparse import SparseSearch as SS
from .custom_metrics import mrr, recall_cap, hole
from .custom_metrics import mrr, recall_cap, hole, top_k_accuracy

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -86,7 +86,7 @@ def evaluate(qrels: Dict[str, Dict[str, int]],
@staticmethod
def evaluate_custom(qrels: Dict[str, Dict[str, int]],
results: Dict[str, Dict[str, float]],
k_values: List[int], metric: str in ["mrr", "r_cap", "hole"]) -> Tuple[Dict[str, float]]:
k_values: List[int], metric: str) -> Tuple[Dict[str, float]]:

if metric.lower() in ["mrr", "mrr@k", "mrr_cut"]:
return mrr(qrels, results, k_values)
Expand All @@ -95,4 +95,7 @@ def evaluate_custom(qrels: Dict[str, Dict[str, int]],
return recall_cap(qrels, results, k_values)

elif metric.lower() in ["hole", "hole@k"]:
return hole(qrels, results, k_values)
return hole(qrels, results, k_values)

elif metric.lower() in ["acc", "top_k_acc", "accuracy", "accuracy@k", "top_k_accuracy"]:
return top_k_accuracy(qrels, results, k_values)
2 changes: 1 addition & 1 deletion beir/retrieval/search/dense/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .exact_search import DenseRetrievalExactSearch
from .faiss_search import DenseRetrievalFaissSearch, BinaryFaissSearch, PQFaissSearch, HNSWFaissSearch, FlatIPFaissSearch, PCAFaissSearch
from .faiss_search import DenseRetrievalFaissSearch, BinaryFaissSearch, PQFaissSearch, HNSWFaissSearch, FlatIPFaissSearch, PCAFaissSearch, SQFaissSearch
9 changes: 7 additions & 2 deletions beir/retrieval/search/dense/exact_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def __init__(self, model, batch_size: int = 128, corpus_chunk_size: int = 50000,
self.model = model
self.batch_size = batch_size
self.score_functions = {'cos_sim': cos_sim, 'dot': dot_score}
self.score_function_desc = {'cos_sim': "Cosine Similarity", 'dot': "Dot Product"}
self.corpus_chunk_size = corpus_chunk_size
self.show_progress_bar = True #TODO: implement no progress bar if false
self.convert_to_tensor = True
Expand All @@ -38,10 +39,14 @@ def search(self,
query_embeddings = self.model.encode_queries(
queries, batch_size=self.batch_size, show_progress_bar=self.show_progress_bar, convert_to_tensor=self.convert_to_tensor)

logger.info("Encoding Corpus in batches... Warning: This might take a while!")
corpus_ids = list(corpus.keys())
logger.info("Sorting Corpus by document length (Longest first)...")

corpus_ids = sorted(corpus, key=lambda k: len(corpus[k].get("title", "") + corpus[k].get("text", "")), reverse=True)
corpus = [corpus[cid] for cid in corpus_ids]

logger.info("Encoding Corpus in batches... Warning: This might take a while!")
logger.info("Scoring Function: {} ({})".format(self.score_function_desc[score_function], score_function))

itr = range(0, len(corpus), self.corpus_chunk_size)

for batch_num, corpus_start_idx in enumerate(itr):
Expand Down
21 changes: 1 addition & 20 deletions beir/retrieval/search/dense/faiss_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def build(
passage_embeddings = np.hstack((passage_embeddings, aux_dims.reshape(-1, 1)))
return super().build(passage_ids, passage_embeddings, index, buffer_size)

class FaissPQIndex(FaissIndex):
class FaissTrainIndex(FaissIndex):
def search(self, query_embeddings: np.ndarray, k: int, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
return super().search(query_embeddings, k)

Expand All @@ -95,25 +95,6 @@ def build(
index.train(passage_embeddings)
return super().build(passage_ids, passage_embeddings, index, buffer_size)

class FaissPCAIndex(FaissIndex):
def search(self, query_embeddings: np.ndarray, k: int, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
return super().search(query_embeddings, k)

def save(self, output_path: str):
super().save(output_path)

@classmethod
def build(
cls,
passage_ids: List[int],
passage_embeddings: np.ndarray,
index: Optional[faiss.Index] = None,
buffer_size: int = 50000,
):
index.train(passage_embeddings)
return super().build(passage_ids, passage_embeddings, index, buffer_size)


class FaissBinaryIndex(FaissIndex):
def __init__(self, index: faiss.Index, passage_ids: List[int] = None, passage_embeddings: np.ndarray = None):
self.index = index
Expand Down
67 changes: 56 additions & 11 deletions beir/retrieval/search/dense/faiss_search.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .util import cos_sim, dot_score, normalize, save_dict_to_tsv, load_tsv_to_dict
from .faiss_index import FaissBinaryIndex, FaissPQIndex, FaissHNSWIndex, FaissPCAIndex, FaissIndex
from .faiss_index import FaissBinaryIndex, FaissTrainIndex, FaissHNSWIndex, FaissIndex
import logging
import sys
import torch
Expand Down Expand Up @@ -61,12 +61,14 @@ def save(self, output_dir: str, prefix: str, ext: str):

def _index(self, corpus: Dict[str, Dict[str, str]], score_function: str = None):

logger.info("Encoding Corpus in batches... Warning: This might take a while!")
corpus_ids = list(corpus.keys())
logger.info("Sorting Corpus by document length (Longest first)...")
corpus_ids = sorted(corpus, key=lambda k: len(corpus[k].get("title", "") + corpus[k].get("text", "")), reverse=True)
self._create_mapping_ids(corpus_ids)
corpus = [corpus[cid] for cid in corpus_ids]
normalize_embeddings = True if score_function == "cos_sim" else False

logger.info("Encoding Corpus in batches... Warning: This might take a while!")

itr = range(0, len(corpus), self.corpus_chunk_size)

for batch_num, corpus_start_idx in enumerate(itr):
Expand Down Expand Up @@ -137,6 +139,7 @@ def load(self, input_dir: str, prefix: str = "my-index", ext: str = "bin"):
def index(self, corpus: Dict[str, Dict[str, str]], score_function: str = None):
faiss_ids, corpus_embeddings = super()._index(corpus, score_function)
logger.info("Using Binary Hashing in Flat Mode!")
logger.info("Output Dimension: {}".format(self.dim_size))
base_index = faiss.IndexBinaryFlat(self.dim_size * 8)
self.faiss_index = FaissBinaryIndex.build(faiss_ids, corpus_embeddings, base_index)

Expand All @@ -154,25 +157,34 @@ def search(self,

class PQFaissSearch(DenseRetrievalFaissSearch):
def __init__(self, model, batch_size: int = 128, corpus_chunk_size: int = 50000, num_of_centroids: int = 96,
code_size: int = 8, similarity_metric=faiss.METRIC_INNER_PRODUCT, **kwargs):
code_size: int = 8, similarity_metric=faiss.METRIC_INNER_PRODUCT, use_rotation: bool = False, **kwargs):
super(PQFaissSearch, self).__init__(model, batch_size, corpus_chunk_size, **kwargs)
self.num_of_centroids = num_of_centroids
self.code_size = code_size
self.similarity_metric = similarity_metric
self.use_rotation = use_rotation

def load(self, input_dir: str, prefix: str = "my-index", ext: str = "pq"):
input_faiss_path, passage_ids = super()._load(input_dir, prefix, ext)
base_index = faiss.read_index(input_faiss_path)
self.faiss_index = FaissPQIndex(base_index, passage_ids)
self.faiss_index = FaissTrainIndex(base_index, passage_ids)

def index(self, corpus: Dict[str, Dict[str, str]], score_function: str = None, **kwargs):
faiss_ids, corpus_embeddings = super()._index(corpus, score_function, **kwargs)
faiss_ids, corpus_embeddings = super()._index(corpus, score_function, **kwargs)

logger.info("Using Product Quantization (PQ) in Flat mode!")
logger.info("Parameters Used: num_of_centroids: {} ".format(self.num_of_centroids))
logger.info("Parameters Used: code_size: {}".format(self.code_size))

logger.info("Parameters Used: code_size: {}".format(self.code_size))
base_index = faiss.IndexPQ(self.dim_size, self.num_of_centroids, self.code_size, self.similarity_metric)
self.faiss_index = FaissPQIndex.build(faiss_ids, corpus_embeddings, base_index)

if self.use_rotation:
logger.info("Rotating data before encoding it with a product quantizer...")
logger.info("Creating OPQ Matrix...")
opq_matrix = faiss.OPQMatrix(self.dim_size, self.code_size)
base_index = faiss.IndexPreTransform(opq_matrix, base_index)

self.faiss_index = FaissTrainIndex.build(faiss_ids, corpus_embeddings, base_index)

def save(self, output_dir: str, prefix: str = "my-index", ext: str = "pq"):
super().save(output_dir, prefix, ext)
Expand Down Expand Up @@ -256,19 +268,52 @@ def __init__(self, model, base_index: faiss.Index, output_dimension: int, batch_
def load(self, input_dir: str, prefix: str = "my-index", ext: str = "pca"):
input_faiss_path, passage_ids = super()._load(input_dir, prefix, ext)
base_index = faiss.read_index(input_faiss_path)
self.faiss_index = FaissPCAIndex(base_index, passage_ids)
self.faiss_index = FaissTrainIndex(base_index, passage_ids)

def index(self, corpus: Dict[str, Dict[str, str]], score_function: str = None, **kwargs):
faiss_ids, corpus_embeddings = super()._index(corpus, score_function, **kwargs)
logger.info("Creating PCA Matrix...")
logger.info("Input Dimension: {}, Output Dimension: {}".format(self.dim_size, self.output_dim))
pca_matrix = faiss.PCAMatrix(self.dim_size, self.output_dim, 0, True)
final_index = faiss.IndexPreTransform(pca_matrix, self.base_index)
self.faiss_index = FaissPCAIndex.build(faiss_ids, corpus_embeddings, final_index)
self.faiss_index = FaissTrainIndex.build(faiss_ids, corpus_embeddings, final_index)

def save(self, output_dir: str, prefix: str = "my-index", ext: str = "pca"):
super().save(output_dir, prefix, ext)

def search(self,
corpus: Dict[str, Dict[str, str]],
queries: Dict[str, str],
top_k: int,
score_function = str, **kwargs) -> Dict[str, Dict[str, float]]:

return super().search(corpus, queries, top_k, score_function, **kwargs)

class SQFaissSearch(DenseRetrievalFaissSearch):
def __init__(self, model, batch_size: int = 128, corpus_chunk_size: int = 50000,
similarity_metric=faiss.METRIC_INNER_PRODUCT, quantizer_type: str = "QT_fp16", **kwargs):
super(SQFaissSearch, self).__init__(model, batch_size, corpus_chunk_size, **kwargs)
self.similarity_metric = similarity_metric
self.qname = quantizer_type

def load(self, input_dir: str, prefix: str = "my-index", ext: str = "sq"):
input_faiss_path, passage_ids = super()._load(input_dir, prefix, ext)
base_index = faiss.read_index(input_faiss_path)
self.faiss_index = FaissTrainIndex(base_index, passage_ids)

def index(self, corpus: Dict[str, Dict[str, str]], score_function: str = None, **kwargs):
faiss_ids, corpus_embeddings = super()._index(corpus, score_function, **kwargs)

logger.info("Using Scalar Quantizer in Flat Mode!")
logger.info("Parameters Used: quantizer_type: {}".format(self.qname))

qtype = getattr(faiss.ScalarQuantizer, self.qname)
base_index = faiss.IndexScalarQuantizer(self.dim_size, qtype, self.similarity_metric)
self.faiss_index = FaissTrainIndex.build(faiss_ids, corpus_embeddings, base_index)

def save(self, output_dir: str, prefix: str = "my-index", ext: str = "sq"):
super().save(output_dir, prefix, ext)

def search(self,
corpus: Dict[str, Dict[str, str]],
queries: Dict[str, str],
Expand Down
10 changes: 7 additions & 3 deletions beir/retrieval/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Dict, Type, List, Callable, Iterable, Tuple
import logging
import time
import difflib

logger = logging.getLogger(__name__)

Expand All @@ -29,9 +30,12 @@ def load_train(self, corpus: Dict[str, Dict[str, str]], queries: Dict[str, str],
for query_id in query_ids_batch:
for corpus_id, score in qrels[query_id].items():
if score >= 1: # if score = 0, we don't consider for training
s1 = queries[query_id]
s2 = corpus[corpus_id].get("title") + " " + corpus[corpus_id].get("text")
train_samples.append(InputExample(guid=idx, texts=[s1, s2], label=1))
try:
s1 = queries[query_id]
s2 = corpus[corpus_id].get("title") + " " + corpus[corpus_id].get("text")
train_samples.append(InputExample(guid=idx, texts=[s1, s2], label=1))
except KeyError:
logging.error("Error: Key {} not present in corpus!".format(corpus_id))

logger.info("Loaded {} training pairs.".format(len(train_samples)))
return train_samples
Expand Down
2 changes: 1 addition & 1 deletion examples/dataset/md5.csv
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ cqadupstack.zip,4e41456d7df8ee7760a7f866133bda78
quora.zip,18fb154900ba42a600f84b839c173167
dbpedia-entity.zip,c2a39eb420a3164af735795df012ac2c
scidocs.zip,38121350fc3a4d2f48850f6aff52e4a9
fever.zip,88591ef8eb2913126d0c93ecbde6285f
fever.zip,5f7d1de60b170fc8027bb7898e2efca1
climate-fever.zip,8b66f0a9126c521bae2bde127b4dc99d
scifact.zip,5f7d1de60b170fc8027bb7898e2efca1
germanquad.zip,95a581c3162d10915a418609bcce851b
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
mrr = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="mrr")
recall_cap = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="recall_cap")
hole = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="hole")
top_k_accuracy = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="top_k_accuracy")

#### Print top-k documents retrieved ####
top_k = 10
Expand Down

0 comments on commit aacc574

Please sign in to comment.