From 7ace6e8553231ca283304280dd0ce7046989642d Mon Sep 17 00:00:00 2001 From: Joshua Hawxwell Date: Wed, 3 Jan 2024 08:28:12 +0000 Subject: [PATCH] Add lambda to create certificate provider --- Makefile | 2 +- docs/openapi/openapi.yaml | 44 ++++- internal/ddb/client.go | 69 +++++++ internal/shared/client.go | 10 - internal/shared/ddb.go | 70 ------- internal/shared/jwt.go | 40 ++-- internal/shared/jwt_test.go | 16 +- internal/shared/lang.go | 12 ++ internal/validate/validate.go | 89 +++++++++ internal/validate/validate_test.go | 112 +++++++++++ lambda/Dockerfile | 2 +- lambda/create-certificate-provider/main.go | 100 ++++++++++ .../create-certificate-provider/main_test.go | 148 ++++++++++++++ .../create-certificate-provider/validate.go | 23 +++ .../validate_test.go | 46 +++++ lambda/create/main.go | 22 +-- lambda/create/validate.go | 183 +++++------------- lambda/create/validate_test.go | 88 --------- lambda/get/main.go | 13 +- lambda/update/main.go | 10 +- terraform/environment/region/apigateway.tf | 1 + terraform/environment/region/main.tf | 1 + 22 files changed, 751 insertions(+), 350 deletions(-) create mode 100644 internal/ddb/client.go delete mode 100644 internal/shared/client.go create mode 100644 internal/shared/lang.go create mode 100644 internal/validate/validate.go create mode 100644 internal/validate/validate_test.go create mode 100644 lambda/create-certificate-provider/main.go create mode 100644 lambda/create-certificate-provider/main_test.go create mode 100644 lambda/create-certificate-provider/validate.go create mode 100644 lambda/create-certificate-provider/validate_test.go diff --git a/Makefile b/Makefile index 1d76d21c..e91005df 100644 --- a/Makefile +++ b/Makefile @@ -17,7 +17,7 @@ down: ## Stop application docker compose down test: ## Unit tests - go test ./lambda/get/... ./lambda/create/... ./lambda/update/... ./internal/shared/... -race -covermode=atomic -coverprofile=coverage.out + go test ./... -race -covermode=atomic -coverprofile=coverage.out test-api: URL ?= http://localhost:9000 test-api: diff --git a/docs/openapi/openapi.yaml b/docs/openapi/openapi.yaml index 3665a7d1..263ceca2 100644 --- a/docs/openapi/openapi.yaml +++ b/docs/openapi/openapi.yaml @@ -119,7 +119,33 @@ paths: httpMethod: "POST" type: "aws_proxy" contentHandling: "CONVERT_TO_TEXT" - + /lpas/{uid}/certifcate-provider: + parameters: + - name: uid + in: path + required: true + description: The UID of the case + schema: + type: string + pattern: "M-([A-Z0-9]{4})-([A-Z0-9]{4})-([A-Z0-9]{4})" + example: M-789Q-P4DF-4UX3 + put: + operationId: putCertificateProvider + summary: Store a certificate provider + requestBody: + content: + application/json: + schema: + $ref: "#/components/schemas/CertificateProviderRequest" + responses: + "201": + description: Certificate provider created + "400": + description: Invalid request + content: + application/json: + schema: + $ref: "#/components/schemas/BadRequestError" /health-check: get: operationId: healthCheck @@ -202,6 +228,22 @@ components: properties: code: enum: ["NOT_FOUND"] + CertificateProviderRequest: + type: object + required: + - contactLanguagePreference + - signedAt + properties: + address: + $ref: "#/components/schemas/Address" + contactLanguagePreference: + enum: + - cy + - en + signedAt: + type: string + format: date-time + additionalProperties: false Lpa: allOf: - $ref: "#/components/schemas/InitialLpa" diff --git a/internal/ddb/client.go b/internal/ddb/client.go new file mode 100644 index 00000000..0f539291 --- /dev/null +++ b/internal/ddb/client.go @@ -0,0 +1,69 @@ +package ddb + +import ( + "context" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" + "github.com/aws/aws-xray-sdk-go/xray" + "github.com/ministryofjustice/opg-data-lpa-store/internal/shared" +) + +type Client struct { + ddb *dynamodb.DynamoDB + tableName string +} + +func (c *Client) Put(ctx context.Context, data any) error { + item, err := dynamodbattribute.MarshalMap(data) + if err != nil { + return err + } + + _, err = c.ddb.PutItemWithContext(ctx, &dynamodb.PutItemInput{ + TableName: aws.String(c.tableName), + Item: item, + }) + + return err +} + +func (c *Client) Get(ctx context.Context, uid string) (shared.Lpa, error) { + lpa := shared.Lpa{} + + marshalledUid, err := dynamodbattribute.Marshal(uid) + if err != nil { + return lpa, err + } + + getItemOutput, err := c.ddb.GetItemWithContext(ctx, &dynamodb.GetItemInput{ + TableName: aws.String(c.tableName), + Key: map[string]*dynamodb.AttributeValue{ + "uid": marshalledUid, + }, + }) + + if err != nil { + return lpa, err + } + + err = dynamodbattribute.UnmarshalMap(getItemOutput.Item, &lpa) + + return lpa, err +} + +func New(endpoint, tableName string) *Client { + sess := session.Must(session.NewSession()) + sess.Config.Endpoint = &endpoint + + c := &Client{ + ddb: dynamodb.New(sess), + tableName: tableName, + } + + xray.AWS(c.ddb.Client) + + return c +} diff --git a/internal/shared/client.go b/internal/shared/client.go deleted file mode 100644 index abff0de6..00000000 --- a/internal/shared/client.go +++ /dev/null @@ -1,10 +0,0 @@ -package shared - -import ( - "context" -) - -type Client interface { - Put(ctx context.Context, data Lpa) error - Get(ctx context.Context, uid string) (Lpa, error) -} diff --git a/internal/shared/ddb.go b/internal/shared/ddb.go index 908beb39..a29b5e40 100644 --- a/internal/shared/ddb.go +++ b/internal/shared/ddb.go @@ -1,71 +1 @@ package shared - -import ( - "context" - "os" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/dynamodb" - "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" - "github.com/aws/aws-xray-sdk-go/xray" -) - -type DynamoDBClient struct { - ddb *dynamodb.DynamoDB - tableName string -} - -func (c DynamoDBClient) Put(ctx context.Context, data Lpa) error { - item, err := dynamodbattribute.MarshalMap(data) - if err != nil { - return err - } - - _, err = c.ddb.PutItemWithContext(ctx, &dynamodb.PutItemInput{ - TableName: aws.String(c.tableName), - Item: item, - }) - - return err -} - -func (c DynamoDBClient) Get(ctx context.Context, uid string) (Lpa, error) { - lpa := Lpa{} - - marshalledUid, err := dynamodbattribute.Marshal(uid) - if err != nil { - return lpa, err - } - - getItemOutput, err := c.ddb.GetItemWithContext(ctx, &dynamodb.GetItemInput{ - TableName: aws.String(c.tableName), - Key: map[string]*dynamodb.AttributeValue{ - "uid": marshalledUid, - }, - }) - - if err != nil { - return lpa, err - } - - err = dynamodbattribute.UnmarshalMap(getItemOutput.Item, &lpa) - - return lpa, err -} - -func NewDynamoDB(tableName string) DynamoDBClient { - sess := session.Must(session.NewSession()) - - endpoint := os.Getenv("AWS_DYNAMODB_ENDPOINT") - sess.Config.Endpoint = &endpoint - - c := DynamoDBClient{ - ddb: dynamodb.New(sess), - tableName: tableName, - } - - xray.AWS(c.ddb.Client) - - return c -} diff --git a/internal/shared/jwt.go b/internal/shared/jwt.go index f201512c..38d28d85 100644 --- a/internal/shared/jwt.go +++ b/internal/shared/jwt.go @@ -88,25 +88,6 @@ func NewJWTVerifier() JWTVerifier { } } -// tokenStr is the JWT token, minus any "Bearer: " prefix -func (v JWTVerifier) VerifyToken(tokenStr string) error { - lsc := lpaStoreClaims{} - - parsedToken, err := jwt.ParseWithClaims(tokenStr, &lsc, func(token *jwt.Token) (interface{}, error) { - return v.secretKey, nil - }) - - if err != nil { - return err - } - - if !parsedToken.Valid { - return fmt.Errorf("Invalid JWT") - } - - return nil -} - var bearerRegexp = regexp.MustCompile("^Bearer[ ]+") // verify JWT from event header @@ -119,9 +100,28 @@ func (v JWTVerifier) VerifyHeader(event events.APIGatewayProxyRequest) bool { } tokenStr := bearerRegexp.ReplaceAllString(jwtHeaders[0], "") - if v.VerifyToken(tokenStr) != nil { + if v.verifyToken(tokenStr) != nil { return false } return true } + +// tokenStr is the JWT token, minus any "Bearer: " prefix +func (v JWTVerifier) verifyToken(tokenStr string) error { + lsc := lpaStoreClaims{} + + parsedToken, err := jwt.ParseWithClaims(tokenStr, &lsc, func(token *jwt.Token) (interface{}, error) { + return v.secretKey, nil + }) + + if err != nil { + return err + } + + if !parsedToken.Valid { + return fmt.Errorf("Invalid JWT") + } + + return nil +} diff --git a/internal/shared/jwt_test.go b/internal/shared/jwt_test.go index b92e3122..8077da3d 100644 --- a/internal/shared/jwt_test.go +++ b/internal/shared/jwt_test.go @@ -26,7 +26,7 @@ func createToken(claims jwt.MapClaims) string { } func TestVerifyEmptyJwt(t *testing.T) { - err := verifier.VerifyToken("") + err := verifier.verifyToken("") assert.NotNil(t, err) } @@ -38,7 +38,7 @@ func TestVerifyExpInPast(t *testing.T) { "sub": "M-3467-89QW-ERTY", }) - err := verifier.VerifyToken(token) + err := verifier.verifyToken(token) assert.NotNil(t, err) if err != nil { @@ -54,7 +54,7 @@ func TestVerifyIatInFuture(t *testing.T) { "sub": "someone@someplace.somewhere.com", }) - err := verifier.VerifyToken(token) + err := verifier.verifyToken(token) assert.NotNil(t, err) if err != nil { @@ -70,7 +70,7 @@ func TestVerifyIssuer(t *testing.T) { "sub": "someone@someplace.somewhere.com", }) - err := verifier.VerifyToken(token) + err := verifier.verifyToken(token) assert.NotNil(t, err) if err != nil { @@ -86,7 +86,7 @@ func TestVerifyBadEmailForSiriusIssuer(t *testing.T) { "sub": "", }) - err := verifier.VerifyToken(token) + err := verifier.verifyToken(token) assert.NotNil(t, err) if err != nil { @@ -102,7 +102,7 @@ func TestVerifyBadUIDForMRLPAIssuer(t *testing.T) { "sub": "", }) - err := verifier.VerifyToken(token) + err := verifier.verifyToken(token) assert.NotNil(t, err) if err != nil { @@ -118,7 +118,7 @@ func TestVerifyGoodJwt(t *testing.T) { "sub": "someone@someplace.somewhere.com", }) - err := verifier.VerifyToken(token) + err := verifier.verifyToken(token) assert.Nil(t, err) } @@ -134,7 +134,7 @@ func TestNewJWTVerifier(t *testing.T) { newVerifier := NewJWTVerifier() os.Unsetenv("JWT_SECRET_KEY") - err := newVerifier.VerifyToken(token) + err := newVerifier.verifyToken(token) assert.Nil(t, err) } diff --git a/internal/shared/lang.go b/internal/shared/lang.go new file mode 100644 index 00000000..2fed3758 --- /dev/null +++ b/internal/shared/lang.go @@ -0,0 +1,12 @@ +package shared + +type Lang string + +var ( + LangCy = Lang("cy") + LangEn = Lang("en") +) + +func (l Lang) IsValid() bool { + return l == LangCy || l == LangEn +} diff --git a/internal/validate/validate.go b/internal/validate/validate.go new file mode 100644 index 00000000..e91bd0a2 --- /dev/null +++ b/internal/validate/validate.go @@ -0,0 +1,89 @@ +package validate + +import ( + "fmt" + "regexp" + "time" + + "github.com/ministryofjustice/opg-data-lpa-store/internal/shared" +) + +var countryCodeRe = regexp.MustCompile("^[A-Z]{2}$") + +func All(fieldErrors ...[]shared.FieldError) []shared.FieldError { + var errors []shared.FieldError + + for _, e := range fieldErrors { + if e != nil { + errors = append(errors, e...) + } + } + + return errors +} + +func IfElse(ok bool, eIf []shared.FieldError, eElse []shared.FieldError) []shared.FieldError { + if ok { + return eIf + } + + return eElse +} + +func If(ok bool, e []shared.FieldError) []shared.FieldError { + return IfElse(ok, e, nil) +} + +func Required(source string, value string) []shared.FieldError { + return If(value == "", []shared.FieldError{{Source: source, Detail: "field is required"}}) +} + +func Empty(source string, value string) []shared.FieldError { + return If(value != "", []shared.FieldError{{Source: source, Detail: "field must not be provided"}}) +} + +func Date(source string, date shared.Date) []shared.FieldError { + if date.IsMalformed { + return []shared.FieldError{{Source: source, Detail: "invalid format"}} + } + + if date.IsZero() { + return []shared.FieldError{{Source: source, Detail: "field is required"}} + } + + return nil +} + +func Time(source string, t time.Time) []shared.FieldError { + return If(t.IsZero(), []shared.FieldError{{Source: source, Detail: "field is required"}}) +} + +func Address(prefix string, address shared.Address) []shared.FieldError { + return All( + Required(fmt.Sprintf("%s/line1", prefix), address.Line1), + Required(fmt.Sprintf("%s/town", prefix), address.Town), + Required(fmt.Sprintf("%s/country", prefix), address.Country), + If(!countryCodeRe.MatchString(address.Country), []shared.FieldError{{Source: fmt.Sprintf("%s/country", prefix), Detail: "must be a valid ISO-3166-1 country code"}}), + ) +} + +type isValid interface { + ~string + IsValid() bool +} + +func IsValid[V isValid](source string, v V) []shared.FieldError { + if e := Required(source, string(v)); e != nil { + return e + } + + if !v.IsValid() { + return []shared.FieldError{{Source: source, Detail: "invalid value"}} + } + + return nil +} + +func Unset(source string, v interface{ Unset() bool }) []shared.FieldError { + return If(!v.Unset(), []shared.FieldError{{Source: source, Detail: "field must not be provided"}}) +} diff --git a/internal/validate/validate_test.go b/internal/validate/validate_test.go new file mode 100644 index 00000000..a74e0289 --- /dev/null +++ b/internal/validate/validate_test.go @@ -0,0 +1,112 @@ +package validate + +import ( + "testing" + "time" + + "github.com/ministryofjustice/opg-data-lpa-store/internal/shared" + "github.com/stretchr/testify/assert" +) + +var validAddress = shared.Address{ + Line1: "123 Main St", + Town: "Homeland", + Country: "GB", +} + +func newDate(date string, isMalformed bool) shared.Date { + t, _ := time.Parse("2006-01-02", date) + + return shared.Date{ + Time: t, + IsMalformed: isMalformed, + } +} + +func TestAll(t *testing.T) { + errA := shared.FieldError{Source: "a", Detail: "a"} + errB := shared.FieldError{Source: "b", Detail: "b"} + errC := shared.FieldError{Source: "c", Detail: "c"} + + assert.Nil(t, All()) + assert.Nil(t, All([]shared.FieldError{}, []shared.FieldError{})) + assert.Equal(t, []shared.FieldError{errA, errB, errC}, All([]shared.FieldError{errA, errB}, []shared.FieldError{errC})) + assert.Equal(t, []shared.FieldError{errA, errB, errC}, All([]shared.FieldError{errA}, []shared.FieldError{errB, errC})) + assert.Equal(t, []shared.FieldError{errA, errB, errC}, All([]shared.FieldError{errA}, []shared.FieldError{errB}, []shared.FieldError{errC})) +} + +func TestIf(t *testing.T) { + errs := []shared.FieldError{{Source: "a", Detail: "a"}} + + assert.Equal(t, errs, If(true, errs)) + assert.Nil(t, If(false, errs)) +} + +func TestIfElse(t *testing.T) { + errsA := []shared.FieldError{{Source: "a", Detail: "a"}} + errsB := []shared.FieldError{{Source: "b", Detail: "b"}} + + assert.Equal(t, errsA, IfElse(true, errsA, errsB)) + assert.Equal(t, errsB, IfElse(false, errsA, errsB)) +} + +func TestRequired(t *testing.T) { + assert.Nil(t, Required("a", "a")) + assert.Equal(t, []shared.FieldError{{Source: "a", Detail: "field is required"}}, Required("a", "")) +} + +func TestEmpty(t *testing.T) { + assert.Nil(t, Empty("a", "")) + assert.Equal(t, []shared.FieldError{{Source: "a", Detail: "field must not be provided"}}, Empty("a", "a")) +} + +func TestDate(t *testing.T) { + assert.Nil(t, Date("a", shared.Date{Time: time.Now()})) + assert.Equal(t, []shared.FieldError{{Source: "a", Detail: "invalid format"}}, Date("a", shared.Date{IsMalformed: true})) + assert.Equal(t, []shared.FieldError{{Source: "a", Detail: "field is required"}}, Date("a", shared.Date{})) +} + +func TestAddressEmpty(t *testing.T) { + address := shared.Address{} + errors := Address("/test", address) + + assert.Contains(t, errors, shared.FieldError{Source: "/test/line1", Detail: "field is required"}) + assert.Contains(t, errors, shared.FieldError{Source: "/test/town", Detail: "field is required"}) + assert.Contains(t, errors, shared.FieldError{Source: "/test/country", Detail: "field is required"}) +} + +func TestAddressValid(t *testing.T) { + errors := Address("/test", validAddress) + + assert.Empty(t, errors) +} + +func TestAddressInvalidCountry(t *testing.T) { + invalidAddress := shared.Address{ + Line1: "123 Main St", + Town: "Homeland", + Country: "United Kingdom", + } + errors := Address("/test", invalidAddress) + + assert.Contains(t, errors, shared.FieldError{Source: "/test/country", Detail: "must be a valid ISO-3166-1 country code"}) +} + +type testIsValid string + +func (t testIsValid) IsValid() bool { return string(t) == "ok" } + +func TestIsValid(t *testing.T) { + assert.Nil(t, IsValid("a", testIsValid("ok"))) + assert.Equal(t, []shared.FieldError{{Source: "a", Detail: "field is required"}}, IsValid("a", testIsValid(""))) + assert.Equal(t, []shared.FieldError{{Source: "a", Detail: "invalid value"}}, IsValid("a", testIsValid("x"))) +} + +type testUnset bool + +func (t testUnset) Unset() bool { return bool(t) } + +func TestUnset(t *testing.T) { + assert.Nil(t, Unset("a", testUnset(true))) + assert.Equal(t, []shared.FieldError{{Source: "a", Detail: "field must not be provided"}}, Unset("a", testUnset(false))) +} diff --git a/lambda/Dockerfile b/lambda/Dockerfile index e84e3bc3..acbee20d 100644 --- a/lambda/Dockerfile +++ b/lambda/Dockerfile @@ -6,7 +6,7 @@ COPY ./go.sum /app/go.sum RUN go mod download -COPY ./internal/shared /app/internal/shared +COPY ./internal /app/internal ARG DIR COPY ./lambda/$DIR /app/lambda/$DIR diff --git a/lambda/create-certificate-provider/main.go b/lambda/create-certificate-provider/main.go new file mode 100644 index 00000000..2862b000 --- /dev/null +++ b/lambda/create-certificate-provider/main.go @@ -0,0 +1,100 @@ +package main + +import ( + "context" + "encoding/json" + "os" + "time" + + "github.com/aws/aws-lambda-go/events" + "github.com/aws/aws-lambda-go/lambda" + "github.com/ministryofjustice/opg-data-lpa-store/internal/ddb" + "github.com/ministryofjustice/opg-data-lpa-store/internal/shared" + "github.com/ministryofjustice/opg-go-common/logging" +) + +type Logger interface { + Print(...interface{}) +} + +type Store interface { + Get(ctx context.Context, uid string) (shared.Lpa, error) + Put(ctx context.Context, data any) error +} + +type Verifier interface { + VerifyHeader(events.APIGatewayProxyRequest) bool +} + +type Lambda struct { + now func() time.Time + store Store + verifier Verifier + logger Logger +} + +func (l *Lambda) HandleEvent(ctx context.Context, event events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, error) { + if !l.verifier.VerifyHeader(event) { + l.logger.Print("Unable to verify JWT from header") + return shared.ProblemUnauthorisedRequest.Respond() + } + + l.logger.Print("Successfully parsed JWT from event header") + + uid := event.PathParameters["uid"] + + response := events.APIGatewayProxyResponse{ + StatusCode: 500, + Body: "{\"code\":\"INTERNAL_SERVER_ERROR\",\"detail\":\"Internal server error\"}", + } + + // check for existing Lpa + existingLpa, err := l.store.Get(ctx, uid) + if err != nil { + l.logger.Print(err) + return shared.ProblemInternalServerError.Respond() + } + if existingLpa.Uid == "" { + return shared.ProblemNotFoundRequest.Respond() + } + + var input CertificateProvider + if err := json.Unmarshal([]byte(event.Body), &input); err != nil { + l.logger.Print(err) + return shared.ProblemInternalServerError.Respond() + } + + // validation + errors := Validate(input) + if len(errors) > 0 { + problem := shared.ProblemInvalidRequest + problem.Errors = errors + + return problem.Respond() + } + + input.UpdatedAt = l.now() + + // save + if err := l.store.Put(ctx, input); err != nil { + l.logger.Print(err) + return shared.ProblemInternalServerError.Respond() + } + + // respond + response.StatusCode = 201 + response.Body = `{}` + + return response, nil +} + +func main() { + l := &Lambda{ + now: time.Now, + store: ddb.New(os.Getenv("AWS_DYNAMODB_ENDPOINT"), os.Getenv("DDB_TABLE_NAME_DEEDS")), + verifier: shared.NewJWTVerifier(), + logger: logging.New(os.Stdout, "opg-data-lpa-store"), + } + + lambda.Start(l.HandleEvent) +} diff --git a/lambda/create-certificate-provider/main_test.go b/lambda/create-certificate-provider/main_test.go new file mode 100644 index 00000000..ff460174 --- /dev/null +++ b/lambda/create-certificate-provider/main_test.go @@ -0,0 +1,148 @@ +package main + +import ( + "bytes" + "context" + "errors" + "io" + "testing" + "time" + + "github.com/aws/aws-lambda-go/events" + "github.com/ministryofjustice/opg-data-lpa-store/internal/shared" + "github.com/ministryofjustice/opg-go-common/logging" + "github.com/stretchr/testify/assert" +) + +var expectedError = errors.New("expected") + +type mockStore struct { + get shared.Lpa + getErr error + put any + putErr error +} + +func (m *mockStore) Get(context.Context, string) (shared.Lpa, error) { return m.get, m.getErr } +func (m *mockStore) Put(ctx context.Context, data any) error { + m.put = data + return m.putErr +} + +type mockVerifier struct{ ok bool } + +func (m *mockVerifier) VerifyHeader(events.APIGatewayProxyRequest) bool { return m.ok } + +func TestHandleEvent(t *testing.T) { + now := time.Now() + store := &mockStore{get: shared.Lpa{Uid: "1"}} + l := Lambda{ + now: func() time.Time { return now }, + store: store, + verifier: &mockVerifier{ok: true}, + logger: logging.New(io.Discard, ""), + } + + resp, err := l.HandleEvent(context.Background(), events.APIGatewayProxyRequest{ + Body: `{"address":{"line1":"x","town":"y","country":"ZZ"},"signedAt":"2022-01-02T12:13:14.000000006Z","contactLanguagePreference":"en"}`, + }) + assert.Nil(t, err) + assert.Equal(t, 201, resp.StatusCode) + assert.JSONEq(t, `{}`, resp.Body) + assert.Equal(t, CertificateProvider{ + UpdatedAt: now, + Address: shared.Address{Line1: "x", Town: "y", Country: "ZZ"}, + SignedAt: time.Date(2022, time.January, 2, 12, 13, 14, 6, time.UTC), + ContactLanguagePreference: shared.LangEn, + }, store.put) +} + +func TestHandleEventWhenPutErrors(t *testing.T) { + var buf bytes.Buffer + l := Lambda{ + now: time.Now, + store: &mockStore{get: shared.Lpa{Uid: "1"}, putErr: expectedError}, + verifier: &mockVerifier{ok: true}, + logger: logging.New(&buf, ""), + } + + resp, err := l.HandleEvent(context.Background(), events.APIGatewayProxyRequest{ + Body: `{"address":{"line1":"x","town":"y","country":"ZZ"},"signedAt":"2022-01-02T12:13:14.000006Z","contactLanguagePreference":"en"}`, + }) + assert.Nil(t, err) + assert.Equal(t, 500, resp.StatusCode) + assert.JSONEq(t, `{"code":"INTERNAL_SERVER_ERROR","detail":"Internal server error"}`, resp.Body) + assert.Contains(t, buf.String(), "expected") +} + +func TestHandleEventWhenInvalid(t *testing.T) { + l := Lambda{ + store: &mockStore{get: shared.Lpa{Uid: "1"}}, + verifier: &mockVerifier{ok: true}, + logger: logging.New(io.Discard, ""), + } + + resp, err := l.HandleEvent(context.Background(), events.APIGatewayProxyRequest{ + Body: `{"address":{"line1":"x","town":"y","country":"ZZ"}}`, + }) + assert.Nil(t, err) + assert.Equal(t, 400, resp.StatusCode) + assert.JSONEq(t, `{"code":"INVALID_REQUEST","detail":"Invalid request","errors":[{"source":"/signedAt","detail":"field is required"},{"source":"/contactLanguagePreference","detail":"field is required"}]}`, resp.Body) +} + +func TestHandleEventWhenRequestJsonBad(t *testing.T) { + var buf bytes.Buffer + l := Lambda{ + store: &mockStore{get: shared.Lpa{Uid: "1"}}, + verifier: &mockVerifier{ok: true}, + logger: logging.New(&buf, ""), + } + + resp, err := l.HandleEvent(context.Background(), events.APIGatewayProxyRequest{}) + assert.Nil(t, err) + assert.Equal(t, 500, resp.StatusCode) + assert.JSONEq(t, `{"code":"INTERNAL_SERVER_ERROR","detail":"Internal server error"}`, resp.Body) + assert.Contains(t, buf.String(), `"unexpected end of JSON input"`) +} + +func TestHandleEventWhenUidNotFound(t *testing.T) { + l := Lambda{ + store: &mockStore{}, + verifier: &mockVerifier{ok: true}, + logger: logging.New(io.Discard, ""), + } + + resp, err := l.HandleEvent(context.Background(), events.APIGatewayProxyRequest{}) + assert.Nil(t, err) + assert.Equal(t, 404, resp.StatusCode) + assert.JSONEq(t, `{"code":"NOT_FOUND","detail":"Record not found"}`, resp.Body) +} + +func TestHandleEventWhenStoreGetErrors(t *testing.T) { + var buf bytes.Buffer + l := Lambda{ + store: &mockStore{getErr: expectedError}, + verifier: &mockVerifier{ok: true}, + logger: logging.New(&buf, ""), + } + + resp, err := l.HandleEvent(context.Background(), events.APIGatewayProxyRequest{}) + assert.Nil(t, err) + assert.Equal(t, 500, resp.StatusCode) + assert.JSONEq(t, `{"code":"INTERNAL_SERVER_ERROR","detail":"Internal server error"}`, resp.Body) + assert.Contains(t, buf.String(), `"expected"`) +} + +func TestHandleEventWhenHeaderNotVerified(t *testing.T) { + var buf bytes.Buffer + l := Lambda{ + verifier: &mockVerifier{ok: false}, + logger: logging.New(&buf, ""), + } + + resp, err := l.HandleEvent(context.Background(), events.APIGatewayProxyRequest{}) + assert.Nil(t, err) + assert.Equal(t, 401, resp.StatusCode) + assert.JSONEq(t, `{"code":"UNAUTHORISED","detail":"Invalid JWT"}`, resp.Body) + assert.Contains(t, buf.String(), `"Unable to verify JWT from header"`) +} diff --git a/lambda/create-certificate-provider/validate.go b/lambda/create-certificate-provider/validate.go new file mode 100644 index 00000000..4766b2f8 --- /dev/null +++ b/lambda/create-certificate-provider/validate.go @@ -0,0 +1,23 @@ +package main + +import ( + "time" + + "github.com/ministryofjustice/opg-data-lpa-store/internal/shared" + "github.com/ministryofjustice/opg-data-lpa-store/internal/validate" +) + +type CertificateProvider struct { + UpdatedAt time.Time `json:"updatedAt"` + Address shared.Address `json:"address"` + SignedAt time.Time `json:"signedAt"` + ContactLanguagePreference shared.Lang `json:"contactLanguagePreference"` +} + +func Validate(certificateProvider CertificateProvider) []shared.FieldError { + return validate.All( + validate.Address("/address", certificateProvider.Address), + validate.Time("/signedAt", certificateProvider.SignedAt), + validate.IsValid("/contactLanguagePreference", certificateProvider.ContactLanguagePreference), + ) +} diff --git a/lambda/create-certificate-provider/validate_test.go b/lambda/create-certificate-provider/validate_test.go new file mode 100644 index 00000000..f3c868f0 --- /dev/null +++ b/lambda/create-certificate-provider/validate_test.go @@ -0,0 +1,46 @@ +package main + +import ( + "testing" + "time" + + "github.com/ministryofjustice/opg-data-lpa-store/internal/shared" + "github.com/stretchr/testify/assert" +) + +func TestValidate(t *testing.T) { + validAddress := shared.Address{ + Line1: "123 Main St", + Town: "Homeland", + Country: "GB", + } + + testcases := map[string]struct { + certificateProvider CertificateProvider + errors []shared.FieldError + }{ + "valid": { + certificateProvider: CertificateProvider{ + Address: validAddress, + SignedAt: time.Now(), + ContactLanguagePreference: shared.LangCy, + }, + }, + "missing all": { + errors: []shared.FieldError{ + {Source: "/address/line1", Detail: "field is required"}, + {Source: "/address/town", Detail: "field is required"}, + {Source: "/address/country", Detail: "field is required"}, + {Source: "/address/country", Detail: "must be a valid ISO-3166-1 country code"}, + {Source: "/signedAt", Detail: "field is required"}, + {Source: "/contactLanguagePreference", Detail: "field is required"}, + }, + }, + } + + for name, tc := range testcases { + t.Run(name, func(t *testing.T) { + assert.ElementsMatch(t, tc.errors, Validate(tc.certificateProvider)) + }) + } +} diff --git a/lambda/create/main.go b/lambda/create/main.go index 197a4ac6..ceb66965 100644 --- a/lambda/create/main.go +++ b/lambda/create/main.go @@ -8,19 +8,22 @@ import ( "github.com/aws/aws-lambda-go/events" "github.com/aws/aws-lambda-go/lambda" + "github.com/ministryofjustice/opg-data-lpa-store/internal/ddb" "github.com/ministryofjustice/opg-data-lpa-store/internal/shared" "github.com/ministryofjustice/opg-go-common/logging" ) -type Response struct { -} - type Logger interface { Print(...interface{}) } +type Store interface { + Put(ctx context.Context, data any) error + Get(ctx context.Context, uid string) (shared.Lpa, error) +} + type Lambda struct { - store shared.Client + store Store verifier shared.JWTVerifier logger Logger } @@ -84,22 +87,15 @@ func (l *Lambda) HandleEvent(ctx context.Context, event events.APIGatewayProxyRe } // respond - body, err := json.Marshal(Response{}) - - if err != nil { - l.logger.Print(err) - return shared.ProblemInternalServerError.Respond() - } - response.StatusCode = 201 - response.Body = string(body) + response.Body = `{}` return response, nil } func main() { l := &Lambda{ - store: shared.NewDynamoDB(os.Getenv("DDB_TABLE_NAME_DEEDS")), + store: ddb.New(os.Getenv("AWS_DYNAMODB_ENDPOINT"), os.Getenv("DDB_TABLE_NAME_DEEDS")), verifier: shared.NewJWTVerifier(), logger: logging.New(os.Stdout, "opg-data-lpa-store"), } diff --git a/lambda/create/validate.go b/lambda/create/validate.go index 29ab8215..74707a24 100644 --- a/lambda/create/validate.go +++ b/lambda/create/validate.go @@ -2,56 +2,53 @@ package main import ( "fmt" - "regexp" - "time" "github.com/ministryofjustice/opg-data-lpa-store/internal/shared" + "github.com/ministryofjustice/opg-data-lpa-store/internal/validate" ) -var countryCodeRe = regexp.MustCompile("^[A-Z]{2}$") - func Validate(lpa shared.LpaInit) []shared.FieldError { activeAttorneyCount, replacementAttorneyCount := countAttorneys(lpa.Attorneys, lpa.TrustCorporations) - return flatten( - validateIsValid("/lpaType", lpa.LpaType), - required("/donor/firstNames", lpa.Donor.FirstNames), - required("/donor/lastName", lpa.Donor.LastName), - validateDate("/donor/dateOfBirth", lpa.Donor.DateOfBirth), - validateAddress("/donor/address", lpa.Donor.Address), - required("/certificateProvider/firstNames", lpa.CertificateProvider.FirstNames), - required("/certificateProvider/lastName", lpa.CertificateProvider.LastName), - validateAddress("/certificateProvider/address", lpa.CertificateProvider.Address), - validateIsValid("/certificateProvider/channel", lpa.CertificateProvider.Channel), - validateIfElse(lpa.CertificateProvider.Channel == shared.ChannelOnline, - required("/certificateProvider/email", lpa.CertificateProvider.Email), - empty("/certificateProvider/email", lpa.CertificateProvider.Email)), + return validate.All( + validate.IsValid("/lpaType", lpa.LpaType), + validate.Required("/donor/firstNames", lpa.Donor.FirstNames), + validate.Required("/donor/lastName", lpa.Donor.LastName), + validate.Date("/donor/dateOfBirth", lpa.Donor.DateOfBirth), + validate.Address("/donor/address", lpa.Donor.Address), + validate.Required("/certificateProvider/firstNames", lpa.CertificateProvider.FirstNames), + validate.Required("/certificateProvider/lastName", lpa.CertificateProvider.LastName), + validate.Address("/certificateProvider/address", lpa.CertificateProvider.Address), + validate.IsValid("/certificateProvider/channel", lpa.CertificateProvider.Channel), + validate.IfElse(lpa.CertificateProvider.Channel == shared.ChannelOnline, + validate.Required("/certificateProvider/email", lpa.CertificateProvider.Email), + validate.Empty("/certificateProvider/email", lpa.CertificateProvider.Email)), validateAttorneys("/attorneys", lpa.Attorneys), validateTrustCorporations("/trustCorporations", lpa.TrustCorporations), - validateIfElse(activeAttorneyCount > 1, - validateIsValid("/howAttorneysMakeDecisions", lpa.HowAttorneysMakeDecisions), - validateUnset("/howAttorneysMakeDecisions", lpa.HowAttorneysMakeDecisions)), - validateIfElse(lpa.HowAttorneysMakeDecisions == shared.HowMakeDecisionsJointlyForSomeSeverallyForOthers, - required("/howAttorneysMakeDecisionsDetails", lpa.HowAttorneysMakeDecisionsDetails), - empty("/howAttorneysMakeDecisionsDetails", lpa.HowAttorneysMakeDecisionsDetails)), - validateIf(replacementAttorneyCount > 0 && lpa.HowAttorneysMakeDecisions == shared.HowMakeDecisionsJointlyAndSeverally, - validateIsValid("/howReplacementAttorneysStepIn", lpa.HowReplacementAttorneysStepIn)), - validateIfElse(lpa.HowReplacementAttorneysStepIn == shared.HowStepInAnotherWay, - required("/howReplacementAttorneysStepInDetails", lpa.HowReplacementAttorneysStepInDetails), - empty("/howReplacementAttorneysStepInDetails", lpa.HowReplacementAttorneysStepInDetails)), - validateIfElse(replacementAttorneyCount > 1 && (lpa.HowReplacementAttorneysStepIn == shared.HowStepInAllCanNoLongerAct || lpa.HowAttorneysMakeDecisions != shared.HowMakeDecisionsJointlyAndSeverally), - validateIsValid("/howReplacementAttorneysMakeDecisions", lpa.HowReplacementAttorneysMakeDecisions), - validateUnset("/howReplacementAttorneysMakeDecisions", lpa.HowReplacementAttorneysMakeDecisions)), - validateIfElse(lpa.HowReplacementAttorneysMakeDecisions == shared.HowMakeDecisionsJointlyForSomeSeverallyForOthers, - required("/howReplacementAttorneysMakeDecisionsDetails", lpa.HowReplacementAttorneysMakeDecisionsDetails), - empty("/howReplacementAttorneysMakeDecisionsDetails", lpa.HowReplacementAttorneysMakeDecisionsDetails)), - validateIf(lpa.LpaType == shared.LpaTypePersonalWelfare, flatten( - validateIsValid("/lifeSustainingTreatmentOption", lpa.LifeSustainingTreatmentOption), - validateUnset("/whenTheLpaCanBeUsed", lpa.WhenTheLpaCanBeUsed))), - validateIf(lpa.LpaType == shared.LpaTypePropertyAndAffairs, flatten( - validateIsValid("/whenTheLpaCanBeUsed", lpa.WhenTheLpaCanBeUsed), - validateUnset("/lifeSustainingTreatmentOption", lpa.LifeSustainingTreatmentOption))), - validateTime("/signedAt", lpa.SignedAt), + validate.IfElse(activeAttorneyCount > 1, + validate.IsValid("/howAttorneysMakeDecisions", lpa.HowAttorneysMakeDecisions), + validate.Unset("/howAttorneysMakeDecisions", lpa.HowAttorneysMakeDecisions)), + validate.IfElse(lpa.HowAttorneysMakeDecisions == shared.HowMakeDecisionsJointlyForSomeSeverallyForOthers, + validate.Required("/howAttorneysMakeDecisionsDetails", lpa.HowAttorneysMakeDecisionsDetails), + validate.Empty("/howAttorneysMakeDecisionsDetails", lpa.HowAttorneysMakeDecisionsDetails)), + validate.If(replacementAttorneyCount > 0 && lpa.HowAttorneysMakeDecisions == shared.HowMakeDecisionsJointlyAndSeverally, + validate.IsValid("/howReplacementAttorneysStepIn", lpa.HowReplacementAttorneysStepIn)), + validate.IfElse(lpa.HowReplacementAttorneysStepIn == shared.HowStepInAnotherWay, + validate.Required("/howReplacementAttorneysStepInDetails", lpa.HowReplacementAttorneysStepInDetails), + validate.Empty("/howReplacementAttorneysStepInDetails", lpa.HowReplacementAttorneysStepInDetails)), + validate.IfElse(replacementAttorneyCount > 1 && (lpa.HowReplacementAttorneysStepIn == shared.HowStepInAllCanNoLongerAct || lpa.HowAttorneysMakeDecisions != shared.HowMakeDecisionsJointlyAndSeverally), + validate.IsValid("/howReplacementAttorneysMakeDecisions", lpa.HowReplacementAttorneysMakeDecisions), + validate.Unset("/howReplacementAttorneysMakeDecisions", lpa.HowReplacementAttorneysMakeDecisions)), + validate.IfElse(lpa.HowReplacementAttorneysMakeDecisions == shared.HowMakeDecisionsJointlyForSomeSeverallyForOthers, + validate.Required("/howReplacementAttorneysMakeDecisionsDetails", lpa.HowReplacementAttorneysMakeDecisionsDetails), + validate.Empty("/howReplacementAttorneysMakeDecisionsDetails", lpa.HowReplacementAttorneysMakeDecisionsDetails)), + validate.If(lpa.LpaType == shared.LpaTypePersonalWelfare, validate.All( + validate.IsValid("/lifeSustainingTreatmentOption", lpa.LifeSustainingTreatmentOption), + validate.Unset("/whenTheLpaCanBeUsed", lpa.WhenTheLpaCanBeUsed))), + validate.If(lpa.LpaType == shared.LpaTypePropertyAndAffairs, validate.All( + validate.IsValid("/whenTheLpaCanBeUsed", lpa.WhenTheLpaCanBeUsed), + validate.Unset("/lifeSustainingTreatmentOption", lpa.LifeSustainingTreatmentOption))), + validate.Time("/signedAt", lpa.SignedAt), ) } @@ -77,84 +74,6 @@ func countAttorneys(as []shared.Attorney, ts []shared.TrustCorporation) (actives return actives, replacements } -func flatten(fieldErrors ...[]shared.FieldError) []shared.FieldError { - var errors []shared.FieldError - - for _, e := range fieldErrors { - if e != nil { - errors = append(errors, e...) - } - } - - return errors -} - -func validateIfElse(ok bool, eIf []shared.FieldError, eElse []shared.FieldError) []shared.FieldError { - if ok { - return eIf - } - - return eElse -} - -func validateIf(ok bool, e []shared.FieldError) []shared.FieldError { - return validateIfElse(ok, e, nil) -} - -func required(source string, value string) []shared.FieldError { - return validateIf(value == "", []shared.FieldError{{Source: source, Detail: "field is required"}}) -} - -func empty(source string, value string) []shared.FieldError { - return validateIf(value != "", []shared.FieldError{{Source: source, Detail: "field must not be provided"}}) -} - -func validateDate(source string, date shared.Date) []shared.FieldError { - if date.IsMalformed { - return []shared.FieldError{{Source: source, Detail: "invalid format"}} - } - - if date.IsZero() { - return []shared.FieldError{{Source: source, Detail: "field is required"}} - } - - return nil -} - -func validateTime(source string, t time.Time) []shared.FieldError { - return validateIf(t.IsZero(), []shared.FieldError{{Source: source, Detail: "field is required"}}) -} - -func validateAddress(prefix string, address shared.Address) []shared.FieldError { - return flatten( - required(fmt.Sprintf("%s/line1", prefix), address.Line1), - required(fmt.Sprintf("%s/town", prefix), address.Town), - required(fmt.Sprintf("%s/country", prefix), address.Country), - validateIf(!countryCodeRe.MatchString(address.Country), []shared.FieldError{{Source: fmt.Sprintf("%s/country", prefix), Detail: "must be a valid ISO-3166-1 country code"}}), - ) -} - -type isValid interface { - ~string - IsValid() bool -} - -func validateIsValid[V isValid](source string, v V) []shared.FieldError { - if e := required(source, string(v)); e != nil { - return e - } - - if !v.IsValid() { - return []shared.FieldError{{Source: source, Detail: "invalid value"}} - } - - return nil -} - -func validateUnset(source string, v interface{ Unset() bool }) []shared.FieldError { - return validateIf(!v.Unset(), []shared.FieldError{{Source: source, Detail: "field must not be provided"}}) -} - func validateAttorneys(prefix string, attorneys []shared.Attorney) []shared.FieldError { var errors []shared.FieldError @@ -172,13 +91,13 @@ func validateAttorneys(prefix string, attorneys []shared.Attorney) []shared.Fiel } func validateAttorney(prefix string, attorney shared.Attorney) []shared.FieldError { - return flatten( - required(fmt.Sprintf("%s/firstNames", prefix), attorney.FirstNames), - required(fmt.Sprintf("%s/lastName", prefix), attorney.LastName), - required(fmt.Sprintf("%s/status", prefix), string(attorney.Status)), - validateDate(fmt.Sprintf("%s/dateOfBirth", prefix), attorney.DateOfBirth), - validateAddress(fmt.Sprintf("%s/address", prefix), attorney.Address), - validateIsValid(fmt.Sprintf("%s/status", prefix), attorney.Status), + return validate.All( + validate.Required(fmt.Sprintf("%s/firstNames", prefix), attorney.FirstNames), + validate.Required(fmt.Sprintf("%s/lastName", prefix), attorney.LastName), + validate.Required(fmt.Sprintf("%s/status", prefix), string(attorney.Status)), + validate.Date(fmt.Sprintf("%s/dateOfBirth", prefix), attorney.DateOfBirth), + validate.Address(fmt.Sprintf("%s/address", prefix), attorney.Address), + validate.IsValid(fmt.Sprintf("%s/status", prefix), attorney.Status), ) } @@ -195,11 +114,11 @@ func validateTrustCorporations(prefix string, trustCorporations []shared.TrustCo } func validateTrustCorporation(prefix string, trustCorporation shared.TrustCorporation) []shared.FieldError { - return flatten( - required(fmt.Sprintf("%s/name", prefix), trustCorporation.Name), - required(fmt.Sprintf("%s/companyNumber", prefix), trustCorporation.CompanyNumber), - required(fmt.Sprintf("%s/email", prefix), trustCorporation.Email), - validateAddress(fmt.Sprintf("%s/address", prefix), trustCorporation.Address), - validateIsValid(fmt.Sprintf("%s/status", prefix), trustCorporation.Status), + return validate.All( + validate.Required(fmt.Sprintf("%s/name", prefix), trustCorporation.Name), + validate.Required(fmt.Sprintf("%s/companyNumber", prefix), trustCorporation.CompanyNumber), + validate.Required(fmt.Sprintf("%s/email", prefix), trustCorporation.Email), + validate.Address(fmt.Sprintf("%s/address", prefix), trustCorporation.Address), + validate.IsValid(fmt.Sprintf("%s/status", prefix), trustCorporation.Status), ) } diff --git a/lambda/create/validate_test.go b/lambda/create/validate_test.go index 2482fdc6..69999243 100644 --- a/lambda/create/validate_test.go +++ b/lambda/create/validate_test.go @@ -40,94 +40,6 @@ func TestCountAttorneys(t *testing.T) { assert.Equal(t, 3, replacements) } -func TestFlatten(t *testing.T) { - errA := shared.FieldError{Source: "a", Detail: "a"} - errB := shared.FieldError{Source: "b", Detail: "b"} - errC := shared.FieldError{Source: "c", Detail: "c"} - - assert.Nil(t, flatten()) - assert.Nil(t, flatten([]shared.FieldError{}, []shared.FieldError{})) - assert.Equal(t, []shared.FieldError{errA, errB, errC}, flatten([]shared.FieldError{errA, errB}, []shared.FieldError{errC})) - assert.Equal(t, []shared.FieldError{errA, errB, errC}, flatten([]shared.FieldError{errA}, []shared.FieldError{errB, errC})) - assert.Equal(t, []shared.FieldError{errA, errB, errC}, flatten([]shared.FieldError{errA}, []shared.FieldError{errB}, []shared.FieldError{errC})) -} - -func TestValidateIf(t *testing.T) { - errs := []shared.FieldError{{Source: "a", Detail: "a"}} - - assert.Equal(t, errs, validateIf(true, errs)) - assert.Nil(t, validateIf(false, errs)) -} - -func TestValidateIfElse(t *testing.T) { - errsA := []shared.FieldError{{Source: "a", Detail: "a"}} - errsB := []shared.FieldError{{Source: "b", Detail: "b"}} - - assert.Equal(t, errsA, validateIfElse(true, errsA, errsB)) - assert.Equal(t, errsB, validateIfElse(false, errsA, errsB)) -} - -func TestRequired(t *testing.T) { - assert.Nil(t, required("a", "a")) - assert.Equal(t, []shared.FieldError{{Source: "a", Detail: "field is required"}}, required("a", "")) -} - -func TestEmpty(t *testing.T) { - assert.Nil(t, empty("a", "")) - assert.Equal(t, []shared.FieldError{{Source: "a", Detail: "field must not be provided"}}, empty("a", "a")) -} - -func TestValidateDate(t *testing.T) { - assert.Nil(t, validateDate("a", shared.Date{Time: time.Now()})) - assert.Equal(t, []shared.FieldError{{Source: "a", Detail: "invalid format"}}, validateDate("a", shared.Date{IsMalformed: true})) - assert.Equal(t, []shared.FieldError{{Source: "a", Detail: "field is required"}}, validateDate("a", shared.Date{})) -} - -func TestValidateAddressEmpty(t *testing.T) { - address := shared.Address{} - errors := validateAddress("/test", address) - - assert.Contains(t, errors, shared.FieldError{Source: "/test/line1", Detail: "field is required"}) - assert.Contains(t, errors, shared.FieldError{Source: "/test/town", Detail: "field is required"}) - assert.Contains(t, errors, shared.FieldError{Source: "/test/country", Detail: "field is required"}) -} - -func TestValidateAddressValid(t *testing.T) { - errors := validateAddress("/test", validAddress) - - assert.Empty(t, errors) -} - -func TestValidateAddressInvalidCountry(t *testing.T) { - invalidAddress := shared.Address{ - Line1: "123 Main St", - Town: "Homeland", - Country: "United Kingdom", - } - errors := validateAddress("/test", invalidAddress) - - assert.Contains(t, errors, shared.FieldError{Source: "/test/country", Detail: "must be a valid ISO-3166-1 country code"}) -} - -type testIsValid string - -func (t testIsValid) IsValid() bool { return string(t) == "ok" } - -func TestValidateIsValid(t *testing.T) { - assert.Nil(t, validateIsValid("a", testIsValid("ok"))) - assert.Equal(t, []shared.FieldError{{Source: "a", Detail: "field is required"}}, validateIsValid("a", testIsValid(""))) - assert.Equal(t, []shared.FieldError{{Source: "a", Detail: "invalid value"}}, validateIsValid("a", testIsValid("x"))) -} - -type testUnset bool - -func (t testUnset) Unset() bool { return bool(t) } - -func TestValidateUnset(t *testing.T) { - assert.Nil(t, validateUnset("a", testUnset(true))) - assert.Equal(t, []shared.FieldError{{Source: "a", Detail: "field must not be provided"}}, validateUnset("a", testUnset(false))) -} - func TestValidateAttorneyEmpty(t *testing.T) { attorney := shared.Attorney{} errors := validateAttorney("/test", attorney) diff --git a/lambda/get/main.go b/lambda/get/main.go index c0d7a1e9..d5b8c1bd 100644 --- a/lambda/get/main.go +++ b/lambda/get/main.go @@ -7,6 +7,7 @@ import ( "github.com/aws/aws-lambda-go/events" "github.com/aws/aws-lambda-go/lambda" + "github.com/ministryofjustice/opg-data-lpa-store/internal/ddb" "github.com/ministryofjustice/opg-data-lpa-store/internal/shared" "github.com/ministryofjustice/opg-go-common/logging" ) @@ -15,8 +16,12 @@ type Logger interface { Print(...interface{}) } +type Store interface { + Get(ctx context.Context, uid string) (shared.Lpa, error) +} + type Lambda struct { - store shared.Client + store Store verifier shared.JWTVerifier logger Logger } @@ -36,8 +41,8 @@ func (l *Lambda) HandleEvent(ctx context.Context, event events.APIGatewayProxyRe lpa, err := l.store.Get(ctx, event.PathParameters["uid"]) - // If item can't be found in DynamoDB then it returns empty object hence 404 error returned if - // empty object returned + // If item can't be found in DynamoDB then it returns empty object hence 404 error returned if + // empty object returned if lpa.Uid == "" { l.logger.Print("Uid not found") return shared.ProblemNotFoundRequest.Respond() @@ -63,7 +68,7 @@ func (l *Lambda) HandleEvent(ctx context.Context, event events.APIGatewayProxyRe func main() { l := &Lambda{ - store: shared.NewDynamoDB(os.Getenv("DDB_TABLE_NAME_DEEDS")), + store: ddb.New(os.Getenv("AWS_DYNAMODB_ENDPOINT"), os.Getenv("DDB_TABLE_NAME_DEEDS")), verifier: shared.NewJWTVerifier(), logger: logging.New(os.Stdout, "opg-data-lpa-store"), } diff --git a/lambda/update/main.go b/lambda/update/main.go index 0a2def0d..2172e827 100644 --- a/lambda/update/main.go +++ b/lambda/update/main.go @@ -9,6 +9,7 @@ import ( "github.com/aws/aws-lambda-go/events" "github.com/aws/aws-lambda-go/lambda" "github.com/go-openapi/jsonpointer" + "github.com/ministryofjustice/opg-data-lpa-store/internal/ddb" "github.com/ministryofjustice/opg-data-lpa-store/internal/shared" "github.com/ministryofjustice/opg-go-common/logging" ) @@ -17,8 +18,13 @@ type Logger interface { Print(...interface{}) } +type Store interface { + Put(ctx context.Context, data any) error + Get(ctx context.Context, uid string) (shared.Lpa, error) +} + type Lambda struct { - store shared.Client + store Store verifier shared.JWTVerifier logger Logger } @@ -114,7 +120,7 @@ func applyUpdate(lpa *shared.Lpa, update shared.Update) ([]shared.FieldError, er func main() { l := &Lambda{ - store: shared.NewDynamoDB(os.Getenv("DDB_TABLE_NAME_DEEDS")), + store: ddb.New(os.Getenv("AWS_DYNAMODB_ENDPOINT"), os.Getenv("DDB_TABLE_NAME_DEEDS")), verifier: shared.NewJWTVerifier(), logger: logging.New(os.Stdout, "opg-data-lpa-store"), } diff --git a/terraform/environment/region/apigateway.tf b/terraform/environment/region/apigateway.tf index 2b59bf05..50d4737d 100644 --- a/terraform/environment/region/apigateway.tf +++ b/terraform/environment/region/apigateway.tf @@ -4,6 +4,7 @@ locals { lambda_create_invoke_arn = module.lambda["create"].invoke_arn lambda_get_invoke_arn = module.lambda["get"].invoke_arn lambda_update_invoke_arn = module.lambda["update"].invoke_arn + lambda_create_certificate_provider_invoke_arn = module.lambda["create-certificate-provider"].invoke_arn }) } diff --git a/terraform/environment/region/main.tf b/terraform/environment/region/main.tf index 458b2bc9..c5d58be1 100644 --- a/terraform/environment/region/main.tf +++ b/terraform/environment/region/main.tf @@ -3,6 +3,7 @@ locals { "create", "get", "update", + "create-certificate-provider", ]) }