Skip to content

Commit

Permalink
refactor aws cloud service
Browse files Browse the repository at this point in the history
  • Loading branch information
oliviassss committed Oct 16, 2024
1 parent 0353352 commit 78f7edc
Show file tree
Hide file tree
Showing 10 changed files with 433 additions and 156 deletions.
15 changes: 7 additions & 8 deletions pkg/aws/cloud.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ func NewCloud(cfg CloudConfig, metricsRegisterer prometheus.Registerer, logger l

if awsClientsProvider == nil {
var err error
awsClientsProvider, err = NewDefaultAWSClientsProvider(awsConfig, endpointsResolver)
awsClientsProvider, err = provider.NewDefaultAWSClientsProvider(awsConfig, endpointsResolver)
if err != nil {
return nil, errors.Wrap(err, "failed to create aws clients provider")
}
Expand All @@ -147,17 +147,16 @@ func NewCloud(cfg CloudConfig, metricsRegisterer prometheus.Registerer, logger l
return &defaultCloud{
cfg: cfg,
ec2: ec2Service,
elbv2: services.NewELBV2(awsConfig, endpointsResolver),
acm: services.NewACM(awsConfig, endpointsResolver),
wafv2: services.NewWAFv2(awsConfig, endpointsResolver),
wafRegional: services.NewWAFRegional(awsConfig, endpointsResolver, cfg.Region),
shield: services.NewShield(awsConfig, endpointsResolver), //done
rgt: services.NewRGT(awsConfig, endpointsResolver),
elbv2: services.NewELBV2(awsClientsProvider),
acm: services.NewACM(awsClientsProvider),
wafv2: services.NewWAFv2(awsClientsProvider),
wafRegional: services.NewWAFRegional(awsClientsProvider, cfg.Region),
shield: services.NewShield(awsClientsProvider),
rgt: services.NewRGT(awsClientsProvider),
}, nil
}

func getVpcID(cfg CloudConfig, ec2Service services.EC2, ec2Metadata services.EC2Metadata, logger logr.Logger) (string, error) {

if cfg.VpcID != "" {
logger.V(1).Info("vpcid is specified using flag --aws-vpc-id, controller will use the value", "vpc: ", cfg.VpcID)
return cfg.VpcID, nil
Expand Down
31 changes: 0 additions & 31 deletions pkg/aws/default_aws_clients_provider.go

This file was deleted.

109 changes: 109 additions & 0 deletions pkg/aws/provider/default_aws_clients_provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package provider

import (
"context"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/acm"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2"
"github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi"
"github.com/aws/aws-sdk-go-v2/service/shield"
"github.com/aws/aws-sdk-go-v2/service/wafregional"
"github.com/aws/aws-sdk-go-v2/service/wafv2"
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints"
)

type defaultAWSClientsProvider struct {
ec2Client *ec2.Client
elbv2Client *elasticloadbalancingv2.Client
acmClient *acm.Client
wafv2Client *wafv2.Client
wafRegionClient *wafregional.Client
shieldClient *shield.Client
rgtClient *resourcegroupstaggingapi.Client
}

func NewDefaultAWSClientsProvider(cfg aws.Config, endpointsResolver *endpoints.Resolver) (*defaultAWSClientsProvider, error) {
ec2CustomEndpoint := endpointsResolver.EndpointFor(ec2.ServiceID)
elbv2CustomEndpoint := endpointsResolver.EndpointFor(elasticloadbalancingv2.ServiceID)
acmCustomEndpoint := endpointsResolver.EndpointFor(acm.ServiceID)
wafv2CustomEndpoint := endpointsResolver.EndpointFor(wafv2.ServiceID)
wafregionalCustomEndpoint := endpointsResolver.EndpointFor(wafregional.ServiceID)
shieldCustomEndpoint := endpointsResolver.EndpointFor(shield.ServiceID)
rgtCustomEndpoint := endpointsResolver.EndpointFor(resourcegroupstaggingapi.ServiceID)

ec2Client := ec2.NewFromConfig(cfg, func(o *ec2.Options) {
if ec2CustomEndpoint != nil {
o.BaseEndpoint = ec2CustomEndpoint
}
})
elbv2Client := elasticloadbalancingv2.NewFromConfig(cfg, func(o *elasticloadbalancingv2.Options) {
if elbv2CustomEndpoint != nil {
o.BaseEndpoint = elbv2CustomEndpoint
}
})
acmClient := acm.NewFromConfig(cfg, func(o *acm.Options) {
if acmCustomEndpoint != nil {
o.BaseEndpoint = acmCustomEndpoint
}
})
wafv2Client := wafv2.NewFromConfig(cfg, func(o *wafv2.Options) {
if wafv2CustomEndpoint != nil {
o.BaseEndpoint = wafv2CustomEndpoint
}
})
wafregionalClient := wafregional.NewFromConfig(cfg, func(o *wafregional.Options) {
o.Region = cfg.Region
o.BaseEndpoint = wafregionalCustomEndpoint
})
sheildClient := shield.NewFromConfig(cfg, func(o *shield.Options) {
o.Region = "us-east-1"
o.BaseEndpoint = shieldCustomEndpoint
})
rgtClient := resourcegroupstaggingapi.NewFromConfig(cfg, func(o *resourcegroupstaggingapi.Options) {
if rgtCustomEndpoint != nil {
o.BaseEndpoint = rgtCustomEndpoint
}
})

return &defaultAWSClientsProvider{
ec2Client: ec2Client,
elbv2Client: elbv2Client,
acmClient: acmClient,
wafv2Client: wafv2Client,
wafRegionClient: wafregionalClient,
shieldClient: sheildClient,
rgtClient: rgtClient,
}, nil
}

// DO NOT REMOVE operationName as parameter, this is on purpose
// to retain the default behavior for OSS controller to use the default client for each aws service
// for our internal controller, we will choose different client based on operationName
func (p *defaultAWSClientsProvider) GetEC2Client(ctx context.Context, operationName string) (*ec2.Client, error) {
return p.ec2Client, nil
}

func (p *defaultAWSClientsProvider) GetELBV2Client(ctx context.Context, operationName string) (*elasticloadbalancingv2.Client, error) {
return p.elbv2Client, nil
}

func (p *defaultAWSClientsProvider) GetACMClient(ctx context.Context, operationName string) (*acm.Client, error) {
return p.acmClient, nil
}

func (p *defaultAWSClientsProvider) GetWAFv2Client(ctx context.Context, operationName string) (*wafv2.Client, error) {
return p.wafv2Client, nil
}

func (p *defaultAWSClientsProvider) GetWAFRegionClient(ctx context.Context, operationName string) (*wafregional.Client, error) {
return p.wafRegionClient, nil
}

func (p *defaultAWSClientsProvider) GetShieldClient(ctx context.Context, operationName string) (*shield.Client, error) {
return p.shieldClient, nil
}

func (p *defaultAWSClientsProvider) GetRGTClient(ctx context.Context, operationName string) (*resourcegroupstaggingapi.Client, error) {
return p.rgtClient, nil
}
12 changes: 12 additions & 0 deletions pkg/aws/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,21 @@ package provider

import (
"context"
"github.com/aws/aws-sdk-go-v2/service/acm"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2"
"github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi"
"github.com/aws/aws-sdk-go-v2/service/shield"
"github.com/aws/aws-sdk-go-v2/service/wafregional"
"github.com/aws/aws-sdk-go-v2/service/wafv2"
)

type AWSClientsProvider interface {
GetEC2Client(ctx context.Context, operationName string) (*ec2.Client, error)
GetELBV2Client(ctx context.Context, operationName string) (*elasticloadbalancingv2.Client, error)
GetACMClient(ctx context.Context, operationName string) (*acm.Client, error)
GetWAFv2Client(ctx context.Context, operationName string) (*wafv2.Client, error)
GetWAFRegionClient(ctx context.Context, operationName string) (*wafregional.Client, error)
GetShieldClient(ctx context.Context, operationName string) (*shield.Client, error)
GetRGTClient(ctx context.Context, operationName string) (*resourcegroupstaggingapi.Client, error)
}
26 changes: 14 additions & 12 deletions pkg/aws/services/acm.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@ package services

import (
"context"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/acm"
"github.com/aws/aws-sdk-go-v2/service/acm/types"
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints"
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/provider"
)

type ACM interface {
Expand All @@ -15,24 +14,23 @@ type ACM interface {
}

// NewACM constructs new ACM implementation.
func NewACM(cfg aws.Config, endpointsResolver *endpoints.Resolver) ACM {
customEndpoint := endpointsResolver.EndpointFor(acm.ServiceID)
func NewACM(awsClientsProvider provider.AWSClientsProvider) ACM {
return &acmClient{
acmClient: acm.NewFromConfig(cfg, func(o *acm.Options) {
if customEndpoint != nil {
o.BaseEndpoint = customEndpoint
}
}),
awsClientsProvider: awsClientsProvider,
}
}

type acmClient struct {
acmClient *acm.Client
awsClientsProvider provider.AWSClientsProvider
}

func (c *acmClient) ListCertificatesAsList(ctx context.Context, input *acm.ListCertificatesInput) ([]types.CertificateSummary, error) {
var result []types.CertificateSummary
paginator := acm.NewListCertificatesPaginator(c.acmClient, input)
client, err := c.awsClientsProvider.GetACMClient(ctx, "ListCertificates")
if err != nil {
return nil, err
}
paginator := acm.NewListCertificatesPaginator(client, input)
for paginator.HasMorePages() {
output, err := paginator.NextPage(ctx)
if err != nil {
Expand All @@ -44,5 +42,9 @@ func (c *acmClient) ListCertificatesAsList(ctx context.Context, input *acm.ListC
}

func (c *acmClient) DescribeCertificateWithContext(ctx context.Context, input *acm.DescribeCertificateInput) (*acm.DescribeCertificateOutput, error) {
return c.acmClient.DescribeCertificate(ctx, input)
client, err := c.awsClientsProvider.GetACMClient(ctx, "DescribeCertificate")
if err != nil {
return nil, err
}
return client.DescribeCertificate(ctx, input)
}
Loading

0 comments on commit 78f7edc

Please sign in to comment.