Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Annotation Check for Load-Balancer Scheme in SG Source Ranges #3781

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pkg/service/model_build_load_balancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func (t *defaultModelBuildTask) buildLoadBalancerSpec(ctx context.Context, schem
if err != nil {
return elbv2model.LoadBalancerSpec{}, err
}
securityGroups, err := t.buildLoadBalancerSecurityGroups(ctx, existingLB, ipAddressType)
securityGroups, err := t.buildLoadBalancerSecurityGroups(ctx, existingLB, scheme, ipAddressType)
if err != nil {
return elbv2model.LoadBalancerSpec{}, err
}
Expand Down Expand Up @@ -101,7 +101,7 @@ func (t *defaultModelBuildTask) buildLoadBalancerSpec(ctx context.Context, schem
}

func (t *defaultModelBuildTask) buildLoadBalancerSecurityGroups(ctx context.Context, existingLB *elbv2deploy.LoadBalancerWithTags,
ipAddressType elbv2model.IPAddressType) ([]core.StringToken, error) {
scheme elbv2model.LoadBalancerScheme, ipAddressType elbv2model.IPAddressType) ([]core.StringToken, error) {
if existingLB != nil && len(existingLB.LoadBalancer.SecurityGroups) == 0 {
return nil, nil
}
Expand All @@ -115,7 +115,7 @@ func (t *defaultModelBuildTask) buildLoadBalancerSecurityGroups(ctx context.Cont
var lbSGTokens []core.StringToken
t.annotationParser.ParseStringSliceAnnotation(annotations.SvcLBSuffixLoadBalancerSecurityGroups, &sgNameOrIDs, t.service.Annotations)
if len(sgNameOrIDs) == 0 {
managedSG, err := t.buildManagedSecurityGroup(ctx, ipAddressType)
managedSG, err := t.buildManagedSecurityGroup(ctx, ipAddressType, scheme)
if err != nil {
return nil, err
}
Expand Down
32 changes: 22 additions & 10 deletions pkg/service/model_build_managed_sg.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,29 @@ import (
"sigs.k8s.io/aws-load-balancer-controller/pkg/annotations"
ec2model "sigs.k8s.io/aws-load-balancer-controller/pkg/model/ec2"
elbv2model "sigs.k8s.io/aws-load-balancer-controller/pkg/model/elbv2"
"sigs.k8s.io/aws-load-balancer-controller/pkg/networking"
)

const (
resourceIDManagedSecurityGroup = "ManagedLBSecurityGroup"
)

func (t *defaultModelBuildTask) buildManagedSecurityGroup(ctx context.Context, ipAddressType elbv2model.IPAddressType) (*ec2model.SecurityGroup, error) {
sgSpec, err := t.buildManagedSecurityGroupSpec(ctx, ipAddressType)
func (t *defaultModelBuildTask) buildManagedSecurityGroup(ctx context.Context, ipAddressType elbv2model.IPAddressType, scheme elbv2model.LoadBalancerScheme) (*ec2model.SecurityGroup, error) {
sgSpec, err := t.buildManagedSecurityGroupSpec(ctx, ipAddressType, scheme)
if err != nil {
return nil, err
}
sg := ec2model.NewSecurityGroup(t.stack, resourceIDManagedSecurityGroup, sgSpec)
return sg, nil
}

func (t *defaultModelBuildTask) buildManagedSecurityGroupSpec(ctx context.Context, ipAddressType elbv2model.IPAddressType) (ec2model.SecurityGroupSpec, error) {
func (t *defaultModelBuildTask) buildManagedSecurityGroupSpec(ctx context.Context, ipAddressType elbv2model.IPAddressType, scheme elbv2model.LoadBalancerScheme) (ec2model.SecurityGroupSpec, error) {
name := t.buildManagedSecurityGroupName(ctx)
tags, err := t.buildManagedSecurityGroupTags(ctx)
if err != nil {
return ec2model.SecurityGroupSpec{}, err
}
ingressPermissions, err := t.buildManagedSecurityGroupIngressPermissions(ctx, ipAddressType)
ingressPermissions, err := t.buildManagedSecurityGroupIngressPermissions(ctx, ipAddressType, scheme)
if err != nil {
return ec2model.SecurityGroupSpec{}, err
}
Expand All @@ -62,11 +63,11 @@ func (t *defaultModelBuildTask) buildManagedSecurityGroupName(_ context.Context)
return fmt.Sprintf("k8s-%.8s-%.8s-%.10s", sanitizedNamespace, sanitizedName, uuid)
}

func (t *defaultModelBuildTask) buildManagedSecurityGroupIngressPermissions(ctx context.Context, ipAddressType elbv2model.IPAddressType) ([]ec2model.IPPermission, error) {
func (t *defaultModelBuildTask) buildManagedSecurityGroupIngressPermissions(ctx context.Context, ipAddressType elbv2model.IPAddressType, scheme elbv2model.LoadBalancerScheme) ([]ec2model.IPPermission, error) {
var permissions []ec2model.IPPermission
var prefixListIDs []string
prefixListsConfigured := t.annotationParser.ParseStringSliceAnnotation(annotations.SvcLBSuffixSecurityGroupPrefixLists, &prefixListIDs, t.service.Annotations)
cidrs, err := t.buildCIDRsFromSourceRanges(ctx, ipAddressType, prefixListsConfigured)
cidrs, err := t.buildCIDRsFromSourceRanges(ctx, ipAddressType, prefixListsConfigured, scheme)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -115,7 +116,7 @@ func (t *defaultModelBuildTask) buildManagedSecurityGroupIngressPermissions(ctx
return permissions, nil
}

func (t *defaultModelBuildTask) buildCIDRsFromSourceRanges(_ context.Context, ipAddressType elbv2model.IPAddressType, prefixListsConfigured bool) ([]string, error) {
func (t *defaultModelBuildTask) buildCIDRsFromSourceRanges(ctx context.Context, ipAddressType elbv2model.IPAddressType, prefixListsConfigured bool, scheme elbv2model.LoadBalancerScheme) ([]string, error) {
var cidrs []string
for _, cidr := range t.service.Spec.LoadBalancerSourceRanges {
cidrs = append(cidrs, cidr)
Expand All @@ -132,9 +133,20 @@ func (t *defaultModelBuildTask) buildCIDRsFromSourceRanges(_ context.Context, ip
if prefixListsConfigured {
return cidrs, nil
}
cidrs = append(cidrs, "0.0.0.0/0")
if ipAddressType == elbv2model.IPAddressTypeDualStack {
cidrs = append(cidrs, "::/0")
if scheme == elbv2model.LoadBalancerSchemeInternal {
vpcInfo, err := t.vpcInfoProvider.FetchVPCInfo(ctx, t.vpcID, networking.FetchVPCInfoWithoutCache())
if err != nil {
return cidrs, err
}
cidrs = append(cidrs, vpcInfo.AssociatedIPv4CIDRs()...)
if ipAddressType == elbv2model.IPAddressTypeDualStack {
cidrs = append(cidrs, vpcInfo.AssociatedIPv6CIDRs()...)
}
} else {
cidrs = append(cidrs, "0.0.0.0/0")
if ipAddressType == elbv2model.IPAddressTypeDualStack {
cidrs = append(cidrs, "::/0")
}
}
}
return cidrs, nil
Expand Down
85 changes: 77 additions & 8 deletions pkg/service/model_build_managed_sg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,32 @@ import (
"testing"

"github.com/aws/aws-sdk-go/aws"
ec2sdk "github.com/aws/aws-sdk-go/service/ec2"
"github.com/golang/mock/gomock"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"sigs.k8s.io/aws-load-balancer-controller/pkg/annotations"
ec2model "sigs.k8s.io/aws-load-balancer-controller/pkg/model/ec2"
"sigs.k8s.io/aws-load-balancer-controller/pkg/model/elbv2"
elbv2model "sigs.k8s.io/aws-load-balancer-controller/pkg/model/elbv2"
"sigs.k8s.io/aws-load-balancer-controller/pkg/networking"
)

func Test_buildCIDRsFromSourceRanges_buildCIDRsFromSourceRanges(t *testing.T) {
type fields struct {
svc *corev1.Service
ipAddressType elbv2model.IPAddressType
prefixListsConfigured bool
scheme elbv2model.LoadBalancerScheme
}
tests := []struct {
name string
fields fields
want []string
wantErr bool
name string
fields fields
setupMock func(MockVPCInfoProvider *networking.MockVPCInfoProvider)
want []string
wantErr bool
}{
{
name: "default IPv4",
Expand All @@ -36,7 +43,8 @@ func Test_buildCIDRsFromSourceRanges_buildCIDRsFromSourceRanges(t *testing.T) {
ipAddressType: elbv2model.IPAddressTypeIPV4,
prefixListsConfigured: false,
},
wantErr: false,
setupMock: func(MockVPCInfoProvider *networking.MockVPCInfoProvider) {},
wantErr: false,
want: []string{
"0.0.0.0/0",
},
Expand All @@ -54,7 +62,8 @@ func Test_buildCIDRsFromSourceRanges_buildCIDRsFromSourceRanges(t *testing.T) {
ipAddressType: elbv2model.IPAddressTypeDualStack,
prefixListsConfigured: false,
},
wantErr: false,
setupMock: func(MockVPCInfoProvider *networking.MockVPCInfoProvider) {},
wantErr: false,
want: []string{
"0.0.0.0/0",
"::/0",
Expand All @@ -73,18 +82,77 @@ func Test_buildCIDRsFromSourceRanges_buildCIDRsFromSourceRanges(t *testing.T) {
ipAddressType: elbv2model.IPAddressTypeDualStack,
prefixListsConfigured: true,
},
setupMock: func(MockVPCInfoProvider *networking.MockVPCInfoProvider) {},
wantErr: false,
want: nil,
},
{
name: "fetch vpc info for internal scheme",
fields: fields{
svc: &corev1.Service{},
ipAddressType: elbv2model.IPAddressTypeDualStack,
prefixListsConfigured: false,
scheme: elbv2.LoadBalancerSchemeInternal,
},
setupMock: func(MockVPCInfoProvider *networking.MockVPCInfoProvider) {
vpcInfo := networking.VPCInfo{
CidrBlockAssociationSet: []*ec2sdk.VpcCidrBlockAssociation{
{
CidrBlock: aws.String("192.168.0.0/16"),
CidrBlockState: &ec2sdk.VpcCidrBlockState{State: aws.String(ec2sdk.VpcCidrBlockStateCodeAssociated)},
},
},
Ipv6CidrBlockAssociationSet: []*ec2sdk.VpcIpv6CidrBlockAssociation{
{
Ipv6CidrBlock: aws.String("fd00::/8"),
Ipv6CidrBlockState: &ec2sdk.VpcCidrBlockState{State: aws.String(ec2sdk.VpcCidrBlockStateCodeAssociated)},
},
},
}
MockVPCInfoProvider.EXPECT().FetchVPCInfo(gomock.Any(), "vpc-1234", gomock.Any()).Return(vpcInfo, nil)
},
wantErr: false,
want: []string{
"192.168.0.0/16",
"fd00::/8",
},
},
{
name: "error fetching vpc info",
fields: fields{
svc: &corev1.Service{
ObjectMeta: metav1.ObjectMeta{
Annotations: map[string]string{
"service.beta.kubernetes.io/aws-load-balancer-scheme": "internal",
},
},
},
ipAddressType: elbv2model.IPAddressTypeDualStack,
prefixListsConfigured: false,
scheme: elbv2.LoadBalancerSchemeInternal,
},
setupMock: func(MockVPCInfoProvider *networking.MockVPCInfoProvider) {
MockVPCInfoProvider.EXPECT().FetchVPCInfo(gomock.Any(), "vpc-1234", gomock.Any()).Return(networking.VPCInfo{}, errors.New("failed to fetch vpcInfo"))
},
wantErr: true,
want: nil,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t1 *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockVPCInfoProvider := networking.NewMockVPCInfoProvider(ctrl)
tt.setupMock(mockVPCInfoProvider)
annotationParser := annotations.NewSuffixAnnotationParser("service.beta.kubernetes.io")
task := &defaultModelBuildTask{
annotationParser: annotationParser,
service: tt.fields.svc,
vpcID: "vpc-1234",
vpcInfoProvider: mockVPCInfoProvider,
}
got, err := task.buildCIDRsFromSourceRanges(context.Background(), tt.fields.ipAddressType, tt.fields.prefixListsConfigured)
got, err := task.buildCIDRsFromSourceRanges(context.Background(), tt.fields.ipAddressType, tt.fields.prefixListsConfigured, tt.fields.scheme)
if tt.wantErr {
assert.Error(t, err)
} else {
Expand All @@ -99,6 +167,7 @@ func Test_buildCIDRsFromSourceRanges_buildManagedSecurityGroupIngressPermissions
type fields struct {
svc *corev1.Service
ipAddressType elbv2model.IPAddressType
scheme elbv2model.LoadBalancerScheme
}
tests := []struct {
name string
Expand Down Expand Up @@ -278,7 +347,7 @@ func Test_buildCIDRsFromSourceRanges_buildManagedSecurityGroupIngressPermissions
annotationParser: annotationParser,
service: tt.fields.svc,
}
got, err := task.buildManagedSecurityGroupIngressPermissions(context.Background(), tt.fields.ipAddressType)
got, err := task.buildManagedSecurityGroupIngressPermissions(context.Background(), tt.fields.ipAddressType, tt.fields.scheme)
if tt.wantErr {
assert.Error(t, err)
} else {
Expand Down
Loading
Loading