Skip to content

Commit

Permalink
Merge branch '2.13' into 2.15
Browse files Browse the repository at this point in the history
  • Loading branch information
jhinch-at-atlassian-com committed Oct 31, 2024
2 parents f57e6af + dbb4e7d commit dcb8cd4
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,9 @@ public ScoreScript newInstance(LeafReaderContext ctx) throws IOException {
public double execute(ExplanationHolder explainationHolder) {
// For testing purposes just look for the "termStats" key and see if stats were injected
if(p.containsKey("termStats")) {
AbstractMap<String, ArrayList<Float>> termStats = (AbstractMap<String,
ArrayList<Float>>) p.get("termStats");
ArrayList<Float> dfStats = termStats.get("df");
Supplier<AbstractMap<String, ArrayList<Float>>> termStats = (Supplier<AbstractMap<String,
ArrayList<Float>>>) p.get("termStats");
ArrayList<Float> dfStats = termStats.get().get("df");
return dfStats.size() > 0 ? dfStats.get(0) : 0.0;
} else {
return 0.0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import org.opensearch.action.search.SearchResponse;
import org.opensearch.common.lucene.search.function.FieldValueFactorFunction;
import org.opensearch.common.lucene.search.function.FunctionScoreQuery;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.index.query.InnerHitBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
Expand Down
30 changes: 18 additions & 12 deletions src/main/java/com/o19s/es/ltr/feature/store/ScriptFeature.java
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;

public class ScriptFeature implements Feature {
Expand All @@ -71,6 +72,14 @@ public class ScriptFeature implements Feature {
public static final String EXTRA_LOGGING = "extra_logging";
public static final String EXTRA_SCRIPT_PARAMS = "extra_script_params";

/**
* A thread local allowing for term stats to be made available for the script score feature.
* This is needed as the parameters for the script score are created up-front when creating the
* lucene query with their values being swapped out for each document using a Supplier. A thread
* local is used to allow for different documents to have their scores computed concurrently.
*/
private static final ThreadLocal<TermStatSupplier> CURRENT_TERM_STATS = new ThreadLocal<>();

private final String name;
private final Script script;
private final Collection<String> queryParams;
Expand Down Expand Up @@ -143,7 +152,6 @@ public Query doToQuery(LtrQueryContext context, FeatureSet featureSet, Map<Strin

FeatureSupplier supplier = new FeatureSupplier(featureSet);
ExtraLoggingSupplier extraLoggingSupplier = new ExtraLoggingSupplier();
TermStatSupplier termstatSupplier = new TermStatSupplier();
Map<String, Object> nparams = new HashMap<>();

// Parse terms if set
Expand Down Expand Up @@ -220,8 +228,8 @@ public Query doToQuery(LtrQueryContext context, FeatureSet featureSet, Map<Strin
}
}

nparams.put(TERM_STAT, termstatSupplier);
nparams.put(MATCH_COUNT, termstatSupplier.getMatchedTermCountSupplier());
nparams.put(TERM_STAT, (Supplier<TermStatSupplier>) CURRENT_TERM_STATS::get);
nparams.put(MATCH_COUNT, (Supplier<Integer>) () -> CURRENT_TERM_STATS.get().getMatchedTermCount());
nparams.put(UNIQUE_TERMS, terms.size());
}

Expand All @@ -240,25 +248,22 @@ public Query doToQuery(LtrQueryContext context, FeatureSet featureSet, Map<Strin
context.getQueryShardContext().indexVersionCreated(),
null //TODO: this is different from ES LTR
);
return new LtrScript(function, supplier, extraLoggingSupplier, termstatSupplier, terms);
return new LtrScript(function, supplier, extraLoggingSupplier, terms);
}

static class LtrScript extends Query implements LtrRewritableQuery {
private final ScriptScoreFunction function;
private final FeatureSupplier supplier;
private final ExtraLoggingSupplier extraLoggingSupplier;
private final TermStatSupplier termStatSupplier;
private final Set<Term> terms;

LtrScript(ScriptScoreFunction function,
FeatureSupplier supplier,
ExtraLoggingSupplier extraLoggingSupplier,
TermStatSupplier termStatSupplier,
Set<Term> terms) {
this.function = function;
this.supplier = supplier;
this.extraLoggingSupplier = extraLoggingSupplier;
this.termStatSupplier = termStatSupplier;
this.terms = terms;
}

Expand All @@ -285,7 +290,7 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo
if (!scoreMode.needsScores()) {
return new MatchAllDocsQuery().createWeight(searcher, scoreMode, 1F);
}
return new LtrScriptWeight(this, this.function, termStatSupplier, terms, searcher, scoreMode);
return new LtrScriptWeight(this, this.function, terms, searcher, scoreMode);
}

@Override
Expand Down Expand Up @@ -317,18 +322,15 @@ static class LtrScriptWeight extends Weight {
private final IndexSearcher searcher;
private final ScoreMode scoreMode;
private final ScriptScoreFunction function;
private final TermStatSupplier termStatSupplier;
private final Set<Term> terms;
private final HashMap<Term, TermStates> termContexts;

LtrScriptWeight(Query query, ScriptScoreFunction function,
TermStatSupplier termStatSupplier,
Set<Term> terms,
IndexSearcher searcher,
ScoreMode scoreMode) throws IOException {
super(query);
this.function = function;
this.termStatSupplier = termStatSupplier;
this.terms = terms;
this.searcher = searcher;
this.scoreMode = scoreMode;
Expand All @@ -355,6 +357,7 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio
public Scorer scorer(LeafReaderContext context) throws IOException {
LeafScoreFunction leafScoreFunction = function.getLeafScoreFunction(context);
DocIdSetIterator iterator = DocIdSetIterator.all(context.reader().maxDoc());
TermStatSupplier termStatSupplier = new TermStatSupplier();
return new Scorer(this) {
@Override
public int docID() {
Expand All @@ -363,12 +366,15 @@ public int docID() {

@Override
public float score() throws IOException {
CURRENT_TERM_STATS.set(termStatSupplier);
// Do the terms magic if the user asked for it
if (terms.size() > 0) {
termStatSupplier.bump(searcher, context, docID(), terms, scoreMode, termContexts);
}

return (float) leafScoreFunction.score(iterator.docID(), 0F);
float score = (float) leafScoreFunction.score(iterator.docID(), 0F);
CURRENT_TERM_STATS.remove();
return score;
}

@Override
Expand Down
42 changes: 29 additions & 13 deletions src/main/java/com/o19s/es/ltr/query/RankerQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
import com.o19s.es.ltr.ranker.LogLtrRanker;
import com.o19s.es.ltr.ranker.LtrRanker;
import com.o19s.es.ltr.ranker.NullRanker;
import com.o19s.es.ltr.utils.Suppliers;
import com.o19s.es.ltr.utils.Suppliers.MutableSupplier;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
Expand Down Expand Up @@ -61,6 +59,26 @@
* or within a BooleanQuery and an appropriate filter clause.
*/
public class RankerQuery extends Query {
/**
* A thread local to allow for sharing the current feature vector between features. This
* is used primarily for derived expression and script features which derive one feature
* score from another. It relies on the following invariants to work:
* <ul>
* <li>
* Any call to {@link LtrRanker#newFeatureVector(LtrRanker.FeatureVector)} is
* followed by a subsequent call to {@link LtrRanker#score(LtrRanker.FeatureVector)}
* </li>
* <li>
* All feature scorers are invoked only between the creation of the feature vector and
* the final score being computed (the calls outlined above)
* </li>
* <li>
* All calls described above happen on the same thread for a single document
* </li>
* </ul>
*/
private static final ThreadLocal<LtrRanker.FeatureVector> CURRENT_VECTOR = new ThreadLocal<>();

private final List<Query> queries;
private final FeatureSet features;
private final LtrRanker ranker;
Expand Down Expand Up @@ -200,9 +218,8 @@ public boolean isCacheable(LeafReaderContext ctx) {
}

List<Weight> weights = new ArrayList<>(queries.size());
MutableSupplier<LtrRanker.FeatureVector> vectorSupplier = new Suppliers.MutableSupplier<>();
FVLtrRankerWrapper ltrRankerWrapper = new FVLtrRankerWrapper(ranker, vectorSupplier);
LtrRewriteContext context = new LtrRewriteContext(ranker, vectorSupplier);
FVLtrRankerWrapper ltrRankerWrapper = new FVLtrRankerWrapper(ranker);
LtrRewriteContext context = new LtrRewriteContext(ranker, CURRENT_VECTOR::get);
for (Query q : queries) {
if (q instanceof LtrRewritableQuery) {
q = ((LtrRewritableQuery) q).ltrRewrite(context);
Expand Down Expand Up @@ -439,11 +456,9 @@ public long cost() {

static class FVLtrRankerWrapper implements LtrRanker {
private final LtrRanker wrapped;
private final MutableSupplier<FeatureVector> vectorSupplier;

FVLtrRankerWrapper(LtrRanker wrapped, MutableSupplier<FeatureVector> vectorSupplier) {
FVLtrRankerWrapper(LtrRanker wrapped) {
this.wrapped = Objects.requireNonNull(wrapped);
this.vectorSupplier = Objects.requireNonNull(vectorSupplier);
}

@Override
Expand All @@ -454,27 +469,28 @@ public String name() {
@Override
public FeatureVector newFeatureVector(FeatureVector reuse) {
FeatureVector fv = wrapped.newFeatureVector(reuse);
vectorSupplier.set(fv);
CURRENT_VECTOR.set(fv);
return fv;
}

@Override
public float score(FeatureVector point) {
return wrapped.score(point);
float score = wrapped.score(point);
CURRENT_VECTOR.remove();
return score;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
FVLtrRankerWrapper that = (FVLtrRankerWrapper) o;
return Objects.equals(wrapped, that.wrapped) &&
Objects.equals(vectorSupplier, that.vectorSupplier);
return Objects.equals(wrapped, that.wrapped);
}

@Override
public int hashCode() {
return Objects.hash(wrapped, vectorSupplier);
return Objects.hash(wrapped);
}
}

Expand Down
20 changes: 0 additions & 20 deletions src/main/java/com/o19s/es/ltr/utils/Suppliers.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@

package com.o19s.es.ltr.utils;

import com.o19s.es.ltr.ranker.LtrRanker;
import org.opensearch.core.Assertions;

import java.util.concurrent.atomic.AtomicReference;
import java.util.Objects;
import java.util.function.Supplier;

Expand Down Expand Up @@ -63,20 +59,4 @@ public E get() {
return value;
}
}

/**
* A mutable supplier: a {@link Supplier} whose supplied value can be replaced
* after construction by calling {@code set}.
*
* @param <T> the type of value supplied
*/
public static class MutableSupplier<T> implements Supplier<T> {
// Holds the current value; created empty, so it is null until set() is called.
private final AtomicReference<T> ref = new AtomicReference<>();

/**
* Returns the value currently held by this supplier.
*
* @return the value most recently passed to {@code set}, or {@code null} if none was set
*/
@Override
public T get() {
return ref.get();
}

/**
* Replaces the value returned by subsequent calls to {@code get}.
*
* @param obj the new value to supply
*/
public void set(T obj) {
this.ref.set(obj);
}
}
}
9 changes: 0 additions & 9 deletions src/main/java/com/o19s/es/termstat/TermStatSupplier.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import com.o19s.es.explore.StatisticsHelper;
import com.o19s.es.explore.StatisticsHelper.AggrType;
import com.o19s.es.ltr.utils.Suppliers;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.ReaderUtil;
Expand Down Expand Up @@ -48,12 +47,10 @@ public class TermStatSupplier extends AbstractMap<String, ArrayList<Float>> {

private final ClassicSimilarity sim;
private final StatisticsHelper df_stats, idf_stats, tf_stats, ttf_stats, tp_stats;
private final Suppliers.MutableSupplier<Integer> matchedCountSupplier;

private int matchedTermCount = 0;

public TermStatSupplier() {
this.matchedCountSupplier = new Suppliers.MutableSupplier<>();
this.sim = new ClassicSimilarity();
this.df_stats = new StatisticsHelper();
this.idf_stats = new StatisticsHelper();
Expand Down Expand Up @@ -124,8 +121,6 @@ public void bump (IndexSearcher searcher, LeafReaderContext context,
tp_stats.add(0.0f);
}
}

matchedCountSupplier.set(matchedTermCount);
}

/**
Expand Down Expand Up @@ -229,10 +224,6 @@ public int getMatchedTermCount() {
return matchedTermCount;
}

public Suppliers.MutableSupplier<Integer> getMatchedTermCountSupplier() {
return matchedCountSupplier;
}

public void setPosAggr(AggrType type) {
this.posAggrType = type;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import com.o19s.es.ltr.feature.FeatureSet;
import com.o19s.es.ltr.ranker.DenseFeatureVector;
import com.o19s.es.ltr.ranker.LtrRanker;
import com.o19s.es.ltr.utils.Suppliers;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.opensearch.index.query.QueryBuilders;

Expand All @@ -45,10 +44,8 @@ public void testGetWhenFeatureVectorNotSet() {

public void testGetWhenFeatureVectorSet() {
FeatureSupplier featureSupplier = new FeatureSupplier(getFeatureSet());
Suppliers.MutableSupplier<LtrRanker.FeatureVector> vectorSupplier = new Suppliers.MutableSupplier<>();
LtrRanker.FeatureVector featureVector = new DenseFeatureVector(1);
vectorSupplier.set(featureVector);
featureSupplier.set(vectorSupplier);
featureSupplier.set(() -> featureVector);
assertEquals(featureVector, featureSupplier.get());
}

Expand All @@ -60,11 +57,9 @@ public void testContainsKey() {

public void testGetFeatureScore() {
FeatureSupplier featureSupplier = new FeatureSupplier(getFeatureSet());
Suppliers.MutableSupplier<LtrRanker.FeatureVector> vectorSupplier = new Suppliers.MutableSupplier<>();
LtrRanker.FeatureVector featureVector = new DenseFeatureVector(1);
featureVector.setFeatureScore(0, 10.0f);
vectorSupplier.set(featureVector);
featureSupplier.set(vectorSupplier);
featureSupplier.set(() -> featureVector);
assertEquals(10.0f, featureSupplier.get("test"), 0.0f);
assertNull(featureSupplier.get("bad_test"));
}
Expand All @@ -81,11 +76,9 @@ public void testEntrySetWhenFeatureVectorNotSet(){

public void testEntrySetWhenFeatureVectorIsSet(){
FeatureSupplier featureSupplier = new FeatureSupplier(getFeatureSet());
Suppliers.MutableSupplier<LtrRanker.FeatureVector> vectorSupplier = new Suppliers.MutableSupplier<>();
LtrRanker.FeatureVector featureVector = new DenseFeatureVector(1);
featureVector.setFeatureScore(0, 10.0f);
vectorSupplier.set(featureVector);
featureSupplier.set(vectorSupplier);
featureSupplier.set(() -> featureVector);

Set<Map.Entry<String, Float>> entrySet = featureSupplier.entrySet();
assertFalse(entrySet.isEmpty());
Expand Down

0 comments on commit dcb8cd4

Please sign in to comment.