Skip to content

Commit 02ce845

Browse files
committed
docstore/memdocstore: #3508 allow nested slices query
1 parent 5695484 commit 02ce845

File tree

4 files changed

+95
-22
lines changed

4 files changed

+95
-22
lines changed

docstore/memdocstore/codec.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ func decodeDoc(m storedDoc, ddoc driver.Document, fps [][]string) error {
104104
// (We don't need the key field because ddoc must already have it.)
105105
m2 = map[string]interface{}{}
106106
for _, fp := range fps {
107-
val, err := getAtFieldPath(m, fp)
107+
val, err := getAtFieldPath(m, fp, false)
108108
if err != nil {
109109
if gcerrors.Code(err) == gcerrors.NotFound {
110110
continue

docstore/memdocstore/mem.go

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,12 @@ type Options struct {
6868
// When the collection is closed, its contents are saved to the file.
6969
Filename string
7070

71+
// AllowNestedSlicesQuery allows querying with nested slices.
72+
// This makes the memdocstore more compatible with MongoDB,
73+
// but other providers may not support this feature.
74+
// See https://github.com/google/go-cloud/pull/3511 for more details.
75+
AllowNestedSlicesQuery bool
76+
7177
// Call this function when the collection is closed.
7278
// For internal use only.
7379
onClose func()
@@ -399,16 +405,34 @@ func (c *collection) checkRevision(arg driver.Document, current storedDoc) error
399405

400406
// getAtFieldPath gets the value of m at fp. It returns an error if fp is invalid
401407
// (see getParentMap).
402-
func getAtFieldPath(m map[string]interface{}, fp []string) (interface{}, error) {
403-
m2, err := getParentMap(m, fp, false)
404-
if err != nil {
405-
return nil, err
408+
func getAtFieldPath(m map[string]interface{}, fp []string, nested bool) (result interface{}, err error) {
409+
410+
var get func(m interface{}, name string) interface{}
411+
get = func(m interface{}, name string) interface{} {
412+
switch concrete := m.(type) {
413+
case map[string]interface{}:
414+
return concrete[name]
415+
case []interface{}:
416+
if !nested {
417+
return nil
418+
}
419+
result := []interface{}{}
420+
for _, e := range concrete {
421+
result = append(result, get(e, name))
422+
}
423+
return result
424+
}
425+
return nil
406426
}
407-
v, ok := m2[fp[len(fp)-1]]
408-
if ok {
409-
return v, nil
427+
result = m
428+
for _, k := range fp {
429+
next := get(result, k)
430+
if next == nil {
431+
return nil, gcerr.Newf(gcerr.NotFound, nil, "field %s not found", strings.Join(fp, "."))
432+
}
433+
result = next
410434
}
411-
return nil, gcerr.Newf(gcerr.NotFound, nil, "field %s not found", fp)
435+
return result, nil
412436
}
413437

414438
// setAtFieldPath sets m's value at fp to val. It creates intermediate maps as
@@ -422,14 +446,6 @@ func setAtFieldPath(m map[string]interface{}, fp []string, val interface{}) erro
422446
return nil
423447
}
424448

425-
// Delete the value from m at the given field path, if it exists.
426-
func deleteAtFieldPath(m map[string]interface{}, fp []string) {
427-
m2, _ := getParentMap(m, fp, false) // ignore error
428-
if m2 != nil {
429-
delete(m2, fp[len(fp)-1])
430-
}
431-
}
432-
433449
// getParentMap returns the map that directly contains the given field path;
434450
// that is, the value of m at the field path that excludes the last component
435451
// of fp. If a non-map is encountered along the way, an InvalidArgument error is

docstore/memdocstore/mem_test.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package memdocstore
1616

1717
import (
1818
"context"
19+
"io"
1920
"os"
2021
"path/filepath"
2122
"testing"
@@ -129,6 +130,49 @@ func TestUpdateAtomic(t *testing.T) {
129130
}
130131
}
131132

133+
func TestQueryNested(t *testing.T) {
134+
ctx := context.Background()
135+
136+
count := func(iter *docstore.DocumentIterator) (c int) {
137+
doc := docmap{}
138+
for {
139+
if err := iter.Next(ctx, doc); err != nil {
140+
if err == io.EOF {
141+
break
142+
}
143+
t.Fatal(err)
144+
}
145+
c++
146+
}
147+
return c
148+
}
149+
150+
dc, err := newCollection(drivertest.KeyField, nil, &Options{AllowNestedSlicesQuery: true})
151+
if err != nil {
152+
t.Fatal(err)
153+
}
154+
coll := docstore.NewCollection(dc)
155+
defer coll.Close()
156+
157+
doc := docmap{drivertest.KeyField: "TestQueryNested",
158+
"list": []any{docmap{"a": "A"}},
159+
"map": docmap{"b": "B"},
160+
dc.RevisionField(): nil,
161+
}
162+
if err := coll.Put(ctx, doc); err != nil {
163+
t.Fatal(err)
164+
}
165+
166+
got := count(coll.Query().Where("list.a", "=", "A").Get(ctx))
167+
if got != 1 {
168+
t.Errorf("got %v docs when filtering by list.a, want 1", got)
169+
}
170+
got = count(coll.Query().Where("map.b", "=", "B").Get(ctx))
171+
if got != 1 {
172+
t.Errorf("got %v docs when filtering by map.b, want 1", got)
173+
}
174+
}
175+
132176
func TestSortDocs(t *testing.T) {
133177
newDocs := func() []storedDoc {
134178
return []storedDoc{

docstore/memdocstore/query.go

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func (c *collection) RunGetQuery(_ context.Context, q *driver.Query) (driver.Doc
3737

3838
var resultDocs []storedDoc
3939
for _, doc := range c.docs {
40-
if filtersMatch(q.Filters, doc) {
40+
if filtersMatch(q.Filters, doc, c.opts.AllowNestedSlicesQuery) {
4141
resultDocs = append(resultDocs, doc)
4242
}
4343
}
@@ -74,17 +74,17 @@ func (c *collection) RunGetQuery(_ context.Context, q *driver.Query) (driver.Doc
7474
}, nil
7575
}
7676

77-
func filtersMatch(fs []driver.Filter, doc storedDoc) bool {
77+
func filtersMatch(fs []driver.Filter, doc storedDoc, nested bool) bool {
7878
for _, f := range fs {
79-
if !filterMatches(f, doc) {
79+
if !filterMatches(f, doc, nested) {
8080
return false
8181
}
8282
}
8383
return true
8484
}
8585

86-
func filterMatches(f driver.Filter, doc storedDoc) bool {
87-
docval, err := getAtFieldPath(doc, f.FieldPath)
86+
func filterMatches(f driver.Filter, doc storedDoc, nested bool) bool {
87+
docval, err := getAtFieldPath(doc, f.FieldPath, nested)
8888
// missing or bad field path => no match
8989
if err != nil {
9090
return false
@@ -138,6 +138,19 @@ func compare(x1, x2 interface{}) (int, bool) {
138138
}
139139
return -1, true
140140
}
141+
if v1.Kind() == reflect.Slice {
142+
for i := 0; i < v1.Len(); i++ {
143+
if c, ok := compare(x2, v1.Index(i).Interface()); ok {
144+
if !ok {
145+
return 0, false
146+
}
147+
if c == 0 {
148+
return 0, true
149+
}
150+
}
151+
}
152+
return -1, true
153+
}
141154
if v1.Kind() == reflect.String && v2.Kind() == reflect.String {
142155
return strings.Compare(v1.String(), v2.String()), true
143156
}

0 commit comments

Comments
 (0)