diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/config/config.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/config/config.go index f1fae18c22..109ef06ba1 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/config/config.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/config/config.go @@ -6,7 +6,6 @@ package config import ( - schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler" "time" v1 "k8s.io/api/core/v1" @@ -207,8 +206,6 @@ type K8sPluginConfig struct { // SendObjectEvents indicates whether to send k8s object events in TaskExecutionEvent updates (similar to kubectl get events). SendObjectEvents bool `json:"send-object-events" pflag:",If true, will send k8s object events in TaskExecutionEvent updates."` - - BatchScheduler schedulerConfig.Config `json:"batchScheduler,omitempty"` } // FlyteCoPilotConfig specifies configuration for the Flyte CoPilot system. FlyteCoPilot, allows running flytekit-less containers diff --git a/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go b/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go index 32d0a45a16..6f50cfd29e 100644 --- a/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go +++ b/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go @@ -2,8 +2,10 @@ package k8s import ( "context" + "sigs.k8s.io/controller-runtime/pkg/client" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io" "github.com/flyteorg/flyte/flytestdlib/storage" @@ -186,3 +188,11 @@ func MaybeUpdatePhaseVersionFromPluginContext(phaseInfo *pluginsCore.PhaseInfo, MaybeUpdatePhaseVersion(phaseInfo, &pluginState) return nil } + +type YunikornScheduablePlugin interface { + MutateResourceForYunikorn(ctx context.Context, object client.Object, taskTmpl *core.TaskTemplate) (client.Object, error) +} + +type KueueScheduablePlugin interface { + MutateResourceForKueue(ctx context.Context, object client.Object, taskTmpl *core.TaskTemplate) (client.Object, error) +} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config.go index e75627b7a5..78769a1123 100644 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config.go +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config.go @@ -1,21 +1,23 @@ package batchscheduler type Config struct { - Scheduler string `json:"scheduler,omitempty" pflag:", Specify batch scheduler to"` - Parameters string `json:"parameters,omitempty" pflag:", Specify static parameters"` + Scheduler string `json:"scheduler,omitempty" pflag:", Specify batch scheduler to"` + Default SchedulingConfig `json:"default,omitempty" pflag:", Specify default scheduling config which batch scheduler adopts"` + NameSpace map[string]SchedulingConfig `json:"Namespace,omitempty" pflag:"-, Specify namespace scheduling config"` + Domain map[string]SchedulingConfig `json:"Domain,omitempty" pflag:"-, Specify domain scheduling config"` } -func NewConfig() Config { - return Config{ - Scheduler: "", - Parameters: "", - } +type SchedulingConfig struct { + KueueConfig `json:"Kueue,omitempty" pflag:", Specify Kueue scheduling scheduling config"` + YunikornConfig `json:"Yunikorn,omitempty" pflag:", Yunikorn scheduling config"` } -func (b *Config) GetScheduler() string { - return b.Scheduler +type KueueConfig struct { + PriorityClassName string `json:"Priority,omitempty" pflag:", Kueue Prioty class"` + Queue string `json:"Queue,omitempty" pflag:", Specify batch scheduler to"` } -func (b *Config) GetParameters() string { - return b.Parameters +type YunikornConfig struct { + Parameters string `json:"parameters,omitempty" pflag:", Specify gangscheduling policy"` + Queue string `json:"queue,omitempty" pflag:", Specify leaf queue to submit to"` } diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config_test.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config_test.go deleted file mode 100644 index e234008b53..0000000000 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config_test.go +++ /dev/null @@ -1,15 +0,0 @@ -package batchscheduler - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestNewConfig(t *testing.T) { - t.Run("New scheduler plugin config", func(t *testing.T) { - config := NewConfig() - assert.Equal(t, "", config.GetScheduler()) - assert.Equal(t, "", config.GetParameters()) - }) -} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/kueue/helper.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/kueue/helper.go new file mode 100644 index 0000000000..1889ec451d --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/kueue/helper.go @@ -0,0 +1,16 @@ +package kueue + +import ( + rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" + + "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/utils" +) + +const ( + QueueName = "kueue.x-k8s.io/queue-name" + PriorityClassName = "kueue.x-k8s.io/priority-class" +) + +func UpdateKueueLabels(labels map[string]string, app *rayv1.RayJob) { + utils.UpdateLabels(labels, &app.ObjectMeta) +} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/plugins.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/plugins.go deleted file mode 100644 index 944d722787..0000000000 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/plugins.go +++ /dev/null @@ -1,23 +0,0 @@ -package batchscheduler - -import ( - "context" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn" - "sigs.k8s.io/controller-runtime/pkg/client" -) - -type SchedulerManager interface { - // Mutate is responsible for mutating the object to be scheduled. - // It will add the necessary annotations, labels, etc. to the object. - Mutate(ctx context.Context, object client.Object) error -} - -func NewSchedulerManager(cfg *Config) SchedulerManager { - switch cfg.GetScheduler() { - case yunikorn.Yunikorn: - return yunikorn.NewSchedulerManager(cfg) - default: - return scheduler.NewNoopSchedulerManager() - } -} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/noop_scheduler.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/noop_scheduler.go deleted file mode 100644 index 8f9ca31146..0000000000 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/noop_scheduler.go +++ /dev/null @@ -1,16 +0,0 @@ -package scheduler - -import ( - "context" - "sigs.k8s.io/controller-runtime/pkg/client" -) - -type NoopSchedulerManager struct{} - -func NewNoopSchedulerManager() *NoopSchedulerManager { - return &NoopSchedulerManager{} -} - -func (p *NoopSchedulerManager) Mutate(ctx context.Context, object client.Object) error { - return nil -} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn.go deleted file mode 100644 index 6abfbbb434..0000000000 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn.go +++ /dev/null @@ -1,34 +0,0 @@ -package yunikorn - -import ( - "context" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler" - rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" - "sigs.k8s.io/controller-runtime/pkg/client" -) - -const ( - Yunikorn = "yunikorn" - AppID = "yunikorn.apache.org/app-id" - TaskGroupNameKey = "yunikorn.apache.org/task-group-name" - TaskGroupsKey = "yunikorn.apache.org/task-groups" - TaskGroupParameters = "yunikorn.apache.org/schedulingPolicyParameters" -) - -type YunikornSchedulerManager struct { - parameters string -} - -func (y *YunikornSchedulerManager) Mutate(ctx context.Context, object client.Object) error { - switch object.(type) { - case *rayv1.RayJob: - return MutateRayJob(y.parameters, object.(*rayv1.RayJob)) - default: - - } - return nil -} - -func NewSchedulerManager(cfg *batchscheduler.Config) batchscheduler.SchedulerManager { - return &YunikornSchedulerManager{parameters: cfg.Parameters} -} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn_test.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn_test.go deleted file mode 100644 index 08fe5a1761..0000000000 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn_test.go +++ /dev/null @@ -1,46 +0,0 @@ -package yunikorn - -import ( - "gotest.tools/assert" - "testing" -) - -func TestNewPlugin(t *testing.T) { - tests := []struct { - input string - expect *Plugin - }{ - { - input: "", - expect: &Plugin{ - Parameters: "", - }, - }, - { - input: "placeholderTimeoutInSeconds=30 gangSchedulingStyle=Hard", - expect: &Plugin{ - Parameters: "placeholderTimeoutInSeconds=30 gangSchedulingStyle=Hard", - }, - }, - } - t.Run("New Yunikorn plugin", func(t *testing.T) { - got := NewSchedulerManager(t.input) - assert.NotNil(t, got) - assert.Equal(t, t.input, got.Parameters) - }) -} - -func TestProcess(t *testing.T) { - tests := []struct { - input interface{} - expect error - }{ - {input: 1, expect: nil}, - {input: "test", expect: nil}, - } - t.Run("Yunikorn plugin process any type", func(t *testing.T) { - got := NewSchedulerManager(t.input) - assert.NotNil(t, got) - assert.Equal(t, t.input, got.Parameters) - }) -} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/utils/helper.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/utils/helper.go new file mode 100644 index 0000000000..7133a328f3 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/utils/helper.go @@ -0,0 +1,30 @@ +package utils + +import ( + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func UpdateLabels(wanted map[string]string, objectMeta *metav1.ObjectMeta) { + for key, value := range wanted { + if _, exist := objectMeta.Labels[key]; !exist { + objectMeta.Labels[key] = value + } + } +} + +func UpdateAnnotations(wanted map[string]string, objectMeta *metav1.ObjectMeta) { + for key, value := range wanted { + if _, exist := objectMeta.Annotations[key]; !exist { + objectMeta.Annotations[key] = value + } + } +} + +func UpdatePodTemplateAnnotatations(wanted map[string]string, pod *v1.PodTemplateSpec) { + UpdateAnnotations(wanted, &pod.ObjectMeta) +} + +func UpdatePodTemplateLabels(wanted map[string]string, pod *v1.PodTemplateSpec) { + UpdateLabels(wanted, &pod.ObjectMeta) +} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/rayhandler.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go similarity index 65% rename from flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/rayhandler.go rename to flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go index 0884a68df1..2480d80eee 100644 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/rayhandler.go +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go @@ -2,12 +2,25 @@ package yunikorn import ( "encoding/json" + rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/utils" +) + +const ( + Yunikorn = "yunikorn" + AppID = "yunikorn.apache.org/app-id" + Queue = "yunikorn.apache.org/queue" + TaskGroupNameKey = "yunikorn.apache.org/task-group-name" + TaskGroupsKey = "yunikorn.apache.org/task-groups" + TaskGroupParameters = "yunikorn.apache.org/schedulingPolicyParameters" ) -func MutateRayJob(parameters string, app *rayv1.RayJob) error { +func MutateRayJob(app *rayv1.RayJob) error { appID := GenerateTaskGroupAppID() rayjobSpec := &app.Spec appSpec := rayjobSpec.RayClusterSpec @@ -60,11 +73,30 @@ func MutateRayJob(parameters string, app *rayv1.RayJob) error { return err } meta.Annotations[TaskGroupsKey] = string(info[:]) - meta.Annotations[TaskGroupParameters] = parameters meta.Annotations[AppID] = appID return nil } +func UpdateGangSchedulingParameters(parameters string, objectMeta *metav1.ObjectMeta) { + if len(parameters) == 0 { + return + } + utils.UpdateAnnotations( + map[string]string{TaskGroupParameters: parameters}, + objectMeta, + ) +} + +func UpdateAnnotations(labels map[string]string, app *rayv1.RayJob) { + appSpec := app.Spec.RayClusterSpec + headSpec := appSpec.HeadGroupSpec + utils.UpdatePodTemplateAnnotatations(labels, &headSpec.Template) + for index := range appSpec.WorkerGroupSpecs { + worker := appSpec.WorkerGroupSpecs[index] + utils.UpdatePodTemplateAnnotatations(labels, &worker.Template) + } +} + func Allocation(containers []v1.Container) v1.ResourceList { totalResources := v1.ResourceList{} for _, c := range containers { @@ -81,19 +113,19 @@ func Allocation(containers []v1.Container) v1.ResourceList { return totalResources } -func Add(a v1.ResourceList, b v1.ResourceList) v1.ResourceList { - result := a - for name, value := range a { - sum := &value - if value2, ok := b[name]; ok { +func Add(left v1.ResourceList, right v1.ResourceList) v1.ResourceList { + result := left + for name, value := range left { + sum := value + if value2, ok := right[name]; ok { sum.Add(value2) - result[name] = *sum + result[name] = sum } else { result[name] = value } } - for name, value := range b { - if _, ok := a[name]; !ok { + for name, value := range right { + if _, ok := left[name]; !ok { result[name] = value } } diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/taskgroup.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/taskgroup.go similarity index 100% rename from flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/taskgroup.go rename to flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/taskgroup.go diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/taskgroup_test.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/taskgroup_test.go similarity index 86% rename from flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/taskgroup_test.go rename to flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/taskgroup_test.go index ce5bfe8d88..5472b6feb6 100644 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/taskgroup_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/taskgroup_test.go @@ -4,9 +4,15 @@ import ( "testing" "github.com/stretchr/testify/assert" + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" ) func TestMarshal(t *testing.T) { + res := v1.ResourceList{ + v1.ResourceCPU: resource.MustParse("500m"), + v1.ResourceMemory: resource.MustParse("512Mi"), + } t1 := TaskGroup{ Name: "tg1", MinMember: int32(1), @@ -43,4 +49,4 @@ func TestMarshal(t *testing.T) { assert.Nil(t, err) } }) -} \ No newline at end of file +} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/utils.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/utils.go similarity index 100% rename from flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/utils.go rename to flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/utils.go diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/utils_test.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/utils_test.go similarity index 89% rename from flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/utils_test.go rename to flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/utils_test.go index 14aba4ddbc..ca0502baaf 100644 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/utils_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/utils_test.go @@ -43,11 +43,9 @@ func TestGenerateTaskGroupName(t *testing.T) { func TestGenerateTaskGroupAppID(t *testing.T) { t.Run("Generate ray app ID", func(t *testing.T) { - for _, tt := range tests { - got := GenerateTaskGroupAppID() - if len(got) <= 0 { - t.Error("Ray app ID is empty") - } + got := GenerateTaskGroupAppID() + if len(got) <= 0 { + t.Error("Ray app ID is empty") } }) -} \ No newline at end of file +} diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/config.go b/flyteplugins/go/tasks/plugins/k8s/ray/config.go index 8be030413b..c4d84e2a48 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/config.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/config.go @@ -2,7 +2,6 @@ package ray import ( "context" - schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler" v1 "k8s.io/api/core/v1" @@ -10,6 +9,7 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/logs" pluginmachinery "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" + schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler" "github.com/flyteorg/flyte/flytestdlib/config" ) diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/config_flags.go b/flyteplugins/go/tasks/plugins/k8s/ray/config_flags.go index a725e93c5a..02284b8d5e 100755 --- a/flyteplugins/go/tasks/plugins/k8s/ray/config_flags.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/config_flags.go @@ -55,8 +55,11 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.String(fmt.Sprintf("%v%v", prefix, "serviceType"), defaultConfig.ServiceType, "") cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "includeDashboard"), defaultConfig.IncludeDashboard, "") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "dashboardHost"), defaultConfig.DashboardHost, "") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "batchScheduler.scheduler"), defaultConfig.BatchScheduler.Scheduler, "") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "batchScheduler.parameters"), defaultConfig.BatchScheduler.Parameters, "") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "batchScheduler.scheduler"), defaultConfig.BatchScheduler.Scheduler, " Specify batch scheduler to") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "batchScheduler.default.Kueue.Priority"), defaultConfig.BatchScheduler.Default.KueueConfig.PriorityClassName, " Kueue Prioty class") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "batchScheduler.default.Kueue.Queue"), defaultConfig.BatchScheduler.Default.KueueConfig.Queue, " Specify batch scheduler to") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "batchScheduler.default.Yunikorn.parameters"), defaultConfig.BatchScheduler.Default.YunikornConfig.Parameters, " Specify gangscheduling policy") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "batchScheduler.default.Yunikorn.queue"), defaultConfig.BatchScheduler.Default.YunikornConfig.Queue, " Specify leaf queue to submit to") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "remoteClusterConfig.name"), defaultConfig.RemoteClusterConfig.Name, "Friendly name of the remote cluster") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "remoteClusterConfig.endpoint"), defaultConfig.RemoteClusterConfig.Endpoint, " Remote K8s cluster endpoint") cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "remoteClusterConfig.enabled"), defaultConfig.RemoteClusterConfig.Enabled, " Boolean flag to enable or disable") diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/config_flags_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/config_flags_test.go index 960752c0a5..73c4d6de37 100755 --- a/flyteplugins/go/tasks/plugins/k8s/ray/config_flags_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/config_flags_test.go @@ -183,14 +183,56 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) - t.Run("Test_batchScheduler.parameters", func(t *testing.T) { + t.Run("Test_batchScheduler.default.Kueue.Priority", func(t *testing.T) { t.Run("Override", func(t *testing.T) { testValue := "1" - cmdFlags.Set("batchScheduler.parameters", testValue) - if vString, err := cmdFlags.GetString("batchScheduler.parameters"); err == nil { - testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.BatchScheduler.Parameters) + cmdFlags.Set("batchScheduler.default.Kueue.Priority", testValue) + if vString, err := cmdFlags.GetString("batchScheduler.default.Kueue.Priority"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.BatchScheduler.Default.KueueConfig.PriorityClassName) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_batchScheduler.default.Kueue.Queue", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("batchScheduler.default.Kueue.Queue", testValue) + if vString, err := cmdFlags.GetString("batchScheduler.default.Kueue.Queue"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.BatchScheduler.Default.KueueConfig.Queue) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_batchScheduler.default.Yunikorn.parameters", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("batchScheduler.default.Yunikorn.parameters", testValue) + if vString, err := cmdFlags.GetString("batchScheduler.default.Yunikorn.parameters"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.BatchScheduler.Default.YunikornConfig.Parameters) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_batchScheduler.default.Yunikorn.queue", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("batchScheduler.default.Yunikorn.queue", testValue) + if vString, err := cmdFlags.GetString("batchScheduler.default.Yunikorn.queue"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.BatchScheduler.Default.YunikornConfig.Queue) } else { assert.FailNow(t, err.Error()) diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go index 178d2427ad..31a7e44fcd 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go @@ -28,6 +28,8 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/kueue" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn" ) const ( @@ -551,6 +553,51 @@ func getEventInfoForRayJob(logConfig logs.LogConfig, pluginContext k8s.PluginCon return &pluginsCore.TaskInfo{Logs: taskLogs}, nil } +func (plugin rayJobResourceHandler) MutateResourceForYunikorn(ctx context.Context, object client.Object, taskTmpl *core.TaskTemplate) (client.Object, error) { + rayJob := object.(*rayv1.RayJob) + // Update gang scheduling annotations + if err := yunikorn.MutateRayJob(rayJob); err != nil { + return rayJob, err + } + // Update Yunikorn annotations + cfg := GetConfig().BatchScheduler.Default.YunikornConfig + id := taskTmpl.Id + annotations := make(map[string]string, 0) + queueName := fmt.Sprintf("root.%s.%s", id.Project, id.Domain) + if len(cfg.Queue) > 0 { + if cfg.Queue == "namespace" { + queueName = fmt.Sprintf("%s.%s", queueName, rayJob.ObjectMeta.Namespace) + } else { + queueName = fmt.Sprintf("%s.%s", queueName, cfg.Queue) + } + } else { + queueName = fmt.Sprintf("%s.%s", queueName, id.ResourceType) + } + annotations[yunikorn.Queue] = queueName + annotations[yunikorn.TaskGroupParameters] = cfg.Parameters + yunikorn.UpdateAnnotations(annotations, rayJob) + return rayJob, nil +} + +func (plugin rayJobResourceHandler) MutateResourceForKueue(ctx context.Context, object client.Object, taskTmpl *core.TaskTemplate) (client.Object, error) { + rayJob := object.(*rayv1.RayJob) + cfg := GetConfig().BatchScheduler.Default.KueueConfig + id := taskTmpl.Id + queueName := fmt.Sprintf("%s.%s", id.Project, id.Domain) + if len(cfg.Queue) > 0 { + if cfg.Queue == "namespace" { + queueName = fmt.Sprintf("%s.%s", queueName, rayJob.ObjectMeta.Namespace) + } else { + queueName = fmt.Sprintf("%s.%s", queueName, cfg.Queue) + } + } else { + queueName = fmt.Sprintf("%s.%s", queueName, id.ResourceType) + } + rayJob.ObjectMeta.Labels[kueue.QueueName] = queueName + rayJob.ObjectMeta.Labels[kueue.PriorityClassName] = cfg.PriorityClassName + return object, nil +} + func (plugin rayJobResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) { rayJob := resource.(*rayv1.RayJob) info, err := getEventInfoForRayJob(GetConfig().Logs, pluginContext, rayJob) diff --git a/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager.go b/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager.go index 6ea6bcede1..1c84eb7b38 100644 --- a/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager.go +++ b/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager.go @@ -3,7 +3,6 @@ package k8s import ( "context" "fmt" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler" "time" "golang.org/x/time/rate" @@ -91,7 +90,6 @@ type PluginManager struct { plugin k8s.Plugin resourceToWatch runtime.Object kubeClient pluginsCore.KubeClient - schedulerMgr batchscheduler.SchedulerManager metrics PluginMetrics // Per namespace-resource backOffController *backoff.Controller @@ -199,18 +197,19 @@ func (e *PluginManager) launchResource(ctx context.Context, tCtx pluginsCore.Tas if err != nil { return pluginsCore.UnknownTransition, err } + if p, ok := e.plugin.(k8s.YunikornScheduablePlugin); ok { + o, err = p.MutateResourceForYunikorn(ctx, o, tmpl) + if err != nil { + return pluginsCore.UnknownTransition, err + } + tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() + } e.addObjectMetadata(k8sTaskCtxMetadata, o, config.GetK8sPluginConfig()) logger.Infof(ctx, "Creating Object: Type:[%v], Object:[%v/%v]", o.GetObjectKind().GroupVersionKind(), o.GetNamespace(), o.GetName()) key := backoff.ComposeResourceKey(o) - err = e.schedulerMgr.Mutate(ctx, o) - if err != nil { - logger.Errorf(ctx, "Scheduler plugin failed to process object with error: %v", err) - return pluginsCore.Transition{}, err - } - pod, casted := o.(*v1.Pod) if e.backOffController != nil && casted { podRequestedResources := e.getPodEffectiveResourceLimits(ctx, pod) @@ -544,8 +543,6 @@ func NewPluginManager(ctx context.Context, iCtx pluginsCore.SetupContext, entry return nil, errors.Errorf(errors.PluginInitializationFailed, "Failed to initialize K8sResource Plugin, Kubeclient cannot be nil!") } - schedulerMgr := batchscheduler.NewSchedulerManager(&config.GetK8sPluginConfig().BatchScheduler) - logger.Infof(ctx, "Initializing K8s plugin [%s]", entry.ID) src := source.Kind(iCtx.KubeClient().GetCache(), entry.ResourceToWatch) @@ -660,7 +657,6 @@ func NewPluginManager(ctx context.Context, iCtx pluginsCore.SetupContext, entry resourceToWatch: entry.ResourceToWatch, metrics: newPluginMetrics(metricsScope), kubeClient: kubeClient, - schedulerMgr: schedulerMgr, resourceLevelMonitor: rm, eventWatcher: eventWatcher, }, nil