Skip to content

Commit

Permalink
Allow IAM role credentials
Browse files Browse the repository at this point in the history
This change allows users to not specify any credentials, thus facilitating
IAM roles if running on AWS instances.

Signed-off-by: Gabriel Adrian Samfira <[email protected]>
  • Loading branch information
gabriel-samfira committed Jun 10, 2024
1 parent 58c45ca commit 1cb4ad6
Show file tree
Hide file tree
Showing 5 changed files with 196 additions and 88 deletions.
50 changes: 40 additions & 10 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ import (
"github.com/aws/aws-sdk-go-v2/credentials"
)

type AWSCredentialType string

const (
AWSCredentialTypeAccessKey AWSCredentialType = "access_key"
AWSCredentialTypeRole AWSCredentialType = "role"
)

// NewConfig returns a new Config
func NewConfig(cfgFile string) (*Config, error) {
var config Config
Expand Down Expand Up @@ -59,7 +66,7 @@ func (c *Config) Validate() error {
return nil
}

type Credentials struct {
type AccessKeyCredentials struct {
// AWS Access key ID
AccessKeyID string `toml:"access_key_id"`

Expand All @@ -70,7 +77,7 @@ type Credentials struct {
SessionToken string `toml:"session_token"`
}

func (c Credentials) Validate() error {
func (c AccessKeyCredentials) Validate() error {
if c.AccessKeyID == "" {
return fmt.Errorf("missing access_key_id")
}
Expand All @@ -85,19 +92,42 @@ func (c Credentials) Validate() error {
return nil
}

type Credentials struct {
CredentialType AWSCredentialType `toml:"credential_type"`
AccessKey AccessKeyCredentials `toml:"access_key"`
}

func (c Credentials) Validate() error {
switch c.CredentialType {
case AWSCredentialTypeAccessKey:
return c.AccessKey.Validate()
case AWSCredentialTypeRole:
}
return nil
}

func (c Config) GetAWSConfig(ctx context.Context) (aws.Config, error) {
if err := c.Credentials.Validate(); err != nil {
return aws.Config{}, fmt.Errorf("failed to validate credentials: %w", err)
}

cfg, err := config.LoadDefaultConfig(ctx,
config.WithCredentialsProvider(
credentials.NewStaticCredentialsProvider(
c.Credentials.AccessKeyID,
c.Credentials.SecretAccessKey,
c.Credentials.SessionToken)),
config.WithRegion(c.Region),
)
var cfg aws.Config
var err error
switch c.Credentials.CredentialType {
case AWSCredentialTypeAccessKey:
cfg, err = config.LoadDefaultConfig(ctx,
config.WithCredentialsProvider(
credentials.NewStaticCredentialsProvider(
c.Credentials.AccessKey.AccessKeyID,
c.Credentials.AccessKey.SecretAccessKey,
c.Credentials.AccessKey.SessionToken)),
config.WithRegion(c.Region),
)
case AWSCredentialTypeRole:
cfg, err = config.LoadDefaultConfig(ctx)
default:
return aws.Config{}, fmt.Errorf("unknown credential type: %s", c.Credentials.CredentialType)
}
if err != nil {
return aws.Config{}, fmt.Errorf("failed to get aws config: %w", err)
}
Expand Down
72 changes: 48 additions & 24 deletions config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,12 @@ func TestConfigValidate(t *testing.T) {
name: "valid config",
c: &Config{
Credentials: Credentials{
AccessKeyID: "access_key_id",
SecretAccessKey: "secret_access_key",
SessionToken: "session_token",
CredentialType: AWSCredentialTypeAccessKey,
AccessKey: AccessKeyCredentials{
AccessKeyID: "AccessKeyID",
SecretAccessKey: "SecretAccessKey",
SessionToken: "SessionToken",
},
},
SubnetID: "subnet_id",
Region: "region",
Expand All @@ -45,9 +48,12 @@ func TestConfigValidate(t *testing.T) {
name: "missing subnet_id",
c: &Config{
Credentials: Credentials{
AccessKeyID: "access_key_id",
SecretAccessKey: "secret_access_key",
SessionToken: "session_token",
CredentialType: AWSCredentialTypeAccessKey,
AccessKey: AccessKeyCredentials{
AccessKeyID: "AccessKeyID",
SecretAccessKey: "SecretAccessKey",
SessionToken: "SessionToken",
},
},
Region: "region",
},
Expand All @@ -57,9 +63,12 @@ func TestConfigValidate(t *testing.T) {
name: "missing region",
c: &Config{
Credentials: Credentials{
AccessKeyID: "access_key_id",
SecretAccessKey: "secret_access_key",
SessionToken: "session_token",
CredentialType: AWSCredentialTypeAccessKey,
AccessKey: AccessKeyCredentials{
AccessKeyID: "AccessKeyID",
SecretAccessKey: "SecretAccessKey",
SessionToken: "SessionToken",
},
},
SubnetID: "subnet_id",
},
Expand Down Expand Up @@ -96,36 +105,48 @@ func TestCredentialsValidate(t *testing.T) {
{
name: "valid credentials",
c: Credentials{
AccessKeyID: "access_key_id",
SecretAccessKey: "secret_access_key",
SessionToken: "session_token",
CredentialType: AWSCredentialTypeAccessKey,
AccessKey: AccessKeyCredentials{
AccessKeyID: "AccessKeyID",
SecretAccessKey: "SecretAccessKey",
SessionToken: "SessionToken",
},
},
errString: "",
},
{
name: "missing access_key_id",
c: Credentials{
AccessKeyID: "",
SecretAccessKey: "secret_access_key",
SessionToken: "session_token",
CredentialType: AWSCredentialTypeAccessKey,
AccessKey: AccessKeyCredentials{
AccessKeyID: "AccessKeyID",
SecretAccessKey: "SecretAccessKey",
SessionToken: "SessionToken",
},
},
errString: "missing access_key_id",
},
{
name: "missing secret_access_key",
c: Credentials{
AccessKeyID: "access_key_id",
SecretAccessKey: "",
SessionToken: "session_token",
CredentialType: AWSCredentialTypeAccessKey,
AccessKey: AccessKeyCredentials{
AccessKeyID: "AccessKeyID",
SecretAccessKey: "SecretAccessKey",
SessionToken: "SessionToken",
},
},
errString: "missing secret_access_key",
},
{
name: "missing session_token",
c: Credentials{
AccessKeyID: "access_key_id",
SecretAccessKey: "secret_access_key",
SessionToken: "",
CredentialType: AWSCredentialTypeAccessKey,
AccessKey: AccessKeyCredentials{
AccessKeyID: "AccessKeyID",
SecretAccessKey: "SecretAccessKey",
SessionToken: "SessionToken",
},
},
errString: "missing session_token",
},
Expand Down Expand Up @@ -171,9 +192,12 @@ func TestNewConfig(t *testing.T) {
require.NoError(t, err, "NewConfig() should not have returned an error")
require.Equal(t, &Config{
Credentials: Credentials{
AccessKeyID: "access_key_id",
SecretAccessKey: "secret",
SessionToken: "token",
CredentialType: AWSCredentialTypeAccessKey,
AccessKey: AccessKeyCredentials{
AccessKeyID: "AccessKeyID",
SecretAccessKey: "SecretAccessKey",
SessionToken: "SessionToken",
},
},
SubnetID: "subnet_id",
Region: "region",
Expand Down
72 changes: 48 additions & 24 deletions internal/client/aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,12 @@ func TestStartInstance(t *testing.T) {
Region: "us-west-2",
SubnetID: "subnet-1234567890abcdef0",
Credentials: config.Credentials{
AccessKeyID: "AccessKeyID",
SecretAccessKey: "SecretAccessKey",
SessionToken: "SessionToken",
CredentialType: config.AWSCredentialTypeAccessKey,
AccessKey: config.AccessKeyCredentials{
AccessKeyID: "AccessKeyID",
SecretAccessKey: "SecretAccessKey",
SessionToken: "SessionToken",
},
},
}
mockClient := new(MockComputeClient)
Expand All @@ -62,9 +65,12 @@ func TestStopInstance(t *testing.T) {
Region: "us-west-2",
SubnetID: "subnet-1234567890abcdef0",
Credentials: config.Credentials{
AccessKeyID: "AccessKeyID",
SecretAccessKey: "SecretAccessKey",
SessionToken: "SessionToken",
CredentialType: config.AWSCredentialTypeAccessKey,
AccessKey: config.AccessKeyCredentials{
AccessKeyID: "AccessKeyID",
SecretAccessKey: "SecretAccessKey",
SessionToken: "SessionToken",
},
},
}
mockClient := new(MockComputeClient)
Expand All @@ -89,9 +95,12 @@ func TestFindInstances(t *testing.T) {
Region: "us-west-2",
SubnetID: "subnet-1234567890abcdef0",
Credentials: config.Credentials{
AccessKeyID: "AccessKeyID",
SecretAccessKey: "SecretAccessKey",
SessionToken: "SessionToken",
CredentialType: config.AWSCredentialTypeAccessKey,
AccessKey: config.AccessKeyCredentials{
AccessKeyID: "AccessKeyID",
SecretAccessKey: "SecretAccessKey",
SessionToken: "SessionToken",
},
},
}
mockClient := new(MockComputeClient)
Expand Down Expand Up @@ -145,9 +154,12 @@ func TestFindOneInstanceWithName(t *testing.T) {
Region: "us-west-2",
SubnetID: "subnet-1234567890abcdef0",
Credentials: config.Credentials{
AccessKeyID: "AccessKeyID",
SecretAccessKey: "SecretAccessKey",
SessionToken: "SessionToken",
CredentialType: config.AWSCredentialTypeAccessKey,
AccessKey: config.AccessKeyCredentials{
AccessKeyID: "AccessKeyID",
SecretAccessKey: "SecretAccessKey",
SessionToken: "SessionToken",
},
},
}
mockClient := new(MockComputeClient)
Expand Down Expand Up @@ -195,9 +207,12 @@ func TestFindOneInstanceWithID(t *testing.T) {
Region: "us-west-2",
SubnetID: "subnet-1234567890abcdef0",
Credentials: config.Credentials{
AccessKeyID: "AccessKeyID",
SecretAccessKey: "SecretAccessKey",
SessionToken: "SessionToken",
CredentialType: config.AWSCredentialTypeAccessKey,
AccessKey: config.AccessKeyCredentials{
AccessKeyID: "AccessKeyID",
SecretAccessKey: "SecretAccessKey",
SessionToken: "SessionToken",
},
},
}
mockClient := new(MockComputeClient)
Expand Down Expand Up @@ -234,9 +249,12 @@ func TestGetInstance(t *testing.T) {
Region: "us-west-2",
SubnetID: "subnet-1234567890abcdef0",
Credentials: config.Credentials{
AccessKeyID: "AccessKeyID",
SecretAccessKey: "SecretAccessKey",
SessionToken: "SessionToken",
CredentialType: config.AWSCredentialTypeAccessKey,
AccessKey: config.AccessKeyCredentials{
AccessKeyID: "AccessKeyID",
SecretAccessKey: "SecretAccessKey",
SessionToken: "SessionToken",
},
},
}
mockClient := new(MockComputeClient)
Expand Down Expand Up @@ -270,9 +288,12 @@ func TestTerminateInstance(t *testing.T) {
Region: "us-west-2",
SubnetID: "subnet-1234567890abcdef0",
Credentials: config.Credentials{
AccessKeyID: "AccessKeyID",
SecretAccessKey: "SecretAccessKey",
SessionToken: "SessionToken",
CredentialType: config.AWSCredentialTypeAccessKey,
AccessKey: config.AccessKeyCredentials{
AccessKeyID: "AccessKeyID",
SecretAccessKey: "SecretAccessKey",
SessionToken: "SessionToken",
},
},
}
mockClient := new(MockComputeClient)
Expand Down Expand Up @@ -315,9 +336,12 @@ func TestCreateRunningInstance(t *testing.T) {
Region: "us-west-2",
SubnetID: "subnet-1234567890abcdef0",
Credentials: config.Credentials{
AccessKeyID: "AccessKeyID",
SecretAccessKey: "SecretAccessKey",
SessionToken: "SessionToken",
CredentialType: config.AWSCredentialTypeAccessKey,
AccessKey: config.AccessKeyCredentials{
AccessKeyID: "AccessKeyID",
SecretAccessKey: "SecretAccessKey",
SessionToken: "SessionToken",
},
},
}
mockClient := new(MockComputeClient)
Expand Down
9 changes: 6 additions & 3 deletions internal/spec/spec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,12 @@ func TestGetRunnerSpecFromBootstrapParams(t *testing.T) {

config := &config.Config{
Credentials: config.Credentials{
AccessKeyID: "access_key_id",
SecretAccessKey: "secret_access_key",
SessionToken: "session_token",
CredentialType: config.AWSCredentialTypeAccessKey,
AccessKey: config.AccessKeyCredentials{
AccessKeyID: "AccessKeyID",
SecretAccessKey: "SecretAccessKey",
SessionToken: "SessionToken",
},
},
SubnetID: "subnet_id",
Region: "region",
Expand Down
Loading

0 comments on commit 1cb4ad6

Please sign in to comment.