diff --git a/docstore/memdocstore/codec.go b/docstore/memdocstore/codec.go
index 68c4d1db85..c2c5baca59 100644
--- a/docstore/memdocstore/codec.go
+++ b/docstore/memdocstore/codec.go
@@ -104,7 +104,7 @@ func decodeDoc(m storedDoc, ddoc driver.Document, fps [][]string) error {
 		// (We don't need the key field because ddoc must already have it.)
 		m2 = map[string]interface{}{}
 		for _, fp := range fps {
-			val, err := getAtFieldPath(m, fp)
+			val, err := getAtFieldPath(m, fp, false)
 			if err != nil {
 				if gcerrors.Code(err) == gcerrors.NotFound {
 					continue
diff --git a/docstore/memdocstore/mem.go b/docstore/memdocstore/mem.go
index c4ecf83178..0077e76fcc 100644
--- a/docstore/memdocstore/mem.go
+++ b/docstore/memdocstore/mem.go
@@ -68,6 +68,13 @@ type Options struct {
 	// When the collection is closed, its contents are saved to the file.
 	Filename string
 
+	// AllowNestedSliceQueries allows querying into nested slices.
+	// If true, a query filter whose field path passes through or ends at a slice
+	// matches when any element of the slice satisfies the filter's operator.
+	// This makes memdocstore more compatible with MongoDB,
+	// but other providers may not support this feature.
+	AllowNestedSliceQueries bool
+
 	// Call this function when the collection is closed.
 	// For internal use only.
 	onClose func()
@@ -397,18 +404,53 @@ func (c *collection) checkRevision(arg driver.Document, current storedDoc) error
 	return nil
 }
 
-// getAtFieldPath gets the value of m at fp. It returns an error if fp is invalid
-// (see getParentMap).
-func getAtFieldPath(m map[string]interface{}, fp []string) (interface{}, error) {
-	m2, err := getParentMap(m, fp, false)
-	if err != nil {
-		return nil, err
-	}
-	v, ok := m2[fp[len(fp)-1]]
-	if ok {
-		return v, nil
-	}
-	return nil, gcerr.Newf(gcerr.NotFound, nil, "field %s not found", fp)
+// getAtFieldPath gets the value of m at fp. It returns a NotFound error if a
+// component of fp is missing or cannot be traversed. A field that is present
+// with a nil value is returned as (nil, nil) rather than reported as missing.
+//
+// If nested is true, slices that must be traversed to resolve fp are searched
+// element by element and the values found are flattened into a single slice
+// (see Options.AllowNestedSliceQueries). If nested is false, a slice that
+// would have to be traversed is treated as not found.
+func getAtFieldPath(m map[string]any, fp []string, nested bool) (any, error) {
+	// get looks up name in v and reports whether the lookup succeeded, so that
+	// a key holding nil can be told apart from an absent key.
+	var get func(v any, name string) (any, bool)
+	get = func(v any, name string) (any, bool) {
+		switch v := v.(type) {
+		case map[string]any:
+			val, ok := v[name]
+			return val, ok
+		case []any:
+			if !nested {
+				return nil, false
+			}
+			var vals []any
+			for _, e := range v {
+				// An element without the field contributes a nil entry.
+				next, _ := get(e, name)
+				// compare does not look inside slices of slices, so flatten
+				// nested results here instead of making compare recursive.
+				if sliced, ok := next.([]any); ok {
+					vals = append(vals, sliced...)
+				} else {
+					vals = append(vals, next)
+				}
+			}
+			return vals, true
+		}
+		return nil, false
+	}
+
+	var result any = m
+	for _, name := range fp {
+		next, ok := get(result, name)
+		if !ok {
+			return nil, gcerr.Newf(gcerr.NotFound, nil, "field %s not found", strings.Join(fp, "."))
+		}
+		result = next
+	}
+	return result, nil
 }
 
 // setAtFieldPath sets m's value at fp to val. It creates intermediate maps as
@@ -422,14 +464,6 @@ func setAtFieldPath(m map[string]interface{}, fp []string, val interface{}) erro
 	return nil
 }
 
-// Delete the value from m at the given field path, if it exists.
-func deleteAtFieldPath(m map[string]interface{}, fp []string) {
-	m2, _ := getParentMap(m, fp, false) // ignore error
-	if m2 != nil {
-		delete(m2, fp[len(fp)-1])
-	}
-}
-
 // getParentMap returns the map that directly contains the given field path;
 // that is, the value of m at the field path that excludes the last component
 // of fp. If a non-map is encountered along the way, an InvalidArgument error is
diff --git a/docstore/memdocstore/mem_test.go b/docstore/memdocstore/mem_test.go
index 9a843202d6..3d8d617c94 100644
--- a/docstore/memdocstore/mem_test.go
+++ b/docstore/memdocstore/mem_test.go
@@ -16,6 +16,8 @@ package memdocstore
 
 import (
 	"context"
+	"io"
 	"os"
 	"path/filepath"
+	"sort"
 	"testing"
@@ -131,6 +133,145 @@ func TestUpdateAtomic(t *testing.T) {
 	}
 }
 
+func TestQueryNested(t *testing.T) {
+	ctx := context.Background()
+
+	dc, err := newCollection(drivertest.KeyField, nil, &Options{AllowNestedSliceQueries: true})
+	if err != nil {
+		t.Fatal(err)
+	}
+	coll := docstore.NewCollection(dc)
+	defer coll.Close()
+
+	// Set up test documents
+	testDocs := []docmap{{
+		drivertest.KeyField: "TestQueryNested",
+		"list":              []any{docmap{"a": "A"}},
+		"map":               docmap{"b": "B"},
+		"listOfMaps":        []any{docmap{"id": "1"}, docmap{"id": "2"}, docmap{"id": "3"}},
+		"mapOfLists":        docmap{"ids": []any{"1", "2", "3"}},
+		"deep":              []any{docmap{"nesting": []any{docmap{"of": docmap{"elements": "yes"}}}}},
+		"listOfLists":       []any{docmap{"items": []any{docmap{"price": 10}, docmap{"price": 20}}}},
+		dc.RevisionField():  nil,
+	}, {
+		drivertest.KeyField: "CheapItems",
+		"items":             []any{docmap{"price": 10}, docmap{"price": 1}},
+		dc.RevisionField():  nil,
+	}, {
+		drivertest.KeyField: "ExpensiveItems",
+		"items":             []any{docmap{"price": 50}, docmap{"price": 100}},
+		dc.RevisionField():  nil,
+	}}
+
+	for _, testDoc := range testDocs {
+		err = coll.Put(ctx, testDoc)
+		if err != nil {
+			t.Fatal(err)
+		}
+	}
+
+	tests := []struct {
+		name     string
+		where    []any
+		wantKeys []string
+	}{
+		{
+			name:     "list field match",
+			where:    []any{"list.a", "=", "A"},
+			wantKeys: []string{"TestQueryNested"},
+		}, {
+			name:  "list field no match",
+			where: []any{"list.a", "=", "missing"},
+		}, {
+			name:     "map field match",
+			where:    []any{"map.b", "=", "B"},
+			wantKeys: []string{"TestQueryNested"},
+		}, {
+			name:     "list of maps field match",
+			where:    []any{"listOfMaps.id", "=", "2"},
+			wantKeys: []string{"TestQueryNested"},
+		}, {
+			name:     "map of lists field match",
+			where:    []any{"mapOfLists.ids", "=", "1"},
+			wantKeys: []string{"TestQueryNested"},
+		}, {
+			name:     "deep nested field match",
+			where:    []any{"deep.nesting.of.elements", "=", "yes"},
+			wantKeys: []string{"TestQueryNested"},
+		}, {
+			name:     "list of lists exact price 10",
+			where:    []any{"listOfLists.items.price", "=", 10},
+			wantKeys: []string{"TestQueryNested"},
+		}, {
+			name:     "list of lists exact price 20",
+			where:    []any{"listOfLists.items.price", "=", 20},
+			wantKeys: []string{"TestQueryNested"},
+		}, {
+			name:     "list of lists price less than or equal to 20",
+			where:    []any{"listOfLists.items.price", "<=", 20},
+			wantKeys: []string{"TestQueryNested"},
+		}, {
+			name:     "items price equals 1",
+			where:    []any{"items.price", "=", 1},
+			wantKeys: []string{"CheapItems"},
+		}, {
+			name:  "items price equals 5 (no match)",
+			where: []any{"items.price", "=", 5},
+		}, {
+			name:     "items price greater than or equal to 1",
+			where:    []any{"items.price", ">=", 1},
+			wantKeys: []string{"CheapItems", "ExpensiveItems"},
+		}, {
+			name:     "items price greater than or equal to 5",
+			where:    []any{"items.price", ">=", 5},
+			wantKeys: []string{"CheapItems", "ExpensiveItems"},
+		}, {
+			name:     "items price greater than or equal to 10",
+			where:    []any{"items.price", ">=", 10},
+			wantKeys: []string{"CheapItems", "ExpensiveItems"},
+		}, {
+			name:     "items price less than or equal to 50",
+			where:    []any{"items.price", "<=", 50},
+			wantKeys: []string{"CheapItems", "ExpensiveItems"},
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			iter := coll.Query().Where(docstore.FieldPath(tc.where[0].(string)), tc.where[1].(string), tc.where[2]).Get(ctx)
+			defer iter.Stop()
+			var got []docmap
+			for {
+				doc := docmap{}
+				err := iter.Next(ctx, doc)
+				if err != nil {
+					if err == io.EOF {
+						break
+					}
+					t.Fatal(err)
+				}
+				got = append(got, doc)
+			}
+
+			// Extract keys from results
+			var gotKeys []string
+			for _, d := range got {
+				if key, ok := d[drivertest.KeyField].(string); ok {
+					gotKeys = append(gotKeys, key)
+				}
+			}
+			// The query has no OrderBy, so the result order is not guaranteed;
+			// sort the keys so the comparison with wantKeys is deterministic.
+			sort.Strings(gotKeys)
+
+			diff := cmp.Diff(gotKeys, tc.wantKeys)
+			if diff != "" {
+				t.Errorf("query results mismatch (-got +want):\n%s", diff)
+			}
+		})
+	}
+}
+
 func TestSortDocs(t *testing.T) {
 	newDocs := func() []storedDoc {
 		return []storedDoc{
diff --git a/docstore/memdocstore/query.go b/docstore/memdocstore/query.go
index 419017b993..24830e6ed2 100644
--- a/docstore/memdocstore/query.go
+++ b/docstore/memdocstore/query.go
@@ -37,7 +37,7 @@ func (c *collection) RunGetQuery(_ context.Context, q *driver.Query) (driver.Doc
 
 	var resultDocs []storedDoc
 	for _, doc := range c.docs {
-		if filtersMatch(q.Filters, doc) {
+		if filtersMatch(q.Filters, doc, c.opts.AllowNestedSliceQueries) {
 			resultDocs = append(resultDocs, doc)
 		}
 	}
@@ -74,22 +74,22 @@ func (c *collection) RunGetQuery(_ context.Context, q *driver.Query) (driver.Doc
 	}, nil
 }
 
-func filtersMatch(fs []driver.Filter, doc storedDoc) bool {
+func filtersMatch(fs []driver.Filter, doc storedDoc, nested bool) bool {
 	for _, f := range fs {
-		if !filterMatches(f, doc) {
+		if !filterMatches(f, doc, nested) {
 			return false
 		}
 	}
 	return true
 }
 
-func filterMatches(f driver.Filter, doc storedDoc) bool {
-	docval, err := getAtFieldPath(doc, f.FieldPath)
+func filterMatches(f driver.Filter, doc storedDoc, nested bool) bool {
+	docval, err := getAtFieldPath(doc, f.FieldPath, nested)
 	// missing or bad field path => no match
 	if err != nil {
 		return false
 	}
-	c, ok := compare(docval, f.Value)
+	c, ok := compare(docval, f.Value, f.Op)
 	if !ok {
 		return false
 	}
@@ -120,24 +120,50 @@ func applyComparison(op string, c int) bool {
 	}
 }
 
-func compare(x1, x2 interface{}) (int, bool) {
+func compare(x1, x2 any, op string) (int, bool) {
 	v1 := reflect.ValueOf(x1)
 	v2 := reflect.ValueOf(x2)
-	// this is for in/not-in queries.
-	// return 0 if x1 is in slice x2, -1 if not.
+	// For in/not-in queries. Otherwise this should only be reached with AllowNestedSliceQueries set.
+	// Return 0 if x1 is in slice x2, -1 if not.
 	if v2.Kind() == reflect.Slice {
-		for i := 0; i < v2.Len(); i++ {
-			if c, ok := compare(x1, v2.Index(i).Interface()); ok {
-				if !ok {
-					return 0, false
-				}
+		for i := range v2.Len() {
+			if c, ok := compare(x1, v2.Index(i).Interface(), op); ok {
 				if c == 0 {
 					return 0, true
 				}
+				if op != "in" && op != "not-in" {
+					return c, true
+				}
 			}
 		}
 		return -1, true
 	}
+	// See Options.AllowNestedSliceQueries.
+	// x1 is a list of document values and x2 is the value being queried for.
+	// A single matching element is enough: 0 is returned if any element equals
+	// x2; otherwise the sign is chosen so that op holds if some element
+	// satisfies it. The recursive call swaps the arguments, hence "v2Greater".
+	if v1.Kind() == reflect.Slice {
+		v2Greater := false
+		v2Less := false
+		for i := range v1.Len() {
+			if c, ok := compare(x2, v1.Index(i).Interface(), op); ok {
+				if c == 0 {
+					return 0, true
+				}
+				v2Greater = v2Greater || c > 0
+				v2Less = v2Less || c < 0
+			}
+		}
+		// HasPrefix rather than op[0] so that an empty op cannot panic.
+		switch {
+		case strings.HasPrefix(op, ">") && v2Less:
+			return 1, true
+		case strings.HasPrefix(op, "<") && v2Greater:
+			return -1, true
+		}
+		return 0, false
+	}
 	if v1.Kind() == reflect.String && v2.Kind() == reflect.String {
 		return strings.Compare(v1.String(), v2.String()), true
 	}
@@ -160,7 +186,7 @@ func compare(x1, x2 interface{}) (int, bool) {
 
 func sortDocs(docs []storedDoc, field string, asc bool) {
 	sort.Slice(docs, func(i, j int) bool {
-		c, ok := compare(docs[i][field], docs[j][field])
+		c, ok := compare(docs[i][field], docs[j][field], ">")
 		if !ok {
 			return false
 		}
diff --git a/docstore/memdocstore/urls.go b/docstore/memdocstore/urls.go
index 84d3258aeb..5f99a00157 100644
--- a/docstore/memdocstore/urls.go
+++ b/docstore/memdocstore/urls.go
@@ -65,8 +65,9 @@ func (o *URLOpener) OpenCollectionURL(ctx context.Context, u *url.URL) (*docstor
 	}
 
 	options := &Options{
-		RevisionField: q.Get("revision_field"),
-		Filename:      q.Get("filename"),
+		RevisionField:           q.Get("revision_field"),
+		Filename:                q.Get("filename"),
+		AllowNestedSliceQueries: q.Get("allow_nested_slice_queries") == "true",
 		onClose: func() {
 			o.mu.Lock()
 			delete(o.collections, collName)
@@ -75,6 +76,7 @@ func (o *URLOpener) OpenCollectionURL(ctx context.Context, u *url.URL) (*docstor
 	}
 	q.Del("revision_field")
 	q.Del("filename")
+	q.Del("allow_nested_slice_queries")
 	for param := range q {
 		return nil, fmt.Errorf("open collection %v: invalid query parameter %q", u, param)
 	}