Skip to content

Commit

Permalink
Consolidate ClientPassedWookieCtxKey and ServerCreatedWookieCtxKey (
Browse files Browse the repository at this point in the history
  • Loading branch information
kkajla12 authored Sep 29, 2023
1 parent e042ba3 commit f052204
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 39 deletions.
11 changes: 5 additions & 6 deletions pkg/authz/wookie/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
package authz

import (
"context"
"net/http"

"github.com/rs/zerolog/hlog"
Expand All @@ -31,7 +30,7 @@ func GenerateWookieMiddleware(wookieSvc *WookieService) service.Middleware {

func wookieMiddleware(next http.Handler, wookieSvc *WookieService) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
clientPassedWookieFromCtx, _ := wookie.GetClientPassedWookieFromRequestContext(r.Context())
clientPassedWookieFromCtx, _ := wookie.GetWookieFromRequestContext(r.Context())
if clientPassedWookieFromCtx != nil {
next.ServeHTTP(w, r)
return
Expand All @@ -48,8 +47,8 @@ func wookieMiddleware(next http.Handler, wookieSvc *WookieService) http.Handler
return
}

ctxWithToken := context.WithValue(r.Context(), wookie.ClientPassedWookieCtxKey{}, *token)
next.ServeHTTP(w, r.WithContext(ctxWithToken))
ctxWithWookie := wookie.WithWookie(r.Context(), token)
next.ServeHTTP(w, r.WithContext(ctxWithWookie))
default:
token, err := wookie.FromString(headerVal)
if err != nil {
Expand All @@ -58,8 +57,8 @@ func wookieMiddleware(next http.Handler, wookieSvc *WookieService) http.Handler
return
}

ctxWithToken := context.WithValue(r.Context(), wookie.ClientPassedWookieCtxKey{}, token)
next.ServeHTTP(w, r.WithContext(ctxWithToken))
ctxWithWookie := wookie.WithWookie(r.Context(), token)
next.ServeHTTP(w, r.WithContext(ctxWithWookie))
}
})
}
10 changes: 6 additions & 4 deletions pkg/authz/wookie/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,12 @@ func (svc WookieService) GetLatestWookie(ctx context.Context) (*wookie.Token, er
}

func (svc WookieService) WithNewWookie(ctx context.Context, txWookieFunc func(txCtx context.Context, createdWookieId int64) error) (*wookie.Token, error) {
serverCreatedWookie, hasServerCreatedWookie := ctx.Value(wookie.ServerCreatedWookieCtxKey{}).(*wookie.Token)
serverCreatedWookie, err := wookie.GetWookieFromRequestContext(ctx)
if err != nil {
return nil, err
}
// An update is already in progress so continue with that ctx
if hasServerCreatedWookie {
if serverCreatedWookie != nil {
e := txWookieFunc(ctx, serverCreatedWookie.ID)
if e != nil {
return nil, e
Expand All @@ -93,14 +96,13 @@ func (svc WookieService) WithNewWookie(ctx context.Context, txWookieFunc func(tx

// Otherwise, create a new tx and a new wookie for writes in txWookieFunc to use.
var newWookie *wookie.Token
var err error
err = svc.Env().DB().WithinTransaction(ctx, func(txCtx context.Context) error {
newWookie, err = svc.CreateNewWookie(txCtx)
if err != nil {
return err
}

wkCtx := context.WithValue(txCtx, wookie.ServerCreatedWookieCtxKey{}, newWookie)
wkCtx := wookie.WithWookie(txCtx, newWookie)
err = txWookieFunc(wkCtx, newWookie.ID)
if err != nil {
return err
Expand Down
16 changes: 8 additions & 8 deletions pkg/wookie/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,33 +37,33 @@ func (t Token) String() string {
}

// De-serialize token from string (from header).
func FromString(wookieString string) (Token, error) {
func FromString(wookieString string) (*Token, error) {
if wookieString == "" {
return Token{}, errors.New("empty wookie string")
return nil, errors.New("empty wookie string")
}
decodedStr, err := base64.StdEncoding.DecodeString(wookieString)
if err != nil {
return Token{}, errors.New("invalid wookie string")
return nil, errors.New("invalid wookie string")
}
parts := strings.Split(string(decodedStr), ";")
if len(parts) != 3 {
return Token{}, errors.New("invalid wookie string")
return nil, errors.New("invalid wookie string")
}
id, err := strconv.ParseInt(parts[0], 0, 64)
if err != nil {
return Token{}, errors.New("invalid id in wookie string")
return nil, errors.New("invalid id in wookie string")
}
version, err := strconv.ParseInt(parts[1], 0, 64)
if err != nil {
return Token{}, errors.New("invalid version in wookie string")
return nil, errors.New("invalid version in wookie string")
}
microTs, err := strconv.ParseInt(parts[2], 0, 64)
if err != nil {
return Token{}, errors.New("invalid timestamp in wookie string")
return nil, errors.New("invalid timestamp in wookie string")
}
timestamp := time.UnixMicro(microTs)

return Token{
return &Token{
ID: id,
Version: version,
Timestamp: timestamp,
Expand Down
31 changes: 10 additions & 21 deletions pkg/wookie/wookie.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ const HeaderName = "Warrant-Token"
const Latest = "latest"

type WarrantTokenCtxKey struct{}
type ClientPassedWookieCtxKey struct{}
type ServerCreatedWookieCtxKey struct{}
type WookieCtxKey struct{}

func WarrantTokenMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -50,41 +49,31 @@ func ContainsLatest(ctx context.Context) bool {
return false
}

func GetServerCreatedWookieFromRequestContext(ctx context.Context) (*Token, error) {
wookieCtxVal := ctx.Value(ServerCreatedWookieCtxKey{})
func GetWookieFromRequestContext(ctx context.Context) (*Token, error) {
wookieCtxVal := ctx.Value(WookieCtxKey{})
if wookieCtxVal == nil {
//nolint:nilnil
return nil, nil
}

wookieToken, ok := wookieCtxVal.(Token)
if !ok {
return nil, errors.New("error fetching server created wookie from request context")
return nil, errors.New("error fetching wookie from request context")
}

return &wookieToken, nil
}

func GetClientPassedWookieFromRequestContext(ctx context.Context) (*Token, error) {
wookieCtxVal := ctx.Value(ClientPassedWookieCtxKey{})
if wookieCtxVal == nil {
//nolint:nilnil
return nil, nil
}

wookieToken, ok := wookieCtxVal.(Token)
if !ok {
return nil, errors.New("error fetching client passed wookie from request context")
}

return &wookieToken, nil
}

// Return a context with wookie set to 'latest'.
// Return a context with Warrant-Token set to 'latest'.
func WithLatest(parent context.Context) context.Context {
return context.WithValue(parent, WarrantTokenCtxKey{}, Latest)
}

// Return context with wookie set to specified Token.
func WithWookie(parent context.Context, wookie *Token) context.Context {
return context.WithValue(parent, WookieCtxKey{}, *wookie)
}

func AddAsResponseHeader(w http.ResponseWriter, token *Token) {
if token != nil {
w.Header().Set(HeaderName, token.String())
Expand Down

0 comments on commit f052204

Please sign in to comment.