Skip to content

Commit

Permalink
Fix data race which can occur when using script and derived expressio…
Browse files Browse the repository at this point in the history
…n features with concurrent segment search

Signed-off-by: Jason Hinch <[email protected]>
  • Loading branch information
jhinch-at-atlassian-com committed Oct 28, 2024
1 parent a3ad0a2 commit dbb4e7d
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,9 @@ public ScoreScript newInstance(LeafReaderContext ctx) throws IOException {
public double execute(ExplanationHolder explainationHolder) {
// For testing purposes just look for the "termStats" key and see if stats were injected
if(p.containsKey("termStats")) {
AbstractMap<String, ArrayList<Float>> termStats = (AbstractMap<String,
ArrayList<Float>>) p.get("termStats");
ArrayList<Float> dfStats = termStats.get("df");
Supplier<AbstractMap<String, ArrayList<Float>>> termStats = (Supplier<AbstractMap<String,
ArrayList<Float>>>) p.get("termStats");
ArrayList<Float> dfStats = termStats.get().get("df");
return dfStats.size() > 0 ? dfStats.get(0) : 0.0;
} else {
return 0.0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import org.opensearch.action.search.SearchResponse;
import org.opensearch.common.lucene.search.function.FieldValueFactorFunction;
import org.opensearch.common.lucene.search.function.FunctionScoreQuery;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.index.query.InnerHitBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
Expand Down
30 changes: 18 additions & 12 deletions src/main/java/com/o19s/es/ltr/feature/store/ScriptFeature.java
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;

public class ScriptFeature implements Feature {
Expand All @@ -71,6 +72,14 @@ public class ScriptFeature implements Feature {
public static final String EXTRA_LOGGING = "extra_logging";
public static final String EXTRA_SCRIPT_PARAMS = "extra_script_params";

/**
* A thread local allowing for term stats to be made available for the script score feature.
* This is needed as the parameters for the script score are created up-front when creating the
* Lucene query, with their values being swapped out for each document using a Supplier. A thread
* local is used to allow for different documents to have their scores computed concurrently.
*/
private static final ThreadLocal<TermStatSupplier> CURRENT_TERM_STATS = new ThreadLocal<>();

private final String name;
private final Script script;
private final Collection<String> queryParams;
Expand Down Expand Up @@ -143,7 +152,6 @@ public Query doToQuery(LtrQueryContext context, FeatureSet featureSet, Map<Strin

FeatureSupplier supplier = new FeatureSupplier(featureSet);
ExtraLoggingSupplier extraLoggingSupplier = new ExtraLoggingSupplier();
TermStatSupplier termstatSupplier = new TermStatSupplier();
Map<String, Object> nparams = new HashMap<>();

// Parse terms if set
Expand Down Expand Up @@ -220,8 +228,8 @@ public Query doToQuery(LtrQueryContext context, FeatureSet featureSet, Map<Strin
}
}

nparams.put(TERM_STAT, termstatSupplier);
nparams.put(MATCH_COUNT, termstatSupplier.getMatchedTermCountSupplier());
nparams.put(TERM_STAT, (Supplier<TermStatSupplier>) CURRENT_TERM_STATS::get);
nparams.put(MATCH_COUNT, (Supplier<Integer>) () -> CURRENT_TERM_STATS.get().getMatchedTermCount());
nparams.put(UNIQUE_TERMS, terms.size());
}

Expand All @@ -240,25 +248,22 @@ public Query doToQuery(LtrQueryContext context, FeatureSet featureSet, Map<Strin
context.getQueryShardContext().indexVersionCreated(),
null //TODO: this is different from ES LTR
);
return new LtrScript(function, supplier, extraLoggingSupplier, termstatSupplier, terms);
return new LtrScript(function, supplier, extraLoggingSupplier, terms);
}

static class LtrScript extends Query implements LtrRewritableQuery {
private final ScriptScoreFunction function;
private final FeatureSupplier supplier;
private final ExtraLoggingSupplier extraLoggingSupplier;
private final TermStatSupplier termStatSupplier;
private final Set<Term> terms;

LtrScript(ScriptScoreFunction function,
FeatureSupplier supplier,
ExtraLoggingSupplier extraLoggingSupplier,
TermStatSupplier termStatSupplier,
Set<Term> terms) {
this.function = function;
this.supplier = supplier;
this.extraLoggingSupplier = extraLoggingSupplier;
this.termStatSupplier = termStatSupplier;
this.terms = terms;
}

Expand All @@ -285,7 +290,7 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo
if (!scoreMode.needsScores()) {
return new MatchAllDocsQuery().createWeight(searcher, scoreMode, 1F);
}
return new LtrScriptWeight(this, this.function, termStatSupplier, terms, searcher, scoreMode);
return new LtrScriptWeight(this, this.function, terms, searcher, scoreMode);
}

@Override
Expand Down Expand Up @@ -317,18 +322,15 @@ static class LtrScriptWeight extends Weight {
private final IndexSearcher searcher;
private final ScoreMode scoreMode;
private final ScriptScoreFunction function;
private final TermStatSupplier termStatSupplier;
private final Set<Term> terms;
private final HashMap<Term, TermStates> termContexts;

LtrScriptWeight(Query query, ScriptScoreFunction function,
TermStatSupplier termStatSupplier,
Set<Term> terms,
IndexSearcher searcher,
ScoreMode scoreMode) throws IOException {
super(query);
this.function = function;
this.termStatSupplier = termStatSupplier;
this.terms = terms;
this.searcher = searcher;
this.scoreMode = scoreMode;
Expand All @@ -355,6 +357,7 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio
public Scorer scorer(LeafReaderContext context) throws IOException {
LeafScoreFunction leafScoreFunction = function.getLeafScoreFunction(context);
DocIdSetIterator iterator = DocIdSetIterator.all(context.reader().maxDoc());
TermStatSupplier termStatSupplier = new TermStatSupplier();
return new Scorer(this) {
@Override
public int docID() {
Expand All @@ -363,12 +366,15 @@ public int docID() {

@Override
public float score() throws IOException {
CURRENT_TERM_STATS.set(termStatSupplier);
// Do the terms magic if the user asked for it
if (terms.size() > 0) {
termStatSupplier.bump(searcher, context, docID(), terms, scoreMode, termContexts);
}

return (float) leafScoreFunction.score(iterator.docID(), 0F);
float score = (float) leafScoreFunction.score(iterator.docID(), 0F);
CURRENT_TERM_STATS.remove();
return score;
}

@Override
Expand Down
42 changes: 29 additions & 13 deletions src/main/java/com/o19s/es/ltr/query/RankerQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
import com.o19s.es.ltr.ranker.LogLtrRanker;
import com.o19s.es.ltr.ranker.LtrRanker;
import com.o19s.es.ltr.ranker.NullRanker;
import com.o19s.es.ltr.utils.Suppliers;
import com.o19s.es.ltr.utils.Suppliers.MutableSupplier;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
Expand Down Expand Up @@ -61,6 +59,26 @@
* or within a BooleanQuery and an appropriate filter clause.
*/
public class RankerQuery extends Query {
/**
* A thread local to allow for sharing the current feature vector between features. This
* is used primarily for derived expression and script features which derive one feature
* score from another. It relies on the following invariants to work:
* <ul>
* <li>
* Any call to {@link LtrRanker#newFeatureVector(LtrRanker.FeatureVector)} is
* followed by a subsequent call to {@link LtrRanker#score(LtrRanker.FeatureVector)}
* </li>
* <li>
* All feature scorers are invoked only between the creation of the feature vector and
* the final score being computed (the calls outlined above)
* </li>
* <li>
* All calls described above happen on the same thread for a single document
* </li>
* </ul>
*/
private static final ThreadLocal<LtrRanker.FeatureVector> CURRENT_VECTOR = new ThreadLocal<>();

private final List<Query> queries;
private final FeatureSet features;
private final LtrRanker ranker;
Expand Down Expand Up @@ -200,9 +218,8 @@ public boolean isCacheable(LeafReaderContext ctx) {
}

List<Weight> weights = new ArrayList<>(queries.size());
MutableSupplier<LtrRanker.FeatureVector> vectorSupplier = new Suppliers.MutableSupplier<>();
FVLtrRankerWrapper ltrRankerWrapper = new FVLtrRankerWrapper(ranker, vectorSupplier);
LtrRewriteContext context = new LtrRewriteContext(ranker, vectorSupplier);
FVLtrRankerWrapper ltrRankerWrapper = new FVLtrRankerWrapper(ranker);
LtrRewriteContext context = new LtrRewriteContext(ranker, CURRENT_VECTOR::get);
for (Query q : queries) {
if (q instanceof LtrRewritableQuery) {
q = ((LtrRewritableQuery) q).ltrRewrite(context);
Expand Down Expand Up @@ -439,11 +456,9 @@ public long cost() {

static class FVLtrRankerWrapper implements LtrRanker {
private final LtrRanker wrapped;
private final MutableSupplier<FeatureVector> vectorSupplier;

FVLtrRankerWrapper(LtrRanker wrapped, MutableSupplier<FeatureVector> vectorSupplier) {
FVLtrRankerWrapper(LtrRanker wrapped) {
this.wrapped = Objects.requireNonNull(wrapped);
this.vectorSupplier = Objects.requireNonNull(vectorSupplier);
}

@Override
Expand All @@ -454,27 +469,28 @@ public String name() {
@Override
public FeatureVector newFeatureVector(FeatureVector reuse) {
FeatureVector fv = wrapped.newFeatureVector(reuse);
vectorSupplier.set(fv);
CURRENT_VECTOR.set(fv);
return fv;
}

@Override
public float score(FeatureVector point) {
return wrapped.score(point);
float score = wrapped.score(point);
CURRENT_VECTOR.remove();
return score;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
FVLtrRankerWrapper that = (FVLtrRankerWrapper) o;
return Objects.equals(wrapped, that.wrapped) &&
Objects.equals(vectorSupplier, that.vectorSupplier);
return Objects.equals(wrapped, that.wrapped);
}

@Override
public int hashCode() {
return Objects.hash(wrapped, vectorSupplier);
return Objects.hash(wrapped);
}
}

Expand Down
20 changes: 0 additions & 20 deletions src/main/java/com/o19s/es/ltr/utils/Suppliers.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@

package com.o19s.es.ltr.utils;

import com.o19s.es.ltr.ranker.LtrRanker;
import org.opensearch.core.Assertions;

import java.util.concurrent.atomic.AtomicReference;
import java.util.Objects;
import java.util.function.Supplier;

Expand Down Expand Up @@ -63,20 +59,4 @@ public E get() {
return value;
}
}

/**
* A mutable supplier
*/
public static class MutableSupplier<T> implements Supplier<T> {
private final AtomicReference<T> ref = new AtomicReference<>();

@Override
public T get() {
return ref.get();
}

public void set(T obj) {
this.ref.set(obj);
}
}
}
9 changes: 0 additions & 9 deletions src/main/java/com/o19s/es/termstat/TermStatSupplier.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import com.o19s.es.explore.StatisticsHelper;
import com.o19s.es.explore.StatisticsHelper.AggrType;
import com.o19s.es.ltr.utils.Suppliers;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.ReaderUtil;
Expand Down Expand Up @@ -48,12 +47,10 @@ public class TermStatSupplier extends AbstractMap<String, ArrayList<Float>> {

private final ClassicSimilarity sim;
private final StatisticsHelper df_stats, idf_stats, tf_stats, ttf_stats, tp_stats;
private final Suppliers.MutableSupplier<Integer> matchedCountSupplier;

private int matchedTermCount = 0;

public TermStatSupplier() {
this.matchedCountSupplier = new Suppliers.MutableSupplier<>();
this.sim = new ClassicSimilarity();
this.df_stats = new StatisticsHelper();
this.idf_stats = new StatisticsHelper();
Expand Down Expand Up @@ -124,8 +121,6 @@ public void bump (IndexSearcher searcher, LeafReaderContext context,
tp_stats.add(0.0f);
}
}

matchedCountSupplier.set(matchedTermCount);
}

/**
Expand Down Expand Up @@ -229,10 +224,6 @@ public int getMatchedTermCount() {
return matchedTermCount;
}

public Suppliers.MutableSupplier<Integer> getMatchedTermCountSupplier() {
return matchedCountSupplier;
}

public void setPosAggr(AggrType type) {
this.posAggrType = type;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import com.o19s.es.ltr.feature.FeatureSet;
import com.o19s.es.ltr.ranker.DenseFeatureVector;
import com.o19s.es.ltr.ranker.LtrRanker;
import com.o19s.es.ltr.utils.Suppliers;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.opensearch.index.query.QueryBuilders;

Expand All @@ -45,10 +44,8 @@ public void testGetWhenFeatureVectorNotSet() {

public void testGetWhenFeatureVectorSet() {
FeatureSupplier featureSupplier = new FeatureSupplier(getFeatureSet());
Suppliers.MutableSupplier<LtrRanker.FeatureVector> vectorSupplier = new Suppliers.MutableSupplier<>();
LtrRanker.FeatureVector featureVector = new DenseFeatureVector(1);
vectorSupplier.set(featureVector);
featureSupplier.set(vectorSupplier);
featureSupplier.set(() -> featureVector);
assertEquals(featureVector, featureSupplier.get());
}

Expand All @@ -60,11 +57,9 @@ public void testContainsKey() {

public void testGetFeatureScore() {
FeatureSupplier featureSupplier = new FeatureSupplier(getFeatureSet());
Suppliers.MutableSupplier<LtrRanker.FeatureVector> vectorSupplier = new Suppliers.MutableSupplier<>();
LtrRanker.FeatureVector featureVector = new DenseFeatureVector(1);
featureVector.setFeatureScore(0, 10.0f);
vectorSupplier.set(featureVector);
featureSupplier.set(vectorSupplier);
featureSupplier.set(() -> featureVector);
assertEquals(10.0f, featureSupplier.get("test"), 0.0f);
assertNull(featureSupplier.get("bad_test"));
}
Expand All @@ -81,11 +76,9 @@ public void testEntrySetWhenFeatureVectorNotSet(){

public void testEntrySetWhenFeatureVectorIsSet(){
FeatureSupplier featureSupplier = new FeatureSupplier(getFeatureSet());
Suppliers.MutableSupplier<LtrRanker.FeatureVector> vectorSupplier = new Suppliers.MutableSupplier<>();
LtrRanker.FeatureVector featureVector = new DenseFeatureVector(1);
featureVector.setFeatureScore(0, 10.0f);
vectorSupplier.set(featureVector);
featureSupplier.set(vectorSupplier);
featureSupplier.set(() -> featureVector);

Set<Map.Entry<String, Float>> entrySet = featureSupplier.entrySet();
assertFalse(entrySet.isEmpty());
Expand Down

0 comments on commit dbb4e7d

Please sign in to comment.