Skip to content

Commit

Permalink
add support for aws clients provider
Browse files Browse the repository at this point in the history
  • Loading branch information
oliviassss committed Oct 16, 2024
1 parent ebc3c25 commit 0353352
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 28 deletions.
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func main() {
ctrl.SetLogger(appLogger)
klog.SetLoggerWithOptions(appLogger, klog.ContextualLogger(true))

cloud, err := aws.NewCloud(controllerCFG.AWSConfig, metrics.Registry, ctrl.Log)
cloud, err := aws.NewCloud(controllerCFG.AWSConfig, metrics.Registry, ctrl.Log, nil)
if err != nil {
setupLog.Error(err, "unable to initialize AWS cloud")
os.Exit(1)
Expand Down
12 changes: 10 additions & 2 deletions pkg/aws/cloud.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/prometheus/client_golang/prometheus"
amerrors "k8s.io/apimachinery/pkg/util/errors"
epresolver "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints"
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/provider"
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services"
)

Expand Down Expand Up @@ -59,7 +60,7 @@ type Cloud interface {
}

// NewCloud constructs new Cloud implementation.
func NewCloud(cfg CloudConfig, metricsRegisterer prometheus.Registerer, logger logr.Logger) (Cloud, error) {
func NewCloud(cfg CloudConfig, metricsRegisterer prometheus.Registerer, logger logr.Logger, awsClientsProvider provider.AWSClientsProvider) (Cloud, error) {
hasIPv4 := true
addrs, err := net.InterfaceAddrs()
if err == nil {
Expand Down Expand Up @@ -129,7 +130,14 @@ func NewCloud(cfg CloudConfig, metricsRegisterer prometheus.Registerer, logger l
awsConfig.APIOptions = metrics.WithSDKMetricCollector(metricsCollector, awsConfig.APIOptions)
}

ec2Service := services.NewEC2(awsConfig, endpointsResolver)
if awsClientsProvider == nil {
var err error
awsClientsProvider, err = NewDefaultAWSClientsProvider(awsConfig, endpointsResolver)
if err != nil {
return nil, errors.Wrap(err, "failed to create aws clients provider")
}
}
ec2Service := services.NewEC2(awsClientsProvider)

vpcID, err := getVpcID(cfg, ec2Service, ec2Metadata, logger)
if err != nil {
Expand Down
31 changes: 31 additions & 0 deletions pkg/aws/default_aws_clients_provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package aws

import (
"context"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2"
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints"
)

type defaultAWSClientsProvider struct {
ec2Client *ec2.Client
elbv2Client *elasticloadbalancingv2.Client
}

func NewDefaultAWSClientsProvider(cfg aws.Config, endpointsResolver *endpoints.Resolver) (*defaultAWSClientsProvider, error) {
customEndpoint := endpointsResolver.EndpointFor(ec2.ServiceID)
ec2Client := ec2.NewFromConfig(cfg, func(o *ec2.Options) {
if customEndpoint != nil {
o.BaseEndpoint = customEndpoint
}
})
return &defaultAWSClientsProvider{
ec2Client: ec2Client,
elbv2Client: nil,
}, nil
}

func (p *defaultAWSClientsProvider) GetEC2Client(ctx context.Context, operationName string) (*ec2.Client, error) {
return p.ec2Client, nil
}
10 changes: 10 additions & 0 deletions pkg/aws/provider/provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package provider

import (
"context"
"github.com/aws/aws-sdk-go-v2/service/ec2"
)

type AWSClientsProvider interface {
GetEC2Client(ctx context.Context, operationName string) (*ec2.Client, error)
}
98 changes: 74 additions & 24 deletions pkg/aws/services/ec2.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@ package services

import (
"context"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/ec2/types"
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints"
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/provider"
)

type EC2 interface {
Expand Down Expand Up @@ -37,28 +36,31 @@ type EC2 interface {
}

// NewEC2 constructs new EC2 implementation.
func NewEC2(cfg aws.Config, endpointsResolver *endpoints.Resolver) EC2 {
customEndpoint := endpointsResolver.EndpointFor(ec2.ServiceID)
func NewEC2(awsClientsProvider provider.AWSClientsProvider) EC2 {
return &ec2Client{
ec2Client: ec2.NewFromConfig(cfg, func(o *ec2.Options) {
if customEndpoint != nil {
o.BaseEndpoint = customEndpoint
}
}),
awsClientsProvider: awsClientsProvider,
}
}

type ec2Client struct {
ec2Client *ec2.Client
awsClientsProvider provider.AWSClientsProvider
}

func (c *ec2Client) DescribeInstancesWithContext(ctx context.Context, input *ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error) {
return c.ec2Client.DescribeInstances(ctx, input)
client, err := c.awsClientsProvider.GetEC2Client(ctx, "DescribeInstances")
if err != nil {
return nil, err
}
return client.DescribeInstances(ctx, input)
}

func (c *ec2Client) DescribeInstancesAsList(ctx context.Context, input *ec2.DescribeInstancesInput) ([]types.Instance, error) {
var result []types.Instance
paginator := ec2.NewDescribeInstancesPaginator(c.ec2Client, input)
client, err := c.awsClientsProvider.GetEC2Client(ctx, "DescribeInstances")
if err != nil {
return nil, err
}
paginator := ec2.NewDescribeInstancesPaginator(client, input)
for paginator.HasMorePages() {
output, err := paginator.NextPage(ctx)
if err != nil {
Expand All @@ -73,7 +75,11 @@ func (c *ec2Client) DescribeInstancesAsList(ctx context.Context, input *ec2.Desc

func (c *ec2Client) DescribeNetworkInterfacesAsList(ctx context.Context, input *ec2.DescribeNetworkInterfacesInput) ([]types.NetworkInterface, error) {
var result []types.NetworkInterface
paginator := ec2.NewDescribeNetworkInterfacesPaginator(c.ec2Client, input)
client, err := c.awsClientsProvider.GetEC2Client(ctx, "DescribeNetworkInterfaces")
if err != nil {
return nil, err
}
paginator := ec2.NewDescribeNetworkInterfacesPaginator(client, input)
for paginator.HasMorePages() {
output, err := paginator.NextPage(ctx)
if err != nil {
Expand All @@ -86,7 +92,11 @@ func (c *ec2Client) DescribeNetworkInterfacesAsList(ctx context.Context, input *

func (c *ec2Client) DescribeSecurityGroupsAsList(ctx context.Context, input *ec2.DescribeSecurityGroupsInput) ([]types.SecurityGroup, error) {
var result []types.SecurityGroup
paginator := ec2.NewDescribeSecurityGroupsPaginator(c.ec2Client, input)
client, err := c.awsClientsProvider.GetEC2Client(ctx, "DescribeSecurityGroups")
if err != nil {
return nil, err
}
paginator := ec2.NewDescribeSecurityGroupsPaginator(client, input)
for paginator.HasMorePages() {
output, err := paginator.NextPage(ctx)
if err != nil {
Expand All @@ -99,7 +109,11 @@ func (c *ec2Client) DescribeSecurityGroupsAsList(ctx context.Context, input *ec2

func (c *ec2Client) DescribeSubnetsAsList(ctx context.Context, input *ec2.DescribeSubnetsInput) ([]types.Subnet, error) {
var result []types.Subnet
paginator := ec2.NewDescribeSubnetsPaginator(c.ec2Client, input)
client, err := c.awsClientsProvider.GetEC2Client(ctx, "DescribeSubnets")
if err != nil {
return nil, err
}
paginator := ec2.NewDescribeSubnetsPaginator(client, input)
for paginator.HasMorePages() {
output, err := paginator.NextPage(ctx)
if err != nil {
Expand All @@ -112,7 +126,11 @@ func (c *ec2Client) DescribeSubnetsAsList(ctx context.Context, input *ec2.Descri

func (c *ec2Client) DescribeVPCsAsList(ctx context.Context, input *ec2.DescribeVpcsInput) ([]types.Vpc, error) {
var result []types.Vpc
paginator := ec2.NewDescribeVpcsPaginator(c.ec2Client, input)
client, err := c.awsClientsProvider.GetEC2Client(ctx, "DescribeVPCs")
if err != nil {
return nil, err
}
paginator := ec2.NewDescribeVpcsPaginator(client, input)
for paginator.HasMorePages() {
output, err := paginator.NextPage(ctx)
if err != nil {
Expand All @@ -124,33 +142,65 @@ func (c *ec2Client) DescribeVPCsAsList(ctx context.Context, input *ec2.DescribeV
}

func (c *ec2Client) CreateTagsWithContext(ctx context.Context, input *ec2.CreateTagsInput) (*ec2.CreateTagsOutput, error) {
return c.ec2Client.CreateTags(ctx, input)
client, err := c.awsClientsProvider.GetEC2Client(ctx, "CreateTags")
if err != nil {
return nil, err
}
return client.CreateTags(ctx, input)
}

func (c *ec2Client) DeleteTagsWithContext(ctx context.Context, input *ec2.DeleteTagsInput) (*ec2.DeleteTagsOutput, error) {
return c.ec2Client.DeleteTags(ctx, input)
client, err := c.awsClientsProvider.GetEC2Client(ctx, "DeleteTags")
if err != nil {
return nil, err
}
return client.DeleteTags(ctx, input)
}

func (c *ec2Client) CreateSecurityGroupWithContext(ctx context.Context, input *ec2.CreateSecurityGroupInput) (*ec2.CreateSecurityGroupOutput, error) {
return c.ec2Client.CreateSecurityGroup(ctx, input)
client, err := c.awsClientsProvider.GetEC2Client(ctx, "CreateSecurityGroup")
if err != nil {
return nil, err
}
return client.CreateSecurityGroup(ctx, input)
}

func (c *ec2Client) DeleteSecurityGroupWithContext(ctx context.Context, input *ec2.DeleteSecurityGroupInput) (*ec2.DeleteSecurityGroupOutput, error) {
return c.ec2Client.DeleteSecurityGroup(ctx, input)
client, err := c.awsClientsProvider.GetEC2Client(ctx, "DeleteSecurityGroup")
if err != nil {
return nil, err
}
return client.DeleteSecurityGroup(ctx, input)
}

func (c *ec2Client) AuthorizeSecurityGroupIngressWithContext(ctx context.Context, input *ec2.AuthorizeSecurityGroupIngressInput) (*ec2.AuthorizeSecurityGroupIngressOutput, error) {
return c.ec2Client.AuthorizeSecurityGroupIngress(ctx, input)
client, err := c.awsClientsProvider.GetEC2Client(ctx, "AuthorizeSecurityGroupIngress")
if err != nil {
return nil, err
}
return client.AuthorizeSecurityGroupIngress(ctx, input)
}

func (c *ec2Client) RevokeSecurityGroupIngressWithContext(ctx context.Context, input *ec2.RevokeSecurityGroupIngressInput) (*ec2.RevokeSecurityGroupIngressOutput, error) {
return c.ec2Client.RevokeSecurityGroupIngress(ctx, input)
client, err := c.awsClientsProvider.GetEC2Client(ctx, "RevokeSecurityGroupIngress")
if err != nil {
return nil, err
}
return client.RevokeSecurityGroupIngress(ctx, input)
}

func (c *ec2Client) DescribeAvailabilityZonesWithContext(ctx context.Context, input *ec2.DescribeAvailabilityZonesInput) (*ec2.DescribeAvailabilityZonesOutput, error) {
return c.ec2Client.DescribeAvailabilityZones(ctx, input)
client, err := c.awsClientsProvider.GetEC2Client(ctx, "DescribeAvailabilityZones")
if err != nil {
return nil, err
}
return client.DescribeAvailabilityZones(ctx, input)
}

func (c *ec2Client) DescribeVpcsWithContext(ctx context.Context, input *ec2.DescribeVpcsInput) (*ec2.DescribeVpcsOutput, error) {
return c.ec2Client.DescribeVpcs(ctx, input)
client, err := c.awsClientsProvider.GetEC2Client(ctx, "DescribeVpcs")
if err != nil {
return nil, err
}
return client.DescribeVpcs(ctx, input)
}
2 changes: 1 addition & 1 deletion test/framework/framework.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func InitFramework() (*Framework, error) {
VpcID: globalOptions.AWSVPCID,
MaxRetries: 3,
ThrottleConfig: throttle.NewDefaultServiceOperationsThrottleConfig(),
}, nil, logger)
}, nil, logger, nil)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 0353352

Please sign in to comment.