From c826a49d703af335986cb1343c5ed5cf6207b858 Mon Sep 17 00:00:00 2001 From: johnabass Date: Mon, 26 Aug 2024 14:50:59 -0700 Subject: [PATCH 1/2] tokens can now have subtokens with conversions similar to the errors package --- mocks_test.go | 28 ++++++ token.go | 143 +++++++++++++++++++++++++++++- token_test.go | 238 ++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 408 insertions(+), 1 deletion(-) diff --git a/mocks_test.go b/mocks_test.go index c9bd335..0f8be0c 100644 --- a/mocks_test.go +++ b/mocks_test.go @@ -35,6 +35,34 @@ func (m *mockTokenWithCapabilities) ExpectCapabilities(caps ...string) *mock.Cal return m.On("Capabilities").Return(caps) } +type mockTokenUnwrapOne struct { + mockToken +} + +func (m *mockTokenUnwrapOne) Unwrap() Token { + args := m.Called() + t, _ := args.Get(0).(Token) + return t +} + +func (m *mockTokenUnwrapOne) ExpectUnwrap(t Token) *mock.Call { + return m.On("Unwrap").Return(t) +} + +type mockTokenUnwrapMany struct { + mockToken +} + +func (m *mockTokenUnwrapMany) Unwrap() []Token { + args := m.Called() + t, _ := args.Get(0).([]Token) + return t +} + +func (m *mockTokenUnwrapMany) ExpectUnwrap(t ...Token) *mock.Call { + return m.On("Unwrap").Return(t) +} + type mockValidator[S any] struct { mock.Mock } diff --git a/token.go b/token.go index 1ebf068..63f3613 100644 --- a/token.go +++ b/token.go @@ -30,13 +30,154 @@ var ( ) // Token is a runtime representation of credentials. This interface will be further -// customized by infrastructure. +// customized by infrastructure. A Token may have subtokens and may provide access +// to an arbitrary tree of subtokens by supplying either an 'Unwrap() Token' or +// an 'Unwrap() []Token' method. Subtokens are not required to have the same principal. type Token interface { // Principal is the security subject of this token, e.g. the user name or other // user identifier. Principal() string } +// MultiToken is an aggregate Token that is the root of a subtree of Tokens. +type MultiToken []Token + +// Principal returns the principal for the first token in this set, or +// the empty string if this set is empty. +func (mt MultiToken) Principal() string { + if len(mt) > 0 { + return mt[0].Principal() + } + + return "" +} + +// Unwrap provides access to this token's children. +func (mt MultiToken) Unwrap() []Token { + return []Token(mt) +} + +// JoinTokens joins multiple tokens into one. Any nil tokens are discarded. +// The principal of the returned token will always be the principal of the +// first non-nil token supplied to this function. +// +// If there is only (1) non-nil token, that token is returned as is. Otherwise, +// no attempt is made to flatten the set of tokens. If there are multiple non-nil +// tokens, the returned token will have an 'Unwrap() []Token' method to access +// the joined tokens individually. +// +// If no non-nil tokens are passed to this function, it returns nil. +func JoinTokens(tokens ...Token) Token { + if len(tokens) == 0 { + return nil + } + + mt := make(MultiToken, 0, len(tokens)) + for _, t := range tokens { + if t != nil { + mt = append(mt, t) + } + } + + switch len(mt) { + case 0: + return nil + + case 1: + return mt[0] + + default: + return mt + } +} + +// UnwrapToken does the opposite of JoinTokens. +// +// If the supplied token provides an 'Unwrap() Token' method, and that +// method returns a non-nil Token, the returned slice contains only that Token. +// +// If the supplied token provides an 'Unwrap() []Token' method, the +// result of that method is returned. +// +// Otherwise, this function returns nil. +func UnwrapToken(t Token) []Token { + switch u := t.(type) { + case interface{ Unwrap() Token }: + uu := u.Unwrap() + if uu != nil { + return []Token{uu} + } + + case interface{ Unwrap() []Token }: + return u.Unwrap() + } + + return nil +} + +var tokenType = reflect.TypeOf((*Token)(nil)).Elem() + +// tokenTargetValue produces a reflect value to set and the required type that +// a token must be convertible to. This function panics in all the same cases +// as errors.As. +func tokenTarget[T any](target *T) (targetValue reflect.Value, targetType reflect.Type) { + if target == nil { + panic("bascule: token target must be a non-nil pointer") + } + + targetValue = reflect.ValueOf(target) + targetType = targetValue.Type().Elem() + if targetType.Kind() != reflect.Interface && !targetType.Implements(tokenType) { + panic("bascule: *target must be an interface or implement Token") + } + + return +} + +// tokenAs is a recursive function that checks the Token tree to see if +// it can do a coversion to the targetType. targetValue will hold the +// result of the conversion. +func tokenAs(t Token, targetValue reflect.Value, targetType reflect.Type) bool { + if reflect.TypeOf(t).AssignableTo(targetType) { + targetValue.Elem().Set(reflect.ValueOf(t)) + return true + } + + switch u := t.(type) { + case interface{ Unwrap() Token }: + t = u.Unwrap() + if t != nil { + return tokenAs(t, targetValue, targetType) + } + + case interface{ Unwrap() []Token }: + for _, t := range u.Unwrap() { + if t != nil && tokenAs(t, targetValue, targetType) { + return true + } + } + } + + return false +} + +// TokenAs attempts to coerce the given Token into an arbitrary target. This function +// is similar to errors.As. If target is nil, this function panics. If target is neither +// an interface or a concrete implementation of the Token interface, this function +// also panics. +// +// The Token's tree is examined depth-first beginning with the given token and +// preceding down. If a token is found that is convertible to T, then target is set +// to that token and this function returns true. Otherwise, this function returns false. +func TokenAs[T any](t Token, target *T) bool { + if t == nil { + return false + } + + targetValue, targetType := tokenTarget(target) + return tokenAs(t, targetValue, targetType) +} + // TokenParser produces tokens from a source. The original source S of the credentials // are made available to the parser. type TokenParser[S any] interface { diff --git a/token_test.go b/token_test.go index 8062bd1..0fc3fa4 100644 --- a/token_test.go +++ b/token_test.go @@ -7,11 +7,249 @@ import ( "context" "errors" "fmt" + "strconv" "testing" "github.com/stretchr/testify/suite" ) +type TokenSuite struct { + suite.Suite +} + +func (suite *TokenSuite) TestMultiToken() { + suite.Run("Empty", func() { + var mt MultiToken + suite.Empty(mt.Principal()) + suite.Empty(mt.Unwrap()) + }) + + suite.Run("One", func() { + t := StubToken("test") + mt := MultiToken{t} + suite.Equal(t.Principal(), mt.Principal()) + + unwrapped := mt.Unwrap() + suite.Require().Len(unwrapped, 1) + suite.Equal(t, unwrapped[0]) + }) + + suite.Run("Several", func() { + var ( + m1 = StubToken("test") + m2 = StubToken("another") + m3 = StubToken("and another") + ) + + mt := MultiToken{m1, m2, m3} + suite.Equal(m1.Principal(), mt.Principal()) + + unwrapped := mt.Unwrap() + suite.Require().Len(unwrapped, 3) + suite.Equal(m1, unwrapped[0]) + suite.Equal(m2, unwrapped[1]) + suite.Equal(m3, unwrapped[2]) + }) +} + +func (suite *TokenSuite) TestJoinTokens() { + suite.Run("Nil", func() { + suite.Nil(JoinTokens()) + suite.Nil(JoinTokens(nil)) + suite.Nil(JoinTokens(nil, nil)) + suite.Nil(JoinTokens(nil, nil, nil)) + }) + + suite.Run("NonNil", func() { + testCases := []struct { + tokens []Token + expectedUnwrap []Token + }{ + { + tokens: []Token{StubToken("test")}, + expectedUnwrap: nil, + }, + { + tokens: []Token{nil, StubToken("test")}, + expectedUnwrap: nil, + }, + { + tokens: []Token{StubToken("test"), nil}, + expectedUnwrap: nil, + }, + { + tokens: []Token{nil, StubToken("test"), nil}, + expectedUnwrap: nil, + }, + { + tokens: []Token{StubToken("test"), StubToken("another"), StubToken("yet another")}, + expectedUnwrap: []Token{StubToken("test"), StubToken("another"), StubToken("yet another")}, + }, + { + tokens: []Token{StubToken("test"), nil, StubToken("another"), StubToken("yet another")}, + expectedUnwrap: []Token{StubToken("test"), StubToken("another"), StubToken("yet another")}, + }, + } + + for i, testCase := range testCases { + suite.Run(strconv.Itoa(i), func() { + joined := JoinTokens(testCase.tokens...) + suite.Equal("test", joined.Principal()) + suite.Equal( + testCase.expectedUnwrap, + UnwrapToken(joined), + ) + }) + } + }) +} + +func (suite *TokenSuite) TestUnwrapToken() { + suite.Run("Nil", func() { + suite.Nil(UnwrapToken(nil)) + }) + + suite.Run("Simple", func() { + suite.Nil( + UnwrapToken(StubToken("solo")), + ) + }) + + suite.Run("Scalar", func() { + t := StubToken("test") + m := new(mockTokenUnwrapOne) + m.ExpectUnwrap(t).Once() + + suite.Equal([]Token{t}, UnwrapToken(m)) + m.AssertExpectations(suite.T()) + }) + + suite.Run("Multi", func() { + t1 := StubToken("test") + t2 := StubToken("another") + m := new(mockTokenUnwrapMany) + m.ExpectUnwrap(t1, t2).Once() + + suite.Equal([]Token{t1, t2}, UnwrapToken(m)) + m.AssertExpectations(suite.T()) + }) +} + +func (suite *TokenSuite) testTokenAsNilToken() { + var target int // won't matter + suite.False( + TokenAs(nil, &target), + ) +} + +func (suite *TokenSuite) testTokenAsNilTarget() { + m := new(mockToken) + suite.Panics(func() { + TokenAs[int](m, nil) + }) + + m.AssertExpectations(suite.T()) +} + +func (suite *TokenSuite) testTokenAsInvalidTargetType() { + var invalid int // not an interface and does not implement Token + m := new(mockToken) + suite.Panics(func() { + TokenAs[int](m, &invalid) + }) + + m.AssertExpectations(suite.T()) +} + +// wrapToken wraps the given token within another, and returns the wrapper. +func (suite *TokenSuite) wrapToken(t Token) Token { + wrapper := new(mockTokenUnwrapOne) + wrapper.ExpectUnwrap(t).Maybe() + return wrapper +} + +func (suite *TokenSuite) testTokenAsConcreteType() { + suite.Run("Trivial", func() { + var target StubToken + t := StubToken("test") + suite.True(TokenAs(t, &target)) + suite.Equal(t, target) + }) + + suite.Run("NoConversion", func() { + var target StubToken + m := new(mockToken) + suite.False(TokenAs(m, &target)) + }) + + suite.Run("Chain", func() { + nested := StubToken("test") + wrapper := suite.wrapToken( + suite.wrapToken(nested), + ) + + var target StubToken + suite.True(TokenAs(wrapper, &target)) + suite.Equal(nested, target) + }) + + suite.Run("Tree", func() { + nested := StubToken("test") + wrapper := JoinTokens(new(mockToken), nested, new(mockToken)) + + var target StubToken + suite.True(TokenAs(wrapper, &target)) + suite.Equal(nested, target) + }) +} + +func (suite *TokenSuite) testTokenAsInterface() { + suite.Run("Trivial", func() { + var target CapabilitiesAccessor + m := new(mockTokenWithCapabilities) + suite.True(TokenAs(m, &target)) + suite.Same(m, target) + }) + + suite.Run("NoConversion", func() { + var target CapabilitiesAccessor + m := new(mockToken) + suite.False(TokenAs(m, &target)) + }) + + suite.Run("Chain", func() { + nested := new(mockTokenWithCapabilities) + wrapper := suite.wrapToken( + suite.wrapToken(nested), + ) + + var target CapabilitiesAccessor + suite.True(TokenAs(wrapper, &target)) + suite.Same(nested, target) + }) + + suite.Run("Tree", func() { + nested := new(mockTokenWithCapabilities) + wrapper := JoinTokens(new(mockToken), nested, new(mockToken)) + + var target CapabilitiesAccessor + suite.True(TokenAs(wrapper, &target)) + suite.Same(nested, target) + }) +} + +func (suite *TokenSuite) TestTokenAs() { + suite.Run("NilToken", suite.testTokenAsNilToken) + suite.Run("NilTarget", suite.testTokenAsNilTarget) + suite.Run("InvalidTargetType", suite.testTokenAsInvalidTargetType) + suite.Run("ConcreteType", suite.testTokenAsConcreteType) + suite.Run("Interface", suite.testTokenAsInterface) +} + +func TestToken(t *testing.T) { + suite.Run(t, new(TokenSuite)) +} + type TokenParserSuite struct { TestSuite From 69196cfe8ff7488026666f070b231b8b15ca046e Mon Sep 17 00:00:00 2001 From: johnabass Date: Mon, 26 Aug 2024 15:06:38 -0700 Subject: [PATCH 2/2] working example of augmenting a token --- authenticator_examples_test.go | 57 ++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 authenticator_examples_test.go diff --git a/authenticator_examples_test.go b/authenticator_examples_test.go new file mode 100644 index 0000000..def9eeb --- /dev/null +++ b/authenticator_examples_test.go @@ -0,0 +1,57 @@ +// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package bascule + +import ( + "context" + "fmt" +) + +type Extra struct { + Name string + Age int +} + +func (e Extra) Principal() string { return e.Name } + +// ExampleJoinTokens_augment shows how to augment a Token as part +// of authentication workflow. +func ExampleJoinTokens_augment() { + original := StubToken("original") + authenticator, _ := NewAuthenticator[string]( + WithTokenParsers( + StubTokenParser[string]{ + Token: original, + }, + ), + WithValidators( + AsValidator[string]( + func(t Token) (Token, error) { + // augment this token with extra information + return JoinTokens(t, Extra{Name: "extra", Age: 33}), nil + }, + ), + ), + ) + + authenticated, _ := authenticator.Authenticate( + context.Background(), + "source", + ) + + fmt.Println("authenticated principal:", authenticated.Principal()) + + var extra Extra + if !TokenAs(authenticated, &extra) { + panic("token cannot be converted") + } + + fmt.Println("extra.Name:", extra.Name) + fmt.Println("extra.Age:", extra.Age) + + // Output: + // authenticated principal: original + // extra.Name: extra + // extra.Age: 33 +}