diff --git a/pkg/controller/jobframework/validation.go b/pkg/controller/jobframework/validation.go
index bf7c58fcb4..4a120595ad 100644
--- a/pkg/controller/jobframework/validation.go
+++ b/pkg/controller/jobframework/validation.go
@@ -24,6 +24,7 @@ import (
 	kfmpi "github.com/kubeflow/mpi-operator/pkg/apis/kubeflow/v2beta1"
 	kftraining "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
 	batchv1 "k8s.io/api/batch/v1"
+	corev1 "k8s.io/api/core/v1"
 	apivalidation "k8s.io/apimachinery/pkg/api/validation"
 	"k8s.io/apimachinery/pkg/util/sets"
 	"k8s.io/apimachinery/pkg/util/validation"
@@ -144,3 +145,35 @@ func validateUpdateForMaxExecTime(oldJob, newJob GenericJob) field.ErrorList {
 	}
 	return nil
 }
+
+// ValidateImmutablePodSpec is used for serving workloads to ensure no changes are allowed
+// to the PodSpec except for the image field of containers and init containers.
+//
+// Containers are matched by index. When the number of containers differs between the
+// two specs, the unmatched entries are left untouched so that the comparison below
+// reports the difference instead of panicking on an out-of-range index.
+func ValidateImmutablePodSpec(newPodSpec *corev1.PodSpec, oldPodSpec *corev1.PodSpec, fieldPath *field.Path) field.ErrorList {
+	// handle updateable fields by munging those fields on a deep copy prior to the
+	// deep equal comparison, so that newPodSpec itself is never modified.
+	mungedPodSpec := newPodSpec.DeepCopy()
+
+	// munge spec.containers[*].image
+	restoreContainerImages(mungedPodSpec.Containers, oldPodSpec.Containers)
+
+	// munge spec.initContainers[*].image
+	restoreContainerImages(mungedPodSpec.InitContainers, oldPodSpec.InitContainers)
+
+	return apivalidation.ValidateImmutableField(mungedPodSpec, oldPodSpec, fieldPath)
+}
+
+// restoreContainerImages overwrites, in place, the image of every container in munged
+// with the image of the container at the same index in old. Containers in munged
+// without a counterpart in old keep their image.
+func restoreContainerImages(munged, old []corev1.Container) {
+	for i := range munged {
+		if i >= len(old) {
+			break
+		}
+		munged[i].Image = old[i].Image
+	}
+}
diff --git a/pkg/controller/jobs/statefulset/statefulset_webhook.go b/pkg/controller/jobs/statefulset/statefulset_webhook.go
index c442ae162a..20912376a9 100644
--- a/pkg/controller/jobs/statefulset/statefulset_webhook.go
+++ b/pkg/controller/jobs/statefulset/statefulset_webhook.go
@@ -107,9 +107,11 @@ var (
 	statefulsetReplicasPath       = field.NewPath("spec", "replicas")
 	priorityClassNameLabelPath    = statefulsetLabelsPath.Key(constants.WorkloadPriorityClassLabel)
 	statefulsetGroupNameLabelPath = statefulsetLabelsPath.Key(pod.GroupNameLabel)
-
-	podSpecQueueNameLabelPath = field.NewPath("spec", "template", "metadata", "labels").
+	podSpecQueueNameLabelPath     = field.NewPath("spec", "template", "metadata", "labels").
 		Key(constants.QueueLabel)
+	specPath         = field.NewPath("spec")
+	specTemplatePath = specPath.Child("template")
+	podSpecPath      = specTemplatePath.Child("spec")
 )

 func (wh *Webhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Object) (warnings admission.Warnings, err error) {
@@ -139,6 +141,12 @@ func (wh *Webhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Ob
 	)...)

 	if isManagedByKueue(newStatefulSet.Object()) {
+		allErrs = append(allErrs, jobframework.ValidateImmutablePodSpec(
+			&newStatefulSet.Spec.Template.Spec,
+			&oldStatefulSet.Spec.Template.Spec,
+			podSpecPath,
+		)...)
+
 		// TODO(#3279): support resizes later
 		allErrs = append(allErrs, apivalidation.ValidateImmutableField(
 			newStatefulSet.Spec.Replicas,
diff --git a/pkg/controller/jobs/statefulset/statefulset_webhook_test.go b/pkg/controller/jobs/statefulset/statefulset_webhook_test.go
index 05f34e0976..1347755ba3 100644
--- a/pkg/controller/jobs/statefulset/statefulset_webhook_test.go
+++ b/pkg/controller/jobs/statefulset/statefulset_webhook_test.go
@@ -24,6 +24,7 @@ import (
 	"github.com/google/go-cmp/cmp/cmpopts"
 	appsv1 "k8s.io/api/apps/v1"
 	corev1 "k8s.io/api/core/v1"
+	"k8s.io/apimachinery/pkg/api/resource"
 	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
 	"k8s.io/apimachinery/pkg/util/validation/field"
 	"k8s.io/utils/ptr"
@@ -326,6 +327,114 @@ func TestValidateUpdate(t *testing.T) {
 				},
 			}.ToAggregate(),
 		},
+		"attempt to change resources in container": {
+			oldObj: testingstatefulset.MakeStatefulSet("test-sts", "test-ns").
+				Queue("test-queue").
+				Template(corev1.PodTemplateSpec{
+					Spec: corev1.PodSpec{
+						Containers: []corev1.Container{
+							{
+								Name:      "c",
+								Image:     "pause:0.1.1",
+								Resources: corev1.ResourceRequirements{Requests: corev1.ResourceList{}},
+							},
+						},
+						InitContainers: []corev1.Container{
+							{
+								Name:      "ic",
+								Image:     "pause:0.1.1",
+								Resources: corev1.ResourceRequirements{Requests: corev1.ResourceList{}},
+							},
+						},
+					},
+				}).
+				Obj(),
+			newObj: testingstatefulset.MakeStatefulSet("test-sts", "test-ns").
+				Queue("test-queue").
+				Template(corev1.PodTemplateSpec{
+					Spec: corev1.PodSpec{
+						Containers: []corev1.Container{
+							{
+								Name:  "c",
+								Image: "pause:0.1.1",
+								Resources: corev1.ResourceRequirements{
+									Requests: corev1.ResourceList{
+										corev1.ResourceCPU: resource.MustParse("1"),
+									},
+								},
+							},
+						},
+						InitContainers: []corev1.Container{
+							{
+								Name:      "ic",
+								Image:     "pause:0.1.1",
+								Resources: corev1.ResourceRequirements{Requests: corev1.ResourceList{}},
+							},
+						},
+					},
+				}).
+				Obj(),
+			wantErr: field.ErrorList{
+				&field.Error{
+					Type:  field.ErrorTypeInvalid,
+					Field: podSpecPath.String(),
+				},
+			}.ToAggregate(),
+		},
+		"attempt to change resources in init container": {
+			oldObj: testingstatefulset.MakeStatefulSet("test-sts", "test-ns").
+				Queue("test-queue").
+				Template(corev1.PodTemplateSpec{
+					Spec: corev1.PodSpec{
+						Containers: []corev1.Container{
+							{
+								Name:      "c",
+								Image:     "pause:0.1.1",
+								Resources: corev1.ResourceRequirements{Requests: corev1.ResourceList{}},
+							},
+						},
+						InitContainers: []corev1.Container{
+							{
+								Name:      "ic",
+								Image:     "pause:0.1.1",
+								Resources: corev1.ResourceRequirements{Requests: corev1.ResourceList{}},
+							},
+						},
+					},
+				}).
+				Obj(),
+			newObj: testingstatefulset.MakeStatefulSet("test-sts", "test-ns").
+				Queue("test-queue").
+				Template(corev1.PodTemplateSpec{
+					Spec: corev1.PodSpec{
+						Containers: []corev1.Container{
+							{
+								Name:      "c",
+								Image:     "pause:0.1.1",
+								Resources: corev1.ResourceRequirements{Requests: corev1.ResourceList{}},
+							},
+						},
+						InitContainers: []corev1.Container{
+							{
+								Name:  "ic",
+								Image: "pause:0.1.1",
+								Resources: corev1.ResourceRequirements{
+									Requests: corev1.ResourceList{
+										corev1.ResourceCPU: resource.MustParse("1"),
+									},
+								},
+							},
+						},
+					},
+				}).
+				Obj(),
+			wantErr: field.ErrorList{
+				&field.Error{
+					Type:  field.ErrorTypeInvalid,
+					Field: podSpecPath.String(),
+				},
+			}.ToAggregate(),
+		},
 	}

 	for name, tc := range testCases {
diff --git a/pkg/util/testingjobs/statefulset/wrappers.go b/pkg/util/testingjobs/statefulset/wrappers.go
index 0bb05964d7..c429071609 100644
--- a/pkg/util/testingjobs/statefulset/wrappers.go
+++ b/pkg/util/testingjobs/statefulset/wrappers.go
@@ -101,6 +101,12 @@ func (ss *StatefulSetWrapper) WithOwnerReference(ownerReference metav1.OwnerRefe
 	return ss
 }

+// Template sets the template of the StatefulSet.
+func (ss *StatefulSetWrapper) Template(template corev1.PodTemplateSpec) *StatefulSetWrapper {
+	ss.Spec.Template = template
+	return ss
+}
+
 // PodTemplateSpecLabel sets the label of the pod template spec of the StatefulSet
 func (ss *StatefulSetWrapper) PodTemplateSpecLabel(k, v string) *StatefulSetWrapper {
 	if ss.Spec.Template.Labels == nil {