Skip to content

Commit

Permalink
more auth.Store improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
btoews committed Nov 13, 2023
1 parent 9504f66 commit a670342
Show file tree
Hide file tree
Showing 4 changed files with 203 additions and 104 deletions.
56 changes: 33 additions & 23 deletions tp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (tp *TP) HandlePollRequest(w http.ResponseWriter, r *http.Request) {
parts := strings.Split(r.URL.EscapedPath(), "/")
last := parts[len(parts)-1]

sd, err := store.GetByPollSecret(last)
sd, err := store.GetByPollSecret(r.Context(), last)
if err != nil || sd == nil {
tp.getLog(r).WithError(err).Warn("store lookup by poll secret")
http.Error(w, `{"error": "not found"}`, http.StatusNotFound)
Expand All @@ -72,7 +72,7 @@ func (tp *TP) HandlePollRequest(w http.ResponseWriter, r *http.Request) {
return
}

if err := store.DeleteByPollSecret(last); err != nil {
if err := store.DeleteByPollSecret(r.Context(), last); err != nil {
tp.getLog(r).WithError(err).Warn("store delete")
http.Error(w, `{"error": "internal server error"}`, http.StatusInternalServerError)
return
Expand Down Expand Up @@ -106,7 +106,7 @@ func (tp *TP) UserRequestMiddleware(next http.Handler) http.Handler {
return
}

sd, err := store.GetByUserSecret(userSecret)
sd, err := store.GetByUserSecret(r.Context(), userSecret)
if err != nil || sd == nil {
tp.getLog(r).WithError(err).Warn("store lookup by poll secret")
http.Error(w, `{"error": "not found"}`, http.StatusNotFound)
Expand Down Expand Up @@ -165,9 +165,9 @@ func (tp *TP) RespondPoll(w http.ResponseWriter, r *http.Request) string {
return ""
}

_, pollSecret, err := store.Put(&StoreData{Ticket: fd.ticket})
_, pollSecret, err := store.Insert(r.Context(), &StoreData{Ticket: fd.ticket})
if err != nil {
tp.getLog(r).WithError(err).Warn("store put")
tp.getLog(r).WithError(err).Warn("store insert")
http.Error(w, `{"error": "internal server error"}`, http.StatusInternalServerError)
return ""
}
Expand All @@ -179,12 +179,12 @@ func (tp *TP) RespondPoll(w http.ResponseWriter, r *http.Request) string {
return pollSecret
}

func (tp *TP) DischargePoll(pollSecret string, caveats ...macaroon.Caveat) error {
return tp.dischargePoller(pollSecret, "", caveats...)
func (tp *TP) DischargePoll(ctx context.Context, pollSecret string, caveats ...macaroon.Caveat) error {
return tp.dischargePoller(ctx, pollSecret, "", caveats...)
}

func (tp *TP) AbortPoll(pollSecret string, message string) error {
return tp.abortPoller(pollSecret, "", message)
func (tp *TP) AbortPoll(ctx context.Context, pollSecret string, message string) error {
return tp.abortPoller(ctx, pollSecret, "", message)
}

func (tp *TP) RespondUserInteractive(w http.ResponseWriter, r *http.Request) string {
Expand All @@ -196,9 +196,9 @@ func (tp *TP) RespondUserInteractive(w http.ResponseWriter, r *http.Request) str
return ""
}

userSecret, pollSecret, err := store.Put(&StoreData{Ticket: fd.ticket})
userSecret, pollSecret, err := store.Insert(r.Context(), &StoreData{Ticket: fd.ticket})
if err != nil {
tp.getLog(r).WithError(err).Warn("store put")
tp.getLog(r).WithError(err).Warn("store insert")
http.Error(w, `{"error": "internal server error"}`, http.StatusInternalServerError)
return ""
}
Expand All @@ -213,15 +213,15 @@ func (tp *TP) RespondUserInteractive(w http.ResponseWriter, r *http.Request) str
return userSecret
}

func (tp *TP) DischargeUserInteractive(userSecret string, caveats ...macaroon.Caveat) error {
return tp.dischargePoller("", userSecret, caveats...)
func (tp *TP) DischargeUserInteractive(ctx context.Context, userSecret string, caveats ...macaroon.Caveat) error {
return tp.dischargePoller(ctx, "", userSecret, caveats...)
}

func (tp *TP) AbortUserInteractive(userSecret string, message string) error {
return tp.abortPoller("", userSecret, message)
func (tp *TP) AbortUserInteractive(ctx context.Context, userSecret string, message string) error {
return tp.abortPoller(ctx, "", userSecret, message)
}

func (tp *TP) dischargePoller(pollSecret, userSecret string, caveats ...macaroon.Caveat) error {
func (tp *TP) dischargePoller(ctx context.Context, pollSecret, userSecret string, caveats ...macaroon.Caveat) error {
if tp.Store == nil {
return errors.New("no store")
}
Expand All @@ -231,9 +231,9 @@ func (tp *TP) dischargePoller(pollSecret, userSecret string, caveats ...macaroon
err error
)
if pollSecret != "" {
sd, err = tp.Store.GetByPollSecret(pollSecret)
sd, err = tp.Store.GetByPollSecret(ctx, pollSecret)
} else {
sd, err = tp.Store.GetByUserSecret(userSecret)
sd, err = tp.Store.GetByUserSecret(ctx, userSecret)
}
if err != nil {
return err
Expand Down Expand Up @@ -261,14 +261,19 @@ func (tp *TP) dischargePoller(pollSecret, userSecret string, caveats ...macaroon
sd.ResponseBody = jresp
sd.ResponseStatus = http.StatusOK

if _, _, err := tp.Store.Put(sd); err != nil {
if pollSecret != "" {
err = tp.Store.UpdateByPollSecret(ctx, pollSecret, sd)
} else {
err = tp.Store.UpdateByUserSecret(ctx, userSecret, sd)
}
if err != nil {
return err
}

return nil
}

func (tp *TP) abortPoller(pollSecret, userSecret string, message string) error {
func (tp *TP) abortPoller(ctx context.Context, pollSecret, userSecret string, message string) error {
if tp.Store == nil {
return errors.New("no store")
}
Expand All @@ -278,9 +283,9 @@ func (tp *TP) abortPoller(pollSecret, userSecret string, message string) error {
err error
)
if pollSecret != "" {
sd, err = tp.Store.GetByPollSecret(pollSecret)
sd, err = tp.Store.GetByPollSecret(ctx, pollSecret)
} else {
sd, err = tp.Store.GetByUserSecret(userSecret)
sd, err = tp.Store.GetByUserSecret(ctx, userSecret)
}
if err != nil {
return err
Expand All @@ -294,7 +299,12 @@ func (tp *TP) abortPoller(pollSecret, userSecret string, message string) error {
sd.ResponseBody = jresp
sd.ResponseStatus = http.StatusOK

if _, _, err := tp.Store.Put(sd); err != nil {
if pollSecret != "" {
err = tp.Store.UpdateByPollSecret(ctx, pollSecret, sd)
} else {
err = tp.Store.UpdateByUserSecret(ctx, userSecret, sd)
}
if err != nil {
return err
}

Expand Down
140 changes: 93 additions & 47 deletions tp/store.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package tp

import (
"context"
"encoding/hex"
"errors"
"net/http"
"strings"
"sync"

lru "github.com/hashicorp/golang-lru/v2"
"golang.org/x/crypto/blake2b"
Expand All @@ -17,13 +19,16 @@ type StoreData struct {
}

type Store interface {
Put(*StoreData) (userSecret, pollSecret string, err error)
Insert(context.Context, *StoreData) (userSecret, pollSecret string, err error)

DeleteByPollSecret(string) error
DeleteByUserSecret(string) error
GetByPollSecret(context.Context, string) (*StoreData, error)
GetByUserSecret(context.Context, string) (*StoreData, error)

GetByPollSecret(string) (*StoreData, error)
GetByUserSecret(string) (*StoreData, error)
UpdateByPollSecret(context.Context, string, *StoreData) error
UpdateByUserSecret(context.Context, string, *StoreData) error

DeleteByPollSecret(context.Context, string) error
DeleteByUserSecret(context.Context, string) error

UserSecretMunger
}
Expand All @@ -35,20 +40,18 @@ type UserSecretMunger interface {

type MemoryStore struct {
UserSecretMunger
Cache *lru.Cache[string, *StoreData]
secret []byte
Cache *lru.Cache[string, *lockedStoreData]
}

func NewMemoryStore(m UserSecretMunger, size int) (*MemoryStore, error) {
cache, err := lru.New[string, *StoreData](size)
cache, err := lru.New[string, *lockedStoreData](size)
if err != nil {
return nil, err
}

return &MemoryStore{
Cache: cache,
UserSecretMunger: m,
secret: randBytes(32),
}, nil
}

Expand All @@ -58,68 +61,111 @@ var (
errNotFound = errors.New("not found")
)

func (s *MemoryStore) Put(sd *StoreData) (userSecret, pollSecret string, err error) {
userSecret, pollSecret = s.ticketSecrets(sd.Ticket)
s.Cache.Add("u"+digest(userSecret), sd)
s.Cache.Add("p"+digest(pollSecret), sd)
return
}
const secretSize = 16

func (s *MemoryStore) DeleteByPollSecret(pollSecret string) error {
sd, err := s.GetByPollSecret(pollSecret)
if err != nil {
return err
func (s *MemoryStore) Insert(_ context.Context, sd *StoreData) (string, string, error) {
us := randHex(secretSize)
uk := userSecretKey(us)
ps := randHex(secretSize)
pk := pollSecretKey(ps)

lsd := &lockedStoreData{
StoreData: *sd,
userSecretKey: uk,
pollSecretKey: pk,
}
return s.delete(sd)

s.Cache.Add(uk, lsd)
s.Cache.Add(pk, lsd)

return us, ps, nil
}

func (s *MemoryStore) DeleteByUserSecret(userSecret string) error {
sd, err := s.GetByUserSecret(userSecret)
if err != nil {
return err
}
return s.delete(sd)
func (s *MemoryStore) GetByPollSecret(_ context.Context, pollSecret string) (*StoreData, error) {
lsd, _ := s.Cache.Get(pollSecretKey(pollSecret))
return lsd.getStoreData()
}

func (s *MemoryStore) delete(sd *StoreData) error {
userSecret, pollSecret := s.ticketSecrets(sd.Ticket)
s.Cache.Remove("u" + digest(userSecret))
s.Cache.Remove("p" + digest(pollSecret))
return nil
func (s *MemoryStore) GetByUserSecret(_ context.Context, userSecret string) (*StoreData, error) {
lsd, _ := s.Cache.Get(userSecretKey(userSecret))
return lsd.getStoreData()
}

func (s *MemoryStore) UpdateByPollSecret(_ context.Context, pollSecret string, sd *StoreData) error {
lsd, _ := s.Cache.Get(pollSecretKey(pollSecret))
return lsd.updateStoreData(sd)
}

func (s *MemoryStore) GetByPollSecret(pollSecret string) (*StoreData, error) {
if sd, ok := s.Cache.Get("p" + digest(pollSecret)); ok {
return sd, nil
func (s *MemoryStore) UpdateByUserSecret(_ context.Context, userSecret string, sd *StoreData) error {
lsd, _ := s.Cache.Get(userSecretKey(userSecret))
return lsd.updateStoreData(sd)
}

func (s *MemoryStore) DeleteByPollSecret(ctx context.Context, pollSecret string) error {
if lsd, _ := s.Cache.Get(pollSecretKey(pollSecret)); lsd != nil {
s.Cache.Remove(lsd.pollSecretKey)
s.Cache.Remove(lsd.userSecretKey)
return nil
}
return nil, errNotFound

return errNotFound
}

func (s *MemoryStore) GetByUserSecret(userSecret string) (*StoreData, error) {
if sd, ok := s.Cache.Get("u" + digest(userSecret)); ok {
return sd, nil
func (s *MemoryStore) DeleteByUserSecret(ctx context.Context, userSecret string) error {
if lsd, _ := s.Cache.Get(userSecretKey(userSecret)); lsd != nil {
s.Cache.Remove(lsd.pollSecretKey)
s.Cache.Remove(lsd.userSecretKey)
return nil
}
return nil, errNotFound

return errNotFound
}

func (s *MemoryStore) ticketSecrets(t []byte) (string, string) {
h, err := blake2b.New(32, s.secret)
if err != nil {
panic(err)
func userSecretKey(userSecret string) string { return "u" + digest(userSecret) }
func pollSecretKey(userSecret string) string { return "p" + digest(userSecret) }

type lockedStoreData struct {
StoreData
userSecretKey string
pollSecretKey string
sync.RWMutex
}

func (lsd *lockedStoreData) getStoreData() (*StoreData, error) {
if lsd == nil {
return nil, errNotFound
}
if _, err = h.Write(t); err != nil {
panic(err)

lsd.RLock()
defer lsd.RUnlock()

sd := lsd.StoreData

return &sd, nil
}

func (lsd *lockedStoreData) updateStoreData(sd *StoreData) error {
if lsd == nil {
return errNotFound
}
d := h.Sum(nil)

return hex.EncodeToString(d[:16]), hex.EncodeToString(d[16:])
lsd.Lock()
defer lsd.Unlock()

lsd.StoreData = *sd

return nil
}

func digest[T string | []byte](d T) string {
digest := blake2b.Sum256([]byte(d))
return hex.EncodeToString(digest[:])
}

func randHex(n int) string {
return hex.EncodeToString(randBytes(n))
}

type PrefixMunger string

var _ UserSecretMunger = PrefixMunger("")
Expand Down
Loading

0 comments on commit a670342

Please sign in to comment.