diff --git a/flyteadmin/auth/authzserver/provider.go b/flyteadmin/auth/authzserver/provider.go index be33ac28f4..b2948331fb 100644 --- a/flyteadmin/auth/authzserver/provider.go +++ b/flyteadmin/auth/authzserver/provider.go @@ -147,6 +147,9 @@ func NewProvider(ctx context.Context, cfg config.AuthorizationServer, sm core.Se if err != nil { return Provider{}, fmt.Errorf("failed to read secretTokenHash file. Error: %w", err) } + if tokenHashBase64 == "" { + return Provider{}, fmt.Errorf("failed to read secretTokenHash. Error: empty value") + } secret, err := base64.RawStdEncoding.DecodeString(tokenHashBase64) if err != nil { @@ -158,8 +161,14 @@ func NewProvider(ctx context.Context, cfg config.AuthorizationServer, sm core.Se if err != nil { return Provider{}, fmt.Errorf("failed to read token signing RSA Key. Error: %w", err) } + if privateKeyPEM == "" { + return Provider{}, fmt.Errorf("failed to read token signing RSA Key. Error: empty value") + } block, _ := pem.Decode([]byte(privateKeyPEM)) + if block == nil { + return Provider{}, fmt.Errorf("failed to decode token signing RSA Key. Error: no PEM data found") + } privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) if err != nil { return Provider{}, fmt.Errorf("failed to parse PKCS1PrivateKey. Error: %w", err) @@ -197,7 +206,13 @@ func NewProvider(ctx context.Context, cfg config.AuthorizationServer, sm core.Se // Try to load old key to validate tokens using it to support key rotation. privateKeyPEM, err = sm.Get(ctx, cfg.OldTokenSigningRSAKeySecretName) if err == nil { + if privateKeyPEM == "" { + return Provider{}, fmt.Errorf("failed to read PKCS1PrivateKey. Error: empty value") + } block, _ = pem.Decode([]byte(privateKeyPEM)) + if block == nil { + return Provider{}, fmt.Errorf("failed to decode PKCS1PrivateKey. Error: no PEM data found") + } oldPrivateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) if err != nil { return Provider{}, fmt.Errorf("failed to parse PKCS1PrivateKey. Error: %w", err) diff --git a/flyteadmin/auth/authzserver/provider_test.go b/flyteadmin/auth/authzserver/provider_test.go index 4659e603a4..45f0778b51 100644 --- a/flyteadmin/auth/authzserver/provider_test.go +++ b/flyteadmin/auth/authzserver/provider_test.go @@ -34,7 +34,7 @@ func newMockProvider(t testing.TB) (Provider, auth.SecretsSet) { var buf bytes.Buffer assert.NoError(t, pem.Encode(&buf, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: privBytes})) sm.OnGet(ctx, config.SecretNameTokenSigningRSAKey).Return(buf.String(), nil) - sm.OnGet(ctx, config.SecretNameOldTokenSigningRSAKey).Return("", fmt.Errorf("not found")) + sm.OnGet(ctx, config.SecretNameOldTokenSigningRSAKey).Return(buf.String(), nil) p, err := NewProvider(ctx, config.DefaultConfig.AppAuth.SelfAuthServer, sm) assert.NoError(t, err) @@ -45,9 +45,126 @@ func TestNewProvider(t *testing.T) { newMockProvider(t) } +func newInvalidMockProvider(ctx context.Context, t *testing.T, secrets auth.SecretsSet, sm *mocks.SecretManager, invalidFunc func() *mocks.SecretManager_Get, errorContains string) { + + sm.OnGet(ctx, config.SecretNameClaimSymmetricKey).Return(base64.RawStdEncoding.EncodeToString(secrets.TokenHashKey), nil) + sm.OnGet(ctx, config.SecretNameCookieBlockKey).Return(base64.RawStdEncoding.EncodeToString(secrets.CookieBlockKey), nil) + sm.OnGet(ctx, config.SecretNameCookieHashKey).Return(base64.RawStdEncoding.EncodeToString(secrets.CookieHashKey), nil) + + privBytes := x509.MarshalPKCS1PrivateKey(secrets.TokenSigningRSAPrivateKey) + var buf bytes.Buffer + assert.NoError(t, pem.Encode(&buf, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: privBytes})) + sm.OnGet(ctx, config.SecretNameTokenSigningRSAKey).Return(buf.String(), nil) + sm.OnGet(ctx, config.SecretNameOldTokenSigningRSAKey).Return(buf.String(), nil) + + invalidFunc() + p, err := NewProvider(ctx, config.DefaultConfig.AppAuth.SelfAuthServer, sm) + assert.Error(t, err) + assert.ErrorContains(t, err, errorContains) + assert.Equal(t, Provider{}, p) +} + +func TestNewInvalidProviderSecretTokenHashBad(t *testing.T) { + secrets, err := auth.NewSecrets() + assert.NoError(t, err) + + ctx := context.Background() + sm := &mocks.SecretManager{} + + invalidFunc := func() *mocks.SecretManager_Get { + sm.OnGet(ctx, config.SecretNameClaimSymmetricKey).Unset() + return sm.OnGet(ctx, config.SecretNameClaimSymmetricKey).Return("", fmt.Errorf("test error")) + } + newInvalidMockProvider(ctx, t, secrets, sm, invalidFunc, "failed to read secretTokenHash file. Error: test error") +} + +func TestNewInvalidProviderSecretTokenHashEmpty(t *testing.T) { + secrets, err := auth.NewSecrets() + assert.NoError(t, err) + + ctx := context.Background() + sm := &mocks.SecretManager{} + + invalidFunc := func() *mocks.SecretManager_Get { + sm.OnGet(ctx, config.SecretNameClaimSymmetricKey).Unset() + return sm.OnGet(ctx, config.SecretNameClaimSymmetricKey).Return("", nil) + } + newInvalidMockProvider(ctx, t, secrets, sm, invalidFunc, "failed to read secretTokenHash. Error: empty value") +} + +func TestNewInvalidProviderTokenSigningRSAKeyBad(t *testing.T) { + secrets, err := auth.NewSecrets() + assert.NoError(t, err) + + ctx := context.Background() + sm := &mocks.SecretManager{} + + invalidFunc := func() *mocks.SecretManager_Get { + sm.OnGet(ctx, config.SecretNameTokenSigningRSAKey).Unset() + return sm.OnGet(ctx, config.SecretNameTokenSigningRSAKey).Return("", fmt.Errorf("test error")) + } + newInvalidMockProvider(ctx, t, secrets, sm, invalidFunc, "failed to read token signing RSA Key. Error: test error") +} + +func TestNewInvalidProviderTokenSigningRSAKeyEmpty(t *testing.T) { + secrets, err := auth.NewSecrets() + assert.NoError(t, err) + + ctx := context.Background() + sm := &mocks.SecretManager{} + + invalidFunc := func() *mocks.SecretManager_Get { + sm.OnGet(ctx, config.SecretNameTokenSigningRSAKey).Unset() + return sm.OnGet(ctx, config.SecretNameTokenSigningRSAKey).Return("", nil) + } + newInvalidMockProvider(ctx, t, secrets, sm, invalidFunc, "failed to read token signing RSA Key. Error: empty value") +} + +func TestNewInvalidProviderTokenSigningRSAKeyNoPEMData(t *testing.T) { + secrets, err := auth.NewSecrets() + assert.NoError(t, err) + + ctx := context.Background() + sm := &mocks.SecretManager{} + + invalidFunc := func() *mocks.SecretManager_Get { + sm.OnGet(ctx, config.SecretNameTokenSigningRSAKey).Unset() + return sm.OnGet(ctx, config.SecretNameTokenSigningRSAKey).Return("this is no PEM data", nil) + } + newInvalidMockProvider(ctx, t, secrets, sm, invalidFunc, "failed to decode token signing RSA Key. Error: no PEM data found") +} + +func TestNewInvalidProviderOldTokenSigningRSAKeyEmpty(t *testing.T) { + secrets, err := auth.NewSecrets() + assert.NoError(t, err) + + ctx := context.Background() + sm := &mocks.SecretManager{} + + invalidFunc := func() *mocks.SecretManager_Get { + sm.OnGet(ctx, config.SecretNameOldTokenSigningRSAKey).Unset() + return sm.OnGet(ctx, config.SecretNameOldTokenSigningRSAKey).Return("", nil) + } + newInvalidMockProvider(ctx, t, secrets, sm, invalidFunc, "failed to read PKCS1PrivateKey. Error: empty value") +} + +func TestNewInvalidProviderOldTokenSigningRSAKeyNoPEMData(t *testing.T) { + secrets, err := auth.NewSecrets() + assert.NoError(t, err) + + ctx := context.Background() + sm := &mocks.SecretManager{} + + invalidFunc := func() *mocks.SecretManager_Get { + sm.OnGet(ctx, config.SecretNameOldTokenSigningRSAKey).Unset() + return sm.OnGet(ctx, config.SecretNameOldTokenSigningRSAKey).Return("this is no PEM data", nil) + } + newInvalidMockProvider(ctx, t, secrets, sm, invalidFunc, "failed to decode PKCS1PrivateKey. Error: no PEM data found") +} + func TestProvider_KeySet(t *testing.T) { p, _ := newMockProvider(t) - assert.Equal(t, 1, p.KeySet().Len()) + assert.Equal(t, 2, p.KeySet().Len()) } func TestProvider_NewJWTSessionToken(t *testing.T) { @@ -64,7 +181,7 @@ func TestProvider_NewJWTSessionToken(t *testing.T) { func TestProvider_PublicKeys(t *testing.T) { p, _ := newMockProvider(t) - assert.Len(t, p.PublicKeys(), 1) + assert.Len(t, p.PublicKeys(), 2) } type CustomClaimsExample struct { @@ -175,7 +292,7 @@ func TestProvider_ValidateAccessToken(t *testing.T) { var buf bytes.Buffer assert.NoError(t, pem.Encode(&buf, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: privBytes})) sm.OnGet(ctx, config.SecretNameTokenSigningRSAKey).Return(buf.String(), nil) - sm.OnGet(ctx, config.SecretNameOldTokenSigningRSAKey).Return("", fmt.Errorf("not found")) + sm.OnGet(ctx, config.SecretNameOldTokenSigningRSAKey).Return(buf.String(), nil) p, err := NewProvider(ctx, config.DefaultConfig.AppAuth.SelfAuthServer, sm) assert.NoError(t, err)