diff --git a/authorizer.go b/authorizer.go index 9818416..79e03a2 100644 --- a/authorizer.go +++ b/authorizer.go @@ -18,26 +18,24 @@ type Authorizer[R any] interface { // cancelation semantics. // // If this method doesn't support the given token, it should return nil. - Authorize(ctx context.Context, token Token, resource R) error + Authorize(ctx context.Context, resource R, token Token) error } // AuthorizerFunc is a closure type that implements Authorizer. -type AuthorizerFunc[R any] func(context.Context, Token, R) error +type AuthorizerFunc[R any] func(context.Context, R, Token) error -func (af AuthorizerFunc[R]) Authorize(ctx context.Context, token Token, resource R) error { - return af(ctx, token, resource) +func (af AuthorizerFunc[R]) Authorize(ctx context.Context, resource R, token Token) error { + return af(ctx, resource, token) } // Authorizers is a collection of Authorizers. type Authorizers[R any] []Authorizer[R] -// Add appends authorizers to this aggregate Authorizers. -func (as *Authorizers[R]) Add(a ...Authorizer[R]) { - if *as == nil { - *as = make(Authorizers[R], 0, len(a)) - } - - *as = append(*as, a...) +// Append tacks on one or more authorizers to this collection. The possibly +// new Authorizers instance is returned. The semantics of this method are +// the same as the built-in append. +func (as Authorizers[R]) Append(a ...Authorizer[R]) Authorizers[R] { + return append(as, a...) } // Authorize requires all authorizers in this sequence to allow access. This @@ -45,9 +43,9 @@ func (as *Authorizers[R]) Add(a ...Authorizer[R]) { // // Because authorization can be arbitrarily expensive, execution halts at the first failed // authorization attempt. -func (as Authorizers[R]) Authorize(ctx context.Context, token Token, resource R) error { +func (as Authorizers[R]) Authorize(ctx context.Context, resource R, token Token) error { for _, a := range as { - if err := a.Authorize(ctx, token, resource); err != nil { + if err := a.Authorize(ctx, resource, token); err != nil { return err } } @@ -61,10 +59,10 @@ type requireAny[R any] struct { // Authorize returns nil at the first authorizer that returns nil, i.e. accepts the access. // Otherwise, this method returns an aggregate error of all the authorization errors. -func (ra requireAny[R]) Authorize(ctx context.Context, token Token, resource R) error { +func (ra requireAny[R]) Authorize(ctx context.Context, resource R, token Token) error { var err error for _, a := range ra.a { - authErr := a.Authorize(ctx, token, resource) + authErr := a.Authorize(ctx, resource, token) if authErr == nil { return nil } diff --git a/authorizer_test.go b/authorizer_test.go index a77ba48..a878a55 100644 --- a/authorizer_test.go +++ b/authorizer_test.go @@ -63,10 +63,10 @@ func (suite *AuthorizersTestSuite) TestAuthorize() { for _, err := range testCase.results { err := err - as.Add( - AuthorizerFunc[string](func(ctx context.Context, token Token, resource string) error { + as = as.Append( + AuthorizerFunc[string](func(ctx context.Context, resource string, token Token) error { suite.Same(testCtx, ctx) - suite.Same(testToken, token) + suite.Equal(testToken, token) suite.Equal(placeholderResource, resource) return err }), @@ -75,7 +75,7 @@ func (suite *AuthorizersTestSuite) TestAuthorize() { suite.Equal( testCase.expectedErr, - as.Authorize(testCtx, testToken, placeholderResource), + as.Authorize(testCtx, placeholderResource, testToken), ) }) } @@ -123,10 +123,10 @@ func (suite *AuthorizersTestSuite) TestAny() { for _, err := range testCase.results { err := err - as.Add( - AuthorizerFunc[string](func(ctx context.Context, token Token, resource string) error { + as = as.Append( + AuthorizerFunc[string](func(ctx context.Context, resource string, token Token) error { suite.Same(testCtx, ctx) - suite.Same(testToken, token) + suite.Equal(testToken, token) suite.Equal(placeholderResource, resource) return err }), @@ -136,13 +136,13 @@ func (suite *AuthorizersTestSuite) TestAny() { anyAs := as.Any() suite.Equal( testCase.expectedErr, - anyAs.Authorize(testCtx, testToken, placeholderResource), + anyAs.Authorize(testCtx, placeholderResource, testToken), ) if len(as) > 0 { // the any instance should be distinct as[0] = AuthorizerFunc[string]( - func(context.Context, Token, string) error { + func(context.Context, string, Token) error { suite.Fail("should not have been called") return nil }, @@ -150,7 +150,7 @@ func (suite *AuthorizersTestSuite) TestAny() { suite.Equal( testCase.expectedErr, - anyAs.Authorize(testCtx, testToken, placeholderResource), + anyAs.Authorize(testCtx, placeholderResource, testToken), ) } }) diff --git a/basculehttp/middleware.go b/basculehttp/middleware.go index 0c9f4b5..9868f6e 100644 --- a/basculehttp/middleware.go +++ b/basculehttp/middleware.go @@ -51,9 +51,9 @@ func WithTokenParser(scheme bascule.Scheme, tp bascule.TokenParser[*http.Request // WithAuthentication adds validators used for authentication to this Middleware. Each // invocation of this option is cumulative. Authentication validators are run in the order // supplied by this option. -func WithAuthentication(v ...bascule.Validator) MiddlewareOption { +func WithAuthentication(v ...bascule.Validator[*http.Request]) MiddlewareOption { return middlewareOptionFunc(func(m *Middleware) error { - m.authentication.Add(v...) + m.authentication = m.authentication.Append(v...) return nil }) } @@ -77,7 +77,7 @@ func WithChallenges(ch ...Challenge) MiddlewareOption { // This is useful for use cases like admin access or alternate capabilities. func WithAuthorization(a ...bascule.Authorizer[*http.Request]) MiddlewareOption { return middlewareOptionFunc(func(m *Middleware) error { - m.authorization.Add(a...) + m.authorization = m.authorization.Append(a...) return nil }) } @@ -114,7 +114,7 @@ func WithErrorMarshaler(em ErrorMarshaler) MiddlewareOption { type Middleware struct { credentialsParser bascule.CredentialsParser[*http.Request] tokenParsers bascule.TokenParsers[*http.Request] - authentication bascule.Validators + authentication bascule.Validators[*http.Request] authorization bascule.Authorizers[*http.Request] challenges Challenges @@ -194,12 +194,12 @@ func (m *Middleware) getCredentialsAndToken(ctx context.Context, request *http.R return } -func (m *Middleware) authenticate(ctx context.Context, token bascule.Token) error { - return m.authentication.Validate(ctx, token) +func (m *Middleware) authenticate(ctx context.Context, request *http.Request, token bascule.Token) (bascule.Token, error) { + return m.authentication.Validate(ctx, request, token) } func (m *Middleware) authorize(ctx context.Context, token bascule.Token, request *http.Request) error { - return m.authorization.Authorize(ctx, token, request) + return m.authorization.Authorize(ctx, request, token) } // frontDoor is the internal handler implementation that protects a handler @@ -220,7 +220,7 @@ func (fd *frontDoor) ServeHTTP(response http.ResponseWriter, request *http.Reque } ctx = bascule.WithCredentials(ctx, creds) - err = fd.middleware.authenticate(ctx, token) + token, err = fd.middleware.authenticate(ctx, request, token) if err != nil { // at this point in the workflow, the request has valid credentials. we use // StatusForbidden as the default because any failure to authenticate isn't a diff --git a/go.mod b/go.mod index b3d4c5b..25b14e9 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,7 @@ require ( github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/rogpeppe/go-internal v1.10.0 // indirect github.com/segmentio/asm v1.2.0 // indirect + github.com/stretchr/objx v0.5.2 // indirect golang.org/x/crypto v0.25.0 // indirect golang.org/x/sys v0.22.0 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect diff --git a/go.sum b/go.sum index d01eb9f..143d6b7 100644 --- a/go.sum +++ b/go.sum @@ -35,6 +35,8 @@ github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncj github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys= github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= diff --git a/mocks_test.go b/mocks_test.go new file mode 100644 index 0000000..00d0e3b --- /dev/null +++ b/mocks_test.go @@ -0,0 +1,36 @@ +// SPDX-FileCopyrightText: 2020 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package bascule + +import ( + "context" + + "github.com/stretchr/testify/mock" +) + +type testToken string + +func (tt testToken) Principal() string { return string(tt) } + +type mockValidator[S any] struct { + mock.Mock +} + +func (m *mockValidator[S]) Validate(ctx context.Context, source S, token Token) (Token, error) { + args := m.Called(ctx, source, token) + t, _ := args.Get(0).(Token) + return t, args.Error(1) +} + +func (m *mockValidator[S]) ExpectValidate(ctx context.Context, source S, token Token) *mock.Call { + return m.On("Validate", ctx, source, token) +} + +func assertValidators[S any](t mock.TestingT, vs ...Validator[S]) (passed bool) { + for _, v := range vs { + passed = v.(*mockValidator[S]).AssertExpectations(t) && passed + } + + return +} diff --git a/testSuite_test.go b/testSuite_test.go index 554ef02..987cf63 100644 --- a/testSuite_test.go +++ b/testSuite_test.go @@ -14,14 +14,6 @@ const ( testScheme Scheme = "Test" ) -type testToken struct { - principal string -} - -func (tt *testToken) Principal() string { - return tt.principal -} - // TestSuite holds generally useful functionality for testing bascule. type TestSuite struct { suite.Suite @@ -43,9 +35,7 @@ func (suite *TestSuite) testCredentials() Credentials { } func (suite *TestSuite) testToken() Token { - return &testToken{ - principal: "test", - } + return testToken("test") } func (suite *TestSuite) contexter(ctx context.Context) Contexter { diff --git a/validator.go b/validator.go index 487f16b..c0bea98 100644 --- a/validator.go +++ b/validator.go @@ -5,50 +5,229 @@ package bascule import ( "context" + "reflect" ) // Validator represents a general strategy for validating tokens. Token validation -// typically happens during authentication, but it can also happen during parsing -// if a caller uses NewValidatingTokenParser. -type Validator interface { +// typically happens during authentication. +type Validator[S any] interface { // Validate validates a token. If this validator needs to interact // with external systems, the supplied context can be passed to honor - // cancelation semantics. + // cancelation semantics. Additionally, the source object from which the + // token was taken is made available. // // This method may be passed a token that it doesn't support, e.g. a Basic // validator can be passed a JWT token. In that case, this method should - // simply return nil. - Validate(context.Context, Token) error + // simply return a nil error. + // + // If this method returns a nil token, then the supplied token should be used + // as is. If this method returns a non-nil token, that new new token should be + // used instead. This allows a validator to augment a token with additional + // data, possibly from an external system or database. + Validate(ctx context.Context, source S, t Token) (Token, error) +} + +// Validate applies several validators to the given token. Although each individual +// validator may return a nil Token to indicate that there is no change in the token, +// this function will always return a non-nil Token. +// +// This function returns the validated Token and a nil error to indicate success. +// If any validator fails, this function halts further validation and returns +// the error. +func Validate[S any](ctx context.Context, source S, original Token, v ...Validator[S]) (validated Token, err error) { + next := original + for i, prev := 0, next; err == nil && i < len(v); i, prev = i+1, next { + next, err = v[i].Validate(ctx, source, prev) + if next == nil { + // no change in the token + next = prev + } + } + + if err == nil { + validated = next + } + + return +} + +// Validators is an aggregate Validator that returns validity if and only if +// all of its contained validators return validity. +type Validators[S any] []Validator[S] + +// Append tacks on more validators to this aggregate, returning the possibly new +// instance. The semantics of this method are the same as the built-in append. +func (vs Validators[S]) Append(more ...Validator[S]) Validators[S] { + return append(vs, more...) } -// ValidatorFunc is a closure type that implements Validator. -type ValidatorFunc func(context.Context, Token) error +// Validate executes each contained validator in order, returning validity only +// if all validators pass. Any validation failure prevents subsequent validators +// from running. +func (vs Validators[S]) Validate(ctx context.Context, source S, t Token) (Token, error) { + return Validate(ctx, source, t, vs...) +} -func (vf ValidatorFunc) Validate(ctx context.Context, token Token) error { - return vf(ctx, token) +// ValidatorFunc defines the closure signatures that are allowed as Validator instances. +type ValidatorFunc[S any] interface { + ~func(Token) error | + ~func(S, Token) error | + ~func(Token) (Token, error) | + ~func(S, Token) (Token, error) | + ~func(context.Context, Token) error | + ~func(context.Context, S, Token) error | + ~func(context.Context, Token) (Token, error) | + ~func(context.Context, S, Token) (Token, error) } -// Validators is an aggregate Validator. -type Validators []Validator +// validatorFunc is an internal type that implements Validator. Used to normalize +// and uncurry a closure. +type validatorFunc[S any] func(context.Context, S, Token) (Token, error) -// Add appends validators to this aggregate Validators. -func (vs *Validators) Add(v ...Validator) { - if *vs == nil { - *vs = make(Validators, 0, len(v)) +func (vf validatorFunc[S]) Validate(ctx context.Context, source S, t Token) (next Token, err error) { + next, err = vf(ctx, source, t) + if next == nil { + next = t } - *vs = append(*vs, v...) + return } -// Validate applies each validator in sequence. Execution stops at the first validator -// that returns an error, and that error is returned. If all validators return nil, -// this method returns nil, indicating the Token is valid. -func (vs Validators) Validate(ctx context.Context, token Token) error { - for _, v := range vs { - if err := v.Validate(ctx, token); err != nil { - return err - } +var ( + tokenReturnError = reflect.TypeOf((func(Token) error)(nil)) + tokenReturnTokenAndError = reflect.TypeOf((func(Token) (Token, error))(nil)) + contextTokenReturnError = reflect.TypeOf((func(context.Context, Token) error)(nil)) + contextTokenReturnTokenError = reflect.TypeOf((func(context.Context, Token) (Token, error))(nil)) +) + +// asValidatorSimple tries simple conversions on f. This function will not catch +// user-defined types. +func asValidatorSimple[S any, F ValidatorFunc[S]](f F) (v Validator[S]) { + switch vf := any(f).(type) { + case func(Token) error: + v = validatorFunc[S]( + func(ctx context.Context, source S, t Token) (Token, error) { + return t, vf(t) + }, + ) + + case func(S, Token) error: + v = validatorFunc[S]( + func(ctx context.Context, source S, t Token) (Token, error) { + return t, vf(source, t) + }, + ) + + case func(Token) (Token, error): + v = validatorFunc[S]( + func(ctx context.Context, source S, t Token) (next Token, err error) { + next, err = vf(t) + if next == nil { + next = t + } + + return + }, + ) + + case func(S, Token) (Token, error): + v = validatorFunc[S]( + func(ctx context.Context, source S, t Token) (next Token, err error) { + next, err = vf(source, t) + if next == nil { + next = t + } + + return + }, + ) + + case func(context.Context, Token) error: + v = validatorFunc[S]( + func(ctx context.Context, source S, t Token) (Token, error) { + return t, vf(ctx, t) + }, + ) + + case func(context.Context, S, Token) error: + v = validatorFunc[S]( + func(ctx context.Context, source S, t Token) (Token, error) { + return t, vf(ctx, source, t) + }, + ) + + case func(context.Context, Token) (Token, error): + v = validatorFunc[S]( + func(ctx context.Context, source S, t Token) (next Token, err error) { + next, err = vf(ctx, t) + if next == nil { + next = t + } + + return + }, + ) + + case func(context.Context, S, Token) (Token, error): + v = validatorFunc[S](vf) } - return nil + return +} + +// AsValidator takes a ValidatorFunc closure and returns a Validator instance that +// executes that closure. This function can also convert custom types which can +// be converted to any of the closure signatures. +func AsValidator[S any, F ValidatorFunc[S]](f F) Validator[S] { + // first, try the simple way: + if v := asValidatorSimple[S](f); v != nil { + return v + } + + // next, support user-defined types that are closures that do not + // require the source type. + fVal := reflect.ValueOf(f) + switch { + case fVal.CanConvert(tokenReturnError): + return asValidatorSimple[S]( + fVal.Convert(tokenReturnError).Interface().(func(Token) error), + ) + + case fVal.CanConvert(tokenReturnTokenAndError): + return asValidatorSimple[S]( + fVal.Convert(tokenReturnTokenAndError).Interface().(func(Token) (Token, error)), + ) + + case fVal.CanConvert(contextTokenReturnError): + return asValidatorSimple[S]( + fVal.Convert(contextTokenReturnError).Interface().(func(context.Context, Token) error), + ) + + case fVal.CanConvert(contextTokenReturnTokenError): + return asValidatorSimple[S]( + fVal.Convert(contextTokenReturnTokenError).Interface().(func(context.Context, Token) (Token, error)), + ) + } + + // finally: user-defined types that are closures involving the source type S. + // we have to look these up here, due to the way generics in golang work. + if ft := reflect.TypeOf((func(S, Token) error)(nil)); fVal.CanConvert(ft) { + return asValidatorSimple[S]( + fVal.Convert(ft).Interface().(func(S, Token) error), + ) + } else if ft := reflect.TypeOf((func(S, Token) (Token, error))(nil)); fVal.CanConvert(ft) { + return asValidatorSimple[S]( + fVal.Convert(ft).Interface().(func(S, Token) (Token, error)), + ) + } else if ft := reflect.TypeOf((func(context.Context, S, Token) error)(nil)); fVal.CanConvert(ft) { + return asValidatorSimple[S]( + fVal.Convert(ft).Interface().(func(context.Context, S, Token) error), + ) + } else { + // we know this can be converted to this final type + ft := reflect.TypeOf((func(context.Context, S, Token) (Token, error))(nil)) + return asValidatorSimple[S]( + fVal.Convert(ft).Interface().(func(context.Context, S, Token) (Token, error)), + ) + } } diff --git a/validator_test.go b/validator_test.go index 4ea2601..af45572 100644 --- a/validator_test.go +++ b/validator_test.go @@ -6,6 +6,7 @@ package bascule import ( "context" "errors" + "fmt" "testing" "github.com/stretchr/testify/suite" @@ -13,70 +14,357 @@ import ( type ValidatorsTestSuite struct { TestSuite + + expectedCtx context.Context + expectedSource int + inputToken Token + outputToken Token + expectedErr error } -func (suite *ValidatorsTestSuite) TestValidate() { - validateErr := errors.New("expected Validate error") - - testCases := []struct { - name string - results []error - expectedErr error - }{ - { - name: "EmptyValidators", - results: nil, - }, - { - name: "OneSuccess", - results: []error{nil}, - }, - { - name: "OneFailure", - results: []error{validateErr}, - expectedErr: validateErr, - }, - { - name: "FirstFailure", - results: []error{validateErr, errors.New("should not be called")}, - expectedErr: validateErr, - }, - { - name: "MiddleFailure", - results: []error{nil, validateErr, errors.New("should not be called")}, - expectedErr: validateErr, - }, - { - name: "AllSuccess", - results: []error{nil, nil, nil}, - }, - } +func (suite *ValidatorsTestSuite) SetupSuite() { + type contextKey struct{} + suite.expectedCtx = context.WithValue( + context.Background(), + contextKey{}, + "value", + ) + + suite.expectedSource = 123 + suite.inputToken = testToken("input token") + suite.outputToken = testToken("output token") + suite.expectedErr = errors.New("expected validator error") +} + +// assertNoTransform verifies that the validator returns the same token as the input token. +func (suite *ValidatorsTestSuite) assertNoTransform(v Validator[int]) { + suite.Require().NotNil(v) + actualToken, actualErr := v.Validate(suite.expectedCtx, suite.expectedSource, suite.inputToken) + suite.Equal(suite.inputToken, actualToken) + suite.ErrorIs(suite.expectedErr, actualErr) +} + +// assertTransform verifies a validator that returns a different token than the input token. +func (suite *ValidatorsTestSuite) assertTransform(v Validator[int]) { + suite.Require().NotNil(v) + actualToken, actualErr := v.Validate(suite.expectedCtx, suite.expectedSource, suite.inputToken) + suite.Equal(suite.outputToken, actualToken) + suite.ErrorIs(suite.expectedErr, actualErr) +} + +// validateToken is a ValidatorFunc of the signature func(Token) error +func (suite *ValidatorsTestSuite) validateToken(actualToken Token) error { + suite.Equal(suite.inputToken, actualToken) + return suite.expectedErr +} + +// validateSourceToken is a ValidatorFunc of the signature func(, Token) error +func (suite *ValidatorsTestSuite) validateSourceToken(actualSource int, actualToken Token) error { + suite.Equal(suite.expectedSource, actualSource) + suite.Equal(suite.inputToken, actualToken) + return suite.expectedErr +} + +// validateContextToken is a ValidatorFunc of the signature func(context.Context, Token) error +func (suite *ValidatorsTestSuite) validateContextToken(actualCtx context.Context, actualToken Token) error { + suite.Equal(suite.expectedCtx, actualCtx) + suite.Equal(suite.inputToken, actualToken) + return suite.expectedErr +} + +// validateContextSourceToken is a ValidatorFunc of the signature func(context.Context, , Token) error +func (suite *ValidatorsTestSuite) validateContextSourceToken(actualCtx context.Context, actualSource int, actualToken Token) error { + suite.Equal(suite.expectedCtx, actualCtx) + suite.Equal(suite.expectedSource, actualSource) + suite.Equal(suite.inputToken, actualToken) + return suite.expectedErr +} + +// transformToken is a ValidatorFunc of the signature func(Token) (Token, error). +// This variant returns suite.outputToken. +func (suite *ValidatorsTestSuite) transformToken(actualToken Token) (Token, error) { + suite.Equal(suite.inputToken, actualToken) + return suite.outputToken, suite.expectedErr +} + +// transformTokenToNil is a ValidatorFunc of the signature func(Token) (Token, error). +// This variant returns a nil Token, indicating that the original token is unchanged. +func (suite *ValidatorsTestSuite) transformTokenToNil(actualToken Token) (Token, error) { + suite.Equal(suite.inputToken, actualToken) + return nil, suite.expectedErr +} - for _, testCase := range testCases { - suite.Run(testCase.name, func() { - var ( - testCtx = suite.testContext() - testToken = suite.testToken() - vs Validators - ) - - for _, err := range testCase.results { - err := err - vs.Add( - ValidatorFunc(func(ctx context.Context, token Token) error { - suite.Same(testCtx, ctx) - suite.Same(testToken, token) - return err - }), - ) - } - - suite.Equal( - testCase.expectedErr, - vs.Validate(testCtx, testToken), - ) +// transformSourceToken is a ValidatorFunc of the signature func(, Token) (Token, error) +// This variant returns suite.outputToken. +func (suite *ValidatorsTestSuite) transformSourceToken(actualSource int, actualToken Token) (Token, error) { + suite.Equal(suite.expectedSource, actualSource) + suite.Equal(suite.inputToken, actualToken) + return suite.outputToken, suite.expectedErr +} + +// transformSourceTokenToNil is a ValidatorFunc of the signature func(, Token) (Token, error) +// This variant returns a nil Token, indicating that the original token is unchanged. +func (suite *ValidatorsTestSuite) transformSourceTokenToNil(actualSource int, actualToken Token) (Token, error) { + suite.Equal(suite.expectedSource, actualSource) + suite.Equal(suite.inputToken, actualToken) + return nil, suite.expectedErr +} + +// transformContextToken is a ValidatorFunc of the signature func(context.context, Token) (Token, error) +// This variant returns suite.outputToken. +func (suite *ValidatorsTestSuite) transformContextToken(actualCtx context.Context, actualToken Token) (Token, error) { + suite.Equal(suite.expectedCtx, actualCtx) + suite.Equal(suite.inputToken, actualToken) + return suite.outputToken, suite.expectedErr +} + +// transformContextTokenToNil is a ValidatorFunc of the signature func(context.context, Token) (Token, error) +// This variant returns a nil Token, indicating that the original token is unchanged. +func (suite *ValidatorsTestSuite) transformContextTokenToNil(actualCtx context.Context, actualToken Token) (Token, error) { + suite.Equal(suite.expectedCtx, actualCtx) + suite.Equal(suite.inputToken, actualToken) + return nil, suite.expectedErr +} + +// transformContextSourceToken is a ValidatorFunc of the signature func(context.Context, , Token) (Token, error) +// This variant returns suite.outputToken. +func (suite *ValidatorsTestSuite) transformContextSourceToken(actualCtx context.Context, actualSource int, actualToken Token) (Token, error) { + suite.Equal(suite.expectedCtx, actualCtx) + suite.Equal(suite.expectedSource, actualSource) + suite.Equal(suite.inputToken, actualToken) + return suite.outputToken, suite.expectedErr +} + +// transformContextSourceToken is a ValidatorFunc of the signature func(context.Context, , Token) (Token, error) +// This variant returns a nil Token, indicating that the original token is unchanged. +func (suite *ValidatorsTestSuite) transformContextSourceTokenToNil(actualCtx context.Context, actualSource int, actualToken Token) (Token, error) { + suite.Equal(suite.expectedCtx, actualCtx) + suite.Equal(suite.expectedSource, actualSource) + suite.Equal(suite.inputToken, actualToken) + return nil, suite.expectedErr +} + +func (suite *ValidatorsTestSuite) testAsValidatorToken() { + suite.Run("ReturnError", func() { + suite.Run("Simple", func() { + v := AsValidator[int](suite.validateToken) + suite.assertNoTransform(v) }) + + suite.Run("CustomType", func() { + type Custom func(Token) error + f := Custom(suite.validateToken) + v := AsValidator[int](f) + suite.assertNoTransform(v) + }) + }) + + suite.Run("ReturnTokenError", func() { + suite.Run("Simple", func() { + v := AsValidator[int](suite.transformToken) + suite.assertTransform(v) + }) + + suite.Run("NilOutputToken", func() { + v := AsValidator[int](suite.transformTokenToNil) + suite.assertNoTransform(v) + }) + + suite.Run("CustomType", func() { + type Custom func(Token) (Token, error) + f := Custom(suite.transformToken) + v := AsValidator[int](f) + suite.assertTransform(v) + }) + }) +} + +func (suite *ValidatorsTestSuite) testAsValidatorSourceToken() { + suite.Run("ReturnError", func() { + suite.Run("Simple", func() { + v := AsValidator[int](suite.validateSourceToken) + suite.assertNoTransform(v) + }) + + suite.Run("CustomType", func() { + type Custom func(int, Token) error + f := Custom(suite.validateSourceToken) + v := AsValidator[int](f) + suite.assertNoTransform(v) + }) + }) + + suite.Run("ReturnTokenError", func() { + suite.Run("Simple", func() { + v := AsValidator[int](suite.transformSourceToken) + suite.assertTransform(v) + }) + + suite.Run("NilOutputToken", func() { + v := AsValidator[int](suite.transformSourceTokenToNil) + suite.assertNoTransform(v) + }) + + suite.Run("CustomType", func() { + type Custom func(int, Token) (Token, error) + f := Custom(suite.transformSourceToken) + v := AsValidator[int](f) + suite.assertTransform(v) + }) + }) +} + +func (suite *ValidatorsTestSuite) testAsValidatorContextToken() { + suite.Run("ReturnError", func() { + suite.Run("Simple", func() { + v := AsValidator[int](suite.validateContextToken) + suite.assertNoTransform(v) + }) + + suite.Run("CustomType", func() { + type Custom func(context.Context, Token) error + f := Custom(suite.validateContextToken) + v := AsValidator[int](f) + suite.assertNoTransform(v) + }) + }) + + suite.Run("ReturnTokenError", func() { + suite.Run("Simple", func() { + v := AsValidator[int](suite.transformContextToken) + suite.assertTransform(v) + }) + + suite.Run("NilOutputToken", func() { + v := AsValidator[int](suite.transformContextTokenToNil) + suite.assertNoTransform(v) + }) + + suite.Run("CustomType", func() { + type Custom func(context.Context, Token) (Token, error) + f := Custom(suite.transformContextToken) + v := AsValidator[int](f) + suite.assertTransform(v) + }) + }) +} + +func (suite *ValidatorsTestSuite) testAsValidatorContextSourceToken() { + suite.Run("ReturnError", func() { + suite.Run("Simple", func() { + v := AsValidator[int](suite.validateContextSourceToken) + suite.assertNoTransform(v) + }) + + suite.Run("CustomType", func() { + type Custom func(context.Context, int, Token) error + f := Custom(suite.validateContextSourceToken) + v := AsValidator[int](f) + suite.assertNoTransform(v) + }) + }) + + suite.Run("ReturnTokenError", func() { + suite.Run("Simple", func() { + v := AsValidator[int](suite.transformContextSourceToken) + suite.assertTransform(v) + }) + + suite.Run("NilOutputToken", func() { + v := AsValidator[int](suite.transformContextSourceTokenToNil) + suite.assertNoTransform(v) + }) + + suite.Run("CustomType", func() { + type Custom func(context.Context, int, Token) (Token, error) + f := Custom(suite.transformContextSourceToken) + v := AsValidator[int](f) + suite.assertTransform(v) + }) + }) +} + +func (suite *ValidatorsTestSuite) TestAsValidator() { + suite.Run("Token", suite.testAsValidatorToken) + suite.Run("SourceToken", suite.testAsValidatorSourceToken) + suite.Run("ContextToken", suite.testAsValidatorContextToken) + suite.Run("ContextSourceToken", suite.testAsValidatorContextSourceToken) +} + +// newValidators constructs an array of validators that can only be called once +// and which successfully validate the suite's input token. +func (suite *ValidatorsTestSuite) newValidators(count int) (vs []Validator[int]) { + vs = make([]Validator[int], 0, count) + for len(vs) < cap(vs) { + v := new(mockValidator[int]) + v.ExpectValidate(suite.expectedCtx, suite.expectedSource, suite.inputToken). + Return(nil, nil).Once() + + vs = append(vs, v) } + + return +} + +func (suite *ValidatorsTestSuite) TestValidate() { + suite.Run("NoValidators", func() { + outputToken, err := Validate[int](suite.expectedCtx, suite.expectedSource, suite.inputToken) + suite.Equal(suite.inputToken, outputToken) + suite.NoError(err) + }) + + suite.Run("NilOutputToken", func() { + for _, count := range []int{1, 2, 5} { + suite.Run(fmt.Sprintf("count=%d", count), func() { + vs := suite.newValidators(count) + actualToken, actualErr := Validate(suite.expectedCtx, suite.expectedSource, suite.inputToken, vs...) + suite.Equal(suite.inputToken, actualToken) + suite.NoError(actualErr) + assertValidators(suite.T(), vs...) + }) + } + }) +} + +func (suite *ValidatorsTestSuite) TestCompositeValidators() { + suite.Run("Empty", func() { + var vs Validators[int] + outputToken, err := vs.Validate(suite.expectedCtx, suite.expectedSource, suite.inputToken) + suite.Equal(suite.inputToken, outputToken) + suite.NoError(err) + }) + + suite.Run("NotEmpty", func() { + suite.Run("len=1", func() { + v := new(mockValidator[int]) + v.ExpectValidate(suite.expectedCtx, suite.expectedSource, suite.inputToken). + Return(suite.outputToken, nil).Once() + + var vs Validators[int] + vs = vs.Append(v) + actualToken, actualErr := vs.Validate(suite.expectedCtx, suite.expectedSource, suite.inputToken) + suite.Equal(suite.outputToken, actualToken) + suite.NoError(actualErr) + assertValidators(suite.T(), v) + }) + + suite.Run("len=2", func() { + v1 := new(mockValidator[int]) + v1.ExpectValidate(suite.expectedCtx, suite.expectedSource, suite.inputToken). + Return(nil, nil).Once() + + v2 := new(mockValidator[int]) + v2.ExpectValidate(suite.expectedCtx, suite.expectedSource, suite.inputToken). + Return(nil, nil).Once() + + var vs Validators[int] + vs = vs.Append(v1, v2) + actualToken, actualErr := vs.Validate(suite.expectedCtx, suite.expectedSource, suite.inputToken) + suite.Equal(suite.inputToken, actualToken) // the token should be unchanged + suite.NoError(actualErr) + assertValidators(suite.T(), v1, v2) + }) + }) } func TestValidators(t *testing.T) {