diff --git a/data/service/api/v1/alerts.go b/data/service/api/v1/alerts.go index c4de56199b..e8036fae54 100644 --- a/data/service/api/v1/alerts.go +++ b/data/service/api/v1/alerts.go @@ -2,12 +2,11 @@ package v1 import ( "context" - "encoding/json" "errors" "fmt" - "log" "net/http" + "github.com/ant0ine/go-json-rest/rest" "go.mongodb.org/mongo-driver/mongo" "github.com/tidepool-org/platform/alerts" @@ -31,7 +30,7 @@ func DeleteAlert(dCtx service.Context) { details := request.DetailsFromContext(ctx) repo := dCtx.AlertsRepository() - if err := checkAuthentication(details, r.PathParam("userID")); err != nil { + if err := checkAuthentication(details); err != nil { dCtx.RespondWithError(platform.ErrorUnauthorized()) return } @@ -39,15 +38,21 @@ func DeleteAlert(dCtx service.Context) { cfg := &alerts.Config{} if err := request.DecodeRequestBody(r.Request, cfg); err != nil { dCtx.RespondWithError(platform.ErrorJSONMalformed()) + return + } + + if err := checkUserIDConsistency(details, r, cfg); err != nil { + dCtx.RespondWithError(platform.ErrorJSONMalformed()) + return } + userID := userIDWithServiceFallback(details, r.PathParam("userID")) pc := dCtx.PermissionClient() - if err := checkUserAuthorization(ctx, pc, details.UserID(), cfg.FollowedID); err != nil { + if err := checkUserAuthorization(ctx, pc, userID, cfg.FollowedID); err != nil { dCtx.RespondWithError(platform.ErrorUnauthorized()) return } - cfg.UserID = details.UserID() if err := repo.Delete(ctx, cfg); err != nil { dCtx.RespondWithError(platform.ErrorInternalServerFailure()) return @@ -60,19 +65,25 @@ func GetAlert(dCtx service.Context) { details := request.DetailsFromContext(ctx) repo := dCtx.AlertsRepository() - if err := checkAuthentication(details, r.PathParam("userID")); err != nil { + if err := checkAuthentication(details); err != nil { dCtx.RespondWithError(platform.ErrorUnauthorized()) return } + userID := userIDWithServiceFallback(details, r.PathParam("userID")) followedID := r.PathParam("followedID") pc := dCtx.PermissionClient() - if err := checkUserAuthorization(ctx, pc, details.UserID(), followedID); err != nil { + if err := checkUserAuthorization(ctx, pc, userID, followedID); err != nil { dCtx.RespondWithError(platform.ErrorUnauthorized()) return } - cfg := &alerts.Config{UserID: details.UserID(), FollowedID: followedID} + cfg := &alerts.Config{UserID: userID, FollowedID: followedID} + if err := checkUserIDConsistency(details, r, cfg); err != nil { + dCtx.RespondWithError(platform.ErrorJSONMalformed()) + return + } + alert, err := repo.Get(ctx, cfg) if err != nil { if errors.Is(err, mongo.ErrNoDocuments) { @@ -94,36 +105,72 @@ func UpsertAlert(dCtx service.Context) { details := request.DetailsFromContext(ctx) repo := dCtx.AlertsRepository() - if err := checkAuthentication(details, r.PathParam("userID")); err != nil { + if err := checkAuthentication(details); err != nil { dCtx.RespondWithError(platform.ErrorUnauthorized()) return } cfg := &alerts.Config{} - if err := json.NewDecoder(r.Body).Decode(cfg); err != nil { + if err := request.DecodeRequestBody(r.Request, cfg); err != nil { + dCtx.RespondWithError(platform.ErrorJSONMalformed()) + return + } + + if err := checkUserIDConsistency(details, r, cfg); err != nil { dCtx.RespondWithError(platform.ErrorJSONMalformed()) + return } + userID := userIDWithServiceFallback(details, r.PathParam("userID")) pc := dCtx.PermissionClient() - if err := checkUserAuthorization(ctx, pc, details.UserID(), cfg.FollowedID); err != nil { + if err := checkUserAuthorization(ctx, pc, userID, cfg.FollowedID); err != nil { dCtx.RespondWithError(platform.ErrorUnauthorized()) return } - cfg.UserID = details.UserID() if err := repo.Upsert(ctx, cfg); err != nil { dCtx.RespondWithError(platform.ErrorInternalServerFailure()) return } } -var ErrUnauthorized = fmt.Errorf("unauthorized") +var ( + ErrBadRequest = fmt.Errorf("bad request") + ErrUnauthorized = fmt.Errorf("unauthorized") +) + +// checkUserIDConsistency verifies the various userIDs in a request. +// +// There are three possible sources of userIDs: +// 1. the request path +// 2. the alerts.Config specified in the request body +// 3. the authenticating token (if a user token) +// +// For safety reasons, if any of these three values don't agree, return an +// error (bad request). +func checkUserIDConsistency(details request.Details, r *rest.Request, cfg *alerts.Config) error { + if details.IsService() { + // Services won't have a userID in their token, so that check is + // skipped. + if r.PathParam("userID") == cfg.UserID { + return nil + } + return ErrBadRequest + } + + if r.PathParam("userID") == details.UserID() && details.UserID() == cfg.UserID { + return nil + } + + return ErrBadRequest +} -func checkAuthentication(details request.Details, userID string) error { +// checkAuthentication ensures that the request has an authentication token. +func checkAuthentication(details request.Details) error { + if details.Token() == "" { + return ErrUnauthorized + } if details.IsUser() { - if details.UserID() != userID { - log.Printf("warning: URL userID doesn't match token userID, token wins ") - } return nil } if details.IsService() { @@ -146,3 +193,18 @@ func checkUserAuthorization(ctx context.Context, pc permission.Client, userID, f } return fmt.Errorf("user isn't authorized for alerting: %q", userID) } + +// userIDWithServiceFallback returns the user's ID. +// +// If the request is from a user, the userID found in the token will be +// returned. This could be an empty string if the request details are +// malformed. +// +// If the request is from a service, then the service fallback value is used, +// as no userID is passed with the details in the event of a service request. +func userIDWithServiceFallback(details request.Details, serviceFallback string) string { + if details.IsUser() { + return details.UserID() + } + return serviceFallback +} diff --git a/data/service/api/v1/alerts_test.go b/data/service/api/v1/alerts_test.go index 3b65f4e3e0..346d7d6c6e 100644 --- a/data/service/api/v1/alerts_test.go +++ b/data/service/api/v1/alerts_test.go @@ -16,7 +16,7 @@ import ( "github.com/tidepool-org/platform/request" ) -func permsNoAlerting() map[string]map[string]permission.Permissions { +func permsNoFollow() map[string]map[string]permission.Permissions { return map[string]map[string]permission.Permissions{ mocks.TestUserID1: { mocks.TestUserID2: { @@ -28,7 +28,7 @@ func permsNoAlerting() map[string]map[string]permission.Permissions { var _ = Describe("Alerts endpoints", func() { - testAuthentication := func(f func(dataservice.Context)) { + testAuthenticationRequired := func(f func(dataservice.Context)) { t := GinkgoT() body := bytes.NewBuffer(mocks.MustMarshalJSON(t, alerts.Config{ UserID: mocks.TestUserID1, @@ -45,7 +45,7 @@ var _ = Describe("Alerts endpoints", func() { Expect(rec.Code).To(Equal(http.StatusForbidden)) } - testPermissions := func(f func(dataservice.Context)) { + testUserHasFollowPermission := func(f func(dataservice.Context)) { t := GinkgoT() body := bytes.NewBuffer(mocks.MustMarshalJSON(t, alerts.Config{ UserID: mocks.TestUserID1, @@ -53,7 +53,7 @@ var _ = Describe("Alerts endpoints", func() { })) dCtx := mocks.NewContext(t, "", "", body) dCtx.MockAlertsRepository = newMockRepo() - dCtx.MockPermissionClient = mocks.NewPermission(permsNoAlerting(), nil, nil) + dCtx.MockPermissionClient = mocks.NewPermission(permsNoFollow(), nil, nil) f(dCtx) @@ -61,7 +61,7 @@ var _ = Describe("Alerts endpoints", func() { Expect(rec.Code).To(Equal(http.StatusForbidden)) } - testUserID := func(f func(dataservice.Context)) { + testAlertsConfigUserIDMustMatchToken := func(f func(dataservice.Context)) { t := GinkgoT() body := bytes.NewBuffer(mocks.MustMarshalJSON(t, alerts.Config{ UserID: "00000000-dead-4123-beef-000000000000", @@ -69,16 +69,48 @@ var _ = Describe("Alerts endpoints", func() { })) dCtx := mocks.NewContext(t, "", "", body) repo := newMockRepo() - repo.ExpectsOwnerID(mocks.TestUserID2) dCtx.MockAlertsRepository = repo - badDetails := mocks.NewDetails(request.MethodSessionToken, mocks.TestUserID1, "") - dCtx.WithDetails(badDetails) f(dCtx) - Expect(repo.UserID).To(Equal(mocks.TestUserID1)) rec := dCtx.Recorder() - Expect(rec.Code).To(Equal(http.StatusOK)) + Expect(rec.Code).To(Equal(http.StatusBadRequest)) + } + + testTokenUserIDMustMatchPathParam := func(f func(dataservice.Context), details *mocks.Details) { + t := GinkgoT() + dCtx := mocks.NewContext(t, "", "", nil) + if details != nil { + dCtx.WithDetails(details) + } + dCtx.RESTRequest.PathParams["userID"] = "bad" + repo := newMockRepo() + dCtx.MockAlertsRepository = repo + + f(dCtx) + + rec := dCtx.Recorder() + Expect(rec.Code).To(Equal(http.StatusBadRequest)) + } + + testAlertsConfigUserIDMustMatchPathParam := func(f func(dataservice.Context), details *mocks.Details) { + t := GinkgoT() + body := bytes.NewBuffer(mocks.MustMarshalJSON(t, alerts.Config{ + UserID: mocks.TestUserID1, + FollowedID: mocks.TestUserID2, + })) + dCtx := mocks.NewContext(t, "", "", body) + if details != nil { + dCtx.WithDetails(details) + } + dCtx.RESTRequest.PathParams["userID"] = "bad" + repo := newMockRepo() + dCtx.MockAlertsRepository = repo + + f(dCtx) + + rec := dCtx.Recorder() + Expect(rec.Code).To(Equal(http.StatusBadRequest)) } testInvalidJSON := func(f func(dataservice.Context)) { @@ -86,10 +118,7 @@ var _ = Describe("Alerts endpoints", func() { body := bytes.NewBuffer([]byte(`"improper JSON data"`)) dCtx := mocks.NewContext(t, "", "", body) repo := newMockRepo() - repo.ExpectsOwnerID(mocks.TestUserID2) dCtx.MockAlertsRepository = repo - badDetails := mocks.NewDetails(request.MethodSessionToken, mocks.TestUserID1, "") - dCtx.WithDetails(badDetails) f(dCtx) @@ -99,11 +128,21 @@ var _ = Describe("Alerts endpoints", func() { Describe("Delete", func() { It("rejects unauthenticated users", func() { - testAuthentication(DeleteAlert) + testAuthenticationRequired(DeleteAlert) + }) + + Context("when called by a service", func() { + It("requires that the alert.Config's userID matches the userID path param", func() { + testAlertsConfigUserIDMustMatchPathParam(UpsertAlert, mocks.ServiceDetails()) + }) + }) + + It("requires that the alert.Config's userID matches the user's token", func() { + testAlertsConfigUserIDMustMatchToken(DeleteAlert) }) - It("uses the authenticated user's userID", func() { - testUserID(DeleteAlert) + It("requires that the alert.Config's userID matches the userID path param", func() { + testAlertsConfigUserIDMustMatchPathParam(UpsertAlert, nil) }) It("errors on invalid JSON", func() { @@ -111,17 +150,45 @@ var _ = Describe("Alerts endpoints", func() { }) It("rejects users without alerting permissions", func() { - testPermissions(DeleteAlert) + testUserHasFollowPermission(DeleteAlert) + }) + }) + + Describe("Upsert", func() { + It("rejects unauthenticated users", func() { + testAuthenticationRequired(UpsertAlert) + }) + + Context("when called by a service", func() { + It("requires that the alert.Config's userID matches the userID path param", func() { + testAlertsConfigUserIDMustMatchPathParam(UpsertAlert, mocks.ServiceDetails()) + }) + }) + + It("requires that the alert.Config's userID matches the user's token", func() { + testAlertsConfigUserIDMustMatchToken(UpsertAlert) + }) + + It("requires that the alert.Config's userID matches the userID path param", func() { + testAlertsConfigUserIDMustMatchPathParam(UpsertAlert, nil) + }) + + It("errors on invalid JSON", func() { + testInvalidJSON(UpsertAlert) + }) + + It("rejects users without alerting permissions", func() { + testUserHasFollowPermission(UpsertAlert) }) }) Describe("Get", func() { It("rejects unauthenticated users", func() { - testAuthentication(GetAlert) + testAuthenticationRequired(GetAlert) }) - It("uses the authenticated user's userID", func() { - testUserID(GetAlert) + It("requires that the user's token matches the userID path param", func() { + testTokenUserIDMustMatchPathParam(GetAlert, nil) }) It("errors when Config doesn't exist", func() { @@ -142,31 +209,13 @@ var _ = Describe("Alerts endpoints", func() { }) It("rejects users without alerting permissions", func() { - testPermissions(func(dCtx dataservice.Context) { + testUserHasFollowPermission(func(dCtx dataservice.Context) { dCtx.Request().PathParams["followedID"] = mocks.TestUserID2 GetAlert(dCtx) }) }) }) - - Describe("Upsert", func() { - It("rejects unauthenticated users", func() { - testAuthentication(UpsertAlert) - }) - - It("uses the authenticated user's userID", func() { - testUserID(UpsertAlert) - }) - - It("errors on invalid JSON", func() { - testInvalidJSON(UpsertAlert) - }) - - It("rejects users without alerting permissions", func() { - testPermissions(UpsertAlert) - }) - }) }) type mockRepo struct { @@ -182,10 +231,6 @@ func (r *mockRepo) ReturnsError(err error) { r.Error = err } -func (r *mockRepo) ExpectsOwnerID(ownerID string) { - r.UserID = ownerID -} - func (r *mockRepo) Upsert(ctx context.Context, conf *alerts.Config) error { if r.Error != nil { return r.Error diff --git a/data/service/api/v1/mocks/context.go b/data/service/api/v1/mocks/context.go index 723171aee0..aa0bacf229 100644 --- a/data/service/api/v1/mocks/context.go +++ b/data/service/api/v1/mocks/context.go @@ -30,7 +30,7 @@ type Context struct { } func NewContext(t likeT, method, url string, body io.Reader) *Context { - details := defDetails() + details := DefaultDetails() ctx := request.NewContextWithDetails(stdcontext.Background(), details) r, err := http.NewRequestWithContext(ctx, method, url, body) if err != nil { @@ -42,7 +42,7 @@ func NewContext(t likeT, method, url string, body io.Reader) *Context { rr := &rest.Request{ Request: r, - PathParams: map[string]string{}, + PathParams: map[string]string{"userID": TestUserID1}, Env: map[string]interface{}{}, } responder, err := servicecontext.NewResponder(w, rr) @@ -70,10 +70,16 @@ func (c *Context) WithDetails(details *Details) { c.RESTRequest.Request = r.WithContext(ctx) } -func defDetails() *Details { +// DefaultDetails provides details for TestUser #1. +func DefaultDetails() *Details { return NewDetails(request.MethodSessionToken, TestUserID1, TestToken1) } +// ServiceDetails provides details for a service call. +func ServiceDetails() *Details { + return NewDetails(request.MethodServiceSecret, "", TestToken2) +} + func (c *Context) Response() rest.ResponseWriter { return c.ResponseWriter }