Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SEC-39: Check if JWT is associated with session #357

Merged
merged 10 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions rpc/server_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ type JWTClaims struct {
jwt.RegisteredClaims
AuthCredentialsType CredentialsType `json:"rpc_creds_type,omitempty"`
AuthMetadata map[string]string `json:"rpc_auth_md,omitempty"`
ApplicationID string `json:"applicationId,omitempty"`
}

// Entity returns the entity from the claims' Subject.
Expand Down Expand Up @@ -238,7 +239,8 @@ func (wrapped ctxWrappedServerStream) Context() context.Context {
return wrapped.ctx
}

func tokenFromContext(ctx context.Context) (string, error) {
// TokenFromContext returns the bearer token from the authorization header and errors if it does not exist.
func TokenFromContext(ctx context.Context) (string, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return "", status.Error(codes.Unauthenticated, "authentication required")
Expand Down Expand Up @@ -289,7 +291,7 @@ func (ss *simpleServer) ensureAuthed(ctx context.Context) (context.Context, erro
return ss.ensureAuthedHandler(ctx)
}

tokenString, err := tokenFromContext(ctx)
tokenString, err := TokenFromContext(ctx)
if err != nil {
// check TLS state
if ss.tlsAuthHandler == nil {
Expand Down
16 changes: 8 additions & 8 deletions web/auth0.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,20 +196,20 @@ func installAuthProviderRoutes(
logger utils.ZapCompatibleLogger,
) {
mux.Handle(pat.New("/login"), &loginHandler{
authProvider,
logger,
redirectStateCookieName,
redirectStateCookieMaxAge,
state: authProvider,
logger: logger,
redirectStateCookieName: redirectStateCookieName,
redirectStateCookieMaxAge: redirectStateCookieMaxAge,
})
mux.Handle(pat.New(redirectURL), &callbackHandler{
authProvider,
logger,
redirectStateCookieName,
})
mux.Handle(pat.New("/logout"), &logoutHandler{
authProvider,
logger,
providerLogoutURL,
state: authProvider,
logger: logger,
providerLogoutURL: providerLogoutURL,
})
mux.Handle(pat.New("/token"), &tokenHandler{
authProvider,
Expand Down Expand Up @@ -425,7 +425,7 @@ func verifyAndSaveToken(ctx context.Context, state *AuthProvider, session *Sessi
}

session.Data["id_token"] = rawIDToken
session.Data["access_token"] = token.AccessToken
session.Data[accessTokenSessionDataField] = token.AccessToken
session.Data["profile"] = profile

return session, nil
Expand Down
49 changes: 49 additions & 0 deletions web/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,23 @@ import (
mongoutils "go.viam.com/utils/mongo"
)

var (
accessTokenSessionDataField = "access_token"
accessTokenFieldPath = fmt.Sprintf("data.%s", accessTokenSessionDataField)
)

var webSessionsIndex = []mongo.IndexModel{
{
Keys: bson.D{
{Key: "lastUpdate", Value: 1},
},
Options: options.Index().SetExpireAfterSeconds(30 * 24 * 3600),
},
{
Keys: bson.D{
{Key: accessTokenFieldPath, Value: 1},
},
},
}

// SessionManager handles working with sessions from http.
Expand All @@ -51,6 +61,7 @@ type Store interface {
Get(ctx context.Context, id string) (*Session, error)
Save(ctx context.Context, s *Session) error
SetSessionManager(*SessionManager)
HasSessionWithToken(ctx context.Context, token string) (bool, error)
}

// ----
Expand Down Expand Up @@ -130,6 +141,17 @@ func (sm *SessionManager) DeleteSession(ctx context.Context, r *http.Request, w
}
}

// HasSessionWithAccessToken returns true if there is an active session associated with that access token.
func (sm *SessionManager) HasSessionWithAccessToken(ctx context.Context, token string) bool {
hasSessionWithToken, err := sm.store.HasSessionWithToken(ctx, token)
if err != nil {
sm.logger.Errorw("error finding session with access token", "err", err)
return false
}

return hasSessionWithToken
}

func (sm *SessionManager) newID() (string, error) {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
Expand Down Expand Up @@ -216,6 +238,18 @@ func (mss *mongoDBSessionStore) Get(ctx context.Context, id string) (*Session, e
return s, nil
}

func (mss *mongoDBSessionStore) HasSessionWithToken(ctx context.Context, token string) (bool, error) {
ctx, span := trace.StartSpan(ctx, "MongoDBSessionStore::HasSessionWithToken")
defer span.End()

count, err := mss.collection.CountDocuments(ctx, bson.M{accessTokenFieldPath: token}, options.Count().SetLimit(1))
if err != nil {
return false, fmt.Errorf("failed to check if token session exists: %w", err)
}

return count > 0, nil
}

func (mss *mongoDBSessionStore) Save(ctx context.Context, s *Session) error {
ctx, span := trace.StartSpan(ctx, "MongoDBSessionStore::Save")
defer span.End()
Expand Down Expand Up @@ -258,6 +292,21 @@ func (mss *memorySessionStore) Delete(ctx context.Context, id string) error {
return nil
}

func (mss *memorySessionStore) HasSessionWithToken(ctx context.Context, token string) (bool, error) {
if mss.data != nil {
for _, session := range mss.data {
savedToken, ok := session.Data[accessTokenSessionDataField]
if !ok {
continue
}
if token == savedToken {
return true, nil
}
}
}
return false, errNoSession
}

func (mss *memorySessionStore) Get(ctx context.Context, id string) (*Session, error) {
if mss.data == nil {
return nil, errNoSession
Expand Down
28 changes: 27 additions & 1 deletion web/sessions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
)

func TestSession1(t *testing.T) {
ctx := context.Background()
sm := NewSessionManager(&memorySessionStore{}, golog.NewTestLogger(t))

r, err := http.NewRequest(http.MethodGet, "http://localhost/", nil)
Expand All @@ -36,7 +37,16 @@ func TestSession1(t *testing.T) {
w := &DummyWriter{}

s.Data["a"] = 1
s.Data["access_token"] = "the_access_token"
s.Save(context.TODO(), r, w)

if hasSession := sm.HasSessionWithAccessToken(ctx, "the_wrong_access_token"); hasSession {
t.Fatal("should not have session with token")
}

if hasSession := sm.HasSessionWithAccessToken(ctx, "the_access_token"); !hasSession {
t.Fatal("should have session with token")
}
}

// ----
Expand Down Expand Up @@ -64,7 +74,7 @@ func TestMongoStore(t *testing.T) {

s1 := &Session{}
s1.id = "foo"
s1.Data = bson.M{"a": 1, "b": 2}
s1.Data = bson.M{"a": 1, "b": 2, "access_token": "testToken"}
err = store.Save(ctx, s1)
if err != nil {
t.Fatal(err)
Expand All @@ -84,6 +94,22 @@ func TestMongoStore(t *testing.T) {
if _, err := store.Get(ctx, "something"); !errors.Is(err, errNoSession) {
t.Fatal(err)
}

hasSession, err := store.HasSessionWithToken(ctx, "no_token")
if err != nil {
t.Fatal(err)
}
if hasSession {
t.Fatal("should not have session")
}

hasSession, err = store.HasSessionWithToken(ctx, "testToken")
if err != nil {
t.Fatal(err)
}
if !hasSession {
t.Fatal("should have session")
}
}

// ----
Expand Down
Loading