diff --git a/cmd/main.go b/cmd/main.go index 1428a69dc1..a694b3dc3a 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -168,7 +168,7 @@ func main() { region = md.GetRegion() } - cloud, err := cloud.NewCloud(region, options.AwsSdkDebugLog, options.UserAgentExtra, options.Batching) + cloud, err := cloud.NewCloud(region, options.AwsSdkDebugLog, options.UserAgentExtra, options.Batching, options.RoleARN) if err != nil { klog.ErrorS(err, "failed to create cloud service") klog.FlushAndExit(klog.ExitFlushTimeout, 1) diff --git a/docs/options.md b/docs/options.md index 05951f2ab5..e9b9ae37e5 100644 --- a/docs/options.md +++ b/docs/options.md @@ -19,3 +19,4 @@ There are a couple of driver options that can be passed as arguments when starti | modify-volume-request-handler-timeout | 10s | 2s | Timeout for the window in which volume modification calls must be received in order for them to coalesce into a single volume modification call to AWS. If changing this, be aware that the ebs-csi-controller's csi-resizer and volumemodifier containers both have timeouts on the calls they make, if this value exceeds those timeouts it will cause them to always fail and fall into a retry loop, so adjust those values accordingly. | warn-on-invalid-tag | true | false | To warn on invalid tags, instead of returning an error| |reserved-volume-attachments | 2 | -1 | Number of volume attachments reserved for system use. Not used when --volume-attach-limit is specified. 
When -1, the amount of reserved attachments is loaded from instance metadata that captured state at node boot and may include not only system disks but also CSI volumes.| +| role-arn | arn:aws:iam::012345678910:role/ExampleRole | | ARN of the IAM role the driver assumes when calling AWS EC2 APIs. When empty, the driver's default credentials are used. | diff --git a/pkg/cloud/cloud.go b/pkg/cloud/cloud.go index 4e946008c8..e1e49d67df 100644 --- a/pkg/cloud/cloud.go +++ b/pkg/cloud/cloud.go @@ -30,8 +30,10 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials/stscreds" "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/aws/smithy-go" "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/batcher" dm "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud/devicemanager" @@ -94,6 +96,8 @@ const ( volumeDetachedState = "detached" volumeAttachedState = "attached" cacheForgetDelay = 1 * time.Hour + + assumeRoleSessionDuration = 1 * time.Hour ) // AWS provisioning limits.
@@ -333,12 +337,12 @@ var _ Cloud = &cloud{} // NewCloud returns a new instance of AWS cloud // It panics if session is invalid -func NewCloud(region string, awsSdkDebugLog bool, userAgentExtra string, batching bool) (Cloud, error) { - c := newEC2Cloud(region, awsSdkDebugLog, userAgentExtra, batching) +func NewCloud(region string, awsSdkDebugLog bool, userAgentExtra string, batching bool, roleARN string) (Cloud, error) { + c := newEC2Cloud(region, awsSdkDebugLog, userAgentExtra, batching, roleARN) return c, nil } -func newEC2Cloud(region string, awsSdkDebugLog bool, userAgentExtra string, batchingEnabled bool) Cloud { +func newEC2Cloud(region string, awsSdkDebugLog bool, userAgentExtra string, batchingEnabled bool, roleARN string) Cloud { cfg, err := config.LoadDefaultConfig(context.Background(), config.WithRegion(region)) if err != nil { panic(err) @@ -355,7 +359,19 @@ func newEC2Cloud(region string, awsSdkDebugLog bool, userAgentExtra string, batc os.Setenv("AWS_EXECUTION_ENV", "aws-ebs-csi-driver-"+driverVersion) } - svc := ec2.NewFromConfig(cfg, func(o *ec2.Options) { + ec2Config := cfg + if roleARN != "" { + creds := stscreds.NewAssumeRoleProvider(sts.NewFromConfig(cfg), roleARN, func(aro *stscreds.AssumeRoleOptions) { + aro.Duration = assumeRoleSessionDuration + }) + ec2Config = aws.Config{ + Region: cfg.Region, + DefaultsMode: aws.DefaultsModeStandard, + Credentials: aws.NewCredentialsCache(creds), + } + } + + svc := ec2.NewFromConfig(ec2Config, func(o *ec2.Options) { o.APIOptions = append(o.APIOptions, RecordRequestsMiddleware(), ) diff --git a/pkg/cloud/cloud_test.go b/pkg/cloud/cloud_test.go index f7acef7b61..1c3396b577 100644 --- a/pkg/cloud/cloud_test.go +++ b/pkg/cloud/cloud_test.go @@ -83,7 +83,6 @@ func extractVolumeIdentifiers(volumes []types.Volume) (volumeIDs []string, volum return volumeIDs, volumeNames } func TestNewCloud(t *testing.T) { - testCases := []struct { name string region string @@ -110,7 +109,7 @@ func TestNewCloud(t *testing.T) { }, 
} for _, tc := range testCases { - ec2Cloud, err := NewCloud(tc.region, tc.awsSdkDebugLog, tc.userAgentExtra, tc.batchingEnabled) + ec2Cloud, err := NewCloud(tc.region, tc.awsSdkDebugLog, tc.userAgentExtra, tc.batchingEnabled, "") if err != nil { t.Fatalf("error %v", err) } diff --git a/pkg/driver/options.go b/pkg/driver/options.go index 5dc13c4b43..478b901f24 100644 --- a/pkg/driver/options.go +++ b/pkg/driver/options.go @@ -32,6 +32,9 @@ type Options struct { // If empty, the in-cluster config will be loaded. Kubeconfig string + // RoleARN is the ARN of the role the driver will use to interact with the AWS EC2 APIs. + RoleARN string + // #### Server options #### //Endpoint is the endpoint for the CSI driver server @@ -91,6 +94,7 @@ type Options struct { func (o *Options) AddFlags(f *flag.FlagSet) { f.StringVar(&o.Kubeconfig, "kubeconfig", "", "Absolute path to a kubeconfig file. The default is the emtpy string, which causes the in-cluster config to be used") + f.StringVar(&o.RoleARN, "role-arn", "", "Arn of the role to be used while interacting with EC2 APIs.
The default is the empty string, which causes the role provided by the Pod identity or OIDC to be used.") // Server options f.StringVar(&o.Endpoint, "endpoint", DefaultCSIEndpoint, "Endpoint for the CSI driver server") diff --git a/pkg/driver/options_test.go b/pkg/driver/options_test.go index 21b4fe4a75..c8f3510ea2 100644 --- a/pkg/driver/options_test.go +++ b/pkg/driver/options_test.go @@ -70,6 +70,9 @@ func TestAddFlags(t *testing.T) { if err := f.Set("reserved-volume-attachments", "5"); err != nil { t.Errorf("error setting reserved-volume-attachments: %v", err) } + if err := f.Set("role-arn", "arn:aws:iam::012345678910:role/ExampleRole"); err != nil { + t.Errorf("error setting role-arn: %v", err) + } if o.Endpoint != "custom-endpoint" { t.Errorf("unexpected Endpoint: got %s, want custom-endpoint", o.Endpoint) @@ -107,6 +110,9 @@ func TestAddFlags(t *testing.T) { if o.ReservedVolumeAttachments != 5 { t.Errorf("unexpected ReservedVolumeAttachments: got %d, want 5", o.ReservedVolumeAttachments) } + if o.RoleARN != "arn:aws:iam::012345678910:role/ExampleRole" { + t.Errorf("unexpected role-arn: got %s, want arn:aws:iam::012345678910:role/ExampleRole", o.RoleARN) + } } func TestValidateAttachmentLimits(t *testing.T) { diff --git a/tests/e2e/dynamic_provisioning.go b/tests/e2e/dynamic_provisioning.go index c39de5b51c..0b4b3e41e7 100644 --- a/tests/e2e/dynamic_provisioning.go +++ b/tests/e2e/dynamic_provisioning.go @@ -623,7 +623,7 @@ var _ = Describe("[ebs-csi-e2e] [single-az] Dynamic Provisioning", func() { availabilityZones := strings.Split(os.Getenv(awsAvailabilityZonesEnv), ",") availabilityZone := availabilityZones[rand.Intn(len(availabilityZones))] region := availabilityZone[0 : len(availabilityZone)-1] - cloud, err := awscloud.NewCloud(region, false, "", true) + cloud, err := awscloud.NewCloud(region, false, "", true, "") if err != nil { Fail(fmt.Sprintf("could not get NewCloud: %v", err)) } diff --git a/tests/e2e/pre_provsioning.go
b/tests/e2e/pre_provsioning.go index cd77e6e6ca..a7b70bdf3e 100644 --- a/tests/e2e/pre_provsioning.go +++ b/tests/e2e/pre_provsioning.go @@ -88,7 +88,7 @@ var _ = Describe("[ebs-csi-e2e] [single-az] Pre-Provisioned", func() { Tags: map[string]string{awscloud.VolumeNameTagKey: dummyVolumeName, awscloud.AwsEbsDriverTagKey: "true"}, } var err error - cloud, err = awscloud.NewCloud(region, false, "", true) + cloud, err = awscloud.NewCloud(region, false, "", true, "") if err != nil { Fail(fmt.Sprintf("could not get NewCloud: %v", err)) } @@ -261,7 +261,7 @@ var _ = Describe("[ebs-csi-e2e] [single-az] Pre-Provisioned with Multi-Attach", Tags: map[string]string{awscloud.VolumeNameTagKey: dummyVolumeName, awscloud.AwsEbsDriverTagKey: "true"}, } var err error - cloud, err = awscloud.NewCloud(region, false, "", true) + cloud, err = awscloud.NewCloud(region, false, "", true, "") if err != nil { Fail(fmt.Sprintf("could not get NewCloud: %v", err)) }