Skip to content

Commit

Permalink
Update Kubeflow MK adapters to support ManagedBy
Browse files Browse the repository at this point in the history
  • Loading branch information
mszadkow committed Jan 31, 2025
1 parent dd4c8b3 commit a0cbb9a
Show file tree
Hide file tree
Showing 10 changed files with 212 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package paddlejob

import (
"context"
"errors"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -49,6 +50,7 @@ func TestMultikueueAdapter(t *testing.T) {
}

paddleJobBuilder := kfutiltesting.MakePaddleJob("paddlejob1", TestNamespace).Queue("queue").Suspend(false)
paddleJobManagedByKueueBuilder := paddleJobBuilder.Clone().ManagedBy(kueue.MultiKueueControllerName)

cases := map[string]struct {
managersPaddleJobs []kftraining.PaddleJob
Expand Down Expand Up @@ -118,6 +120,42 @@ func TestMultikueueAdapter(t *testing.T) {
return adapter.DeleteRemoteObject(ctx, workerClient, types.NamespacedName{Name: "paddlejob1", Namespace: TestNamespace})
},
},
"missing job is not considered managed": {
operation: func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error {
if isManged, _, _ := adapter.IsJobManagedByKueue(ctx, managerClient, types.NamespacedName{Name: "paddlejob1", Namespace: TestNamespace}); isManged {
return errors.New("expecting false")
}
return nil
},
},
"job with wrong managedBy is not considered managed": {
managersPaddleJobs: []kftraining.PaddleJob{
*paddleJobBuilder.DeepCopy(),
},
operation: func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error {
if isManged, _, _ := adapter.IsJobManagedByKueue(ctx, managerClient, types.NamespacedName{Name: "paddlejob1", Namespace: TestNamespace}); isManged {
return errors.New("expecting false")
}
return nil
},
wantManagersPaddleJobs: []kftraining.PaddleJob{
*paddleJobBuilder.DeepCopy(),
},
},
"job managedBy multikueue": {
managersPaddleJobs: []kftraining.PaddleJob{
*paddleJobManagedByKueueBuilder.DeepCopy(),
},
operation: func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error {
if isManged, _, _ := adapter.IsJobManagedByKueue(ctx, managerClient, types.NamespacedName{Name: "paddlejob1", Namespace: TestNamespace}); !isManged {
return errors.New("expecting true")
}
return nil
},
wantManagersPaddleJobs: []kftraining.PaddleJob{
*paddleJobManagedByKueueBuilder.DeepCopy(),
},
},
}
for name, tc := range cases {
t.Run(name, func(t *testing.T) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package pytorchjob

import (
"context"
"errors"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -49,6 +50,7 @@ func TestMultikueueAdapter(t *testing.T) {
}

pyTorchJobBuilder := kfutiltesting.MakePyTorchJob("pytorchjob1", TestNamespace).Queue("queue").Suspend(false)
pyTorchJobManagedByKueueBuilder := pyTorchJobBuilder.Clone().ManagedBy(kueue.MultiKueueControllerName)

cases := map[string]struct {
managersPyTorchJobs []kftraining.PyTorchJob
Expand Down Expand Up @@ -117,6 +119,42 @@ func TestMultikueueAdapter(t *testing.T) {
return adapter.DeleteRemoteObject(ctx, workerClient, types.NamespacedName{Name: "pytorchjob1", Namespace: TestNamespace})
},
},
"missing job is not considered managed": {
operation: func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error {
if isManged, _, _ := adapter.IsJobManagedByKueue(ctx, managerClient, types.NamespacedName{Name: "pytorchjob1", Namespace: TestNamespace}); isManged {
return errors.New("expecting false")
}
return nil
},
},
"job with wrong managedBy is not considered managed": {
managersPyTorchJobs: []kftraining.PyTorchJob{
*pyTorchJobBuilder.DeepCopy(),
},
operation: func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error {
if isManged, _, _ := adapter.IsJobManagedByKueue(ctx, managerClient, types.NamespacedName{Name: "pytorchjob1", Namespace: TestNamespace}); isManged {
return errors.New("expecting false")
}
return nil
},
wantManagersPyTorchJobs: []kftraining.PyTorchJob{
*pyTorchJobBuilder.DeepCopy(),
},
},
"job managedBy multikueue": {
managersPyTorchJobs: []kftraining.PyTorchJob{
*pyTorchJobManagedByKueueBuilder.DeepCopy(),
},
operation: func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error {
if isManged, _, _ := adapter.IsJobManagedByKueue(ctx, managerClient, types.NamespacedName{Name: "pytorchjob1", Namespace: TestNamespace}); !isManged {
return errors.New("expecting true")
}
return nil
},
wantManagersPyTorchJobs: []kftraining.PyTorchJob{
*pyTorchJobManagedByKueueBuilder.DeepCopy(),
},
},
}
for name, tc := range cases {
t.Run(name, func(t *testing.T) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package tfjob

import (
"context"
"errors"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -49,6 +50,7 @@ func TestMultikueueAdapter(t *testing.T) {
}

tfJobBuilder := kfutiltesting.MakeTFJob("tfjob1", TestNamespace).Queue("queue").Suspend(false)
tfJobManagedByKueueBuilder := tfJobBuilder.Clone().ManagedBy(kueue.MultiKueueControllerName)

cases := map[string]struct {
managersTFJobs []kftraining.TFJob
Expand Down Expand Up @@ -117,6 +119,42 @@ func TestMultikueueAdapter(t *testing.T) {
return adapter.DeleteRemoteObject(ctx, workerClient, types.NamespacedName{Name: "tfjob1", Namespace: TestNamespace})
},
},
"missing job is not considered managed": {
operation: func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error {
if isManged, _, _ := adapter.IsJobManagedByKueue(ctx, managerClient, types.NamespacedName{Name: "tfjob1", Namespace: TestNamespace}); isManged {
return errors.New("expecting false")
}
return nil
},
},
"job with wrong managedBy is not considered managed": {
managersTFJobs: []kftraining.TFJob{
*tfJobBuilder.DeepCopy(),
},
operation: func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error {
if isManged, _, _ := adapter.IsJobManagedByKueue(ctx, managerClient, types.NamespacedName{Name: "tfjob1", Namespace: TestNamespace}); isManged {
return errors.New("expecting false")
}
return nil
},
wantManagersTFJobs: []kftraining.TFJob{
*tfJobBuilder.DeepCopy(),
},
},
"job managedBy multikueue": {
managersTFJobs: []kftraining.TFJob{
*tfJobManagedByKueueBuilder.DeepCopy(),
},
operation: func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error {
if isManged, _, _ := adapter.IsJobManagedByKueue(ctx, managerClient, types.NamespacedName{Name: "tfjob1", Namespace: TestNamespace}); !isManged {
return errors.New("expecting true")
}
return nil
},
wantManagersTFJobs: []kftraining.TFJob{
*tfJobManagedByKueueBuilder.DeepCopy(),
},
},
}
for name, tc := range cases {
t.Run(name, func(t *testing.T) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package xgboostjob

import (
"context"
"errors"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -49,6 +50,7 @@ func TestMultikueueAdapter(t *testing.T) {
}

xgboostJobBuilder := kfutiltesting.MakeXGBoostJob("xgboostjob1", TestNamespace).Queue("queue").Suspend(false)
xgboostJobManagedByKueueBuilder := xgboostJobBuilder.Clone().ManagedBy(kueue.MultiKueueControllerName)

cases := map[string]struct {
managersXGBoostJobs []kftraining.XGBoostJob
Expand Down Expand Up @@ -117,6 +119,42 @@ func TestMultikueueAdapter(t *testing.T) {
return adapter.DeleteRemoteObject(ctx, workerClient, types.NamespacedName{Name: "xgboostjob1", Namespace: TestNamespace})
},
},
"missing job is not considered managed": {
operation: func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error {
if isManged, _, _ := adapter.IsJobManagedByKueue(ctx, managerClient, types.NamespacedName{Name: "xgboostjob1", Namespace: TestNamespace}); isManged {
return errors.New("expecting false")
}
return nil
},
},
"job with wrong managedBy is not considered managed": {
managersXGBoostJobs: []kftraining.XGBoostJob{
*xgboostJobBuilder.DeepCopy(),
},
operation: func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error {
if isManged, _, _ := adapter.IsJobManagedByKueue(ctx, managerClient, types.NamespacedName{Name: "xgboostjob1", Namespace: TestNamespace}); isManged {
return errors.New("expecting false")
}
return nil
},
wantManagersXGBoostJobs: []kftraining.XGBoostJob{
*xgboostJobBuilder.DeepCopy(),
},
},
"job managedBy multikueue": {
managersXGBoostJobs: []kftraining.XGBoostJob{
*xgboostJobManagedByKueueBuilder.DeepCopy(),
},
operation: func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error {
if isManged, _, _ := adapter.IsJobManagedByKueue(ctx, managerClient, types.NamespacedName{Name: "xgboostjob1", Namespace: TestNamespace}); !isManged {
return errors.New("expecting true")
}
return nil
},
wantManagersXGBoostJobs: []kftraining.XGBoostJob{
*xgboostJobManagedByKueueBuilder.DeepCopy(),
},
},
}
for name, tc := range cases {
t.Run(name, func(t *testing.T) {
Expand Down
16 changes: 16 additions & 0 deletions pkg/controller/jobs/kubeflow/kubeflowjob/kubeflowjob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (

kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1"
"sigs.k8s.io/kueue/pkg/controller/jobframework"
"sigs.k8s.io/kueue/pkg/features"
"sigs.k8s.io/kueue/pkg/podset"
)

Expand All @@ -39,6 +40,7 @@ type KubeflowJob struct {
var _ jobframework.GenericJob = (*KubeflowJob)(nil)
var _ jobframework.JobWithPriorityClass = (*KubeflowJob)(nil)
var _ jobframework.JobWithCustomValidation = (*KubeflowJob)(nil)
var _ jobframework.JobWithManagedBy = (*KubeflowJob)(nil)

func (j *KubeflowJob) Object() client.Object {
return j.KFJobControl.Object()
Expand Down Expand Up @@ -196,3 +198,17 @@ func (j *KubeflowJob) ValidateOnUpdate(_ jobframework.GenericJob) field.ErrorLis
func podsCount(replicaSpecs map[kftraining.ReplicaType]*kftraining.ReplicaSpec, replicaType kftraining.ReplicaType) int32 {
return ptr.Deref(replicaSpecs[replicaType].Replicas, 1)
}

// CanDefaultManagedBy reports whether the job's managedBy may be defaulted:
// the MultiKueue feature gate must be enabled and the run policy's ManagedBy
// must be either unset or equal to kftraining.KubeflowJobsController.
func (j *KubeflowJob) CanDefaultManagedBy() bool {
	current := j.KFJobControl.RunPolicy().ManagedBy
	if !features.Enabled(features.MultiKueue) {
		return false
	}
	if current == nil {
		return true
	}
	return *current == kftraining.KubeflowJobsController
}

// ManagedBy returns the ManagedBy field of the job's run policy, which may be nil.
func (j *KubeflowJob) ManagedBy() *string {
	runPolicy := j.KFJobControl.RunPolicy()
	return runPolicy.ManagedBy
}

// SetManagedBy stores managedBy in the ManagedBy field of the job's run policy.
// Passing nil clears the field.
func (j *KubeflowJob) SetManagedBy(managedBy *string) {
	runPolicy := j.KFJobControl.RunPolicy()
	runPolicy.ManagedBy = managedBy
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/types"
"k8s.io/klog/v2"
"k8s.io/utils/ptr"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"

Expand Down Expand Up @@ -77,7 +78,19 @@ func (a adapter[PtrT, T]) KeepAdmissionCheckPending() bool {
return false
}

func (a adapter[PtrT, T]) IsJobManagedByKueue(context.Context, client.Client, types.NamespacedName) (bool, string, error) {
// IsJobManagedByKueue fetches the job identified by key through c and reports
// whether its run policy's ManagedBy equals kueue.MultiKueueControllerName.
//
// It returns (true, "", nil) on a match, (false, reason, nil) when the value
// differs, and (false, "", err) when the Get call fails.
func (a adapter[PtrT, T]) IsJobManagedByKueue(ctx context.Context, c client.Client, key types.NamespacedName) (bool, string, error) {
	obj := PtrT(new(T))
	if err := c.Get(ctx, key, obj); err != nil {
		return false, "", err
	}

	// A nil ManagedBy is compared as the empty string.
	managedBy := ptr.Deref(a.fromObject(obj).KFJobControl.RunPolicy().ManagedBy, "")
	if managedBy == kueue.MultiKueueControllerName {
		return true, "", nil
	}

	return false, fmt.Sprintf("Expecting spec.managedBy to be %q not %q", kueue.MultiKueueControllerName, managedBy), nil
}

Expand Down Expand Up @@ -127,6 +140,12 @@ func (a adapter[PtrT, T]) SyncJob(
labels[kueue.MultiKueueOriginLabel] = origin
remoteJob.SetLabels(labels)

fmt.Println("KACZKA1", a.fromObject(remoteJob).ManagedBy() )
// clear the managedBy enables the controller to take over
a.fromObject(remoteJob).SetManagedBy(nil)

fmt.Println("KACZKA2", a.fromObject(remoteJob).ManagedBy() )

return remoteClient.Create(ctx, remoteJob)
}

Expand Down
6 changes: 6 additions & 0 deletions pkg/util/testingjobs/paddlejob/wrappers.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,3 +215,9 @@ func (j *PaddleJobWrapper) StatusConditions(conditions ...kftraining.JobConditio
j.Status.Conditions = conditions
return j
}

// ManagedBy sets Spec.RunPolicy.ManagedBy to point at the given value and
// returns the wrapper so that calls can be chained.
func (j *PaddleJobWrapper) ManagedBy(c string) *PaddleJobWrapper {
	managedBy := &c
	j.Spec.RunPolicy.ManagedBy = managedBy
	return j
}
6 changes: 6 additions & 0 deletions pkg/util/testingjobs/pytorchjob/wrappers_pytorchjob.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,3 +232,9 @@ func (j *PyTorchJobWrapper) SetTypeMeta() *PyTorchJobWrapper {
j.Kind = kftraining.PyTorchJobKind
return j
}

// ManagedBy sets Spec.RunPolicy.ManagedBy to point at the given value and
// returns the wrapper so that calls can be chained.
func (j *PyTorchJobWrapper) ManagedBy(c string) *PyTorchJobWrapper {
	managedBy := &c
	j.Spec.RunPolicy.ManagedBy = managedBy
	return j
}
6 changes: 6 additions & 0 deletions pkg/util/testingjobs/tfjob/wrappers_tfjob.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,9 @@ func (j *TFJobWrapper) Image(replicaType kftraining.ReplicaType, image string, a
j.Spec.TFReplicaSpecs[replicaType].Template.Spec.Containers[0].Args = args
return j
}

// ManagedBy sets Spec.RunPolicy.ManagedBy to point at the given value and
// returns the wrapper so that calls can be chained.
func (j *TFJobWrapper) ManagedBy(c string) *TFJobWrapper {
	managedBy := &c
	j.Spec.RunPolicy.ManagedBy = managedBy
	return j
}
6 changes: 6 additions & 0 deletions pkg/util/testingjobs/xgboostjob/wrappers.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,3 +215,9 @@ func (j *XGBoostJobWrapper) StatusConditions(conditions ...kftraining.JobConditi
j.Status.Conditions = conditions
return j
}

// ManagedBy sets Spec.RunPolicy.ManagedBy to point at the given value and
// returns the wrapper so that calls can be chained.
func (j *XGBoostJobWrapper) ManagedBy(c string) *XGBoostJobWrapper {
	managedBy := &c
	j.Spec.RunPolicy.ManagedBy = managedBy
	return j
}

0 comments on commit a0cbb9a

Please sign in to comment.