From 7d7da1852873e366c82cf629324a38881934a6ba Mon Sep 17 00:00:00 2001 From: johnabass Date: Wed, 10 Jul 2024 21:55:35 -0700 Subject: [PATCH 1/4] incorporate the generic source type into validation --- basculehttp/middleware.go | 12 +- validator.go | 224 +++++++++++++++++++++++++++++++++----- validator_test.go | 66 ----------- 3 files changed, 204 insertions(+), 98 deletions(-) diff --git a/basculehttp/middleware.go b/basculehttp/middleware.go index 0c9f4b5..a94c392 100644 --- a/basculehttp/middleware.go +++ b/basculehttp/middleware.go @@ -51,9 +51,9 @@ func WithTokenParser(scheme bascule.Scheme, tp bascule.TokenParser[*http.Request // WithAuthentication adds validators used for authentication to this Middleware. Each // invocation of this option is cumulative. Authentication validators are run in the order // supplied by this option. -func WithAuthentication(v ...bascule.Validator) MiddlewareOption { +func WithAuthentication(v ...bascule.Validator[*http.Request]) MiddlewareOption { return middlewareOptionFunc(func(m *Middleware) error { - m.authentication.Add(v...) + m.authentication = m.authentication.Append(v...) return nil }) } @@ -114,7 +114,7 @@ func WithErrorMarshaler(em ErrorMarshaler) MiddlewareOption { type Middleware struct { credentialsParser bascule.CredentialsParser[*http.Request] tokenParsers bascule.TokenParsers[*http.Request] - authentication bascule.Validators + authentication bascule.Validators[*http.Request] authorization bascule.Authorizers[*http.Request] challenges Challenges @@ -194,8 +194,8 @@ func (m *Middleware) getCredentialsAndToken(ctx context.Context, request *http.R return } -func (m *Middleware) authenticate(ctx context.Context, token bascule.Token) error { - return m.authentication.Validate(ctx, token) +func (m *Middleware) authenticate(ctx context.Context, request *http.Request, token bascule.Token) (bascule.Token, error) { + return m.authentication.Validate(ctx, request, token) } func (m *Middleware) authorize(ctx context.Context, token bascule.Token, request *http.Request) error { @@ -220,7 +220,7 @@ func (fd *frontDoor) ServeHTTP(response http.ResponseWriter, request *http.Reque } ctx = bascule.WithCredentials(ctx, creds) - err = fd.middleware.authenticate(ctx, token) + token, err = fd.middleware.authenticate(ctx, request, token) if err != nil { // at this point in the workflow, the request has valid credentials. we use // StatusForbidden as the default because any failure to authenticate isn't a diff --git a/validator.go b/validator.go index 487f16b..5aaa32b 100644 --- a/validator.go +++ b/validator.go @@ -5,50 +5,222 @@ package bascule import ( "context" + "reflect" ) // Validator represents a general strategy for validating tokens. Token validation -// typically happens during authentication, but it can also happen during parsing -// if a caller uses NewValidatingTokenParser. -type Validator interface { +// typically happens during authentication. +type Validator[S any] interface { // Validate validates a token. If this validator needs to interact // with external systems, the supplied context can be passed to honor - // cancelation semantics. + // cancelation semantics. Additionally, the source object from which the + // token was taken is made available. // // This method may be passed a token that it doesn't support, e.g. a Basic // validator can be passed a JWT token. In that case, this method should - // simply return nil. - Validate(context.Context, Token) error + // simply return a nil error. + // + // If this method returns a nil token, then the supplied token should be used + // as is. If this method returns a non-nil token, that new new token should be + // used instead. This allows a validator to augment a token with additional + // data, possibly from an external system or database. + Validate(ctx context.Context, source S, t Token) (Token, error) } -// ValidatorFunc is a closure type that implements Validator. -type ValidatorFunc func(context.Context, Token) error +// Validate applies several validators to the given token. Although each individual +// validator may return a nil Token to indicate that there is no change in the token, +// this function will always return a non-nil Token. +// +// This function returns the validated Token and a nil error to indicate success. +// If any validator fails, this function halts further validation and returns +// the error. +func Validate[S any](ctx context.Context, source S, original Token, v ...Validator[S]) (validated Token, err error) { + next := original + for i, prev := 0, next; err == nil && i < len(v); i, prev = i+1, next { + next, err = v[i].Validate(ctx, source, prev) + if next == nil { + // no change in the token + next = prev + } + } + + if err == nil { + validated = next + } -func (vf ValidatorFunc) Validate(ctx context.Context, token Token) error { - return vf(ctx, token) + return } -// Validators is an aggregate Validator. -type Validators []Validator +// Validators is an aggregate Validator that returns validity if and only if +// all of its contained validators return validity. +type Validators[S any] []Validator[S] + +// Append tacks on more validators to this aggregate, returning the possibly new +// instance. The semantics of this method are the same as the built-in append. +func (vs Validators[S]) Append(more ...Validator[S]) Validators[S] { + return append(vs, more...) +} + +// Validate executes each contained validator in order, returning validity only +// if all validators pass. Any validation failure prevents subsequent validators +// from running. +func (vs Validators[S]) Validate(ctx context.Context, source S, t Token) (Token, error) { + return Validate(ctx, source, t, vs...) +} + +// ValidatorFunc defines the closure signatures that are allowed as Validator instances. +type ValidatorFunc[S any] interface { + ~func(Token) error | + ~func(S, Token) error | + ~func(Token) (Token, error) | + ~func(S, Token) (Token, error) | + ~func(context.Context, Token) error | + ~func(context.Context, S, Token) error | + ~func(context.Context, Token) (Token, error) | + ~func(context.Context, S, Token) (Token, error) +} + +// validatorFunc is an internal type that implements Validator. Used to normalize +// and uncurry a closure. +type validatorFunc[S any] func(context.Context, S, Token) (Token, error) + +func (vf validatorFunc[S]) Validate(ctx context.Context, source S, t Token) (Token, error) { + return vf(ctx, source, t) +} + +var ( + tokenReturnsError = reflect.TypeOf((func(Token) error)(nil)) + tokenReturnsTokenError = reflect.TypeOf((func(Token) (Token, error))(nil)) + contextTokenReturnsError = reflect.TypeOf((func(context.Context, Token) error)(nil)) + contextTokenReturnsTokenError = reflect.TypeOf((func(context.Context, Token) (Token, error))(nil)) +) + +// asValidatorSimple tries simple conversions on f. This function will not catch +// user-defined types. +func asValidatorSimple[S any, F ValidatorFunc[S]](f F) (v Validator[S]) { + switch vf := any(f).(type) { + case func(Token) error: + v = validatorFunc[S]( + func(ctx context.Context, source S, t Token) (Token, error) { + return nil, vf(t) + }, + ) -// Add appends validators to this aggregate Validators. -func (vs *Validators) Add(v ...Validator) { - if *vs == nil { - *vs = make(Validators, 0, len(v)) + case func(S, Token) error: + v = validatorFunc[S]( + func(ctx context.Context, source S, t Token) (Token, error) { + return nil, vf(source, t) + }, + ) + + case func(Token) (Token, error): + v = validatorFunc[S]( + func(ctx context.Context, source S, t Token) (next Token, err error) { + next, err = vf(t) + if next == nil { + next = t + } + + return + }, + ) + + case func(S, Token) (Token, error): + v = validatorFunc[S]( + func(ctx context.Context, source S, t Token) (next Token, err error) { + next, err = vf(source, t) + if next == nil { + next = t + } + + return + }, + ) + + case func(context.Context, Token) error: + v = validatorFunc[S]( + func(ctx context.Context, source S, t Token) (Token, error) { + return nil, vf(ctx, t) + }, + ) + + case func(context.Context, S, Token) error: + v = validatorFunc[S]( + func(ctx context.Context, source S, t Token) (Token, error) { + return nil, vf(ctx, source, t) + }, + ) + + case func(context.Context, Token) (Token, error): + v = validatorFunc[S]( + func(ctx context.Context, source S, t Token) (next Token, err error) { + next, err = vf(ctx, t) + if next == nil { + next = t + } + + return + }, + ) + + case func(context.Context, S, Token) (Token, error): + v = validatorFunc[S](vf) } - *vs = append(*vs, v...) + return } -// Validate applies each validator in sequence. Execution stops at the first validator -// that returns an error, and that error is returned. If all validators return nil, -// this method returns nil, indicating the Token is valid. -func (vs Validators) Validate(ctx context.Context, token Token) error { - for _, v := range vs { - if err := v.Validate(ctx, token); err != nil { - return err - } +// AsValidator takes a ValidatorFunc closure and returns a Validator instance that +// executes that closure. +func AsValidator[S any, F ValidatorFunc[S]](f F) Validator[S] { + // first, try the simple way: + if v := asValidatorSimple[S](f); v != nil { + return v + } + + // next, support user-defined types that are closures that do not + // require the source type. + fVal := reflect.ValueOf(f) + switch { + case fVal.CanConvert(tokenReturnsError): + return asValidatorSimple[S]( + fVal.Convert(tokenReturnsError).Interface().(func(Token) error), + ) + + case fVal.CanConvert(tokenReturnsTokenError): + return asValidatorSimple[S]( + fVal.Convert(tokenReturnsError).Interface().(func(Token) (Token, error)), + ) + + case fVal.CanConvert(contextTokenReturnsError): + return asValidatorSimple[S]( + fVal.Convert(contextTokenReturnsError).Interface().(func(context.Context, Token) error), + ) + + case fVal.CanConvert(contextTokenReturnsTokenError): + return asValidatorSimple[S]( + fVal.Convert(contextTokenReturnsError).Interface().(func(context.Context, Token) (Token, error)), + ) } - return nil + // finally: user-defined types that are closures involving the source type S. + // we have to look these up here, due to the way generics in golang work. + if ft := reflect.TypeOf((func(S, Token) error)(nil)); fVal.CanConvert(ft) { + return asValidatorSimple[S]( + fVal.Convert(ft).Interface().(func(S, Token) error), + ) + } else if ft := reflect.TypeOf((func(S, Token) (Token, error))(nil)); fVal.CanConvert(ft) { + return asValidatorSimple[S]( + fVal.Convert(ft).Interface().(func(S, Token) (Token, error)), + ) + } else if ft := reflect.TypeOf((func(context.Context, S, Token) error)(nil)); fVal.CanConvert(ft) { + return asValidatorSimple[S]( + fVal.Convert(ft).Interface().(func(context.Context, S, Token) error), + ) + } else { + // we know this can be converted to this final type + return asValidatorSimple[S]( + fVal.Convert(ft).Interface().(func(context.Context, S, Token) (Token, error)), + ) + } } diff --git a/validator_test.go b/validator_test.go index 4ea2601..94fd939 100644 --- a/validator_test.go +++ b/validator_test.go @@ -4,8 +4,6 @@ package bascule import ( - "context" - "errors" "testing" "github.com/stretchr/testify/suite" @@ -15,70 +13,6 @@ type ValidatorsTestSuite struct { TestSuite } -func (suite *ValidatorsTestSuite) TestValidate() { - validateErr := errors.New("expected Validate error") - - testCases := []struct { - name string - results []error - expectedErr error - }{ - { - name: "EmptyValidators", - results: nil, - }, - { - name: "OneSuccess", - results: []error{nil}, - }, - { - name: "OneFailure", - results: []error{validateErr}, - expectedErr: validateErr, - }, - { - name: "FirstFailure", - results: []error{validateErr, errors.New("should not be called")}, - expectedErr: validateErr, - }, - { - name: "MiddleFailure", - results: []error{nil, validateErr, errors.New("should not be called")}, - expectedErr: validateErr, - }, - { - name: "AllSuccess", - results: []error{nil, nil, nil}, - }, - } - - for _, testCase := range testCases { - suite.Run(testCase.name, func() { - var ( - testCtx = suite.testContext() - testToken = suite.testToken() - vs Validators - ) - - for _, err := range testCase.results { - err := err - vs.Add( - ValidatorFunc(func(ctx context.Context, token Token) error { - suite.Same(testCtx, ctx) - suite.Same(testToken, token) - return err - }), - ) - } - - suite.Equal( - testCase.expectedErr, - vs.Validate(testCtx, testToken), - ) - }) - } -} - func TestValidators(t *testing.T) { suite.Run(t, new(ValidatorsTestSuite)) } From 539a004f65e23308fdb2b1e6a2e5dfbe44f8e746 Mon Sep 17 00:00:00 2001 From: johnabass Date: Thu, 11 Jul 2024 12:50:06 -0700 Subject: [PATCH 2/4] refactored the Authorizer API to be consistent with Validator --- authorizer.go | 28 +++++++++++++--------------- authorizer_test.go | 16 ++++++++-------- basculehttp/middleware.go | 4 ++-- 3 files changed, 23 insertions(+), 25 deletions(-) diff --git a/authorizer.go b/authorizer.go index 9818416..79e03a2 100644 --- a/authorizer.go +++ b/authorizer.go @@ -18,26 +18,24 @@ type Authorizer[R any] interface { // cancelation semantics. // // If this method doesn't support the given token, it should return nil. - Authorize(ctx context.Context, token Token, resource R) error + Authorize(ctx context.Context, resource R, token Token) error } // AuthorizerFunc is a closure type that implements Authorizer. -type AuthorizerFunc[R any] func(context.Context, Token, R) error +type AuthorizerFunc[R any] func(context.Context, R, Token) error -func (af AuthorizerFunc[R]) Authorize(ctx context.Context, token Token, resource R) error { - return af(ctx, token, resource) +func (af AuthorizerFunc[R]) Authorize(ctx context.Context, resource R, token Token) error { + return af(ctx, resource, token) } // Authorizers is a collection of Authorizers. type Authorizers[R any] []Authorizer[R] -// Add appends authorizers to this aggregate Authorizers. -func (as *Authorizers[R]) Add(a ...Authorizer[R]) { - if *as == nil { - *as = make(Authorizers[R], 0, len(a)) - } - - *as = append(*as, a...) +// Append tacks on one or more authorizers to this collection. The possibly +// new Authorizers instance is returned. The semantics of this method are +// the same as the built-in append. +func (as Authorizers[R]) Append(a ...Authorizer[R]) Authorizers[R] { + return append(as, a...) } // Authorize requires all authorizers in this sequence to allow access. This @@ -45,9 +43,9 @@ func (as *Authorizers[R]) Add(a ...Authorizer[R]) { // // Because authorization can be arbitrarily expensive, execution halts at the first failed // authorization attempt. -func (as Authorizers[R]) Authorize(ctx context.Context, token Token, resource R) error { +func (as Authorizers[R]) Authorize(ctx context.Context, resource R, token Token) error { for _, a := range as { - if err := a.Authorize(ctx, token, resource); err != nil { + if err := a.Authorize(ctx, resource, token); err != nil { return err } } @@ -61,10 +59,10 @@ type requireAny[R any] struct { // Authorize returns nil at the first authorizer that returns nil, i.e. accepts the access. // Otherwise, this method returns an aggregate error of all the authorization errors. -func (ra requireAny[R]) Authorize(ctx context.Context, token Token, resource R) error { +func (ra requireAny[R]) Authorize(ctx context.Context, resource R, token Token) error { var err error for _, a := range ra.a { - authErr := a.Authorize(ctx, token, resource) + authErr := a.Authorize(ctx, resource, token) if authErr == nil { return nil } diff --git a/authorizer_test.go b/authorizer_test.go index a77ba48..3f5c564 100644 --- a/authorizer_test.go +++ b/authorizer_test.go @@ -63,8 +63,8 @@ func (suite *AuthorizersTestSuite) TestAuthorize() { for _, err := range testCase.results { err := err - as.Add( - AuthorizerFunc[string](func(ctx context.Context, token Token, resource string) error { + as = as.Append( + AuthorizerFunc[string](func(ctx context.Context, resource string, token Token) error { suite.Same(testCtx, ctx) suite.Same(testToken, token) suite.Equal(placeholderResource, resource) @@ -75,7 +75,7 @@ func (suite *AuthorizersTestSuite) TestAuthorize() { suite.Equal( testCase.expectedErr, - as.Authorize(testCtx, testToken, placeholderResource), + as.Authorize(testCtx, placeholderResource, testToken), ) }) } @@ -123,8 +123,8 @@ func (suite *AuthorizersTestSuite) TestAny() { for _, err := range testCase.results { err := err - as.Add( - AuthorizerFunc[string](func(ctx context.Context, token Token, resource string) error { + as = as.Append( + AuthorizerFunc[string](func(ctx context.Context, resource string, token Token) error { suite.Same(testCtx, ctx) suite.Same(testToken, token) suite.Equal(placeholderResource, resource) @@ -136,13 +136,13 @@ func (suite *AuthorizersTestSuite) TestAny() { anyAs := as.Any() suite.Equal( testCase.expectedErr, - anyAs.Authorize(testCtx, testToken, placeholderResource), + anyAs.Authorize(testCtx, placeholderResource, testToken), ) if len(as) > 0 { // the any instance should be distinct as[0] = AuthorizerFunc[string]( - func(context.Context, Token, string) error { + func(context.Context, string, Token) error { suite.Fail("should not have been called") return nil }, @@ -150,7 +150,7 @@ func (suite *AuthorizersTestSuite) TestAny() { suite.Equal( testCase.expectedErr, - anyAs.Authorize(testCtx, testToken, placeholderResource), + anyAs.Authorize(testCtx, placeholderResource, testToken), ) } }) diff --git a/basculehttp/middleware.go b/basculehttp/middleware.go index a94c392..9868f6e 100644 --- a/basculehttp/middleware.go +++ b/basculehttp/middleware.go @@ -77,7 +77,7 @@ func WithChallenges(ch ...Challenge) MiddlewareOption { // This is useful for use cases like admin access or alternate capabilities. func WithAuthorization(a ...bascule.Authorizer[*http.Request]) MiddlewareOption { return middlewareOptionFunc(func(m *Middleware) error { - m.authorization.Add(a...) + m.authorization = m.authorization.Append(a...) return nil }) } @@ -199,7 +199,7 @@ func (m *Middleware) authenticate(ctx context.Context, request *http.Request, to } func (m *Middleware) authorize(ctx context.Context, token bascule.Token, request *http.Request) error { - return m.authorization.Authorize(ctx, token, request) + return m.authorization.Authorize(ctx, request, token) } // frontDoor is the internal handler implementation that protects a handler From 7b342c473eacec34109a70c945d3fdc5a0f527c0 Mon Sep 17 00:00:00 2001 From: johnabass Date: Mon, 29 Jul 2024 11:46:48 -0700 Subject: [PATCH 3/4] refactored validator code to allow access to the source during validation --- authorizer_test.go | 4 +- mocks_test.go | 36 +++++ testSuite_test.go | 12 +- validator.go | 45 +++--- validator_test.go | 354 +++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 419 insertions(+), 32 deletions(-) create mode 100644 mocks_test.go diff --git a/authorizer_test.go b/authorizer_test.go index 3f5c564..a878a55 100644 --- a/authorizer_test.go +++ b/authorizer_test.go @@ -66,7 +66,7 @@ func (suite *AuthorizersTestSuite) TestAuthorize() { as = as.Append( AuthorizerFunc[string](func(ctx context.Context, resource string, token Token) error { suite.Same(testCtx, ctx) - suite.Same(testToken, token) + suite.Equal(testToken, token) suite.Equal(placeholderResource, resource) return err }), @@ -126,7 +126,7 @@ func (suite *AuthorizersTestSuite) TestAny() { as = as.Append( AuthorizerFunc[string](func(ctx context.Context, resource string, token Token) error { suite.Same(testCtx, ctx) - suite.Same(testToken, token) + suite.Equal(testToken, token) suite.Equal(placeholderResource, resource) return err }), diff --git a/mocks_test.go b/mocks_test.go new file mode 100644 index 0000000..00d0e3b --- /dev/null +++ b/mocks_test.go @@ -0,0 +1,36 @@ +// SPDX-FileCopyrightText: 2020 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package bascule + +import ( + "context" + + "github.com/stretchr/testify/mock" +) + +type testToken string + +func (tt testToken) Principal() string { return string(tt) } + +type mockValidator[S any] struct { + mock.Mock +} + +func (m *mockValidator[S]) Validate(ctx context.Context, source S, token Token) (Token, error) { + args := m.Called(ctx, source, token) + t, _ := args.Get(0).(Token) + return t, args.Error(1) +} + +func (m *mockValidator[S]) ExpectValidate(ctx context.Context, source S, token Token) *mock.Call { + return m.On("Validate", ctx, source, token) +} + +func assertValidators[S any](t mock.TestingT, vs ...Validator[S]) (passed bool) { + for _, v := range vs { + passed = v.(*mockValidator[S]).AssertExpectations(t) && passed + } + + return +} diff --git a/testSuite_test.go b/testSuite_test.go index 554ef02..987cf63 100644 --- a/testSuite_test.go +++ b/testSuite_test.go @@ -14,14 +14,6 @@ const ( testScheme Scheme = "Test" ) -type testToken struct { - principal string -} - -func (tt *testToken) Principal() string { - return tt.principal -} - // TestSuite holds generally useful functionality for testing bascule. type TestSuite struct { suite.Suite @@ -43,9 +35,7 @@ func (suite *TestSuite) testCredentials() Credentials { } func (suite *TestSuite) testToken() Token { - return &testToken{ - principal: "test", - } + return testToken("test") } func (suite *TestSuite) contexter(ctx context.Context) Contexter { diff --git a/validator.go b/validator.go index 5aaa32b..c0bea98 100644 --- a/validator.go +++ b/validator.go @@ -84,15 +84,20 @@ type ValidatorFunc[S any] interface { // and uncurry a closure. type validatorFunc[S any] func(context.Context, S, Token) (Token, error) -func (vf validatorFunc[S]) Validate(ctx context.Context, source S, t Token) (Token, error) { - return vf(ctx, source, t) +func (vf validatorFunc[S]) Validate(ctx context.Context, source S, t Token) (next Token, err error) { + next, err = vf(ctx, source, t) + if next == nil { + next = t + } + + return } var ( - tokenReturnsError = reflect.TypeOf((func(Token) error)(nil)) - tokenReturnsTokenError = reflect.TypeOf((func(Token) (Token, error))(nil)) - contextTokenReturnsError = reflect.TypeOf((func(context.Context, Token) error)(nil)) - contextTokenReturnsTokenError = reflect.TypeOf((func(context.Context, Token) (Token, error))(nil)) + tokenReturnError = reflect.TypeOf((func(Token) error)(nil)) + tokenReturnTokenAndError = reflect.TypeOf((func(Token) (Token, error))(nil)) + contextTokenReturnError = reflect.TypeOf((func(context.Context, Token) error)(nil)) + contextTokenReturnTokenError = reflect.TypeOf((func(context.Context, Token) (Token, error))(nil)) ) // asValidatorSimple tries simple conversions on f. This function will not catch @@ -102,14 +107,14 @@ func asValidatorSimple[S any, F ValidatorFunc[S]](f F) (v Validator[S]) { case func(Token) error: v = validatorFunc[S]( func(ctx context.Context, source S, t Token) (Token, error) { - return nil, vf(t) + return t, vf(t) }, ) case func(S, Token) error: v = validatorFunc[S]( func(ctx context.Context, source S, t Token) (Token, error) { - return nil, vf(source, t) + return t, vf(source, t) }, ) @@ -140,14 +145,14 @@ func asValidatorSimple[S any, F ValidatorFunc[S]](f F) (v Validator[S]) { case func(context.Context, Token) error: v = validatorFunc[S]( func(ctx context.Context, source S, t Token) (Token, error) { - return nil, vf(ctx, t) + return t, vf(ctx, t) }, ) case func(context.Context, S, Token) error: v = validatorFunc[S]( func(ctx context.Context, source S, t Token) (Token, error) { - return nil, vf(ctx, source, t) + return t, vf(ctx, source, t) }, ) @@ -171,7 +176,8 @@ func asValidatorSimple[S any, F ValidatorFunc[S]](f F) (v Validator[S]) { } // AsValidator takes a ValidatorFunc closure and returns a Validator instance that -// executes that closure. +// executes that closure. This function can also convert custom types which can +// be converted to any of the closure signatures. func AsValidator[S any, F ValidatorFunc[S]](f F) Validator[S] { // first, try the simple way: if v := asValidatorSimple[S](f); v != nil { @@ -182,24 +188,24 @@ func AsValidator[S any, F ValidatorFunc[S]](f F) Validator[S] { // require the source type. fVal := reflect.ValueOf(f) switch { - case fVal.CanConvert(tokenReturnsError): + case fVal.CanConvert(tokenReturnError): return asValidatorSimple[S]( - fVal.Convert(tokenReturnsError).Interface().(func(Token) error), + fVal.Convert(tokenReturnError).Interface().(func(Token) error), ) - case fVal.CanConvert(tokenReturnsTokenError): + case fVal.CanConvert(tokenReturnTokenAndError): return asValidatorSimple[S]( - fVal.Convert(tokenReturnsError).Interface().(func(Token) (Token, error)), + fVal.Convert(tokenReturnTokenAndError).Interface().(func(Token) (Token, error)), ) - case fVal.CanConvert(contextTokenReturnsError): + case fVal.CanConvert(contextTokenReturnError): return asValidatorSimple[S]( - fVal.Convert(contextTokenReturnsError).Interface().(func(context.Context, Token) error), + fVal.Convert(contextTokenReturnError).Interface().(func(context.Context, Token) error), ) - case fVal.CanConvert(contextTokenReturnsTokenError): + case fVal.CanConvert(contextTokenReturnTokenError): return asValidatorSimple[S]( - fVal.Convert(contextTokenReturnsError).Interface().(func(context.Context, Token) (Token, error)), + fVal.Convert(contextTokenReturnTokenError).Interface().(func(context.Context, Token) (Token, error)), ) } @@ -219,6 +225,7 @@ func AsValidator[S any, F ValidatorFunc[S]](f F) Validator[S] { ) } else { // we know this can be converted to this final type + ft := reflect.TypeOf((func(context.Context, S, Token) (Token, error))(nil)) return asValidatorSimple[S]( fVal.Convert(ft).Interface().(func(context.Context, S, Token) (Token, error)), ) diff --git a/validator_test.go b/validator_test.go index 94fd939..af45572 100644 --- a/validator_test.go +++ b/validator_test.go @@ -4,6 +4,9 @@ package bascule import ( + "context" + "errors" + "fmt" "testing" "github.com/stretchr/testify/suite" @@ -11,6 +14,357 @@ import ( type ValidatorsTestSuite struct { TestSuite + + expectedCtx context.Context + expectedSource int + inputToken Token + outputToken Token + expectedErr error +} + +func (suite *ValidatorsTestSuite) SetupSuite() { + type contextKey struct{} + suite.expectedCtx = context.WithValue( + context.Background(), + contextKey{}, + "value", + ) + + suite.expectedSource = 123 + suite.inputToken = testToken("input token") + suite.outputToken = testToken("output token") + suite.expectedErr = errors.New("expected validator error") +} + +// assertNoTransform verifies that the validator returns the same token as the input token. +func (suite *ValidatorsTestSuite) assertNoTransform(v Validator[int]) { + suite.Require().NotNil(v) + actualToken, actualErr := v.Validate(suite.expectedCtx, suite.expectedSource, suite.inputToken) + suite.Equal(suite.inputToken, actualToken) + suite.ErrorIs(suite.expectedErr, actualErr) +} + +// assertTransform verifies a validator that returns a different token than the input token. +func (suite *ValidatorsTestSuite) assertTransform(v Validator[int]) { + suite.Require().NotNil(v) + actualToken, actualErr := v.Validate(suite.expectedCtx, suite.expectedSource, suite.inputToken) + suite.Equal(suite.outputToken, actualToken) + suite.ErrorIs(suite.expectedErr, actualErr) +} + +// validateToken is a ValidatorFunc of the signature func(Token) error +func (suite *ValidatorsTestSuite) validateToken(actualToken Token) error { + suite.Equal(suite.inputToken, actualToken) + return suite.expectedErr +} + +// validateSourceToken is a ValidatorFunc of the signature func(, Token) error +func (suite *ValidatorsTestSuite) validateSourceToken(actualSource int, actualToken Token) error { + suite.Equal(suite.expectedSource, actualSource) + suite.Equal(suite.inputToken, actualToken) + return suite.expectedErr +} + +// validateContextToken is a ValidatorFunc of the signature func(context.Context, Token) error +func (suite *ValidatorsTestSuite) validateContextToken(actualCtx context.Context, actualToken Token) error { + suite.Equal(suite.expectedCtx, actualCtx) + suite.Equal(suite.inputToken, actualToken) + return suite.expectedErr +} + +// validateContextSourceToken is a ValidatorFunc of the signature func(context.Context, , Token) error +func (suite *ValidatorsTestSuite) validateContextSourceToken(actualCtx context.Context, actualSource int, actualToken Token) error { + suite.Equal(suite.expectedCtx, actualCtx) + suite.Equal(suite.expectedSource, actualSource) + suite.Equal(suite.inputToken, actualToken) + return suite.expectedErr +} + +// transformToken is a ValidatorFunc of the signature func(Token) (Token, error). +// This variant returns suite.outputToken. +func (suite *ValidatorsTestSuite) transformToken(actualToken Token) (Token, error) { + suite.Equal(suite.inputToken, actualToken) + return suite.outputToken, suite.expectedErr +} + +// transformTokenToNil is a ValidatorFunc of the signature func(Token) (Token, error). +// This variant returns a nil Token, indicating that the original token is unchanged. +func (suite *ValidatorsTestSuite) transformTokenToNil(actualToken Token) (Token, error) { + suite.Equal(suite.inputToken, actualToken) + return nil, suite.expectedErr +} + +// transformSourceToken is a ValidatorFunc of the signature func(, Token) (Token, error) +// This variant returns suite.outputToken. +func (suite *ValidatorsTestSuite) transformSourceToken(actualSource int, actualToken Token) (Token, error) { + suite.Equal(suite.expectedSource, actualSource) + suite.Equal(suite.inputToken, actualToken) + return suite.outputToken, suite.expectedErr +} + +// transformSourceTokenToNil is a ValidatorFunc of the signature func(, Token) (Token, error) +// This variant returns a nil Token, indicating that the original token is unchanged. +func (suite *ValidatorsTestSuite) transformSourceTokenToNil(actualSource int, actualToken Token) (Token, error) { + suite.Equal(suite.expectedSource, actualSource) + suite.Equal(suite.inputToken, actualToken) + return nil, suite.expectedErr +} + +// transformContextToken is a ValidatorFunc of the signature func(context.context, Token) (Token, error) +// This variant returns suite.outputToken. +func (suite *ValidatorsTestSuite) transformContextToken(actualCtx context.Context, actualToken Token) (Token, error) { + suite.Equal(suite.expectedCtx, actualCtx) + suite.Equal(suite.inputToken, actualToken) + return suite.outputToken, suite.expectedErr +} + +// transformContextTokenToNil is a ValidatorFunc of the signature func(context.context, Token) (Token, error) +// This variant returns a nil Token, indicating that the original token is unchanged. +func (suite *ValidatorsTestSuite) transformContextTokenToNil(actualCtx context.Context, actualToken Token) (Token, error) { + suite.Equal(suite.expectedCtx, actualCtx) + suite.Equal(suite.inputToken, actualToken) + return nil, suite.expectedErr +} + +// transformContextSourceToken is a ValidatorFunc of the signature func(context.Context, , Token) (Token, error) +// This variant returns suite.outputToken. +func (suite *ValidatorsTestSuite) transformContextSourceToken(actualCtx context.Context, actualSource int, actualToken Token) (Token, error) { + suite.Equal(suite.expectedCtx, actualCtx) + suite.Equal(suite.expectedSource, actualSource) + suite.Equal(suite.inputToken, actualToken) + return suite.outputToken, suite.expectedErr +} + +// transformContextSourceToken is a ValidatorFunc of the signature func(context.Context, , Token) (Token, error) +// This variant returns a nil Token, indicating that the original token is unchanged. +func (suite *ValidatorsTestSuite) transformContextSourceTokenToNil(actualCtx context.Context, actualSource int, actualToken Token) (Token, error) { + suite.Equal(suite.expectedCtx, actualCtx) + suite.Equal(suite.expectedSource, actualSource) + suite.Equal(suite.inputToken, actualToken) + return nil, suite.expectedErr +} + +func (suite *ValidatorsTestSuite) testAsValidatorToken() { + suite.Run("ReturnError", func() { + suite.Run("Simple", func() { + v := AsValidator[int](suite.validateToken) + suite.assertNoTransform(v) + }) + + suite.Run("CustomType", func() { + type Custom func(Token) error + f := Custom(suite.validateToken) + v := AsValidator[int](f) + suite.assertNoTransform(v) + }) + }) + + suite.Run("ReturnTokenError", func() { + suite.Run("Simple", func() { + v := AsValidator[int](suite.transformToken) + suite.assertTransform(v) + }) + + suite.Run("NilOutputToken", func() { + v := AsValidator[int](suite.transformTokenToNil) + suite.assertNoTransform(v) + }) + + suite.Run("CustomType", func() { + type Custom func(Token) (Token, error) + f := Custom(suite.transformToken) + v := AsValidator[int](f) + suite.assertTransform(v) + }) + }) +} + +func (suite *ValidatorsTestSuite) testAsValidatorSourceToken() { + suite.Run("ReturnError", func() { + suite.Run("Simple", func() { + v := AsValidator[int](suite.validateSourceToken) + suite.assertNoTransform(v) + }) + + suite.Run("CustomType", func() { + type Custom func(int, Token) error + f := Custom(suite.validateSourceToken) + v := AsValidator[int](f) + suite.assertNoTransform(v) + }) + }) + + suite.Run("ReturnTokenError", func() { + suite.Run("Simple", func() { + v := AsValidator[int](suite.transformSourceToken) + suite.assertTransform(v) + }) + + suite.Run("NilOutputToken", func() { + v := AsValidator[int](suite.transformSourceTokenToNil) + suite.assertNoTransform(v) + }) + + suite.Run("CustomType", func() { + type Custom func(int, Token) (Token, error) + f := Custom(suite.transformSourceToken) + v := AsValidator[int](f) + suite.assertTransform(v) + }) + }) +} + +func (suite *ValidatorsTestSuite) testAsValidatorContextToken() { + suite.Run("ReturnError", func() { + suite.Run("Simple", func() { + v := AsValidator[int](suite.validateContextToken) + suite.assertNoTransform(v) + }) + + suite.Run("CustomType", func() { + type Custom func(context.Context, Token) error + f := Custom(suite.validateContextToken) + v := AsValidator[int](f) + suite.assertNoTransform(v) + }) + }) + + suite.Run("ReturnTokenError", func() { + suite.Run("Simple", func() { + v := AsValidator[int](suite.transformContextToken) + suite.assertTransform(v) + }) + + suite.Run("NilOutputToken", func() { + v := AsValidator[int](suite.transformContextTokenToNil) + suite.assertNoTransform(v) + }) + + suite.Run("CustomType", func() { + type Custom func(context.Context, Token) (Token, error) + f := Custom(suite.transformContextToken) + v := AsValidator[int](f) + suite.assertTransform(v) + }) + }) +} + +func (suite *ValidatorsTestSuite) testAsValidatorContextSourceToken() { + suite.Run("ReturnError", func() { + suite.Run("Simple", func() { + v := AsValidator[int](suite.validateContextSourceToken) + suite.assertNoTransform(v) + }) + + suite.Run("CustomType", func() { + type Custom func(context.Context, int, Token) error + f := Custom(suite.validateContextSourceToken) + v := AsValidator[int](f) + suite.assertNoTransform(v) + }) + }) + + suite.Run("ReturnTokenError", func() { + suite.Run("Simple", func() { + v := AsValidator[int](suite.transformContextSourceToken) + suite.assertTransform(v) + }) + + suite.Run("NilOutputToken", func() { + v := AsValidator[int](suite.transformContextSourceTokenToNil) + suite.assertNoTransform(v) + }) + + suite.Run("CustomType", func() { + type Custom func(context.Context, int, Token) (Token, error) + f := Custom(suite.transformContextSourceToken) + v := AsValidator[int](f) + suite.assertTransform(v) + }) + }) +} + +func (suite *ValidatorsTestSuite) TestAsValidator() { + suite.Run("Token", suite.testAsValidatorToken) + suite.Run("SourceToken", suite.testAsValidatorSourceToken) + suite.Run("ContextToken", suite.testAsValidatorContextToken) + suite.Run("ContextSourceToken", suite.testAsValidatorContextSourceToken) +} + +// newValidators constructs an array of validators that can only be called once +// and which successfully validate the suite's input token. +func (suite *ValidatorsTestSuite) newValidators(count int) (vs []Validator[int]) { + vs = make([]Validator[int], 0, count) + for len(vs) < cap(vs) { + v := new(mockValidator[int]) + v.ExpectValidate(suite.expectedCtx, suite.expectedSource, suite.inputToken). + Return(nil, nil).Once() + + vs = append(vs, v) + } + + return +} + +func (suite *ValidatorsTestSuite) TestValidate() { + suite.Run("NoValidators", func() { + outputToken, err := Validate[int](suite.expectedCtx, suite.expectedSource, suite.inputToken) + suite.Equal(suite.inputToken, outputToken) + suite.NoError(err) + }) + + suite.Run("NilOutputToken", func() { + for _, count := range []int{1, 2, 5} { + suite.Run(fmt.Sprintf("count=%d", count), func() { + vs := suite.newValidators(count) + actualToken, actualErr := Validate(suite.expectedCtx, suite.expectedSource, suite.inputToken, vs...) + suite.Equal(suite.inputToken, actualToken) + suite.NoError(actualErr) + assertValidators(suite.T(), vs...) + }) + } + }) +} + +func (suite *ValidatorsTestSuite) TestCompositeValidators() { + suite.Run("Empty", func() { + var vs Validators[int] + outputToken, err := vs.Validate(suite.expectedCtx, suite.expectedSource, suite.inputToken) + suite.Equal(suite.inputToken, outputToken) + suite.NoError(err) + }) + + suite.Run("NotEmpty", func() { + suite.Run("len=1", func() { + v := new(mockValidator[int]) + v.ExpectValidate(suite.expectedCtx, suite.expectedSource, suite.inputToken). + Return(suite.outputToken, nil).Once() + + var vs Validators[int] + vs = vs.Append(v) + actualToken, actualErr := vs.Validate(suite.expectedCtx, suite.expectedSource, suite.inputToken) + suite.Equal(suite.outputToken, actualToken) + suite.NoError(actualErr) + assertValidators(suite.T(), v) + }) + + suite.Run("len=2", func() { + v1 := new(mockValidator[int]) + v1.ExpectValidate(suite.expectedCtx, suite.expectedSource, suite.inputToken). + Return(nil, nil).Once() + + v2 := new(mockValidator[int]) + v2.ExpectValidate(suite.expectedCtx, suite.expectedSource, suite.inputToken). + Return(nil, nil).Once() + + var vs Validators[int] + vs = vs.Append(v1, v2) + actualToken, actualErr := vs.Validate(suite.expectedCtx, suite.expectedSource, suite.inputToken) + suite.Equal(suite.inputToken, actualToken) // the token should be unchanged + suite.NoError(actualErr) + assertValidators(suite.T(), v1, v2) + }) + }) } func TestValidators(t *testing.T) { From 04766591944957bc2f818f3382fc776da7e2b51d Mon Sep 17 00:00:00 2001 From: johnabass Date: Mon, 29 Jul 2024 12:07:22 -0700 Subject: [PATCH 4/4] go mod tidy --- go.mod | 1 + go.sum | 2 ++ 2 files changed, 3 insertions(+) diff --git a/go.mod b/go.mod index 8085c91..ad283f5 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,7 @@ require ( github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/rogpeppe/go-internal v1.10.0 // indirect github.com/segmentio/asm v1.2.0 // indirect + github.com/stretchr/objx v0.5.2 // indirect golang.org/x/crypto v0.24.0 // indirect golang.org/x/sys v0.21.0 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect diff --git a/go.sum b/go.sum index ce32581..10874eb 100644 --- a/go.sum +++ b/go.sum @@ -35,6 +35,8 @@ github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncj github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys= github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=