From d32570e20b6ffb08285e5d68f41c5e8fb853bae3 Mon Sep 17 00:00:00 2001 From: Alex Saunders Date: Thu, 1 Feb 2024 15:29:21 +0000 Subject: [PATCH] prefer passing org ID --- internal/app/organisation_store.go | 12 ++-- internal/page/donor/register.go | 12 +++- internal/page/fixtures/supporter.go | 17 ++++-- internal/page/paths.go | 4 -- internal/page/supporter/dashboard.go | 2 +- .../page/supporter/enter_organisation_name.go | 16 +++++- internal/page/supporter/login_callback.go | 24 +++++--- internal/page/supporter/mock_Handler_test.go | 26 +++++---- .../supporter/mock_OrganisationStore_test.go | 28 +++++++--- internal/page/supporter/register.go | 56 +------------------ internal/sesh/sesh.go | 7 ++- 11 files changed, 99 insertions(+), 105 deletions(-) diff --git a/internal/app/organisation_store.go b/internal/app/organisation_store.go index 3e6c8c9619..7291f7bdbf 100644 --- a/internal/app/organisation_store.go +++ b/internal/app/organisation_store.go @@ -17,14 +17,14 @@ type organisationStore struct { now func() time.Time } -func (s *organisationStore) Create(ctx context.Context, name string) error { +func (s *organisationStore) Create(ctx context.Context, name string) (*actor.Organisation, error) { data, err := page.SessionDataFromContext(ctx) if err != nil { - return err + return nil, err } if data.SessionID == "" { - return errors.New("organisationStore.Create requires SessionID") + return nil, errors.New("organisationStore.Create requires SessionID") } organisationID := s.uuidString() @@ -38,7 +38,7 @@ func (s *organisationStore) Create(ctx context.Context, name string) error { } if err := s.dynamoClient.Create(ctx, organisation); err != nil { - return fmt.Errorf("error creating organisation: %w", err) + return nil, fmt.Errorf("error creating organisation: %w", err) } member := &actor.Member{ @@ -48,10 +48,10 @@ func (s *organisationStore) Create(ctx context.Context, name string) error { } if err := s.dynamoClient.Create(ctx, member); err != nil { - return fmt.Errorf("error creating organisation member: %w", err) + return nil, fmt.Errorf("error creating organisation member: %w", err) } - return nil + return organisation, nil } func (s *organisationStore) Get(ctx context.Context) (*actor.Organisation, error) { diff --git a/internal/page/donor/register.go b/internal/page/donor/register.go index 34385f463d..2aab12b392 100644 --- a/internal/page/donor/register.go +++ b/internal/page/donor/register.go @@ -436,15 +436,21 @@ func makeLpaHandle(mux *http.ServeMux, store sesh.Store, defaultOptions page.Han appData.ActorType = actor.TypeDonor appData.AppPublicURL = appPublicURL - donorSession, err := sesh.Login(store, r) + loginSession, err := sesh.Login(store, r) if err != nil { http.Redirect(w, r, page.Paths.Start.Format(), http.StatusFound) return } - appData.SessionID = donorSession.SessionID() - sessionData, err := page.SessionDataFromContext(ctx) + + appData.SessionID = loginSession.SessionID() + + if loginSession.OrganisationID != "" { + appData.IsSupporter = true + sessionData.OrganisationID = loginSession.OrganisationID + } + if err == nil { sessionData.SessionID = appData.SessionID ctx = page.ContextWithSessionData(ctx, sessionData) diff --git a/internal/page/fixtures/supporter.go b/internal/page/fixtures/supporter.go index f0d0e5da47..7d4c3c6057 100644 --- a/internal/page/fixtures/supporter.go +++ b/internal/page/fixtures/supporter.go @@ -5,13 +5,14 @@ import ( "encoding/base64" "net/http" + "github.com/ministryofjustice/opg-modernising-lpa/internal/actor" "github.com/ministryofjustice/opg-modernising-lpa/internal/page" "github.com/ministryofjustice/opg-modernising-lpa/internal/random" "github.com/ministryofjustice/opg-modernising-lpa/internal/sesh" ) type OrganisationStore interface { - Create(context.Context, string) error + Create(context.Context, string) (*actor.Organisation, error) } func Supporter(sessionStore sesh.Store, organisationStore OrganisationStore) page.Handler { @@ -25,14 +26,20 @@ func Supporter(sessionStore sesh.Store, organisationStore OrganisationStore) pag ctx = page.ContextWithSessionData(r.Context(), &page.SessionData{SessionID: supporterSessionID}) ) - if err := sesh.SetLoginSession(sessionStore, r, w, &sesh.LoginSession{Sub: supporterSub, Email: testEmail}); err != nil { - return err - } + loginSession := &sesh.LoginSession{Sub: supporterSub, Email: testEmail} if organisation == "1" { - if err := organisationStore.Create(ctx, random.String(12)); err != nil { + org, err := organisationStore.Create(ctx, random.String(12)) + + if err != nil { return err } + + loginSession.OrganisationID = org.ID + } + + if err := sesh.SetLoginSession(sessionStore, r, w, loginSession); err != nil { + return err } if redirect != page.Paths.Supporter.EnterOrganisationName.Format() { diff --git a/internal/page/paths.go b/internal/page/paths.go index cc89e7c1b9..1be8fb6e7d 100644 --- a/internal/page/paths.go +++ b/internal/page/paths.go @@ -43,10 +43,6 @@ func (p LpaPath) Redirect(w http.ResponseWriter, r *http.Request, appData AppDat rurl = fromURL } - if appData.IsSupporter { - rurl = "/supporter/" + rurl - } - if CanGoTo(donor, rurl) { http.Redirect(w, r, appData.Lang.URL(rurl), http.StatusFound) } else { diff --git a/internal/page/supporter/dashboard.go b/internal/page/supporter/dashboard.go index 156999ea45..4b9e914867 100644 --- a/internal/page/supporter/dashboard.go +++ b/internal/page/supporter/dashboard.go @@ -22,7 +22,7 @@ func Dashboard(tmpl template.Template, organisationStore OrganisationStore) Hand return err } - return page.Paths.Supporter.DonorDetails.RedirectToLpa(w, r.WithContext(r.Context()), appData, donorProvided) + return page.Paths.YourDetails.Redirect(w, r.WithContext(r.Context()), appData, donorProvided) } return tmpl(w, DashboardData{App: appData}) diff --git a/internal/page/supporter/enter_organisation_name.go b/internal/page/supporter/enter_organisation_name.go index 0931ee56b6..1544ed0b8e 100644 --- a/internal/page/supporter/enter_organisation_name.go +++ b/internal/page/supporter/enter_organisation_name.go @@ -5,6 +5,7 @@ import ( "github.com/ministryofjustice/opg-go-common/template" "github.com/ministryofjustice/opg-modernising-lpa/internal/page" + "github.com/ministryofjustice/opg-modernising-lpa/internal/sesh" "github.com/ministryofjustice/opg-modernising-lpa/internal/validation" ) @@ -14,7 +15,7 @@ type enterOrganisationNameData struct { Form *enterOrganisationNameForm } -func EnterOrganisationName(tmpl template.Template, organisationStore OrganisationStore) page.Handler { +func EnterOrganisationName(tmpl template.Template, organisationStore OrganisationStore, sessionStore sesh.Store) page.Handler { return func(appData page.AppData, w http.ResponseWriter, r *http.Request) error { data := &enterOrganisationNameData{ App: appData, @@ -26,7 +27,18 @@ func EnterOrganisationName(tmpl template.Template, organisationStore Organisatio data.Errors = data.Form.Validate() if !data.Errors.Any() { - if err := organisationStore.Create(r.Context(), data.Form.Name); err != nil { + organisation, err := organisationStore.Create(r.Context(), data.Form.Name) + if err != nil { + return err + } + + loginSession, err := sesh.Login(sessionStore, r) + if err != nil { + return page.Paths.Supporter.Start.Redirect(w, r, appData) + } + + loginSession.OrganisationID = organisation.ID + if err := sesh.SetLoginSession(sessionStore, r, w, loginSession); err != nil { return err } diff --git a/internal/page/supporter/login_callback.go b/internal/page/supporter/login_callback.go index 0ec726fda7..2480a97c57 100644 --- a/internal/page/supporter/login_callback.go +++ b/internal/page/supporter/login_callback.go @@ -35,22 +35,30 @@ func LoginCallback(oneLoginClient LoginCallbackOneLoginClient, sessionStore sesh session := &sesh.LoginSession{ IDToken: idToken, - Sub: userInfo.Sub, + Sub: "supporter-" + userInfo.Sub, Email: userInfo.Email, } - if err := sesh.SetLoginSession(sessionStore, r, w, session); err != nil { - return err - } - ctx := page.ContextWithSessionData(r.Context(), &page.SessionData{SessionID: session.SessionID()}) - _, err = organisationStore.Get(ctx) + organisation, err := organisationStore.Get(ctx) if err == nil { + session.OrganisationID = organisation.ID + if err := sesh.SetLoginSession(sessionStore, r, w, session); err != nil { + return err + } + return page.Paths.Supporter.Dashboard.Redirect(w, r, appData) } - if !errors.Is(err, dynamo.NotFoundError{}) { - return err + + if errors.Is(err, dynamo.NotFoundError{}) { + if err := sesh.SetLoginSession(sessionStore, r, w, &sesh.LoginSession{ + IDToken: idToken, + Sub: "supporter-" + userInfo.Sub, + Email: userInfo.Email, + }); err != nil { + return err + } } return page.Paths.Supporter.EnterOrganisationName.Redirect(w, r, appData) diff --git a/internal/page/supporter/mock_Handler_test.go b/internal/page/supporter/mock_Handler_test.go index cfbc60079c..7c142bd45c 100644 --- a/internal/page/supporter/mock_Handler_test.go +++ b/internal/page/supporter/mock_Handler_test.go @@ -5,8 +5,11 @@ package supporter import ( http "net/http" - page "github.com/ministryofjustice/opg-modernising-lpa/internal/page" + actor "github.com/ministryofjustice/opg-modernising-lpa/internal/actor" + mock "github.com/stretchr/testify/mock" + + page "github.com/ministryofjustice/opg-modernising-lpa/internal/page" ) // mockHandler is an autogenerated mock type for the Handler type @@ -22,17 +25,17 @@ func (_m *mockHandler) EXPECT() *mockHandler_Expecter { return &mockHandler_Expecter{mock: &_m.Mock} } -// Execute provides a mock function with given fields: data, w, r -func (_m *mockHandler) Execute(data page.AppData, w http.ResponseWriter, r *http.Request) error { - ret := _m.Called(data, w, r) +// Execute provides a mock function with given fields: data, w, r, organisation +func (_m *mockHandler) Execute(data page.AppData, w http.ResponseWriter, r *http.Request, organisation *actor.Organisation) error { + ret := _m.Called(data, w, r, organisation) if len(ret) == 0 { panic("no return value specified for Execute") } var r0 error - if rf, ok := ret.Get(0).(func(page.AppData, http.ResponseWriter, *http.Request) error); ok { - r0 = rf(data, w, r) + if rf, ok := ret.Get(0).(func(page.AppData, http.ResponseWriter, *http.Request, *actor.Organisation) error); ok { + r0 = rf(data, w, r, organisation) } else { r0 = ret.Error(0) } @@ -49,13 +52,14 @@ type mockHandler_Execute_Call struct { // - data page.AppData // - w http.ResponseWriter // - r *http.Request -func (_e *mockHandler_Expecter) Execute(data interface{}, w interface{}, r interface{}) *mockHandler_Execute_Call { - return &mockHandler_Execute_Call{Call: _e.mock.On("Execute", data, w, r)} +// - organisation *actor.Organisation +func (_e *mockHandler_Expecter) Execute(data interface{}, w interface{}, r interface{}, organisation interface{}) *mockHandler_Execute_Call { + return &mockHandler_Execute_Call{Call: _e.mock.On("Execute", data, w, r, organisation)} } -func (_c *mockHandler_Execute_Call) Run(run func(data page.AppData, w http.ResponseWriter, r *http.Request)) *mockHandler_Execute_Call { +func (_c *mockHandler_Execute_Call) Run(run func(data page.AppData, w http.ResponseWriter, r *http.Request, organisation *actor.Organisation)) *mockHandler_Execute_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(page.AppData), args[1].(http.ResponseWriter), args[2].(*http.Request)) + run(args[0].(page.AppData), args[1].(http.ResponseWriter), args[2].(*http.Request), args[3].(*actor.Organisation)) }) return _c } @@ -65,7 +69,7 @@ func (_c *mockHandler_Execute_Call) Return(_a0 error) *mockHandler_Execute_Call return _c } -func (_c *mockHandler_Execute_Call) RunAndReturn(run func(page.AppData, http.ResponseWriter, *http.Request) error) *mockHandler_Execute_Call { +func (_c *mockHandler_Execute_Call) RunAndReturn(run func(page.AppData, http.ResponseWriter, *http.Request, *actor.Organisation) error) *mockHandler_Execute_Call { _c.Call.Return(run) return _c } diff --git a/internal/page/supporter/mock_OrganisationStore_test.go b/internal/page/supporter/mock_OrganisationStore_test.go index 055bb0c7f8..2490f0a4c7 100644 --- a/internal/page/supporter/mock_OrganisationStore_test.go +++ b/internal/page/supporter/mock_OrganisationStore_test.go @@ -24,21 +24,33 @@ func (_m *mockOrganisationStore) EXPECT() *mockOrganisationStore_Expecter { } // Create provides a mock function with given fields: ctx, name -func (_m *mockOrganisationStore) Create(ctx context.Context, name string) error { +func (_m *mockOrganisationStore) Create(ctx context.Context, name string) (*actor.Organisation, error) { ret := _m.Called(ctx, name) if len(ret) == 0 { panic("no return value specified for Create") } - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + var r0 *actor.Organisation + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (*actor.Organisation, error)); ok { + return rf(ctx, name) + } + if rf, ok := ret.Get(0).(func(context.Context, string) *actor.Organisation); ok { r0 = rf(ctx, name) } else { - r0 = ret.Error(0) + if ret.Get(0) != nil { + r0 = ret.Get(0).(*actor.Organisation) + } } - return r0 + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, name) + } else { + r1 = ret.Error(1) + } + + return r0, r1 } // mockOrganisationStore_Create_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Create' @@ -60,12 +72,12 @@ func (_c *mockOrganisationStore_Create_Call) Run(run func(ctx context.Context, n return _c } -func (_c *mockOrganisationStore_Create_Call) Return(_a0 error) *mockOrganisationStore_Create_Call { - _c.Call.Return(_a0) +func (_c *mockOrganisationStore_Create_Call) Return(_a0 *actor.Organisation, _a1 error) *mockOrganisationStore_Create_Call { + _c.Call.Return(_a0, _a1) return _c } -func (_c *mockOrganisationStore_Create_Call) RunAndReturn(run func(context.Context, string) error) *mockOrganisationStore_Create_Call { +func (_c *mockOrganisationStore_Create_Call) RunAndReturn(run func(context.Context, string) (*actor.Organisation, error)) *mockOrganisationStore_Create_Call { _c.Call.Return(run) return _c } diff --git a/internal/page/supporter/register.go b/internal/page/supporter/register.go index f2fa4f9de1..d43ad460e7 100644 --- a/internal/page/supporter/register.go +++ b/internal/page/supporter/register.go @@ -11,13 +11,12 @@ import ( "github.com/ministryofjustice/opg-modernising-lpa/internal/notify" "github.com/ministryofjustice/opg-modernising-lpa/internal/onelogin" "github.com/ministryofjustice/opg-modernising-lpa/internal/page" - "github.com/ministryofjustice/opg-modernising-lpa/internal/page/donor" "github.com/ministryofjustice/opg-modernising-lpa/internal/random" "github.com/ministryofjustice/opg-modernising-lpa/internal/sesh" ) type OrganisationStore interface { - Create(ctx context.Context, name string) error + Create(ctx context.Context, name string) (*actor.Organisation, error) CreateMemberInvite(ctx context.Context, organisation *actor.Organisation, email, code string) error Get(ctx context.Context) (*actor.Organisation, error) CreateLPA(ctx context.Context, organisationID string) (*actor.DonorProvidedDetails, error) @@ -72,17 +71,13 @@ func Register( handleRoot(supporterPaths.LoginCallback, page.None, LoginCallback(oneLoginClient, sessionStore, organisationStore)) handleRoot(supporterPaths.EnterOrganisationName, page.RequireSession, - EnterOrganisationName(supporterTmpls.Get("enter_organisation_name.gohtml"), organisationStore)) + EnterOrganisationName(supporterTmpls.Get("enter_organisation_name.gohtml"), organisationStore, sessionStore)) supporterMux := http.NewServeMux() rootMux.Handle("/supporter/", http.StripPrefix("/supporter", supporterMux)) - supporterLpaMux := http.NewServeMux() - rootMux.Handle("/supporter/lpa/", page.RouteToPrefix("/supporter/lpa/", supporterLpaMux, notFoundHandler)) - handleSupporter := makeHandle(supporterMux, sessionStore, errorHandler) handleWithSupporter := makeSupporterHandle(supporterMux, sessionStore, errorHandler, organisationStore) - handleWithSupporterAndDonor := makeSupporterDonorHandle(supporterLpaMux, sessionStore, errorHandler, organisationStore, donorStore) handleSupporter(page.Paths.Root, page.None, notFoundHandler) @@ -94,9 +89,6 @@ func Register( InviteMember(supporterTmpls.Get("invite_member.gohtml"), organisationStore, notifyClient, random.String)) handleWithSupporter(supporterPaths.InviteMemberConfirmation, Guidance(supporterTmpls.Get("invite_member_confirmation.gohtml"))) - - handleWithSupporterAndDonor(supporterPaths.DonorDetails, - donor.YourDetails(donorTmpls.Get("your_details.gohtml"), donorStore, sessionStore)) } func makeHandle(mux *http.ServeMux, store sesh.Store, errorHandler page.ErrorHandler) func(page.Path, page.HandleOpt, page.Handler) { @@ -168,47 +160,3 @@ func makeSupporterHandle(mux *http.ServeMux, store sesh.Store, errorHandler page }) } } - -func makeSupporterDonorHandle(mux *http.ServeMux, store sesh.Store, errorHandler page.ErrorHandler, organisationStore OrganisationStore, donorStore DonorStore) func(page.SupporterPath, donor.Handler) { - return func(path page.SupporterPath, h donor.Handler) { - mux.HandleFunc(path.String(), func(w http.ResponseWriter, r *http.Request) { - loginSession, err := sesh.Login(store, r) - if err != nil { - http.Redirect(w, r, page.Paths.Supporter.Start.Format(), http.StatusFound) - return - } - - ctx := r.Context() - - sessionData, err := page.SessionDataFromContext(ctx) - if err != nil { - errorHandler(w, r, err) - } - - sessionData.SessionID = loginSession.SessionID() - - appData := page.AppDataFromContext(ctx) - appData.IsSupporter = true - appData.SessionID = loginSession.SessionID() - appData.LpaID = sessionData.LpaID - - member, err := organisationStore.GetMember(page.ContextWithSessionData(ctx, sessionData)) - if err != nil { - errorHandler(w, r, err) - } - - appData.OrganisationID = member.OrganisationID() - sessionData.OrganisationID = member.OrganisationID() - - ctx = page.ContextWithAppData(page.ContextWithSessionData(ctx, sessionData), appData) - donorProvided, err := donorStore.Get(ctx) - if err != nil { - errorHandler(w, r, err) - } - - if err := h(appData, w, r.WithContext(ctx), donorProvided); err != nil { - errorHandler(w, r, err) - } - }) - } -} diff --git a/internal/sesh/sesh.go b/internal/sesh/sesh.go index 6d551a12f6..7c823f039c 100644 --- a/internal/sesh/sesh.go +++ b/internal/sesh/sesh.go @@ -107,9 +107,10 @@ func SetOneLogin(store sessions.Store, r *http.Request, w http.ResponseWriter, o } type LoginSession struct { - IDToken string - Sub string - Email string + IDToken string + Sub string + Email string + OrganisationID string } func (s LoginSession) SessionID() string {