diff --git a/charts/aws-ebs-csi-driver/templates/clusterrole-csi-node.yaml b/charts/aws-ebs-csi-driver/templates/clusterrole-csi-node.yaml index 43ca2ce735..2b7295aafd 100644 --- a/charts/aws-ebs-csi-driver/templates/clusterrole-csi-node.yaml +++ b/charts/aws-ebs-csi-driver/templates/clusterrole-csi-node.yaml @@ -12,3 +12,6 @@ rules: - apiGroups: ["storage.k8s.io"] resources: ["volumeattachments"] verbs: ["get", "list", "watch"] + - apiGroups: ["storage.k8s.io"] + resources: ["csinodes"] + verbs: ["get"] diff --git a/deploy/kubernetes/base/clusterrole-csi-node.yaml b/deploy/kubernetes/base/clusterrole-csi-node.yaml index 7786bed2ae..f4d08b63f7 100644 --- a/deploy/kubernetes/base/clusterrole-csi-node.yaml +++ b/deploy/kubernetes/base/clusterrole-csi-node.yaml @@ -13,3 +13,6 @@ rules: - apiGroups: ["storage.k8s.io"] resources: ["volumeattachments"] verbs: ["get", "list", "watch"] + - apiGroups: ["storage.k8s.io"] + resources: ["csinodes"] + verbs: ["get"] diff --git a/hack/update-mockgen.sh b/hack/update-mockgen.sh index 5681f223f9..08aa608c0a 100755 --- a/hack/update-mockgen.sh +++ b/hack/update-mockgen.sh @@ -29,6 +29,7 @@ BIN="$(dirname "$(realpath "${BASH_SOURCE[0]}")")/../bin" "${BIN}/mockgen" -package driver -destination=./pkg/driver/mock_k8s_client.go -mock_names='Interface=MockKubernetesClient' k8s.io/client-go/kubernetes Interface "${BIN}/mockgen" -package driver -destination=./pkg/driver/mock_k8s_corev1.go k8s.io/client-go/kubernetes/typed/core/v1 CoreV1Interface,NodeInterface "${BIN}/mockgen" -package driver -destination=./pkg/driver/mock_k8s_storagev1.go k8s.io/client-go/kubernetes/typed/storage/v1 VolumeAttachmentInterface,StorageV1Interface +"${BIN}/mockgen" -package driver -destination=./pkg/driver/mock_k8s_storagev1_csinode.go k8s.io/client-go/kubernetes/typed/storage/v1 CSINodeInterface # Fixes "Mounter Type cannot implement 'Mounter' as it has a non-exported method and is defined in a different package" # See https://github.com/kubernetes/mount-utils/commit/a20fcfb15a701977d086330b47b7efad51eb608e for context. diff --git a/pkg/driver/mock_k8s_storagev1_csinode.go b/pkg/driver/mock_k8s_storagev1_csinode.go new file mode 100644 index 0000000000..0cd771f294 --- /dev/null +++ b/pkg/driver/mock_k8s_storagev1_csinode.go @@ -0,0 +1,178 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: k8s.io/client-go/kubernetes/typed/storage/v1 (interfaces: CSINodeInterface) + +// Package driver is a generated GoMock package. +package driver + +import ( + context "context" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + v1 "k8s.io/api/storage/v1" + v10 "k8s.io/apimachinery/pkg/apis/meta/v1" + types "k8s.io/apimachinery/pkg/types" + watch "k8s.io/apimachinery/pkg/watch" + v11 "k8s.io/client-go/applyconfigurations/storage/v1" +) + +// MockCSINodeInterface is a mock of CSINodeInterface interface. +type MockCSINodeInterface struct { + ctrl *gomock.Controller + recorder *MockCSINodeInterfaceMockRecorder +} + +// MockCSINodeInterfaceMockRecorder is the mock recorder for MockCSINodeInterface. +type MockCSINodeInterfaceMockRecorder struct { + mock *MockCSINodeInterface +} + +// NewMockCSINodeInterface creates a new mock instance. +func NewMockCSINodeInterface(ctrl *gomock.Controller) *MockCSINodeInterface { + mock := &MockCSINodeInterface{ctrl: ctrl} + mock.recorder = &MockCSINodeInterfaceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockCSINodeInterface) EXPECT() *MockCSINodeInterfaceMockRecorder { + return m.recorder +} + +// Apply mocks base method. +func (m *MockCSINodeInterface) Apply(arg0 context.Context, arg1 *v11.CSINodeApplyConfiguration, arg2 v10.ApplyOptions) (*v1.CSINode, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Apply", arg0, arg1, arg2) + ret0, _ := ret[0].(*v1.CSINode) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Apply indicates an expected call of Apply. +func (mr *MockCSINodeInterfaceMockRecorder) Apply(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Apply", reflect.TypeOf((*MockCSINodeInterface)(nil).Apply), arg0, arg1, arg2) +} + +// Create mocks base method. +func (m *MockCSINodeInterface) Create(arg0 context.Context, arg1 *v1.CSINode, arg2 v10.CreateOptions) (*v1.CSINode, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Create", arg0, arg1, arg2) + ret0, _ := ret[0].(*v1.CSINode) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Create indicates an expected call of Create. +func (mr *MockCSINodeInterfaceMockRecorder) Create(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockCSINodeInterface)(nil).Create), arg0, arg1, arg2) +} + +// Delete mocks base method. +func (m *MockCSINodeInterface) Delete(arg0 context.Context, arg1 string, arg2 v10.DeleteOptions) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Delete", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// Delete indicates an expected call of Delete. +func (mr *MockCSINodeInterfaceMockRecorder) Delete(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockCSINodeInterface)(nil).Delete), arg0, arg1, arg2) +} + +// DeleteCollection mocks base method. +func (m *MockCSINodeInterface) DeleteCollection(arg0 context.Context, arg1 v10.DeleteOptions, arg2 v10.ListOptions) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteCollection", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteCollection indicates an expected call of DeleteCollection. +func (mr *MockCSINodeInterfaceMockRecorder) DeleteCollection(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteCollection", reflect.TypeOf((*MockCSINodeInterface)(nil).DeleteCollection), arg0, arg1, arg2) +} + +// Get mocks base method. +func (m *MockCSINodeInterface) Get(arg0 context.Context, arg1 string, arg2 v10.GetOptions) (*v1.CSINode, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", arg0, arg1, arg2) + ret0, _ := ret[0].(*v1.CSINode) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockCSINodeInterfaceMockRecorder) Get(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockCSINodeInterface)(nil).Get), arg0, arg1, arg2) +} + +// List mocks base method. +func (m *MockCSINodeInterface) List(arg0 context.Context, arg1 v10.ListOptions) (*v1.CSINodeList, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "List", arg0, arg1) + ret0, _ := ret[0].(*v1.CSINodeList) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// List indicates an expected call of List. +func (mr *MockCSINodeInterfaceMockRecorder) List(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockCSINodeInterface)(nil).List), arg0, arg1) +} + +// Patch mocks base method. +func (m *MockCSINodeInterface) Patch(arg0 context.Context, arg1 string, arg2 types.PatchType, arg3 []byte, arg4 v10.PatchOptions, arg5 ...string) (*v1.CSINode, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1, arg2, arg3, arg4} + for _, a := range arg5 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Patch", varargs...) + ret0, _ := ret[0].(*v1.CSINode) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Patch indicates an expected call of Patch. +func (mr *MockCSINodeInterfaceMockRecorder) Patch(arg0, arg1, arg2, arg3, arg4 interface{}, arg5 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1, arg2, arg3, arg4}, arg5...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Patch", reflect.TypeOf((*MockCSINodeInterface)(nil).Patch), varargs...) +} + +// Update mocks base method. +func (m *MockCSINodeInterface) Update(arg0 context.Context, arg1 *v1.CSINode, arg2 v10.UpdateOptions) (*v1.CSINode, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Update", arg0, arg1, arg2) + ret0, _ := ret[0].(*v1.CSINode) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Update indicates an expected call of Update. +func (mr *MockCSINodeInterfaceMockRecorder) Update(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockCSINodeInterface)(nil).Update), arg0, arg1, arg2) +} + +// Watch mocks base method. +func (m *MockCSINodeInterface) Watch(arg0 context.Context, arg1 v10.ListOptions) (watch.Interface, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Watch", arg0, arg1) + ret0, _ := ret[0].(watch.Interface) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Watch indicates an expected call of Watch. +func (mr *MockCSINodeInterfaceMockRecorder) Watch(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Watch", reflect.TypeOf((*MockCSINodeInterface)(nil).Watch), arg0, arg1) +} diff --git a/pkg/driver/node.go b/pkg/driver/node.go index b5f50ef0cd..0d9c5f12c6 100644 --- a/pkg/driver/node.go +++ b/pkg/driver/node.go @@ -35,6 +35,7 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" k8stypes "k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/client-go/kubernetes" "k8s.io/klog/v2" "k8s.io/kubernetes/pkg/volume" ) @@ -68,6 +69,8 @@ var ( csi.NodeServiceCapability_RPC_GET_VOLUME_STATS, } + // taintRemovalInitialDelay is the initial delay for node taint removal + taintRemovalInitialDelay = 1 * time.Second // taintRemovalBackoff is the exponential backoff configuration for node taint removal taintRemovalBackoff = wait.Backoff{ Duration: 500 * time.Millisecond, @@ -103,7 +106,9 @@ func newNodeService(driverOptions *DriverOptions) nodeService { // Remove taint from node to indicate driver startup success // This is done at the last possible moment to prevent race conditions or false positive removals - go removeTaintInBackground(cloud.DefaultKubernetesAPIClient, removeNotReadyTaint) + time.AfterFunc(taintRemovalInitialDelay, func() { + removeTaintInBackground(cloud.DefaultKubernetesAPIClient, removeNotReadyTaint) + }) return nodeService{ metadata: metadata, @@ -896,6 +901,11 @@ func removeNotReadyTaint(k8sClient cloud.KubernetesAPIClient) error { return err } + err = checkAllocatable(clientset, nodeName) + if err != nil { + return err + } + var taintsToKeep []corev1.Taint for _, taint := range node.Spec.Taints { if taint.Key != AgentNotReadyNodeTaintKey { @@ -936,6 +946,25 @@ func removeNotReadyTaint(k8sClient cloud.KubernetesAPIClient) error { return nil } +func checkAllocatable(clientset kubernetes.Interface, nodeName string) error { + csiNode, err := clientset.StorageV1().CSINodes().Get(context.Background(), nodeName, metav1.GetOptions{}) + if err != nil { + return fmt.Errorf("isAllocatableSet: failed to get CSINode for %s: %w", nodeName, err) + } + + for _, driver := range csiNode.Spec.Drivers { + if driver.Name == DriverName { + if driver.Allocatable != nil && driver.Allocatable.Count != nil { + klog.InfoS("CSINode Allocatable value is set", "nodeName", nodeName, "count", *driver.Allocatable.Count) + return nil + } + return fmt.Errorf("isAllocatableSet: allocatable value not set for driver on node %s", nodeName) + } + } + + return fmt.Errorf("isAllocatableSet: driver not found on node %s", nodeName) +} + func recheckFormattingOptionParameter(context map[string]string, key string, fsConfigs map[string]fileSystemConfig, fsType string) (value string, err error) { v, ok := context[key] if ok { diff --git a/pkg/driver/node_test.go b/pkg/driver/node_test.go index 0fcf67b504..79f53ab2d1 100644 --- a/pkg/driver/node_test.go +++ b/pkg/driver/node_test.go @@ -38,6 +38,8 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" corev1 "k8s.io/api/core/v1" + v1 "k8s.io/api/storage/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/kubernetes" ) @@ -2403,6 +2405,35 @@ func TestRemoveNotReadyTaint(t *testing.T) { t.Setenv("CSI_NODE_NAME", nodeName) getNodeMock, _ := getNodeMock(mockCtl, nodeName, &corev1.Node{}, nil) + storageV1Mock := NewMockStorageV1Interface(mockCtl) + getNodeMock.(*MockKubernetesClient).EXPECT().StorageV1().Return(storageV1Mock).AnyTimes() + + csiNodesMock := NewMockCSINodeInterface(mockCtl) + storageV1Mock.EXPECT().CSINodes().Return(csiNodesMock).Times(1) + + count := int32(1) + + mockCSINode := &v1.CSINode{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-node-123", + }, + Spec: v1.CSINodeSpec{ + Drivers: []v1.CSINodeDriver{ + { + Name: DriverName, + Allocatable: &v1.VolumeNodeResources{ + Count: &count, + }, + }, + }, + }, + } + + csiNodesMock.EXPECT(). + Get(gomock.Any(), gomock.Eq("test-node-123"), gomock.Any()). + Return(mockCSINode, nil). + Times(1) + return func() (kubernetes.Interface, error) { return getNodeMock, nil } @@ -2414,16 +2445,52 @@ func TestRemoveNotReadyTaint(t *testing.T) { setup: func(t *testing.T, mockCtl *gomock.Controller) func() (kubernetes.Interface, error) { t.Setenv("CSI_NODE_NAME", nodeName) getNodeMock, mockNode := getNodeMock(mockCtl, nodeName, &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: nodeName, + }, Spec: corev1.NodeSpec{ Taints: []corev1.Taint{ { Key: AgentNotReadyNodeTaintKey, - Effect: "NoExecute", + Effect: corev1.TaintEffectNoExecute, }, }, }, }, nil) - mockNode.EXPECT().Patch(gomock.Any(), gomock.Eq(nodeName), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("Failed to patch node!")) + + storageV1Mock := NewMockStorageV1Interface(mockCtl) + getNodeMock.(*MockKubernetesClient).EXPECT().StorageV1().Return(storageV1Mock).AnyTimes() + + csiNodesMock := NewMockCSINodeInterface(mockCtl) + storageV1Mock.EXPECT().CSINodes().Return(csiNodesMock).Times(1) + + count := int32(1) + mockCSINode := &v1.CSINode{ + ObjectMeta: metav1.ObjectMeta{ + Name: nodeName, + }, + Spec: v1.CSINodeSpec{ + Drivers: []v1.CSINodeDriver{ + { + Name: DriverName, + NodeID: nodeName, + Allocatable: &v1.VolumeNodeResources{ + Count: &count, + }, + }, + }, + }, + } + + csiNodesMock.EXPECT(). + Get(gomock.Any(), gomock.Eq(nodeName), gomock.Any()). + Return(mockCSINode, nil). + Times(1) + + mockNode.EXPECT(). + Patch(gomock.Any(), gomock.Eq(nodeName), gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("Failed to patch node!")). + Times(1) return func() (kubernetes.Interface, error) { return getNodeMock, nil @@ -2436,16 +2503,52 @@ func TestRemoveNotReadyTaint(t *testing.T) { setup: func(t *testing.T, mockCtl *gomock.Controller) func() (kubernetes.Interface, error) { t.Setenv("CSI_NODE_NAME", nodeName) getNodeMock, mockNode := getNodeMock(mockCtl, nodeName, &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: nodeName, + }, Spec: corev1.NodeSpec{ Taints: []corev1.Taint{ { Key: AgentNotReadyNodeTaintKey, - Effect: "NoSchedule", + Effect: corev1.TaintEffectNoSchedule, }, }, }, }, nil) - mockNode.EXPECT().Patch(gomock.Any(), gomock.Eq(nodeName), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) + + storageV1Mock := NewMockStorageV1Interface(mockCtl) + getNodeMock.(*MockKubernetesClient).EXPECT().StorageV1().Return(storageV1Mock).AnyTimes() + + csiNodesMock := NewMockCSINodeInterface(mockCtl) + storageV1Mock.EXPECT().CSINodes().Return(csiNodesMock).Times(1) + + count := int32(1) + mockCSINode := &v1.CSINode{ + ObjectMeta: metav1.ObjectMeta{ + Name: nodeName, + }, + Spec: v1.CSINodeSpec{ + Drivers: []v1.CSINodeDriver{ + { + Name: DriverName, + NodeID: nodeName, + Allocatable: &v1.VolumeNodeResources{ + Count: &count, + }, + }, + }, + }, + } + + csiNodesMock.EXPECT(). + Get(gomock.Any(), gomock.Eq(nodeName), gomock.Any()). + Return(mockCSINode, nil). + Times(1) + + mockNode.EXPECT(). + Patch(gomock.Any(), gomock.Eq(nodeName), gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil, nil). + Times(1) return func() (kubernetes.Interface, error) { return getNodeMock, nil @@ -2453,6 +2556,132 @@ func TestRemoveNotReadyTaint(t *testing.T) { }, expResult: nil, }, + { + name: "failed to get CSINode", + setup: func(t *testing.T, mockCtl *gomock.Controller) func() (kubernetes.Interface, error) { + t.Setenv("CSI_NODE_NAME", nodeName) + getNodeMock, _ := getNodeMock(mockCtl, nodeName, &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: nodeName, + }, + Spec: corev1.NodeSpec{ + Taints: []corev1.Taint{ + { + Key: AgentNotReadyNodeTaintKey, + Effect: corev1.TaintEffectNoSchedule, + }, + }, + }, + }, nil) + + storageV1Mock := NewMockStorageV1Interface(mockCtl) + getNodeMock.(*MockKubernetesClient).EXPECT().StorageV1().Return(storageV1Mock).AnyTimes() + + csiNodesMock := NewMockCSINodeInterface(mockCtl) + storageV1Mock.EXPECT().CSINodes().Return(csiNodesMock).Times(1) + + csiNodesMock.EXPECT(). + Get(gomock.Any(), gomock.Eq(nodeName), gomock.Any()). + Return(nil, fmt.Errorf("Failed to get CSINode")). + Times(1) + + return func() (kubernetes.Interface, error) { + return getNodeMock, nil + } + }, + expResult: fmt.Errorf("isAllocatableSet: failed to get CSINode for %s: Failed to get CSINode", nodeName), + }, + { + name: "allocatable value not set for driver on node", + setup: func(t *testing.T, mockCtl *gomock.Controller) func() (kubernetes.Interface, error) { + t.Setenv("CSI_NODE_NAME", nodeName) + getNodeMock, _ := getNodeMock(mockCtl, nodeName, &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: nodeName, + }, + Spec: corev1.NodeSpec{ + Taints: []corev1.Taint{ + { + Key: AgentNotReadyNodeTaintKey, + Effect: corev1.TaintEffectNoSchedule, + }, + }, + }, + }, nil) + + storageV1Mock := NewMockStorageV1Interface(mockCtl) + getNodeMock.(*MockKubernetesClient).EXPECT().StorageV1().Return(storageV1Mock).AnyTimes() + + csiNodesMock := NewMockCSINodeInterface(mockCtl) + storageV1Mock.EXPECT().CSINodes().Return(csiNodesMock).Times(1) + + mockCSINode := &v1.CSINode{ + ObjectMeta: metav1.ObjectMeta{ + Name: nodeName, + }, + Spec: v1.CSINodeSpec{ + Drivers: []v1.CSINodeDriver{ + { + Name: DriverName, + NodeID: nodeName, + }, + }, + }, + } + + csiNodesMock.EXPECT(). + Get(gomock.Any(), gomock.Eq(nodeName), gomock.Any()). + Return(mockCSINode, nil). + Times(1) + + return func() (kubernetes.Interface, error) { + return getNodeMock, nil + } + }, + expResult: fmt.Errorf("isAllocatableSet: allocatable value not set for driver on node %s", nodeName), + }, + { + name: "driver not found on node", + setup: func(t *testing.T, mockCtl *gomock.Controller) func() (kubernetes.Interface, error) { + t.Setenv("CSI_NODE_NAME", nodeName) + getNodeMock, _ := getNodeMock(mockCtl, nodeName, &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: nodeName, + }, + Spec: corev1.NodeSpec{ + Taints: []corev1.Taint{ + { + Key: AgentNotReadyNodeTaintKey, + Effect: corev1.TaintEffectNoSchedule, + }, + }, + }, + }, nil) + + storageV1Mock := NewMockStorageV1Interface(mockCtl) + getNodeMock.(*MockKubernetesClient).EXPECT().StorageV1().Return(storageV1Mock).AnyTimes() + + csiNodesMock := NewMockCSINodeInterface(mockCtl) + storageV1Mock.EXPECT().CSINodes().Return(csiNodesMock).Times(1) + + mockCSINode := &v1.CSINode{ + ObjectMeta: metav1.ObjectMeta{ + Name: nodeName, + }, + Spec: v1.CSINodeSpec{}, + } + + csiNodesMock.EXPECT(). + Get(gomock.Any(), gomock.Eq(nodeName), gomock.Any()). + Return(mockCSINode, nil). + Times(1) + + return func() (kubernetes.Interface, error) { + return getNodeMock, nil + } + }, + expResult: fmt.Errorf("isAllocatableSet: driver not found on node %s", nodeName), + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { @@ -2462,8 +2691,13 @@ func TestRemoveNotReadyTaint(t *testing.T) { k8sClientGetter := tc.setup(t, mockCtl) result := removeNotReadyTaint(k8sClientGetter) - if !reflect.DeepEqual(result, tc.expResult) { - t.Fatalf("Expected result `%v`, got result `%v`", tc.expResult, result) + if (result == nil) != (tc.expResult == nil) { + t.Fatalf("expected %v, got %v", tc.expResult, result) + } + if result != nil && tc.expResult != nil { + if result.Error() != tc.expResult.Error() { + t.Fatalf("Expected error message `%v`, got `%v`", tc.expResult.Error(), result.Error()) + } } }) }