From 6ee80c3696df786143570bd88b2330e92bcc9aac Mon Sep 17 00:00:00 2001 From: Weiwei Li Date: Thu, 31 Oct 2024 11:06:20 -0700 Subject: [PATCH] Fix: order must match specified subnets --- pkg/service/model_build_load_balancer.go | 28 +++++-- pkg/service/model_build_load_balancer_test.go | 79 ++++++++++++++++++- 2 files changed, 101 insertions(+), 6 deletions(-) diff --git a/pkg/service/model_build_load_balancer.go b/pkg/service/model_build_load_balancer.go index 6657a4188..4c55b8656 100644 --- a/pkg/service/model_build_load_balancer.go +++ b/pkg/service/model_build_load_balancer.go @@ -300,11 +300,17 @@ func (t *defaultModelBuildTask) buildLoadBalancerTags(ctx context.Context) (map[ func (t *defaultModelBuildTask) buildLoadBalancerSubnetMappings(_ context.Context, ipAddressType elbv2model.IPAddressType, scheme elbv2model.LoadBalancerScheme, ec2Subnets []ec2types.Subnet) ([]elbv2model.SubnetMapping, error) { var eipAllocation []string eipConfigured := t.annotationParser.ParseStringSliceAnnotation(annotations.SvcLBSuffixEIPAllocations, &eipAllocation, t.service.Annotations) + var subnetsId []string + subnetIDConfigured := t.annotationParser.ParseStringSliceAnnotation(annotations.SvcLBSuffixSubnets, &subnetsId, t.service.Annotations) + reorderedSubnets := append(ec2Subnets[:0:0], ec2Subnets...) + if subnetIDConfigured { + reorderedSubnets = sortSubnetsBySubnetIdConfigured(subnetsId, ec2Subnets) + } if eipConfigured { if scheme != elbv2model.LoadBalancerSchemeInternetFacing { return nil, errors.Errorf("EIP allocations can only be set for internet facing load balancers") } - if len(eipAllocation) != len(ec2Subnets) { + if len(eipAllocation) != len(reorderedSubnets) { return nil, errors.Errorf("count of EIP allocations (%d) and subnets (%d) must match", len(eipAllocation), len(ec2Subnets)) } } @@ -317,7 +323,7 @@ func (t *defaultModelBuildTask) buildLoadBalancerSubnetMappings(_ context.Contex return nil, errors.Errorf("private IPv4 addresses can only be set for internal load balancers") } // TODO: consider relax this requirement as ELBv2 API don't require every subnet to have IPv4 address specified. - if len(rawIPv4Addresses) != len(ec2Subnets) { + if len(rawIPv4Addresses) != len(reorderedSubnets) { return nil, errors.Errorf("count of private IPv4 addresses (%d) and subnets (%d) must match", len(rawIPv4Addresses), len(ec2Subnets)) } for _, rawIPv4Address := range rawIPv4Addresses { @@ -340,7 +346,7 @@ func (t *defaultModelBuildTask) buildLoadBalancerSubnetMappings(_ context.Contex return nil, errors.Errorf("IPv6 addresses can only be set for dualstack load balancers") } // TODO: consider relax this requirement as ELBv2 API don't require every subnet to have IPv6 address specified. - if len(rawIPv6Addresses) != len(ec2Subnets) { + if len(rawIPv6Addresses) != len(reorderedSubnets) { return nil, errors.Errorf("count of IPv6 addresses (%d) and subnets (%d) must match", len(rawIPv6Addresses), len(ec2Subnets)) } for _, rawIPv6Address := range rawIPv6Addresses { @@ -355,8 +361,8 @@ func (t *defaultModelBuildTask) buildLoadBalancerSubnetMappings(_ context.Contex } } - subnetMappings := make([]elbv2model.SubnetMapping, 0, len(ec2Subnets)) - for idx, subnet := range ec2Subnets { + subnetMappings := make([]elbv2model.SubnetMapping, 0, len(reorderedSubnets)) + for idx, subnet := range reorderedSubnets { mapping := elbv2model.SubnetMapping{ SubnetID: awssdk.ToString(subnet.SubnetId), } @@ -534,3 +540,15 @@ func (t *defaultModelBuildTask) buildLoadBalancerName(_ context.Context, scheme sanitizedName := invalidLoadBalancerNamePattern.ReplaceAllString(t.service.Name, "") return fmt.Sprintf("k8s-%.8s-%.8s-%.10s", sanitizedNamespace, sanitizedName, uuid), nil } + +func sortSubnetsBySubnetIdConfigured(subnetId []string, ec2Subnets []ec2types.Subnet) []ec2types.Subnet { + subnetIndex := make(map[string]int) + for index, id := range subnetId { + subnetIndex[id] = index + } + sortedSubnets := append(ec2Subnets[:0:0], ec2Subnets...) + sort.Slice(sortedSubnets, func(i, j int) bool { + return subnetIndex[*sortedSubnets[i].SubnetId] < subnetIndex[*sortedSubnets[j].SubnetId] + }) + return sortedSubnets +} diff --git a/pkg/service/model_build_load_balancer_test.go b/pkg/service/model_build_load_balancer_test.go index 65551cce1..f7ce792e2 100644 --- a/pkg/service/model_build_load_balancer_test.go +++ b/pkg/service/model_build_load_balancer_test.go @@ -3,9 +3,11 @@ package service import ( "context" "errors" + "testing" + + awssdk "github.com/aws/aws-sdk-go-v2/aws" ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" elbv2types "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" - "testing" "k8s.io/apimachinery/pkg/util/sets" "sigs.k8s.io/aws-load-balancer-controller/pkg/model/core" @@ -1445,3 +1447,78 @@ func Test_defaultModelBuildTask_buildLoadBalancerName(t *testing.T) { }) } } + +func Test_sortSubnetsBySubnetIdConfigured(t *testing.T) { + type args struct { + subnetId []string + ec2Subnets []ec2types.Subnet + } + tests := []struct { + name string + args args + wantSubnets []ec2types.Subnet + }{ + { + name: "Sorting", + args: args{ + subnetId: []string{"subnet-a", "subnet-b", "subnet-c"}, + ec2Subnets: []ec2types.Subnet{ + { + SubnetId: awssdk.String("subnet-b"), + }, + { + SubnetId: awssdk.String("subnet-a"), + }, + { + SubnetId: awssdk.String("subnet-c"), + }, + }, + }, + wantSubnets: []ec2types.Subnet{ + { + SubnetId: awssdk.String("subnet-a"), + }, + { + SubnetId: awssdk.String("subnet-b"), + }, + { + SubnetId: awssdk.String("subnet-c"), + }, + }, + }, + { + name: "sorted subnets", + args: args{ + subnetId: []string{"subnet-a", "subnet-b", "subnet-c"}, + ec2Subnets: []ec2types.Subnet{ + { + SubnetId: awssdk.String("subnet-a"), + }, + { + SubnetId: awssdk.String("subnet-b"), + }, + { + SubnetId: awssdk.String("subnet-c"), + }, + }, + }, + wantSubnets: []ec2types.Subnet{ + { + SubnetId: awssdk.String("subnet-a"), + }, + { + SubnetId: awssdk.String("subnet-b"), + }, + { + SubnetId: awssdk.String("subnet-c"), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := sortSubnetsBySubnetIdConfigured(tt.args.subnetId, tt.args.ec2Subnets) + assert.Equal(t, tt.wantSubnets, result) + }) + } +}