diff --git a/pkg/service/model_build_load_balancer_addons.go b/pkg/service/model_build_load_balancer_addons.go index b9c4fabb3..a8cc0eba5 100644 --- a/pkg/service/model_build_load_balancer_addons.go +++ b/pkg/service/model_build_load_balancer_addons.go @@ -3,19 +3,19 @@ package service import ( "context" - "github.com/pkg/errors" "sigs.k8s.io/aws-load-balancer-controller/pkg/annotations" + "sigs.k8s.io/aws-load-balancer-controller/pkg/model/core" shieldmodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/shield" ) -func (t *defaultModelBuildTask) buildLoadBalancerAddOns(ctx context.Context) error { - if _, err := t.buildShieldProtection(ctx); err != nil { +func (t *defaultModelBuildTask) buildLoadBalancerAddOns(ctx context.Context, lbARN core.StringToken) error { + if _, err := t.buildShieldProtection(ctx, lbARN); err != nil { return err } return nil } -func (t *defaultModelBuildTask) buildShieldProtection(_ context.Context) (*shieldmodel.Protection, error) { +func (t *defaultModelBuildTask) buildShieldProtection(_ context.Context, lbARN core.StringToken) (*shieldmodel.Protection, error) { explicitEnableProtections := make(map[bool]struct{}) rawEnableProtection := false exists, err := t.annotationParser.ParseBoolAnnotation(annotations.SvcLBSuffixShieldAdvancedProtection, &rawEnableProtection, t.service.Annotations) @@ -28,14 +28,10 @@ func (t *defaultModelBuildTask) buildShieldProtection(_ context.Context) (*shiel if len(explicitEnableProtections) == 0 { return nil, nil } - if len(explicitEnableProtections) > 1 { - return nil, errors.New("conflicting enable shield advanced protection") - } - if _, enableProtection := explicitEnableProtections[true]; enableProtection { - protection := shieldmodel.NewProtection(t.stack, resourceIDLoadBalancer, shieldmodel.ProtectionSpec{ - ResourceARN: t.loadBalancer.LoadBalancerARN(), - }) - return protection, nil - } - return nil, nil + _, enableProtection := explicitEnableProtections[true] + protection := shieldmodel.NewProtection(t.stack, resourceIDLoadBalancer, shieldmodel.ProtectionSpec{ + Enabled: enableProtection, + ResourceARN: lbARN, + }) + return protection, nil } diff --git a/pkg/service/model_build_load_balancer_addons_test.go b/pkg/service/model_build_load_balancer_addons_test.go new file mode 100644 index 000000000..5b2ea8815 --- /dev/null +++ b/pkg/service/model_build_load_balancer_addons_test.go @@ -0,0 +1,114 @@ +package service + +import ( + "context" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/stretchr/testify/assert" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "sigs.k8s.io/aws-load-balancer-controller/pkg/annotations" + "sigs.k8s.io/aws-load-balancer-controller/pkg/model/core" + shieldmodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/shield" +) + +func Test_defaultModelBuildTask_buildShieldProtection(t *testing.T) { + type args struct { + lbARN core.StringToken + } + tests := []struct { + testName string + svc *corev1.Service + args args + want *shieldmodel.Protection + wantError bool + }{ + { + testName: "when shield-advanced-protection annotation is not specified", + svc: &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{}, + }, + }, + args: args{ + lbARN: core.LiteralStringToken("awesome-lb-arn"), + }, + want: nil, + wantError: false, + }, + { + testName: "when shield-advanced-protection annotation set to true", + svc: &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + "service.beta.kubernetes.io/aws-load-balancer-nlb-shield-advanced-protection": "true", + }, + }, + }, + args: args{ + lbARN: core.LiteralStringToken("awesome-lb-arn"), + }, + want: &shieldmodel.Protection{ + Spec: shieldmodel.ProtectionSpec{ + Enabled: true, + ResourceARN: core.LiteralStringToken("awesome-lb-arn"), + }, + }, + wantError: false, + }, + { + testName: "when shield-advanced-protection annotation set to false", + svc: &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + "service.beta.kubernetes.io/aws-load-balancer-nlb-shield-advanced-protection": "false", + }, + }, + }, + args: args{ + lbARN: core.LiteralStringToken("awesome-lb-arn"), + }, + want: &shieldmodel.Protection{ + Spec: shieldmodel.ProtectionSpec{ + Enabled: false, + ResourceARN: core.LiteralStringToken("awesome-lb-arn"), + }, + }, + wantError: false, + }, + { + testName: "when shield-advanced-protection annotation has non boolean value", + svc: &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + "service.beta.kubernetes.io/aws-load-balancer-nlb-shield-advanced-protection": "FalSe1", + }, + }, + }, + args: args{ + lbARN: core.LiteralStringToken("awesome-lb-arn"), + }, + wantError: true, + }, + } + for _, tt := range tests { + t.Run(tt.testName, func(t *testing.T) { + stack := core.NewDefaultStack(core.StackID{Name: "awesome-stack"}) + annotationParser := annotations.NewSuffixAnnotationParser("service.beta.kubernetes.io") + task := &defaultModelBuildTask{ + service: tt.svc, + annotationParser: annotationParser, + stack: stack, + } + got, err := task.buildShieldProtection(context.Background(), tt.args.lbARN) + if tt.wantError { + assert.Error(t, err) + } else { + opts := cmpopts.IgnoreTypes(core.ResourceMeta{}) + assert.True(t, cmp.Equal(tt.want, got, opts), "diff", cmp.Diff(tt.want, got, opts)) + } + }) + } +} diff --git a/pkg/service/model_builder.go b/pkg/service/model_builder.go index 5743886cc..2b16a1ba8 100644 --- a/pkg/service/model_builder.go +++ b/pkg/service/model_builder.go @@ -2,10 +2,11 @@ package service import ( "context" - ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" "strconv" "sync" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "github.com/go-logr/logr" "github.com/pkg/errors" corev1 "k8s.io/api/core/v1" @@ -249,7 +250,7 @@ func (t *defaultModelBuildTask) buildModel(ctx context.Context) error { if err != nil { return err } - if err := t.buildLoadBalancerAddOns(ctx); err != nil { + if err := t.buildLoadBalancerAddOns(ctx, t.loadBalancer.LoadBalancerARN()); err != nil { return err } return nil