Skip to content

Commit

Permalink
feat: fewer DB loads when linking credentials, add tracing
Browse files Browse the repository at this point in the history
  • Loading branch information
alnr committed Jan 2, 2025
1 parent 7578f00 commit 2c5bb21
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 13 deletions.
18 changes: 11 additions & 7 deletions selfservice/flow/login/hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"net/url"
"time"

"github.com/gofrs/uuid"
"github.com/pkg/errors"
"go.opentelemetry.io/otel/attribute"

Expand Down Expand Up @@ -393,7 +392,10 @@ func (e *HookExecutor) PreLoginHook(w http.ResponseWriter, r *http.Request, a *F
}

// maybeLinkCredentials links the identity with the credentials of the inner context of the login flow.
func (e *HookExecutor) maybeLinkCredentials(ctx context.Context, sess *session.Session, ident *identity.Identity, loginFlow *Flow) error {
func (e *HookExecutor) maybeLinkCredentials(ctx context.Context, sess *session.Session, ident *identity.Identity, loginFlow *Flow) (err error) {
ctx, span := e.d.Tracer(ctx).Tracer().Start(ctx, "HookExecutor.PostLoginHook.maybeLinkCredentials")
defer otelx.End(span, &err)

if e.checkAAL(ctx, sess, loginFlow) != nil {
// we don't yet want to link credentials because the required AAL is not satisfied
return nil
Expand All @@ -406,7 +408,7 @@ func (e *HookExecutor) maybeLinkCredentials(ctx context.Context, sess *session.S
return nil
}

if err = e.checkDuplicateCredentialsIdentifierMatch(ctx, ident.ID, lc.DuplicateIdentifier); err != nil {
if err = e.checkDuplicateCredentialsIdentifierMatch(ctx, ident, lc.DuplicateIdentifier); err != nil {
return err
}
strategy, err := e.d.AllLoginStrategies().Strategy(lc.CredentialsType)
Expand All @@ -431,11 +433,13 @@ func (e *HookExecutor) maybeLinkCredentials(ctx context.Context, sess *session.S
return nil
}

func (e *HookExecutor) checkDuplicateCredentialsIdentifierMatch(ctx context.Context, identityID uuid.UUID, match string) error {
i, err := e.d.PrivilegedIdentityPool().GetIdentityConfidential(ctx, identityID)
if err != nil {
return err
func (e *HookExecutor) checkDuplicateCredentialsIdentifierMatch(ctx context.Context, i *identity.Identity, match string) error {
if len(i.Credentials) == 0 {
if err := e.d.PrivilegedIdentityPool().HydrateIdentityAssociations(ctx, i, identity.ExpandCredentials); err != nil {
return err
}
}

for _, credentials := range i.Credentials {
for _, identifier := range credentials.Identifiers {
if identifier == match {
Expand Down
15 changes: 12 additions & 3 deletions selfservice/strategy/oidc/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -778,10 +778,19 @@ func (s *Strategy) processIDToken(r *http.Request, provider Provider, idToken, i
return claims, nil
}

func (s *Strategy) linkCredentials(ctx context.Context, i *identity.Identity, tokens *identity.CredentialsOIDCEncryptedTokens, provider, subject, organization string) error {
if err := s.d.PrivilegedIdentityPool().HydrateIdentityAssociations(ctx, i, identity.ExpandCredentials); err != nil {
return err
func (s *Strategy) linkCredentials(ctx context.Context, i *identity.Identity, tokens *identity.CredentialsOIDCEncryptedTokens, provider, subject, organization string) (err error) {
ctx, span := s.d.Tracer(ctx).Tracer().Start(ctx, "strategy.oidc.linkCredentials", trace.WithAttributes(
attribute.String("provider", provider),
// attribute.String("subject", subject), // PII
attribute.String("organization", organization)))
defer otelx.End(span, &err)

if len(i.Credentials) == 0 {
if err := s.d.PrivilegedIdentityPool().HydrateIdentityAssociations(ctx, i, identity.ExpandCredentials); err != nil {
return err
}
}

var conf identity.CredentialsOIDC
creds, err := i.ParseCredentials(s.ID(), &conf)
if errors.Is(err, herodot.ErrNotFound) {
Expand Down
8 changes: 5 additions & 3 deletions selfservice/strategy/oidc/strategy_settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,10 @@ func (s *Strategy) handleSettingsError(ctx context.Context, w http.ResponseWrite
return err
}

func (s *Strategy) Link(ctx context.Context, i *identity.Identity, credentialsConfig sqlxx.JSONRawMessage) error {
func (s *Strategy) Link(ctx context.Context, i *identity.Identity, credentialsConfig sqlxx.JSONRawMessage) (err error) {
ctx, span := s.d.Tracer(ctx).Tracer().Start(ctx, "selfservice.strategy.oidc.Strategy.Link")
defer otelx.End(span, &err)

var credentialsOIDCConfig identity.CredentialsOIDC
if err := json.Unmarshal(credentialsConfig, &credentialsOIDCConfig); err != nil {
return err
Expand All @@ -540,8 +543,7 @@ func (s *Strategy) Link(ctx context.Context, i *identity.Identity, credentialsCo
return err
}

options := []identity.ManagerOption{identity.ManagerAllowWriteProtectedTraits}
if err := s.d.IdentityManager().Update(ctx, i, options...); err != nil {
if err := s.d.IdentityManager().Update(ctx, i, identity.ManagerAllowWriteProtectedTraits); err != nil {
return err
}

Expand Down

0 comments on commit 2c5bb21

Please sign in to comment.