diff --git a/tp/server.go b/tp/server.go index 030dc76..5cbe153 100644 --- a/tp/server.go +++ b/tp/server.go @@ -55,7 +55,7 @@ func (tp *TP) HandlePollRequest(w http.ResponseWriter, r *http.Request) { parts := strings.Split(r.URL.EscapedPath(), "/") last := parts[len(parts)-1] - sd, err := store.GetByPollSecret(last) + sd, err := store.GetByPollSecret(r.Context(), last) if err != nil || sd == nil { tp.getLog(r).WithError(err).Warn("store lookup by poll secret") http.Error(w, `{"error": "not found"}`, http.StatusNotFound) @@ -72,7 +72,7 @@ func (tp *TP) HandlePollRequest(w http.ResponseWriter, r *http.Request) { return } - if err := store.DeleteByPollSecret(last); err != nil { + if err := store.DeleteByPollSecret(r.Context(), last); err != nil { tp.getLog(r).WithError(err).Warn("store delete") http.Error(w, `{"error": "internal server error"}`, http.StatusInternalServerError) return @@ -106,7 +106,7 @@ func (tp *TP) UserRequestMiddleware(next http.Handler) http.Handler { return } - sd, err := store.GetByUserSecret(userSecret) + sd, err := store.GetByUserSecret(r.Context(), userSecret) if err != nil || sd == nil { tp.getLog(r).WithError(err).Warn("store lookup by poll secret") http.Error(w, `{"error": "not found"}`, http.StatusNotFound) @@ -165,9 +165,9 @@ func (tp *TP) RespondPoll(w http.ResponseWriter, r *http.Request) string { return "" } - _, pollSecret, err := store.Put(&StoreData{Ticket: fd.ticket}) + _, pollSecret, err := store.Insert(r.Context(), &StoreData{Ticket: fd.ticket}) if err != nil { - tp.getLog(r).WithError(err).Warn("store put") + tp.getLog(r).WithError(err).Warn("store insert") http.Error(w, `{"error": "internal server error"}`, http.StatusInternalServerError) return "" } @@ -179,12 +179,12 @@ func (tp *TP) RespondPoll(w http.ResponseWriter, r *http.Request) string { return pollSecret } -func (tp *TP) DischargePoll(pollSecret string, caveats ...macaroon.Caveat) error { - return tp.dischargePoller(pollSecret, "", caveats...) +func (tp *TP) DischargePoll(ctx context.Context, pollSecret string, caveats ...macaroon.Caveat) error { + return tp.dischargePoller(ctx, pollSecret, "", caveats...) } -func (tp *TP) AbortPoll(pollSecret string, message string) error { - return tp.abortPoller(pollSecret, "", message) +func (tp *TP) AbortPoll(ctx context.Context, pollSecret string, message string) error { + return tp.abortPoller(ctx, pollSecret, "", message) } func (tp *TP) RespondUserInteractive(w http.ResponseWriter, r *http.Request) string { @@ -196,9 +196,9 @@ func (tp *TP) RespondUserInteractive(w http.ResponseWriter, r *http.Request) str return "" } - userSecret, pollSecret, err := store.Put(&StoreData{Ticket: fd.ticket}) + userSecret, pollSecret, err := store.Insert(r.Context(), &StoreData{Ticket: fd.ticket}) if err != nil { - tp.getLog(r).WithError(err).Warn("store put") + tp.getLog(r).WithError(err).Warn("store insert") http.Error(w, `{"error": "internal server error"}`, http.StatusInternalServerError) return "" } @@ -213,15 +213,15 @@ func (tp *TP) RespondUserInteractive(w http.ResponseWriter, r *http.Request) str return userSecret } -func (tp *TP) DischargeUserInteractive(userSecret string, caveats ...macaroon.Caveat) error { - return tp.dischargePoller("", userSecret, caveats...) +func (tp *TP) DischargeUserInteractive(ctx context.Context, userSecret string, caveats ...macaroon.Caveat) error { + return tp.dischargePoller(ctx, "", userSecret, caveats...) } -func (tp *TP) AbortUserInteractive(userSecret string, message string) error { - return tp.abortPoller("", userSecret, message) +func (tp *TP) AbortUserInteractive(ctx context.Context, userSecret string, message string) error { + return tp.abortPoller(ctx, "", userSecret, message) } -func (tp *TP) dischargePoller(pollSecret, userSecret string, caveats ...macaroon.Caveat) error { +func (tp *TP) dischargePoller(ctx context.Context, pollSecret, userSecret string, caveats ...macaroon.Caveat) error { if tp.Store == nil { return errors.New("no store") } @@ -231,9 +231,9 @@ func (tp *TP) dischargePoller(pollSecret, userSecret string, caveats ...macaroon err error ) if pollSecret != "" { - sd, err = tp.Store.GetByPollSecret(pollSecret) + sd, err = tp.Store.GetByPollSecret(ctx, pollSecret) } else { - sd, err = tp.Store.GetByUserSecret(userSecret) + sd, err = tp.Store.GetByUserSecret(ctx, userSecret) } if err != nil { return err @@ -261,14 +261,19 @@ func (tp *TP) dischargePoller(pollSecret, userSecret string, caveats ...macaroon sd.ResponseBody = jresp sd.ResponseStatus = http.StatusOK - if _, _, err := tp.Store.Put(sd); err != nil { + if pollSecret != "" { + err = tp.Store.UpdateByPollSecret(ctx, pollSecret, sd) + } else { + err = tp.Store.UpdateByUserSecret(ctx, userSecret, sd) + } + if err != nil { return err } return nil } -func (tp *TP) abortPoller(pollSecret, userSecret string, message string) error { +func (tp *TP) abortPoller(ctx context.Context, pollSecret, userSecret string, message string) error { if tp.Store == nil { return errors.New("no store") } @@ -278,9 +283,9 @@ func (tp *TP) abortPoller(pollSecret, userSecret string, message string) error { err error ) if pollSecret != "" { - sd, err = tp.Store.GetByPollSecret(pollSecret) + sd, err = tp.Store.GetByPollSecret(ctx, pollSecret) } else { - sd, err = tp.Store.GetByUserSecret(userSecret) + sd, err = tp.Store.GetByUserSecret(ctx, userSecret) } if err != nil { return err @@ -294,7 +299,12 @@ func (tp *TP) abortPoller(pollSecret, userSecret string, message string) error { sd.ResponseBody = jresp sd.ResponseStatus = http.StatusOK - if _, _, err := tp.Store.Put(sd); err != nil { + if pollSecret != "" { + err = tp.Store.UpdateByPollSecret(ctx, pollSecret, sd) + } else { + err = tp.Store.UpdateByUserSecret(ctx, userSecret, sd) + } + if err != nil { return err } diff --git a/tp/store.go b/tp/store.go index 267281a..cdb040c 100644 --- a/tp/store.go +++ b/tp/store.go @@ -1,10 +1,12 @@ package tp import ( + "context" "encoding/hex" "errors" "net/http" "strings" + "sync" lru "github.com/hashicorp/golang-lru/v2" "golang.org/x/crypto/blake2b" @@ -17,13 +19,16 @@ type StoreData struct { } type Store interface { - Put(*StoreData) (userSecret, pollSecret string, err error) + Insert(context.Context, *StoreData) (userSecret, pollSecret string, err error) - DeleteByPollSecret(string) error - DeleteByUserSecret(string) error + GetByPollSecret(context.Context, string) (*StoreData, error) + GetByUserSecret(context.Context, string) (*StoreData, error) - GetByPollSecret(string) (*StoreData, error) - GetByUserSecret(string) (*StoreData, error) + UpdateByPollSecret(context.Context, string, *StoreData) error + UpdateByUserSecret(context.Context, string, *StoreData) error + + DeleteByPollSecret(context.Context, string) error + DeleteByUserSecret(context.Context, string) error UserSecretMunger } @@ -35,12 +40,11 @@ type UserSecretMunger interface { type MemoryStore struct { UserSecretMunger - Cache *lru.Cache[string, *StoreData] - secret []byte + Cache *lru.Cache[string, *lockedStoreData] } func NewMemoryStore(m UserSecretMunger, size int) (*MemoryStore, error) { - cache, err := lru.New[string, *StoreData](size) + cache, err := lru.New[string, *lockedStoreData](size) if err != nil { return nil, err } @@ -48,7 +52,6 @@ func NewMemoryStore(m UserSecretMunger, size int) (*MemoryStore, error) { return &MemoryStore{ Cache: cache, UserSecretMunger: m, - secret: randBytes(32), }, nil } @@ -58,61 +61,100 @@ var ( errNotFound = errors.New("not found") ) -func (s *MemoryStore) Put(sd *StoreData) (userSecret, pollSecret string, err error) { - userSecret, pollSecret = s.ticketSecrets(sd.Ticket) - s.Cache.Add("u"+digest(userSecret), sd) - s.Cache.Add("p"+digest(pollSecret), sd) - return -} +const secretSize = 16 -func (s *MemoryStore) DeleteByPollSecret(pollSecret string) error { - sd, err := s.GetByPollSecret(pollSecret) - if err != nil { - return err +func (s *MemoryStore) Insert(_ context.Context, sd *StoreData) (string, string, error) { + us := randHex(secretSize) + uk := userSecretKey(us) + ps := randHex(secretSize) + pk := pollSecretKey(ps) + + lsd := &lockedStoreData{ + StoreData: *sd, + userSecretKey: uk, + pollSecretKey: pk, } - return s.delete(sd) + + s.Cache.Add(uk, lsd) + s.Cache.Add(pk, lsd) + + return us, ps, nil } -func (s *MemoryStore) DeleteByUserSecret(userSecret string) error { - sd, err := s.GetByUserSecret(userSecret) - if err != nil { - return err - } - return s.delete(sd) +func (s *MemoryStore) GetByPollSecret(_ context.Context, pollSecret string) (*StoreData, error) { + lsd, _ := s.Cache.Get(pollSecretKey(pollSecret)) + return lsd.getStoreData() } -func (s *MemoryStore) delete(sd *StoreData) error { - userSecret, pollSecret := s.ticketSecrets(sd.Ticket) - s.Cache.Remove("u" + digest(userSecret)) - s.Cache.Remove("p" + digest(pollSecret)) - return nil +func (s *MemoryStore) GetByUserSecret(_ context.Context, userSecret string) (*StoreData, error) { + lsd, _ := s.Cache.Get(userSecretKey(userSecret)) + return lsd.getStoreData() +} + +func (s *MemoryStore) UpdateByPollSecret(_ context.Context, pollSecret string, sd *StoreData) error { + lsd, _ := s.Cache.Get(pollSecretKey(pollSecret)) + return lsd.updateStoreData(sd) } -func (s *MemoryStore) GetByPollSecret(pollSecret string) (*StoreData, error) { - if sd, ok := s.Cache.Get("p" + digest(pollSecret)); ok { - return sd, nil +func (s *MemoryStore) UpdateByUserSecret(_ context.Context, userSecret string, sd *StoreData) error { + lsd, _ := s.Cache.Get(userSecretKey(userSecret)) + return lsd.updateStoreData(sd) +} + +func (s *MemoryStore) DeleteByPollSecret(ctx context.Context, pollSecret string) error { + if lsd, _ := s.Cache.Get(pollSecretKey(pollSecret)); lsd != nil { + s.Cache.Remove(lsd.pollSecretKey) + s.Cache.Remove(lsd.userSecretKey) + return nil } - return nil, errNotFound + + return errNotFound } -func (s *MemoryStore) GetByUserSecret(userSecret string) (*StoreData, error) { - if sd, ok := s.Cache.Get("u" + digest(userSecret)); ok { - return sd, nil +func (s *MemoryStore) DeleteByUserSecret(ctx context.Context, userSecret string) error { + if lsd, _ := s.Cache.Get(userSecretKey(userSecret)); lsd != nil { + s.Cache.Remove(lsd.pollSecretKey) + s.Cache.Remove(lsd.userSecretKey) + return nil } - return nil, errNotFound + + return errNotFound } -func (s *MemoryStore) ticketSecrets(t []byte) (string, string) { - h, err := blake2b.New(32, s.secret) - if err != nil { - panic(err) +func userSecretKey(userSecret string) string { return "u" + digest(userSecret) } +func pollSecretKey(userSecret string) string { return "p" + digest(userSecret) } + +type lockedStoreData struct { + StoreData + userSecretKey string + pollSecretKey string + sync.RWMutex +} + +func (lsd *lockedStoreData) getStoreData() (*StoreData, error) { + if lsd == nil { + return nil, errNotFound } - if _, err = h.Write(t); err != nil { - panic(err) + + lsd.RLock() + defer lsd.RUnlock() + + sd := lsd.StoreData + + return &sd, nil +} + +func (lsd *lockedStoreData) updateStoreData(sd *StoreData) error { + if lsd == nil { + return errNotFound } - d := h.Sum(nil) - return hex.EncodeToString(d[:16]), hex.EncodeToString(d[16:]) + lsd.Lock() + defer lsd.Unlock() + + lsd.StoreData = *sd + + return nil } func digest[T string | []byte](d T) string { @@ -120,6 +162,10 @@ func digest[T string | []byte](d T) string { return hex.EncodeToString(digest[:]) } +func randHex(n int) string { + return hex.EncodeToString(randBytes(n)) +} + type PrefixMunger string var _ UserSecretMunger = PrefixMunger("") diff --git a/tp/store_test.go b/tp/store_test.go index c339e2e..d249a4b 100644 --- a/tp/store_test.go +++ b/tp/store_test.go @@ -1,90 +1,133 @@ package tp import ( + "context" "testing" "github.com/alecthomas/assert/v2" ) func TestMemoryStoreSecrets(t *testing.T) { + ctx := context.Background() + ms, err := NewMemoryStore(PrefixMunger("/user/"), 100) assert.NoError(t, err) - assert.Equal(t, 32, len(ms.secret)) - - x, y := ms.ticketSecrets([]byte("hi")) - assert.Equal(t, 32, len(x)) - assert.Equal(t, 32, len(y)) - a := &StoreData{Ticket: []byte("a")} - aUS, aPS, err := ms.Put(a) + aUS, aPS, err := ms.Insert(ctx, a) assert.NoError(t, err) b := &StoreData{Ticket: []byte("b")} - bUS, bPS, err := ms.Put(b) + bUS, bPS, err := ms.Insert(ctx, b) assert.NoError(t, err) - sd, err := ms.GetByUserSecret(aUS) + sd, err := ms.GetByUserSecret(ctx, aUS) assert.NoError(t, err) assert.Equal(t, []byte("a"), sd.Ticket) - _, err = ms.GetByPollSecret(aUS) + _, err = ms.GetByPollSecret(ctx, aUS) assert.Equal(t, errNotFound, err) - sd, err = ms.GetByPollSecret(aPS) + sd, err = ms.GetByPollSecret(ctx, aPS) assert.NoError(t, err) assert.Equal(t, []byte("a"), sd.Ticket) - _, err = ms.GetByUserSecret(aPS) + _, err = ms.GetByUserSecret(ctx, aPS) assert.Equal(t, errNotFound, err) - sd, err = ms.GetByUserSecret(bUS) + sd, err = ms.GetByUserSecret(ctx, bUS) assert.NoError(t, err) assert.Equal(t, []byte("b"), sd.Ticket) - _, err = ms.GetByPollSecret(bUS) + _, err = ms.GetByPollSecret(ctx, bUS) assert.Equal(t, errNotFound, err) - sd, err = ms.GetByPollSecret(bPS) + sd, err = ms.GetByPollSecret(ctx, bPS) assert.NoError(t, err) assert.Equal(t, []byte("b"), sd.Ticket) - _, err = ms.GetByUserSecret(bPS) + _, err = ms.GetByUserSecret(ctx, bPS) assert.Equal(t, errNotFound, err) - assert.NoError(t, ms.DeleteByPollSecret(aPS)) + err = ms.DeleteByUserSecret(ctx, aPS) + assert.Equal(t, errNotFound, err) + err = ms.DeleteByPollSecret(ctx, aUS) + assert.Equal(t, errNotFound, err) + assert.NoError(t, ms.DeleteByPollSecret(ctx, aPS)) - _, err = ms.GetByPollSecret(aPS) + _, err = ms.GetByPollSecret(ctx, aPS) assert.Equal(t, errNotFound, err) - _, err = ms.GetByUserSecret(aUS) + _, err = ms.GetByUserSecret(ctx, aUS) assert.Equal(t, errNotFound, err) - sd, err = ms.GetByUserSecret(bUS) + sd, err = ms.GetByUserSecret(ctx, bUS) assert.NoError(t, err) assert.Equal(t, []byte("b"), sd.Ticket) - _, err = ms.GetByPollSecret(bUS) + _, err = ms.GetByPollSecret(ctx, bUS) assert.Equal(t, errNotFound, err) - sd, err = ms.GetByPollSecret(bPS) + sd, err = ms.GetByPollSecret(ctx, bPS) assert.NoError(t, err) assert.Equal(t, []byte("b"), sd.Ticket) - _, err = ms.GetByUserSecret(bPS) + _, err = ms.GetByUserSecret(ctx, bPS) assert.Equal(t, errNotFound, err) - bb := *b - bb.ResponseBody = []byte{1, 2, 3} - bbUS, bbPS, err := ms.Put(&bb) + b.ResponseBody = []byte{1, 2, 3} + err = ms.UpdateByPollSecret(ctx, bPS, b) assert.NoError(t, err) - assert.Equal(t, bUS, bbUS) - assert.Equal(t, bPS, bbPS) - sd, err = ms.GetByUserSecret(bUS) + sd, err = ms.GetByUserSecret(ctx, bUS) assert.NoError(t, err) assert.Equal(t, []byte("b"), sd.Ticket) assert.Equal(t, []byte{1, 2, 3}, sd.ResponseBody) - _, err = ms.GetByPollSecret(bUS) + _, err = ms.GetByPollSecret(ctx, bUS) assert.Equal(t, errNotFound, err) - sd, err = ms.GetByPollSecret(bPS) + sd, err = ms.GetByPollSecret(ctx, bPS) assert.NoError(t, err) assert.Equal(t, []byte("b"), sd.Ticket) assert.Equal(t, []byte{1, 2, 3}, sd.ResponseBody) - _, err = ms.GetByUserSecret(bPS) + _, err = ms.GetByUserSecret(ctx, bPS) + assert.Equal(t, errNotFound, err) + + b.ResponseBody = []byte{4, 5, 6} + err = ms.UpdateByUserSecret(ctx, bUS, b) + assert.NoError(t, err) + + sd, err = ms.GetByUserSecret(ctx, bUS) + assert.NoError(t, err) + assert.Equal(t, []byte("b"), sd.Ticket) + assert.Equal(t, []byte{4, 5, 6}, sd.ResponseBody) + _, err = ms.GetByPollSecret(ctx, bUS) + assert.Equal(t, errNotFound, err) + + sd, err = ms.GetByPollSecret(ctx, bPS) + assert.NoError(t, err) + assert.Equal(t, []byte("b"), sd.Ticket) + assert.Equal(t, []byte{4, 5, 6}, sd.ResponseBody) + _, err = ms.GetByUserSecret(ctx, bPS) + assert.Equal(t, errNotFound, err) + + b.ResponseBody = []byte{9, 9, 9} + err = ms.UpdateByPollSecret(ctx, bUS, b) + assert.Equal(t, errNotFound, err) + err = ms.UpdateByUserSecret(ctx, bPS, b) + assert.Equal(t, errNotFound, err) + + sd, err = ms.GetByUserSecret(ctx, bUS) + assert.NoError(t, err) + assert.Equal(t, []byte("b"), sd.Ticket) + assert.Equal(t, []byte{4, 5, 6}, sd.ResponseBody) + _, err = ms.GetByPollSecret(ctx, bUS) + assert.Equal(t, errNotFound, err) + + sd, err = ms.GetByPollSecret(ctx, bPS) + assert.NoError(t, err) + assert.Equal(t, []byte("b"), sd.Ticket) + assert.Equal(t, []byte{4, 5, 6}, sd.ResponseBody) + _, err = ms.GetByUserSecret(ctx, bPS) + assert.Equal(t, errNotFound, err) + + assert.NoError(t, ms.DeleteByUserSecret(ctx, bUS)) + + _, err = ms.GetByPollSecret(ctx, bPS) + assert.Equal(t, errNotFound, err) + _, err = ms.GetByUserSecret(ctx, bUS) assert.Equal(t, errNotFound, err) } diff --git a/tp/tp_test.go b/tp/tp_test.go index ec2942c..0a40c49 100644 --- a/tp/tp_test.go +++ b/tp/tp_test.go @@ -93,7 +93,7 @@ func TestTP(t *testing.T) { case <-pollSecretSet: select { case <-time.After(5 * time.Millisecond): - assert.NoError(t, tp.DischargePoll(pollSecret, myCaveat("dis-cav"))) + assert.NoError(t, tp.DischargePoll(context.Background(), pollSecret, myCaveat("dis-cav"))) case <-ctx.Done(): panic("oh no") } @@ -130,7 +130,7 @@ func TestTP(t *testing.T) { }, UserURLCallback: func(_ context.Context, url string) error { time.Sleep(10 * time.Millisecond) - assert.NoError(t, tp.DischargeUserInteractive(userSecret, myCaveat("dis-cav"))) + assert.NoError(t, tp.DischargeUserInteractive(context.Background(), userSecret, myCaveat("dis-cav"))) return nil }, }