diff --git a/internal/jwe/JWE.go b/internal/jwe/JWE.go index 874da518..2cc1bc06 100644 --- a/internal/jwe/JWE.go +++ b/internal/jwe/JWE.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "github.com/markbates/goth/providers/openidConnect" jose "gopkg.in/square/go-jose.v2" ) @@ -75,3 +76,12 @@ func (handler *Handler) Decrypt(token string) (string, error) { return string(decrypted), nil } + +func DecryptIDToken(jweHandler *Handler, JWTSession *openidConnect.Session) error { + decryptedIDToken, err := jweHandler.Decrypt(JWTSession.IDToken) + if err != nil { + return err + } + JWTSession.IDToken = decryptedIDToken + return nil +} diff --git a/internal/jwe/JWE_test.go b/internal/jwe/JWE_test.go new file mode 100644 index 00000000..dac4f6ac --- /dev/null +++ b/internal/jwe/JWE_test.go @@ -0,0 +1,171 @@ +package jwe + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "github.com/markbates/goth/providers/openidConnect" + "gopkg.in/square/go-jose.v2" + + "github.com/stretchr/testify/assert" + "testing" +) + +func generateMockPrivateKey() (*tls.Certificate, error) { + // Generate a new RSA private key for testing + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, err + } + cert := &tls.Certificate{ + PrivateKey: privKey, + // Normally you would also populate the Certificate fields, but it's not needed for this test + } + return cert, nil +} + +func createJWE(payload []byte, recipient *rsa.PublicKey) (string, error) { + + encrypter, err := jose.NewEncrypter( + jose.A256GCM, + jose.Recipient{ + Algorithm: jose.RSA_OAEP_256, + Key: recipient, + }, + (&jose.EncrypterOptions{}).WithType("JWT")) + if err != nil { + return "", err + } + jwe, err := encrypter.Encrypt(payload) + if err != nil { + return "", err + } + return jwe.CompactSerialize() +} + +// Test case for Handler.Decrypt +func TestHandler_Decrypt(t *testing.T) { + // Generate a mock private key + mockCert, err := generateMockPrivateKey() + assert.NoError(t, err) + + // Create a valid JWE token for testing + jweString, err := createJWE([]byte("test token"), mockCert.PrivateKey.(*rsa.PrivateKey).Public().(*rsa.PublicKey)) + assert.NoError(t, err) + + tests := []struct { + name string + handler *Handler + token string + expected string + expectError bool + errorMessage string + }{ + { + name: "Disabled Handler", + handler: &Handler{ + Enabled: false, + }, + token: jweString, + expected: jweString, + expectError: false, + }, + { + name: "Key Not Loaded", + handler: &Handler{ + Enabled: true, + Key: nil, + }, + token: jweString, + expected: "", + expectError: true, + errorMessage: "JWE Private Key not loaded", + }, + { + name: "Successful Decryption", + handler: &Handler{ + Enabled: true, + Key: mockCert, + }, + token: jweString, + expected: "test token", + expectError: false, + }, + { + name: "Invalid Token", + handler: &Handler{ + Enabled: true, + Key: mockCert, + }, + token: "invalid-token", + expected: "", + expectError: true, + errorMessage: "error parsing JWE", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + decrypted, err := tt.handler.Decrypt(tt.token) + + if tt.expectError { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMessage) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, decrypted) + } + }) + } +} + +func TestDecryptIDToken(t *testing.T) { + mockCert, err := generateMockPrivateKey() + assert.NoError(t, err) + + // Create a valid JWE token for testing + jweString, err := createJWE([]byte("test token"), mockCert.PrivateKey.(*rsa.PrivateKey).Public().(*rsa.PublicKey)) + assert.NoError(t, err) + + // Setup a valid JWE handler + jweHandler := &Handler{ + Enabled: true, + Key: mockCert, + } + + tests := []struct { + name string + jwtSession *openidConnect.Session + expectError bool + expectedToken string + }{ + { + name: "Successful Decryption", + jwtSession: &openidConnect.Session{ + IDToken: jweString, + }, + expectError: false, + expectedToken: "test token", + }, + { + name: "Invalid Token", + jwtSession: &openidConnect.Session{ + IDToken: "invalid-token", + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := DecryptIDToken(jweHandler, tt.jwtSession) + + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectedToken, tt.jwtSession.IDToken) + } + }) + } +} diff --git a/internal/jwe/helper.go b/internal/jwe/helper.go index aee0dc09..a4f8e613 100644 --- a/internal/jwe/helper.go +++ b/internal/jwe/helper.go @@ -7,6 +7,8 @@ import ( "fmt" ) +// file to be removed after testing + // using the keycloak one var publicKeyPEM = ` -----BEGIN PUBLIC KEY----- diff --git a/providers/social.go b/providers/social.go index 4902ac82..98f0f835 100644 --- a/providers/social.go +++ b/providers/social.go @@ -46,11 +46,10 @@ var socialLogger = log.WithField("prefix", SocialLogTag) // Social is the identity handler for all social auth, it is a wrapper around Goth, and makes use of it's pluggable // providers to provide a raft of social OAuth providers as SSO or Login delegates. type Social struct { - handler tap.IdentityHandler - config GothConfig - toth toth.TothInstance - profile tap.Profile - jweHandler jwe.Handler + handler tap.IdentityHandler + config GothConfig + toth toth.TothInstance + profile tap.Profile } // GothProviderConfig the configurations required for the individual goth providers @@ -192,8 +191,7 @@ func (s *Social) checkConstraints(user interface{}) error { // HandleCallback handles the callback from the OAuth provider func (s *Social) HandleCallback(w http.ResponseWriter, r *http.Request, onError func(tag string, errorMsg string, rawErr error, code int, w http.ResponseWriter, r *http.Request), profile tap.Profile) { - - user, err := tothic.CompleteUserAuth(w, r, &s.toth, profile, &s.jweHandler) + user, err := tothic.CompleteUserAuth(w, r, &s.toth, profile, &s.config.JWE) if err != nil { fmt.Fprintln(w, err) return diff --git a/tothic/tothic.go b/tothic/tothic.go index 0dcca70f..2131df63 100644 --- a/tothic/tothic.go +++ b/tothic/tothic.go @@ -226,21 +226,20 @@ var CompleteUserAuth = func(res http.ResponseWriter, req *http.Request, toth *to return goth.User{}, err } - // for testing override the id token - JWTSession.IDToken, _ = jwe.Encrypt(JWTSession.IDToken) - //--end testing - // no decryption is required if !jweHandler.Enabled { return provider.FetchUser(sess) } + // for testing override the id token + // JWTSession.IDToken, _ = jwe.Encrypt(JWTSession.IDToken) + //--end testing + // we must decrypt the ID token - decryptedIDToken, err := jweHandler.Decrypt(JWTSession.IDToken) + err = jwe.DecryptIDToken(jweHandler, JWTSession) if err != nil { return goth.User{}, err } - JWTSession.IDToken = decryptedIDToken return provider.FetchUser(JWTSession) }