diff --git a/pkg/constants/constants.go b/pkg/constants/constants.go index f522293c..18950f25 100644 --- a/pkg/constants/constants.go +++ b/pkg/constants/constants.go @@ -42,8 +42,9 @@ const ( BucketVersioning = "bucketVersioning" - IsNodeServer = "IS_NODE_SERVER" - KubeNodeName = "KUBE_NODE_NAME" + IsNodeServer = "IS_NODE_SERVER" + KubeNodeName = "KUBE_NODE_NAME" + MaxVolumesPerNodeEnv = "MAX_VOLUMES_PER_NODE" ) var ( diff --git a/pkg/driver/nodeserver.go b/pkg/driver/nodeserver.go index baafb839..e5b922c5 100644 --- a/pkg/driver/nodeserver.go +++ b/pkg/driver/nodeserver.go @@ -12,7 +12,6 @@ package driver import ( "fmt" - "os" "github.com/IBM/ibm-object-csi-driver/pkg/constants" "github.com/IBM/ibm-object-csi-driver/pkg/mounter" @@ -29,12 +28,19 @@ import ( type nodeServer struct { *S3Driver csi.UnimplementedNodeServer - Stats utils.StatsUtils - NodeID string + Stats utils.StatsUtils + NodeServerConfig Mounter mounter.NewMounterFactory MounterUtils mounterUtils.MounterUtils } +type NodeServerConfig struct { + MaxVolumesPerNode int64 + Region string + Zone string + NodeID string +} + func (ns *nodeServer) NodeStageVolume(_ context.Context, req *csi.NodeStageVolumeRequest) (*csi.NodeStageVolumeResponse, error) { klog.V(2).Infof("CSINodeServer-NodeStageVolume: Request %+v", req) @@ -286,28 +292,15 @@ func (ns *nodeServer) NodeGetCapabilities(_ context.Context, req *csi.NodeGetCap func (ns *nodeServer) NodeGetInfo(_ context.Context, req *csi.NodeGetInfoRequest) (*csi.NodeGetInfoResponse, error) { klog.V(3).Infof("NodeGetInfo: called with args %+v", req) - nodeName := os.Getenv(constants.KubeNodeName) - if nodeName == "" { - return nil, fmt.Errorf("KUBE_NODE_NAME env variable not set") - } - - region, zone, err := ns.Stats.GetRegionAndZone(nodeName) - if err != nil { - return nil, err - } - - klog.V(3).Infof("NodeGetInfo: Node region %s", region) - klog.V(3).Infof("NodeGetInfo: Node zone %s", zone) - topology := &csi.Topology{ Segments: map[string]string{ - constants.NodeRegionLabel: region, - constants.NodeZoneLabel: zone, + constants.NodeRegionLabel: ns.Region, + constants.NodeZoneLabel: ns.Zone, }, } resp := &csi.NodeGetInfoResponse{ NodeId: ns.NodeID, - MaxVolumesPerNode: constants.DefaultVolumesPerNode, + MaxVolumesPerNode: ns.MaxVolumesPerNode, AccessibleTopology: topology, } klog.V(2).Info("NodeGetInfo: ", resp) diff --git a/pkg/driver/nodeserver_test.go b/pkg/driver/nodeserver_test.go index d75e3141..e3bf9e86 100644 --- a/pkg/driver/nodeserver_test.go +++ b/pkg/driver/nodeserver_test.go @@ -18,8 +18,6 @@ package driver import ( "errors" - "fmt" - "os" "reflect" "testing" @@ -643,64 +641,46 @@ func TestNodeGetCapabilities(t *testing.T) { } func TestNodeGetInfo(t *testing.T) { + testMaxVolumesPerNode := int64(10) + testRegion := "test-region" + testZone := "test-zone" + + nodeServer := nodeServer{ + NodeServerConfig: NodeServerConfig{ + MaxVolumesPerNode: testMaxVolumesPerNode, + Region: testRegion, + Zone: testZone, + NodeID: testNodeID, + }, + } + testCases := []struct { - testCaseName string - envKubeNodeName string - driverStatsUtils utils.StatsUtils - req *csi.NodeGetInfoRequest - expectedResp *csi.NodeGetInfoResponse - expectedErr error + testCaseName string + req *csi.NodeGetInfoRequest + expectedResp *csi.NodeGetInfoResponse + expectedErr error }{ { - testCaseName: "Positive: Successful", - envKubeNodeName: testNodeID, - req: &csi.NodeGetInfoRequest{}, - driverStatsUtils: utils.NewFakeStatsUtilsImpl(utils.FakeStatsUtilsFuncStruct{ - GetRegionAndZoneFn: func(nodeName string) (string, string, error) { - return "test-region", "test-zone", nil - }, - }), + testCaseName: "Positive: Successful", + req: &csi.NodeGetInfoRequest{}, expectedResp: &csi.NodeGetInfoResponse{ NodeId: testNodeID, - MaxVolumesPerNode: constants.DefaultVolumesPerNode, + MaxVolumesPerNode: testMaxVolumesPerNode, AccessibleTopology: &csi.Topology{ Segments: map[string]string{ - constants.NodeRegionLabel: "test-region", - constants.NodeZoneLabel: "test-zone", + constants.NodeRegionLabel: testRegion, + constants.NodeZoneLabel: testZone, }, }, }, expectedErr: nil, }, - { - testCaseName: "Negative: Failed to get KUBE_NODE_NAME env variable", - envKubeNodeName: "", - driverStatsUtils: &utils.DriverStatsUtils{}, - req: &csi.NodeGetInfoRequest{}, - expectedResp: nil, - expectedErr: errors.New("KUBE_NODE_NAME env variable not set"), - }, - { - testCaseName: "Negative: Failed to get region and zone", - envKubeNodeName: testNodeID, - driverStatsUtils: &utils.DriverStatsUtils{}, - req: &csi.NodeGetInfoRequest{}, - expectedResp: nil, - expectedErr: errors.New("unable to load in-cluster configuration"), - }, } for _, tc := range testCases { t.Log("Testcase being executed", zap.String("testcase", tc.testCaseName)) - _ = os.Setenv(constants.KubeNodeName, tc.envKubeNodeName) - - nodeServer := nodeServer{ - NodeID: testNodeID, - Stats: tc.driverStatsUtils, - } actualResp, actualErr := nodeServer.NodeGetInfo(ctx, tc.req) - fmt.Println(actualErr) if tc.expectedErr != nil { assert.Error(t, actualErr) diff --git a/pkg/driver/s3-driver.go b/pkg/driver/s3-driver.go index bea88a5a..185c2ba7 100644 --- a/pkg/driver/s3-driver.go +++ b/pkg/driver/s3-driver.go @@ -12,8 +12,11 @@ package driver import ( "fmt" + "os" + "strconv" "github.com/IBM/ibm-csi-common/pkg/utils" + "github.com/IBM/ibm-object-csi-driver/pkg/constants" "github.com/IBM/ibm-object-csi-driver/pkg/mounter" mounterUtils "github.com/IBM/ibm-object-csi-driver/pkg/mounter/utils" "github.com/IBM/ibm-object-csi-driver/pkg/s3client" @@ -129,14 +132,36 @@ func newControllerServer(d *S3Driver, statsUtil pkgUtils.StatsUtils, s3cosSessio } } -func newNodeServer(d *S3Driver, statsUtil pkgUtils.StatsUtils, nodeID string, mountObj mounter.NewMounterFactory, mounterUtil mounterUtils.MounterUtils) *nodeServer { - return &nodeServer{ - S3Driver: d, - Stats: statsUtil, - NodeID: nodeID, - Mounter: mountObj, - MounterUtils: mounterUtil, +func newNodeServer(d *S3Driver, statsUtil pkgUtils.StatsUtils, nodeID string, mountObj mounter.NewMounterFactory, mounterUtil mounterUtils.MounterUtils) (*nodeServer, error) { + nodeName := os.Getenv(constants.KubeNodeName) + if nodeName == "" { + return nil, fmt.Errorf("KUBE_NODE_NAME env variable not set") + } + + region, zone, err := statsUtil.GetRegionAndZone(nodeName) + if err != nil { + return nil, err } + + var maxVolumesPerNode int64 + maxVolumesPerNodeStr := os.Getenv(constants.MaxVolumesPerNodeEnv) + if maxVolumesPerNodeStr != "" { + maxVolumesPerNode, err = strconv.ParseInt(maxVolumesPerNodeStr, 10, 64) + if err != nil { + return nil, err + } + } else { + d.logger.Warn("MAX_VOLUMES_PER_NODE env variable not set. Using default value") + maxVolumesPerNode = int64(constants.DefaultVolumesPerNode) + } + + return &nodeServer{ + S3Driver: d, + Stats: statsUtil, + NodeServerConfig: NodeServerConfig{MaxVolumesPerNode: maxVolumesPerNode, Region: region, Zone: zone, NodeID: nodeID}, + Mounter: mountObj, + MounterUtils: mounterUtil, + }, nil } func (driver *S3Driver) NewS3CosDriver(nodeID string, endpoint string, s3cosSession s3client.ObjectStorageSessionFactory, mountObj mounter.NewMounterFactory, statsUtil pkgUtils.StatsUtils, mounterUtil mounterUtils.MounterUtils) (*S3Driver, error) { @@ -158,13 +183,13 @@ func (driver *S3Driver) NewS3CosDriver(nodeID string, endpoint string, s3cosSess case "controller": driver.cs = newControllerServer(driver, statsUtil, s3cosSession, driver.logger) case "node": - driver.ns = newNodeServer(driver, statsUtil, nodeID, mountObj, mounterUtil) + driver.ns, err = newNodeServer(driver, statsUtil, nodeID, mountObj, mounterUtil) case "controller-node": driver.cs = newControllerServer(driver, statsUtil, s3cosSession, driver.logger) - driver.ns = newNodeServer(driver, statsUtil, nodeID, mountObj, mounterUtil) + driver.ns, err = newNodeServer(driver, statsUtil, nodeID, mountObj, mounterUtil) } - return driver, nil + return driver, err } func (driver *S3Driver) Run() { diff --git a/pkg/driver/s3-driver_test.go b/pkg/driver/s3-driver_test.go index 358af1b1..a184b1bc 100644 --- a/pkg/driver/s3-driver_test.go +++ b/pkg/driver/s3-driver_test.go @@ -18,8 +18,11 @@ package driver import ( "bytes" + "errors" + "strconv" "testing" + "github.com/IBM/ibm-object-csi-driver/pkg/constants" "github.com/IBM/ibm-object-csi-driver/pkg/mounter" mounterUtils "github.com/IBM/ibm-object-csi-driver/pkg/mounter/utils" "github.com/IBM/ibm-object-csi-driver/pkg/s3client" @@ -94,64 +97,198 @@ func TestAddNodeServiceCapabilities(t *testing.T) { } } -func TestNewS3CosDriver(t *testing.T) { +func TestNewNodeServer(t *testing.T) { vendorVersion := "test-vendor-version-1.1.2" - driverName := "mydriver" + driverName := "test-csi-driver" - endpoint := "test-endpoint" nodeID := "test-nodeID" - - fakeCosSession := &s3client.FakeCOSSessionFactory{} - fakeMountObj := &mounter.FakeMounterFactory{} + testRegion := "test-region" + testZone := "test-zone" + + testCases := []struct { + testCaseName string + envVars map[string]string + statsUtils utils.StatsUtils + verifyResult func(*testing.T, *nodeServer, error) + expectedErr error + }{ + { + testCaseName: "Positive: success", + envVars: map[string]string{ + constants.KubeNodeName: nodeID, + constants.MaxVolumesPerNodeEnv: "10", + }, + statsUtils: utils.NewFakeStatsUtilsImpl(utils.FakeStatsUtilsFuncStruct{ + GetRegionAndZoneFn: func(nodeName string) (string, string, error) { return testRegion, testZone, nil }, + }), + verifyResult: func(t *testing.T, ns *nodeServer, err error) { + assert.NoError(t, err) + assert.NotNil(t, ns) + assert.Equal(t, ns.MaxVolumesPerNode, int64(10)) + assert.Equal(t, ns.Region, testRegion) + assert.Equal(t, ns.Zone, testZone) + assert.Equal(t, ns.NodeID, nodeID) + }, + expectedErr: nil, + }, + { + testCaseName: "Negative: Failed to get KUBE_NODE_NAME env variable", + envVars: map[string]string{ + constants.KubeNodeName: "", + constants.MaxVolumesPerNodeEnv: "", + }, + verifyResult: func(t *testing.T, ns *nodeServer, err error) { + assert.Nil(t, ns) + }, + expectedErr: errors.New("KUBE_NODE_NAME env variable not set"), + }, + { + testCaseName: "Negative: Failed to get region and zone", + envVars: map[string]string{ + constants.KubeNodeName: nodeID, + constants.MaxVolumesPerNodeEnv: "", + }, + statsUtils: utils.NewFakeStatsUtilsImpl(utils.FakeStatsUtilsFuncStruct{ + GetRegionAndZoneFn: func(nodeName string) (string, string, error) { + return "", "", errors.New("unable to load in-cluster configuration") + }, + }), + verifyResult: func(t *testing.T, ns *nodeServer, err error) { + assert.Nil(t, ns) + }, + expectedErr: errors.New("unable to load in-cluster configuration"), + }, + { + testCaseName: "Negative: invalid value of maxVolumesPerNode", + envVars: map[string]string{ + constants.KubeNodeName: nodeID, + constants.MaxVolumesPerNodeEnv: "invalid", + }, + statsUtils: utils.NewFakeStatsUtilsImpl(utils.FakeStatsUtilsFuncStruct{ + GetRegionAndZoneFn: func(nodeName string) (string, string, error) { + return testRegion, testZone, nil + }, + }), + verifyResult: func(t *testing.T, ns *nodeServer, err error) { + assert.Nil(t, ns) + }, + expectedErr: errors.New("invalid syntax"), + }, + { + testCaseName: "Positive: maxVolumesPerNode not set", + envVars: map[string]string{ + constants.KubeNodeName: nodeID, + constants.MaxVolumesPerNodeEnv: "", + }, + statsUtils: utils.NewFakeStatsUtilsImpl(utils.FakeStatsUtilsFuncStruct{ + GetRegionAndZoneFn: func(nodeName string) (string, string, error) { return testRegion, testZone, nil }, + }), + verifyResult: func(t *testing.T, ns *nodeServer, err error) { + assert.NoError(t, err) + assert.NotNil(t, ns) + assert.Equal(t, ns.MaxVolumesPerNode, int64(0)) + assert.Equal(t, ns.Region, testRegion) + assert.Equal(t, ns.Zone, testZone) + assert.Equal(t, ns.NodeID, nodeID) + }, + expectedErr: nil, + }, + } logger, teardown := GetTestLogger(t) defer teardown() // Setup the CSI driver - driver, err := Setups3Driver(defaultMode, driverName, vendorVersion, logger) + driver, err := Setups3Driver("node", driverName, vendorVersion, logger) assert.NoError(t, err) assert.NotEmpty(t, driver) - statsUtil := utils.NewFakeStatsUtilsImpl(utils.FakeStatsUtilsFuncStruct{}) + fakeMountObj := &mounter.FakeMounterFactory{} mounterUtil := mounterUtils.NewFakeMounterUtilsImpl(mounterUtils.FakeMounterUtilsFuncStruct{}) - csiDriver, err := driver.NewS3CosDriver(nodeID, endpoint, fakeCosSession, fakeMountObj, statsUtil, mounterUtil) - assert.NoError(t, err) - assert.NotEmpty(t, csiDriver) -} - -func TestNewS3CosDriver_mode_node(t *testing.T) { - vendorVersion := "test-vendor-version-1.1.2" - driverName := "mydriver" + for _, tc := range testCases { + t.Log("Testcase being executed", zap.String("testcase", tc.testCaseName)) - endpoint := "test-endpoint" - nodeID := "test-nodeID" - - fakeCosSession := &s3client.FakeCOSSessionFactory{} - fakeMountObj := &mounter.FakeMounterFactory{} - - logger, teardown := GetTestLogger(t) - defer teardown() + for k, v := range tc.envVars { + t.Setenv(k, v) + } - // Setup the CSI driver - driver, err := Setups3Driver("node", driverName, vendorVersion, logger) - assert.NoError(t, err) - assert.NotEmpty(t, driver) + actualResp, actualErr := newNodeServer(driver, tc.statsUtils, nodeID, fakeMountObj, mounterUtil) - statsUtil := utils.NewFakeStatsUtilsImpl(utils.FakeStatsUtilsFuncStruct{}) - mounterUtil := mounterUtils.NewFakeMounterUtilsImpl(mounterUtils.FakeMounterUtilsFuncStruct{}) + if tc.expectedErr != nil { + assert.Error(t, actualErr) + assert.Contains(t, actualErr.Error(), tc.expectedErr.Error()) + } else { + assert.NoError(t, actualErr) + } - csiDriver, err := driver.NewS3CosDriver(nodeID, endpoint, fakeCosSession, fakeMountObj, statsUtil, mounterUtil) - assert.NoError(t, err) - assert.NotEmpty(t, csiDriver) + if tc.verifyResult != nil { + tc.verifyResult(t, actualResp, actualErr) + } + } } -func TestNewS3CosDriver_mode_controller_node(t *testing.T) { +func TestNewS3CosDriver(t *testing.T) { vendorVersion := "test-vendor-version-1.1.2" - driverName := "mydriver" + driverName := "test-csi-driver" endpoint := "test-endpoint" nodeID := "test-nodeID" + testRegion := "test-region" + testZone := "test-zone" + + envVars := map[string]string{ + constants.KubeNodeName: testNodeID, + constants.MaxVolumesPerNodeEnv: strconv.Itoa(constants.DefaultVolumesPerNode), + } + + testCases := []struct { + testCaseName string + mode string + statsUtils utils.StatsUtils + verifyResult func(*testing.T, *S3Driver, error) + expectedErr error + }{ + { + testCaseName: "Positive: controller mode", + mode: "controller", + statsUtils: utils.NewFakeStatsUtilsImpl(utils.FakeStatsUtilsFuncStruct{}), + verifyResult: func(t *testing.T, driver *S3Driver, err error) { + assert.NoError(t, err) + assert.NotEmpty(t, driver.cs) + }, + expectedErr: nil, + }, + { + testCaseName: "Positive: node mode", + mode: "node", + statsUtils: utils.NewFakeStatsUtilsImpl(utils.FakeStatsUtilsFuncStruct{ + GetRegionAndZoneFn: func(nodeName string) (string, string, error) { return testRegion, testZone, nil }, + }), + verifyResult: func(t *testing.T, driver *S3Driver, err error) { + assert.NoError(t, err) + assert.NotEmpty(t, driver.ns) + assert.Equal(t, driver.ns.Region, testRegion) + assert.Equal(t, driver.ns.Zone, testZone) + }, + expectedErr: nil, + }, + { + testCaseName: "Positive: controller and node mode", + mode: "controller-node", + statsUtils: utils.NewFakeStatsUtilsImpl(utils.FakeStatsUtilsFuncStruct{ + GetRegionAndZoneFn: func(nodeName string) (string, string, error) { return testRegion, testZone, nil }, + }), + verifyResult: func(t *testing.T, driver *S3Driver, err error) { + assert.NoError(t, err) + assert.NotEmpty(t, driver.cs) + assert.NotEmpty(t, driver.ns) + assert.Equal(t, driver.ns.Region, testRegion) + assert.Equal(t, driver.ns.Zone, testZone) + }, + expectedErr: nil, + }, + } fakeCosSession := &s3client.FakeCOSSessionFactory{} fakeMountObj := &mounter.FakeMounterFactory{} @@ -159,17 +296,33 @@ func TestNewS3CosDriver_mode_controller_node(t *testing.T) { logger, teardown := GetTestLogger(t) defer teardown() - // Setup the CSI driver - driver, err := Setups3Driver("controller-node", driverName, vendorVersion, logger) - assert.NoError(t, err) - assert.NotEmpty(t, driver) + for k, v := range envVars { + t.Setenv(k, v) + } - statsUtil := utils.NewFakeStatsUtilsImpl(utils.FakeStatsUtilsFuncStruct{}) mounterUtil := mounterUtils.NewFakeMounterUtilsImpl(mounterUtils.FakeMounterUtilsFuncStruct{}) - csiDriver, err := driver.NewS3CosDriver(nodeID, endpoint, fakeCosSession, fakeMountObj, statsUtil, mounterUtil) - assert.NoError(t, err) - assert.NotEmpty(t, csiDriver) + for _, tc := range testCases { + t.Log("Testcase being executed", zap.String("testcase", tc.testCaseName)) + + // Setup the CSI driver + driver, err := Setups3Driver(tc.mode, driverName, vendorVersion, logger) + assert.NoError(t, err) + assert.NotEmpty(t, driver) + + actualResp, actualErr := driver.NewS3CosDriver(nodeID, endpoint, fakeCosSession, fakeMountObj, tc.statsUtils, mounterUtil) + + if tc.expectedErr != nil { + assert.Error(t, actualErr) + assert.Contains(t, actualErr.Error(), tc.expectedErr.Error()) + } else { + assert.NoError(t, actualErr) + } + + if tc.verifyResult != nil { + tc.verifyResult(t, actualResp, actualErr) + } + } } func TestSetups3Driver_Positive(t *testing.T) {