Skip to content

Commit

Permalink
feat: auth middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
washanhanzi committed Nov 27, 2023
1 parent 05babba commit e689800
Show file tree
Hide file tree
Showing 7 changed files with 168 additions and 45 deletions.
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,17 @@
used with caution, still in development

```go
//default bearer token extractor and parser
//extract token from "Authorization": Bearer <token>, and parse token into jwt.MapClaim
authMiddleware, err := middleware.NewAuthMiddleware(middleware.WithDefaultBearerExtractorAndParser([]byte("secret")))
if err != nil {
panic(err)
}
http.ListenAndServe(
"localhost:8080",
// Use h2c so we can serve HTTP/2 without TLS.
h2c.NewHandler(authMiddleware.Wrap(mux), &http2.Server{}),
)
```

## TODO
Expand Down
43 changes: 21 additions & 22 deletions authInterceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,13 @@ However, the added generic has no benefit when extract the value from context, u
*/
type authInterceptor struct {
ServiceHandlerType
parser Parser
clientHandler ClientTokenGetter
serviceHandler *AuthHandler
}

type opt func(*authInterceptor)
type authInterceptorOpt func(*authInterceptor)

func NewAuthInterceptor(opts ...opt) (*authInterceptor, error) {
func NewAuthInterceptor(opts ...authInterceptorOpt) (*authInterceptor, error) {
i := authInterceptor{
ServiceHandlerType: UnaryHandler,
serviceHandler: &AuthHandler{
Expand Down Expand Up @@ -70,24 +69,24 @@ func (i *authInterceptor) preventNilServiceHandler() {
}
}

func WithDefaultBearerExtractor() opt {
func WithInterceptorDefaultBearerExtractor() authInterceptorOpt {
return func(i *authInterceptor) {
i.preventNilServiceHandler()
i.serviceHandler.Extractor = DefaultBearerTokenExtractor().ToExtractor()
i.serviceHandler.Extractor = DefaultBasicExtractor().ToExtractor()
}
}

func WithDefaultBearerExtractorAndParser(signningKey any) opt {
func WithInterceptorDefaultBearerExtractorAndParser(signningKey any) authInterceptorOpt {
return func(i *authInterceptor) {
i.preventNilServiceHandler()
i.parser = DefaultJWTMapClaimsParser(signningKey)
i.serviceHandler.Extractor = DefaultBearerTokenExtractor().ToExtractor()
i.serviceHandler.Parser = DefaultJWTMapClaimsParser(signningKey)
i.serviceHandler.Extractor = DefaultBasicExtractor().ToExtractor()
}
}

func WithDefaultJWTMapClaimsParser(signningKey any) opt {
func WithInterceptorDefaultJWTMapClaimsParser(signningKey any) authInterceptorOpt {
return func(i *authInterceptor) {
i.parser = DefaultJWTMapClaimsParser(signningKey)
i.serviceHandler.Parser = DefaultJWTMapClaimsParser(signningKey)
}
}

Expand All @@ -97,14 +96,14 @@ func WithDefaultJWTMapClaimsParser(signningKey any) opt {
// func(ctx context.Context) jwt.Claims{
// return &jwt.MapClaims{}
// }
func WithCustomJWTClaimsParser(signningKey any, claimsFunc func(context.Context) jwt.Claims) opt {
func WithInterceptorCustomJWTClaimsParser(signningKey any, claimsFunc func(context.Context) jwt.Claims) authInterceptorOpt {
return func(i *authInterceptor) {
p, _ := NewJWTParser(WithSigningKey(signningKey), WithNewClaimsFunc(claimsFunc))
i.parser = p.ToParser()
i.serviceHandler.Parser = p.ToParser()
}
}

func WithIgnoreError() opt {
func WithInterceptorIgnoreError() authInterceptorOpt {
return func(i *authInterceptor) {
i.preventNilServiceHandler()
i.serviceHandler.ErrorHandler = func(context.Context, *Request, error) error {
Expand All @@ -114,55 +113,55 @@ func WithIgnoreError() opt {
}

// WithClientTokenGetter sets client token getter when the interceptor in client side
func WithClientTokenGetter(getter ClientTokenGetter) opt {
func WithInterceptorClientTokenGetter(getter ClientTokenGetter) authInterceptorOpt {
return func(i *authInterceptor) {
i.clientHandler = getter
}
}

// WithUnarySkipper skip the interceptor for unary handler
func WithSkipper(s Skipper) opt {
func WithInterceptorSkipper(s Skipper) authInterceptorOpt {
return func(i *authInterceptor) {
i.preventNilServiceHandler()
i.serviceHandler.Skipper = s
}
}

func WithBeforeFunc(fn BeforeOrSuccessFunc) opt {
func WithInterceptorBeforeFunc(fn BeforeOrSuccessFunc) authInterceptorOpt {
return func(i *authInterceptor) {
i.preventNilServiceHandler()
i.serviceHandler.BeforeFunc = fn
}
}

func WithSuccessFunc(fn BeforeOrSuccessFunc) opt {
func WithInterceptorSuccessFunc(fn BeforeOrSuccessFunc) authInterceptorOpt {
return func(i *authInterceptor) {
i.preventNilServiceHandler()
i.serviceHandler.SuccessFunc = fn
}
}

func WithErrorHandler(fn ErrorHandle) opt {
func WithInterceptorErrorHandler(fn ErrorHandle) authInterceptorOpt {
return func(i *authInterceptor) {
i.preventNilServiceHandler()
i.serviceHandler.ErrorHandler = fn
}
}

func WithExtractor(fn Extractor) opt {
func WithInterceptorExtractor(fn Extractor) authInterceptorOpt {
return func(i *authInterceptor) {
i.preventNilServiceHandler()
i.serviceHandler.Extractor = fn
}
}

func WithParser(p Parser) opt {
func WithInterceptorParser(p Parser) authInterceptorOpt {
return func(i *authInterceptor) {
i.parser = p
i.serviceHandler.Parser = p
}
}

func WithServiceHandlerType(s ServiceHandlerType) opt {
func WithServiceHandlerType(s ServiceHandlerType) authInterceptorOpt {
return func(i *authInterceptor) {
i.ServiceHandlerType = s
}
Expand Down
119 changes: 114 additions & 5 deletions authMiddleware.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,131 @@
package middleware

import (
"context"
"net/http"
"strings"

"connectrpc.com/connect"
"github.com/cockroachdb/errors"
"github.com/golang-jwt/jwt/v5"
)

type authMiddleware struct {
handler *AuthHandler
errW *connect.ErrorWriter
}

func NewAuthMiddleware(handler *AuthHandler) *authMiddleware {
return &authMiddleware{
handler: handler,
//TODO opts
errW: connect.NewErrorWriter(),
func NewAuthMiddleware(opts ...authMiddlewareOpt) (*authMiddleware, error) {
m := authMiddleware{}
for _, o := range opts {
o(&m)
}
if m.handler == nil {
return nil, errors.New("no handler set")
}
if m.errW == nil {
m.errW = connect.NewErrorWriter()
}
return &m, nil
}

type authMiddlewareOpt func(*authMiddleware)

func WithErrorWriterOpts(opts ...connect.HandlerOption) authMiddlewareOpt {
return func(m *authMiddleware) {
m.errW = connect.NewErrorWriter(opts...)
}
}

func (m *authMiddleware) preventNilHandler() {
if m.handler == nil {
m.handler = &AuthHandler{
Skipper: DefaultSkipper,
}
}
}

func WithDefaultBearerExtractor() authMiddlewareOpt {
return func(m *authMiddleware) {
m.preventNilHandler()
m.handler.Extractor = DefaultBasicExtractor().ToExtractor()
}
}

func WithDefaultBearerExtractorAndParser(signningKey any) authMiddlewareOpt {
return func(m *authMiddleware) {
m.preventNilHandler()
m.handler.Extractor = DefaultBearerTokenExtractor().ToExtractor()
m.handler.Parser = DefaultJWTMapClaimsParser(signningKey)
}
}

func WithDefaultJWTMapClaimsParser(signningKey any) authMiddlewareOpt {
return func(m *authMiddleware) {
m.handler.Parser = DefaultJWTMapClaimsParser(signningKey)
}
}

// WithCustomJWTClaimsParser sets Parser with signning key and a claimsFunc, the claimsFunc must return a reference
// for example:
//
// func(ctx context.Context) jwt.Claims{
// return &jwt.MapClaims{}
// }
func WithCustomJWTClaimsParser(signningKey any, claimsFunc func(context.Context) jwt.Claims) authMiddlewareOpt {
return func(m *authMiddleware) {
p, _ := NewJWTParser(WithSigningKey(signningKey), WithNewClaimsFunc(claimsFunc))
m.handler.Parser = p.ToParser()
}
}

func WithIgnoreError() authMiddlewareOpt {
return func(m *authMiddleware) {
m.preventNilHandler()
m.handler.ErrorHandler = func(context.Context, *Request, error) error {
return nil
}
}
}

// WithUnarySkipper skip the interceptor for unary handler
func WithSkipper(s Skipper) authMiddlewareOpt {
return func(m *authMiddleware) {
m.preventNilHandler()
m.handler.Skipper = s
}
}

func WithBeforeFunc(fn BeforeOrSuccessFunc) authMiddlewareOpt {
return func(m *authMiddleware) {
m.preventNilHandler()
m.handler.BeforeFunc = fn
}
}

func WithSuccessFunc(fn BeforeOrSuccessFunc) authMiddlewareOpt {
return func(m *authMiddleware) {
m.handler.SuccessFunc = fn
}
}

func WithErrorHandler(fn ErrorHandle) authMiddlewareOpt {
return func(m *authMiddleware) {
m.preventNilHandler()
m.handler.ErrorHandler = fn
}
}

func WithExtractor(fn Extractor) authMiddlewareOpt {
return func(m *authMiddleware) {
m.preventNilHandler()
m.handler.Extractor = fn
}
}

func WithParser(p Parser) authMiddlewareOpt {
return func(m *authMiddleware) {
m.handler.Parser = p
}
}

Expand Down
20 changes: 10 additions & 10 deletions auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ var unaryAuthTests = []struct {
{
Case: "skip",
Interceptor: func(t *testing.T) connect.Interceptor {
interceptor, err := NewAuthInterceptor(WithDefaultBearerExtractorAndParser([]byte("secret")))
interceptor, err := NewAuthInterceptor(WithInterceptorDefaultBearerExtractorAndParser([]byte("secret")))
assert.Nil(t, err)
return interceptor
},
Expand All @@ -53,7 +53,7 @@ var unaryAuthTests = []struct {
{
Case: "ignore error",
Interceptor: func(t *testing.T) connect.Interceptor {
interceptor, err := NewAuthInterceptor(WithDefaultBearerExtractorAndParser([]byte("secret")), WithIgnoreError())
interceptor, err := NewAuthInterceptor(WithInterceptorDefaultBearerExtractorAndParser([]byte("secret")), WithInterceptorIgnoreError())
assert.Nil(t, err)
return interceptor
},
Expand All @@ -74,7 +74,7 @@ var unaryAuthTests = []struct {
{
Case: "invalid bearer token",
Interceptor: func(t *testing.T) connect.Interceptor {
interceptor, err := NewAuthInterceptor(WithDefaultBearerExtractorAndParser([]byte("secret")))
interceptor, err := NewAuthInterceptor(WithInterceptorDefaultBearerExtractorAndParser([]byte("secret")))
assert.Nil(t, err)
return interceptor
},
Expand All @@ -96,7 +96,7 @@ var unaryAuthTests = []struct {
{
Case: "invalid auth header",
Interceptor: func(t *testing.T) connect.Interceptor {
interceptor, err := NewAuthInterceptor(WithDefaultBearerExtractorAndParser([]byte("secret")))
interceptor, err := NewAuthInterceptor(WithInterceptorDefaultBearerExtractorAndParser([]byte("secret")))
assert.Nil(t, err)
return interceptor
},
Expand All @@ -118,7 +118,7 @@ var unaryAuthTests = []struct {
{
Case: "invalid signing key",
Interceptor: func(t *testing.T) connect.Interceptor {
interceptor, err := NewAuthInterceptor(WithDefaultBearerExtractorAndParser([]byte("secet")))
interceptor, err := NewAuthInterceptor(WithInterceptorDefaultBearerExtractorAndParser([]byte("secet")))
assert.Nil(t, err)
return interceptor
},
Expand All @@ -140,7 +140,7 @@ var unaryAuthTests = []struct {
{
Case: "default",
Interceptor: func(t *testing.T) connect.Interceptor {
interceptor, err := NewAuthInterceptor(WithDefaultBearerExtractorAndParser([]byte("secret")))
interceptor, err := NewAuthInterceptor(WithInterceptorDefaultBearerExtractorAndParser([]byte("secret")))
assert.Nil(t, err)
return interceptor
},
Expand All @@ -164,8 +164,8 @@ var unaryAuthTests = []struct {
Case: "custom claim",
Interceptor: func(t *testing.T) connect.Interceptor {
interceptor, err := NewAuthInterceptor(
WithDefaultBearerExtractor(),
WithCustomJWTClaimsParser([]byte("secret"), func(ctx context.Context) jwt.Claims {
WithInterceptorDefaultBearerExtractor(),
WithInterceptorCustomJWTClaimsParser([]byte("secret"), func(ctx context.Context) jwt.Claims {
return &jwtCustomClaims{}
}),
)
Expand Down Expand Up @@ -226,8 +226,8 @@ var unaryAuthTests = []struct {
return payload, nil
}
interceptor, err := NewAuthInterceptor(
WithExtractor(extractor.ToExtractor()),
WithParser(parser),
WithInterceptorExtractor(extractor.ToExtractor()),
WithInterceptorParser(parser),
)
assert.Nil(t, err)
return interceptor
Expand Down
2 changes: 1 addition & 1 deletion e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func newServer(t *testing.T, interceptor connect.Interceptor, validator func(con
return memhttptest.New(t, mux)
}
func TestE2E(t *testing.T) {
authInterceptor, err := NewAuthInterceptor(WithDefaultBearerExtractorAndParser([]byte("secret")))
authInterceptor, err := NewAuthInterceptor(WithInterceptorDefaultBearerExtractorAndParser([]byte("secret")))
assert.Nil(t, err)
s := newServer(t, authInterceptor, func(ctx context.Context) {
claims, ok := FromContext[jwt.MapClaims](ctx)
Expand Down
16 changes: 10 additions & 6 deletions example/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,21 @@ func (s *PingServer) CumSum(
}

func main() {
auth, err := middleware.NewAuthInterceptor(middleware.WithDefaultBearerExtractorAndParser([]byte("secret")))
// auth, err := middleware.NewAuthInterceptor(middleware.WithInterceptorDefaultBearerExtractorAndParser([]byte("secret")))
// if err != nil {
// panic(err)
// }
// interceptors := connect.WithInterceptors(auth)
greeter := &PingServer{pingv1connect.UnimplementedPingServiceHandler{}}
mux := http.NewServeMux()
mux.Handle(pingv1connect.NewPingServiceHandler(greeter))
authMiddleware, err := middleware.NewAuthMiddleware(middleware.WithDefaultBearerExtractorAndParser([]byte("secret")))
if err != nil {
panic(err)
}
interceptors := connect.WithInterceptors(auth)
greeter := &PingServer{pingv1connect.UnimplementedPingServiceHandler{}}
mux := http.NewServeMux()
mux.Handle(pingv1connect.NewPingServiceHandler(greeter, interceptors))
http.ListenAndServe(
"localhost:8080",
// Use h2c so we can serve HTTP/2 without TLS.
h2c.NewHandler(mux, &http2.Server{}),
h2c.NewHandler(authMiddleware.Wrap(mux), &http2.Server{}),
)
}
Loading

0 comments on commit e689800

Please sign in to comment.