Skip to content

Commit

Permalink
little bit of test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
jr22 committed Sep 24, 2024
1 parent 1d96286 commit 02ea251
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 5 deletions.
2 changes: 1 addition & 1 deletion web/auth0.go
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ func verifyAndSaveToken(ctx context.Context, state *AuthProvider, session *Sessi
}

session.Data["id_token"] = rawIDToken
session.Data["access_token"] = token.AccessToken
session.Data[accessTokenSessionDataField] = token.AccessToken
session.Data["profile"] = profile

return session, nil
Expand Down
11 changes: 8 additions & 3 deletions web/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ import (
mongoutils "go.viam.com/utils/mongo"
)

var (
accessTokenSessionDataField = "access_token"
accessTokenFieldPath = fmt.Sprintf("data.%s", accessTokenSessionDataField)
)

var webSessionsIndex = []mongo.IndexModel{
{
Keys: bson.D{
Expand All @@ -27,7 +32,7 @@ var webSessionsIndex = []mongo.IndexModel{
},
{
Keys: bson.D{
{Key: "data.access_token", Value: 1},
{Key: accessTokenFieldPath, Value: 1},
},
},
}
Expand Down Expand Up @@ -237,7 +242,7 @@ func (mss *mongoDBSessionStore) HasSessionWithToken(ctx context.Context, token s
ctx, span := trace.StartSpan(ctx, "MongoDBSessionStore::Get")
defer span.End()

count, err := mss.collection.CountDocuments(ctx, bson.M{"data.access_token": token}, options.Count().SetLimit(1))
count, err := mss.collection.CountDocuments(ctx, bson.M{accessTokenFieldPath: token}, options.Count().SetLimit(1))
if err != nil {
if errors.Is(err, mongo.ErrNoDocuments) {
return false, errNoSession
Expand Down Expand Up @@ -293,7 +298,7 @@ func (mss *memorySessionStore) Delete(ctx context.Context, id string) error {
func (mss *memorySessionStore) HasSessionWithToken(ctx context.Context, token string) (bool, error) {
if mss.data != nil {
for _, session := range mss.data {
savedToken, ok := session.Data["access_token"]
savedToken, ok := session.Data[accessTokenSessionDataField]
if !ok {
continue
}
Expand Down
28 changes: 27 additions & 1 deletion web/sessions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
)

func TestSession1(t *testing.T) {
ctx := context.Background()
sm := NewSessionManager(&memorySessionStore{}, golog.NewTestLogger(t))

r, err := http.NewRequest(http.MethodGet, "http://localhost/", nil)
Expand All @@ -36,7 +37,16 @@ func TestSession1(t *testing.T) {
w := &DummyWriter{}

s.Data["a"] = 1
s.Data["access_token"] = "the_access_token"
s.Save(context.TODO(), r, w)

if hasSession := sm.HasSessionWithAccessToken(ctx, "the_wrong_access_token"); hasSession {
t.Fatal("should not have session with token")
}

if hasSession := sm.HasSessionWithAccessToken(ctx, "the_access_token"); !hasSession {
t.Fatal("should have session with token")
}
}

// ----
Expand Down Expand Up @@ -64,7 +74,7 @@ func TestMongoStore(t *testing.T) {

s1 := &Session{}
s1.id = "foo"
s1.Data = bson.M{"a": 1, "b": 2}
s1.Data = bson.M{"a": 1, "b": 2, "access_token": "testToken"}
err = store.Save(ctx, s1)
if err != nil {
t.Fatal(err)
Expand All @@ -84,6 +94,22 @@ func TestMongoStore(t *testing.T) {
if _, err := store.Get(ctx, "something"); !errors.Is(err, errNoSession) {
t.Fatal(err)
}

hasSession, err := store.HasSessionWithToken(ctx, "no_token")
if err != nil {
t.Fatal(err)
}
if hasSession {
t.Fatal("should not have session")
}

hasSession, err = store.HasSessionWithToken(ctx, "testToken")
if err != nil {
t.Fatal(err)
}
if !hasSession {
t.Fatal("should have session")
}
}

// ----
Expand Down

0 comments on commit 02ea251

Please sign in to comment.