diff --git a/selfservice/flow/login/handler.go b/selfservice/flow/login/handler.go index e252a23e7011..1104b461a8ba 100644 --- a/selfservice/flow/login/handler.go +++ b/selfservice/flow/login/handler.go @@ -882,7 +882,7 @@ continueLogin: return } - if err := h.d.LoginHookExecutor().PostLoginHook(w, r, group, f, i, sess, ""); err != nil { + if err := h.d.LoginHookExecutor().PostLoginHook(w, r, group, f, i, sess, nil, ""); err != nil { if errors.Is(err, ErrAddressNotVerified) { h.d.LoginFlowErrorHandler().WriteFlowError(w, r, f, node.DefaultGroup, errors.WithStack(schema.NewAddressNotVerifiedError())) return diff --git a/selfservice/flow/login/hook.go b/selfservice/flow/login/hook.go index bfa581d0a85c..f8618d32c3d5 100644 --- a/selfservice/flow/login/hook.go +++ b/selfservice/flow/login/hook.go @@ -20,6 +20,7 @@ import ( "github.com/ory/kratos/schema" "github.com/ory/kratos/selfservice/flow" "github.com/ory/kratos/selfservice/sessiontokenexchange" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/session" "github.com/ory/kratos/ui/container" "github.com/ory/kratos/ui/node" @@ -34,7 +35,7 @@ type ( } PostHookExecutor interface { - ExecuteLoginPostHook(w http.ResponseWriter, r *http.Request, g node.UiNodeGroup, a *Flow, s *session.Session) error + ExecuteLoginPostHook(w http.ResponseWriter, r *http.Request, g node.UiNodeGroup, a *Flow, s *session.Session, c *claims.Claims) error } HooksProvider interface { @@ -126,6 +127,7 @@ func (e *HookExecutor) PostLoginHook( f *Flow, i *identity.Identity, s *session.Session, + c *claims.Claims, provider string, ) (err error) { ctx := r.Context() @@ -141,15 +143,15 @@ func (e *HookExecutor) PostLoginHook( return err } - c := e.d.Config() + cfg := e.d.Config() // Verify the redirect URL before we do any other processing. returnTo, err := x.SecureRedirectTo(r, - c.SelfServiceBrowserDefaultReturnTo(r.Context()), + cfg.SelfServiceBrowserDefaultReturnTo(r.Context()), x.SecureRedirectReturnTo(f.ReturnTo), x.SecureRedirectUseSourceURL(f.RequestURL), - x.SecureRedirectAllowURLs(c.SelfServiceBrowserAllowedReturnToDomains(r.Context())), - x.SecureRedirectAllowSelfServiceURLs(c.SelfPublicURL(r.Context())), - x.SecureRedirectOverrideDefaultReturnTo(c.SelfServiceFlowLoginReturnTo(r.Context(), f.Active.String())), + x.SecureRedirectAllowURLs(cfg.SelfServiceBrowserAllowedReturnToDomains(r.Context())), + x.SecureRedirectAllowSelfServiceURLs(cfg.SelfPublicURL(r.Context())), + x.SecureRedirectOverrideDefaultReturnTo(cfg.SelfServiceFlowLoginReturnTo(r.Context(), f.Active.String())), ) if err != nil { return err @@ -173,7 +175,7 @@ func (e *HookExecutor) PostLoginHook( WithField("flow_method", f.Active). Debug("Running ExecuteLoginPostHook.") for k, executor := range e.d.PostLoginHooks(r.Context(), f.Active) { - if err := executor.ExecuteLoginPostHook(w, r, g, f, s); err != nil { + if err := executor.ExecuteLoginPostHook(w, r, g, f, s, c); err != nil { if errors.Is(err, ErrHookAbortFlow) { e.d.Logger(). WithRequest(r). diff --git a/selfservice/flow/login/hook_test.go b/selfservice/flow/login/hook_test.go index 7d7f8e158174..351f7bd217fe 100644 --- a/selfservice/flow/login/hook_test.go +++ b/selfservice/flow/login/hook_test.go @@ -72,7 +72,7 @@ func TestLoginExecutor(t *testing.T) { } testhelpers.SelfServiceHookLoginErrorHandler(t, w, r, - reg.LoginHookExecutor().PostLoginHook(w, r, strategy.ToUiNodeGroup(), loginFlow, useIdentity, sess, "")) + reg.LoginHookExecutor().PostLoginHook(w, r, strategy.ToUiNodeGroup(), loginFlow, useIdentity, sess, nil, "")) }) ts := httptest.NewServer(router) diff --git a/selfservice/hook/address_verifier.go b/selfservice/hook/address_verifier.go index b28ca6d9b7e4..7b713380a6dc 100644 --- a/selfservice/hook/address_verifier.go +++ b/selfservice/hook/address_verifier.go @@ -14,6 +14,7 @@ import ( "github.com/ory/kratos/identity" "github.com/ory/kratos/selfservice/flow/login" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/session" ) @@ -25,7 +26,7 @@ func NewAddressVerifier() *AddressVerifier { return &AddressVerifier{} } -func (e *AddressVerifier) ExecuteLoginPostHook(_ http.ResponseWriter, _ *http.Request, _ node.UiNodeGroup, f *login.Flow, s *session.Session) error { +func (e *AddressVerifier) ExecuteLoginPostHook(_ http.ResponseWriter, _ *http.Request, _ node.UiNodeGroup, f *login.Flow, s *session.Session, _ *claims.Claims) error { // if the login happens using the password method, there must be at least one verified address if f.Active != identity.CredentialsTypePassword { return nil diff --git a/selfservice/hook/address_verifier_test.go b/selfservice/hook/address_verifier_test.go index fa80c3632644..7f6532507f64 100644 --- a/selfservice/hook/address_verifier_test.go +++ b/selfservice/hook/address_verifier_test.go @@ -82,7 +82,7 @@ func TestAddressVerifier(t *testing.T) { Identity: &identity.Identity{ID: x.NewUUID(), VerifiableAddresses: uc.verifiableAddresses}, } - err := verifier.ExecuteLoginPostHook(nil, nil, node.DefaultGroup, tc.flow, sessions) + err := verifier.ExecuteLoginPostHook(nil, nil, node.DefaultGroup, tc.flow, sessions, nil) if tc.neverError || uc.expectedError == nil { assert.NoError(t, err) diff --git a/selfservice/hook/error.go b/selfservice/hook/error.go index bb396578e5f8..a82d9e3511fe 100644 --- a/selfservice/hook/error.go +++ b/selfservice/hook/error.go @@ -12,6 +12,7 @@ import ( "github.com/ory/kratos/selfservice/flow/recovery" "github.com/ory/kratos/selfservice/flow/verification" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/ui/node" "github.com/ory/kratos/identity" @@ -64,7 +65,7 @@ func (e Error) ExecuteSettingsPostPersistHook(w http.ResponseWriter, r *http.Req return e.err("ExecuteSettingsPostPersistHook", settings.ErrHookAbortFlow) } -func (e Error) ExecuteLoginPostHook(w http.ResponseWriter, r *http.Request, g node.UiNodeGroup, a *login.Flow, s *session.Session) error { +func (e Error) ExecuteLoginPostHook(w http.ResponseWriter, r *http.Request, g node.UiNodeGroup, a *login.Flow, s *session.Session, c *claims.Claims) error { return e.err("ExecuteLoginPostHook", login.ErrHookAbortFlow) } diff --git a/selfservice/hook/session_destroyer.go b/selfservice/hook/session_destroyer.go index 16f7aa11b435..0e19b2d85d9a 100644 --- a/selfservice/hook/session_destroyer.go +++ b/selfservice/hook/session_destroyer.go @@ -11,14 +11,17 @@ import ( "github.com/ory/kratos/selfservice/flow/login" "github.com/ory/kratos/selfservice/flow/recovery" "github.com/ory/kratos/selfservice/flow/settings" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/session" "github.com/ory/kratos/ui/node" "github.com/ory/x/otelx" ) -var _ login.PostHookExecutor = new(SessionDestroyer) -var _ recovery.PostHookExecutor = new(SessionDestroyer) -var _ settings.PostHookPostPersistExecutor = new(SessionDestroyer) +var ( + _ login.PostHookExecutor = new(SessionDestroyer) + _ recovery.PostHookExecutor = new(SessionDestroyer) + _ settings.PostHookPostPersistExecutor = new(SessionDestroyer) +) type ( sessionDestroyerDependencies interface { @@ -34,7 +37,7 @@ func NewSessionDestroyer(r sessionDestroyerDependencies) *SessionDestroyer { return &SessionDestroyer{r: r} } -func (e *SessionDestroyer) ExecuteLoginPostHook(_ http.ResponseWriter, r *http.Request, _ node.UiNodeGroup, _ *login.Flow, s *session.Session) error { +func (e *SessionDestroyer) ExecuteLoginPostHook(_ http.ResponseWriter, r *http.Request, _ node.UiNodeGroup, _ *login.Flow, s *session.Session, _ *claims.Claims) error { return otelx.WithSpan(r.Context(), "selfservice.hook.SessionDestroyer.ExecuteLoginPostHook", func(ctx context.Context) error { if _, err := e.r.SessionPersister().RevokeSessionsIdentityExcept(ctx, s.Identity.ID, s.ID); err != nil { return err diff --git a/selfservice/hook/session_destroyer_test.go b/selfservice/hook/session_destroyer_test.go index e2d0cc21c2e2..833dd2968a61 100644 --- a/selfservice/hook/session_destroyer_test.go +++ b/selfservice/hook/session_destroyer_test.go @@ -52,6 +52,7 @@ func TestSessionDestroyer(t *testing.T) { node.DefaultGroup, nil, &session.Session{Identity: i}, + nil, ) }, }, diff --git a/selfservice/hook/show_verification_ui.go b/selfservice/hook/show_verification_ui.go index 65a5935ec7a6..c1893485fe7b 100644 --- a/selfservice/hook/show_verification_ui.go +++ b/selfservice/hook/show_verification_ui.go @@ -11,6 +11,7 @@ import ( "github.com/ory/kratos/selfservice/flow" "github.com/ory/kratos/selfservice/flow/login" "github.com/ory/kratos/selfservice/flow/registration" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/session" "github.com/ory/kratos/ui/node" "github.com/ory/kratos/x" @@ -52,7 +53,7 @@ func (e *ShowVerificationUIHook) ExecutePostRegistrationPostPersistHook(_ http.R // ExecuteLoginPostHook adds redirect headers and status code if the request is a browser request. // If the request is not a browser request, this hook does nothing. -func (e *ShowVerificationUIHook) ExecuteLoginPostHook(_ http.ResponseWriter, r *http.Request, _ node.UiNodeGroup, f *login.Flow, _ *session.Session) error { +func (e *ShowVerificationUIHook) ExecuteLoginPostHook(_ http.ResponseWriter, r *http.Request, _ node.UiNodeGroup, f *login.Flow, _ *session.Session, _ *claims.Claims) error { return otelx.WithSpan(r.Context(), "selfservice.hook.ShowVerificationUIHook.ExecutePostRegistrationPostPersistHook", func(ctx context.Context) error { return e.execute(r.WithContext(ctx), f) }) diff --git a/selfservice/hook/show_verification_ui_test.go b/selfservice/hook/show_verification_ui_test.go index 22171f0c345b..824eefe0bbf3 100644 --- a/selfservice/hook/show_verification_ui_test.go +++ b/selfservice/hook/show_verification_ui_test.go @@ -84,7 +84,7 @@ func TestExecutePostRegistrationPostPersistHook(t *testing.T) { browserRequest := httptest.NewRequest("GET", "/", nil) f := &login.Flow{} rec := httptest.NewRecorder() - require.NoError(t, h.ExecuteLoginPostHook(rec, browserRequest, "", f, nil)) + require.NoError(t, h.ExecuteLoginPostHook(rec, browserRequest, "", f, nil, nil)) require.Equal(t, 200, rec.Code) }) @@ -95,7 +95,7 @@ func TestExecutePostRegistrationPostPersistHook(t *testing.T) { browserRequest.Header.Add("Accept", "application/json") f := &login.Flow{} rec := httptest.NewRecorder() - require.NoError(t, h.ExecuteLoginPostHook(rec, browserRequest, "", f, nil)) + require.NoError(t, h.ExecuteLoginPostHook(rec, browserRequest, "", f, nil, nil)) require.Equal(t, 200, rec.Code) }) @@ -112,7 +112,7 @@ func TestExecutePostRegistrationPostPersistHook(t *testing.T) { flow.NewContinueWithVerificationUI(vf, "some@ory.sh", ""), } rec := httptest.NewRecorder() - require.NoError(t, h.ExecuteLoginPostHook(rec, browserRequest, "", rf, nil)) + require.NoError(t, h.ExecuteLoginPostHook(rec, browserRequest, "", rf, nil, nil)) assert.Equal(t, 200, rec.Code) assert.Equal(t, "/verification?flow="+vf.ID.String(), rf.ReturnToVerification) }) @@ -127,7 +127,7 @@ func TestExecutePostRegistrationPostPersistHook(t *testing.T) { flow.NewContinueWithSetToken("token"), } rec := httptest.NewRecorder() - require.NoError(t, h.ExecuteLoginPostHook(rec, browserRequest, "", rf, nil)) + require.NoError(t, h.ExecuteLoginPostHook(rec, browserRequest, "", rf, nil, nil)) assert.Equal(t, 200, rec.Code) }) }) diff --git a/selfservice/hook/stub/test_body.jsonnet b/selfservice/hook/stub/test_body.jsonnet index 117ecc587707..f406ad51076e 100644 --- a/selfservice/hook/stub/test_body.jsonnet +++ b/selfservice/hook/stub/test_body.jsonnet @@ -1,10 +1,14 @@ function(ctx) std.prune({ flow_id: ctx.flow.id, - identity_id: if std.objectHas(ctx, "identity") then ctx.identity.id, - session_id: if std.objectHas(ctx, "session") then ctx.session.id, + identity_id: if std.objectHas(ctx, 'identity') then ctx.identity.id, + session_id: if std.objectHas(ctx, 'session') then ctx.session.id, headers: ctx.request_headers, url: ctx.request_url, method: ctx.request_method, cookies: ctx.request_cookies, - transient_payload: if std.objectHas(ctx.flow, "transient_payload") then ctx.flow.transient_payload, + transient_payload: if std.objectHas(ctx.flow, 'transient_payload') then ctx.flow.transient_payload, + nickname: if std.objectHas(ctx, 'claims') then ctx.claims.nickname, + groups: if std.objectHas(ctx, 'claims') && + std.objectHas(ctx.claims, 'raw_claims') && + std.objectHas(ctx.claims.raw_claims, 'groups') then ctx.claims.raw_claims.groups, }) diff --git a/selfservice/hook/verification.go b/selfservice/hook/verification.go index 6fdd039c146f..bd722f5edfb0 100644 --- a/selfservice/hook/verification.go +++ b/selfservice/hook/verification.go @@ -16,6 +16,7 @@ import ( "github.com/ory/kratos/selfservice/flow/registration" "github.com/ory/kratos/selfservice/flow/settings" "github.com/ory/kratos/selfservice/flow/verification" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/session" "github.com/ory/kratos/ui/node" "github.com/ory/kratos/x" @@ -65,7 +66,7 @@ func (e *Verifier) ExecuteSettingsPostPersistHook(w http.ResponseWriter, r *http }) } -func (e *Verifier) ExecuteLoginPostHook(w http.ResponseWriter, r *http.Request, g node.UiNodeGroup, f *login.Flow, s *session.Session) (err error) { +func (e *Verifier) ExecuteLoginPostHook(w http.ResponseWriter, r *http.Request, g node.UiNodeGroup, f *login.Flow, s *session.Session, c *claims.Claims) (err error) { ctx, span := e.r.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.hook.Verifier.ExecuteLoginPostHook") r = r.WithContext(ctx) defer otelx.End(span, &err) diff --git a/selfservice/hook/verification_test.go b/selfservice/hook/verification_test.go index 5e815fa4d4a4..9facac1653f5 100644 --- a/selfservice/hook/verification_test.go +++ b/selfservice/hook/verification_test.go @@ -45,7 +45,7 @@ func TestVerifier(t *testing.T) { name: "login", execHook: func(h *hook.Verifier, i *identity.Identity, f flow.Flow) error { return h.ExecuteLoginPostHook( - httptest.NewRecorder(), u, node.CodeGroup, f.(*login.Flow), &session.Session{ID: x.NewUUID(), Identity: i}) + httptest.NewRecorder(), u, node.CodeGroup, f.(*login.Flow), &session.Session{ID: x.NewUUID(), Identity: i}, nil) }, originalFlow: func() flow.FlowWithContinueWith { return &login.Flow{RequestURL: "http://foo.com/login", RequestedAAL: "aal1"} @@ -126,7 +126,7 @@ func TestVerifier(t *testing.T) { h := hook.NewVerifier(reg) i := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID) f := &login.Flow{RequestedAAL: "aal2"} - require.NoError(t, h.ExecuteLoginPostHook(httptest.NewRecorder(), u, node.CodeGroup, f, &session.Session{ID: x.NewUUID(), Identity: i})) + require.NoError(t, h.ExecuteLoginPostHook(httptest.NewRecorder(), u, node.CodeGroup, f, &session.Session{ID: x.NewUUID(), Identity: i}, nil)) messages, err := reg.CourierPersister().NextMessages(context.Background(), 12) require.EqualError(t, err, "queue is empty") diff --git a/selfservice/hook/web_hook.go b/selfservice/hook/web_hook.go index 6773e9ec4b89..199e14100f40 100644 --- a/selfservice/hook/web_hook.go +++ b/selfservice/hook/web_hook.go @@ -34,6 +34,7 @@ import ( "github.com/ory/kratos/selfservice/flow/registration" "github.com/ory/kratos/selfservice/flow/settings" "github.com/ory/kratos/selfservice/flow/verification" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/session" "github.com/ory/kratos/text" "github.com/ory/kratos/ui/node" @@ -84,6 +85,7 @@ type ( RequestCookies map[string]string `json:"request_cookies"` Identity *identity.Identity `json:"identity,omitempty"` Session *session.Session `json:"session,omitempty"` + Claims *claims.Claims `json:"claims,omitempty"` } WebHook struct { @@ -134,7 +136,7 @@ func (e *WebHook) ExecuteLoginPreHook(_ http.ResponseWriter, req *http.Request, }) } -func (e *WebHook) ExecuteLoginPostHook(_ http.ResponseWriter, req *http.Request, _ node.UiNodeGroup, flow *login.Flow, session *session.Session) error { +func (e *WebHook) ExecuteLoginPostHook(_ http.ResponseWriter, req *http.Request, _ node.UiNodeGroup, flow *login.Flow, session *session.Session, claims *claims.Claims) error { return otelx.WithSpan(req.Context(), "selfservice.hook.WebHook.ExecuteLoginPostHook", func(ctx context.Context) error { return e.execute(ctx, &templateContext{ Flow: flow, @@ -144,6 +146,7 @@ func (e *WebHook) ExecuteLoginPostHook(_ http.ResponseWriter, req *http.Request, RequestCookies: cookies(req), Identity: session.Identity, Session: session, + Claims: claims, }) }) } diff --git a/selfservice/hook/web_hook_integration_test.go b/selfservice/hook/web_hook_integration_test.go index cae9659a285c..03bad3a36a23 100644 --- a/selfservice/hook/web_hook_integration_test.go +++ b/selfservice/hook/web_hook_integration_test.go @@ -39,6 +39,7 @@ import ( "github.com/ory/kratos/selfservice/flow/settings" "github.com/ory/kratos/selfservice/flow/verification" "github.com/ory/kratos/selfservice/hook" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/session" "github.com/ory/kratos/text" "github.com/ory/kratos/ui/node" @@ -56,6 +57,13 @@ var transientPayload = json.RawMessage(`{ } }`) +var oidcClaims = claims.Claims{ + Nickname: "nicky", + RawClaims: map[string]interface{}{ + "groups": []string{"first", "second"}, + }, +} + func TestWebHooks(t *testing.T) { _, reg := internal.NewFastRegistryWithMocks(t) logger := logrusx.New("kratos", "test") @@ -159,7 +167,7 @@ func TestWebHooks(t *testing.T) { return body } - bodyWithFlowAndIdentityAndSessionAndTransientPayload := func(req *http.Request, f flow.Flow, s *session.Session, tp json.RawMessage) string { + bodyWithFlowAndIdentityAndSessionAndClaimsAndTransientPayload := func(req *http.Request, f flow.Flow, s *session.Session, c *claims.Claims, tp json.RawMessage) string { body := fmt.Sprintf(`{ "flow_id": "%s", "identity_id": "%s", @@ -171,8 +179,10 @@ func TestWebHooks(t *testing.T) { "Some-Cookie-2": "Some-other-Cookie-Value", "Some-Cookie-3": "Third-Cookie-Value" }, - "transient_payload": %s - }`, f.GetID(), s.Identity.ID, s.ID, req.Method, "http://www.ory.sh/some_end_point", string(tp)) + "transient_payload": %s, + "nickname": "%s", + "groups": ["%s", "%s"] + }`, f.GetID(), s.Identity.ID, s.ID, req.Method, "http://www.ory.sh/some_end_point", string(tp), c.Nickname, c.RawClaims["groups"].([]string)[0], c.RawClaims["groups"].([]string)[1]) if len(req.Header) != 0 { if ua := req.Header.Get("User-Agent"); ua != "" { body, _ = sjson.Set(body, "headers.User-Agent", []string{ua}) @@ -202,10 +212,10 @@ func TestWebHooks(t *testing.T) { uc: "Post Login Hook", createFlow: func() flow.Flow { return &login.Flow{ID: x.NewUUID(), TransientPayload: transientPayload} }, callWebHook: func(wh *hook.WebHook, req *http.Request, f flow.Flow, s *session.Session) error { - return wh.ExecuteLoginPostHook(nil, req, node.PasswordGroup, f.(*login.Flow), s) + return wh.ExecuteLoginPostHook(nil, req, node.PasswordGroup, f.(*login.Flow), s, &oidcClaims) }, expectedBody: func(req *http.Request, f flow.Flow, s *session.Session) string { - return bodyWithFlowAndIdentityAndSessionAndTransientPayload(req, f, s, transientPayload) + return bodyWithFlowAndIdentityAndSessionAndClaimsAndTransientPayload(req, f, s, &oidcClaims, transientPayload) }, }, { @@ -446,7 +456,7 @@ func TestWebHooks(t *testing.T) { uc: "Post Login Hook - no block", createFlow: func() flow.Flow { return &login.Flow{ID: x.NewUUID()} }, callWebHook: func(wh *hook.WebHook, req *http.Request, f flow.Flow, s *session.Session) error { - return wh.ExecuteLoginPostHook(nil, req, node.PasswordGroup, f.(*login.Flow), s) + return wh.ExecuteLoginPostHook(nil, req, node.PasswordGroup, f.(*login.Flow), s, nil) }, webHookResponse: func() (int, []byte) { return http.StatusOK, []byte{} @@ -457,7 +467,7 @@ func TestWebHooks(t *testing.T) { uc: "Post Login Hook - block", createFlow: func() flow.Flow { return &login.Flow{ID: x.NewUUID()} }, callWebHook: func(wh *hook.WebHook, req *http.Request, f flow.Flow, s *session.Session) error { - return wh.ExecuteLoginPostHook(nil, req, node.PasswordGroup, f.(*login.Flow), s) + return wh.ExecuteLoginPostHook(nil, req, node.PasswordGroup, f.(*login.Flow), s, nil) }, webHookResponse: func() (int, []byte) { return http.StatusBadRequest, webHookResponse @@ -1022,7 +1032,7 @@ func TestDisallowPrivateIPRanges(t *testing.T) { "method": "GET", "body": "file://stub/test_body.jsonnet" }`)) - err := wh.ExecuteLoginPostHook(nil, req, node.DefaultGroup, f, s) + err := wh.ExecuteLoginPostHook(nil, req, node.DefaultGroup, f, s, nil) require.Error(t, err) require.Contains(t, err.Error(), "is not a permitted destination") }) @@ -1034,7 +1044,7 @@ func TestDisallowPrivateIPRanges(t *testing.T) { "method": "GET", "body": "file://stub/test_body.jsonnet" }`)) - err := wh.ExecuteLoginPostHook(nil, req, node.DefaultGroup, f, s) + err := wh.ExecuteLoginPostHook(nil, req, node.DefaultGroup, f, s, nil) require.Error(t, err, "the target does not exist and we still receive an error") require.NotContains(t, err.Error(), "is not a permitted destination", "but the error is not related to the IP range.") }) @@ -1055,7 +1065,7 @@ func TestDisallowPrivateIPRanges(t *testing.T) { "method": "GET", "body": "http://192.168.178.0/test_body.jsonnet" }`)) - err := wh.ExecuteLoginPostHook(nil, req, node.DefaultGroup, f, s) + err := wh.ExecuteLoginPostHook(nil, req, node.DefaultGroup, f, s, nil) require.Error(t, err) require.Contains(t, err.Error(), "is not a permitted destination") }) @@ -1112,7 +1122,7 @@ func TestAsyncWebhook(t *testing.T) { "ignore": true } }`, webhookReceiver.URL))) - err := wh.ExecuteLoginPostHook(nil, req, node.DefaultGroup, f, s) + err := wh.ExecuteLoginPostHook(nil, req, node.DefaultGroup, f, s, nil) require.NoError(t, err) // execution returns immediately for async webhook select { case <-time.After(1 * time.Second): diff --git a/selfservice/strategy/oidc/claims/claims.go b/selfservice/strategy/oidc/claims/claims.go new file mode 100644 index 000000000000..73f492652359 --- /dev/null +++ b/selfservice/strategy/oidc/claims/claims.go @@ -0,0 +1,49 @@ +package claims + +import ( + "github.com/pkg/errors" + + "github.com/ory/herodot" + "github.com/ory/kratos/x" +) + +// ConvertibleBoolean is used as Apple casually sends the email_verified field as a string. +type Claims struct { + Issuer string `json:"iss,omitempty"` + Subject string `json:"sub,omitempty"` + Name string `json:"name,omitempty"` + GivenName string `json:"given_name,omitempty"` + FamilyName string `json:"family_name,omitempty"` + LastName string `json:"last_name,omitempty"` + MiddleName string `json:"middle_name,omitempty"` + Nickname string `json:"nickname,omitempty"` + PreferredUsername string `json:"preferred_username,omitempty"` + Profile string `json:"profile,omitempty"` + Picture string `json:"picture,omitempty"` + Website string `json:"website,omitempty"` + Email string `json:"email,omitempty"` + EmailVerified x.ConvertibleBoolean `json:"email_verified,omitempty"` + Gender string `json:"gender,omitempty"` + Birthdate string `json:"birthdate,omitempty"` + Zoneinfo string `json:"zoneinfo,omitempty"` + Locale Locale `json:"locale,omitempty"` + PhoneNumber string `json:"phone_number,omitempty"` + PhoneNumberVerified bool `json:"phone_number_verified,omitempty"` + UpdatedAt int64 `json:"updated_at,omitempty"` + HD string `json:"hd,omitempty"` + Team string `json:"team,omitempty"` + Nonce string `json:"nonce,omitempty"` + NonceSupported bool `json:"nonce_supported,omitempty"` + RawClaims map[string]interface{} `json:"raw_claims,omitempty"` +} + +// Validate checks if the claims are valid. +func (c *Claims) Validate() error { + if c.Subject == "" { + return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("provider did not return a subject")) + } + if c.Issuer == "" { + return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("issuer not set in claims")) + } + return nil +} diff --git a/selfservice/strategy/oidc/claims/claims_test.go b/selfservice/strategy/oidc/claims/claims_test.go new file mode 100644 index 000000000000..47ada4a4695c --- /dev/null +++ b/selfservice/strategy/oidc/claims/claims_test.go @@ -0,0 +1,18 @@ +package claims_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/ory/kratos/selfservice/strategy/oidc/claims" +) + +func TestClaimsValidate(t *testing.T) { + require.Error(t, new(claims.Claims).Validate()) + require.Error(t, (&claims.Claims{Issuer: "not-empty"}).Validate()) + require.Error(t, (&claims.Claims{Issuer: "not-empty"}).Validate()) + require.Error(t, (&claims.Claims{Subject: "not-empty"}).Validate()) + require.Error(t, (&claims.Claims{Subject: "not-empty"}).Validate()) + require.NoError(t, (&claims.Claims{Issuer: "not-empty", Subject: "not-empty"}).Validate()) +} diff --git a/selfservice/strategy/oidc/claims/locale.go b/selfservice/strategy/oidc/claims/locale.go new file mode 100644 index 000000000000..07fb8576becc --- /dev/null +++ b/selfservice/strategy/oidc/claims/locale.go @@ -0,0 +1,29 @@ +package claims + +import ( + "encoding/json" + "strings" +) + +type Locale string + +func (l *Locale) UnmarshalJSON(data []byte) error { + var linkedInLocale struct { + Language string `json:"language"` + Country string `json:"country"` + } + if err := json.Unmarshal(data, &linkedInLocale); err == nil { + switch { + case linkedInLocale.Language == "": + *l = Locale(linkedInLocale.Country) + case linkedInLocale.Country == "": + *l = Locale(linkedInLocale.Language) + default: + *l = Locale(strings.Join([]string{linkedInLocale.Language, linkedInLocale.Country}, "-")) + } + + return nil + } + + return json.Unmarshal(data, (*string)(l)) +} diff --git a/selfservice/strategy/oidc/provider.go b/selfservice/strategy/oidc/provider.go index 30ea305a22ed..c241d12f257c 100644 --- a/selfservice/strategy/oidc/provider.go +++ b/selfservice/strategy/oidc/provider.go @@ -5,19 +5,14 @@ package oidc import ( "context" - "encoding/json" "net/http" "net/url" - "strings" "github.com/dghubble/oauth1" - "github.com/pkg/errors" - "github.com/ory/herodot" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "golang.org/x/oauth2" - - "github.com/ory/kratos/x" ) type Provider interface { @@ -28,14 +23,14 @@ type OAuth2Provider interface { Provider AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption OAuth2(ctx context.Context) (*oauth2.Config, error) - Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) + Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) } type OAuth1Provider interface { Provider OAuth1(ctx context.Context) *oauth1.Config AuthURL(ctx context.Context, state string) (string, error) - Claims(ctx context.Context, token *oauth1.Token) (*Claims, error) + Claims(ctx context.Context, token *oauth1.Token) (*claims.Claims, error) ExchangeToken(ctx context.Context, req *http.Request) (*oauth1.Token, error) } @@ -44,75 +39,11 @@ type OAuth2TokenExchanger interface { } type IDTokenVerifier interface { - Verify(ctx context.Context, rawIDToken string) (*Claims, error) + Verify(ctx context.Context, rawIDToken string) (*claims.Claims, error) } type NonceValidationSkipper interface { - CanSkipNonce(*Claims) bool -} - -// ConvertibleBoolean is used as Apple casually sends the email_verified field as a string. -type Claims struct { - Issuer string `json:"iss,omitempty"` - Subject string `json:"sub,omitempty"` - Name string `json:"name,omitempty"` - GivenName string `json:"given_name,omitempty"` - FamilyName string `json:"family_name,omitempty"` - LastName string `json:"last_name,omitempty"` - MiddleName string `json:"middle_name,omitempty"` - Nickname string `json:"nickname,omitempty"` - PreferredUsername string `json:"preferred_username,omitempty"` - Profile string `json:"profile,omitempty"` - Picture string `json:"picture,omitempty"` - Website string `json:"website,omitempty"` - Email string `json:"email,omitempty"` - EmailVerified x.ConvertibleBoolean `json:"email_verified,omitempty"` - Gender string `json:"gender,omitempty"` - Birthdate string `json:"birthdate,omitempty"` - Zoneinfo string `json:"zoneinfo,omitempty"` - Locale Locale `json:"locale,omitempty"` - PhoneNumber string `json:"phone_number,omitempty"` - PhoneNumberVerified bool `json:"phone_number_verified,omitempty"` - UpdatedAt int64 `json:"updated_at,omitempty"` - HD string `json:"hd,omitempty"` - Team string `json:"team,omitempty"` - Nonce string `json:"nonce,omitempty"` - NonceSupported bool `json:"nonce_supported,omitempty"` - RawClaims map[string]interface{} `json:"raw_claims,omitempty"` -} - -type Locale string - -func (l *Locale) UnmarshalJSON(data []byte) error { - var linkedInLocale struct { - Language string `json:"language"` - Country string `json:"country"` - } - if err := json.Unmarshal(data, &linkedInLocale); err == nil { - switch { - case linkedInLocale.Language == "": - *l = Locale(linkedInLocale.Country) - case linkedInLocale.Country == "": - *l = Locale(linkedInLocale.Language) - default: - *l = Locale(strings.Join([]string{linkedInLocale.Language, linkedInLocale.Country}, "-")) - } - - return nil - } - - return json.Unmarshal(data, (*string)(l)) -} - -// Validate checks if the claims are valid. -func (c *Claims) Validate() error { - if c.Subject == "" { - return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("provider did not return a subject")) - } - if c.Issuer == "" { - return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("issuer not set in claims")) - } - return nil + CanSkipNonce(*claims.Claims) bool } // UpstreamParameters returns a list of oauth2.AuthCodeOption based on the upstream parameters. diff --git a/selfservice/strategy/oidc/provider_apple.go b/selfservice/strategy/oidc/provider_apple.go index 706a7150c5e4..0a45411e7b27 100644 --- a/selfservice/strategy/oidc/provider_apple.go +++ b/selfservice/strategy/oidc/provider_apple.go @@ -15,6 +15,8 @@ import ( "github.com/coreos/go-oidc/v3/oidc" "github.com/golang-jwt/jwt/v4" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" + "github.com/pkg/errors" "golang.org/x/oauth2" @@ -112,7 +114,7 @@ func (a *ProviderApple) AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption { return options } -func (a *ProviderApple) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (a *ProviderApple) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { claims, err := a.ProviderGenericOIDC.Claims(ctx, exchange, query) if err != nil { return claims, err @@ -126,7 +128,7 @@ func (a *ProviderApple) Claims(ctx context.Context, exchange *oauth2.Token, quer // The info is sent as an extra query parameter to the redirect URL. // See https://developer.apple.com/documentation/sign_in_with_apple/sign_in_with_apple_js/configuring_your_webpage_for_sign_in_with_apple#3331292 // Note that there's no way to make sure the info hasn't been tampered with. -func (a *ProviderApple) DecodeQuery(query url.Values, claims *Claims) { +func (a *ProviderApple) DecodeQuery(query url.Values, claims *claims.Claims) { var user struct { Name *struct { FirstName *string `json:"firstName"` @@ -156,7 +158,7 @@ var _ IDTokenVerifier = new(ProviderApple) const issuerUrlApple = "https://appleid.apple.com" -func (a *ProviderApple) Verify(ctx context.Context, rawIDToken string) (*Claims, error) { +func (a *ProviderApple) Verify(ctx context.Context, rawIDToken string) (*claims.Claims, error) { keySet := oidc.NewRemoteKeySet(ctx, a.JWKSUrl) ctx = oidc.ClientContext(ctx, a.reg.HTTPClient(ctx).HTTPClient) @@ -165,6 +167,6 @@ func (a *ProviderApple) Verify(ctx context.Context, rawIDToken string) (*Claims, var _ NonceValidationSkipper = new(ProviderApple) -func (a *ProviderApple) CanSkipNonce(c *Claims) bool { +func (a *ProviderApple) CanSkipNonce(c *claims.Claims) bool { return c.NonceSupported } diff --git a/selfservice/strategy/oidc/provider_apple_test.go b/selfservice/strategy/oidc/provider_apple_test.go index 422ae643708a..39c6c95462f4 100644 --- a/selfservice/strategy/oidc/provider_apple_test.go +++ b/selfservice/strategy/oidc/provider_apple_test.go @@ -20,6 +20,7 @@ import ( "github.com/ory/kratos/internal" "github.com/ory/kratos/selfservice/strategy/oidc" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" ) func TestDecodeQuery(t *testing.T) { @@ -28,15 +29,15 @@ func TestDecodeQuery(t *testing.T) { } for k, tc := range []struct { - claims *oidc.Claims + claims *claims.Claims familyName string givenName string lastName string }{ - {claims: &oidc.Claims{}, familyName: "last", givenName: "first", lastName: "last"}, - {claims: &oidc.Claims{FamilyName: "fam"}, familyName: "fam", givenName: "first", lastName: "last"}, - {claims: &oidc.Claims{FamilyName: "fam", GivenName: "giv"}, familyName: "fam", givenName: "giv", lastName: "last"}, - {claims: &oidc.Claims{FamilyName: "fam", GivenName: "giv", LastName: "las"}, familyName: "fam", givenName: "giv", lastName: "las"}, + {claims: &claims.Claims{}, familyName: "last", givenName: "first", lastName: "last"}, + {claims: &claims.Claims{FamilyName: "fam"}, familyName: "fam", givenName: "first", lastName: "last"}, + {claims: &claims.Claims{FamilyName: "fam", GivenName: "giv"}, familyName: "fam", givenName: "giv", lastName: "last"}, + {claims: &claims.Claims{FamilyName: "fam", GivenName: "giv", LastName: "las"}, familyName: "fam", givenName: "giv", lastName: "las"}, } { t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { a := oidc.NewProviderApple(&oidc.Configuration{}, nil).(*oidc.ProviderApple) diff --git a/selfservice/strategy/oidc/provider_auth0.go b/selfservice/strategy/oidc/provider_auth0.go index a4c9ee46e1ab..479879980371 100644 --- a/selfservice/strategy/oidc/provider_auth0.go +++ b/selfservice/strategy/oidc/provider_auth0.go @@ -11,6 +11,7 @@ import ( "path" "time" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/httpx" "github.com/ory/x/stringsx" @@ -71,7 +72,7 @@ func (g *ProviderAuth0) OAuth2(ctx context.Context) (*oauth2.Config, error) { return g.oauth2(ctx) } -func (g *ProviderAuth0) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (g *ProviderAuth0) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { o, err := g.OAuth2(ctx) if err != nil { return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) @@ -113,7 +114,7 @@ func (g *ProviderAuth0) Claims(ctx context.Context, exchange *oauth2.Token, quer } // Once we get here, we know that if there is an updated_at field in the json, it is the correct type. - var claims Claims + var claims claims.Claims if err := json.Unmarshal(b, &claims); err != nil { return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) } diff --git a/selfservice/strategy/oidc/provider_dingtalk.go b/selfservice/strategy/oidc/provider_dingtalk.go index 12abffe85942..436dbe931aee 100644 --- a/selfservice/strategy/oidc/provider_dingtalk.go +++ b/selfservice/strategy/oidc/provider_dingtalk.go @@ -13,6 +13,7 @@ import ( "github.com/pkg/errors" "golang.org/x/oauth2" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/httpx" "github.com/hashicorp/go-retryablehttp" @@ -40,7 +41,7 @@ func (g *ProviderDingTalk) Config() *Configuration { } func (g *ProviderDingTalk) oauth2(ctx context.Context) *oauth2.Config { - var endpoint = oauth2.Endpoint{ + endpoint := oauth2.Endpoint{ AuthURL: "https://login.dingtalk.com/oauth2/auth", TokenURL: "https://api.dingtalk.com/v1.0/oauth2/userAccessToken", } @@ -122,7 +123,7 @@ func (g *ProviderDingTalk) ExchangeOAuth2Token(ctx context.Context, code string, return token, nil } -func (g *ProviderDingTalk) Claims(ctx context.Context, exchange *oauth2.Token, _ url.Values) (*Claims, error) { +func (g *ProviderDingTalk) Claims(ctx context.Context, exchange *oauth2.Token, _ url.Values) (*claims.Claims, error) { userInfoURL := "https://api.dingtalk.com/v1.0/contact/users/me" accessToken := exchange.AccessToken @@ -160,7 +161,7 @@ func (g *ProviderDingTalk) Claims(ctx context.Context, exchange *oauth2.Token, _ return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("userResp.ErrCode = %s, userResp.ErrMsg = %s", user.ErrCode, user.ErrMsg)) } - return &Claims{ + return &claims.Claims{ Issuer: userInfoURL, Subject: user.OpenId, Nickname: user.Nick, diff --git a/selfservice/strategy/oidc/provider_discord.go b/selfservice/strategy/oidc/provider_discord.go index 99bea24d5770..2c542d0aa8f8 100644 --- a/selfservice/strategy/oidc/provider_discord.go +++ b/selfservice/strategy/oidc/provider_discord.go @@ -8,6 +8,7 @@ import ( "fmt" "net/url" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/x" "github.com/bwmarrin/discordgo" @@ -66,7 +67,7 @@ func (d *ProviderDiscord) AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption { } } -func (d *ProviderDiscord) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (d *ProviderDiscord) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { grantedScopes := stringsx.Splitx(fmt.Sprintf("%s", exchange.Extra("scope")), " ") for _, check := range d.Config().Scope { if !stringslice.Has(grantedScopes, check) { @@ -84,7 +85,7 @@ func (d *ProviderDiscord) Claims(ctx context.Context, exchange *oauth2.Token, qu return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) } - claims := &Claims{ + claims := &claims.Claims{ Issuer: discordgo.EndpointOauth2, Subject: user.ID, Name: fmt.Sprintf("%s#%s", user.Username, user.Discriminator), @@ -93,7 +94,7 @@ func (d *ProviderDiscord) Claims(ctx context.Context, exchange *oauth2.Token, qu Picture: user.AvatarURL(""), Email: user.Email, EmailVerified: x.ConvertibleBoolean(user.Verified), - Locale: Locale(user.Locale), + Locale: claims.Locale(user.Locale), } return claims, nil diff --git a/selfservice/strategy/oidc/provider_facebook.go b/selfservice/strategy/oidc/provider_facebook.go index abf9806cce05..be67f20098b1 100644 --- a/selfservice/strategy/oidc/provider_facebook.go +++ b/selfservice/strategy/oidc/provider_facebook.go @@ -16,6 +16,7 @@ import ( "github.com/ory/x/httpx" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/x" "github.com/pkg/errors" @@ -62,7 +63,7 @@ func (g *ProviderFacebook) OAuth2(ctx context.Context) (*oauth2.Config, error) { return g.oauth2ConfigFromEndpoint(ctx, endpoint), nil } -func (g *ProviderFacebook) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (g *ProviderFacebook) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { o, err := g.OAuth2(ctx) if err != nil { return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) @@ -114,7 +115,7 @@ func (g *ProviderFacebook) Claims(ctx context.Context, exchange *oauth2.Token, q user.EmailVerified = true } - return &Claims{ + return &claims.Claims{ Issuer: u.String(), Subject: user.Id, Name: user.Name, diff --git a/selfservice/strategy/oidc/provider_generic_oidc.go b/selfservice/strategy/oidc/provider_generic_oidc.go index 146505165807..4e28ebcc4318 100644 --- a/selfservice/strategy/oidc/provider_generic_oidc.go +++ b/selfservice/strategy/oidc/provider_generic_oidc.go @@ -13,6 +13,7 @@ import ( gooidc "github.com/coreos/go-oidc/v3/oidc" "github.com/ory/herodot" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/stringslice" ) @@ -96,13 +97,13 @@ func (g *ProviderGenericOIDC) AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption return options } -func (g *ProviderGenericOIDC) verifyAndDecodeClaimsWithProvider(ctx context.Context, provider *gooidc.Provider, raw string) (*Claims, error) { +func (g *ProviderGenericOIDC) verifyAndDecodeClaimsWithProvider(ctx context.Context, provider *gooidc.Provider, raw string) (*claims.Claims, error) { token, err := provider.VerifierContext(g.withHTTPClientContext(ctx), &gooidc.Config{ClientID: g.config.ClientID}).Verify(ctx, raw) if err != nil { return nil, errors.WithStack(herodot.ErrBadRequest.WithReasonf("%s", err)) } - var claims Claims + var claims claims.Claims if err := token.Claims(&claims); err != nil { return nil, errors.WithStack(herodot.ErrBadRequest.WithReasonf("%s", err)) } @@ -116,7 +117,7 @@ func (g *ProviderGenericOIDC) verifyAndDecodeClaimsWithProvider(ctx context.Cont return &claims, nil } -func (g *ProviderGenericOIDC) Claims(ctx context.Context, exchange *oauth2.Token, _ url.Values) (*Claims, error) { +func (g *ProviderGenericOIDC) Claims(ctx context.Context, exchange *oauth2.Token, _ url.Values) (*claims.Claims, error) { switch g.config.ClaimsSource { case ClaimsSourceIDToken, "": return g.claimsFromIDToken(ctx, exchange) @@ -128,7 +129,7 @@ func (g *ProviderGenericOIDC) Claims(ctx context.Context, exchange *oauth2.Token WithReasonf("Unknown claims source: %q", g.config.ClaimsSource)) } -func (g *ProviderGenericOIDC) claimsFromUserInfo(ctx context.Context, exchange *oauth2.Token) (*Claims, error) { +func (g *ProviderGenericOIDC) claimsFromUserInfo(ctx context.Context, exchange *oauth2.Token) (*claims.Claims, error) { p, err := g.provider(ctx) if err != nil { return nil, err @@ -139,7 +140,7 @@ func (g *ProviderGenericOIDC) claimsFromUserInfo(ctx context.Context, exchange * return nil, err } - var claims Claims + var claims claims.Claims if err = userInfo.Claims(&claims); err != nil { return nil, err } @@ -178,7 +179,7 @@ func (g *ProviderGenericOIDC) claimsFromUserInfo(ctx context.Context, exchange * return &claims, nil } -func (g *ProviderGenericOIDC) claimsFromIDToken(ctx context.Context, exchange *oauth2.Token) (*Claims, error) { +func (g *ProviderGenericOIDC) claimsFromIDToken(ctx context.Context, exchange *oauth2.Token) (*claims.Claims, error) { p, raw, err := g.idTokenAndProvider(ctx, exchange) if err != nil { return nil, err diff --git a/selfservice/strategy/oidc/provider_github.go b/selfservice/strategy/oidc/provider_github.go index fe1d2bc371d1..08b00d19afa5 100644 --- a/selfservice/strategy/oidc/provider_github.go +++ b/selfservice/strategy/oidc/provider_github.go @@ -8,6 +8,7 @@ import ( "fmt" "net/url" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/x" "github.com/pkg/errors" @@ -60,7 +61,7 @@ func (g *ProviderGitHub) AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption { return []oauth2.AuthCodeOption{} } -func (g *ProviderGitHub) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (g *ProviderGitHub) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { grantedScopes := stringsx.Splitx(fmt.Sprintf("%s", exchange.Extra("scope")), ",") for _, check := range g.Config().Scope { if !stringslice.Has(grantedScopes, check) { @@ -76,7 +77,7 @@ func (g *ProviderGitHub) Claims(ctx context.Context, exchange *oauth2.Token, que return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) } - claims := &Claims{ + claims := &claims.Claims{ Subject: fmt.Sprintf("%d", user.GetID()), Issuer: github.Endpoint.TokenURL, Name: user.GetName(), diff --git a/selfservice/strategy/oidc/provider_github_app.go b/selfservice/strategy/oidc/provider_github_app.go index 83cfd9bdb882..dfa55d225849 100644 --- a/selfservice/strategy/oidc/provider_github_app.go +++ b/selfservice/strategy/oidc/provider_github_app.go @@ -8,6 +8,7 @@ import ( "fmt" "net/url" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/x" "github.com/ory/x/httpx" @@ -57,7 +58,7 @@ func (g *ProviderGitHubApp) AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption { return []oauth2.AuthCodeOption{} } -func (g *ProviderGitHubApp) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (g *ProviderGitHubApp) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { ctx, client := httpx.SetOAuth2(ctx, g.reg.HTTPClient(ctx), g.oauth2(ctx), exchange) gh := ghapi.NewClient(client.HTTPClient) @@ -66,7 +67,7 @@ func (g *ProviderGitHubApp) Claims(ctx context.Context, exchange *oauth2.Token, return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) } - claims := &Claims{ + claims := &claims.Claims{ Subject: fmt.Sprintf("%d", user.GetID()), Issuer: github.Endpoint.TokenURL, Name: user.GetName(), diff --git a/selfservice/strategy/oidc/provider_gitlab.go b/selfservice/strategy/oidc/provider_gitlab.go index 9ef55b4beef7..417e6acf60c2 100644 --- a/selfservice/strategy/oidc/provider_gitlab.go +++ b/selfservice/strategy/oidc/provider_gitlab.go @@ -9,6 +9,7 @@ import ( "net/url" "path" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/stringsx" "github.com/hashicorp/go-retryablehttp" @@ -69,7 +70,7 @@ func (g *ProviderGitLab) OAuth2(ctx context.Context) (*oauth2.Config, error) { return g.oauth2(ctx) } -func (g *ProviderGitLab) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (g *ProviderGitLab) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { o, err := g.OAuth2(ctx) if err != nil { return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) @@ -98,7 +99,7 @@ func (g *ProviderGitLab) Claims(ctx context.Context, exchange *oauth2.Token, que return nil, err } - var claims Claims + var claims claims.Claims if err := json.NewDecoder(resp.Body).Decode(&claims); err != nil { return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) } diff --git a/selfservice/strategy/oidc/provider_google.go b/selfservice/strategy/oidc/provider_google.go index e27832692faa..867380dcd720 100644 --- a/selfservice/strategy/oidc/provider_google.go +++ b/selfservice/strategy/oidc/provider_google.go @@ -9,6 +9,7 @@ import ( gooidc "github.com/coreos/go-oidc/v3/oidc" "golang.org/x/oauth2" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/stringslice" ) @@ -73,7 +74,7 @@ var _ IDTokenVerifier = new(ProviderGoogle) const issuerUrlGoogle = "https://accounts.google.com" -func (p *ProviderGoogle) Verify(ctx context.Context, rawIDToken string) (*Claims, error) { +func (p *ProviderGoogle) Verify(ctx context.Context, rawIDToken string) (*claims.Claims, error) { keySet := gooidc.NewRemoteKeySet(ctx, p.JWKSUrl) ctx = gooidc.ClientContext(ctx, p.reg.HTTPClient(ctx).HTTPClient) return verifyToken(ctx, keySet, p.config, rawIDToken, issuerUrlGoogle) @@ -81,7 +82,7 @@ func (p *ProviderGoogle) Verify(ctx context.Context, rawIDToken string) (*Claims var _ NonceValidationSkipper = new(ProviderGoogle) -func (a *ProviderGoogle) CanSkipNonce(c *Claims) bool { +func (a *ProviderGoogle) CanSkipNonce(c *claims.Claims) bool { // Not all SDKs support nonce validation, so we skip it if no nonce is present in the claims of the ID Token. return c.Nonce == "" } diff --git a/selfservice/strategy/oidc/provider_lark.go b/selfservice/strategy/oidc/provider_lark.go index 52902dc20e8c..66006d08754c 100644 --- a/selfservice/strategy/oidc/provider_lark.go +++ b/selfservice/strategy/oidc/provider_lark.go @@ -13,6 +13,7 @@ import ( "golang.org/x/oauth2" "github.com/ory/herodot" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/httpx" ) @@ -46,7 +47,6 @@ func (g *ProviderLark) Config() *Configuration { } func (g *ProviderLark) OAuth2(ctx context.Context) (*oauth2.Config, error) { - return &oauth2.Config{ ClientID: g.config.ClientID, ClientSecret: g.config.ClientSecret, @@ -55,10 +55,9 @@ func (g *ProviderLark) OAuth2(ctx context.Context) (*oauth2.Config, error) { Scopes: g.config.Scope, RedirectURL: g.config.Redir(g.reg.Config().OIDCRedirectURIBase(ctx)), }, nil - } -func (g *ProviderLark) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (g *ProviderLark) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { // larkClaim is defined in the https://open.feishu.cn/document/common-capabilities/sso/api/get-user-info type larkClaim struct { Sub string `json:"sub"` @@ -101,7 +100,7 @@ func (g *ProviderLark) Claims(ctx context.Context, exchange *oauth2.Token, query return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) } - return &Claims{ + return &claims.Claims{ Issuer: larkUserEndpoint, Subject: user.OpenID, Name: user.Name, diff --git a/selfservice/strategy/oidc/provider_linkedin.go b/selfservice/strategy/oidc/provider_linkedin.go index 03a3db3e490d..363a33fe8afd 100644 --- a/selfservice/strategy/oidc/provider_linkedin.go +++ b/selfservice/strategy/oidc/provider_linkedin.go @@ -9,6 +9,7 @@ import ( "net/http" "net/url" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/otelx" "github.com/hashicorp/go-retryablehttp" @@ -165,7 +166,7 @@ func (l *ProviderLinkedIn) ProfilePicture(profile *LinkedInProfile) string { return identifiers[0].Identifier } -func (l *ProviderLinkedIn) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (_ *Claims, err error) { +func (l *ProviderLinkedIn) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (_ *claims.Claims, err error) { ctx, span := l.reg.Tracer(ctx).Tracer().Start(ctx, "selfservice.strategy.oidc.ProviderLinkedIn.Claims") defer otelx.End(span, &err) @@ -185,7 +186,7 @@ func (l *ProviderLinkedIn) Claims(ctx context.Context, exchange *oauth2.Token, q return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) } - claims := &Claims{ + claims := &claims.Claims{ Subject: profile.ID, Issuer: "https://login.linkedin.com/", Email: email.Elements[0].Handle.EmailAddress, diff --git a/selfservice/strategy/oidc/provider_linkedin_test.go b/selfservice/strategy/oidc/provider_linkedin_test.go index d5b9df86d25a..cff4f0edb9f7 100644 --- a/selfservice/strategy/oidc/provider_linkedin_test.go +++ b/selfservice/strategy/oidc/provider_linkedin_test.go @@ -19,6 +19,7 @@ import ( "github.com/ory/kratos/internal" "github.com/ory/kratos/selfservice/strategy/oidc" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" ) func TestProviderLinkedin_Claims(t *testing.T) { @@ -122,7 +123,7 @@ func TestProviderLinkedin_Claims(t *testing.T) { ) require.NoError(t, err) - assert.Equal(t, &oidc.Claims{ + assert.Equal(t, &claims.Claims{ Issuer: "https://login.linkedin.com/", Subject: "5foOWOiYXD", GivenName: "John", @@ -198,7 +199,7 @@ func TestProviderLinkedin_No_Picture(t *testing.T) { ) require.NoError(t, err) - assert.Equal(t, &oidc.Claims{ + assert.Equal(t, &claims.Claims{ Issuer: "https://login.linkedin.com/", Subject: "5foOWOiYXD", GivenName: "John", diff --git a/selfservice/strategy/oidc/provider_microsoft.go b/selfservice/strategy/oidc/provider_microsoft.go index d69206ec4d87..c05c31407c27 100644 --- a/selfservice/strategy/oidc/provider_microsoft.go +++ b/selfservice/strategy/oidc/provider_microsoft.go @@ -11,6 +11,7 @@ import ( "github.com/hashicorp/go-retryablehttp" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/httpx" "github.com/gofrs/uuid" @@ -53,7 +54,7 @@ func (m *ProviderMicrosoft) OAuth2(ctx context.Context) (*oauth2.Config, error) return m.oauth2ConfigFromEndpoint(ctx, endpoint), nil } -func (m *ProviderMicrosoft) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (m *ProviderMicrosoft) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { raw, ok := exchange.Extra("id_token").(string) if !ok || len(raw) == 0 { return nil, errors.WithStack(ErrIDTokenMissing) @@ -84,7 +85,7 @@ func (m *ProviderMicrosoft) Claims(ctx context.Context, exchange *oauth2.Token, return m.updateSubject(ctx, claims, exchange) } -func (m *ProviderMicrosoft) updateSubject(ctx context.Context, claims *Claims, exchange *oauth2.Token) (*Claims, error) { +func (m *ProviderMicrosoft) updateSubject(ctx context.Context, claims *claims.Claims, exchange *oauth2.Token) (*claims.Claims, error) { if m.config.SubjectSource == "me" { o, err := m.OAuth2(ctx) if err != nil { diff --git a/selfservice/strategy/oidc/provider_netid.go b/selfservice/strategy/oidc/provider_netid.go index dfe83c958433..c9b823ce4cd1 100644 --- a/selfservice/strategy/oidc/provider_netid.go +++ b/selfservice/strategy/oidc/provider_netid.go @@ -11,6 +11,7 @@ import ( gooidc "github.com/coreos/go-oidc/v3/oidc" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/stringslice" "github.com/hashicorp/go-retryablehttp" @@ -71,7 +72,7 @@ func (n *ProviderNetID) oAuth2(ctx context.Context) (*oauth2.Config, error) { }, nil } -func (n *ProviderNetID) Claims(ctx context.Context, exchange *oauth2.Token, _ url.Values) (*Claims, error) { +func (n *ProviderNetID) Claims(ctx context.Context, exchange *oauth2.Token, _ url.Values) (*claims.Claims, error) { o, err := n.OAuth2(ctx) if err != nil { return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) @@ -103,17 +104,17 @@ func (n *ProviderNetID) Claims(ctx context.Context, exchange *oauth2.Token, _ ur return nil, errors.WithStack(ErrIDTokenMissing) } - claims, err := n.verifyAndDecodeClaimsWithProvider(ctx, p, raw) + dec, err := n.verifyAndDecodeClaimsWithProvider(ctx, p, raw) if err != nil { return nil, err } - var userinfo Claims + var userinfo claims.Claims if err := json.NewDecoder(resp.Body).Decode(&userinfo); err != nil { return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) } - userinfo.Issuer = claims.Issuer - userinfo.Subject = claims.Subject + userinfo.Issuer = dec.Issuer + userinfo.Subject = dec.Subject return &userinfo, nil } diff --git a/selfservice/strategy/oidc/provider_patreon.go b/selfservice/strategy/oidc/provider_patreon.go index 745dc8fcc199..cf4a09df2773 100644 --- a/selfservice/strategy/oidc/provider_patreon.go +++ b/selfservice/strategy/oidc/provider_patreon.go @@ -10,6 +10,7 @@ import ( "github.com/hashicorp/go-retryablehttp" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/httpx" "github.com/pkg/errors" @@ -79,7 +80,7 @@ func (d *ProviderPatreon) AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption { } } -func (d *ProviderPatreon) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (d *ProviderPatreon) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { identityUrl := "https://www.patreon.com/api/oauth2/v2/identity?fields%5Buser%5D=first_name,last_name,url,full_name,email,image_url" o := d.oauth2(ctx) @@ -107,7 +108,7 @@ func (d *ProviderPatreon) Claims(ctx context.Context, exchange *oauth2.Token, qu return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", jsonErr)) } - claims := &Claims{ + claims := &claims.Claims{ Issuer: "https://www.patreon.com/", Subject: data.Data.Id, Name: data.Data.Attributes.FullName, diff --git a/selfservice/strategy/oidc/provider_salesforce.go b/selfservice/strategy/oidc/provider_salesforce.go index 1d028a1a8de7..1644afa57cef 100644 --- a/selfservice/strategy/oidc/provider_salesforce.go +++ b/selfservice/strategy/oidc/provider_salesforce.go @@ -11,6 +11,7 @@ import ( "path" "time" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/httpx" "github.com/ory/x/stringsx" @@ -71,7 +72,7 @@ func (g *ProviderSalesforce) OAuth2(ctx context.Context) (*oauth2.Config, error) return g.oauth2(ctx) } -func (g *ProviderSalesforce) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (g *ProviderSalesforce) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { o, err := g.OAuth2(ctx) if err != nil { return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) @@ -113,7 +114,7 @@ func (g *ProviderSalesforce) Claims(ctx context.Context, exchange *oauth2.Token, } // Once we get here, we know that if there is an updated_at field in the json, it is the correct type. - var claims Claims + var claims claims.Claims if err := json.Unmarshal(b, &claims); err != nil { return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) } diff --git a/selfservice/strategy/oidc/provider_slack.go b/selfservice/strategy/oidc/provider_slack.go index 7c7e26c99da4..d1c4a9eb4519 100644 --- a/selfservice/strategy/oidc/provider_slack.go +++ b/selfservice/strategy/oidc/provider_slack.go @@ -9,6 +9,7 @@ import ( "net/url" "github.com/ory/herodot" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/pkg/errors" "golang.org/x/oauth2" @@ -61,7 +62,7 @@ func (d *ProviderSlack) AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption { return []oauth2.AuthCodeOption{} } -func (d *ProviderSlack) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (d *ProviderSlack) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { grantedScopes := stringsx.Splitx(fmt.Sprintf("%s", exchange.Extra("scope")), ",") for _, check := range d.Config().Scope { if !stringslice.Has(grantedScopes, check) { @@ -75,7 +76,7 @@ func (d *ProviderSlack) Claims(ctx context.Context, exchange *oauth2.Token, quer return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) } - claims := &Claims{ + claims := &claims.Claims{ Issuer: "https://slack.com/oauth/", Subject: identity.User.ID, Name: identity.User.Name, diff --git a/selfservice/strategy/oidc/provider_spotify.go b/selfservice/strategy/oidc/provider_spotify.go index 366105c94d0e..b1dc3791f423 100644 --- a/selfservice/strategy/oidc/provider_spotify.go +++ b/selfservice/strategy/oidc/provider_spotify.go @@ -13,6 +13,7 @@ import ( "github.com/pkg/errors" "golang.org/x/oauth2" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/httpx" "github.com/ory/x/stringslice" "github.com/ory/x/stringsx" @@ -60,7 +61,7 @@ func (g *ProviderSpotify) AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption { return []oauth2.AuthCodeOption{} } -func (g *ProviderSpotify) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (g *ProviderSpotify) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { grantedScopes := stringsx.Splitx(fmt.Sprintf("%s", exchange.Extra("scope")), " ") for _, check := range g.Config().Scope { if !stringslice.Has(grantedScopes, check) { @@ -85,7 +86,7 @@ func (g *ProviderSpotify) Claims(ctx context.Context, exchange *oauth2.Token, qu userPicture = user.Images[0].URL } - claims := &Claims{ + claims := &claims.Claims{ Subject: user.ID, Issuer: spotify.Endpoint.TokenURL, Name: user.DisplayName, diff --git a/selfservice/strategy/oidc/provider_test.go b/selfservice/strategy/oidc/provider_test.go index a5733d2e95f8..1041ea2c5c56 100644 --- a/selfservice/strategy/oidc/provider_test.go +++ b/selfservice/strategy/oidc/provider_test.go @@ -11,16 +11,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" -) -func TestClaimsValidate(t *testing.T) { - require.Error(t, new(Claims).Validate()) - require.Error(t, (&Claims{Issuer: "not-empty"}).Validate()) - require.Error(t, (&Claims{Issuer: "not-empty"}).Validate()) - require.Error(t, (&Claims{Subject: "not-empty"}).Validate()) - require.Error(t, (&Claims{Subject: "not-empty"}).Validate()) - require.NoError(t, (&Claims{Issuer: "not-empty", Subject: "not-empty"}).Validate()) -} + "github.com/ory/kratos/selfservice/strategy/oidc/claims" +) type TestProvider struct { *ProviderGenericOIDC @@ -43,11 +36,11 @@ func RegisterTestProvider(id string) func() { var _ IDTokenVerifier = new(TestProvider) -func (t *TestProvider) Verify(_ context.Context, token string) (*Claims, error) { +func (t *TestProvider) Verify(_ context.Context, token string) (*claims.Claims, error) { if token == "error" { return nil, fmt.Errorf("stub error") } - c := Claims{} + c := claims.Claims{} if err := json.Unmarshal([]byte(token), &c); err != nil { return nil, err } @@ -95,7 +88,7 @@ func TestLocale(t *testing.T) { expected: "", }} { t.Run(tc.name, func(t *testing.T) { - var c Claims + var c claims.Claims err := json.Unmarshal([]byte(tc.json), &c) if tc.assertErr != nil { tc.assertErr(t, err) diff --git a/selfservice/strategy/oidc/provider_userinfo_test.go b/selfservice/strategy/oidc/provider_userinfo_test.go index dde2507af319..1a720970af06 100644 --- a/selfservice/strategy/oidc/provider_userinfo_test.go +++ b/selfservice/strategy/oidc/provider_userinfo_test.go @@ -20,6 +20,7 @@ import ( "github.com/ory/kratos/driver/config" "github.com/ory/kratos/internal" "github.com/ory/kratos/selfservice/strategy/oidc" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/httpx" "github.com/ory/x/otelx" @@ -45,7 +46,7 @@ func TestProviderClaimsRespectsErrorCodes(t *testing.T) { ctx := context.Background() token := &oauth2.Token{AccessToken: "foo", Expiry: time.Now().Add(time.Hour)} - expectedClaims := &oidc.Claims{ + expectedClaims := &claims.Claims{ Issuer: "ignore-me", Subject: "123456789012345", Name: "John Doe", @@ -75,7 +76,7 @@ func TestProviderClaimsRespectsErrorCodes(t *testing.T) { config *oidc.Configuration provider oidc.Provider userInfoHandler func(req *http.Request) (*http.Response, error) - expectedClaims *oidc.Claims + expectedClaims *claims.Claims useToken *oauth2.Token hook func(t *testing.T) }{ @@ -135,7 +136,7 @@ func TestProviderClaimsRespectsErrorCodes(t *testing.T) { }, ) }, - expectedClaims: &oidc.Claims{Issuer: "https://broker.netid.de/", Subject: "1234567890", Name: "John Doe", GivenName: "John", FamilyName: "Doe", LastName: "", MiddleName: "", Nickname: "John Doe", PreferredUsername: "John Doe", Profile: "", Picture: "", Website: "", Email: "john.doe@example.com", EmailVerified: true, Gender: "", Birthdate: "01/01/1990", Zoneinfo: "", Locale: "", PhoneNumber: "", PhoneNumberVerified: false, UpdatedAt: 0, HD: "", Team: ""}, + expectedClaims: &claims.Claims{Issuer: "https://broker.netid.de/", Subject: "1234567890", Name: "John Doe", GivenName: "John", FamilyName: "Doe", LastName: "", MiddleName: "", Nickname: "John Doe", PreferredUsername: "John Doe", Profile: "", Picture: "", Website: "", Email: "john.doe@example.com", EmailVerified: true, Gender: "", Birthdate: "01/01/1990", Zoneinfo: "", Locale: "", PhoneNumber: "", PhoneNumberVerified: false, UpdatedAt: 0, HD: "", Team: ""}, }, { name: "vk", @@ -158,7 +159,7 @@ func TestProviderClaimsRespectsErrorCodes(t *testing.T) { return resp, err }, - expectedClaims: &oidc.Claims{ + expectedClaims: &claims.Claims{ Issuer: "https://api.vk.com/method/users.get", Subject: "123456789012345", Email: "john.doe@example.com", @@ -186,7 +187,7 @@ func TestProviderClaimsRespectsErrorCodes(t *testing.T) { return resp, err }, - expectedClaims: &oidc.Claims{ + expectedClaims: &claims.Claims{ Issuer: "https://login.yandex.ru/info", Subject: "123456789012345", Email: "john.doe@example.com", @@ -231,7 +232,7 @@ func TestProviderClaimsRespectsErrorCodes(t *testing.T) { }) return resp, err }, - expectedClaims: &oidc.Claims{ + expectedClaims: &claims.Claims{ Issuer: "https://graph.facebook.com/me?fields=id,name,first_name,last_name,middle_name,email,picture,birthday,gender&appsecret_proof=0c0d98f7e3d9d45e72e8877bc1b104327efb9c07b18f2ffeced76d81307f1fff", Subject: "123456789012345", Name: "John Doe", @@ -302,7 +303,7 @@ func TestProviderClaimsRespectsErrorCodes(t *testing.T) { }, ) }, - expectedClaims: &oidc.Claims{ + expectedClaims: &claims.Claims{ Issuer: "https://login.microsoftonline.com/a9b86385-f32c-4803-afc8-4b2312fbdf24/v2.0", Subject: "new-id", Name: "John Doe", Email: "john.doe@example.com", RawClaims: map[string]interface{}{"aud": []interface{}{"foo"}, "exp": 4.071728504e+09, "iat": 1.516239022e+09, "iss": "https://login.microsoftonline.com/a9b86385-f32c-4803-afc8-4b2312fbdf24/v2.0", "email": "john.doe@example.com", "name": "John Doe", "sub": "1234567890", "tid": "a9b86385-f32c-4803-afc8-4b2312fbdf24"}, }, @@ -327,7 +328,7 @@ func TestProviderClaimsRespectsErrorCodes(t *testing.T) { ID: "dingtalk", Provider: "dingtalk", }, reg), - expectedClaims: &oidc.Claims{ + expectedClaims: &claims.Claims{ Issuer: "https://api.dingtalk.com/v1.0/contact/users/me", Subject: "123456789012345", Email: "john.doe@example.com", diff --git a/selfservice/strategy/oidc/provider_vk.go b/selfservice/strategy/oidc/provider_vk.go index 2a3513b6e050..97d77cb8af65 100644 --- a/selfservice/strategy/oidc/provider_vk.go +++ b/selfservice/strategy/oidc/provider_vk.go @@ -11,6 +11,7 @@ import ( "github.com/hashicorp/go-retryablehttp" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/httpx" "github.com/pkg/errors" @@ -59,7 +60,7 @@ func (g *ProviderVK) OAuth2(ctx context.Context) (*oauth2.Config, error) { return g.oauth2(ctx), nil } -func (g *ProviderVK) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (g *ProviderVK) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { o, err := g.OAuth2(ctx) if err != nil { return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) @@ -118,7 +119,7 @@ func (g *ProviderVK) Claims(ctx context.Context, exchange *oauth2.Token, query u gender = "male" } - return &Claims{ + return &claims.Claims{ Issuer: "https://api.vk.com/method/users.get", Subject: strconv.Itoa(user.Id), GivenName: user.FirstName, diff --git a/selfservice/strategy/oidc/provider_x.go b/selfservice/strategy/oidc/provider_x.go index f58dbd48182f..3c4933d403b5 100644 --- a/selfservice/strategy/oidc/provider_x.go +++ b/selfservice/strategy/oidc/provider_x.go @@ -9,6 +9,7 @@ import ( "fmt" "net/http" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/otelx" "github.com/dghubble/oauth1" @@ -18,11 +19,15 @@ import ( "github.com/ory/herodot" ) -var _ Provider = (*ProviderX)(nil) -var _ OAuth1Provider = (*ProviderX)(nil) +var ( + _ Provider = (*ProviderX)(nil) + _ OAuth1Provider = (*ProviderX)(nil) +) -const xUserInfoBase = "https://api.twitter.com/1.1/account/verify_credentials.json" -const xUserInfoWithEmail = xUserInfoBase + "?include_email=true" +const ( + xUserInfoBase = "https://api.twitter.com/1.1/account/verify_credentials.json" + xUserInfoWithEmail = xUserInfoBase + "?include_email=true" +) type ProviderX struct { config *Configuration @@ -35,7 +40,8 @@ func (p *ProviderX) Config() *Configuration { func NewProviderX( config *Configuration, - reg Dependencies) Provider { + reg Dependencies, +) Provider { return &ProviderX{ config: config, reg: reg, @@ -107,7 +113,7 @@ func (p *ProviderX) userInfoEndpoint() string { return xUserInfoBase } -func (p *ProviderX) Claims(ctx context.Context, token *oauth1.Token) (*Claims, error) { +func (p *ProviderX) Claims(ctx context.Context, token *oauth1.Token) (*claims.Claims, error) { ctx = context.WithValue(ctx, oauth1.HTTPClient, p.reg.HTTPClient(ctx).HTTPClient) c := p.OAuth1(ctx) @@ -134,7 +140,7 @@ func (p *ProviderX) Claims(ctx context.Context, token *oauth1.Token) (*Claims, e website = *user.URL } - return &Claims{ + return &claims.Claims{ Issuer: endpoint, Subject: user.IDStr, Name: user.Name, diff --git a/selfservice/strategy/oidc/provider_yandex.go b/selfservice/strategy/oidc/provider_yandex.go index 07b30caee52b..6f8582cb5ee3 100644 --- a/selfservice/strategy/oidc/provider_yandex.go +++ b/selfservice/strategy/oidc/provider_yandex.go @@ -12,6 +12,7 @@ import ( "github.com/pkg/errors" "golang.org/x/oauth2" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/httpx" "github.com/ory/herodot" @@ -57,7 +58,7 @@ func (g *ProviderYandex) OAuth2(ctx context.Context) (*oauth2.Config, error) { return g.oauth2(ctx), nil } -func (g *ProviderYandex) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (g *ProviderYandex) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { o, err := g.OAuth2(ctx) if err != nil { return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) @@ -100,7 +101,7 @@ func (g *ProviderYandex) Claims(ctx context.Context, exchange *oauth2.Token, que user.Picture = "" } - return &Claims{ + return &claims.Claims{ Issuer: "https://login.yandex.ru/info", Subject: user.Id, GivenName: user.FirstName, diff --git a/selfservice/strategy/oidc/strategy.go b/selfservice/strategy/oidc/strategy.go index e6733f0f0019..1a0e1c968e66 100644 --- a/selfservice/strategy/oidc/strategy.go +++ b/selfservice/strategy/oidc/strategy.go @@ -27,6 +27,7 @@ import ( "github.com/ory/kratos/cipher" "github.com/ory/kratos/selfservice/sessiontokenexchange" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/jsonnetsecure" "github.com/ory/x/otelx" @@ -422,7 +423,7 @@ func (s *Strategy) HandleCallback(w http.ResponseWriter, r *http.Request, ps htt return } - var claims *Claims + var claims *claims.Claims var et *identity.CredentialsOIDCEncryptedTokens switch p := provider.(type) { case OAuth2Provider: @@ -703,7 +704,7 @@ func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context) session.Au } } -func (s *Strategy) processIDToken(w http.ResponseWriter, r *http.Request, provider Provider, idToken, idTokenNonce string) (*Claims, error) { +func (s *Strategy) processIDToken(w http.ResponseWriter, r *http.Request, provider Provider, idToken, idTokenNonce string) (*claims.Claims, error) { verifier, ok := provider.(IDTokenVerifier) if !ok { return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("The provider %s does not support id_token verification", provider.Config().Provider)) diff --git a/selfservice/strategy/oidc/strategy_helper_test.go b/selfservice/strategy/oidc/strategy_helper_test.go index 7b39729cc6af..0285f484c286 100644 --- a/selfservice/strategy/oidc/strategy_helper_test.go +++ b/selfservice/strategy/oidc/strategy_helper_test.go @@ -368,7 +368,7 @@ var publicJWKS []byte //go:embed stub/jwks_public2.json var publicJWKS2 []byte -type claims struct { +type jwtClaims struct { *jwt.RegisteredClaims Email string `json:"email"` } @@ -376,7 +376,7 @@ type claims struct { func createIdToken(t *testing.T, cl jwt.RegisteredClaims) string { key := &jwk.KeySpec{} require.NoError(t, json.Unmarshal(rawKey, key)) - token := jwt.NewWithClaims(jwt.SigningMethodRS256, &claims{ + token := jwt.NewWithClaims(jwt.SigningMethodRS256, &jwtClaims{ RegisteredClaims: &cl, Email: "acme@ory.sh", }) diff --git a/selfservice/strategy/oidc/strategy_login.go b/selfservice/strategy/oidc/strategy_login.go index 23e96eeb1504..135dce24bba6 100644 --- a/selfservice/strategy/oidc/strategy_login.go +++ b/selfservice/strategy/oidc/strategy_login.go @@ -25,6 +25,7 @@ import ( "github.com/ory/x/sqlcon" "github.com/ory/kratos/selfservice/flow/registration" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/text" @@ -107,7 +108,7 @@ type UpdateLoginFlowWithOidcMethod struct { TransientPayload json.RawMessage `json:"transient_payload,omitempty" form:"transient_payload"` } -func (s *Strategy) processLogin(ctx context.Context, w http.ResponseWriter, r *http.Request, loginFlow *login.Flow, token *identity.CredentialsOIDCEncryptedTokens, claims *Claims, provider Provider, container *AuthCodeContainer) (_ *registration.Flow, err error) { +func (s *Strategy) processLogin(ctx context.Context, w http.ResponseWriter, r *http.Request, loginFlow *login.Flow, token *identity.CredentialsOIDCEncryptedTokens, claims *claims.Claims, provider Provider, container *AuthCodeContainer) (_ *registration.Flow, err error) { ctx, span := s.d.Tracer(ctx).Tracer().Start(ctx, "selfservice.strategy.oidc.strategy.processLogin") defer otelx.End(span, &err) @@ -181,7 +182,7 @@ func (s *Strategy) processLogin(ctx context.Context, w http.ResponseWriter, r *h httprouter.ParamsFromContext(ctx).ByName("organization")) for _, c := range oidcCredentials.Providers { if c.Subject == claims.Subject && c.Provider == provider.Config().ID { - if err = s.d.LoginHookExecutor().PostLoginHook(w, r, node.OpenIDConnectGroup, loginFlow, i, sess, provider.Config().ID); err != nil { + if err = s.d.LoginHookExecutor().PostLoginHook(w, r, node.OpenIDConnectGroup, loginFlow, i, sess, claims, provider.Config().ID); err != nil { return nil, s.handleError(ctx, w, r, loginFlow, provider.Config().ID, nil, err) } return nil, nil diff --git a/selfservice/strategy/oidc/strategy_registration.go b/selfservice/strategy/oidc/strategy_registration.go index f756e8caf156..c54e320e7dce 100644 --- a/selfservice/strategy/oidc/strategy_registration.go +++ b/selfservice/strategy/oidc/strategy_registration.go @@ -24,6 +24,7 @@ import ( "github.com/ory/kratos/selfservice/flow" "github.com/ory/kratos/selfservice/flow/login" "github.com/ory/kratos/selfservice/flow/registration" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/text" "github.com/ory/kratos/x" "github.com/ory/x/decoderx" @@ -279,7 +280,7 @@ func (s *Strategy) registrationToLogin(w http.ResponseWriter, r *http.Request, r return lf, nil } -func (s *Strategy) processRegistration(ctx context.Context, w http.ResponseWriter, r *http.Request, rf *registration.Flow, token *identity.CredentialsOIDCEncryptedTokens, claims *Claims, provider Provider, container *AuthCodeContainer, idToken string) (_ *login.Flow, err error) { +func (s *Strategy) processRegistration(ctx context.Context, w http.ResponseWriter, r *http.Request, rf *registration.Flow, token *identity.CredentialsOIDCEncryptedTokens, claims *claims.Claims, provider Provider, container *AuthCodeContainer, idToken string) (_ *login.Flow, err error) { ctx, span := s.d.Tracer(ctx).Tracer().Start(ctx, "selfservice.strategy.oidc.strategy.processRegistration") defer otelx.End(span, &err) @@ -350,7 +351,7 @@ func (s *Strategy) processRegistration(ctx context.Context, w http.ResponseWrite return nil, nil } -func (s *Strategy) createIdentity(ctx context.Context, w http.ResponseWriter, r *http.Request, a *registration.Flow, claims *Claims, provider Provider, container *AuthCodeContainer, jsonnetSnippet []byte) (*identity.Identity, []VerifiedAddress, error) { +func (s *Strategy) createIdentity(ctx context.Context, w http.ResponseWriter, r *http.Request, a *registration.Flow, claims *claims.Claims, provider Provider, container *AuthCodeContainer, jsonnetSnippet []byte) (*identity.Identity, []VerifiedAddress, error) { var jsonClaims bytes.Buffer if err := json.NewEncoder(&jsonClaims).Encode(claims); err != nil { return nil, nil, s.handleError(ctx, w, r, a, provider.Config().ID, nil, err) @@ -399,7 +400,7 @@ func (s *Strategy) createIdentity(ctx context.Context, w http.ResponseWriter, r return i, va, nil } -func (s *Strategy) setTraits(ctx context.Context, w http.ResponseWriter, r *http.Request, a *registration.Flow, claims *Claims, provider Provider, container *AuthCodeContainer, evaluated string, i *identity.Identity) error { +func (s *Strategy) setTraits(ctx context.Context, w http.ResponseWriter, r *http.Request, a *registration.Flow, claims *claims.Claims, provider Provider, container *AuthCodeContainer, evaluated string, i *identity.Identity) error { jsonTraits := gjson.Get(evaluated, "identity.traits") if !jsonTraits.IsObject() { return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("OpenID Connect Jsonnet mapper did not return an object for key identity.traits. Please check your Jsonnet code!")) diff --git a/selfservice/strategy/oidc/strategy_settings.go b/selfservice/strategy/oidc/strategy_settings.go index 82df2298692c..a98c5a39d324 100644 --- a/selfservice/strategy/oidc/strategy_settings.go +++ b/selfservice/strategy/oidc/strategy_settings.go @@ -20,6 +20,7 @@ import ( "github.com/ory/kratos/continuity" "github.com/ory/kratos/selfservice/strategy" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/decoderx" "github.com/ory/kratos/session" @@ -409,7 +410,7 @@ func (s *Strategy) initLinkProvider(w http.ResponseWriter, r *http.Request, ctxU return errors.WithStack(flow.ErrCompletedByStrategy) } -func (s *Strategy) linkProvider(w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, token *identity.CredentialsOIDCEncryptedTokens, claims *Claims, provider Provider) error { +func (s *Strategy) linkProvider(w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, token *identity.CredentialsOIDCEncryptedTokens, claims *claims.Claims, provider Provider) error { ctx := r.Context() p := &updateSettingsFlowWithOidcMethod{ Link: provider.Config().ID, FlowID: ctxUpdate.Flow.ID.String(), diff --git a/selfservice/strategy/oidc/token_verifier.go b/selfservice/strategy/oidc/token_verifier.go index ce9cb8b3d3ee..864d9faa153e 100644 --- a/selfservice/strategy/oidc/token_verifier.go +++ b/selfservice/strategy/oidc/token_verifier.go @@ -9,9 +9,11 @@ import ( "strings" "github.com/coreos/go-oidc/v3/oidc" + + "github.com/ory/kratos/selfservice/strategy/oidc/claims" ) -func verifyToken(ctx context.Context, keySet oidc.KeySet, config *Configuration, rawIDToken, issuerURL string) (*Claims, error) { +func verifyToken(ctx context.Context, keySet oidc.KeySet, config *Configuration, rawIDToken, issuerURL string) (*claims.Claims, error) { tokenAudiences := append([]string{config.ClientID}, config.AdditionalIDTokenAudiences...) var token *oidc.IDToken err := fmt.Errorf("no audience matched the token's audience") @@ -34,7 +36,7 @@ func verifyToken(ctx context.Context, keySet oidc.KeySet, config *Configuration, // None of the allowed audiences matched the audience in the token return nil, fmt.Errorf("token audience didn't match allowed audiences: %+v %w", tokenAudiences, err) } - claims := &Claims{} + claims := &claims.Claims{} if err := token.Claims(claims); err != nil { return nil, err }