diff --git a/docs/advanced-functionality.rst b/docs/advanced-functionality.rst index 438bfcc9..7730ffbd 100644 --- a/docs/advanced-functionality.rst +++ b/docs/advanced-functionality.rst @@ -280,6 +280,8 @@ Also you can limit the information to a single node in the cluster:: TermStat Query ============================= +**Experimental** - This query is currently in an experimental stage and the DSL may change as the code advances. For stable term statistic access please see the `ExplorerQuery`. + The :code:`TermStatQuery` is a re-imagination of the legacy :code:`ExplorerQuery` which offers clearer specification of terms and more freedom to experiment. This query surfaces the same data as the `ExplorerQuery` but it allows the user to specify a custom Lucene expression for the type of data they would like to retrieve. For example:: POST tmdb/_search @@ -309,8 +311,13 @@ Supported aggregation types are: - :code:`min` -- the minimum - :code:`max` -- the maximum - :code:`avg` -- the mean +- :code:`sum` -- the sum - :code:`stddev` -- the standard deviation +Additionally the following counts are available: +- :code:`matches` -- The number of terms that matched in the current document +- :code:`unique` -- The unique number of terms that were passed in + The :code:`terms` parameter is array of terms to gather statistics for. Currently only single terms are supported, there is not support for phrases or span queries. Note: If your field is tokenized you can pass multiple terms in one string in the array. The :code:`fields` parameter specifies which fields to check for the specified :code:`terms`. Note if no :code:`analyzer` is specified then we use the analyzer specified for the field. @@ -324,4 +331,96 @@ Optional Parameters Script Injection ---------------- -Finally, one last addition that this functionality provides is the ability to inject term statistics into a scripting context. When working with :code:`ScriptFeatures` if you pass a :code:`term_stat` object in with the :code:`terms`, :code:`fields` and :code:`analyzer` parameters you can access the raw values directly in a custom script via an injected variable named :code:`terms`. This provides for advanced feature engineering when you need to look at all the data to make decisions. +Finally, one last addition that this functionality provides is the ability to inject term statistics into a scripting context. When working with :code:`ScriptFeatures` if you pass a :code:`term_stat` object in with the :code:`terms`, :code:`fields` and :code:`analyzer` parameters you can access the raw values directly in a custom script via an injected variable named :code:`termStats`. This provides for advanced feature engineering when you need to look at all the data to make decisions. + +Scripts access matching and unique counts slightly differently than inside the TermStatQuery: + +To access the count of matched tokens: `params.matchCount.get()` +To access the count of unique tokens: `params.uniqueTerms` + +You have the following options for sending in parameters to scripts. If you always want to find stats about the same terms (i.e. stopwords or other common terms in your index) you can hardcode the parameters along with your script:: + + POST _ltr/_featureset/test + { + "featureset": { + "features": [ + { + "name": "injection", + "template_language": "script_feature", + "template": { + "lang": "painless", + "source": "params.termStats['df'].size()", + "params": { + "term_stat": { + "analyzer": "!standard" + "terms": ["rambo rocky"], + "fields": ["overview"] + } + } + } + } + ] + } + } + + Note: Analyzer names must be prefixed with a bang(!) if specifying locally, otherwise it will treat the value as a parameter lookup. + +To set parameter lookups simply pass the name of the parameter to pull the value from like so:: + + POST _ltr/_featureset/test + { + "featureset": { + "features": [ + { + "name": "injection", + "template_language": "script_feature", + "template": { + "lang": "painless", + "source": "params.termStats['df'].size()", + "params": { + "term_stat": { + "analyzer": "analyzerParam" + "terms": "termsParam", + "fields": "fieldsParam" + } + } + } + } + ] + } + } + +The following example shows how to set the parameters at query time:: + + POST tmdb/_search + { + "query": { + "bool": { + "filter": [ + { + "terms": { + "_id": ["7555", "1370", "1369"] + } + }, + { + "sltr": { + "_name": "logged_featureset", + "featureset": "test", + "params": { + "analyzerParam": "standard", + "termsParam": ["troutman"], + "fieldsParam": ["overview"] + } + }} + ] + } + }, + "ext": { + "ltr_log": { + "log_specs": { + "name": "log_entry1", + "named_query": "logged_featureset" + } + } + } + } diff --git a/src/main/java/com/o19s/es/explore/StatisticsHelper.java b/src/main/java/com/o19s/es/explore/StatisticsHelper.java index cbfd38ca..f70ad04b 100644 --- a/src/main/java/com/o19s/es/explore/StatisticsHelper.java +++ b/src/main/java/com/o19s/es/explore/StatisticsHelper.java @@ -27,7 +27,9 @@ public enum AggrType { AVG("avg"), MAX("max"), MIN("min"), - STDDEV("stddev"); + SUM("sum"), + STDDEV("stddev"), + MATCHES("matches"); private String type; @@ -118,6 +120,8 @@ public float getAggr(AggrType type) { return getMax(); case MIN: return getMin(); + case SUM: + return getSum(); case STDDEV: return getStdDev(); default: diff --git a/src/main/java/com/o19s/es/ltr/feature/store/ScriptFeature.java b/src/main/java/com/o19s/es/ltr/feature/store/ScriptFeature.java index dd7f8b36..cce3a138 100644 --- a/src/main/java/com/o19s/es/ltr/feature/store/ScriptFeature.java +++ b/src/main/java/com/o19s/es/ltr/feature/store/ScriptFeature.java @@ -44,7 +44,9 @@ public class ScriptFeature implements Feature { public static final String TEMPLATE_LANGUAGE = "script_feature"; public static final String FEATURE_VECTOR = "feature_vector"; - public static final String TERM_STAT = "terms"; + public static final String TERM_STAT = "termStats"; + public static final String MATCH_COUNT = "matchCount"; + public static final String UNIQUE_TERMS = "uniqueTerms"; public static final String EXTRA_LOGGING = "extra_logging"; public static final String EXTRA_SCRIPT_PARAMS = "extra_script_params"; @@ -76,6 +78,7 @@ public static ScriptFeature compile(StoredFeature feature) { try { XContentParser xContentParser = XContentType.JSON.xContent().createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, feature.template()); + return new ScriptFeature(feature.name(), Script.parse(xContentParser, "native"), feature.queryParams()); } catch (IOException e) { throw new RuntimeException(e); @@ -94,6 +97,7 @@ public String name() { * Transform this feature into a lucene query */ @Override + @SuppressWarnings("unchecked") public Query doToQuery(LtrQueryContext context, FeatureSet featureSet, Map params) { List missingParams = queryParams.stream() .filter((x) -> !params.containsKey(x)) @@ -116,7 +120,6 @@ public Query doToQuery(LtrQueryContext context, FeatureSet featureSet, Map terms = new HashSet<>(); if (baseScriptParams.containsKey("term_stat")) { - @SuppressWarnings("unchecked") HashMap termspec = (HashMap) baseScriptParams.get("term_stat"); - @SuppressWarnings("unchecked") - ArrayList fields = (ArrayList) termspec.get("fields"); + String analyzerName = null; + ArrayList fields = null; + ArrayList termList = null; + + final Object analyzerNameObj = termspec.get("analyzer"); + final Object fieldsObj = termspec.get("fields"); + final Object termListObj = termspec.get("terms"); + + // Support lookup via params or direct assignment + if (analyzerNameObj != null) { + if (analyzerNameObj instanceof String) { + // Support direct assignment by prefixing analyzer with a bang + if (((String)analyzerNameObj).startsWith("!")) { + analyzerName = ((String) analyzerNameObj).substring(1); + } else { + analyzerName = (String) params.get(analyzerNameObj); + } + } + } - @SuppressWarnings("unchecked") - ArrayList termList = (ArrayList) termspec.get("terms"); + if (fieldsObj != null) { + if (fieldsObj instanceof String) { + fields = (ArrayList) params.get(fieldsObj); + } else if (fieldsObj instanceof ArrayList) { + fields = (ArrayList) fieldsObj; + } + } - String analyzerName = (String) termspec.get("analyzer"); + if (termListObj != null) { + if (termListObj instanceof String) { + termList = (ArrayList) params.get(termListObj); + } else if (termListObj instanceof ArrayList) { + termList = (ArrayList) termListObj; + } + } if (fields == null || termList == null) { throw new IllegalArgumentException("Term Stats injection requires fields and terms"); @@ -143,9 +173,9 @@ public Query doToQuery(LtrQueryContext context, FeatureSet featureSet, Map> { private final List ACCEPTED_KEYS = Arrays.asList(new String[]{"df", "idf", "tf", "ttf", "tp"}); private AggrType posAggrType = AggrType.AVG; - private ClassicSimilarity sim; - private StatisticsHelper df_stats, idf_stats, tf_stats, ttf_stats, tp_stats; + private final ClassicSimilarity sim; + private final StatisticsHelper df_stats, idf_stats, tf_stats, ttf_stats, tp_stats; + private final Suppliers.MutableSupplier matchedCountSupplier; + + private int matchedTermCount = 0; public TermStatSupplier() { + this.matchedCountSupplier = new Suppliers.MutableSupplier<>(); this.sim = new ClassicSimilarity(); this.df_stats = new StatisticsHelper(); this.idf_stats = new StatisticsHelper(); @@ -48,6 +53,7 @@ public void bump (IndexSearcher searcher, LeafReaderContext context, tf_stats.getData().clear(); ttf_stats.getData().clear(); tp_stats.getData().clear(); + matchedTermCount = 0; PostingsEnum postingsEnum = null; for (Term term : terms) { @@ -63,6 +69,7 @@ public void bump (IndexSearcher searcher, LeafReaderContext context, TermState state = termStates.get(context); if (state == null) { + insertZeroes(); // Zero out stats for terms we don't know about in the index continue; } @@ -78,6 +85,8 @@ public void bump (IndexSearcher searcher, LeafReaderContext context, // Verify document is in postings if (postingsEnum.advance(docID) == docID){ + matchedTermCount++; + tf_stats.add(postingsEnum.freq()); if(postingsEnum.freq() > 0) { @@ -96,6 +105,8 @@ public void bump (IndexSearcher searcher, LeafReaderContext context, tp_stats.add(0.0f); } } + + matchedCountSupplier.set(matchedTermCount); } /** @@ -195,8 +206,24 @@ public int size() { }; } + public int getMatchedTermCount() { + return matchedTermCount; + } + + public Suppliers.MutableSupplier getMatchedTermCountSupplier() { + return matchedCountSupplier; + } + public void setPosAggr(AggrType type) { this.posAggrType = type; } + + private void insertZeroes() { + df_stats.add(0.0f); + idf_stats.add(0.0f); + tf_stats.add(0.0f); + ttf_stats.add(0.0f); + tp_stats.add(0.0f); + } } diff --git a/src/test/java/com/o19s/es/ltr/action/BaseIntegrationTest.java b/src/test/java/com/o19s/es/ltr/action/BaseIntegrationTest.java index fa35d1ce..2c349dc1 100644 --- a/src/test/java/com/o19s/es/ltr/action/BaseIntegrationTest.java +++ b/src/test/java/com/o19s/es/ltr/action/BaseIntegrationTest.java @@ -176,9 +176,9 @@ public ScoreScript newInstance(LeafReaderContext ctx) throws IOException { @Override public double execute(ExplanationHolder explainationHolder) { // For testing purposes just look for the "terms" key and see if stats were injected - if(p.containsKey("terms")) { + if(p.containsKey("termStats")) { AbstractMap> termStats = (AbstractMap>) p.get("terms"); + ArrayList>) p.get("termStats"); ArrayList dfStats = termStats.get("df"); return dfStats.size() > 0 ? dfStats.get(0) : 0.0; } else { diff --git a/src/test/java/com/o19s/es/ltr/logging/LoggingIT.java b/src/test/java/com/o19s/es/ltr/logging/LoggingIT.java index d37f1c7f..e213a997 100644 --- a/src/test/java/com/o19s/es/ltr/logging/LoggingIT.java +++ b/src/test/java/com/o19s/es/ltr/logging/LoggingIT.java @@ -101,12 +101,25 @@ public void prepareModelsExtraLogging() throws Exception { LinearRankerParserTests.generateRandomModelString(set), true)); addElement(model); } - public void prepareScriptFeatures() throws Exception { + public void prepareExternalScriptFeatures() throws Exception { + List features = new ArrayList<>(3); + features.add(new StoredFeature("test_inject", Arrays.asList(), ScriptFeature.TEMPLATE_LANGUAGE, + "{\"lang\": \"inject\", \"source\": \"df\", \"params\": {\"term_stat\": { " + + "\"analyzer\": \"analyzerParam\", " + + "\"terms\": \"termsParam\", " + + "\"fields\": \"fieldsParam\" } } }")); + + StoredFeatureSet set = new StoredFeatureSet("my_set", features); + addElement(set); + } + + public void prepareInternalScriptFeatures() throws Exception { List features = new ArrayList<>(3); features.add(new StoredFeature("test_inject", Arrays.asList("query"), ScriptFeature.TEMPLATE_LANGUAGE, "{\"lang\": \"inject\", \"source\": \"df\", \"params\": {\"term_stat\": { " + - "\"terms\": [\"found\"], " + - "\"fields\": [\"field1\"] } } }")); + "\"analyzer\": \"!standard\", " + + "\"terms\": [\"found\"], " + + "\"fields\": [\"field1\"] } } }")); StoredFeatureSet set = new StoredFeatureSet("my_set", features); addElement(set); @@ -326,8 +339,8 @@ public void testLogExtraLogging() throws Exception { assertSearchHitsExtraLogging(docs, resp3); } - public void testScriptLog() throws Exception { - prepareScriptFeatures(); + public void testScriptLogInternalParams() throws Exception { + prepareInternalScriptFeatures(); Map docs = buildIndex(); Map params = new HashMap<>(); @@ -364,6 +377,83 @@ public void testScriptLog() throws Exception { assertTrue((Float)log.get(0).get("value") > 0.0F); } + public void testScriptLogExternalParams() throws Exception { + prepareExternalScriptFeatures(); + Map docs = buildIndex(); + + Map params = new HashMap<>(); + ArrayList terms = new ArrayList<>(); + terms.add("found"); + params.put("termsParam", terms); + + ArrayList fields = new ArrayList<>(); + fields.add("field1"); + params.put("fieldsParam", fields); + + params.put("analyzerParam", "standard"); + + List idsColl = new ArrayList<>(docs.keySet()); + Collections.shuffle(idsColl, random()); + String[] ids = idsColl.subList(0, TestUtil.nextInt(random(), 5, 15)).toArray(new String[0]); + StoredLtrQueryBuilder sbuilder = new StoredLtrQueryBuilder(LtrTestUtils.nullLoader()) + .featureSetName("my_set") + .params(params) + .queryName("test") + .boost(random().nextInt(3)); + + QueryBuilder query = QueryBuilders.boolQuery().must(new WrapperQueryBuilder(sbuilder.toString())) + .filter(QueryBuilders.idsQuery("test").addIds(ids)); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().query(query) + .fetchSource(false) + .size(10) + .ext(Collections.singletonList( + new LoggingSearchExtBuilder() + .addQueryLogging("first_log", "test", false))); + + SearchResponse resp = client().prepareSearch("test_index").setTypes("test").setSource(sourceBuilder).get(); + + SearchHits hits = resp.getHits(); + SearchHit testHit = hits.getAt(0); + Map>> logs = testHit.getFields().get("_ltrlog").getValue(); + + assertTrue(logs.containsKey("first_log")); + List> log = logs.get("first_log"); + + assertEquals(log.get(0).get("name"), "test_inject"); + assertTrue((Float)log.get(0).get("value") > 0.0F); + } + + public void testScriptLogInvalidExternalParams() throws Exception { + prepareExternalScriptFeatures(); + Map docs = buildIndex(); + + Map params = new HashMap<>(); + params.put("query", "found"); + + List idsColl = new ArrayList<>(docs.keySet()); + Collections.shuffle(idsColl, random()); + String[] ids = idsColl.subList(0, TestUtil.nextInt(random(), 5, 15)).toArray(new String[0]); + StoredLtrQueryBuilder sbuilder = new StoredLtrQueryBuilder(LtrTestUtils.nullLoader()) + .featureSetName("my_set") + .params(params) + .queryName("test") + .boost(random().nextInt(3)); + + QueryBuilder query = QueryBuilders.boolQuery().must(new WrapperQueryBuilder(sbuilder.toString())) + .filter(QueryBuilders.idsQuery("test").addIds(ids)); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().query(query) + .fetchSource(false) + .size(10) + .ext(Collections.singletonList( + new LoggingSearchExtBuilder() + .addQueryLogging("first_log", "test", false))); + + assertExcWithMessage(() -> client().prepareSearch("test_index") + .setTypes("test") + .setSource(sourceBuilder).get(), + IllegalArgumentException.class, "Term Stats injection requires fields and terms"); + } + protected void assertSearchHits(Map docs, SearchResponse resp) { for (SearchHit hit: resp.getHits()) { assertTrue(hit.getFields().containsKey("_ltrlog")); diff --git a/src/test/java/com/o19s/es/termstat/TermStatQueryTests.java b/src/test/java/com/o19s/es/termstat/TermStatQueryTests.java index d669cfdd..50ad1b60 100644 --- a/src/test/java/com/o19s/es/termstat/TermStatQueryTests.java +++ b/src/test/java/com/o19s/es/termstat/TermStatQueryTests.java @@ -119,4 +119,42 @@ public void testBasicFormula() throws Exception { Explanation explanation = searcher.explain(tsq, docs.scoreDocs[0].doc); assertThat(explanation.toString().trim(), equalTo("1.8472979 = weight(" + expr + " in doc 0)")); } + + public void testMatchCount() throws Exception { + String expr = "matches"; + AggrType aggr = AggrType.AVG; + AggrType pos_aggr = AggrType.AVG; + + Set terms = new HashSet<>(); + terms.add(new Term("text", "brown")); + terms.add(new Term("text", "cow")); + terms.add(new Term("text", "horse")); + + Expression compiledExpression = (Expression) Scripting.compile(expr); + TermStatQuery tsq = new TermStatQuery(compiledExpression, aggr, pos_aggr, terms); + + // Verify explain + TopDocs docs = searcher.search(tsq, 4); + Explanation explanation = searcher.explain(tsq, docs.scoreDocs[0].doc); + assertThat(explanation.toString().trim(), equalTo("2.0 = weight(" + expr + " in doc 0)")); + } + + public void testUniqueCount() throws Exception { + String expr = "unique"; + AggrType aggr = AggrType.AVG; + AggrType pos_aggr = AggrType.AVG; + + Set terms = new HashSet<>(); + terms.add(new Term("text", "brown")); + terms.add(new Term("text", "cow")); + terms.add(new Term("text", "horse")); + + Expression compiledExpression = (Expression) Scripting.compile(expr); + TermStatQuery tsq = new TermStatQuery(compiledExpression, aggr, pos_aggr, terms); + + // Verify explain + TopDocs docs = searcher.search(tsq, 4); + Explanation explanation = searcher.explain(tsq, docs.scoreDocs[0].doc); + assertThat(explanation.toString().trim(), equalTo("3.0 = weight(" + expr + " in doc 0)")); + } }