diff --git a/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go b/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go index 38a84f9b2b..6f50cfd29e 100644 --- a/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go +++ b/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go @@ -5,6 +5,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io" "github.com/flyteorg/flyte/flytestdlib/storage" @@ -187,3 +188,11 @@ func MaybeUpdatePhaseVersionFromPluginContext(phaseInfo *pluginsCore.PhaseInfo, MaybeUpdatePhaseVersion(phaseInfo, &pluginState) return nil } + +type YunikornScheduablePlugin interface { + MutateResourceForYunikorn(ctx context.Context, object client.Object, taskTmpl *core.TaskTemplate) (client.Object, error) +} + +type KueueScheduablePlugin interface { + MutateResourceForKueue(ctx context.Context, object client.Object, taskTmpl *core.TaskTemplate) (client.Object, error) +} diff --git a/flyteplugins/go/tasks/pluginmachinery/workqueue/mocks/processor.go b/flyteplugins/go/tasks/pluginmachinery/workqueue/mocks/processor.go index ec5c0d12a6..b9d59d1b47 100644 --- a/flyteplugins/go/tasks/pluginmachinery/workqueue/mocks/processor.go +++ b/flyteplugins/go/tasks/pluginmachinery/workqueue/mocks/processor.go @@ -23,12 +23,12 @@ func (_m Processor_Process) Return(_a0 workqueue.WorkStatus, _a1 error) *Process } func (_m *Processor) OnProcess(ctx context.Context, workItem workqueue.WorkItem) *Processor_Process { - c_call := _m.On("Process", ctx, workItem) + c_call := _m.On("Mutate", ctx, workItem) return &Processor_Process{Call: c_call} } func (_m *Processor) OnProcessMatch(matchers ...interface{}) *Processor_Process { - c_call := _m.On("Process", matchers...) + c_call := _m.On("Mutate", matchers...) 
return &Processor_Process{Call: c_call} } diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config.go new file mode 100644 index 0000000000..78769a1123 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config.go @@ -0,0 +1,23 @@ +package batchscheduler + +type Config struct { + Scheduler string `json:"scheduler,omitempty" pflag:", Specify batch scheduler to"` + Default SchedulingConfig `json:"default,omitempty" pflag:", Specify default scheduling config which batch scheduler adopts"` + NameSpace map[string]SchedulingConfig `json:"Namespace,omitempty" pflag:"-, Specify namespace scheduling config"` + Domain map[string]SchedulingConfig `json:"Domain,omitempty" pflag:"-, Specify domain scheduling config"` +} + +type SchedulingConfig struct { + KueueConfig `json:"Kueue,omitempty" pflag:", Specify Kueue scheduling scheduling config"` + YunikornConfig `json:"Yunikorn,omitempty" pflag:", Yunikorn scheduling config"` +} + +type KueueConfig struct { + PriorityClassName string `json:"Priority,omitempty" pflag:", Kueue Prioty class"` + Queue string `json:"Queue,omitempty" pflag:", Specify batch scheduler to"` +} + +type YunikornConfig struct { + Parameters string `json:"parameters,omitempty" pflag:", Specify gangscheduling policy"` + Queue string `json:"queue,omitempty" pflag:", Specify leaf queue to submit to"` +} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/kueue/helper.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/kueue/helper.go new file mode 100644 index 0000000000..1889ec451d --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/kueue/helper.go @@ -0,0 +1,16 @@ +package kueue + +import ( + rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" + + "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/utils" +) + +const ( + QueueName = "kueue.x-k8s.io/queue-name" + PriorityClassName = 
"kueue.x-k8s.io/priority-class" +) + +func UpdateKueueLabels(labels map[string]string, app *rayv1.RayJob) { + utils.UpdateLabels(labels, &app.ObjectMeta) +} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/utils/helper.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/utils/helper.go new file mode 100644 index 0000000000..7133a328f3 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/utils/helper.go @@ -0,0 +1,30 @@ +package utils + +import ( + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func UpdateLabels(wanted map[string]string, objectMeta *metav1.ObjectMeta) { + for key, value := range wanted { + if _, exist := objectMeta.Labels[key]; !exist { + objectMeta.Labels[key] = value + } + } +} + +func UpdateAnnotations(wanted map[string]string, objectMeta *metav1.ObjectMeta) { + for key, value := range wanted { + if _, exist := objectMeta.Annotations[key]; !exist { + objectMeta.Annotations[key] = value + } + } +} + +func UpdatePodTemplateAnnotatations(wanted map[string]string, pod *v1.PodTemplateSpec) { + UpdateAnnotations(wanted, &pod.ObjectMeta) +} + +func UpdatePodTemplateLabels(wanted map[string]string, pod *v1.PodTemplateSpec) { + UpdateLabels(wanted, &pod.ObjectMeta) +} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go new file mode 100644 index 0000000000..2480d80eee --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go @@ -0,0 +1,133 @@ +package yunikorn + +import ( + "encoding/json" + + rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/utils" +) + +const ( + Yunikorn = "yunikorn" + AppID = "yunikorn.apache.org/app-id" + Queue = "yunikorn.apache.org/queue" 
+ TaskGroupNameKey = "yunikorn.apache.org/task-group-name" + TaskGroupsKey = "yunikorn.apache.org/task-groups" + TaskGroupParameters = "yunikorn.apache.org/schedulingPolicyParameters" +) + +func MutateRayJob(app *rayv1.RayJob) error { + appID := GenerateTaskGroupAppID() + rayjobSpec := &app.Spec + appSpec := rayjobSpec.RayClusterSpec + TaskGroups := make([]TaskGroup, 1) + for index := range appSpec.WorkerGroupSpecs { + worker := &appSpec.WorkerGroupSpecs[index] + worker.Template.Spec.SchedulerName = Yunikorn + meta := worker.Template.ObjectMeta + spec := worker.Template.Spec + name := GenerateTaskGroupName(false, index) + TaskGroups = append(TaskGroups, TaskGroup{ + Name: name, + MinMember: *worker.Replicas, + Labels: meta.Labels, + Annotations: meta.Annotations, + MinResource: Allocation(spec.Containers), + NodeSelector: spec.NodeSelector, + Affinity: spec.Affinity, + TopologySpreadConstraints: spec.TopologySpreadConstraints, + }) + meta.Annotations[TaskGroupNameKey] = name + meta.Annotations[AppID] = appID + } + headSpec := &appSpec.HeadGroupSpec + headSpec.Template.Spec.SchedulerName = Yunikorn + meta := headSpec.Template.ObjectMeta + spec := headSpec.Template.Spec + headName := GenerateTaskGroupName(true, 0) + res := Allocation(spec.Containers) + if ok := *appSpec.EnableInTreeAutoscaling; ok { + res2 := v1.ResourceList{ + v1.ResourceCPU: resource.MustParse("500m"), + v1.ResourceMemory: resource.MustParse("512Mi"), + } + res = Add(res, res2) + } + TaskGroups[0] = TaskGroup{ + Name: headName, + MinMember: 1, + Labels: meta.Labels, + Annotations: meta.Annotations, + MinResource: res, + NodeSelector: spec.NodeSelector, + Affinity: spec.Affinity, + TopologySpreadConstraints: spec.TopologySpreadConstraints, + } + meta.Annotations[TaskGroupNameKey] = headName + info, err := json.Marshal(TaskGroups) + if err != nil { + return err + } + meta.Annotations[TaskGroupsKey] = string(info[:]) + meta.Annotations[AppID] = appID + return nil +} + +func 
UpdateGangSchedulingParameters(parameters string, objectMeta *metav1.ObjectMeta) { + if len(parameters) == 0 { + return + } + utils.UpdateAnnotations( + map[string]string{TaskGroupParameters: parameters}, + objectMeta, + ) +} + +func UpdateAnnotations(labels map[string]string, app *rayv1.RayJob) { + appSpec := app.Spec.RayClusterSpec + headSpec := appSpec.HeadGroupSpec + utils.UpdatePodTemplateAnnotatations(labels, &headSpec.Template) + for index := range appSpec.WorkerGroupSpecs { + worker := appSpec.WorkerGroupSpecs[index] + utils.UpdatePodTemplateAnnotatations(labels, &worker.Template) + } +} + +func Allocation(containers []v1.Container) v1.ResourceList { + totalResources := v1.ResourceList{} + for _, c := range containers { + for name, q := range c.Resources.Limits { + if _, exists := totalResources[name]; !exists { + totalResources[name] = q.DeepCopy() + continue + } + total := totalResources[name] + total.Add(q) + totalResources[name] = total + } + } + return totalResources +} + +func Add(left v1.ResourceList, right v1.ResourceList) v1.ResourceList { + result := left + for name, value := range left { + sum := value + if value2, ok := right[name]; ok { + sum.Add(value2) + result[name] = sum + } else { + result[name] = value + } + } + for name, value := range right { + if _, ok := left[name]; !ok { + result[name] = value + } + } + return result +} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/taskgroup.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/taskgroup.go new file mode 100644 index 0000000000..5a52579ce4 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/taskgroup.go @@ -0,0 +1,24 @@ +package yunikorn + +import ( + "encoding/json" + + v1 "k8s.io/api/core/v1" +) + +type TaskGroup struct { + Name string + MinMember int32 + Labels map[string]string + Annotations map[string]string + MinResource v1.ResourceList + NodeSelector map[string]string + Tolerations []v1.Toleration + Affinity 
*v1.Affinity + TopologySpreadConstraints []v1.TopologySpreadConstraint +} + +func Marshal(taskGroups []TaskGroup) ([]byte, error) { + info, err := json.Marshal(taskGroups) + return info, err +} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/taskgroup_test.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/taskgroup_test.go new file mode 100644 index 0000000000..5472b6feb6 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/taskgroup_test.go @@ -0,0 +1,52 @@ +package yunikorn + +import ( + "testing" + + "github.com/stretchr/testify/assert" + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" +) + +func TestMarshal(t *testing.T) { + res := v1.ResourceList{ + v1.ResourceCPU: resource.MustParse("500m"), + v1.ResourceMemory: resource.MustParse("512Mi"), + } + t1 := TaskGroup{ + Name: "tg1", + MinMember: int32(1), + Labels: map[string]string{"attr": "value"}, + Annotations: map[string]string{"attr": "value"}, + MinResource: res, + NodeSelector: map[string]string{"node": "gpunode"}, + Tolerations: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + } + t2 := TaskGroup{ + Name: "tg2", + MinMember: int32(1), + Labels: map[string]string{"attr": "value"}, + Annotations: map[string]string{"attr": "value"}, + MinResource: res, + NodeSelector: map[string]string{"node": "gpunode"}, + Tolerations: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + } + var tests = []struct { + input []TaskGroup + }{ + {input: nil}, + {input: []TaskGroup{}}, + {input: []TaskGroup{t1}}, + {input: []TaskGroup{t1, t2}}, + } + t.Run("Serialize task groups", func(t *testing.T) { + for _, tt := range tests { + _, err := Marshal(tt.input) + assert.Nil(t, err) + } + }) +} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/utils.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/utils.go new file mode 100644 index 0000000000..91a5357fac --- /dev/null +++ 
const (
	// TaskGroupGenericName is the common prefix of generated task-group names.
	TaskGroupGenericName = "task-group"
)

// GenerateTaskGroupName builds the task-group name for one group of pods.
// With master set the name is "<TaskGroupGenericName>-head" and index is
// ignored; otherwise it is "<TaskGroupGenericName>-worker-<index>".
func GenerateTaskGroupName(master bool, index int) string {
	name := fmt.Sprintf("%s-%s", TaskGroupGenericName, "head")
	if !master {
		name = fmt.Sprintf("%s-%s-%d", TaskGroupGenericName, "worker", index)
	}
	return name
}
{ + t.Error("Ray app ID is empty") + } + }) +} diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/config.go b/flyteplugins/go/tasks/plugins/k8s/ray/config.go index 9a05f98f25..c4d84e2a48 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/config.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/config.go @@ -9,6 +9,7 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/logs" pluginmachinery "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" + schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler" "github.com/flyteorg/flyte/flytestdlib/config" ) @@ -23,6 +24,7 @@ var ( DashboardHost: "0.0.0.0", EnableUsageStats: false, ServiceAccount: "default", + BatchScheduler: schedulerConfig.Config{}, Defaults: DefaultConfig{ HeadNode: NodeConfig{ StartParameters: map[string]string{ @@ -76,6 +78,8 @@ type Config struct { // or 0.0.0.0 (available from all interfaces). By default, this is localhost. DashboardHost string `json:"dashboardHost,omitempty"` + BatchScheduler schedulerConfig.Config `json:"batchScheduler,omitempty"` + // DeprecatedNodeIPAddress the IP address of the head node. By default, this is pod ip address. DeprecatedNodeIPAddress string `json:"nodeIPAddress,omitempty" pflag:"-,DEPRECATED. 
Please use DefaultConfig.[HeadNode|WorkerNode].IPAddress"` diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/config_flags.go b/flyteplugins/go/tasks/plugins/k8s/ray/config_flags.go index 5048869eab..02284b8d5e 100755 --- a/flyteplugins/go/tasks/plugins/k8s/ray/config_flags.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/config_flags.go @@ -55,6 +55,11 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.String(fmt.Sprintf("%v%v", prefix, "serviceType"), defaultConfig.ServiceType, "") cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "includeDashboard"), defaultConfig.IncludeDashboard, "") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "dashboardHost"), defaultConfig.DashboardHost, "") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "batchScheduler.scheduler"), defaultConfig.BatchScheduler.Scheduler, " Specify batch scheduler to") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "batchScheduler.default.Kueue.Priority"), defaultConfig.BatchScheduler.Default.KueueConfig.PriorityClassName, " Kueue Prioty class") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "batchScheduler.default.Kueue.Queue"), defaultConfig.BatchScheduler.Default.KueueConfig.Queue, " Specify batch scheduler to") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "batchScheduler.default.Yunikorn.parameters"), defaultConfig.BatchScheduler.Default.YunikornConfig.Parameters, " Specify gangscheduling policy") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "batchScheduler.default.Yunikorn.queue"), defaultConfig.BatchScheduler.Default.YunikornConfig.Queue, " Specify leaf queue to submit to") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "remoteClusterConfig.name"), defaultConfig.RemoteClusterConfig.Name, "Friendly name of the remote cluster") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "remoteClusterConfig.endpoint"), defaultConfig.RemoteClusterConfig.Endpoint, " Remote K8s cluster endpoint") cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "remoteClusterConfig.enabled"), defaultConfig.RemoteClusterConfig.Enabled, " 
Boolean flag to enable or disable") diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/config_flags_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/config_flags_test.go index 05871adc51..73c4d6de37 100755 --- a/flyteplugins/go/tasks/plugins/k8s/ray/config_flags_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/config_flags_test.go @@ -169,6 +169,76 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_batchScheduler.scheduler", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("batchScheduler.scheduler", testValue) + if vString, err := cmdFlags.GetString("batchScheduler.scheduler"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.BatchScheduler.Scheduler) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_batchScheduler.default.Kueue.Priority", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("batchScheduler.default.Kueue.Priority", testValue) + if vString, err := cmdFlags.GetString("batchScheduler.default.Kueue.Priority"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.BatchScheduler.Default.KueueConfig.PriorityClassName) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_batchScheduler.default.Kueue.Queue", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("batchScheduler.default.Kueue.Queue", testValue) + if vString, err := cmdFlags.GetString("batchScheduler.default.Kueue.Queue"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.BatchScheduler.Default.KueueConfig.Queue) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_batchScheduler.default.Yunikorn.parameters", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("batchScheduler.default.Yunikorn.parameters", testValue) + if vString, err := 
cmdFlags.GetString("batchScheduler.default.Yunikorn.parameters"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.BatchScheduler.Default.YunikornConfig.Parameters) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_batchScheduler.default.Yunikorn.queue", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("batchScheduler.default.Yunikorn.queue", testValue) + if vString, err := cmdFlags.GetString("batchScheduler.default.Yunikorn.queue"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.BatchScheduler.Default.YunikornConfig.Queue) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) t.Run("Test_remoteClusterConfig.name", func(t *testing.T) { t.Run("Override", func(t *testing.T) { diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go index 90388b46a5..31a7e44fcd 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go @@ -28,6 +28,8 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/kueue" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn" ) const ( @@ -119,9 +121,7 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC podSpec.ServiceAccountName = cfg.ServiceAccount headPodSpec := podSpec.DeepCopy() - rayjob, err := constructRayJob(taskCtx, rayJob, objectMeta, *podSpec, headPodSpec, headNodeRayStartParams, primaryContainerIdx, *primaryContainer) - return rayjob, err } @@ -209,9 +209,9 @@ func constructRayJob(taskCtx pluginsCore.TaskExecutionContext, rayJob plugins.Ra submitterPodTemplate := 
buildSubmitterPodTemplate(headPodSpec, objectMeta, taskCtx) // TODO: This is for backward compatibility. Remove this block once runtime_env is removed from ray proto. - var err error var runtimeEnvYaml string runtimeEnvYaml = rayJob.RuntimeEnvYaml + var err error // If runtime_env exists but runtime_env_yaml does not, convert runtime_env to runtime_env_yaml if rayJob.RuntimeEnv != "" && rayJob.RuntimeEnvYaml == "" { runtimeEnvYaml, err = convertBase64RuntimeEnvToYaml(rayJob.RuntimeEnv) @@ -553,6 +553,51 @@ func getEventInfoForRayJob(logConfig logs.LogConfig, pluginContext k8s.PluginCon return &pluginsCore.TaskInfo{Logs: taskLogs}, nil } +func (plugin rayJobResourceHandler) MutateResourceForYunikorn(ctx context.Context, object client.Object, taskTmpl *core.TaskTemplate) (client.Object, error) { + rayJob := object.(*rayv1.RayJob) + // Update gang scheduling annotations + if err := yunikorn.MutateRayJob(rayJob); err != nil { + return rayJob, err + } + // Update Yunikorn annotations + cfg := GetConfig().BatchScheduler.Default.YunikornConfig + id := taskTmpl.Id + annotations := make(map[string]string, 0) + queueName := fmt.Sprintf("root.%s.%s", id.Project, id.Domain) + if len(cfg.Queue) > 0 { + if cfg.Queue == "namespace" { + queueName = fmt.Sprintf("%s.%s", queueName, rayJob.ObjectMeta.Namespace) + } else { + queueName = fmt.Sprintf("%s.%s", queueName, cfg.Queue) + } + } else { + queueName = fmt.Sprintf("%s.%s", queueName, id.ResourceType) + } + annotations[yunikorn.Queue] = queueName + annotations[yunikorn.TaskGroupParameters] = cfg.Parameters + yunikorn.UpdateAnnotations(annotations, rayJob) + return rayJob, nil +} + +func (plugin rayJobResourceHandler) MutateResourceForKueue(ctx context.Context, object client.Object, taskTmpl *core.TaskTemplate) (client.Object, error) { + rayJob := object.(*rayv1.RayJob) + cfg := GetConfig().BatchScheduler.Default.KueueConfig + id := taskTmpl.Id + queueName := fmt.Sprintf("%s.%s", id.Project, id.Domain) + if len(cfg.Queue) > 0 { + 
if cfg.Queue == "namespace" { + queueName = fmt.Sprintf("%s.%s", queueName, rayJob.ObjectMeta.Namespace) + } else { + queueName = fmt.Sprintf("%s.%s", queueName, cfg.Queue) + } + } else { + queueName = fmt.Sprintf("%s.%s", queueName, id.ResourceType) + } + rayJob.ObjectMeta.Labels[kueue.QueueName] = queueName + rayJob.ObjectMeta.Labels[kueue.PriorityClassName] = cfg.PriorityClassName + return object, nil +} + func (plugin rayJobResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) { rayJob := resource.(*rayv1.RayJob) info, err := getEventInfoForRayJob(GetConfig().Logs, pluginContext, rayJob) diff --git a/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager.go b/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager.go index f9c3806ee6..1c84eb7b38 100644 --- a/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager.go +++ b/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager.go @@ -197,6 +197,13 @@ func (e *PluginManager) launchResource(ctx context.Context, tCtx pluginsCore.Tas if err != nil { return pluginsCore.UnknownTransition, err } + if p, ok := e.plugin.(k8s.YunikornScheduablePlugin); ok { + o, err = p.MutateResourceForYunikorn(ctx, o, tmpl) + if err != nil { + return pluginsCore.UnknownTransition, err + } + tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() + } e.addObjectMetadata(k8sTaskCtxMetadata, o, config.GetK8sPluginConfig()) logger.Infof(ctx, "Creating Object: Type:[%v], Object:[%v/%v]", o.GetObjectKind().GroupVersionKind(), o.GetNamespace(), o.GetName())