From f3e40e7dcc624339009f8dbef78c49ccef70065f Mon Sep 17 00:00:00 2001 From: ilija Date: Wed, 14 Aug 2024 14:40:34 +0200 Subject: [PATCH] Add interface tests for QueryKey value comparators --- .../chainreader/chain_reader_test.go | 32 ++++++++-- .../pluginprovider/chainreader/helper_test.go | 64 +++++++++++++++---- .../chain_reader_interface_tests.go | 64 +++++++++++++++++++ pkg/types/interfacetests/utils.go | 34 +++++++--- pkg/types/query/primitives/primitives.go | 19 ++++++ 5 files changed, 187 insertions(+), 26 deletions(-) diff --git a/pkg/loop/internal/relayer/pluginprovider/chainreader/chain_reader_test.go b/pkg/loop/internal/relayer/pluginprovider/chainreader/chain_reader_test.go index af93c6889..fb4c17b72 100644 --- a/pkg/loop/internal/relayer/pluginprovider/chainreader/chain_reader_test.go +++ b/pkg/loop/internal/relayer/pluginprovider/chainreader/chain_reader_test.go @@ -340,11 +340,11 @@ type eventConfidencePair struct { type fakeChainReader struct { fakeTypeProvider - vals []valConfidencePair - triggers []eventConfidencePair - stored []TestStruct + vals []valConfidencePair + triggers []eventConfidencePair + stored []TestStruct batchStored BatchCallEntry - lock sync.Mutex + lock sync.Mutex } func (f *fakeChainReader) Start(_ context.Context) error { return nil } @@ -512,7 +512,28 @@ func (f *fakeChainReader) QueryKey(_ context.Context, _ string, filter query.Key var sequences []types.Sequence for _, trigger := range f.triggers { - sequences = append(sequences, types.Sequence{Data: trigger.testStruct}) + doAppend := true + for _, expr := range filter.Expressions { + if primitive, ok := expr.Primitive.(*primitives.Comparator); ok { + if len(primitive.ValueComparators) == 0 { + return nil, fmt.Errorf("value comparator for %s should not be empty", primitive.Name) + } + if primitive.Name == "Field" { + for _, valComp := range primitive.ValueComparators { + doAppend = doAppend && Compare(*trigger.testStruct.Field, *valComp.Value.(*int32), valComp.Operator) + } + } else if primitive.Name == "NestedStruct.FixedBytes" { + // in practice, we won't throw error if there are multiple value comparators for un-comparable types, but such query wouldn't ever return results + if len(primitive.ValueComparators) > 1 || primitive.ValueComparators[0].Operator != primitives.Eq { + return nil, fmt.Errorf("value comparator for FixedBytes should only be filtered by equality and does not support %s operator", primitive.ValueComparators[0].Operator) + } + doAppend = reflect.DeepEqual(*primitive.ValueComparators[0].Value.(*[2]byte), trigger.testStruct.NestedStruct.FixedBytes) + } + } + } + if len(filter.Expressions) == 0 || doAppend { + sequences = append(sequences, types.Sequence{Data: trigger.testStruct}) + } } if !limitAndSort.HasSequenceSort() { @@ -579,6 +600,7 @@ func (e *errChainReader) QueryKey(_ context.Context, _ string, _ query.KeyFilter } type protoConversionTestChainReader struct { + testProtoConversionTypeProvider expectedBindings types.BoundContract expectedQueryFilter query.KeyFilter expectedLimitAndSort query.LimitAndSort diff --git a/pkg/loop/internal/relayer/pluginprovider/chainreader/helper_test.go b/pkg/loop/internal/relayer/pluginprovider/chainreader/helper_test.go index a38b87eb0..0d2ab71df 100644 --- a/pkg/loop/internal/relayer/pluginprovider/chainreader/helper_test.go +++ b/pkg/loop/internal/relayer/pluginprovider/chainreader/helper_test.go @@ -93,8 +93,47 @@ func (fakeTypeProvider) CreateContractType(_, itemType string, isEncode bool) (a return &FilterEventParams{}, nil } return &TestStruct{}, nil + case EventNameField: + if isEncode { + var typ int32 + return &typ, nil + } + return 0, errors.New("comparator types should only be encoded") + case EventNameNestedField: + if isEncode { + var typ [2]byte + return &typ, nil + } + return 0, errors.New("comparator types should only be encoded") } + return nil, types.ErrInvalidType +} + +type testProtoConversionTypeProvider struct{} + +func (f testProtoConversionTypeProvider) CreateType(itemType string, isEncode bool) (any, error) { + return f.CreateContractType("", itemType, isEncode) +} + +var _ types.ContractTypeProvider = (*testProtoConversionTypeProvider)(nil) +func (testProtoConversionTypeProvider) CreateContractType(_, itemType string, isEncode bool) (any, error) { + switch itemType { + case ProtoTest: + return &map[string]any{}, nil + case ProtoTestIntComparator: + if isEncode { + var typ int + return &typ, nil + } + return 0, errors.New("comparator types should only be encoded") + case ProtoTestStringComparator: + if isEncode { + var typ string + return &typ, nil + } + return 0, errors.New("comparator types should only be encoded") + } return nil, types.ErrInvalidType } @@ -102,28 +141,24 @@ func generateQueryFilterTestCases(t *testing.T) []query.KeyFilter { var queryFilters []query.KeyFilter confirmationsValues := []primitives.ConfidenceLevel{primitives.Finalized, primitives.Unconfirmed} operatorValues := []primitives.ComparisonOperator{primitives.Eq, primitives.Neq, primitives.Gt, primitives.Lt, primitives.Gte, primitives.Lte} - comparableValues := []string{"", " ", "number", "123"} primitiveExpressions := []query.Expression{query.TxHash("txHash")} for _, op := range operatorValues { primitiveExpressions = append(primitiveExpressions, query.Block("123", op)) primitiveExpressions = append(primitiveExpressions, query.Timestamp(123, op)) - var valueComparators []primitives.ValueComparator - for _, comparableValue := range comparableValues { - valueComparators = append(valueComparators, primitives.ValueComparator{ - Value: comparableValue, - Operator: op, - }) - } - primitiveExpressions = append(primitiveExpressions, query.Comparator("someName", valueComparators...)) + a, b, c, d := 1, 2, "123", "321" + valueComparatorsInt := []primitives.ValueComparator{{Value: &a, Operator: op}, {Value: &b, Operator: op}} + valueComparatorsString := []primitives.ValueComparator{{Value: &c, Operator: op}, {Value: &d, Operator: op}} + primitiveExpressions = append(primitiveExpressions, query.Comparator("IntComparator", valueComparatorsInt...)) + primitiveExpressions = append(primitiveExpressions, query.Comparator("StringComparator", valueComparatorsString...)) } for _, conf := range confirmationsValues { primitiveExpressions = append(primitiveExpressions, query.Confidence(conf)) } - qf, err := query.Where("primitives", primitiveExpressions...) + qf, err := query.Where(ProtoTest, primitiveExpressions...) require.NoError(t, err) queryFilters = append(queryFilters, qf) @@ -138,15 +173,18 @@ func generateQueryFilterTestCases(t *testing.T) []query.KeyFilter { ) require.NoError(t, err) - qf, err = query.Where("andOverPrimitivesBoolExpr", andOverPrimitivesBoolExpr) + // andOverPrimitivesBoolExpr + qf, err = query.Where(ProtoTest, andOverPrimitivesBoolExpr) require.NoError(t, err) queryFilters = append(queryFilters, qf) - qf, err = query.Where("orOverPrimitivesBoolExpr", orOverPrimitivesBoolExpr) + // orOverPrimitivesBoolExpr + qf, err = query.Where(ProtoTest, orOverPrimitivesBoolExpr) require.NoError(t, err) queryFilters = append(queryFilters, qf) - qf, err = query.Where("nestedBoolExpr", nestedBoolExpr) + // nestedBoolExpr + qf, err = query.Where(ProtoTest, nestedBoolExpr) require.NoError(t, err) queryFilters = append(queryFilters, qf) diff --git a/pkg/types/interfacetests/chain_reader_interface_tests.go b/pkg/types/interfacetests/chain_reader_interface_tests.go index 14cecf7db..c35fac0ec 100644 --- a/pkg/types/interfacetests/chain_reader_interface_tests.go +++ b/pkg/types/interfacetests/chain_reader_interface_tests.go @@ -43,6 +43,11 @@ const ( MethodReturningUint64Slice = "GetSliceValue" MethodReturningSeenStruct = "GetSeenStruct" EventName = "SomeEvent" + EventNameField = EventName + ".Field" + EventNameNestedField = EventName + ".NestedStruct.FixedBytes" + ProtoTest = "ProtoTest" + ProtoTestIntComparator = ProtoTest + ".IntComparator" + ProtoTestStringComparator = ProtoTest + ".StringComparator" EventWithFilterName = "SomeEventToFilter" AnyContractName = "TestContract" AnySecondContractName = "Not" + AnyContractName @@ -543,6 +548,65 @@ func runQueryKeyInterfaceTests[T TestingT[T]](t T, tester ChainReaderInterfaceTe }, tester.MaxWaitTimeForEvents(), time.Millisecond*10) }, }, + { + name: "QueryKey can filter data with value comparator", + test: func(t T) { + ctx := tests.Context(t) + cr := tester.GetChainReader(t) + require.NoError(t, cr.Bind(ctx, tester.GetBindings(t))) + ts1 := CreateTestStruct[T](0, tester) + tester.TriggerEvent(t, &ts1) + ts2 := CreateTestStruct[T](15, tester) + tester.TriggerEvent(t, &ts2) + ts3 := CreateTestStruct[T](35, tester) + tester.TriggerEvent(t, &ts3) + + ts := &TestStruct{} + assert.Eventually(t, func() bool { + // sequences from queryKey without limit and sort should be in descending order + sequences, err := cr.QueryKey(ctx, AnyContractName, query.KeyFilter{Key: EventName, Expressions: []query.Expression{ + query.Comparator("Field", + primitives.ValueComparator{ + Value: int32(15), + Operator: primitives.Gte, + }, + primitives.ValueComparator{ + Value: int32(35), + Operator: primitives.Lte, + }), + }, + }, query.LimitAndSort{}, ts) + return err == nil && len(sequences) == 2 && reflect.DeepEqual(&ts2, sequences[1].Data) && reflect.DeepEqual(&ts3, sequences[0].Data) + }, tester.MaxWaitTimeForEvents(), time.Millisecond*10) + }, + }, + { + name: "QueryKey can filter on nested non dynamic data with value comparator", + test: func(t T) { + ctx := tests.Context(t) + cr := tester.GetChainReader(t) + require.NoError(t, cr.Bind(ctx, tester.GetBindings(t))) + ts1 := CreateTestStruct[T](0, tester) + tester.TriggerEvent(t, &ts1) + ts2 := CreateTestStruct[T](15, tester) + tester.TriggerEvent(t, &ts2) + ts3 := CreateTestStruct[T](35, tester) + tester.TriggerEvent(t, &ts3) + + ts := &TestStruct{} + assert.Eventually(t, func() bool { + sequences, err := cr.QueryKey(ctx, AnyContractName, query.KeyFilter{Key: EventName, Expressions: []query.Expression{ + query.Comparator("NestedStruct.FixedBytes", + primitives.ValueComparator{ + Value: [2]byte{15, 16}, + Operator: primitives.Eq, + }), + }, + }, query.LimitAndSort{}, ts) + return err == nil && len(sequences) == 1 && reflect.DeepEqual(&ts2, sequences[0].Data) + }, tester.MaxWaitTimeForEvents(), time.Millisecond*10) + }, + }, } runTests(t, tester, tests) diff --git a/pkg/types/interfacetests/utils.go b/pkg/types/interfacetests/utils.go index cf0983104..a45116ed9 100644 --- a/pkg/types/interfacetests/utils.go +++ b/pkg/types/interfacetests/utils.go @@ -59,13 +59,13 @@ type MidLevelTestStruct struct { type TestStruct struct { Field *int32 + NestedStruct MidLevelTestStruct DifferentField string OracleID commontypes.OracleID OracleIDs [32]commontypes.OracleID Account []byte Accounts [][]byte BigField *big.Int - NestedStruct MidLevelTestStruct } type TestStructWithExtraField struct { @@ -115,13 +115,7 @@ func CreateTestStruct[T any](i int, tester BasicTester[T]) TestStruct { s := fmt.Sprintf("field%v", i) fv := int32(i) return TestStruct{ - Field: &fv, - DifferentField: s, - OracleID: commontypes.OracleID(i + 1), - OracleIDs: [32]commontypes.OracleID{commontypes.OracleID(i + 2), commontypes.OracleID(i + 3)}, - Account: tester.GetAccountBytes(i + 3), - Accounts: [][]byte{tester.GetAccountBytes(i + 4), tester.GetAccountBytes(i + 5)}, - BigField: big.NewInt(int64((i + 1) * (i + 2))), + Field: &fv, NestedStruct: MidLevelTestStruct{ FixedBytes: [2]byte{uint8(i), uint8(i + 1)}, Inner: InnerTestStruct{ @@ -129,5 +123,29 @@ func CreateTestStruct[T any](i int, tester BasicTester[T]) TestStruct { S: s, }, }, + DifferentField: s, + OracleID: commontypes.OracleID(i + 1), + OracleIDs: [32]commontypes.OracleID{commontypes.OracleID(i + 2), commontypes.OracleID(i + 3)}, + Account: tester.GetAccountBytes(i + 3), + Accounts: [][]byte{tester.GetAccountBytes(i + 4), tester.GetAccountBytes(i + 5)}, + BigField: big.NewInt(int64((i + 1) * (i + 2))), + } +} + +func Compare[T int32](a, b T, op primitives.ComparisonOperator) bool { + switch op { + case primitives.Eq: + return a == b + case primitives.Neq: + return a != b + case primitives.Gt: + return a > b + case primitives.Lt: + return a < b + case primitives.Gte: + return a >= b + case primitives.Lte: + return a <= b } + return false } diff --git a/pkg/types/query/primitives/primitives.go b/pkg/types/query/primitives/primitives.go index 2285237ae..b0fdf3b98 100644 --- a/pkg/types/query/primitives/primitives.go +++ b/pkg/types/query/primitives/primitives.go @@ -25,6 +25,25 @@ const ( Lte ) +func (cmpOp ComparisonOperator) String() string { + switch cmpOp { + case Eq: + return "==" + case Neq: + return "!=" + case Gt: + return ">" + case Lt: + return "<" + case Gte: + return ">=" + case Lte: + return "<=" + default: + return "Unknown" + } +} + type ValueComparator struct { Value any Operator ComparisonOperator