Skip to content

Commit

Permalink
added test for jwe handling
Browse files Browse the repository at this point in the history
  • Loading branch information
sredny buitrago authored and sredny buitrago committed Sep 27, 2024
1 parent 3f14e01 commit 7a77cec
Show file tree
Hide file tree
Showing 5 changed files with 193 additions and 13 deletions.
10 changes: 10 additions & 0 deletions internal/jwe/JWE.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/markbates/goth/providers/openidConnect"
jose "gopkg.in/square/go-jose.v2"
)

Expand Down Expand Up @@ -75,3 +76,12 @@ func (handler *Handler) Decrypt(token string) (string, error) {

return string(decrypted), nil
}

func DecryptIDToken(jweHandler *Handler, JWTSession *openidConnect.Session) error {
decryptedIDToken, err := jweHandler.Decrypt(JWTSession.IDToken)
if err != nil {
return err
}
JWTSession.IDToken = decryptedIDToken
return nil
}
171 changes: 171 additions & 0 deletions internal/jwe/JWE_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
package jwe

import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"github.com/markbates/goth/providers/openidConnect"
"gopkg.in/square/go-jose.v2"

"github.com/stretchr/testify/assert"
"testing"
)

func generateMockPrivateKey() (*tls.Certificate, error) {
// Generate a new RSA private key for testing
privKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, err
}
cert := &tls.Certificate{
PrivateKey: privKey,
// Normally you would also populate the Certificate fields, but it's not needed for this test
}
return cert, nil
}

func createJWE(payload []byte, recipient *rsa.PublicKey) (string, error) {

encrypter, err := jose.NewEncrypter(
jose.A256GCM,
jose.Recipient{
Algorithm: jose.RSA_OAEP_256,
Key: recipient,
},
(&jose.EncrypterOptions{}).WithType("JWT"))
if err != nil {
return "", err
}
jwe, err := encrypter.Encrypt(payload)
if err != nil {
return "", err
}
return jwe.CompactSerialize()
}

// Test case for Handler.Decrypt
func TestHandler_Decrypt(t *testing.T) {
// Generate a mock private key
mockCert, err := generateMockPrivateKey()
assert.NoError(t, err)

// Create a valid JWE token for testing
jweString, err := createJWE([]byte("test token"), mockCert.PrivateKey.(*rsa.PrivateKey).Public().(*rsa.PublicKey))
assert.NoError(t, err)

tests := []struct {
name string
handler *Handler
token string
expected string
expectError bool
errorMessage string
}{
{
name: "Disabled Handler",
handler: &Handler{
Enabled: false,
},
token: jweString,
expected: jweString,
expectError: false,
},
{
name: "Key Not Loaded",
handler: &Handler{
Enabled: true,
Key: nil,
},
token: jweString,
expected: "",
expectError: true,
errorMessage: "JWE Private Key not loaded",
},
{
name: "Successful Decryption",
handler: &Handler{
Enabled: true,
Key: mockCert,
},
token: jweString,
expected: "test token",
expectError: false,
},
{
name: "Invalid Token",
handler: &Handler{
Enabled: true,
Key: mockCert,
},
token: "invalid-token",
expected: "",
expectError: true,
errorMessage: "error parsing JWE",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
decrypted, err := tt.handler.Decrypt(tt.token)

if tt.expectError {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errorMessage)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expected, decrypted)
}
})
}
}

func TestDecryptIDToken(t *testing.T) {
mockCert, err := generateMockPrivateKey()
assert.NoError(t, err)

// Create a valid JWE token for testing
jweString, err := createJWE([]byte("test token"), mockCert.PrivateKey.(*rsa.PrivateKey).Public().(*rsa.PublicKey))
assert.NoError(t, err)

// Setup a valid JWE handler
jweHandler := &Handler{
Enabled: true,
Key: mockCert,
}

tests := []struct {
name string
jwtSession *openidConnect.Session
expectError bool
expectedToken string
}{
{
name: "Successful Decryption",
jwtSession: &openidConnect.Session{
IDToken: jweString,
},
expectError: false,
expectedToken: "test token",
},
{
name: "Invalid Token",
jwtSession: &openidConnect.Session{
IDToken: "invalid-token",
},
expectError: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := DecryptIDToken(jweHandler, tt.jwtSession)

if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expectedToken, tt.jwtSession.IDToken)
}
})
}
}
2 changes: 2 additions & 0 deletions internal/jwe/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"fmt"
)

// file to be removed after testing

// using the keycloak one
var publicKeyPEM = `
-----BEGIN PUBLIC KEY-----
Expand Down
12 changes: 5 additions & 7 deletions providers/social.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,10 @@ var socialLogger = log.WithField("prefix", SocialLogTag)
// Social is the identity handler for all social auth, it is a wrapper around Goth, and makes use of it's pluggable
// providers to provide a raft of social OAuth providers as SSO or Login delegates.
type Social struct {
handler tap.IdentityHandler
config GothConfig
toth toth.TothInstance
profile tap.Profile
jweHandler jwe.Handler
handler tap.IdentityHandler
config GothConfig
toth toth.TothInstance
profile tap.Profile
}

// GothProviderConfig the configurations required for the individual goth providers
Expand Down Expand Up @@ -192,8 +191,7 @@ func (s *Social) checkConstraints(user interface{}) error {

// HandleCallback handles the callback from the OAuth provider
func (s *Social) HandleCallback(w http.ResponseWriter, r *http.Request, onError func(tag string, errorMsg string, rawErr error, code int, w http.ResponseWriter, r *http.Request), profile tap.Profile) {

user, err := tothic.CompleteUserAuth(w, r, &s.toth, profile, &s.jweHandler)
user, err := tothic.CompleteUserAuth(w, r, &s.toth, profile, &s.config.JWE)
if err != nil {
fmt.Fprintln(w, err)
return
Expand Down
11 changes: 5 additions & 6 deletions tothic/tothic.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,21 +226,20 @@ var CompleteUserAuth = func(res http.ResponseWriter, req *http.Request, toth *to
return goth.User{}, err
}

// for testing override the id token
JWTSession.IDToken, _ = jwe.Encrypt(JWTSession.IDToken)
//--end testing

// no decryption is required
if !jweHandler.Enabled {
return provider.FetchUser(sess)
}

// for testing override the id token
// JWTSession.IDToken, _ = jwe.Encrypt(JWTSession.IDToken)
//--end testing

// we must decrypt the ID token
decryptedIDToken, err := jweHandler.Decrypt(JWTSession.IDToken)
err = jwe.DecryptIDToken(jweHandler, JWTSession)
if err != nil {
return goth.User{}, err
}
JWTSession.IDToken = decryptedIDToken
return provider.FetchUser(JWTSession)
}

Expand Down

0 comments on commit 7a77cec

Please sign in to comment.