From 0c6372a437711d4a932c215eaaee37751200f715 Mon Sep 17 00:00:00 2001 From: Tom Fenech Date: Fri, 14 Jun 2024 06:31:14 +0200 Subject: [PATCH] refactor: use functional options pattern to reduce passing nil --- selfservice/flow/login/handler.go | 2 +- selfservice/flow/login/hook.go | 17 +++++++++++++++-- selfservice/flow/login/hook_test.go | 2 +- selfservice/strategy/oidc/strategy_login.go | 2 +- 4 files changed, 18 insertions(+), 5 deletions(-) diff --git a/selfservice/flow/login/handler.go b/selfservice/flow/login/handler.go index cf97ff7f9e44..63c38731ef19 100644 --- a/selfservice/flow/login/handler.go +++ b/selfservice/flow/login/handler.go @@ -853,7 +853,7 @@ continueLogin: return } - if err := h.d.LoginHookExecutor().PostLoginHook(w, r, group, f, i, sess, nil, ""); err != nil { + if err := h.d.LoginHookExecutor().PostLoginHook(w, r, group, f, i, sess, ""); err != nil { if errors.Is(err, ErrAddressNotVerified) { h.d.LoginFlowErrorHandler().WriteFlowError(w, r, f, node.DefaultGroup, errors.WithStack(schema.NewAddressNotVerifiedError())) return diff --git a/selfservice/flow/login/hook.go b/selfservice/flow/login/hook.go index 16f40e96f34f..73a484c22e50 100644 --- a/selfservice/flow/login/hook.go +++ b/selfservice/flow/login/hook.go @@ -63,6 +63,7 @@ type ( } HookExecutor struct { d executorDependencies + c *claims.Claims } HookExecutorProvider interface { LoginHookExecutor() *HookExecutor @@ -119,6 +120,14 @@ func (e *HookExecutor) handleLoginError(_ http.ResponseWriter, r *http.Request, return flowError } +type PostLoginHookOpt func(*HookExecutor) + +func WithClaims(c *claims.Claims) PostLoginHookOpt { + return func(h *HookExecutor) { + h.c = c + } +} + func (e *HookExecutor) PostLoginHook( w http.ResponseWriter, r *http.Request, @@ -126,8 +135,8 @@ func (e *HookExecutor) PostLoginHook( f *Flow, i *identity.Identity, s *session.Session, - c *claims.Claims, provider string, + opts ...PostLoginHookOpt, ) (err error) { ctx := r.Context() ctx, span := e.d.Tracer(ctx).Tracer().Start(ctx, "HookExecutor.PostLoginHook") @@ -168,13 +177,17 @@ func (e *HookExecutor) PostLoginHook( classified := s s = s.Declassified() + for _, o := range opts { + o(e) + } + e.d.Logger(). WithRequest(r). WithField("identity_id", i.ID). WithField("flow_method", f.Active). Debug("Running ExecuteLoginPostHook.") for k, executor := range e.d.PostLoginHooks(r.Context(), f.Active) { - if err := executor.ExecuteLoginPostHook(w, r, g, f, s, c); err != nil { + if err := executor.ExecuteLoginPostHook(w, r, g, f, s, e.c); err != nil { if errors.Is(err, ErrHookAbortFlow) { e.d.Logger(). WithRequest(r). diff --git a/selfservice/flow/login/hook_test.go b/selfservice/flow/login/hook_test.go index 351f7bd217fe..7d7f8e158174 100644 --- a/selfservice/flow/login/hook_test.go +++ b/selfservice/flow/login/hook_test.go @@ -72,7 +72,7 @@ func TestLoginExecutor(t *testing.T) { } testhelpers.SelfServiceHookLoginErrorHandler(t, w, r, - reg.LoginHookExecutor().PostLoginHook(w, r, strategy.ToUiNodeGroup(), loginFlow, useIdentity, sess, nil, "")) + reg.LoginHookExecutor().PostLoginHook(w, r, strategy.ToUiNodeGroup(), loginFlow, useIdentity, sess, "")) }) ts := httptest.NewServer(router) diff --git a/selfservice/strategy/oidc/strategy_login.go b/selfservice/strategy/oidc/strategy_login.go index 962f17c5fe8d..88f0c7727537 100644 --- a/selfservice/strategy/oidc/strategy_login.go +++ b/selfservice/strategy/oidc/strategy_login.go @@ -178,7 +178,7 @@ func (s *Strategy) processLogin(w http.ResponseWriter, r *http.Request, loginFlo httprouter.ParamsFromContext(r.Context()).ByName("organization")) for _, c := range oidcCredentials.Providers { if c.Subject == claims.Subject && c.Provider == provider.Config().ID { - if err = s.d.LoginHookExecutor().PostLoginHook(w, r, node.OpenIDConnectGroup, loginFlow, i, sess, claims, provider.Config().ID); err != nil { + if err = s.d.LoginHookExecutor().PostLoginHook(w, r, node.OpenIDConnectGroup, loginFlow, i, sess, provider.Config().ID, login.WithClaims(claims)); err != nil { return nil, s.handleError(w, r, loginFlow, provider.Config().ID, nil, err) } return nil, nil