diff --git a/config/config.go b/config/config.go index b25d708..6a1fd69 100644 --- a/config/config.go +++ b/config/config.go @@ -25,6 +25,13 @@ import ( "github.com/aws/aws-sdk-go-v2/credentials" ) +type AWSCredentialType string + +const ( + AWSCredentialTypeAccessKey AWSCredentialType = "access_key" + AWSCredentialTypeRole AWSCredentialType = "role" +) + // NewConfig returns a new Config func NewConfig(cfgFile string) (*Config, error) { var config Config @@ -59,7 +66,7 @@ func (c *Config) Validate() error { return nil } -type Credentials struct { +type AccessKeyCredentials struct { // AWS Access key ID AccessKeyID string `toml:"access_key_id"` @@ -70,7 +77,7 @@ type Credentials struct { SessionToken string `toml:"session_token"` } -func (c Credentials) Validate() error { +func (c AccessKeyCredentials) Validate() error { if c.AccessKeyID == "" { return fmt.Errorf("missing access_key_id") } @@ -85,19 +92,42 @@ func (c Credentials) Validate() error { return nil } +type Credentials struct { + CredentialType AWSCredentialType `toml:"credential_type"` + AccessKey AccessKeyCredentials `toml:"access_key"` +} + +func (c Credentials) Validate() error { + switch c.CredentialType { + case AWSCredentialTypeAccessKey: + return c.AccessKey.Validate() + case AWSCredentialTypeRole: + } + return nil +} + func (c Config) GetAWSConfig(ctx context.Context) (aws.Config, error) { if err := c.Credentials.Validate(); err != nil { return aws.Config{}, fmt.Errorf("failed to validate credentials: %w", err) } - cfg, err := config.LoadDefaultConfig(ctx, - config.WithCredentialsProvider( - credentials.NewStaticCredentialsProvider( - c.Credentials.AccessKeyID, - c.Credentials.SecretAccessKey, - c.Credentials.SessionToken)), - config.WithRegion(c.Region), - ) + var cfg aws.Config + var err error + switch c.Credentials.CredentialType { + case AWSCredentialTypeAccessKey: + cfg, err = config.LoadDefaultConfig(ctx, + config.WithCredentialsProvider( + credentials.NewStaticCredentialsProvider( + c.Credentials.AccessKey.AccessKeyID, + c.Credentials.AccessKey.SecretAccessKey, + c.Credentials.AccessKey.SessionToken)), + config.WithRegion(c.Region), + ) + case AWSCredentialTypeRole: + cfg, err = config.LoadDefaultConfig(ctx) + default: + return aws.Config{}, fmt.Errorf("unknown credential type: %s", c.Credentials.CredentialType) + } if err != nil { return aws.Config{}, fmt.Errorf("failed to get aws config: %w", err) } diff --git a/config/config_test.go b/config/config_test.go index ae16b71..2763af5 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -32,9 +32,12 @@ func TestConfigValidate(t *testing.T) { name: "valid config", c: &Config{ Credentials: Credentials{ - AccessKeyID: "access_key_id", - SecretAccessKey: "secret_access_key", - SessionToken: "session_token", + CredentialType: AWSCredentialTypeAccessKey, + AccessKey: AccessKeyCredentials{ + AccessKeyID: "AccessKeyID", + SecretAccessKey: "SecretAccessKey", + SessionToken: "SessionToken", + }, }, SubnetID: "subnet_id", Region: "region", @@ -45,9 +48,12 @@ func TestConfigValidate(t *testing.T) { name: "missing subnet_id", c: &Config{ Credentials: Credentials{ - AccessKeyID: "access_key_id", - SecretAccessKey: "secret_access_key", - SessionToken: "session_token", + CredentialType: AWSCredentialTypeAccessKey, + AccessKey: AccessKeyCredentials{ + AccessKeyID: "AccessKeyID", + SecretAccessKey: "SecretAccessKey", + SessionToken: "SessionToken", + }, }, Region: "region", }, @@ -57,9 +63,12 @@ func TestConfigValidate(t *testing.T) { name: "missing region", c: &Config{ Credentials: Credentials{ - AccessKeyID: "access_key_id", - SecretAccessKey: "secret_access_key", - SessionToken: "session_token", + CredentialType: AWSCredentialTypeAccessKey, + AccessKey: AccessKeyCredentials{ + AccessKeyID: "AccessKeyID", + SecretAccessKey: "SecretAccessKey", + SessionToken: "SessionToken", + }, }, SubnetID: "subnet_id", }, @@ -96,36 +105,48 @@ func TestCredentialsValidate(t *testing.T) { { name: "valid credentials", c: Credentials{ - AccessKeyID: "access_key_id", - SecretAccessKey: "secret_access_key", - SessionToken: "session_token", + CredentialType: AWSCredentialTypeAccessKey, + AccessKey: AccessKeyCredentials{ + AccessKeyID: "AccessKeyID", + SecretAccessKey: "SecretAccessKey", + SessionToken: "SessionToken", + }, }, errString: "", }, { name: "missing access_key_id", c: Credentials{ - AccessKeyID: "", - SecretAccessKey: "secret_access_key", - SessionToken: "session_token", + CredentialType: AWSCredentialTypeAccessKey, + AccessKey: AccessKeyCredentials{ + AccessKeyID: "AccessKeyID", + SecretAccessKey: "SecretAccessKey", + SessionToken: "SessionToken", + }, }, errString: "missing access_key_id", }, { name: "missing secret_access_key", c: Credentials{ - AccessKeyID: "access_key_id", - SecretAccessKey: "", - SessionToken: "session_token", + CredentialType: AWSCredentialTypeAccessKey, + AccessKey: AccessKeyCredentials{ + AccessKeyID: "AccessKeyID", + SecretAccessKey: "SecretAccessKey", + SessionToken: "SessionToken", + }, }, errString: "missing secret_access_key", }, { name: "missing session_token", c: Credentials{ - AccessKeyID: "access_key_id", - SecretAccessKey: "secret_access_key", - SessionToken: "", + CredentialType: AWSCredentialTypeAccessKey, + AccessKey: AccessKeyCredentials{ + AccessKeyID: "AccessKeyID", + SecretAccessKey: "SecretAccessKey", + SessionToken: "SessionToken", + }, }, errString: "missing session_token", }, @@ -171,9 +192,12 @@ func TestNewConfig(t *testing.T) { require.NoError(t, err, "NewConfig() should not have returned an error") require.Equal(t, &Config{ Credentials: Credentials{ - AccessKeyID: "access_key_id", - SecretAccessKey: "secret", - SessionToken: "token", + CredentialType: AWSCredentialTypeAccessKey, + AccessKey: AccessKeyCredentials{ + AccessKeyID: "AccessKeyID", + SecretAccessKey: "SecretAccessKey", + SessionToken: "SessionToken", + }, }, SubnetID: "subnet_id", Region: "region", diff --git a/internal/client/aws_test.go b/internal/client/aws_test.go index 7c654a1..12324cf 100644 --- a/internal/client/aws_test.go +++ b/internal/client/aws_test.go @@ -35,9 +35,12 @@ func TestStartInstance(t *testing.T) { Region: "us-west-2", SubnetID: "subnet-1234567890abcdef0", Credentials: config.Credentials{ - AccessKeyID: "AccessKeyID", - SecretAccessKey: "SecretAccessKey", - SessionToken: "SessionToken", + CredentialType: config.AWSCredentialTypeAccessKey, + AccessKey: config.AccessKeyCredentials{ + AccessKeyID: "AccessKeyID", + SecretAccessKey: "SecretAccessKey", + SessionToken: "SessionToken", + }, }, } mockClient := new(MockComputeClient) @@ -62,9 +65,12 @@ func TestStopInstance(t *testing.T) { Region: "us-west-2", SubnetID: "subnet-1234567890abcdef0", Credentials: config.Credentials{ - AccessKeyID: "AccessKeyID", - SecretAccessKey: "SecretAccessKey", - SessionToken: "SessionToken", + CredentialType: config.AWSCredentialTypeAccessKey, + AccessKey: config.AccessKeyCredentials{ + AccessKeyID: "AccessKeyID", + SecretAccessKey: "SecretAccessKey", + SessionToken: "SessionToken", + }, }, } mockClient := new(MockComputeClient) @@ -89,9 +95,12 @@ func TestFindInstances(t *testing.T) { Region: "us-west-2", SubnetID: "subnet-1234567890abcdef0", Credentials: config.Credentials{ - AccessKeyID: "AccessKeyID", - SecretAccessKey: "SecretAccessKey", - SessionToken: "SessionToken", + CredentialType: config.AWSCredentialTypeAccessKey, + AccessKey: config.AccessKeyCredentials{ + AccessKeyID: "AccessKeyID", + SecretAccessKey: "SecretAccessKey", + SessionToken: "SessionToken", + }, }, } mockClient := new(MockComputeClient) @@ -145,9 +154,12 @@ func TestFindOneInstanceWithName(t *testing.T) { Region: "us-west-2", SubnetID: "subnet-1234567890abcdef0", Credentials: config.Credentials{ - AccessKeyID: "AccessKeyID", - SecretAccessKey: "SecretAccessKey", - SessionToken: "SessionToken", + CredentialType: config.AWSCredentialTypeAccessKey, + AccessKey: config.AccessKeyCredentials{ + AccessKeyID: "AccessKeyID", + SecretAccessKey: "SecretAccessKey", + SessionToken: "SessionToken", + }, }, } mockClient := new(MockComputeClient) @@ -195,9 +207,12 @@ func TestFindOneInstanceWithID(t *testing.T) { Region: "us-west-2", SubnetID: "subnet-1234567890abcdef0", Credentials: config.Credentials{ - AccessKeyID: "AccessKeyID", - SecretAccessKey: "SecretAccessKey", - SessionToken: "SessionToken", + CredentialType: config.AWSCredentialTypeAccessKey, + AccessKey: config.AccessKeyCredentials{ + AccessKeyID: "AccessKeyID", + SecretAccessKey: "SecretAccessKey", + SessionToken: "SessionToken", + }, }, } mockClient := new(MockComputeClient) @@ -234,9 +249,12 @@ func TestGetInstance(t *testing.T) { Region: "us-west-2", SubnetID: "subnet-1234567890abcdef0", Credentials: config.Credentials{ - AccessKeyID: "AccessKeyID", - SecretAccessKey: "SecretAccessKey", - SessionToken: "SessionToken", + CredentialType: config.AWSCredentialTypeAccessKey, + AccessKey: config.AccessKeyCredentials{ + AccessKeyID: "AccessKeyID", + SecretAccessKey: "SecretAccessKey", + SessionToken: "SessionToken", + }, }, } mockClient := new(MockComputeClient) @@ -270,9 +288,12 @@ func TestTerminateInstance(t *testing.T) { Region: "us-west-2", SubnetID: "subnet-1234567890abcdef0", Credentials: config.Credentials{ - AccessKeyID: "AccessKeyID", - SecretAccessKey: "SecretAccessKey", - SessionToken: "SessionToken", + CredentialType: config.AWSCredentialTypeAccessKey, + AccessKey: config.AccessKeyCredentials{ + AccessKeyID: "AccessKeyID", + SecretAccessKey: "SecretAccessKey", + SessionToken: "SessionToken", + }, }, } mockClient := new(MockComputeClient) @@ -315,9 +336,12 @@ func TestCreateRunningInstance(t *testing.T) { Region: "us-west-2", SubnetID: "subnet-1234567890abcdef0", Credentials: config.Credentials{ - AccessKeyID: "AccessKeyID", - SecretAccessKey: "SecretAccessKey", - SessionToken: "SessionToken", + CredentialType: config.AWSCredentialTypeAccessKey, + AccessKey: config.AccessKeyCredentials{ + AccessKeyID: "AccessKeyID", + SecretAccessKey: "SecretAccessKey", + SessionToken: "SessionToken", + }, }, } mockClient := new(MockComputeClient) diff --git a/internal/spec/spec_test.go b/internal/spec/spec_test.go index 6780e03..9aa0344 100644 --- a/internal/spec/spec_test.go +++ b/internal/spec/spec_test.go @@ -120,9 +120,12 @@ func TestGetRunnerSpecFromBootstrapParams(t *testing.T) { config := &config.Config{ Credentials: config.Credentials{ - AccessKeyID: "access_key_id", - SecretAccessKey: "secret_access_key", - SessionToken: "session_token", + CredentialType: config.AWSCredentialTypeAccessKey, + AccessKey: config.AccessKeyCredentials{ + AccessKeyID: "AccessKeyID", + SecretAccessKey: "SecretAccessKey", + SessionToken: "SessionToken", + }, }, SubnetID: "subnet_id", Region: "region", diff --git a/provider/provider_test.go b/provider/provider_test.go index 015989b..8a84277 100644 --- a/provider/provider_test.go +++ b/provider/provider_test.go @@ -74,9 +74,12 @@ func TestCreateInstance(t *testing.T) { Region: "us-east-1", SubnetID: "subnet-123456", Credentials: config.Credentials{ - AccessKeyID: "accessKey", - SecretAccessKey: "secretKey", - SessionToken: "token", + CredentialType: config.AWSCredentialTypeAccessKey, + AccessKey: config.AccessKeyCredentials{ + AccessKeyID: "AccessKeyID", + SecretAccessKey: "SecretAccessKey", + SessionToken: "SessionToken", + }, }, } mockComputeClient := new(client.MockComputeClient) @@ -106,9 +109,12 @@ func TestDeleteInstanceWithID(t *testing.T) { Region: "us-east-1", SubnetID: "subnet-123456", Credentials: config.Credentials{ - AccessKeyID: "accessKey", - SecretAccessKey: "secretKey", - SessionToken: "token", + CredentialType: config.AWSCredentialTypeAccessKey, + AccessKey: config.AccessKeyCredentials{ + AccessKeyID: "AccessKeyID", + SecretAccessKey: "SecretAccessKey", + SessionToken: "SessionToken", + }, }, } mockComputeClient := new(client.MockComputeClient) @@ -134,9 +140,12 @@ func TestDeleteInstanceWithName(t *testing.T) { Region: "us-east-1", SubnetID: "subnet-123456", Credentials: config.Credentials{ - AccessKeyID: "accessKey", - SecretAccessKey: "secretKey", - SessionToken: "token", + CredentialType: config.AWSCredentialTypeAccessKey, + AccessKey: config.AccessKeyCredentials{ + AccessKeyID: "AccessKeyID", + SecretAccessKey: "SecretAccessKey", + SessionToken: "SessionToken", + }, }, } mockComputeClient := new(client.MockComputeClient) @@ -195,9 +204,12 @@ func TestGetInstanceWithID(t *testing.T) { Region: "us-east-1", SubnetID: "subnet-123456", Credentials: config.Credentials{ - AccessKeyID: "accessKey", - SecretAccessKey: "secretKey", - SessionToken: "token", + CredentialType: config.AWSCredentialTypeAccessKey, + AccessKey: config.AccessKeyCredentials{ + AccessKeyID: "AccessKeyID", + SecretAccessKey: "SecretAccessKey", + SessionToken: "SessionToken", + }, }, } mockComputeClient := new(client.MockComputeClient) @@ -264,9 +276,12 @@ func TestGetInstanceWithName(t *testing.T) { Region: "us-east-1", SubnetID: "subnet-123456", Credentials: config.Credentials{ - AccessKeyID: "accessKey", - SecretAccessKey: "secretKey", - SessionToken: "token", + CredentialType: config.AWSCredentialTypeAccessKey, + AccessKey: config.AccessKeyCredentials{ + AccessKeyID: "AccessKeyID", + SecretAccessKey: "SecretAccessKey", + SessionToken: "SessionToken", + }, }, } mockComputeClient := new(client.MockComputeClient) @@ -348,9 +363,12 @@ func TestListInstances(t *testing.T) { Region: "us-east-1", SubnetID: "subnet-123456", Credentials: config.Credentials{ - AccessKeyID: "accessKey", - SecretAccessKey: "secretKey", - SessionToken: "token", + CredentialType: config.AWSCredentialTypeAccessKey, + AccessKey: config.AccessKeyCredentials{ + AccessKeyID: "AccessKeyID", + SecretAccessKey: "SecretAccessKey", + SessionToken: "SessionToken", + }, }, } mockComputeClient := new(client.MockComputeClient) @@ -434,9 +452,12 @@ func TestStop(t *testing.T) { Region: "us-east-1", SubnetID: "subnet-123456", Credentials: config.Credentials{ - AccessKeyID: "accessKey", - SecretAccessKey: "secretKey", - SessionToken: "token", + CredentialType: config.AWSCredentialTypeAccessKey, + AccessKey: config.AccessKeyCredentials{ + AccessKeyID: "AccessKeyID", + SecretAccessKey: "SecretAccessKey", + SessionToken: "SessionToken", + }, }, } mockComputeClient := new(client.MockComputeClient) @@ -461,9 +482,12 @@ func TestStartStoppedInstance(t *testing.T) { Region: "us-east-1", SubnetID: "subnet-123456", Credentials: config.Credentials{ - AccessKeyID: "accessKey", - SecretAccessKey: "secretKey", - SessionToken: "token", + CredentialType: config.AWSCredentialTypeAccessKey, + AccessKey: config.AccessKeyCredentials{ + AccessKeyID: "AccessKeyID", + SecretAccessKey: "SecretAccessKey", + SessionToken: "SessionToken", + }, }, } mockComputeClient := new(client.MockComputeClient) @@ -510,9 +534,12 @@ func TestStartStoppingInstance(t *testing.T) { Region: "us-east-1", SubnetID: "subnet-123456", Credentials: config.Credentials{ - AccessKeyID: "accessKey", - SecretAccessKey: "secretKey", - SessionToken: "token", + CredentialType: config.AWSCredentialTypeAccessKey, + AccessKey: config.AccessKeyCredentials{ + AccessKeyID: "AccessKeyID", + SecretAccessKey: "SecretAccessKey", + SessionToken: "SessionToken", + }, }, } mockComputeClient := new(client.MockComputeClient)