Skip to content

Commit

Permalink
chore: provide default session credentials to pause container (#5430)
Browse files Browse the repository at this point in the history
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the Apache 2.0 License.
  • Loading branch information
dannyrandall authored Nov 1, 2023
1 parent 9c9cf3b commit baf2a78
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 78 deletions.
73 changes: 47 additions & 26 deletions internal/pkg/cli/run_local.go
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,14 @@ func (o *runLocalOpts) getTask(ctx context.Context) (orchestrator.Task, error) {
Containers: make(map[string]orchestrator.ContainerDefinition, len(td.ContainerDefinitions)),
}

if o.proxy {
pauseSecrets, err := sessionEnvVars(ctx, o.sess)
if err != nil {
return orchestrator.Task{}, fmt.Errorf("get pause container secrets: %w", err)
}
task.PauseSecrets = pauseSecrets
}

for _, ctr := range td.ContainerDefinitions {
name := aws.StringValue(ctr.Name)
def := orchestrator.ContainerDefinition{
Expand Down Expand Up @@ -401,6 +409,24 @@ func (o *runLocalOpts) getTask(ctx context.Context) (orchestrator.Task, error) {
return task, nil
}

func sessionEnvVars(ctx context.Context, sess *session.Session) (map[string]string, error) {
creds, err := sess.Config.Credentials.GetWithContext(ctx)
if err != nil {
return nil, fmt.Errorf("get IAM credentials: %w", err)
}

env := map[string]string{
"AWS_ACCESS_KEY_ID": creds.AccessKeyID,
"AWS_SECRET_ACCESS_KEY": creds.SecretAccessKey,
"AWS_SESSION_TOKEN": creds.SessionToken,
}
if sess.Config.Region != nil {
env["AWS_DEFAULT_REGION"] = aws.StringValue(sess.Config.Region)
env["AWS_REGION"] = aws.StringValue(sess.Config.Region)
}
return env, nil
}

type containerEnv map[string]envVarValue

type envVarValue struct {
Expand Down Expand Up @@ -439,35 +465,12 @@ func (c containerEnv) Secrets() map[string]string {

// getEnvVars uses env overrides passed by flags and environment variables/secrets
// specified in the Task Definition to return a set of environment varibles for each
// continer defined in the TaskDefinition. The returned map is a map of container names,
// container defined in the TaskDefinition. The returned map is a map of container names,
// each of which contains a mapping of key->envVarValue, which defines if the variable is a secret or not.
func (o *runLocalOpts) getEnvVars(ctx context.Context, taskDef *awsecs.TaskDefinition) (map[string]containerEnv, error) {
creds, err := o.sess.Config.Credentials.GetWithContext(ctx)
if err != nil {
return nil, fmt.Errorf("get IAM credentials: %w", err)
}

envVars := make(map[string]containerEnv)
envVars := make(map[string]containerEnv, len(taskDef.ContainerDefinitions))
for _, ctr := range taskDef.ContainerDefinitions {
name := aws.StringValue(ctr.Name)
envVars[name] = map[string]envVarValue{
"AWS_ACCESS_KEY_ID": {
Value: creds.AccessKeyID,
},
"AWS_SECRET_ACCESS_KEY": {
Value: creds.SecretAccessKey,
},
"AWS_SESSION_TOKEN": {
Value: creds.SessionToken,
},
}
if o.sess.Config.Region != nil {
val := envVarValue{
Value: aws.StringValue(o.sess.Config.Region),
}
envVars[name]["AWS_DEFAULT_REGION"] = val
envVars[name]["AWS_REGION"] = val
}
envVars[aws.StringValue(ctr.Name)] = make(map[string]envVarValue)
}

for _, e := range taskDef.EnvironmentVariables() {
Expand All @@ -483,6 +486,24 @@ func (o *runLocalOpts) getEnvVars(ctx context.Context, taskDef *awsecs.TaskDefin
if err := o.fillSecrets(ctx, envVars, taskDef); err != nil {
return nil, fmt.Errorf("get secrets: %w", err)
}

// inject session variables if they haven't been already set
sessionVars, err := sessionEnvVars(ctx, o.sess)
if err != nil {
return nil, err
}

for ctr := range envVars {
for k, v := range sessionVars {
if _, ok := envVars[ctr][k]; !ok {
envVars[ctr][k] = envVarValue{
Value: v,
Secret: true,
}
}
}
}

return envVars, nil
}

Expand Down
143 changes: 93 additions & 50 deletions internal/pkg/cli/run_local_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,14 +317,14 @@ func TestRunLocalOpts_Execute(t *testing.T) {
"foo": {
ImageURI: "image1",
EnvVars: map[string]string{
"FOO_VAR": "foo-value",
"FOO_VAR": "foo-value",
},
Secrets: map[string]string{
"SHARED_SECRET": "secretvalue",
"AWS_ACCESS_KEY_ID": "myID",
"AWS_SECRET_ACCESS_KEY": "mySecret",
"AWS_SESSION_TOKEN": "myToken",
},
Secrets: map[string]string{
"SHARED_SECRET": "secretvalue",
},
Ports: map[string]string{
"80": "8080",
"999": "9999",
Expand All @@ -333,30 +333,38 @@ func TestRunLocalOpts_Execute(t *testing.T) {
"bar": {
ImageURI: "image2",
EnvVars: map[string]string{
"BAR_VAR": "bar-value",
"BAR_VAR": "bar-value",
},
Secrets: map[string]string{
"SHARED_SECRET": "secretvalue",
"AWS_ACCESS_KEY_ID": "myID",
"AWS_SECRET_ACCESS_KEY": "mySecret",
"AWS_SESSION_TOKEN": "myToken",
},
Secrets: map[string]string{
"SHARED_SECRET": "secretvalue",
},
Ports: map[string]string{
"777": "7777",
"10000": "10000",
},
},
},
}
expectedProxyTask := orchestrator.Task{
Containers: expectedTask.Containers,
PauseSecrets: map[string]string{
"AWS_ACCESS_KEY_ID": "myID",
"AWS_SECRET_ACCESS_KEY": "mySecret",
"AWS_SESSION_TOKEN": "myToken",
},
}

testCases := map[string]struct {
inputAppName string
inputEnvName string
inputWkldName string
inputEnvOverrides map[string]string
inputPortOverrides []string
inputProxy bool
buildImagesError error
inProxy bool

setupMocks func(t *testing.T, m *runLocalExecuteMocks)
wantedWkldName string
Expand Down Expand Up @@ -389,7 +397,7 @@ func TestRunLocalOpts_Execute(t *testing.T) {
inputAppName: testAppName,
inputWkldName: testWkldName,
inputEnvName: testEnvName,
inProxy: true,
inputProxy: true,
setupMocks: func(t *testing.T, m *runLocalExecuteMocks) {
m.ecsClient.EXPECT().TaskDefinition(testAppName, testEnvName, testWkldName).Return(taskDef, nil)
m.ssm.EXPECT().GetSecretValue(gomock.Any(), "mysecret").Return("secretvalue", nil)
Expand All @@ -401,7 +409,7 @@ func TestRunLocalOpts_Execute(t *testing.T) {
inputAppName: testAppName,
inputWkldName: testWkldName,
inputEnvName: testEnvName,
inProxy: true,
inputProxy: true,
setupMocks: func(t *testing.T, m *runLocalExecuteMocks) {
m.ecsClient.EXPECT().TaskDefinition(testAppName, testEnvName, testWkldName).Return(taskDef, nil)
m.ssm.EXPECT().GetSecretValue(gomock.Any(), "mysecret").Return("secretvalue", nil)
Expand All @@ -413,7 +421,7 @@ func TestRunLocalOpts_Execute(t *testing.T) {
inputAppName: testAppName,
inputWkldName: testWkldName,
inputEnvName: testEnvName,
inProxy: true,
inputProxy: true,
setupMocks: func(t *testing.T, m *runLocalExecuteMocks) {
m.ecsClient.EXPECT().TaskDefinition(testAppName, testEnvName, testWkldName).Return(taskDef, nil)
m.ssm.EXPECT().GetSecretValue(gomock.Any(), "mysecret").Return("secretvalue", nil)
Expand Down Expand Up @@ -460,7 +468,7 @@ func TestRunLocalOpts_Execute(t *testing.T) {
},
wantedError: errors.New(`build images: some error`),
},
"pulls errors from orchestrator": {
"success, one run task call": {
inputAppName: testAppName,
inputWkldName: testWkldName,
inputEnvName: testEnvName,
Expand All @@ -484,6 +492,40 @@ func TestRunLocalOpts_Execute(t *testing.T) {
}
},
},
"success, one run task call, proxy": {
inputAppName: testAppName,
inputWkldName: testWkldName,
inputEnvName: testEnvName,
inputProxy: true,
setupMocks: func(t *testing.T, m *runLocalExecuteMocks) {
m.ecsClient.EXPECT().TaskDefinition(testAppName, testEnvName, testWkldName).Return(taskDef, nil)
m.ssm.EXPECT().GetSecretValue(gomock.Any(), "mysecret").Return("secretvalue", nil)
m.envChecker.EXPECT().Version().Return("v1.32.0", nil)
m.hostFinder.HostsFn = func(ctx context.Context) ([]host, error) {
return []host{
{
host: "a-different-service",
port: "80",
},
}, nil
}
m.ws.EXPECT().ReadWorkloadManifest(testWkldName).Return([]byte(""), nil)
m.interpolator.EXPECT().Interpolate("").Return("", nil)

errCh := make(chan error, 1)
m.orchestrator.StartFn = func() <-chan error {
errCh <- errors.New("some error")
return errCh
}
m.orchestrator.RunTaskFn = func(task orchestrator.Task) {
require.Equal(t, expectedProxyTask, task)
}
m.orchestrator.StopFn = func() {
require.Len(t, errCh, 0)
close(errCh)
}
},
},
"handles ctrl-c, waits to get all errors": {
inputAppName: testAppName,
inputWkldName: testWkldName,
Expand Down Expand Up @@ -557,7 +599,7 @@ func TestRunLocalOpts_Execute(t *testing.T) {
container: "9999",
},
},
proxy: tc.inProxy,
proxy: tc.inputProxy,
},
newInterpolator: func(app, env string) interpolator {
return m.interpolator
Expand Down Expand Up @@ -623,10 +665,6 @@ func TestRunLocalOpts_getEnvVars(t *testing.T) {
want map[string]containerEnv
wantError string
}{
"error getting creds": {
credsError: errors.New("some error"),
wantError: `get IAM credentials: some error`,
},
"invalid container in env override": {
taskDef: &ecs.TaskDefinition{},
envOverrides: map[string]string{
Expand Down Expand Up @@ -654,16 +692,16 @@ func TestRunLocalOpts_getEnvVars(t *testing.T) {
"foo": {
"OVERRIDE_ALL": newVar("all", true, false),
"OVERRIDE": newVar("foo", true, false),
"AWS_ACCESS_KEY_ID": newVar("myID", false, false),
"AWS_SECRET_ACCESS_KEY": newVar("mySecret", false, false),
"AWS_SESSION_TOKEN": newVar("myToken", false, false),
"AWS_ACCESS_KEY_ID": newVar("myID", false, true),
"AWS_SECRET_ACCESS_KEY": newVar("mySecret", false, true),
"AWS_SESSION_TOKEN": newVar("myToken", false, true),
},
"bar": {
"OVERRIDE_ALL": newVar("all", true, false),
"OVERRIDE": newVar("bar", true, false),
"AWS_ACCESS_KEY_ID": newVar("myID", false, false),
"AWS_SECRET_ACCESS_KEY": newVar("mySecret", false, false),
"AWS_SESSION_TOKEN": newVar("myToken", false, false),
"AWS_ACCESS_KEY_ID": newVar("myID", false, true),
"AWS_SECRET_ACCESS_KEY": newVar("mySecret", false, true),
"AWS_SESSION_TOKEN": newVar("myToken", false, true),
},
},
},
Expand Down Expand Up @@ -716,17 +754,17 @@ func TestRunLocalOpts_getEnvVars(t *testing.T) {
"RANDOM_FOO": newVar("foo", false, false),
"OVERRIDE_ALL": newVar("all", true, false),
"OVERRIDE": newVar("foo", true, false),
"AWS_ACCESS_KEY_ID": newVar("myID", false, false),
"AWS_SECRET_ACCESS_KEY": newVar("mySecret", false, false),
"AWS_SESSION_TOKEN": newVar("myToken", false, false),
"AWS_ACCESS_KEY_ID": newVar("myID", false, true),
"AWS_SECRET_ACCESS_KEY": newVar("mySecret", false, true),
"AWS_SESSION_TOKEN": newVar("myToken", false, true),
},
"bar": {
"RANDOM_BAR": newVar("bar", false, false),
"OVERRIDE_ALL": newVar("all", true, false),
"OVERRIDE": newVar("bar", true, false),
"AWS_ACCESS_KEY_ID": newVar("myID", false, false),
"AWS_SECRET_ACCESS_KEY": newVar("mySecret", false, false),
"AWS_SESSION_TOKEN": newVar("myToken", false, false),
"AWS_ACCESS_KEY_ID": newVar("myID", false, true),
"AWS_SECRET_ACCESS_KEY": newVar("mySecret", false, true),
"AWS_SESSION_TOKEN": newVar("myToken", false, true),
},
},
},
Expand Down Expand Up @@ -819,9 +857,9 @@ func TestRunLocalOpts_getEnvVars(t *testing.T) {
"SSM": newVar("ssm", false, true),
"SECRETS_MANAGER": newVar("secretsmanager", false, true),
"DEFAULT": newVar("default", false, true),
"AWS_ACCESS_KEY_ID": newVar("myID", false, false),
"AWS_SECRET_ACCESS_KEY": newVar("mySecret", false, false),
"AWS_SESSION_TOKEN": newVar("myToken", false, false),
"AWS_ACCESS_KEY_ID": newVar("myID", false, true),
"AWS_SECRET_ACCESS_KEY": newVar("mySecret", false, true),
"AWS_SESSION_TOKEN": newVar("myToken", false, true),
},
},
},
Expand Down Expand Up @@ -865,16 +903,16 @@ func TestRunLocalOpts_getEnvVars(t *testing.T) {
"foo": {
"ONE": newVar("shared-value", false, true),
"TWO": newVar("foo-value", false, true),
"AWS_ACCESS_KEY_ID": newVar("myID", false, false),
"AWS_SECRET_ACCESS_KEY": newVar("mySecret", false, false),
"AWS_SESSION_TOKEN": newVar("myToken", false, false),
"AWS_ACCESS_KEY_ID": newVar("myID", false, true),
"AWS_SECRET_ACCESS_KEY": newVar("mySecret", false, true),
"AWS_SESSION_TOKEN": newVar("myToken", false, true),
},
"bar": {
"THREE": newVar("shared-value", false, true),
"FOUR": newVar("bar-value", false, true),
"AWS_ACCESS_KEY_ID": newVar("myID", false, false),
"AWS_SECRET_ACCESS_KEY": newVar("mySecret", false, false),
"AWS_SESSION_TOKEN": newVar("myToken", false, false),
"AWS_ACCESS_KEY_ID": newVar("myID", false, true),
"AWS_SECRET_ACCESS_KEY": newVar("mySecret", false, true),
"AWS_SESSION_TOKEN": newVar("myToken", false, true),
},
},
},
Expand Down Expand Up @@ -921,20 +959,25 @@ func TestRunLocalOpts_getEnvVars(t *testing.T) {
"foo": {
"ONE": newVar("one-overridden", true, false),
"TWO": newVar("foo-value", false, true),
"AWS_ACCESS_KEY_ID": newVar("myID", false, false),
"AWS_SECRET_ACCESS_KEY": newVar("mySecret", false, false),
"AWS_SESSION_TOKEN": newVar("myToken", false, false),
"AWS_ACCESS_KEY_ID": newVar("myID", false, true),
"AWS_SECRET_ACCESS_KEY": newVar("mySecret", false, true),
"AWS_SESSION_TOKEN": newVar("myToken", false, true),
},
"bar": {
"ONE": newVar("one-overridden", true, false),
"THREE": newVar("shared-value", false, true),
"FOUR": newVar("four-overridden", true, false),
"AWS_ACCESS_KEY_ID": newVar("myID", false, false),
"AWS_SECRET_ACCESS_KEY": newVar("mySecret", false, false),
"AWS_SESSION_TOKEN": newVar("myToken", false, false),
"AWS_ACCESS_KEY_ID": newVar("myID", false, true),
"AWS_SECRET_ACCESS_KEY": newVar("mySecret", false, true),
"AWS_SESSION_TOKEN": newVar("myToken", false, true),
},
},
},
"error getting creds": {
taskDef: &ecs.TaskDefinition{},
credsError: errors.New("some error"),
wantError: `get IAM credentials: some error`,
},
"region env vars set": {
taskDef: &ecs.TaskDefinition{
ContainerDefinitions: []*sdkecs.ContainerDefinition{
Expand All @@ -947,11 +990,11 @@ func TestRunLocalOpts_getEnvVars(t *testing.T) {
region: aws.String("myRegion"),
want: map[string]containerEnv{
"foo": {
"AWS_ACCESS_KEY_ID": newVar("myID", false, false),
"AWS_SECRET_ACCESS_KEY": newVar("mySecret", false, false),
"AWS_SESSION_TOKEN": newVar("myToken", false, false),
"AWS_REGION": newVar("myRegion", false, false),
"AWS_DEFAULT_REGION": newVar("myRegion", false, false),
"AWS_ACCESS_KEY_ID": newVar("myID", false, true),
"AWS_SECRET_ACCESS_KEY": newVar("mySecret", false, true),
"AWS_SESSION_TOKEN": newVar("myToken", false, true),
"AWS_REGION": newVar("myRegion", false, true),
"AWS_DEFAULT_REGION": newVar("myRegion", false, true),
},
},
},
Expand Down
Loading

0 comments on commit baf2a78

Please sign in to comment.