Skip to content

Commit

Permalink
feat: add verification hook to login flow (#3829)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonas-jonas authored Mar 21, 2024
1 parent 8ebdfd2 commit 43e4ead
Show file tree
Hide file tree
Showing 8 changed files with 291 additions and 132 deletions.
2 changes: 2 additions & 0 deletions driver/registry_default_hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ func (m *RegistryDefault) getHooks(credentialsType string, configs []config.Self
i = append(i, m.HookShowVerificationUI())
case hook.KeyTwoStepRegistration:
i = append(i, m.HookTwoStepRegistration())
case hook.KeyVerifier:
i = append(i, m.HookVerifier())
default:
var found bool
for name, m := range m.injectedSelfserviceHooks {
Expand Down
22 changes: 22 additions & 0 deletions embedx/config.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,16 @@
"additionalProperties": false,
"required": ["hook"]
},
"selfServiceVerificationHook": {
"type": "object",
"properties": {
"hook": {
"const": "verification"
}
},
"additionalProperties": false,
"required": ["hook"]
},
"selfServiceShowVerificationUIHook": {
"type": "object",
"properties": {
Expand Down Expand Up @@ -735,6 +745,12 @@
},
{
"$ref": "#/definitions/selfServiceWebHook"
},
{
"$ref": "#/definitions/selfServiceVerificationHook"
},
{
"$ref": "#/definitions/selfServiceShowVerificationUIHook"
}
]
},
Expand Down Expand Up @@ -893,6 +909,12 @@
{
"$ref": "#/definitions/selfServiceRequireVerifiedAddressHook"
},
{
"$ref": "#/definitions/selfServiceVerificationHook"
},
{
"$ref": "#/definitions/selfServiceShowVerificationUIHook"
},
{
"$ref": "#/definitions/b2bSSOHook"
}
Expand Down
16 changes: 16 additions & 0 deletions selfservice/flow/login/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,12 @@ type Flow struct {
//
// required: false
TransientPayload json.RawMessage `json:"transient_payload,omitempty" faker:"-" db:"-"`

// Contains a list of actions, that could follow this flow
//
// It can, for example, contain a reference to the verification flow, created as part of the user's
// registration.
ContinueWithItems []flow.ContinueWith `json:"-" db:"-" faker:"-" `
}

var _ flow.Flow = new(Flow)
Expand Down Expand Up @@ -301,3 +307,13 @@ func (f *Flow) SetState(state flow.State) {
func (t *Flow) GetTransientPayload() json.RawMessage {
return t.TransientPayload
}

var _ flow.FlowWithContinueWith = new(Flow)

func (f *Flow) AddContinueWith(c flow.ContinueWith) {
f.ContinueWithItems = append(f.ContinueWithItems, c)
}

func (f *Flow) ContinueWith() []flow.ContinueWith {
return f.ContinueWithItems
}
67 changes: 37 additions & 30 deletions selfservice/flow/login/hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func (e *HookExecutor) PostLoginHook(
w http.ResponseWriter,
r *http.Request,
g node.UiNodeGroup,
a *Flow,
f *Flow,
i *identity.Identity,
s *session.Session,
provider string,
Expand All @@ -132,7 +132,7 @@ func (e *HookExecutor) PostLoginHook(
r = r.WithContext(ctx)
defer otelx.End(span, &err)

if err := e.maybeLinkCredentials(r.Context(), s, i, a); err != nil {
if err := e.maybeLinkCredentials(r.Context(), s, i, f); err != nil {
return err
}

Expand All @@ -144,11 +144,11 @@ func (e *HookExecutor) PostLoginHook(
// Verify the redirect URL before we do any other processing.
returnTo, err := x.SecureRedirectTo(r,
c.SelfServiceBrowserDefaultReturnTo(r.Context()),
x.SecureRedirectReturnTo(a.ReturnTo),
x.SecureRedirectUseSourceURL(a.RequestURL),
x.SecureRedirectReturnTo(f.ReturnTo),
x.SecureRedirectUseSourceURL(f.RequestURL),
x.SecureRedirectAllowURLs(c.SelfServiceBrowserAllowedReturnToDomains(r.Context())),
x.SecureRedirectAllowSelfServiceURLs(c.SelfPublicURL(r.Context())),
x.SecureRedirectOverrideDefaultReturnTo(c.SelfServiceFlowLoginReturnTo(r.Context(), a.Active.String())),
x.SecureRedirectOverrideDefaultReturnTo(c.SelfServiceFlowLoginReturnTo(r.Context(), f.Active.String())),
)
if err != nil {
return err
Expand All @@ -165,38 +165,38 @@ func (e *HookExecutor) PostLoginHook(
e.d.Logger().
WithRequest(r).
WithField("identity_id", i.ID).
WithField("flow_method", a.Active).
WithField("flow_method", f.Active).
Debug("Running ExecuteLoginPostHook.")
for k, executor := range e.d.PostLoginHooks(r.Context(), a.Active) {
if err := executor.ExecuteLoginPostHook(w, r, g, a, s); err != nil {
for k, executor := range e.d.PostLoginHooks(r.Context(), f.Active) {
if err := executor.ExecuteLoginPostHook(w, r, g, f, s); err != nil {
if errors.Is(err, ErrHookAbortFlow) {
e.d.Logger().
WithRequest(r).
WithField("executor", fmt.Sprintf("%T", executor)).
WithField("executor_position", k).
WithField("executors", PostHookExecutorNames(e.d.PostLoginHooks(r.Context(), a.Active))).
WithField("executors", PostHookExecutorNames(e.d.PostLoginHooks(r.Context(), f.Active))).
WithField("identity_id", i.ID).
WithField("flow_method", a.Active).
WithField("flow_method", f.Active).
Debug("A ExecuteLoginPostHook hook aborted early.")

span.SetAttributes(attribute.String("redirect_reason", "aborted by hook"), attribute.String("executor", fmt.Sprintf("%T", executor)))

return nil
}
return e.handleLoginError(w, r, g, a, i, err)
return e.handleLoginError(w, r, g, f, i, err)
}

e.d.Logger().
WithRequest(r).
WithField("executor", fmt.Sprintf("%T", executor)).
WithField("executor_position", k).
WithField("executors", PostHookExecutorNames(e.d.PostLoginHooks(r.Context(), a.Active))).
WithField("executors", PostHookExecutorNames(e.d.PostLoginHooks(r.Context(), f.Active))).
WithField("identity_id", i.ID).
WithField("flow_method", a.Active).
WithField("flow_method", f.Active).
Debug("ExecuteLoginPostHook completed successfully.")
}

if a.Type == flow.TypeAPI {
if f.Type == flow.TypeAPI {
span.SetAttributes(attribute.String("flow_type", string(flow.TypeAPI)))
if err := e.d.SessionPersister().UpsertSession(r.Context(), s); err != nil {
return errors.WithStack(err)
Expand All @@ -210,23 +210,27 @@ func (e *HookExecutor) PostLoginHook(
span.AddEvent(events.NewLoginSucceeded(r.Context(), &events.LoginSucceededOpts{
SessionID: s.ID,
IdentityID: i.ID,
FlowType: string(a.Type),
RequestedAAL: string(a.RequestedAAL),
IsRefresh: a.Refresh,
Method: a.Active.String(),
FlowType: string(f.Type),
RequestedAAL: string(f.RequestedAAL),
IsRefresh: f.Refresh,
Method: f.Active.String(),
SSOProvider: provider,
}))
if a.IDToken != "" {
if f.IDToken != "" {
// We don't want to redirect with the code, if the flow was submitted with an ID token.
// This is the case for Sign in with native Apple SDK or Google SDK.
} else if handled, err := e.d.SessionManager().MaybeRedirectAPICodeFlow(w, r, a, s.ID, g); err != nil {
} else if handled, err := e.d.SessionManager().MaybeRedirectAPICodeFlow(w, r, f, s.ID, g); err != nil {
return errors.WithStack(err)
} else if handled {
return nil
}

response := &APIFlowResponse{Session: s, Token: s.Token}
if required, _ := e.requiresAAL2(r, classified, a); required {
response := &APIFlowResponse{
Session: s,
Token: s.Token,
ContinueWith: f.ContinueWith(),
}
if required, _ := e.requiresAAL2(r, classified, f); required {
// If AAL is not satisfied, we omit the identity to preserve the user's privacy in case of a phishing attack.
response.Session.Identity = nil
}
Expand All @@ -247,7 +251,7 @@ func (e *HookExecutor) PostLoginHook(

trace.SpanFromContext(r.Context()).AddEvent(events.NewLoginSucceeded(r.Context(), &events.LoginSucceededOpts{
SessionID: s.ID,
IdentityID: i.ID, FlowType: string(a.Type), RequestedAAL: string(a.RequestedAAL), IsRefresh: a.Refresh, Method: a.Active.String(),
IdentityID: i.ID, FlowType: string(f.Type), RequestedAAL: string(f.RequestedAAL), IsRefresh: f.Refresh, Method: f.Active.String(),
SSOProvider: provider,
}))

Expand All @@ -258,7 +262,7 @@ func (e *HookExecutor) PostLoginHook(
s.Token = ""

// If we detect that whoami would require a higher AAL, we redirect!
if _, err := e.requiresAAL2(r, s, a); err != nil {
if _, err := e.requiresAAL2(r, s, f); err != nil {
if aalErr := new(session.ErrAALNotSatisfied); errors.As(err, &aalErr) {
span.SetAttributes(attribute.String("return_to", aalErr.RedirectTo), attribute.String("redirect_reason", "requires aal2"))
e.d.Writer().WriteError(w, r, flow.NewBrowserLocationChangeRequiredError(aalErr.RedirectTo))
Expand All @@ -269,10 +273,10 @@ func (e *HookExecutor) PostLoginHook(

// If Kratos is used as a Hydra login provider, we need to redirect back to Hydra by returning a 422 status
// with the post login challenge URL as the body.
if a.OAuth2LoginChallenge != "" {
if f.OAuth2LoginChallenge != "" {
postChallengeURL, err := e.d.Hydra().AcceptLoginRequest(r.Context(),
hydra.AcceptLoginRequestParams{
LoginChallenge: string(a.OAuth2LoginChallenge),
LoginChallenge: string(f.OAuth2LoginChallenge),
IdentityID: i.ID.String(),
SessionID: s.ID.String(),
AuthenticationMethods: s.AMR,
Expand All @@ -285,13 +289,16 @@ func (e *HookExecutor) PostLoginHook(
return nil
}

response := &APIFlowResponse{Session: s}
response := &APIFlowResponse{
Session: s,
ContinueWith: f.ContinueWith(),
}
e.d.Writer().Write(w, r, response)
return nil
}

// If we detect that whoami would require a higher AAL, we redirect!
if _, err := e.requiresAAL2(r, s, a); err != nil {
if _, err := e.requiresAAL2(r, s, f); err != nil {
if aalErr := new(session.ErrAALNotSatisfied); errors.As(err, &aalErr) {
http.Redirect(w, r, aalErr.RedirectTo, http.StatusSeeOther)
return nil
Expand All @@ -300,10 +307,10 @@ func (e *HookExecutor) PostLoginHook(
}

finalReturnTo := returnTo.String()
if a.OAuth2LoginChallenge != "" {
if f.OAuth2LoginChallenge != "" {
rt, err := e.d.Hydra().AcceptLoginRequest(r.Context(),
hydra.AcceptLoginRequestParams{
LoginChallenge: string(a.OAuth2LoginChallenge),
LoginChallenge: string(f.OAuth2LoginChallenge),
IdentityID: i.ID.String(),
SessionID: s.ID.String(),
AuthenticationMethods: s.AMR,
Expand Down
13 changes: 12 additions & 1 deletion selfservice/flow/login/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@

package login

import "github.com/ory/kratos/session"
import (
"github.com/ory/kratos/selfservice/flow"
"github.com/ory/kratos/session"
)

// The Response for Login Flows via API
//
Expand All @@ -26,4 +29,12 @@ type APIFlowResponse struct {
//
// required: true
Session *session.Session `json:"session"`

// Contains a list of actions, that could follow this flow
//
// It can, for example, this will contain a reference to the verification flow, created as part of the user's
// registration or the token of the session.
//
// required: false
ContinueWith []flow.ContinueWith `json:"continue_with"`
}
1 change: 1 addition & 0 deletions selfservice/hook/hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ const (
KeyAddressVerifier = "require_verified_address"
KeyVerificationUI = "show_verification_ui"
KeyTwoStepRegistration = "two_step_registration"
KeyVerifier = "verification"
)
27 changes: 24 additions & 3 deletions selfservice/hook/verification.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,20 @@ import (
"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/identity"
"github.com/ory/kratos/selfservice/flow"
"github.com/ory/kratos/selfservice/flow/login"
"github.com/ory/kratos/selfservice/flow/registration"
"github.com/ory/kratos/selfservice/flow/settings"
"github.com/ory/kratos/selfservice/flow/verification"
"github.com/ory/kratos/session"
"github.com/ory/kratos/ui/node"
"github.com/ory/kratos/x"
"github.com/ory/x/otelx"
)

var (
_ registration.PostHookPostPersistExecutor = new(Verifier)
_ settings.PostHookPostPersistExecutor = new(Verifier)
_ login.PostHookExecutor = new(Verifier)
)

type (
Expand All @@ -34,6 +37,7 @@ type (
verification.FlowPersistenceProvider
identity.PrivilegedPoolProvider
x.WriterProvider
x.TracingProvider
}
Verifier struct {
r verifierDependencies
Expand Down Expand Up @@ -61,6 +65,18 @@ func (e *Verifier) ExecuteSettingsPostPersistHook(w http.ResponseWriter, r *http
})
}

func (e *Verifier) ExecuteLoginPostHook(w http.ResponseWriter, r *http.Request, g node.UiNodeGroup, f *login.Flow, s *session.Session) (err error) {
ctx, span := e.r.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.hook.Verifier.ExecuteLoginPostHook")
r = r.WithContext(ctx)
defer otelx.End(span, &err)
if f.RequestedAAL != identity.AuthenticatorAssuranceLevel1 {
span.AddEvent("Skipping verification hook because AAL is not 1")
return nil
}

return e.do(w, r.WithContext(ctx), s.Identity, f, nil)
}

func (e *Verifier) do(
w http.ResponseWriter,
r *http.Request,
Expand All @@ -78,19 +94,24 @@ func (e *Verifier) do(
}

isBrowserFlow := f.GetType() == flow.TypeBrowser
isRegistrationFlow := f.GetFlowName() == flow.RegistrationFlow
isRegistrationOrLoginFlow := f.GetFlowName() == flow.RegistrationFlow

for k := range i.VerifiableAddresses {
address := &i.VerifiableAddresses[k]
if address.Status != identity.VerifiableAddressStatusPending {
if isRegistrationOrLoginFlow && address.Verified {
continue
} else if !isRegistrationOrLoginFlow && address.Status != identity.VerifiableAddressStatusPending {
// In case of the settings flow, we only want to create a new verification flow if there is no pending
// verification flow for the address. Otherwise, we would create a new verification flow for each setting,
// even if the address did not change.
continue
}

var csrf string

// TODO: this is pretty ugly, we should probably have a better way to handle CSRF tokens here.
if isBrowserFlow {
if isRegistrationFlow {
if isRegistrationOrLoginFlow {
// If this hook is executed from a registration flow, we need to regenerate the CSRF token.
csrf = e.r.CSRFHandler().RegenerateToken(w, r)
} else {
Expand Down
Loading

0 comments on commit 43e4ead

Please sign in to comment.