diff --git a/web/auth0.go b/web/auth0.go index 2ab38764..631e795c 100644 --- a/web/auth0.go +++ b/web/auth0.go @@ -425,7 +425,7 @@ func verifyAndSaveToken(ctx context.Context, state *AuthProvider, session *Sessi } session.Data["id_token"] = rawIDToken - session.Data["access_token"] = token.AccessToken + session.Data[accessTokenSessionDataField] = token.AccessToken session.Data["profile"] = profile return session, nil diff --git a/web/sessions.go b/web/sessions.go index 66b30d05..3994a14c 100644 --- a/web/sessions.go +++ b/web/sessions.go @@ -18,6 +18,11 @@ import ( mongoutils "go.viam.com/utils/mongo" ) +var ( + accessTokenSessionDataField = "access_token" + accessTokenFieldPath = fmt.Sprintf("data.%s", accessTokenSessionDataField) +) + var webSessionsIndex = []mongo.IndexModel{ { Keys: bson.D{ @@ -27,7 +32,7 @@ var webSessionsIndex = []mongo.IndexModel{ }, { Keys: bson.D{ - {Key: "data.access_token", Value: 1}, + {Key: accessTokenFieldPath, Value: 1}, }, }, } @@ -237,7 +242,7 @@ func (mss *mongoDBSessionStore) HasSessionWithToken(ctx context.Context, token s ctx, span := trace.StartSpan(ctx, "MongoDBSessionStore::Get") defer span.End() - count, err := mss.collection.CountDocuments(ctx, bson.M{"data.access_token": token}, options.Count().SetLimit(1)) + count, err := mss.collection.CountDocuments(ctx, bson.M{accessTokenFieldPath: token}, options.Count().SetLimit(1)) if err != nil { if errors.Is(err, mongo.ErrNoDocuments) { return false, errNoSession @@ -293,7 +298,7 @@ func (mss *memorySessionStore) Delete(ctx context.Context, id string) error { func (mss *memorySessionStore) HasSessionWithToken(ctx context.Context, token string) (bool, error) { if mss.data != nil { for _, session := range mss.data { - savedToken, ok := session.Data["access_token"] + savedToken, ok := session.Data[accessTokenSessionDataField] if !ok { continue } diff --git a/web/sessions_test.go b/web/sessions_test.go index 9a6db2bf..e2f4e58d 100644 --- a/web/sessions_test.go +++ b/web/sessions_test.go @@ -13,6 +13,7 @@ import ( ) func TestSession1(t *testing.T) { + ctx := context.Background() sm := NewSessionManager(&memorySessionStore{}, golog.NewTestLogger(t)) r, err := http.NewRequest(http.MethodGet, "http://localhost/", nil) @@ -36,7 +37,16 @@ func TestSession1(t *testing.T) { w := &DummyWriter{} s.Data["a"] = 1 + s.Data["access_token"] = "the_access_token" s.Save(context.TODO(), r, w) + + if hasSession := sm.HasSessionWithAccessToken(ctx, "the_wrong_access_token"); hasSession { + t.Fatal("should not have session with token") + } + + if hasSession := sm.HasSessionWithAccessToken(ctx, "the_access_token"); !hasSession { + t.Fatal("should have session with token") + } } // ---- @@ -64,7 +74,7 @@ func TestMongoStore(t *testing.T) { s1 := &Session{} s1.id = "foo" - s1.Data = bson.M{"a": 1, "b": 2} + s1.Data = bson.M{"a": 1, "b": 2, "access_token": "testToken"} err = store.Save(ctx, s1) if err != nil { t.Fatal(err) @@ -84,6 +94,22 @@ func TestMongoStore(t *testing.T) { if _, err := store.Get(ctx, "something"); !errors.Is(err, errNoSession) { t.Fatal(err) } + + hasSession, err := store.HasSessionWithToken(ctx, "no_token") + if err != nil { + t.Fatal(err) + } + if hasSession { + t.Fatal("should not have session") + } + + hasSession, err = store.HasSessionWithToken(ctx, "testToken") + if err != nil { + t.Fatal(err) + } + if !hasSession { + t.Fatal("should have session") + } } // ----