Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Added batch equivalent of computeQueryDocumentScore #1882

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 154 additions & 1 deletion src/main/java/io/anserini/index/IndexReaderUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import io.anserini.search.SearchArgs;
import io.anserini.search.query.BagOfWordsQueryGenerator;
import io.anserini.search.query.PhraseQueryGenerator;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to @HAKSOAT: Remove these logger imports if not needed in the final implementation.

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.DirectoryReader;
Expand Down Expand Up @@ -59,12 +61,17 @@
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

/**
* Class containing a bunch of static helper methods for accessing a Lucene inverted index.
* This class provides a lot of functionality that is exposed in Python via Pyserini.
*/
public class IndexReaderUtils {
private static final Logger LOG = LogManager.getLogger(IndexReaderUtils.class);

/**
* An individual posting in a postings list. Note that this class is used primarily for inspecting
Expand Down Expand Up @@ -726,7 +733,153 @@ public static float computeQueryDocumentScoreWithSimilarityAndAnalyzer(
return rs.scoreDocs.length == 0 ? 0 : rs.scoreDocs[0].score - 1;
}

// TODO: Write a variant of computeQueryDocumentScore that takes a set of documents.
/**
 * Computes the scores of a batch of documents with respect to a query, using the default BM25
 * scoring function and the default analyzer.
 *
 * @param reader index reader
 * @param docids list of docids of the documents to score
 * @param q query
 * @param threads number of threads
 * @return a map of document ids to their scores with respect to the query
 * @throws IOException if error encountered during query
 */
public static Map<String, Float> batchComputeQueryDocumentScore(
    IndexReader reader, List<String> docids, String q, int threads)
    throws IOException {

  // Pull the default BM25 parameters (k1, b) from SearchArgs so the defaults live in one place.
  SearchArgs args = new SearchArgs();
  return batchComputeQueryDocumentScoreWithSimilarityAndAnalyzer(reader, docids, q,
      new BM25Similarity(Float.parseFloat(args.bm25_k1[0]), Float.parseFloat(args.bm25_b[0])),
      IndexCollection.DEFAULT_ANALYZER, threads);
}


/**
 * Computes the scores of a batch of documents with respect to a query, given a scoring function
 * and using the default analyzer.
 *
 * @param reader index reader
 * @param docids list of docids of the documents to score
 * @param q query
 * @param similarity scoring function
 * @param threads number of threads
 * @return a map of document ids to their scores with respect to the query
 * @throws IOException if error encountered during query
 */
public static Map<String, Float> batchComputeQueryDocumentScore(
    IndexReader reader, List<String> docids, String q, Similarity similarity, int threads)
    throws IOException {

  return batchComputeQueryDocumentScoreWithSimilarityAndAnalyzer(reader, docids, q, similarity,
      IndexCollection.DEFAULT_ANALYZER, threads);
}


/**
 * Computes the scores of a batch of documents with respect to a query, given an analyzer and
 * using the default BM25 scoring function.
 *
 * @param reader index reader
 * @param docids list of docids of the documents to score
 * @param q query
 * @param analyzer analyzer to use
 * @param threads number of threads
 * @return a map of document ids to their scores with respect to the query
 * @throws IOException if error encountered during query
 */
public static Map<String, Float> batchComputeQueryDocumentScore(
    IndexReader reader, List<String> docids, String q, Analyzer analyzer, int threads)
    throws IOException {

  // Pull the default BM25 parameters (k1, b) from SearchArgs so the defaults live in one place.
  SearchArgs args = new SearchArgs();
  return batchComputeQueryDocumentScoreWithSimilarityAndAnalyzer(reader, docids, q,
      new BM25Similarity(Float.parseFloat(args.bm25_k1[0]), Float.parseFloat(args.bm25_b[0])),
      analyzer, threads);
}


/**
 * Computes the scores of a batch of documents with respect to a query, given both a scoring
 * function and an analyzer. Convenience alias that simply forwards to
 * {@code batchComputeQueryDocumentScoreWithSimilarityAndAnalyzer}.
 *
 * @param reader index reader
 * @param docids list of docids of the documents to score
 * @param q query
 * @param similarity scoring function
 * @param analyzer analyzer to use
 * @param threads number of threads
 * @return a map of document ids to their scores with respect to the query
 * @throws IOException if error encountered during query
 */
public static Map<String, Float> batchComputeQueryDocumentScore(
    IndexReader reader, List<String> docids, String q, Similarity similarity, Analyzer analyzer, int threads)
    throws IOException {
  // Direct pass-through: all arguments are explicitly supplied, nothing to default here.
  return batchComputeQueryDocumentScoreWithSimilarityAndAnalyzer(
      reader, docids, q, similarity, analyzer, threads);
}


/**
* Computes the scores of a batch of documents with respect to a query given a scoring function and an analyzer.
*
* @param reader index reader
* @param docids A list of docids of the documents to score
* @param q query
* @param similarity scoring function
* @param analyzer analyzer to use
* @param threads number of threads
* @return a map of document ids to their scores with respect to the query
* @throws IOException if error encountered during query
*/
public static Map<String, Float> batchComputeQueryDocumentScoreWithSimilarityAndAnalyzer(
IndexReader reader, List<String> docids, String q, Similarity similarity, Analyzer analyzer, int threads)
throws IOException {
// We compute the query-document score by issuing the query with an additional filter clause that restricts
// consideration to only the docid in question, and then returning the retrieval score.
//
// This implementation is inefficient, but as the advantage of using the existing Lucene similarity, which means
// that we don't need to copy the scoring function and keep it in sync wrt code updates.

IndexSearcher searcher = new IndexSearcher(reader);
searcher.setSimilarity(similarity);

ThreadPoolExecutor executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(threads);
ConcurrentHashMap<String, Float> results = new ConcurrentHashMap<>();

for (String docid: docids) {
executor.execute(() -> {
try {
Query query = new BagOfWordsQueryGenerator().buildQuery(IndexArgs.CONTENTS, analyzer, q);

Query filterQuery = new ConstantScoreQuery(new TermQuery(new Term(IndexArgs.ID, docid)));
BooleanQuery.Builder builder = new BooleanQuery.Builder();
builder.add(filterQuery, BooleanClause.Occur.MUST);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What you want to do is to move the docids here: In the non-batch impl, the filter clause restricts to a single docid. Here, in the batch impl, you want to restrict to a set of docids - i.e., add multiple sub-clauses in the filter query.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @lintool I took a look at this and tried testing with a set of documents from robust04. However, I came across the error: org.apache.lucene.search.BooleanQuery$TooManyClauses: maxClauseCount is set to 1024.

This resulted from doing:

    for (String docid: docids){
      // Setting default result value for all docids.
      results.put(docid, 0.0f);
      Query filterQuery = new ConstantScoreQuery(new TermQuery(new Term(IndexArgs.ID, docid)));
      builder.add(filterQuery, BooleanClause.Occur.SHOULD);
    }

What are your thoughts on this?

Am I doing the right thing? I tried this with tests and it works when the clause count is less than 1024.

builder.add(query, BooleanClause.Occur.MUST);
Query finalQuery = builder.build();

TopDocs rs = searcher.search(finalQuery, 1);

// We want the score of the first (and only) hit, but remember to remove 1 for the ConstantScoreQuery.
// If we get zero results, indicates that term isn't found in the document.
float result = rs.scoreDocs.length == 0 ? 0 : rs.scoreDocs[0].score - 1;
results.put(docid, result);
} catch (Exception e){}
});
}

executor.shutdown();

try {
// Wait for existing tasks to terminate
while (!executor.awaitTermination(1, TimeUnit.MINUTES)) {
LOG.info(String.format("%.2f percent completed",
(double) executor.getCompletedTaskCount() / docids.size() * 100.0d));
}
} catch (InterruptedException ie) {
// (Re-)Cancel if current thread also interrupted
executor.shutdownNow();
// Preserve interrupt status
Thread.currentThread().interrupt();
}

return results;
}

/**
* Converts a collection docid to a Lucene internal docid.
Expand Down
47 changes: 47 additions & 0 deletions src/test/java/io/anserini/index/IndexReaderUtilsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.anserini.analysis.DefaultEnglishAnalyzer;
import io.anserini.search.SearchArgs;
import io.anserini.search.SimpleSearcher;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexReader;
Expand All @@ -40,6 +41,7 @@

import java.io.ByteArrayOutputStream;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
Expand Down Expand Up @@ -535,6 +537,51 @@ public void testComputeQueryDocumentScore() throws Exception {
dir.close();
}

@Test
public void testBatchComputeQueryDocumentScore() throws Exception {
  SimpleSearcher searcher1 = new SimpleSearcher(tempDir1.toString());
  // Use an analyzer aside from the default for the second searcher.
  // Note: the stemmer name is "krovetz" (Krovetz stemmer); "krovertz" was a typo.
  Analyzer stemAnalyzer = DefaultEnglishAnalyzer.newStemmingInstance("krovetz");
  SimpleSearcher searcher2 = new SimpleSearcher(tempDir1.toString(), stemAnalyzer);
  Directory dir = FSDirectory.open(tempDir1);
  IndexReader reader = DirectoryReader.open(dir);
  Similarity similarity = new BM25Similarity(0.9f, 0.4f);

  // A bunch of test queries...
  String[] queries = {"text city", "text", "city"};

  for (String query : queries) {
    SimpleSearcher.Result[] results1 = searcher1.search(query);

    List<String> docids = new ArrayList<String>();
    for (SimpleSearcher.Result result : results1) {
      docids.add(result.docid);
    }

    // Batch scores must match the scores produced by an ordinary search.
    Map<String, Float> batchScore1 = IndexReaderUtils.batchComputeQueryDocumentScore(reader, docids, query, similarity, 2);
    for (SimpleSearcher.Result result : results1) {
      assertEquals(batchScore1.get(result.docid), result.score, 10e-5);
    }

    // NOTE(review): docids comes from searcher1's hits; assumes searcher2 retrieves a subset of
    // the same docids on this fixture — otherwise batchScore2.get(...) could be null. Confirm.
    SimpleSearcher.Result[] results2 = searcher2.search(query);
    Map<String, Float> batchScore2 = IndexReaderUtils.batchComputeQueryDocumentScore(reader, docids, query, similarity, stemAnalyzer, 2);
    for (SimpleSearcher.Result result : results2) {
      assertEquals(batchScore2.get(result.docid), result.score, 10e-5);
    }

    // This is hard coded - doc3 isn't retrieved by any of the queries, so its score must be 0.
    String fakeId = "doc3";
    docids = List.of(fakeId);
    Map<String, Float> batchScore = IndexReaderUtils.batchComputeQueryDocumentScore(
        reader, docids, query, similarity, 2);
    assertEquals(0.0f, batchScore.get(fakeId), 10e-6);
  }

  reader.close();
  dir.close();
}


@Test
public void testGetIndexStats() throws Exception {
Directory dir = FSDirectory.open(tempDir1);
Expand Down