diff --git a/docstore/awsdynamodb/query.go b/docstore/awsdynamodb/query.go index b89b2317ea..a53871c7ac 100644 --- a/docstore/awsdynamodb/query.go +++ b/docstore/awsdynamodb/query.go @@ -20,6 +20,7 @@ import ( "errors" "fmt" "io" + "reflect" "sort" "strings" "time" @@ -436,11 +437,26 @@ func toFilter(f driver.Filter) expression.ConditionBuilder { return expression.GreaterThanEqual(name, val) case ">": return expression.GreaterThan(name, val) + case "in": + return toInCondition(f) + case "not-in": + return expression.Not(toInCondition(f)) default: panic(fmt.Sprint("invalid filter operation:", f.Op)) } } +func toInCondition(f driver.Filter) expression.ConditionBuilder { + name := expression.Name(strings.Join(f.FieldPath, ".")) + vslice := reflect.ValueOf(f.Value) + right := expression.Value(vslice.Index(0).Interface()) + other := make([]expression.OperandBuilder, vslice.Len()-1) + for i := 1; i < vslice.Len(); i++ { + other[i-1] = expression.Value(vslice.Index(i).Interface()) + } + return expression.In(name, right, other...) +} + type documentIterator struct { qr *queryRunner items []map[string]*dyn.AttributeValue diff --git a/docstore/drivertest/drivertest.go b/docstore/drivertest/drivertest.go index 44051e3b32..39aaaed5cd 100644 --- a/docstore/drivertest/drivertest.go +++ b/docstore/drivertest/drivertest.go @@ -1423,6 +1423,16 @@ func testGetQuery(t *testing.T, _ Harness, coll *docstore.Collection) { q: coll.Query().Where("Score", ">=", 50).Where("Time", ">", date(4, 1)), want: func(h *HighScore) bool { return h.Score >= 50 && h.Time.After(date(4, 1)) }, }, + { + name: "PlayerIn", + q: coll.Query().Where("Player", "in", []string{"pat", "billie"}), + want: func(h *HighScore) bool { return h.Player == "pat" || h.Player == "billie" }, + }, + { + name: "PlayerNotIn", + q: coll.Query().Where("Player", "not-in", []string{"pat", "billie"}), + want: func(h *HighScore) bool { return h.Player != "pat" && h.Player != "billie" }, + }, { name: "AllByPlayerAsc", q: coll.Query().OrderBy("Player", docstore.Ascending), diff --git a/docstore/gcpfirestore/query.go b/docstore/gcpfirestore/query.go index f4c4eaad00..2c826ae328 100644 --- a/docstore/gcpfirestore/query.go +++ b/docstore/gcpfirestore/query.go @@ -293,8 +293,7 @@ func (c *collection) filterToProto(f driver.Filter) (*pb.StructuredQuery_Filter, FilterType: &pb.StructuredQuery_Filter_UnaryFilter{ UnaryFilter: &pb.StructuredQuery_UnaryFilter{ OperandType: &pb.StructuredQuery_UnaryFilter_Field{ - Field: fieldRef(f.FieldPath), - }, + Field: fieldRef(f.FieldPath)}, Op: uop, }, }, @@ -346,6 +345,10 @@ func newFieldFilter(fp []string, op string, val *pb.Value) (*pb.StructuredQuery_ fop = pb.StructuredQuery_FieldFilter_GREATER_THAN_OR_EQUAL case driver.EqualOp: fop = pb.StructuredQuery_FieldFilter_EQUAL + case "in": + fop = pb.StructuredQuery_FieldFilter_IN + case "not-in": + fop = pb.StructuredQuery_FieldFilter_NOT_IN // TODO(jba): can we support array-contains portably? // case "array-contains": // fop = pb.StructuredQuery_FieldFilter_ARRAY_CONTAINS diff --git a/docstore/memdocstore/query.go b/docstore/memdocstore/query.go index 0ea0bb129d..1086ef484e 100644 --- a/docstore/memdocstore/query.go +++ b/docstore/memdocstore/query.go @@ -86,7 +86,7 @@ func filterMatches(f driver.Filter, doc storedDoc) bool { return applyComparison(f.Op, c) } -// op is one of the five permitted docstore operators ("=", "<", etc.) +// op is one of the permitted docstore operators ("=", "<", etc.) // c is the result of strings.Compare or the like. // TODO(jba): dedup from gcpfirestore/query? func applyComparison(op string, c int) bool { @@ -101,6 +101,10 @@ func applyComparison(op string, c int) bool { return c >= 0 case "<=": return c <= 0 + case "in": + return c == 0 + case "not-in": + return c != 0 default: panic("bad op") } @@ -109,6 +113,21 @@ func applyComparison(op string, c int) bool { func compare(x1, x2 interface{}) (int, bool) { v1 := reflect.ValueOf(x1) v2 := reflect.ValueOf(x2) + // this is for in/not-in queries. + // return 0 if x1 is in slice x2, -1 if not. + if v2.Kind() == reflect.Slice { + for i := 0; i < v2.Len(); i++ { + if c, ok := compare(x1, v2.Index(i).Interface()); ok { + if !ok { + return 0, false + } + if c == 0 { + return 0, true + } + } + } + return -1, true + } if v1.Kind() == reflect.String && v2.Kind() == reflect.String { return strings.Compare(v1.String(), v2.String()), true } diff --git a/docstore/mongodocstore/mongo_test.go b/docstore/mongodocstore/mongo_test.go index 9073d7a1c4..2e29e1d21e 100644 --- a/docstore/mongodocstore/mongo_test.go +++ b/docstore/mongodocstore/mongo_test.go @@ -301,4 +301,23 @@ func TestLowercaseFields(t *testing.T) { var got6 S must(coll.Query().OrderBy("G", docstore.Descending).Get(ctx).Next(ctx, &got6)) check(got6, *sdoc2) + + // List queries + // select F from coll WHERE G IN (50, 51) ORDER BY G DESC + // test that F is 99 + sdoc3 := &S{ID: 3, F: 99, G: 50} + sdoc4 := &S{ID: 4, F: 99, G: 51} + must(coll.Put(ctx, sdoc3)) + must(coll.Put(ctx, sdoc4)) + var got7, got8 S + iter := coll.Query().Where("G", "in", []int{50, 51}).OrderBy("G", docstore.Descending).Get(ctx) + must(iter.Next(ctx, &got7)) + must(iter.Next(ctx, &got8)) + check(got7, *sdoc4) + check(got8, *sdoc3) + + // same query with not-in, expect to get sdoc2 back even though G is higher for sdoc3 and sdoc4 + var got9 S + must(coll.Query().Where("G", "not-in", []int{50, 51}).OrderBy("G", docstore.Descending).Get(ctx).Next(ctx, &got9)) + check(got9, *sdoc2) } diff --git a/docstore/mongodocstore/query.go b/docstore/mongodocstore/query.go index d37ea2f814..e90c625b22 100644 --- a/docstore/mongodocstore/query.go +++ b/docstore/mongodocstore/query.go @@ -75,6 +75,8 @@ var mongoQueryOps = map[string]string{ ">=": "$gte", "<": "$lt", "<=": "$lte", + "in": "$in", + "not-in": "$nin", } // filtersToBSON converts a []driver.Filter to the MongoDB equivalent, expressed diff --git a/docstore/query.go b/docstore/query.go index fabfef03d4..44b2bb995e 100644 --- a/docstore/query.go +++ b/docstore/query.go @@ -37,7 +37,7 @@ func (c *Collection) Query() *Query { } // Where expresses a condition on the query. -// Valid ops are: "=", ">", "<", ">=", "<=". +// Valid ops are: "=", ">", "<", ">=", "<=, "in", "not-in". // Valid values are strings, integers, floating-point numbers, and time.Time values. func (q *Query) Where(fp FieldPath, op string, value interface{}) *Query { if q.err != nil { @@ -48,10 +48,11 @@ func (q *Query) Where(fp FieldPath, op string, value interface{}) *Query { q.err = err return q } - if !validOp[op] { - return q.invalidf("invalid filter operator: %q. Use one of: =, >, <, >=, <=", op) + validator, ok := validOp[op] + if !ok { + return q.invalidf("invalid filter operator: %q. Use one of: =, >, <, >=, <=, in, not-in", op) } - if !validFilterValue(value) { + if !validator(value) { return q.invalidf("invalid filter value: %v", value) } q.dq.Filters = append(q.dq.Filters, driver.Filter{ @@ -62,12 +63,16 @@ func (q *Query) Where(fp FieldPath, op string, value interface{}) *Query { return q } -var validOp = map[string]bool{ - "=": true, - ">": true, - "<": true, - ">=": true, - "<=": true, +type valueValidator func(interface{}) bool + +var validOp = map[string]valueValidator{ + "=": validFilterValue, + ">": validFilterValue, + "<": validFilterValue, + ">=": validFilterValue, + "<=": validFilterValue, + "in": validFilterSlice, + "not-in": validFilterSlice, } func validFilterValue(v interface{}) bool { @@ -91,6 +96,19 @@ func validFilterValue(v interface{}) bool { } } +func validFilterSlice(v interface{}) bool { + if v == nil || reflect.TypeOf(v).Kind() != reflect.Slice { + return false + } + vv := reflect.ValueOf(v) + for i := 0; i < vv.Len(); i++ { + if !validFilterValue(vv.Index(i).Interface()) { + return false + } + } + return true +} + // Limit will limit the results to at most n documents. // n must be positive. // It is an error to specify Limit more than once in a Get query, or