Skip to content

Commit

Permalink
feat: redirect to OIDC providers only once in registration flows
Browse files Browse the repository at this point in the history
  • Loading branch information
Saancreed committed Aug 3, 2023
1 parent cd9e6a0 commit 4b97280
Showing 1 changed file with 50 additions and 0 deletions.
50 changes: 50 additions & 0 deletions selfservice/strategy/oidc/strategy_registration.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,17 @@ var _ registration.Strategy = new(Strategy)

type MetadataType string

type OIDCProviderData struct {
Provider string `json:"provider"`
Token string `json:"token"`
Claims Claims `json:"claims"`
}

const (
PublicMetadata MetadataType = "identity.metadata_public"
AdminMetadata MetadataType = "identity.metadata_admin"

InternalContextKeyProviderData = "provider_data"
)

func (s *Strategy) RegisterRegistrationRoutes(r *x.RouterPublic) {
Expand Down Expand Up @@ -157,6 +165,33 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registrat
return errors.WithStack(registration.ErrAlreadyLoggedIn)
}

if oidcProviderData := gjson.GetBytes(f.InternalContext, flow.PrefixInternalContextKey(s.ID(), InternalContextKeyProviderData)); oidcProviderData.IsObject() {
var providerData OIDCProviderData
if err := json.Unmarshal([]byte(oidcProviderData.Raw), &providerData); err != nil {
return s.handleError(w, r, f, pid, nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Expected OIDC provider data in internal context to be an object but got: %s", err)))
}
if pid != providerData.Provider {
return s.handleError(w, r, f, pid, nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Expected OIDC provider data in internal context to have matching provider but got: %s", providerData.Provider)))
}

decryptedToken, err := s.d.Cipher(r.Context()).Decrypt(r.Context(), providerData.Token)
if err != nil {
return s.handleError(w, r, f, pid, nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Could not decrypt OAuth2 token from internal context's OIDC provider data: %s", err)))
}

var token oauth2.Token
if err := json.Unmarshal(decryptedToken, &token); err != nil {
return s.handleError(w, r, f, pid, nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Could not unmarshal OAuth2 token from internal context's OIDC provider data: %s", err)))
}

_, err = s.processRegistration(w, r, f, &token, &providerData.Claims, provider, &authCodeContainer{
FlowID: f.ID.String(),
Traits: p.Traits,
TransientPayload: f.TransientPayload,
})
return err
}

state := generateState(f.ID.String())
if code, hasCode, _ := s.d.SessionTokenExchangePersister().CodeForFlow(r.Context(), f.ID); hasCode {
state.setCode(code.InitCode)
Expand Down Expand Up @@ -243,6 +278,17 @@ func (s *Strategy) processRegistration(w http.ResponseWriter, r *http.Request, r
return nil, nil
}

providerDataKey := flow.PrefixInternalContextKey(s.ID(), InternalContextKeyProviderData)
if hasOIDCProviderData := gjson.GetBytes(rf.InternalContext, providerDataKey).IsObject(); !hasOIDCProviderData {
if tokenJSON, err := json.Marshal(token); err == nil {
if encryptedToken, err := s.d.Cipher(r.Context()).Encrypt(r.Context(), tokenJSON); err == nil {
if internalContext, err := sjson.SetBytes(rf.InternalContext, providerDataKey, &OIDCProviderData{Provider: provider.Config().ID, Token: encryptedToken, Claims: *claims}); err == nil {
rf.InternalContext = internalContext
}
}
}
}

fetch := fetcher.NewFetcher(fetcher.WithClient(s.d.HTTPClient(r.Context())))
jn, err := fetch.FetchContext(r.Context(), provider.Config().Mapper)
if err != nil {
Expand Down Expand Up @@ -286,6 +332,10 @@ func (s *Strategy) processRegistration(w http.ResponseWriter, r *http.Request, r
return nil, s.handleError(w, r, rf, provider.Config().ID, i.Traits, err)
}

if internalContext, err := sjson.DeleteBytes(rf.InternalContext, providerDataKey); err == nil {
rf.InternalContext = internalContext
}

return nil, nil
}

Expand Down

0 comments on commit 4b97280

Please sign in to comment.