Skip to content

Commit

Permalink
DisruptionCrons should validate their DisruptionTemplates (#894)
Browse files Browse the repository at this point in the history
* Add new functions to DisruptionSpec to validate without requiring selectors

* DisruptionCrons should validate their DisruptionTemplate

* Validate DisruptionCrons on Update as well
  • Loading branch information
ptnapoleon authored Aug 7, 2024
1 parent 9947edb commit 561c08c
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 20 deletions.
11 changes: 10 additions & 1 deletion api/v1beta1/disruption_cron_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package v1beta1
import (
"encoding/json"
"errors"
"fmt"
"strings"
"time"

Expand Down Expand Up @@ -74,13 +75,17 @@ func (d *DisruptionCron) ValidateCreate() (admission.Warnings, error) {
return nil, err
}

if err := d.Spec.DisruptionTemplate.ValidateSelectorsOptional(false); err != nil {
return nil, fmt.Errorf("error while validating the spec.disruptionTemplate: %w", err)
}

// send informative event to disruption cron to broadcast
d.emitEvent(EventDisruptionCronCreated)

return nil, nil
}

func (d *DisruptionCron) ValidateUpdate(oldObject runtime.Object) (warnings admission.Warnings, err error) {
func (d *DisruptionCron) ValidateUpdate(oldObject runtime.Object) (admission.Warnings, error) {
log := logger.With("disruptionCronName", d.Name, "disruptionCronNamespace", d.Namespace)

log.Infow("validating updated disruption cron", "spec", d.Spec)
Expand All @@ -89,6 +94,10 @@ func (d *DisruptionCron) ValidateUpdate(oldObject runtime.Object) (warnings admi
return nil, err
}

if err := d.Spec.DisruptionTemplate.ValidateSelectorsOptional(false); err != nil {
return nil, err
}

// send informative event to disruption cron to broadcast
d.emitEvent(EventDisruptionCronUpdated)

Expand Down
23 changes: 22 additions & 1 deletion api/v1beta1/disruption_cron_webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
authV1 "k8s.io/api/authentication/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/util/intstr"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
Expand Down Expand Up @@ -50,7 +51,6 @@ var _ = Describe("DisruptionCron Webhook", func() {
It("should send an EventDisruptionCronCreated event to the broadcast", func() {
// Arrange
disruptionCron := makeValidDisruptionCron()

disruptionCronJSON, err := json.Marshal(disruptionCron)
Expect(err).ShouldNot(HaveOccurred())

Expand Down Expand Up @@ -136,6 +136,25 @@ var _ = Describe("DisruptionCron Webhook", func() {
})
})

When("disruptionTemplate is invalid", func() {
When("the count is invalid", func() {
// Other forms of invalid disruptions are covered by the disruption_webook_test.go
// we just want to confirm we do validate the disruptionTemplate as part of the disruption cron webhook
It("should return an error", func() {
disruptionCron := makeValidDisruptionCron()
disruptionCron.Spec.DisruptionTemplate.Count = &intstr.IntOrString{
StrVal: "2hr",
}

warnings, err := disruptionCron.ValidateCreate()

Expect(warnings).To(BeNil())
Expect(err).Should(HaveOccurred())
Expect(err).To(MatchError(ContainSubstring("error while validating the spec.disruptionTemplate")))
})
})
})

When("permitted user groups is present", func() {

BeforeEach(func() {
Expand Down Expand Up @@ -409,9 +428,11 @@ var _ = Describe("DisruptionCron Webhook", func() {
})

func makeValidDisruptionCron() *DisruptionCron {
validDisruption := makeValidNetworkDisruption()
return &DisruptionCron{
TypeMeta: metav1.TypeMeta{
Kind: DisruptionCronKind,
},
Spec: DisruptionCronSpec{DisruptionTemplate: validDisruption.Spec},
}
}
45 changes: 27 additions & 18 deletions api/v1beta1/disruption_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -493,9 +493,16 @@ func (s DisruptionSpec) HashNoCount() (string, error) {
return s.Hash()
}

// Validate applies rules for disruption global scope and all subsequent disruption specifications
// Validate applies rules for disruption global scope and all subsequent disruption specifications, requiring selectors
// intended to be called when DisruptionSpec belongs directly to a Disruption
// also exists for backwards compatibility
func (s DisruptionSpec) Validate() (retErr error) {
if err := s.validateGlobalDisruptionScope(); err != nil {
return s.ValidateSelectorsOptional(true)
}

// ValidateSelectorsOptional applies rules for disruption global scope and all subsequent disruption specifications
func (s DisruptionSpec) ValidateSelectorsOptional(requireSelectors bool) (retErr error) {
if err := s.validateGlobalDisruptionScope(requireSelectors); err != nil {
retErr = multierror.Append(retErr, err)
}

Expand Down Expand Up @@ -548,25 +555,27 @@ func AdvancedSelectorsToRequirements(advancedSelectors []metav1.LabelSelectorReq
return reqs, nil
}

// Validate applies rules for disruption global scope
func (s DisruptionSpec) validateGlobalDisruptionScope() (retErr error) {
// Rule: at least one kind of selector is set
if s.Selector.AsSelector().Empty() && len(s.AdvancedSelector) == 0 {
retErr = multierror.Append(retErr, errors.New("either selector or advancedSelector field must be set"))
}
// validateGlobalDisruptionScope applies rules for disruption global scope, leaving selectors optional
func (s DisruptionSpec) validateGlobalDisruptionScope(requireSelectors bool) (retErr error) {
if requireSelectors {
// Rule: at least one kind of selector is set
if s.Selector.AsSelector().Empty() && len(s.AdvancedSelector) == 0 {
retErr = multierror.Append(retErr, errors.New("either selector or advancedSelector field must be set"))
}

// Rule: selectors must be valid
if !s.Selector.AsSelector().Empty() {
_, err := labels.ParseToRequirements(s.Selector.AsSelector().String())
if err != nil {
retErr = multierror.Append(retErr, err)
// Rule: selectors must be valid
if !s.Selector.AsSelector().Empty() {
_, err := labels.ParseToRequirements(s.Selector.AsSelector().String())
if err != nil {
retErr = multierror.Append(retErr, err)
}
}
}

if len(s.AdvancedSelector) > 0 {
_, err := AdvancedSelectorsToRequirements(s.AdvancedSelector)
if err != nil {
retErr = multierror.Append(retErr, err)
if len(s.AdvancedSelector) > 0 {
_, err := AdvancedSelectorsToRequirements(s.AdvancedSelector)
if err != nil {
retErr = multierror.Append(retErr, err)
}
}
}

Expand Down

0 comments on commit 561c08c

Please sign in to comment.