Skip to content

Commit

Permalink
Add interface tests for QueryKey value comparators
Browse files Browse the repository at this point in the history
  • Loading branch information
ilija42 committed Aug 14, 2024
1 parent 9e3fa6b commit f3e40e7
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -340,11 +340,11 @@ type eventConfidencePair struct {

type fakeChainReader struct {
fakeTypeProvider
vals []valConfidencePair
triggers []eventConfidencePair
stored []TestStruct
vals []valConfidencePair
triggers []eventConfidencePair
stored []TestStruct
batchStored BatchCallEntry
lock sync.Mutex
lock sync.Mutex
}

func (f *fakeChainReader) Start(_ context.Context) error { return nil }
Expand Down Expand Up @@ -512,7 +512,28 @@ func (f *fakeChainReader) QueryKey(_ context.Context, _ string, filter query.Key

var sequences []types.Sequence
for _, trigger := range f.triggers {
sequences = append(sequences, types.Sequence{Data: trigger.testStruct})
doAppend := true
for _, expr := range filter.Expressions {
if primitive, ok := expr.Primitive.(*primitives.Comparator); ok {
if len(primitive.ValueComparators) == 0 {
return nil, fmt.Errorf("value comparator for %s should not be empty", primitive.Name)
}
if primitive.Name == "Field" {
for _, valComp := range primitive.ValueComparators {
doAppend = doAppend && Compare(*trigger.testStruct.Field, *valComp.Value.(*int32), valComp.Operator)
}
} else if primitive.Name == "NestedStruct.FixedBytes" {
// in practice, we won't throw error if there are multiple value comparators for un-comparable types, but such query wouldn't ever return results
if len(primitive.ValueComparators) > 1 || primitive.ValueComparators[0].Operator != primitives.Eq {
return nil, fmt.Errorf("value comparator for FixedBytes should only be filtered by equality and does not support %s operator", primitive.ValueComparators[0].Operator)
}
doAppend = reflect.DeepEqual(*primitive.ValueComparators[0].Value.(*[2]byte), trigger.testStruct.NestedStruct.FixedBytes)
}
}
}
if len(filter.Expressions) == 0 || doAppend {
sequences = append(sequences, types.Sequence{Data: trigger.testStruct})
}
}

if !limitAndSort.HasSequenceSort() {
Expand Down Expand Up @@ -579,6 +600,7 @@ func (e *errChainReader) QueryKey(_ context.Context, _ string, _ query.KeyFilter
}

type protoConversionTestChainReader struct {
testProtoConversionTypeProvider
expectedBindings types.BoundContract
expectedQueryFilter query.KeyFilter
expectedLimitAndSort query.LimitAndSort
Expand Down
64 changes: 51 additions & 13 deletions pkg/loop/internal/relayer/pluginprovider/chainreader/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,37 +93,72 @@ func (fakeTypeProvider) CreateContractType(_, itemType string, isEncode bool) (a
return &FilterEventParams{}, nil
}
return &TestStruct{}, nil
case EventNameField:
if isEncode {
var typ int32
return &typ, nil
}
return 0, errors.New("comparator types should only be encoded")
case EventNameNestedField:
if isEncode {
var typ [2]byte
return &typ, nil
}
return 0, errors.New("comparator types should only be encoded")
}
return nil, types.ErrInvalidType
}

type testProtoConversionTypeProvider struct{}

func (f testProtoConversionTypeProvider) CreateType(itemType string, isEncode bool) (any, error) {
return f.CreateContractType("", itemType, isEncode)
}

var _ types.ContractTypeProvider = (*testProtoConversionTypeProvider)(nil)

func (testProtoConversionTypeProvider) CreateContractType(_, itemType string, isEncode bool) (any, error) {
switch itemType {
case ProtoTest:
return &map[string]any{}, nil
case ProtoTestIntComparator:
if isEncode {
var typ int
return &typ, nil
}
return 0, errors.New("comparator types should only be encoded")
case ProtoTestStringComparator:
if isEncode {
var typ string
return &typ, nil
}
return 0, errors.New("comparator types should only be encoded")
}
return nil, types.ErrInvalidType
}

func generateQueryFilterTestCases(t *testing.T) []query.KeyFilter {
var queryFilters []query.KeyFilter
confirmationsValues := []primitives.ConfidenceLevel{primitives.Finalized, primitives.Unconfirmed}
operatorValues := []primitives.ComparisonOperator{primitives.Eq, primitives.Neq, primitives.Gt, primitives.Lt, primitives.Gte, primitives.Lte}
comparableValues := []string{"", " ", "number", "123"}

primitiveExpressions := []query.Expression{query.TxHash("txHash")}
for _, op := range operatorValues {
primitiveExpressions = append(primitiveExpressions, query.Block("123", op))
primitiveExpressions = append(primitiveExpressions, query.Timestamp(123, op))

var valueComparators []primitives.ValueComparator
for _, comparableValue := range comparableValues {
valueComparators = append(valueComparators, primitives.ValueComparator{
Value: comparableValue,
Operator: op,
})
}
primitiveExpressions = append(primitiveExpressions, query.Comparator("someName", valueComparators...))
a, b, c, d := 1, 2, "123", "321"
valueComparatorsInt := []primitives.ValueComparator{{Value: &a, Operator: op}, {Value: &b, Operator: op}}
valueComparatorsString := []primitives.ValueComparator{{Value: &c, Operator: op}, {Value: &d, Operator: op}}
primitiveExpressions = append(primitiveExpressions, query.Comparator("IntComparator", valueComparatorsInt...))
primitiveExpressions = append(primitiveExpressions, query.Comparator("StringComparator", valueComparatorsString...))
}

for _, conf := range confirmationsValues {
primitiveExpressions = append(primitiveExpressions, query.Confidence(conf))
}

qf, err := query.Where("primitives", primitiveExpressions...)
qf, err := query.Where(ProtoTest, primitiveExpressions...)
require.NoError(t, err)
queryFilters = append(queryFilters, qf)

Expand All @@ -138,15 +173,18 @@ func generateQueryFilterTestCases(t *testing.T) []query.KeyFilter {
)
require.NoError(t, err)

qf, err = query.Where("andOverPrimitivesBoolExpr", andOverPrimitivesBoolExpr)
// andOverPrimitivesBoolExpr
qf, err = query.Where(ProtoTest, andOverPrimitivesBoolExpr)
require.NoError(t, err)
queryFilters = append(queryFilters, qf)

qf, err = query.Where("orOverPrimitivesBoolExpr", orOverPrimitivesBoolExpr)
// orOverPrimitivesBoolExpr
qf, err = query.Where(ProtoTest, orOverPrimitivesBoolExpr)
require.NoError(t, err)
queryFilters = append(queryFilters, qf)

qf, err = query.Where("nestedBoolExpr", nestedBoolExpr)
// nestedBoolExpr
qf, err = query.Where(ProtoTest, nestedBoolExpr)
require.NoError(t, err)
queryFilters = append(queryFilters, qf)

Expand Down
64 changes: 64 additions & 0 deletions pkg/types/interfacetests/chain_reader_interface_tests.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ const (
MethodReturningUint64Slice = "GetSliceValue"
MethodReturningSeenStruct = "GetSeenStruct"
EventName = "SomeEvent"
EventNameField = EventName + ".Field"
EventNameNestedField = EventName + ".NestedStruct.FixedBytes"
ProtoTest = "ProtoTest"
ProtoTestIntComparator = ProtoTest + ".IntComparator"
ProtoTestStringComparator = ProtoTest + ".StringComparator"
EventWithFilterName = "SomeEventToFilter"
AnyContractName = "TestContract"
AnySecondContractName = "Not" + AnyContractName
Expand Down Expand Up @@ -543,6 +548,65 @@ func runQueryKeyInterfaceTests[T TestingT[T]](t T, tester ChainReaderInterfaceTe
}, tester.MaxWaitTimeForEvents(), time.Millisecond*10)
},
},
{
name: "QueryKey can filter data with value comparator",
test: func(t T) {
ctx := tests.Context(t)
cr := tester.GetChainReader(t)
require.NoError(t, cr.Bind(ctx, tester.GetBindings(t)))
ts1 := CreateTestStruct[T](0, tester)
tester.TriggerEvent(t, &ts1)
ts2 := CreateTestStruct[T](15, tester)
tester.TriggerEvent(t, &ts2)
ts3 := CreateTestStruct[T](35, tester)
tester.TriggerEvent(t, &ts3)

ts := &TestStruct{}
assert.Eventually(t, func() bool {
// sequences from queryKey without limit and sort should be in descending order
sequences, err := cr.QueryKey(ctx, AnyContractName, query.KeyFilter{Key: EventName, Expressions: []query.Expression{
query.Comparator("Field",
primitives.ValueComparator{
Value: int32(15),
Operator: primitives.Gte,
},
primitives.ValueComparator{
Value: int32(35),
Operator: primitives.Lte,
}),
},
}, query.LimitAndSort{}, ts)
return err == nil && len(sequences) == 2 && reflect.DeepEqual(&ts2, sequences[1].Data) && reflect.DeepEqual(&ts3, sequences[0].Data)
}, tester.MaxWaitTimeForEvents(), time.Millisecond*10)
},
},
{
name: "QueryKey can filter on nested non dynamic data with value comparator",
test: func(t T) {
ctx := tests.Context(t)
cr := tester.GetChainReader(t)
require.NoError(t, cr.Bind(ctx, tester.GetBindings(t)))
ts1 := CreateTestStruct[T](0, tester)
tester.TriggerEvent(t, &ts1)
ts2 := CreateTestStruct[T](15, tester)
tester.TriggerEvent(t, &ts2)
ts3 := CreateTestStruct[T](35, tester)
tester.TriggerEvent(t, &ts3)

ts := &TestStruct{}
assert.Eventually(t, func() bool {
sequences, err := cr.QueryKey(ctx, AnyContractName, query.KeyFilter{Key: EventName, Expressions: []query.Expression{
query.Comparator("NestedStruct.FixedBytes",
primitives.ValueComparator{
Value: [2]byte{15, 16},
Operator: primitives.Eq,
}),
},
}, query.LimitAndSort{}, ts)
return err == nil && len(sequences) == 1 && reflect.DeepEqual(&ts2, sequences[0].Data)
}, tester.MaxWaitTimeForEvents(), time.Millisecond*10)
},
},
}

runTests(t, tester, tests)
Expand Down
34 changes: 26 additions & 8 deletions pkg/types/interfacetests/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,13 @@ type MidLevelTestStruct struct {

type TestStruct struct {
Field *int32
NestedStruct MidLevelTestStruct
DifferentField string
OracleID commontypes.OracleID
OracleIDs [32]commontypes.OracleID
Account []byte
Accounts [][]byte
BigField *big.Int
NestedStruct MidLevelTestStruct
}

type TestStructWithExtraField struct {
Expand Down Expand Up @@ -115,19 +115,37 @@ func CreateTestStruct[T any](i int, tester BasicTester[T]) TestStruct {
s := fmt.Sprintf("field%v", i)
fv := int32(i)
return TestStruct{
Field: &fv,
DifferentField: s,
OracleID: commontypes.OracleID(i + 1),
OracleIDs: [32]commontypes.OracleID{commontypes.OracleID(i + 2), commontypes.OracleID(i + 3)},
Account: tester.GetAccountBytes(i + 3),
Accounts: [][]byte{tester.GetAccountBytes(i + 4), tester.GetAccountBytes(i + 5)},
BigField: big.NewInt(int64((i + 1) * (i + 2))),
Field: &fv,
NestedStruct: MidLevelTestStruct{
FixedBytes: [2]byte{uint8(i), uint8(i + 1)},
Inner: InnerTestStruct{
I: i,
S: s,
},
},
DifferentField: s,
OracleID: commontypes.OracleID(i + 1),
OracleIDs: [32]commontypes.OracleID{commontypes.OracleID(i + 2), commontypes.OracleID(i + 3)},
Account: tester.GetAccountBytes(i + 3),
Accounts: [][]byte{tester.GetAccountBytes(i + 4), tester.GetAccountBytes(i + 5)},
BigField: big.NewInt(int64((i + 1) * (i + 2))),
}
}

func Compare[T int32](a, b T, op primitives.ComparisonOperator) bool {
switch op {
case primitives.Eq:
return a == b
case primitives.Neq:
return a != b
case primitives.Gt:
return a > b
case primitives.Lt:
return a < b
case primitives.Gte:
return a >= b
case primitives.Lte:
return a <= b
}
return false
}
19 changes: 19 additions & 0 deletions pkg/types/query/primitives/primitives.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,25 @@ const (
Lte
)

func (cmpOp ComparisonOperator) String() string {
switch cmpOp {
case Eq:
return "=="
case Neq:
return "!="
case Gt:
return ">"
case Lt:
return "<"
case Gte:
return ">="
case Lte:
return "<="
default:
return "Unknown"
}
}

type ValueComparator struct {
Value any
Operator ComparisonOperator
Expand Down

0 comments on commit f3e40e7

Please sign in to comment.