From 71145a48ba060651ebfcfac1d5df8e6bb5fed241 Mon Sep 17 00:00:00 2001 From: Joshua Hawxwell Date: Wed, 3 Jan 2024 08:28:12 +0000 Subject: [PATCH] Add update for when certificate provider signs --- Makefile | 2 +- docs/openapi/openapi.yaml | 2 +- internal/ddb/client.go | 69 ++++++++ internal/shared/client.go | 10 -- internal/shared/ddb.go | 70 -------- internal/shared/jwt.go | 40 ++--- internal/shared/jwt_test.go | 16 +- internal/shared/lang.go | 12 ++ internal/shared/person.go | 4 + internal/validate/validate.go | 93 +++++++++++ internal/validate/validate_test.go | 112 +++++++++++++ lambda/Dockerfile | 2 +- lambda/create/main.go | 22 +-- lambda/create/validate.go | 183 ++++++--------------- lambda/create/validate_test.go | 88 ---------- lambda/get/main.go | 13 +- lambda/update/main.go | 32 +++- lambda/update/main_test.go | 151 +++++++++++++++++ lambda/update/validate.go | 101 ++++++++++++ lambda/update/validate_test.go | 99 +++++++++++ terraform/environment/region/apigateway.tf | 1 + terraform/environment/region/main.tf | 1 + 22 files changed, 767 insertions(+), 356 deletions(-) create mode 100644 internal/ddb/client.go delete mode 100644 internal/shared/client.go create mode 100644 internal/shared/lang.go create mode 100644 internal/validate/validate.go create mode 100644 internal/validate/validate_test.go create mode 100644 lambda/update/main_test.go create mode 100644 lambda/update/validate.go create mode 100644 lambda/update/validate_test.go diff --git a/Makefile b/Makefile index 1d76d21c..e91005df 100644 --- a/Makefile +++ b/Makefile @@ -17,7 +17,7 @@ down: ## Stop application docker compose down test: ## Unit tests - go test ./lambda/get/... ./lambda/create/... ./lambda/update/... ./internal/shared/... -race -covermode=atomic -coverprofile=coverage.out + go test ./... -race -covermode=atomic -coverprofile=coverage.out test-api: URL ?= http://localhost:9000 test-api: diff --git a/docs/openapi/openapi.yaml b/docs/openapi/openapi.yaml index 3665a7d1..743ac70e 100644 --- a/docs/openapi/openapi.yaml +++ b/docs/openapi/openapi.yaml @@ -119,7 +119,6 @@ paths: httpMethod: "POST" type: "aws_proxy" contentHandling: "CONVERT_TO_TEXT" - /health-check: get: operationId: healthCheck @@ -448,6 +447,7 @@ components: - DONOR_ADDRESS_UPDATE - ATTORNEY_ADDRESS_UPDATE - SCANNING_CORRECTION + - CERTIFICATE_PROVIDER_SIGN changes: type: array items: diff --git a/internal/ddb/client.go b/internal/ddb/client.go new file mode 100644 index 00000000..0f539291 --- /dev/null +++ b/internal/ddb/client.go @@ -0,0 +1,69 @@ +package ddb + +import ( + "context" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" + "github.com/aws/aws-xray-sdk-go/xray" + "github.com/ministryofjustice/opg-data-lpa-store/internal/shared" +) + +type Client struct { + ddb *dynamodb.DynamoDB + tableName string +} + +func (c *Client) Put(ctx context.Context, data any) error { + item, err := dynamodbattribute.MarshalMap(data) + if err != nil { + return err + } + + _, err = c.ddb.PutItemWithContext(ctx, &dynamodb.PutItemInput{ + TableName: aws.String(c.tableName), + Item: item, + }) + + return err +} + +func (c *Client) Get(ctx context.Context, uid string) (shared.Lpa, error) { + lpa := shared.Lpa{} + + marshalledUid, err := dynamodbattribute.Marshal(uid) + if err != nil { + return lpa, err + } + + getItemOutput, err := c.ddb.GetItemWithContext(ctx, &dynamodb.GetItemInput{ + TableName: aws.String(c.tableName), + Key: map[string]*dynamodb.AttributeValue{ + "uid": marshalledUid, + }, + }) + + if err != nil { + return lpa, err + } + + err = dynamodbattribute.UnmarshalMap(getItemOutput.Item, &lpa) + + return lpa, err +} + +func New(endpoint, tableName string) *Client { + sess := session.Must(session.NewSession()) + sess.Config.Endpoint = &endpoint + + c := &Client{ + ddb: dynamodb.New(sess), + tableName: tableName, + } + + xray.AWS(c.ddb.Client) + + return c +} diff --git a/internal/shared/client.go b/internal/shared/client.go deleted file mode 100644 index abff0de6..00000000 --- a/internal/shared/client.go +++ /dev/null @@ -1,10 +0,0 @@ -package shared - -import ( - "context" -) - -type Client interface { - Put(ctx context.Context, data Lpa) error - Get(ctx context.Context, uid string) (Lpa, error) -} diff --git a/internal/shared/ddb.go b/internal/shared/ddb.go index 908beb39..a29b5e40 100644 --- a/internal/shared/ddb.go +++ b/internal/shared/ddb.go @@ -1,71 +1 @@ package shared - -import ( - "context" - "os" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/dynamodb" - "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" - "github.com/aws/aws-xray-sdk-go/xray" -) - -type DynamoDBClient struct { - ddb *dynamodb.DynamoDB - tableName string -} - -func (c DynamoDBClient) Put(ctx context.Context, data Lpa) error { - item, err := dynamodbattribute.MarshalMap(data) - if err != nil { - return err - } - - _, err = c.ddb.PutItemWithContext(ctx, &dynamodb.PutItemInput{ - TableName: aws.String(c.tableName), - Item: item, - }) - - return err -} - -func (c DynamoDBClient) Get(ctx context.Context, uid string) (Lpa, error) { - lpa := Lpa{} - - marshalledUid, err := dynamodbattribute.Marshal(uid) - if err != nil { - return lpa, err - } - - getItemOutput, err := c.ddb.GetItemWithContext(ctx, &dynamodb.GetItemInput{ - TableName: aws.String(c.tableName), - Key: map[string]*dynamodb.AttributeValue{ - "uid": marshalledUid, - }, - }) - - if err != nil { - return lpa, err - } - - err = dynamodbattribute.UnmarshalMap(getItemOutput.Item, &lpa) - - return lpa, err -} - -func NewDynamoDB(tableName string) DynamoDBClient { - sess := session.Must(session.NewSession()) - - endpoint := os.Getenv("AWS_DYNAMODB_ENDPOINT") - sess.Config.Endpoint = &endpoint - - c := DynamoDBClient{ - ddb: dynamodb.New(sess), - tableName: tableName, - } - - xray.AWS(c.ddb.Client) - - return c -} diff --git a/internal/shared/jwt.go b/internal/shared/jwt.go index f201512c..38d28d85 100644 --- a/internal/shared/jwt.go +++ b/internal/shared/jwt.go @@ -88,25 +88,6 @@ func NewJWTVerifier() JWTVerifier { } } -// tokenStr is the JWT token, minus any "Bearer: " prefix -func (v JWTVerifier) VerifyToken(tokenStr string) error { - lsc := lpaStoreClaims{} - - parsedToken, err := jwt.ParseWithClaims(tokenStr, &lsc, func(token *jwt.Token) (interface{}, error) { - return v.secretKey, nil - }) - - if err != nil { - return err - } - - if !parsedToken.Valid { - return fmt.Errorf("Invalid JWT") - } - - return nil -} - var bearerRegexp = regexp.MustCompile("^Bearer[ ]+") // verify JWT from event header @@ -119,9 +100,28 @@ func (v JWTVerifier) VerifyHeader(event events.APIGatewayProxyRequest) bool { } tokenStr := bearerRegexp.ReplaceAllString(jwtHeaders[0], "") - if v.VerifyToken(tokenStr) != nil { + if v.verifyToken(tokenStr) != nil { return false } return true } + +// tokenStr is the JWT token, minus any "Bearer: " prefix +func (v JWTVerifier) verifyToken(tokenStr string) error { + lsc := lpaStoreClaims{} + + parsedToken, err := jwt.ParseWithClaims(tokenStr, &lsc, func(token *jwt.Token) (interface{}, error) { + return v.secretKey, nil + }) + + if err != nil { + return err + } + + if !parsedToken.Valid { + return fmt.Errorf("Invalid JWT") + } + + return nil +} diff --git a/internal/shared/jwt_test.go b/internal/shared/jwt_test.go index b92e3122..8077da3d 100644 --- a/internal/shared/jwt_test.go +++ b/internal/shared/jwt_test.go @@ -26,7 +26,7 @@ func createToken(claims jwt.MapClaims) string { } func TestVerifyEmptyJwt(t *testing.T) { - err := verifier.VerifyToken("") + err := verifier.verifyToken("") assert.NotNil(t, err) } @@ -38,7 +38,7 @@ func TestVerifyExpInPast(t *testing.T) { "sub": "M-3467-89QW-ERTY", }) - err := verifier.VerifyToken(token) + err := verifier.verifyToken(token) assert.NotNil(t, err) if err != nil { @@ -54,7 +54,7 @@ func TestVerifyIatInFuture(t *testing.T) { "sub": "someone@someplace.somewhere.com", }) - err := verifier.VerifyToken(token) + err := verifier.verifyToken(token) assert.NotNil(t, err) if err != nil { @@ -70,7 +70,7 @@ func TestVerifyIssuer(t *testing.T) { "sub": "someone@someplace.somewhere.com", }) - err := verifier.VerifyToken(token) + err := verifier.verifyToken(token) assert.NotNil(t, err) if err != nil { @@ -86,7 +86,7 @@ func TestVerifyBadEmailForSiriusIssuer(t *testing.T) { "sub": "", }) - err := verifier.VerifyToken(token) + err := verifier.verifyToken(token) assert.NotNil(t, err) if err != nil { @@ -102,7 +102,7 @@ func TestVerifyBadUIDForMRLPAIssuer(t *testing.T) { "sub": "", }) - err := verifier.VerifyToken(token) + err := verifier.verifyToken(token) assert.NotNil(t, err) if err != nil { @@ -118,7 +118,7 @@ func TestVerifyGoodJwt(t *testing.T) { "sub": "someone@someplace.somewhere.com", }) - err := verifier.VerifyToken(token) + err := verifier.verifyToken(token) assert.Nil(t, err) } @@ -134,7 +134,7 @@ func TestNewJWTVerifier(t *testing.T) { newVerifier := NewJWTVerifier() os.Unsetenv("JWT_SECRET_KEY") - err := newVerifier.VerifyToken(token) + err := newVerifier.verifyToken(token) assert.Nil(t, err) } diff --git a/internal/shared/lang.go b/internal/shared/lang.go new file mode 100644 index 00000000..2fed3758 --- /dev/null +++ b/internal/shared/lang.go @@ -0,0 +1,12 @@ +package shared + +type Lang string + +var ( + LangCy = Lang("cy") + LangEn = Lang("en") +) + +func (l Lang) IsValid() bool { + return l == LangCy || l == LangEn +} diff --git a/internal/shared/person.go b/internal/shared/person.go index febd607c..251a5941 100644 --- a/internal/shared/person.go +++ b/internal/shared/person.go @@ -9,6 +9,10 @@ type Address struct { Country string `json:"country"` } +func (a Address) IsSet() bool { + return a.Line1 != "" || a.Line2 != "" || a.Line3 != "" || a.Town != "" || a.Postcode != "" || a.Country != "" +} + type Person struct { FirstNames string `json:"firstNames"` LastName string `json:"lastName"` diff --git a/internal/validate/validate.go b/internal/validate/validate.go new file mode 100644 index 00000000..8882f7c6 --- /dev/null +++ b/internal/validate/validate.go @@ -0,0 +1,93 @@ +package validate + +import ( + "fmt" + "regexp" + "time" + + "github.com/ministryofjustice/opg-data-lpa-store/internal/shared" +) + +var countryCodeRe = regexp.MustCompile("^[A-Z]{2}$") + +func All(fieldErrors ...[]shared.FieldError) []shared.FieldError { + var errors []shared.FieldError + + for _, e := range fieldErrors { + if e != nil { + errors = append(errors, e...) + } + } + + return errors +} + +func IfElse(ok bool, eIf []shared.FieldError, eElse []shared.FieldError) []shared.FieldError { + if ok { + return eIf + } + + return eElse +} + +func If(ok bool, e []shared.FieldError) []shared.FieldError { + return IfElse(ok, e, nil) +} + +func Required(source string, value string) []shared.FieldError { + return If(value == "", []shared.FieldError{{Source: source, Detail: "field is required"}}) +} + +func Empty(source string, value string) []shared.FieldError { + return If(value != "", []shared.FieldError{{Source: source, Detail: "field must not be provided"}}) +} + +func Date(source string, date shared.Date) []shared.FieldError { + if date.IsMalformed { + return []shared.FieldError{{Source: source, Detail: "invalid format"}} + } + + if date.IsZero() { + return []shared.FieldError{{Source: source, Detail: "field is required"}} + } + + return nil +} + +func Time(source string, t time.Time) []shared.FieldError { + return If(t.IsZero(), []shared.FieldError{{Source: source, Detail: "field is required"}}) +} + +func Address(prefix string, address shared.Address) []shared.FieldError { + return All( + Required(fmt.Sprintf("%s/line1", prefix), address.Line1), + Required(fmt.Sprintf("%s/town", prefix), address.Town), + Required(fmt.Sprintf("%s/country", prefix), address.Country), + Country(fmt.Sprintf("%s/country", prefix), address.Country), + ) +} + +func Country(source string, country string) []shared.FieldError { + return If(!countryCodeRe.MatchString(country), []shared.FieldError{{Source: source, Detail: "must be a valid ISO-3166-1 country code"}}) +} + +type isValid interface { + ~string + IsValid() bool +} + +func IsValid[V isValid](source string, v V) []shared.FieldError { + if e := Required(source, string(v)); e != nil { + return e + } + + if !v.IsValid() { + return []shared.FieldError{{Source: source, Detail: "invalid value"}} + } + + return nil +} + +func Unset(source string, v interface{ Unset() bool }) []shared.FieldError { + return If(!v.Unset(), []shared.FieldError{{Source: source, Detail: "field must not be provided"}}) +} diff --git a/internal/validate/validate_test.go b/internal/validate/validate_test.go new file mode 100644 index 00000000..a74e0289 --- /dev/null +++ b/internal/validate/validate_test.go @@ -0,0 +1,112 @@ +package validate + +import ( + "testing" + "time" + + "github.com/ministryofjustice/opg-data-lpa-store/internal/shared" + "github.com/stretchr/testify/assert" +) + +var validAddress = shared.Address{ + Line1: "123 Main St", + Town: "Homeland", + Country: "GB", +} + +func newDate(date string, isMalformed bool) shared.Date { + t, _ := time.Parse("2006-01-02", date) + + return shared.Date{ + Time: t, + IsMalformed: isMalformed, + } +} + +func TestAll(t *testing.T) { + errA := shared.FieldError{Source: "a", Detail: "a"} + errB := shared.FieldError{Source: "b", Detail: "b"} + errC := shared.FieldError{Source: "c", Detail: "c"} + + assert.Nil(t, All()) + assert.Nil(t, All([]shared.FieldError{}, []shared.FieldError{})) + assert.Equal(t, []shared.FieldError{errA, errB, errC}, All([]shared.FieldError{errA, errB}, []shared.FieldError{errC})) + assert.Equal(t, []shared.FieldError{errA, errB, errC}, All([]shared.FieldError{errA}, []shared.FieldError{errB, errC})) + assert.Equal(t, []shared.FieldError{errA, errB, errC}, All([]shared.FieldError{errA}, []shared.FieldError{errB}, []shared.FieldError{errC})) +} + +func TestIf(t *testing.T) { + errs := []shared.FieldError{{Source: "a", Detail: "a"}} + + assert.Equal(t, errs, If(true, errs)) + assert.Nil(t, If(false, errs)) +} + +func TestIfElse(t *testing.T) { + errsA := []shared.FieldError{{Source: "a", Detail: "a"}} + errsB := []shared.FieldError{{Source: "b", Detail: "b"}} + + assert.Equal(t, errsA, IfElse(true, errsA, errsB)) + assert.Equal(t, errsB, IfElse(false, errsA, errsB)) +} + +func TestRequired(t *testing.T) { + assert.Nil(t, Required("a", "a")) + assert.Equal(t, []shared.FieldError{{Source: "a", Detail: "field is required"}}, Required("a", "")) +} + +func TestEmpty(t *testing.T) { + assert.Nil(t, Empty("a", "")) + assert.Equal(t, []shared.FieldError{{Source: "a", Detail: "field must not be provided"}}, Empty("a", "a")) +} + +func TestDate(t *testing.T) { + assert.Nil(t, Date("a", shared.Date{Time: time.Now()})) + assert.Equal(t, []shared.FieldError{{Source: "a", Detail: "invalid format"}}, Date("a", shared.Date{IsMalformed: true})) + assert.Equal(t, []shared.FieldError{{Source: "a", Detail: "field is required"}}, Date("a", shared.Date{})) +} + +func TestAddressEmpty(t *testing.T) { + address := shared.Address{} + errors := Address("/test", address) + + assert.Contains(t, errors, shared.FieldError{Source: "/test/line1", Detail: "field is required"}) + assert.Contains(t, errors, shared.FieldError{Source: "/test/town", Detail: "field is required"}) + assert.Contains(t, errors, shared.FieldError{Source: "/test/country", Detail: "field is required"}) +} + +func TestAddressValid(t *testing.T) { + errors := Address("/test", validAddress) + + assert.Empty(t, errors) +} + +func TestAddressInvalidCountry(t *testing.T) { + invalidAddress := shared.Address{ + Line1: "123 Main St", + Town: "Homeland", + Country: "United Kingdom", + } + errors := Address("/test", invalidAddress) + + assert.Contains(t, errors, shared.FieldError{Source: "/test/country", Detail: "must be a valid ISO-3166-1 country code"}) +} + +type testIsValid string + +func (t testIsValid) IsValid() bool { return string(t) == "ok" } + +func TestIsValid(t *testing.T) { + assert.Nil(t, IsValid("a", testIsValid("ok"))) + assert.Equal(t, []shared.FieldError{{Source: "a", Detail: "field is required"}}, IsValid("a", testIsValid(""))) + assert.Equal(t, []shared.FieldError{{Source: "a", Detail: "invalid value"}}, IsValid("a", testIsValid("x"))) +} + +type testUnset bool + +func (t testUnset) Unset() bool { return bool(t) } + +func TestUnset(t *testing.T) { + assert.Nil(t, Unset("a", testUnset(true))) + assert.Equal(t, []shared.FieldError{{Source: "a", Detail: "field must not be provided"}}, Unset("a", testUnset(false))) +} diff --git a/lambda/Dockerfile b/lambda/Dockerfile index e84e3bc3..acbee20d 100644 --- a/lambda/Dockerfile +++ b/lambda/Dockerfile @@ -6,7 +6,7 @@ COPY ./go.sum /app/go.sum RUN go mod download -COPY ./internal/shared /app/internal/shared +COPY ./internal /app/internal ARG DIR COPY ./lambda/$DIR /app/lambda/$DIR diff --git a/lambda/create/main.go b/lambda/create/main.go index 197a4ac6..ceb66965 100644 --- a/lambda/create/main.go +++ b/lambda/create/main.go @@ -8,19 +8,22 @@ import ( "github.com/aws/aws-lambda-go/events" "github.com/aws/aws-lambda-go/lambda" + "github.com/ministryofjustice/opg-data-lpa-store/internal/ddb" "github.com/ministryofjustice/opg-data-lpa-store/internal/shared" "github.com/ministryofjustice/opg-go-common/logging" ) -type Response struct { -} - type Logger interface { Print(...interface{}) } +type Store interface { + Put(ctx context.Context, data any) error + Get(ctx context.Context, uid string) (shared.Lpa, error) +} + type Lambda struct { - store shared.Client + store Store verifier shared.JWTVerifier logger Logger } @@ -84,22 +87,15 @@ func (l *Lambda) HandleEvent(ctx context.Context, event events.APIGatewayProxyRe } // respond - body, err := json.Marshal(Response{}) - - if err != nil { - l.logger.Print(err) - return shared.ProblemInternalServerError.Respond() - } - response.StatusCode = 201 - response.Body = string(body) + response.Body = `{}` return response, nil } func main() { l := &Lambda{ - store: shared.NewDynamoDB(os.Getenv("DDB_TABLE_NAME_DEEDS")), + store: ddb.New(os.Getenv("AWS_DYNAMODB_ENDPOINT"), os.Getenv("DDB_TABLE_NAME_DEEDS")), verifier: shared.NewJWTVerifier(), logger: logging.New(os.Stdout, "opg-data-lpa-store"), } diff --git a/lambda/create/validate.go b/lambda/create/validate.go index 29ab8215..74707a24 100644 --- a/lambda/create/validate.go +++ b/lambda/create/validate.go @@ -2,56 +2,53 @@ package main import ( "fmt" - "regexp" - "time" "github.com/ministryofjustice/opg-data-lpa-store/internal/shared" + "github.com/ministryofjustice/opg-data-lpa-store/internal/validate" ) -var countryCodeRe = regexp.MustCompile("^[A-Z]{2}$") - func Validate(lpa shared.LpaInit) []shared.FieldError { activeAttorneyCount, replacementAttorneyCount := countAttorneys(lpa.Attorneys, lpa.TrustCorporations) - return flatten( - validateIsValid("/lpaType", lpa.LpaType), - required("/donor/firstNames", lpa.Donor.FirstNames), - required("/donor/lastName", lpa.Donor.LastName), - validateDate("/donor/dateOfBirth", lpa.Donor.DateOfBirth), - validateAddress("/donor/address", lpa.Donor.Address), - required("/certificateProvider/firstNames", lpa.CertificateProvider.FirstNames), - required("/certificateProvider/lastName", lpa.CertificateProvider.LastName), - validateAddress("/certificateProvider/address", lpa.CertificateProvider.Address), - validateIsValid("/certificateProvider/channel", lpa.CertificateProvider.Channel), - validateIfElse(lpa.CertificateProvider.Channel == shared.ChannelOnline, - required("/certificateProvider/email", lpa.CertificateProvider.Email), - empty("/certificateProvider/email", lpa.CertificateProvider.Email)), + return validate.All( + validate.IsValid("/lpaType", lpa.LpaType), + validate.Required("/donor/firstNames", lpa.Donor.FirstNames), + validate.Required("/donor/lastName", lpa.Donor.LastName), + validate.Date("/donor/dateOfBirth", lpa.Donor.DateOfBirth), + validate.Address("/donor/address", lpa.Donor.Address), + validate.Required("/certificateProvider/firstNames", lpa.CertificateProvider.FirstNames), + validate.Required("/certificateProvider/lastName", lpa.CertificateProvider.LastName), + validate.Address("/certificateProvider/address", lpa.CertificateProvider.Address), + validate.IsValid("/certificateProvider/channel", lpa.CertificateProvider.Channel), + validate.IfElse(lpa.CertificateProvider.Channel == shared.ChannelOnline, + validate.Required("/certificateProvider/email", lpa.CertificateProvider.Email), + validate.Empty("/certificateProvider/email", lpa.CertificateProvider.Email)), validateAttorneys("/attorneys", lpa.Attorneys), validateTrustCorporations("/trustCorporations", lpa.TrustCorporations), - validateIfElse(activeAttorneyCount > 1, - validateIsValid("/howAttorneysMakeDecisions", lpa.HowAttorneysMakeDecisions), - validateUnset("/howAttorneysMakeDecisions", lpa.HowAttorneysMakeDecisions)), - validateIfElse(lpa.HowAttorneysMakeDecisions == shared.HowMakeDecisionsJointlyForSomeSeverallyForOthers, - required("/howAttorneysMakeDecisionsDetails", lpa.HowAttorneysMakeDecisionsDetails), - empty("/howAttorneysMakeDecisionsDetails", lpa.HowAttorneysMakeDecisionsDetails)), - validateIf(replacementAttorneyCount > 0 && lpa.HowAttorneysMakeDecisions == shared.HowMakeDecisionsJointlyAndSeverally, - validateIsValid("/howReplacementAttorneysStepIn", lpa.HowReplacementAttorneysStepIn)), - validateIfElse(lpa.HowReplacementAttorneysStepIn == shared.HowStepInAnotherWay, - required("/howReplacementAttorneysStepInDetails", lpa.HowReplacementAttorneysStepInDetails), - empty("/howReplacementAttorneysStepInDetails", lpa.HowReplacementAttorneysStepInDetails)), - validateIfElse(replacementAttorneyCount > 1 && (lpa.HowReplacementAttorneysStepIn == shared.HowStepInAllCanNoLongerAct || lpa.HowAttorneysMakeDecisions != shared.HowMakeDecisionsJointlyAndSeverally), - validateIsValid("/howReplacementAttorneysMakeDecisions", lpa.HowReplacementAttorneysMakeDecisions), - validateUnset("/howReplacementAttorneysMakeDecisions", lpa.HowReplacementAttorneysMakeDecisions)), - validateIfElse(lpa.HowReplacementAttorneysMakeDecisions == shared.HowMakeDecisionsJointlyForSomeSeverallyForOthers, - required("/howReplacementAttorneysMakeDecisionsDetails", lpa.HowReplacementAttorneysMakeDecisionsDetails), - empty("/howReplacementAttorneysMakeDecisionsDetails", lpa.HowReplacementAttorneysMakeDecisionsDetails)), - validateIf(lpa.LpaType == shared.LpaTypePersonalWelfare, flatten( - validateIsValid("/lifeSustainingTreatmentOption", lpa.LifeSustainingTreatmentOption), - validateUnset("/whenTheLpaCanBeUsed", lpa.WhenTheLpaCanBeUsed))), - validateIf(lpa.LpaType == shared.LpaTypePropertyAndAffairs, flatten( - validateIsValid("/whenTheLpaCanBeUsed", lpa.WhenTheLpaCanBeUsed), - validateUnset("/lifeSustainingTreatmentOption", lpa.LifeSustainingTreatmentOption))), - validateTime("/signedAt", lpa.SignedAt), + validate.IfElse(activeAttorneyCount > 1, + validate.IsValid("/howAttorneysMakeDecisions", lpa.HowAttorneysMakeDecisions), + validate.Unset("/howAttorneysMakeDecisions", lpa.HowAttorneysMakeDecisions)), + validate.IfElse(lpa.HowAttorneysMakeDecisions == shared.HowMakeDecisionsJointlyForSomeSeverallyForOthers, + validate.Required("/howAttorneysMakeDecisionsDetails", lpa.HowAttorneysMakeDecisionsDetails), + validate.Empty("/howAttorneysMakeDecisionsDetails", lpa.HowAttorneysMakeDecisionsDetails)), + validate.If(replacementAttorneyCount > 0 && lpa.HowAttorneysMakeDecisions == shared.HowMakeDecisionsJointlyAndSeverally, + validate.IsValid("/howReplacementAttorneysStepIn", lpa.HowReplacementAttorneysStepIn)), + validate.IfElse(lpa.HowReplacementAttorneysStepIn == shared.HowStepInAnotherWay, + validate.Required("/howReplacementAttorneysStepInDetails", lpa.HowReplacementAttorneysStepInDetails), + validate.Empty("/howReplacementAttorneysStepInDetails", lpa.HowReplacementAttorneysStepInDetails)), + validate.IfElse(replacementAttorneyCount > 1 && (lpa.HowReplacementAttorneysStepIn == shared.HowStepInAllCanNoLongerAct || lpa.HowAttorneysMakeDecisions != shared.HowMakeDecisionsJointlyAndSeverally), + validate.IsValid("/howReplacementAttorneysMakeDecisions", lpa.HowReplacementAttorneysMakeDecisions), + validate.Unset("/howReplacementAttorneysMakeDecisions", lpa.HowReplacementAttorneysMakeDecisions)), + validate.IfElse(lpa.HowReplacementAttorneysMakeDecisions == shared.HowMakeDecisionsJointlyForSomeSeverallyForOthers, + validate.Required("/howReplacementAttorneysMakeDecisionsDetails", lpa.HowReplacementAttorneysMakeDecisionsDetails), + validate.Empty("/howReplacementAttorneysMakeDecisionsDetails", lpa.HowReplacementAttorneysMakeDecisionsDetails)), + validate.If(lpa.LpaType == shared.LpaTypePersonalWelfare, validate.All( + validate.IsValid("/lifeSustainingTreatmentOption", lpa.LifeSustainingTreatmentOption), + validate.Unset("/whenTheLpaCanBeUsed", lpa.WhenTheLpaCanBeUsed))), + validate.If(lpa.LpaType == shared.LpaTypePropertyAndAffairs, validate.All( + validate.IsValid("/whenTheLpaCanBeUsed", lpa.WhenTheLpaCanBeUsed), + validate.Unset("/lifeSustainingTreatmentOption", lpa.LifeSustainingTreatmentOption))), + validate.Time("/signedAt", lpa.SignedAt), ) } @@ -77,84 +74,6 @@ func countAttorneys(as []shared.Attorney, ts []shared.TrustCorporation) (actives return actives, replacements } -func flatten(fieldErrors ...[]shared.FieldError) []shared.FieldError { - var errors []shared.FieldError - - for _, e := range fieldErrors { - if e != nil { - errors = append(errors, e...) - } - } - - return errors -} - -func validateIfElse(ok bool, eIf []shared.FieldError, eElse []shared.FieldError) []shared.FieldError { - if ok { - return eIf - } - - return eElse -} - -func validateIf(ok bool, e []shared.FieldError) []shared.FieldError { - return validateIfElse(ok, e, nil) -} - -func required(source string, value string) []shared.FieldError { - return validateIf(value == "", []shared.FieldError{{Source: source, Detail: "field is required"}}) -} - -func empty(source string, value string) []shared.FieldError { - return validateIf(value != "", []shared.FieldError{{Source: source, Detail: "field must not be provided"}}) -} - -func validateDate(source string, date shared.Date) []shared.FieldError { - if date.IsMalformed { - return []shared.FieldError{{Source: source, Detail: "invalid format"}} - } - - if date.IsZero() { - return []shared.FieldError{{Source: source, Detail: "field is required"}} - } - - return nil -} - -func validateTime(source string, t time.Time) []shared.FieldError { - return validateIf(t.IsZero(), []shared.FieldError{{Source: source, Detail: "field is required"}}) -} - -func validateAddress(prefix string, address shared.Address) []shared.FieldError { - return flatten( - required(fmt.Sprintf("%s/line1", prefix), address.Line1), - required(fmt.Sprintf("%s/town", prefix), address.Town), - required(fmt.Sprintf("%s/country", prefix), address.Country), - validateIf(!countryCodeRe.MatchString(address.Country), []shared.FieldError{{Source: fmt.Sprintf("%s/country", prefix), Detail: "must be a valid ISO-3166-1 country code"}}), - ) -} - -type isValid interface { - ~string - IsValid() bool -} - -func validateIsValid[V isValid](source string, v V) []shared.FieldError { - if e := required(source, string(v)); e != nil { - return e - } - - if !v.IsValid() { - return []shared.FieldError{{Source: source, Detail: "invalid value"}} - } - - return nil -} - -func validateUnset(source string, v interface{ Unset() bool }) []shared.FieldError { - return validateIf(!v.Unset(), []shared.FieldError{{Source: source, Detail: "field must not be provided"}}) -} - func validateAttorneys(prefix string, attorneys []shared.Attorney) []shared.FieldError { var errors []shared.FieldError @@ -172,13 +91,13 @@ func validateAttorneys(prefix string, attorneys []shared.Attorney) []shared.Fiel } func validateAttorney(prefix string, attorney shared.Attorney) []shared.FieldError { - return flatten( - required(fmt.Sprintf("%s/firstNames", prefix), attorney.FirstNames), - required(fmt.Sprintf("%s/lastName", prefix), attorney.LastName), - required(fmt.Sprintf("%s/status", prefix), string(attorney.Status)), - validateDate(fmt.Sprintf("%s/dateOfBirth", prefix), attorney.DateOfBirth), - validateAddress(fmt.Sprintf("%s/address", prefix), attorney.Address), - validateIsValid(fmt.Sprintf("%s/status", prefix), attorney.Status), + return validate.All( + validate.Required(fmt.Sprintf("%s/firstNames", prefix), attorney.FirstNames), + validate.Required(fmt.Sprintf("%s/lastName", prefix), attorney.LastName), + validate.Required(fmt.Sprintf("%s/status", prefix), string(attorney.Status)), + validate.Date(fmt.Sprintf("%s/dateOfBirth", prefix), attorney.DateOfBirth), + validate.Address(fmt.Sprintf("%s/address", prefix), attorney.Address), + validate.IsValid(fmt.Sprintf("%s/status", prefix), attorney.Status), ) } @@ -195,11 +114,11 @@ func validateTrustCorporations(prefix string, trustCorporations []shared.TrustCo } func validateTrustCorporation(prefix string, trustCorporation shared.TrustCorporation) []shared.FieldError { - return flatten( - required(fmt.Sprintf("%s/name", prefix), trustCorporation.Name), - required(fmt.Sprintf("%s/companyNumber", prefix), trustCorporation.CompanyNumber), - required(fmt.Sprintf("%s/email", prefix), trustCorporation.Email), - validateAddress(fmt.Sprintf("%s/address", prefix), trustCorporation.Address), - validateIsValid(fmt.Sprintf("%s/status", prefix), trustCorporation.Status), + return validate.All( + validate.Required(fmt.Sprintf("%s/name", prefix), trustCorporation.Name), + validate.Required(fmt.Sprintf("%s/companyNumber", prefix), trustCorporation.CompanyNumber), + validate.Required(fmt.Sprintf("%s/email", prefix), trustCorporation.Email), + validate.Address(fmt.Sprintf("%s/address", prefix), trustCorporation.Address), + validate.IsValid(fmt.Sprintf("%s/status", prefix), trustCorporation.Status), ) } diff --git a/lambda/create/validate_test.go b/lambda/create/validate_test.go index 2482fdc6..69999243 100644 --- a/lambda/create/validate_test.go +++ b/lambda/create/validate_test.go @@ -40,94 +40,6 @@ func TestCountAttorneys(t *testing.T) { assert.Equal(t, 3, replacements) } -func TestFlatten(t *testing.T) { - errA := shared.FieldError{Source: "a", Detail: "a"} - errB := shared.FieldError{Source: "b", Detail: "b"} - errC := shared.FieldError{Source: "c", Detail: "c"} - - assert.Nil(t, flatten()) - assert.Nil(t, flatten([]shared.FieldError{}, []shared.FieldError{})) - assert.Equal(t, []shared.FieldError{errA, errB, errC}, flatten([]shared.FieldError{errA, errB}, []shared.FieldError{errC})) - assert.Equal(t, []shared.FieldError{errA, errB, errC}, flatten([]shared.FieldError{errA}, []shared.FieldError{errB, errC})) - assert.Equal(t, []shared.FieldError{errA, errB, errC}, flatten([]shared.FieldError{errA}, []shared.FieldError{errB}, []shared.FieldError{errC})) -} - -func TestValidateIf(t *testing.T) { - errs := []shared.FieldError{{Source: "a", Detail: "a"}} - - assert.Equal(t, errs, validateIf(true, errs)) - assert.Nil(t, validateIf(false, errs)) -} - -func TestValidateIfElse(t *testing.T) { - errsA := []shared.FieldError{{Source: "a", Detail: "a"}} - errsB := []shared.FieldError{{Source: "b", Detail: "b"}} - - assert.Equal(t, errsA, validateIfElse(true, errsA, errsB)) - assert.Equal(t, errsB, validateIfElse(false, errsA, errsB)) -} - -func TestRequired(t *testing.T) { - assert.Nil(t, required("a", "a")) - assert.Equal(t, []shared.FieldError{{Source: "a", Detail: "field is required"}}, required("a", "")) -} - -func TestEmpty(t *testing.T) { - assert.Nil(t, empty("a", "")) - assert.Equal(t, []shared.FieldError{{Source: "a", Detail: "field must not be provided"}}, empty("a", "a")) -} - -func TestValidateDate(t *testing.T) { - assert.Nil(t, validateDate("a", shared.Date{Time: time.Now()})) - assert.Equal(t, []shared.FieldError{{Source: "a", Detail: "invalid format"}}, validateDate("a", shared.Date{IsMalformed: true})) - assert.Equal(t, []shared.FieldError{{Source: "a", Detail: "field is required"}}, validateDate("a", shared.Date{})) -} - -func TestValidateAddressEmpty(t *testing.T) { - address := shared.Address{} - errors := validateAddress("/test", address) - - assert.Contains(t, errors, shared.FieldError{Source: "/test/line1", Detail: "field is required"}) - assert.Contains(t, errors, shared.FieldError{Source: "/test/town", Detail: "field is required"}) - assert.Contains(t, errors, shared.FieldError{Source: "/test/country", Detail: "field is required"}) -} - -func TestValidateAddressValid(t *testing.T) { - errors := validateAddress("/test", validAddress) - - assert.Empty(t, errors) -} - -func TestValidateAddressInvalidCountry(t *testing.T) { - invalidAddress := shared.Address{ - Line1: "123 Main St", - Town: "Homeland", - Country: "United Kingdom", - } - errors := validateAddress("/test", invalidAddress) - - assert.Contains(t, errors, shared.FieldError{Source: "/test/country", Detail: "must be a valid ISO-3166-1 country code"}) -} - -type testIsValid string - -func (t testIsValid) IsValid() bool { return string(t) == "ok" } - -func TestValidateIsValid(t *testing.T) { - assert.Nil(t, validateIsValid("a", testIsValid("ok"))) - assert.Equal(t, []shared.FieldError{{Source: "a", Detail: "field is required"}}, validateIsValid("a", testIsValid(""))) - assert.Equal(t, []shared.FieldError{{Source: "a", Detail: "invalid value"}}, validateIsValid("a", testIsValid("x"))) -} - -type testUnset bool - -func (t testUnset) Unset() bool { return bool(t) } - -func TestValidateUnset(t *testing.T) { - assert.Nil(t, validateUnset("a", testUnset(true))) - assert.Equal(t, []shared.FieldError{{Source: "a", Detail: "field must not be provided"}}, validateUnset("a", testUnset(false))) -} - func TestValidateAttorneyEmpty(t *testing.T) { attorney := shared.Attorney{} errors := validateAttorney("/test", attorney) diff --git a/lambda/get/main.go b/lambda/get/main.go index c0d7a1e9..d5b8c1bd 100644 --- a/lambda/get/main.go +++ b/lambda/get/main.go @@ -7,6 +7,7 @@ import ( "github.com/aws/aws-lambda-go/events" "github.com/aws/aws-lambda-go/lambda" + "github.com/ministryofjustice/opg-data-lpa-store/internal/ddb" "github.com/ministryofjustice/opg-data-lpa-store/internal/shared" "github.com/ministryofjustice/opg-go-common/logging" ) @@ -15,8 +16,12 @@ type Logger interface { Print(...interface{}) } +type Store interface { + Get(ctx context.Context, uid string) (shared.Lpa, error) +} + type Lambda struct { - store shared.Client + store Store verifier shared.JWTVerifier logger Logger } @@ -36,8 +41,8 @@ func (l *Lambda) HandleEvent(ctx context.Context, event events.APIGatewayProxyRe lpa, err := l.store.Get(ctx, event.PathParameters["uid"]) - // If item can't be found in DynamoDB then it returns empty object hence 404 error returned if - // empty object returned + // If item can't be found in DynamoDB then it returns empty object hence 404 error returned if + // empty object returned if lpa.Uid == "" { l.logger.Print("Uid not found") return shared.ProblemNotFoundRequest.Respond() @@ -63,7 +68,7 @@ func (l *Lambda) HandleEvent(ctx context.Context, event events.APIGatewayProxyRe func main() { l := &Lambda{ - store: shared.NewDynamoDB(os.Getenv("DDB_TABLE_NAME_DEEDS")), + store: ddb.New(os.Getenv("AWS_DYNAMODB_ENDPOINT"), os.Getenv("DDB_TABLE_NAME_DEEDS")), verifier: shared.NewJWTVerifier(), logger: logging.New(os.Stdout, "opg-data-lpa-store"), } diff --git a/lambda/update/main.go b/lambda/update/main.go index 0a2def0d..7b3b5fc1 100644 --- a/lambda/update/main.go +++ b/lambda/update/main.go @@ -9,6 +9,7 @@ import ( "github.com/aws/aws-lambda-go/events" "github.com/aws/aws-lambda-go/lambda" "github.com/go-openapi/jsonpointer" + "github.com/ministryofjustice/opg-data-lpa-store/internal/ddb" "github.com/ministryofjustice/opg-data-lpa-store/internal/shared" "github.com/ministryofjustice/opg-go-common/logging" ) @@ -17,10 +18,17 @@ type Logger interface { Print(...interface{}) } +type Store interface { + Put(ctx context.Context, data any) error + Get(ctx context.Context, uid string) (shared.Lpa, error) +} + type Lambda struct { - store shared.Client - verifier shared.JWTVerifier - logger Logger + store Store + verifier interface { + VerifyHeader(events.APIGatewayProxyRequest) bool + } + logger Logger } func (l *Lambda) HandleEvent(ctx context.Context, event events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, error) { @@ -48,9 +56,19 @@ func (l *Lambda) HandleEvent(ctx context.Context, event events.APIGatewayProxyRe l.logger.Print(err) return shared.ProblemInternalServerError.Respond() } + if lpa.Uid == "" { + l.logger.Print("Uid not found") + return shared.ProblemNotFoundRequest.Respond() + } - validationErrs, err := applyUpdate(&lpa, update) + if errors := validateUpdate(update); len(errors) > 0 { + problem := shared.ProblemInvalidRequest + problem.Errors = errors + return problem.Respond() + } + + validationErrs, err := applyUpdate(&lpa, update) if err != nil { l.logger.Print(err) return shared.ProblemInternalServerError.Respond() @@ -63,14 +81,12 @@ func (l *Lambda) HandleEvent(ctx context.Context, event events.APIGatewayProxyRe return problem.Respond() } - err = l.store.Put(ctx, lpa) - if err != nil { + if err := l.store.Put(ctx, lpa); err != nil { l.logger.Print(err) return shared.ProblemInternalServerError.Respond() } body, err := json.Marshal(lpa) - if err != nil { l.logger.Print(err) return shared.ProblemInternalServerError.Respond() @@ -114,7 +130,7 @@ func applyUpdate(lpa *shared.Lpa, update shared.Update) ([]shared.FieldError, er func main() { l := &Lambda{ - store: shared.NewDynamoDB(os.Getenv("DDB_TABLE_NAME_DEEDS")), + store: ddb.New(os.Getenv("AWS_DYNAMODB_ENDPOINT"), os.Getenv("DDB_TABLE_NAME_DEEDS")), verifier: shared.NewJWTVerifier(), logger: logging.New(os.Stdout, "opg-data-lpa-store"), } diff --git a/lambda/update/main_test.go b/lambda/update/main_test.go new file mode 100644 index 00000000..b273ed7a --- /dev/null +++ b/lambda/update/main_test.go @@ -0,0 +1,151 @@ +package main + +import ( + "context" + "errors" + "io" + "testing" + + "github.com/aws/aws-lambda-go/events" + "github.com/ministryofjustice/opg-data-lpa-store/internal/shared" + "github.com/ministryofjustice/opg-go-common/logging" + "github.com/stretchr/testify/assert" +) + +var expectedError = errors.New("expected") + +type mockStore struct { + get shared.Lpa + getErr error + put any + putErr error +} + +func (m *mockStore) Get(context.Context, string) (shared.Lpa, error) { return m.get, m.getErr } +func (m *mockStore) Put(ctx context.Context, data any) error { + m.put = data + return m.putErr +} + +type mockVerifier struct{ ok bool } + +func (m *mockVerifier) VerifyHeader(events.APIGatewayProxyRequest) bool { return m.ok } + +func TestHandleEvent(t *testing.T) { + store := &mockStore{get: shared.Lpa{Uid: "1", LpaInit: shared.LpaInit{Donor: shared.Donor{Person: shared.Person{FirstNames: "Johm"}}}}} + l := Lambda{ + store: store, + verifier: &mockVerifier{ok: true}, + logger: logging.New(io.Discard, ""), + } + + resp, err := l.HandleEvent(context.Background(), events.APIGatewayProxyRequest{ + Body: `{"type":"SCANNING_CORRECTION","changes":[{"key":"/donor/firstNames","old":"Johm","new":"John"}]}`, + }) + assert.Nil(t, err) + assert.Equal(t, 201, resp.StatusCode) + assert.Contains(t, resp.Body, `"John"`) + assert.Equal(t, shared.Lpa{ + Uid: "1", + LpaInit: shared.LpaInit{Donor: shared.Donor{Person: shared.Person{FirstNames: "John"}}}, + }, store.put) +} + +func TestHandleEventWhenOldDoesNotMatch(t *testing.T) { + l := Lambda{ + store: &mockStore{get: shared.Lpa{Uid: "1"}}, + verifier: &mockVerifier{ok: true}, + logger: logging.New(io.Discard, ""), + } + + resp, err := l.HandleEvent(context.Background(), events.APIGatewayProxyRequest{ + Body: `{"type":"SCANNING_CORRECTION","changes":[{"key":"/donor/firstNames","old":"Johm","new":"John"}]}`, + }) + assert.Nil(t, err) + assert.Equal(t, 400, resp.StatusCode) + assert.JSONEq(t, `{"code":"INVALID_REQUEST","detail":"Invalid request","errors":[{"source":"/changes/0/old","detail":"does not match existing value"}]}`, resp.Body) +} + +func TestHandleEventWhenIncorrectPointer(t *testing.T) { + l := Lambda{ + store: &mockStore{get: shared.Lpa{Uid: "1"}}, + verifier: &mockVerifier{ok: true}, + logger: logging.New(io.Discard, ""), + } + + resp, err := l.HandleEvent(context.Background(), events.APIGatewayProxyRequest{ + Body: `{"type":"SCANNING_CORRECTION","changes":[{"key":"/donor/whatEven","old":"Johm","new":"John"}]}`, + }) + assert.Nil(t, err) + assert.Equal(t, 500, resp.StatusCode) + assert.JSONEq(t, `{"code":"INTERNAL_SERVER_ERROR","detail":"Internal server error"}`, resp.Body) +} + +func TestHandleEventWhenUpdateInvalid(t *testing.T) { + l := Lambda{ + store: &mockStore{get: shared.Lpa{Uid: "1"}}, + verifier: &mockVerifier{ok: true}, + logger: logging.New(io.Discard, ""), + } + + resp, err := l.HandleEvent(context.Background(), events.APIGatewayProxyRequest{ + Body: `{"type":"CERTIFICATE_PROVIDER_SIGN","changes":[]}`, + }) + assert.Nil(t, err) + assert.Equal(t, 400, resp.StatusCode) + assert.JSONEq(t, `{"code":"INVALID_REQUEST","detail":"Invalid request","errors":[{"source":"/changes","detail":"missing /certificateProvider/signedAt"},{"source":"/changes","detail":"missing /certificateProvider/contactLanguagePreference"}]}`, resp.Body) +} + +func TestHandleEventWhenLpaNotFound(t *testing.T) { + l := Lambda{ + store: &mockStore{}, + verifier: &mockVerifier{ok: true}, + logger: logging.New(io.Discard, ""), + } + + resp, err := l.HandleEvent(context.Background(), events.APIGatewayProxyRequest{ + Body: `{}`, + }) + assert.Nil(t, err) + assert.Equal(t, 404, resp.StatusCode) + assert.JSONEq(t, `{"code":"NOT_FOUND","detail":"Record not found"}`, resp.Body) +} + +func TestHandleEventWhenStoreGetError(t *testing.T) { + l := Lambda{ + store: &mockStore{getErr: expectedError}, + verifier: &mockVerifier{ok: true}, + logger: logging.New(io.Discard, ""), + } + + resp, err := l.HandleEvent(context.Background(), events.APIGatewayProxyRequest{ + Body: `{}`, + }) + assert.Nil(t, err) + assert.Equal(t, 500, resp.StatusCode) + assert.JSONEq(t, `{"code":"INTERNAL_SERVER_ERROR","detail":"Internal server error"}`, resp.Body) +} + +func TestHandleEventWhenRequestBodyNotJSON(t *testing.T) { + l := Lambda{ + verifier: &mockVerifier{ok: true}, + logger: logging.New(io.Discard, ""), + } + + resp, err := l.HandleEvent(context.Background(), events.APIGatewayProxyRequest{}) + assert.Nil(t, err) + assert.Equal(t, 500, resp.StatusCode) + assert.JSONEq(t, `{"code":"INTERNAL_SERVER_ERROR","detail":"Internal server error"}`, resp.Body) +} + +func TestHandleEventWhenHeaderNotVerified(t *testing.T) { + l := Lambda{ + verifier: &mockVerifier{ok: false}, + logger: logging.New(io.Discard, ""), + } + + resp, err := l.HandleEvent(context.Background(), events.APIGatewayProxyRequest{}) + assert.Nil(t, err) + assert.Equal(t, 401, resp.StatusCode) + assert.JSONEq(t, `{"code":"UNAUTHORISED","detail":"Invalid JWT"}`, resp.Body) +} diff --git a/lambda/update/validate.go b/lambda/update/validate.go new file mode 100644 index 00000000..a0a96890 --- /dev/null +++ b/lambda/update/validate.go @@ -0,0 +1,101 @@ +package main + +import ( + "fmt" + "time" + + "github.com/ministryofjustice/opg-data-lpa-store/internal/shared" + "github.com/ministryofjustice/opg-data-lpa-store/internal/validate" +) + +var ( + detailMustBeString = "must be a string" +) + +type CertificateProviderSign struct { + Address shared.Address + SignedAt time.Time + ContactLanguagePreference shared.Lang +} + +func validateUpdate(update shared.Update) []shared.FieldError { + switch update.Type { + case "CERTIFICATE_PROVIDER_SIGN": + var errors []shared.FieldError + var ok bool + x := CertificateProviderSign{} + + for i, change := range update.Changes { + if change.Old != nil { + errors = append(errors, shared.FieldError{Source: fmt.Sprintf("/changes/%d/old", i), Detail: "field must not be provided"}) + } + + newKey := fmt.Sprintf("/changes/%d/new", i) + switch change.Key { + case "/certificateProvider/address/line1": + if x.Address.Line1, ok = change.New.(string); !ok { + errors = append(errors, shared.FieldError{Source: newKey, Detail: detailMustBeString}) + } + case "/certificateProvider/address/line2": + if x.Address.Line2, ok = change.New.(string); !ok { + errors = append(errors, shared.FieldError{Source: newKey, Detail: detailMustBeString}) + } + case "/certificateProvider/address/line3": + if x.Address.Line3, ok = change.New.(string); !ok { + errors = append(errors, shared.FieldError{Source: newKey, Detail: detailMustBeString}) + } + case "/certificateProvider/address/town": + if x.Address.Town, ok = change.New.(string); !ok { + errors = append(errors, shared.FieldError{Source: newKey, Detail: detailMustBeString}) + } + case "/certificateProvider/address/postcode": + if x.Address.Postcode, ok = change.New.(string); !ok { + errors = append(errors, shared.FieldError{Source: newKey, Detail: detailMustBeString}) + } + case "/certificateProvider/address/country": + if x.Address.Country, ok = change.New.(string); !ok { + errors = append(errors, shared.FieldError{Source: newKey, Detail: detailMustBeString}) + } else { + errors = append(errors, validate.Country(newKey, x.Address.Country)...) + } + case "/certificateProvider/signedAt": + if x.SignedAt, ok = change.New.(time.Time); !ok { + errors = append(errors, shared.FieldError{Source: newKey, Detail: "must be a datetime"}) + } + case "/certificateProvider/contactLanguagePreference": + if x.ContactLanguagePreference, ok = change.New.(shared.Lang); !ok { + errors = append(errors, shared.FieldError{Source: newKey, Detail: detailMustBeString}) + } + default: + errors = append(errors, shared.FieldError{Source: fmt.Sprintf("/changes/%d", i), Detail: "change not allowed for type"}) + } + } + + if x.Address.IsSet() { + if x.Address.Line1 == "" { + errors = append(errors, shared.FieldError{Source: "/changes", Detail: "missing /certificateProvider/address/line1"}) + } + + if x.Address.Town == "" { + errors = append(errors, shared.FieldError{Source: "/changes", Detail: "missing /certificateProvider/address/town"}) + } + + if x.Address.Country == "" { + errors = append(errors, shared.FieldError{Source: "/changes", Detail: "missing /certificateProvider/address/country"}) + } + } + + if x.SignedAt.IsZero() { + errors = append(errors, shared.FieldError{Source: "/changes", Detail: "missing /certificateProvider/signedAt"}) + } + + if x.ContactLanguagePreference == shared.Lang("") { + errors = append(errors, shared.FieldError{Source: "/changes", Detail: "missing /certificateProvider/contactLanguagePreference"}) + } + + return errors + + default: + return []shared.FieldError{} + } +} diff --git a/lambda/update/validate_test.go b/lambda/update/validate_test.go new file mode 100644 index 00000000..900c879b --- /dev/null +++ b/lambda/update/validate_test.go @@ -0,0 +1,99 @@ +package main + +import ( + "testing" + "time" + + "github.com/ministryofjustice/opg-data-lpa-store/internal/shared" + "github.com/stretchr/testify/assert" +) + +func TestValidateUpdate(t *testing.T) { + testcases := map[string]struct { + update shared.Update + errors []shared.FieldError + }{ + "CERTIFICATE_PROVIDER_SIGN/valid": { + update: shared.Update{ + Type: "CERTIFICATE_PROVIDER_SIGN", + Changes: []shared.Change{ + { + Key: "/certificateProvider/address/line1", + New: "123 Main St", + }, + { + Key: "/certificateProvider/address/town", + New: "Homeland", + }, + { + Key: "/certificateProvider/address/country", + New: "GB", + }, + { + Key: "/certificateProvider/signedAt", + New: time.Now(), + }, + { + Key: "/certificateProvider/contactLanguagePreference", + New: shared.LangCy, + }, + }, + }, + }, + "CERTIFICATE_PROVIDER_SIGN/missing all": { + update: shared.Update{Type: "CERTIFICATE_PROVIDER_SIGN"}, + errors: []shared.FieldError{ + {Source: "/changes", Detail: "missing /certificateProvider/signedAt"}, + {Source: "/changes", Detail: "missing /certificateProvider/contactLanguagePreference"}, + }, + }, + "CERTIFICATE_PROVIDER_SIGN/bad address": { + update: shared.Update{ + Type: "CERTIFICATE_PROVIDER_SIGN", + Changes: []shared.Change{ + { + Key: "/certificateProvider/address/country", + New: "x", + }, + }, + }, + errors: []shared.FieldError{ + {Source: "/changes", Detail: "missing /certificateProvider/address/line1"}, + {Source: "/changes", Detail: "missing /certificateProvider/address/town"}, + {Source: "/changes/0/new", Detail: "must be a valid ISO-3166-1 country code"}, + {Source: "/changes", Detail: "missing /certificateProvider/signedAt"}, + {Source: "/changes", Detail: "missing /certificateProvider/contactLanguagePreference"}, + }, + }, + "CERTIFICATE_PROVIDER_SIGN/extra fields": { + update: shared.Update{ + Type: "CERTIFICATE_PROVIDER_SIGN", + Changes: []shared.Change{ + { + Key: "/certificateProvider/signedAt", + New: time.Now(), + }, + { + Key: "/certificateProvider/contactLanguagePreference", + Old: shared.LangEn, + New: shared.LangCy, + }, + { + Key: "/donor/firstNames", + New: "John", + }, + }, + }, + errors: []shared.FieldError{ + {Source: "/changes/1/old", Detail: "field must not be provided"}, + {Source: "/changes/2", Detail: "change not allowed for type"}, + }, + }, + } + + for name, tc := range testcases { + t.Run(name, func(t *testing.T) { + assert.ElementsMatch(t, tc.errors, validateUpdate(tc.update)) + }) + } +} diff --git a/terraform/environment/region/apigateway.tf b/terraform/environment/region/apigateway.tf index 2b59bf05..50d4737d 100644 --- a/terraform/environment/region/apigateway.tf +++ b/terraform/environment/region/apigateway.tf @@ -4,6 +4,7 @@ locals { lambda_create_invoke_arn = module.lambda["create"].invoke_arn lambda_get_invoke_arn = module.lambda["get"].invoke_arn lambda_update_invoke_arn = module.lambda["update"].invoke_arn + lambda_create_certificate_provider_invoke_arn = module.lambda["create-certificate-provider"].invoke_arn }) } diff --git a/terraform/environment/region/main.tf b/terraform/environment/region/main.tf index 458b2bc9..c5d58be1 100644 --- a/terraform/environment/region/main.tf +++ b/terraform/environment/region/main.tf @@ -3,6 +3,7 @@ locals { "create", "get", "update", + "create-certificate-provider", ]) }