diff --git a/consent/handler.go b/consent/handler.go index 2eea50bc8d..89e521058b 100644 --- a/consent/handler.go +++ b/consent/handler.go @@ -4,6 +4,7 @@ package consent import ( + "context" "encoding/json" "net/http" "net/url" @@ -476,11 +477,12 @@ func (h *Handler) acceptOAuth2LoginRequest(w http.ResponseWriter, r *http.Reques } handledLoginRequest.RequestedAt = loginRequest.RequestedAt - f, err := flowctx.Decode[flow.Flow](ctx, h.r.FlowCipher(), challenge, flowctx.AsLoginChallenge) + f, err := h.decodeFlowWithClient(ctx, challenge, flowctx.AsLoginChallenge) if err != nil { h.r.Writer().WriteError(w, r, err) return } + request, err := h.r.ConsentManager().HandleLoginRequest(ctx, f, challenge, &handledLoginRequest) if err != nil { h.r.Writer().WriteError(w, r, errorsx.WithStack(err)) @@ -575,7 +577,7 @@ func (h *Handler) rejectOAuth2LoginRequest(w http.ResponseWriter, r *http.Reques return } - f, err := flowctx.Decode[flow.Flow](ctx, h.r.FlowCipher(), challenge, flowctx.AsLoginChallenge) + f, err := h.decodeFlowWithClient(ctx, challenge, flowctx.AsLoginChallenge) if err != nil { h.r.Writer().WriteError(w, r, err) return @@ -761,7 +763,7 @@ func (h *Handler) acceptOAuth2ConsentRequest(w http.ResponseWriter, r *http.Requ p.RequestedAt = cr.RequestedAt p.HandledAt = sqlxx.NullTime(time.Now().UTC()) - f, err := flowctx.Decode[flow.Flow](ctx, h.r.FlowCipher(), challenge, flowctx.AsConsentChallenge) + f, err := h.decodeFlowWithClient(ctx, challenge, flowctx.AsConsentChallenge) if err != nil { h.r.Writer().WriteError(w, r, err) return @@ -868,7 +870,7 @@ func (h *Handler) rejectOAuth2ConsentRequest(w http.ResponseWriter, r *http.Requ return } - f, err := flowctx.Decode[flow.Flow](ctx, h.r.FlowCipher(), challenge, flowctx.AsConsentChallenge) + f, err := h.decodeFlowWithClient(ctx, challenge, flowctx.AsConsentChallenge) if err != nil { h.r.Writer().WriteError(w, r, err) return @@ -1044,3 +1046,17 @@ func (h *Handler) getOAuth2LogoutRequest(w http.ResponseWriter, r *http.Request, h.r.Writer().Write(w, r, request) } + +func (h *Handler) decodeFlowWithClient(ctx context.Context, challenge string, opts ...flowctx.CodecOption) (*flow.Flow, error) { + f, err := flowctx.Decode[flow.Flow](ctx, h.r.FlowCipher(), challenge, opts...) + if err != nil { + return nil, err + } + + f.Client, err = h.r.ClientManager().GetConcreteClient(ctx, f.ClientID) + if err != nil { + return nil, err + } + + return f, nil +} diff --git a/consent/strategy_default.go b/consent/strategy_default.go index 117fba9254..cb1f83ef42 100644 --- a/consent/strategy_default.go +++ b/consent/strategy_default.go @@ -348,6 +348,11 @@ func (s *DefaultStrategy) verifyAuthentication( return nil, errorsx.WithStack(fosite.ErrAccessDenied.WithHint("The login verifier is invalid.")) } + f.Client, err = s.r.ClientManager().GetConcreteClient(ctx, f.ClientID) + if err != nil { + return nil, err + } + session, err := s.r.ConsentManager().VerifyAndInvalidateLoginRequest(ctx, verifier) if errors.Is(err, sqlcon.ErrNoRows) { return nil, errorsx.WithStack(fosite.ErrAccessDenied.WithHint("The login verifier has already been used, has not been granted, or is invalid.")) @@ -652,6 +657,12 @@ func (s *DefaultStrategy) verifyConsent(ctx context.Context, _ http.ResponseWrit if err != nil { return nil, nil, errorsx.WithStack(fosite.ErrAccessDenied.WithHint("The consent verifier has already been used, has not been granted, or is invalid.")) } + + f.Client, err = s.r.ClientManager().GetConcreteClient(ctx, f.ClientID) + if err != nil { + return nil, nil, err + } + if f.Client.GetID() != r.URL.Query().Get("client_id") { return nil, nil, errorsx.WithStack(fosite.ErrInvalidClient.WithHint("The flow client id does not match the authorize request client id.")) } diff --git a/flow/flow.go b/flow/flow.go index 9003fc029f..7db6b9bd74 100644 --- a/flow/flow.go +++ b/flow/flow.go @@ -119,7 +119,7 @@ type Flow struct { // Client is the OAuth 2.0 Client that initiated the request. // // required: true - Client *client.Client `db:"-" json:"c,omitempty"` + Client *client.Client `db:"-" json:"-"` ClientID string `db:"client_id" json:"ci,omitempty"` // RequestURL is the original OAuth 2.0 Authorization URL requested by the OAuth 2.0 client. It is the URL which @@ -511,20 +511,32 @@ type CipherProvider interface { // ToLoginChallenge converts the flow into a login challenge. func (f *Flow) ToLoginChallenge(ctx context.Context, cipherProvider CipherProvider) (string, error) { + if f.Client != nil { + f.ClientID = f.Client.GetID() + } return flowctx.Encode(ctx, cipherProvider.FlowCipher(), f, flowctx.AsLoginChallenge) } // ToLoginVerifier converts the flow into a login verifier. func (f *Flow) ToLoginVerifier(ctx context.Context, cipherProvider CipherProvider) (string, error) { + if f.Client != nil { + f.ClientID = f.Client.GetID() + } return flowctx.Encode(ctx, cipherProvider.FlowCipher(), f, flowctx.AsLoginVerifier) } // ToConsentChallenge converts the flow into a consent challenge. func (f *Flow) ToConsentChallenge(ctx context.Context, cipherProvider CipherProvider) (string, error) { + if f.Client != nil { + f.ClientID = f.Client.GetID() + } return flowctx.Encode(ctx, cipherProvider.FlowCipher(), f, flowctx.AsConsentChallenge) } // ToConsentVerifier converts the flow into a consent verifier. func (f *Flow) ToConsentVerifier(ctx context.Context, cipherProvider CipherProvider) (string, error) { + if f.Client != nil { + f.ClientID = f.Client.GetID() + } return flowctx.Encode(ctx, cipherProvider.FlowCipher(), f, flowctx.AsConsentVerifier) } diff --git a/oauth2/flowctx/encoding.go b/oauth2/flowctx/encoding.go index e3a62be328..b38300e89a 100644 --- a/oauth2/flowctx/encoding.go +++ b/oauth2/flowctx/encoding.go @@ -14,7 +14,6 @@ import ( "github.com/ory/fosite" "github.com/ory/hydra/v2/aead" - "github.com/ory/hydra/v2/driver/config" ) type ( @@ -103,29 +102,6 @@ func Encode(ctx context.Context, cipher aead.Cipher, val any, opts ...CodecOptio return cipher.Encrypt(ctx, b.Bytes(), additionalDataFromOpts(opts...)) } -// SetCookie encrypts the given value and sets it in a cookie. -func SetCookie(ctx context.Context, w http.ResponseWriter, reg interface { - FlowCipher() *aead.XChaCha20Poly1305 - config.Provider -}, cookieName string, value any, opts ...CodecOption) error { - cipher := reg.FlowCipher() - cookie, err := Encode(ctx, cipher, value, opts...) - if err != nil { - return err - } - - http.SetCookie(w, &http.Cookie{ - Name: cookieName, - Value: cookie, - HttpOnly: true, - Domain: reg.Config().CookieDomain(ctx), - Secure: reg.Config().CookieSecure(ctx), - SameSite: reg.Config().CookieSameSiteMode(ctx), - }) - - return nil -} - // FromCookie looks up the value stored in the cookie and decodes it. func FromCookie[T any](ctx context.Context, r *http.Request, cipher aead.Cipher, cookieName string, opts ...CodecOption) (*T, error) { cookie, err := r.Cookie(cookieName) diff --git a/persistence/sql/persister_consent.go b/persistence/sql/persister_consent.go index 0b4582c3f4..a9821b8744 100644 --- a/persistence/sql/persister_consent.go +++ b/persistence/sql/persister_consent.go @@ -198,6 +198,10 @@ func (p *Persister) GetFlowByConsentChallenge(ctx context.Context, challenge str if f.RequestedAt.Add(p.config.ConsentRequestMaxAge(ctx)).Before(time.Now()) { return nil, errorsx.WithStack(fosite.ErrRequestUnauthorized.WithHint("The consent request has expired, please try again.")) } + f.Client, err = p.GetConcreteClient(ctx, f.ClientID) + if err != nil { + return nil, err + } return f, nil } @@ -262,6 +266,10 @@ func (p *Persister) GetLoginRequest(ctx context.Context, loginChallenge string) if f.RequestedAt.Add(p.config.ConsentRequestMaxAge(ctx)).Before(time.Now()) { return nil, errorsx.WithStack(fosite.ErrRequestUnauthorized.WithHint("The login request has expired, please try again.")) } + f.Client, err = p.GetConcreteClient(ctx, f.ClientID) + if err != nil { + return nil, err + } lr := f.GetLoginRequest() // Restore the short challenge ID, which was previously sent to the encoded flow, // to make sure that the challenge ID in the returned flow matches the param. @@ -301,6 +309,10 @@ func (p *Persister) VerifyAndInvalidateConsentRequest(ctx context.Context, verif if f.NID != p.NetworkID(ctx) { return nil, errorsx.WithStack(sqlcon.ErrNoRows) } + f.Client, err = p.GetConcreteClient(ctx, f.ClientID) + if err != nil { + return nil, err + } if err = f.InvalidateConsentRequest(); err != nil { return nil, errorsx.WithStack(fosite.ErrInvalidRequest.WithDebug(err.Error())) @@ -347,6 +359,10 @@ func (p *Persister) VerifyAndInvalidateLoginRequest(ctx context.Context, verifie if f.NID != p.NetworkID(ctx) { return nil, errorsx.WithStack(sqlcon.ErrNoRows) } + f.Client, err = p.GetConcreteClient(ctx, f.ClientID) + if err != nil { + return nil, err + } if err := f.InvalidateLoginRequest(); err != nil { return nil, errorsx.WithStack(fosite.ErrInvalidRequest.WithDebug(err.Error()))