Skip to content

Commit

Permalink
wip to pass scheme from upper layer
Browse files Browse the repository at this point in the history
  • Loading branch information
yash97 committed Aug 18, 2024
1 parent df30344 commit 275be82
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 36 deletions.
6 changes: 3 additions & 3 deletions pkg/service/model_build_load_balancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func (t *defaultModelBuildTask) buildLoadBalancerSpec(ctx context.Context, schem
if err != nil {
return elbv2model.LoadBalancerSpec{}, err
}
securityGroups, err := t.buildLoadBalancerSecurityGroups(ctx, existingLB, ipAddressType)
securityGroups, err := t.buildLoadBalancerSecurityGroups(ctx, existingLB, scheme, ipAddressType)
if err != nil {
return elbv2model.LoadBalancerSpec{}, err
}
Expand Down Expand Up @@ -101,7 +101,7 @@ func (t *defaultModelBuildTask) buildLoadBalancerSpec(ctx context.Context, schem
}

func (t *defaultModelBuildTask) buildLoadBalancerSecurityGroups(ctx context.Context, existingLB *elbv2deploy.LoadBalancerWithTags,
ipAddressType elbv2model.IPAddressType) ([]core.StringToken, error) {
scheme elbv2model.LoadBalancerScheme, ipAddressType elbv2model.IPAddressType) ([]core.StringToken, error) {
if existingLB != nil && len(existingLB.LoadBalancer.SecurityGroups) == 0 {
return nil, nil
}
Expand All @@ -115,7 +115,7 @@ func (t *defaultModelBuildTask) buildLoadBalancerSecurityGroups(ctx context.Cont
var lbSGTokens []core.StringToken
t.annotationParser.ParseStringSliceAnnotation(annotations.SvcLBSuffixLoadBalancerSecurityGroups, &sgNameOrIDs, t.service.Annotations)
if len(sgNameOrIDs) == 0 {
managedSG, err := t.buildManagedSecurityGroup(ctx, ipAddressType)
managedSG, err := t.buildManagedSecurityGroup(ctx, ipAddressType, scheme)
if err != nil {
return nil, err
}
Expand Down
18 changes: 8 additions & 10 deletions pkg/service/model_build_managed_sg.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,22 @@ const (
resourceIDManagedSecurityGroup = "ManagedLBSecurityGroup"
)

func (t *defaultModelBuildTask) buildManagedSecurityGroup(ctx context.Context, ipAddressType elbv2model.IPAddressType) (*ec2model.SecurityGroup, error) {
sgSpec, err := t.buildManagedSecurityGroupSpec(ctx, ipAddressType)
func (t *defaultModelBuildTask) buildManagedSecurityGroup(ctx context.Context, ipAddressType elbv2model.IPAddressType, scheme elbv2model.LoadBalancerScheme) (*ec2model.SecurityGroup, error) {
sgSpec, err := t.buildManagedSecurityGroupSpec(ctx, ipAddressType, scheme)
if err != nil {
return nil, err
}
sg := ec2model.NewSecurityGroup(t.stack, resourceIDManagedSecurityGroup, sgSpec)
return sg, nil
}

func (t *defaultModelBuildTask) buildManagedSecurityGroupSpec(ctx context.Context, ipAddressType elbv2model.IPAddressType) (ec2model.SecurityGroupSpec, error) {
func (t *defaultModelBuildTask) buildManagedSecurityGroupSpec(ctx context.Context, ipAddressType elbv2model.IPAddressType, scheme elbv2model.LoadBalancerScheme) (ec2model.SecurityGroupSpec, error) {
name := t.buildManagedSecurityGroupName(ctx)
tags, err := t.buildManagedSecurityGroupTags(ctx)
if err != nil {
return ec2model.SecurityGroupSpec{}, err
}
ingressPermissions, err := t.buildManagedSecurityGroupIngressPermissions(ctx, ipAddressType)
ingressPermissions, err := t.buildManagedSecurityGroupIngressPermissions(ctx, ipAddressType, scheme)
if err != nil {
return ec2model.SecurityGroupSpec{}, err
}
Expand All @@ -63,11 +63,11 @@ func (t *defaultModelBuildTask) buildManagedSecurityGroupName(_ context.Context)
return fmt.Sprintf("k8s-%.8s-%.8s-%.10s", sanitizedNamespace, sanitizedName, uuid)
}

func (t *defaultModelBuildTask) buildManagedSecurityGroupIngressPermissions(ctx context.Context, ipAddressType elbv2model.IPAddressType) ([]ec2model.IPPermission, error) {
func (t *defaultModelBuildTask) buildManagedSecurityGroupIngressPermissions(ctx context.Context, ipAddressType elbv2model.IPAddressType, scheme elbv2model.LoadBalancerScheme) ([]ec2model.IPPermission, error) {
var permissions []ec2model.IPPermission
var prefixListIDs []string
prefixListsConfigured := t.annotationParser.ParseStringSliceAnnotation(annotations.SvcLBSuffixSecurityGroupPrefixLists, &prefixListIDs, t.service.Annotations)
cidrs, err := t.buildCIDRsFromSourceRanges(ctx, ipAddressType, prefixListsConfigured)
cidrs, err := t.buildCIDRsFromSourceRanges(ctx, ipAddressType, prefixListsConfigured, scheme)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -116,7 +116,7 @@ func (t *defaultModelBuildTask) buildManagedSecurityGroupIngressPermissions(ctx
return permissions, nil
}

func (t *defaultModelBuildTask) buildCIDRsFromSourceRanges(ctx context.Context, ipAddressType elbv2model.IPAddressType, prefixListsConfigured bool) ([]string, error) {
func (t *defaultModelBuildTask) buildCIDRsFromSourceRanges(ctx context.Context, ipAddressType elbv2model.IPAddressType, prefixListsConfigured bool, scheme elbv2model.LoadBalancerScheme) ([]string, error) {
var cidrs []string
for _, cidr := range t.service.Spec.LoadBalancerSourceRanges {
cidrs = append(cidrs, cidr)
Expand All @@ -133,9 +133,7 @@ func (t *defaultModelBuildTask) buildCIDRsFromSourceRanges(ctx context.Context,
if prefixListsConfigured {
return cidrs, nil
}
var scheme string
ok := t.annotationParser.ParseStringAnnotation(annotations.SvcLBSuffixScheme, &scheme, t.service.Annotations)
if ok && (scheme == string(elbv2model.LoadBalancerSchemeInternal) || scheme == "") {
if scheme == elbv2model.LoadBalancerSchemeInternal {
vpcInfo, err := t.vpcInfoProvider.FetchVPCInfo(ctx, t.vpcID, networking.FetchVPCInfoWithoutCache())
if err != nil {
return cidrs, err
Expand Down
45 changes: 22 additions & 23 deletions pkg/service/model_build_managed_sg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"sigs.k8s.io/aws-load-balancer-controller/pkg/annotations"
ec2model "sigs.k8s.io/aws-load-balancer-controller/pkg/model/ec2"
"sigs.k8s.io/aws-load-balancer-controller/pkg/model/elbv2"
elbv2model "sigs.k8s.io/aws-load-balancer-controller/pkg/model/elbv2"
"sigs.k8s.io/aws-load-balancer-controller/pkg/networking"
)
Expand All @@ -22,6 +23,7 @@ func Test_buildCIDRsFromSourceRanges_buildCIDRsFromSourceRanges(t *testing.T) {
svc *corev1.Service
ipAddressType elbv2model.IPAddressType
prefixListsConfigured bool
scheme elbv2model.LoadBalancerScheme
}
tests := []struct {
name string
Expand Down Expand Up @@ -87,33 +89,28 @@ func Test_buildCIDRsFromSourceRanges_buildCIDRsFromSourceRanges(t *testing.T) {
{
name: "fetch vpc info for internal scheme",
fields: fields{
svc: &corev1.Service{
ObjectMeta: metav1.ObjectMeta{
Annotations: map[string]string{
"service.beta.kubernetes.io/aws-load-balancer-scheme": "internal",
},
},
},
svc: &corev1.Service{},
ipAddressType: elbv2model.IPAddressTypeDualStack,
prefixListsConfigured: false,
scheme: elbv2.LoadBalancerSchemeInternal,
},
setupMock: func(MockVPCInfoProvider *networking.MockVPCInfoProvider) {
vpcInfo := networking.VPCInfo{
CidrBlockAssociationSet: []*ec2sdk.VpcCidrBlockAssociation{
{
CidrBlock: aws.String("192.168.0.0/16"),
CidrBlockState: &ec2sdk.VpcCidrBlockState{State: aws.String(ec2sdk.VpcCidrBlockStateCodeAssociated)},
},
CidrBlockAssociationSet: []*ec2sdk.VpcCidrBlockAssociation{
{
CidrBlock: aws.String("192.168.0.0/16"),
CidrBlockState: &ec2sdk.VpcCidrBlockState{State: aws.String(ec2sdk.VpcCidrBlockStateCodeAssociated)},
},
Ipv6CidrBlockAssociationSet: []*ec2sdk.VpcIpv6CidrBlockAssociation{
{
Ipv6CidrBlock: aws.String("fd00::/8"),
Ipv6CidrBlockState: &ec2sdk.VpcCidrBlockState{State: aws.String(ec2sdk.VpcCidrBlockStateCodeAssociated)},
},
},
Ipv6CidrBlockAssociationSet: []*ec2sdk.VpcIpv6CidrBlockAssociation{
{
Ipv6CidrBlock: aws.String("fd00::/8"),
Ipv6CidrBlockState: &ec2sdk.VpcCidrBlockState{State: aws.String(ec2sdk.VpcCidrBlockStateCodeAssociated)},
},
}
MockVPCInfoProvider.EXPECT().FetchVPCInfo(gomock.Any(), "vpc-1234", gomock.Any()).Return(vpcInfo, nil)
},
},
}
MockVPCInfoProvider.EXPECT().FetchVPCInfo(gomock.Any(), "vpc-1234", gomock.Any()).Return(vpcInfo, nil)
},
wantErr: false,
want: []string{
"192.168.0.0/16",
Expand All @@ -132,6 +129,7 @@ func Test_buildCIDRsFromSourceRanges_buildCIDRsFromSourceRanges(t *testing.T) {
},
ipAddressType: elbv2model.IPAddressTypeDualStack,
prefixListsConfigured: false,
scheme: elbv2.LoadBalancerSchemeInternal,
},
setupMock: func(MockVPCInfoProvider *networking.MockVPCInfoProvider) {
MockVPCInfoProvider.EXPECT().FetchVPCInfo(gomock.Any(), "vpc-1234", gomock.Any()).Return(networking.VPCInfo{}, errors.New("failed to fetch vpcInfo"))
Expand All @@ -152,9 +150,9 @@ func Test_buildCIDRsFromSourceRanges_buildCIDRsFromSourceRanges(t *testing.T) {
annotationParser: annotationParser,
service: tt.fields.svc,
vpcID: "vpc-1234",
vpcInfoProvider: mockVPCInfoProvider,
vpcInfoProvider: mockVPCInfoProvider,
}
got, err := task.buildCIDRsFromSourceRanges(context.Background(), tt.fields.ipAddressType, tt.fields.prefixListsConfigured)
got, err := task.buildCIDRsFromSourceRanges(context.Background(), tt.fields.ipAddressType, tt.fields.prefixListsConfigured, tt.fields.scheme)
if tt.wantErr {
assert.Error(t, err)
} else {
Expand All @@ -169,6 +167,7 @@ func Test_buildCIDRsFromSourceRanges_buildManagedSecurityGroupIngressPermissions
type fields struct {
svc *corev1.Service
ipAddressType elbv2model.IPAddressType
scheme elbv2model.LoadBalancerScheme
}
tests := []struct {
name string
Expand Down Expand Up @@ -348,7 +347,7 @@ func Test_buildCIDRsFromSourceRanges_buildManagedSecurityGroupIngressPermissions
annotationParser: annotationParser,
service: tt.fields.svc,
}
got, err := task.buildManagedSecurityGroupIngressPermissions(context.Background(), tt.fields.ipAddressType)
got, err := task.buildManagedSecurityGroupIngressPermissions(context.Background(), tt.fields.ipAddressType, tt.fields.scheme)
if tt.wantErr {
assert.Error(t, err)
} else {
Expand Down

0 comments on commit 275be82

Please sign in to comment.