Skip to content

Commit

Permalink
Fix restoring PodTemplate on JobSet suspend
Browse files Browse the repository at this point in the history
  • Loading branch information
mimowo committed Jul 26, 2024
1 parent dd5bc69 commit b9678b7
Show file tree
Hide file tree
Showing 6 changed files with 198 additions and 39 deletions.
23 changes: 8 additions & 15 deletions pkg/controllers/jobset_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,11 @@ func (r *JobSetReconciler) reconcile(ctx context.Context, js *jobset.JobSet, upd
// Handle suspending a jobset or resuming a suspended jobset.
jobsetSuspended := jobSetSuspended(js)
if jobsetSuspended {
if err := r.suspendJobs(ctx, js, ownedJobs.active, updateStatusOpts); err != nil {
log.Error(err, "suspending jobset")
if err := r.deleteJobs(ctx, ownedJobs.active); err != nil {
log.Error(err, "deleting jobs")
return ctrl.Result{}, err
}
setJobSetSuspendedCondition(js, updateStatusOpts)
} else {
if err := r.resumeJobsIfNecessary(ctx, js, ownedJobs.active, rjobStatuses, updateStatusOpts); err != nil {
log.Error(err, "resuming jobset")
Expand Down Expand Up @@ -378,19 +379,6 @@ func (r *JobSetReconciler) calculateReplicatedJobStatuses(ctx context.Context, j
return rjStatus
}

func (r *JobSetReconciler) suspendJobs(ctx context.Context, js *jobset.JobSet, activeJobs []*batchv1.Job, updateStatusOpts *statusUpdateOpts) error {
for _, job := range activeJobs {
if !jobSuspended(job) {
job.Spec.Suspend = ptr.To(true)
if err := r.Update(ctx, job); err != nil {
return err
}
}
}
setJobSetSuspendedCondition(js, updateStatusOpts)
return nil
}

// resumeJobsIfNecessary iterates through each replicatedJob, resuming any suspended jobs if the JobSet
// is not suspended.
func (r *JobSetReconciler) resumeJobsIfNecessary(ctx context.Context, js *jobset.JobSet, activeJobs []*batchv1.Job, replicatedJobStatuses []jobset.ReplicatedJobStatus, updateStatusOpts *statusUpdateOpts) error {
Expand Down Expand Up @@ -493,6 +481,11 @@ func (r *JobSetReconciler) reconcileReplicatedJobs(ctx context.Context, js *jobs
return err
}

// Don't create child Jobs if the JobSet is suspended
if jobSetSuspended(js) {
continue
}

status := findReplicatedJobStatus(replicatedJobStatus, replicatedJob.Name)

// For startup policy, if the replicatedJob is started we can skip this loop.
Expand Down
35 changes: 25 additions & 10 deletions pkg/webhooks/jobset_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -309,18 +309,33 @@ func (j *jobSetWebhook) ValidateUpdate(ctx context.Context, old, newObj runtime.
}
mungedSpec := js.Spec.DeepCopy()

// Allow pod template to be mutated for suspended JobSets.
// Allow pod template to be mutated for suspended JobSets, or JobSets getting suspended.
// This is needed for integration with Kueue/DWS.
if ptr.Deref(oldJS.Spec.Suspend, false) {
if ptr.Deref(oldJS.Spec.Suspend, false) || ptr.Deref(js.Spec.Suspend, false) {
for index := range js.Spec.ReplicatedJobs {
// Pod values which must be mutable for Kueue are defined here: https://github.com/kubernetes-sigs/kueue/blob/a50d395c36a2cb3965be5232162cf1fded1bdb08/apis/kueue/v1beta1/workload_types.go#L256-L260
mungedSpec.ReplicatedJobs[index].Template.Spec.Template.Annotations = oldJS.Spec.ReplicatedJobs[index].Template.Spec.Template.Annotations
mungedSpec.ReplicatedJobs[index].Template.Spec.Template.Labels = oldJS.Spec.ReplicatedJobs[index].Template.Spec.Template.Labels
mungedSpec.ReplicatedJobs[index].Template.Spec.Template.Spec.NodeSelector = oldJS.Spec.ReplicatedJobs[index].Template.Spec.Template.Spec.NodeSelector
mungedSpec.ReplicatedJobs[index].Template.Spec.Template.Spec.Tolerations = oldJS.Spec.ReplicatedJobs[index].Template.Spec.Template.Spec.Tolerations

// Pod Scheduling Gates can be updated for batch/v1 Job: https://github.com/kubernetes/kubernetes/blob/ceb58a4dbc671b9d0a2de6d73a1616bc0c299863/pkg/apis/batch/validation/validation.go#L662
mungedSpec.ReplicatedJobs[index].Template.Spec.Template.Spec.SchedulingGates = oldJS.Spec.ReplicatedJobs[index].Template.Spec.Template.Spec.SchedulingGates
munge := true
// Don't allow to unsuspend if there are still active Jobs with
// different PodTemplate. We do this to avoid a race condition when
// Jobs with an old PodTemplate (before suspension) are still
// runnable and would not be deleted if we unsuspend the JobSet.
if !ptr.Deref(js.Spec.Suspend, false) {
rStatus := js.Status.ReplicatedJobsStatus
// Don't allow to mutate PodTemplate on unsuspending if there
// are still active or suspended jobs from the previous run
if len(rStatus) > index && (rStatus[index].Active > 0 || rStatus[index].Suspended > 0) {
munge = false
}
}
if munge {
// Pod values which must be mutable for Kueue are defined here: https://github.com/kubernetes-sigs/kueue/blob/a50d395c36a2cb3965be5232162cf1fded1bdb08/apis/kueue/v1beta1/workload_types.go#L256-L260
mungedSpec.ReplicatedJobs[index].Template.Spec.Template.Annotations = oldJS.Spec.ReplicatedJobs[index].Template.Spec.Template.Annotations
mungedSpec.ReplicatedJobs[index].Template.Spec.Template.Labels = oldJS.Spec.ReplicatedJobs[index].Template.Spec.Template.Labels
mungedSpec.ReplicatedJobs[index].Template.Spec.Template.Spec.NodeSelector = oldJS.Spec.ReplicatedJobs[index].Template.Spec.Template.Spec.NodeSelector
mungedSpec.ReplicatedJobs[index].Template.Spec.Template.Spec.Tolerations = oldJS.Spec.ReplicatedJobs[index].Template.Spec.Template.Spec.Tolerations

// Pod Scheduling Gates can be updated for batch/v1 Job: https://github.com/kubernetes/kubernetes/blob/ceb58a4dbc671b9d0a2de6d73a1616bc0c299863/pkg/apis/batch/validation/validation.go#L662
mungedSpec.ReplicatedJobs[index].Template.Spec.Template.Spec.SchedulingGates = oldJS.Spec.ReplicatedJobs[index].Template.Spec.Template.Spec.SchedulingGates
}
}
}

Expand Down
43 changes: 43 additions & 0 deletions pkg/webhooks/jobset_webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1517,6 +1517,49 @@ func TestValidateUpdate(t *testing.T) {
},
},
},
{
name: "replicated job pod template can be updated for jobset getting suspended",
js: &jobset.JobSet{
ObjectMeta: validObjectMeta,
Spec: jobset.JobSetSpec{
Suspend: ptr.To(true),
ReplicatedJobs: []jobset.ReplicatedJob{
{
Name: "test-jobset-replicated-job-0",
Replicas: 2,
Template: batchv1.JobTemplateSpec{
// Restoring the template by removing the annotation
Spec: batchv1.JobSpec{
Parallelism: ptr.To[int32](2),
Template: corev1.PodTemplateSpec{},
},
},
},
},
},
},
oldJs: &jobset.JobSet{
ObjectMeta: validObjectMeta,
Spec: jobset.JobSetSpec{
ReplicatedJobs: []jobset.ReplicatedJob{
{
Name: "test-jobset-replicated-job-0",
Replicas: 2,
Template: batchv1.JobTemplateSpec{
Spec: batchv1.JobSpec{
Parallelism: ptr.To[int32](2),
Template: corev1.PodTemplateSpec{
ObjectMeta: metav1.ObjectMeta{
Annotations: map[string]string{"key": "value"},
},
},
},
},
},
},
},
},
},
{
name: "replicated job pod template cannot be updated for running jobset",
js: &jobset.JobSet{
Expand Down
94 changes: 94 additions & 0 deletions test/e2e/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"
"k8s.io/utils/ptr"

jobset "sigs.k8s.io/jobset/api/jobset/v1alpha2"
"sigs.k8s.io/jobset/pkg/util/testing"
Expand Down Expand Up @@ -131,6 +132,77 @@ var _ = ginkgo.Describe("JobSet", func() {
})
})

ginkgo.When("job is unsuspended and suspend", func() {
for rep := 0; rep < 100; rep++ {
ginkgo.It(fmt.Sprintf("should allow to update PodTemplate on unsuspend and restore the PodTemplate on suspend-%v", rep), func() {
ctx := context.Background()
js := shortSleepTestJobSet(ns).Obj()
jsKey := types.NamespacedName{Name: js.Name, Namespace: js.Namespace}

ginkgo.By("Create a suspended JobSet", func() {
js.Spec.Suspend = ptr.To(true)
js.Spec.TTLSecondsAfterFinished = ptr.To[int32](5)
gomega.Expect(k8sClient.Create(ctx, js)).Should(gomega.Succeed())
})

ginkgo.By("Unsuspend the JobSet setting nodeSelectors that prevent pods from being scheduled", func() {
gomega.Eventually(func() error {
gomega.Expect(k8sClient.Get(ctx, jsKey, js)).Should(gomega.Succeed())
js.Spec.Suspend = ptr.To(false)
podTemplate := &js.Spec.ReplicatedJobs[0].Template.Spec.Template
if podTemplate.Spec.NodeSelector == nil {
podTemplate.Spec.NodeSelector = make(map[string]string)
}
podTemplate.Spec.NodeSelector["kubernetes.io/hostname"] = "non-existing-node"
if podTemplate.Labels == nil {
podTemplate.Labels = make(map[string]string)
}
podTemplate.Labels["custom-label-key"] = "custom-label-value"
if podTemplate.Annotations == nil {
podTemplate.Annotations = make(map[string]string)
}
podTemplate.Annotations["custom-annotation-key"] = "custom-annotation-value"
return k8sClient.Update(ctx, js)
}, timeout, interval).Should(gomega.Succeed())
})

ginkgo.By("Await for at least one active Job to make sure there are some running Pods", func() {
gomega.Eventually(func() int32 {
gomega.Expect(k8sClient.Get(ctx, jsKey, js)).Should(gomega.Succeed())
if js.Status.ReplicatedJobsStatus == nil {
return 0
}
return js.Status.ReplicatedJobsStatus[0].Active
}, timeout, interval).Should(gomega.BeNumerically(">=", 1))
})

ginkgo.By("Suspend the JobSet restoring the PodTemplate properties", func() {
gomega.Eventually(func() error {
gomega.Expect(k8sClient.Get(ctx, jsKey, js)).Should(gomega.Succeed())
js.Spec.Suspend = ptr.To(true)
podTemplate := &js.Spec.ReplicatedJobs[0].Template.Spec.Template
delete(podTemplate.Spec.NodeSelector, "kubernetes.io/hostname")
delete(podTemplate.Labels, "custom-label-key")
delete(podTemplate.Annotations, "custom-annotation-key")
return k8sClient.Update(ctx, js)
}, timeout, interval).Should(gomega.Succeed())
})

ginkgo.By("Unsuspending the JobSet again with PodTemplate allowing completion", func() {
gomega.Eventually(func() error {
gomega.Expect(k8sClient.Get(ctx, jsKey, js)).Should(gomega.Succeed())
js.Spec.Suspend = ptr.To(false)
return k8sClient.Update(ctx, js)
}, timeout, interval).Should(gomega.Succeed())
})

ginkgo.By("Await for the JobSet to complete successfully", func() {
util.JobSetCompleted(ctx, k8sClient, js, timeout)
})
})
}
})

}) // end of Describe

// getPingCommand returns ping command for 4 hostnames
Expand Down Expand Up @@ -246,3 +318,25 @@ func sleepTestJobSet(ns *corev1.Namespace) *testing.JobSetWrapper {
Replicas(int32(replicas)).
Obj())
}

func shortSleepTestJobSet(ns *corev1.Namespace) *testing.JobSetWrapper {
jsName := "js"
rjobName := "rjob"
replicas := 3
return testing.MakeJobSet(jsName, ns.Name).
ReplicatedJob(testing.MakeReplicatedJob(rjobName).
Job(testing.MakeJobTemplate("job", ns.Name).
PodSpec(corev1.PodSpec{
RestartPolicy: "Never",
Containers: []corev1.Container{
{
Name: "short-sleep-test-container",
Image: "bash:latest",
Command: []string{"bash", "-c"},
Args: []string{"sleep 1"},
},
},
}).Obj()).
Replicas(int32(replicas)).
Obj())
}
38 changes: 24 additions & 14 deletions test/integration/controller/jobset_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,10 @@ var _ = ginkgo.Describe("JobSet controller", func() {

if up.jobSetUpdateFn != nil {
up.jobSetUpdateFn(&jobSet)
gomega.Expect(k8sClient.Get(ctx, types.NamespacedName{Name: js.Name, Namespace: js.Namespace}, &jobSet)).To(gomega.Succeed())
} else if up.jobUpdateFn != nil {
if up.checkJobCreation == nil {
gomega.Eventually(testutil.NumJobs, timeout, interval).WithArguments(ctx, k8sClient, js).Should(gomega.Equal(testutil.NumExpectedJobs(js)))
gomega.Eventually(testutil.NumJobs, timeout, interval).WithArguments(ctx, k8sClient, js).Should(gomega.Equal(testutil.NumExpectedJobs(&jobSet)))
} else {
up.checkJobCreation(&jobSet)
}
Expand Down Expand Up @@ -899,11 +900,11 @@ var _ = ginkgo.Describe("JobSet controller", func() {
matchJobSetReplicatedStatus(js, []jobset.ReplicatedJobStatus{
{
Name: "replicated-job-b",
Suspended: 3,
Suspended: 0,
},
{
Name: "replicated-job-a",
Suspended: 1,
Suspended: 0,
},
})
},
Expand Down Expand Up @@ -931,11 +932,11 @@ var _ = ginkgo.Describe("JobSet controller", func() {
matchJobSetReplicatedStatus(js, []jobset.ReplicatedJobStatus{
{
Name: "replicated-job-b",
Suspended: 3,
Suspended: 0,
},
{
Name: "replicated-job-a",
Suspended: 1,
Suspended: 0,
},
})
},
Expand Down Expand Up @@ -987,6 +988,11 @@ var _ = ginkgo.Describe("JobSet controller", func() {
{
jobSetUpdateFn: func(js *jobset.JobSet) {
suspendJobSet(js, true)
// For suspended JobSet all jobs will be deleted, so we
// expect a foreground deletion finalizer for every job.
numJobs, err := testutil.NumJobs(ctx, k8sClient, js)
gomega.Expect(err).To(gomega.BeNil())
removeForegroundDeletionFinalizers(js, numJobs)
},
checkJobSetState: func(js *jobset.JobSet) {
ginkgo.By("checking all jobs are suspended")
Expand Down Expand Up @@ -1143,11 +1149,11 @@ var _ = ginkgo.Describe("JobSet controller", func() {
matchJobSetReplicatedStatus(js, []jobset.ReplicatedJobStatus{
{
Name: "replicated-job-b",
Suspended: 3,
Suspended: 0,
},
{
Name: "replicated-job-a",
Suspended: 1,
Suspended: 0,
},
})
},
Expand All @@ -1167,11 +1173,11 @@ var _ = ginkgo.Describe("JobSet controller", func() {
matchJobSetReplicatedStatus(js, []jobset.ReplicatedJobStatus{
{
Name: "replicated-job-b",
Suspended: 3,
Suspended: 0,
},
{
Name: "replicated-job-a",
Suspended: 1,
Suspended: 0,
},
})
},
Expand All @@ -1191,11 +1197,11 @@ var _ = ginkgo.Describe("JobSet controller", func() {
matchJobSetReplicatedStatus(js, []jobset.ReplicatedJobStatus{
{
Name: "replicated-job-b",
Suspended: 3,
Suspended: 0,
},
{
Name: "replicated-job-a",
Suspended: 1,
Suspended: 0,
},
})
},
Expand Down Expand Up @@ -1241,11 +1247,11 @@ var _ = ginkgo.Describe("JobSet controller", func() {
matchJobSetReplicatedStatus(js, []jobset.ReplicatedJobStatus{
{
Name: "replicated-job-b",
Suspended: 3,
Suspended: 0,
},
{
Name: "replicated-job-a",
Suspended: 1,
Suspended: 0,
},
})
},
Expand All @@ -1262,7 +1268,7 @@ var _ = ginkgo.Describe("JobSet controller", func() {
matchJobSetReplicatedStatus(js, []jobset.ReplicatedJobStatus{
{
Name: "replicated-job-b",
Suspended: 3,
Suspended: 0,
},
{
Name: "replicated-job-a",
Expand All @@ -1277,6 +1283,10 @@ var _ = ginkgo.Describe("JobSet controller", func() {
jobUpdateFn: func(jobList *batchv1.JobList) {
readyReplicatedJob(jobList, "replicated-job-a")
},
checkJobCreation: func(js *jobset.JobSet) {
expectedStarts := 1
gomega.Eventually(testutil.NumJobs, timeout, interval).WithArguments(ctx, k8sClient, js).Should(gomega.Equal(expectedStarts))
},
},
{
checkJobSetState: func(js *jobset.JobSet) {
Expand Down
4 changes: 4 additions & 0 deletions test/util/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
batchv1 "k8s.io/api/batch/v1"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/utils/ptr"

apierrors "k8s.io/apimachinery/pkg/api/errors"

Expand All @@ -38,6 +39,9 @@ const interval = time.Millisecond * 250

func NumExpectedJobs(js *jobset.JobSet) int {
expectedJobs := 0
if ptr.Deref(js.Spec.Suspend, false) {
return 0
}
for _, rjob := range js.Spec.ReplicatedJobs {
expectedJobs += int(rjob.Replicas)
}
Expand Down

0 comments on commit b9678b7

Please sign in to comment.