diff --git a/pkg/algorithm/slices.go b/pkg/algorithm/slices.go new file mode 100644 index 0000000000..a82eae5832 --- /dev/null +++ b/pkg/algorithm/slices.go @@ -0,0 +1,18 @@ +package algorithm + +import "cmp" + +// RemoveSliceDuplicates returns a copy of the slice without duplicate entries. +func RemoveSliceDuplicates[S ~[]E, E cmp.Ordered](s S) []E { + result := make([]E, 0, len(s)) + found := make(map[E]struct{}, len(s)) + + for _, x := range s { + if _, ok := found[x]; !ok { + found[x] = struct{}{} + result = append(result, x) + } + } + + return result +} diff --git a/pkg/algorithm/slices_test.go b/pkg/algorithm/slices_test.go new file mode 100644 index 0000000000..decf9deb6e --- /dev/null +++ b/pkg/algorithm/slices_test.go @@ -0,0 +1,46 @@ +package algorithm + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_RemoveSliceDuplicates(t *testing.T) { + type args struct { + data []string + } + tests := []struct { + name string + args args + want []string + }{ + { + name: "empty", + args: args{ + data: []string{}, + }, + want: []string{}, + }, + { + name: "no duplicate entries", + args: args{ + data: []string{"a", "b", "c", "d"}, + }, + want: []string{"a", "b", "c", "d"}, + }, + { + name: "with duplicates", + args: args{ + data: []string{"a", "b", "a", "c", "b"}, + }, + want: []string{"a", "b", "c"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := RemoveSliceDuplicates(tt.args.data) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/pkg/networking/security_group_resolver.go b/pkg/networking/security_group_resolver.go index 402d1795f0..3b807729d6 100644 --- a/pkg/networking/security_group_resolver.go +++ b/pkg/networking/security_group_resolver.go @@ -7,6 +7,7 @@ import ( awssdk "github.com/aws/aws-sdk-go/aws" ec2sdk "github.com/aws/aws-sdk-go/service/ec2" "github.com/pkg/errors" + "sigs.k8s.io/aws-load-balancer-controller/pkg/algorithm" "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" ) @@ -33,29 +34,38 @@ type defaultSecurityGroupResolver struct { } func (r *defaultSecurityGroupResolver) ResolveViaNameOrID(ctx context.Context, sgNameOrIDs []string) ([]string, error) { - sgIDs, sgNames := r.splitIntoSgNameAndIDs(sgNameOrIDs) var resolvedSGs []*ec2sdk.SecurityGroup + var errMessages []string + + sgIDs, sgNames := r.splitIntoSgNameAndIDs(sgNameOrIDs) + if len(sgIDs) > 0 { sgs, err := r.resolveViaGroupID(ctx, sgIDs) if err != nil { - return nil, err + errMessages = append(errMessages, err.Error()) + } else { + resolvedSGs = append(resolvedSGs, sgs...) } - resolvedSGs = append(resolvedSGs, sgs...) } + if len(sgNames) > 0 { sgs, err := r.resolveViaGroupName(ctx, sgNames) if err != nil { - return nil, err + errMessages = append(errMessages, err.Error()) + } else { + resolvedSGs = append(resolvedSGs, sgs...) } - resolvedSGs = append(resolvedSGs, sgs...) } + + if len(errMessages) > 0 { + return nil, errors.Errorf("couldn't find all security groups: %s", strings.Join(errMessages, ", ")) + } + resolvedSGIDs := make([]string, 0, len(resolvedSGs)) for _, sg := range resolvedSGs { resolvedSGIDs = append(resolvedSGIDs, awssdk.StringValue(sg.GroupId)) } - if len(resolvedSGIDs) != len(sgNameOrIDs) { - return nil, errors.Errorf("couldn't find all securityGroups, nameOrIDs: %v, found: %v", sgNameOrIDs, resolvedSGIDs) - } + return resolvedSGIDs, nil } @@ -63,14 +73,27 @@ func (r *defaultSecurityGroupResolver) resolveViaGroupID(ctx context.Context, sg req := &ec2sdk.DescribeSecurityGroupsInput{ GroupIds: awssdk.StringSlice(sgIDs), } + sgs, err := r.ec2Client.DescribeSecurityGroupsAsList(ctx, req) if err != nil { return nil, err } + + resolvedSGIDs := make([]string, 0, len(sgs)) + for _, sg := range sgs { + resolvedSGIDs = append(resolvedSGIDs, awssdk.StringValue(sg.GroupId)) + } + + if len(sgIDs) != len(resolvedSGIDs) { + return nil, errors.Errorf("requested ids [%s] but found [%s]", strings.Join(sgIDs, ", "), strings.Join(resolvedSGIDs, ", ")) + } + return sgs, nil } func (r *defaultSecurityGroupResolver) resolveViaGroupName(ctx context.Context, sgNames []string) ([]*ec2sdk.SecurityGroup, error) { + sgNames = algorithm.RemoveSliceDuplicates(sgNames) + req := &ec2sdk.DescribeSecurityGroupsInput{ Filters: []*ec2sdk.Filter{ { @@ -83,10 +106,27 @@ func (r *defaultSecurityGroupResolver) resolveViaGroupName(ctx context.Context, }, }, } + sgs, err := r.ec2Client.DescribeSecurityGroupsAsList(ctx, req) if err != nil { return nil, err } + + resolvedSGNames := make([]string, 0, len(sgs)) + for _, sg := range sgs { + for _, tag := range sg.Tags { + if awssdk.StringValue(tag.Key) == "Name" { + resolvedSGNames = append(resolvedSGNames, awssdk.StringValue(tag.Value)) + } + } + } + + resolvedSGNames = algorithm.RemoveSliceDuplicates(resolvedSGNames) + + if len(sgNames) != len(resolvedSGNames) { + return nil, errors.Errorf("requested names [%s] but found [%s]", strings.Join(sgNames, ", "), strings.Join(resolvedSGNames, ", ")) + } + return sgs, nil } diff --git a/pkg/networking/security_group_resolver_test.go b/pkg/networking/security_group_resolver_test.go index ad155b75a1..bd2c663c50 100644 --- a/pkg/networking/security_group_resolver_test.go +++ b/pkg/networking/security_group_resolver_test.go @@ -88,9 +88,15 @@ func Test_defaultSecurityGroupResolver_ResolveViaNameOrID(t *testing.T) { resp: []*ec2sdk.SecurityGroup{ { GroupId: awssdk.String("sg-0912f63b"), + Tags: []*ec2sdk.Tag{ + {Key: awssdk.String("Name"), Value: awssdk.String("sg group one")}, + }, }, { GroupId: awssdk.String("sg-08982de7"), + Tags: []*ec2sdk.Tag{ + {Key: awssdk.String("Name"), Value: awssdk.String("sg group two")}, + }, }, }, }, @@ -101,6 +107,50 @@ func Test_defaultSecurityGroupResolver_ResolveViaNameOrID(t *testing.T) { "sg-0912f63b", }, }, + { + name: "single name multiple ids", + args: args{ + nameOrIDs: []string{ + "sg group one", + }, + describeSGCalls: []describeSecurityGroupsAsListCall{ + { + req: &ec2sdk.DescribeSecurityGroupsInput{ + Filters: []*ec2sdk.Filter{ + { + Name: awssdk.String("tag:Name"), + Values: awssdk.StringSlice([]string{ + "sg group one", + }), + }, + { + Name: awssdk.String("vpc-id"), + Values: awssdk.StringSlice([]string{defaultVPCID}), + }, + }, + }, + resp: []*ec2sdk.SecurityGroup{ + { + GroupId: awssdk.String("sg-id1"), + Tags: []*ec2sdk.Tag{ + {Key: awssdk.String("Name"), Value: awssdk.String("sg group one")}, + }, + }, + { + GroupId: awssdk.String("sg-id2"), + Tags: []*ec2sdk.Tag{ + {Key: awssdk.String("Name"), Value: awssdk.String("sg group one")}, + }, + }, + }, + }, + }, + }, + want: []string{ + "sg-id1", + "sg-id2", + }, + }, { name: "mixed group name and id", args: args{ @@ -127,6 +177,9 @@ func Test_defaultSecurityGroupResolver_ResolveViaNameOrID(t *testing.T) { resp: []*ec2sdk.SecurityGroup{ { GroupId: awssdk.String("sg-0912f63b"), + Tags: []*ec2sdk.Tag{ + {Key: awssdk.String("Name"), Value: awssdk.String("sg group one")}, + }, }, }, }, @@ -151,7 +204,6 @@ func Test_defaultSecurityGroupResolver_ResolveViaNameOrID(t *testing.T) { name: "describe by id returns error", args: args{ nameOrIDs: []string{ - "sg group name", "sg-id", }, describeSGCalls: []describeSecurityGroupsAsListCall{ @@ -163,24 +215,21 @@ func Test_defaultSecurityGroupResolver_ResolveViaNameOrID(t *testing.T) { }, }, }, - wantErr: errors.New("Describe.Error: unable to describe security groups"), + wantErr: errors.New("couldn't find all security groups: Describe.Error: unable to describe security groups"), }, { name: "describe by name returns error", args: args{ nameOrIDs: []string{ "sg group name", - "sg-id", }, describeSGCalls: []describeSecurityGroupsAsListCall{ { req: &ec2sdk.DescribeSecurityGroupsInput{ Filters: []*ec2sdk.Filter{ { - Name: awssdk.String("tag:Name"), - Values: awssdk.StringSlice([]string{ - "sg group name", - }), + Name: awssdk.String("tag:Name"), + Values: awssdk.StringSlice([]string{"sg group name"}), }, { Name: awssdk.String("vpc-id"), @@ -190,27 +239,38 @@ func Test_defaultSecurityGroupResolver_ResolveViaNameOrID(t *testing.T) { }, err: awserr.New("Describe.Error", "unable to describe security groups", nil), }, + }, + }, + wantErr: errors.New("couldn't find all security groups: Describe.Error: unable to describe security groups"), + }, + { + name: "unable to resolve security groups by id", + args: args{ + nameOrIDs: []string{ + "sg-id1", + "sg-id404", + }, + describeSGCalls: []describeSecurityGroupsAsListCall{ { req: &ec2sdk.DescribeSecurityGroupsInput{ - GroupIds: awssdk.StringSlice([]string{"sg-id"}), + GroupIds: awssdk.StringSlice([]string{"sg-id1", "sg-id404"}), }, resp: []*ec2sdk.SecurityGroup{ { - GroupId: awssdk.String("sg-id"), + GroupId: awssdk.String("sg-id1"), }, }, }, }, }, - wantErr: errors.New("Describe.Error: unable to describe security groups"), + wantErr: errors.New("couldn't find all security groups: requested ids [sg-id1, sg-id404] but found [sg-id1]"), }, { - name: "unable to resolve all security groups", + name: "unable to resolve security groups by name", args: args{ nameOrIDs: []string{ "sg group one", - "sg-id1", - "sg-id404", + "sg group two", }, describeSGCalls: []describeSecurityGroupsAsListCall{ { @@ -220,6 +280,7 @@ func Test_defaultSecurityGroupResolver_ResolveViaNameOrID(t *testing.T) { Name: awssdk.String("tag:Name"), Values: awssdk.StringSlice([]string{ "sg group one", + "sg group two", }), }, { @@ -231,22 +292,48 @@ func Test_defaultSecurityGroupResolver_ResolveViaNameOrID(t *testing.T) { resp: []*ec2sdk.SecurityGroup{ { GroupId: awssdk.String("sg-0912f63b"), + Tags: []*ec2sdk.Tag{ + {Key: awssdk.String("Name"), Value: awssdk.String("sg group one")}, + }, }, }, }, + }, + }, + wantErr: errors.New("couldn't find all security groups: requested names [sg group one, sg group two] but found [sg group one]"), + }, + { + name: "unable to resolve all security groups by ids and names", + args: args{ + nameOrIDs: []string{ + "sg-08982de7", + "sg group one", + }, + describeSGCalls: []describeSecurityGroupsAsListCall{ { req: &ec2sdk.DescribeSecurityGroupsInput{ - GroupIds: awssdk.StringSlice([]string{"sg-id1", "sg-id404"}), + GroupIds: awssdk.StringSlice([]string{"sg-08982de7"}), }, - resp: []*ec2sdk.SecurityGroup{ - { - GroupId: awssdk.String("sg-id1"), + resp: []*ec2sdk.SecurityGroup{}, + }, + { + req: &ec2sdk.DescribeSecurityGroupsInput{ + Filters: []*ec2sdk.Filter{ + { + Name: awssdk.String("tag:Name"), + Values: awssdk.StringSlice([]string{"sg group one"}), + }, + { + Name: awssdk.String("vpc-id"), + Values: awssdk.StringSlice([]string{defaultVPCID}), + }, }, }, + resp: []*ec2sdk.SecurityGroup{}, }, }, }, - wantErr: errors.New("couldn't find all securityGroups, nameOrIDs: [sg group one sg-id1 sg-id404], found: [sg-id1 sg-0912f63b]"), + wantErr: errors.New("couldn't find all security groups: requested ids [sg-08982de7] but found [], requested names [sg group one] but found []"), }, }