diff --git a/selfservice/flow/settings/handler.go b/selfservice/flow/settings/handler.go index 6232f47b7a9b..bbf3a9010ee6 100644 --- a/selfservice/flow/settings/handler.go +++ b/selfservice/flow/settings/handler.go @@ -214,7 +214,7 @@ type createNativeSettingsFlow struct { // default: errorGeneric func (h *Handler) createNativeSettingsFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { ctx := r.Context() - s, err := h.d.SessionManager().FetchFromRequest(ctx, r) + s, err := h.d.SessionManager().FetchFromRequestContext(ctx, r) if err != nil { h.d.Writer().WriteError(w, r, err) return @@ -298,7 +298,7 @@ type createBrowserSettingsFlow struct { // default: errorGeneric func (h *Handler) createBrowserSettingsFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { ctx := r.Context() - s, err := h.d.SessionManager().FetchFromRequest(ctx, r) + s, err := h.d.SessionManager().FetchFromRequestContext(ctx, r) if err != nil { h.d.SelfServiceErrorManager().Forward(ctx, w, r, err) return @@ -404,7 +404,7 @@ func (h *Handler) getSettingsFlow(w http.ResponseWriter, r *http.Request, _ http return } - sess, err := h.d.SessionManager().FetchFromRequest(ctx, r) + sess, err := h.d.SessionManager().FetchFromRequestContext(ctx, r) if err != nil { h.d.Writer().WriteError(w, r, err) return @@ -574,7 +574,7 @@ func (h *Handler) updateSettingsFlow(w http.ResponseWriter, r *http.Request, ps return } - ss, err := h.d.SessionManager().FetchFromRequest(ctx, r) + ss, err := h.d.SessionManager().FetchFromRequestContext(ctx, r) if err != nil { h.d.SettingsFlowErrorHandler().WriteFlowError(w, r, node.DefaultGroup, f, nil, err) return diff --git a/session/handler.go b/session/handler.go index 8b8cd84e3bd8..ad10e389af99 100644 --- a/session/handler.go +++ b/session/handler.go @@ -4,6 +4,7 @@ package session import ( + "context" "fmt" "net/http" "strconv" @@ -837,9 +838,17 @@ func (h *Handler) listMySessions(w http.ResponseWriter, r *http.Request, _ httpr h.r.Writer().Write(w, r, sess) } +type sessionInContext int + +const ( + sessionInContextKey sessionInContext = iota +) + func (h *Handler) IsAuthenticated(wrap httprouter.Handle, onUnauthenticated httprouter.Handle) httprouter.Handle { return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - if _, err := h.r.SessionManager().FetchFromRequest(r.Context(), r); err != nil { + ctx := r.Context() + sess, err := h.r.SessionManager().FetchFromRequest(ctx, r) + if err != nil { if onUnauthenticated != nil { onUnauthenticated(w, r, ps) return @@ -849,7 +858,7 @@ func (h *Handler) IsAuthenticated(wrap httprouter.Handle, onUnauthenticated http return } - wrap(w, r, ps) + wrap(w, r.WithContext(context.WithValue(ctx, sessionInContextKey, sess)), ps) } } diff --git a/session/manager.go b/session/manager.go index e44d91f9111f..e6409c21dfd9 100644 --- a/session/manager.go +++ b/session/manager.go @@ -133,6 +133,9 @@ type Manager interface { // FetchFromRequest creates an HTTP session using cookies. FetchFromRequest(context.Context, *http.Request) (*Session, error) + // FetchFromRequestContext returns the session from the context or if that is unset, falls back to FetchFromRequest. + FetchFromRequestContext(context.Context, *http.Request) (*Session, error) + // PurgeFromRequest removes an HTTP session. PurgeFromRequest(context.Context, http.ResponseWriter, *http.Request) error diff --git a/session/manager_http.go b/session/manager_http.go index b56a16be88ee..2b3ffc6aa1e4 100644 --- a/session/manager_http.go +++ b/session/manager_http.go @@ -227,6 +227,17 @@ func (s *ManagerHTTP) extractToken(r *http.Request) string { return token } +func (s *ManagerHTTP) FetchFromRequestContext(ctx context.Context, r *http.Request) (_ *Session, err error) { + ctx, span := s.r.Tracer(ctx).Tracer().Start(ctx, "sessions.ManagerHTTP.FetchFromRequestContext") + otelx.End(span, &err) + + if sess, ok := ctx.Value(sessionInContextKey).(*Session); ok { + return sess, nil + } + + return s.FetchFromRequest(ctx, r) +} + func (s *ManagerHTTP) FetchFromRequest(ctx context.Context, r *http.Request) (_ *Session, err error) { ctx, span := s.r.Tracer(ctx).Tracer().Start(ctx, "sessions.ManagerHTTP.FetchFromRequest") defer func() { diff --git a/session/manager_http_test.go b/session/manager_http_test.go index a2ca87893075..0bbfdd625461 100644 --- a/session/manager_http_test.go +++ b/session/manager_http_test.go @@ -244,6 +244,16 @@ func TestManagerHTTP(t *testing.T) { reg.Writer().Write(w, r, sess) }) + rp.GET("/session/get-middleware", reg.SessionHandler().IsAuthenticated(func(w http.ResponseWriter, r *http.Request, p httprouter.Params) { + sess, err := reg.SessionManager().FetchFromRequestContext(r.Context(), r) + if err != nil { + t.Logf("Got error on lookup: %s %T", err, errors.Unwrap(err)) + reg.Writer().WriteError(w, r, err) + return + } + reg.Writer().Write(w, r, sess) + }, session.RedirectOnUnauthenticated("https://failed.com"))) + pts := httptest.NewServer(x.NewTestCSRFHandler(rp, reg)) t.Cleanup(pts.Close) conf.MustSet(ctx, config.ViperKeyPublicBaseURL, pts.URL) @@ -263,6 +273,10 @@ func TestManagerHTTP(t *testing.T) { res, err := c.Get(pts.URL + "/session/get") require.NoError(t, err) assert.EqualValues(t, http.StatusOK, res.StatusCode) + + res, err = c.Get(pts.URL + "/session/get-middleware") + require.NoError(t, err) + assert.EqualValues(t, http.StatusOK, res.StatusCode) }) t.Run("case=key rotation", func(t *testing.T) {