Skip to content

Commit

Permalink
Improve relationship between driver and cloud
Browse files Browse the repository at this point in the history
Signed-off-by: torredil <[email protected]>
  • Loading branch information
torredil committed Apr 4, 2024
1 parent 5f95608 commit c249057
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 262 deletions.
25 changes: 24 additions & 1 deletion cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"time"

"github.com/kubernetes-sigs/aws-ebs-csi-driver/cmd/hooks"
"github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud"
"github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud/metadata"
"github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/driver"
"github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/metrics"
Expand Down Expand Up @@ -132,7 +133,29 @@ func main() {
r.InitializeMetricsHandler(options.HttpEndpoint, "/metrics")
}

drv, err := driver.NewDriver(&options)
region := os.Getenv("AWS_REGION")
if region == "" {
klog.V(5).InfoS("[Debug] Retrieving region from metadata service")
cfg := metadata.MetadataServiceConfig{
EC2MetadataClient: metadata.DefaultEC2MetadataClient,
K8sAPIClient: metadata.DefaultKubernetesAPIClient,
}
metadata, metadataErr := metadata.NewMetadataService(cfg, region)
if metadataErr != nil {
klog.ErrorS(err, "Could not determine region from any metadata service. The region can be manually supplied via the AWS_REGION environment variable.")
panic(err)
}
region = metadata.GetRegion()
}

klog.InfoS("batching", "status", options.Batching)
cloud, err := cloud.NewCloud(region, options.AwsSdkDebugLog, options.UserAgentExtra, options.Batching)
if err != nil {
klog.ErrorS(err, "failed to create cloud service")
klog.FlushAndExit(klog.ExitFlushTimeout, 1)
}

drv, err := driver.NewDriver(cloud, &options)
if err != nil {
klog.ErrorS(err, "failed to create driver")
klog.FlushAndExit(klog.ExitFlushTimeout, 1)
Expand Down
77 changes: 21 additions & 56 deletions pkg/driver/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,13 @@ import (
"context"
"errors"
"fmt"
"os"
"strconv"
"strings"

"github.com/aws/aws-sdk-go-v2/aws/arn"
"github.com/awslabs/volume-modifier-for-k8s/pkg/rpc"
csi "github.com/container-storage-interface/spec/lib/go/csi"
"github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud"
"github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud/metadata"
"github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/driver/internal"
"github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/util"
"github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/util/template"
Expand Down Expand Up @@ -58,59 +56,26 @@ var (

const isManagedByDriver = "true"

// controllerService represents the controller service of CSI driver
type controllerService struct {
// ControllerService represents the controller service of CSI driver
type ControllerService struct {
cloud cloud.Cloud
inFlight *internal.InFlight
options *Options
modifyVolumeManager *modifyVolumeManager

rpc.UnimplementedModifyServer
}

var (
// NewMetadataFunc is a variable for the cloud.NewMetadata function that can
// be overwritten in unit tests.
NewMetadataFunc = metadata.NewMetadataService
// NewCloudFunc is a variable for the cloud.NewCloud function that can
// be overwritten in unit tests.
NewCloudFunc = cloud.NewCloud
)

// newControllerService creates a new controller service
// it panics if failed to create the service
func newControllerService(o *Options) controllerService {
region := os.Getenv("AWS_REGION")
if region == "" {
klog.V(5).InfoS("[Debug] Retrieving region from metadata service")

cfg := metadata.MetadataServiceConfig{
EC2MetadataClient: metadata.DefaultEC2MetadataClient,
K8sAPIClient: metadata.DefaultKubernetesAPIClient,
}
metadata, err := NewMetadataFunc(cfg, region)
if err != nil {
klog.ErrorS(err, "Could not determine region from any metadata service. The region can be manually supplied via the AWS_REGION environment variable.")
panic(err)
}
region = metadata.GetRegion()
}

klog.InfoS("batching", "status", o.Batching)
cloudSrv, err := NewCloudFunc(region, o.AwsSdkDebugLog, o.UserAgentExtra, o.Batching)
if err != nil {
panic(err)
}

return controllerService{
cloud: cloudSrv,
inFlight: internal.NewInFlight(),
// NewControllerService creates a new controller service
func NewControllerService(c cloud.Cloud, o *Options) *ControllerService {
return &ControllerService{
cloud: c,
options: o,
inFlight: internal.NewInFlight(),
modifyVolumeManager: newModifyVolumeManager(),
}
}

func (d *controllerService) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) (*csi.CreateVolumeResponse, error) {
func (d *ControllerService) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) (*csi.CreateVolumeResponse, error) {
klog.V(4).InfoS("CreateVolume: called", "args", *req)
if err := validateCreateVolumeRequest(req); err != nil {
return nil, err
Expand Down Expand Up @@ -399,7 +364,7 @@ func validateCreateVolumeRequest(req *csi.CreateVolumeRequest) error {
return nil
}

func (d *controllerService) DeleteVolume(ctx context.Context, req *csi.DeleteVolumeRequest) (*csi.DeleteVolumeResponse, error) {
func (d *ControllerService) DeleteVolume(ctx context.Context, req *csi.DeleteVolumeRequest) (*csi.DeleteVolumeResponse, error) {
klog.V(4).InfoS("DeleteVolume: called", "args", *req)
if err := validateDeleteVolumeRequest(req); err != nil {
return nil, err
Expand Down Expand Up @@ -431,7 +396,7 @@ func validateDeleteVolumeRequest(req *csi.DeleteVolumeRequest) error {
return nil
}

func (d *controllerService) ControllerPublishVolume(ctx context.Context, req *csi.ControllerPublishVolumeRequest) (*csi.ControllerPublishVolumeResponse, error) {
func (d *ControllerService) ControllerPublishVolume(ctx context.Context, req *csi.ControllerPublishVolumeRequest) (*csi.ControllerPublishVolumeResponse, error) {
klog.V(4).InfoS("ControllerPublishVolume: called", "args", *req)
if err := validateControllerPublishVolumeRequest(req); err != nil {
return nil, err
Expand Down Expand Up @@ -480,7 +445,7 @@ func validateControllerPublishVolumeRequest(req *csi.ControllerPublishVolumeRequ
return nil
}

func (d *controllerService) ControllerUnpublishVolume(ctx context.Context, req *csi.ControllerUnpublishVolumeRequest) (*csi.ControllerUnpublishVolumeResponse, error) {
func (d *ControllerService) ControllerUnpublishVolume(ctx context.Context, req *csi.ControllerUnpublishVolumeRequest) (*csi.ControllerUnpublishVolumeResponse, error) {
klog.V(4).InfoS("ControllerUnpublishVolume: called", "args", *req)

if err := validateControllerUnpublishVolumeRequest(req); err != nil {
Expand Down Expand Up @@ -520,7 +485,7 @@ func validateControllerUnpublishVolumeRequest(req *csi.ControllerUnpublishVolume
return nil
}

func (d *controllerService) ControllerGetCapabilities(ctx context.Context, req *csi.ControllerGetCapabilitiesRequest) (*csi.ControllerGetCapabilitiesResponse, error) {
func (d *ControllerService) ControllerGetCapabilities(ctx context.Context, req *csi.ControllerGetCapabilitiesRequest) (*csi.ControllerGetCapabilitiesResponse, error) {
klog.V(4).InfoS("ControllerGetCapabilities: called", "args", *req)
var caps []*csi.ControllerServiceCapability
for _, cap := range controllerCaps {
Expand All @@ -536,17 +501,17 @@ func (d *controllerService) ControllerGetCapabilities(ctx context.Context, req *
return &csi.ControllerGetCapabilitiesResponse{Capabilities: caps}, nil
}

func (d *controllerService) GetCapacity(ctx context.Context, req *csi.GetCapacityRequest) (*csi.GetCapacityResponse, error) {
func (d *ControllerService) GetCapacity(ctx context.Context, req *csi.GetCapacityRequest) (*csi.GetCapacityResponse, error) {
klog.V(4).InfoS("GetCapacity: called", "args", *req)
return nil, status.Error(codes.Unimplemented, "")
}

func (d *controllerService) ListVolumes(ctx context.Context, req *csi.ListVolumesRequest) (*csi.ListVolumesResponse, error) {
func (d *ControllerService) ListVolumes(ctx context.Context, req *csi.ListVolumesRequest) (*csi.ListVolumesResponse, error) {
klog.V(4).InfoS("ListVolumes: called", "args", *req)
return nil, status.Error(codes.Unimplemented, "")
}

func (d *controllerService) ValidateVolumeCapabilities(ctx context.Context, req *csi.ValidateVolumeCapabilitiesRequest) (*csi.ValidateVolumeCapabilitiesResponse, error) {
func (d *ControllerService) ValidateVolumeCapabilities(ctx context.Context, req *csi.ValidateVolumeCapabilitiesRequest) (*csi.ValidateVolumeCapabilitiesResponse, error) {
klog.V(4).InfoS("ValidateVolumeCapabilities: called", "args", *req)
volumeID := req.GetVolumeId()
if len(volumeID) == 0 {
Expand Down Expand Up @@ -574,7 +539,7 @@ func (d *controllerService) ValidateVolumeCapabilities(ctx context.Context, req
}, nil
}

func (d *controllerService) ControllerExpandVolume(ctx context.Context, req *csi.ControllerExpandVolumeRequest) (*csi.ControllerExpandVolumeResponse, error) {
func (d *ControllerService) ControllerExpandVolume(ctx context.Context, req *csi.ControllerExpandVolumeRequest) (*csi.ControllerExpandVolumeResponse, error) {
klog.V(4).InfoS("ControllerExpandVolume: called", "args", *req)
volumeID := req.GetVolumeId()
if len(volumeID) == 0 {
Expand Down Expand Up @@ -627,7 +592,7 @@ func (d *controllerService) ControllerExpandVolume(ctx context.Context, req *csi
}, nil
}

func (d *controllerService) ControllerModifyVolume(ctx context.Context, req *csi.ControllerModifyVolumeRequest) (*csi.ControllerModifyVolumeResponse, error) {
func (d *ControllerService) ControllerModifyVolume(ctx context.Context, req *csi.ControllerModifyVolumeRequest) (*csi.ControllerModifyVolumeResponse, error) {
klog.V(4).InfoS("ControllerModifyVolume: called", "args", *req)

volumeID := req.GetVolumeId()
Expand All @@ -648,7 +613,7 @@ func (d *controllerService) ControllerModifyVolume(ctx context.Context, req *csi
return &csi.ControllerModifyVolumeResponse{}, nil
}

func (d *controllerService) ControllerGetVolume(ctx context.Context, req *csi.ControllerGetVolumeRequest) (*csi.ControllerGetVolumeResponse, error) {
func (d *ControllerService) ControllerGetVolume(ctx context.Context, req *csi.ControllerGetVolumeRequest) (*csi.ControllerGetVolumeResponse, error) {
klog.V(4).InfoS("ControllerGetVolume: called", "args", *req)
return nil, status.Error(codes.Unimplemented, "")
}
Expand Down Expand Up @@ -706,7 +671,7 @@ func isValidVolumeContext(volContext map[string]string) bool {
return true
}

func (d *controllerService) CreateSnapshot(ctx context.Context, req *csi.CreateSnapshotRequest) (*csi.CreateSnapshotResponse, error) {
func (d *ControllerService) CreateSnapshot(ctx context.Context, req *csi.CreateSnapshotRequest) (*csi.CreateSnapshotResponse, error) {
klog.V(4).InfoS("CreateSnapshot: called", "args", req)
if err := validateCreateSnapshotRequest(req); err != nil {
return nil, err
Expand Down Expand Up @@ -835,7 +800,7 @@ func validateCreateSnapshotRequest(req *csi.CreateSnapshotRequest) error {
return nil
}

func (d *controllerService) DeleteSnapshot(ctx context.Context, req *csi.DeleteSnapshotRequest) (*csi.DeleteSnapshotResponse, error) {
func (d *ControllerService) DeleteSnapshot(ctx context.Context, req *csi.DeleteSnapshotRequest) (*csi.DeleteSnapshotResponse, error) {
klog.V(4).InfoS("DeleteSnapshot: called", "args", req)
if err := validateDeleteSnapshotRequest(req); err != nil {
return nil, err
Expand Down Expand Up @@ -868,7 +833,7 @@ func validateDeleteSnapshotRequest(req *csi.DeleteSnapshotRequest) error {
return nil
}

func (d *controllerService) ListSnapshots(ctx context.Context, req *csi.ListSnapshotsRequest) (*csi.ListSnapshotsResponse, error) {
func (d *ControllerService) ListSnapshots(ctx context.Context, req *csi.ListSnapshotsRequest) (*csi.ListSnapshotsResponse, error) {
klog.V(4).InfoS("ListSnapshots: called", "args", req)
var snapshots []*cloud.Snapshot

Expand Down
12 changes: 6 additions & 6 deletions pkg/driver/controller_modify_volume.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func (h *modifyVolumeRequestHandler) mergeModifyVolumeRequest(r *modifyVolumeReq
// the ec2 API call to the CSI Driver main thread via response channels.
// This method receives requests from CSI driver main thread via the request channel. When a new request is received from the request channel, we first
// validate the new request. If the new request is acceptable, it will be merged with the existing request for the volume.
func (d *controllerService) processModifyVolumeRequests(h *modifyVolumeRequestHandler, responseChans []chan modifyVolumeResponse) {
func (d *ControllerService) processModifyVolumeRequests(h *modifyVolumeRequestHandler, responseChans []chan modifyVolumeResponse) {
klog.V(4).InfoS("Start processing ModifyVolumeRequest for ", "volume ID", h.volumeID)
process := func(req *modifyVolumeRequest) {
if err := h.validateModifyVolumeRequest(req); err != nil {
Expand Down Expand Up @@ -167,7 +167,7 @@ func (d *controllerService) processModifyVolumeRequests(h *modifyVolumeRequestHa
// If there’s ModifyVolumeRequestHandler for the volume, meaning that there is inflight request(s) for the volume, we will send the new request
// to the goroutine for the volume via the receiving channel.
// Note that each volume with inflight requests has their own goroutine which follows timeout schedule of their own.
func (d *controllerService) addModifyVolumeRequest(volumeID string, r *modifyVolumeRequest) {
func (d *ControllerService) addModifyVolumeRequest(volumeID string, r *modifyVolumeRequest) {
requestHandler := newModifyVolumeRequestHandler(volumeID, r)
handler, loaded := d.modifyVolumeManager.requestHandlerMap.LoadOrStore(volumeID, requestHandler)
if loaded {
Expand All @@ -179,7 +179,7 @@ func (d *controllerService) addModifyVolumeRequest(volumeID string, r *modifyVol
}
}

func (d *controllerService) executeModifyVolumeRequest(volumeID string, req *modifyVolumeRequest) (int32, error) {
func (d *ControllerService) executeModifyVolumeRequest(volumeID string, req *modifyVolumeRequest) (int32, error) {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
actualSizeGiB, err := d.cloud.ResizeOrModifyDisk(ctx, volumeID, req.newSize, &req.modifyDiskOptions)
Expand All @@ -190,14 +190,14 @@ func (d *controllerService) executeModifyVolumeRequest(volumeID string, req *mod
}
}

func (d *controllerService) GetCSIDriverModificationCapability(
func (d *ControllerService) GetCSIDriverModificationCapability(
_ context.Context,
_ *rpc.GetCSIDriverModificationCapabilityRequest,
) (*rpc.GetCSIDriverModificationCapabilityResponse, error) {
return &rpc.GetCSIDriverModificationCapabilityResponse{}, nil
}

func (d *controllerService) ModifyVolumeProperties(
func (d *ControllerService) ModifyVolumeProperties(
ctx context.Context,
req *rpc.ModifyVolumePropertiesRequest,
) (*rpc.ModifyVolumePropertiesResponse, error) {
Expand Down Expand Up @@ -260,7 +260,7 @@ func parseModifyVolumeParameters(params map[string]string) (*cloud.ModifyDiskOpt
return &options, nil
}

func (d *controllerService) modifyVolumeWithCoalescing(ctx context.Context, volume string, options *cloud.ModifyDiskOptions) error {
func (d *ControllerService) modifyVolumeWithCoalescing(ctx context.Context, volume string, options *cloud.ModifyDiskOptions) error {
responseChan := make(chan modifyVolumeResponse)
request := modifyVolumeRequest{
modifyDiskOptions: *options,
Expand Down
Loading

0 comments on commit c249057

Please sign in to comment.