From cbafdca085380710bb4d8e60d19959c60fdf1bc0 Mon Sep 17 00:00:00 2001 From: Thejas-bhat <35959007+Thejas-bhat@users.noreply.github.com> Date: Wed, 15 Jan 2025 14:38:22 +0530 Subject: [PATCH] MB-58901: Introduce support for BM25 scoring (#2113) Introducing support for BM25 scoring Key stats necessary for the scoring - fieldLength - the number of terms in a field within a doc. - avgDocLength - the average of terms in a field across all the docs in the index. - totalDocs - total number of docs in an index. Introduces a mechanism to maintain consistent scoring in a situation where the index is partitioned as a `bleve.IndexAlias`. This is achieved using the existing preSearch mechanism where the first phase of the entire search involves fetching the above mentioned stats, aggregating them and redistributing back to the bleve indexes which would use them while calculating the score for a hit. In order to enable this global scoring mechanism, the user needs to set the `context` argument of the SearchInContext with: `ctx = context.WithValue(ctx, search.SearchTypeKey, search.GlobalScoring)` Implementation wise, the user needs to explicitly mention BM25 as the scoring mechanism at `indexMapping.ScoringModel` level to actually use this scoring mechanism. This parameter is a global setting, i.e. when performing a search on multiple fields, all the fields are scored with the same scoring model. The storage layer exposes an API which returns the number of terms in a field's term dictionary which is used to compute the `avgDocLength`. At the indexing layer, we check if the queried field supports BM25 scoring and if consistent scoring is availed. This is followed by fetching the stats either from the local bleve index or from a context (in the case where we're availing the consistent scoring) to compute the actual score. 
Note: The scoring is highly dependent on the size of an individual bleve index's termDictionary (specific to a field) so there can be some discrepancies especially given that each index is further composed of multiple 'segments'. However in large scale use cases these discrepancies can be quite small and don't affect the order of the doc hits - in which case the user may choose to avoid this altogether. --------- Co-authored-by: Aditi Ahuja Co-authored-by: Abhinav Dangeti --- go.mod | 16 +- go.sum | 32 ++-- index/scorch/snapshot_index.go | 9 +- index/scorch/snapshot_index_dict.go | 14 +- index/upsidedown/field_dict.go | 4 + index_alias_impl.go | 64 +++++++- index_impl.go | 53 ++++++ index_test.go | 221 +++++++++++++++++++++++++- mapping/field.go | 4 +- mapping/index.go | 12 ++ mapping/mapping_vectors.go | 6 +- pre_search.go | 33 ++++ search.go | 3 + search/query/knn.go | 2 +- search/query/query.go | 13 ++ search/scorer/scorer_term.go | 108 +++++++++++-- search/scorer/scorer_term_test.go | 12 +- search/searcher/search_disjunction.go | 2 +- search/searcher/search_term.go | 90 +++++++++-- search/util.go | 53 +++++- 20 files changed, 668 insertions(+), 83 deletions(-) diff --git a/go.mod b/go.mod index cfee95607..49b71da6e 100644 --- a/go.mod +++ b/go.mod @@ -5,26 +5,26 @@ go 1.21 require ( github.com/RoaringBitmap/roaring v1.9.3 github.com/bits-and-blooms/bitset v1.12.0 - github.com/blevesearch/bleve_index_api v1.2.0 + github.com/blevesearch/bleve_index_api v1.2.1 github.com/blevesearch/geo v0.1.20 github.com/blevesearch/go-faiss v1.0.24 github.com/blevesearch/go-metrics v0.0.0-20201227073835-cf1acfcdf475 github.com/blevesearch/go-porterstemmer v1.0.3 github.com/blevesearch/goleveldb v1.0.1 github.com/blevesearch/gtreap v0.1.1 - github.com/blevesearch/scorch_segment_api/v2 v2.3.0 + github.com/blevesearch/scorch_segment_api/v2 v2.3.1 github.com/blevesearch/segment v0.9.1 github.com/blevesearch/snowball v0.6.1 github.com/blevesearch/snowballstem v0.9.0 
github.com/blevesearch/stempel v0.2.0 github.com/blevesearch/upsidedown_store_api v1.0.2 github.com/blevesearch/vellum v1.1.0 - github.com/blevesearch/zapx/v11 v11.3.10 - github.com/blevesearch/zapx/v12 v12.3.10 - github.com/blevesearch/zapx/v13 v13.3.10 - github.com/blevesearch/zapx/v14 v14.3.10 - github.com/blevesearch/zapx/v15 v15.3.17 - github.com/blevesearch/zapx/v16 v16.1.11-0.20241219160422-82553cdd4b38 + github.com/blevesearch/zapx/v11 v11.3.11 + github.com/blevesearch/zapx/v12 v12.3.11 + github.com/blevesearch/zapx/v13 v13.3.11 + github.com/blevesearch/zapx/v14 v14.3.11 + github.com/blevesearch/zapx/v15 v15.3.18 + github.com/blevesearch/zapx/v16 v16.1.11-0.20250107152255-021e66397612 github.com/couchbase/moss v0.2.0 github.com/golang/protobuf v1.3.2 github.com/spf13/cobra v1.7.0 diff --git a/go.sum b/go.sum index f21c89611..1914f7919 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,8 @@ github.com/RoaringBitmap/roaring v1.9.3 h1:t4EbC5qQwnisr5PrP9nt0IRhRTb9gMUgQF4t4 github.com/RoaringBitmap/roaring v1.9.3/go.mod h1:6AXUsoIEzDTFFQCe1RbGA6uFONMhvejWj5rqITANK90= github.com/bits-and-blooms/bitset v1.12.0 h1:U/q1fAF7xXRhFCrhROzIfffYnu+dlS38vCZtmFVPHmA= github.com/bits-and-blooms/bitset v1.12.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= -github.com/blevesearch/bleve_index_api v1.2.0 h1:/DXMMWBwx/UmGKM1xDhTwDoJI5yQrG6rqRWPFcOgUVo= -github.com/blevesearch/bleve_index_api v1.2.0/go.mod h1:PbcwjIcRmjhGbkS/lJCpfgVSMROV6TRubGGAODaK1W8= +github.com/blevesearch/bleve_index_api v1.2.1 h1:IuXwLvmyp7I7+e0FOA68gcHHLfzSQ4AqQ8wVab5uxk0= +github.com/blevesearch/bleve_index_api v1.2.1/go.mod h1:rKQDl4u51uwafZxFrPD1R7xFOwKnzZW7s/LSeK4lgo0= github.com/blevesearch/geo v0.1.20 h1:paaSpu2Ewh/tn5DKn/FB5SzvH0EWupxHEIwbCk/QPqM= github.com/blevesearch/geo v0.1.20/go.mod h1:DVG2QjwHNMFmjo+ZgzrIq2sfCh6rIHzy9d9d0B59I6w= github.com/blevesearch/go-faiss v1.0.24 h1:K79IvKjoKHdi7FdiXEsAhxpMuns0x4fM0BO93bW5jLI= @@ -19,8 +19,8 @@ github.com/blevesearch/gtreap v0.1.1/go.mod 
h1:QaQyDRAT51sotthUWAH4Sj08awFSSWzgY github.com/blevesearch/mmap-go v1.0.2/go.mod h1:ol2qBqYaOUsGdm7aRMRrYGgPvnwLe6Y+7LMvAB5IbSA= github.com/blevesearch/mmap-go v1.0.4 h1:OVhDhT5B/M1HNPpYPBKIEJaD0F3Si+CrEKULGCDPWmc= github.com/blevesearch/mmap-go v1.0.4/go.mod h1:EWmEAOmdAS9z/pi/+Toxu99DnsbhG1TIxUoRmJw/pSs= -github.com/blevesearch/scorch_segment_api/v2 v2.3.0 h1:vxCjbXAkkEBSb4AB3Iqgr/EJcPyYRsiGxpcvsS8E1Dw= -github.com/blevesearch/scorch_segment_api/v2 v2.3.0/go.mod h1:5y+TgXYSx+xJGaCwSlvy9G/UJBIY5wzvIkhvhBm2ATc= +github.com/blevesearch/scorch_segment_api/v2 v2.3.1 h1:jjexIzwOdBtC9MlUceNErYHepLvoKxTdA5atbeZSRWE= +github.com/blevesearch/scorch_segment_api/v2 v2.3.1/go.mod h1:Np3Y03rsemM5TsyFxQ3wy+tG97EcviLTbp2S5W0tpRY= github.com/blevesearch/segment v0.9.1 h1:+dThDy+Lvgj5JMxhmOVlgFfkUtZV2kw49xax4+jTfSU= github.com/blevesearch/segment v0.9.1/go.mod h1:zN21iLm7+GnBHWTao9I+Au/7MBiL8pPFtJBJTsk6kQw= github.com/blevesearch/snowball v0.6.1 h1:cDYjn/NCH+wwt2UdehaLpr2e4BwLIjN4V/TdLsL+B5A= @@ -33,18 +33,18 @@ github.com/blevesearch/upsidedown_store_api v1.0.2 h1:U53Q6YoWEARVLd1OYNc9kvhBMG github.com/blevesearch/upsidedown_store_api v1.0.2/go.mod h1:M01mh3Gpfy56Ps/UXHjEO/knbqyQ1Oamg8If49gRwrQ= github.com/blevesearch/vellum v1.1.0 h1:CinkGyIsgVlYf8Y2LUQHvdelgXr6PYuvoDIajq6yR9w= github.com/blevesearch/vellum v1.1.0/go.mod h1:QgwWryE8ThtNPxtgWJof5ndPfx0/YMBh+W2weHKPw8Y= -github.com/blevesearch/zapx/v11 v11.3.10 h1:hvjgj9tZ9DeIqBCxKhi70TtSZYMdcFn7gDb71Xo/fvk= -github.com/blevesearch/zapx/v11 v11.3.10/go.mod h1:0+gW+FaE48fNxoVtMY5ugtNHHof/PxCqh7CnhYdnMzQ= -github.com/blevesearch/zapx/v12 v12.3.10 h1:yHfj3vXLSYmmsBleJFROXuO08mS3L1qDCdDK81jDl8s= -github.com/blevesearch/zapx/v12 v12.3.10/go.mod h1:0yeZg6JhaGxITlsS5co73aqPtM04+ycnI6D1v0mhbCs= -github.com/blevesearch/zapx/v13 v13.3.10 h1:0KY9tuxg06rXxOZHg3DwPJBjniSlqEgVpxIqMGahDE8= -github.com/blevesearch/zapx/v13 v13.3.10/go.mod h1:w2wjSDQ/WBVeEIvP0fvMJZAzDwqwIEzVPnCPrz93yAk= -github.com/blevesearch/zapx/v14 v14.3.10 
h1:SG6xlsL+W6YjhX5N3aEiL/2tcWh3DO75Bnz77pSwwKU= -github.com/blevesearch/zapx/v14 v14.3.10/go.mod h1:qqyuR0u230jN1yMmE4FIAuCxmahRQEOehF78m6oTgns= -github.com/blevesearch/zapx/v15 v15.3.17 h1:NkkMI98pYLq/uHnB6YWcITrrLpCVyvZ9iP+AyfpW1Ys= -github.com/blevesearch/zapx/v15 v15.3.17/go.mod h1:vXRQzJJvlGVCdmOD5hg7t7JdjUT5DmDPhsAfjvtzIq8= -github.com/blevesearch/zapx/v16 v16.1.11-0.20241219160422-82553cdd4b38 h1:iJ3Q3sbyo2d0bjfb720RmGjj7cqzh/EdP3528ggDIMY= -github.com/blevesearch/zapx/v16 v16.1.11-0.20241219160422-82553cdd4b38/go.mod h1:JTZseJiEpogtkepKSubIKAmfgbQiOReJXfmjxB1qta4= +github.com/blevesearch/zapx/v11 v11.3.11 h1:r6/wFHFAKWvXJb82f5aO53l6p+gRH6eiX7S1tb3VGc0= +github.com/blevesearch/zapx/v11 v11.3.11/go.mod h1:0+gW+FaE48fNxoVtMY5ugtNHHof/PxCqh7CnhYdnMzQ= +github.com/blevesearch/zapx/v12 v12.3.11 h1:GBBAmXesxXLV5UZ+FZ0qILb7HPssT+kxEkbPPfp5HPM= +github.com/blevesearch/zapx/v12 v12.3.11/go.mod h1:0yeZg6JhaGxITlsS5co73aqPtM04+ycnI6D1v0mhbCs= +github.com/blevesearch/zapx/v13 v13.3.11 h1:H5ZvgS1qM1XKzsAuwp3kvDfh5sJFu9bLH/B8U6Im5e8= +github.com/blevesearch/zapx/v13 v13.3.11/go.mod h1:w2wjSDQ/WBVeEIvP0fvMJZAzDwqwIEzVPnCPrz93yAk= +github.com/blevesearch/zapx/v14 v14.3.11 h1:pg+c/YFzMJ32GkOwLzH/HAQ/GBr6y1Ar7/K5ZQpxTNo= +github.com/blevesearch/zapx/v14 v14.3.11/go.mod h1:qqyuR0u230jN1yMmE4FIAuCxmahRQEOehF78m6oTgns= +github.com/blevesearch/zapx/v15 v15.3.18 h1:yJcQnQyHGNF6rAiwq85OHn3HaXo26t7vgd83RclEw7U= +github.com/blevesearch/zapx/v15 v15.3.18/go.mod h1:vXRQzJJvlGVCdmOD5hg7t7JdjUT5DmDPhsAfjvtzIq8= +github.com/blevesearch/zapx/v16 v16.1.11-0.20250107152255-021e66397612 h1:LhORiqEVyUPUrVETzmmVuT0Yudsz2R3qGLFJWUpMsQo= +github.com/blevesearch/zapx/v16 v16.1.11-0.20250107152255-021e66397612/go.mod h1:+FIylxb+5Z/sFVmNaGpppGLHKBMUEnPSbkKoi+izER8= github.com/couchbase/ghistogram v0.1.0 h1:b95QcQTCzjTUocDXp/uMgSNQi8oj1tGwnJ4bODWZnps= github.com/couchbase/ghistogram v0.1.0/go.mod h1:s1Jhy76zqfEecpNWJfWUiKZookAFaiGOEoyzgHt9i7k= github.com/couchbase/moss v0.2.0 
h1:VCYrMzFwEryyhRSeI+/b3tRBSeTpi/8gn5Kf6dxqn+o= diff --git a/index/scorch/snapshot_index.go b/index/scorch/snapshot_index.go index 6d0a0b60e..ece32eee6 100644 --- a/index/scorch/snapshot_index.go +++ b/index/scorch/snapshot_index.go @@ -42,8 +42,9 @@ type asynchSegmentResult struct { dict segment.TermDictionary dictItr segment.DictionaryIterator - index int - docs *roaring.Bitmap + cardinality int + index int + docs *roaring.Bitmap thesItr segment.ThesaurusIterator @@ -137,6 +138,7 @@ func (is *IndexSnapshot) newIndexSnapshotFieldDict(field string, results := make(chan *asynchSegmentResult) var totalBytesRead uint64 + var fieldCardinality int64 for _, s := range is.segment { go func(s *SegmentSnapshot) { dict, err := s.segment.Dictionary(field) @@ -146,6 +148,7 @@ func (is *IndexSnapshot) newIndexSnapshotFieldDict(field string, if dictStats, ok := dict.(segment.DiskStatsReporter); ok { atomic.AddUint64(&totalBytesRead, dictStats.BytesRead()) } + atomic.AddInt64(&fieldCardinality, int64(dict.Cardinality())) if randomLookup { results <- &asynchSegmentResult{dict: dict} } else { @@ -160,6 +163,7 @@ func (is *IndexSnapshot) newIndexSnapshotFieldDict(field string, snapshot: is, cursors: make([]*segmentDictCursor, 0, len(is.segment)), } + for count := 0; count < len(is.segment); count++ { asr := <-results if asr.err != nil && err == nil { @@ -183,6 +187,7 @@ func (is *IndexSnapshot) newIndexSnapshotFieldDict(field string, } } } + rv.cardinality = int(fieldCardinality) rv.bytesRead = totalBytesRead // after ensuring we've read all items on channel if err != nil { diff --git a/index/scorch/snapshot_index_dict.go b/index/scorch/snapshot_index_dict.go index 658aa8148..2ae789c6b 100644 --- a/index/scorch/snapshot_index_dict.go +++ b/index/scorch/snapshot_index_dict.go @@ -28,10 +28,12 @@ type segmentDictCursor struct { } type IndexSnapshotFieldDict struct { - snapshot *IndexSnapshot - cursors []*segmentDictCursor - entry index.DictEntry - bytesRead uint64 + cardinality int + 
bytesRead uint64 + + snapshot *IndexSnapshot + cursors []*segmentDictCursor + entry index.DictEntry } func (i *IndexSnapshotFieldDict) BytesRead() uint64 { @@ -94,6 +96,10 @@ func (i *IndexSnapshotFieldDict) Next() (*index.DictEntry, error) { return &i.entry, nil } +func (i *IndexSnapshotFieldDict) Cardinality() int { + return i.cardinality +} + func (i *IndexSnapshotFieldDict) Close() error { return nil } diff --git a/index/upsidedown/field_dict.go b/index/upsidedown/field_dict.go index 4875680c9..c990fd47b 100644 --- a/index/upsidedown/field_dict.go +++ b/index/upsidedown/field_dict.go @@ -77,6 +77,10 @@ func (r *UpsideDownCouchFieldDict) Next() (*index.DictEntry, error) { } +func (r *UpsideDownCouchFieldDict) Cardinality() int { + return 0 +} + func (r *UpsideDownCouchFieldDict) Close() error { return r.iterator.Close() } diff --git a/index_alias_impl.go b/index_alias_impl.go index 766240b4a..a4f724e34 100644 --- a/index_alias_impl.go +++ b/index_alias_impl.go @@ -192,9 +192,11 @@ func (i *indexAliasImpl) SearchInContext(ctx context.Context, req *SearchRequest // indicates that this index alias is set as an Index // in another alias, so we need to do a preSearch search // and NOT a real search + bm25PreSearch := isBM25Enabled(i.mapping) flags := &preSearchFlags{ knn: requestHasKNN(req), synonyms: !isMatchNoneQuery(req.Query), + bm25: bm25PreSearch, } return preSearchDataSearch(ctx, req, flags, i.indexes...) 
} @@ -234,7 +236,7 @@ func (i *indexAliasImpl) SearchInContext(ctx context.Context, req *SearchRequest // - the request requires preSearch var preSearchDuration time.Duration var sr *SearchResult - flags, err := preSearchRequired(req, i.mapping) + flags, err := preSearchRequired(ctx, req, i.mapping) if err != nil { return nil, err } @@ -244,6 +246,7 @@ func (i *indexAliasImpl) SearchInContext(ctx context.Context, req *SearchRequest if err != nil { return nil, err } + // check if the preSearch result has any errors and if so // return the search result as is without executing the query // so that the errors are not lost @@ -573,11 +576,20 @@ type asyncSearchResult struct { type preSearchFlags struct { knn bool synonyms bool + bm25 bool // needs presearch for this too } -// preSearchRequired checks if preSearch is required and returns a boolean flag -// It only allocates the preSearchFlags struct if necessary -func preSearchRequired(req *SearchRequest, m mapping.IndexMapping) (*preSearchFlags, error) { +func isBM25Enabled(m mapping.IndexMapping) bool { + var rv bool + if m, ok := m.(*mapping.IndexMappingImpl); ok { + rv = m.ScoringModel == index.BM25Scoring + } + return rv +} + +// preSearchRequired checks if preSearch is required and returns the presearch flags struct +// indicating which preSearch is required +func preSearchRequired(ctx context.Context, req *SearchRequest, m mapping.IndexMapping) (*preSearchFlags, error) { // Check for KNN query knn := requestHasKNN(req) var synonyms bool @@ -598,18 +610,32 @@ func preSearchRequired(req *SearchRequest, m mapping.IndexMapping) (*preSearchFl } } } - if knn || synonyms { + var bm25 bool + if !isMatchNoneQuery(req.Query) { + if ctx != nil { + if searchType := ctx.Value(search.SearchTypeKey); searchType != nil { + if searchType.(string) == search.GlobalScoring { + bm25 = isBM25Enabled(m) + } + } + } + } + + if knn || synonyms || bm25 { return &preSearchFlags{ knn: knn, synonyms: synonyms, + bm25: bm25, }, nil } return 
nil, nil } func preSearch(ctx context.Context, req *SearchRequest, flags *preSearchFlags, indexes ...Index) (*SearchResult, error) { + // create a dummy request with a match none query + // since we only care about the preSearchData in PreSearch var dummyQuery = req.Query - if !flags.synonyms { + if !flags.bm25 && !flags.synonyms { // create a dummy request with a match none query // since we only care about the preSearchData in PreSearch dummyQuery = query.NewMatchNoneQuery() @@ -694,6 +720,19 @@ func constructSynonymPreSearchData(rv map[string]map[string]interface{}, sr *Sea return rv } +func constructBM25PreSearchData(rv map[string]map[string]interface{}, sr *SearchResult, indexes []Index) map[string]map[string]interface{} { + bmStats := sr.BM25Stats + if bmStats != nil { + for _, index := range indexes { + rv[index.Name()][search.BM25PreSearchDataKey] = &search.BM25Stats{ + DocCount: bmStats.DocCount, + FieldCardinality: bmStats.FieldCardinality, + } + } + } + return rv +} + func constructPreSearchData(req *SearchRequest, flags *preSearchFlags, preSearchResult *SearchResult, indexes []Index) (map[string]map[string]interface{}, error) { if flags == nil || preSearchResult == nil { @@ -713,6 +752,9 @@ func constructPreSearchData(req *SearchRequest, flags *preSearchFlags, if flags.synonyms { mergedOut = constructSynonymPreSearchData(mergedOut, preSearchResult, indexes) } + if flags.bm25 { + mergedOut = constructBM25PreSearchData(mergedOut, preSearchResult, indexes) + } return mergedOut, nil } @@ -822,6 +864,12 @@ func redistributePreSearchData(req *SearchRequest, indexes []Index) (map[string] rv[index.Name()][search.SynonymPreSearchDataKey] = fts } } + + if bm25Data, ok := req.PreSearchData[search.BM25PreSearchDataKey].(*search.BM25Stats); ok { + for _, index := range indexes { + rv[index.Name()][search.BM25PreSearchDataKey] = bm25Data + } + } return rv, nil } @@ -1009,3 +1057,7 @@ func (f *indexAliasImplFieldDict) Close() error { defer f.index.mutex.RUnlock() 
return f.fieldDict.Close() } + +func (f *indexAliasImplFieldDict) Cardinality() int { + return f.fieldDict.Cardinality() +} diff --git a/index_impl.go b/index_impl.go index 289014f6c..d59dfb9a1 100644 --- a/index_impl.go +++ b/index_impl.go @@ -485,6 +485,8 @@ func (i *indexImpl) preSearch(ctx context.Context, req *SearchRequest, reader in } var fts search.FieldTermSynonymMap + var count uint64 + var fieldCardinality map[string]int if !isMatchNoneQuery(req.Query) { if synMap, ok := i.m.(mapping.SynonymMapping); ok { if synReader, ok := reader.(index.ThesaurusReader); ok { @@ -494,6 +496,26 @@ func (i *indexImpl) preSearch(ctx context.Context, req *SearchRequest, reader in } } } + if ok := isBM25Enabled(i.m); ok { + fieldCardinality = make(map[string]int) + count, err = reader.DocCount() + if err != nil { + return nil, err + } + + fs := make(query.FieldSet) + fs, err := query.ExtractFields(req.Query, i.m, fs) + if err != nil { + return nil, err + } + for field := range fs { + dict, err := reader.FieldDict(field) + if err != nil { + return nil, err + } + fieldCardinality[field] = dict.Cardinality() + } + } } return &SearchResult{ @@ -503,6 +525,10 @@ func (i *indexImpl) preSearch(ctx context.Context, req *SearchRequest, reader in }, Hits: knnHits, SynonymResult: fts, + BM25Stats: &search.BM25Stats{ + DocCount: float64(count), + FieldCardinality: fieldCardinality, + }, }, nil } @@ -558,6 +584,7 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr var fts search.FieldTermSynonymMap var skipSynonymCollector bool + var bm25Data *search.BM25Stats var ok bool if req.PreSearchData != nil { for k, v := range req.PreSearchData { @@ -578,6 +605,13 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr } skipSynonymCollector = true } + case search.BM25PreSearchDataKey: + if v != nil { + bm25Data, ok = v.(*search.BM25Stats) + if !ok { + return nil, fmt.Errorf("bm25 preSearchData must be of type map[string]interface{}") 
+ } + } } } } @@ -605,6 +639,21 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr ctx = context.WithValue(ctx, search.FieldTermSynonymMapKey, fts) } + scoringModelCallback := func() string { + if isBM25Enabled(i.m) { + return index.BM25Scoring + } + return index.DefaultScoringModel + } + ctx = context.WithValue(ctx, search.GetScoringModelCallbackKey, + search.GetScoringModelCallbackFn(scoringModelCallback)) + + // set the bm25 presearch data (stats important for consistent scoring) in + // the context object + if bm25Data != nil { + ctx = context.WithValue(ctx, search.BM25PreSearchDataKey, bm25Data) + } + // This callback and variable handles the tracking of bytes read // 1. as part of creation of tfr and its Next() calls which is // accounted by invoking this callback when the TFR is closed. @@ -1107,6 +1156,10 @@ func (f *indexImplFieldDict) Close() error { return f.indexReader.Close() } +func (f *indexImplFieldDict) Cardinality() int { + return f.fieldDict.Cardinality() +} + // helper function to remove duplicate entries from slice of strings func deDuplicate(fields []string) []string { entries := make(map[string]struct{}) diff --git a/index_test.go b/index_test.go index 82be0d947..c2844584a 100644 --- a/index_test.go +++ b/index_test.go @@ -350,6 +350,216 @@ func TestBytesWritten(t *testing.T) { cleanupTmpIndexPath(t, tmpIndexPath4) } +func createIndexMappingOnSampleData() *mapping.IndexMappingImpl { + indexMapping := NewIndexMapping() + indexMapping.TypeField = "type" + indexMapping.DefaultAnalyzer = "en" + indexMapping.ScoringModel = index.DefaultScoringModel + documentMapping := NewDocumentMapping() + indexMapping.AddDocumentMapping("hotel", documentMapping) + indexMapping.StoreDynamic = false + indexMapping.DocValuesDynamic = false + contentFieldMapping := NewTextFieldMapping() + contentFieldMapping.Store = false + + reviewsMapping := NewDocumentMapping() + reviewsMapping.AddFieldMappingsAt("content", contentFieldMapping) + 
documentMapping.AddSubDocumentMapping("reviews", reviewsMapping) + + typeFieldMapping := NewTextFieldMapping() + typeFieldMapping.Store = false + documentMapping.AddFieldMappingsAt("type", typeFieldMapping) + + return indexMapping +} + +func TestBM25TFIDFScoring(t *testing.T) { + tmpIndexPath1 := createTmpIndexPath(t) + defer cleanupTmpIndexPath(t, tmpIndexPath1) + tmpIndexPath2 := createTmpIndexPath(t) + defer cleanupTmpIndexPath(t, tmpIndexPath2) + + indexMapping := createIndexMappingOnSampleData() + indexMapping.ScoringModel = index.BM25Scoring + indexBM25, err := NewUsing(tmpIndexPath1, indexMapping, Config.DefaultIndexType, Config.DefaultMemKVStore, nil) + if err != nil { + t.Fatal(err) + } + + indexMapping1 := createIndexMappingOnSampleData() + indexTFIDF, err := NewUsing(tmpIndexPath2, indexMapping1, Config.DefaultIndexType, Config.DefaultMemKVStore, nil) + if err != nil { + t.Fatal(err) + } + + defer func() { + err := indexBM25.Close() + if err != nil { + t.Fatal(err) + } + + err = indexTFIDF.Close() + if err != nil { + t.Fatal(err) + } + }() + + batch, err := getBatchFromData(indexBM25, "sample-data.json") + if err != nil { + t.Fatalf("failed to form a batch") + } + err = indexBM25.Batch(batch) + if err != nil { + t.Fatalf("failed to index batch %v\n", err) + } + query := NewMatchQuery("Hotel") + query.FieldVal = "name" + searchRequest := NewSearchRequestOptions(query, int(10), 0, true) + + resBM25, err := indexBM25.Search(searchRequest) + if err != nil { + t.Error(err) + } + + batch, err = getBatchFromData(indexTFIDF, "sample-data.json") + if err != nil { + t.Fatalf("failed to form a batch") + } + err = indexTFIDF.Batch(batch) + if err != nil { + t.Fatalf("failed to index batch %v\n", err) + } + + resTFIDF, err := indexTFIDF.Search(searchRequest) + if err != nil { + t.Error(err) + } + + for i, hit := range resTFIDF.Hits { + if hit.Score < resBM25.Hits[i].Score { + t.Fatalf("expected the score to be higher for BM25, got %v and %v", + resBM25.Hits[i].Score, 
hit.Score) + } + } +} + +func TestBM25GlobalScoring(t *testing.T) { + tmpIndexPath := createTmpIndexPath(t) + defer cleanupTmpIndexPath(t, tmpIndexPath) + + indexMapping := createIndexMappingOnSampleData() + indexMapping.ScoringModel = index.BM25Scoring + idxSinglePartition, err := NewUsing(tmpIndexPath, indexMapping, Config.DefaultIndexType, Config.DefaultMemKVStore, nil) + if err != nil { + t.Fatal(err) + } + + defer func() { + err := idxSinglePartition.Close() + if err != nil { + t.Fatal(err) + } + }() + + batch, err := getBatchFromData(idxSinglePartition, "sample-data.json") + if err != nil { + t.Fatalf("failed to form a batch") + } + err = idxSinglePartition.Batch(batch) + if err != nil { + t.Fatalf("failed to index batch %v\n", err) + } + query := NewMatchQuery("Hotel") + query.FieldVal = "name" + searchRequest := NewSearchRequestOptions(query, int(10), 0, true) + + res, err := idxSinglePartition.Search(searchRequest) + if err != nil { + t.Error(err) + } + + singlePartHits := res.Hits + + dataset, _ := readDataFromFile("sample-data.json") + tmpIndexPath1 := createTmpIndexPath(t) + defer cleanupTmpIndexPath(t, tmpIndexPath1) + + idxPart1, err := NewUsing(tmpIndexPath1, indexMapping, Config.DefaultIndexType, Config.DefaultMemKVStore, nil) + if err != nil { + t.Fatal(err) + } + + defer func() { + err := idxPart1.Close() + if err != nil { + t.Fatal(err) + } + }() + + batch1 := idxPart1.NewBatch() + for _, doc := range dataset[:len(dataset)/2] { + err = batch1.Index(fmt.Sprintf("%d", doc["id"]), doc) + if err != nil { + t.Fatal(err) + } + } + err = idxPart1.Batch(batch1) + if err != nil { + t.Fatal(err) + } + + tmpIndexPath2 := createTmpIndexPath(t) + defer cleanupTmpIndexPath(t, tmpIndexPath2) + + idxPart2, err := NewUsing(tmpIndexPath2, indexMapping, Config.DefaultIndexType, Config.DefaultMemKVStore, nil) + if err != nil { + t.Fatal(err) + } + + defer func() { + err := idxPart2.Close() + if err != nil { + t.Fatal(err) + } + }() + + batch2 := idxPart2.NewBatch() 
+ for _, doc := range dataset[len(dataset)/2:] { + err = batch2.Index(fmt.Sprintf("%d", doc["id"]), doc) + if err != nil { + t.Fatal(err) + } + } + err = idxPart2.Batch(batch2) + if err != nil { + t.Fatal(err) + } + + multiPartIndex := NewIndexAlias(idxPart1, idxPart2) + err = multiPartIndex.SetIndexMapping(indexMapping) + if err != nil { + t.Fatal(err) + } + + ctx := context.Background() + // this key is set to ensure that we have a consistent scoring at the index alias + // level (it forces a pre search phase which can have a small overhead) + ctx = context.WithValue(ctx, search.SearchTypeKey, search.GlobalScoring) + + res, err = multiPartIndex.SearchInContext(ctx, searchRequest) + if err != nil { + t.Error(err) + } + + for i, hit := range res.Hits { + if hit.Score != singlePartHits[i].Score { + t.Fatalf("expected the scores to be the same, got %v and %v", + hit.Score, singlePartHits[i].Score) + } + } + +} + func TestBytesRead(t *testing.T) { tmpIndexPath := createTmpIndexPath(t) defer cleanupTmpIndexPath(t, tmpIndexPath) @@ -671,23 +881,30 @@ func TestBytesReadStored(t *testing.T) { } } -func getBatchFromData(idx Index, fileName string) (*Batch, error) { +func readDataFromFile(fileName string) ([]map[string]interface{}, error) { pwd, err := os.Getwd() if err != nil { return nil, err } path := filepath.Join(pwd, "data", "test", fileName) - batch := idx.NewBatch() + var dataset []map[string]interface{} fileContent, err := os.ReadFile(path) if err != nil { return nil, err } + err = json.Unmarshal(fileContent, &dataset) if err != nil { return nil, err } + return dataset, nil +} + +func getBatchFromData(idx Index, fileName string) (*Batch, error) { + dataset, err := readDataFromFile(fileName) + batch := idx.NewBatch() for _, doc := range dataset { err = batch.Index(fmt.Sprintf("%d", doc["id"]), doc) if err != nil { diff --git a/mapping/field.go b/mapping/field.go index ce2878b18..cfb390b40 100644 --- a/mapping/field.go +++ b/mapping/field.go @@ -74,8 +74,8 @@ type 
FieldMapping struct { Dims int `json:"dims,omitempty"` // Similarity is the similarity algorithm used for scoring - // vector fields. - // See: index.DefaultSimilarityMetric & index.SupportedSimilarityMetrics + // field's content while performing search on it. + // See: index.SimilarityModels Similarity string `json:"similarity,omitempty"` // Applicable to vector fields only - optimization string diff --git a/mapping/index.go b/mapping/index.go index 8a0d5e34a..6150f2a38 100644 --- a/mapping/index.go +++ b/mapping/index.go @@ -50,6 +50,7 @@ type IndexMappingImpl struct { DefaultAnalyzer string `json:"default_analyzer"` DefaultDateTimeParser string `json:"default_datetime_parser"` DefaultSynonymSource string `json:"default_synonym_source,omitempty"` + ScoringModel string `json:"scoring_model,omitempty"` DefaultField string `json:"default_field"` StoreDynamic bool `json:"store_dynamic"` IndexDynamic bool `json:"index_dynamic"` @@ -201,6 +202,11 @@ func (im *IndexMappingImpl) Validate() error { return err } } + + if _, ok := index.SupportedScoringModels[im.ScoringModel]; !ok && im.ScoringModel != "" { + return fmt.Errorf("unsupported scoring model: %s", im.ScoringModel) + } + return nil } @@ -303,6 +309,12 @@ func (im *IndexMappingImpl) UnmarshalJSON(data []byte) error { if err != nil { return err } + case "scoring_model": + err := util.UnmarshalJSON(v, &im.ScoringModel) + if err != nil { + return err + } + default: invalidKeys = append(invalidKeys, k) } diff --git a/mapping/mapping_vectors.go b/mapping/mapping_vectors.go index dbfde1fb0..20cbac6a8 100644 --- a/mapping/mapping_vectors.go +++ b/mapping/mapping_vectors.go @@ -204,7 +204,7 @@ func validateVectorFieldAlias(field *FieldMapping, parentName string, } if field.Similarity == "" { - field.Similarity = index.DefaultSimilarityMetric + field.Similarity = index.DefaultVectorSimilarityMetric } if field.VectorIndexOptimizedFor == "" { @@ -249,10 +249,10 @@ func validateVectorFieldAlias(field *FieldMapping, parentName 
string, MinVectorDims, MaxVectorDims) } - if _, ok := index.SupportedSimilarityMetrics[field.Similarity]; !ok { + if _, ok := index.SupportedVectorSimilarityMetrics[field.Similarity]; !ok { return fmt.Errorf("field: '%s', invalid similarity "+ "metric: '%s', valid metrics are: %+v", field.Name, field.Similarity, - reflect.ValueOf(index.SupportedSimilarityMetrics).MapKeys()) + reflect.ValueOf(index.SupportedVectorSimilarityMetrics).MapKeys()) } if fieldAliasCtx != nil { // writing to a nil map is unsafe diff --git a/pre_search.go b/pre_search.go index 5fd710d68..3dd7e0fe3 100644 --- a/pre_search.go +++ b/pre_search.go @@ -82,6 +82,34 @@ func (s *synonymPreSearchResultProcessor) finalize(sr *SearchResult) { } } +type bm25PreSearchResultProcessor struct { + docCount float64 // bm25 specific stats + fieldCardinality map[string]int +} + +func newBM25PreSearchResultProcessor() *bm25PreSearchResultProcessor { + return &bm25PreSearchResultProcessor{ + fieldCardinality: make(map[string]int), + } +} + +// TODO How will this work for queries other than term queries? 
+func (b *bm25PreSearchResultProcessor) add(sr *SearchResult, indexName string) { + if sr.BM25Stats != nil { + b.docCount += sr.BM25Stats.DocCount + for field, cardinality := range sr.BM25Stats.FieldCardinality { + b.fieldCardinality[field] += cardinality + } + } +} + +func (b *bm25PreSearchResultProcessor) finalize(sr *SearchResult) { + sr.BM25Stats = &search.BM25Stats{ + DocCount: b.docCount, + FieldCardinality: b.fieldCardinality, + } +} + // ----------------------------------------------------------------------------- // Master struct that can hold any number of presearch result processors type compositePreSearchResultProcessor struct { @@ -122,6 +150,11 @@ func createPreSearchResultProcessor(req *SearchRequest, flags *preSearchFlags) p processors = append(processors, synonymProcessor) } } + if flags.bm25 { + if bm25Processtor := newBM25PreSearchResultProcessor(); bm25Processtor != nil { + processors = append(processors, bm25Processtor) + } + } // Return based on the number of processors, optimizing for the common case of 1 processor // If there are no processors, return nil switch len(processors) { diff --git a/search.go b/search.go index 72bfca5e2..e13a93703 100644 --- a/search.go +++ b/search.go @@ -447,6 +447,9 @@ type SearchResult struct { // special fields that are applicable only for search // results that are obtained from a presearch SynonymResult search.FieldTermSynonymMap `json:"synonym_result,omitempty"` + + // The following fields are applicable to BM25 preSearch + BM25Stats *search.BM25Stats `json:"bm25_stats,omitempty"` } func (sr *SearchResult) Size() int { diff --git a/search/query/knn.go b/search/query/knn.go index 4d105d943..8221fbcea 100644 --- a/search/query/knn.go +++ b/search/query/knn.go @@ -82,7 +82,7 @@ func (q *KNNQuery) Searcher(ctx context.Context, i index.IndexReader, fieldMapping := m.FieldMappingForPath(q.VectorField) similarityMetric := fieldMapping.Similarity if similarityMetric == "" { - similarityMetric = 
index.DefaultSimilarityMetric + similarityMetric = index.DefaultVectorSimilarityMetric } if q.K <= 0 || len(q.Vector) == 0 { return nil, fmt.Errorf("k must be greater than 0 and vector must be non-empty") diff --git a/search/query/query.go b/search/query/query.go index 86859ae5b..a1f7b3404 100644 --- a/search/query/query.go +++ b/search/query/query.go @@ -105,6 +105,19 @@ func ParsePreSearchData(input []byte) (map[string]interface{}, error) { rv = make(map[string]interface{}) } rv[search.SynonymPreSearchDataKey] = value + case search.BM25PreSearchDataKey: + var value *search.BM25Stats + if v != nil { + err := util.UnmarshalJSON(v, &value) + if err != nil { + return nil, err + } + } + if rv == nil { + rv = make(map[string]interface{}) + } + rv[search.BM25PreSearchDataKey] = value + } } return rv, nil diff --git a/search/scorer/scorer_term.go b/search/scorer/scorer_term.go index 7b60eda4e..f5f8ec935 100644 --- a/search/scorer/scorer_term.go +++ b/search/scorer/scorer_term.go @@ -35,8 +35,9 @@ type TermQueryScorer struct { queryTerm string queryField string queryBoost float64 - docTerm uint64 - docTotal uint64 + docTerm uint64 // number of documents containing the term + docTotal uint64 // total number of documents in the index + avgDocLength float64 idf float64 options search.SearcherOptions idfExplanation *search.Explanation @@ -61,19 +62,43 @@ func (s *TermQueryScorer) Size() int { return sizeInBytes } -func NewTermQueryScorer(queryTerm []byte, queryField string, queryBoost float64, docTotal, docTerm uint64, options search.SearcherOptions) *TermQueryScorer { +func (s *TermQueryScorer) computeIDF(avgDocLength float64, docTotal, docTerm uint64) float64 { + var rv float64 + if avgDocLength > 0 { + // avgDocLength is set only for bm25 scoring + rv = math.Log(1 + (float64(docTotal)-float64(docTerm)+0.5)/ + (float64(docTerm)+0.5)) + } else { + rv = 1.0 + math.Log(float64(docTotal)/ + float64(docTerm+1.0)) + } + + return rv +} + +// queryTerm - the specific term being 
scored by this scorer object +// queryField - the field in which the term is being searched +// queryBoost - the boost value for the query term +// docTotal - total number of documents in the index +// docTerm - number of documents containing the term +// avgDocLength - average document length in the index +// options - search options such as explain scoring, include the location of the term etc. +func NewTermQueryScorer(queryTerm []byte, queryField string, queryBoost float64, docTotal, + docTerm uint64, avgDocLength float64, options search.SearcherOptions) *TermQueryScorer { + rv := TermQueryScorer{ queryTerm: string(queryTerm), queryField: queryField, queryBoost: queryBoost, docTerm: docTerm, docTotal: docTotal, - idf: 1.0 + math.Log(float64(docTotal)/float64(docTerm+1.0)), + avgDocLength: avgDocLength, options: options, queryWeight: 1.0, includeScore: options.Score != "none", } + rv.idf = rv.computeIDF(avgDocLength, docTotal, docTerm) if options.Explain { rv.idfExplanation = &search.Explanation{ Value: rv.idf, @@ -114,6 +139,63 @@ func (s *TermQueryScorer) SetQueryNorm(qnorm float64) { } } +func (s *TermQueryScorer) docScore(tf, norm float64) (score float64, model string) { + if s.avgDocLength > 0 { + // bm25 scoring + // using the posting's norm value to recompute the field length for the doc num + fieldLength := 1 / (norm * norm) + + score = s.idf * (tf * search.BM25_k1) / + (tf + search.BM25_k1*(1-search.BM25_b+(search.BM25_b*fieldLength/s.avgDocLength))) + model = index.BM25Scoring + } else { + // tf-idf scoring by default + score = tf * norm * s.idf + model = index.DefaultScoringModel + } + return score, model +} + +func (s *TermQueryScorer) scoreExplanation(tf float64, termMatch *index.TermFieldDoc) []*search.Explanation { + var rv []*search.Explanation + if s.avgDocLength > 0 { + fieldLength := 1 / (termMatch.Norm * termMatch.Norm) + fieldNormVal := 1 - search.BM25_b + (search.BM25_b * fieldLength / s.avgDocLength) + fieldNormalizeExplanation := 
&search.Explanation{ + Value: fieldNormVal, + Message: fmt.Sprintf("fieldNorm(field=%s), b=%f, fieldLength=%f, avgFieldLength=%f)", + s.queryField, search.BM25_b, fieldLength, s.avgDocLength), + } + + saturationExplanation := &search.Explanation{ + Value: search.BM25_k1 / (tf + search.BM25_k1*fieldNormVal), + Message: fmt.Sprintf("saturation(term:%s), k1=%f/(tf=%f + k1*fieldNorm=%f))", + termMatch.Term, search.BM25_k1, tf, fieldNormVal), + Children: []*search.Explanation{fieldNormalizeExplanation}, + } + + rv = make([]*search.Explanation, 3) + rv[0] = &search.Explanation{ + Value: tf, + Message: fmt.Sprintf("tf(termFreq(%s:%s)=%d", s.queryField, s.queryTerm, termMatch.Freq), + } + rv[1] = saturationExplanation + rv[2] = s.idfExplanation + } else { + rv = make([]*search.Explanation, 3) + rv[0] = &search.Explanation{ + Value: tf, + Message: fmt.Sprintf("tf(termFreq(%s:%s)=%d", s.queryField, s.queryTerm, termMatch.Freq), + } + rv[1] = &search.Explanation{ + Value: termMatch.Norm, + Message: fmt.Sprintf("fieldNorm(field=%s, doc=%s)", s.queryField, termMatch.ID), + } + rv[2] = s.idfExplanation + } + return rv +} + func (s *TermQueryScorer) Score(ctx *search.SearchContext, termMatch *index.TermFieldDoc) *search.DocumentMatch { rv := ctx.DocumentMatchPool.Get() // perform any score computations only when needed @@ -125,22 +207,14 @@ func (s *TermQueryScorer) Score(ctx *search.SearchContext, termMatch *index.Term } else { tf = math.Sqrt(float64(termMatch.Freq)) } - score := tf * termMatch.Norm * s.idf + score, scoringModel := s.docScore(tf, termMatch.Norm) if s.options.Explain { - childrenExplanations := make([]*search.Explanation, 3) - childrenExplanations[0] = &search.Explanation{ - Value: tf, - Message: fmt.Sprintf("tf(termFreq(%s:%s)=%d", s.queryField, s.queryTerm, termMatch.Freq), - } - childrenExplanations[1] = &search.Explanation{ - Value: termMatch.Norm, - Message: fmt.Sprintf("fieldNorm(field=%s, doc=%s)", s.queryField, termMatch.ID), - } - childrenExplanations[2] 
= s.idfExplanation + childrenExplanations := s.scoreExplanation(tf, termMatch) scoreExplanation = &search.Explanation{ - Value: score, - Message: fmt.Sprintf("fieldWeight(%s:%s in %s), product of:", s.queryField, s.queryTerm, termMatch.ID), + Value: score, + Message: fmt.Sprintf("fieldWeight(%s:%s in %s), as per %s model, "+ + "product of:", s.queryField, s.queryTerm, termMatch.ID, scoringModel), Children: childrenExplanations, } } diff --git a/search/scorer/scorer_term_test.go b/search/scorer/scorer_term_test.go index ffe535183..097dbe243 100644 --- a/search/scorer/scorer_term_test.go +++ b/search/scorer/scorer_term_test.go @@ -30,7 +30,7 @@ func TestTermScorer(t *testing.T) { var queryTerm = []byte("beer") var queryField = "desc" var queryBoost = 1.0 - scorer := NewTermQueryScorer(queryTerm, queryField, queryBoost, docTotal, docTerm, search.SearcherOptions{Explain: true}) + scorer := NewTermQueryScorer(queryTerm, queryField, queryBoost, docTotal, docTerm, 0, search.SearcherOptions{Explain: true}) idf := 1.0 + math.Log(float64(docTotal)/float64(docTerm+1.0)) tests := []struct { @@ -58,7 +58,7 @@ func TestTermScorer(t *testing.T) { Sort: []string{}, Expl: &search.Explanation{ Value: math.Sqrt(1.0) * idf, - Message: "fieldWeight(desc:beer in one), product of:", + Message: "fieldWeight(desc:beer in one), as per tfidf model, product of:", Children: []*search.Explanation{ { Value: 1, @@ -100,7 +100,7 @@ func TestTermScorer(t *testing.T) { Sort: []string{}, Expl: &search.Explanation{ Value: math.Sqrt(1.0) * idf, - Message: "fieldWeight(desc:beer in one), product of:", + Message: "fieldWeight(desc:beer in one), as per tfidf model, product of:", Children: []*search.Explanation{ { Value: 1, @@ -131,7 +131,7 @@ func TestTermScorer(t *testing.T) { Sort: []string{}, Expl: &search.Explanation{ Value: math.Sqrt(65) * idf, - Message: "fieldWeight(desc:beer in one), product of:", + Message: "fieldWeight(desc:beer in one), as per tfidf model, product of:", Children: 
[]*search.Explanation{ { Value: math.Sqrt(65), @@ -175,7 +175,7 @@ func TestTermScorerWithQueryNorm(t *testing.T) { var queryTerm = []byte("beer") var queryField = "desc" var queryBoost = 3.0 - scorer := NewTermQueryScorer(queryTerm, queryField, queryBoost, docTotal, docTerm, search.SearcherOptions{Explain: true}) + scorer := NewTermQueryScorer(queryTerm, queryField, queryBoost, docTotal, docTerm, 0, search.SearcherOptions{Explain: true}) idf := 1.0 + math.Log(float64(docTotal)/float64(docTerm+1.0)) scorer.SetQueryNorm(2.0) @@ -224,7 +224,7 @@ func TestTermScorerWithQueryNorm(t *testing.T) { }, { Value: math.Sqrt(1.0) * idf, - Message: "fieldWeight(desc:beer in one), product of:", + Message: "fieldWeight(desc:beer in one), as per tfidf model, product of:", Children: []*search.Explanation{ { Value: 1, diff --git a/search/searcher/search_disjunction.go b/search/searcher/search_disjunction.go index d165ec027..434c705e7 100644 --- a/search/searcher/search_disjunction.go +++ b/search/searcher/search_disjunction.go @@ -114,7 +114,7 @@ func optimizeCompositeSearcher(ctx context.Context, optimizationKind string, return nil, nil } - return newTermSearcherFromReader(indexReader, tfr, + return newTermSearcherFromReader(ctx, indexReader, tfr, []byte(optimizationKind), "*", 1.0, options) } diff --git a/search/searcher/search_term.go b/search/searcher/search_term.go index c519d8d51..1c33c6a41 100644 --- a/search/searcher/search_term.go +++ b/search/searcher/search_term.go @@ -16,6 +16,8 @@ package searcher import ( "context" + "fmt" + "math" "reflect" "github.com/blevesearch/bleve/v2/search" @@ -38,14 +40,16 @@ type TermSearcher struct { tfd index.TermFieldDoc } -func NewTermSearcher(ctx context.Context, indexReader index.IndexReader, term string, field string, boost float64, options search.SearcherOptions) (search.Searcher, error) { +func NewTermSearcher(ctx context.Context, indexReader index.IndexReader, + term string, field string, boost float64, options 
search.SearcherOptions) (search.Searcher, error) { if isTermQuery(ctx) { ctx = context.WithValue(ctx, search.QueryTypeKey, search.Term) } return NewTermSearcherBytes(ctx, indexReader, []byte(term), field, boost, options) } -func NewTermSearcherBytes(ctx context.Context, indexReader index.IndexReader, term []byte, field string, boost float64, options search.SearcherOptions) (search.Searcher, error) { +func NewTermSearcherBytes(ctx context.Context, indexReader index.IndexReader, + term []byte, field string, boost float64, options search.SearcherOptions) (search.Searcher, error) { if ctx != nil { if fts, ok := ctx.Value(search.FieldTermSynonymMapKey).(search.FieldTermSynonymMap); ok { if ts, exists := fts[field]; exists { @@ -60,17 +64,85 @@ func NewTermSearcherBytes(ctx context.Context, indexReader index.IndexReader, te if err != nil { return nil, err } - return newTermSearcherFromReader(indexReader, reader, term, field, boost, options) + return newTermSearcherFromReader(ctx, indexReader, reader, term, field, boost, options) } -func newTermSearcherFromReader(indexReader index.IndexReader, reader index.TermFieldReader, - term []byte, field string, boost float64, options search.SearcherOptions) (*TermSearcher, error) { +func tfIDFScoreMetrics(indexReader index.IndexReader) (uint64, error) { + // default tf-idf stats count, err := indexReader.DocCount() if err != nil { - _ = reader.Close() - return nil, err + return 0, err + } + + if count == 0 { + return 0, nil + } + return count, nil +} + +func bm25ScoreMetrics(ctx context.Context, field string, + indexReader index.IndexReader) (uint64, float64, error) { + var count uint64 + var fieldCardinality int + var err error + + bm25Stats, ok := ctx.Value(search.BM25PreSearchDataKey).(*search.BM25Stats) + if !ok { + count, err = indexReader.DocCount() + if err != nil { + return 0, 0, err + } + dict, err := indexReader.FieldDict(field) + if err != nil { + return 0, 0, err + } + fieldCardinality = dict.Cardinality() + } else { + 
count = uint64(bm25Stats.DocCount) + fieldCardinality, ok = bm25Stats.FieldCardinality[field] + if !ok { + return 0, 0, fmt.Errorf("field stat for bm25 not present %s", field) + } + } + + if count == 0 { + return 0, 0, nil + } + return count, math.Ceil(float64(fieldCardinality) / float64(count)), nil +} + +func newTermSearcherFromReader(ctx context.Context, indexReader index.IndexReader, + reader index.TermFieldReader, term []byte, field string, boost float64, + options search.SearcherOptions) (*TermSearcher, error) { + var count uint64 + var avgDocLength float64 + var err error + var similarityModel string + + // as a fallback case we track certain stats for tf-idf scoring + if ctx != nil { + if similarityModelCallback, ok := ctx.Value(search. + GetScoringModelCallbackKey).(search.GetScoringModelCallbackFn); ok { + similarityModel = similarityModelCallback() + } + } + switch similarityModel { + case index.BM25Scoring: + count, avgDocLength, err = bm25ScoreMetrics(ctx, field, indexReader) + if err != nil { + _ = reader.Close() + return nil, err + } + case index.TFIDFScoring: + fallthrough + default: + count, err = tfIDFScoreMetrics(indexReader) + if err != nil { + _ = reader.Close() + return nil, err + } } - scorer := scorer.NewTermQueryScorer(term, field, boost, count, reader.Count(), options) + scorer := scorer.NewTermQueryScorer(term, field, boost, count, reader.Count(), avgDocLength, options) return &TermSearcher{ indexReader: indexReader, reader: reader, @@ -85,7 +157,7 @@ func NewSynonymSearcher(ctx context.Context, indexReader index.IndexReader, term if err != nil { return nil, err } - return newTermSearcherFromReader(indexReader, reader, term, field, boostVal, options) + return newTermSearcherFromReader(ctx, indexReader, reader, term, field, boostVal, options) } // create a searcher for the term itself termSearcher, err := createTermSearcher(term, boost) diff --git a/search/util.go b/search/util.go index 2e95f1180..0530c6732 100644 
--- a/search/util.go +++ b/search/util.go @@ -135,16 +135,47 @@ const MinGeoBufPoolSize = 24 type GeoBufferPoolCallbackFunc func() *s2.GeoBufferPool +// PreSearchKey indicates whether to perform a preliminary search to gather necessary +// information which would be used in the actual search down the line. +const PreSearchKey = "_presearch_key" + +// *PreSearchDataKey are used to store the data gathered during the presearch phase +// which would be used in the actual search phase. const KnnPreSearchDataKey = "_knn_pre_search_data_key" const SynonymPreSearchDataKey = "_synonym_pre_search_data_key" +const BM25PreSearchDataKey = "_bm25_pre_search_data_key" -const PreSearchKey = "_presearch_key" +// SearchTypeKey is used to identify the type of the search being performed. +// +// for consistent scoring in cases where an index is partitioned/sharded (using an +// index alias), GlobalScoring helps in aggregating the necessary stats across +// all the child bleve indexes (shards/partitions) first before the actual search +// is performed, such that the scoring involved using these stats would be at a +// global level. +const SearchTypeKey = "_search_type_key" + +// The following keys are used to invoke the callbacks at the start and end stages +// of optimizing the disjunction/conjunction searcher creation. +const SearcherStartCallbackKey = "_searcher_start_callback_key" +const SearcherEndCallbackKey = "_searcher_end_callback_key" -type ScoreExplCorrectionCallbackFunc func(queryMatch *DocumentMatch, knnMatch *DocumentMatch) (float64, *Explanation) +// FieldTermSynonymMapKey is used to store and transport the synonym definitions data +// to the actual search phase which would use the synonyms to perform the search. +const FieldTermSynonymMapKey = "_field_term_synonym_map_key" + +const GlobalScoring = "_global_scoring" + +// GetScoringModelCallbackKey is used to help the underlying searcher identify +// which scoring mechanism to use based on index mapping. 
+const GetScoringModelCallbackKey = "_get_scoring_model" type SearcherStartCallbackFn func(size uint64) error type SearcherEndCallbackFn func(size uint64) error +type GetScoringModelCallbackFn func() string + +type ScoreExplCorrectionCallbackFunc func(queryMatch *DocumentMatch, knnMatch *DocumentMatch) (float64, *Explanation) + // field -> term -> synonyms type FieldTermSynonymMap map[string]map[string][]string @@ -161,7 +192,17 @@ func (f FieldTermSynonymMap) MergeWith(fts FieldTermSynonymMap) { } } -const FieldTermSynonymMapKey = "_field_term_synonym_map_key" - -const SearcherStartCallbackKey = "_searcher_start_callback_key" -const SearcherEndCallbackKey = "_searcher_end_callback_key" +// BM25 specific multipliers which control the scoring of a document. +// +// BM25_b - controls the extent to which doc's field length normalize term frequency part of score +// BM25_k1 - controls the saturation of the score due to term frequency +// the default values are as per elastic search's implementation +// - https://www.elastic.co/guide/en/elasticsearch/reference/current/index-modules-similarity.html#bm25 +// - https://www.elastic.co/blog/practical-bm25-part-3-considerations-for-picking-b-and-k1-in-elasticsearch +var BM25_k1 float64 = 1.2 +var BM25_b float64 = 0.75 + +type BM25Stats struct { + DocCount float64 `json:"doc_count"` + FieldCardinality map[string]int `json:"field_cardinality"` +}