Add support for query time parameters for ScriptFeature term statistics (#330)

* Add support for query time parameters

* Rename response object so it doesn't collide with the passed in
parameters

* Beef up docs around term stat injection.
Allow override of parameters

* Add sum stat type

* Making sure things blow up when they should

* Rework termstat parameter passing

* Insert zeroes for unknown terms

* Add experimental note to termstat section

* Add match/unique term counts
worleydl authored Sep 15, 2020
1 parent b647875 commit e265bf6
Showing 8 changed files with 315 additions and 23 deletions.
101 changes: 100 additions & 1 deletion docs/advanced-functionality.rst
@@ -280,6 +280,8 @@ Also you can limit the information to a single node in the cluster::
TermStat Query
=============================

**Experimental** - This query is currently in an experimental stage and the DSL may change as the code advances. For stable term statistic access, please see the :code:`ExplorerQuery`.

The :code:`TermStatQuery` is a re-imagination of the legacy :code:`ExplorerQuery` which offers clearer specification of terms and more freedom to experiment. This query surfaces the same data as the :code:`ExplorerQuery`, but it allows the user to specify a custom Lucene expression for the type of data they would like to retrieve. For example::

POST tmdb/_search
@@ -309,8 +311,13 @@ Supported aggregation types are:
- :code:`min` -- the minimum
- :code:`max` -- the maximum
- :code:`avg` -- the mean
- :code:`sum` -- the sum
- :code:`stddev` -- the standard deviation

Additionally, the following counts are available (see the example below):

- :code:`matches` -- the number of terms that matched in the current document
- :code:`unique` -- the number of unique terms that were passed in
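
For instance, the counts can be referenced by name in the custom Lucene expression alongside the other statistics (a small illustration, not an example from this change; the scorer changes below bind :code:`matches` and :code:`unique` for the expression). The fraction of supplied terms that matched in the current document could be computed as::

    matches / unique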

The :code:`terms` parameter is an array of terms to gather statistics for. Currently only single terms are supported; there is no support for phrases or span queries. Note: if your field is tokenized, you can pass multiple terms in one string in the array.

The :code:`fields` parameter specifies which fields to check for the specified :code:`terms`. Note: if no :code:`analyzer` is specified, the search analyzer configured for the field is used.
@@ -324,4 +331,96 @@ Optional Parameters
Script Injection
----------------

Finally, one last addition that this functionality provides is the ability to inject term statistics into a scripting context. When working with :code:`ScriptFeatures` if you pass a :code:`term_stat` object in with the :code:`terms`, :code:`fields` and :code:`analyzer` parameters you can access the raw values directly in a custom script via an injected variable named :code:`terms`. This provides for advanced feature engineering when you need to look at all the data to make decisions.
Finally, this functionality also provides the ability to inject term statistics into a scripting context. When working with :code:`ScriptFeatures`, if you pass a :code:`term_stat` object in with the :code:`terms`, :code:`fields`, and :code:`analyzer` parameters, you can access the raw values directly in a custom script via an injected variable named :code:`termStats`. This provides for advanced feature engineering when you need to look at all the data to make decisions.

Scripts access the match and unique counts slightly differently than the :code:`TermStatQuery` does:

- To access the count of matched terms: :code:`params.matchCount.get()`
- To access the count of unique terms: :code:`params.uniqueTerms`
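
For example, a Painless feature script that scores a document by the fraction of the supplied terms it matches might look like the following (a sketch, not part of this commit; the source would go in a :code:`script_feature` template like the ones below)::

    // params.termStats['df'] still holds the raw per-term document frequencies
    int matched = params.matchCount.get();  // supplied terms found in this document
    int unique = params.uniqueTerms;        // unique terms passed in via term_stat
    return unique > 0 ? (double) matched / unique : 0.0;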

You have the following options for sending parameters to scripts. If you always want to gather statistics for the same terms (e.g. stopwords or other common terms in your index), you can hardcode the parameters along with your script::

    POST _ltr/_featureset/test
    {
      "featureset": {
        "features": [
          {
            "name": "injection",
            "template_language": "script_feature",
            "template": {
              "lang": "painless",
              "source": "params.termStats['df'].size()",
              "params": {
                "term_stat": {
                  "analyzer": "!standard",
                  "terms": ["rambo rocky"],
                  "fields": ["overview"]
                }
              }
            }
          }
        ]
      }
    }

Note: analyzer names must be prefixed with a bang (!) when specified directly; otherwise the value is treated as a parameter lookup.

To set up parameter lookups, simply pass the name of the parameter to pull the value from, like so::

    POST _ltr/_featureset/test
    {
      "featureset": {
        "features": [
          {
            "name": "injection",
            "template_language": "script_feature",
            "template": {
              "lang": "painless",
              "source": "params.termStats['df'].size()",
              "params": {
                "term_stat": {
                  "analyzer": "analyzerParam",
                  "terms": "termsParam",
                  "fields": "fieldsParam"
                }
              }
            }
          }
        ]
      }
    }

The following example shows how to set the parameters at query time::

    POST tmdb/_search
    {
      "query": {
        "bool": {
          "filter": [
            {
              "terms": {
                "_id": ["7555", "1370", "1369"]
              }
            },
            {
              "sltr": {
                "_name": "logged_featureset",
                "featureset": "test",
                "params": {
                  "analyzerParam": "standard",
                  "termsParam": ["troutman"],
                  "fieldsParam": ["overview"]
                }
              }
            }
          ]
        }
      },
      "ext": {
        "ltr_log": {
          "log_specs": {
            "name": "log_entry1",
            "named_query": "logged_featureset"
          }
        }
      }
    }
6 changes: 5 additions & 1 deletion src/main/java/com/o19s/es/explore/StatisticsHelper.java
@@ -27,7 +27,9 @@ public enum AggrType {
AVG("avg"),
MAX("max"),
MIN("min"),
STDDEV("stddev");
SUM("sum"),
STDDEV("stddev"),
MATCHES("matches");

private String type;

@@ -118,6 +120,8 @@ public float getAggr(AggrType type) {
return getMax();
case MIN:
return getMin();
case SUM:
return getSum();
case STDDEV:
return getStdDev();
default:
56 changes: 44 additions & 12 deletions src/main/java/com/o19s/es/ltr/feature/store/ScriptFeature.java
@@ -44,7 +44,9 @@
public class ScriptFeature implements Feature {
public static final String TEMPLATE_LANGUAGE = "script_feature";
public static final String FEATURE_VECTOR = "feature_vector";
public static final String TERM_STAT = "terms";
public static final String TERM_STAT = "termStats";
public static final String MATCH_COUNT = "matchCount";
public static final String UNIQUE_TERMS = "uniqueTerms";
public static final String EXTRA_LOGGING = "extra_logging";
public static final String EXTRA_SCRIPT_PARAMS = "extra_script_params";

@@ -76,6 +78,7 @@ public static ScriptFeature compile(StoredFeature feature) {
try {
XContentParser xContentParser = XContentType.JSON.xContent().createParser(NamedXContentRegistry.EMPTY,
LoggingDeprecationHandler.INSTANCE, feature.template());

return new ScriptFeature(feature.name(), Script.parse(xContentParser, "native"), feature.queryParams());
} catch (IOException e) {
throw new RuntimeException(e);
@@ -94,6 +97,7 @@ public String name() {
* Transform this feature into a lucene query
*/
@Override
@SuppressWarnings("unchecked")
public Query doToQuery(LtrQueryContext context, FeatureSet featureSet, Map<String, Object> params) {
List<String> missingParams = queryParams.stream()
.filter((x) -> !params.containsKey(x))
@@ -116,7 +120,6 @@ public Query doToQuery(LtrQueryContext context, FeatureSet featureSet, Map<String, Object> params) {
}
}


FeatureSupplier supplier = new FeatureSupplier(featureSet);
ExtraLoggingSupplier extraLoggingSupplier = new ExtraLoggingSupplier();
TermStatSupplier termstatSupplier = new TermStatSupplier();
@@ -125,16 +128,43 @@ public Query doToQuery(LtrQueryContext context, FeatureSet featureSet, Map<String, Object> params) {
// Parse terms if set
Set<Term> terms = new HashSet<>();
if (baseScriptParams.containsKey("term_stat")) {
@SuppressWarnings("unchecked")
HashMap<String, Object> termspec = (HashMap<String, Object>) baseScriptParams.get("term_stat");

@SuppressWarnings("unchecked")
ArrayList<String> fields = (ArrayList<String>) termspec.get("fields");
String analyzerName = null;
ArrayList<String> fields = null;
ArrayList<String> termList = null;

final Object analyzerNameObj = termspec.get("analyzer");
final Object fieldsObj = termspec.get("fields");
final Object termListObj = termspec.get("terms");

// Support lookup via params or direct assignment
if (analyzerNameObj != null) {
if (analyzerNameObj instanceof String) {
// Support direct assignment by prefixing analyzer with a bang
if (((String)analyzerNameObj).startsWith("!")) {
analyzerName = ((String) analyzerNameObj).substring(1);
} else {
analyzerName = (String) params.get(analyzerNameObj);
}
}
}

@SuppressWarnings("unchecked")
ArrayList<String> termList = (ArrayList<String>) termspec.get("terms");
if (fieldsObj != null) {
if (fieldsObj instanceof String) {
fields = (ArrayList<String>) params.get(fieldsObj);
} else if (fieldsObj instanceof ArrayList) {
fields = (ArrayList<String>) fieldsObj;
}
}

String analyzerName = (String) termspec.get("analyzer");
if (termListObj != null) {
if (termListObj instanceof String) {
termList = (ArrayList<String>) params.get(termListObj);
} else if (termListObj instanceof ArrayList) {
termList = (ArrayList<String>) termListObj;
}
}

if (fields == null || termList == null) {
throw new IllegalArgumentException("Term Stats injection requires fields and terms");
@@ -143,9 +173,9 @@ public Query doToQuery(LtrQueryContext context, FeatureSet featureSet, Map<String, Object> params) {
Analyzer analyzer = null;
for(String field : fields) {
if (analyzerName == null) {
MappedFieldType fieldType = context.getQueryShardContext().getMapperService().fieldType(field);
final MappedFieldType fieldType = context.getQueryShardContext().getMapperService().fieldType(field);
analyzer = context.getQueryShardContext().getSearchAnalyzer(fieldType);
} else if (analyzer == null) {
} else {
analyzer = context.getQueryShardContext().getIndexAnalyzers().get(analyzerName);
}

@@ -154,8 +184,8 @@ public Query doToQuery(LtrQueryContext context, FeatureSet featureSet, Map<String, Object> params) {
}

for (String termString : termList) {
TokenStream ts = analyzer.tokenStream(field, termString);
TermToBytesRefAttribute termAtt = ts.getAttribute(TermToBytesRefAttribute.class);
final TokenStream ts = analyzer.tokenStream(field, termString);
final TermToBytesRefAttribute termAtt = ts.getAttribute(TermToBytesRefAttribute.class);

try {
ts.reset();
@@ -170,6 +200,8 @@ public Query doToQuery(LtrQueryContext context, FeatureSet featureSet, Map<String, Object> params) {
}

nparams.put(TERM_STAT, termstatSupplier);
nparams.put(MATCH_COUNT, termstatSupplier.getMatchedTermCountSupplier());
nparams.put(UNIQUE_TERMS, terms.size());
}

nparams.putAll(baseScriptParams);
2 changes: 2 additions & 0 deletions src/main/java/com/o19s/es/termstat/TermStatScorer.java
@@ -89,6 +89,8 @@ public DoubleValuesSource getDoubleValuesSource(String name) {
termStatDict.put("tf", tsq.get("tf").get(i));
termStatDict.put("tp", tsq.get("tp").get(i));
termStatDict.put("ttf", tsq.get("ttf").get(i));
termStatDict.put("matches", (float) tsq.getMatchedTermCount());
termStatDict.put("unique", (float) terms.size());

// Run the expression and store the result in computed
DoubleValuesSource dvSrc = compiledExpression.getDoubleValuesSource(bindings);
31 changes: 29 additions & 2 deletions src/main/java/com/o19s/es/termstat/TermStatSupplier.java
@@ -2,6 +2,7 @@

import com.o19s.es.explore.StatisticsHelper;
import com.o19s.es.explore.StatisticsHelper.AggrType;
import com.o19s.es.ltr.utils.Suppliers;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.ReaderUtil;
@@ -28,10 +29,14 @@ public class TermStatSupplier extends AbstractMap<String, ArrayList<Float>> {
private final List<String> ACCEPTED_KEYS = Arrays.asList(new String[]{"df", "idf", "tf", "ttf", "tp"});
private AggrType posAggrType = AggrType.AVG;

private ClassicSimilarity sim;
private StatisticsHelper df_stats, idf_stats, tf_stats, ttf_stats, tp_stats;
private final ClassicSimilarity sim;
private final StatisticsHelper df_stats, idf_stats, tf_stats, ttf_stats, tp_stats;
private final Suppliers.MutableSupplier<Integer> matchedCountSupplier;

private int matchedTermCount = 0;

public TermStatSupplier() {
this.matchedCountSupplier = new Suppliers.MutableSupplier<>();
this.sim = new ClassicSimilarity();
this.df_stats = new StatisticsHelper();
this.idf_stats = new StatisticsHelper();
@@ -48,6 +53,7 @@ public void bump (IndexSearcher searcher, LeafReaderContext context,
tf_stats.getData().clear();
ttf_stats.getData().clear();
tp_stats.getData().clear();
matchedTermCount = 0;

PostingsEnum postingsEnum = null;
for (Term term : terms) {
@@ -63,6 +69,7 @@ public void bump (IndexSearcher searcher, LeafReaderContext context,
TermState state = termStates.get(context);

if (state == null) {
insertZeroes(); // Zero out stats for terms we don't know about in the index
continue;
}

@@ -78,6 +85,8 @@ public void bump (IndexSearcher searcher, LeafReaderContext context,

// Verify document is in postings
if (postingsEnum.advance(docID) == docID){
matchedTermCount++;

tf_stats.add(postingsEnum.freq());

if(postingsEnum.freq() > 0) {
@@ -96,6 +105,8 @@ public void bump (IndexSearcher searcher, LeafReaderContext context,
tp_stats.add(0.0f);
}
}

matchedCountSupplier.set(matchedTermCount);
}

/**
@@ -195,8 +206,24 @@ public int size() {
};
}

public int getMatchedTermCount() {
return matchedTermCount;
}

public Suppliers.MutableSupplier<Integer> getMatchedTermCountSupplier() {
return matchedCountSupplier;
}

public void setPosAggr(AggrType type) {
this.posAggrType = type;
}

private void insertZeroes() {
df_stats.add(0.0f);
idf_stats.add(0.0f);
tf_stats.add(0.0f);
ttf_stats.add(0.0f);
tp_stats.add(0.0f);
}
}

4 changes: 2 additions & 2 deletions src/test/java/com/o19s/es/ltr/action/BaseIntegrationTest.java
@@ -176,9 +176,9 @@ public ScoreScript newInstance(LeafReaderContext ctx) throws IOException {
@Override
public double execute(ExplanationHolder explainationHolder) {
// For testing purposes just look for the "terms" key and see if stats were injected
if(p.containsKey("terms")) {
if(p.containsKey("termStats")) {
AbstractMap<String, ArrayList<Float>> termStats = (AbstractMap<String,
ArrayList<Float>>) p.get("terms");
ArrayList<Float>>) p.get("termStats");
ArrayList<Float> dfStats = termStats.get("df");
return dfStats.size() > 0 ? dfStats.get(0) : 0.0;
} else {