diff --git a/driver/config/config.go b/driver/config/config.go index 0d755a11ba63..05d7ddef52a7 100644 --- a/driver/config/config.go +++ b/driver/config/config.go @@ -116,6 +116,7 @@ const ( ViperKeySessionTokenizerTemplates = "session.whoami.tokenizer.templates" ViperKeySessionWhoAmIAAL = "session.whoami.required_aal" ViperKeySessionWhoAmICaching = "feature_flags.cacheable_sessions" + ViperKeyFeatureFlagFasterSessionExtend = "feature_flags.faster_session_extend" ViperKeySessionWhoAmICachingMaxAge = "feature_flags.cacheable_sessions_max_age" ViperKeyUseContinueWithTransitions = "feature_flags.use_continue_with_transitions" ViperKeySessionRefreshMinTimeLeft = "session.earliest_possible_extend" @@ -1369,6 +1370,10 @@ func (p *Config) SessionWhoAmICaching(ctx context.Context) bool { return p.GetProvider(ctx).Bool(ViperKeySessionWhoAmICaching) } +func (p *Config) FeatureFlagFasterSessionExtend(ctx context.Context) bool { + return p.GetProvider(ctx).Bool(ViperKeyFeatureFlagFasterSessionExtend) +} + func (p *Config) SessionWhoAmICachingMaxAge(ctx context.Context) time.Duration { return p.GetProvider(ctx).DurationF(ViperKeySessionWhoAmICachingMaxAge, 0) } diff --git a/embedx/config.schema.json b/embedx/config.schema.json index 79bc83c88b1d..5349164b6f9b 100644 --- a/embedx/config.schema.json +++ b/embedx/config.schema.json @@ -2852,6 +2852,12 @@ "title": "Enable new flow transitions using `continue_with` items", "description": "If enabled allows new flow transitions using `continue_with` items.", "default": false + }, + "faster_session_extend": { + "type": "boolean", + "title": "Enable faster session extension", + "description": "If enabled allows faster session extension by skipping the session lookup. Disabling this feature will be deprecated in the future.", + "default": false } }, "additionalProperties": false diff --git a/persistence/sql/persister_session.go b/persistence/sql/persister_session.go index c0c1c3d865ba..412ed2d8a825 100644 --- a/persistence/sql/persister_session.go +++ b/persistence/sql/persister_session.go @@ -8,6 +8,9 @@ import ( "fmt" "time" + "github.com/ory/herodot" + "github.com/ory/x/dbal" + "github.com/gobuffalo/pop/v6" "github.com/gofrs/uuid" "github.com/pkg/errors" @@ -176,6 +179,61 @@ func (p *Persister) ListSessionsByIdentity( return s, t, nil } +// ExtendSession updates the expiry of a session. +func (p *Persister) ExtendSession(ctx context.Context, sessionID uuid.UUID) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ExtendSession") + defer otelx.End(span, &err) + + nid := p.NetworkID(ctx) + s := new(session.Session) + var didRefresh bool + if err := errors.WithStack(p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) (err error) { + lockBehavior := "" + if tx.Dialect.Name() == dbal.DriverCockroachDB { + // SKIP LOCKED returns no rows if the row is locked by another transaction. + lockBehavior = "FOR UPDATE SKIP LOCKED" + } + + if err := tx. + Where( + // We make use of the fact that CRDB supports FOR UPDATE as part of the WHERE clause. + fmt.Sprintf("id = ? AND nid = ? %s", lockBehavior), + sessionID, nid, + ).First(s); err != nil { + + // This is a special case for CockroachDB. If the row is locked, we do not see the session. Therefor we return + // a 404 not found error indicating to the user that the session might already be updated by someone else. + if errors.Is(err, sqlcon.ErrNoRows) && tx.Dialect.Name() == dbal.DriverCockroachDB { + return errors.WithStack(herodot.ErrNotFound.WithReason("The session you are trying to extend is already being extended by another request or does not exist.")) + } + + return sqlcon.HandleError(err) + } + + if !s.CanBeRefreshed(ctx, p.r.Config()) { + // This prevents excessive writes to the database. + return nil + } + + didRefresh = true + s = s.Refresh(ctx, p.r.Config()) + + if _, err := tx.Where("id = ? AND nid = ?", sessionID, nid).UpdateQuery(s, "expires_at"); err != nil { + return sqlcon.HandleError(err) + } + + return nil + })); err != nil { + return err + } + + if didRefresh { + trace.SpanFromContext(ctx).AddEvent(events.NewSessionLifespanExtended(ctx, s.ID, s.IdentityID, s.ExpiresAt)) + } + + return nil +} + // UpsertSession creates a session if not found else updates. // This operation also inserts Session device records when a session is being created. // The update operation skips updating Session device records since only one record would need to be updated in this case. @@ -196,6 +254,7 @@ func (p *Persister) UpsertSession(ctx context.Context, s *session.Session) (err trace.SpanFromContext(ctx).AddEvent(events.NewSessionIssued(ctx, string(s.AuthenticatorAssuranceLevel), s.ID, s.IdentityID)) } }() + return errors.WithStack(p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) (err error) { updated = false exists := false diff --git a/session/handler.go b/session/handler.go index 8953fb37c6c3..9dab8860c773 100644 --- a/session/handler.go +++ b/session/handler.go @@ -873,6 +873,10 @@ type extendSession struct { // Calling this endpoint extends the given session ID. If `session.earliest_possible_extend` is set it // will only extend the session after the specified time has passed. // +// This endpoint returns per default a 204 No Content response on success. Older Ory Network projects may +// return a 200 OK response with the session in the body. Returning the session as part of the response +// will be deprecated in the future and should not be relied upon. +// // Retrieve the session ID from the `/sessions/whoami` endpoint / `toSession` SDK method. // // Schemes: http, https @@ -882,30 +886,35 @@ type extendSession struct { // // Responses: // 200: session +// 204: emptyResponse // 400: errorGeneric // 404: errorGeneric // default: errorGeneric func (h *Handler) adminSessionExtend(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - iID, err := uuid.FromString(ps.ByName("id")) + id, err := uuid.FromString(ps.ByName("id")) if err != nil { h.r.Writer().WriteError(w, r, errors.WithStack(herodot.ErrBadRequest.WithError(err.Error()).WithDebug("could not parse UUID"))) return } - s, err := h.r.SessionPersister().GetSession(r.Context(), iID, ExpandDefault) - if err != nil { + c := h.r.Config() + if err := h.r.SessionPersister().ExtendSession(r.Context(), id); err != nil { h.r.Writer().WriteError(w, r, err) return } - c := h.r.Config() - if s.CanBeRefreshed(r.Context(), c) { - if err := h.r.SessionPersister().UpsertSession(r.Context(), s.Refresh(r.Context(), c)); err != nil { - h.r.Writer().WriteError(w, r, err) - return - } + // Default behavior going forward. + if c.FeatureFlagFasterSessionExtend(r.Context()) { + w.WriteHeader(http.StatusNoContent) + return } + // WARNING - this will be deprecated at some point! + s, err := h.r.SessionPersister().GetSession(r.Context(), id, ExpandDefault) + if err != nil { + h.r.Writer().WriteError(w, r, err) + return + } h.r.Writer().Write(w, r, s) } diff --git a/session/persistence.go b/session/persistence.go index 34bf4d2654d4..ab5ec0a47735 100644 --- a/session/persistence.go +++ b/session/persistence.go @@ -35,6 +35,9 @@ type Persister interface { // UpsertSession inserts or updates a session into / in the store. UpsertSession(ctx context.Context, s *Session) error + // ExtendSession updates the expiry of a session. + ExtendSession(ctx context.Context, sessionID uuid.UUID) error + // DeleteSession removes a session from the store. DeleteSession(ctx context.Context, id uuid.UUID) error diff --git a/session/test/persistence.go b/session/test/persistence.go index 727cc744bdce..8e8cbfeb18b2 100644 --- a/session/test/persistence.go +++ b/session/test/persistence.go @@ -8,6 +8,11 @@ import ( "testing" "time" + "github.com/pkg/errors" + "golang.org/x/sync/errgroup" + + "github.com/ory/x/dbal" + "github.com/gobuffalo/pop/v6" "github.com/ory/x/pagination/keysetpagination" @@ -604,5 +609,82 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface { _, err = p.GetSessionByToken(ctx, t2, session.ExpandNothing, identity.ExpandDefault) require.ErrorIs(t, err, sqlcon.ErrNoRows) }) + + t.Run("extend session lifespan but min time is not yet reached", func(t *testing.T) { + conf.MustSet(ctx, config.ViperKeySessionRefreshMinTimeLeft, time.Hour*2) + t.Cleanup(func() { + conf.MustSet(ctx, config.ViperKeySessionRefreshMinTimeLeft, nil) + }) + + var expected session.Session + require.NoError(t, faker.FakeData(&expected)) + expected.ExpiresAt = time.Now().Add(time.Hour * 10).Round(time.Second).UTC() + require.NoError(t, p.CreateIdentity(ctx, expected.Identity)) + require.NoError(t, p.UpsertSession(ctx, &expected)) + + require.NoError(t, p.ExtendSession(ctx, expected.ID)) + actual, err := p.GetSession(ctx, expected.ID, session.ExpandNothing) + require.NoError(t, err) + assert.Equal(t, expected.ExpiresAt, actual.ExpiresAt) + }) + + t.Run("extend session lifespan", func(t *testing.T) { + conf.MustSet(ctx, config.ViperKeySessionRefreshMinTimeLeft, time.Hour) + t.Cleanup(func() { + conf.MustSet(ctx, config.ViperKeySessionRefreshMinTimeLeft, nil) + }) + + conf.MustSet(ctx, config.ViperKeySessionRefreshMinTimeLeft, time.Hour*2) + var expected session.Session + require.NoError(t, faker.FakeData(&expected)) + expected.ExpiresAt = time.Now().Add(time.Hour).UTC() + require.NoError(t, p.CreateIdentity(ctx, expected.Identity)) + require.NoError(t, p.UpsertSession(ctx, &expected)) + + expectedExpiry := expected.Refresh(ctx, conf).ExpiresAt.Round(time.Minute) + require.NoError(t, p.ExtendSession(ctx, expected.ID)) + actual, err := p.GetSession(ctx, expected.ID, session.ExpandNothing) + require.NoError(t, err) + assert.Equal(t, expectedExpiry, actual.ExpiresAt.Round(time.Minute)) + }) + + t.Run("extend session lifespan on CockroachDB", func(t *testing.T) { + if p.GetConnection(ctx).Dialect.Name() != dbal.DriverCockroachDB { + t.Skip("Skipping test because driver is not CockroachDB") + } + + conf.MustSet(ctx, config.ViperKeySessionRefreshMinTimeLeft, time.Hour) + t.Cleanup(func() { + conf.MustSet(ctx, config.ViperKeySessionRefreshMinTimeLeft, nil) + }) + + conf.MustSet(ctx, config.ViperKeySessionRefreshMinTimeLeft, time.Hour*2) + var expected session.Session + require.NoError(t, faker.FakeData(&expected)) + expected.ExpiresAt = time.Now().Add(time.Hour).UTC() + require.NoError(t, p.CreateIdentity(ctx, expected.Identity)) + require.NoError(t, p.UpsertSession(ctx, &expected)) + + expectedExpiry := expected.Refresh(ctx, conf).ExpiresAt.Round(time.Minute) + + var foundExpectedCockroachError bool + g := errgroup.Group{} + for i := 0; i < 10; i++ { + g.Go(func() error { + err := p.ExtendSession(ctx, expected.ID) + if errors.Is(err, sqlcon.ErrNoRows) { + foundExpectedCockroachError = true + return nil + } + return err + }) + } + require.NoError(t, g.Wait()) + + actual, err := p.GetSession(ctx, expected.ID, session.ExpandNothing) + require.NoError(t, err) + assert.Equal(t, expectedExpiry, actual.ExpiresAt.Round(time.Minute)) + assert.True(t, foundExpectedCockroachError, "We expect to find a not found error caused by ... FOR UPDATE SKIP LOCKED") + }) } } diff --git a/x/events/events.go b/x/events/events.go index 52e921112d12..13ec16955539 100644 --- a/x/events/events.go +++ b/x/events/events.go @@ -16,31 +16,33 @@ import ( ) const ( - SessionIssued semconv.Event = "SessionIssued" - SessionChanged semconv.Event = "SessionChanged" - SessionRevoked semconv.Event = "SessionRevoked" - SessionChecked semconv.Event = "SessionChecked" - SessionTokenizedAsJWT semconv.Event = "SessionTokenizedAsJWT" - RegistrationFailed semconv.Event = "RegistrationFailed" - RegistrationSucceeded semconv.Event = "RegistrationSucceeded" - LoginFailed semconv.Event = "LoginFailed" - LoginSucceeded semconv.Event = "LoginSucceeded" - SettingsFailed semconv.Event = "SettingsFailed" - SettingsSucceeded semconv.Event = "SettingsSucceeded" - RecoveryFailed semconv.Event = "RecoveryFailed" - RecoverySucceeded semconv.Event = "RecoverySucceeded" - VerificationFailed semconv.Event = "VerificationFailed" - VerificationSucceeded semconv.Event = "VerificationSucceeded" - IdentityCreated semconv.Event = "IdentityCreated" - IdentityUpdated semconv.Event = "IdentityUpdated" - WebhookDelivered semconv.Event = "WebhookDelivered" - WebhookSucceeded semconv.Event = "WebhookSucceeded" - WebhookFailed semconv.Event = "WebhookFailed" + SessionIssued semconv.Event = "SessionIssued" + SessionChanged semconv.Event = "SessionChanged" + SessionLifespanExtended semconv.Event = "SessionLifespanExtended" + SessionRevoked semconv.Event = "SessionRevoked" + SessionChecked semconv.Event = "SessionChecked" + SessionTokenizedAsJWT semconv.Event = "SessionTokenizedAsJWT" + RegistrationFailed semconv.Event = "RegistrationFailed" + RegistrationSucceeded semconv.Event = "RegistrationSucceeded" + LoginFailed semconv.Event = "LoginFailed" + LoginSucceeded semconv.Event = "LoginSucceeded" + SettingsFailed semconv.Event = "SettingsFailed" + SettingsSucceeded semconv.Event = "SettingsSucceeded" + RecoveryFailed semconv.Event = "RecoveryFailed" + RecoverySucceeded semconv.Event = "RecoverySucceeded" + VerificationFailed semconv.Event = "VerificationFailed" + VerificationSucceeded semconv.Event = "VerificationSucceeded" + IdentityCreated semconv.Event = "IdentityCreated" + IdentityUpdated semconv.Event = "IdentityUpdated" + WebhookDelivered semconv.Event = "WebhookDelivered" + WebhookSucceeded semconv.Event = "WebhookSucceeded" + WebhookFailed semconv.Event = "WebhookFailed" ) const ( attributeKeySessionID semconv.AttributeKey = "SessionID" attributeKeySessionAAL semconv.AttributeKey = "SessionAAL" + attributeKeySessionExpiresAt semconv.AttributeKey = "SessionExpiresAt" attributeKeySelfServiceFlowType semconv.AttributeKey = "SelfServiceFlowType" attributeKeySelfServiceMethodUsed semconv.AttributeKey = "SelfServiceMethodUsed" attributeKeySelfServiceSSOProviderUsed semconv.AttributeKey = "SelfServiceSSOProviderUsed" @@ -71,6 +73,10 @@ func attLoginRequestedAAL(val string) otelattr.KeyValue { return otelattr.String(attributeKeyLoginRequestedAAL.String(), val) } +func attSessionExpiresAt(expiresAt time.Time) otelattr.KeyValue { + return otelattr.String(attributeKeySessionExpiresAt.String(), expiresAt.String()) +} + func attLoginRequestedPrivilegedSession(val bool) otelattr.KeyValue { return otelattr.Bool(attributeKeyLoginRequestedPrivilegedSession.String(), val) } @@ -135,6 +141,18 @@ func NewSessionChanged(ctx context.Context, aal string, sessionID, identityID uu ) } +func NewSessionLifespanExtended(ctx context.Context, sessionID, identityID uuid.UUID, newExpiry time.Time) (string, trace.EventOption) { + return SessionLifespanExtended.String(), + trace.WithAttributes( + append( + semconv.AttributesFromContext(ctx), + semconv.AttrIdentityID(identityID), + attrSessionID(sessionID), + attSessionExpiresAt(newExpiry), + )..., + ) +} + type LoginSucceededOpts struct { SessionID, IdentityID uuid.UUID FlowType, RequestedAAL, Method, SSOProvider string