From 04d023265a476334fde7670105bc66c1bd1cca50 Mon Sep 17 00:00:00 2001 From: johnabass Date: Fri, 16 Aug 2024 08:07:20 -0700 Subject: [PATCH] refactored challenge API --- basculehttp/authorization.go | 7 - basculehttp/challenge.go | 260 ++++++++++++++++++++++++---------- basculehttp/challenge_test.go | 235 ++++++++++++++++++++++++++++++ basculehttp/fastIsSpace.go | 23 +++ basculehttp/middleware.go | 42 ++++-- 5 files changed, 477 insertions(+), 90 deletions(-) create mode 100644 basculehttp/challenge_test.go create mode 100644 basculehttp/fastIsSpace.go diff --git a/basculehttp/authorization.go b/basculehttp/authorization.go index 1ac13ae..f76c789 100644 --- a/basculehttp/authorization.go +++ b/basculehttp/authorization.go @@ -28,13 +28,6 @@ var ( ErrMissingAuthorization = errors.New("missing authorization") ) -// fastIsSpace tests an ASCII byte to see if it's whitespace. -// HTTP headers are restricted to US-ASCII, so we don't need -// the full unicode stack. -func fastIsSpace(b byte) bool { - return b == ' ' || b == '\t' || b == '\n' || b == '\r' || b == '\v' || b == '\f' -} - // ParseAuthorization parses an authorization value typically passed in // the Authorization HTTP header. // diff --git a/basculehttp/challenge.go b/basculehttp/challenge.go index a9171a4..73a1b24 100644 --- a/basculehttp/challenge.go +++ b/basculehttp/challenge.go @@ -4,115 +4,231 @@ package basculehttp import ( + "errors" "net/http" "strings" ) const ( - // WwwAuthenticateHeaderName is the HTTP header used for StatusUnauthorized challenges. - WwwAuthenticateHeaderName = "WWW-Authenticate" + // WWWAuthenticateHeaderName is the HTTP header used for StatusUnauthorized challenges + // when encountered by the Middleware. + // + // This value is used by default when no header is supplied to Challenges.WriteHeader. + WWWAuthenticateHeaderName = "WWW-Authenticate" +) + +var ( + // ErrInvalidChallengeScheme indicates that a scheme was improperly formatted. Usually, + // this methods the scheme was either blank or contained whitespace. + ErrInvalidChallengeScheme = errors.New("Invalid challenge auth scheme") - // DefaultBasicRealm is the realm used for a basic challenge - // when no realm is supplied. - DefaultBasicRealm string = "bascule" + // ErrInvalidChallengeParameter indicates that an attempt was made to an a challenge + // auth parameter that wasn't validly formatted. Usually, this means that the + // name contained whitespace. + ErrInvalidChallengeParameter = errors.New("Invalid challenge auth parameter") - // DefaultBearerRealm is the realm used for a bearer challenge - // when no realm is supplied. - DefaultBearerRealm string = "bascule" + // ErrReservedChallengeParameter indicates that an attempt was made to add a + // challenge auth parameter that was reserved by the RFC. + ErrReservedChallengeParameter = errors.New("Reserved challenge auth parameter") ) -// Challenge represents a WWW-Authenticate challenge. -type Challenge interface { - // FormatAuthenticate formats the authenticate string. - FormatAuthenticate(strings.Builder) +// reservedChallengeParameterNames holds the names of reserved challenge auth parameters +// that cannot be added to a ChallengeParameters. +var reservedChallengeParameterNames = map[string]bool{ + "realm": true, + "token68": true, } -// Challenges represents a sequence of challenges to associated with -// a StatusUnauthorized response. -type Challenges []Challenge +// ChallengeParameters holds the set of parameters. The zero value of this +// type is ready to use. This type handles writing parameters as well as +// provides commonly used parameter names for convenience. +type ChallengeParameters struct { + names, values []string + byName map[string]int // the parameter index +} -// Add appends challenges to this set. -func (chs *Challenges) Add(ch ...Challenge) { - if *chs == nil { - *chs = make(Challenges, 0, len(ch)) +// Len returns the number of name/value pairs contained in these parameters. +func (cp *ChallengeParameters) Len() int { + return len(cp.names) +} + +// Set sets the value of a parameter. If a parameter was already set, it is +// ovewritten. +// +// If the parameter name is invalid, this method raises an error. +func (cp *ChallengeParameters) Set(name, value string) (err error) { + switch { + case len(name) == 0: + err = ErrInvalidChallengeParameter + + case fastContainsSpace(name): + err = ErrInvalidChallengeParameter + + case reservedChallengeParameterNames[name]: + err = ErrReservedChallengeParameter + + default: + if i, exists := cp.byName[name]; exists { + cp.values[i] = value + } else { + if cp.byName == nil { + cp.byName = make(map[string]int) + } + + cp.byName[name] = len(cp.names) + cp.names = append(cp.names, name) + cp.values = append(cp.values, value) + } } - *chs = append(*chs, ch...) + return } -// WriteHeader inserts one WWW-Authenticate header per challenge in this set. -// If this set is empty, the given http.Header is not modified. -func (chs Challenges) WriteHeader(h http.Header) { +// Charset sets a charset auth parameter. Basic auth is the main scheme +// that uses this. +func (cp *ChallengeParameters) Charset(value string) error { + return cp.Set("charset", value) +} + +// Write formats this challenge to the given builder. +func (cp *ChallengeParameters) Write(o *strings.Builder) { + for i := 0; i < len(cp.names); i++ { + if i > 0 { + o.WriteString(", ") + } + + o.WriteString(cp.names[i]) + o.WriteString(`="`) + o.WriteString(cp.values[i]) + o.WriteRune('"') + } +} + +// String returns the RFC 7235 format of these parameters. +func (cp *ChallengeParameters) String() string { var o strings.Builder - for _, ch := range chs { - ch.FormatAuthenticate(o) - h.Add(WwwAuthenticateHeaderName, o.String()) - o.Reset() + cp.Write(&o) + return o.String() +} + +// NewChallengeParameters creates a ChallengeParameters from a sequence of name/value pairs. +// The strings are expected to be in name, value, name, value, ... sequence. If the number +// of strings is odd, then the last parameter will have a blank value. +// +// If any error occurs while setting parameters, execution is halted and that +// error is returned. +func NewChallengeParameters(s ...string) (cp ChallengeParameters, err error) { + for i, j := 0, 1; err == nil && i < len(s); i, j = i+2, j+2 { + if j < len(s) { + err = cp.Set(s[i], s[j]) + } else { + err = cp.Set(s[i], "") + } } + + return } -// BasicChallenge represents a WWW-Authenticate basic auth challenge. -type BasicChallenge struct { - // Scheme is the name of scheme supplied in the challenge. If this - // field is unset, BasicScheme is used. +// Challenge represets an HTTP authentication challenge, as defined by RFC 7235. +type Challenge struct { + // Scheme is the name of scheme supplied in the challenge. This field is required. Scheme Scheme - // Realm is the name of the realm for the challenge. If this field - // is unset, DefaultBasicRealm is used. - // - // Note that this field should always be set. The default isn't very - // useful outside of development. + // Realm is the name of the realm for the challenge. This field is + // optional, but it is HIGHLY recommended to set it to something useful + // to a client. Realm string - // UTF8 indicates whether "charset=UTF-8" is appended to the challenge. - // This is the only charset allowed for a Basic challenge. - UTF8 bool + // Token68 controls whether the token68 flag is written in the challenge. + Token68 bool + + // Parameters are the optional auth parameters. + Parameters ChallengeParameters } -func (bc BasicChallenge) FormatAuthenticate(o strings.Builder) { - if len(bc.Scheme) > 0 { - o.WriteString(string(bc.Scheme)) - } else { - o.WriteString(string(SchemeBasic)) +// Write formats this challenge to the given builder. Any error halts +// formatting and that error is returned. +func (c Challenge) Write(o *strings.Builder) (err error) { + s := string(c.Scheme) + switch { + case len(s) == 0: + err = ErrInvalidChallengeScheme + + case fastContainsSpace(s): + err = ErrInvalidChallengeScheme + + default: + o.WriteString(s) + if len(c.Realm) > 0 { + o.WriteString(` realm="`) + o.WriteString(c.Realm) + o.WriteRune('"') + } + + if c.Token68 { + o.WriteString(" token68") + } + + if c.Parameters.Len() > 0 { + o.WriteRune(' ') + c.Parameters.Write(o) + } } - o.WriteString(` realm="`) - if len(bc.Realm) > 0 { - o.WriteString(bc.Realm) - } else { - o.WriteString(DefaultBasicRealm) + return +} + +// NewBasicChallenge is a convenience for creating a Challenge for basic auth. +// +// Although realm is optional, it is HIGHLY recommended to set it to something +// recognizable for a client. +func NewBasicChallenge(realm string, UTF8 bool) (c Challenge, err error) { + c = Challenge{ + Scheme: SchemeBasic, + Realm: realm, } - o.WriteRune('"') - if bc.UTF8 { - o.WriteString(`, charset="UTF-8"`) + if UTF8 { + err = c.Parameters.Charset("UTF-8") } + + return } -type BearerChallenge struct { - // Scheme is the name of scheme supplied in the challenge. If this - // field is unset, BearerScheme is used. - Scheme Scheme +// Challenges represents a sequence of challenges to associated with +// a StatusUnauthorized response. +type Challenges []Challenge - // Realm is the name of the realm for the challenge. If this field - // is unset, DefaultBearerRealm is used. - // - // Note that this field should always be set. The default isn't very - // useful outside of development. - Realm string +// Append appends challenges to this set. The semantics of this +// method are the same as the built-in append. +func (chs Challenges) Append(ch ...Challenge) Challenges { + return append(chs, ch...) } -func (bc BearerChallenge) FormatAuthenticate(o strings.Builder) { - if len(bc.Scheme) > 0 { - o.WriteString(string(bc.Scheme)) - } else { - o.WriteString(string(SchemeBearer)) +// WriteHeader inserts one Http authenticate header per challenge in this set. +// If this set is empty, the given http.Header is not modified. +// +// The name is used as the header name for each header this method writes. +// Typically, this will be WWW-Authenticate or Proxy-Authenticate. If name +// is blank, WWWAuthenticateHeaderName is used. +// +// If any challenge returns an error during formatting, execution is +// halted and that error is returned. +func (chs Challenges) WriteHeader(name string, h http.Header) error { + if len(name) == 0 { + name = WWWAuthenticateHeaderName } - o.WriteString(` realm="`) - if len(bc.Realm) > 0 { - o.WriteString(bc.Realm) - } else { - o.WriteString(DefaultBasicRealm) + var o strings.Builder + for _, ch := range chs { + err := ch.Write(&o) + if err != nil { + return err + } + + h.Add(name, o.String()) + o.Reset() } + + return nil } diff --git a/basculehttp/challenge_test.go b/basculehttp/challenge_test.go new file mode 100644 index 0000000..b49e494 --- /dev/null +++ b/basculehttp/challenge_test.go @@ -0,0 +1,235 @@ +// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package basculehttp + +import ( + "net/http" + "strconv" + "strings" + "testing" + + "github.com/stretchr/testify/suite" +) + +type ChallengeTestSuite struct { + suite.Suite +} + +// newValidParameters creates a test set of parameters and asserts +// that they're value. +func (suite *ChallengeTestSuite) newValidParameters(s ...string) ChallengeParameters { + cp, err := NewChallengeParameters(s...) + suite.Require().NoError(err) + return cp +} + +// newValidBasic uses NewBasicChallenge to create a Challenge and asserts +// that no error occurred. +func (suite *ChallengeTestSuite) newValidBasic(realm string, UTF8 bool) Challenge { + c, err := NewBasicChallenge(realm, UTF8) + suite.Require().NoError(err) + return c +} + +func (suite *ChallengeTestSuite) TestChallengeParameters() { + suite.Run("Invalid", func() { + badParameterNames := []string{ + "", + " ", + "this is not ok", + "neither\tis\bthis", + "token68", // reserved + "realm", // reserved + } + + for i, bad := range badParameterNames { + suite.Run(strconv.Itoa(i), func() { + var cp ChallengeParameters + suite.Error(cp.Set(bad, "value")) + }) + } + }) + + suite.Run("Duplicate", func() { + var cp ChallengeParameters + suite.NoError(cp.Set("name", "value1")) + suite.NoError(cp.Set("another", "somevalue")) + suite.NoError(cp.Set("name", "value2")) + suite.Equal( + `name="value2", another="somevalue"`, + cp.String(), + ) + }) +} + +func (suite *ChallengeTestSuite) testChallengeValid() { + testCases := []struct { + challenge Challenge + expectedFormat string + }{ + { + challenge: Challenge{ + Scheme: SchemeBasic, + }, + expectedFormat: `Basic`, + }, + { + challenge: Challenge{ + Scheme: SchemeBasic, + Realm: "test", + }, + expectedFormat: `Basic realm="test"`, + }, + { + challenge: suite.newValidBasic("", false), + expectedFormat: `Basic`, + }, + { + challenge: suite.newValidBasic("test", false), + expectedFormat: `Basic realm="test"`, + }, + { + challenge: suite.newValidBasic("test@example.com", true), + expectedFormat: `Basic realm="test@example.com" charset="UTF-8"`, + }, + { + challenge: Challenge{ + Scheme: Scheme("Custom"), + Realm: "test@example.com", + Parameters: suite.newValidParameters( + "nonce", "this is a nonce", + "qop", "a, b, c", + "custom", "1234", + ), + }, + expectedFormat: `Custom realm="test@example.com" nonce="this is a nonce", qop="a, b, c", custom="1234"`, + }, + { + challenge: Challenge{ + Scheme: Scheme("Bearer"), + Realm: "my realm", + Token68: true, + Parameters: suite.newValidParameters( + "nonce", "this is a nonce", + "blank", + ), + }, + expectedFormat: `Bearer realm="my realm" token68 nonce="this is a nonce", blank=""`, + }, + } + + for i, testCase := range testCases { + suite.Run(strconv.Itoa(i), func() { + var o strings.Builder + suite.NoError(testCase.challenge.Write(&o)) + suite.Equal(testCase.expectedFormat, o.String()) + }) + } +} + +func (suite *ChallengeTestSuite) testChallengeInvalid() { + badChallenges := []Challenge{ + Challenge{}, // blank scheme + Challenge{ + Scheme: Scheme("this is not a valid scheme"), + }, + } + + for i, bad := range badChallenges { + suite.Run(strconv.Itoa(i), func() { + var o strings.Builder + suite.Error(bad.Write(&o)) + }) + } +} + +func (suite *ChallengeTestSuite) TestChallenge() { + suite.Run("Valid", suite.testChallengeValid) + suite.Run("Invalid", suite.testChallengeInvalid) +} + +func (suite *ChallengeTestSuite) testChallengesValid() { + testCases := []struct { + challenges Challenges + expected []string + }{ + { + challenges: Challenges{}, // empty is always valid and should do nothing + expected: nil, + }, + { + challenges: Challenges{}. + Append( + suite.newValidBasic("test@server.com", true), + ), + expected: []string{ + `Basic realm="test@server.com" charset="UTF-8"`, + }, + }, + { + challenges: Challenges{}. + Append(Challenge{ + Scheme: Scheme("Bearer"), + Realm: "my realm", + Parameters: suite.newValidParameters("foo", "bar"), + }). + Append(Challenge{ + Scheme: Scheme("Custom"), + Realm: "another realm@somewhere.net", + Token68: true, + Parameters: suite.newValidParameters("nonce", "this is a nonce", "age", "123"), + }), + expected: []string{ + `Bearer realm="my realm" foo="bar"`, + `Custom realm="another realm@somewhere.net" token68 nonce="this is a nonce", age="123"`, + }, + }, + } + + for i, testCase := range testCases { + suite.Run(strconv.Itoa(i), func() { + suite.Run("DefaultHeader", func() { + header := make(http.Header) + suite.NoError(testCase.challenges.WriteHeader("", header)) + suite.ElementsMatch(testCase.expected, header.Values(WWWAuthenticateHeaderName)) + }) + + suite.Run("CustomHeader", func() { + header := make(http.Header) + suite.NoError(testCase.challenges.WriteHeader("Custom", header)) + suite.ElementsMatch(testCase.expected, header.Values("Custom")) + }) + }) + } +} + +func (suite *ChallengeTestSuite) testChallengesInvalid() { + badChallenges := []Challenges{ + Challenges{}.Append(Challenge{ + Scheme: Scheme("bad scheme"), + }), + Challenges{}.Append(Challenge{ + Scheme: Scheme("Good"), + }). + Append(Challenge{ + Scheme: Scheme("bad scheme"), + }), + } + + for i, bad := range badChallenges { + suite.Run(strconv.Itoa(i), func() { + header := make(http.Header) + suite.Error(bad.WriteHeader("", header)) + }) + } +} + +func (suite *ChallengeTestSuite) TestChallenges() { + suite.Run("Valid", suite.testChallengesValid) + suite.Run("Invalid", suite.testChallengesInvalid) +} + +func TestChallenge(t *testing.T) { + suite.Run(t, new(ChallengeTestSuite)) +} diff --git a/basculehttp/fastIsSpace.go b/basculehttp/fastIsSpace.go new file mode 100644 index 0000000..691e572 --- /dev/null +++ b/basculehttp/fastIsSpace.go @@ -0,0 +1,23 @@ +// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package basculehttp + +// fastIsSpace tests an ASCII byte to see if it's whitespace. +// HTTP headers are restricted to US-ASCII, so we don't need +// the full unicode stack. +func fastIsSpace(b byte) bool { + return b == ' ' || b == '\t' || b == '\n' || b == '\r' || b == '\v' || b == '\f' +} + +// fastContainsSpace uses fastIsSpace on each character in a string +// until it finds a space. +func fastContainsSpace(v string) bool { + for i := 0; i < len(v); i++ { + if fastIsSpace(v[i]) { + return true + } + } + + return false +} diff --git a/basculehttp/middleware.go b/basculehttp/middleware.go index 6f6348a..4529cba 100644 --- a/basculehttp/middleware.go +++ b/basculehttp/middleware.go @@ -72,7 +72,7 @@ func UseAuthorizer(authorizer *bascule.Authorizer[*http.Request], err error) Mid // in a separate WWW-Authenticate header, in the order specified by this option. func WithChallenges(ch ...Challenge) MiddlewareOption { return middlewareOptionFunc(func(m *Middleware) error { - m.challenges.Add(ch...) + m.challenges = m.challenges.Append(ch...) return nil }) } @@ -152,31 +152,51 @@ func (m *Middleware) ThenFunc(protected http.HandlerFunc) http.Handler { return m.Then(protected) } -// writeError handles writing error information to an HTTP response. This will include any WWW-Authenticate -// challenges that are configured if a 401 status is detected. +// writeRawError is a fallback to write an error that came from this package. +// The response is always a text/plain representation of the error. +func (m *Middleware) writeRawError(response http.ResponseWriter, err error) { + response.WriteHeader(http.StatusInternalServerError) + response.Header().Set("Content-Type", "text/plain") + + errBody := []byte(err.Error()) + response.Header().Set("Content-Length", strconv.Itoa(len(errBody))) + response.Write(errBody) +} + +// writeWorkflowError handles writing an error that came from the bascule workflow to an HTTP request. +// This will include writing any HTTP challenges if a 401 status is detected. // // The defaultCode is used as the response status code if the given error does not supply a StatusCode method. // // If the error supports JSON or text marshaling, that is used for the response body. Otherwise, a text/plain // response with the Error() method's text is used. -func (m *Middleware) writeError(response http.ResponseWriter, request *http.Request, defaultCode int, err error) { +func (m *Middleware) writeWorkflowError(response http.ResponseWriter, request *http.Request, defaultCode int, err error) { statusCode := m.errorStatusCoder(request, err) if statusCode < 100 { statusCode = defaultCode } + var ( + contentType string + content []byte + writeErr error + ) + if statusCode == http.StatusUnauthorized { - m.challenges.WriteHeader(response.Header()) + writeErr = m.challenges.WriteHeader("", response.Header()) } - contentType, content, marshalErr := m.errorMarshaler(request, err) + if writeErr == nil { + contentType, content, writeErr = m.errorMarshaler(request, err) + } - // TODO: what if marshalErr != nil ? - if marshalErr == nil { + if writeErr != nil { + m.writeRawError(response, writeErr) + } else { response.Header().Set("Content-Type", contentType) response.Header().Set("Content-Length", strconv.Itoa(len(content))) response.WriteHeader(statusCode) - response.Write(content) // TODO: handle errors here somehow + response.Write(content) } } @@ -193,7 +213,7 @@ func (fd *frontDoor) ServeHTTP(response http.ResponseWriter, request *http.Reque token, err := fd.authenticator.Authenticate(ctx, request) if err != nil { // by default, failing to parse a token is a malformed request - fd.writeError(response, request, http.StatusBadRequest, err) + fd.writeWorkflowError(response, request, http.StatusBadRequest, err) return } @@ -203,7 +223,7 @@ func (fd *frontDoor) ServeHTTP(response http.ResponseWriter, request *http.Reque if fd.authorizer != nil { err = fd.authorizer.Authorize(ctx, request, token) if err != nil { - fd.writeError(response, request, http.StatusForbidden, err) + fd.writeWorkflowError(response, request, http.StatusForbidden, err) return } }