From 4e3fad4b4739b5cf00d658155350cb599f2cd06a Mon Sep 17 00:00:00 2001 From: hackerman <3372410+aeneasr@users.noreply.github.com> Date: Thu, 6 Jun 2024 17:37:39 +0200 Subject: [PATCH] feat: improve session extend performance (#3948) This patch improves the performance for extending session lifespans. Lifespan extension is tricky as it is often part of the middleware of Ory Kratos consumers. As such, it is prone to transaction contention when we read and write to the same session row at the same time (and potentially multiple times). To address this, we: 1. Introduce a locking mechanism on the row to reduce transaction contention; 2. Add a new feature flag that toggles returning 204 No Content instead of 200 + session. Be aware that all reads on the session table will have to wait for the transaction to commit before they return a value. This may cause long(er) response times on `/sessions/whoami` for sessions that are being extended at the same time. BREAKING CHANGES: Going forward, the `/admin/sessions/.../extend` endpoint will return 204 No Content for new Ory Network projects. We will deprecate returning 200 + session body in the future. 
--- driver/config/config.go | 5 ++ embedx/config.schema.json | 6 ++ persistence/sql/persister_session.go | 59 ++++++++++++++++++++ session/handler.go | 27 ++++++--- session/persistence.go | 3 + session/test/persistence.go | 82 ++++++++++++++++++++++++++++ x/events/events.go | 58 +++++++++++++------- 7 files changed, 211 insertions(+), 29 deletions(-) diff --git a/driver/config/config.go b/driver/config/config.go index 0d755a11ba63..05d7ddef52a7 100644 --- a/driver/config/config.go +++ b/driver/config/config.go @@ -116,6 +116,7 @@ const ( ViperKeySessionTokenizerTemplates = "session.whoami.tokenizer.templates" ViperKeySessionWhoAmIAAL = "session.whoami.required_aal" ViperKeySessionWhoAmICaching = "feature_flags.cacheable_sessions" + ViperKeyFeatureFlagFasterSessionExtend = "feature_flags.faster_session_extend" ViperKeySessionWhoAmICachingMaxAge = "feature_flags.cacheable_sessions_max_age" ViperKeyUseContinueWithTransitions = "feature_flags.use_continue_with_transitions" ViperKeySessionRefreshMinTimeLeft = "session.earliest_possible_extend" @@ -1369,6 +1370,10 @@ func (p *Config) SessionWhoAmICaching(ctx context.Context) bool { return p.GetProvider(ctx).Bool(ViperKeySessionWhoAmICaching) } +func (p *Config) FeatureFlagFasterSessionExtend(ctx context.Context) bool { + return p.GetProvider(ctx).Bool(ViperKeyFeatureFlagFasterSessionExtend) +} + func (p *Config) SessionWhoAmICachingMaxAge(ctx context.Context) time.Duration { return p.GetProvider(ctx).DurationF(ViperKeySessionWhoAmICachingMaxAge, 0) } diff --git a/embedx/config.schema.json b/embedx/config.schema.json index 79bc83c88b1d..5349164b6f9b 100644 --- a/embedx/config.schema.json +++ b/embedx/config.schema.json @@ -2852,6 +2852,12 @@ "title": "Enable new flow transitions using `continue_with` items", "description": "If enabled allows new flow transitions using `continue_with` items.", "default": false + }, + "faster_session_extend": { + "type": "boolean", + "title": "Enable faster session extension", + 
"description": "If enabled allows faster session extension by skipping the session lookup. Disabling this feature will be deprecated in the future.", + "default": false } }, "additionalProperties": false diff --git a/persistence/sql/persister_session.go b/persistence/sql/persister_session.go index c0c1c3d865ba..412ed2d8a825 100644 --- a/persistence/sql/persister_session.go +++ b/persistence/sql/persister_session.go @@ -8,6 +8,9 @@ import ( "fmt" "time" + "github.com/ory/herodot" + "github.com/ory/x/dbal" + "github.com/gobuffalo/pop/v6" "github.com/gofrs/uuid" "github.com/pkg/errors" @@ -176,6 +179,61 @@ func (p *Persister) ListSessionsByIdentity( return s, t, nil } +// ExtendSession updates the expiry of a session. +func (p *Persister) ExtendSession(ctx context.Context, sessionID uuid.UUID) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ExtendSession") + defer otelx.End(span, &err) + + nid := p.NetworkID(ctx) + s := new(session.Session) + var didRefresh bool + if err := errors.WithStack(p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) (err error) { + lockBehavior := "" + if tx.Dialect.Name() == dbal.DriverCockroachDB { + // SKIP LOCKED returns no rows if the row is locked by another transaction. + lockBehavior = "FOR UPDATE SKIP LOCKED" + } + + if err := tx. + Where( + // We make use of the fact that CRDB supports FOR UPDATE as part of the WHERE clause. + fmt.Sprintf("id = ? AND nid = ? %s", lockBehavior), + sessionID, nid, + ).First(s); err != nil { + + // This is a special case for CockroachDB. If the row is locked, we do not see the session. Therefore, we return + // a 404 not found error indicating to the user that the session might already be updated by someone else. 
+ if errors.Is(err, sqlcon.ErrNoRows) && tx.Dialect.Name() == dbal.DriverCockroachDB { + return errors.WithStack(herodot.ErrNotFound.WithReason("The session you are trying to extend is already being extended by another request or does not exist.")) + } + + return sqlcon.HandleError(err) + } + + if !s.CanBeRefreshed(ctx, p.r.Config()) { + // This prevents excessive writes to the database. + return nil + } + + didRefresh = true + s = s.Refresh(ctx, p.r.Config()) + + if _, err := tx.Where("id = ? AND nid = ?", sessionID, nid).UpdateQuery(s, "expires_at"); err != nil { + return sqlcon.HandleError(err) + } + + return nil + })); err != nil { + return err + } + + if didRefresh { + trace.SpanFromContext(ctx).AddEvent(events.NewSessionLifespanExtended(ctx, s.ID, s.IdentityID, s.ExpiresAt)) + } + + return nil +} + // UpsertSession creates a session if not found else updates. // This operation also inserts Session device records when a session is being created. // The update operation skips updating Session device records since only one record would need to be updated in this case. @@ -196,6 +254,7 @@ func (p *Persister) UpsertSession(ctx context.Context, s *session.Session) (err trace.SpanFromContext(ctx).AddEvent(events.NewSessionIssued(ctx, string(s.AuthenticatorAssuranceLevel), s.ID, s.IdentityID)) } }() + return errors.WithStack(p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) (err error) { updated = false exists := false diff --git a/session/handler.go b/session/handler.go index 8953fb37c6c3..9dab8860c773 100644 --- a/session/handler.go +++ b/session/handler.go @@ -873,6 +873,10 @@ type extendSession struct { // Calling this endpoint extends the given session ID. If `session.earliest_possible_extend` is set it // will only extend the session after the specified time has passed. // +// This endpoint returns per default a 204 No Content response on success. Older Ory Network projects may +// return a 200 OK response with the session in the body. 
Returning the session as part of the response +// will be deprecated in the future and should not be relied upon. +// // Retrieve the session ID from the `/sessions/whoami` endpoint / `toSession` SDK method. // // Schemes: http, https @@ -882,30 +886,35 @@ type extendSession struct { // // Responses: // 200: session +// 204: emptyResponse // 400: errorGeneric // 404: errorGeneric // default: errorGeneric func (h *Handler) adminSessionExtend(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - iID, err := uuid.FromString(ps.ByName("id")) + id, err := uuid.FromString(ps.ByName("id")) if err != nil { h.r.Writer().WriteError(w, r, errors.WithStack(herodot.ErrBadRequest.WithError(err.Error()).WithDebug("could not parse UUID"))) return } - s, err := h.r.SessionPersister().GetSession(r.Context(), iID, ExpandDefault) - if err != nil { + c := h.r.Config() + if err := h.r.SessionPersister().ExtendSession(r.Context(), id); err != nil { h.r.Writer().WriteError(w, r, err) return } - c := h.r.Config() - if s.CanBeRefreshed(r.Context(), c) { - if err := h.r.SessionPersister().UpsertSession(r.Context(), s.Refresh(r.Context(), c)); err != nil { - h.r.Writer().WriteError(w, r, err) - return - } + // Default behavior going forward. + if c.FeatureFlagFasterSessionExtend(r.Context()) { + w.WriteHeader(http.StatusNoContent) + return } + // WARNING - this will be deprecated at some point! + s, err := h.r.SessionPersister().GetSession(r.Context(), id, ExpandDefault) + if err != nil { + h.r.Writer().WriteError(w, r, err) + return + } h.r.Writer().Write(w, r, s) } diff --git a/session/persistence.go b/session/persistence.go index 34bf4d2654d4..ab5ec0a47735 100644 --- a/session/persistence.go +++ b/session/persistence.go @@ -35,6 +35,9 @@ type Persister interface { // UpsertSession inserts or updates a session into / in the store. UpsertSession(ctx context.Context, s *Session) error + // ExtendSession updates the expiry of a session. 
+ ExtendSession(ctx context.Context, sessionID uuid.UUID) error + // DeleteSession removes a session from the store. DeleteSession(ctx context.Context, id uuid.UUID) error diff --git a/session/test/persistence.go b/session/test/persistence.go index 727cc744bdce..8e8cbfeb18b2 100644 --- a/session/test/persistence.go +++ b/session/test/persistence.go @@ -8,6 +8,11 @@ import ( "testing" "time" + "github.com/pkg/errors" + "golang.org/x/sync/errgroup" + + "github.com/ory/x/dbal" + "github.com/gobuffalo/pop/v6" "github.com/ory/x/pagination/keysetpagination" @@ -604,5 +609,82 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface { _, err = p.GetSessionByToken(ctx, t2, session.ExpandNothing, identity.ExpandDefault) require.ErrorIs(t, err, sqlcon.ErrNoRows) }) + + t.Run("extend session lifespan but min time is not yet reached", func(t *testing.T) { + conf.MustSet(ctx, config.ViperKeySessionRefreshMinTimeLeft, time.Hour*2) + t.Cleanup(func() { + conf.MustSet(ctx, config.ViperKeySessionRefreshMinTimeLeft, nil) + }) + + var expected session.Session + require.NoError(t, faker.FakeData(&expected)) + expected.ExpiresAt = time.Now().Add(time.Hour * 10).Round(time.Second).UTC() + require.NoError(t, p.CreateIdentity(ctx, expected.Identity)) + require.NoError(t, p.UpsertSession(ctx, &expected)) + + require.NoError(t, p.ExtendSession(ctx, expected.ID)) + actual, err := p.GetSession(ctx, expected.ID, session.ExpandNothing) + require.NoError(t, err) + assert.Equal(t, expected.ExpiresAt, actual.ExpiresAt) + }) + + t.Run("extend session lifespan", func(t *testing.T) { + conf.MustSet(ctx, config.ViperKeySessionRefreshMinTimeLeft, time.Hour) + t.Cleanup(func() { + conf.MustSet(ctx, config.ViperKeySessionRefreshMinTimeLeft, nil) + }) + + conf.MustSet(ctx, config.ViperKeySessionRefreshMinTimeLeft, time.Hour*2) + var expected session.Session + require.NoError(t, faker.FakeData(&expected)) + expected.ExpiresAt = time.Now().Add(time.Hour).UTC() + require.NoError(t, 
p.CreateIdentity(ctx, expected.Identity)) + require.NoError(t, p.UpsertSession(ctx, &expected)) + + expectedExpiry := expected.Refresh(ctx, conf).ExpiresAt.Round(time.Minute) + require.NoError(t, p.ExtendSession(ctx, expected.ID)) + actual, err := p.GetSession(ctx, expected.ID, session.ExpandNothing) + require.NoError(t, err) + assert.Equal(t, expectedExpiry, actual.ExpiresAt.Round(time.Minute)) + }) + + t.Run("extend session lifespan on CockroachDB", func(t *testing.T) { + if p.GetConnection(ctx).Dialect.Name() != dbal.DriverCockroachDB { + t.Skip("Skipping test because driver is not CockroachDB") + } + + conf.MustSet(ctx, config.ViperKeySessionRefreshMinTimeLeft, time.Hour) + t.Cleanup(func() { + conf.MustSet(ctx, config.ViperKeySessionRefreshMinTimeLeft, nil) + }) + + conf.MustSet(ctx, config.ViperKeySessionRefreshMinTimeLeft, time.Hour*2) + var expected session.Session + require.NoError(t, faker.FakeData(&expected)) + expected.ExpiresAt = time.Now().Add(time.Hour).UTC() + require.NoError(t, p.CreateIdentity(ctx, expected.Identity)) + require.NoError(t, p.UpsertSession(ctx, &expected)) + + expectedExpiry := expected.Refresh(ctx, conf).ExpiresAt.Round(time.Minute) + + var foundExpectedCockroachError bool + g := errgroup.Group{} + for i := 0; i < 10; i++ { + g.Go(func() error { + err := p.ExtendSession(ctx, expected.ID) + if errors.Is(err, sqlcon.ErrNoRows) { + foundExpectedCockroachError = true + return nil + } + return err + }) + } + require.NoError(t, g.Wait()) + + actual, err := p.GetSession(ctx, expected.ID, session.ExpandNothing) + require.NoError(t, err) + assert.Equal(t, expectedExpiry, actual.ExpiresAt.Round(time.Minute)) + assert.True(t, foundExpectedCockroachError, "We expect to find a not found error caused by ... 
FOR UPDATE SKIP LOCKED") + }) } } diff --git a/x/events/events.go b/x/events/events.go index 52e921112d12..13ec16955539 100644 --- a/x/events/events.go +++ b/x/events/events.go @@ -16,31 +16,33 @@ import ( ) const ( - SessionIssued semconv.Event = "SessionIssued" - SessionChanged semconv.Event = "SessionChanged" - SessionRevoked semconv.Event = "SessionRevoked" - SessionChecked semconv.Event = "SessionChecked" - SessionTokenizedAsJWT semconv.Event = "SessionTokenizedAsJWT" - RegistrationFailed semconv.Event = "RegistrationFailed" - RegistrationSucceeded semconv.Event = "RegistrationSucceeded" - LoginFailed semconv.Event = "LoginFailed" - LoginSucceeded semconv.Event = "LoginSucceeded" - SettingsFailed semconv.Event = "SettingsFailed" - SettingsSucceeded semconv.Event = "SettingsSucceeded" - RecoveryFailed semconv.Event = "RecoveryFailed" - RecoverySucceeded semconv.Event = "RecoverySucceeded" - VerificationFailed semconv.Event = "VerificationFailed" - VerificationSucceeded semconv.Event = "VerificationSucceeded" - IdentityCreated semconv.Event = "IdentityCreated" - IdentityUpdated semconv.Event = "IdentityUpdated" - WebhookDelivered semconv.Event = "WebhookDelivered" - WebhookSucceeded semconv.Event = "WebhookSucceeded" - WebhookFailed semconv.Event = "WebhookFailed" + SessionIssued semconv.Event = "SessionIssued" + SessionChanged semconv.Event = "SessionChanged" + SessionLifespanExtended semconv.Event = "SessionLifespanExtended" + SessionRevoked semconv.Event = "SessionRevoked" + SessionChecked semconv.Event = "SessionChecked" + SessionTokenizedAsJWT semconv.Event = "SessionTokenizedAsJWT" + RegistrationFailed semconv.Event = "RegistrationFailed" + RegistrationSucceeded semconv.Event = "RegistrationSucceeded" + LoginFailed semconv.Event = "LoginFailed" + LoginSucceeded semconv.Event = "LoginSucceeded" + SettingsFailed semconv.Event = "SettingsFailed" + SettingsSucceeded semconv.Event = "SettingsSucceeded" + RecoveryFailed semconv.Event = "RecoveryFailed" + 
RecoverySucceeded semconv.Event = "RecoverySucceeded" + VerificationFailed semconv.Event = "VerificationFailed" + VerificationSucceeded semconv.Event = "VerificationSucceeded" + IdentityCreated semconv.Event = "IdentityCreated" + IdentityUpdated semconv.Event = "IdentityUpdated" + WebhookDelivered semconv.Event = "WebhookDelivered" + WebhookSucceeded semconv.Event = "WebhookSucceeded" + WebhookFailed semconv.Event = "WebhookFailed" ) const ( attributeKeySessionID semconv.AttributeKey = "SessionID" attributeKeySessionAAL semconv.AttributeKey = "SessionAAL" + attributeKeySessionExpiresAt semconv.AttributeKey = "SessionExpiresAt" attributeKeySelfServiceFlowType semconv.AttributeKey = "SelfServiceFlowType" attributeKeySelfServiceMethodUsed semconv.AttributeKey = "SelfServiceMethodUsed" attributeKeySelfServiceSSOProviderUsed semconv.AttributeKey = "SelfServiceSSOProviderUsed" @@ -71,6 +73,10 @@ func attLoginRequestedAAL(val string) otelattr.KeyValue { return otelattr.String(attributeKeyLoginRequestedAAL.String(), val) } +func attSessionExpiresAt(expiresAt time.Time) otelattr.KeyValue { + return otelattr.String(attributeKeySessionExpiresAt.String(), expiresAt.String()) +} + func attLoginRequestedPrivilegedSession(val bool) otelattr.KeyValue { return otelattr.Bool(attributeKeyLoginRequestedPrivilegedSession.String(), val) } @@ -135,6 +141,18 @@ func NewSessionChanged(ctx context.Context, aal string, sessionID, identityID uu ) } +func NewSessionLifespanExtended(ctx context.Context, sessionID, identityID uuid.UUID, newExpiry time.Time) (string, trace.EventOption) { + return SessionLifespanExtended.String(), + trace.WithAttributes( + append( + semconv.AttributesFromContext(ctx), + semconv.AttrIdentityID(identityID), + attrSessionID(sessionID), + attSessionExpiresAt(newExpiry), + )..., + ) +} + type LoginSucceededOpts struct { SessionID, IdentityID uuid.UUID FlowType, RequestedAAL, Method, SSOProvider string