Skip to content

Commit

Permalink
chore: minor improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
zepatrik committed Oct 17, 2024
1 parent 848e63a commit f7d43a1
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 65 deletions.
109 changes: 53 additions & 56 deletions oauth2/fosite_store_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ import (
"testing"
"time"

"github.com/ory/hydra/v2/driver/config"

"github.com/ory/x/assertx"

"github.com/ory/hydra/v2/flow"
Expand All @@ -27,6 +25,7 @@ import (

"github.com/ory/hydra/v2/oauth2/trust"

"github.com/ory/hydra/v2/driver/config"
"github.com/ory/hydra/v2/x"

"github.com/ory/fosite/storage"
Expand Down Expand Up @@ -232,28 +231,29 @@ func TestHelperRunner(t *testing.T, store InternalRegistry, k string) {

func testHelperRequestIDMultiples(m InternalRegistry, _ string) func(t *testing.T) {
return func(t *testing.T) {
requestId := uuid.New()
mockRequestForeignKey(t, requestId, m)
ctx := context.Background()
requestID := uuid.New()
mockRequestForeignKey(t, requestID, m)
cl := &client.Client{ID: "foobar"}

fositeRequest := &fosite.Request{
ID: requestId,
ID: requestID,
Client: cl,
RequestedAt: time.Now().UTC().Round(time.Second),
Session: NewSession("bar"),
}

for i := 0; i < 4; i++ {
signature := uuid.New()
err := m.OAuth2Storage().CreateRefreshTokenSession(context.TODO(), signature, fositeRequest)
err := m.OAuth2Storage().CreateRefreshTokenSession(ctx, signature, fositeRequest)
assert.NoError(t, err)
err = m.OAuth2Storage().CreateAccessTokenSession(context.TODO(), signature, fositeRequest)
err = m.OAuth2Storage().CreateAccessTokenSession(ctx, signature, fositeRequest)
assert.NoError(t, err)
err = m.OAuth2Storage().CreateOpenIDConnectSession(context.TODO(), signature, fositeRequest)
err = m.OAuth2Storage().CreateOpenIDConnectSession(ctx, signature, fositeRequest)
assert.NoError(t, err)
err = m.OAuth2Storage().CreatePKCERequestSession(context.TODO(), signature, fositeRequest)
err = m.OAuth2Storage().CreatePKCERequestSession(ctx, signature, fositeRequest)
assert.NoError(t, err)
err = m.OAuth2Storage().CreateAuthorizeCodeSession(context.TODO(), signature, fositeRequest)
err = m.OAuth2Storage().CreateAuthorizeCodeSession(ctx, signature, fositeRequest)
assert.NoError(t, err)
}
}
Expand Down Expand Up @@ -478,7 +478,7 @@ func testHelperNilAccessToken(x InternalRegistry) func(t *testing.T) {
m := x.OAuth2Storage()
c := &client.Client{ID: "nil-request-client-id-123"}
require.NoError(t, x.ClientManager().CreateClient(context.Background(), c))
err := m.CreateAccessTokenSession(context.TODO(), "nil-request-id", &fosite.Request{
err := m.CreateAccessTokenSession(context.Background(), "nil-request-id", &fosite.Request{
ID: "",
RequestedAt: time.Now().UTC().Round(time.Second),
Client: c,
Expand Down Expand Up @@ -558,31 +558,29 @@ func testHelperRevokeAccessToken(x InternalRegistry) func(t *testing.T) {

func testHelperRevokeRefreshTokenMaybeGracePeriod(x InternalRegistry) func(t *testing.T) {
return func(t *testing.T) {
ctx := context.Background()

t.Run("Revokes refresh token when grace period not configured", func(t *testing.T) {
// SETUP
m := x.OAuth2Storage()
ctx := context.Background()

refreshTokenSession := fmt.Sprintf("refresh_token_%d", time.Now().Unix())
err := m.CreateRefreshTokenSession(ctx, refreshTokenSession, &defaultRequest)
require.NoError(t, err, "precondition failed: could not create refresh token session")

// ACT
err = m.RevokeRefreshTokenMaybeGracePeriod(ctx, defaultRequest.GetID(), refreshTokenSession)
assert.NoError(t, err)
require.NoError(t, err)

tmpSession := new(fosite.Session)
_, err = m.GetRefreshTokenSession(ctx, refreshTokenSession, *tmpSession)

// ASSERT
// a revoked refresh token returns an error when getting the token again
assert.Error(t, err)
assert.True(t, errors.Is(err, fosite.ErrInactiveToken))
assert.ErrorIs(t, err, fosite.ErrInactiveToken)
})

t.Run("refresh token enters grace period when configured,", func(t *testing.T) {
ctx := context.Background()

// SETUP
x.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "1m")

Expand All @@ -595,16 +593,14 @@ func testHelperRevokeRefreshTokenMaybeGracePeriod(x InternalRegistry) func(t *te

refreshTokenSession := fmt.Sprintf("refresh_token_%d_with_grace_period", time.Now().Unix())
err := m.CreateRefreshTokenSession(ctx, refreshTokenSession, &defaultRequest)
assert.NoError(t, err, "precondition failed: could not create refresh token session")
require.NoError(t, err, "precondition failed: could not create refresh token session")

// ACT
assert.NoError(t, m.RevokeRefreshTokenMaybeGracePeriod(ctx, defaultRequest.GetID(), refreshTokenSession))
assert.NoError(t, m.RevokeRefreshTokenMaybeGracePeriod(ctx, defaultRequest.GetID(), refreshTokenSession))
assert.NoError(t, m.RevokeRefreshTokenMaybeGracePeriod(ctx, defaultRequest.GetID(), refreshTokenSession))
require.NoError(t, m.RevokeRefreshTokenMaybeGracePeriod(ctx, defaultRequest.GetID(), refreshTokenSession))
require.NoError(t, m.RevokeRefreshTokenMaybeGracePeriod(ctx, defaultRequest.GetID(), refreshTokenSession))
require.NoError(t, m.RevokeRefreshTokenMaybeGracePeriod(ctx, defaultRequest.GetID(), refreshTokenSession))

tmpSession := new(fosite.Session)
req, err := m.GetRefreshTokenSession(ctx, refreshTokenSession, *tmpSession)
assert.NoError(t, err)
req, err := m.GetRefreshTokenSession(ctx, refreshTokenSession, nil)

// ASSERT
// when grace period is configured the refresh token can be obtained within
Expand Down Expand Up @@ -944,6 +940,7 @@ func testFositeStoreClientAssertionJWTValid(m InternalRegistry) func(*testing.T)

func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) {
return func(t *testing.T) {
ctx := context.Background()
grantManager := x.GrantManager()
keyManager := x.KeyManager()
grantStorage := x.OAuth2Storage().(rfc7523.RFC7523KeyStorage)
Expand All @@ -966,28 +963,28 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) {
ExpiresAt: time.Now().UTC().Round(time.Second).AddDate(1, 0, 0),
}

storedKeySet, err := grantStorage.GetPublicKeys(context.TODO(), issuer, subject)
storedKeySet, err := grantStorage.GetPublicKeys(ctx, issuer, subject)
require.NoError(t, err)
require.Len(t, storedKeySet.Keys, 0)

err = grantManager.CreateGrant(context.TODO(), grant, publicKey)
err = grantManager.CreateGrant(ctx, grant, publicKey)
require.NoError(t, err)

storedKeySet, err = grantStorage.GetPublicKeys(context.TODO(), issuer, subject)
storedKeySet, err = grantStorage.GetPublicKeys(ctx, issuer, subject)
require.NoError(t, err)
assert.Len(t, storedKeySet.Keys, 1)

storedKey, err := grantStorage.GetPublicKey(context.TODO(), issuer, subject, publicKey.KeyID)
storedKey, err := grantStorage.GetPublicKey(ctx, issuer, subject, publicKey.KeyID)
require.NoError(t, err)
assert.Equal(t, publicKey.KeyID, storedKey.KeyID)
assert.Equal(t, publicKey.Use, storedKey.Use)
assert.Equal(t, publicKey.Key, storedKey.Key)

storedScopes, err := grantStorage.GetPublicKeyScopes(context.TODO(), issuer, subject, publicKey.KeyID)
storedScopes, err := grantStorage.GetPublicKeyScopes(ctx, issuer, subject, publicKey.KeyID)
require.NoError(t, err)
assert.Equal(t, grant.Scope, storedScopes)

storedKeySet, err = keyManager.GetKey(context.TODO(), issuer, publicKey.KeyID)
storedKeySet, err = keyManager.GetKey(ctx, issuer, publicKey.KeyID)
require.NoError(t, err)
assert.Equal(t, publicKey.KeyID, storedKeySet.Keys[0].KeyID)
assert.Equal(t, publicKey.Use, storedKeySet.Keys[0].Use)
Expand Down Expand Up @@ -1017,7 +1014,7 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) {

keySet2ToReturn, err := jwk.GenerateJWK(context.Background(), jose.ES256, "maria-key-2", "sig")
require.NoError(t, err)
require.NoError(t, grantManager.CreateGrant(context.TODO(), trust.Grant{
require.NoError(t, grantManager.CreateGrant(ctx, trust.Grant{
ID: uuid.New(),
Issuer: issuer,
Subject: subject,
Expand Down Expand Up @@ -1075,22 +1072,22 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) {
ExpiresAt: time.Now().UTC().Round(time.Second).AddDate(1, 0, 0),
}

err = grantManager.CreateGrant(context.TODO(), grant, publicKey)
err = grantManager.CreateGrant(ctx, grant, publicKey)
require.NoError(t, err)

_, err = grantStorage.GetPublicKey(context.TODO(), issuer, subject, grant.PublicKey.KeyID)
_, err = grantStorage.GetPublicKey(ctx, issuer, subject, grant.PublicKey.KeyID)
require.NoError(t, err)

_, err = keyManager.GetKey(context.TODO(), issuer, publicKey.KeyID)
_, err = keyManager.GetKey(ctx, issuer, publicKey.KeyID)
require.NoError(t, err)

err = grantManager.DeleteGrant(context.TODO(), grant.ID)
err = grantManager.DeleteGrant(ctx, grant.ID)
require.NoError(t, err)

_, err = grantStorage.GetPublicKey(context.TODO(), issuer, subject, publicKey.KeyID)
_, err = grantStorage.GetPublicKey(ctx, issuer, subject, publicKey.KeyID)
assert.Error(t, err)

_, err = keyManager.GetKey(context.TODO(), issuer, publicKey.KeyID)
_, err = keyManager.GetKey(ctx, issuer, publicKey.KeyID)
assert.Error(t, err)
})

Expand All @@ -1112,22 +1109,22 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) {
ExpiresAt: time.Now().UTC().Round(time.Second).AddDate(1, 0, 0),
}

err = grantManager.CreateGrant(context.TODO(), grant, publicKey)
err = grantManager.CreateGrant(ctx, grant, publicKey)
require.NoError(t, err)

_, err = grantStorage.GetPublicKey(context.TODO(), issuer, subject, publicKey.KeyID)
_, err = grantStorage.GetPublicKey(ctx, issuer, subject, publicKey.KeyID)
require.NoError(t, err)

_, err = keyManager.GetKey(context.TODO(), issuer, publicKey.KeyID)
_, err = keyManager.GetKey(ctx, issuer, publicKey.KeyID)
require.NoError(t, err)

err = keyManager.DeleteKey(context.TODO(), issuer, publicKey.KeyID)
err = keyManager.DeleteKey(ctx, issuer, publicKey.KeyID)
require.NoError(t, err)

_, err = keyManager.GetKey(context.TODO(), issuer, publicKey.KeyID)
_, err = keyManager.GetKey(ctx, issuer, publicKey.KeyID)
assert.Error(t, err)

_, err = grantManager.GetConcreteGrant(context.TODO(), grant.ID)
_, err = grantManager.GetConcreteGrant(ctx, grant.ID)
assert.Error(t, err)
})

Expand All @@ -1149,25 +1146,25 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) {
ExpiresAt: time.Now().UTC().Round(time.Second).AddDate(1, 0, 0),
}

err = grantManager.CreateGrant(context.TODO(), grant, publicKey)
err = grantManager.CreateGrant(ctx, grant, publicKey)
require.NoError(t, err)

// All three get methods should only return the public key when using the valid subject
_, err = grantStorage.GetPublicKey(context.TODO(), issuer, "any-subject-1", publicKey.KeyID)
_, err = grantStorage.GetPublicKey(ctx, issuer, "any-subject-1", publicKey.KeyID)
require.Error(t, err)
_, err = grantStorage.GetPublicKey(context.TODO(), issuer, subject, publicKey.KeyID)
_, err = grantStorage.GetPublicKey(ctx, issuer, subject, publicKey.KeyID)
require.NoError(t, err)

_, err = grantStorage.GetPublicKeyScopes(context.TODO(), issuer, "any-subject-2", publicKey.KeyID)
_, err = grantStorage.GetPublicKeyScopes(ctx, issuer, "any-subject-2", publicKey.KeyID)
require.Error(t, err)
_, err = grantStorage.GetPublicKeyScopes(context.TODO(), issuer, subject, publicKey.KeyID)
_, err = grantStorage.GetPublicKeyScopes(ctx, issuer, subject, publicKey.KeyID)
require.NoError(t, err)

jwks, err := grantStorage.GetPublicKeys(context.TODO(), issuer, "any-subject-3")
jwks, err := grantStorage.GetPublicKeys(ctx, issuer, "any-subject-3")
require.NoError(t, err)
require.NotNil(t, jwks)
require.Empty(t, jwks.Keys)
jwks, err = grantStorage.GetPublicKeys(context.TODO(), issuer, subject)
jwks, err = grantStorage.GetPublicKeys(ctx, issuer, subject)
require.NoError(t, err)
require.NotNil(t, jwks)
require.NotEmpty(t, jwks.Keys)
Expand All @@ -1190,17 +1187,17 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) {
ExpiresAt: time.Now().UTC().Round(time.Second).AddDate(1, 0, 0),
}

err = grantManager.CreateGrant(context.TODO(), grant, publicKey)
err = grantManager.CreateGrant(ctx, grant, publicKey)
require.NoError(t, err)

// All three get methods should always return the public key
_, err = grantStorage.GetPublicKey(context.TODO(), issuer, "any-subject-1", publicKey.KeyID)
_, err = grantStorage.GetPublicKey(ctx, issuer, "any-subject-1", publicKey.KeyID)
require.NoError(t, err)

_, err = grantStorage.GetPublicKeyScopes(context.TODO(), issuer, "any-subject-2", publicKey.KeyID)
_, err = grantStorage.GetPublicKeyScopes(ctx, issuer, "any-subject-2", publicKey.KeyID)
require.NoError(t, err)

jwks, err := grantStorage.GetPublicKeys(context.TODO(), issuer, "any-subject-3")
jwks, err := grantStorage.GetPublicKeys(ctx, issuer, "any-subject-3")
require.NoError(t, err)
require.NotNil(t, jwks)
require.NotEmpty(t, jwks.Keys)
Expand All @@ -1223,10 +1220,10 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) {
ExpiresAt: time.Now().UTC().Round(time.Second).AddDate(-1, 0, 0),
}

err = grantManager.CreateGrant(context.TODO(), grant, publicKey)
err = grantManager.CreateGrant(ctx, grant, publicKey)
require.NoError(t, err)

keys, err := grantStorage.GetPublicKeys(context.TODO(), issuer, "any-subject-3")
keys, err := grantStorage.GetPublicKeys(ctx, issuer, "any-subject-3")
require.NoError(t, err)
assert.Len(t, keys.Keys, 0)
})
Expand Down
13 changes: 4 additions & 9 deletions persistence/sql/persister_oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -475,14 +475,14 @@ func (p *Persister) GetRefreshTokenSession(ctx context.Context, signature string

if gracePeriod := p.r.Config().RefreshTokenRotationGracePeriod(ctx); gracePeriod > 0 && r.FirstUsedAt.Valid {
if r.FirstUsedAt.Time.Add(gracePeriod).Before(time.Now()) {
return fositeRequest, errorsx.WithStack(fosite.ErrInactiveToken)
return fositeRequest, errors.WithStack(fosite.ErrInactiveToken)
}

r.Active = true // We set active to true because we are in the grace period.
return r.toRequest(ctx, session, p) // And re-generate the request
}

return fositeRequest, errorsx.WithStack(fosite.ErrInactiveToken)
return fositeRequest, errors.WithStack(fosite.ErrInactiveToken)
}

func (p *Persister) DeleteRefreshTokenSession(ctx context.Context, signature string) (err error) {
Expand Down Expand Up @@ -536,20 +536,15 @@ func (p *Persister) RevokeRefreshToken(ctx context.Context, id string) (err erro
return p.deactivateSessionByRequestID(ctx, id, sqlTableRefresh)
}

func (p *Persister) RevokeRefreshTokenMaybeGracePeriod(ctx context.Context, id string, signature string) (err error) {
func (p *Persister) RevokeRefreshTokenMaybeGracePeriod(ctx context.Context, id string, _ string) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeRefreshTokenMaybeGracePeriod")
defer otelx.End(span, &err)

gracePeriod := p.config.RefreshTokenRotationGracePeriod(ctx)
if gracePeriod <= 0 {
return p.deactivateSessionByRequestID(ctx, id, sqlTableRefresh)
}

/* #nosec G201 table is static */
return sqlcon.HandleError(
p.Connection(ctx).
RawQuery(
fmt.Sprintf("UPDATE %s SET active=false, first_used_at = CURRENT_TIMESTAMP WHERE request_id=? AND nid = ? AND active=true", OAuth2RequestSQL{Table: sqlTableRefresh}.TableName()),
fmt.Sprintf("UPDATE %s SET active=false, first_used_at = CURRENT_TIMESTAMP WHERE request_id=? AND nid = ? AND active", OAuth2RequestSQL{Table: sqlTableRefresh}.TableName()),
id,
p.NetworkID(ctx),
).
Expand Down

0 comments on commit f7d43a1

Please sign in to comment.