diff --git a/pkg/secrets/secrets.go b/pkg/secrets/secrets.go index d2c19c24..5d1e8652 100644 --- a/pkg/secrets/secrets.go +++ b/pkg/secrets/secrets.go @@ -51,6 +51,7 @@ type SecretDefinition struct { paths []string secrets map[string]string plural bool + secretEnv bool } // SecretFetcher inspects the environment for variables that @@ -84,6 +85,7 @@ func SecretFetcher(client *api.Client, config cfg.Config) { envkey: envKey, secretApex: apex, secrets: make(map[string]string), + secretEnv: config.SecretEnv, } switch { @@ -285,22 +287,21 @@ func (sd *SecretDefinition) addSecrets(secretResult *SecretResult) error { return fmt.Errorf("vault listed a secret %s %s, but failed trying to read it; likely the rate-limiting retry attempts were exceeded", keyName, keyPath) } + envKey := os.Getenv(secretValueKeyPrefix + sd.secretID) + if envKey != "" { + log.Info().Str("key", secretValueKeyPrefix+sd.secretID).Msg("Found an explicit vault value key, will read this value key instead of using the default") + } + + if !sd.plural && sd.secretEnv && envKey != "" { + return sd.copyValue(secretData, envKey) + } + if !sd.plural && sd.outputDestination != "" { singleValueKey := defaultKeyName - if envKey := os.Getenv(secretValueKeyPrefix + sd.secretID); envKey != "" { - log.Info().Str("key", secretValueKeyPrefix+sd.secretID).Str("value", singleValueKey).Msg("Found an explicit vault value key, will read this value key instead of using the default") + if envKey != "" { singleValueKey = envKey } - v, ok := secretData[singleValueKey] - if ok { - secretValue, err := valueConverter(v) - if err == nil { - sd.Lock() - sd.secrets[singleValueKey] = secretValue - sd.Unlock() - } - return err - } + return sd.copyValue(secretData, singleValueKey) } for k, v := range secretData { @@ -321,6 +322,24 @@ func (sd *SecretDefinition) addSecrets(secretResult *SecretResult) error { return nil } +// copyValues copies a value from the secretData object returned by vault and writes it into the secrets map of the +// SecretDefintion +func (sd *SecretDefinition) copyValue(secretData map[string]interface{}, key string) error { + v, ok := secretData[key] + if ok { + secretValue, err := valueConverter(v) + if err == nil { + sd.Lock() + sd.secrets[key] = secretValue + sd.Unlock() + return nil + } else { + return err + } + } + return nil +} + // Walk walks a SecretDefintions SecretApex. This is used for iteration // of the provided apex path func (sd *SecretDefinition) Walk(client *api.Client) error { diff --git a/pkg/secrets/secrets_test.go b/pkg/secrets/secrets_test.go index b32cc5a8..43ff2d12 100644 --- a/pkg/secrets/secrets_test.go +++ b/pkg/secrets/secrets_test.go @@ -288,7 +288,7 @@ func TestSecretAWalk(t *testing.T) { os.Setenv("VAULT_SECRETS_COMMON", "secret/path/common") os.Setenv("DAYTONA_SECRET_DESTINATION_COMMON", destinationPrefixFile.Name()) defer os.Unsetenv("VAULT_SECRETS_COMMON") - defer os.Unsetenv("VAULT_SECRETS_GENERIC") + defer os.Unsetenv("DAYTONA_SECRET_DESTINATION_COMMON") config.SecretPayloadPath = file.Name() SecretFetcher(client, config) @@ -440,7 +440,7 @@ func TestUnmatchedPluralDesintation(t *testing.T) { os.Setenv("DAYTONA_SECRET_DESTINATION_jacka", f2.Name()) defer os.Unsetenv("VAULT_SECRETS_APEX") - defer os.Setenv("DAYTONA_SECRET_DESTINATION_tha", f1.Name()) + defer os.Unsetenv("DAYTONA_SECRET_DESTINATION_tha") defer os.Unsetenv("DAYTONA_SECRET_DESTINATION_jacka") SecretFetcher(client, config) @@ -804,3 +804,91 @@ func TestSecretSingularDestinationKeyOverride(t *testing.T) { assert.Equal(t, "nonstandard", string(data)) } + +func TestSecretSingularEnvKeyOverride(t *testing.T) { + var config cfg.Config + + ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, ` + { + "auth": null, + "data": { + "username": "alice", + "password": "p@ssw0rd" + }, + "lease_duration": 3600, + "lease_id": "", + "renewable": false + } + `) + })) + defer ts.Close() + client, err := testhelpers.GetTestClient(ts.URL) + if err != nil { + t.Fatal(err) + } + + os.Setenv("VAULT_SECRET_APPLICATIONA", "secret/applicationA") + os.Setenv("VAULT_VALUE_KEY_APPLICATIONA", "password") + + defer os.Unsetenv("VAULT_SECRET_APPLICATIONA") + defer os.Unsetenv("VAULT_VALUE_KEY_APPLICATIONA") + defer os.Unsetenv("password") + + config.Workers = 3 + config.SecretEnv = true + SecretFetcher(client, config) + + assert.Equal(t, "p@ssw0rd", os.Getenv("password")) +} + +func TestCopyValue(t *testing.T) { + tests := []struct { + name string + secretData map[string]interface{} + key string + expectedSecrets map[string]string + expectedError error + }{ + { + name: "copy value", + secretData: map[string]interface{}{ + "foo": "bar", + }, + key: "foo", + expectedSecrets: map[string]string{ + "foo": "bar", + }, + expectedError: nil, + }, + { + name: "value not found", + secretData: map[string]interface{}{ + "foo": "bar", + }, + key: "baz", + expectedSecrets: map[string]string{}, + expectedError: nil, + }, + { + name: "value conversion error", + secretData: map[string]interface{}{ + "foo": 42, + }, + key: "foo", + expectedSecrets: map[string]string{}, + expectedError: fmt.Errorf("unsupported value type retrieved from vault: int"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sd := &SecretDefinition{secrets: map[string]string{}} + + err := sd.copyValue(tt.secretData, tt.key) + + assert.Equal(t, tt.expectedSecrets, sd.secrets) + assert.Equal(t, tt.expectedError, err) + }) + } +}