diff --git a/go.mod b/go.mod index 1ee8cb2c..b44e57c1 100644 --- a/go.mod +++ b/go.mod @@ -24,6 +24,7 @@ require ( github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/klauspost/compress v1.17.2 // indirect github.com/kr/pretty v0.3.1 // indirect + github.com/leodido/go-urn v1.4.0 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect diff --git a/go.sum b/go.sum index 3a31410a..d7d02118 100644 --- a/go.sum +++ b/go.sum @@ -71,6 +71,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= +github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/ministryofjustice/opg-go-common v0.0.0-20231128145056-24628fba649c h1:598i3upKVEHRLW+eSkGmCaV7+yIaFTV6lMiHOC3tXDY= github.com/ministryofjustice/opg-go-common v0.0.0-20231128145056-24628fba649c/go.mod h1:qktwZb46YkojkLVHU2QNnVK6yVktXkNpBuJ+TyobvuY= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= diff --git a/internal/shared/jwt.go b/internal/shared/jwt.go index ebf033b2..a25ffcd1 100644 --- a/internal/shared/jwt.go +++ b/internal/shared/jwt.go @@ -8,7 +8,8 @@ import ( "time" "github.com/aws/aws-lambda-go/events" - "github.com/golang-jwt/jwt/v5" + jwt "github.com/golang-jwt/jwt/v5" + urn "github.com/leodido/go-urn" ) const ( @@ -61,17 +62,17 @@ func (l LpaStoreClaims) Validate() error { return err } - if iss == sirius { - emailRegex := regexp.MustCompile("^[a-zA-Z0-9.!#$%&'*+/=?^_`{|}~-]+@[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(?:\\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$") - if !emailRegex.MatchString(sub) { - return errors.New("Subject is not a valid email") - } - } - - if iss == mrlpa { - uidRegex := regexp.MustCompile("^.+$") - if !uidRegex.MatchString(sub) { - return errors.New("Subject is not a valid UID") + _, isUrn := urn.Parse([]byte(sub)) + + if !isUrn { + switch iss { + case mrlpa: + return errors.New("Subject is not a valid URN") + case sirius: + emailRegex := regexp.MustCompile("^[a-zA-Z0-9.!#$%&'*+/=?^_`{|}~-]+@[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(?:\\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$") + if !emailRegex.MatchString(sub) { + return errors.New("Subject is not a valid email or URN") + } } } diff --git a/internal/shared/jwt_test.go b/internal/shared/jwt_test.go index 9424f8e9..11f9926a 100644 --- a/internal/shared/jwt_test.go +++ b/internal/shared/jwt_test.go @@ -35,7 +35,7 @@ func TestVerifyExpInPast(t *testing.T) { "exp": time.Now().Add(time.Hour * -24).Unix(), "iat": time.Now().Add(time.Hour * -24).Unix(), "iss": "opg.poas.makeregister", - "sub": "M-3467-89QW-ERTY", + "sub": "urn:opg:poas:makeregister:users:e6707412-c9cd-4547-b428-7039a87e985e", }) _, err := verifier.verifyToken(token) @@ -78,7 +78,7 @@ func TestVerifyIssuer(t *testing.T) { } } -func TestVerifyBadEmailForSiriusIssuer(t *testing.T) { +func TestVerifyBadSubForSiriusIssuer(t *testing.T) { token := createToken(jwt.MapClaims{ "exp": time.Now().Add(time.Hour * 24).Unix(), "iat": time.Now().Add(time.Hour * -24).Unix(), @@ -90,11 +90,11 @@ func TestVerifyBadEmailForSiriusIssuer(t *testing.T) { assert.NotNil(t, err) if err != nil { - assert.Containsf(t, err.Error(), "Subject is not a valid email", "") + assert.Containsf(t, err.Error(), "Subject is not a valid email or URN", "") } } -func TestVerifyBadUIDForMRLPAIssuer(t *testing.T) { +func TestVerifyBadSubForMRLPAIssuer(t *testing.T) { token := createToken(jwt.MapClaims{ "exp": time.Now().Add(time.Hour * 24).Unix(), "iat": time.Now().Add(time.Hour * -24).Unix(), @@ -106,11 +106,11 @@ func TestVerifyBadUIDForMRLPAIssuer(t *testing.T) { assert.NotNil(t, err) if err != nil { - assert.Containsf(t, err.Error(), "Subject is not a valid UID", "") + assert.Containsf(t, err.Error(), "Subject is not a valid URN", "") } } -func TestVerifyGoodJwt(t *testing.T) { +func TestVerifyGoodJwtSiriusSubs(t *testing.T) { token := createToken(jwt.MapClaims{ "exp": time.Now().Add(time.Hour * 24).Unix(), "iat": time.Now().Add(time.Hour * -24).Unix(), @@ -120,6 +120,28 @@ func TestVerifyGoodJwt(t *testing.T) { _, err := verifier.verifyToken(token) assert.Nil(t, err) + + token = createToken(jwt.MapClaims{ + "exp": time.Now().Add(time.Hour * 24).Unix(), + "iat": time.Now().Add(time.Hour * -24).Unix(), + "iss": "opg.poas.sirius", + "sub": "urn:opg:sirius:users:34", + }) + + _, err = verifier.verifyToken(token) + assert.Nil(t, err) +} + +func TestVerifyGoodJwtMRLPASubs(t *testing.T) { + token := createToken(jwt.MapClaims{ + "exp": time.Now().Add(time.Hour * 24).Unix(), + "iat": time.Now().Add(time.Hour * -24).Unix(), + "iss": "opg.poas.makeregister", + "sub": "urn:opg:poas:makeregister:users:e6707412-c9cd-4547-b428-7039a87e985e", + }) + + _, err := verifier.verifyToken(token) + assert.Nil(t, err) } func TestNewJWTVerifier(t *testing.T) {