From 78f7edcf7976182af7691abe296960d34b75ba87 Mon Sep 17 00:00:00 2001 From: oliviassss Date: Wed, 16 Oct 2024 14:40:24 -0700 Subject: [PATCH] refactor aws cloud service --- pkg/aws/cloud.go | 15 +- pkg/aws/default_aws_clients_provider.go | 31 --- .../provider/default_aws_clients_provider.go | 109 ++++++++ pkg/aws/provider/provider.go | 12 + pkg/aws/services/acm.go | 26 +- pkg/aws/services/elbv2.go | 263 ++++++++++++++---- pkg/aws/services/rgt.go | 22 +- pkg/aws/services/shield.go | 41 ++- pkg/aws/services/wafregional.go | 35 ++- pkg/aws/services/wafv2.go | 35 ++- 10 files changed, 433 insertions(+), 156 deletions(-) delete mode 100644 pkg/aws/default_aws_clients_provider.go create mode 100644 pkg/aws/provider/default_aws_clients_provider.go diff --git a/pkg/aws/cloud.go b/pkg/aws/cloud.go index 0a0ee6509..cc679b943 100644 --- a/pkg/aws/cloud.go +++ b/pkg/aws/cloud.go @@ -132,7 +132,7 @@ func NewCloud(cfg CloudConfig, metricsRegisterer prometheus.Registerer, logger l if awsClientsProvider == nil { var err error - awsClientsProvider, err = NewDefaultAWSClientsProvider(awsConfig, endpointsResolver) + awsClientsProvider, err = provider.NewDefaultAWSClientsProvider(awsConfig, endpointsResolver) if err != nil { return nil, errors.Wrap(err, "failed to create aws clients provider") } @@ -147,17 +147,16 @@ func NewCloud(cfg CloudConfig, metricsRegisterer prometheus.Registerer, logger l return &defaultCloud{ cfg: cfg, ec2: ec2Service, - elbv2: services.NewELBV2(awsConfig, endpointsResolver), - acm: services.NewACM(awsConfig, endpointsResolver), - wafv2: services.NewWAFv2(awsConfig, endpointsResolver), - wafRegional: services.NewWAFRegional(awsConfig, endpointsResolver, cfg.Region), - shield: services.NewShield(awsConfig, endpointsResolver), //done - rgt: services.NewRGT(awsConfig, endpointsResolver), + elbv2: services.NewELBV2(awsClientsProvider), + acm: services.NewACM(awsClientsProvider), + wafv2: services.NewWAFv2(awsClientsProvider), + wafRegional: services.NewWAFRegional(awsClientsProvider, cfg.Region), + shield: services.NewShield(awsClientsProvider), + rgt: services.NewRGT(awsClientsProvider), }, nil } func getVpcID(cfg CloudConfig, ec2Service services.EC2, ec2Metadata services.EC2Metadata, logger logr.Logger) (string, error) { - if cfg.VpcID != "" { logger.V(1).Info("vpcid is specified using flag --aws-vpc-id, controller will use the value", "vpc: ", cfg.VpcID) return cfg.VpcID, nil diff --git a/pkg/aws/default_aws_clients_provider.go b/pkg/aws/default_aws_clients_provider.go deleted file mode 100644 index 5acad300c..000000000 --- a/pkg/aws/default_aws_clients_provider.go +++ /dev/null @@ -1,31 +0,0 @@ -package aws - -import ( - "context" - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/ec2" - "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" - "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints" -) - -type defaultAWSClientsProvider struct { - ec2Client *ec2.Client - elbv2Client *elasticloadbalancingv2.Client -} - -func NewDefaultAWSClientsProvider(cfg aws.Config, endpointsResolver *endpoints.Resolver) (*defaultAWSClientsProvider, error) { - customEndpoint := endpointsResolver.EndpointFor(ec2.ServiceID) - ec2Client := ec2.NewFromConfig(cfg, func(o *ec2.Options) { - if customEndpoint != nil { - o.BaseEndpoint = customEndpoint - } - }) - return &defaultAWSClientsProvider{ - ec2Client: ec2Client, - elbv2Client: nil, - }, nil -} - -func (p *defaultAWSClientsProvider) GetEC2Client(ctx context.Context, operationName string) (*ec2.Client, error) { - return p.ec2Client, nil -} diff --git a/pkg/aws/provider/default_aws_clients_provider.go b/pkg/aws/provider/default_aws_clients_provider.go new file mode 100644 index 000000000..25752e2a0 --- /dev/null +++ b/pkg/aws/provider/default_aws_clients_provider.go @@ -0,0 +1,109 @@ +package provider + +import ( + "context" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/acm" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" + "github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi" + "github.com/aws/aws-sdk-go-v2/service/shield" + "github.com/aws/aws-sdk-go-v2/service/wafregional" + "github.com/aws/aws-sdk-go-v2/service/wafv2" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints" +) + +type defaultAWSClientsProvider struct { + ec2Client *ec2.Client + elbv2Client *elasticloadbalancingv2.Client + acmClient *acm.Client + wafv2Client *wafv2.Client + wafRegionClient *wafregional.Client + shieldClient *shield.Client + rgtClient *resourcegroupstaggingapi.Client +} + +func NewDefaultAWSClientsProvider(cfg aws.Config, endpointsResolver *endpoints.Resolver) (*defaultAWSClientsProvider, error) { + ec2CustomEndpoint := endpointsResolver.EndpointFor(ec2.ServiceID) + elbv2CustomEndpoint := endpointsResolver.EndpointFor(elasticloadbalancingv2.ServiceID) + acmCustomEndpoint := endpointsResolver.EndpointFor(acm.ServiceID) + wafv2CustomEndpoint := endpointsResolver.EndpointFor(wafv2.ServiceID) + wafregionalCustomEndpoint := endpointsResolver.EndpointFor(wafregional.ServiceID) + shieldCustomEndpoint := endpointsResolver.EndpointFor(shield.ServiceID) + rgtCustomEndpoint := endpointsResolver.EndpointFor(resourcegroupstaggingapi.ServiceID) + + ec2Client := ec2.NewFromConfig(cfg, func(o *ec2.Options) { + if ec2CustomEndpoint != nil { + o.BaseEndpoint = ec2CustomEndpoint + } + }) + elbv2Client := elasticloadbalancingv2.NewFromConfig(cfg, func(o *elasticloadbalancingv2.Options) { + if elbv2CustomEndpoint != nil { + o.BaseEndpoint = elbv2CustomEndpoint + } + }) + acmClient := acm.NewFromConfig(cfg, func(o *acm.Options) { + if acmCustomEndpoint != nil { + o.BaseEndpoint = acmCustomEndpoint + } + }) + wafv2Client := wafv2.NewFromConfig(cfg, func(o *wafv2.Options) { + if wafv2CustomEndpoint != nil { + o.BaseEndpoint = wafv2CustomEndpoint + } + }) + wafregionalClient := wafregional.NewFromConfig(cfg, func(o *wafregional.Options) { + o.Region = cfg.Region + o.BaseEndpoint = wafregionalCustomEndpoint + }) + sheildClient := shield.NewFromConfig(cfg, func(o *shield.Options) { + o.Region = "us-east-1" + o.BaseEndpoint = shieldCustomEndpoint + }) + rgtClient := resourcegroupstaggingapi.NewFromConfig(cfg, func(o *resourcegroupstaggingapi.Options) { + if rgtCustomEndpoint != nil { + o.BaseEndpoint = rgtCustomEndpoint + } + }) + + return &defaultAWSClientsProvider{ + ec2Client: ec2Client, + elbv2Client: elbv2Client, + acmClient: acmClient, + wafv2Client: wafv2Client, + wafRegionClient: wafregionalClient, + shieldClient: sheildClient, + rgtClient: rgtClient, + }, nil +} + +// DO NOT REMOVE operationName as parameter, this is on purpose +// to retain the default behavior for OSS controller to use the default client for each aws service +// for our internal controller, we will choose different client based on operationName +func (p *defaultAWSClientsProvider) GetEC2Client(ctx context.Context, operationName string) (*ec2.Client, error) { + return p.ec2Client, nil +} + +func (p *defaultAWSClientsProvider) GetELBV2Client(ctx context.Context, operationName string) (*elasticloadbalancingv2.Client, error) { + return p.elbv2Client, nil +} + +func (p *defaultAWSClientsProvider) GetACMClient(ctx context.Context, operationName string) (*acm.Client, error) { + return p.acmClient, nil +} + +func (p *defaultAWSClientsProvider) GetWAFv2Client(ctx context.Context, operationName string) (*wafv2.Client, error) { + return p.wafv2Client, nil +} + +func (p *defaultAWSClientsProvider) GetWAFRegionClient(ctx context.Context, operationName string) (*wafregional.Client, error) { + return p.wafRegionClient, nil +} + +func (p *defaultAWSClientsProvider) GetShieldClient(ctx context.Context, operationName string) (*shield.Client, error) { + return p.shieldClient, nil +} + +func (p *defaultAWSClientsProvider) GetRGTClient(ctx context.Context, operationName string) (*resourcegroupstaggingapi.Client, error) { + return p.rgtClient, nil +} diff --git a/pkg/aws/provider/provider.go b/pkg/aws/provider/provider.go index 33734e736..f2fbfbbca 100644 --- a/pkg/aws/provider/provider.go +++ b/pkg/aws/provider/provider.go @@ -2,9 +2,21 @@ package provider import ( "context" + "github.com/aws/aws-sdk-go-v2/service/acm" "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" + "github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi" + "github.com/aws/aws-sdk-go-v2/service/shield" + "github.com/aws/aws-sdk-go-v2/service/wafregional" + "github.com/aws/aws-sdk-go-v2/service/wafv2" ) type AWSClientsProvider interface { GetEC2Client(ctx context.Context, operationName string) (*ec2.Client, error) + GetELBV2Client(ctx context.Context, operationName string) (*elasticloadbalancingv2.Client, error) + GetACMClient(ctx context.Context, operationName string) (*acm.Client, error) + GetWAFv2Client(ctx context.Context, operationName string) (*wafv2.Client, error) + GetWAFRegionClient(ctx context.Context, operationName string) (*wafregional.Client, error) + GetShieldClient(ctx context.Context, operationName string) (*shield.Client, error) + GetRGTClient(ctx context.Context, operationName string) (*resourcegroupstaggingapi.Client, error) } diff --git a/pkg/aws/services/acm.go b/pkg/aws/services/acm.go index 78ad10fd3..eab8e4319 100644 --- a/pkg/aws/services/acm.go +++ b/pkg/aws/services/acm.go @@ -2,10 +2,9 @@ package services import ( "context" - "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/acm" "github.com/aws/aws-sdk-go-v2/service/acm/types" - "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/provider" ) type ACM interface { @@ -15,24 +14,23 @@ type ACM interface { } // NewACM constructs new ACM implementation. -func NewACM(cfg aws.Config, endpointsResolver *endpoints.Resolver) ACM { - customEndpoint := endpointsResolver.EndpointFor(acm.ServiceID) +func NewACM(awsClientsProvider provider.AWSClientsProvider) ACM { return &acmClient{ - acmClient: acm.NewFromConfig(cfg, func(o *acm.Options) { - if customEndpoint != nil { - o.BaseEndpoint = customEndpoint - } - }), + awsClientsProvider: awsClientsProvider, } } type acmClient struct { - acmClient *acm.Client + awsClientsProvider provider.AWSClientsProvider } func (c *acmClient) ListCertificatesAsList(ctx context.Context, input *acm.ListCertificatesInput) ([]types.CertificateSummary, error) { var result []types.CertificateSummary - paginator := acm.NewListCertificatesPaginator(c.acmClient, input) + client, err := c.awsClientsProvider.GetACMClient(ctx, "ListCertificates") + if err != nil { + return nil, err + } + paginator := acm.NewListCertificatesPaginator(client, input) for paginator.HasMorePages() { output, err := paginator.NextPage(ctx) if err != nil { @@ -44,5 +42,9 @@ func (c *acmClient) ListCertificatesAsList(ctx context.Context, input *acm.ListC } func (c *acmClient) DescribeCertificateWithContext(ctx context.Context, input *acm.DescribeCertificateInput) (*acm.DescribeCertificateOutput, error) { - return c.acmClient.DescribeCertificate(ctx, input) + client, err := c.awsClientsProvider.GetACMClient(ctx, "DescribeCertificate") + if err != nil { + return nil, err + } + return client.DescribeCertificate(ctx, input) } diff --git a/pkg/aws/services/elbv2.go b/pkg/aws/services/elbv2.go index 0ff0e7d18..e3983d0e2 100644 --- a/pkg/aws/services/elbv2.go +++ b/pkg/aws/services/elbv2.go @@ -2,12 +2,11 @@ package services import ( "context" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/provider" "time" - "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" - "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints" ) type ELBV2 interface { @@ -62,154 +61,284 @@ type ELBV2 interface { ModifyListenerAttributesWithContext(ctx context.Context, input *elasticloadbalancingv2.ModifyListenerAttributesInput) (*elasticloadbalancingv2.ModifyListenerAttributesOutput, error) } -func NewELBV2(cfg aws.Config, endpointsResolver *endpoints.Resolver) ELBV2 { - customEndpoint := endpointsResolver.EndpointFor(elasticloadbalancingv2.ServiceID) - client := elasticloadbalancingv2.NewFromConfig(cfg, func(o *elasticloadbalancingv2.Options) { - if customEndpoint != nil { - o.BaseEndpoint = customEndpoint - } - }) - return &elbv2Client{elbv2Client: client} +func NewELBV2(awsClientsProvider provider.AWSClientsProvider) ELBV2 { + return &elbv2Client{ + awsClientsProvider: awsClientsProvider, + } } // default implementation for ELBV2. type elbv2Client struct { - elbv2Client *elasticloadbalancingv2.Client + awsClientsProvider provider.AWSClientsProvider } func (c *elbv2Client) AddListenerCertificatesWithContext(ctx context.Context, input *elasticloadbalancingv2.AddListenerCertificatesInput) (*elasticloadbalancingv2.AddListenerCertificatesOutput, error) { - return c.elbv2Client.AddListenerCertificates(ctx, input) + client, err := c.awsClientsProvider.GetELBV2Client(ctx, "AddListenerCertificates") + if err != nil { + return nil, err + } + return client.AddListenerCertificates(ctx, input) } func (c *elbv2Client) RemoveListenerCertificatesWithContext(ctx context.Context, input *elasticloadbalancingv2.RemoveListenerCertificatesInput) (*elasticloadbalancingv2.RemoveListenerCertificatesOutput, error) { - return c.elbv2Client.RemoveListenerCertificates(ctx, input) + client, err := c.awsClientsProvider.GetELBV2Client(ctx, "RemoveListenerCertificates") + if err != nil { + return nil, err + } + return client.RemoveListenerCertificates(ctx, input) } func (c *elbv2Client) DescribeListenersWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeListenersInput) (*elasticloadbalancingv2.DescribeListenersOutput, error) { - return c.elbv2Client.DescribeListeners(ctx, input) + client, err := c.awsClientsProvider.GetELBV2Client(ctx, "DescribeListeners") + if err != nil { + return nil, err + } + return client.DescribeListeners(ctx, input) } func (c *elbv2Client) DescribeRulesWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeRulesInput) (*elasticloadbalancingv2.DescribeRulesOutput, error) { - return c.elbv2Client.DescribeRules(ctx, input) + client, err := c.awsClientsProvider.GetELBV2Client(ctx, "DescribeRules") + if err != nil { + return nil, err + } + return client.DescribeRules(ctx, input) } func (c *elbv2Client) RegisterTargetsWithContext(ctx context.Context, input *elasticloadbalancingv2.RegisterTargetsInput) (*elasticloadbalancingv2.RegisterTargetsOutput, error) { - return c.elbv2Client.RegisterTargets(ctx, input) + client, err := c.awsClientsProvider.GetELBV2Client(ctx, "RegisterTargets") + if err != nil { + return nil, err + } + return client.RegisterTargets(ctx, input) } func (c *elbv2Client) DeregisterTargetsWithContext(ctx context.Context, input *elasticloadbalancingv2.DeregisterTargetsInput) (*elasticloadbalancingv2.DeregisterTargetsOutput, error) { - return c.elbv2Client.DeregisterTargets(ctx, input) + client, err := c.awsClientsProvider.GetELBV2Client(ctx, "DeregisterTargets") + if err != nil { + return nil, err + } + return client.DeregisterTargets(ctx, input) } func (c *elbv2Client) DescribeTrustStoresWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeTrustStoresInput) (*elasticloadbalancingv2.DescribeTrustStoresOutput, error) { - return c.elbv2Client.DescribeTrustStores(ctx, input) + client, err := c.awsClientsProvider.GetELBV2Client(ctx, "DescribeTrustStores") + if err != nil { + return nil, err + } + return client.DescribeTrustStores(ctx, input) } func (c *elbv2Client) ModifyRuleWithContext(ctx context.Context, input *elasticloadbalancingv2.ModifyRuleInput) (*elasticloadbalancingv2.ModifyRuleOutput, error) { - return c.elbv2Client.ModifyRule(ctx, input) + client, err := c.awsClientsProvider.GetELBV2Client(ctx, "ModifyRule") + if err != nil { + return nil, err + } + return client.ModifyRule(ctx, input) } func (c *elbv2Client) DeleteRuleWithContext(ctx context.Context, input *elasticloadbalancingv2.DeleteRuleInput) (*elasticloadbalancingv2.DeleteRuleOutput, error) { - return c.elbv2Client.DeleteRule(ctx, input) + client, err := c.awsClientsProvider.GetELBV2Client(ctx, "DeleteRule") + if err != nil { + return nil, err + } + return client.DeleteRule(ctx, input) } func (c *elbv2Client) CreateRuleWithContext(ctx context.Context, input *elasticloadbalancingv2.CreateRuleInput) (*elasticloadbalancingv2.CreateRuleOutput, error) { - return c.elbv2Client.CreateRule(ctx, input) + client, err := c.awsClientsProvider.GetELBV2Client(ctx, "CreateRule") + if err != nil { + return nil, err + } + return client.CreateRule(ctx, input) } func (c *elbv2Client) WaitUntilLoadBalancerAvailableWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeLoadBalancersInput) error { - waiter := elasticloadbalancingv2.NewLoadBalancerAvailableWaiter(c.elbv2Client) - err := waiter.Wait(ctx, input, 5*time.Minute) + client, err := c.awsClientsProvider.GetELBV2Client(ctx, "DescribeLoadBalancers") + if err != nil { + return err + } + waiter := elasticloadbalancingv2.NewLoadBalancerAvailableWaiter(client) + err = waiter.Wait(ctx, input, 5*time.Minute) return err } func (c *elbv2Client) DescribeLoadBalancersWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeLoadBalancersInput) (*elasticloadbalancingv2.DescribeLoadBalancersOutput, error) { - return c.elbv2Client.DescribeLoadBalancers(ctx, input) + client, err := c.awsClientsProvider.GetELBV2Client(ctx, "DescribeLoadBalancers") + if err != nil { + return nil, err + } + return client.DescribeLoadBalancers(ctx, input) } func (c *elbv2Client) DescribeTargetHealthWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeTargetHealthInput) (*elasticloadbalancingv2.DescribeTargetHealthOutput, error) { - return c.elbv2Client.DescribeTargetHealth(ctx, input) + client, err := c.awsClientsProvider.GetELBV2Client(ctx, "DescribeTargetHealth") + if err != nil { + return nil, err + } + return client.DescribeTargetHealth(ctx, input) } func (c *elbv2Client) DescribeTargetGroupsWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeTargetGroupsInput) (*elasticloadbalancingv2.DescribeTargetGroupsOutput, error) { - return c.elbv2Client.DescribeTargetGroups(ctx, input) + client, err := c.awsClientsProvider.GetELBV2Client(ctx, "DescribeTargetGroups") + if err != nil { + return nil, err + } + return client.DescribeTargetGroups(ctx, input) } func (c *elbv2Client) DeleteTargetGroupWithContext(ctx context.Context, input *elasticloadbalancingv2.DeleteTargetGroupInput) (*elasticloadbalancingv2.DeleteTargetGroupOutput, error) { - return c.elbv2Client.DeleteTargetGroup(ctx, input) + client, err := c.awsClientsProvider.GetELBV2Client(ctx, "DeleteTargetGroup") + if err != nil { + return nil, err + } + return client.DeleteTargetGroup(ctx, input) } func (c *elbv2Client) ModifyTargetGroupWithContext(ctx context.Context, input *elasticloadbalancingv2.ModifyTargetGroupInput) (*elasticloadbalancingv2.ModifyTargetGroupOutput, error) { - return c.elbv2Client.ModifyTargetGroup(ctx, input) + client, err := c.awsClientsProvider.GetELBV2Client(ctx, "ModifyTargetGroup") + if err != nil { + return nil, err + } + return client.ModifyTargetGroup(ctx, input) } func (c *elbv2Client) CreateTargetGroupWithContext(ctx context.Context, input *elasticloadbalancingv2.CreateTargetGroupInput) (*elasticloadbalancingv2.CreateTargetGroupOutput, error) { - return c.elbv2Client.CreateTargetGroup(ctx, input) + client, err := c.awsClientsProvider.GetELBV2Client(ctx, "CreateTargetGroup") + if err != nil { + return nil, err + } + return client.CreateTargetGroup(ctx, input) } func (c *elbv2Client) DescribeTargetGroupAttributesWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeTargetGroupAttributesInput) (*elasticloadbalancingv2.DescribeTargetGroupAttributesOutput, error) { - return c.elbv2Client.DescribeTargetGroupAttributes(ctx, input) + client, err := c.awsClientsProvider.GetELBV2Client(ctx, "DescribeTargetGroupAttributes") + if err != nil { + return nil, err + } + return client.DescribeTargetGroupAttributes(ctx, input) } func (c *elbv2Client) ModifyTargetGroupAttributesWithContext(ctx context.Context, input *elasticloadbalancingv2.ModifyTargetGroupAttributesInput) (*elasticloadbalancingv2.ModifyTargetGroupAttributesOutput, error) { - return c.elbv2Client.ModifyTargetGroupAttributes(ctx, input) + client, err := c.awsClientsProvider.GetELBV2Client(ctx, "ModifyTargetGroupAttributes") + if err != nil { + return nil, err + } + return client.ModifyTargetGroupAttributes(ctx, input) } func (c *elbv2Client) SetSecurityGroupsWithContext(ctx context.Context, input *elasticloadbalancingv2.SetSecurityGroupsInput) (*elasticloadbalancingv2.SetSecurityGroupsOutput, error) { - return c.elbv2Client.SetSecurityGroups(ctx, input) + client, err := c.awsClientsProvider.GetELBV2Client(ctx, "SetSecurityGroups") + if err != nil { + return nil, err + } + return client.SetSecurityGroups(ctx, input) } func (c *elbv2Client) SetSubnetsWithContext(ctx context.Context, input *elasticloadbalancingv2.SetSubnetsInput) (*elasticloadbalancingv2.SetSubnetsOutput, error) { - return c.elbv2Client.SetSubnets(ctx, input) + client, err := c.awsClientsProvider.GetELBV2Client(ctx, "SetSubnets") + if err != nil { + return nil, err + } + return client.SetSubnets(ctx, input) } func (c *elbv2Client) SetIpAddressTypeWithContext(ctx context.Context, input *elasticloadbalancingv2.SetIpAddressTypeInput) (*elasticloadbalancingv2.SetIpAddressTypeOutput, error) { - return c.elbv2Client.SetIpAddressType(ctx, input) + client, err := c.awsClientsProvider.GetELBV2Client(ctx, "SetIpAddressType") + if err != nil { + return nil, err + } + return client.SetIpAddressType(ctx, input) } func (c *elbv2Client) DeleteLoadBalancerWithContext(ctx context.Context, input *elasticloadbalancingv2.DeleteLoadBalancerInput) (*elasticloadbalancingv2.DeleteLoadBalancerOutput, error) { - return c.elbv2Client.DeleteLoadBalancer(ctx, input) + client, err := c.awsClientsProvider.GetELBV2Client(ctx, "DeleteLoadBalancer") + if err != nil { + return nil, err + } + return client.DeleteLoadBalancer(ctx, input) } func (c *elbv2Client) CreateLoadBalancerWithContext(ctx context.Context, input *elasticloadbalancingv2.CreateLoadBalancerInput) (*elasticloadbalancingv2.CreateLoadBalancerOutput, error) { - return c.elbv2Client.CreateLoadBalancer(ctx, input) + client, err := c.awsClientsProvider.GetELBV2Client(ctx, "CreateLoadBalancer") + if err != nil { + return nil, err + } + return client.CreateLoadBalancer(ctx, input) } func (c *elbv2Client) DescribeLoadBalancerAttributesWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeLoadBalancerAttributesInput) (*elasticloadbalancingv2.DescribeLoadBalancerAttributesOutput, error) { - return c.elbv2Client.DescribeLoadBalancerAttributes(ctx, input) + client, err := c.awsClientsProvider.GetELBV2Client(ctx, "DescribeLoadBalancerAttributes") + if err != nil { + return nil, err + } + return client.DescribeLoadBalancerAttributes(ctx, input) } func (c *elbv2Client) ModifyLoadBalancerAttributesWithContext(ctx context.Context, input *elasticloadbalancingv2.ModifyLoadBalancerAttributesInput) (*elasticloadbalancingv2.ModifyLoadBalancerAttributesOutput, error) { - return c.elbv2Client.ModifyLoadBalancerAttributes(ctx, input) + client, err := c.awsClientsProvider.GetELBV2Client(ctx, "ModifyLoadBalancerAttributes") + if err != nil { + return nil, err + } + return client.ModifyLoadBalancerAttributes(ctx, input) } func (c *elbv2Client) ModifyListenerWithContext(ctx context.Context, input *elasticloadbalancingv2.ModifyListenerInput) (*elasticloadbalancingv2.ModifyListenerOutput, error) { - return c.elbv2Client.ModifyListener(ctx, input) + client, err := c.awsClientsProvider.GetELBV2Client(ctx, "ModifyListener") + if err != nil { + return nil, err + } + return client.ModifyListener(ctx, input) } func (c *elbv2Client) DeleteListenerWithContext(ctx context.Context, input *elasticloadbalancingv2.DeleteListenerInput) (*elasticloadbalancingv2.DeleteListenerOutput, error) { - return c.elbv2Client.DeleteListener(ctx, input) + client, err := c.awsClientsProvider.GetELBV2Client(ctx, "DeleteListener") + if err != nil { + return nil, err + } + return client.DeleteListener(ctx, input) } func (c *elbv2Client) CreateListenerWithContext(ctx context.Context, input *elasticloadbalancingv2.CreateListenerInput) (*elasticloadbalancingv2.CreateListenerOutput, error) { - return c.elbv2Client.CreateListener(ctx, input) + client, err := c.awsClientsProvider.GetELBV2Client(ctx, "CreateListener") + if err != nil { + return nil, err + } + return client.CreateListener(ctx, input) } func (c *elbv2Client) DescribeTagsWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeTagsInput) (*elasticloadbalancingv2.DescribeTagsOutput, error) { - return c.elbv2Client.DescribeTags(ctx, input) + client, err := c.awsClientsProvider.GetELBV2Client(ctx, "DescribeTags") + if err != nil { + return nil, err + } + return client.DescribeTags(ctx, input) } func (c *elbv2Client) AddTagsWithContext(ctx context.Context, input *elasticloadbalancingv2.AddTagsInput) (*elasticloadbalancingv2.AddTagsOutput, error) { - return c.elbv2Client.AddTags(ctx, input) + client, err := c.awsClientsProvider.GetELBV2Client(ctx, "AddTags") + if err != nil { + return nil, err + } + return client.AddTags(ctx, input) } func (c *elbv2Client) RemoveTagsWithContext(ctx context.Context, input *elasticloadbalancingv2.RemoveTagsInput) (*elasticloadbalancingv2.RemoveTagsOutput, error) { - return c.elbv2Client.RemoveTags(ctx, input) + client, err := c.awsClientsProvider.GetELBV2Client(ctx, "RemoveTags") + if err != nil { + return nil, err + } + return client.RemoveTags(ctx, input) } func (c *elbv2Client) DescribeLoadBalancersAsList(ctx context.Context, input *elasticloadbalancingv2.DescribeLoadBalancersInput) ([]types.LoadBalancer, error) { var result []types.LoadBalancer - paginator := elasticloadbalancingv2.NewDescribeLoadBalancersPaginator(c.elbv2Client, input) + var client *elasticloadbalancingv2.Client + var err error + client, err = c.awsClientsProvider.GetELBV2Client(ctx, "DescribeLoadBalancers") + if err != nil { + return nil, err + } + paginator := elasticloadbalancingv2.NewDescribeLoadBalancersPaginator(client, input) for paginator.HasMorePages() { output, err := paginator.NextPage(ctx) if err != nil { @@ -222,7 +351,13 @@ func (c *elbv2Client) DescribeLoadBalancersAsList(ctx context.Context, input *el func (c *elbv2Client) DescribeTargetGroupsAsList(ctx context.Context, input *elasticloadbalancingv2.DescribeTargetGroupsInput) ([]types.TargetGroup, error) { var result []types.TargetGroup - paginator := elasticloadbalancingv2.NewDescribeTargetGroupsPaginator(c.elbv2Client, input) + var client *elasticloadbalancingv2.Client + var err error + client, err = c.awsClientsProvider.GetELBV2Client(ctx, "DescribeTargetGroups") + if err != nil { + return nil, err + } + paginator := elasticloadbalancingv2.NewDescribeTargetGroupsPaginator(client, input) for paginator.HasMorePages() { output, err := paginator.NextPage(ctx) if err != nil { @@ -235,7 +370,13 @@ func (c *elbv2Client) DescribeTargetGroupsAsList(ctx context.Context, input *ela func (c *elbv2Client) DescribeListenersAsList(ctx context.Context, input *elasticloadbalancingv2.DescribeListenersInput) ([]types.Listener, error) { var result []types.Listener - paginator := elasticloadbalancingv2.NewDescribeListenersPaginator(c.elbv2Client, input) + var client *elasticloadbalancingv2.Client + var err error + client, err = c.awsClientsProvider.GetELBV2Client(ctx, "DescribeListeners") + if err != nil { + return nil, err + } + paginator := elasticloadbalancingv2.NewDescribeListenersPaginator(client, input) for paginator.HasMorePages() { output, err := paginator.NextPage(ctx) if err != nil { @@ -248,7 +389,13 @@ func (c *elbv2Client) DescribeListenersAsList(ctx context.Context, input *elasti func (c *elbv2Client) DescribeListenerCertificatesAsList(ctx context.Context, input *elasticloadbalancingv2.DescribeListenerCertificatesInput) ([]types.Certificate, error) { var result []types.Certificate - paginator := elasticloadbalancingv2.NewDescribeListenerCertificatesPaginator(c.elbv2Client, input) + var client *elasticloadbalancingv2.Client + var err error + client, err = c.awsClientsProvider.GetELBV2Client(ctx, "DescribeListenerCertificates") + if err != nil { + return nil, err + } + paginator := elasticloadbalancingv2.NewDescribeListenerCertificatesPaginator(client, input) for paginator.HasMorePages() { output, err := paginator.NextPage(ctx) if err != nil { @@ -261,7 +408,13 @@ func (c *elbv2Client) DescribeListenerCertificatesAsList(ctx context.Context, in func (c *elbv2Client) DescribeRulesAsList(ctx context.Context, input *elasticloadbalancingv2.DescribeRulesInput) ([]types.Rule, error) { var result []types.Rule - paginator := elasticloadbalancingv2.NewDescribeRulesPaginator(c.elbv2Client, input) + var client *elasticloadbalancingv2.Client + var err error + client, err = c.awsClientsProvider.GetELBV2Client(ctx, "DescribeRules") + if err != nil { + return nil, err + } + paginator := elasticloadbalancingv2.NewDescribeRulesPaginator(client, input) for paginator.HasMorePages() { output, err := paginator.NextPage(ctx) if err != nil { @@ -273,9 +426,17 @@ func (c *elbv2Client) DescribeRulesAsList(ctx context.Context, input *elasticloa } func (c *elbv2Client) DescribeListenerAttributesWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeListenerAttributesInput) (*elasticloadbalancingv2.DescribeListenerAttributesOutput, error) { - return c.elbv2Client.DescribeListenerAttributes(ctx, input) + client, err := c.awsClientsProvider.GetELBV2Client(ctx, "DescribeListenerAttributes") + if err != nil { + return nil, err + } + return client.DescribeListenerAttributes(ctx, input) } func (c *elbv2Client) ModifyListenerAttributesWithContext(ctx context.Context, input *elasticloadbalancingv2.ModifyListenerAttributesInput) (*elasticloadbalancingv2.ModifyListenerAttributesOutput, error) { - return c.elbv2Client.ModifyListenerAttributes(ctx, input) + client, err := c.awsClientsProvider.GetELBV2Client(ctx, "ModifyListenerAttributes") + if err != nil { + return nil, err + } + return client.ModifyListenerAttributes(ctx, input) } diff --git a/pkg/aws/services/rgt.go b/pkg/aws/services/rgt.go index 1d39a0bf1..1558e0e4e 100644 --- a/pkg/aws/services/rgt.go +++ b/pkg/aws/services/rgt.go @@ -5,7 +5,7 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi" rgttypes "github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi/types" - "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/provider" ) const ( @@ -18,23 +18,23 @@ type RGT interface { } // NewRGT constructs new RGT implementation. -func NewRGT(cfg aws.Config, endpointsResolver *endpoints.Resolver) RGT { - customEndpoint := endpointsResolver.EndpointFor(resourcegroupstaggingapi.ServiceID) - client := resourcegroupstaggingapi.NewFromConfig(cfg, func(o *resourcegroupstaggingapi.Options) { - if customEndpoint != nil { - o.BaseEndpoint = customEndpoint - } - }) - return &rgtClient{rgtClient: client} +func NewRGT(awsClientsProvider provider.AWSClientsProvider) RGT { + return &rgtClient{ + awsClientsProvider: awsClientsProvider, + } } type rgtClient struct { - rgtClient *resourcegroupstaggingapi.Client + awsClientsProvider provider.AWSClientsProvider } func (c *rgtClient) GetResourcesAsList(ctx context.Context, input *resourcegroupstaggingapi.GetResourcesInput) ([]rgttypes.ResourceTagMapping, error) { + client, err := c.awsClientsProvider.GetRGTClient(ctx, "GetResources") + if err != nil { + return nil, err + } var result []rgttypes.ResourceTagMapping - paginator := resourcegroupstaggingapi.NewGetResourcesPaginator(c.rgtClient, input) + paginator := resourcegroupstaggingapi.NewGetResourcesPaginator(client, input) for paginator.HasMorePages() { output, err := paginator.NextPage(ctx) if err != nil { diff --git a/pkg/aws/services/shield.go b/pkg/aws/services/shield.go index 02def20c4..ad97be240 100644 --- a/pkg/aws/services/shield.go +++ b/pkg/aws/services/shield.go @@ -2,9 +2,8 @@ package services import ( "context" - "github.com/aws/aws-sdk-go-v2/aws" shieldsdk "github.com/aws/aws-sdk-go-v2/service/shield" - "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/provider" ) type Shield interface { @@ -15,33 +14,45 @@ type Shield interface { } // NewShield constructs new Shield implementation. -func NewShield(cfg aws.Config, endpointsResolver *endpoints.Resolver) Shield { - customEndpoint := endpointsResolver.EndpointFor(shieldsdk.ServiceID) - // shield is only available as a global API in us-east-1. - client := shieldsdk.NewFromConfig(cfg, func(o *shieldsdk.Options) { - o.Region = "us-east-1" - o.BaseEndpoint = customEndpoint - }) - return &shieldClient{shieldClient: client} +func NewShield(awsClientsProvider provider.AWSClientsProvider) Shield { + return &shieldClient{ + awsClientsProvider: awsClientsProvider, + } } // default implementation for Shield. type shieldClient struct { - shieldClient *shieldsdk.Client + awsClientsProvider provider.AWSClientsProvider } func (s *shieldClient) GetSubscriptionStateWithContext(ctx context.Context, input *shieldsdk.GetSubscriptionStateInput) (*shieldsdk.GetSubscriptionStateOutput, error) { - return s.shieldClient.GetSubscriptionState(ctx, input) + client, err := s.awsClientsProvider.GetShieldClient(ctx, "GetSubscriptionState") + if err != nil { + return nil, err + } + return client.GetSubscriptionState(ctx, input) } func (s *shieldClient) DescribeProtectionWithContext(ctx context.Context, input *shieldsdk.DescribeProtectionInput) (*shieldsdk.DescribeProtectionOutput, error) { - return s.shieldClient.DescribeProtection(ctx, input) + client, err := s.awsClientsProvider.GetShieldClient(ctx, "DescribeProtection") + if err != nil { + return nil, err + } + return client.DescribeProtection(ctx, input) } func (s *shieldClient) CreateProtectionWithContext(ctx context.Context, input *shieldsdk.CreateProtectionInput) (*shieldsdk.CreateProtectionOutput, error) { - return s.shieldClient.CreateProtection(ctx, input) + client, err := s.awsClientsProvider.GetShieldClient(ctx, "CreateProtection") + if err != nil { + return nil, err + } + return client.CreateProtection(ctx, input) } func (s *shieldClient) DeleteProtectionWithContext(ctx context.Context, input *shieldsdk.DeleteProtectionInput) (*shieldsdk.DeleteProtectionOutput, error) { - return s.shieldClient.DeleteProtection(ctx, input) + client, err := s.awsClientsProvider.GetShieldClient(ctx, "DeleteProtection") + if err != nil { + return nil, err + } + return client.DeleteProtection(ctx, input) } diff --git a/pkg/aws/services/wafregional.go b/pkg/aws/services/wafregional.go index e11a81c37..9ffdf0661 100644 --- a/pkg/aws/services/wafregional.go +++ b/pkg/aws/services/wafregional.go @@ -2,9 +2,8 @@ package services import ( "context" - "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/wafregional" - "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/provider" ) type WAFRegional interface { @@ -16,21 +15,17 @@ type WAFRegional interface { } // NewWAFRegional constructs new WAFRegional implementation. -func NewWAFRegional(cfg aws.Config, endpointsResolver *endpoints.Resolver, region string) WAFRegional { - customEndpoint := endpointsResolver.EndpointFor(wafregional.ServiceID) +func NewWAFRegional(awsClientsProvider provider.AWSClientsProvider, region string) WAFRegional { return &wafRegionalClient{ - wafRegionalClient: wafregional.NewFromConfig(cfg, func(o *wafregional.Options) { - o.Region = region - o.BaseEndpoint = customEndpoint - }), - region: region, + awsClientsProvider: awsClientsProvider, + region: region, } } // default implementation for WAFRegional. type wafRegionalClient struct { - wafRegionalClient *wafregional.Client - region string + awsClientsProvider provider.AWSClientsProvider + region string } func (c *wafRegionalClient) Available() bool { @@ -42,13 +37,25 @@ func (c *wafRegionalClient) Available() bool { } func (c *wafRegionalClient) AssociateWebACLWithContext(ctx context.Context, input *wafregional.AssociateWebACLInput) (*wafregional.AssociateWebACLOutput, error) { - return c.wafRegionalClient.AssociateWebACL(ctx, input) + client, err := c.awsClientsProvider.GetWAFRegionClient(ctx, "AssociateWebACL") + if err != nil { + return nil, err + } + return client.AssociateWebACL(ctx, input) } func (c *wafRegionalClient) DisassociateWebACLWithContext(ctx context.Context, input *wafregional.DisassociateWebACLInput) (*wafregional.DisassociateWebACLOutput, error) { - return c.wafRegionalClient.DisassociateWebACL(ctx, input) + client, err := c.awsClientsProvider.GetWAFRegionClient(ctx, "DisassociateWebACL") + if err != nil { + return nil, err + } + return client.DisassociateWebACL(ctx, input) } func (c *wafRegionalClient) GetWebACLForResourceWithContext(ctx context.Context, input *wafregional.GetWebACLForResourceInput) (*wafregional.GetWebACLForResourceOutput, error) { - return c.wafRegionalClient.GetWebACLForResource(ctx, input) + client, err := c.awsClientsProvider.GetWAFRegionClient(ctx, "GetWebACLForResource") + if err != nil { + return nil, err + } + return client.GetWebACLForResource(ctx, input) } diff --git a/pkg/aws/services/wafv2.go b/pkg/aws/services/wafv2.go index 3547a8678..e18640920 100644 --- a/pkg/aws/services/wafv2.go +++ b/pkg/aws/services/wafv2.go @@ -2,9 +2,8 @@ package services import ( "context" - "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/wafv2" - "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/provider" ) type WAFv2 interface { @@ -14,28 +13,36 @@ type WAFv2 interface { } // NewWAFv2 constructs new WAFv2 implementation. -func NewWAFv2(cfg aws.Config, endpointsResolver *endpoints.Resolver) WAFv2 { - customEndpoint := endpointsResolver.EndpointFor(wafv2.ServiceID) - client := wafv2.NewFromConfig(cfg, func(o *wafv2.Options) { - if customEndpoint != nil { - o.BaseEndpoint = customEndpoint - } - }) - return &wafv2Client{wafv2Client: client} +func NewWAFv2(awsClientsProvider provider.AWSClientsProvider) WAFv2 { + return &wafv2Client{ + awsClientsProvider: awsClientsProvider, + } } type wafv2Client struct { - wafv2Client *wafv2.Client + awsClientsProvider provider.AWSClientsProvider } func (c *wafv2Client) AssociateWebACLWithContext(ctx context.Context, req *wafv2.AssociateWebACLInput) (*wafv2.AssociateWebACLOutput, error) { - return c.wafv2Client.AssociateWebACL(ctx, req) + client, err := c.awsClientsProvider.GetWAFv2Client(ctx, "AssociateWebACL") + if err != nil { + return nil, err + } + return client.AssociateWebACL(ctx, req) } func (c *wafv2Client) DisassociateWebACLWithContext(ctx context.Context, req *wafv2.DisassociateWebACLInput) (*wafv2.DisassociateWebACLOutput, error) { - return c.wafv2Client.DisassociateWebACL(ctx, req) + client, err := c.awsClientsProvider.GetWAFv2Client(ctx, "DisassociateWebACL") + if err != nil { + return nil, err + } + return client.DisassociateWebACL(ctx, req) } func (c *wafv2Client) GetWebACLForResourceWithContext(ctx context.Context, req *wafv2.GetWebACLForResourceInput) (*wafv2.GetWebACLForResourceOutput, error) { - return c.wafv2Client.GetWebACLForResource(ctx, req) + client, err := c.awsClientsProvider.GetWAFv2Client(ctx, "GetWebACLForResource") + if err != nil { + return nil, err + } + return client.GetWebACLForResource(ctx, req) }