diff --git a/selfservice/strategy/oidc/strategy_registration.go b/selfservice/strategy/oidc/strategy_registration.go index 1e303c50df35..ca286714f06e 100644 --- a/selfservice/strategy/oidc/strategy_registration.go +++ b/selfservice/strategy/oidc/strategy_registration.go @@ -38,9 +38,17 @@ var _ registration.Strategy = new(Strategy) type MetadataType string +type OIDCProviderData struct { + Provider string `json:"provider"` + Token string `json:"token"` + Claims Claims `json:"claims"` +} + const ( PublicMetadata MetadataType = "identity.metadata_public" AdminMetadata MetadataType = "identity.metadata_admin" + + InternalContextKeyProviderData = "provider_data" ) func (s *Strategy) RegisterRegistrationRoutes(r *x.RouterPublic) { @@ -157,6 +165,33 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registrat return errors.WithStack(registration.ErrAlreadyLoggedIn) } + if oidcProviderData := gjson.GetBytes(f.InternalContext, flow.PrefixInternalContextKey(s.ID(), InternalContextKeyProviderData)); oidcProviderData.IsObject() { + var providerData OIDCProviderData + if err := json.Unmarshal([]byte(oidcProviderData.Raw), &providerData); err != nil { + return s.handleError(w, r, f, pid, nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Expected OIDC provider data in internal context to be an object but got: %s", err))) + } + if pid != providerData.Provider { + return s.handleError(w, r, f, pid, nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Expected OIDC provider data in internal context to have matching provider but got: %s", providerData.Provider))) + } + + decryptedToken, err := s.d.Cipher(r.Context()).Decrypt(r.Context(), providerData.Token) + if err != nil { + return s.handleError(w, r, f, pid, nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Could not decrypt OAuth2 token from internal context's OIDC provider data: %s", err))) + } + + var token oauth2.Token + if err := json.Unmarshal(decryptedToken, &token); err != nil { + return s.handleError(w, r, f, pid, nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Could not unmarshal OAuth2 token from internal context's OIDC provider data: %s", err))) + } + + _, err = s.processRegistration(w, r, f, &token, &providerData.Claims, provider, &authCodeContainer{ + FlowID: f.ID.String(), + Traits: p.Traits, + TransientPayload: f.TransientPayload, + }) + return err + } + state := generateState(f.ID.String()) if code, hasCode, _ := s.d.SessionTokenExchangePersister().CodeForFlow(r.Context(), f.ID); hasCode { state.setCode(code.InitCode) @@ -243,6 +278,17 @@ func (s *Strategy) processRegistration(w http.ResponseWriter, r *http.Request, r return nil, nil } + providerDataKey := flow.PrefixInternalContextKey(s.ID(), InternalContextKeyProviderData) + if hasOIDCProviderData := gjson.GetBytes(rf.InternalContext, providerDataKey).IsObject(); !hasOIDCProviderData { + if tokenJSON, err := json.Marshal(token); err == nil { + if encryptedToken, err := s.d.Cipher(r.Context()).Encrypt(r.Context(), tokenJSON); err == nil { + if internalContext, err := sjson.SetBytes(rf.InternalContext, providerDataKey, &OIDCProviderData{Provider: provider.Config().ID, Token: encryptedToken, Claims: *claims}); err == nil { + rf.InternalContext = internalContext + } + } + } + } + fetch := fetcher.NewFetcher(fetcher.WithClient(s.d.HTTPClient(r.Context()))) jn, err := fetch.FetchContext(r.Context(), provider.Config().Mapper) if err != nil { @@ -286,6 +332,10 @@ func (s *Strategy) processRegistration(w http.ResponseWriter, r *http.Request, r return nil, s.handleError(w, r, rf, provider.Config().ID, i.Traits, err) } + if internalContext, err := sjson.DeleteBytes(rf.InternalContext, providerDataKey); err == nil { + rf.InternalContext = internalContext + } + return nil, nil }