Skip to content

Commit

Permalink
feat: improve session extend performance (ory#3948)
Browse files Browse the repository at this point in the history
This patch improves the performance for extending session lifespans. Lifespan extension is tricky as it is often part of the middleware of Ory Kratos consumers. As such, it is prone to transaction contention when we read and write to the same session row at the same time (and potentially multiple times).

To address this, we:

1. Introduce a locking mechanism on the row to reduce transaction contention;
2. Add a new feature flag that toggles returning 204 no content instead of 200 + session.

Be aware that all reads on the session table will have to wait for the transaction to commit before they return a value. This may cause longer response times on `/sessions/whoami` for sessions that are being extended at the same time.

BREAKING CHANGES: Going forward, the `/admin/session/.../extend` endpoint will return 204 no content for new Ory Network projects. We will deprecate returning 200 + session body in the future.
  • Loading branch information
aeneasr authored Jun 6, 2024
1 parent 278d8e0 commit 4e3fad4
Show file tree
Hide file tree
Showing 7 changed files with 211 additions and 29 deletions.
5 changes: 5 additions & 0 deletions driver/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ const (
ViperKeySessionTokenizerTemplates = "session.whoami.tokenizer.templates"
ViperKeySessionWhoAmIAAL = "session.whoami.required_aal"
ViperKeySessionWhoAmICaching = "feature_flags.cacheable_sessions"
ViperKeyFeatureFlagFasterSessionExtend = "feature_flags.faster_session_extend"
ViperKeySessionWhoAmICachingMaxAge = "feature_flags.cacheable_sessions_max_age"
ViperKeyUseContinueWithTransitions = "feature_flags.use_continue_with_transitions"
ViperKeySessionRefreshMinTimeLeft = "session.earliest_possible_extend"
Expand Down Expand Up @@ -1369,6 +1370,10 @@ func (p *Config) SessionWhoAmICaching(ctx context.Context) bool {
return p.GetProvider(ctx).Bool(ViperKeySessionWhoAmICaching)
}

// FeatureFlagFasterSessionExtend reports the boolean value stored under
// ViperKeyFeatureFlagFasterSessionExtend in the provider resolved for ctx.
func (p *Config) FeatureFlagFasterSessionExtend(ctx context.Context) bool {
	provider := p.GetProvider(ctx)
	return provider.Bool(ViperKeyFeatureFlagFasterSessionExtend)
}

// SessionWhoAmICachingMaxAge returns the duration stored under
// ViperKeySessionWhoAmICachingMaxAge in the provider resolved for ctx.
// A zero duration is passed to DurationF as its second argument
// (presumably the fallback used when the key is unset — confirm against the
// provider's documentation).
func (p *Config) SessionWhoAmICachingMaxAge(ctx context.Context) time.Duration {
	provider := p.GetProvider(ctx)
	return provider.DurationF(ViperKeySessionWhoAmICachingMaxAge, 0)
}
Expand Down
6 changes: 6 additions & 0 deletions embedx/config.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -2852,6 +2852,12 @@
"title": "Enable new flow transitions using `continue_with` items",
"description": "If enabled allows new flow transitions using `continue_with` items.",
"default": false
},
"faster_session_extend": {
"type": "boolean",
"title": "Enable faster session extension",
"description": "If enabled allows faster session extension by skipping the session lookup. Disabling this feature will be deprecated in the future.",
"default": false
}
},
"additionalProperties": false
Expand Down
59 changes: 59 additions & 0 deletions persistence/sql/persister_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ import (
"fmt"
"time"

"github.com/ory/herodot"
"github.com/ory/x/dbal"

"github.com/gobuffalo/pop/v6"
"github.com/gofrs/uuid"
"github.com/pkg/errors"
Expand Down Expand Up @@ -176,6 +179,61 @@ func (p *Persister) ListSessionsByIdentity(
return s, t, nil
}

// ExtendSession updates the expiry of a session.
//
// The session identified by sessionID, scoped to the network ID derived from
// ctx, is loaded inside a transaction. If CanBeRefreshed reports true, the
// session is refreshed and only its expires_at column is written back. If the
// session cannot be refreshed yet, no write is issued and nil is returned.
//
// On CockroachDB the row is read with FOR UPDATE SKIP LOCKED. A row that is
// locked by a concurrent transaction is therefore not returned, and the call
// fails with a 404 not found error rather than waiting for the lock.
//
// A SessionLifespanExtended event is added to the span found in ctx only when
// the expiry was actually updated.
func (p *Persister) ExtendSession(ctx context.Context, sessionID uuid.UUID) (err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ExtendSession")
	defer otelx.End(span, &err)

	nid := p.NetworkID(ctx)
	s := new(session.Session)
	var didRefresh bool
	if err := errors.WithStack(p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) (err error) {
		// Reset the flag every time the callback runs. If the transaction
		// function is invoked more than once (NOTE(review): presumably on
		// CockroachDB retries — confirm), a stale `true` from an earlier
		// attempt must not cause an event for an extension that did not happen.
		didRefresh = false

		lockBehavior := ""
		if tx.Dialect.Name() == dbal.DriverCockroachDB {
			// SKIP LOCKED returns no rows if the row is locked by another transaction.
			lockBehavior = "FOR UPDATE SKIP LOCKED"
		}

		if err := tx.
			Where(
				// We make use of the fact that CRDB supports FOR UPDATE as part of the WHERE clause.
				fmt.Sprintf("id = ? AND nid = ? %s", lockBehavior),
				sessionID, nid,
			).First(s); err != nil {

			// This is a special case for CockroachDB. If the row is locked, we do not see the session. Therefore we return
			// a 404 not found error indicating to the user that the session might already be updated by someone else.
			if errors.Is(err, sqlcon.ErrNoRows) && tx.Dialect.Name() == dbal.DriverCockroachDB {
				return errors.WithStack(herodot.ErrNotFound.WithReason("The session you are trying to extend is already being extended by another request or does not exist."))
			}

			return sqlcon.HandleError(err)
		}

		if !s.CanBeRefreshed(ctx, p.r.Config()) {
			// This prevents excessive writes to the database.
			return nil
		}

		didRefresh = true
		s = s.Refresh(ctx, p.r.Config())

		// Only the expiry changes, so restrict the UPDATE to that column.
		if _, err := tx.Where("id = ? AND nid = ?", sessionID, nid).UpdateQuery(s, "expires_at"); err != nil {
			return sqlcon.HandleError(err)
		}

		return nil
	})); err != nil {
		return err
	}

	if didRefresh {
		trace.SpanFromContext(ctx).AddEvent(events.NewSessionLifespanExtended(ctx, s.ID, s.IdentityID, s.ExpiresAt))
	}

	return nil
}

// UpsertSession creates a session if not found else updates.
// This operation also inserts Session device records when a session is being created.
// The update operation skips updating Session device records since only one record would need to be updated in this case.
Expand All @@ -196,6 +254,7 @@ func (p *Persister) UpsertSession(ctx context.Context, s *session.Session) (err
trace.SpanFromContext(ctx).AddEvent(events.NewSessionIssued(ctx, string(s.AuthenticatorAssuranceLevel), s.ID, s.IdentityID))
}
}()

return errors.WithStack(p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) (err error) {
updated = false
exists := false
Expand Down
27 changes: 18 additions & 9 deletions session/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -873,6 +873,10 @@ type extendSession struct {
// Calling this endpoint extends the given session ID. If `session.earliest_possible_extend` is set it
// will only extend the session after the specified time has passed.
//
// By default, this endpoint returns a 204 No Content response on success. Older Ory Network projects may
// return a 200 OK response with the session in the body. Returning the session as part of the response
// will be deprecated in the future and should not be relied upon.
//
// Retrieve the session ID from the `/sessions/whoami` endpoint / `toSession` SDK method.
//
// Schemes: http, https
Expand All @@ -882,30 +886,35 @@ type extendSession struct {
//
// Responses:
// 200: session
// 204: emptyResponse
// 400: errorGeneric
// 404: errorGeneric
// default: errorGeneric
func (h *Handler) adminSessionExtend(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
iID, err := uuid.FromString(ps.ByName("id"))
id, err := uuid.FromString(ps.ByName("id"))
if err != nil {
h.r.Writer().WriteError(w, r, errors.WithStack(herodot.ErrBadRequest.WithError(err.Error()).WithDebug("could not parse UUID")))
return
}

s, err := h.r.SessionPersister().GetSession(r.Context(), iID, ExpandDefault)
if err != nil {
c := h.r.Config()
if err := h.r.SessionPersister().ExtendSession(r.Context(), id); err != nil {
h.r.Writer().WriteError(w, r, err)
return
}

c := h.r.Config()
if s.CanBeRefreshed(r.Context(), c) {
if err := h.r.SessionPersister().UpsertSession(r.Context(), s.Refresh(r.Context(), c)); err != nil {
h.r.Writer().WriteError(w, r, err)
return
}
// Default behavior going forward.
if c.FeatureFlagFasterSessionExtend(r.Context()) {
w.WriteHeader(http.StatusNoContent)
return
}

// WARNING - this will be deprecated at some point!
s, err := h.r.SessionPersister().GetSession(r.Context(), id, ExpandDefault)
if err != nil {
h.r.Writer().WriteError(w, r, err)
return
}
h.r.Writer().Write(w, r, s)
}

Expand Down
3 changes: 3 additions & 0 deletions session/persistence.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ type Persister interface {
// UpsertSession inserts or updates a session into / in the store.
UpsertSession(ctx context.Context, s *Session) error

// ExtendSession updates the expiry of a session.
ExtendSession(ctx context.Context, sessionID uuid.UUID) error

// DeleteSession removes a session from the store.
DeleteSession(ctx context.Context, id uuid.UUID) error

Expand Down
82 changes: 82 additions & 0 deletions session/test/persistence.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ import (
"testing"
"time"

"github.com/pkg/errors"
"golang.org/x/sync/errgroup"

"github.com/ory/x/dbal"

"github.com/gobuffalo/pop/v6"

"github.com/ory/x/pagination/keysetpagination"
Expand Down Expand Up @@ -604,5 +609,82 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {
_, err = p.GetSessionByToken(ctx, t2, session.ExpandNothing, identity.ExpandDefault)
require.ErrorIs(t, err, sqlcon.ErrNoRows)
})

t.Run("extend session lifespan but min time is not yet reached", func(t *testing.T) {
	// The refresh threshold is two hours while the session below still has
	// ten hours left, so ExtendSession is expected to leave the expiry alone.
	conf.MustSet(ctx, config.ViperKeySessionRefreshMinTimeLeft, time.Hour*2)
	t.Cleanup(func() {
		conf.MustSet(ctx, config.ViperKeySessionRefreshMinTimeLeft, nil)
	})

	var want session.Session
	require.NoError(t, faker.FakeData(&want))
	want.ExpiresAt = time.Now().Add(time.Hour * 10).Round(time.Second).UTC()
	require.NoError(t, p.CreateIdentity(ctx, want.Identity))
	require.NoError(t, p.UpsertSession(ctx, &want))

	require.NoError(t, p.ExtendSession(ctx, want.ID))

	got, err := p.GetSession(ctx, want.ID, session.ExpandNothing)
	require.NoError(t, err)
	assert.Equal(t, want.ExpiresAt, got.ExpiresAt)
})

t.Run("extend session lifespan", func(t *testing.T) {
	// The refresh threshold is two hours and the session below expires in one
	// hour, so ExtendSession is expected to push the expiry out.
	// The previous version set the threshold to one hour and then immediately
	// overwrote it with two hours; only the effective value is kept.
	conf.MustSet(ctx, config.ViperKeySessionRefreshMinTimeLeft, time.Hour*2)
	t.Cleanup(func() {
		conf.MustSet(ctx, config.ViperKeySessionRefreshMinTimeLeft, nil)
	})

	var expected session.Session
	require.NoError(t, faker.FakeData(&expected))
	expected.ExpiresAt = time.Now().Add(time.Hour).UTC()
	require.NoError(t, p.CreateIdentity(ctx, expected.Identity))
	require.NoError(t, p.UpsertSession(ctx, &expected))

	// Compare at minute granularity: the persisted expiry is computed at a
	// slightly later instant than the one calculated here.
	expectedExpiry := expected.Refresh(ctx, conf).ExpiresAt.Round(time.Minute)
	require.NoError(t, p.ExtendSession(ctx, expected.ID))
	actual, err := p.GetSession(ctx, expected.ID, session.ExpandNothing)
	require.NoError(t, err)
	assert.Equal(t, expectedExpiry, actual.ExpiresAt.Round(time.Minute))
})

t.Run("extend session lifespan on CockroachDB", func(t *testing.T) {
	if p.GetConnection(ctx).Dialect.Name() != dbal.DriverCockroachDB {
		t.Skip("Skipping test because driver is not CockroachDB")
	}

	// The refresh threshold is two hours and the session below expires in one
	// hour, so the session is eligible for extension.
	conf.MustSet(ctx, config.ViperKeySessionRefreshMinTimeLeft, time.Hour*2)
	t.Cleanup(func() {
		conf.MustSet(ctx, config.ViperKeySessionRefreshMinTimeLeft, nil)
	})

	var expected session.Session
	require.NoError(t, faker.FakeData(&expected))
	expected.ExpiresAt = time.Now().Add(time.Hour).UTC()
	require.NoError(t, p.CreateIdentity(ctx, expected.Identity))
	require.NoError(t, p.UpsertSession(ctx, &expected))

	expectedExpiry := expected.Refresh(ctx, conf).ExpiresAt.Round(time.Minute)

	const attempts = 10

	// Each goroutine that hits the expected error drops a token into this
	// buffered channel. Writing a shared bool from several goroutines would
	// be a data race; the channel is sized so that sends never block.
	notFound := make(chan struct{}, attempts)

	g := errgroup.Group{}
	for i := 0; i < attempts; i++ {
		g.Go(func() error {
			err := p.ExtendSession(ctx, expected.ID)
			if errors.Is(err, sqlcon.ErrNoRows) {
				notFound <- struct{}{}
				return nil
			}
			return err
		})
	}
	require.NoError(t, g.Wait())

	// All goroutines have finished, so the channel length is stable.
	foundExpectedCockroachError := len(notFound) > 0

	actual, err := p.GetSession(ctx, expected.ID, session.ExpandNothing)
	require.NoError(t, err)
	assert.Equal(t, expectedExpiry, actual.ExpiresAt.Round(time.Minute))
	assert.True(t, foundExpectedCockroachError, "We expect to find a not found error caused by ... FOR UPDATE SKIP LOCKED")
})
}
}
58 changes: 38 additions & 20 deletions x/events/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,33 @@ import (
)

const (
SessionIssued semconv.Event = "SessionIssued"
SessionChanged semconv.Event = "SessionChanged"
SessionRevoked semconv.Event = "SessionRevoked"
SessionChecked semconv.Event = "SessionChecked"
SessionTokenizedAsJWT semconv.Event = "SessionTokenizedAsJWT"
RegistrationFailed semconv.Event = "RegistrationFailed"
RegistrationSucceeded semconv.Event = "RegistrationSucceeded"
LoginFailed semconv.Event = "LoginFailed"
LoginSucceeded semconv.Event = "LoginSucceeded"
SettingsFailed semconv.Event = "SettingsFailed"
SettingsSucceeded semconv.Event = "SettingsSucceeded"
RecoveryFailed semconv.Event = "RecoveryFailed"
RecoverySucceeded semconv.Event = "RecoverySucceeded"
VerificationFailed semconv.Event = "VerificationFailed"
VerificationSucceeded semconv.Event = "VerificationSucceeded"
IdentityCreated semconv.Event = "IdentityCreated"
IdentityUpdated semconv.Event = "IdentityUpdated"
WebhookDelivered semconv.Event = "WebhookDelivered"
WebhookSucceeded semconv.Event = "WebhookSucceeded"
WebhookFailed semconv.Event = "WebhookFailed"
SessionIssued semconv.Event = "SessionIssued"
SessionChanged semconv.Event = "SessionChanged"
SessionLifespanExtended semconv.Event = "SessionLifespanExtended"
SessionRevoked semconv.Event = "SessionRevoked"
SessionChecked semconv.Event = "SessionChecked"
SessionTokenizedAsJWT semconv.Event = "SessionTokenizedAsJWT"
RegistrationFailed semconv.Event = "RegistrationFailed"
RegistrationSucceeded semconv.Event = "RegistrationSucceeded"
LoginFailed semconv.Event = "LoginFailed"
LoginSucceeded semconv.Event = "LoginSucceeded"
SettingsFailed semconv.Event = "SettingsFailed"
SettingsSucceeded semconv.Event = "SettingsSucceeded"
RecoveryFailed semconv.Event = "RecoveryFailed"
RecoverySucceeded semconv.Event = "RecoverySucceeded"
VerificationFailed semconv.Event = "VerificationFailed"
VerificationSucceeded semconv.Event = "VerificationSucceeded"
IdentityCreated semconv.Event = "IdentityCreated"
IdentityUpdated semconv.Event = "IdentityUpdated"
WebhookDelivered semconv.Event = "WebhookDelivered"
WebhookSucceeded semconv.Event = "WebhookSucceeded"
WebhookFailed semconv.Event = "WebhookFailed"
)

const (
attributeKeySessionID semconv.AttributeKey = "SessionID"
attributeKeySessionAAL semconv.AttributeKey = "SessionAAL"
attributeKeySessionExpiresAt semconv.AttributeKey = "SessionExpiresAt"
attributeKeySelfServiceFlowType semconv.AttributeKey = "SelfServiceFlowType"
attributeKeySelfServiceMethodUsed semconv.AttributeKey = "SelfServiceMethodUsed"
attributeKeySelfServiceSSOProviderUsed semconv.AttributeKey = "SelfServiceSSOProviderUsed"
Expand Down Expand Up @@ -71,6 +73,10 @@ func attLoginRequestedAAL(val string) otelattr.KeyValue {
return otelattr.String(attributeKeyLoginRequestedAAL.String(), val)
}

// attSessionExpiresAt builds the string attribute keyed by
// attributeKeySessionExpiresAt from the given expiry.
//
// The timestamp is rendered in UTC using RFC 3339 with sub-second precision.
// time.Time.String is documented as debugging output without a stable format
// and may append a monotonic clock reading, which makes it unsuitable for an
// attribute that downstream consumers are likely to parse.
func attSessionExpiresAt(expiresAt time.Time) otelattr.KeyValue {
	return otelattr.String(attributeKeySessionExpiresAt.String(), expiresAt.UTC().Format(time.RFC3339Nano))
}

func attLoginRequestedPrivilegedSession(val bool) otelattr.KeyValue {
return otelattr.Bool(attributeKeyLoginRequestedPrivilegedSession.String(), val)
}
Expand Down Expand Up @@ -135,6 +141,18 @@ func NewSessionChanged(ctx context.Context, aal string, sessionID, identityID uu
)
}

// NewSessionLifespanExtended returns the SessionLifespanExtended event name
// together with an event option carrying the attributes taken from ctx plus
// the identity ID, the session ID and the session's new expiry.
func NewSessionLifespanExtended(ctx context.Context, sessionID, identityID uuid.UUID, newExpiry time.Time) (string, trace.EventOption) {
	attrs := append(
		semconv.AttributesFromContext(ctx),
		semconv.AttrIdentityID(identityID),
		attrSessionID(sessionID),
		attSessionExpiresAt(newExpiry),
	)

	return SessionLifespanExtended.String(), trace.WithAttributes(attrs...)
}

type LoginSucceededOpts struct {
SessionID, IdentityID uuid.UUID
FlowType, RequestedAAL, Method, SSOProvider string
Expand Down

0 comments on commit 4e3fad4

Please sign in to comment.