Skip to content

Commit

Permalink
Merge pull request #1464 from ministryofjustice/MLPAB-2407-onelogin-did
Browse files Browse the repository at this point in the history
MLPAB-2407 Use published onelogin public key for verifying identity jwt
  • Loading branch information
hawx authored Sep 12, 2024
2 parents a623c6f + 07386d8 commit e611634
Show file tree
Hide file tree
Showing 11 changed files with 702 additions and 118 deletions.
27 changes: 2 additions & 25 deletions cmd/mlpa/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ package main
import (
"cmp"
"context"
"crypto/ecdsa"
"encoding/base64"
"encoding/json"
"fmt"
html "html/template"
Expand All @@ -21,7 +19,6 @@ import (
"github.com/aws/aws-sdk-go-v2/aws"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/golang-jwt/jwt/v5"
"github.com/gorilla/handlers"
"github.com/ministryofjustice/opg-go-common/template"
"github.com/ministryofjustice/opg-modernising-lpa/internal/actor"
Expand Down Expand Up @@ -94,6 +91,7 @@ func run(ctx context.Context, logger *slog.Logger) error {
awsBaseURL = os.Getenv("AWS_BASE_URL")
clientID = cmp.Or(os.Getenv("CLIENT_ID"), "client-id-value")
issuer = cmp.Or(os.Getenv("ISSUER"), "http://mock-onelogin:8080")
identityURL = cmp.Or(os.Getenv("IDENTITY_URL"), "http://mock-onelogin:8080")
dynamoTableLpas = cmp.Or(os.Getenv("DYNAMODB_TABLE_LPAS"), "lpas")
notifyBaseURL = cmp.Or(os.Getenv("GOVUK_NOTIFY_BASE_URL"), "http://mock-notify:8080")
notifyIsProduction = os.Getenv("GOVUK_NOTIFY_IS_PRODUCTION") == "1"
Expand All @@ -114,7 +112,6 @@ func run(ctx context.Context, logger *slog.Logger) error {
oneloginURL = cmp.Or(os.Getenv("ONELOGIN_URL"), "https://home.integration.account.gov.uk")
evidenceBucketName = cmp.Or(os.Getenv("UPLOADS_S3_BUCKET_NAME"), "evidence")
eventBusName = cmp.Or(os.Getenv("EVENT_BUS_NAME"), "default")
mockIdentityPublicKey = os.Getenv("MOCK_IDENTITY_PUBLIC_KEY")
searchEndpoint = os.Getenv("SEARCH_ENDPOINT")
searchIndexName = cmp.Or(os.Getenv("SEARCH_INDEX_NAME"), "lpas")
searchIndexingEnabled = os.Getenv("SEARCH_INDEXING_DISABLED") != "1"
Expand Down Expand Up @@ -246,27 +243,7 @@ func run(ctx context.Context, logger *slog.Logger) error {

redirectURL := authRedirectBaseURL + page.PathAuthRedirect.Format()

identityPublicKeyFunc := func(ctx context.Context) (*ecdsa.PublicKey, error) {
bytes, err := secretsClient.SecretBytes(ctx, secrets.GovUkOneLoginIdentityPublicKey)
if err != nil {
return nil, err
}

return jwt.ParseECPublicKeyFromPEM(bytes)
}

if mockIdentityPublicKey != "" {
identityPublicKeyFunc = func(ctx context.Context) (*ecdsa.PublicKey, error) {
bytes, err := base64.StdEncoding.DecodeString(mockIdentityPublicKey)
if err != nil {
return nil, err
}

return jwt.ParseECPublicKeyFromPEM(bytes)
}
}

oneloginClient := onelogin.New(ctx, logger, httpClient, secretsClient, issuer, clientID, redirectURL, identityPublicKeyFunc)
oneloginClient := onelogin.New(ctx, logger, httpClient, secretsClient, issuer, identityURL, clientID, redirectURL)

payApiKey, err := secretsClient.Secret(ctx, secrets.GovUkPay)
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions docker/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ services:
- CLIENT_ID=client-id-value
- REDIRECT_URL=http://localhost:5050/auth/redirect
- TEMPLATE_SUB=1
- TEMPLATE_SUB_DEFAULT=random
- TEMPLATE_RETURN_CODES=1

mock-notify:
Expand Down
1 change: 0 additions & 1 deletion docker/localstack/localstack-init.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ openssl rsa -pubout -in /tmp/private_key.pem -out /tmp/public_key.pem

echo 'setting secrets'
awslocal secretsmanager create-secret --region eu-west-1 --name "private-jwt-key-base64" --secret-string "$(base64 /tmp/private_key.pem)"
awslocal secretsmanager create-secret --region eu-west-1 --name "gov-uk-onelogin-identity-public-key" --secret-string "LS0tLS1CRUdJTiBQVUJMSUMgS0VZLS0tLS0KTUZrd0V3WUhLb1pJemowQ0FRWUlLb1pJemowREFRY0RRZ0FFSlEyVmtpZWtzNW9rSTIxY1Jma0FhOXVxN0t4TQo2bTJqWllCeHBybFVXQlpDRWZ4cTI3cFV0Qzd5aXplVlRiZUVqUnlJaStYalhPQjFBbDhPbHFtaXJnPT0KLS0tLS1FTkQgUFVCTElDIEtFWS0tLS0tCg=="
awslocal secretsmanager create-secret --region eu-west-1 --name "cookie-session-keys" --secret-string "[\"$(head -c32 /dev/random | base64)\"]"
awslocal secretsmanager create-secret --region eu-west-1 --name "gov-uk-pay-api-key" --secret-string "totally-fake-key"
awslocal secretsmanager create-secret --region eu-west-1 --name "os-postcode-lookup-api-key" --secret-string "another-fake-key"
Expand Down
34 changes: 17 additions & 17 deletions internal/onelogin/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,29 +28,29 @@ type SecretsClient interface {
type IdentityPublicKeyFunc func(context.Context) (*ecdsa.PublicKey, error)

type Client struct {
ctx context.Context
logger Logger
httpClient Doer
openidConfiguration *configurationClient
secretsClient SecretsClient
randomString func(int) string
identityPublicKeyFunc IdentityPublicKeyFunc
ctx context.Context
logger Logger
httpClient Doer
openidConfiguration *configurationClient
secretsClient SecretsClient
randomString func(int) string
didClient *didClient

clientID string
redirectURL string
}

func New(ctx context.Context, logger Logger, httpClient *http.Client, secretsClient SecretsClient, issuer, clientID, redirectURL string, identityPublicKeyFunc IdentityPublicKeyFunc) *Client {
func New(ctx context.Context, logger Logger, httpClient *http.Client, secretsClient SecretsClient, issuer, identityURL, clientID, redirectURL string) *Client {
return &Client{
ctx: ctx,
logger: logger,
httpClient: httpClient,
secretsClient: secretsClient,
randomString: random.String,
identityPublicKeyFunc: identityPublicKeyFunc,
clientID: clientID,
redirectURL: redirectURL,
openidConfiguration: getConfiguration(ctx, logger, httpClient, issuer),
ctx: ctx,
logger: logger,
httpClient: httpClient,
secretsClient: secretsClient,
randomString: random.String,
clientID: clientID,
redirectURL: redirectURL,
openidConfiguration: getConfiguration(ctx, logger, httpClient, issuer),
didClient: getDID(ctx, logger, httpClient, identityURL),
}
}

Expand Down
182 changes: 182 additions & 0 deletions internal/onelogin/did.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
package onelogin

import (
"context"
"crypto"
"encoding/json"
"fmt"
"log/slog"
"net/http"
"strconv"
"strings"
"time"

"github.com/MicahParks/jwkset"
)

const didDocumentEndpoint = "/.well-known/did.json"

type didDocument struct {
Context []string `json:"@context"`
ID string `json:"id"`
AssertionMethods []didAssertionMethod `json:"assertionMethod"`
}

type didAssertionMethod struct {
Type string `json:"type"`
ID string `json:"id"`
Controller string `json:"controller"`
PublicKeyJWK jwkset.JWKMarshal `json:"publicKeyJwk"`
}

type didClient struct {
ctx context.Context
identityURL string
http Doer
logger Logger
now func() time.Time
refreshRateLimit time.Duration
refreshRequest chan (struct{})

controllerID string
assertionMethods map[string]crypto.PublicKey
}

func getDID(ctx context.Context, logger Logger, httpClient Doer, identityURL string) *didClient {
client := &didClient{
ctx: ctx,
identityURL: identityURL,
http: httpClient,
logger: logger,
now: time.Now,
refreshRateLimit: refreshRateLimit,
// only allow a single request to be waiting
refreshRequest: make(chan struct{}, 1),
}

go client.backgroundRefresh()

return client
}

// ForKID retrieves the public key for the given kid.
func (c *didClient) ForKID(kid string) (crypto.PublicKey, error) {
if c.controllerID == "" {
c.requestRefresh()
return nil, ErrConfigurationMissing
}

controllerID, _, found := strings.Cut(kid, "#")
if !found {
return nil, fmt.Errorf("malformed kid missing '#'")
}

if c.controllerID != controllerID {
return nil, fmt.Errorf("controller id does not match: %s != %s", c.controllerID, controllerID)
}

publicKey, ok := c.assertionMethods[kid]
if !ok {
return nil, fmt.Errorf("missing jwk for kid %s", kid)
}

return publicKey, nil
}

// refresh updates the did documents.
func (c *didClient) refresh() (time.Duration, error) {
const errRefresh = time.Minute

ctx, cancel := context.WithTimeout(c.ctx, refreshTimeout)
defer cancel()

req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.identityURL+didDocumentEndpoint, nil)
if err != nil {
return errRefresh, err
}

resp, err := c.http.Do(req)
if err != nil {
return errRefresh, err
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return errRefresh, fmt.Errorf("unexpected response status code %d for %s", resp.StatusCode, c.identityURL+didDocumentEndpoint)
}

var document didDocument
if err := json.NewDecoder(resp.Body).Decode(&document); err != nil {
return errRefresh, err
}

assertionMethods := map[string]crypto.PublicKey{}

for _, method := range document.AssertionMethods {
jwk, err := jwkset.NewJWKFromMarshal(method.PublicKeyJWK, jwkset.JWKMarshalOptions{}, jwkset.JWKValidateOptions{})
if err != nil {
return errRefresh, fmt.Errorf("could not unmarshal public key jwk for %s: %w", method.ID, err)
}

assertionMethods[method.ID] = jwk.Key().(crypto.PublicKey)
}

c.controllerID = document.ID
c.assertionMethods = assertionMethods

return parseCacheControl(resp.Header.Get("Cache-Control")), nil
}

// requestRefresh will request that the DID document is refreshed, if no other request is waiting
func (c *didClient) requestRefresh() {
select {
case c.refreshRequest <- struct{}{}:
default:
}
}

func (c *didClient) backgroundRefresh() {
var (
lastRefresh time.Time
refreshIn time.Duration
err error
)

for {
select {
case <-time.After(refreshIn):
c.requestRefresh()

case <-c.refreshRequest:
if lastRefresh.Add(c.refreshRateLimit).After(c.now()) {
continue
}

refreshIn, err = c.refresh()
if err != nil {
c.logger.WarnContext(c.ctx, "problem refreshing did document", slog.Any("err", err.Error()))
}
lastRefresh = c.now()

case <-c.ctx.Done():
return
}
}
}

func parseCacheControl(s string) time.Duration {
for _, directive := range strings.Split(s, ",") {
key, val, _ := strings.Cut(strings.TrimSpace(directive), "=")
switch key {
case "max-age":
i, err := strconv.Atoi(val)
if err != nil {
continue
}

return time.Duration(i) * time.Second
}
}

return refreshInterval
}
Loading

0 comments on commit e611634

Please sign in to comment.