diff --git a/main.go b/main.go index c58ebe64e..e3426cd69 100644 --- a/main.go +++ b/main.go @@ -81,7 +81,7 @@ func main() { ctrl.SetLogger(appLogger) klog.SetLoggerWithOptions(appLogger, klog.ContextualLogger(true)) - cloud, err := aws.NewCloud(controllerCFG.AWSConfig, metrics.Registry, ctrl.Log) + cloud, err := aws.NewCloud(controllerCFG.AWSConfig, metrics.Registry, ctrl.Log, nil) if err != nil { setupLog.Error(err, "unable to initialize AWS cloud") os.Exit(1) diff --git a/pkg/aws/cloud.go b/pkg/aws/cloud.go index 1ebe084da..0a0ee6509 100644 --- a/pkg/aws/cloud.go +++ b/pkg/aws/cloud.go @@ -24,6 +24,7 @@ import ( "github.com/prometheus/client_golang/prometheus" amerrors "k8s.io/apimachinery/pkg/util/errors" epresolver "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/provider" "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" ) @@ -59,7 +60,7 @@ type Cloud interface { } // NewCloud constructs new Cloud implementation. -func NewCloud(cfg CloudConfig, metricsRegisterer prometheus.Registerer, logger logr.Logger) (Cloud, error) { +func NewCloud(cfg CloudConfig, metricsRegisterer prometheus.Registerer, logger logr.Logger, awsClientsProvider provider.AWSClientsProvider) (Cloud, error) { hasIPv4 := true addrs, err := net.InterfaceAddrs() if err == nil { @@ -129,7 +130,14 @@ func NewCloud(cfg CloudConfig, metricsRegisterer prometheus.Registerer, logger l awsConfig.APIOptions = metrics.WithSDKMetricCollector(metricsCollector, awsConfig.APIOptions) } - ec2Service := services.NewEC2(awsConfig, endpointsResolver) + if awsClientsProvider == nil { + var err error + awsClientsProvider, err = NewDefaultAWSClientsProvider(awsConfig, endpointsResolver) + if err != nil { + return nil, errors.Wrap(err, "failed to create aws clients provider") + } + } + ec2Service := services.NewEC2(awsClientsProvider) vpcID, err := getVpcID(cfg, ec2Service, ec2Metadata, logger) if err != nil { diff --git a/pkg/aws/default_aws_clients_provider.go b/pkg/aws/default_aws_clients_provider.go new file mode 100644 index 000000000..5acad300c --- /dev/null +++ b/pkg/aws/default_aws_clients_provider.go @@ -0,0 +1,31 @@ +package aws + +import ( + "context" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints" +) + +type defaultAWSClientsProvider struct { + ec2Client *ec2.Client + elbv2Client *elasticloadbalancingv2.Client +} + +func NewDefaultAWSClientsProvider(cfg aws.Config, endpointsResolver *endpoints.Resolver) (*defaultAWSClientsProvider, error) { + customEndpoint := endpointsResolver.EndpointFor(ec2.ServiceID) + ec2Client := ec2.NewFromConfig(cfg, func(o *ec2.Options) { + if customEndpoint != nil { + o.BaseEndpoint = customEndpoint + } + }) + return &defaultAWSClientsProvider{ + ec2Client: ec2Client, + elbv2Client: nil, + }, nil +} + +func (p *defaultAWSClientsProvider) GetEC2Client(ctx context.Context, operationName string) (*ec2.Client, error) { + return p.ec2Client, nil +} diff --git a/pkg/aws/provider/provider.go b/pkg/aws/provider/provider.go new file mode 100644 index 000000000..33734e736 --- /dev/null +++ b/pkg/aws/provider/provider.go @@ -0,0 +1,10 @@ +package provider + +import ( + "context" + "github.com/aws/aws-sdk-go-v2/service/ec2" +) + +type AWSClientsProvider interface { + GetEC2Client(ctx context.Context, operationName string) (*ec2.Client, error) +} diff --git a/pkg/aws/services/ec2.go b/pkg/aws/services/ec2.go index f52969cc5..4cb7d4b6b 100644 --- a/pkg/aws/services/ec2.go +++ b/pkg/aws/services/ec2.go @@ -2,10 +2,9 @@ package services import ( "context" - "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" - "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/provider" ) type EC2 interface { @@ -37,28 +36,31 @@ type EC2 interface { } // NewEC2 constructs new EC2 implementation. -func NewEC2(cfg aws.Config, endpointsResolver *endpoints.Resolver) EC2 { - customEndpoint := endpointsResolver.EndpointFor(ec2.ServiceID) +func NewEC2(awsClientsProvider provider.AWSClientsProvider) EC2 { return &ec2Client{ - ec2Client: ec2.NewFromConfig(cfg, func(o *ec2.Options) { - if customEndpoint != nil { - o.BaseEndpoint = customEndpoint - } - }), + awsClientsProvider: awsClientsProvider, } } type ec2Client struct { - ec2Client *ec2.Client + awsClientsProvider provider.AWSClientsProvider } func (c *ec2Client) DescribeInstancesWithContext(ctx context.Context, input *ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error) { - return c.ec2Client.DescribeInstances(ctx, input) + client, err := c.awsClientsProvider.GetEC2Client(ctx, "DescribeInstances") + if err != nil { + return nil, err + } + return client.DescribeInstances(ctx, input) } func (c *ec2Client) DescribeInstancesAsList(ctx context.Context, input *ec2.DescribeInstancesInput) ([]types.Instance, error) { var result []types.Instance - paginator := ec2.NewDescribeInstancesPaginator(c.ec2Client, input) + client, err := c.awsClientsProvider.GetEC2Client(ctx, "DescribeInstances") + if err != nil { + return nil, err + } + paginator := ec2.NewDescribeInstancesPaginator(client, input) for paginator.HasMorePages() { output, err := paginator.NextPage(ctx) if err != nil { @@ -73,7 +75,11 @@ func (c *ec2Client) DescribeInstancesAsList(ctx context.Context, input *ec2.Desc func (c *ec2Client) DescribeNetworkInterfacesAsList(ctx context.Context, input *ec2.DescribeNetworkInterfacesInput) ([]types.NetworkInterface, error) { var result []types.NetworkInterface - paginator := ec2.NewDescribeNetworkInterfacesPaginator(c.ec2Client, input) + client, err := c.awsClientsProvider.GetEC2Client(ctx, "DescribeNetworkInterfaces") + if err != nil { + return nil, err + } + paginator := ec2.NewDescribeNetworkInterfacesPaginator(client, input) for paginator.HasMorePages() { output, err := paginator.NextPage(ctx) if err != nil { @@ -86,7 +92,11 @@ func (c *ec2Client) DescribeNetworkInterfacesAsList(ctx context.Context, input * func (c *ec2Client) DescribeSecurityGroupsAsList(ctx context.Context, input *ec2.DescribeSecurityGroupsInput) ([]types.SecurityGroup, error) { var result []types.SecurityGroup - paginator := ec2.NewDescribeSecurityGroupsPaginator(c.ec2Client, input) + client, err := c.awsClientsProvider.GetEC2Client(ctx, "DescribeSecurityGroups") + if err != nil { + return nil, err + } + paginator := ec2.NewDescribeSecurityGroupsPaginator(client, input) for paginator.HasMorePages() { output, err := paginator.NextPage(ctx) if err != nil { @@ -99,7 +109,11 @@ func (c *ec2Client) DescribeSecurityGroupsAsList(ctx context.Context, input *ec2 func (c *ec2Client) DescribeSubnetsAsList(ctx context.Context, input *ec2.DescribeSubnetsInput) ([]types.Subnet, error) { var result []types.Subnet - paginator := ec2.NewDescribeSubnetsPaginator(c.ec2Client, input) + client, err := c.awsClientsProvider.GetEC2Client(ctx, "DescribeSubnets") + if err != nil { + return nil, err + } + paginator := ec2.NewDescribeSubnetsPaginator(client, input) for paginator.HasMorePages() { output, err := paginator.NextPage(ctx) if err != nil { @@ -112,7 +126,11 @@ func (c *ec2Client) DescribeSubnetsAsList(ctx context.Context, input *ec2.Descri func (c *ec2Client) DescribeVPCsAsList(ctx context.Context, input *ec2.DescribeVpcsInput) ([]types.Vpc, error) { var result []types.Vpc - paginator := ec2.NewDescribeVpcsPaginator(c.ec2Client, input) + client, err := c.awsClientsProvider.GetEC2Client(ctx, "DescribeVPCs") + if err != nil { + return nil, err + } + paginator := ec2.NewDescribeVpcsPaginator(client, input) for paginator.HasMorePages() { output, err := paginator.NextPage(ctx) if err != nil { @@ -124,33 +142,65 @@ func (c *ec2Client) DescribeVPCsAsList(ctx context.Context, input *ec2.DescribeV } func (c *ec2Client) CreateTagsWithContext(ctx context.Context, input *ec2.CreateTagsInput) (*ec2.CreateTagsOutput, error) { - return c.ec2Client.CreateTags(ctx, input) + client, err := c.awsClientsProvider.GetEC2Client(ctx, "CreateTags") + if err != nil { + return nil, err + } + return client.CreateTags(ctx, input) } func (c *ec2Client) DeleteTagsWithContext(ctx context.Context, input *ec2.DeleteTagsInput) (*ec2.DeleteTagsOutput, error) { - return c.ec2Client.DeleteTags(ctx, input) + client, err := c.awsClientsProvider.GetEC2Client(ctx, "DeleteTags") + if err != nil { + return nil, err + } + return client.DeleteTags(ctx, input) } func (c *ec2Client) CreateSecurityGroupWithContext(ctx context.Context, input *ec2.CreateSecurityGroupInput) (*ec2.CreateSecurityGroupOutput, error) { - return c.ec2Client.CreateSecurityGroup(ctx, input) + client, err := c.awsClientsProvider.GetEC2Client(ctx, "CreateSecurityGroup") + if err != nil { + return nil, err + } + return client.CreateSecurityGroup(ctx, input) } func (c *ec2Client) DeleteSecurityGroupWithContext(ctx context.Context, input *ec2.DeleteSecurityGroupInput) (*ec2.DeleteSecurityGroupOutput, error) { - return c.ec2Client.DeleteSecurityGroup(ctx, input) + client, err := c.awsClientsProvider.GetEC2Client(ctx, "DeleteSecurityGroup") + if err != nil { + return nil, err + } + return client.DeleteSecurityGroup(ctx, input) } func (c *ec2Client) AuthorizeSecurityGroupIngressWithContext(ctx context.Context, input *ec2.AuthorizeSecurityGroupIngressInput) (*ec2.AuthorizeSecurityGroupIngressOutput, error) { - return c.ec2Client.AuthorizeSecurityGroupIngress(ctx, input) + client, err := c.awsClientsProvider.GetEC2Client(ctx, "AuthorizeSecurityGroupIngress") + if err != nil { + return nil, err + } + return client.AuthorizeSecurityGroupIngress(ctx, input) } func (c *ec2Client) RevokeSecurityGroupIngressWithContext(ctx context.Context, input *ec2.RevokeSecurityGroupIngressInput) (*ec2.RevokeSecurityGroupIngressOutput, error) { - return c.ec2Client.RevokeSecurityGroupIngress(ctx, input) + client, err := c.awsClientsProvider.GetEC2Client(ctx, "RevokeSecurityGroupIngress") + if err != nil { + return nil, err + } + return client.RevokeSecurityGroupIngress(ctx, input) } func (c *ec2Client) DescribeAvailabilityZonesWithContext(ctx context.Context, input *ec2.DescribeAvailabilityZonesInput) (*ec2.DescribeAvailabilityZonesOutput, error) { - return c.ec2Client.DescribeAvailabilityZones(ctx, input) + client, err := c.awsClientsProvider.GetEC2Client(ctx, "DescribeAvailabilityZones") + if err != nil { + return nil, err + } + return client.DescribeAvailabilityZones(ctx, input) } func (c *ec2Client) DescribeVpcsWithContext(ctx context.Context, input *ec2.DescribeVpcsInput) (*ec2.DescribeVpcsOutput, error) { - return c.ec2Client.DescribeVpcs(ctx, input) + client, err := c.awsClientsProvider.GetEC2Client(ctx, "DescribeVpcs") + if err != nil { + return nil, err + } + return client.DescribeVpcs(ctx, input) } diff --git a/test/framework/framework.go b/test/framework/framework.go index 4402f16d1..52171817a 100644 --- a/test/framework/framework.go +++ b/test/framework/framework.go @@ -62,7 +62,7 @@ func InitFramework() (*Framework, error) { VpcID: globalOptions.AWSVPCID, MaxRetries: 3, ThrottleConfig: throttle.NewDefaultServiceOperationsThrottleConfig(), - }, nil, logger) + }, nil, logger, nil) if err != nil { return nil, err }