diff --git a/pkg/cloud/cloud.go b/pkg/cloud/cloud.go index 975152db85..e31e6752ad 100644 --- a/pkg/cloud/cloud.go +++ b/pkg/cloud/cloud.go @@ -23,6 +23,7 @@ import ( "errors" "fmt" "os" + "strconv" "strings" "sync" "time" @@ -34,6 +35,7 @@ import ( "github.com/aws/smithy-go" "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/batcher" dm "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud/devicemanager" + "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/expiringcache" "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/util" "k8s.io/apimachinery/pkg/util/wait" "k8s.io/klog/v2" @@ -91,6 +93,7 @@ var ( const ( volumeDetachedState = "detached" volumeAttachedState = "attached" + cacheForgetDelay = 1 * time.Hour ) // AWS provisioning limits. @@ -310,12 +313,14 @@ type batcherManager struct { } type cloud struct { - region string - ec2 EC2API - dm dm.DeviceManager - bm *batcherManager - rm *retryManager - vwp volumeWaitParameters + region string + ec2 EC2API + dm dm.DeviceManager + bm *batcherManager + rm *retryManager + vwp volumeWaitParameters + likelyBadDeviceNames expiringcache.ExpiringCache[string, sync.Map] + latestClientTokens expiringcache.ExpiringCache[string, int] } var _ Cloud = &cloud{} @@ -364,12 +369,14 @@ func newEC2Cloud(region string, awsSdkDebugLog bool, userAgentExtra string, batc } return &cloud{ - region: region, - dm: dm.NewDeviceManager(), - ec2: svc, - bm: bm, - rm: newRetryManager(), - vwp: vwp, + region: region, + dm: dm.NewDeviceManager(), + ec2: svc, + bm: bm, + rm: newRetryManager(), + vwp: vwp, + likelyBadDeviceNames: expiringcache.New[string, sync.Map](cacheForgetDelay), + latestClientTokens: expiringcache.New[string, int](cacheForgetDelay), } } @@ -586,8 +593,22 @@ func (c *cloud) CreateDisk(ctx context.Context, volumeName string, diskOptions * } } - // We hash the volume name to generate a unique token that is less than or equal to 64 characters - clientToken := sha256.Sum256([]byte(volumeName)) + // The first client token used for any volume is the volume name as provided via CSI + // However, if a volume fails to create asynchronously (that is, the CreateVolume call + // succeeds but the volume ultimately fails to create), the client token is burned until + // EC2 forgets about its use (measured as 12 hours under normal conditions) + // + // To prevent becoming stuck for 12 hours when this occurs, we sequentially append "-2", + // "-3", "-4", etc. to the volume name before hashing on the subsequent attempt after a + // volume fails to create because of an IdempotentParameterMismatch AWS error + // The most recent appended value is stored in an expiring cache to prevent memory leaks + tokenBase := volumeName + if tokenNumber, ok := c.latestClientTokens.Get(volumeName); ok { + tokenBase += "-" + strconv.Itoa(*tokenNumber) + } + + // We use a sha256 hash to guarantee that the token is less than or equal to 64 characters + clientToken := sha256.Sum256([]byte(tokenBase)) requestInput := &ec2.CreateVolumeInput{ AvailabilityZone: aws.String(zone), @@ -630,6 +651,11 @@ func (c *cloud) CreateDisk(ctx context.Context, volumeName string, diskOptions * return nil, ErrNotFound } if isAWSErrorIdempotentParameterMismatch(err) { + nextTokenNumber := 2 + if tokenNumber, ok := c.latestClientTokens.Get(volumeName); ok { + nextTokenNumber = *tokenNumber + 1 + } + c.latestClientTokens.Set(volumeName, &nextTokenNumber) return nil, ErrIdempotentParameterMismatch } return nil, fmt.Errorf("could not create volume in EC2: %w", err) } @@ -847,34 +873,19 @@ func (c *cloud)
batchDescribeInstances(request *ec2.DescribeInstancesInput) (*ty return r.Result, nil } -// Node likely bad device names cache -// Remember device names that are already in use on an instance and use them last when attaching volumes -// This works around device names that are used but do not appear in the mapping from DescribeInstanceStatus -const cacheForgetDelay = 1 * time.Hour - -type cachedNode struct { - timer *time.Timer - likelyBadNames map[string]struct{} -} - -var cacheMutex sync.Mutex -var nodeDeviceCache map[string]cachedNode = map[string]cachedNode{} - func (c *cloud) AttachDisk(ctx context.Context, volumeID, nodeID string) (string, error) { instance, err := c.getInstance(ctx, nodeID) if err != nil { return "", err } - likelyBadNames := map[string]struct{}{} - cacheMutex.Lock() - if node, ok := nodeDeviceCache[nodeID]; ok { - likelyBadNames = node.likelyBadNames - node.timer.Reset(cacheForgetDelay) + likelyBadDeviceNames, ok := c.likelyBadDeviceNames.Get(nodeID) + if !ok { + likelyBadDeviceNames = new(sync.Map) + c.likelyBadDeviceNames.Set(nodeID, likelyBadDeviceNames) } - cacheMutex.Unlock() - device, err := c.dm.NewDevice(instance, volumeID, likelyBadNames) + device, err := c.dm.NewDevice(instance, volumeID, likelyBadDeviceNames) if err != nil { return "", err } @@ -892,37 +903,16 @@ func (c *cloud) AttachDisk(ctx context.Context, volumeID, nodeID string) (string }) if attachErr != nil { if isAWSErrorBlockDeviceInUse(attachErr) { - cacheMutex.Lock() - if node, ok := nodeDeviceCache[nodeID]; ok { - // Node already had existing cached bad names, add on to the list - node.likelyBadNames[device.Path] = struct{}{} - node.timer.Reset(cacheForgetDelay) - } else { - // Node has no existing cached bad device names, setup a new struct instance - nodeDeviceCache[nodeID] = cachedNode{ - timer: time.AfterFunc(cacheForgetDelay, func() { - // If this ever fires, the node has not had a volume attached for an hour - // In order to prevent a semi-permanent memory leak, delete it from the map - cacheMutex.Lock() - delete(nodeDeviceCache, nodeID) - cacheMutex.Unlock() - }), - likelyBadNames: map[string]struct{}{ - device.Path: {}, - }, - } - } - cacheMutex.Unlock() + // If block device is "in use", that likely indicates a bad name that is in use by a block + // device that we do not know about (example: block devices attached in the AMI, which are + // not reported in DescribeInstance's block device map) + // + // Store such bad names in the "likely bad" map to be considered last in future attempts + likelyBadDeviceNames.Store(device.Path, struct{}{}) } return "", fmt.Errorf("could not attach volume %q to node %q: %w", volumeID, nodeID, attachErr) } - cacheMutex.Lock() - if node, ok := nodeDeviceCache[nodeID]; ok { - // Remove succesfully attached devices from the "likely bad" list - delete(node.likelyBadNames, device.Path) - node.timer.Reset(cacheForgetDelay) - } - cacheMutex.Unlock() + likelyBadDeviceNames.Delete(device.Path) klog.V(5).InfoS("[Debug] AttachVolume", "volumeID", volumeID, "nodeID", nodeID, "resp", resp) } diff --git a/pkg/cloud/cloud_test.go b/pkg/cloud/cloud_test.go index 99b21ed81f..f0142510eb 100644 --- a/pkg/cloud/cloud_test.go +++ b/pkg/cloud/cloud_test.go @@ -20,13 +20,14 @@ import ( "context" "errors" "fmt" - "k8s.io/apimachinery/pkg/util/wait" "reflect" "strings" "sync" "testing" "time" + "k8s.io/apimachinery/pkg/util/wait" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" @@ -35,6 +36,7 @@ 
import ( "github.com/golang/mock/gomock" dm "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud/devicemanager" + "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/expiringcache" "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/util" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -1280,6 +1282,72 @@ func TestCreateDisk(t *testing.T) { } } +// Test client error IdempotentParameterMismatch by forcing it to progress twice +func TestCreateDiskClientToken(t *testing.T) { + t.Parallel() + + const volumeName = "test-vol-client-token" + const volumeId = "vol-abcd1234" + diskOptions := &DiskOptions{ + CapacityBytes: util.GiBToBytes(1), + Tags: map[string]string{VolumeNameTagKey: volumeName, AwsEbsDriverTagKey: "true"}, + AvailabilityZone: defaultZone, + } + + // Hash of "test-vol-client-token" + const expectedClientToken1 = "6a1b29bd7c5c5541d9d6baa2938e954fc5739dc77e97facf23590bd13f8582c2" + // Hash of "test-vol-client-token-2" + const expectedClientToken2 = "21465f5586388bb8804d0cec2df13c00f9a975c8cddec4bc35e964cdce59015b" + // Hash of "test-vol-client-token-3" + const expectedClientToken3 = "1bee5a79d83981c0041df2c414bb02e0c10aeb49343b63f50f71470edbaa736b" + + mockCtrl := gomock.NewController(t) + mockEC2 := NewMockEC2API(mockCtrl) + c := newCloud(mockEC2) + + gomock.InOrder( + mockEC2.EXPECT().CreateVolume(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(_ context.Context, input *ec2.CreateVolumeInput, _ ...func(*ec2.Options)) (*ec2.CreateVolumeOutput, error) { + assert.Equal(t, expectedClientToken1, *input.ClientToken) + return nil, &smithy.GenericAPIError{Code: "IdempotentParameterMismatch"} + }), + mockEC2.EXPECT().CreateVolume(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(_ context.Context, input *ec2.CreateVolumeInput, _ ...func(*ec2.Options)) (*ec2.CreateVolumeOutput, error) { + assert.Equal(t, expectedClientToken2, *input.ClientToken) + return nil, &smithy.GenericAPIError{Code: "IdempotentParameterMismatch"} + }), + mockEC2.EXPECT().CreateVolume(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(_ context.Context, input *ec2.CreateVolumeInput, _ ...func(*ec2.Options)) (*ec2.CreateVolumeOutput, error) { + assert.Equal(t, expectedClientToken3, *input.ClientToken) + return &ec2.CreateVolumeOutput{ + VolumeId: aws.String(volumeId), + Size: aws.Int32(util.BytesToGiB(diskOptions.CapacityBytes)), + }, nil + }), + mockEC2.EXPECT().DescribeVolumes(gomock.Any(), gomock.Any()).Return(&ec2.DescribeVolumesOutput{ + Volumes: []types.Volume{ + { + VolumeId: aws.String(volumeId), + Size: aws.Int32(util.BytesToGiB(diskOptions.CapacityBytes)), + State: types.VolumeState("available"), + AvailabilityZone: aws.String(diskOptions.AvailabilityZone), + }, + }, + }, nil).AnyTimes(), + ) + + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(defaultCreateDiskDeadline)) + defer cancel() + for i := range 3 { + _, err := c.CreateDisk(ctx, volumeName, diskOptions) + if i < 2 { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + } +} + func TestDeleteDisk(t *testing.T) { testCases := []struct { name string @@ -1341,14 +1409,13 @@ func TestAttachDisk(t *testing.T) { } testCases := []struct { - name string - volumeID string - nodeID string - nodeID2 string - path string - expErr error - mockFunc func(*MockEC2API, context.Context, string, string, string, string, dm.DeviceManager) - validateFunc func(t *testing.T) + name string + volumeID string + nodeID string + nodeID2 string + path string + expErr error + 
mockFunc func(*MockEC2API, context.Context, string, string, string, string, dm.DeviceManager) }{ { name: "success: AttachVolume normal", @@ -1377,16 +1444,23 @@ func TestAttachDisk(t *testing.T) { name: "success: AttachVolume skip likely bad name", volumeID: defaultVolumeID, nodeID: defaultNodeID, + nodeID2: defaultNodeID, // Induce second attach path: "/dev/xvdab", - expErr: nil, + expErr: fmt.Errorf("could not attach volume %q to node %q: %w", defaultVolumeID, defaultNodeID, blockDeviceInUseErr), mockFunc: func(mockEC2 *MockEC2API, ctx context.Context, volumeID, nodeID, nodeID2, path string, dm dm.DeviceManager) { volumeRequest := createVolumeRequest(volumeID) instanceRequest := createInstanceRequest(nodeID) - attachRequest := createAttachRequest(volumeID, nodeID, path) + attachRequest1 := createAttachRequest(volumeID, nodeID, defaultPath) + attachRequest2 := createAttachRequest(volumeID, nodeID, path) gomock.InOrder( + // First call - fail with "already in use" error mockEC2.EXPECT().DescribeInstances(gomock.Any(), gomock.Eq(instanceRequest)).Return(newDescribeInstancesOutput(nodeID), nil), - mockEC2.EXPECT().AttachVolume(gomock.Any(), gomock.Eq(attachRequest), gomock.Any()).Return(&ec2.AttachVolumeOutput{ + mockEC2.EXPECT().AttachVolume(gomock.Any(), gomock.Eq(attachRequest1), gomock.Any()).Return(nil, blockDeviceInUseErr), + + // Second call - succeed, expect bad device name to be skipped + mockEC2.EXPECT().DescribeInstances(gomock.Any(), gomock.Eq(instanceRequest)).Return(newDescribeInstancesOutput(nodeID), nil), + mockEC2.EXPECT().AttachVolume(gomock.Any(), gomock.Eq(attachRequest2), gomock.Any()).Return(&ec2.AttachVolumeOutput{ Device: aws.String(path), InstanceId: aws.String(nodeID), VolumeId: aws.String(volumeID), @@ -1394,15 +1468,6 @@ func TestAttachDisk(t *testing.T) { }, nil), mockEC2.EXPECT().DescribeVolumes(gomock.Any(), volumeRequest).Return(createDescribeVolumesOutput([]*string{&volumeID}, nodeID, path, "attached"), nil), ) - - nodeDeviceCache = map[string]cachedNode{ - defaultNodeID: { - timer: time.NewTimer(1 * time.Hour), - likelyBadNames: map[string]struct{}{ - defaultPath: {}, - }, - }, - } }, }, { @@ -1416,7 +1481,7 @@ func TestAttachDisk(t *testing.T) { instanceRequest := createInstanceRequest(nodeID) fakeInstance := newFakeInstance(nodeID, volumeID, path) - _, err := dm.NewDevice(&fakeInstance, volumeID, map[string]struct{}{}) + _, err := dm.NewDevice(&fakeInstance, volumeID, new(sync.Map)) require.NoError(t, err) gomock.InOrder( @@ -1439,9 +1504,6 @@ func TestAttachDisk(t *testing.T) { mockEC2.EXPECT().AttachVolume(gomock.Any(), attachRequest, gomock.Any()).Return(nil, errors.New("AttachVolume error")), ) }, - validateFunc: func(t *testing.T) { - assert.NotContains(t, nodeDeviceCache, defaultNodeID) - }, }, { name: "fail: AttachVolume returned block device already in use error", @@ -1458,11 +1520,6 @@ func TestAttachDisk(t *testing.T) { mockEC2.EXPECT().AttachVolume(ctx, attachRequest, gomock.Any()).Return(nil, blockDeviceInUseErr), ) }, - validateFunc: func(t *testing.T) { - assert.Contains(t, nodeDeviceCache, defaultNodeID) - assert.NotNil(t, nodeDeviceCache[defaultNodeID].timer) - assert.Contains(t, nodeDeviceCache[defaultNodeID].likelyBadNames, defaultPath) - }, }, { name: "success: AttachVolume multi-attach", @@ -1524,9 +1581,6 @@ func TestAttachDisk(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - // Reset node likely bad names cache - nodeDeviceCache = map[string]cachedNode{} - mockCtrl := gomock.NewController(t) mockEC2 
:= NewMockEC2API(mockCtrl) c := newCloud(mockEC2) @@ -1552,10 +1606,6 @@ func TestAttachDisk(t *testing.T) { assert.Equal(t, tc.path, devicePath) } - if tc.validateFunc != nil { - tc.validateFunc(t) - } - mockCtrl.Finish() }) } @@ -3086,11 +3136,13 @@ func testVolumeWaitParameters() volumeWaitParameters { func newCloud(mockEC2 EC2API) Cloud { c := &cloud{ - region: "test-region", - dm: dm.NewDeviceManager(), - ec2: mockEC2, - rm: newRetryManager(), - vwp: testVolumeWaitParameters(), + region: "test-region", + dm: dm.NewDeviceManager(), + ec2: mockEC2, + rm: newRetryManager(), + vwp: testVolumeWaitParameters(), + likelyBadDeviceNames: expiringcache.New[string, sync.Map](cacheForgetDelay), + latestClientTokens: expiringcache.New[string, int](cacheForgetDelay), } return c } diff --git a/pkg/cloud/devicemanager/allocator.go b/pkg/cloud/devicemanager/allocator.go index 52ebb7829c..683eee5b03 100644 --- a/pkg/cloud/devicemanager/allocator.go +++ b/pkg/cloud/devicemanager/allocator.go @@ -18,6 +18,7 @@ package devicemanager import ( "fmt" + "sync" ) // ExistingNames is a map of assigned device names. Presence of a key with a device @@ -34,7 +35,7 @@ type ExistingNames map[string]string // call), so all available device names are used eventually and it minimizes // device name reuse. type NameAllocator interface { - GetNext(existingNames ExistingNames, likelyBadNames map[string]struct{}) (name string, err error) + GetNext(existingNames ExistingNames, likelyBadNames *sync.Map) (name string, err error) } type nameAllocator struct{} @@ -46,18 +47,27 @@ var _ NameAllocator = &nameAllocator{} // // likelyBadNames is a map of names that have previously returned an "in use" error when attempting to mount to them // These names are unlikely to result in a successful mount, and may be permanently unavailable, so use them last -func (d *nameAllocator) GetNext(existingNames ExistingNames, likelyBadNames map[string]struct{}) (string, error) { +func (d *nameAllocator) GetNext(existingNames ExistingNames, likelyBadNames *sync.Map) (string, error) { for _, name := range deviceNames { _, existing := existingNames[name] - _, likelyBad := likelyBadNames[name] + _, likelyBad := likelyBadNames.Load(name) if !existing && !likelyBad { return name, nil } } - for name := range likelyBadNames { - if _, existing := existingNames[name]; !existing { - return name, nil + + finalResortName := "" + likelyBadNames.Range(func(name, _ interface{}) bool { + if name, ok := name.(string); ok { + if _, existing := existingNames[name]; !existing { + finalResortName = name + return false + } } + return true + }) + if finalResortName != "" { + return finalResortName, nil } return "", fmt.Errorf("there are no names available") diff --git a/pkg/cloud/devicemanager/allocator_test.go b/pkg/cloud/devicemanager/allocator_test.go index d1b3f8082e..eb05e2e889 100644 --- a/pkg/cloud/devicemanager/allocator_test.go +++ b/pkg/cloud/devicemanager/allocator_test.go @@ -17,6 +17,7 @@ limitations under the License. 
package devicemanager import ( + "sync" "testing" ) @@ -26,7 +27,7 @@ func TestNameAllocator(t *testing.T) { for _, name := range deviceNames { t.Run(name, func(t *testing.T) { - actual, err := allocator.GetNext(existingNames, map[string]struct{}{}) + actual, err := allocator.GetNext(existingNames, new(sync.Map)) if err != nil { t.Errorf("test %q: unexpected error: %v", name, err) } @@ -39,18 +40,24 @@ } func TestNameAllocatorLikelyBadName(t *testing.T) { - skippedName := deviceNames[32] - existingNames := map[string]string{} + skippedNameExisting := deviceNames[11] + skippedNameNew := deviceNames[32] + likelyBadNames := new(sync.Map) + likelyBadNames.Store(skippedNameExisting, struct{}{}) + likelyBadNames.Store(skippedNameNew, struct{}{}) + existingNames := map[string]string{ + skippedNameExisting: "", + } allocator := nameAllocator{} for _, name := range deviceNames { - if name == skippedName { - // Name in likelyBadNames should be skipped until it is the last available name + if name == skippedNameExisting || name == skippedNameNew { + // Names in likelyBadNames should be skipped until they are the only available names continue } t.Run(name, func(t *testing.T) { - actual, err := allocator.GetNext(existingNames, map[string]struct{}{skippedName: {}}) + actual, err := allocator.GetNext(existingNames, likelyBadNames) if err != nil { t.Errorf("test %q: unexpected error: %v", name, err) } @@ -61,9 +68,16 @@ }) } - lastName, _ := allocator.GetNext(existingNames, map[string]struct{}{skippedName: {}}) - if lastName != skippedName { - t.Errorf("test %q: expected %q, got %q (likelyBadNames fallback)", skippedName, skippedName, lastName) + onlyExisting := new(sync.Map) + onlyExisting.Store(skippedNameExisting, struct{}{}) + _, err := allocator.GetNext(existingNames, onlyExisting) + if err != nil { + t.Errorf("unexpected error %v (likelyBadNames with only existing names)", err) + } + + lastName, _ := allocator.GetNext(existingNames, likelyBadNames) + if lastName != skippedNameNew { + t.Errorf("test %q: expected %q, got %q (likelyBadNames fallback)", skippedNameNew, skippedNameNew, lastName) } } @@ -72,10 +86,10 @@ func TestNameAllocatorError(t *testing.T) { existingNames := map[string]string{} for i := 0; i < len(deviceNames); i++ { - name, _ := allocator.GetNext(existingNames, map[string]struct{}{}) + name, _ := allocator.GetNext(existingNames, new(sync.Map)) existingNames[name] = "" } - name, err := allocator.GetNext(existingNames, map[string]struct{}{}) + name, err := allocator.GetNext(existingNames, new(sync.Map)) if err == nil { t.Errorf("expected error, got device %q", name) } diff --git a/pkg/cloud/devicemanager/manager.go b/pkg/cloud/devicemanager/manager.go index 784ac252b0..76fa703677 100644 --- a/pkg/cloud/devicemanager/manager.go +++ b/pkg/cloud/devicemanager/manager.go @@ -52,7 +52,7 @@ type DeviceManager interface { // NewDevice retrieves the device if the device is already assigned. // Otherwise it creates a new device with next available device name // and mark it as unassigned device. - NewDevice(instance *types.Instance, volumeID string, likelyBadNames map[string]struct{}) (device *Device, err error) + NewDevice(instance *types.Instance, volumeID string, likelyBadNames *sync.Map) (device *Device, err error) // GetDevice returns the device already assigned to the volume.
GetDevice(instance *types.Instance, volumeID string) (device *Device, err error) @@ -103,7 +103,7 @@ func NewDeviceManager() DeviceManager { } } -func (d *deviceManager) NewDevice(instance *types.Instance, volumeID string, likelyBadNames map[string]struct{}) (*Device, error) { +func (d *deviceManager) NewDevice(instance *types.Instance, volumeID string, likelyBadNames *sync.Map) (*Device, error) { d.mux.Lock() defer d.mux.Unlock() diff --git a/pkg/cloud/devicemanager/manager_test.go b/pkg/cloud/devicemanager/manager_test.go index f71a845ba9..88b16e24d0 100644 --- a/pkg/cloud/devicemanager/manager_test.go +++ b/pkg/cloud/devicemanager/manager_test.go @@ -17,6 +17,7 @@ limitations under the License. package devicemanager import ( + "sync" "testing" "github.com/aws/aws-sdk-go-v2/aws" @@ -59,7 +60,7 @@ func TestNewDevice(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { // Should fail if instance is nil - dev1, err := dm.NewDevice(nil, tc.volumeID, map[string]struct{}{}) + dev1, err := dm.NewDevice(nil, tc.volumeID, new(sync.Map)) if err == nil { t.Fatalf("Expected error when nil instance is passed in, got nothing") } @@ -70,11 +71,11 @@ func TestNewDevice(t *testing.T) { fakeInstance := newFakeInstance(tc.instanceID, tc.existingVolumeID, tc.existingDevicePath) // Should create valid Device with valid path - dev1, err = dm.NewDevice(fakeInstance, tc.volumeID, map[string]struct{}{}) + dev1, err = dm.NewDevice(fakeInstance, tc.volumeID, new(sync.Map)) assertDevice(t, dev1, false, err) // Devices with same instance and volume should have same paths - dev2, err := dm.NewDevice(fakeInstance, tc.volumeID, map[string]struct{}{}) + dev2, err := dm.NewDevice(fakeInstance, tc.volumeID, new(sync.Map)) assertDevice(t, dev2, true /*IsAlreadyAssigned*/, err) if dev1.Path != dev2.Path { t.Fatalf("Expected equal paths, got %v and %v", dev1.Path, dev2.Path) @@ -82,7 +83,7 @@ func TestNewDevice(t *testing.T) { // Should create new Device with the same path after releasing dev2.Release(false) - dev3, err := dm.NewDevice(fakeInstance, tc.volumeID, map[string]struct{}{}) + dev3, err := dm.NewDevice(fakeInstance, tc.volumeID, new(sync.Map)) assertDevice(t, dev3, false, err) if dev3.Path != dev1.Path { t.Fatalf("Expected equal paths, got %v and %v", dev1.Path, dev3.Path) @@ -136,7 +137,7 @@ func TestNewDeviceWithExistingDevice(t *testing.T) { t.Run(tc.name, func(t *testing.T) { fakeInstance := newFakeInstance("fake-instance", tc.existingID, tc.existingPath) - dev, err := dm.NewDevice(fakeInstance, tc.volumeID, map[string]struct{}{}) + dev, err := dm.NewDevice(fakeInstance, tc.volumeID, new(sync.Map)) assertDevice(t, dev, tc.existingID == tc.volumeID, err) if dev.Path != tc.expectedPath { @@ -169,7 +170,7 @@ func TestGetDevice(t *testing.T) { fakeInstance := newFakeInstance(tc.instanceID, tc.existingVolumeID, tc.existingDevicePath) // Should create valid Device with valid path - dev1, err := dm.NewDevice(fakeInstance, tc.volumeID, map[string]struct{}{}) + dev1, err := dm.NewDevice(fakeInstance, tc.volumeID, new(sync.Map)) assertDevice(t, dev1, false /*IsAlreadyAssigned*/, err) // Devices with same instance and volume should have same paths @@ -205,7 +206,7 @@ func TestReleaseDevice(t *testing.T) { fakeInstance := newFakeInstance(tc.instanceID, tc.existingVolumeID, tc.existingDevicePath) // Should get assigned Device after releasing tainted device - dev, err := dm.NewDevice(fakeInstance, tc.volumeID, map[string]struct{}{}) + dev, err := dm.NewDevice(fakeInstance, tc.volumeID, 
new(sync.Map)) assertDevice(t, dev, false /*IsAlreadyAssigned*/, err) dev.Taint() dev.Release(false) diff --git a/pkg/expiringcache/expiring_cache.go b/pkg/expiringcache/expiring_cache.go new file mode 100644 index 0000000000..962301d12d --- /dev/null +++ b/pkg/expiringcache/expiring_cache.go @@ -0,0 +1,97 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package expiringcache + +import ( + "sync" + "time" +) + +// ExpiringCache is a thread-safe "time expiring" cache that +// automatically removes objects that are not accessed for a +// configurable delay +// +// It is used in various places where we need to cache data for an +// unknown amount of time, to prevent memory leaks +// +// From the consumer's perspective, it behaves similarly to a map +// KeyType is the type of the object that is used as a key +// ValueType is the type of the object that is stored +type ExpiringCache[KeyType comparable, ValueType any] interface { + // Get operates identically to retrieving from a map, returning + // the value and/or boolean indicating if the value existed in the map + // + // Multiple callers can receive the same value simultaneously from Get, + // it is the caller's responsibility to ensure they are not modified + Get(key KeyType) (value *ValueType, ok bool) + // Set operates identically to setting a value in a map, adding an entry + // or overriding the existing value for a given key + Set(key KeyType, value *ValueType) +} + +type timedValue[ValueType any] struct { + value *ValueType + timer *time.Timer +} + +type expiringCache[KeyType comparable, ValueType any] struct { + expirationDelay time.Duration + values map[KeyType]timedValue[ValueType] + mutex sync.Mutex +} + +// New returns a new ExpiringCache +// for a given KeyType, ValueType, and expiration delay +func New[KeyType comparable, ValueType any](expirationDelay time.Duration) ExpiringCache[KeyType, ValueType] { + return &expiringCache[KeyType, ValueType]{ + expirationDelay: expirationDelay, + values: make(map[KeyType]timedValue[ValueType]), + } +} + +func (c *expiringCache[KeyType, ValueType]) Get(key KeyType) (*ValueType, bool) { + c.mutex.Lock() + defer c.mutex.Unlock() + + if v, ok := c.values[key]; ok { + v.timer.Reset(c.expirationDelay) + return v.value, true + } else { + return nil, false + } +} + +func (c *expiringCache[KeyType, ValueType]) Set(key KeyType, value *ValueType) { + c.mutex.Lock() + defer c.mutex.Unlock() + + if v, ok := c.values[key]; ok { + v.timer.Reset(c.expirationDelay) + v.value = value + c.values[key] = v + } else { + c.values[key] = timedValue[ValueType]{ + timer: time.AfterFunc(c.expirationDelay, func() { + c.mutex.Lock() + defer c.mutex.Unlock() + + delete(c.values, key) + }), + value: value, + } + } +} diff --git a/pkg/expiringcache/expiring_cache_test.go b/pkg/expiringcache/expiring_cache_test.go new file mode 100644 index 0000000000..64f8169cba --- /dev/null +++ b/pkg/expiringcache/expiring_cache_test.go @@ -0,0 +1,65 @@ +/* +Copyright 2024 The Kubernetes Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package expiringcache + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +const ( + testExpiration = time.Millisecond * 50 + testSleep = time.Millisecond * 35 + testKey = "key" +) + +var ( + testValue1 = "value" + testValue2 = "value2" +) + +func TestExpiringCache(t *testing.T) { + t.Parallel() + + cache := New[string, string](testExpiration) + + value, ok := cache.Get(testKey) + assert.False(t, ok, "Should not be able to Get() value before Set()ing it") + assert.Nil(t, value, "Value should be nil when Get() returns not ok") + + cache.Set(testKey, &testValue1) + value, ok = cache.Get(testKey) + assert.True(t, ok, "Should be able to Get() after Set()ing it") + assert.Equal(t, &testValue1, value, "Should Get() the same value that was Set()") + + cache.Set(testKey, &testValue2) + value, ok = cache.Get(testKey) + assert.True(t, ok, "Should be able to Get() after Set()ing it (after overwrite)") + assert.Equal(t, &testValue2, value, "Should Get() the same value that was Set() (after overwrite)") + + time.Sleep(testSleep) + value, ok = cache.Get(testKey) + assert.True(t, ok, "Should be able to Get() after sleeping less than the expiration delay") + assert.Equal(t, &testValue2, value, "Should Get() the same value that was Set() (after sleep)") + + time.Sleep(testSleep * 2) + value, ok = cache.Get(testKey) + assert.False(t, ok, "Should not be able to Get() value after it expires") + assert.Nil(t, value, "Value should be nil when Get() returns not ok (after expiration)") +}
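
A minimal usage sketch of the new expiringcache package, mirroring the likelyBadDeviceNames pattern that AttachDisk uses above. The standalone main wrapper, the node ID value, and the device name are illustrative assumptions only, not part of the patch:

package main

import (
	"fmt"
	"sync"
	"time"

	"github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/expiringcache"
)

func main() {
	// One sync.Map of likely-bad device names per node ID; an entry is dropped
	// automatically if it is not accessed for the configured delay.
	likelyBadDeviceNames := expiringcache.New[string, sync.Map](1 * time.Hour)

	nodeID := "i-0123456789abcdef0" // hypothetical node ID for illustration

	names, ok := likelyBadDeviceNames.Get(nodeID)
	if !ok {
		names = new(sync.Map)
		likelyBadDeviceNames.Set(nodeID, names)
	}

	// Record a device name that failed with "already in use", then clear it
	// once an attach using that name succeeds, as AttachDisk does.
	names.Store("/dev/xvdba", struct{}{})
	names.Delete("/dev/xvdba")

	if _, ok := likelyBadDeviceNames.Get(nodeID); ok {
		fmt.Println("each Get/Set resets the timer; the entry expires 1h after the last access")
	}
}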