Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Payload validation refactoring for processes, roles, routes, and service bindings #2633

Merged
merged 5 commits into from
Jun 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions api/handlers/app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ var _ = Describe("App", func() {
BeforeEach(func() {
payload = &payloads.AppCreate{
Name: appName,
Relationships: payloads.AppRelationships{
Space: payloads.Relationship{
Relationships: &payloads.AppRelationships{
Space: &payloads.Relationship{
Data: &payloads.RelationshipData{
GUID: spaceGUID,
},
Expand Down
16 changes: 6 additions & 10 deletions api/handlers/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,20 +44,20 @@ type Process struct {
serverURL url.URL
processRepo CFProcessRepository
processStats ProcessStats
decoderValidator *DecoderValidator
requestValidator RequestValidator
}

func NewProcess(
serverURL url.URL,
processRepo CFProcessRepository,
processStatsFetcher ProcessStats,
decoderValidator *DecoderValidator,
requestValidator RequestValidator,
) *Process {
return &Process{
serverURL: serverURL,
processRepo: processRepo,
processStats: processStatsFetcher,
decoderValidator: decoderValidator,
requestValidator: requestValidator,
}
}

Expand Down Expand Up @@ -110,7 +110,7 @@ func (h *Process) scale(r *http.Request) (*routing.Response, error) {
processGUID := routing.URLParam(r, "guid")

var payload payloads.ProcessScale
if err := h.decoderValidator.DecodeAndValidateJSONPayload(r, &payload); err != nil {
if err := h.requestValidator.DecodeAndValidateJSONPayload(r, &payload); err != nil {
return nil, apierrors.LogAndReturn(logger, err, "failed to decode payload")
}

Expand Down Expand Up @@ -149,12 +149,8 @@ func (h *Process) list(r *http.Request) (*routing.Response, error) { //nolint:du
authInfo, _ := authorization.InfoFromContext(r.Context())
logger := logr.FromContextOrDiscard(r.Context()).WithName("handlers.process.list")

if err := r.ParseForm(); err != nil {
return nil, apierrors.LogAndReturn(logger, err, "Unable to parse request query parameters")
}

processListFilter := new(payloads.ProcessList)
err := payloads.Decode(processListFilter, r.Form)
err := h.requestValidator.DecodeAndValidateURLValues(r, processListFilter)
if err != nil {
return nil, apierrors.LogAndReturn(logger, err, "Unable to decode request query parameters")
}
Expand All @@ -174,7 +170,7 @@ func (h *Process) update(r *http.Request) (*routing.Response, error) {
processGUID := routing.URLParam(r, "guid")

var payload payloads.ProcessPatch
if err := h.decoderValidator.DecodeAndValidateJSONPayload(r, &payload); err != nil {
if err := h.requestValidator.DecodeAndValidateJSONPayload(r, &payload); err != nil {
return nil, apierrors.LogAndReturn(logger, err, "failed to decode json payload")
}

Expand Down
115 changes: 61 additions & 54 deletions api/handlers/process_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@ package handlers_test

import (
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"

"code.cloudfoundry.org/korifi/api/actions"
apierrors "code.cloudfoundry.org/korifi/api/errors"
. "code.cloudfoundry.org/korifi/api/handlers"
"code.cloudfoundry.org/korifi/api/handlers/fake"
"code.cloudfoundry.org/korifi/api/payloads"
"code.cloudfoundry.org/korifi/api/repositories"
. "code.cloudfoundry.org/korifi/tests/matchers"
"code.cloudfoundry.org/korifi/tools"
Expand All @@ -20,21 +21,21 @@ import (

var _ = Describe("Process", func() {
var (
processRepo *fake.CFProcessRepository
processStats *fake.ProcessStats
processRepo *fake.CFProcessRepository
processStats *fake.ProcessStats
requestValidator *fake.RequestValidator
)

BeforeEach(func() {
processRepo = new(fake.CFProcessRepository)
processStats = new(fake.ProcessStats)
decoderValidator, err := NewDefaultDecoderValidator()
Expect(err).NotTo(HaveOccurred())
requestValidator = new(fake.RequestValidator)

apiHandler := NewProcess(
*serverURL,
processRepo,
processStats,
decoderValidator,
requestValidator,
)
routerBuilder.LoadRoutes(apiHandler)
})
Expand Down Expand Up @@ -151,6 +152,12 @@ var _ = Describe("Process", func() {
"memory_in_mb": 512,
"disk_in_mb": 256
}`

requestValidator.DecodeAndValidateJSONPayloadStub = decodeAndValidateJSONPayloadStub(&payloads.ProcessScale{
Instances: tools.PtrTo(3),
MemoryMB: tools.PtrTo[int64](512),
DiskMB: tools.PtrTo[int64](256),
})
})

JustBeforeEach(func() {
Expand All @@ -160,6 +167,12 @@ var _ = Describe("Process", func() {
})

It("scales the process", func() {
Expect(requestValidator.DecodeAndValidateJSONPayloadCallCount()).To(Equal(1))
req, _ := requestValidator.DecodeAndValidateJSONPayloadArgsForCall(0)
reqBytes, err := io.ReadAll(req.Body)
Expect(err).NotTo(HaveOccurred())
Expect(string(reqBytes)).To(Equal(requestBody))

Expect(processRepo.GetProcessCallCount()).To(Equal(1))
_, actualAuthInfo, actualProcessGUID := processRepo.GetProcessArgsForCall(0)
Expect(actualAuthInfo).To(Equal(authInfo))
Expand Down Expand Up @@ -190,11 +203,11 @@ var _ = Describe("Process", func() {

When("the request JSON is invalid", func() {
BeforeEach(func() {
requestBody = `}`
requestValidator.DecodeAndValidateJSONPayloadReturns(errors.New("boom"))
})

It("has the expected error response body", func() {
expectBadRequestError()
It("returns an error", func() {
expectUnknownError()
})
})

Expand Down Expand Up @@ -227,24 +240,6 @@ var _ = Describe("Process", func() {
expectUnknownError()
})
})

When("validating scale parameters", func() {
DescribeTable("returns a validation decision",
func(requestBody string, status int) {
tableTestRecorder := httptest.NewRecorder()
req, err := http.NewRequestWithContext(ctx, "POST", "/v3/processes/process-guid/actions/scale", strings.NewReader(requestBody))
Expect(err).NotTo(HaveOccurred())
routerBuilder.Build().ServeHTTP(tableTestRecorder, req)
Expect(tableTestRecorder.Code).To(Equal(status))
},
Entry("instances is negative", `{"instances":-1}`, http.StatusUnprocessableEntity),
Entry("memory is not a positive integer", `{"memory_in_mb":0}`, http.StatusUnprocessableEntity),
Entry("disk is not a positive integer", `{"disk_in_mb":0}`, http.StatusUnprocessableEntity),
Entry("instances is zero", `{"instances":0}`, http.StatusOK),
Entry("memory is a positive integer", `{"memory_in_mb":1024}`, http.StatusOK),
Entry("disk is a positive integer", `{"disk_in_mb":1024}`, http.StatusOK),
)
})
})

Describe("the GET /v3/processes/<guid>/stats endpoint", func() {
Expand Down Expand Up @@ -308,10 +303,8 @@ var _ = Describe("Process", func() {
})

Describe("the GET /v3/processes endpoint", func() {
var queryString string

BeforeEach(func() {
queryString = ""
requestValidator.DecodeAndValidateURLValuesStub = decodeAndValidateURLValuesStub(&payloads.ProcessList{})
processRepo.ListProcessesReturns([]repositories.ProcessRecord{
{
GUID: "process-guid",
Expand All @@ -320,12 +313,16 @@ var _ = Describe("Process", func() {
})

JustBeforeEach(func() {
req, err := http.NewRequestWithContext(ctx, "GET", "/v3/processes"+queryString, nil)
req, err := http.NewRequestWithContext(ctx, "GET", "/v3/processes", nil)
Expect(err).NotTo(HaveOccurred())
routerBuilder.Build().ServeHTTP(rr, req)
})

It("returns the processes", func() {
Expect(requestValidator.DecodeAndValidateURLValuesCallCount()).To(Equal(1))
req, _ := requestValidator.DecodeAndValidateURLValuesArgsForCall(0)
Expect(req.URL.String()).To(HaveSuffix("/v3/processes"))

Expect(rr).To(HaveHTTPStatus(http.StatusOK))
Expect(rr).To(HaveHTTPHeaderWithValue("Content-Type", "application/json"))
Expect(rr).To(HaveHTTPBody(SatisfyAll(
Expand All @@ -335,9 +332,11 @@ var _ = Describe("Process", func() {
)))
})

When("Query Parameters are provided", func() {
When("app_guids query parameter is provided", func() {
BeforeEach(func() {
queryString = "?app_guids=my-app-guid"
requestValidator.DecodeAndValidateURLValuesStub = decodeAndValidateURLValuesStub(&payloads.ProcessList{
AppGUIDs: "my-app-guid",
})
})

It("invokes process repository with correct args", func() {
Expand All @@ -349,13 +348,13 @@ var _ = Describe("Process", func() {
})
})

When("invalid query parameters are provided", func() {
When("the request body is invalid", func() {
BeforeEach(func() {
queryString = "?foo=my-app-guid"
requestValidator.DecodeAndValidateURLValuesReturns(errors.New("boo"))
})

It("returns an Unknown key error", func() {
expectUnknownKeyError("The query parameter is invalid: Valid parameters are: .*")
It("returns an error", func() {
expectUnknownError()
})
})

Expand Down Expand Up @@ -397,6 +396,22 @@ var _ = Describe("Process", func() {
processRepo.PatchProcessReturns(repositories.ProcessRecord{
GUID: "process-guid",
}, nil)

requestValidator.DecodeAndValidateJSONPayloadStub = decodeAndValidateJSONPayloadStub(&payloads.ProcessPatch{
Metadata: &payloads.MetadataPatch{
Labels: map[string]*string{
"foo": tools.PtrTo("value1"),
},
},
HealthCheck: &payloads.HealthCheck{
Type: tools.PtrTo("port"),
Data: &payloads.Data{
Timeout: tools.PtrTo[int64](5),
Endpoint: tools.PtrTo("http://myapp.com/health"),
InvocationTimeout: tools.PtrTo[int64](2),
},
},
})
})

JustBeforeEach(func() {
Expand All @@ -406,6 +421,12 @@ var _ = Describe("Process", func() {
})

It("updates the process", func() {
Expect(requestValidator.DecodeAndValidateJSONPayloadCallCount()).To(Equal(1))
req, _ := requestValidator.DecodeAndValidateJSONPayloadArgsForCall(0)
reqBytes, err := io.ReadAll(req.Body)
Expect(err).NotTo(HaveOccurred())
Expect(string(reqBytes)).To(Equal(requestBody))

Expect(processRepo.PatchProcessCallCount()).To(Equal(1))
_, actualAuthInfo, actualMsg := processRepo.PatchProcessArgsForCall(0)
Expect(actualAuthInfo).To(Equal(authInfo))
Expand All @@ -426,25 +447,11 @@ var _ = Describe("Process", func() {

When("the request body is invalid json", func() {
BeforeEach(func() {
requestBody = `{`
requestValidator.DecodeAndValidateJSONPayloadReturns(errors.New("boom"))
})

It("return an request malformed error", func() {
expectBadRequestError()
})
})

When("the request body is invalid with an unknown field", func() {
BeforeEach(func() {
requestBody = `{
"health_check": {
"endpoint": "my-endpoint"
}
}`
})

It("return an request malformed error", func() {
expectUnprocessableEntityError("invalid request body: json: unknown field \"endpoint\"")
It("return an error", func() {
expectUnknownError()
})
})

Expand Down
36 changes: 8 additions & 28 deletions api/handlers/role.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,6 @@ const (
RolePath = RolesPath + "/{guid}"
)

type RoleName string

const (
RoleAdmin RoleName = "admin"
RoleAdminReadOnly RoleName = "admin_read_only"
RoleGlobalAuditor RoleName = "global_auditor"
RoleOrganizationAuditor RoleName = "organization_auditor"
RoleOrganizationBillingManager RoleName = "organization_billing_manager"
RoleOrganizationManager RoleName = "organization_manager"
RoleOrganizationUser RoleName = "organization_user"
RoleSpaceAuditor RoleName = "space_auditor"
RoleSpaceDeveloper RoleName = "space_developer"
RoleSpaceManager RoleName = "space_manager"
RoleSpaceSupporter RoleName = "space_supporter"
)

//counterfeiter:generate -o fake -fake-name CFRoleRepository . CFRoleRepository

type CFRoleRepository interface {
Expand All @@ -49,14 +33,14 @@ type CFRoleRepository interface {
type Role struct {
apiBaseURL url.URL
roleRepo CFRoleRepository
decoderValidator RequestValidator
requestValidator RequestValidator
}

func NewRole(apiBaseURL url.URL, roleRepo CFRoleRepository, decoderValidator RequestValidator) *Role {
func NewRole(apiBaseURL url.URL, roleRepo CFRoleRepository, requestValidator RequestValidator) *Role {
return &Role{
apiBaseURL: apiBaseURL,
roleRepo: roleRepo,
decoderValidator: decoderValidator,
requestValidator: requestValidator,
}
}

Expand All @@ -65,7 +49,7 @@ func (h *Role) create(r *http.Request) (*routing.Response, error) {
logger := logr.FromContextOrDiscard(r.Context()).WithName("handlers.role.create")

var payload payloads.RoleCreate
if err := h.decoderValidator.DecodeAndValidateJSONPayload(r, &payload); err != nil {
if err := h.requestValidator.DecodeAndValidateJSONPayload(r, &payload); err != nil {
return nil, apierrors.LogAndReturn(logger, err, "failed to decode payload")
}

Expand All @@ -84,12 +68,8 @@ func (h *Role) list(r *http.Request) (*routing.Response, error) {
authInfo, _ := authorization.InfoFromContext(r.Context())
logger := logr.FromContextOrDiscard(r.Context()).WithName("handlers.role.list")

if err := r.ParseForm(); err != nil {
return nil, apierrors.LogAndReturn(logger, err, "Unable to parse request query parameters")
}

roleListFilter := new(payloads.RoleListFilter)
err := payloads.Decode(roleListFilter, r.Form)
roleListFilter := new(payloads.RoleList)
err := h.requestValidator.DecodeAndValidateURLValues(r, roleListFilter)
if err != nil {
return nil, apierrors.LogAndReturn(logger, err, "Unable to decode request query parameters")
}
Expand All @@ -101,14 +81,14 @@ func (h *Role) list(r *http.Request) (*routing.Response, error) {

filteredRoles := filterRoles(roleListFilter, roles)

if err := h.sortList(filteredRoles, r.FormValue("order_by")); err != nil {
if err := h.sortList(filteredRoles, roleListFilter.OrderBy); err != nil {
return nil, apierrors.LogAndReturn(logger, err, "unable to parse order by request")
}

return routing.NewResponse(http.StatusOK).WithBody(presenter.ForList(presenter.ForRole, filteredRoles, h.apiBaseURL, *r.URL)), nil
}

func filterRoles(roleListFilter *payloads.RoleListFilter, roles []repositories.RoleRecord) []repositories.RoleRecord {
func filterRoles(roleListFilter *payloads.RoleList, roles []repositories.RoleRecord) []repositories.RoleRecord {
var filteredRoles []repositories.RoleRecord
for _, role := range roles {
if match(roleListFilter.GUIDs, role.GUID) &&
Expand Down
Loading
Loading