Skip to content

Commit

Permalink
Batch EC2 DescribeSnapshots
Browse files Browse the repository at this point in the history
  • Loading branch information
AndrewSirenko committed Mar 6, 2024
1 parent 5cd0e3a commit d2dc135
Show file tree
Hide file tree
Showing 2 changed files with 308 additions and 18 deletions.
170 changes: 152 additions & 18 deletions pkg/cloud/cloud.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,12 @@ const (

// Batcher
const (
volumeIDBatcher batcherType = iota
volumeIDBatcher volumeBatcherType = iota
volumeTagBatcher

snapshotIDBatcher snapshotBatcherType = iota
snapshotTagBatcher

batchDescribeTimeout = 30 * time.Second
)

Expand Down Expand Up @@ -242,14 +245,19 @@ type ec2ListSnapshotsResponse struct {
NextToken *string
}

// batcherType is an enum representing the types of batchers available.
type batcherType int
// volumeBatcherType is an enum representing the types of volume batchers available.
type volumeBatcherType int

// snapshotBatcherType is an enum representing the types of snapshot batchers available.
type snapshotBatcherType int

// batcherManager maintains a collection of batchers for different types of tasks.
type batcherManager struct {
volumeIDBatcher *batcher.Batcher[string, *ec2.Volume]
volumeTagBatcher *batcher.Batcher[string, *ec2.Volume]
instanceIDBatcher *batcher.Batcher[string, *ec2.Instance]
volumeIDBatcher *batcher.Batcher[string, *ec2.Volume]
volumeTagBatcher *batcher.Batcher[string, *ec2.Volume]
instanceIDBatcher *batcher.Batcher[string, *ec2.Instance]
snapshotIDBatcher *batcher.Batcher[string, *ec2.Snapshot]
snapshotTagBatcher *batcher.Batcher[string, *ec2.Snapshot]
}

type cloud struct {
Expand Down Expand Up @@ -334,14 +342,20 @@ func newBatcherManager(svc ec2iface.EC2API) *batcherManager {
volumeTagBatcher: batcher.New(500, 1*time.Second, func(names []string) (map[string]*ec2.Volume, error) {
return execBatchDescribeVolumes(svc, names, volumeTagBatcher)
}),
instanceIDBatcher: batcher.New(50, 300*time.Millisecond, func(names []string) (map[string]*ec2.Instance, error) {
return execBatchDescribeInstances(svc, names)
instanceIDBatcher: batcher.New(50, 300*time.Millisecond, func(ids []string) (map[string]*ec2.Instance, error) {
return execBatchDescribeInstances(svc, ids)
}),
snapshotIDBatcher: batcher.New(500, 300*time.Millisecond, func(ids []string) (map[string]*ec2.Snapshot, error) {
return execBatchDescribeSnapshots(svc, ids, snapshotIDBatcher)
}),
snapshotTagBatcher: batcher.New(500, 300*time.Millisecond, func(names []string) (map[string]*ec2.Snapshot, error) {
return execBatchDescribeSnapshots(svc, names, snapshotTagBatcher)
}),
}
}

// execBatchDescribeVolumes executes a batched DescribeVolumes API call depending on the type of batcher.
func execBatchDescribeVolumes(svc ec2iface.EC2API, input []string, batcher batcherType) (map[string]*ec2.Volume, error) {
func execBatchDescribeVolumes(svc ec2iface.EC2API, input []string, batcher volumeBatcherType) (map[string]*ec2.Volume, error) {
var request *ec2.DescribeVolumesInput

switch batcher {
Expand Down Expand Up @@ -427,7 +441,7 @@ func (c *cloud) batchDescribeVolumes(request *ec2.DescribeVolumesInput) (*ec2.Vo
// extractVolumeKey retrieves the key associated with a given volume based on the batcher type.
// For the volumeIDBatcher type, it returns the volume's ID.
// For other types, it searches for the VolumeNameTagKey within the volume's tags.
func extractVolumeKey(v *ec2.Volume, batcher batcherType) (string, error) {
func extractVolumeKey(v *ec2.Volume, batcher volumeBatcherType) (string, error) {
if batcher == volumeIDBatcher {
if v.VolumeId == nil {
return "", errors.New("extractVolumeKey: missing volume ID")
Expand Down Expand Up @@ -1014,6 +1028,114 @@ func (c *cloud) GetDiskByID(ctx context.Context, volumeID string) (*Disk, error)
}, nil
}

// execBatchDescribeSnapshots executes a batched DescribeSnapshots API call depending on the type of batcher.
func execBatchDescribeSnapshots(svc ec2iface.EC2API, input []string, batcher snapshotBatcherType) (map[string]*ec2.Snapshot, error) {
var request *ec2.DescribeSnapshotsInput

switch batcher {
case snapshotIDBatcher:
klog.V(7).InfoS("execBatchDescribeSnapshots", "snapshotIds", input)
request = &ec2.DescribeSnapshotsInput{
SnapshotIds: aws.StringSlice(input),
}

case snapshotTagBatcher:
klog.V(7).InfoS("execBatchDescribeSnapshots", "names", input)
filters := []*ec2.Filter{
{
Name: aws.String("tag:" + SnapshotNameTagKey),
Values: aws.StringSlice(input),
},
}
request = &ec2.DescribeSnapshotsInput{
Filters: filters,
}

default:
return nil, fmt.Errorf("execBatchDescribeSnapshots: unsupported request type")
}

ctx, cancel := context.WithTimeout(context.Background(), batchDescribeTimeout)
defer cancel()

resp, err := describeSnapshots(ctx, svc, request)
if err != nil {
return nil, err
}

result := make(map[string]*ec2.Snapshot)

for _, snapshot := range resp {
key, err := extractSnapshotKey(snapshot, batcher)
if err != nil {
klog.Warningf("execBatchDescribeSnapshots: skipping snapshot: %v, reason: %v", snapshot, err)
continue
}
result[key] = snapshot
}

klog.V(7).InfoS("execBatchDescribeSnapshots: success", "result", result)
return result, nil
}

// batchDescribeSnapshots processes a DescribeSnapshots request. Depending on the request,
// it determines the appropriate batcher to use, queues the task, and waits for the result.
func (c *cloud) batchDescribeSnapshots(request *ec2.DescribeSnapshotsInput) (*ec2.Snapshot, error) {
var b *batcher.Batcher[string, *ec2.Snapshot]
var task string

switch {
case len(request.SnapshotIds) == 1 && request.SnapshotIds[0] != nil:
b = c.bm.snapshotIDBatcher
task = *request.SnapshotIds[0]

case len(request.Filters) == 1 && *request.Filters[0].Name == "tag:"+SnapshotNameTagKey && len(request.Filters[0].Values) == 1:
b = c.bm.snapshotTagBatcher
task = *request.Filters[0].Values[0]

default:
return nil, fmt.Errorf("batchDescribeSnapshots: invalid request, request: %v", request)
}

ch := make(chan batcher.BatchResult[*ec2.Snapshot])

b.AddTask(task, ch)

r := <-ch

if r.Err != nil {
return nil, r.Err
}
if r.Result == nil {
return nil, ErrNotFound
}
return r.Result, nil
}

// extractSnapshotKey retrieves the key associated with a given snapshot based on the batcher type.
// For the snapshotIDBatcher type, it returns the snapshot's ID.
// For other types, it searches for the SnapshotNameTagKey within the snapshot's tags.
func extractSnapshotKey(s *ec2.Snapshot, batcher snapshotBatcherType) (string, error) {
if batcher == snapshotIDBatcher {
if s.SnapshotId == nil {
return "", errors.New("extractSnapshotKey: missing snapshot ID")
}
return *s.SnapshotId, nil
}
for _, tag := range s.Tags {
klog.V(7).InfoS("extractSnapshotKey: processing tag", "snapshot", s, "*tag.Key", *tag.Key, "SnapshotNameTagKey", SnapshotNameTagKey)
if tag.Key == nil || tag.Value == nil {
klog.V(7).InfoS("extractSnapshotKey: skipping snapshot due to missing tag", "snapshot", s, "tag", tag)
continue
}
if *tag.Key == SnapshotNameTagKey {
klog.V(7).InfoS("extractSnapshotKey: found snapshot name tag", "snapshot", s, "tag", tag)
return *tag.Value, nil
}
}
return "", errors.New("extractSnapshotKey: missing SnapshotNameTagKey in snapshot tags")
}

func (c *cloud) CreateSnapshot(ctx context.Context, volumeID string, snapshotOptions *SnapshotOptions) (snapshot *Snapshot, err error) {
descriptions := "Created by AWS EBS CSI driver for volume " + volumeID

Expand Down Expand Up @@ -1257,11 +1379,11 @@ func (c *cloud) getInstance(ctx context.Context, nodeID string) (*ec2.Instance,
}
}

func (c *cloud) getSnapshot(ctx context.Context, request *ec2.DescribeSnapshotsInput) (*ec2.Snapshot, error) {
func describeSnapshots(ctx context.Context, svc ec2iface.EC2API, request *ec2.DescribeSnapshotsInput) ([]*ec2.Snapshot, error) {
var snapshots []*ec2.Snapshot
var nextToken *string
for {
response, err := c.ec2.DescribeSnapshotsWithContext(ctx, request)
response, err := svc.DescribeSnapshotsWithContext(ctx, request)
if err != nil {
return nil, err
}
Expand All @@ -1273,13 +1395,25 @@ func (c *cloud) getSnapshot(ctx context.Context, request *ec2.DescribeSnapshotsI
request.NextToken = nextToken
}

if l := len(snapshots); l > 1 {
return nil, ErrMultiSnapshots
} else if l < 1 {
return nil, ErrNotFound
}
return snapshots, nil
}

return snapshots[0], nil
func (c *cloud) getSnapshot(ctx context.Context, request *ec2.DescribeSnapshotsInput) (*ec2.Snapshot, error) {
if c.bm == nil {
snapshots, err := describeSnapshots(ctx, c.ec2, request)
if err != nil {
return nil, err
}

if l := len(snapshots); l > 1 {
return nil, ErrMultiSnapshots
} else if l < 1 {
return nil, ErrNotFound
}
return snapshots[0], nil
} else {
return c.batchDescribeSnapshots(request)
}
}

// listSnapshots returns all snapshots based from a request
Expand Down
156 changes: 156 additions & 0 deletions pkg/cloud/cloud_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,162 @@ func executeDescribeInstancesTest(t *testing.T, c *cloud, instanceIds []string,
}
}

func generateSnapshots(snapIDCount, snapTagCount int) []*ec2.Snapshot {
snapshots := make([]*ec2.Snapshot, 0, snapIDCount+snapTagCount)

for i := 0; i < snapIDCount; i++ {
snapID := fmt.Sprintf("snap-%d", i)
snapshots = append(snapshots, &ec2.Snapshot{SnapshotId: aws.String(snapID)})
}

for i := 0; i < snapTagCount; i++ {
snapshotName := fmt.Sprintf("snap-name-%d", i)
snapshots = append(snapshots, &ec2.Snapshot{Tags: []*ec2.Tag{{Key: aws.String(SnapshotNameTagKey), Value: aws.String(snapshotName)}}})
}

return snapshots
}

func extractSnapshotIdentifiers(snapshots []*ec2.Snapshot) (snapshotIDs []string, snapshotNames []string) {
for _, snapshot := range snapshots {
if snapshot.SnapshotId != nil {
snapshotIDs = append(snapshotIDs, *snapshot.SnapshotId)
}
for _, tag := range snapshot.Tags {
if tag.Key != nil && *tag.Key == SnapshotNameTagKey && tag.Value != nil {
snapshotNames = append(snapshotNames, *tag.Value)
}
}
}
return snapshotIDs, snapshotNames
}

func TestBatchDescribeSnapshots(t *testing.T) {
testCases := []struct {
name string
snapshots []*ec2.Snapshot
mockFunc func(mockEC2 *MockEC2API, expErr error, snapshots []*ec2.Snapshot)
expErr error
}{
{
name: "success: snapshot by ID",
snapshots: generateSnapshots(3, 0),
mockFunc: func(mockEC2 *MockEC2API, expErr error, snapshots []*ec2.Snapshot) {
snapshotOutput := &ec2.DescribeSnapshotsOutput{Snapshots: snapshots}
mockEC2.EXPECT().DescribeSnapshotsWithContext(gomock.Any(), gomock.Any()).Return(snapshotOutput, expErr).Times(1)
},
},
{
name: "success: snapshot by tag",
snapshots: generateSnapshots(0, 3),
mockFunc: func(mockEC2 *MockEC2API, expErr error, snapshots []*ec2.Snapshot) {
snapshotOutput := &ec2.DescribeSnapshotsOutput{Snapshots: snapshots}
mockEC2.EXPECT().DescribeSnapshotsWithContext(gomock.Any(), gomock.Any()).Return(snapshotOutput, expErr).Times(1)
},
},
{
name: "success: snapshot by ID and tag",
snapshots: generateSnapshots(3, 4),
mockFunc: func(mockEC2 *MockEC2API, expErr error, snapshots []*ec2.Snapshot) {
snapshotOutput := &ec2.DescribeSnapshotsOutput{Snapshots: snapshots}
mockEC2.EXPECT().DescribeSnapshotsWithContext(gomock.Any(), gomock.Any()).Return(snapshotOutput, expErr).Times(2)
},
},
{
name: "fail: EC2 API generic error",
snapshots: generateSnapshots(3, 2),
mockFunc: func(mockEC2 *MockEC2API, expErr error, snapshots []*ec2.Snapshot) {
mockEC2.EXPECT().DescribeSnapshotsWithContext(gomock.Any(), gomock.Any()).Return(nil, expErr).Times(2)
},
expErr: fmt.Errorf("generic EC2 API error"),
},
}

for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()

mockEC2 := NewMockEC2API(mockCtrl)
c := newCloud(mockEC2)
cloudInstance := c.(*cloud)
cloudInstance.bm = newBatcherManager(cloudInstance.ec2)

tc.mockFunc(mockEC2, tc.expErr, tc.snapshots)
snapshotIDs, snapshotNames := extractSnapshotIdentifiers(tc.snapshots)
executeDescribeSnapshotsTest(t, cloudInstance, snapshotIDs, snapshotNames, tc.expErr)
})
}
}

func executeDescribeSnapshotsTest(t *testing.T, c *cloud, snapshotIDs, snapshotNames []string, expErr error) {
var wg sync.WaitGroup

getRequestForID := func(id string) *ec2.DescribeSnapshotsInput {
return &ec2.DescribeSnapshotsInput{SnapshotIds: []*string{&id}}
}

getRequestForTag := func(snapName string) *ec2.DescribeSnapshotsInput {
return &ec2.DescribeSnapshotsInput{
Filters: []*ec2.Filter{
{
Name: aws.String("tag:" + SnapshotNameTagKey),
Values: []*string{&snapName},
},
},
}
}

requests := make([]*ec2.DescribeSnapshotsInput, 0, len(snapshotIDs)+len(snapshotNames))
for _, snapshotID := range snapshotIDs {
requests = append(requests, getRequestForID(snapshotID))
}
for _, snapshotName := range snapshotNames {
requests = append(requests, getRequestForTag(snapshotName))
}

r := make([]chan *ec2.Snapshot, len(requests))
e := make([]chan error, len(requests))

for i, request := range requests {
wg.Add(1)
r[i] = make(chan *ec2.Snapshot, 1)
e[i] = make(chan error, 1)

go func(resultCh chan *ec2.Snapshot, errCh chan error) {
defer wg.Done()
snapshot, err := c.batchDescribeSnapshots(request)
if err != nil {
errCh <- err
return
}
resultCh <- snapshot
}(r[i], e[i])
}

wg.Wait()

for i := range requests {
select {
case result := <-r[i]:
if result == nil {
t.Errorf("Received nil result for a request")
}
case err := <-e[i]:
if expErr == nil {
t.Errorf("Error while processing request: %v", err)
}
if !errors.Is(err, expErr) {
t.Errorf("Expected error %v, but got %v", expErr, err)
}
default:
t.Errorf("Did not receive a result or an error for a request")
}
}
}

func TestCreateDisk(t *testing.T) {
testCases := []struct {
name string
Expand Down

0 comments on commit d2dc135

Please sign in to comment.