From 271ba61890b728532a388487e08a3743b19204a4 Mon Sep 17 00:00:00 2001 From: johnabass Date: Wed, 31 Jul 2024 17:40:44 -0700 Subject: [PATCH 1/6] refactored token parsing to remove credentials and simplify the API --- authorizer.go | 4 +- basculehttp/credentials.go | 7 -- context.go | 25 ----- context_test.go | 84 --------------- credentials.go | 32 ------ credentials_test.go | 43 -------- error.go | 109 ------------------- error_test.go | 68 ------------ mocks_test.go | 22 ++++ testSuite_test.go | 11 -- token.go | 119 ++++++++++++++++----- token_test.go | 208 ++++++++++++++++++++++++++++--------- validator.go | 24 ++--- validator_test.go | 8 +- 14 files changed, 289 insertions(+), 475 deletions(-) delete mode 100644 credentials.go delete mode 100644 credentials_test.go delete mode 100644 error.go delete mode 100644 error_test.go diff --git a/authorizer.go b/authorizer.go index 79e03a2..5917b06 100644 --- a/authorizer.go +++ b/authorizer.go @@ -34,8 +34,8 @@ type Authorizers[R any] []Authorizer[R] // Append tacks on one or more authorizers to this collection. The possibly // new Authorizers instance is returned. The semantics of this method are // the same as the built-in append. -func (as Authorizers[R]) Append(a ...Authorizer[R]) Authorizers[R] { - return append(as, a...) +func (as Authorizers[R]) Append(more ...Authorizer[R]) Authorizers[R] { + return append(as, more...) } // Authorize requires all authorizers in this sequence to allow access. This diff --git a/basculehttp/credentials.go b/basculehttp/credentials.go index 000d6ca..025119e 100644 --- a/basculehttp/credentials.go +++ b/basculehttp/credentials.go @@ -52,13 +52,6 @@ func (err *MissingHeaderError) StatusCode() int { return http.StatusUnauthorized } -// fastIsSpace tests an ASCII byte to see if it's whitespace. -// HTTP headers are restricted to US-ASCII, so we don't need -// the full unicode stack. -func fastIsSpace(b byte) bool { - return b == ' ' || b == '\t' || b == '\n' || b == '\r' || b == '\v' || b == '\f' -} - // DefaultCredentialsParser is the default algorithm used to produce HTTP credentials // from a source request. type DefaultCredentialsParser struct { diff --git a/context.go b/context.go index 3207fd9..bf9804d 100644 --- a/context.go +++ b/context.go @@ -13,31 +13,6 @@ type Contexter interface { Context() context.Context } -type credentialsContextKey struct{} - -// GetCredentials examines the context and returns the credentials used to -// build the Token. If no credentials are in the context, this function -// returns false. -func GetCredentials(ctx context.Context) (c Credentials, found bool) { - c, found = ctx.Value(credentialsContextKey{}).(Credentials) - return -} - -// GetCredentialsFrom uses the context held by src to obtain credentials. -// As with GetCredentials, if no credentials are found this function returns false. -func GetCredentialsFrom(src Contexter) (Credentials, bool) { - return GetCredentials(src.Context()) -} - -// WithCredentials constructs a new context with the supplied credentials. -func WithCredentials(ctx context.Context, c Credentials) context.Context { - return context.WithValue( - ctx, - credentialsContextKey{}, - c, - ) -} - type tokenContextKey struct{} // GetToken retrieves a Token from a context. If not token is in the context, diff --git a/context_test.go b/context_test.go index fa0d79a..fb41ec4 100644 --- a/context_test.go +++ b/context_test.go @@ -14,90 +14,6 @@ type ContextTestSuite struct { TestSuite } -func (suite *ContextTestSuite) testGetCredentialsSuccess() { - ctx := context.WithValue( - context.Background(), - credentialsContextKey{}, - suite.testCredentials(), - ) - - creds, ok := GetCredentials(ctx) - suite.Require().True(ok) - suite.Equal( - suite.testCredentials(), - creds, - ) -} - -func (suite *ContextTestSuite) testGetCredentialsMissing() { - creds, ok := GetCredentials(context.Background()) - suite.Equal(Credentials{}, creds) - suite.False(ok) -} - -func (suite *ContextTestSuite) testGetCredentialsWrongType() { - ctx := context.WithValue(context.Background(), credentialsContextKey{}, 123) - creds, ok := GetCredentials(ctx) - suite.Equal(Credentials{}, creds) - suite.False(ok) -} - -func (suite *ContextTestSuite) TestGetCredentials() { - suite.Run("Success", suite.testGetCredentialsSuccess) - suite.Run("Missing", suite.testGetCredentialsMissing) - suite.Run("WrongType", suite.testGetCredentialsWrongType) -} - -func (suite *ContextTestSuite) testGetCredentialsFromSuccess() { - c := suite.contexter( - context.WithValue( - context.Background(), - credentialsContextKey{}, - suite.testCredentials(), - ), - ) - - creds, ok := GetCredentialsFrom(c) - suite.Require().True(ok) - suite.Equal( - suite.testCredentials(), - creds, - ) -} - -func (suite *ContextTestSuite) testGetCredentialsFromMissing() { - creds, ok := GetCredentialsFrom( - suite.contexter(context.Background()), - ) - - suite.Equal(Credentials{}, creds) - suite.False(ok) -} - -func (suite *ContextTestSuite) testGetCredentialsFromWrongType() { - c := suite.contexter( - context.WithValue(context.Background(), credentialsContextKey{}, 123), - ) - - creds, ok := GetCredentialsFrom(c) - suite.Equal(Credentials{}, creds) - suite.False(ok) -} - -func (suite *ContextTestSuite) TestGetCredentialsFrom() { - suite.Run("Success", suite.testGetCredentialsFromSuccess) - suite.Run("Missing", suite.testGetCredentialsFromMissing) - suite.Run("WrongType", suite.testGetCredentialsFromWrongType) -} - -func (suite *ContextTestSuite) TestWithCredentials() { - ctx := WithCredentials(context.Background(), suite.testCredentials()) - - creds, ok := ctx.Value(credentialsContextKey{}).(Credentials) - suite.Require().True(ok) - suite.Equal(suite.testCredentials(), creds) -} - func (suite *ContextTestSuite) testGetTokenSuccess() { ctx := context.WithValue( context.Background(), diff --git a/credentials.go b/credentials.go deleted file mode 100644 index eeda081..0000000 --- a/credentials.go +++ /dev/null @@ -1,32 +0,0 @@ -// SPDX-FileCopyrightText: 2020 Comcast Cable Communications Management, LLC -// SPDX-License-Identifier: Apache-2.0 - -package bascule - -import "context" - -// Scheme represents how a security token should be parsed. For HTTP, examples -// of a scheme are "Bearer" and "Basic". -type Scheme string - -// Credentials holds the raw, unparsed token information. -type Credentials struct { - // Scheme is the parsing scheme used for the credential value. - Scheme Scheme - - // Value is the raw, unparsed credential information. - Value string -} - -// CredentialsParser produces Credentials from a data source. -type CredentialsParser[S any] interface { - // Parse extracts Credentials from a Source data object. - Parse(ctx context.Context, source S) (Credentials, error) -} - -// CredentialsParserFunc is a function type that implements CredentialsParser. -type CredentialsParserFunc[S any] func(context.Context, S) (Credentials, error) - -func (cpf CredentialsParserFunc[S]) Parse(ctx context.Context, source S) (Credentials, error) { - return cpf(ctx, source) -} diff --git a/credentials_test.go b/credentials_test.go deleted file mode 100644 index 20956ca..0000000 --- a/credentials_test.go +++ /dev/null @@ -1,43 +0,0 @@ -// SPDX-FileCopyrightText: 2020 Comcast Cable Communications Management, LLC -// SPDX-License-Identifier: Apache-2.0 - -package bascule - -import ( - "context" - "errors" - "testing" - - "github.com/stretchr/testify/suite" -) - -type CredentialsTestSuite struct { - suite.Suite -} - -func (suite *CredentialsTestSuite) TestCredentialsParserFunc() { - const expectedRaw = "expected raw credentials" - expectedErr := errors.New("expected error") - var c CredentialsParser[string] = CredentialsParserFunc[string](func(_ context.Context, raw string) (Credentials, error) { - suite.Equal(expectedRaw, raw) - return Credentials{ - Scheme: Scheme("test"), - Value: "value", - }, expectedErr - }) - - creds, err := c.Parse(context.Background(), expectedRaw) - suite.Equal( - Credentials{ - Scheme: Scheme("test"), - Value: "value", - }, - creds, - ) - - suite.Same(expectedErr, err) -} - -func TestCredentials(t *testing.T) { - suite.Run(t, new(CredentialsTestSuite)) -} diff --git a/error.go b/error.go deleted file mode 100644 index 7d29d6f..0000000 --- a/error.go +++ /dev/null @@ -1,109 +0,0 @@ -// SPDX-FileCopyrightText: 2021 Comcast Cable Communications Management, LLC -// SPDX-License-Identifier: Apache-2.0 - -package bascule - -import ( - "errors" - "strings" -) - -// ErrorType is an enumeration type for various types of security errors. -// This type can be used to determine more detail about the context of an error. -type ErrorType int - -const ( - // ErrorTypeUnknown indicates an error that didn't specify an ErrorType, - // possibly because the error didn't implement the Error interface in this package. - ErrorTypeUnknown ErrorType = iota - - // ErrorTypeMissingCredentials indicates that no credentials could be found. - // For example, this is the type used when no credentials are present in an HTTP request. - ErrorTypeMissingCredentials - - // ErrorTypeBadCredentials indcates that credentials exist, but they were badly formatted. - // In other words, bascule could not parse the credentials. - ErrorTypeBadCredentials - - // ErrorTypeInvalidCredentials indicates that credentials exist and are properly formatted, - // but they failed validation. Typically, this is due to failed authentication. It can also - // mean that a token's fields are invalid, such as the exp field of a JWT. - ErrorTypeInvalidCredentials - - // ErrorTypeForbidden indicates that a token did not have sufficient privileges to - // perform an operation. - ErrorTypeForbidden -) - -// Error is an optional interface that errors may implement to expose security -// metadata about the error. -type Error interface { - // Type is the ErrorType describing this error. - Type() ErrorType -} - -type typedError struct { - error - et ErrorType -} - -func (te *typedError) Unwrap() error { return te.error } - -func (te *typedError) Type() ErrorType { return te.et } - -// NewTypedError wraps a given error and associates an ErrorType with it. -// The returned error will implement the Error interface in this package, -// and will have an Unwrap method that returns err. -func NewTypedError(err error, et ErrorType) error { - return &typedError{ - error: err, - et: et, - } -} - -// GetErrorType examines err to determine its associated metadata type. If err -// does not implement Error, this function returns ErrorTypeUnknown. -func GetErrorType(err error) ErrorType { - var e Error - if errors.As(err, &e) { - return e.Type() - } - - return ErrorTypeUnknown -} - -// UnsupportedSchemeError indicates that a credentials scheme was not supported -// by a TokenParser. -type UnsupportedSchemeError struct { - // Scheme is the unsupported credential scheme. - Scheme Scheme -} - -// Type tags errors of this type as ErrorTypeBadCredentials. -func (err *UnsupportedSchemeError) Type() ErrorType { return ErrorTypeBadCredentials } - -func (err *UnsupportedSchemeError) Error() string { - var o strings.Builder - o.WriteString(`Unsupported scheme: "`) - o.WriteString(string(err.Scheme)) - o.WriteRune('"') - return o.String() -} - -// BadCredentialsError is a general-purpose error indicating that credentials -// could not be parsed. -type BadCredentialsError struct { - // Raw is the raw value of the credentials that could not be parsed. - Raw string -} - -// Type tags errors of this type as ErrorTypeBadCredentials. -func (err *BadCredentialsError) Type() ErrorType { return ErrorTypeBadCredentials } - -func (err *BadCredentialsError) Error() string { - var o strings.Builder - o.WriteString(`Bad credentials: "`) - o.WriteString(err.Raw) - o.WriteRune('"') - return o.String() -} diff --git a/error_test.go b/error_test.go deleted file mode 100644 index e715ab6..0000000 --- a/error_test.go +++ /dev/null @@ -1,68 +0,0 @@ -// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC -// SPDX-License-Identifier: Apache-2.0 - -package bascule - -import ( - "errors" - "testing" - - "github.com/stretchr/testify/suite" -) - -type ErrorSuite struct { - suite.Suite -} - -func (suite *ErrorSuite) TestUnsupportedSchemeError() { - err := UnsupportedSchemeError{ - Scheme: Scheme("scheme"), - } - - suite.Contains(err.Error(), "scheme") - suite.Equal(ErrorTypeBadCredentials, err.Type()) -} - -func (suite *ErrorSuite) TestBadCredentialsError() { - err := BadCredentialsError{ - Raw: "these are an unparseable, raw credentials", - } - - suite.Contains(err.Error(), "these are an unparseable, raw credentials") - suite.Equal(ErrorTypeBadCredentials, err.Type()) -} - -func (suite *ErrorSuite) TestNewTypedError() { - original := errors.New("original error") - typed := NewTypedError(original, ErrorTypeBadCredentials) - - suite.ErrorIs(typed, original) - suite.Require().Implements((*Error)(nil), typed) - - var e Error - suite.Require().ErrorAs(typed, &e) - suite.Equal( - ErrorTypeBadCredentials, - e.Type(), - ) -} - -func (suite *ErrorSuite) TestGetErrorType() { - suite.Run("Unknown", func() { - suite.Equal( - ErrorTypeUnknown, - GetErrorType(errors.New("this is an error that is unknown to bascule")), - ) - }) - - suite.Run("ImplementsError", func() { - suite.Equal( - ErrorTypeBadCredentials, - GetErrorType(new(BadCredentialsError)), - ) - }) -} - -func TestError(t *testing.T) { - suite.Run(t, new(ErrorSuite)) -} diff --git a/mocks_test.go b/mocks_test.go index 00d0e3b..8a4f613 100644 --- a/mocks_test.go +++ b/mocks_test.go @@ -34,3 +34,25 @@ func assertValidators[S any](t mock.TestingT, vs ...Validator[S]) (passed bool) return } + +type mockTokenParser[S any] struct { + mock.Mock +} + +func (m *mockTokenParser[S]) Parse(ctx context.Context, source S) (Token, error) { + args := m.Called(ctx, source) + t, _ := args.Get(0).(Token) + return t, args.Error(1) +} + +func (m *mockTokenParser[S]) ExpectParse(ctx context.Context, source S) *mock.Call { + return m.On("Parse", ctx, source) +} + +func assertTokenParsers[S any](t mock.TestingT, tps ...TokenParser[S]) (passed bool) { + for _, p := range tps { + passed = p.(*mockTokenParser[S]).AssertExpectations(t) && passed + } + + return +} diff --git a/testSuite_test.go b/testSuite_test.go index 987cf63..43c6f1d 100644 --- a/testSuite_test.go +++ b/testSuite_test.go @@ -10,10 +10,6 @@ import ( "github.com/stretchr/testify/suite" ) -const ( - testScheme Scheme = "Test" -) - // TestSuite holds generally useful functionality for testing bascule. type TestSuite struct { suite.Suite @@ -27,13 +23,6 @@ func (suite *TestSuite) testContext() context.Context { ) } -func (suite *TestSuite) testCredentials() Credentials { - return Credentials{ - Scheme: testScheme, - Value: "test", - } -} - func (suite *TestSuite) testToken() Token { return testToken("test") } diff --git a/token.go b/token.go index 884eaf0..8290718 100644 --- a/token.go +++ b/token.go @@ -5,6 +5,24 @@ package bascule import ( "context" + "errors" + "reflect" +) + +var ( + // ErrorNoTokenParsers is returned by TokenParsers.Parse to indicate an empty array. + // This distinguishes the absence of a token from a source from the absence of a token + // because of configuration, possibly intentionally. + ErrorNoTokenParsers = errors.New("no token parsers") + + // ErrorMissingCredentials is returned by TokenParser.Parse to indicate that a source + // object did not have any credentials recognized by that parser. + ErrorMissingCredentials = errors.New("missing credentials") + + // ErrorInvalidCredentials is returned by TokenParser.Parse to indicate that a source + // did contain recognizable credentials, but those credentials could not be parsed, + // possibly due to bad formatting. + ErrorInvalidCredentials = errors.New("invalid credentials") ) // Token is a runtime representation of credentials. This interface will be further @@ -15,42 +33,95 @@ type Token interface { Principal() string } -// TokenParser produces tokens from credentials. The original source S of the credentials +// TokenParser produces tokens from a source. The original source S of the credentials // are made available to the parser. type TokenParser[S any] interface { - // Parse extracts a Token from a set of credentials. - Parse(ctx context.Context, source S, c Credentials) (Token, error) + // Parse extracts a Token from a source object, e.g. an HTTP request. + // + // If a particular source instance doesn't have the credentials expected by this + // parser, this method must return an error with MissingCredentials in the returned + // error's chain. + // + // If a source has credentials that failed to parse, this method must return an error + // with InvalidCredentials in its error chain. + // + // If this method returns a nil Token, it must return a non-nil error. Returning an + // error with a non-nil Token is allowed but not required. + Parse(ctx context.Context, source S) (Token, error) } -// TokenParserFunc is a closure type that implements TokenParser. -type TokenParserFunc[S any] func(context.Context, S, Credentials) (Token, error) +// TokenParserFunc describes the closure signatures that are allowed as TokenParser instances. +type TokenParserFunc[S any] interface { + ~func(source S) (Token, error) | + ~func(ctx context.Context, source S) (Token, error) +} + +// tokenParserFunc is the internal closure type that can be used to adapt +// a TokenParserFunc onto a TokenParser instance. +type tokenParserFunc[S any] func(context.Context, S) (Token, error) -func (tpf TokenParserFunc[S]) Parse(ctx context.Context, source S, c Credentials) (Token, error) { - return tpf(ctx, source, c) +func (tpf tokenParserFunc[S]) Parse(ctx context.Context, source S) (Token, error) { + return tpf(ctx, source) } -// TokenParsers is a registry of parsers based on credential schemes. -// The zero value of this type is valid and ready to use. -type TokenParsers[S any] map[Scheme]TokenParser[S] +// AsTokenParser accepts a closure and turns it into a TokenParser instance. +// Custom types that are convertible to a TokenParserFunc are also supported. +func AsTokenParser[S any, F TokenParserFunc[S]](f F) TokenParser[S] { + // first, try the simple cases + switch ft := any(f).(type) { + case func(S) (Token, error): + return tokenParserFunc[S](func(_ context.Context, source S) (Token, error) { + return ft(source) // curry away the context + }) + + case func(context.Context, S) (Token, error): + return tokenParserFunc[S](ft) + } -// Register adds or replaces the parser associated with the given scheme. -func (tp *TokenParsers[S]) Register(scheme Scheme, p TokenParser[S]) { - if *tp == nil { - *tp = make(TokenParsers[S]) + // now handle user-defined types. we have to look these up here, instead + // of "caching" them, because of the way generics in golang work. + fVal := reflect.ValueOf(f) + if ft := reflect.TypeOf((func(S) (Token, error))(nil)); fVal.CanConvert(ft) { + sourceOnly := fVal.Convert(ft).Interface().(func(S) (Token, error)) + return tokenParserFunc[S](func(_ context.Context, source S) (Token, error) { + return sourceOnly(source) // curry away the context + }) + } else { + ft := reflect.TypeOf((func(context.Context, S) (Token, error))(nil)) + return tokenParserFunc[S]( + fVal.Convert(ft).Interface().(func(context.Context, S) (Token, error)), + ) } +} + +// TokenParsers is an aggregate, ordered list of TokenParser implementations for +// a given type of source. +type TokenParsers[S any] []TokenParser[S] - (*tp)[scheme] = p +// Append adds one or more parsers to this aggregate TokenParsers. The semantics +// of this method are the same as the built-in append. +func (tps TokenParsers[S]) Append(more ...TokenParser[S]) TokenParsers[S] { + return append(tps, more...) } -// Parse chooses a TokenParser based on the Scheme and invokes that -// parser. If the credential scheme is unsupported, an error is returned. -func (tp TokenParsers[S]) Parse(ctx context.Context, source S, c Credentials) (t Token, err error) { - if p, ok := tp[c.Scheme]; ok { - t, err = p.Parse(ctx, source, c) - } else { - err = &UnsupportedSchemeError{ - Scheme: c.Scheme, - } +// Parse executes each TokenParser in turn. +// +// If this TokenParsers is empty, this method returns ErrorNoTokenParsers. +// +// If a parser returns MissingCredentials, it is skipped. If all parsers return +// MissingCredentials, the last error is returned. +// +// If a parser returns any other error, parsing is halted early and that error is returned. +// +// Otherwise, the token returned from the first successful parse is returned by +// this aggregate method. +func (tps TokenParsers[S]) Parse(ctx context.Context, source S) (t Token, err error) { + if len(tps) == 0 { + err = ErrorNoTokenParsers + } + + for i := 0; i < len(tps) && t == nil && (err == nil || errors.Is(err, ErrorMissingCredentials)); i++ { + t, err = tps[i].Parse(ctx, source) } return diff --git a/token_test.go b/token_test.go index 052475e..951bfd1 100644 --- a/token_test.go +++ b/token_test.go @@ -6,79 +6,185 @@ package bascule import ( "context" "errors" + "fmt" "testing" "github.com/stretchr/testify/suite" ) -type TokenParsersSuite struct { +type TokenParserSuite struct { TestSuite + + expectedCtx context.Context + expectedSource int + expectedToken Token + expectedErr error +} + +func (suite *TokenParserSuite) SetupSuite() { + suite.expectedCtx = suite.testContext() + suite.expectedSource = 123 + suite.expectedToken = testToken("expected token") + suite.expectedErr = errors.New("expected token parser error") +} + +func (suite *TokenParserSuite) assertParserResult(actualToken Token, actualErr error) { + suite.Equal(suite.expectedToken, actualToken) + suite.Equal(suite.expectedErr, actualErr) +} + +func (suite *TokenParserSuite) validateSource(actualSource int) (Token, error) { + suite.Equal(suite.expectedSource, actualSource) + return suite.expectedToken, suite.expectedErr +} + +func (suite *TokenParserSuite) validateContextSource(actualCtx context.Context, actualSource int) (Token, error) { + suite.Equal(suite.expectedCtx, actualCtx) + suite.Equal(suite.expectedSource, actualSource) + return suite.expectedToken, suite.expectedErr +} + +func (suite *TokenParserSuite) testAsTokenParserSource() { + suite.Run("Simple", func() { + suite.assertParserResult( + AsTokenParser[int](suite.validateSource). + Parse(suite.expectedCtx, suite.expectedSource), + ) + }) + + suite.Run("CustomType", func() { + type Custom func(int) (Token, error) + var cf Custom = Custom(suite.validateSource) + + suite.assertParserResult( + AsTokenParser[int](cf). + Parse(suite.expectedCtx, suite.expectedSource), + ) + }) } -func (suite *TokenParsersSuite) assertUnsupportedScheme(scheme Scheme, err error) { - var use *UnsupportedSchemeError - if suite.ErrorAs(err, &use) { - suite.Equal(scheme, use.Scheme) +func (suite *TokenParserSuite) testAsTokenParserContextSource() { + suite.Run("Simple", func() { + suite.assertParserResult( + AsTokenParser[int](suite.validateContextSource). + Parse(suite.expectedCtx, suite.expectedSource), + ) + }) + + suite.Run("CustomType", func() { + type Custom func(context.Context, int) (Token, error) + var cf Custom = Custom(suite.validateContextSource) + + suite.assertParserResult( + AsTokenParser[int](cf). + Parse(suite.expectedCtx, suite.expectedSource), + ) + }) +} + +func (suite *TokenParserSuite) TestAsTokenParser() { + suite.Run("Source", suite.testAsTokenParserSource) + suite.Run("ContextSource", suite.testAsTokenParserContextSource) +} + +// appendMissing appends a count of mocked TokenParser objects that return +// (nil, ErrorMissingCredentials) and expect this suite's expected input. +func (suite *TokenParserSuite) appendMissing(tps TokenParsers[int], count int) TokenParsers[int] { + for repeat := 0; repeat < count; repeat++ { + m := new(mockTokenParser[int]) + m.ExpectParse(suite.expectedCtx, suite.expectedSource). + Return(nil, ErrorMissingCredentials).Once() + tps = tps.Append(m) } + + return tps } -func (suite *TokenParsersSuite) testParseEmpty() { - var tp TokenParsers[string] +// appendSuccess appends a single mocked TokenParser that returns success using this +// suite's expected inputs and outputs. +func (suite *TokenParserSuite) appendSuccess(tps TokenParsers[int]) TokenParsers[int] { + m := new(mockTokenParser[int]) + m.ExpectParse(suite.expectedCtx, suite.expectedSource). + Return(suite.expectedToken, nil).Once() - // legal, but will always fail - token, err := tp.Parse(context.Background(), "doesnotmatter", suite.testCredentials()) - suite.Nil(token) - suite.assertUnsupportedScheme(testScheme, err) + return tps.Append(m) } -func (suite *TokenParsersSuite) testParseUnsupported() { - var tp TokenParsers[string] - tp.Register( - Scheme("Supported"), - TokenParserFunc[string]( - func(context.Context, string, Credentials) (Token, error) { - suite.Fail("TokenParser should not have been called") - return nil, nil - }, - ), - ) +// appendFail appends a single mocked TokenParser that returns a nil token and a failing +// error, using this suite's expected inputs and outputs. +func (suite *TokenParserSuite) appendFail(tps TokenParsers[int]) TokenParsers[int] { + m := new(mockTokenParser[int]) + m.ExpectParse(suite.expectedCtx, suite.expectedSource). + Return(nil, suite.expectedErr).Once() - token, err := tp.Parse(context.Background(), "doesnotmatter", suite.testCredentials()) - suite.Nil(token) - suite.assertUnsupportedScheme(testScheme, err) + return tps.Append(m) } -func (suite *TokenParsersSuite) testParseSupported() { - var ( - expectedErr = errors.New("expected Parse error") +// appendNoCall appends a count of mocked TokenParser objects that expect no calls to +// be made. Useful to verify that a TokenParsers instance stops parsing upon +// a successful parse or a non-missing error. +func (suite *TokenParserSuite) appendNoCall(tps TokenParsers[int], count int) TokenParsers[int] { + for repeat := 0; repeat < count; repeat++ { + m := new(mockTokenParser[int]) + tps = tps.Append(m) + } - testCtx = suite.testContext() - testCredentials = suite.testCredentials() - ) + return tps +} - var tp TokenParsers[string] - tp.Register( - testCredentials.Scheme, - TokenParserFunc[string]( - func(ctx context.Context, _ string, c Credentials) (Token, error) { - suite.Equal(testCtx, ctx) - suite.Equal(testCredentials, c) - return suite.testToken(), expectedErr - }, - ), - ) +// assertTokenParsersSuccess calls Parse and asserts that this suite's expected input occurred and +// that the ultimate token was this suite's expected token with a nil error. +func (suite *TokenParserSuite) assertTokenParsersSuccess(tps TokenParsers[int]) { + actualToken, actualErr := tps.Parse(suite.expectedCtx, suite.expectedSource) + suite.Equal(suite.expectedToken, actualToken) + suite.NoError(actualErr) + assertTokenParsers(suite.T(), tps...) +} - token, err := tp.Parse(testCtx, "doesnotmatter", testCredentials) - suite.Equal(suite.testToken(), token) - suite.Same(expectedErr, err) +// assertTokenParsersFail calls Parse and asserts that this suite's expected input occurred and +// that a failure occurred with this suite's expected error. +func (suite *TokenParserSuite) assertTokenParsersFail(tps TokenParsers[int]) { + actualToken, actualErr := tps.Parse(suite.expectedCtx, suite.expectedSource) + suite.Nil(actualToken) + suite.ErrorIs(actualErr, suite.expectedErr) + assertTokenParsers(suite.T(), tps...) +} + +func (suite *TokenParserSuite) testTokenParsersEmpty() { + var tps TokenParsers[int] + t, err := tps.Parse(suite.expectedCtx, suite.expectedSource) + suite.Nil(t) + suite.ErrorIs(err, ErrorNoTokenParsers) +} + +func (suite *TokenParserSuite) testTokenParsersSuccess() { + for _, count := range []int{1, 2, 5, 8} { + suite.Run(fmt.Sprintf("count=%d", count), func() { + tps := suite.appendMissing(nil, count/2) // half the parsers report missing, i.e. unsupported + tps = suite.appendSuccess(tps) + tps = suite.appendNoCall(tps, count-len(tps)) + suite.assertTokenParsersSuccess(tps) + }) + } +} + +func (suite *TokenParserSuite) testTokenParsersFail() { + for _, count := range []int{1, 2, 5, 8} { + suite.Run(fmt.Sprintf("count=%d", count), func() { + tps := suite.appendMissing(nil, count/2) // half the parsers report missing, i.e. unsupported + tps = suite.appendFail(tps) + tps = suite.appendNoCall(tps, count-len(tps)) + suite.assertTokenParsersFail(tps) + }) + } } -func (suite *TokenParsersSuite) TestParse() { - suite.Run("Empty", suite.testParseEmpty) - suite.Run("Unsupported", suite.testParseUnsupported) - suite.Run("Supported", suite.testParseSupported) +func (suite *TokenParserSuite) TestTokenParsers() { + suite.Run("Empty", suite.testTokenParsersEmpty) + suite.Run("Success", suite.testTokenParsersSuccess) + suite.Run("Fail", suite.testTokenParsersFail) } -func TestTokenParsers(t *testing.T) { - suite.Run(t, new(TokenParsersSuite)) +func TestTokenParser(t *testing.T) { + suite.Run(t, new(TokenParserSuite)) } diff --git a/validator.go b/validator.go index c0bea98..b72f3af 100644 --- a/validator.go +++ b/validator.go @@ -94,10 +94,10 @@ func (vf validatorFunc[S]) Validate(ctx context.Context, source S, t Token) (nex } var ( - tokenReturnError = reflect.TypeOf((func(Token) error)(nil)) - tokenReturnTokenAndError = reflect.TypeOf((func(Token) (Token, error))(nil)) - contextTokenReturnError = reflect.TypeOf((func(context.Context, Token) error)(nil)) - contextTokenReturnTokenError = reflect.TypeOf((func(context.Context, Token) (Token, error))(nil)) + validatorTokenReturnError = reflect.TypeOf((func(Token) error)(nil)) + validatorTokenReturnTokenAndError = reflect.TypeOf((func(Token) (Token, error))(nil)) + validatorContextTokenReturnError = reflect.TypeOf((func(context.Context, Token) error)(nil)) + validatorContextTokenReturnTokenError = reflect.TypeOf((func(context.Context, Token) (Token, error))(nil)) ) // asValidatorSimple tries simple conversions on f. This function will not catch @@ -188,24 +188,24 @@ func AsValidator[S any, F ValidatorFunc[S]](f F) Validator[S] { // require the source type. fVal := reflect.ValueOf(f) switch { - case fVal.CanConvert(tokenReturnError): + case fVal.CanConvert(validatorTokenReturnError): return asValidatorSimple[S]( - fVal.Convert(tokenReturnError).Interface().(func(Token) error), + fVal.Convert(validatorTokenReturnError).Interface().(func(Token) error), ) - case fVal.CanConvert(tokenReturnTokenAndError): + case fVal.CanConvert(validatorTokenReturnTokenAndError): return asValidatorSimple[S]( - fVal.Convert(tokenReturnTokenAndError).Interface().(func(Token) (Token, error)), + fVal.Convert(validatorTokenReturnTokenAndError).Interface().(func(Token) (Token, error)), ) - case fVal.CanConvert(contextTokenReturnError): + case fVal.CanConvert(validatorContextTokenReturnError): return asValidatorSimple[S]( - fVal.Convert(contextTokenReturnError).Interface().(func(context.Context, Token) error), + fVal.Convert(validatorContextTokenReturnError).Interface().(func(context.Context, Token) error), ) - case fVal.CanConvert(contextTokenReturnTokenError): + case fVal.CanConvert(validatorContextTokenReturnTokenError): return asValidatorSimple[S]( - fVal.Convert(contextTokenReturnTokenError).Interface().(func(context.Context, Token) (Token, error)), + fVal.Convert(validatorContextTokenReturnTokenError).Interface().(func(context.Context, Token) (Token, error)), ) } diff --git a/validator_test.go b/validator_test.go index af45572..1c8b523 100644 --- a/validator_test.go +++ b/validator_test.go @@ -23,13 +23,7 @@ type ValidatorsTestSuite struct { } func (suite *ValidatorsTestSuite) SetupSuite() { - type contextKey struct{} - suite.expectedCtx = context.WithValue( - context.Background(), - contextKey{}, - "value", - ) - + suite.expectedCtx = suite.testContext() suite.expectedSource = 123 suite.inputToken = testToken("input token") suite.outputToken = testToken("output token") From 2d75fee152506d97572f7c0e2baefa0f769d7c8d Mon Sep 17 00:00:00 2001 From: johnabass Date: Thu, 1 Aug 2024 09:17:50 -0700 Subject: [PATCH 2/6] chore: renamed to standard golang conventions --- token.go | 18 +++++++++--------- token_test.go | 4 ++-- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/token.go b/token.go index 8290718..f0b388a 100644 --- a/token.go +++ b/token.go @@ -10,19 +10,19 @@ import ( ) var ( - // ErrorNoTokenParsers is returned by TokenParsers.Parse to indicate an empty array. + // ErrNoTokenParsers is returned by TokenParsers.Parse to indicate an empty array. // This distinguishes the absence of a token from a source from the absence of a token // because of configuration, possibly intentionally. - ErrorNoTokenParsers = errors.New("no token parsers") + ErrNoTokenParsers = errors.New("no token parsers") - // ErrorMissingCredentials is returned by TokenParser.Parse to indicate that a source + // ErrMissingCredentials is returned by TokenParser.Parse to indicate that a source // object did not have any credentials recognized by that parser. - ErrorMissingCredentials = errors.New("missing credentials") + ErrMissingCredentials = errors.New("missing credentials") - // ErrorInvalidCredentials is returned by TokenParser.Parse to indicate that a source + // ErrInvalidCredentials is returned by TokenParser.Parse to indicate that a source // did contain recognizable credentials, but those credentials could not be parsed, // possibly due to bad formatting. - ErrorInvalidCredentials = errors.New("invalid credentials") + ErrInvalidCredentials = errors.New("invalid credentials") ) // Token is a runtime representation of credentials. This interface will be further @@ -106,7 +106,7 @@ func (tps TokenParsers[S]) Append(more ...TokenParser[S]) TokenParsers[S] { // Parse executes each TokenParser in turn. // -// If this TokenParsers is empty, this method returns ErrorNoTokenParsers. +// If this TokenParsers is empty, this method returns ErrNoTokenParsers. // // If a parser returns MissingCredentials, it is skipped. If all parsers return // MissingCredentials, the last error is returned. @@ -117,10 +117,10 @@ func (tps TokenParsers[S]) Append(more ...TokenParser[S]) TokenParsers[S] { // this aggregate method. func (tps TokenParsers[S]) Parse(ctx context.Context, source S) (t Token, err error) { if len(tps) == 0 { - err = ErrorNoTokenParsers + err = ErrNoTokenParsers } - for i := 0; i < len(tps) && t == nil && (err == nil || errors.Is(err, ErrorMissingCredentials)); i++ { + for i := 0; i < len(tps) && t == nil && (err == nil || errors.Is(err, ErrMissingCredentials)); i++ { t, err = tps[i].Parse(ctx, source) } diff --git a/token_test.go b/token_test.go index 951bfd1..4606708 100644 --- a/token_test.go +++ b/token_test.go @@ -93,7 +93,7 @@ func (suite *TokenParserSuite) appendMissing(tps TokenParsers[int], count int) T for repeat := 0; repeat < count; repeat++ { m := new(mockTokenParser[int]) m.ExpectParse(suite.expectedCtx, suite.expectedSource). - Return(nil, ErrorMissingCredentials).Once() + Return(nil, ErrMissingCredentials).Once() tps = tps.Append(m) } @@ -154,7 +154,7 @@ func (suite *TokenParserSuite) testTokenParsersEmpty() { var tps TokenParsers[int] t, err := tps.Parse(suite.expectedCtx, suite.expectedSource) suite.Nil(t) - suite.ErrorIs(err, ErrorNoTokenParsers) + suite.ErrorIs(err, ErrNoTokenParsers) } func (suite *TokenParserSuite) testTokenParsersSuccess() { From cc8747beef0b542205c656412ba9226507d553ec Mon Sep 17 00:00:00 2001 From: johnabass Date: Thu, 1 Aug 2024 11:26:51 -0700 Subject: [PATCH 3/6] chore: renamed context API functions to something more standard --- basculehttp/basic.go | 60 +++++++----------- basculehttp/challenge.go | 16 ++--- basculehttp/credentials.go | 108 -------------------------------- basculehttp/credentials_test.go | 108 -------------------------------- basculehttp/middleware.go | 34 ++-------- context.go | 10 +-- context_test.go | 40 ++++++------ 7 files changed, 55 insertions(+), 321 deletions(-) delete mode 100644 basculehttp/credentials.go delete mode 100644 basculehttp/credentials_test.go diff --git a/basculehttp/basic.go b/basculehttp/basic.go index 43e69b2..315a056 100644 --- a/basculehttp/basic.go +++ b/basculehttp/basic.go @@ -5,55 +5,39 @@ package basculehttp import ( "context" - "encoding/base64" "net/http" - "strings" "github.com/xmidt-org/bascule/v1" ) -// InvalidBasicAuthError indicates that the Basic credentials were improperly -// encoded, either due to base64 issues or formatting. -type InvalidBasicAuthError struct { - // Cause represents the lower level error that occurred, e.g. a base64 - // encoding error. - Cause error +// BasicToken is a bascule.Token that results from Basic authorization. +type BasicToken struct { + UserName string + Password string } -func (err *InvalidBasicAuthError) Unwrap() error { return err.Cause } - -func (err *InvalidBasicAuthError) Error() string { - var o strings.Builder - o.WriteString("Basic auth string not encoded properly") - - if err.Cause != nil { - o.WriteString(": ") - o.WriteString(err.Cause.Error()) - } - - return o.String() +// Principal returns the user name from Basic auth. +func (bt BasicToken) Principal() string { + return bt.UserName } -type basicTokenParser struct{} - -func (btp basicTokenParser) Parse(_ context.Context, _ *http.Request, c bascule.Credentials) (t bascule.Token, err error) { - var decoded []byte - decoded, err = base64.StdEncoding.DecodeString(c.Value) - if err != nil { - err = &InvalidBasicAuthError{ - Cause: err, - } - - return - } - - username, _, found := strings.Cut(string(decoded), ":") - if found { - t = &Token{ - principal: username, +// BasicTokenParser is a bascule.TokenParser expects Basic auth to +// be present. +type BasicTokenParser struct{} + +// Parse extracts the Basic auth credentials from the source request. +// The net/http package is used to do this parsing. +// +// If no Basic auth credentials could be found, this method returns +// bascule.MissingCredentials. +func (btp BasicTokenParser) Parse(_ context.Context, source *http.Request) (t bascule.Token, err error) { + if userName, password, ok := source.BasicAuth(); ok { + t = BasicToken{ + UserName: userName, + Password: password, } } else { - err = &InvalidBasicAuthError{} + err = bascule.MissingCredentials } return diff --git a/basculehttp/challenge.go b/basculehttp/challenge.go index b81e033..a9c0811 100644 --- a/basculehttp/challenge.go +++ b/basculehttp/challenge.go @@ -6,17 +6,9 @@ package basculehttp import ( "net/http" "strings" - - "github.com/xmidt-org/bascule/v1" ) const ( - // BasicScheme is the name of the basic HTTP authentication scheme. - BasicScheme bascule.Scheme = "Basic" - - // BearerScheme is the name of the bearer HTTP authentication scheme. - BearerScheme bascule.Scheme = "Bearer" - // WwwAuthenticateHeaderName is the HTTP header used for StatusUnauthorized challenges. WwwAuthenticateHeaderName = "WWW-Authenticate" @@ -63,7 +55,7 @@ func (chs Challenges) WriteHeader(h http.Header) { type BasicChallenge struct { // Scheme is the name of scheme supplied in the challenge. If this // field is unset, BasicScheme is used. - Scheme bascule.Scheme + Scheme Scheme // Realm is the name of the realm for the challenge. If this field // is unset, DefaultBasicRealm is used. @@ -81,7 +73,7 @@ func (bc BasicChallenge) FormatAuthenticate(o strings.Builder) { if len(bc.Scheme) > 0 { o.WriteString(string(bc.Scheme)) } else { - o.WriteString(string(BasicScheme)) + o.WriteString(string(SchemeBasic)) } o.WriteString(` realm="`) @@ -100,7 +92,7 @@ func (bc BasicChallenge) FormatAuthenticate(o strings.Builder) { type BearerChallenge struct { // Scheme is the name of scheme supplied in the challenge. If this // field is unset, BearerScheme is used. - Scheme bascule.Scheme + Scheme Scheme // Realm is the name of the realm for the challenge. If this field // is unset, DefaultBearerRealm is used. @@ -114,7 +106,7 @@ func (bc BearerChallenge) FormatAuthenticate(o strings.Builder) { if len(bc.Scheme) > 0 { o.WriteString(string(bc.Scheme)) } else { - o.WriteString(string(BasicScheme)) + o.WriteString(string(SchemeBearer)) } o.WriteString(` realm="`) diff --git a/basculehttp/credentials.go b/basculehttp/credentials.go deleted file mode 100644 index 025119e..0000000 --- a/basculehttp/credentials.go +++ /dev/null @@ -1,108 +0,0 @@ -// SPDX-FileCopyrightText: 2020 Comcast Cable Communications Management, LLC -// SPDX-License-Identifier: Apache-2.0 - -package basculehttp - -import ( - "context" - "net/http" - "strings" - - "github.com/xmidt-org/bascule/v1" -) - -const ( - // DefaultAuthorizationHeader is the name of the header used by default to obtain - // the raw credentials. - DefaultAuthorizationHeader = "Authorization" -) - -// DuplicateHeaderError indicates that an HTTP header had more than one value -// when only one value was expected. -type DuplicateHeaderError struct { - // Header is the name of the duplicate header. - Header string -} - -func (err *DuplicateHeaderError) Error() string { - var o strings.Builder - o.WriteString(`Duplicate header: "`) - o.WriteString(err.Header) - o.WriteString(`"`) - return o.String() -} - -// MissingHeaderError indicates that an expected HTTP header is missing. -type MissingHeaderError struct { - // Header is the name of the missing header. - Header string -} - -func (err *MissingHeaderError) Error() string { - var o strings.Builder - o.WriteString(`Missing header: "`) - o.WriteString(err.Header) - o.WriteString(`"`) - return o.String() -} - -// StatusCode returns http.StatusUnauthorized, as the request carries -// no authorization in it. -func (err *MissingHeaderError) StatusCode() int { - return http.StatusUnauthorized -} - -// DefaultCredentialsParser is the default algorithm used to produce HTTP credentials -// from a source request. -type DefaultCredentialsParser struct { - // Header is the name of the authorization header. If unset, - // DefaultAuthorizationHeader is used. - Header string - - // ErrorOnDuplicate controls whether an error is returned if more - // than one Header is found in the request. By default, this is false. - ErrorOnDuplicate bool -} - -func (dcp DefaultCredentialsParser) Parse(_ context.Context, source *http.Request) (c bascule.Credentials, err error) { - header := dcp.Header - if len(header) == 0 { - header = DefaultAuthorizationHeader - } - - var raw string - values := source.Header.Values(header) - switch { - case len(values) == 0: - err = &MissingHeaderError{ - Header: header, - } - - case len(values) == 1 || !dcp.ErrorOnDuplicate: - raw = values[0] - - default: - err = &DuplicateHeaderError{ - Header: header, - } - } - - if err == nil { - // format is - // the code is strict: it requires no leading or trailing space - // and exactly one (1) space as a separator. - scheme, credValue, found := strings.Cut(raw, " ") - if found && len(scheme) > 0 && !fastIsSpace(credValue[0]) && !fastIsSpace(credValue[len(credValue)-1]) { - c = bascule.Credentials{ - Scheme: bascule.Scheme(scheme), - Value: credValue, - } - } else { - err = &bascule.BadCredentialsError{ - Raw: raw, - } - } - } - - return -} diff --git a/basculehttp/credentials_test.go b/basculehttp/credentials_test.go deleted file mode 100644 index 89941e0..0000000 --- a/basculehttp/credentials_test.go +++ /dev/null @@ -1,108 +0,0 @@ -// SPDX-FileCopyrightText: 2020 Comcast Cable Communications Management, LLC -// SPDX-License-Identifier: Apache-2.0 - -package basculehttp - -import ( - "context" - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/suite" - "github.com/xmidt-org/bascule/v1" -) - -type CredentialsTestSuite struct { - suite.Suite -} - -func (suite *CredentialsTestSuite) newDefaultSource(value string) *http.Request { - r := httptest.NewRequest("GET", "/", nil) - r.Header.Set(DefaultAuthorizationHeader, value) - return r -} - -func (suite *CredentialsTestSuite) testDefaultCredentialsParserSuccess() { - const ( - expectedScheme bascule.Scheme = "Test" - expectedValue string = "credentialValue" - ) - - testCases := []string{ - "Test credentialValue", - } - - for _, testCase := range testCases { - suite.Run(testCase, func() { - dcp := DefaultCredentialsParser{} - suite.Require().NotNil(dcp) - - creds, err := dcp.Parse(context.Background(), suite.newDefaultSource(testCase)) - suite.Require().NoError(err) - suite.Equal( - bascule.Credentials{ - Scheme: expectedScheme, - Value: expectedValue, - }, - creds, - ) - }) - } -} - -func (suite *CredentialsTestSuite) testDefaultCredentialsParserFailure() { - testCases := []string{ - "", - " ", - "thisisnotvalid", - "Test\tcredentialValue", - " Test credentialValue", - "Test credentialValue ", - "Test credentialValue", - } - - for _, testCase := range testCases { - suite.Run(testCase, func() { - dcp := DefaultCredentialsParser{} - suite.Require().NotNil(dcp) - - creds, err := dcp.Parse(context.Background(), suite.newDefaultSource(testCase)) - suite.Require().Error(err) - suite.Equal(bascule.Credentials{}, creds) - - var ice *bascule.BadCredentialsError - if suite.ErrorAs(err, &ice) { - suite.Equal(testCase, ice.Raw) - } - }) - } -} - -func (suite *CredentialsTestSuite) testDefaultCredentialsParserMissingHeader() { - dcp := DefaultCredentialsParser{} - suite.Require().NotNil(dcp) - - r := httptest.NewRequest("GET", "/", nil) - creds, err := dcp.Parse(context.Background(), r) - suite.Require().Error(err) - suite.Equal(bascule.Credentials{}, creds) - - type statusCoder interface { - StatusCode() int - } - - var sc statusCoder - suite.Require().ErrorAs(err, &sc) - suite.Equal(http.StatusUnauthorized, sc.StatusCode()) -} - -func (suite *CredentialsTestSuite) TestDefaultCredentialsParser() { - suite.Run("Success", suite.testDefaultCredentialsParserSuccess) - suite.Run("Failure", suite.testDefaultCredentialsParserFailure) - suite.Run("MissingHeader", suite.testDefaultCredentialsParserMissingHeader) -} - -func TestCredentials(t *testing.T) { - suite.Run(t, new(CredentialsTestSuite)) -} diff --git a/basculehttp/middleware.go b/basculehttp/middleware.go index 9868f6e..eaf2e38 100644 --- a/basculehttp/middleware.go +++ b/basculehttp/middleware.go @@ -23,31 +23,6 @@ func (mof middlewareOptionFunc) apply(m *Middleware) error { return mof(m) } -// WithCredentialsParser configures a credentials parser for this Middleware. If not supplied -// or if the supplied CredentialsParser is nil, DefaultCredentialsParser() is used. -func WithCredentialsParser(cp bascule.CredentialsParser[*http.Request]) MiddlewareOption { - return middlewareOptionFunc(func(m *Middleware) error { - if cp != nil { - m.credentialsParser = cp - } else { - m.credentialsParser = DefaultCredentialsParser{} - } - - return nil - }) -} - -// WithTokenParser registers a token parser for the given scheme. If the scheme has -// already been registered, the given parser will replace that registration. -// -// The parser cannot be nil. -func WithTokenParser(scheme bascule.Scheme, tp bascule.TokenParser[*http.Request]) MiddlewareOption { - return middlewareOptionFunc(func(m *Middleware) error { - m.tokenParsers.Register(scheme, tp) - return nil - }) -} - // WithAuthentication adds validators used for authentication to this Middleware. Each // invocation of this option is cumulative. Authentication validators are run in the order // supplied by this option. @@ -112,11 +87,10 @@ func WithErrorMarshaler(em ErrorMarshaler) MiddlewareOption { // Middleware is an immutable configuration that can decorate multiple handlers. type Middleware struct { - credentialsParser bascule.CredentialsParser[*http.Request] - tokenParsers bascule.TokenParsers[*http.Request] - authentication bascule.Validators[*http.Request] - authorization bascule.Authorizers[*http.Request] - challenges Challenges + tokenParsers bascule.TokenParsers[*http.Request] + authentication bascule.Validators[*http.Request] + authorization bascule.Authorizers[*http.Request] + challenges Challenges errorStatusCoder ErrorStatusCoder errorMarshaler ErrorMarshaler diff --git a/context.go b/context.go index bf9804d..c90104b 100644 --- a/context.go +++ b/context.go @@ -15,17 +15,17 @@ type Contexter interface { type tokenContextKey struct{} -// GetToken retrieves a Token from a context. If not token is in the context, +// Get retrieves a Token from a context. If not token is in the context, // this function returns false. -func GetToken(ctx context.Context) (t Token, found bool) { +func Get(ctx context.Context) (t Token, found bool) { t, found = ctx.Value(tokenContextKey{}).(Token) return } -// GetTokenFrom uses the context held by src to obtain a Token. As with GetToken, +// GetFrom uses the context held by src to obtain a Token. As with GetToken, // if no token is found this function returns false. -func GetTokenFrom(src Contexter) (Token, bool) { - return GetToken(src.Context()) +func GetFrom(src Contexter) (Token, bool) { + return Get(src.Context()) } // WithToken constructs a new context with the supplied token. diff --git a/context_test.go b/context_test.go index fb41ec4..4680e38 100644 --- a/context_test.go +++ b/context_test.go @@ -14,14 +14,14 @@ type ContextTestSuite struct { TestSuite } -func (suite *ContextTestSuite) testGetTokenSuccess() { +func (suite *ContextTestSuite) testGetSuccess() { ctx := context.WithValue( context.Background(), tokenContextKey{}, suite.testToken(), ) - token, ok := GetToken(ctx) + token, ok := Get(ctx) suite.Require().True(ok) suite.Equal( suite.testToken(), @@ -29,26 +29,26 @@ func (suite *ContextTestSuite) testGetTokenSuccess() { ) } -func (suite *ContextTestSuite) testGetTokenMissing() { - token, ok := GetToken(context.Background()) +func (suite *ContextTestSuite) testGetMissing() { + token, ok := Get(context.Background()) suite.Nil(token) suite.False(ok) } -func (suite *ContextTestSuite) testGetTokenWrongType() { +func (suite *ContextTestSuite) testGetWrongType() { ctx := context.WithValue(context.Background(), tokenContextKey{}, 123) - token, ok := GetToken(ctx) + token, ok := Get(ctx) suite.Nil(token) suite.False(ok) } -func (suite *ContextTestSuite) TestGetToken() { - suite.Run("Success", suite.testGetTokenSuccess) - suite.Run("Missing", suite.testGetTokenMissing) - suite.Run("WrongType", suite.testGetTokenWrongType) +func (suite *ContextTestSuite) TestGet() { + suite.Run("Success", suite.testGetSuccess) + suite.Run("Missing", suite.testGetMissing) + suite.Run("WrongType", suite.testGetWrongType) } -func (suite *ContextTestSuite) testGetTokenFromSuccess() { +func (suite *ContextTestSuite) testGetFromSuccess() { c := suite.contexter( context.WithValue( context.Background(), @@ -57,7 +57,7 @@ func (suite *ContextTestSuite) testGetTokenFromSuccess() { ), ) - token, ok := GetTokenFrom(c) + token, ok := GetFrom(c) suite.Require().True(ok) suite.Equal( suite.testToken(), @@ -65,8 +65,8 @@ func (suite *ContextTestSuite) testGetTokenFromSuccess() { ) } -func (suite *ContextTestSuite) testGetTokenFromMissing() { - token, ok := GetTokenFrom( +func (suite *ContextTestSuite) testGetFromMissing() { + token, ok := GetFrom( suite.contexter(context.Background()), ) @@ -74,20 +74,20 @@ func (suite *ContextTestSuite) testGetTokenFromMissing() { suite.False(ok) } -func (suite *ContextTestSuite) testGetTokenFromWrongType() { +func (suite *ContextTestSuite) testGetFromWrongType() { c := suite.contexter( context.WithValue(context.Background(), tokenContextKey{}, 123), ) - token, ok := GetTokenFrom(c) + token, ok := GetFrom(c) suite.Nil(token) suite.False(ok) } -func (suite *ContextTestSuite) TestGetTokenFrom() { - suite.Run("Success", suite.testGetTokenFromSuccess) - suite.Run("Missing", suite.testGetTokenFromMissing) - suite.Run("WrongType", suite.testGetTokenFromWrongType) +func (suite *ContextTestSuite) TestGetFrom() { + suite.Run("Success", suite.testGetFromSuccess) + suite.Run("Missing", suite.testGetFromMissing) + suite.Run("WrongType", suite.testGetFromWrongType) } func (suite *ContextTestSuite) TestWithToken() { From fa873822ed214e675a534449f988e7df2907fc49 Mon Sep 17 00:00:00 2001 From: johnabass Date: Mon, 5 Aug 2024 15:44:10 -0700 Subject: [PATCH 4/6] simplified token parsing and error handling --- basculehttp/authorization.go | 105 ++++++++++++++++++++++++ basculehttp/basic.go | 71 ++++++++++------ basculehttp/error.go | 52 ++++++------ basculehttp/middleware.go | 50 ++++++----- basculehttp/middleware_examples_test.go | 2 +- basculehttp/scheme.go | 15 ++++ basculehttp/token.go | 28 ------- 7 files changed, 221 insertions(+), 102 deletions(-) create mode 100644 basculehttp/authorization.go create mode 100644 basculehttp/scheme.go delete mode 100644 basculehttp/token.go diff --git a/basculehttp/authorization.go b/basculehttp/authorization.go new file mode 100644 index 0000000..46a38a9 --- /dev/null +++ b/basculehttp/authorization.go @@ -0,0 +1,105 @@ +// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package basculehttp + +import ( + "context" + "errors" + "net/http" + "strings" + + "github.com/xmidt-org/bascule/v1" +) + +const ( + // DefaultAuthorizationHeader is the default HTTP header used for authorization + // tokens in an HTTP request. + DefaultAuthorizationHeader = "Authorization" +) + +var ( + // ErrInvalidAuthorization indicates an authorization header value did not + // correspond to the standard. + ErrInvalidAuthorization = errors.New("invalidation authorization") + + // ErrMissingAuthorization indicates that no authorization header was + // present in the source HTTP request. + ErrMissingAuthorization = errors.New("missing authorization") +) + +// fastIsSpace tests an ASCII byte to see if it's whitespace. +// HTTP headers are restricted to US-ASCII, so we don't need +// the full unicode stack. +func fastIsSpace(b byte) bool { + return b == ' ' || b == '\t' || b == '\n' || b == '\r' || b == '\v' || b == '\f' +} + +// ParseAuthorization parses an authorization value typically passed in +// the Authorization HTTP header. +// +// The required format is . This function +// is strict: it requires no leading or trailing space and exactly (1) space as +// a separator. If the raw value does not adhere to this format, ErrInvalidAuthorization +// is returned. +func ParseAuthorization(raw string) (s Scheme, v string, err error) { + var scheme string + var found bool + scheme, v, found = strings.Cut(raw, " ") + if found && len(scheme) > 0 && !fastIsSpace(v[0]) && !fastIsSpace(v[len(v)-1]) { + s = Scheme(scheme) + } else { + err = ErrInvalidAuthorization + } + + return +} + +type AuthorizationParserOption interface { + apply(*AuthorizationParser) error +} + +type authorizationParserOptionFunc func(*AuthorizationParser) error + +func (apof authorizationParserOptionFunc) apply(ap *AuthorizationParser) error { return apof(ap) } + +func WithAuthorizationHeader(header string) AuthorizationParserOption { + return authorizationParserOptionFunc(func(ap *AuthorizationParser) error { + ap.header = header + return nil + }) +} + +func WithScheme(scheme Scheme, parser bascule.TokenParser[string]) AuthorizationParserOption { + return authorizationParserOptionFunc(func(ap *AuthorizationParser) error { + ap.parsers[scheme] = parser + return nil + }) +} + +type AuthorizationParser struct { + header string + parsers map[Scheme]bascule.TokenParser[string] +} + +func NewAuthorizationParser(opts ...AuthorizationParserOption) (*AuthorizationParser, error) { + ap := &AuthorizationParser{ + parsers: make(map[Scheme]bascule.TokenParser[string]), + } + + for _, o := range opts { + if err := o.apply(ap); err != nil { + return nil, err + } + } + + if len(ap.header) == 0 { + ap.header = DefaultAuthorizationHeader + } + + return ap, nil +} + +func (ap *AuthorizationParser) Parse(_ context.Context, source *http.Request) (bascule.Token, error) { + return nil, nil // TODO +} diff --git a/basculehttp/basic.go b/basculehttp/basic.go index 315a056..bff71c0 100644 --- a/basculehttp/basic.go +++ b/basculehttp/basic.go @@ -5,40 +5,63 @@ package basculehttp import ( "context" - "net/http" + "encoding/base64" + "strings" "github.com/xmidt-org/bascule/v1" ) -// BasicToken is a bascule.Token that results from Basic authorization. -type BasicToken struct { - UserName string - Password string +// BasicToken is the interface that Basic Auth tokens implement. +type BasicToken interface { + UserName() string + Password() string } -// Principal returns the user name from Basic auth. -func (bt BasicToken) Principal() string { - return bt.UserName +// basicToken is the internal basic token struct that results from +// parsing a Basic Authorization header value. +type basicToken struct { + userName string + password string } -// BasicTokenParser is a bascule.TokenParser expects Basic auth to -// be present. -type BasicTokenParser struct{} +func (bt basicToken) Principal() string { + return bt.userName +} + +func (bt basicToken) UserName() string { + return bt.userName +} + +func (bt basicToken) Password() string { + return bt.password +} -// Parse extracts the Basic auth credentials from the source request. -// The net/http package is used to do this parsing. +// BasicTokenParser is a string-based bascule.TokenParser that produces +// BasicToken instances from strings. // -// If no Basic auth credentials could be found, this method returns -// bascule.MissingCredentials. -func (btp BasicTokenParser) Parse(_ context.Context, source *http.Request) (t bascule.Token, err error) { - if userName, password, ok := source.BasicAuth(); ok { - t = BasicToken{ - UserName: userName, - Password: password, - } - } else { - err = bascule.MissingCredentials +// An instance of this parser may be passed to WithScheme in order to +// configure an AuthorizationParser. +type BasicTokenParser struct{} + +// Parse assumes that value is of the format required by https://datatracker.ietf.org/doc/html/rfc7617. +// The returned Token will return the basic auth username from its Principal() method. +// The returned Token will also implement BasicToken. +func (BasicTokenParser) Parse(_ context.Context, value string) (bascule.Token, error) { + // this mimics what the stdlib does at net/http.Request.BasicAuth() + raw, err := base64.StdEncoding.DecodeString(value) + if err != nil { + return nil, bascule.ErrInvalidCredentials + } + + var ( + bt basicToken + ok bool + ) + + bt.userName, bt.password, ok = strings.Cut(string(raw), ":") + if ok { + return bt, nil } - return + return nil, bascule.ErrInvalidCredentials } diff --git a/basculehttp/error.go b/basculehttp/error.go index b96d6f5..01a1846 100644 --- a/basculehttp/error.go +++ b/basculehttp/error.go @@ -14,44 +14,42 @@ import ( // ErrorStatusCoder is a strategy for determining the HTTP response code for an error. // -// The defaultCode is used when this strategy cannot determine the code from the error. -// This default can be a sentinel for decorators, e.g. zero (0), or can be an actual -// status code. -type ErrorStatusCoder func(request *http.Request, defaultCode int, err error) int +// If this closure returns a value less than 100, which is the smallest valid HTTP +// response code, the caller should supply a useful default. +type ErrorStatusCoder func(request *http.Request, err error) int // DefaultErrorStatusCoder is the strategy used when no ErrorStatusCoder is supplied. -// This function first tries to see if the error implements bascule.Error, in which case -// the error's type will dictate the response code. Next, if the wrapper error provides -// a StatusCode() method, that code is used. Failing all of that, the defaultCode is -// returned. // -// This function can also be decorated. Passing a sentinel value for defaultCode allows -// a decorator to take further action. -func DefaultErrorStatusCoder(_ *http.Request, defaultCode int, err error) int { - switch bascule.GetErrorType(err) { - case bascule.ErrorTypeMissingCredentials: - return http.StatusUnauthorized - - case bascule.ErrorTypeBadCredentials: - return http.StatusBadRequest - - case bascule.ErrorTypeInvalidCredentials: - return http.StatusForbidden - - case bascule.ErrorTypeForbidden: - return http.StatusForbidden - } - +// If err has bascule.ErrMissingCredentials in its chain, this function returns +// http.StatusUnauthorized. +// +// If err has bascule.ErrInvalidCredentials in its chain, this function returns +// http.StatusBadRequest. +// +// Failing the previous two checks, if the error provides a StatusCode() method, +// the return value from that method is used. +// +// Otherwise, this method returns 0 to indicate that it doesn't know how to +// produce a status code from the error. +func DefaultErrorStatusCoder(_ *http.Request, err error) int { type statusCoder interface { StatusCode() int } var sc statusCoder - if errors.As(err, &sc) { + + switch { + case errors.Is(err, bascule.ErrMissingCredentials): + return http.StatusUnauthorized + + case errors.Is(err, bascule.ErrInvalidCredentials): + return http.StatusBadRequest + + case errors.As(err, &sc): return sc.StatusCode() } - return defaultCode + return 0 } // ErrorMarshaler is a strategy for marshaling an error's contents, particularly to diff --git a/basculehttp/middleware.go b/basculehttp/middleware.go index eaf2e38..1b2de87 100644 --- a/basculehttp/middleware.go +++ b/basculehttp/middleware.go @@ -23,6 +23,16 @@ func (mof middlewareOptionFunc) apply(m *Middleware) error { return mof(m) } +// WithTokenParsers appends token parsers to the chain used by the middleware. +// Each invocation of this option is cumulative. Token parsers are run in the +// order supplied via this option. +func WithTokenParsers(tps ...bascule.TokenParser[*http.Request]) MiddlewareOption { + return middlewareOptionFunc(func(m *Middleware) error { + m.tokenParsers = m.tokenParsers.Append(tps...) + return nil + }) +} + // WithAuthentication adds validators used for authentication to this Middleware. Each // invocation of this option is cumulative. Authentication validators are run in the order // supplied by this option. @@ -100,10 +110,8 @@ type Middleware struct { // No options will result in a Middleware with default behavior. func NewMiddleware(opts ...MiddlewareOption) (m *Middleware, err error) { m = &Middleware{ - credentialsParser: DefaultCredentialsParser{}, - tokenParsers: DefaultTokenParsers(), - errorStatusCoder: DefaultErrorStatusCoder, - errorMarshaler: DefaultErrorMarshaler, + errorStatusCoder: DefaultErrorStatusCoder, + errorMarshaler: DefaultErrorMarshaler, } for _, o := range opts { @@ -121,7 +129,7 @@ func (m *Middleware) Then(protected http.Handler) http.Handler { } return &frontDoor{ - middleware: m, + Middleware: m, protected: protected, } } @@ -143,7 +151,11 @@ func (m *Middleware) ThenFunc(protected http.HandlerFunc) http.Handler { // If the error supports JSON or text marshaling, that is used for the response body. Otherwise, a text/plain // response with the Error() method's text is used. func (m *Middleware) writeError(response http.ResponseWriter, request *http.Request, defaultCode int, err error) { - statusCode := m.errorStatusCoder(request, defaultCode, err) + statusCode := m.errorStatusCoder(request, err) + if statusCode < 100 { + statusCode = defaultCode + } + if statusCode == http.StatusUnauthorized { m.challenges.WriteHeader(response.Header()) } @@ -159,13 +171,8 @@ func (m *Middleware) writeError(response http.ResponseWriter, request *http.Requ } } -func (m *Middleware) getCredentialsAndToken(ctx context.Context, request *http.Request) (c bascule.Credentials, t bascule.Token, err error) { - c, err = m.credentialsParser.Parse(request.Context(), request) - if err == nil { - t, err = m.tokenParsers.Parse(ctx, request, c) - } - - return +func (m *Middleware) parseToken(ctx context.Context, request *http.Request) (bascule.Token, error) { + return m.tokenParsers.Parse(ctx, request) } func (m *Middleware) authenticate(ctx context.Context, request *http.Request, token bascule.Token) (bascule.Token, error) { @@ -179,35 +186,34 @@ func (m *Middleware) authorize(ctx context.Context, token bascule.Token, request // frontDoor is the internal handler implementation that protects a handler // using the bascule workflow. type frontDoor struct { - middleware *Middleware - protected http.Handler + *Middleware + protected http.Handler } // ServeHTTP implements the bascule workflow, using the configured middleware. func (fd *frontDoor) ServeHTTP(response http.ResponseWriter, request *http.Request) { ctx := request.Context() - creds, token, err := fd.middleware.getCredentialsAndToken(ctx, request) + token, err := fd.parseToken(ctx, request) if err != nil { // by default, failing to parse a token is a malformed request - fd.middleware.writeError(response, request, http.StatusBadRequest, err) + fd.writeError(response, request, http.StatusBadRequest, err) return } - ctx = bascule.WithCredentials(ctx, creds) - token, err = fd.middleware.authenticate(ctx, request, token) + token, err = fd.authenticate(ctx, request, token) if err != nil { // at this point in the workflow, the request has valid credentials. we use // StatusForbidden as the default because any failure to authenticate isn't a // case where the caller needs to supply credentials. Rather, the supplied // credentials aren't adequate enough. - fd.middleware.writeError(response, request, http.StatusForbidden, err) + fd.writeError(response, request, http.StatusForbidden, err) return } ctx = bascule.WithToken(ctx, token) - err = fd.middleware.authorize(ctx, token, request) + err = fd.authorize(ctx, token, request) if err != nil { - fd.middleware.writeError(response, request, http.StatusForbidden, err) + fd.writeError(response, request, http.StatusForbidden, err) return } diff --git a/basculehttp/middleware_examples_test.go b/basculehttp/middleware_examples_test.go index 04da19c..471bd6a 100644 --- a/basculehttp/middleware_examples_test.go +++ b/basculehttp/middleware_examples_test.go @@ -22,7 +22,7 @@ func ExampleMiddleware_simple() { // decorate a handler that needs authorization h := m.ThenFunc( func(response http.ResponseWriter, request *http.Request) { - t, ok := bascule.GetTokenFrom(request) + t, ok := bascule.GetFrom(request) if !ok { panic("no token found") } diff --git a/basculehttp/scheme.go b/basculehttp/scheme.go new file mode 100644 index 0000000..09e3186 --- /dev/null +++ b/basculehttp/scheme.go @@ -0,0 +1,15 @@ +// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package basculehttp + +// Scheme is the authorization header scheme, e.g. Basic, Bearer, etc. +type Scheme string + +const ( + // SchemeBasic is the Basic HTTP authorization scheme. + SchemeBasic Scheme = "Basic" + + // SchemeBearer is the Bearer HTTP authorization scheme. + SchemeBearer Scheme = "Bearer" +) diff --git a/basculehttp/token.go b/basculehttp/token.go deleted file mode 100644 index 52f2dc0..0000000 --- a/basculehttp/token.go +++ /dev/null @@ -1,28 +0,0 @@ -// SPDX-FileCopyrightText: 2020 Comcast Cable Communications Management, LLC -// SPDX-License-Identifier: Apache-2.0 - -package basculehttp - -import ( - "net/http" - - "github.com/xmidt-org/bascule/v1" -) - -// Token is bascule's default HTTP token. -type Token struct { - principal string -} - -func (t *Token) Principal() string { - return t.principal -} - -// DefaultTokenParsers returns the default suite of parsers supported by -// bascule. This method returns a distinct instance each time it is called, -// thus allowing calling code to tailor it independently of other calls. -func DefaultTokenParsers() bascule.TokenParsers[*http.Request] { - return bascule.TokenParsers[*http.Request]{ - BasicScheme: basicTokenParser{}, - } -} From 93f541ab01dabc1816dea01495a715a1cf137bf8 Mon Sep 17 00:00:00 2001 From: johnabass Date: Mon, 5 Aug 2024 16:04:36 -0700 Subject: [PATCH 5/6] fully implemented the authorization parsing; fixed middleware example --- basculehttp/authorization.go | 29 ++++++++++++++++++--- basculehttp/middleware_examples_test.go | 19 +++++++++++--- basculehttp/scheme.go | 34 +++++++++++++++++++++++++ 3 files changed, 75 insertions(+), 7 deletions(-) diff --git a/basculehttp/authorization.go b/basculehttp/authorization.go index 46a38a9..1ac13ae 100644 --- a/basculehttp/authorization.go +++ b/basculehttp/authorization.go @@ -72,7 +72,8 @@ func WithAuthorizationHeader(header string) AuthorizationParserOption { func WithScheme(scheme Scheme, parser bascule.TokenParser[string]) AuthorizationParserOption { return authorizationParserOptionFunc(func(ap *AuthorizationParser) error { - ap.parsers[scheme] = parser + // we want case-insensitive matches, so lowercase everything + ap.parsers[scheme.lower()] = parser return nil }) } @@ -100,6 +101,28 @@ func NewAuthorizationParser(opts ...AuthorizationParserOption) (*AuthorizationPa return ap, nil } -func (ap *AuthorizationParser) Parse(_ context.Context, source *http.Request) (bascule.Token, error) { - return nil, nil // TODO +// Parse extracts the appropriate header, Authorization by default, and parses the +// scheme and value. Schemes are case-insensitive, e.g. BASIC and Basic are the same scheme. +// +// If a token parser is registered for the given scheme, that token parser is invoked. +// Otherwise, UnsupportedSchemeError is returned, indicating the scheme in question. +func (ap *AuthorizationParser) Parse(ctx context.Context, source *http.Request) (bascule.Token, error) { + authValue := source.Header.Get(ap.header) + if len(authValue) == 0 { + return nil, bascule.ErrMissingCredentials + } + + scheme, value, err := ParseAuthorization(authValue) + if err != nil { + return nil, err + } + + p, registered := ap.parsers[scheme.lower()] + if !registered { + return nil, &UnsupportedSchemeError{ + Scheme: scheme, + } + } + + return p.Parse(ctx, value) } diff --git a/basculehttp/middleware_examples_test.go b/basculehttp/middleware_examples_test.go index 471bd6a..8d1417d 100644 --- a/basculehttp/middleware_examples_test.go +++ b/basculehttp/middleware_examples_test.go @@ -11,10 +11,21 @@ import ( "github.com/xmidt-org/bascule/v1" ) -// ExampleMiddleware_simple illustrates how to use a basculehttp Middleware with -// just the defaults. -func ExampleMiddleware_simple() { - m, err := NewMiddleware() // all defaults +// ExampleMiddleware_basicauth illustrates how to use a basculehttp Middleware with +// just basic auth. +func ExampleMiddleware_basicauth() { + tp, err := NewAuthorizationParser( + WithScheme(SchemeBasic, BasicTokenParser{}), + ) + + if err != nil { + panic(err) + } + + m, err := NewMiddleware( + WithTokenParsers(tp), + ) + if err != nil { panic(err) } diff --git a/basculehttp/scheme.go b/basculehttp/scheme.go index 09e3186..726d330 100644 --- a/basculehttp/scheme.go +++ b/basculehttp/scheme.go @@ -3,6 +3,11 @@ package basculehttp +import ( + "net/http" + "strings" +) + // Scheme is the authorization header scheme, e.g. Basic, Bearer, etc. type Scheme string @@ -13,3 +18,32 @@ const ( // SchemeBearer is the Bearer HTTP authorization scheme. SchemeBearer Scheme = "Bearer" ) + +// lower returns a lowercased version of this Scheme. Useful +// for ensuring case-insensitive matches. +func (s Scheme) lower() Scheme { + return Scheme( + strings.ToLower(string(s)), + ) +} + +// UnsupportedSchemeError is used to indicate that a particular HTTP Authorization +// scheme is not supported by the server. +type UnsupportedSchemeError struct { + Scheme Scheme +} + +// StatusCode marks this error as using the http.StatusUnauthorized code. +// This is appropriate for almost all cases, as this error occurs because +// the server does not accept or understand the scheme that the +// HTTP client supplied. +func (use *UnsupportedSchemeError) StatusCode() int { + return http.StatusUnauthorized +} + +func (use *UnsupportedSchemeError) Error() string { + var o strings.Builder + o.WriteString("Unsupported authorization scheme: ") + o.WriteString(string(use.Scheme)) + return o.String() +} From 4bcfc4eaf1026ef21ce5c71b098530eecd86c2d7 Mon Sep 17 00:00:00 2001 From: johnabass Date: Mon, 5 Aug 2024 16:09:27 -0700 Subject: [PATCH 6/6] use the new parsing interface --- basculejwt/token.go | 24 +++++++----------------- 1 file changed, 7 insertions(+), 17 deletions(-) diff --git a/basculejwt/token.go b/basculejwt/token.go index b53bec9..e0dcb2d 100644 --- a/basculejwt/token.go +++ b/basculejwt/token.go @@ -37,16 +37,6 @@ type Claims interface { Subject() string } -// Token is the interface implemented by JWT-based tokens supplied by this package. -// User-defined claims can be accessed through the bascule.Attributes interface. -// -// Note that the Princpal method returns the same value as the Subject claim. -type Token interface { - bascule.Token - bascule.Attributes - Claims -} - // token is the internal implementation of the JWT Token interface. It fronts // a lestrrat-go Token. type token struct { @@ -59,15 +49,13 @@ func (t *token) Principal() string { // tokenParser is the canonical parser for bascule that deals with JWTs. // This parser does not use the source. -type tokenParser[S any] struct { +type tokenParser struct { options []jwt.ParseOption } // NewTokenParser constructs a parser using the supplied set of parse options. -// The returned parser will produce tokens that implement the Token interface -// in this package. -func NewTokenParser[S any](options ...jwt.ParseOption) (bascule.TokenParser[S], error) { - return &tokenParser[S]{ +func NewTokenParser(options ...jwt.ParseOption) (bascule.TokenParser[string], error) { + return &tokenParser{ options: append( make([]jwt.ParseOption, 0, len(options)), options..., @@ -75,8 +63,10 @@ func NewTokenParser[S any](options ...jwt.ParseOption) (bascule.TokenParser[S], }, nil } -func (tp *tokenParser[S]) Parse(_ context.Context, _ S, c bascule.Credentials) (bascule.Token, error) { - jwtToken, err := jwt.ParseString(c.Value, tp.options...) +// Parse parses the value as a JWT, using the parsing options passed to NewTokenParser. +// The returned Token will implement the bascule.Attributes and Claims interfaces. +func (tp *tokenParser) Parse(ctx context.Context, value string) (bascule.Token, error) { + jwtToken, err := jwt.ParseString(value, tp.options...) if err != nil { return nil, err }