From aa9b02cbf4904e28171078477078830aed00f003 Mon Sep 17 00:00:00 2001 From: Kryvchun Date: Thu, 18 May 2023 06:54:02 +0300 Subject: [PATCH] fix: renew token of vault k8s auth method --- dependency/vault_token.go | 51 ++++++++++++++++++++++++++++++---- dependency/vault_token_test.go | 20 +++++++++++++ watch/vault_token.go | 33 ++++++++++++++++++++-- watch/vault_token_test.go | 31 +++++++++++++++++++++ 4 files changed, 128 insertions(+), 7 deletions(-) diff --git a/dependency/vault_token.go b/dependency/vault_token.go index 5d184e05d..6efd80705 100644 --- a/dependency/vault_token.go +++ b/dependency/vault_token.go @@ -8,6 +8,9 @@ import ( "github.com/pkg/errors" ) +// VaultTokenRefreshCurrent tells to refresh the current client token. +const VaultTokenRefreshCurrent = "" + // Ensure implements var _ Dependency = (*VaultTokenQuery)(nil) @@ -16,6 +19,8 @@ type VaultTokenQuery struct { stopCh chan struct{} secret *Secret vaultSecret *api.Secret + + initialToken string } // NewVaultTokenQuery creates a new dependency. @@ -28,9 +33,10 @@ func NewVaultTokenQuery(token string) (*VaultTokenQuery, error) { }, } return &VaultTokenQuery{ - stopCh: make(chan struct{}, 1), - vaultSecret: vaultSecret, - secret: transformSecret(vaultSecret), + stopCh: make(chan struct{}, 1), + vaultSecret: vaultSecret, + secret: transformSecret(vaultSecret), + initialToken: token, }, nil } @@ -43,8 +49,16 @@ func (d *VaultTokenQuery) Fetch(clients *ClientSet, opts *QueryOptions, default: } - if vaultSecretRenewable(d.secret) { - err := renewSecret(clients, d) + var currentRenewer renewer = d + + if d.initialToken == VaultTokenRefreshCurrent { + currentRenewer = newVaultSecretsOverrideRenewer(d, clients.Vault().Token()) + } + + secret, _ := currentRenewer.secrets() + + if vaultSecretRenewable(secret) { + err := renewSecret(clients, currentRenewer) if err != nil { return nil, nil, errors.Wrap(err, d.String()) } @@ -80,3 +94,30 @@ func (d *VaultTokenQuery) String() string { func (d *VaultTokenQuery) Type() Type { return TypeVault } + +func newVaultSecretsOverrideRenewer(parent renewer, token string) *vaultSecretsOverrideRenewer { + vaultSecret := &api.Secret{ + Auth: &api.SecretAuth{ + ClientToken: token, + Renewable: true, + LeaseDuration: 1, + }, + } + + return &vaultSecretsOverrideRenewer{ + renewer: parent, + vaultSecret: vaultSecret, + secret: transformSecret(vaultSecret), + } +} + +type vaultSecretsOverrideRenewer struct { + renewer + + secret *Secret + vaultSecret *api.Secret +} + +func (d *vaultSecretsOverrideRenewer) secrets() (*Secret, *api.Secret) { + return d.secret, d.vaultSecret +} diff --git a/dependency/vault_token_test.go b/dependency/vault_token_test.go index 73a4e0980..45ce1ae25 100644 --- a/dependency/vault_token_test.go +++ b/dependency/vault_token_test.go @@ -9,6 +9,7 @@ import ( "github.com/hashicorp/vault/api" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewVaultTokenQuery(t *testing.T) { @@ -76,3 +77,22 @@ func TestVaultTokenQuery_String(t *testing.T) { }) } } + +func TestTewVaultSecretsOverrideRenewer(t *testing.T) { + const token = "expected_token" + + parent, err := NewVaultTokenQuery(VaultTokenRefreshCurrent) + require.NoError(t, err) + + vaultTokenSecretsOverride := newVaultSecretsOverrideRenewer(parent, token) + + secret, vaultSecret := vaultTokenSecretsOverride.secrets() + + if assert.NotNil(t, secret) && assert.NotNil(t, secret.Auth) { + assert.Equal(t, token, secret.Auth.ClientToken) + } + + if assert.NotNil(t, vaultSecret) && assert.NotNil(t, vaultSecret.Auth) { + assert.Equal(t, token, vaultSecret.Auth.ClientToken) + } +} diff --git a/watch/vault_token.go b/watch/vault_token.go index 7e3cc2fbd..9ec7cd6ae 100644 --- a/watch/vault_token.go +++ b/watch/vault_token.go @@ -15,6 +15,36 @@ import ( "github.com/hashicorp/vault/api" ) +func newVaultK8SAuthMethodRefreshTokenWatcher( + clients *dep.ClientSet, c *config.VaultConfig, doneCh chan struct{}, +) (*Watcher, error) { + isK8SAuthMethod := config.StringVal(c.K8SServiceAccountToken) != "" || config.StringVal(c.K8SServiceAccountTokenPath) != "" + + if !isK8SAuthMethod || !config.BoolVal(c.RenewToken) { + return nil, nil + } + + watcher := NewWatcher(&NewWatcherInput{ + Clients: clients, + RetryFuncVault: RetryFunc(c.Retry.RetryFunc()), + }) + + vaultQuery, err := dep.NewVaultTokenQuery(dep.VaultTokenRefreshCurrent) + if err != nil { + watcher.Stop() + + return nil, fmt.Errorf("vaultwatcher: %w", err) + } + + if _, err := watcher.Add(vaultQuery); err != nil { + watcher.Stop() + + return nil, fmt.Errorf("vaultwatcher: %w", err) + } + + return watcher, nil +} + // VaultTokenWatcher monitors the vault token for updates func VaultTokenWatcher( clients *dep.ClientSet, c *config.VaultConfig, doneCh chan struct{}, @@ -24,7 +54,7 @@ func VaultTokenWatcher( // tokens are not being used. raw_token := strings.TrimSpace(config.StringVal(c.Token)) if raw_token == "" { - return nil, nil + return newVaultK8SAuthMethodRefreshTokenWatcher(clients, c, doneCh) } unwrap := config.BoolVal(c.UnwrapToken) @@ -76,7 +106,6 @@ func VaultTokenWatcher( return watcher, nil } - func watchTokenFile( w *Watcher, tokenFile, raw_token string, unwrap bool, doneCh chan struct{}, ) (func(), error) { diff --git a/watch/vault_token_test.go b/watch/vault_token_test.go index 9e9398935..2c01dff7d 100644 --- a/watch/vault_token_test.go +++ b/watch/vault_token_test.go @@ -14,6 +14,8 @@ import ( "github.com/hashicorp/consul-template/config" dep "github.com/hashicorp/consul-template/dependency" "github.com/hashicorp/vault/api" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // approle auto-auth setup in watch_test.go, TestMain() @@ -141,6 +143,35 @@ func TestVaultTokenWatcher(t *testing.T) { // give it a chance to throw an error } }) + + t.Run("renew_kubernetes", func(t *testing.T) { + // Check that there is an attempt to refresh token. + testClients.Vault().SetToken(vaultToken) + + _, err := testClients.Vault().Auth().Token().Create( + &api.TokenCreateRequest{ + ID: "c_token", + TTL: "1m", + Renewable: config.Bool(true), + }) + require.NoError(t, err) + + conf := config.DefaultVaultConfig() + conf.Token = config.String("") + conf.RenewToken = config.Bool(true) + conf.K8SServiceAccountToken = config.String("any_k8s_token") + + watcher, err := VaultTokenWatcher(testClients, conf, nil) + require.NoError(t, err) + + defer watcher.Stop() + + select { + case err := <-watcher.ErrCh(): + assert.ErrorIs(t, err, dep.ErrLeaseExpired) + case <-time.After(time.Millisecond * 100): + } + }) } func TestVaultTokenRefreshToken(t *testing.T) {