diff --git a/auth/encoding_utils.go b/auth/encoding_utils.go new file mode 100644 index 0000000000..e0f296e2f1 --- /dev/null +++ b/auth/encoding_utils.go @@ -0,0 +1,22 @@ +package auth + +import ( + "context" + "encoding/base64" + + "github.com/flyteorg/flytestdlib/logger" +) + +// EncodeBase64 returns the base64 encoded version of the data +func EncodeBase64(raw []byte) string { + return base64.RawStdEncoding.EncodeToString(raw) +} + +// DecodeFromBase64 returns the original encoded bytes and logs warning in case of error +func DecodeFromBase64(encodedData string) ([]byte, error) { + decodedData, err := base64.StdEncoding.DecodeString(encodedData) + if err != nil { + logger.Warnf(context.TODO(), "Unable to decode %v due to %v", encodedData, err) + } + return decodedData, err +} diff --git a/auth/encoding_utils_test.go b/auth/encoding_utils_test.go new file mode 100644 index 0000000000..2918d76ccb --- /dev/null +++ b/auth/encoding_utils_test.go @@ -0,0 +1,31 @@ +package auth + +import ( + "encoding/base64" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestEncodeAscii(t *testing.T) { + assert.Equal(t, "bmls", EncodeBase64([]byte("nil"))) + assert.Equal(t, "w4RwZmVs", EncodeBase64([]byte("Äpfel"))) +} + +func TestDecodeFromAscii(t *testing.T) { + type data struct { + decoded string + encoded string + expectedErr error + } + tt := []data{ + {decoded: "nil", encoded: "bmls", expectedErr: nil}, + {decoded: "Äpfel", encoded: "w4RwZmVs", expectedErr: nil}, + {decoded: "", encoded: "Äpfel", expectedErr: base64.CorruptInputError(0)}, + } + for _, testdata := range tt { + actualDecoded, actualErr := DecodeFromBase64(testdata.encoded) + assert.Equal(t, []byte(testdata.decoded), actualDecoded) + assert.Equal(t, testdata.expectedErr, actualErr) + } +} diff --git a/auth/handlers.go b/auth/handlers.go index 26c9469c01..048d49b82f 100644 --- a/auth/handlers.go +++ b/auth/handlers.go @@ -323,7 +323,7 @@ func GetHTTPRequestCookieToMetadataHandler(authCtx interfaces.AuthenticationCont } if len(raw) > 0 { - meta.Set(UserInfoMDKey, string(raw)) + meta.Set(UserInfoMDKey, EncodeBase64(raw)) } return meta diff --git a/auth/handlers_test.go b/auth/handlers_test.go index d86dee3a4e..d8c0f86d8f 100644 --- a/auth/handlers_test.go +++ b/auth/handlers_test.go @@ -230,6 +230,7 @@ func TestGetHTTPRequestCookieToMetadataHandler(t *testing.T) { req.AddCookie(&idCookie) assert.Equal(t, "IDToken a.b.c", handler(ctx, req)["authorization"][0]) + assert.Equal(t, "bnVsbA", handler(ctx, req).Get(UserInfoMDKey)[0]) } func TestGetHTTPMetadataTaggingHandler(t *testing.T) { diff --git a/auth/token.go b/auth/token.go index 974463f9ab..7b97f637da 100644 --- a/auth/token.go +++ b/auth/token.go @@ -111,10 +111,10 @@ func GRPCGetIdentityFromIDToken(ctx context.Context, clientID string, provider * } meta := metautils.ExtractIncoming(ctx) - userInfoStr := meta.Get(UserInfoMDKey) + userInfoDecoded, _ := DecodeFromBase64(meta.Get(UserInfoMDKey)) userInfo := &service.UserInfoResponse{} - if len(userInfoStr) > 0 { - err = json.Unmarshal([]byte(userInfoStr), userInfo) + if len(userInfoDecoded) > 0 { + err = json.Unmarshal(userInfoDecoded, userInfo) if err != nil { logger.Infof(ctx, "Could not unmarshal user info from metadata %v", err) }