diff --git a/pkg/controller/jobs/kubeflow/jobs/paddlejob/paddlejob_multikueue_adapter_test.go b/pkg/controller/jobs/kubeflow/jobs/paddlejob/paddlejob_multikueue_adapter_test.go index f70fada709..4c2caafd58 100644 --- a/pkg/controller/jobs/kubeflow/jobs/paddlejob/paddlejob_multikueue_adapter_test.go +++ b/pkg/controller/jobs/kubeflow/jobs/paddlejob/paddlejob_multikueue_adapter_test.go @@ -18,6 +18,7 @@ package paddlejob import ( "context" + "errors" "testing" "github.com/google/go-cmp/cmp" @@ -49,6 +50,7 @@ func TestMultikueueAdapter(t *testing.T) { } paddleJobBuilder := kfutiltesting.MakePaddleJob("paddlejob1", TestNamespace).Queue("queue").Suspend(false) + paddleJobManagedByKueueBuilder := paddleJobBuilder.Clone().ManagedBy(kueue.MultiKueueControllerName) cases := map[string]struct { managersPaddleJobs []kftraining.PaddleJob @@ -118,6 +120,42 @@ func TestMultikueueAdapter(t *testing.T) { return adapter.DeleteRemoteObject(ctx, workerClient, types.NamespacedName{Name: "paddlejob1", Namespace: TestNamespace}) }, }, + "missing job is not considered managed": { + operation: func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error { + if isManged, _, _ := adapter.IsJobManagedByKueue(ctx, managerClient, types.NamespacedName{Name: "paddlejob1", Namespace: TestNamespace}); isManged { + return errors.New("expecting false") + } + return nil + }, + }, + "job with wrong managedBy is not considered managed": { + managersPaddleJobs: []kftraining.PaddleJob{ + *paddleJobBuilder.DeepCopy(), + }, + operation: func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error { + if isManged, _, _ := adapter.IsJobManagedByKueue(ctx, managerClient, types.NamespacedName{Name: "paddlejob1", Namespace: TestNamespace}); isManged { + return errors.New("expecting false") + } + return nil + }, + wantManagersPaddleJobs: []kftraining.PaddleJob{ + *paddleJobBuilder.DeepCopy(), + }, + }, + 
"job managedBy multikueue": { + managersPaddleJobs: []kftraining.PaddleJob{ + *paddleJobManagedByKueueBuilder.DeepCopy(), + }, + operation: func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error { + if isManged, _, _ := adapter.IsJobManagedByKueue(ctx, managerClient, types.NamespacedName{Name: "paddlejob1", Namespace: TestNamespace}); !isManged { + return errors.New("expecting true") + } + return nil + }, + wantManagersPaddleJobs: []kftraining.PaddleJob{ + *paddleJobManagedByKueueBuilder.DeepCopy(), + }, + }, } for name, tc := range cases { t.Run(name, func(t *testing.T) { diff --git a/pkg/controller/jobs/kubeflow/jobs/pytorchjob/pytorch_multikueue_adapter_test.go b/pkg/controller/jobs/kubeflow/jobs/pytorchjob/pytorch_multikueue_adapter_test.go index 905000b0d1..70e08dbb4b 100644 --- a/pkg/controller/jobs/kubeflow/jobs/pytorchjob/pytorch_multikueue_adapter_test.go +++ b/pkg/controller/jobs/kubeflow/jobs/pytorchjob/pytorch_multikueue_adapter_test.go @@ -18,6 +18,7 @@ package pytorchjob import ( "context" + "errors" "testing" "github.com/google/go-cmp/cmp" @@ -49,6 +50,7 @@ func TestMultikueueAdapter(t *testing.T) { } pyTorchJobBuilder := kfutiltesting.MakePyTorchJob("pytorchjob1", TestNamespace).Queue("queue").Suspend(false) + pyTorchJobManagedByKueueBuilder := pyTorchJobBuilder.Clone().ManagedBy(kueue.MultiKueueControllerName) cases := map[string]struct { managersPyTorchJobs []kftraining.PyTorchJob @@ -117,6 +119,42 @@ func TestMultikueueAdapter(t *testing.T) { return adapter.DeleteRemoteObject(ctx, workerClient, types.NamespacedName{Name: "pytorchjob1", Namespace: TestNamespace}) }, }, + "missing job is not considered managed": { + operation: func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error { + if isManged, _, _ := adapter.IsJobManagedByKueue(ctx, managerClient, types.NamespacedName{Name: "pytorchjob1", Namespace: TestNamespace}); isManged 
{ + return errors.New("expecting false") + } + return nil + }, + }, + "job with wrong managedBy is not considered managed": { + managersPyTorchJobs: []kftraining.PyTorchJob{ + *pyTorchJobBuilder.DeepCopy(), + }, + operation: func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error { + if isManged, _, _ := adapter.IsJobManagedByKueue(ctx, managerClient, types.NamespacedName{Name: "pytorchjob1", Namespace: TestNamespace}); isManged { + return errors.New("expecting false") + } + return nil + }, + wantManagersPyTorchJobs: []kftraining.PyTorchJob{ + *pyTorchJobBuilder.DeepCopy(), + }, + }, + "job managedBy multikueue": { + managersPyTorchJobs: []kftraining.PyTorchJob{ + *pyTorchJobManagedByKueueBuilder.DeepCopy(), + }, + operation: func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error { + if isManged, _, _ := adapter.IsJobManagedByKueue(ctx, managerClient, types.NamespacedName{Name: "pytorchjob1", Namespace: TestNamespace}); !isManged { + return errors.New("expecting true") + } + return nil + }, + wantManagersPyTorchJobs: []kftraining.PyTorchJob{ + *pyTorchJobManagedByKueueBuilder.DeepCopy(), + }, + }, } for name, tc := range cases { t.Run(name, func(t *testing.T) { diff --git a/pkg/controller/jobs/kubeflow/jobs/tfjob/tfjob_multikueue_adapter_test.go b/pkg/controller/jobs/kubeflow/jobs/tfjob/tfjob_multikueue_adapter_test.go index 3a3b90503c..ee43b784d2 100644 --- a/pkg/controller/jobs/kubeflow/jobs/tfjob/tfjob_multikueue_adapter_test.go +++ b/pkg/controller/jobs/kubeflow/jobs/tfjob/tfjob_multikueue_adapter_test.go @@ -18,6 +18,7 @@ package tfjob import ( "context" + "errors" "testing" "github.com/google/go-cmp/cmp" @@ -49,6 +50,7 @@ func TestMultikueueAdapter(t *testing.T) { } tfJobBuilder := kfutiltesting.MakeTFJob("tfjob1", TestNamespace).Queue("queue").Suspend(false) + tfJobManagedByKueueBuilder := 
tfJobBuilder.Clone().ManagedBy(kueue.MultiKueueControllerName) cases := map[string]struct { managersTFJobs []kftraining.TFJob @@ -117,6 +119,42 @@ func TestMultikueueAdapter(t *testing.T) { return adapter.DeleteRemoteObject(ctx, workerClient, types.NamespacedName{Name: "tfjob1", Namespace: TestNamespace}) }, }, + "missing job is not considered managed": { + operation: func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error { + if isManged, _, _ := adapter.IsJobManagedByKueue(ctx, managerClient, types.NamespacedName{Name: "tfjob1", Namespace: TestNamespace}); isManged { + return errors.New("expecting false") + } + return nil + }, + }, + "job with wrong managedBy is not considered managed": { + managersTFJobs: []kftraining.TFJob{ + *tfJobBuilder.DeepCopy(), + }, + operation: func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error { + if isManged, _, _ := adapter.IsJobManagedByKueue(ctx, managerClient, types.NamespacedName{Name: "tfjob1", Namespace: TestNamespace}); isManged { + return errors.New("expecting false") + } + return nil + }, + wantManagersTFJobs: []kftraining.TFJob{ + *tfJobBuilder.DeepCopy(), + }, + }, + "job managedBy multikueue": { + managersTFJobs: []kftraining.TFJob{ + *tfJobManagedByKueueBuilder.DeepCopy(), + }, + operation: func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error { + if isManged, _, _ := adapter.IsJobManagedByKueue(ctx, managerClient, types.NamespacedName{Name: "tfjob1", Namespace: TestNamespace}); !isManged { + return errors.New("expecting true") + } + return nil + }, + wantManagersTFJobs: []kftraining.TFJob{ + *tfJobManagedByKueueBuilder.DeepCopy(), + }, + }, } for name, tc := range cases { t.Run(name, func(t *testing.T) { diff --git a/pkg/controller/jobs/kubeflow/jobs/xgboostjob/xgboostjob_multikueue_adapter_test.go 
b/pkg/controller/jobs/kubeflow/jobs/xgboostjob/xgboostjob_multikueue_adapter_test.go index fb26a4241a..d212929bf3 100644 --- a/pkg/controller/jobs/kubeflow/jobs/xgboostjob/xgboostjob_multikueue_adapter_test.go +++ b/pkg/controller/jobs/kubeflow/jobs/xgboostjob/xgboostjob_multikueue_adapter_test.go @@ -18,6 +18,7 @@ package xgboostjob import ( "context" + "errors" "testing" "github.com/google/go-cmp/cmp" @@ -49,6 +50,7 @@ func TestMultikueueAdapter(t *testing.T) { } xgboostJobBuilder := kfutiltesting.MakeXGBoostJob("xgboostjob1", TestNamespace).Queue("queue").Suspend(false) + xgboostJobManagedByKueueBuilder := xgboostJobBuilder.Clone().ManagedBy(kueue.MultiKueueControllerName) cases := map[string]struct { managersXGBoostJobs []kftraining.XGBoostJob @@ -117,6 +119,42 @@ func TestMultikueueAdapter(t *testing.T) { return adapter.DeleteRemoteObject(ctx, workerClient, types.NamespacedName{Name: "xgboostjob1", Namespace: TestNamespace}) }, }, + "missing job is not considered managed": { + operation: func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error { + if isManged, _, _ := adapter.IsJobManagedByKueue(ctx, managerClient, types.NamespacedName{Name: "xgboostjob1", Namespace: TestNamespace}); isManged { + return errors.New("expecting false") + } + return nil + }, + }, + "job with wrong managedBy is not considered managed": { + managersXGBoostJobs: []kftraining.XGBoostJob{ + *xgboostJobBuilder.DeepCopy(), + }, + operation: func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error { + if isManged, _, _ := adapter.IsJobManagedByKueue(ctx, managerClient, types.NamespacedName{Name: "xgboostjob1", Namespace: TestNamespace}); isManged { + return errors.New("expecting false") + } + return nil + }, + wantManagersXGBoostJobs: []kftraining.XGBoostJob{ + *xgboostJobBuilder.DeepCopy(), + }, + }, + "job managedBy multikueue": { + managersXGBoostJobs: 
[]kftraining.XGBoostJob{ + *xgboostJobManagedByKueueBuilder.DeepCopy(), + }, + operation: func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error { + if isManged, _, _ := adapter.IsJobManagedByKueue(ctx, managerClient, types.NamespacedName{Name: "xgboostjob1", Namespace: TestNamespace}); !isManged { + return errors.New("expecting true") + } + return nil + }, + wantManagersXGBoostJobs: []kftraining.XGBoostJob{ + *xgboostJobManagedByKueueBuilder.DeepCopy(), + }, + }, } for name, tc := range cases { t.Run(name, func(t *testing.T) { diff --git a/pkg/controller/jobs/kubeflow/kubeflowjob/kubeflowjob_controller.go b/pkg/controller/jobs/kubeflow/kubeflowjob/kubeflowjob_controller.go index e804d4f673..5a6ceb1a4d 100644 --- a/pkg/controller/jobs/kubeflow/kubeflowjob/kubeflowjob_controller.go +++ b/pkg/controller/jobs/kubeflow/kubeflowjob/kubeflowjob_controller.go @@ -29,6 +29,7 @@ import ( kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1" "sigs.k8s.io/kueue/pkg/controller/jobframework" + "sigs.k8s.io/kueue/pkg/features" "sigs.k8s.io/kueue/pkg/podset" ) @@ -39,6 +40,7 @@ type KubeflowJob struct { var _ jobframework.GenericJob = (*KubeflowJob)(nil) var _ jobframework.JobWithPriorityClass = (*KubeflowJob)(nil) var _ jobframework.JobWithCustomValidation = (*KubeflowJob)(nil) +var _ jobframework.JobWithManagedBy = (*KubeflowJob)(nil) func (j *KubeflowJob) Object() client.Object { return j.KFJobControl.Object() @@ -196,3 +198,17 @@ func (j *KubeflowJob) ValidateOnUpdate(_ jobframework.GenericJob) field.ErrorLis func podsCount(replicaSpecs map[kftraining.ReplicaType]*kftraining.ReplicaSpec, replicaType kftraining.ReplicaType) int32 { return ptr.Deref(replicaSpecs[replicaType].Replicas, 1) } + +func (j *KubeflowJob) CanDefaultManagedBy() bool { + jobSpecManagedBy := j.KFJobControl.RunPolicy().ManagedBy + return features.Enabled(features.MultiKueue) && + (jobSpecManagedBy == nil || *jobSpecManagedBy == 
kftraining.KubeflowJobsController) +} + +func (j *KubeflowJob) ManagedBy() *string { + return j.KFJobControl.RunPolicy().ManagedBy +} + +func (j *KubeflowJob) SetManagedBy(managedBy *string) { + j.KFJobControl.RunPolicy().ManagedBy = managedBy +} diff --git a/pkg/controller/jobs/kubeflow/kubeflowjob/kubeflowjob_multikueue_adapter.go b/pkg/controller/jobs/kubeflow/kubeflowjob/kubeflowjob_multikueue_adapter.go index 2b8d85774d..06fbb11a03 100644 --- a/pkg/controller/jobs/kubeflow/kubeflowjob/kubeflowjob_multikueue_adapter.go +++ b/pkg/controller/jobs/kubeflow/kubeflowjob/kubeflowjob_multikueue_adapter.go @@ -25,6 +25,7 @@ import ( "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/types" "k8s.io/klog/v2" + "k8s.io/utils/ptr" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" @@ -77,7 +78,21 @@ func (a adapter[PtrT, T]) KeepAdmissionCheckPending() bool { return false } -func (a adapter[PtrT, T]) IsJobManagedByKueue(context.Context, client.Client, types.NamespacedName) (bool, string, error) { +// IsJobManagedByKueue fetches the job identified by key through c and reports whether its RunPolicy().ManagedBy +// equals kueue.MultiKueueControllerName; a fetch error is returned as-is, a mismatch yields false and a reason. +func (a adapter[PtrT, T]) IsJobManagedByKueue(ctx context.Context, c client.Client, key types.NamespacedName) (bool, string, error) { + kJobObj := PtrT(new(T)) + err := c.Get(ctx, key, kJobObj) + if err != nil { + return false, "", err + } + + kJob := a.fromObject(kJobObj) + jobControllerName := ptr.Deref(kJob.KFJobControl.RunPolicy().ManagedBy, "") + if jobControllerName != kueue.MultiKueueControllerName { + return false, fmt.Sprintf("Expecting spec.managedBy to be %q not %q", kueue.MultiKueueControllerName, jobControllerName), nil + } + return true, "", nil } @@ -127,6 +142,8 @@ func (a adapter[PtrT, T]) SyncJob( labels[kueue.MultiKueueOriginLabel] = origin remoteJob.SetLabels(labels) + // clearing managedBy lets the controller on the remote cluster take over the job + a.fromObject(remoteJob).SetManagedBy(nil) + return
remoteClient.Create(ctx, remoteJob) } diff --git a/pkg/util/testingjobs/paddlejob/wrappers.go b/pkg/util/testingjobs/paddlejob/wrappers.go index 7172b939ab..aa521d4e03 100644 --- a/pkg/util/testingjobs/paddlejob/wrappers.go +++ b/pkg/util/testingjobs/paddlejob/wrappers.go @@ -215,3 +215,9 @@ func (j *PaddleJobWrapper) StatusConditions(conditions ...kftraining.JobConditio j.Status.Conditions = conditions return j } + +// ManagedBy sets Spec.RunPolicy.ManagedBy to c and returns the wrapper. +func (j *PaddleJobWrapper) ManagedBy(c string) *PaddleJobWrapper { + j.Spec.RunPolicy.ManagedBy = &c + return j +} diff --git a/pkg/util/testingjobs/pytorchjob/wrappers_pytorchjob.go b/pkg/util/testingjobs/pytorchjob/wrappers_pytorchjob.go index 4fe11f192a..ecd5f1f5da 100644 --- a/pkg/util/testingjobs/pytorchjob/wrappers_pytorchjob.go +++ b/pkg/util/testingjobs/pytorchjob/wrappers_pytorchjob.go @@ -232,3 +232,9 @@ func (j *PyTorchJobWrapper) SetTypeMeta() *PyTorchJobWrapper { j.Kind = kftraining.PyTorchJobKind return j } + +// ManagedBy sets Spec.RunPolicy.ManagedBy to c and returns the wrapper. +func (j *PyTorchJobWrapper) ManagedBy(c string) *PyTorchJobWrapper { + j.Spec.RunPolicy.ManagedBy = &c + return j +} diff --git a/pkg/util/testingjobs/tfjob/wrappers_tfjob.go b/pkg/util/testingjobs/tfjob/wrappers_tfjob.go index 19c6838215..bac65571d5 100644 --- a/pkg/util/testingjobs/tfjob/wrappers_tfjob.go +++ b/pkg/util/testingjobs/tfjob/wrappers_tfjob.go @@ -211,3 +211,9 @@ func (j *TFJobWrapper) Image(replicaType kftraining.ReplicaType, image string, a j.Spec.TFReplicaSpecs[replicaType].Template.Spec.Containers[0].Args = args return j } + +// ManagedBy sets Spec.RunPolicy.ManagedBy to c and returns the wrapper.
+func (j *TFJobWrapper) ManagedBy(c string) *TFJobWrapper { + j.Spec.RunPolicy.ManagedBy = &c + return j +} diff --git a/pkg/util/testingjobs/xgboostjob/wrappers.go b/pkg/util/testingjobs/xgboostjob/wrappers.go index 718f1ee6ee..725edef41d 100644 --- a/pkg/util/testingjobs/xgboostjob/wrappers.go +++ b/pkg/util/testingjobs/xgboostjob/wrappers.go @@ -215,3 +215,9 @@ func (j *XGBoostJobWrapper) StatusConditions(conditions ...kftraining.JobConditi j.Status.Conditions = conditions return j } + +// ManagedBy sets Spec.RunPolicy.ManagedBy to c and returns the wrapper. +func (j *XGBoostJobWrapper) ManagedBy(c string) *XGBoostJobWrapper { + j.Spec.RunPolicy.ManagedBy = &c + return j +}