diff --git a/pkg/cloud/cloud.go b/pkg/cloud/cloud.go index e912abd8a4..710a358696 100644 --- a/pkg/cloud/cloud.go +++ b/pkg/cloud/cloud.go @@ -141,9 +141,12 @@ const ( // Batcher const ( - volumeIDBatcher batcherType = iota + volumeIDBatcher volumeBatcherType = iota volumeTagBatcher + snapshotIDBatcher snapshotBatcherType = iota + snapshotTagBatcher + batchDescribeTimeout = 30 * time.Second ) @@ -242,14 +245,19 @@ type ec2ListSnapshotsResponse struct { NextToken *string } -// batcherType is an enum representing the types of batchers available. -type batcherType int +// volumeBatcherType is an enum representing the types of volume batchers available. +type volumeBatcherType int + +// snapshotBatcherType is an enum representing the types of snapshot batchers available. +type snapshotBatcherType int // batcherManager maintains a collection of batchers for different types of tasks. type batcherManager struct { - volumeIDBatcher *batcher.Batcher[string, *ec2.Volume] - volumeTagBatcher *batcher.Batcher[string, *ec2.Volume] - instanceIDBatcher *batcher.Batcher[string, *ec2.Instance] + volumeIDBatcher *batcher.Batcher[string, *ec2.Volume] + volumeTagBatcher *batcher.Batcher[string, *ec2.Volume] + instanceIDBatcher *batcher.Batcher[string, *ec2.Instance] + snapshotIDBatcher *batcher.Batcher[string, *ec2.Snapshot] + snapshotTagBatcher *batcher.Batcher[string, *ec2.Snapshot] } type cloud struct { @@ -334,14 +342,20 @@ func newBatcherManager(svc ec2iface.EC2API) *batcherManager { volumeTagBatcher: batcher.New(500, 1*time.Second, func(names []string) (map[string]*ec2.Volume, error) { return execBatchDescribeVolumes(svc, names, volumeTagBatcher) }), - instanceIDBatcher: batcher.New(50, 300*time.Millisecond, func(names []string) (map[string]*ec2.Instance, error) { - return execBatchDescribeInstances(svc, names) + instanceIDBatcher: batcher.New(50, 300*time.Millisecond, func(ids []string) (map[string]*ec2.Instance, error) { + return execBatchDescribeInstances(svc, ids) + }), + snapshotIDBatcher: batcher.New(500, 300*time.Millisecond, func(ids []string) (map[string]*ec2.Snapshot, error) { + return execBatchDescribeSnapshots(svc, ids, snapshotIDBatcher) + }), + snapshotTagBatcher: batcher.New(500, 300*time.Millisecond, func(names []string) (map[string]*ec2.Snapshot, error) { + return execBatchDescribeSnapshots(svc, names, snapshotTagBatcher) }), } } // execBatchDescribeVolumes executes a batched DescribeVolumes API call depending on the type of batcher. -func execBatchDescribeVolumes(svc ec2iface.EC2API, input []string, batcher batcherType) (map[string]*ec2.Volume, error) { +func execBatchDescribeVolumes(svc ec2iface.EC2API, input []string, batcher volumeBatcherType) (map[string]*ec2.Volume, error) { var request *ec2.DescribeVolumesInput switch batcher { @@ -427,7 +441,7 @@ func (c *cloud) batchDescribeVolumes(request *ec2.DescribeVolumesInput) (*ec2.Vo // extractVolumeKey retrieves the key associated with a given volume based on the batcher type. // For the volumeIDBatcher type, it returns the volume's ID. // For other types, it searches for the VolumeNameTagKey within the volume's tags. -func extractVolumeKey(v *ec2.Volume, batcher batcherType) (string, error) { +func extractVolumeKey(v *ec2.Volume, batcher volumeBatcherType) (string, error) { if batcher == volumeIDBatcher { if v.VolumeId == nil { return "", errors.New("extractVolumeKey: missing volume ID") @@ -1014,6 +1028,114 @@ func (c *cloud) GetDiskByID(ctx context.Context, volumeID string) (*Disk, error) }, nil } +// execBatchDescribeSnapshots executes a batched DescribeSnapshots API call depending on the type of batcher. +func execBatchDescribeSnapshots(svc ec2iface.EC2API, input []string, batcher snapshotBatcherType) (map[string]*ec2.Snapshot, error) { + var request *ec2.DescribeSnapshotsInput + + switch batcher { + case snapshotIDBatcher: + klog.V(7).InfoS("execBatchDescribeSnapshots", "snapshotIds", input) + request = &ec2.DescribeSnapshotsInput{ + SnapshotIds: aws.StringSlice(input), + } + + case snapshotTagBatcher: + klog.V(7).InfoS("execBatchDescribeSnapshots", "names", input) + filters := []*ec2.Filter{ + { + Name: aws.String("tag:" + SnapshotNameTagKey), + Values: aws.StringSlice(input), + }, + } + request = &ec2.DescribeSnapshotsInput{ + Filters: filters, + } + + default: + return nil, fmt.Errorf("execBatchDescribeSnapshots: unsupported request type") + } + + ctx, cancel := context.WithTimeout(context.Background(), batchDescribeTimeout) + defer cancel() + + resp, err := describeSnapshots(ctx, svc, request) + if err != nil { + return nil, err + } + + result := make(map[string]*ec2.Snapshot) + + for _, snapshot := range resp { + key, err := extractSnapshotKey(snapshot, batcher) + if err != nil { + klog.Warningf("execBatchDescribeSnapshots: skipping snapshot: %v, reason: %v", snapshot, err) + continue + } + result[key] = snapshot + } + + klog.V(7).InfoS("execBatchDescribeSnapshots: success", "result", result) + return result, nil +} + +// batchDescribeSnapshots processes a DescribeSnapshots request. Depending on the request, +// it determines the appropriate batcher to use, queues the task, and waits for the result. +func (c *cloud) batchDescribeSnapshots(request *ec2.DescribeSnapshotsInput) (*ec2.Snapshot, error) { + var b *batcher.Batcher[string, *ec2.Snapshot] + var task string + + switch { + case len(request.SnapshotIds) == 1 && request.SnapshotIds[0] != nil: + b = c.bm.snapshotIDBatcher + task = *request.SnapshotIds[0] + + case len(request.Filters) == 1 && *request.Filters[0].Name == "tag:"+SnapshotNameTagKey && len(request.Filters[0].Values) == 1: + b = c.bm.snapshotTagBatcher + task = *request.Filters[0].Values[0] + + default: + return nil, fmt.Errorf("batchDescribeSnapshots: invalid request, request: %v", request) + } + + ch := make(chan batcher.BatchResult[*ec2.Snapshot]) + + b.AddTask(task, ch) + + r := <-ch + + if r.Err != nil { + return nil, r.Err + } + if r.Result == nil { + return nil, ErrNotFound + } + return r.Result, nil +} + +// extractSnapshotKey retrieves the key associated with a given snapshot based on the batcher type. +// For the snapshotIDBatcher type, it returns the snapshot's ID. +// For other types, it searches for the SnapshotNameTagKey within the snapshot's tags. +func extractSnapshotKey(s *ec2.Snapshot, batcher snapshotBatcherType) (string, error) { + if batcher == snapshotIDBatcher { + if s.SnapshotId == nil { + return "", errors.New("extractSnapshotKey: missing snapshot ID") + } + return *s.SnapshotId, nil + } + for _, tag := range s.Tags { + klog.V(7).InfoS("extractSnapshotKey: processing tag", "snapshot", s, "*tag.Key", *tag.Key, "SnapshotNameTagKey", SnapshotNameTagKey) + if tag.Key == nil || tag.Value == nil { + klog.V(7).InfoS("extractSnapshotKey: skipping snapshot due to missing tag", "snapshot", s, "tag", tag) + continue + } + if *tag.Key == SnapshotNameTagKey { + klog.V(7).InfoS("extractSnapshotKey: found snapshot name tag", "snapshot", s, "tag", tag) + return *tag.Value, nil + } + } + return "", errors.New("extractSnapshotKey: missing SnapshotNameTagKey in snapshot tags") +} + func (c *cloud) CreateSnapshot(ctx context.Context, volumeID string, snapshotOptions *SnapshotOptions) (snapshot *Snapshot, err error) { descriptions := "Created by AWS EBS CSI driver for volume " + volumeID @@ -1257,11 +1379,11 @@ func (c *cloud) getInstance(ctx context.Context, nodeID string) (*ec2.Instance, } } -func (c *cloud) getSnapshot(ctx context.Context, request *ec2.DescribeSnapshotsInput) (*ec2.Snapshot, error) { +func describeSnapshots(ctx context.Context, svc ec2iface.EC2API, request *ec2.DescribeSnapshotsInput) ([]*ec2.Snapshot, error) { var snapshots []*ec2.Snapshot var nextToken *string for { - response, err := c.ec2.DescribeSnapshotsWithContext(ctx, request) + response, err := svc.DescribeSnapshotsWithContext(ctx, request) if err != nil { return nil, err } @@ -1273,13 +1395,25 @@ func (c *cloud) getSnapshot(ctx context.Context, request *ec2.DescribeSnapshotsI request.NextToken = nextToken } - if l := len(snapshots); l > 1 { - return nil, ErrMultiSnapshots - } else if l < 1 { - return nil, ErrNotFound - } + return snapshots, nil +} - return snapshots[0], nil +func (c *cloud) getSnapshot(ctx context.Context, request *ec2.DescribeSnapshotsInput) (*ec2.Snapshot, error) { + if c.bm == nil { + snapshots, err := describeSnapshots(ctx, c.ec2, request) + if err != nil { + return nil, err + } + + if l := len(snapshots); l > 1 { + return nil, ErrMultiSnapshots + } else if l < 1 { + return nil, ErrNotFound + } + return snapshots[0], nil + } else { + return c.batchDescribeSnapshots(request) + } } // listSnapshots returns all snapshots based from a request diff --git a/pkg/cloud/cloud_test.go b/pkg/cloud/cloud_test.go index 18e77c991d..7e2a0b686f 100644 --- a/pkg/cloud/cloud_test.go +++ b/pkg/cloud/cloud_test.go @@ -351,6 +351,162 @@ func executeDescribeInstancesTest(t *testing.T, c *cloud, instanceIds []string, } } +func generateSnapshots(snapIDCount, snapTagCount int) []*ec2.Snapshot { + snapshots := make([]*ec2.Snapshot, 0, snapIDCount+snapTagCount) + + for i := 0; i < snapIDCount; i++ { + snapID := fmt.Sprintf("snap-%d", i) + snapshots = append(snapshots, &ec2.Snapshot{SnapshotId: aws.String(snapID)}) + } + + for i := 0; i < snapTagCount; i++ { + snapshotName := fmt.Sprintf("snap-name-%d", i) + snapshots = append(snapshots, &ec2.Snapshot{Tags: []*ec2.Tag{{Key: aws.String(SnapshotNameTagKey), Value: aws.String(snapshotName)}}}) + } + + return snapshots +} + +func extractSnapshotIdentifiers(snapshots []*ec2.Snapshot) (snapshotIDs []string, snapshotNames []string) { + for _, snapshot := range snapshots { + if snapshot.SnapshotId != nil { + snapshotIDs = append(snapshotIDs, *snapshot.SnapshotId) + } + for _, tag := range snapshot.Tags { + if tag.Key != nil && *tag.Key == SnapshotNameTagKey && tag.Value != nil { + snapshotNames = append(snapshotNames, *tag.Value) + } + } + } + return snapshotIDs, snapshotNames +} + +func TestBatchDescribeSnapshots(t *testing.T) { + testCases := []struct { + name string + snapshots []*ec2.Snapshot + mockFunc func(mockEC2 *MockEC2API, expErr error, snapshots []*ec2.Snapshot) + expErr error + }{ + { + name: "success: snapshot by ID", + snapshots: generateSnapshots(3, 0), + mockFunc: func(mockEC2 *MockEC2API, expErr error, snapshots []*ec2.Snapshot) { + snapshotOutput := &ec2.DescribeSnapshotsOutput{Snapshots: snapshots} + mockEC2.EXPECT().DescribeSnapshotsWithContext(gomock.Any(), gomock.Any()).Return(snapshotOutput, expErr).Times(1) + }, + }, + { + name: "success: snapshot by tag", + snapshots: generateSnapshots(0, 3), + mockFunc: func(mockEC2 *MockEC2API, expErr error, snapshots []*ec2.Snapshot) { + snapshotOutput := &ec2.DescribeSnapshotsOutput{Snapshots: snapshots} + mockEC2.EXPECT().DescribeSnapshotsWithContext(gomock.Any(), gomock.Any()).Return(snapshotOutput, expErr).Times(1) + }, + }, + { + name: "success: snapshot by ID and tag", + snapshots: generateSnapshots(3, 4), + mockFunc: func(mockEC2 *MockEC2API, expErr error, snapshots []*ec2.Snapshot) { + snapshotOutput := &ec2.DescribeSnapshotsOutput{Snapshots: snapshots} + mockEC2.EXPECT().DescribeSnapshotsWithContext(gomock.Any(), gomock.Any()).Return(snapshotOutput, expErr).Times(2) + }, + }, + { + name: "fail: EC2 API generic error", + snapshots: generateSnapshots(3, 2), + mockFunc: func(mockEC2 *MockEC2API, expErr error, snapshots []*ec2.Snapshot) { + mockEC2.EXPECT().DescribeSnapshotsWithContext(gomock.Any(), gomock.Any()).Return(nil, expErr).Times(2) + }, + expErr: fmt.Errorf("generic EC2 API error"), + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + mockEC2 := NewMockEC2API(mockCtrl) + c := newCloud(mockEC2) + cloudInstance := c.(*cloud) + cloudInstance.bm = newBatcherManager(cloudInstance.ec2) + + tc.mockFunc(mockEC2, tc.expErr, tc.snapshots) + snapshotIDs, snapshotNames := extractSnapshotIdentifiers(tc.snapshots) + executeDescribeSnapshotsTest(t, cloudInstance, snapshotIDs, snapshotNames, tc.expErr) + }) + } +} + +func executeDescribeSnapshotsTest(t *testing.T, c *cloud, snapshotIDs, snapshotNames []string, expErr error) { + var wg sync.WaitGroup + + getRequestForID := func(id string) *ec2.DescribeSnapshotsInput { + return &ec2.DescribeSnapshotsInput{SnapshotIds: []*string{&id}} + } + + getRequestForTag := func(snapName string) *ec2.DescribeSnapshotsInput { + return &ec2.DescribeSnapshotsInput{ + Filters: []*ec2.Filter{ + { + Name: aws.String("tag:" + SnapshotNameTagKey), + Values: []*string{&snapName}, + }, + }, + } + } + + requests := make([]*ec2.DescribeSnapshotsInput, 0, len(snapshotIDs)+len(snapshotNames)) + for _, snapshotID := range snapshotIDs { + requests = append(requests, getRequestForID(snapshotID)) + } + for _, snapshotName := range snapshotNames { + requests = append(requests, getRequestForTag(snapshotName)) + } + + r := make([]chan *ec2.Snapshot, len(requests)) + e := make([]chan error, len(requests)) + + for i, request := range requests { + wg.Add(1) + r[i] = make(chan *ec2.Snapshot, 1) + e[i] = make(chan error, 1) + + go func(resultCh chan *ec2.Snapshot, errCh chan error) { + defer wg.Done() + snapshot, err := c.batchDescribeSnapshots(request) + if err != nil { + errCh <- err + return + } + resultCh <- snapshot + }(r[i], e[i]) + } + + wg.Wait() + + for i := range requests { + select { + case result := <-r[i]: + if result == nil { + t.Errorf("Received nil result for a request") + } + case err := <-e[i]: + if expErr == nil { + t.Errorf("Error while processing request: %v", err) + } + if !errors.Is(err, expErr) { + t.Errorf("Expected error %v, but got %v", expErr, err) + } + default: + t.Errorf("Did not receive a result or an error for a request") + } + } +} + func TestCreateDisk(t *testing.T) { testCases := []struct { name string