From c0cde3c75f2b87b6333b120095985be36456da6d Mon Sep 17 00:00:00 2001 From: James Date: Tue, 14 Jan 2025 10:17:38 +0000 Subject: [PATCH] Feat/ssm 24 (#29) * feat/SSM-24: added auth middleware and refactored components. * feat/SSM-24: fixed invalid creds string for auth middleware_test.go. * feat/SSM-24: fixed service_cases_test.go. * feat/SSM-24: refactored solution to provide more pluggable auth handlers. * feat/SSM-24: minor clean-up, removed duplication. * feat/SSM-24: changed routes. * feat/SSM-24: minor fixes based on latest review of PR. * feat/SSM-24: removed redundant HashPassword * feat/SSM-24: added test authenticator on creds and added segregation to authenticator logic. --- Makefile | 13 +- service-app/config/config.go | 7 +- service-app/go.mod | 22 ++- service-app/go.sum | 20 ++- service-app/internal/api/handler.go | 81 +++++++--- service-app/internal/api/service.go | 4 + .../internal/api/service_cases_test.go | 89 ++++------ .../internal/api/service_documents_test.go | 44 ++--- service-app/internal/auth/authenticator.go | 82 ++++++++++ .../internal/auth/authenticator_test.go | 52 ++++++ service-app/internal/auth/cookie.go | 37 +++++ service-app/internal/auth/middleware.go | 56 +++++++ service-app/internal/auth/middleware_mock.go | 41 +++++ service-app/internal/auth/middleware_test.go | 75 +++++++++ service-app/internal/auth/token.go | 153 ++++++++++++++++++ service-app/internal/auth/user.go | 18 +++ service-app/internal/aws/client.go | 45 ++++++ service-app/internal/aws/client_mock.go | 5 + service-app/internal/httpclient/middleware.go | 128 +++------------ .../internal/httpclient/middleware_test.go | 75 --------- .../{httpclient => mocks}/request_mock.go | 2 +- 21 files changed, 732 insertions(+), 317 deletions(-) create mode 100644 service-app/internal/auth/authenticator.go create mode 100644 service-app/internal/auth/authenticator_test.go create mode 100644 service-app/internal/auth/cookie.go create mode 100644 service-app/internal/auth/middleware.go create mode 100644 service-app/internal/auth/middleware_mock.go create mode 100644 service-app/internal/auth/middleware_test.go create mode 100644 service-app/internal/auth/token.go create mode 100644 service-app/internal/auth/user.go delete mode 100644 service-app/internal/httpclient/middleware_test.go rename service-app/internal/{httpclient => mocks}/request_mock.go (97%) diff --git a/Makefile b/Makefile index 02924a9..79ac707 100644 --- a/Makefile +++ b/Makefile @@ -1,10 +1,6 @@ -.PHONY: all test clean build start +.PHONY: all test build start clean -all: build test start clean - -build: - @echo "Building the application image using Docker Compose..." - @docker-compose build || { echo "Failed to build the application image"; exit 1; } +all: test start clean test: @echo "Running tests in the service-app-test container..." @@ -12,7 +8,12 @@ test: @docker-compose run --rm service-app-test || { echo "Tests failed"; exit 1; } @docker-compose down --remove-orphans --volumes service-app-test +build: + @echo "Building the application..." + @docker-compose build || { echo "Failed to build the application image"; exit 1; } + start: + @${MAKE} build @echo "Running the application using Docker Compose..." @docker-compose up -d || { echo "Failed to start Docker Compose"; exit 1; } diff --git a/service-app/config/config.go b/service-app/config/config.go index 5b5c4bd..5ab1499 100644 --- a/service-app/config/config.go +++ b/service-app/config/config.go @@ -33,9 +33,10 @@ type ( } Auth struct { - ApiUsername string `envconfig:"API_USERNAME" default:"opg_document_and_d@publicguardian.gsi.gov.uk"` - JWTSecretARN string `envconfig:"JWT_SECRET_ARN" default:"local/jwt-key"` - JWTExpiration int `envconfig:"JWT_EXPIRATION" default:"3600"` + ApiUsername string `envconfig:"API_USERNAME" default:"opg_document_and_d@publicguardian.gsi.gov.uk"` + JWTSecretARN string `envconfig:"JWT_SECRET_ARN" default:"local/jwt-key"` + CredentialsARN string `envconfig:"CREDENTIALS_ARN" default:"/local/local-credentials"` + JWTExpiration int `envconfig:"JWT_EXPIRATION" default:"3600"` } HTTP struct { diff --git a/service-app/go.mod b/service-app/go.mod index 7f365d6..c70ac61 100644 --- a/service-app/go.mod +++ b/service-app/go.mod @@ -4,7 +4,7 @@ go 1.23.2 require ( github.com/aws/aws-sdk-go v1.55.5 - github.com/aws/aws-sdk-go-v2 v1.32.6 + github.com/aws/aws-sdk-go-v2 v1.32.7 github.com/kelseyhightower/envconfig v1.4.0 github.com/stretchr/testify v1.10.0 ) @@ -13,8 +13,8 @@ require ( github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7 // indirect github.com/aws/aws-sdk-go-v2/credentials v1.17.47 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.21 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.25 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.25 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.26 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.26 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.25 // indirect github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1 // indirect @@ -33,8 +33,6 @@ require ( github.com/grpc-ecosystem/grpc-gateway/v2 v2.24.0 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/pkg/errors v0.9.1 // indirect - github.com/stretchr/objx v0.5.2 // indirect - go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/contrib/detectors/aws/ecs v1.33.0 // indirect go.opentelemetry.io/contrib/propagators/aws v1.33.0 // indirect go.opentelemetry.io/otel v1.33.0 // indirect @@ -45,7 +43,7 @@ require ( go.opentelemetry.io/otel/trace v1.33.0 // indirect go.opentelemetry.io/proto/otlp v1.4.0 // indirect golang.org/x/net v0.33.0 // indirect - golang.org/x/sys v0.28.0 // indirect + golang.org/x/sys v0.29.0 // indirect golang.org/x/text v0.21.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20241209162323-e6fa225c2576 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20241209162323-e6fa225c2576 // indirect @@ -53,14 +51,24 @@ require ( google.golang.org/protobuf v1.35.2 // indirect ) +require ( + github.com/stretchr/objx v0.5.2 // indirect + go.opentelemetry.io/auto/sdk v1.1.0 // indirect +) + require ( github.com/aws/aws-sdk-go-v2/config v1.28.6 github.com/aws/aws-sdk-go-v2/service/s3 v1.71.0 github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.34.7 - github.com/davecgh/go-spew v1.1.1 // indirect + github.com/aws/aws-sdk-go-v2/service/ssm v1.56.2 github.com/golang-jwt/jwt/v5 v5.2.1 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect github.com/lestrrat-go/libxml2 v0.0.0-20240905100032-c934e3fcb9d3 github.com/ministryofjustice/opg-go-common v1.45.0 github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/crypto v0.32.0 gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/service-app/go.sum b/service-app/go.sum index 9b97343..4067a05 100644 --- a/service-app/go.sum +++ b/service-app/go.sum @@ -1,7 +1,7 @@ github.com/aws/aws-sdk-go v1.55.5 h1:KKUZBfBoyqy5d3swXyiC7Q76ic40rYcbqH7qjh59kzU= github.com/aws/aws-sdk-go v1.55.5/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= -github.com/aws/aws-sdk-go-v2 v1.32.6 h1:7BokKRgRPuGmKkFMhEg/jSul+tB9VvXhcViILtfG8b4= -github.com/aws/aws-sdk-go-v2 v1.32.6/go.mod h1:P5WJBrYqqbWVaOxgH0X/FYYD47/nooaPOZPlQdmiN2U= +github.com/aws/aws-sdk-go-v2 v1.32.7 h1:ky5o35oENWi0JYWUZkB7WYvVPP+bcRF5/Iq7JWSb5Rw= +github.com/aws/aws-sdk-go-v2 v1.32.7/go.mod h1:P5WJBrYqqbWVaOxgH0X/FYYD47/nooaPOZPlQdmiN2U= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7 h1:lL7IfaFzngfx0ZwUGOZdsFFnQ5uLvR0hWqqhyE7Q9M8= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7/go.mod h1:QraP0UcVlQJsmHfioCrveWOC1nbiWUl3ej08h4mXWoc= github.com/aws/aws-sdk-go-v2/config v1.28.6 h1:D89IKtGrs/I3QXOLNTH93NJYtDhm8SYa9Q5CsPShmyo= @@ -10,10 +10,10 @@ github.com/aws/aws-sdk-go-v2/credentials v1.17.47 h1:48bA+3/fCdi2yAwVt+3COvmatZ6 github.com/aws/aws-sdk-go-v2/credentials v1.17.47/go.mod h1:+KdckOejLW3Ks3b0E3b5rHsr2f9yuORBum0WPnE5o5w= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.21 h1:AmoU1pziydclFT/xRV+xXE/Vb8fttJCLRPv8oAkprc0= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.21/go.mod h1:AjUdLYe4Tgs6kpH4Bv7uMZo7pottoyHMn4eTcIcneaY= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.25 h1:s/fF4+yDQDoElYhfIVvSNyeCydfbuTKzhxSXDXCPasU= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.25/go.mod h1:IgPfDv5jqFIzQSNbUEMoitNooSMXjRSDkhXv8jiROvU= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.25 h1:ZntTCl5EsYnhN/IygQEUugpdwbhdkom9uHcbCftiGgA= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.25/go.mod h1:DBdPrgeocww+CSl1C8cEV8PN1mHMBhuCDLpXezyvWkE= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.26 h1:I/5wmGMffY4happ8NOCuIUEWGUvvFp5NSeQcXl9RHcI= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.26/go.mod h1:FR8f4turZtNy6baO0KJ5FJUmXH/cSkI9fOngs0yl6mA= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.26 h1:zXFLuEuMMUOvEARXFUVJdfqZ4bvvSgdGRq/ATcrQxzM= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.26/go.mod h1:3o2Wpy0bogG1kyOPrgkXA8pgIfEEv0+m19O9D5+W8y8= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 h1:VaRN3TlFdd6KxX1x3ILT5ynH6HvKgqdiXoTxAF4HQcQ= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc= github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.25 h1:r67ps7oHCYnflpgDy2LZU0MAQtQbYIOqNNnqGO6xQkE= @@ -30,6 +30,8 @@ github.com/aws/aws-sdk-go-v2/service/s3 v1.71.0 h1:nyuzXooUNJexRT0Oy0UQY6AhOzxPx github.com/aws/aws-sdk-go-v2/service/s3 v1.71.0/go.mod h1:sT/iQz8JK3u/5gZkT+Hmr7GzVZehUMkRZpOaAwYXeGY= github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.34.7 h1:Nyfbgei75bohfmZNxgN27i528dGYVzqWJGlAO6lzXy8= github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.34.7/go.mod h1:FG4p/DciRxPgjA+BEOlwRHN0iA8hX2h9g5buSy3cTDA= +github.com/aws/aws-sdk-go-v2/service/ssm v1.56.2 h1:MOxvXH2kRP5exvqJxAZ0/H9Ar51VmADJh95SgZE8u60= +github.com/aws/aws-sdk-go-v2/service/ssm v1.56.2/go.mod h1:RKWoqC9FlgMCkrfVOtgfqfwdaUIaq8H93UAt4xNaR0A= github.com/aws/aws-sdk-go-v2/service/sso v1.24.7 h1:rLnYAfXQ3YAccocshIH5mzNNwZBkBo+bP6EhIxak6Hw= github.com/aws/aws-sdk-go-v2/service/sso v1.24.7/go.mod h1:ZHtuQJ6t9A/+YDuxOLnbryAmITtr8UysSny3qcyvJTc= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.6 h1:JnhTZR3PiYDNKlXy50/pNeix9aGMo6lLpXwJ1mw8MD4= @@ -112,10 +114,12 @@ go.opentelemetry.io/proto/otlp v1.4.0 h1:TA9WRvW6zMwP+Ssb6fLoUIuirti1gGbP28GcKG1 go.opentelemetry.io/proto/otlp v1.4.0/go.mod h1:PPBWZIP98o2ElSqI35IHfu7hIhSwvc5N38Jw8pXuGFY= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= +golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= -golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= -golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= +golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= google.golang.org/genproto/googleapis/api v0.0.0-20241209162323-e6fa225c2576 h1:CkkIfIt50+lT6NHAVoRYEyAvQGFM7xEwXUUywFvEb3Q= diff --git a/service-app/internal/api/handler.go b/service-app/internal/api/handler.go index 23d41e6..47297d8 100644 --- a/service-app/internal/api/handler.go +++ b/service-app/internal/api/handler.go @@ -12,6 +12,7 @@ import ( "github.com/ministryofjustice/opg-go-common/telemetry" "github.com/ministryofjustice/opg-scanning/config" + "github.com/ministryofjustice/opg-scanning/internal/auth" "github.com/ministryofjustice/opg-scanning/internal/aws" "github.com/ministryofjustice/opg-scanning/internal/httpclient" "github.com/ministryofjustice/opg-scanning/internal/ingestion" @@ -20,31 +21,75 @@ import ( ) type IndexController struct { - config *config.Config - logger *logger.Logger - validator *ingestion.Validator - Queue *ingestion.JobQueue - AwsClient *aws.AwsClient + config *config.Config + logger *logger.Logger + validator *ingestion.Validator + httpMiddleware *httpclient.Middleware + authMiddleware *auth.Middleware + Queue *ingestion.JobQueue + AwsClient *aws.AwsClient } func NewIndexController(awsClient *aws.AwsClient, appConfig *config.Config) *IndexController { + logger := logger.NewLogger(appConfig) + + // Create dependencies + httpClient := httpclient.NewHttpClient(*appConfig, *logger) + tokenGenerator := auth.NewJWTTokenGenerator(awsClient, appConfig, logger) + cookieHelper := auth.MembraneCookieHelper{ + CookieName: "membrane", + Secure: appConfig.App.Environment != "local", + } + authenticator := auth.NewBasicAuthAuthenticator(awsClient, cookieHelper, tokenGenerator) + + // Create authentication middleware + authMiddleware := auth.NewMiddleware(authenticator, tokenGenerator, cookieHelper, logger) + // Create HTTP middleware + httpMiddleware, _ := httpclient.NewMiddleware(httpClient, tokenGenerator) + return &IndexController{ - config: appConfig, - logger: logger.NewLogger(appConfig), - validator: ingestion.NewValidator(), - Queue: ingestion.NewJobQueue(appConfig), - AwsClient: awsClient, + config: appConfig, + logger: logger, + validator: ingestion.NewValidator(), + httpMiddleware: httpMiddleware, + authMiddleware: authMiddleware, + Queue: ingestion.NewJobQueue(appConfig), + AwsClient: awsClient, } } func (c *IndexController) HandleRequests() { - http.Handle("/ingest", telemetry.Middleware(c.logger.SlogLogger)(http.HandlerFunc(c.IngestHandler))) + // Create the route to handle user authentication and issue JWT token + http.Handle("/auth/sessions", http.HandlerFunc(c.AuthHandler)) + + // Protect the route with JWT validation (using the authMiddleware) + http.Handle("/api/ddc", telemetry.Middleware(c.logger.SlogLogger)( + c.authMiddleware.CheckAuthMiddleware(http.HandlerFunc(c.IngestHandler)), + )) c.logger.Info("Starting server on :"+c.config.HTTP.Port, nil) http.ListenAndServe(":"+c.config.HTTP.Port, nil) } +func (c *IndexController) AuthHandler(w http.ResponseWriter, r *http.Request) { + // Authenticate user credentials and issue JWT token + _, err := c.authMiddleware.Authenticator.Authenticate(w, r) + if err != nil { + c.respondWithError(w, http.StatusUnauthorized, "Authentication failed", err) + return + } + + w.Write([]byte("Authentication successful")) +} + func (c *IndexController) IngestHandler(w http.ResponseWriter, r *http.Request) { + // Extract claims from context + // _, ok := r.Context().Value("claims").(jwt.MapClaims) + // if !ok { + // c.respondWithError(w, http.StatusUnauthorized, "Unauthorized: Unable to extract claims", nil) + // return + // } + if r.Method != http.MethodPost { c.respondWithError(w, http.StatusMethodNotAllowed, "Invalid HTTP method", nil) return @@ -54,7 +99,7 @@ func (c *IndexController) IngestHandler(w http.ResponseWriter, r *http.Request) bodyStr, err := c.readRequestBody(r) if err != nil { - c.respondWithError(w, http.StatusBadRequest, "Failed to read request body", err) + c.respondWithError(w, http.StatusBadRequest, "Invalid request body", err) return } @@ -76,17 +121,8 @@ func (c *IndexController) IngestHandler(w http.ResponseWriter, r *http.Request) return } - // Sirius API integration - // Create a case stub in Sirius if we have a case to create - httpClient := httpclient.NewHttpClient(*c.config, *c.logger) - middleware, err := httpclient.NewMiddleware(httpClient, c.AwsClient) - if err != nil { - c.respondWithError(w, http.StatusInternalServerError, "Failed to create middleware", err) - return - } - // Create a new client and prepare to attach documents - client := NewClient(middleware) + client := NewClient(c.httpMiddleware) service := NewService(client, parsedBaseXml) scannedCaseResponse, err := service.CreateCaseStub(r.Context()) if err != nil { @@ -103,6 +139,7 @@ func (c *IndexController) IngestHandler(w http.ResponseWriter, r *http.Request) } // Queue each document for further processing c.logger.Info("Queueing documents for processing", nil) + for i := range parsedBaseXml.Body.Documents { doc := &parsedBaseXml.Body.Documents[i] c.Queue.AddToQueue(doc, "xml", func(processedDoc interface{}, originalDoc *types.BaseDocument) { diff --git a/service-app/internal/api/service.go b/service-app/internal/api/service.go index cd31130..357d194 100644 --- a/service-app/internal/api/service.go +++ b/service-app/internal/api/service.go @@ -91,6 +91,10 @@ func (s *Service) CreateCaseStub(ctx context.Context) (*types.ScannedCaseRespons }, nil } + if s.Client.Middleware == nil { + return nil, fmt.Errorf("middleware is nil") + } + url := fmt.Sprintf("%s/%s", s.Client.Middleware.Config.App.SiriusBaseURL, s.Client.Middleware.Config.App.SiriusCaseStubURL) resp, err := s.Client.ClientRequest(ctx, scannedCaseRequest, url) diff --git a/service-app/internal/api/service_cases_test.go b/service-app/internal/api/service_cases_test.go index a12f8a8..fe13ac1 100644 --- a/service-app/internal/api/service_cases_test.go +++ b/service-app/internal/api/service_cases_test.go @@ -1,16 +1,15 @@ package api import ( + "bytes" "context" "encoding/json" "encoding/xml" "fmt" - "net/http" - "net/http/httptest" "testing" "github.com/ministryofjustice/opg-scanning/config" - "github.com/ministryofjustice/opg-scanning/internal/aws" + "github.com/ministryofjustice/opg-scanning/internal/auth" "github.com/ministryofjustice/opg-scanning/internal/httpclient" "github.com/ministryofjustice/opg-scanning/internal/logger" "github.com/ministryofjustice/opg-scanning/internal/types" @@ -91,70 +90,46 @@ func parseXMLPayload(t *testing.T, payload string) types.BaseSet { return set } -func setupMockServer(t *testing.T, expectedReq *types.ScannedCaseRequest) *httptest.Server { - return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - defer r.Body.Close() - - if r.Method != http.MethodPost { - t.Errorf("expected POST method, got %s", r.Method) - } - - var receivedRequest types.ScannedCaseRequest - if err := json.NewDecoder(r.Body).Decode(&receivedRequest); err != nil { - t.Fatalf("failed to decode request body: %v", err) - } - - if expectedReq != nil { - if receivedRequest.CaseType != expectedReq.CaseType { - t.Errorf("expected case type %s, but got %s", expectedReq.CaseType, receivedRequest.CaseType) - } - if expectedReq.CourtReference != "" && receivedRequest.CourtReference != expectedReq.CourtReference { - t.Errorf("expected court reference %s, but got %s", expectedReq.CourtReference, receivedRequest.CourtReference) - } - } - - w.WriteHeader(http.StatusOK) - w.Write([]byte(`{"UID": "dummy-uid-1234"}`)) - })) -} - func runStubCaseTest(t *testing.T, tt requestCaseStub) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - endpoint := "api/public/v1/scanned-cases" set := parseXMLPayload(t, tt.xmlPayload) - mockServer := setupMockServer(t, tt.expectedReq) - defer mockServer.Close() - - mockConfig := config.Config{ - App: config.App{ - SiriusBaseURL: mockServer.URL, - SiriusCaseStubURL: endpoint, - }, - Auth: config.Auth{ - ApiUsername: "test", - JWTSecretARN: "local/jwt-key", - JWTExpiration: 3600, - }, - } + mockConfig := *config.NewConfig() logger := *logger.NewLogger(&mockConfig) // Mock dependencies - mockAwsClient := new(aws.MockAwsClient) - mockAwsClient.On("GetSecretValue", mock.Anything, mock.AnythingOfType("string")). - Return("mock-signing-secret", nil) - - httpClient := httpclient.NewHttpClient(mockConfig, logger) - middleware, err := httpclient.NewMiddleware(httpClient, mockAwsClient) - if err != nil { - t.Fatalf("failed to create middleware: %v", err) - } - client := NewClient(middleware) + mockHttpClient, _, _, tokenGenerator := auth.PrepareMocks(&mockConfig, &logger) + httpMiddleware, _ := httpclient.NewMiddleware(mockHttpClient, tokenGenerator) + + mockHttpClient.On("HTTPRequest", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Maybe(). + Run(func(args mock.Arguments) { + payload := args[3].([]byte) + + var receivedRequest types.ScannedCaseRequest + if err := json.NewDecoder(bytes.NewReader(payload)).Decode(&receivedRequest); err != nil { + t.Fatalf("failed to decode request body: %v", err) + } + + // Perform assertions or checks on the received request data + if tt.expectedReq != nil { + if receivedRequest.CaseType != tt.expectedReq.CaseType { + t.Errorf("expected case type %s, but got %s", tt.expectedReq.CaseType, receivedRequest.CaseType) + } + if tt.expectedReq.CourtReference != "" && receivedRequest.CourtReference != tt.expectedReq.CourtReference { + t.Errorf("expected court reference %s, but got %s", tt.expectedReq.CourtReference, receivedRequest.CourtReference) + } + } + }). + Return([]byte(`{"UID": "dummy-uid-1234"}`), nil) + + client := NewClient(httpMiddleware) service := NewService(client, &set) ctx := context.Background() - _, err = service.CreateCaseStub(ctx) + + _, err := service.CreateCaseStub(ctx) if tt.expectedErr { if err == nil { @@ -166,6 +141,8 @@ func runStubCaseTest(t *testing.T, tt requestCaseStub) { } } + // Assert mock expectations + mockHttpClient.AssertExpectations(t) }) } diff --git a/service-app/internal/api/service_documents_test.go b/service-app/internal/api/service_documents_test.go index 60a188c..fb9c386 100644 --- a/service-app/internal/api/service_documents_test.go +++ b/service-app/internal/api/service_documents_test.go @@ -5,11 +5,10 @@ import ( "encoding/base64" "encoding/json" "os" - "regexp" "testing" "github.com/ministryofjustice/opg-scanning/config" - "github.com/ministryofjustice/opg-scanning/internal/aws" + "github.com/ministryofjustice/opg-scanning/internal/auth" "github.com/ministryofjustice/opg-scanning/internal/httpclient" "github.com/ministryofjustice/opg-scanning/internal/logger" "github.com/ministryofjustice/opg-scanning/internal/types" @@ -20,23 +19,11 @@ import ( func TestAttachDocument_Correspondence(t *testing.T) { // Mock dependencies - mockAwsClient := new(aws.MockAwsClient) - mockAwsClient.On("GetSecretValue", mock.Anything, mock.AnythingOfType("string")). - Return("mock-signing-secret", nil) + mockConfig := *config.NewConfig() + logger := *logger.NewLogger(&mockConfig) - mockClient := new(httpclient.MockHttpClient) - - mockConfig := config.NewConfig() - mockClient.On("GetConfig").Return(mockConfig) - - mockLogger := logger.NewLogger(mockConfig) - mockClient.On("GetLogger").Return(mockLogger) - - // Create middleware instance - middleware, err := httpclient.NewMiddleware(mockClient, mockAwsClient) - if err != nil { - t.Fatalf("failed to create middleware: %v", err) - } + mockHttpClient, _, _, tokenGenerator := auth.PrepareMocks(&mockConfig, &logger) + httpMiddleware, _ := httpclient.NewMiddleware(mockHttpClient, tokenGenerator) // Load PDF from the test file data, err := os.ReadFile("../../pdf/dummy.pdf") @@ -53,7 +40,7 @@ func TestAttachDocument_Correspondence(t *testing.T) { // Prepare service instance service := &Service{ - Client: &Client{Middleware: middleware}, + Client: &Client{Middleware: httpMiddleware}, originalDoc: &types.BaseDocument{ EmbeddedXML: xmlData, EmbeddedPDF: base64.StdEncoding.EncodeToString(data), @@ -82,12 +69,12 @@ func TestAttachDocument_Correspondence(t *testing.T) { } mockResponseBytes, _ := json.Marshal(mockResponse) - // Mock the HTTPRequest method - mockClient.On("HTTPRequest", mock.Anything, mock.MatchedBy(func(url string) bool { - // Remove domain from url using regex pattern - urlWithoutDomain := regexp.MustCompile(`^https?://[^/]+`).ReplaceAllString(url, "") - return urlWithoutDomain == "/api/public/v1/scanned-documents" - }), "POST", mock.Anything, mock.Anything).Return(mockResponseBytes, nil) + mockHttpClient.On("HTTPRequest", + mock.Anything, + mock.Anything, + "POST", + mock.Anything, + mock.Anything).Return(mockResponseBytes, nil) ctx := context.Background() response, err := service.AttachDocuments(ctx, caseResponse) @@ -96,10 +83,5 @@ func TestAttachDocument_Correspondence(t *testing.T) { } assert.NotNil(t, response, "Expected non-nil response") assert.Equal(t, mockResponse, response, "Expected response to match mock response") - - // Assert HTTPRequest was called with the expected parameters - mockClient.AssertCalled(t, "HTTPRequest", mock.Anything, mock.MatchedBy(func(url string) bool { - urlWithoutDomain := regexp.MustCompile(`^https?://[^/]+`).ReplaceAllString(url, "") - return urlWithoutDomain == "/api/public/v1/scanned-documents" - }), "POST", mock.Anything, mock.Anything) + mockHttpClient.AssertExpectations(t) } diff --git a/service-app/internal/auth/authenticator.go b/service-app/internal/auth/authenticator.go new file mode 100644 index 0000000..d5231f7 --- /dev/null +++ b/service-app/internal/auth/authenticator.go @@ -0,0 +1,82 @@ +package auth + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + + "github.com/ministryofjustice/opg-scanning/internal/aws" + "golang.org/x/crypto/bcrypt" +) + +type Authenticator interface { + Authenticate(w http.ResponseWriter, r *http.Request) (context.Context, error) + ValidateCredentials(ctx context.Context, creds User) (context.Context, error) +} + +type BasicAuthAuthenticator struct { + awsClient aws.AwsClientInterface + CookieHelper CookieHelper + TokenGenerator TokenGenerator +} + +func NewBasicAuthAuthenticator(awsClient aws.AwsClientInterface, cookieHelper CookieHelper, tokenGenerator TokenGenerator) *BasicAuthAuthenticator { + return &BasicAuthAuthenticator{ + awsClient: awsClient, + CookieHelper: cookieHelper, + TokenGenerator: tokenGenerator, + } +} + +func (a *BasicAuthAuthenticator) Authenticate(w http.ResponseWriter, r *http.Request) (context.Context, error) { + var creds User + + decoder := json.NewDecoder(r.Body) + if err := decoder.Decode(&creds); err != nil { + return nil, fmt.Errorf("invalid JSON payload: %w", err) + } + + // Validate credentials first + ctx, err := a.ValidateCredentials(r.Context(), creds) + if err != nil { + return nil, err + } + + if err := a.TokenGenerator.EnsureToken(ctx); err != nil { + return nil, fmt.Errorf("failed to ensure token: %w", err) + } + + token := a.TokenGenerator.GetToken() + expiry := a.TokenGenerator.GetExpiry() + + if err := a.CookieHelper.SetTokenInCookie(w, token, expiry); err != nil { + return nil, fmt.Errorf("failed to set cookie: %w", err) + } + + return ctx, nil +} + +func (a *BasicAuthAuthenticator) ValidateCredentials(ctx context.Context, creds User) (context.Context, error) { + if creds.Email == "" || creds.Password == "" { + return nil, fmt.Errorf("missing email or password") + } + + // Fetch credentials from AWS + storedCredentials, err := a.awsClient.FetchCredentials(ctx) + if err != nil { + return nil, fmt.Errorf("failed to fetch credentials: %w", err) + } + + storedHash, ok := storedCredentials[creds.Email] + if !ok { + return nil, fmt.Errorf("unknown email: %s", creds.Email) + } + + if err := bcrypt.CompareHashAndPassword([]byte(storedHash), []byte(creds.Password)); err != nil { + return nil, fmt.Errorf("invalid credentials") + } + + ctx = context.WithValue(ctx, userContextKey, creds.Email) + return ctx, nil +} diff --git a/service-app/internal/auth/authenticator_test.go b/service-app/internal/auth/authenticator_test.go new file mode 100644 index 0000000..3a38d68 --- /dev/null +++ b/service-app/internal/auth/authenticator_test.go @@ -0,0 +1,52 @@ +package auth + +import ( + "context" + "testing" + + "github.com/ministryofjustice/opg-scanning/config" + "github.com/ministryofjustice/opg-scanning/internal/logger" + "github.com/stretchr/testify/assert" +) + +func TestAuthenticatorCredentials(t *testing.T) { + cfg := config.NewConfig() + logger := logger.NewLogger(cfg) + + _, authMiddleware, _, _ := PrepareMocks(cfg, logger) + + tests := []struct { + name string + creds User + isValid bool + }{ + { + "Valid creds", + User{ + Email: cfg.Auth.ApiUsername, + Password: "password", + }, + true, + }, + { + "Invalid creds", + User{ + Email: cfg.Auth.ApiUsername, + Password: "", + }, + false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + authenticator := authMiddleware.Authenticator + _, err := authenticator.ValidateCredentials(context.Background(), tt.creds) + if tt.isValid { + assert.NoError(t, err) + } else { + assert.Error(t, err) + } + }) + } +} diff --git a/service-app/internal/auth/cookie.go b/service-app/internal/auth/cookie.go new file mode 100644 index 0000000..00cc147 --- /dev/null +++ b/service-app/internal/auth/cookie.go @@ -0,0 +1,37 @@ +package auth + +import ( + "net/http" + "time" +) + +type CookieHelper interface { + GetTokenFromCookie(r *http.Request) (string, error) + SetTokenInCookie(w http.ResponseWriter, token string, expiry time.Time) error +} + +type MembraneCookieHelper struct { + CookieName string + Secure bool +} + +func (h MembraneCookieHelper) GetTokenFromCookie(r *http.Request) (string, error) { + cookie, err := r.Cookie(h.CookieName) + if err != nil { + return "", err + } + return cookie.Value, nil +} + +func (h MembraneCookieHelper) SetTokenInCookie(w http.ResponseWriter, token string, expiry time.Time) error { + http.SetCookie(w, &http.Cookie{ + Name: h.CookieName, + Value: token, + Expires: expiry, + Path: "/", + HttpOnly: true, + Secure: h.Secure, + SameSite: http.SameSiteStrictMode, + }) + return nil +} diff --git a/service-app/internal/auth/middleware.go b/service-app/internal/auth/middleware.go new file mode 100644 index 0000000..a19cd47 --- /dev/null +++ b/service-app/internal/auth/middleware.go @@ -0,0 +1,56 @@ +package auth + +import ( + "context" + "net/http" + + "github.com/ministryofjustice/opg-scanning/internal/logger" +) + +type Middleware struct { + Authenticator Authenticator + TokenGenerator TokenGenerator + CookieHelper CookieHelper + logger *logger.Logger +} + +func NewMiddleware(authenticator Authenticator, tokenGenerator TokenGenerator, cookieHelper CookieHelper, logger *logger.Logger) *Middleware { + return &Middleware{ + Authenticator: authenticator, + TokenGenerator: tokenGenerator, + CookieHelper: cookieHelper, + logger: logger, + } +} + +func (m *Middleware) AuthenticateMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx, err := m.Authenticator.Authenticate(w, r) + if err != nil { + m.respondWithError(w, http.StatusUnauthorized, "Unauthorized", err) + return + } + + // Pass the new context with user info to the next handler + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +func (m *Middleware) CheckAuthMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + token, err := m.CookieHelper.GetTokenFromCookie(r) + if err != nil { + m.respondWithError(w, http.StatusUnauthorized, "Unauthorized: Missing token", err) + return + } + + ctx := context.WithValue(r.Context(), userContextKey, token) + + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +func (m *Middleware) respondWithError(w http.ResponseWriter, statusCode int, message string, err error) { + m.logger.Error("%s: %v", nil, message, err) + http.Error(w, message, statusCode) +} diff --git a/service-app/internal/auth/middleware_mock.go b/service-app/internal/auth/middleware_mock.go new file mode 100644 index 0000000..3eec41b --- /dev/null +++ b/service-app/internal/auth/middleware_mock.go @@ -0,0 +1,41 @@ +package auth + +import ( + "github.com/ministryofjustice/opg-scanning/config" + "github.com/ministryofjustice/opg-scanning/internal/aws" + "github.com/ministryofjustice/opg-scanning/internal/logger" + "github.com/ministryofjustice/opg-scanning/internal/mocks" + "github.com/stretchr/testify/mock" + "golang.org/x/crypto/bcrypt" +) + +func PrepareMocks(mockConfig *config.Config, logger *logger.Logger) (*mocks.MockHttpClient, *Middleware, *aws.MockAwsClient, *JWTTokenGenerator) { + // Initialize the mock AWS client + mockAwsClient := new(aws.MockAwsClient) + mockAwsClient.On("GetSecretValue", mock.Anything, "local/jwt-key").Maybe().Return("mysupersecrettestkeythatis128bits", nil) + mockAwsClient.On("FetchCredentials", mock.Anything).Maybe().Return(map[string]string{ + mockConfig.Auth.ApiUsername: hashPassword("password"), + }, nil) + // Create the HTTP client and middleware + mockHttpClient := new(mocks.MockHttpClient) + mockHttpClient.On("GetConfig").Return(mockConfig) + mockHttpClient.On("GetLogger").Return(logger) + + tokenGenerator := NewJWTTokenGenerator(mockAwsClient, mockConfig, logger) + cookieHelper := MembraneCookieHelper{ + CookieName: "membrane", + Secure: false, + } + authenticator := NewBasicAuthAuthenticator(mockAwsClient, cookieHelper, tokenGenerator) + authMiddleware := NewMiddleware(authenticator, tokenGenerator, cookieHelper, logger) + + return mockHttpClient, authMiddleware, mockAwsClient, tokenGenerator +} + +func hashPassword(password string) string { + hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + panic("failed to hash password in mock: " + err.Error()) + } + return string(hashedBytes) +} diff --git a/service-app/internal/auth/middleware_test.go b/service-app/internal/auth/middleware_test.go new file mode 100644 index 0000000..0e2afb9 --- /dev/null +++ b/service-app/internal/auth/middleware_test.go @@ -0,0 +1,75 @@ +package auth + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/ministryofjustice/opg-scanning/config" + "github.com/ministryofjustice/opg-scanning/internal/logger" + "github.com/stretchr/testify/assert" +) + +func TestEnsureTokenConcurrency(t *testing.T) { + cfg := config.NewConfig() + // Set a reasonable JWTExpiration for testiing e.g. 60 seconds + cfg.Auth.JWTExpiration = 60 + + logger := logger.NewLogger(cfg) + + _, _, mockAwsClient, _ := PrepareMocks(cfg, logger) + tokenGenerator := NewJWTTokenGenerator(mockAwsClient, cfg, logger) + + var wg sync.WaitGroup + numGoroutines := 10 + tokensChan := make(chan string, numGoroutines) + errorsChan := make(chan error, numGoroutines) + + // Launch multiple goroutines to call EnsureToken concurrently + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + err := tokenGenerator.EnsureToken(context.Background()) + if err != nil { + errorsChan <- err + return + } + token := tokenGenerator.GetToken() + tokensChan <- token + }() + } + wg.Wait() + close(errorsChan) + close(tokensChan) + + // Check for any errors from goroutines + for err := range errorsChan { + t.Errorf("EnsureToken failed: %v", err) + } + + // Assert that GetSecretValue was called only once + mockAwsClient.AssertNumberOfCalls(t, "GetSecretValue", 1) + + // Collect all tokens and verify they are the same + var firstToken string + for token := range tokensChan { + if firstToken == "" { + firstToken = token + assert.NotEmpty(t, firstToken, "First token should not be empty") + } else { + assert.Equal(t, firstToken, token, "All tokens should be identical") + } + } + + // verify that the token expiry is in the future + expiry := tokenGenerator.GetExpiry() + assert.True(t, expiry.After(time.Now()), "Token expiry should be in the future") + + // verify that tokenExpiry is around now + JWTExpiration + expectedExpiry := time.Now().Add(time.Duration(cfg.Auth.JWTExpiration) * time.Second) + assert.WithinDuration(t, expectedExpiry, expiry, 2*time.Second, "Token expiry should be set correctly") + + mockAwsClient.AssertExpectations(t) +} diff --git a/service-app/internal/auth/token.go b/service-app/internal/auth/token.go new file mode 100644 index 0000000..9e7653c --- /dev/null +++ b/service-app/internal/auth/token.go @@ -0,0 +1,153 @@ +package auth + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/ministryofjustice/opg-scanning/config" + "github.com/ministryofjustice/opg-scanning/internal/aws" + "github.com/ministryofjustice/opg-scanning/internal/logger" +) + +// Refresh token after 10 minutes +const secretTTL = 10 * time.Minute + +type TokenGenerator interface { + EnsureToken(ctx context.Context) error + GetToken() string + GetExpiry() time.Time +} + +type JWTTokenGenerator struct { + awsClient aws.AwsClientInterface + config *config.Config + logger *logger.Logger + signingSecret string + ApiUser string + mu sync.RWMutex + token string + tokenExpiry time.Time + lastSecretFetch time.Time +} + +type Claims struct { + SessionData string `json:"session-data"` + Iat int64 `json:"iat"` + Exp int64 `json:"exp"` +} + +func NewJWTTokenGenerator(awsClient aws.AwsClientInterface, config *config.Config, logger *logger.Logger) *JWTTokenGenerator { + return &JWTTokenGenerator{ + awsClient: awsClient, + config: config, + logger: logger, + } +} + +func (tg *JWTTokenGenerator) NewClaims() (Claims, error) { + if tg.ApiUser == "" { + tg.ApiUser = tg.config.Auth.ApiUsername + } + + return Claims{ + SessionData: tg.ApiUser, + Iat: time.Now().Unix(), + Exp: time.Now().Add(time.Duration(tg.config.Auth.JWTExpiration) * time.Second).Unix(), + }, nil +} + +func (tg *JWTTokenGenerator) fetchSigningSecret(ctx context.Context) error { + shouldFetch := time.Since(tg.lastSecretFetch) >= secretTTL || tg.signingSecret == "" + if !shouldFetch { + return nil + } + + secret, err := tg.awsClient.GetSecretValue(ctx, tg.config.Auth.JWTSecretARN) + if err != nil { + return fmt.Errorf("failed to fetch signing secret: %w", err) + } + tg.signingSecret = secret + tg.lastSecretFetch = time.Now() + return nil +} + +// Creates a new JWT token. +func (tg *JWTTokenGenerator) generateToken(ctx context.Context) (string, error) { + if err := tg.fetchSigningSecret(ctx); err != nil { + return "", err + } + + claims, err := tg.NewClaims() + if err != nil { + return "", fmt.Errorf("failed to create claims: %w", err) + } + + jwtClaims := jwt.MapClaims{ + "session-data": claims.SessionData, + "iat": claims.Iat, + "exp": claims.Exp, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwtClaims) + signedToken, err := token.SignedString([]byte(tg.signingSecret)) + if err != nil { + return "", fmt.Errorf("failed to sign token: %w", err) + } + + expiry := time.Unix(claims.Exp, 0) + + tg.token = signedToken + tg.tokenExpiry = expiry + + tg.logger.Info("Generated new JWT token.", nil) + + return signedToken, nil +} + +// Ensures that a valid token exists, generating a new one if necessary. +func (tg *JWTTokenGenerator) EnsureToken(ctx context.Context) error { + // First, acquire a read lock to check if the token is valid + tg.mu.RLock() + tokenValid := tg.token != "" && time.Now().Before(tg.tokenExpiry) + tg.mu.RUnlock() + + if tokenValid { + tg.logger.Info("Using cached JWT token.", nil) + return nil + } + + tg.mu.Lock() + defer tg.mu.Unlock() + + // Recheck the token validity under the write lock + if tg.token != "" && time.Now().Before(tg.tokenExpiry) { + tg.logger.Info("Another goroutine refreshed the token.", nil) + return nil + } + + // Generate a new token + ctx, cancel := context.WithTimeout(ctx, time.Duration(tg.config.HTTP.Timeout)*time.Second) + defer cancel() + + _, err := tg.generateToken(ctx) + if err != nil { + return fmt.Errorf("failed to generate token: %w", err) + } + + return nil +} + +func (tg *JWTTokenGenerator) GetToken() string { + tg.mu.RLock() + defer tg.mu.RUnlock() + return tg.token +} + +func (tg *JWTTokenGenerator) GetExpiry() time.Time { + tg.mu.RLock() + defer tg.mu.RUnlock() + return tg.tokenExpiry +} diff --git a/service-app/internal/auth/user.go b/service-app/internal/auth/user.go new file mode 100644 index 0000000..c6e19c8 --- /dev/null +++ b/service-app/internal/auth/user.go @@ -0,0 +1,18 @@ +package auth + +import "context" + +type User struct { + Email string `json:"email"` + Password string `json:"password"` +} + +type contextKey string + +const userContextKey = contextKey("auth-user") + +// Retrieves the users identity from the context +func UserFromContext(ctx context.Context) (string, bool) { + user, ok := ctx.Value(userContextKey).(string) + return user, ok +} diff --git a/service-app/internal/aws/client.go b/service-app/internal/aws/client.go index afa6413..e2a0b8a 100644 --- a/service-app/internal/aws/client.go +++ b/service-app/internal/aws/client.go @@ -2,24 +2,29 @@ package aws import ( "context" + "encoding/json" "fmt" "io" + "strings" "time" awsSdk "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/aws-sdk-go-v2/service/s3/types" "github.com/aws/aws-sdk-go-v2/service/secretsmanager" + "github.com/aws/aws-sdk-go-v2/service/ssm" "github.com/ministryofjustice/opg-scanning/config" ) type AwsClientInterface interface { GetSecretValue(ctx context.Context, secretName string) (string, error) + FetchCredentials(ctx context.Context) (map[string]string, error) } type AwsClient struct { config *config.Config SecretsManager *secretsmanager.Client + SSM *ssm.Client S3 *s3.Client } @@ -35,6 +40,10 @@ func NewAwsClient(ctx context.Context, cfg awsSdk.Config, appConfig *config.Conf o.BaseEndpoint = &customEndpoint }) + SsmClient := ssm.NewFromConfig(cfg, func(o *ssm.Options) { + o.BaseEndpoint = &customEndpoint + }) + s3Client := s3.NewFromConfig(cfg, func(o *s3.Options) { o.BaseEndpoint = &customEndpoint o.UsePathStyle = appConfig.App.Environment == "local" @@ -43,6 +52,7 @@ func NewAwsClient(ctx context.Context, cfg awsSdk.Config, appConfig *config.Conf return &AwsClient{ config: appConfig, SecretsManager: smClient, + SSM: SsmClient, S3: s3Client, }, nil } @@ -59,6 +69,21 @@ func (a *AwsClient) GetSecretValue(ctx context.Context, secretName string) (stri return *output.SecretString, nil } +// Fetch secret value from SSM Parameter Store +func (a *AwsClient) GetSsmValue(ctx context.Context, secretName string) (string, error) { + input := &ssm.GetParameterInput{ + Name: &secretName, + WithDecryption: awsSdk.Bool(true), + } + + output, err := a.SSM.GetParameter(ctx, input) + if err != nil { + return "", err + } + + return *output.Parameter.Value, nil +} + func (a *AwsClient) PersistFormData(ctx context.Context, body io.Reader, docType string) (string, error) { bucketName := a.config.Aws.JobsQueueBucket if bucketName == "" { @@ -88,3 +113,23 @@ func (a *AwsClient) PersistFormData(ctx context.Context, body io.Reader, docType return fileName, nil } + +func (a *AwsClient) FetchCredentials(ctx context.Context) (map[string]string, error) { + secretValue, err := a.GetSsmValue(ctx, a.config.Auth.CredentialsARN) + if err != nil { + return nil, fmt.Errorf("failed to retrieve secret from AWS: %w", err) + } + + secretValue = strings.TrimPrefix(secretValue, "kms:alias/aws/ssm:") + + var credentials map[string]string + if err := json.Unmarshal([]byte(secretValue), &credentials); err != nil { + return nil, fmt.Errorf("failed to unmarshal credentials: %w", err) + } + + if len(credentials) == 0 { + return nil, fmt.Errorf("no credentials found in secret") + } + + return credentials, nil +} diff --git a/service-app/internal/aws/client_mock.go b/service-app/internal/aws/client_mock.go index 604e816..9437201 100644 --- a/service-app/internal/aws/client_mock.go +++ b/service-app/internal/aws/client_mock.go @@ -14,3 +14,8 @@ func (m *MockAwsClient) GetSecretValue(ctx context.Context, secretName string) ( args := m.Called(ctx, secretName) return args.String(0), args.Error(1) } + +func (m *MockAwsClient) FetchCredentials(ctx context.Context) (map[string]string, error) { + args := m.Called(ctx) + return args.Get(0).(map[string]string), args.Error(1) +} diff --git a/service-app/internal/httpclient/middleware.go b/service-app/internal/httpclient/middleware.go index aff3ce5..46e20ed 100644 --- a/service-app/internal/httpclient/middleware.go +++ b/service-app/internal/httpclient/middleware.go @@ -1,34 +1,24 @@ +// httpclient/middleware.go package httpclient import ( "context" "fmt" - "sync" - "time" - "github.com/golang-jwt/jwt/v5" "github.com/ministryofjustice/opg-scanning/config" - "github.com/ministryofjustice/opg-scanning/internal/aws" + "github.com/ministryofjustice/opg-scanning/internal/auth" "github.com/ministryofjustice/opg-scanning/internal/logger" ) +// Middleware handles HTTP requests with authorization. type Middleware struct { - Client HttpClientInterface - Config *config.Config - Logger *logger.Logger - awsClient aws.AwsClientInterface - token string - tokenExpiry time.Time - mu sync.RWMutex + Client HttpClientInterface + Config *config.Config + Logger *logger.Logger + TokenGenerator auth.TokenGenerator } -type Claims struct { - SessionData string - Iat int64 - Exp int64 -} - -func NewMiddleware(client HttpClientInterface, awsClient aws.AwsClientInterface) (*Middleware, error) { +func NewMiddleware(client HttpClientInterface, tokenGenerator auth.TokenGenerator) (*Middleware, error) { config := client.GetConfig() logger := client.GetLogger() @@ -41,108 +31,30 @@ func NewMiddleware(client HttpClientInterface, awsClient aws.AwsClientInterface) } return &Middleware{ - Client: client, - awsClient: awsClient, - Config: config, - Logger: logger, - }, nil -} - -func NewClaims(cfg config.Config) (Claims, error) { - if cfg.Auth.ApiUsername == "" { - return Claims{}, fmt.Errorf("middleware configuration is missing or invalid") - } - - return Claims{ - SessionData: cfg.Auth.ApiUsername, - Iat: time.Now().Unix(), - Exp: time.Now().Add(time.Duration(cfg.Auth.JWTExpiration) * time.Second).Unix(), + Client: client, + Config: config, + Logger: logger, + TokenGenerator: tokenGenerator, }, nil } -func (m *Middleware) fetchSigningSecret(ctx context.Context) (string, error) { - secret, err := m.awsClient.GetSecretValue(ctx, m.Config.Auth.JWTSecretARN) - if err != nil { - return "", fmt.Errorf("failed to fetch signing secret: %w", err) - } - return secret, nil -} - -func (m *Middleware) generateToken(ctx context.Context) (string, error) { - signingSecret, err := m.fetchSigningSecret(ctx) - if err != nil { - return "", fmt.Errorf("failed to fetch signing secret: %w", err) - } - - claims, err := NewClaims(*m.Config) - if err != nil { - return "", fmt.Errorf("failed to create claims: %w", err) - } - - jwtClaims := jwt.MapClaims{ - "session-data": claims.SessionData, - "iat": claims.Iat, - "exp": claims.Exp, - } - - token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwtClaims) - signedToken, err := token.SignedString([]byte(signingSecret)) - if err != nil { - return "", fmt.Errorf("failed to sign token: %w", err) - } - - expiry := time.Unix(claims.Exp, 0) - - m.mu.Lock() - m.token = signedToken - m.tokenExpiry = expiry - m.mu.Unlock() - - return signedToken, nil -} - -func (m *Middleware) ensureToken(ctx context.Context) error { - m.mu.RLock() - tokenValid := m.token != "" && time.Now().Before(m.tokenExpiry) - m.mu.RUnlock() - - if tokenValid { - m.Logger.Info("Using cached JWT token.", nil) - return nil - } - - // Recheck token validity after acquiring the write lock - if m.token != "" && time.Now().Before(m.tokenExpiry) { - m.Logger.Info("Another goroutine refreshed the token.", nil) - return nil - } - - ctx, cancel := context.WithTimeout(ctx, time.Duration(m.Config.HTTP.Timeout)*time.Second) - defer cancel() - - _, err := m.generateToken(ctx) - if err != nil { - return fmt.Errorf("failed to generate token: %w", err) - } - return nil -} - // Acts as an HTTP wrapper for existing client with Authorization header set. func (m *Middleware) HTTPRequest(ctx context.Context, url, method string, payload []byte, headers map[string]string) ([]byte, error) { - // Ensure a valid token is available - if err := m.ensureToken(ctx); err != nil { + // Ensure token is valid using the TokenGenerator from auth package + if err := m.TokenGenerator.EnsureToken(ctx); err != nil { return nil, fmt.Errorf("failed to ensure token: %w", err) } - // Safely initialize headers if nil + // Retrieve the token + token := m.TokenGenerator.GetToken() + + // Add the Authorization header if headers == nil { headers = make(map[string]string) } + headers["Authorization"] = "Bearer " + token - // Add Authorization header - headers["Authorization"] = "Bearer " + m.token - - // Perform the HTTP request + // Perform the target HTTP request response, err := m.Client.HTTPRequest(ctx, url, method, payload, headers) if err != nil { return nil, fmt.Errorf("failed to perform HTTP request: %w", err) diff --git a/service-app/internal/httpclient/middleware_test.go b/service-app/internal/httpclient/middleware_test.go deleted file mode 100644 index 62ef81b..0000000 --- a/service-app/internal/httpclient/middleware_test.go +++ /dev/null @@ -1,75 +0,0 @@ -package httpclient - -import ( - "context" - "sync" - "testing" - "time" - - awsConfig "github.com/aws/aws-sdk-go-v2/config" - "github.com/ministryofjustice/opg-scanning/config" - "github.com/ministryofjustice/opg-scanning/internal/aws" - "github.com/ministryofjustice/opg-scanning/internal/logger" -) - -func TestEnsureTokenConcurrency(t *testing.T) { - ctx := context.Background() - cfg := config.NewConfig() - logger := *logger.NewLogger(cfg) - mockConfig := config.Config{ - HTTP: cfg.HTTP, - App: cfg.App, - Aws: cfg.Aws, - Auth: config.Auth{ - ApiUsername: "test", - JWTSecretARN: "local/jwt-key", - JWTExpiration: 3600, - }, - } - - // TODO: Check if we can integration AWS client during git actions workflow - // For now skip the test - if cfg.App.Environment != "local" { - t.Skip("Skipping test as it requires localstack and RUN_LOCAL_TESTS is not set to true") - } - - // Log mockConfig - // t.Logf("mockConfig: %+v", mockConfig) - - // Load AWS configuration - awsCfg, err := awsConfig.LoadDefaultConfig(ctx, - awsConfig.WithRegion(cfg.Aws.Region), - ) - if err != nil { - t.Errorf("Failed to load AWS config %v", err) - return - } - // Initialize AwsClient - awsClient, err := aws.NewAwsClient(ctx, awsCfg, &mockConfig) - if err != nil { - t.Errorf("failed to initialize AWS clients: %v", err) - } - httpClient := NewHttpClient(mockConfig, logger) - - middleware := &Middleware{ - Client: httpClient, - Config: &mockConfig, - Logger: &logger, - awsClient: awsClient, - tokenExpiry: time.Now().Add(time.Hour), - mu: sync.RWMutex{}, - } - - var wg sync.WaitGroup - for i := 0; i < 10; i++ { - wg.Add(1) - go func() { - defer wg.Done() - err := middleware.ensureToken(context.Background()) - if err != nil { - t.Errorf("ensureToken failed: %v", err) - } - }() - } - wg.Wait() -} diff --git a/service-app/internal/httpclient/request_mock.go b/service-app/internal/mocks/request_mock.go similarity index 97% rename from service-app/internal/httpclient/request_mock.go rename to service-app/internal/mocks/request_mock.go index bd7f853..10bb54b 100644 --- a/service-app/internal/httpclient/request_mock.go +++ b/service-app/internal/mocks/request_mock.go @@ -1,4 +1,4 @@ -package httpclient +package mocks import ( "context"