diff --git a/persistence/sql/migrations/sql/20230823000000000001_verification_add_oauth2_login_challenge_params.down.sql b/persistence/sql/migrations/sql/20230823000000000001_verification_add_oauth2_login_challenge_params.down.sql new file mode 100644 index 000000000000..1a1437ae0526 --- /dev/null +++ b/persistence/sql/migrations/sql/20230823000000000001_verification_add_oauth2_login_challenge_params.down.sql @@ -0,0 +1,2 @@ +ALTER TABLE selfservice_verification_flows DROP COLUMN identity_id; +ALTER TABLE selfservice_verification_flows DROP COLUMN authentication_methods; diff --git a/persistence/sql/migrations/sql/20230823000000000001_verification_add_oauth2_login_challenge_params.mysql.up.sql b/persistence/sql/migrations/sql/20230823000000000001_verification_add_oauth2_login_challenge_params.mysql.up.sql new file mode 100644 index 000000000000..31614eb14175 --- /dev/null +++ b/persistence/sql/migrations/sql/20230823000000000001_verification_add_oauth2_login_challenge_params.mysql.up.sql @@ -0,0 +1,2 @@ +ALTER TABLE selfservice_verification_flows ADD COLUMN identity_id VARCHAR(36); +ALTER TABLE selfservice_verification_flows ADD COLUMN authentication_methods JSON; diff --git a/persistence/sql/migrations/sql/20230823000000000001_verification_add_oauth2_login_challenge_params.sqlite.up.sql b/persistence/sql/migrations/sql/20230823000000000001_verification_add_oauth2_login_challenge_params.sqlite.up.sql new file mode 100644 index 000000000000..ab661e1c6753 --- /dev/null +++ b/persistence/sql/migrations/sql/20230823000000000001_verification_add_oauth2_login_challenge_params.sqlite.up.sql @@ -0,0 +1,2 @@ +ALTER TABLE selfservice_verification_flows ADD COLUMN identity_id VARCHAR(36); +ALTER TABLE selfservice_verification_flows ADD COLUMN authentication_methods TEXT; diff --git a/persistence/sql/migrations/sql/20230823000000000001_verification_add_oauth2_login_challenge_params.up.sql b/persistence/sql/migrations/sql/20230823000000000001_verification_add_oauth2_login_challenge_params.up.sql new file mode 100644 index 000000000000..42b75ecd7dfd --- /dev/null +++ b/persistence/sql/migrations/sql/20230823000000000001_verification_add_oauth2_login_challenge_params.up.sql @@ -0,0 +1,2 @@ +ALTER TABLE selfservice_verification_flows ADD COLUMN identity_id UUID; +ALTER TABLE selfservice_verification_flows ADD COLUMN authentication_methods JSON; diff --git a/selfservice/flow/verification/flow.go b/selfservice/flow/verification/flow.go index a77cf30dc688..71a484f7d9c5 100644 --- a/selfservice/flow/verification/flow.go +++ b/selfservice/flow/verification/flow.go @@ -17,6 +17,7 @@ import ( "github.com/ory/kratos/driver/config" "github.com/ory/kratos/selfservice/flow" + "github.com/ory/kratos/session" "github.com/ory/kratos/ui/container" "github.com/ory/kratos/x" "github.com/ory/x/sqlxx" @@ -81,9 +82,7 @@ type Flow struct { // OAuth2LoginChallenge holds the login challenge originally set during the registration flow. OAuth2LoginChallenge sqlxx.NullString `json:"-" db:"oauth2_login_challenge"` - - // SessionID holds the session id if set from a registraton hook. - SessionID uuid.NullUUID `json:"-" faker:"-" db:"session_id"` + OAuth2LoginChallengeParams // CSRFToken contains the anti-csrf token associated with this request. CSRFToken string `json:"-" db:"csrf_token"` @@ -95,6 +94,18 @@ type Flow struct { NID uuid.UUID `json:"-" faker:"-" db:"nid"` } +type OAuth2LoginChallengeParams struct { + // SessionID holds the session id if set from a registraton hook. + SessionID uuid.NullUUID `json:"-" faker:"-" db:"session_id"` + + // IdentityID holds the identity id if set from a registraton hook. + IdentityID uuid.NullUUID `json:"-" faker:"-" db:"identity_id"` + + // AMR contains a list of authentication methods that were used to verify the + // session if set from a registration hook. + AMR session.AuthenticationMethods `db:"authentication_methods" json:"-"` +} + func (f *Flow) GetType() flow.Type { return f.Type } diff --git a/selfservice/flow/verification/handler.go b/selfservice/flow/verification/handler.go index 2ed3ca3f3679..977cb7bd8616 100644 --- a/selfservice/flow/verification/handler.go +++ b/selfservice/flow/verification/handler.go @@ -4,7 +4,6 @@ package verification import ( - "context" "net/http" "time" @@ -54,6 +53,7 @@ type ( x.CSRFTokenGeneratorProvider x.WriterProvider x.CSRFProvider + x.LoggingProvider FlowPersistenceProvider ErrorHandlerProvider @@ -445,8 +445,7 @@ func (h *Handler) updateVerificationFlow(w http.ResponseWriter, r *http.Request, // Special case: If we ended up here through a OAuth2 login challenge, we need to accept the login request // and redirect back to the OAuth2 provider. if f.OAuth2LoginChallenge.String() != "" { - sess := h.maybeGetSession(ctx, f) - if sess == nil { + if !f.IdentityID.Valid || !f.SessionID.Valid { h.d.VerificationFlowErrorHandler().WriteFlowError(w, r, f, node.DefaultGroup, herodot.ErrBadRequest.WithReasonf("No session was found for this flow. Please retry the authentication.")) return @@ -455,9 +454,9 @@ func (h *Handler) updateVerificationFlow(w http.ResponseWriter, r *http.Request, callbackURL, err := h.d.Hydra().AcceptLoginRequest(ctx, hydra.AcceptLoginRequestParams{ LoginChallenge: string(f.OAuth2LoginChallenge), - IdentityID: sess.IdentityID.String(), - SessionID: sess.ID.String(), - AuthenticationMethods: sess.AMR, + IdentityID: f.IdentityID.UUID.String(), + SessionID: f.SessionID.UUID.String(), + AuthenticationMethods: f.AMR, }) if err != nil { h.d.VerificationFlowErrorHandler().WriteFlowError(w, r, f, node.DefaultGroup, err) @@ -480,16 +479,3 @@ func (h *Handler) updateVerificationFlow(w http.ResponseWriter, r *http.Request, h.d.Writer().Write(w, r, updatedFlow) } - -// maybeGetSession returns the session if it was found in the flow or nil otherwise. -func (h *Handler) maybeGetSession(ctx context.Context, f *Flow) *session.Session { - if !f.SessionID.Valid { - return nil - } - s, err := h.d.SessionPersister().GetSession(ctx, f.SessionID.UUID, session.ExpandNothing) - if err != nil { - return nil - } - - return s -} diff --git a/selfservice/flow/verification/handler_test.go b/selfservice/flow/verification/handler_test.go index cfac42d1eced..092f659e4a36 100644 --- a/selfservice/flow/verification/handler_test.go +++ b/selfservice/flow/verification/handler_test.go @@ -20,9 +20,11 @@ import ( "github.com/ory/kratos/driver/config" "github.com/ory/kratos/hydra" + "github.com/ory/kratos/identity" "github.com/ory/kratos/internal" "github.com/ory/kratos/internal/testhelpers" "github.com/ory/kratos/selfservice/flow/verification" + "github.com/ory/kratos/session" "github.com/ory/kratos/x" ) @@ -214,15 +216,17 @@ func TestPostFlow(t *testing.T) { t.Run("suite=with OIDC login challenge", func(t *testing.T) { t.Run("case=succeeds with a session", func(t *testing.T) { - s := testhelpers.CreateSession(t, reg) - f := &verification.Flow{ ID: uuid.Must(uuid.NewV4()), Type: "browser", ExpiresAt: time.Now().Add(1 * time.Hour), IssuedAt: time.Now(), OAuth2LoginChallenge: hydra.FakeValidLoginChallenge, - SessionID: uuid.NullUUID{UUID: s.ID, Valid: true}, + OAuth2LoginChallengeParams: verification.OAuth2LoginChallengeParams{ + SessionID: uuid.NullUUID{UUID: uuid.Must(uuid.NewV4()), Valid: true}, + IdentityID: uuid.NullUUID{UUID: uuid.Must(uuid.NewV4()), Valid: true}, + AMR: session.AuthenticationMethods{{Method: identity.CredentialsTypePassword}}, + }, } require.NoError(t, reg.VerificationFlowPersister().CreateVerificationFlow(ctx, f)) diff --git a/selfservice/hook/verification.go b/selfservice/hook/verification.go index 899542bd8f60..ba4d7e7d79a7 100644 --- a/selfservice/hook/verification.go +++ b/selfservice/hook/verification.go @@ -47,6 +47,8 @@ func (e *Verifier) ExecutePostRegistrationPostPersistHook(w http.ResponseWriter, return e.do(w, r.WithContext(ctx), s.Identity, f, func(v *verification.Flow) { v.OAuth2LoginChallenge = f.OAuth2LoginChallenge v.SessionID = uuid.NullUUID{UUID: s.ID, Valid: true} + v.IdentityID = uuid.NullUUID{UUID: s.Identity.ID, Valid: true} + v.AMR = s.AMR }) }) } diff --git a/session/session.go b/session/session.go index c409323c2962..f915fd0aad02 100644 --- a/session/session.go +++ b/session/session.go @@ -329,6 +329,9 @@ type AuthenticationMethod struct { // Scan implements the Scanner interface. func (n *AuthenticationMethod) Scan(value interface{}) error { + if value == nil { + return nil + } v := fmt.Sprintf("%s", value) if len(v) == 0 { return nil @@ -347,6 +350,9 @@ func (n AuthenticationMethod) Value() (driver.Value, error) { // Scan implements the Scanner interface. func (n *AuthenticationMethods) Scan(value interface{}) error { + if value == nil { + return nil + } v := fmt.Sprintf("%s", value) if len(v) == 0 { return nil