From 516978ee0ec50dd728c6146fa05ad28c482ab5c3 Mon Sep 17 00:00:00 2001 From: yuteng Date: Mon, 22 Jul 2024 22:02:07 +0800 Subject: [PATCH 01/30] gangscheduling annotations Signed-off-by: yuteng --- flyteidl/go.sum | 4 +- .../plugins/k8s/ray/batchscheduler/config.go | 6 + .../k8s/ray/batchscheduler/yunikorn.go | 131 ++++++++++++++++++ .../go/tasks/plugins/k8s/ray/config.go | 7 + flyteplugins/go/tasks/plugins/k8s/ray/ray.go | 23 ++- 5 files changed, 167 insertions(+), 4 deletions(-) create mode 100644 flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config.go create mode 100644 flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go diff --git a/flyteidl/go.sum b/flyteidl/go.sum index 5d5cb7e9a2..1819269c1c 100644 --- a/flyteidl/go.sum +++ b/flyteidl/go.sum @@ -214,8 +214,8 @@ github.com/prometheus/common v0.44.0 h1:+5BrQJwiBB9xsMygAB3TNvpQKOwlkc25LbISbrdO github.com/prometheus/common v0.44.0/go.mod h1:ofAIvZbQ1e/nugmZGz4/qCb9Ap1VoSTIO7x0VV9VvuY= github.com/prometheus/procfs v0.10.1 h1:kYK1Va/YMlutzCGazswoHKo//tZVlFpKYh+PymziUAg= github.com/prometheus/procfs v0.10.1/go.mod h1:nwNm2aOCAYw8uTR/9bWRREkZFxAUcWzPHWJq+XBB/FM= -github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= -github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= +github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config.go new file mode 100644 index 0000000000..fb9bef7652 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config.go @@ -0,0 +1,6 @@ +package batchscheduler + +type BatchSchedulerConfig struct { + Scheduler string `json:"scheduler"` + Parameters string `json:"parameters,omitempty` +} diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go new file mode 100644 index 0000000000..1feb54ff20 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go @@ -0,0 +1,131 @@ +package batchscheduler + +import ( + "encoding/json" + "fmt" + + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +const ( + // Pod lebel + BatchSchedulerLabel = "batch-scheduler" + SchedulerLabel = "scheduler" + SchedulerName = "yunikorn" + TaskGroupNameKey = "yunikorn.apache.org/task-group-name" + TaskGroupsKey = "yunikorn.apache.org/task-groups" + TaskGroupPrarameters = "yunikorn.apache.org/schedulingPolicyParameters" +) + +type TaskGroup struct { + Name string + MinMember int32 + Labels map[string]string + Annotations map[string]string + MinResource v1.ResourceList + NodeSelector map[string]string + Tolerations []v1.Toleration + Affinity *v1.Affinity + TopologySpreadConstraints []v1.TopologySpreadConstraint +} + +func GenerateTaskGroupName(master bool, index int) string { + tgName := "task-group" + if master { + return fmt.Sprintf("%s-%s", tgName, "head") + } + return fmt.Sprintf("%s-%s-%d", tgName, "worker", index) +} + +func SetSchedulerNameAndBuildGangInfo(config BatchSchedulerConfig, metadata *metav1.ObjectMeta, workerGroupsSpec []*plugins.WorkerGroupSpec, head, worker *v1.PodSpec) (map[string]map[string]string, error) { + if config.Scheduler != SchedulerName { + return nil, nil + } + head.SchedulerName = SchedulerName + worker.SchedulerName = SchedulerName + + TaskGroupsAnnotations := make(map[string]map[string]string, 0) + // Parsing placeholders from the pod resource among head and workers + TaskGroups := make([]TaskGroup, 0) + headName := GenerateTaskGroupNameFromMaster(metadata.Name, true, 0) + TaskGroups = append(TaskGroups, TaskGroup{ + Name: headName, + MinMember: 1, + Labels: metadata.Labels, + Annotations: metadata.Annotations, + MinResource: head.Containers[0].Resources.Requests, + NodeSelector: head.NodeSelector, + Affinity: head.Affinity, + TopologySpreadConstraints: head.TopologySpreadConstraints, + }) + + for index, spec := range workerGroupsSpec { + name := GenerateTaskGroupNameFromMaster(metadata.Name, false, index) + tg := TaskGroup{ + Name: name, + MinMember: spec.Replicas, + Labels: metadata.Labels, + Annotations: metadata.Annotations, + MinResource: worker.Containers[0].Resources.Requests, + NodeSelector: worker.NodeSelector, + Affinity: worker.Affinity, + TopologySpreadConstraints: worker.TopologySpreadConstraints, + } + TaskGroupsAnnotations[name] = map[string]string{ + TaskGroupNameKey: name, + } + TaskGroups = append(TaskGroups, tg) + } + + // Yunikorn head gang scheduling annotations + info, _ := json.Marshal(TaskGroups) + if err != nil { + return nil, err + } + headAnnotations := make(map[string]string, 0) + headAnnotations[TaskGroupNameKey] = headName + headAnnotations[TaskGroupsKey] = string(info[:]) + headAnnotations[TaskGroupPrarameters] = config.Parameters + TaskGroupsAnnotations[headName] = headAnnotations + return TaskGroupsAnnotations, nil +} + +func AddGangSchedulingAnnotations(name string, metadata *metav1.ObjectMeta, TGAnnotations map[string]map[string]string) { + if TGAnnotations == nil { + return + } + + if _, ok := TGAnnotations[name]; !ok { + return + } + + annotations := TGAnnotations[name] + if _, ok := metadata.Annotations[TaskGroupNameKey]; !ok { + metadata.Annotations[TaskGroupNameKey] = annotations[TaskGroupNameKey] + } + if _, ok := metadata.Annotations[TaskGroupsKey]; !ok { + metadata.Annotations[TaskGroupsKey] = annotations[TaskGroupsKey] + } + if _, ok := metadata.Annotations[TaskGroupPrarameters]; !ok { + if _, ok = annotations[TaskGroupPrarameters]; !ok { + return + } + metadata.Annotations[TaskGroupPrarameters] = annotations[TaskGroupPrarameters] + } + return +} + +func RemoveGangSchedulingAnnotations(metadata *metav1.ObjectMeta) { + if _, ok := metadata.Annotations[TaskGroupNameKey]; ok { + delete(metadata.Annotations, TaskGroupNameKey) + } + if _, ok := metadata.Annotations[TaskGroupsKey]; ok { + delete(metadata.Annotations, TaskGroupsKey) + } + if _, ok := metadata.Annotations[TaskGroupPrarameters]; ok { + delete(metadata.Annotations, TaskGroupPrarameters) + } + return +} diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/config.go b/flyteplugins/go/tasks/plugins/k8s/ray/config.go index 9a05f98f25..f39782421f 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/config.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/config.go @@ -9,6 +9,7 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/logs" pluginmachinery "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler" "github.com/flyteorg/flyte/flytestdlib/config" ) @@ -23,6 +24,10 @@ var ( DashboardHost: "0.0.0.0", EnableUsageStats: false, ServiceAccount: "default", + BatchScheduler: batchscheduler.BatchSchedulerConfig{ + Scheduler: "yunikorn", + Parameters: "timeout=10", + }, Defaults: DefaultConfig{ HeadNode: NodeConfig{ StartParameters: map[string]string{ @@ -76,6 +81,8 @@ type Config struct { // or 0.0.0.0 (available from all interfaces). By default, this is localhost. DashboardHost string `json:"dashboardHost,omitempty"` + BatchScheduler batchscheduler.BatchSchedulerConfig `json:"batchSchedulerConfig,omitempty"` + // DeprecatedNodeIPAddress the IP address of the head node. By default, this is pod ip address. DeprecatedNodeIPAddress string `json:"nodeIPAddress,omitempty" pflag:"-,DEPRECATED. Please use DefaultConfig.[HeadNode|WorkerNode].IPAddress"` diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go index 90388b46a5..185d55316d 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go @@ -28,6 +28,7 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler" ) const ( @@ -119,7 +120,6 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC podSpec.ServiceAccountName = cfg.ServiceAccount headPodSpec := podSpec.DeepCopy() - rayjob, err := constructRayJob(taskCtx, rayJob, objectMeta, *podSpec, headPodSpec, headNodeRayStartParams, primaryContainerIdx, *primaryContainer) return rayjob, err @@ -128,6 +128,18 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC func constructRayJob(taskCtx pluginsCore.TaskExecutionContext, rayJob plugins.RayJob, objectMeta *metav1.ObjectMeta, podSpec v1.PodSpec, headPodSpec *v1.PodSpec, headNodeRayStartParams map[string]string, primaryContainerIdx int, primaryContainer v1.Container) (*rayv1.RayJob, error) { enableIngress := true cfg := GetConfig() + TGAnnotations := batchscheduler.SetSchedulerNameAndBuildGangInfo( + cfg.BatchScheduler, + objectMeta, + rayJob.RayCluster.WorkerGroupSpec, + &podSpec, + headPodSpec, + ) + batchscheduler.AddGangSchedulingAnnotations( + batchscheduler.GenerateTaskGroupName(true, 0), + objectMeta, + TGAnnotations, + ) rayClusterSpec := rayv1.RayClusterSpec{ HeadGroupSpec: rayv1.HeadGroupSpec{ Template: buildHeadPodTemplate( @@ -143,9 +155,15 @@ func constructRayJob(taskCtx pluginsCore.TaskExecutionContext, rayJob plugins.Ra WorkerGroupSpecs: []rayv1.WorkerGroupSpec{}, EnableInTreeAutoscaling: &rayJob.RayCluster.EnableAutoscaling, } + batchscheduler.RemoveGangSchedulingAnnotations(objectMeta) - for _, spec := range rayJob.RayCluster.WorkerGroupSpec { + for index, spec := range rayJob.RayCluster.WorkerGroupSpec { workerPodSpec := podSpec.DeepCopy() + batchscheduler.AddGangSchedulingAnnotations( + batchscheduler.GenerateTaskGroupName(false, index), + objectMeta, + TGAnnotations, + ) workerPodTemplate := buildWorkerPodTemplate( &workerPodSpec.Containers[primaryContainerIdx], workerPodSpec, @@ -188,6 +206,7 @@ func constructRayJob(taskCtx pluginsCore.TaskExecutionContext, rayJob plugins.Ra rayClusterSpec.WorkerGroupSpecs = append(rayClusterSpec.WorkerGroupSpecs, workerNodeSpec) } + batchscheduler.RemoveGangSchedulingAnnotations(objectMeta) serviceAccountName := flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()) if len(serviceAccountName) == 0 { From 0e00f6181676045177952812dc7ef05ad8a3e365 Mon Sep 17 00:00:00 2001 From: yuteng Date: Mon, 22 Jul 2024 23:04:15 +0800 Subject: [PATCH 02/30] test Signed-off-by: yuteng --- .../k8s/ray/batchscheduler/yunikorn.go | 12 ++--- .../k8s/ray/batchscheduler/yunikorn_test.go | 49 +++++++++++++++++++ .../go/tasks/plugins/k8s/ray/config.go | 4 +- 3 files changed, 57 insertions(+), 8 deletions(-) create mode 100644 flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn_test.go diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go index 1feb54ff20..997419720b 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go @@ -17,6 +17,7 @@ const ( TaskGroupNameKey = "yunikorn.apache.org/task-group-name" TaskGroupsKey = "yunikorn.apache.org/task-groups" TaskGroupPrarameters = "yunikorn.apache.org/schedulingPolicyParameters" + TaskGroupGenericName = "task-group" ) type TaskGroup struct { @@ -32,11 +33,10 @@ type TaskGroup struct { } func GenerateTaskGroupName(master bool, index int) string { - tgName := "task-group" if master { - return fmt.Sprintf("%s-%s", tgName, "head") + return fmt.Sprintf("%s-%s", TaskGroupGenericName, "head") } - return fmt.Sprintf("%s-%s-%d", tgName, "worker", index) + return fmt.Sprintf("%s-%s-%d", TaskGroupGenericName, "worker", index) } func SetSchedulerNameAndBuildGangInfo(config BatchSchedulerConfig, metadata *metav1.ObjectMeta, workerGroupsSpec []*plugins.WorkerGroupSpec, head, worker *v1.PodSpec) (map[string]map[string]string, error) { @@ -102,13 +102,13 @@ func AddGangSchedulingAnnotations(name string, metadata *metav1.ObjectMeta, TGAn } annotations := TGAnnotations[name] - if _, ok := metadata.Annotations[TaskGroupNameKey]; !ok { + if _, ok := metadata.Annotations[TaskGroupNameKey]; ok { metadata.Annotations[TaskGroupNameKey] = annotations[TaskGroupNameKey] } - if _, ok := metadata.Annotations[TaskGroupsKey]; !ok { + if _, ok := metadata.Annotations[TaskGroupsKey]; ok { metadata.Annotations[TaskGroupsKey] = annotations[TaskGroupsKey] } - if _, ok := metadata.Annotations[TaskGroupPrarameters]; !ok { + if _, ok := metadata.Annotations[TaskGroupPrarameters]; ok { if _, ok = annotations[TaskGroupPrarameters]; !ok { return } diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn_test.go new file mode 100644 index 0000000000..49e94a71b7 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn_test.go @@ -0,0 +1,49 @@ +package batchscheduler + +import ( + "fmt" + "testing" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func TestGenerateTaskGroupName(t *testing.T) { + var tests = []struct{ + master bool + index int + expect string + }{ + {true, 0, fmt.Sprintf("%s-%s", GenerateTaskGroupName, "head")}, + {false, 0, fmt.Sprintf("%s-%s-%d", GenerateTaskGroupName, "worker", 0)}, + {false, 1, fmt.Sprintf("%s-%s-%d", GenerateTaskGroupName, "worker", 1)}, + } + for _, tt := range tests { + t.Run("Generating Task group name", func(t *testing.T) { + if got := GenerateTaskGroupName(tt.master, tt.index); got != tt.expect { + t.Errorf("got %s, expect %s", got, tt.expect) + } + }) + } +} + +func TestRemoveGangSchedulingAnnotations(t *testing.T) { + var tests = []struct{ + input *metav1.ObjectMeta + expect int + }{ + {input: &metav1.ObjectMeta{"others": "extra", TaskGroupNameKey: "TGName", TaskGroupsKey: "TGs", TaskGroupPrarameters: "parameters"}, 1}, + {input: &metav1.ObjectMeta{TaskGroupNameKey: "TGName", TaskGroupsKey: "TGs", TaskGroupPrarameters: "parameters"}, 0}, + {input: &metav1.ObjectMeta{TaskGroupNameKey: "TGName", TaskGroupsKey: "TGs"}, 0}, + {input: &metav1.ObjectMeta{TaskGroupNameKey: "TGName"}, 0}, + {input: &metav1.ObjectMeta{}, 0}, + } + for _, tt := range tests { + t.Run("Remove Gang scheduling labels", func(t *testing.T){ + RemoveGangSchedulingAnnotations(tt.input) + if got := len(tt.input); got != tt.expect { + t.Errorf("got %d, expect %d", got, tt.expect) + } + }) + } +} + diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/config.go b/flyteplugins/go/tasks/plugins/k8s/ray/config.go index f39782421f..8e6d52d98a 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/config.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/config.go @@ -25,8 +25,8 @@ var ( EnableUsageStats: false, ServiceAccount: "default", BatchScheduler: batchscheduler.BatchSchedulerConfig{ - Scheduler: "yunikorn", - Parameters: "timeout=10", + Scheduler: "", + Parameters: "", }, Defaults: DefaultConfig{ HeadNode: NodeConfig{ From 93090066d1fdd9d39763d79eb54950de1ae78aae Mon Sep 17 00:00:00 2001 From: yuteng Date: Mon, 22 Jul 2024 23:20:40 +0800 Subject: [PATCH 03/30] test Signed-off-by: yuteng --- .../plugins/k8s/ray/batchscheduler/yunikorn_test.go | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn_test.go index 49e94a71b7..76f0cf3e72 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn_test.go @@ -8,9 +8,9 @@ import ( ) func TestGenerateTaskGroupName(t *testing.T) { - var tests = []struct{ + var tests = []struct { master bool - index int + index int expect string }{ {true, 0, fmt.Sprintf("%s-%s", GenerateTaskGroupName, "head")}, @@ -27,8 +27,8 @@ func TestGenerateTaskGroupName(t *testing.T) { } func TestRemoveGangSchedulingAnnotations(t *testing.T) { - var tests = []struct{ - input *metav1.ObjectMeta + var tests = []struct { + input *metav1.ObjectMeta expect int }{ {input: &metav1.ObjectMeta{"others": "extra", TaskGroupNameKey: "TGName", TaskGroupsKey: "TGs", TaskGroupPrarameters: "parameters"}, 1}, @@ -38,7 +38,7 @@ func TestRemoveGangSchedulingAnnotations(t *testing.T) { {input: &metav1.ObjectMeta{}, 0}, } for _, tt := range tests { - t.Run("Remove Gang scheduling labels", func(t *testing.T){ + t.Run("Remove Gang scheduling labels", func(t *testing.T) { RemoveGangSchedulingAnnotations(tt.input) if got := len(tt.input); got != tt.expect { t.Errorf("got %d, expect %d", got, tt.expect) @@ -46,4 +46,3 @@ func TestRemoveGangSchedulingAnnotations(t *testing.T) { }) } } - From dc0e84e48f234e7dc2735c6e2ece7107605708ba Mon Sep 17 00:00:00 2001 From: yuteng Date: Tue, 23 Jul 2024 21:36:43 +0800 Subject: [PATCH 04/30] return error when parse fail Signed-off-by: yuteng --- .../go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go | 6 +++--- flyteplugins/go/tasks/plugins/k8s/ray/ray.go | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go index 997419720b..780036a462 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go @@ -49,7 +49,7 @@ func SetSchedulerNameAndBuildGangInfo(config BatchSchedulerConfig, metadata *met TaskGroupsAnnotations := make(map[string]map[string]string, 0) // Parsing placeholders from the pod resource among head and workers TaskGroups := make([]TaskGroup, 0) - headName := GenerateTaskGroupNameFromMaster(metadata.Name, true, 0) + headName := GenerateTaskGroupName(true, 0) TaskGroups = append(TaskGroups, TaskGroup{ Name: headName, MinMember: 1, @@ -62,7 +62,7 @@ func SetSchedulerNameAndBuildGangInfo(config BatchSchedulerConfig, metadata *met }) for index, spec := range workerGroupsSpec { - name := GenerateTaskGroupNameFromMaster(metadata.Name, false, index) + name := GenerateTaskGroupName(false, index) tg := TaskGroup{ Name: name, MinMember: spec.Replicas, @@ -80,7 +80,7 @@ func SetSchedulerNameAndBuildGangInfo(config BatchSchedulerConfig, metadata *met } // Yunikorn head gang scheduling annotations - info, _ := json.Marshal(TaskGroups) + info, err := json.Marshal(TaskGroups) if err != nil { return nil, err } diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go index 185d55316d..a4fd0537bb 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go @@ -128,7 +128,8 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC func constructRayJob(taskCtx pluginsCore.TaskExecutionContext, rayJob plugins.RayJob, objectMeta *metav1.ObjectMeta, podSpec v1.PodSpec, headPodSpec *v1.PodSpec, headNodeRayStartParams map[string]string, primaryContainerIdx int, primaryContainer v1.Container) (*rayv1.RayJob, error) { enableIngress := true cfg := GetConfig() - TGAnnotations := batchscheduler.SetSchedulerNameAndBuildGangInfo( + var err error + TGAnnotations, err := batchscheduler.SetSchedulerNameAndBuildGangInfo( cfg.BatchScheduler, objectMeta, rayJob.RayCluster.WorkerGroupSpec, @@ -228,7 +229,6 @@ func constructRayJob(taskCtx pluginsCore.TaskExecutionContext, rayJob plugins.Ra submitterPodTemplate := buildSubmitterPodTemplate(headPodSpec, objectMeta, taskCtx) // TODO: This is for backward compatibility. Remove this block once runtime_env is removed from ray proto. - var err error var runtimeEnvYaml string runtimeEnvYaml = rayJob.RuntimeEnvYaml // If runtime_env exists but runtime_env_yaml does not, convert runtime_env to runtime_env_yaml From 20f3646127405a86cfda3bc079e6b46f9eaade0f Mon Sep 17 00:00:00 2001 From: yuteng Date: Tue, 23 Jul 2024 22:21:23 +0800 Subject: [PATCH 05/30] update when no assign Signed-off-by: yuteng --- .../go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go | 6 +++--- flyteplugins/go/tasks/plugins/k8s/ray/config.go | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go index 780036a462..0cb3c9f7b0 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go @@ -102,13 +102,13 @@ func AddGangSchedulingAnnotations(name string, metadata *metav1.ObjectMeta, TGAn } annotations := TGAnnotations[name] - if _, ok := metadata.Annotations[TaskGroupNameKey]; ok { + if _, ok := metadata.Annotations[TaskGroupNameKey]; !ok { metadata.Annotations[TaskGroupNameKey] = annotations[TaskGroupNameKey] } - if _, ok := metadata.Annotations[TaskGroupsKey]; ok { + if _, ok := metadata.Annotations[TaskGroupsKey]; !ok { metadata.Annotations[TaskGroupsKey] = annotations[TaskGroupsKey] } - if _, ok := metadata.Annotations[TaskGroupPrarameters]; ok { + if _, ok := metadata.Annotations[TaskGroupPrarameters]; !ok { if _, ok = annotations[TaskGroupPrarameters]; !ok { return } diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/config.go b/flyteplugins/go/tasks/plugins/k8s/ray/config.go index 8e6d52d98a..f39782421f 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/config.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/config.go @@ -25,8 +25,8 @@ var ( EnableUsageStats: false, ServiceAccount: "default", BatchScheduler: batchscheduler.BatchSchedulerConfig{ - Scheduler: "", - Parameters: "", + Scheduler: "yunikorn", + Parameters: "timeout=10", }, Defaults: DefaultConfig{ HeadNode: NodeConfig{ From f86948e09f4f86624ec61cbf5655dd42193a26c2 Mon Sep 17 00:00:00 2001 From: yuteng Date: Wed, 24 Jul 2024 14:54:24 +0800 Subject: [PATCH 06/30] fix bug that worker groups share same task group name Signed-off-by: yuteng --- .../plugins/k8s/ray/batchscheduler/yunikorn.go | 14 +++++++++----- flyteplugins/go/tasks/plugins/k8s/ray/ray.go | 2 +- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go index 0cb3c9f7b0..4dbe538460 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go @@ -101,18 +101,22 @@ func AddGangSchedulingAnnotations(name string, metadata *metav1.ObjectMeta, TGAn return } + // Updating Yunikorn gang scheduling annotations annotations := TGAnnotations[name] if _, ok := metadata.Annotations[TaskGroupNameKey]; !ok { - metadata.Annotations[TaskGroupNameKey] = annotations[TaskGroupNameKey] + if _, ok = annotations[TaskGroupNameKey]; ok { + metadata.Annotations[TaskGroupNameKey] = annotations[TaskGroupNameKey] + } } if _, ok := metadata.Annotations[TaskGroupsKey]; !ok { - metadata.Annotations[TaskGroupsKey] = annotations[TaskGroupsKey] + if _, ok = annotations[TaskGroupsKey]; ok { + metadata.Annotations[TaskGroupsKey] = annotations[TaskGroupsKey] + } } if _, ok := metadata.Annotations[TaskGroupPrarameters]; !ok { - if _, ok = annotations[TaskGroupPrarameters]; !ok { - return + if _, ok = annotations[TaskGroupPrarameters]; ok { + metadata.Annotations[TaskGroupPrarameters] = annotations[TaskGroupPrarameters] } - metadata.Annotations[TaskGroupPrarameters] = annotations[TaskGroupPrarameters] } return } diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go index a4fd0537bb..f73b21dd99 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go @@ -206,8 +206,8 @@ func constructRayJob(taskCtx pluginsCore.TaskExecutionContext, rayJob plugins.Ra } rayClusterSpec.WorkerGroupSpecs = append(rayClusterSpec.WorkerGroupSpecs, workerNodeSpec) + batchscheduler.RemoveGangSchedulingAnnotations(objectMeta) } - batchscheduler.RemoveGangSchedulingAnnotations(objectMeta) serviceAccountName := flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()) if len(serviceAccountName) == 0 { From 75ccabc1e26ff22cc851d1f273ff32e3003631ab Mon Sep 17 00:00:00 2001 From: yuteng Date: Thu, 25 Jul 2024 06:25:10 +0800 Subject: [PATCH 07/30] make test pass Signed-off-by: yuteng --- .../k8s/ray/batchscheduler/yunikorn_test.go | 91 +++++++++++++++++-- .../go/tasks/plugins/k8s/ray/config.go | 4 +- 2 files changed, 83 insertions(+), 12 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn_test.go index 76f0cf3e72..edd8fda8ea 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn_test.go @@ -1,10 +1,46 @@ package batchscheduler import ( - "fmt" "testing" + "k8s.io/apimachinery/pkg/api/resource" + v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" +) + +var ( + podSpec = &v1.PodSpec{ + Containers: []v1.Container{ + v1.Container{ + Resources: v1.ResourceRequirements{ + Requests: v1.ResourceList { + "cpu": resource.MustParse("500m"), + "memory": resource.MustParse("1Gi"), + }, + }, + }, + }, + NodeSelector: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + } + rayWorkersSpec = []*plugins.WorkerGroupSpec{ + &plugins.WorkerGroupSpec{ + GroupName: "group1", + Replicas: int32(1), + MinReplicas: int32(1), + MaxReplicas: int32(2), + RayStartParams: nil, + }, + &plugins.WorkerGroupSpec{ + GroupName: "group2", + Replicas: int32(1), + MinReplicas: int32(1), + MaxReplicas: int32(2), + RayStartParams: nil, + }, + } ) func TestGenerateTaskGroupName(t *testing.T) { @@ -13,9 +49,9 @@ func TestGenerateTaskGroupName(t *testing.T) { index int expect string }{ - {true, 0, fmt.Sprintf("%s-%s", GenerateTaskGroupName, "head")}, - {false, 0, fmt.Sprintf("%s-%s-%d", GenerateTaskGroupName, "worker", 0)}, - {false, 1, fmt.Sprintf("%s-%s-%d", GenerateTaskGroupName, "worker", 1)}, + {true, 0, GenerateTaskGroupName(true, 0)}, + {false, 0, GenerateTaskGroupName(false, 0)}, + {false, 1, GenerateTaskGroupName(false, 1)}, } for _, tt := range tests { t.Run("Generating Task group name", func(t *testing.T) { @@ -26,21 +62,56 @@ func TestGenerateTaskGroupName(t *testing.T) { } } +func TestSetSchedulerName(t *testing.T) { + head := podSpec.DeepCopy() + worker := podSpec.DeepCopy() + var tests = []struct { + schedulerConfig BatchSchedulerConfig + expect string + }{ + {schedulerConfig: BatchSchedulerConfig{Scheduler:"", Parameters:""}, expect: ""}, + {schedulerConfig: BatchSchedulerConfig{Scheduler:SchedulerName, Parameters:"gangSchedulingStyle=Hard"}, expect: SchedulerName}, + {schedulerConfig: BatchSchedulerConfig{Scheduler:"other", Parameters:""}, expect: ""}, + } + for _, tt := range tests { + t.Run("Scheduler Name", func(t *testing.T) { + SetSchedulerNameAndBuildGangInfo( + tt.schedulerConfig, + &metav1.ObjectMeta{ + Labels: map[string]string{}, + Annotations: map[string]string{}, + }, + rayWorkersSpec, + head, + worker, + ) + if got := head.SchedulerName; got != tt.expect { + t.Errorf("head pod scheduler name: expect %s, got %s", tt.expect, got) + } + if got := worker.SchedulerName; got != tt.expect { + t.Errorf("worker pod scheduler name: expect %s, got %s", tt.expect, got) + } + head.SchedulerName = "" + worker.SchedulerName = "" + }) + } +} + func TestRemoveGangSchedulingAnnotations(t *testing.T) { var tests = []struct { input *metav1.ObjectMeta expect int }{ - {input: &metav1.ObjectMeta{"others": "extra", TaskGroupNameKey: "TGName", TaskGroupsKey: "TGs", TaskGroupPrarameters: "parameters"}, 1}, - {input: &metav1.ObjectMeta{TaskGroupNameKey: "TGName", TaskGroupsKey: "TGs", TaskGroupPrarameters: "parameters"}, 0}, - {input: &metav1.ObjectMeta{TaskGroupNameKey: "TGName", TaskGroupsKey: "TGs"}, 0}, - {input: &metav1.ObjectMeta{TaskGroupNameKey: "TGName"}, 0}, - {input: &metav1.ObjectMeta{}, 0}, + {input: &metav1.ObjectMeta{Annotations: map[string]string{"others": "extra", TaskGroupNameKey: "TGName", TaskGroupsKey: "TGs", TaskGroupPrarameters: "parameters"}}, expect: 1}, + {input: &metav1.ObjectMeta{Annotations: map[string]string{TaskGroupNameKey: "TGName", TaskGroupsKey: "TGs", TaskGroupPrarameters: "parameters"}}, expect: 0}, + {input: &metav1.ObjectMeta{Annotations: map[string]string{TaskGroupNameKey: "TGName", TaskGroupsKey: "TGs"}}, expect: 0}, + {input: &metav1.ObjectMeta{Annotations: map[string]string{TaskGroupNameKey: "TGName"}}, expect: 0}, + {input: &metav1.ObjectMeta{}, expect: 0}, } for _, tt := range tests { t.Run("Remove Gang scheduling labels", func(t *testing.T) { RemoveGangSchedulingAnnotations(tt.input) - if got := len(tt.input); got != tt.expect { + if got := len(tt.input.Annotations); got != tt.expect { t.Errorf("got %d, expect %d", got, tt.expect) } }) diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/config.go b/flyteplugins/go/tasks/plugins/k8s/ray/config.go index f39782421f..8e6d52d98a 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/config.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/config.go @@ -25,8 +25,8 @@ var ( EnableUsageStats: false, ServiceAccount: "default", BatchScheduler: batchscheduler.BatchSchedulerConfig{ - Scheduler: "yunikorn", - Parameters: "timeout=10", + Scheduler: "", + Parameters: "", }, Defaults: DefaultConfig{ HeadNode: NodeConfig{ From e3f36318d244ee5ce764fa7e0b683f25859e13d5 Mon Sep 17 00:00:00 2001 From: yuteng Date: Thu, 25 Jul 2024 12:23:40 +0800 Subject: [PATCH 08/30] refactor Signed-off-by: yuteng --- .../plugins/k8s/ray/batchscheduler/config.go | 19 ++- .../plugins/k8s/ray/batchscheduler/default.go | 32 +++++ .../plugins/k8s/ray/batchscheduler/plugins.go | 24 ++++ .../k8s/ray/batchscheduler/plugins_test.go | 23 ++++ .../k8s/ray/batchscheduler/yunikorn.go | 110 +++++++++------ .../k8s/ray/batchscheduler/yunikorn_test.go | 127 ++++++++++-------- .../go/tasks/plugins/k8s/ray/config.go | 5 +- flyteplugins/go/tasks/plugins/k8s/ray/ray.go | 28 ++-- 8 files changed, 248 insertions(+), 120 deletions(-) create mode 100644 flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/default.go create mode 100644 flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins.go create mode 100644 flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins_test.go diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config.go index fb9bef7652..ff436c3b05 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config.go @@ -1,6 +1,21 @@ package batchscheduler type BatchSchedulerConfig struct { - Scheduler string `json:"scheduler"` - Parameters string `json:"parameters,omitempty` + Scheduler string `json:"scheduler,omitempty"` + Parameters string `json:"parameters,omitempty"` +} + +func NewDefaultBatchSchedulerConfig() BatchSchedulerConfig { + return BatchSchedulerConfig{ + Scheduler: "", + Parameters: "", + } +} + +func (b *BatchSchedulerConfig) GetScheduler() string { + return b.Scheduler +} + +func (b *BatchSchedulerConfig) GetParameters() string { + return b.Parameters } diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/default.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/default.go new file mode 100644 index 0000000000..02f0d4658a --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/default.go @@ -0,0 +1,32 @@ +package batchscheduler + +import ( + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +var ( + DefaultScheduler = "default" +) + +type DefaultPlugin struct{} + +func NewDefaultPlugin() *DefaultPlugin { + return &DefaultPlugin{} +} + +func (d *DefaultPlugin) GetSchedulerName() string { return DefaultScheduler } + +func (d *DefaultPlugin) ParseJob( + config *BatchSchedulerConfig, + metadata *metav1.ObjectMeta, + workerGroupsSpec []*plugins.WorkerGroupSpec, + pod *v1.PodSpec, + primaryContainerIdx int, +) error { + return nil +} +func (d *DefaultPlugin) ProcessHead(metadata *metav1.ObjectMeta, head *v1.PodSpec) {} +func (d *DefaultPlugin) ProcessWorker(metadata *metav1.ObjectMeta, worker *v1.PodSpec, index int) {} +func (d *DefaultPlugin) AfterProcess(metadata *metav1.ObjectMeta) {} diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins.go new file mode 100644 index 0000000000..edf51b0eb7 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins.go @@ -0,0 +1,24 @@ +package batchscheduler + +import ( + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +type SchedulerPlugin interface { + GetSchedulerName() string + ParseJob(config *BatchSchedulerConfig, metadata *metav1.ObjectMeta, workerGroupsSpec []*plugins.WorkerGroupSpec, pod *v1.PodSpec, primaryContainerIdx int) error + ProcessHead(metadata *metav1.ObjectMeta, head *v1.PodSpec) + ProcessWorker(metadata *metav1.ObjectMeta, worker *v1.PodSpec, index int) + AfterProcess(metadata *metav1.ObjectMeta) +} + +func NewSchedulerPlugin(config *BatchSchedulerConfig) SchedulerPlugin { + switch config.GetScheduler() { + case Yunikorn: + return NewYunikornPlugin() + default: + return NewDefaultPlugin() + } +} diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins_test.go new file mode 100644 index 0000000000..2d0bc3e957 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins_test.go @@ -0,0 +1,23 @@ +package batchscheduler + +import ( + "testing" +) + +func TestCreateSchedulerPlugin(t *testing.T) { + var tests = []struct{ + input *BatchSchedulerConfig + expect string + }{ + {input: &BatchSchedulerConfig{Scheduler: DefaultScheduler}, expect: DefaultScheduler}, + {input: &BatchSchedulerConfig{Scheduler: Yunikorn}, expect: Yunikorn}, + {input: &BatchSchedulerConfig{Scheduler:"Unknown"}, expect: DefaultScheduler}, + } + for _, tt := range tests { + t.Run("New scheduler plugin", func(t *testing.T) { + if got := NewSchedulerPlugin(tt.input); got.GetSchedulerName() != tt.expect { + t.Errorf("got %s, expect %s", got, tt.expect) + } + }) + } +} \ No newline at end of file diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go index 4dbe538460..8d70a46ae0 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go @@ -11,15 +11,45 @@ import ( const ( // Pod lebel - BatchSchedulerLabel = "batch-scheduler" - SchedulerLabel = "scheduler" - SchedulerName = "yunikorn" + Yunikorn = "yunikorn" TaskGroupNameKey = "yunikorn.apache.org/task-group-name" TaskGroupsKey = "yunikorn.apache.org/task-groups" TaskGroupPrarameters = "yunikorn.apache.org/schedulingPolicyParameters" TaskGroupGenericName = "task-group" ) +type YunikornGangSchedulingConfig struct { + Annotations map[string]map[string]string + Parameters string +} + +func NewYunikornPlugin() *YunikornGangSchedulingConfig { + return &YunikornGangSchedulingConfig{ + Annotations: nil, + } +} + +func (s *YunikornGangSchedulingConfig) GetSchedulerName() string { return Yunikorn } + +func (s *YunikornGangSchedulingConfig) ParseJob(config *BatchSchedulerConfig, metadata *metav1.ObjectMeta, workerGroupsSpec []*plugins.WorkerGroupSpec, pod *v1.PodSpec, primaryContainerIdx int) error { + s.Parameters = config.GetParameters() + return s.BuildGangInfo(metadata, workerGroupsSpec, pod, primaryContainerIdx) +} + +func (s *YunikornGangSchedulingConfig) ProcessHead(metadata *metav1.ObjectMeta, head *v1.PodSpec) { + s.SetSchedulerName(head) + s.AddGangSchedulingAnnotations(GenerateTaskGroupName(true, 0), metadata) +} + +func (s *YunikornGangSchedulingConfig) ProcessWorker(metadata *metav1.ObjectMeta, worker *v1.PodSpec, index int) { + s.SetSchedulerName(worker) + s.AddGangSchedulingAnnotations(GenerateTaskGroupName(false, index), metadata) +} + +func (s *YunikornGangSchedulingConfig) AfterProcess(metadata *metav1.ObjectMeta) { + RemoveGangSchedulingAnnotations(metadata) +} + type TaskGroup struct { Name string MinMember int32 @@ -39,15 +69,18 @@ func GenerateTaskGroupName(master bool, index int) string { return fmt.Sprintf("%s-%s-%d", TaskGroupGenericName, "worker", index) } -func SetSchedulerNameAndBuildGangInfo(config BatchSchedulerConfig, metadata *metav1.ObjectMeta, workerGroupsSpec []*plugins.WorkerGroupSpec, head, worker *v1.PodSpec) (map[string]map[string]string, error) { - if config.Scheduler != SchedulerName { - return nil, nil - } - head.SchedulerName = SchedulerName - worker.SchedulerName = SchedulerName +func (s *YunikornGangSchedulingConfig) SetSchedulerName(spec *v1.PodSpec) { + spec.SchedulerName = s.GetSchedulerName() +} - TaskGroupsAnnotations := make(map[string]map[string]string, 0) +func (s *YunikornGangSchedulingConfig) BuildGangInfo( + metadata *metav1.ObjectMeta, + workerGroupsSpec []*plugins.WorkerGroupSpec, + pod *v1.PodSpec, + primaryContainerIdx int, +) error { // Parsing placeholders from the pod resource among head and workers + s.Annotations = make(map[string]map[string]string, 0) TaskGroups := make([]TaskGroup, 0) headName := GenerateTaskGroupName(true, 0) TaskGroups = append(TaskGroups, TaskGroup{ @@ -55,12 +88,11 @@ func SetSchedulerNameAndBuildGangInfo(config BatchSchedulerConfig, metadata *met MinMember: 1, Labels: metadata.Labels, Annotations: metadata.Annotations, - MinResource: head.Containers[0].Resources.Requests, - NodeSelector: head.NodeSelector, - Affinity: head.Affinity, - TopologySpreadConstraints: head.TopologySpreadConstraints, + MinResource: pod.Containers[primaryContainerIdx].Resources.Requests, + NodeSelector: pod.NodeSelector, + Affinity: pod.Affinity, + TopologySpreadConstraints: pod.TopologySpreadConstraints, }) - for index, spec := range workerGroupsSpec { name := GenerateTaskGroupName(false, index) tg := TaskGroup{ @@ -68,41 +100,42 @@ func SetSchedulerNameAndBuildGangInfo(config BatchSchedulerConfig, metadata *met MinMember: spec.Replicas, Labels: metadata.Labels, Annotations: metadata.Annotations, - MinResource: worker.Containers[0].Resources.Requests, - NodeSelector: worker.NodeSelector, - Affinity: worker.Affinity, - TopologySpreadConstraints: worker.TopologySpreadConstraints, + MinResource: pod.Containers[primaryContainerIdx].Resources.Requests, + NodeSelector: pod.NodeSelector, + Affinity: pod.Affinity, + TopologySpreadConstraints: pod.TopologySpreadConstraints, } - TaskGroupsAnnotations[name] = map[string]string{ + s.Annotations[name] = map[string]string{ TaskGroupNameKey: name, } TaskGroups = append(TaskGroups, tg) } - // Yunikorn head gang scheduling annotations - info, err := json.Marshal(TaskGroups) - if err != nil { - return nil, err + var info []byte + var err error + if info, err = json.Marshal(TaskGroups); err != nil { + s.Annotations = nil + return err } headAnnotations := make(map[string]string, 0) headAnnotations[TaskGroupNameKey] = headName headAnnotations[TaskGroupsKey] = string(info[:]) - headAnnotations[TaskGroupPrarameters] = config.Parameters - TaskGroupsAnnotations[headName] = headAnnotations - return TaskGroupsAnnotations, nil + headAnnotations[TaskGroupPrarameters] = s.Parameters + s.Annotations[headName] = headAnnotations + return nil } -func AddGangSchedulingAnnotations(name string, metadata *metav1.ObjectMeta, TGAnnotations map[string]map[string]string) { - if TGAnnotations == nil { +func (s *YunikornGangSchedulingConfig) AddGangSchedulingAnnotations(name string, metadata *metav1.ObjectMeta) { + if s.Annotations == nil { return } - if _, ok := TGAnnotations[name]; !ok { + if _, ok := s.Annotations[name]; !ok { return } // Updating Yunikorn gang scheduling annotations - annotations := TGAnnotations[name] + annotations := s.Annotations[name] if _, ok := metadata.Annotations[TaskGroupNameKey]; !ok { if _, ok = annotations[TaskGroupNameKey]; ok { metadata.Annotations[TaskGroupNameKey] = annotations[TaskGroupNameKey] @@ -118,18 +151,13 @@ func AddGangSchedulingAnnotations(name string, metadata *metav1.ObjectMeta, TGAn metadata.Annotations[TaskGroupPrarameters] = annotations[TaskGroupPrarameters] } } - return } func RemoveGangSchedulingAnnotations(metadata *metav1.ObjectMeta) { - if _, ok := metadata.Annotations[TaskGroupNameKey]; ok { - delete(metadata.Annotations, TaskGroupNameKey) - } - if _, ok := metadata.Annotations[TaskGroupsKey]; ok { - delete(metadata.Annotations, TaskGroupsKey) - } - if _, ok := metadata.Annotations[TaskGroupPrarameters]; ok { - delete(metadata.Annotations, TaskGroupPrarameters) + if metadata == nil { + return } - return + delete(metadata.Annotations, TaskGroupNameKey) + delete(metadata.Annotations, TaskGroupsKey) + delete(metadata.Annotations, TaskGroupPrarameters) } diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn_test.go index edd8fda8ea..79e7cd76a6 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn_test.go @@ -3,46 +3,57 @@ package batchscheduler import ( "testing" - "k8s.io/apimachinery/pkg/api/resource" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" ) var ( podSpec = &v1.PodSpec{ Containers: []v1.Container{ - v1.Container{ + { Resources: v1.ResourceRequirements{ - Requests: v1.ResourceList { - "cpu": resource.MustParse("500m"), + Requests: v1.ResourceList{ + "cpu": resource.MustParse("500m"), "memory": resource.MustParse("1Gi"), }, }, }, }, - NodeSelector: nil, - Affinity: nil, + NodeSelector: nil, + Affinity: nil, TopologySpreadConstraints: nil, } rayWorkersSpec = []*plugins.WorkerGroupSpec{ - &plugins.WorkerGroupSpec{ - GroupName: "group1", - Replicas: int32(1), - MinReplicas: int32(1), - MaxReplicas: int32(2), + { + GroupName: "group1", + Replicas: int32(1), + MinReplicas: int32(1), + MaxReplicas: int32(2), RayStartParams: nil, }, - &plugins.WorkerGroupSpec{ - GroupName: "group2", - Replicas: int32(1), - MinReplicas: int32(1), - MaxReplicas: int32(2), + { + GroupName: "group2", + Replicas: int32(1), + MinReplicas: int32(1), + MaxReplicas: int32(2), RayStartParams: nil, }, } ) +func TestSetSchedulerName(t *testing.T) { + t.Run("Set Scheduler Name", func(t *testing.T){ + p := NewYunikornPlugin() + p.SetSchedulerName(podSpec) + if got := podSpec.SchedulerName; got != p.GetSchedulerName() { + t.Errorf("got %s, expect %s", got, p.GetSchedulerName()) + } + podSpec.SchedulerName = "" + }) +} + func TestGenerateTaskGroupName(t *testing.T) { var tests = []struct { master bool @@ -62,51 +73,53 @@ func TestGenerateTaskGroupName(t *testing.T) { } } -func TestSetSchedulerName(t *testing.T) { - head := podSpec.DeepCopy() - worker := podSpec.DeepCopy() - var tests = []struct { - schedulerConfig BatchSchedulerConfig - expect string - }{ - {schedulerConfig: BatchSchedulerConfig{Scheduler:"", Parameters:""}, expect: ""}, - {schedulerConfig: BatchSchedulerConfig{Scheduler:SchedulerName, Parameters:"gangSchedulingStyle=Hard"}, expect: SchedulerName}, - {schedulerConfig: BatchSchedulerConfig{Scheduler:"other", Parameters:""}, expect: ""}, - } - for _, tt := range tests { - t.Run("Scheduler Name", func(t *testing.T) { - SetSchedulerNameAndBuildGangInfo( - tt.schedulerConfig, - &metav1.ObjectMeta{ - Labels: map[string]string{}, - Annotations: map[string]string{}, - }, - rayWorkersSpec, - head, - worker, - ) - if got := head.SchedulerName; got != tt.expect { - t.Errorf("head pod scheduler name: expect %s, got %s", tt.expect, got) - } - if got := worker.SchedulerName; got != tt.expect { - t.Errorf("worker pod scheduler name: expect %s, got %s", tt.expect, got) - } - head.SchedulerName = "" - worker.SchedulerName = "" - }) - } -} - func TestRemoveGangSchedulingAnnotations(t *testing.T) { var tests = []struct { input *metav1.ObjectMeta expect int }{ - {input: &metav1.ObjectMeta{Annotations: map[string]string{"others": "extra", TaskGroupNameKey: "TGName", TaskGroupsKey: "TGs", TaskGroupPrarameters: "parameters"}}, expect: 1}, - {input: &metav1.ObjectMeta{Annotations: map[string]string{TaskGroupNameKey: "TGName", TaskGroupsKey: "TGs", TaskGroupPrarameters: "parameters"}}, expect: 0}, - {input: &metav1.ObjectMeta{Annotations: map[string]string{TaskGroupNameKey: "TGName", TaskGroupsKey: "TGs"}}, expect: 0}, - {input: &metav1.ObjectMeta{Annotations: map[string]string{TaskGroupNameKey: "TGName"}}, expect: 0}, - {input: &metav1.ObjectMeta{}, expect: 0}, + { + input: &metav1.ObjectMeta{ + Annotations: map[string]string{ + "others": "extra", + TaskGroupNameKey: "TGName", + TaskGroupsKey: "TGs", + TaskGroupPrarameters: "parameters", + }, + }, + expect: 1, + }, + { + input: &metav1.ObjectMeta{ + Annotations: map[string]string{ + TaskGroupNameKey: "TGName", + TaskGroupsKey: "TGs", + TaskGroupPrarameters: "parameters", + }, + }, + expect: 0, + }, + { + input: &metav1.ObjectMeta{ + Annotations: map[string]string{ + TaskGroupNameKey: "TGName", + TaskGroupsKey: "TGs", + }, + }, + expect: 0, + }, + { + input: &metav1.ObjectMeta{ + Annotations: map[string]string{ + TaskGroupNameKey: "TGName", + }, + }, + expect: 0, + }, + { + input: &metav1.ObjectMeta{}, + expect: 0, + }, } for _, tt := range tests { t.Run("Remove Gang scheduling labels", func(t *testing.T) { diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/config.go b/flyteplugins/go/tasks/plugins/k8s/ray/config.go index 8e6d52d98a..286ba59aff 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/config.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/config.go @@ -24,10 +24,7 @@ var ( DashboardHost: "0.0.0.0", EnableUsageStats: false, ServiceAccount: "default", - BatchScheduler: batchscheduler.BatchSchedulerConfig{ - Scheduler: "", - Parameters: "", - }, + BatchScheduler: batchscheduler.NewDefaultBatchSchedulerConfig(), Defaults: DefaultConfig{ HeadNode: NodeConfig{ StartParameters: map[string]string{ diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go index f73b21dd99..18224b3c8a 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go @@ -126,21 +126,21 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC } func constructRayJob(taskCtx pluginsCore.TaskExecutionContext, rayJob plugins.RayJob, objectMeta *metav1.ObjectMeta, podSpec v1.PodSpec, headPodSpec *v1.PodSpec, headNodeRayStartParams map[string]string, primaryContainerIdx int, primaryContainer v1.Container) (*rayv1.RayJob, error) { + var err error enableIngress := true cfg := GetConfig() - var err error - TGAnnotations, err := batchscheduler.SetSchedulerNameAndBuildGangInfo( - cfg.BatchScheduler, + schedulerPlugin := batchscheduler.NewSchedulerPlugin(&cfg.BatchScheduler) + err = schedulerPlugin.ParseJob( + &cfg.BatchScheduler, objectMeta, rayJob.RayCluster.WorkerGroupSpec, &podSpec, - headPodSpec, - ) - batchscheduler.AddGangSchedulingAnnotations( - batchscheduler.GenerateTaskGroupName(true, 0), - objectMeta, - TGAnnotations, + primaryContainerIdx, ) + if err != nil { + return nil, err + } + schedulerPlugin.ProcessHead(objectMeta, headPodSpec) rayClusterSpec := rayv1.RayClusterSpec{ HeadGroupSpec: rayv1.HeadGroupSpec{ Template: buildHeadPodTemplate( @@ -156,15 +156,11 @@ func constructRayJob(taskCtx pluginsCore.TaskExecutionContext, rayJob plugins.Ra WorkerGroupSpecs: []rayv1.WorkerGroupSpec{}, EnableInTreeAutoscaling: &rayJob.RayCluster.EnableAutoscaling, } - batchscheduler.RemoveGangSchedulingAnnotations(objectMeta) + schedulerPlugin.AfterProcess(objectMeta) for index, spec := range rayJob.RayCluster.WorkerGroupSpec { workerPodSpec := podSpec.DeepCopy() - batchscheduler.AddGangSchedulingAnnotations( - batchscheduler.GenerateTaskGroupName(false, index), - objectMeta, - TGAnnotations, - ) + schedulerPlugin.ProcessWorker(objectMeta, workerPodSpec, index) workerPodTemplate := buildWorkerPodTemplate( &workerPodSpec.Containers[primaryContainerIdx], workerPodSpec, @@ -206,7 +202,7 @@ func constructRayJob(taskCtx pluginsCore.TaskExecutionContext, rayJob plugins.Ra } rayClusterSpec.WorkerGroupSpecs = append(rayClusterSpec.WorkerGroupSpecs, workerNodeSpec) - batchscheduler.RemoveGangSchedulingAnnotations(objectMeta) + schedulerPlugin.AfterProcess(objectMeta) } serviceAccountName := flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()) From ca4a4eb83318a7db4d296a42cc06c8e51a6fb4f6 Mon Sep 17 00:00:00 2001 From: yuteng Date: Thu, 25 Jul 2024 15:35:46 +0800 Subject: [PATCH 09/30] unit tests Signed-off-by: yuteng --- .../k8s/ray/batchscheduler/plugins_test.go | 7 +- .../k8s/ray/batchscheduler/yunikorn.go | 5 +- .../k8s/ray/batchscheduler/yunikorn_test.go | 129 ++++++++++++++---- .../go/tasks/plugins/k8s/ray/config.go | 2 +- 4 files changed, 112 insertions(+), 31 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins_test.go index 2d0bc3e957..28142e5334 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins_test.go @@ -2,6 +2,8 @@ package batchscheduler import ( "testing" + + "github.com/stretchr/testify/assert" ) func TestCreateSchedulerPlugin(t *testing.T) { @@ -15,9 +17,8 @@ func TestCreateSchedulerPlugin(t *testing.T) { } for _, tt := range tests { t.Run("New scheduler plugin", func(t *testing.T) { - if got := NewSchedulerPlugin(tt.input); got.GetSchedulerName() != tt.expect { - t.Errorf("got %s, expect %s", got, tt.expect) - } + p := NewSchedulerPlugin(tt.input) + assert.Equal(t, tt.expect, p.GetSchedulerName()) }) } } \ No newline at end of file diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go index 8d70a46ae0..35717429fe 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go @@ -32,6 +32,7 @@ func NewYunikornPlugin() *YunikornGangSchedulingConfig { func (s *YunikornGangSchedulingConfig) GetSchedulerName() string { return Yunikorn } func (s *YunikornGangSchedulingConfig) ParseJob(config *BatchSchedulerConfig, metadata *metav1.ObjectMeta, workerGroupsSpec []*plugins.WorkerGroupSpec, pod *v1.PodSpec, primaryContainerIdx int) error { + s.Annotations = nil s.Parameters = config.GetParameters() return s.BuildGangInfo(metadata, workerGroupsSpec, pod, primaryContainerIdx) } @@ -120,7 +121,9 @@ func (s *YunikornGangSchedulingConfig) BuildGangInfo( headAnnotations := make(map[string]string, 0) headAnnotations[TaskGroupNameKey] = headName headAnnotations[TaskGroupsKey] = string(info[:]) - headAnnotations[TaskGroupPrarameters] = s.Parameters + if len(s.Parameters) > 0 { + headAnnotations[TaskGroupPrarameters] = s.Parameters + } s.Annotations[headName] = headAnnotations return nil } diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn_test.go index 79e7cd76a6..59e730329c 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn_test.go @@ -2,6 +2,9 @@ package batchscheduler import ( "testing" + "encoding/json" + + "github.com/stretchr/testify/assert" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" v1 "k8s.io/api/core/v1" @@ -25,35 +28,113 @@ var ( Affinity: nil, TopologySpreadConstraints: nil, } - rayWorkersSpec = []*plugins.WorkerGroupSpec{ - { - GroupName: "group1", - Replicas: int32(1), - MinReplicas: int32(1), - MaxReplicas: int32(2), - RayStartParams: nil, - }, - { - GroupName: "group2", - Replicas: int32(1), - MinReplicas: int32(1), - MaxReplicas: int32(2), - RayStartParams: nil, - }, - } ) func TestSetSchedulerName(t *testing.T) { - t.Run("Set Scheduler Name", func(t *testing.T){ + t.Run("Set Scheduler Name", func(t *testing.T) { p := NewYunikornPlugin() p.SetSchedulerName(podSpec) - if got := podSpec.SchedulerName; got != p.GetSchedulerName() { - t.Errorf("got %s, expect %s", got, p.GetSchedulerName()) - } + assert.Equal(t, p.GetSchedulerName(), podSpec.SchedulerName) podSpec.SchedulerName = "" }) } +func TestBuildGangInfo(t *testing.T) { + names := []string{GenerateTaskGroupName(true, 0)} + res := v1.ResourceList{ + "cpu": resource.MustParse("500m"), + "memory": resource.MustParse("1Gi"), + } + for index := 0; index < 2; index++ { + names = append(names, GenerateTaskGroupName(false, index)) + } + var tests = []struct { + workerGroupNum int + taskGroups []TaskGroup + }{ + { + workerGroupNum: 2, + taskGroups: []TaskGroup{ + { + Name: names[0], + MinMember: int32(1), + Labels: nil, + Annotations: map[string]string{"others": "extra"}, + MinResource: res, + NodeSelector: nil, + Tolerations: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + }, + { + Name: names[1], + MinMember: int32(1), + Labels: nil, + Annotations: map[string]string{"others": "extra"}, + MinResource: res, + NodeSelector: nil, + Tolerations: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + }, + { + Name: names[2], + MinMember: int32(2), + Labels: nil, + Annotations: map[string]string{"others": "extra"}, + MinResource: res, + NodeSelector: nil, + Tolerations: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + }, + }, + }, + } + for _, tt := range tests { + t.Run("Create Yunikorn gang scheduling annotations", func(t *testing.T) { + workersSpec := make([]*plugins.WorkerGroupSpec, 0) + for index := 0; index < tt.workerGroupNum; index++ { + count := 1 * (1 + index) + max := 2 * (1 + index) + workersSpec = append(workersSpec, &plugins.WorkerGroupSpec{ + Replicas: int32(count), + MinReplicas: int32(count), + MaxReplicas: int32(max), + }) + } + metadata := &metav1.ObjectMeta{ + Annotations: map[string]string{"others": "extra"}, + } + p := NewYunikornPlugin() + err := p.BuildGangInfo(metadata, workersSpec, podSpec, 0) + assert.Nil(t, err) + // test worker name + for index := 0; index < tt.workerGroupNum; index++ { + workerIndex := index + 1 + name := names[workerIndex] + if annotations, ok := p.Annotations[name]; ok { + assert.Equal(t, 1, len(annotations)) + assert.Equal(t, name, annotations[TaskGroupNameKey]) + } else { + t.Errorf("Worker group %d annotatiosn miss", index) + } + } + // Test head name and groups + headName := names[0] + if annotations, ok := p.Annotations[headName]; ok { + info, err := json.Marshal(tt.taskGroups) + assert.Nil(t, err) + assert.Equal(t, 2, len(annotations)) + assert.Equal(t, headName, annotations[TaskGroupNameKey]) + assert.Equal(t, string(info[:]), annotations[TaskGroupsKey]) + } else { + t.Error("Head annotations miss") + } + }) + } +} + func TestGenerateTaskGroupName(t *testing.T) { var tests = []struct { master bool @@ -66,9 +147,7 @@ func TestGenerateTaskGroupName(t *testing.T) { } for _, tt := range tests { t.Run("Generating Task group name", func(t *testing.T) { - if got := GenerateTaskGroupName(tt.master, tt.index); got != tt.expect { - t.Errorf("got %s, expect %s", got, tt.expect) - } + assert.Equal(t, tt.expect, GenerateTaskGroupName(tt.master, tt.index)) }) } } @@ -124,9 +203,7 @@ func TestRemoveGangSchedulingAnnotations(t *testing.T) { for _, tt := range tests { t.Run("Remove Gang scheduling labels", func(t *testing.T) { RemoveGangSchedulingAnnotations(tt.input) - if got := len(tt.input.Annotations); got != tt.expect { - t.Errorf("got %d, expect %d", got, tt.expect) - } + assert.Equal(t, tt.expect, len(tt.input.Annotations)) }) } } diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/config.go b/flyteplugins/go/tasks/plugins/k8s/ray/config.go index 286ba59aff..3d96e66ec7 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/config.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/config.go @@ -78,7 +78,7 @@ type Config struct { // or 0.0.0.0 (available from all interfaces). By default, this is localhost. DashboardHost string `json:"dashboardHost,omitempty"` - BatchScheduler batchscheduler.BatchSchedulerConfig `json:"batchSchedulerConfig,omitempty"` + BatchScheduler batchscheduler.BatchSchedulerConfig `json:"BatchScheduler,omitempty"` // DeprecatedNodeIPAddress the IP address of the head node. By default, this is pod ip address. DeprecatedNodeIPAddress string `json:"nodeIPAddress,omitempty" pflag:"-,DEPRECATED. Please use DefaultConfig.[HeadNode|WorkerNode].IPAddress"` From d8ba6bcf62c79941b531dc204dfa81c684464b72 Mon Sep 17 00:00:00 2001 From: yuteng Date: Thu, 25 Jul 2024 15:43:57 +0800 Subject: [PATCH 10/30] lint:scheduler config Signed-off-by: yuteng --- .../go/tasks/plugins/k8s/ray/batchscheduler/config.go | 10 +++++----- .../go/tasks/plugins/k8s/ray/batchscheduler/default.go | 2 +- .../go/tasks/plugins/k8s/ray/batchscheduler/plugins.go | 4 ++-- .../plugins/k8s/ray/batchscheduler/plugins_test.go | 8 ++++---- .../tasks/plugins/k8s/ray/batchscheduler/yunikorn.go | 2 +- flyteplugins/go/tasks/plugins/k8s/ray/config.go | 4 ++-- 6 files changed, 15 insertions(+), 15 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config.go index ff436c3b05..70dccd7d88 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config.go @@ -1,21 +1,21 @@ package batchscheduler -type BatchSchedulerConfig struct { +type Config struct { Scheduler string `json:"scheduler,omitempty"` Parameters string `json:"parameters,omitempty"` } -func NewDefaultBatchSchedulerConfig() BatchSchedulerConfig { - return BatchSchedulerConfig{ +func NewConfig() Config { + return Config{ Scheduler: "", Parameters: "", } } -func (b *BatchSchedulerConfig) GetScheduler() string { +func (b *Config) GetScheduler() string { return b.Scheduler } -func (b *BatchSchedulerConfig) GetParameters() string { +func (b *Config) GetParameters() string { return b.Parameters } diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/default.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/default.go index 02f0d4658a..dc17d1155f 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/default.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/default.go @@ -19,7 +19,7 @@ func NewDefaultPlugin() *DefaultPlugin { func (d *DefaultPlugin) GetSchedulerName() string { return DefaultScheduler } func (d *DefaultPlugin) ParseJob( - config *BatchSchedulerConfig, + config *Config, metadata *metav1.ObjectMeta, workerGroupsSpec []*plugins.WorkerGroupSpec, pod *v1.PodSpec, diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins.go index edf51b0eb7..5aa2ff2ff9 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins.go @@ -8,13 +8,13 @@ import ( type SchedulerPlugin interface { GetSchedulerName() string - ParseJob(config *BatchSchedulerConfig, metadata *metav1.ObjectMeta, workerGroupsSpec []*plugins.WorkerGroupSpec, pod *v1.PodSpec, primaryContainerIdx int) error + ParseJob(config *Config, metadata *metav1.ObjectMeta, workerGroupsSpec []*plugins.WorkerGroupSpec, pod *v1.PodSpec, primaryContainerIdx int) error ProcessHead(metadata *metav1.ObjectMeta, head *v1.PodSpec) ProcessWorker(metadata *metav1.ObjectMeta, worker *v1.PodSpec, index int) AfterProcess(metadata *metav1.ObjectMeta) } -func NewSchedulerPlugin(config *BatchSchedulerConfig) SchedulerPlugin { +func NewSchedulerPlugin(config *Config) SchedulerPlugin { switch config.GetScheduler() { case Yunikorn: return NewYunikornPlugin() diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins_test.go index 28142e5334..ac3ee9db9c 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins_test.go @@ -8,12 +8,12 @@ import ( func TestCreateSchedulerPlugin(t *testing.T) { var tests = []struct{ - input *BatchSchedulerConfig + input *Config expect string }{ - {input: &BatchSchedulerConfig{Scheduler: DefaultScheduler}, expect: DefaultScheduler}, - {input: &BatchSchedulerConfig{Scheduler: Yunikorn}, expect: Yunikorn}, - {input: &BatchSchedulerConfig{Scheduler:"Unknown"}, expect: DefaultScheduler}, + {input: &Config{Scheduler: DefaultScheduler}, expect: DefaultScheduler}, + {input: &Config{Scheduler: Yunikorn}, expect: Yunikorn}, + {input: &Config{Scheduler:"Unknown"}, expect: DefaultScheduler}, } for _, tt := range tests { t.Run("New scheduler plugin", func(t *testing.T) { diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go index 35717429fe..c9f93b10be 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go @@ -31,7 +31,7 @@ func NewYunikornPlugin() *YunikornGangSchedulingConfig { func (s *YunikornGangSchedulingConfig) GetSchedulerName() string { return Yunikorn } -func (s *YunikornGangSchedulingConfig) ParseJob(config *BatchSchedulerConfig, metadata *metav1.ObjectMeta, workerGroupsSpec []*plugins.WorkerGroupSpec, pod *v1.PodSpec, primaryContainerIdx int) error { +func (s *YunikornGangSchedulingConfig) ParseJob(config *Config, metadata *metav1.ObjectMeta, workerGroupsSpec []*plugins.WorkerGroupSpec, pod *v1.PodSpec, primaryContainerIdx int) error { s.Annotations = nil s.Parameters = config.GetParameters() return s.BuildGangInfo(metadata, workerGroupsSpec, pod, primaryContainerIdx) diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/config.go b/flyteplugins/go/tasks/plugins/k8s/ray/config.go index 3d96e66ec7..ce547a680f 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/config.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/config.go @@ -24,7 +24,7 @@ var ( DashboardHost: "0.0.0.0", EnableUsageStats: false, ServiceAccount: "default", - BatchScheduler: batchscheduler.NewDefaultBatchSchedulerConfig(), + BatchScheduler: batchscheduler.NewConfig(), Defaults: DefaultConfig{ HeadNode: NodeConfig{ StartParameters: map[string]string{ @@ -78,7 +78,7 @@ type Config struct { // or 0.0.0.0 (available from all interfaces). By default, this is localhost. DashboardHost string `json:"dashboardHost,omitempty"` - BatchScheduler batchscheduler.BatchSchedulerConfig `json:"BatchScheduler,omitempty"` + BatchScheduler batchscheduler.Config `json:"BatchScheduler,omitempty"` // DeprecatedNodeIPAddress the IP address of the head node. By default, this is pod ip address. DeprecatedNodeIPAddress string `json:"nodeIPAddress,omitempty" pflag:"-,DEPRECATED. Please use DefaultConfig.[HeadNode|WorkerNode].IPAddress"` From b4e6467847edc6530ccc94e4d489a69e71fbebd6 Mon Sep 17 00:00:00 2001 From: yuteng Date: Thu, 25 Jul 2024 15:53:40 +0800 Subject: [PATCH 11/30] gci format Signed-off-by: yuteng --- .../go/tasks/plugins/k8s/ray/batchscheduler/default.go | 7 ++++--- .../go/tasks/plugins/k8s/ray/batchscheduler/plugins.go | 3 ++- .../tasks/plugins/k8s/ray/batchscheduler/plugins_test.go | 8 ++++---- .../go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go | 5 +++-- .../tasks/plugins/k8s/ray/batchscheduler/yunikorn_test.go | 6 +++--- 5 files changed, 16 insertions(+), 13 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/default.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/default.go index dc17d1155f..e00f595ff4 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/default.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/default.go @@ -1,9 +1,10 @@ package batchscheduler import ( - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" ) var ( @@ -27,6 +28,6 @@ func (d *DefaultPlugin) ParseJob( ) error { return nil } -func (d *DefaultPlugin) ProcessHead(metadata *metav1.ObjectMeta, head *v1.PodSpec) {} +func (d *DefaultPlugin) ProcessHead(metadata *metav1.ObjectMeta, head *v1.PodSpec) {} func (d *DefaultPlugin) ProcessWorker(metadata *metav1.ObjectMeta, worker *v1.PodSpec, index int) {} -func (d *DefaultPlugin) AfterProcess(metadata *metav1.ObjectMeta) {} +func (d *DefaultPlugin) AfterProcess(metadata *metav1.ObjectMeta) {} diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins.go index 5aa2ff2ff9..13962bfc23 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins.go @@ -1,9 +1,10 @@ package batchscheduler import ( - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" ) type SchedulerPlugin interface { diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins_test.go index ac3ee9db9c..f8a13bfeae 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins_test.go @@ -7,13 +7,13 @@ import ( ) func TestCreateSchedulerPlugin(t *testing.T) { - var tests = []struct{ - input *Config + var tests = []struct { + input *Config expect string }{ {input: &Config{Scheduler: DefaultScheduler}, expect: DefaultScheduler}, {input: &Config{Scheduler: Yunikorn}, expect: Yunikorn}, - {input: &Config{Scheduler:"Unknown"}, expect: DefaultScheduler}, + {input: &Config{Scheduler: "Unknown"}, expect: DefaultScheduler}, } for _, tt := range tests { t.Run("New scheduler plugin", func(t *testing.T) { @@ -21,4 +21,4 @@ func TestCreateSchedulerPlugin(t *testing.T) { assert.Equal(t, tt.expect, p.GetSchedulerName()) }) } -} \ No newline at end of file +} diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go index c9f93b10be..3a864e5ce7 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go @@ -4,9 +4,10 @@ import ( "encoding/json" "fmt" - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" ) const ( @@ -20,7 +21,7 @@ const ( type YunikornGangSchedulingConfig struct { Annotations map[string]map[string]string - Parameters string + Parameters string } func NewYunikornPlugin() *YunikornGangSchedulingConfig { diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn_test.go index 59e730329c..26368a7494 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn_test.go @@ -1,15 +1,15 @@ package batchscheduler import ( - "testing" "encoding/json" + "testing" "github.com/stretchr/testify/assert" - - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" ) var ( From 3fbe18d016b032967318784afc5ed0f278aae75d Mon Sep 17 00:00:00 2001 From: yuteng Date: Fri, 26 Jul 2024 20:35:02 +0800 Subject: [PATCH 12/30] Update go.sum to origin one Signed-off-by: yuteng --- flyteidl/go.sum | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flyteidl/go.sum b/flyteidl/go.sum index 1819269c1c..5d5cb7e9a2 100644 --- a/flyteidl/go.sum +++ b/flyteidl/go.sum @@ -214,8 +214,8 @@ github.com/prometheus/common v0.44.0 h1:+5BrQJwiBB9xsMygAB3TNvpQKOwlkc25LbISbrdO github.com/prometheus/common v0.44.0/go.mod h1:ofAIvZbQ1e/nugmZGz4/qCb9Ap1VoSTIO7x0VV9VvuY= github.com/prometheus/procfs v0.10.1 h1:kYK1Va/YMlutzCGazswoHKo//tZVlFpKYh+PymziUAg= github.com/prometheus/procfs v0.10.1/go.mod h1:nwNm2aOCAYw8uTR/9bWRREkZFxAUcWzPHWJq+XBB/FM= -github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= -github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= +github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= From 20c91fe292d1787c039afcc67f63cfab0485ab3c Mon Sep 17 00:00:00 2001 From: yuteng Date: Mon, 29 Jul 2024 04:53:02 +0800 Subject: [PATCH 13/30] refactor Signed-off-by: yuteng --- .../ray/batchscheduler/{ => config}/config.go | 2 +- .../ray/batchscheduler/config/config_test.go | 15 + .../plugins/k8s/ray/batchscheduler/default.go | 33 - .../plugins/k8s/ray/batchscheduler/plugins.go | 15 +- .../k8s/ray/batchscheduler/plugins_test.go | 12 +- .../scheduler/kubernetes/default.go | 34 + .../scheduler/kubernetes/default_test.go | 103 +++ .../scheduler/yunikorn/taskgroup.go | 24 + .../scheduler/yunikorn/taskgroup_test.go | 46 ++ .../scheduler/yunikorn/utils.go | 16 + .../scheduler/yunikorn/utils_test.go | 42 ++ .../{ => scheduler/yunikorn}/yunikorn.go | 117 ++-- .../scheduler/yunikorn/yunikorn_test.go | 650 ++++++++++++++++++ .../k8s/ray/batchscheduler/yunikorn_test.go | 209 ------ .../go/tasks/plugins/k8s/ray/config.go | 6 +- flyteplugins/go/tasks/plugins/k8s/ray/ray.go | 2 +- 16 files changed, 1003 insertions(+), 323 deletions(-) rename flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/{ => config}/config.go (93%) create mode 100644 flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config/config_test.go delete mode 100644 flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/default.go create mode 100644 flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/kubernetes/default.go create mode 100644 flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/kubernetes/default_test.go create mode 100644 flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/taskgroup.go create mode 100644 flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/taskgroup_test.go create mode 100644 flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/utils.go create mode 100644 flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/utils_test.go rename flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/{ => scheduler/yunikorn}/yunikorn.go (50%) create mode 100644 flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/yunikorn_test.go delete mode 100644 flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn_test.go diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config/config.go similarity index 93% rename from flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config.go rename to flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config/config.go index 70dccd7d88..e1633a5bc0 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config/config.go @@ -1,4 +1,4 @@ -package batchscheduler +package config type Config struct { Scheduler string `json:"scheduler,omitempty"` diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config/config_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config/config_test.go new file mode 100644 index 0000000000..b7eb9fc354 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config/config_test.go @@ -0,0 +1,15 @@ +package config + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewConfig(t *testing.T) { + t.Run("New scheduler plugin config", func(t *testing.T) { + config := NewConfig() + assert.Equal(t, "", config.GetScheduler()) + assert.Equal(t, "", config.GetParameters()) + }) +} diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/default.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/default.go deleted file mode 100644 index e00f595ff4..0000000000 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/default.go +++ /dev/null @@ -1,33 +0,0 @@ -package batchscheduler - -import ( - v1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" -) - -var ( - DefaultScheduler = "default" -) - -type DefaultPlugin struct{} - -func NewDefaultPlugin() *DefaultPlugin { - return &DefaultPlugin{} -} - -func (d *DefaultPlugin) GetSchedulerName() string { return DefaultScheduler } - -func (d *DefaultPlugin) ParseJob( - config *Config, - metadata *metav1.ObjectMeta, - workerGroupsSpec []*plugins.WorkerGroupSpec, - pod *v1.PodSpec, - primaryContainerIdx int, -) error { - return nil -} -func (d *DefaultPlugin) ProcessHead(metadata *metav1.ObjectMeta, head *v1.PodSpec) {} -func (d *DefaultPlugin) ProcessWorker(metadata *metav1.ObjectMeta, worker *v1.PodSpec, index int) {} -func (d *DefaultPlugin) AfterProcess(metadata *metav1.ObjectMeta) {} diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins.go index 13962bfc23..6ee85b8043 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins.go @@ -5,21 +5,24 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" + schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/kubernetes" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn" ) type SchedulerPlugin interface { GetSchedulerName() string - ParseJob(config *Config, metadata *metav1.ObjectMeta, workerGroupsSpec []*plugins.WorkerGroupSpec, pod *v1.PodSpec, primaryContainerIdx int) error - ProcessHead(metadata *metav1.ObjectMeta, head *v1.PodSpec) + ParseJob(config *schedulerConfig.Config, metadata *metav1.ObjectMeta, workerGroupsSpec []*plugins.WorkerGroupSpec, pod *v1.PodSpec, primaryContainerIdx int) error + ProcessHead(metadata *metav1.ObjectMeta, head *v1.PodSpec, index int) ProcessWorker(metadata *metav1.ObjectMeta, worker *v1.PodSpec, index int) AfterProcess(metadata *metav1.ObjectMeta) } -func NewSchedulerPlugin(config *Config) SchedulerPlugin { +func NewSchedulerPlugin(config *schedulerConfig.Config) SchedulerPlugin { switch config.GetScheduler() { - case Yunikorn: - return NewYunikornPlugin() + case yunikorn.Yunikorn: + return yunikorn.NewYunikornPlugin() default: - return NewDefaultPlugin() + return kubernetes.NewDefaultPlugin() } } diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins_test.go index f8a13bfeae..11731ec91f 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins_test.go @@ -4,16 +4,20 @@ import ( "testing" "github.com/stretchr/testify/assert" + + schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/kubernetes" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn" ) func TestCreateSchedulerPlugin(t *testing.T) { var tests = []struct { - input *Config + input *schedulerConfig.Config expect string }{ - {input: &Config{Scheduler: DefaultScheduler}, expect: DefaultScheduler}, - {input: &Config{Scheduler: Yunikorn}, expect: Yunikorn}, - {input: &Config{Scheduler: "Unknown"}, expect: DefaultScheduler}, + {input: &schedulerConfig.Config{Scheduler: kubernetes.DefaultScheduler}, expect: kubernetes.DefaultScheduler}, + {input: &schedulerConfig.Config{Scheduler: yunikorn.Yunikorn}, expect: yunikorn.Yunikorn}, + {input: &schedulerConfig.Config{Scheduler: "Unknown"}, expect: kubernetes.DefaultScheduler}, } for _, tt := range tests { t.Run("New scheduler plugin", func(t *testing.T) { diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/kubernetes/default.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/kubernetes/default.go new file mode 100644 index 0000000000..20241f8752 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/kubernetes/default.go @@ -0,0 +1,34 @@ +package kubernetes + +import ( + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" + schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config" +) + +var ( + DefaultScheduler = "default" +) + +type Plugin struct{} + +func NewDefaultPlugin() *Plugin { + return &Plugin{} +} + +func (d *Plugin) GetSchedulerName() string { return DefaultScheduler } + +func (d *Plugin) ParseJob( + config *schedulerConfig.Config, + metadata *metav1.ObjectMeta, + workerGroupsSpec []*plugins.WorkerGroupSpec, + pod *v1.PodSpec, + primaryContainerIdx int, +) error { + return nil +} +func (d *Plugin) ProcessHead(metadata *metav1.ObjectMeta, head *v1.PodSpec, index int) {} +func (d *Plugin) ProcessWorker(metadata *metav1.ObjectMeta, worker *v1.PodSpec, index int) {} +func (d *Plugin) AfterProcess(metadata *metav1.ObjectMeta) {} diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/kubernetes/default_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/kubernetes/default_test.go new file mode 100644 index 0000000000..3857e296bd --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/kubernetes/default_test.go @@ -0,0 +1,103 @@ +package kubernetes + +import ( + "testing" + + "github.com/stretchr/testify/assert" + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" + schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config" +) + +var ( + metadata = &metav1.ObjectMeta{ + Labels: map[string]string{"others": "extra"}, + Annotations: map[string]string{"others": "extra"}, + } + res = v1.ResourceList{ + "cpu": resource.MustParse("500m"), + "memory": resource.MustParse("1Gi"), + } + podSpec = &v1.PodSpec{ + Containers: []v1.Container{ + { + Resources: v1.ResourceRequirements{ + Requests: res, + }, + }, + }, + NodeSelector: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + } +) + +func TestNewDefaultPlugin(t *testing.T) { + t.Run("New default scheduler plugin", func(t *testing.T) { + p := NewDefaultPlugin() + assert.NotNil(t, p) + assert.Equal(t, DefaultScheduler, p.GetSchedulerName()) + }) +} + +func TestParseJob(t *testing.T) { + t.Run("Default scheduler plugin parse job", func(t *testing.T) { + p := schedulerConfig.NewConfig() + rayWorkersSpec := []*plugins.WorkerGroupSpec{ + { + GroupName: "g1", + Replicas: int32(2), + MinReplicas: int32(1), + MaxReplicas: int32(3), + RayStartParams: map[string]string{ + "parameters": "specific parameters", + }, + }, + } + index := 0 + err := NewDefaultPlugin().ParseJob(&p, metadata, rayWorkersSpec, podSpec, index) + assert.Nil(t, err) + workerspec := rayWorkersSpec[0] + assert.Equal(t, "g1", workerspec.GroupName) + assert.Equal(t, int32(2), workerspec.Replicas) + assert.Equal(t, int32(1), workerspec.MinReplicas) + assert.Equal(t, int32(3), workerspec.MaxReplicas) + assert.Equal(t, map[string]string{"parameters": "specific parameters"}, workerspec.RayStartParams) + assert.Equal(t, map[string]string{"others": "extra"}, metadata.Annotations) + assert.Equal(t, map[string]string{"others": "extra"}, metadata.Labels) + assert.Equal(t, res, podSpec.Containers[index].Resources.Requests) + assert.Equal(t, "", p.GetScheduler()) + assert.Equal(t, "", p.GetParameters()) + }) +} + +func TestProcessHead(t *testing.T) { + t.Run("Default scheduler plugin process head", func(t *testing.T) { + index := 0 + NewDefaultPlugin().ProcessHead(metadata, podSpec, index) + assert.Equal(t, map[string]string{"others": "extra"}, metadata.Annotations) + assert.Equal(t, map[string]string{"others": "extra"}, metadata.Labels) + assert.Equal(t, res, podSpec.Containers[index].Resources.Requests) + }) +} + +func TestProcessWorker(t *testing.T) { + t.Run("Default scheduler plugin preprocess worker", func(t *testing.T) { + index := 0 + NewDefaultPlugin().ProcessWorker(metadata, podSpec, index) + assert.Equal(t, map[string]string{"others": "extra"}, metadata.Annotations) + assert.Equal(t, map[string]string{"others": "extra"}, metadata.Labels) + assert.Equal(t, res, podSpec.Containers[index].Resources.Requests) + }) +} + +func TestAfterProcess(t *testing.T) { + t.Run("Default scheduler plugin afterly process worker", func(t *testing.T) { + NewDefaultPlugin().AfterProcess(metadata) + assert.Equal(t, map[string]string{"others": "extra"}, metadata.Annotations) + assert.Equal(t, map[string]string{"others": "extra"}, metadata.Labels) + }) +} diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/taskgroup.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/taskgroup.go new file mode 100644 index 0000000000..5a52579ce4 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/taskgroup.go @@ -0,0 +1,24 @@ +package yunikorn + +import ( + "encoding/json" + + v1 "k8s.io/api/core/v1" +) + +type TaskGroup struct { + Name string + MinMember int32 + Labels map[string]string + Annotations map[string]string + MinResource v1.ResourceList + NodeSelector map[string]string + Tolerations []v1.Toleration + Affinity *v1.Affinity + TopologySpreadConstraints []v1.TopologySpreadConstraint +} + +func Marshal(taskGroups []TaskGroup) ([]byte, error) { + info, err := json.Marshal(taskGroups) + return info, err +} diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/taskgroup_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/taskgroup_test.go new file mode 100644 index 0000000000..180e2a6e84 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/taskgroup_test.go @@ -0,0 +1,46 @@ +package yunikorn + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMarshal(t *testing.T) { + t1 := TaskGroup{ + Name: "tg1", + MinMember: int32(1), + Labels: map[string]string{"attr": "value"}, + Annotations: map[string]string{"attr": "value"}, + MinResource: res, + NodeSelector: map[string]string{"node": "gpunode"}, + Tolerations: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + } + t2 := TaskGroup{ + Name: "tg2", + MinMember: int32(1), + Labels: map[string]string{"attr": "value"}, + Annotations: map[string]string{"attr": "value"}, + MinResource: res, + NodeSelector: map[string]string{"node": "gpunode"}, + Tolerations: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + } + var tests = []struct { + input []TaskGroup + }{ + {input: nil}, + {input: []TaskGroup{}}, + {input: []TaskGroup{t1}}, + {input: []TaskGroup{t1, t2}}, + } + t.Run("Serialize task groups", func(t *testing.T) { + for _, tt := range tests { + _, err := Marshal(tt.input) + assert.Nil(t, err) + } + }) +} diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/utils.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/utils.go new file mode 100644 index 0000000000..ff94255282 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/utils.go @@ -0,0 +1,16 @@ +package yunikorn + +import ( + "fmt" +) + +const ( + TaskGroupGenericName = "task-group" +) + +func GenerateTaskGroupName(master bool, index int) string { + if master { + return fmt.Sprintf("%s-%s", TaskGroupGenericName, "head") + } + return fmt.Sprintf("%s-%s-%d", TaskGroupGenericName, "worker", index) +} diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/utils_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/utils_test.go new file mode 100644 index 0000000000..6e14d9b7f7 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/utils_test.go @@ -0,0 +1,42 @@ +package yunikorn + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGenerateTaskGroupName(t *testing.T) { + type inputFormat struct { + isMaster bool + index int + } + var tests = []struct { + input inputFormat + expect string + }{ + { + input: inputFormat{isMaster: true, index: 0}, + expect: fmt.Sprintf("%s-%s", TaskGroupGenericName, "head"), + }, + { + input: inputFormat{isMaster: true, index: 1}, + expect: fmt.Sprintf("%s-%s", TaskGroupGenericName, "head"), + }, + { + input: inputFormat{isMaster: false, index: 0}, + expect: fmt.Sprintf("%s-%s-%d", TaskGroupGenericName, "worker", 0), + }, + { + input: inputFormat{isMaster: false, index: 1}, + expect: fmt.Sprintf("%s-%s-%d", TaskGroupGenericName, "worker", 1), + }, + } + t.Run("Gernerate ray task group name", func(t *testing.T) { + for _, tt := range tests { + got := GenerateTaskGroupName(tt.input.isMaster, tt.input.index) + assert.Equal(t, tt.expect, got) + } + }) +} diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/yunikorn.go similarity index 50% rename from flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go rename to flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/yunikorn.go index 3a864e5ce7..0eda2e8f2e 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/yunikorn.go @@ -1,13 +1,13 @@ -package batchscheduler +package yunikorn import ( - "encoding/json" - "fmt" + "errors" v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" + schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config" ) const ( @@ -16,92 +16,89 @@ const ( TaskGroupNameKey = "yunikorn.apache.org/task-group-name" TaskGroupsKey = "yunikorn.apache.org/task-groups" TaskGroupPrarameters = "yunikorn.apache.org/schedulingPolicyParameters" - TaskGroupGenericName = "task-group" ) -type YunikornGangSchedulingConfig struct { +type Plugin struct { Annotations map[string]map[string]string Parameters string } -func NewYunikornPlugin() *YunikornGangSchedulingConfig { - return &YunikornGangSchedulingConfig{ +func NewYunikornPlugin() *Plugin { + return &Plugin{ Annotations: nil, + Parameters: "", } } -func (s *YunikornGangSchedulingConfig) GetSchedulerName() string { return Yunikorn } +func (s *Plugin) GetSchedulerName() string { return Yunikorn } -func (s *YunikornGangSchedulingConfig) ParseJob(config *Config, metadata *metav1.ObjectMeta, workerGroupsSpec []*plugins.WorkerGroupSpec, pod *v1.PodSpec, primaryContainerIdx int) error { +func (s *Plugin) ParseJob(config *schedulerConfig.Config, metadata *metav1.ObjectMeta, workerGroupsSpec []*plugins.WorkerGroupSpec, pod *v1.PodSpec, primaryContainerIdx int) error { s.Annotations = nil - s.Parameters = config.GetParameters() + if parameters := config.GetParameters(); len(parameters) > 0 { + s.Parameters = parameters + } return s.BuildGangInfo(metadata, workerGroupsSpec, pod, primaryContainerIdx) } -func (s *YunikornGangSchedulingConfig) ProcessHead(metadata *metav1.ObjectMeta, head *v1.PodSpec) { +func (s *Plugin) ProcessHead(metadata *metav1.ObjectMeta, head *v1.PodSpec, index int) { s.SetSchedulerName(head) - s.AddGangSchedulingAnnotations(GenerateTaskGroupName(true, 0), metadata) + s.AddGangSchedulingAnnotations(GenerateTaskGroupName(true, index), metadata) } -func (s *YunikornGangSchedulingConfig) ProcessWorker(metadata *metav1.ObjectMeta, worker *v1.PodSpec, index int) { +func (s *Plugin) ProcessWorker(metadata *metav1.ObjectMeta, worker *v1.PodSpec, index int) { s.SetSchedulerName(worker) s.AddGangSchedulingAnnotations(GenerateTaskGroupName(false, index), metadata) } -func (s *YunikornGangSchedulingConfig) AfterProcess(metadata *metav1.ObjectMeta) { - RemoveGangSchedulingAnnotations(metadata) -} - -type TaskGroup struct { - Name string - MinMember int32 - Labels map[string]string - Annotations map[string]string - MinResource v1.ResourceList - NodeSelector map[string]string - Tolerations []v1.Toleration - Affinity *v1.Affinity - TopologySpreadConstraints []v1.TopologySpreadConstraint -} - -func GenerateTaskGroupName(master bool, index int) string { - if master { - return fmt.Sprintf("%s-%s", TaskGroupGenericName, "head") +func (s *Plugin) AfterProcess(metadata *metav1.ObjectMeta) { + if metadata == nil { + return } - return fmt.Sprintf("%s-%s-%d", TaskGroupGenericName, "worker", index) + delete(metadata.Annotations, TaskGroupNameKey) + delete(metadata.Annotations, TaskGroupsKey) + delete(metadata.Annotations, TaskGroupPrarameters) } -func (s *YunikornGangSchedulingConfig) SetSchedulerName(spec *v1.PodSpec) { +func (s *Plugin) SetSchedulerName(spec *v1.PodSpec) { spec.SchedulerName = s.GetSchedulerName() } -func (s *YunikornGangSchedulingConfig) BuildGangInfo( +func (s *Plugin) BuildGangInfo( metadata *metav1.ObjectMeta, workerGroupsSpec []*plugins.WorkerGroupSpec, pod *v1.PodSpec, primaryContainerIdx int, ) error { + if pod == nil { + return errors.New("Ray gang scheduling: pod is nil") + } // Parsing placeholders from the pod resource among head and workers - s.Annotations = make(map[string]map[string]string, 0) + var labels, annotations map[string]string = nil, nil + if metadata != nil { + labels = metadata.Labels + annotations = metadata.Annotations + } TaskGroups := make([]TaskGroup, 0) headName := GenerateTaskGroupName(true, 0) TaskGroups = append(TaskGroups, TaskGroup{ Name: headName, MinMember: 1, - Labels: metadata.Labels, - Annotations: metadata.Annotations, + Labels: labels, + Annotations: annotations, MinResource: pod.Containers[primaryContainerIdx].Resources.Requests, NodeSelector: pod.NodeSelector, Affinity: pod.Affinity, TopologySpreadConstraints: pod.TopologySpreadConstraints, }) + + s.Annotations = make(map[string]map[string]string, 0) for index, spec := range workerGroupsSpec { name := GenerateTaskGroupName(false, index) tg := TaskGroup{ Name: name, MinMember: spec.Replicas, - Labels: metadata.Labels, - Annotations: metadata.Annotations, + Labels: labels, + Annotations: annotations, MinResource: pod.Containers[primaryContainerIdx].Resources.Requests, NodeSelector: pod.NodeSelector, Affinity: pod.Affinity, @@ -112,13 +109,10 @@ func (s *YunikornGangSchedulingConfig) BuildGangInfo( } TaskGroups = append(TaskGroups, tg) } + // Yunikorn head gang scheduling annotations var info []byte - var err error - if info, err = json.Marshal(TaskGroups); err != nil { - s.Annotations = nil - return err - } + info, _ = Marshal(TaskGroups) headAnnotations := make(map[string]string, 0) headAnnotations[TaskGroupNameKey] = headName headAnnotations[TaskGroupsKey] = string(info[:]) @@ -129,8 +123,8 @@ func (s *YunikornGangSchedulingConfig) BuildGangInfo( return nil } -func (s *YunikornGangSchedulingConfig) AddGangSchedulingAnnotations(name string, metadata *metav1.ObjectMeta) { - if s.Annotations == nil { +func (s *Plugin) AddGangSchedulingAnnotations(name string, metadata *metav1.ObjectMeta) { + if s.Annotations == nil || metadata == nil { return } @@ -138,30 +132,21 @@ func (s *YunikornGangSchedulingConfig) AddGangSchedulingAnnotations(name string, return } + if metadata.Annotations == nil { + metadata.Annotations = make(map[string]string, 0) + } + // Updating Yunikorn gang scheduling annotations annotations := s.Annotations[name] - if _, ok := metadata.Annotations[TaskGroupNameKey]; !ok { - if _, ok = annotations[TaskGroupNameKey]; ok { - metadata.Annotations[TaskGroupNameKey] = annotations[TaskGroupNameKey] - } + if _, ok := annotations[TaskGroupNameKey]; ok { + metadata.Annotations[TaskGroupNameKey] = annotations[TaskGroupNameKey] } - if _, ok := metadata.Annotations[TaskGroupsKey]; !ok { - if _, ok = annotations[TaskGroupsKey]; ok { - metadata.Annotations[TaskGroupsKey] = annotations[TaskGroupsKey] - } + if _, ok := annotations[TaskGroupsKey]; ok { + metadata.Annotations[TaskGroupsKey] = annotations[TaskGroupsKey] } if _, ok := metadata.Annotations[TaskGroupPrarameters]; !ok { - if _, ok = annotations[TaskGroupPrarameters]; ok { - metadata.Annotations[TaskGroupPrarameters] = annotations[TaskGroupPrarameters] + if parameters, ok := annotations[TaskGroupPrarameters]; ok && len(parameters) > 0 { + metadata.Annotations[TaskGroupPrarameters] = parameters } } } - -func RemoveGangSchedulingAnnotations(metadata *metav1.ObjectMeta) { - if metadata == nil { - return - } - delete(metadata.Annotations, TaskGroupNameKey) - delete(metadata.Annotations, TaskGroupsKey) - delete(metadata.Annotations, TaskGroupPrarameters) -} diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/yunikorn_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/yunikorn_test.go new file mode 100644 index 0000000000..768ea8e376 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/yunikorn_test.go @@ -0,0 +1,650 @@ +package yunikorn + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" + schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config" +) + +var ( + res = v1.ResourceList{ + "cpu": resource.MustParse("500m"), + "memory": resource.MustParse("1Gi"), + } +) + +func TestParseJob(t *testing.T) { + type inputFormat struct { + config *schedulerConfig.Config + metadata *metav1.ObjectMeta + workerGroupNum int + podSpec *v1.PodSpec + index int + } + type expectFormat struct { + raiseErr bool + parameters string + taskGroups []TaskGroup + } + var tests = []struct { + input inputFormat + expect expectFormat + }{ + { + input: inputFormat{ + config: &schedulerConfig.Config{ + Scheduler: "yunikorn", + Parameters: "placeholderTimeoutInSeconds=15 gangSchedulingStyle=Soft", + }, + workerGroupNum: 1, + podSpec: nil, + metadata: &metav1.ObjectMeta{}, + index: 0, + }, + expect: expectFormat{ + raiseErr: true, + parameters: "placeholderTimeoutInSeconds=15 gangSchedulingStyle=Soft", + taskGroups: []TaskGroup{ + { + Name: GenerateTaskGroupName(true, 0), + MinMember: int32(1), + Labels: nil, + Annotations: map[string]string{"others": "extra"}, + MinResource: res, + NodeSelector: nil, + Tolerations: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + }, + { + Name: GenerateTaskGroupName(false, 0), + MinMember: int32(1), + Labels: nil, + Annotations: map[string]string{"others": "extra"}, + MinResource: res, + NodeSelector: nil, + Tolerations: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + }, + }, + }, + }, + { + input: inputFormat{ + config: &schedulerConfig.Config{ + Scheduler: "yunikorn", + Parameters: "placeholderTimeoutInSeconds=15 gangSchedulingStyle=Soft", + }, + workerGroupNum: 1, + podSpec: &v1.PodSpec{ + Containers: []v1.Container{ + { + Resources: v1.ResourceRequirements{ + Requests: res, + }, + }, + }, + NodeSelector: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + }, + metadata: &metav1.ObjectMeta{ + Annotations: map[string]string{"others": "extra"}, + }, + index: 0, + }, + expect: expectFormat{ + raiseErr: false, + parameters: "placeholderTimeoutInSeconds=15 gangSchedulingStyle=Soft", + taskGroups: []TaskGroup{ + { + Name: GenerateTaskGroupName(true, 0), + MinMember: int32(1), + Labels: nil, + Annotations: map[string]string{"others": "extra"}, + MinResource: res, + NodeSelector: nil, + Tolerations: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + }, + { + Name: GenerateTaskGroupName(false, 0), + MinMember: int32(1), + Labels: nil, + Annotations: map[string]string{"others": "extra"}, + MinResource: res, + NodeSelector: nil, + Tolerations: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run("Yunikorn parse job", func(t *testing.T) { + workersSpec := make([]*plugins.WorkerGroupSpec, 0) + for index := 0; index < tt.input.workerGroupNum; index++ { + count := 1 * (1 + index) + max := 2 * (1 + index) + workersSpec = append(workersSpec, &plugins.WorkerGroupSpec{ + Replicas: int32(count), + MinReplicas: int32(count), + MaxReplicas: int32(max), + }) + } + p := NewYunikornPlugin() + err := p.ParseJob(tt.input.config, tt.input.metadata, workersSpec, tt.input.podSpec, tt.input.index) + if tt.expect.raiseErr { + assert.NotNil(t, err) + } else { + assert.Nil(t, err) + assert.Equal(t, Yunikorn, p.GetSchedulerName()) + names := []string{GenerateTaskGroupName(true, 0)} + for index := 0; index < tt.input.workerGroupNum; index++ { + names = append(names, GenerateTaskGroupName(false, index)) + } + // task-groups among head and workers + assert.Equal(t, len(names), len(p.Annotations)) + // check head annotations + head := p.Annotations[names[0]] + assert.Equal(t, names[0], head[TaskGroupNameKey]) + assert.Equal(t, tt.expect.parameters, head[TaskGroupPrarameters]) + // task-groups in head + var taskgroups []TaskGroup + err = json.Unmarshal([]byte(head[TaskGroupsKey]), &taskgroups) + assert.Nil(t, err) + assert.Equal(t, len(names), len(taskgroups)) + for index, tg := range taskgroups { + assert.Equal(t, names[index], tg.Name) + } + } + }) + } +} + +func TestProcessHead(t *testing.T) { + type inputFormat struct { + config *schedulerConfig.Config + metadata *metav1.ObjectMeta + workerGroupNum int + podSpec *v1.PodSpec + index int + } + type expectFormat struct { + name string + taskgroupsNum int + parameters string + } + var tests = []struct { + input inputFormat + expect expectFormat + }{ + { + input: inputFormat{ + config: &schedulerConfig.Config{ + Scheduler: "yunikorn", + Parameters: "placeholderTimeoutInSeconds=15 gangSchedulingStyle=Soft", + }, + workerGroupNum: 1, + podSpec: &v1.PodSpec{ + Containers: []v1.Container{ + { + Resources: v1.ResourceRequirements{ + Requests: res, + }, + }, + }, + NodeSelector: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + }, + metadata: &metav1.ObjectMeta{ + Annotations: map[string]string{"others": "extra"}, + }, + index: 0, + }, + expect: expectFormat{ + name: GenerateTaskGroupName(true, 0), + taskgroupsNum: 2, + parameters: "placeholderTimeoutInSeconds=15 gangSchedulingStyle=Soft", + }, + }, + } + for _, tt := range tests { + t.Run("Yunikorn process head", func(t *testing.T) { + workersSpec := make([]*plugins.WorkerGroupSpec, 0) + for index := 0; index < tt.input.workerGroupNum; index++ { + workersSpec = append(workersSpec, &plugins.WorkerGroupSpec{ + Replicas: int32(1), + MinReplicas: int32(1), + MaxReplicas: int32(2), + }) + } + p := NewYunikornPlugin() + err := p.ParseJob(tt.input.config, tt.input.metadata, workersSpec, tt.input.podSpec, tt.input.index) + assert.Nil(t, err) + p.ProcessHead(tt.input.metadata, tt.input.podSpec, tt.input.index) + assert.Equal(t, Yunikorn, tt.input.podSpec.SchedulerName) + assert.Equal(t, tt.expect.name, tt.input.metadata.Annotations[TaskGroupNameKey]) + assert.Equal(t, tt.expect.parameters, tt.input.metadata.Annotations[TaskGroupPrarameters]) + var taskgroups []TaskGroup + err = json.Unmarshal([]byte(tt.input.metadata.Annotations[TaskGroupsKey]), &taskgroups) + assert.Nil(t, err) + assert.Equal(t, tt.expect.taskgroupsNum, len(taskgroups)) + }) + } +} + +func TestProcessWorker(t *testing.T) { + type inputFormat struct { + config *schedulerConfig.Config + metadata *metav1.ObjectMeta + workerGroupNum int + podSpec *v1.PodSpec + index int + } + type expectFormat struct { + name string + taskgroupsNum int + } + var tests = []struct { + input inputFormat + expect expectFormat + }{ + { + input: inputFormat{ + config: &schedulerConfig.Config{ + Scheduler: "yunikorn", + Parameters: "placeholderTimeoutInSeconds=15 gangSchedulingStyle=Soft", + }, + workerGroupNum: 1, + podSpec: &v1.PodSpec{ + Containers: []v1.Container{ + { + Resources: v1.ResourceRequirements{ + Requests: res, + }, + }, + }, + NodeSelector: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + }, + metadata: &metav1.ObjectMeta{ + Annotations: map[string]string{"others": "extra"}, + }, + index: 0, + }, + expect: expectFormat{ + name: GenerateTaskGroupName(false, 0), + taskgroupsNum: 2, + }, + }, + } + for _, tt := range tests { + t.Run("Yunikorn process worker", func(t *testing.T) { + workersSpec := make([]*plugins.WorkerGroupSpec, 0) + for index := 0; index < tt.input.workerGroupNum; index++ { + workersSpec = append(workersSpec, &plugins.WorkerGroupSpec{ + Replicas: int32(1), + MinReplicas: int32(1), + MaxReplicas: int32(2), + }) + } + p := NewYunikornPlugin() + err := p.ParseJob(tt.input.config, tt.input.metadata, workersSpec, tt.input.podSpec, tt.input.index) + assert.Nil(t, err) + p.ProcessWorker(tt.input.metadata, tt.input.podSpec, tt.input.index) + assert.Equal(t, Yunikorn, tt.input.podSpec.SchedulerName) + assert.Equal(t, tt.expect.name, tt.input.metadata.Annotations[TaskGroupNameKey]) + }) + } +} + +func TestAfterProcess(t *testing.T) { + type expectFormat struct { + isNil bool + length int + } + var tests = []struct { + input *metav1.ObjectMeta + expect expectFormat + }{ + { + input: nil, + expect: expectFormat{isNil: true, length: -1}, + }, + { + input: &metav1.ObjectMeta{ + Annotations: map[string]string{ + "others": "extra", + TaskGroupNameKey: "TGName", + TaskGroupsKey: "TGs", + TaskGroupPrarameters: "parameters", + }, + }, + expect: expectFormat{isNil: false, length: 1}, + }, + { + input: &metav1.ObjectMeta{ + Annotations: map[string]string{ + TaskGroupNameKey: "TGName", + TaskGroupsKey: "TGs", + TaskGroupPrarameters: "parameters", + }, + }, + expect: expectFormat{isNil: false, length: 0}, + }, + { + input: &metav1.ObjectMeta{ + Annotations: map[string]string{ + TaskGroupNameKey: "TGName", + TaskGroupsKey: "TGs", + }, + }, + expect: expectFormat{isNil: false, length: 0}, + }, + { + input: &metav1.ObjectMeta{ + Annotations: map[string]string{ + TaskGroupNameKey: "TGName", + }, + }, + expect: expectFormat{isNil: false, length: 0}, + }, + { + input: &metav1.ObjectMeta{}, + expect: expectFormat{isNil: false, length: 0}, + }, + } + for _, tt := range tests { + t.Run("Remove Gang scheduling labels", func(t *testing.T) { + p := NewYunikornPlugin() + p.AfterProcess(tt.input) + if tt.expect.isNil { + assert.Nil(t, tt.input) + } else { + assert.NotNil(t, tt.input) + assert.Equal(t, tt.expect.length, len(tt.input.Annotations)) + } + }) + } +} + +func TestSetSchedulerName(t *testing.T) { + t.Run("Set Scheduler Name", func(t *testing.T) { + p := NewYunikornPlugin() + podSpec := &v1.PodSpec{ + Containers: []v1.Container{ + { + Resources: v1.ResourceRequirements{ + Requests: res, + }, + }, + }, + NodeSelector: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + } + p.SetSchedulerName(podSpec) + assert.Equal(t, p.GetSchedulerName(), podSpec.SchedulerName) + podSpec.SchedulerName = "" + }) +} + +func TestBuildGangInfo(t *testing.T) { + names := []string{GenerateTaskGroupName(true, 0)} + for index := 0; index < 2; index++ { + names = append(names, GenerateTaskGroupName(false, index)) + } + type inputFormat struct { + workerGroupNum int + podSpec *v1.PodSpec + metadata *metav1.ObjectMeta + } + var tests = []struct { + input inputFormat + taskGroups []TaskGroup + }{ + { + input: inputFormat{ + workerGroupNum: 1, + podSpec: nil, + metadata: &metav1.ObjectMeta{ + Annotations: map[string]string{"others": "extra"}, + }, + }, + taskGroups: nil, + }, + { + input: inputFormat{ + workerGroupNum: 1, + podSpec: &v1.PodSpec{ + Containers: []v1.Container{ + { + Resources: v1.ResourceRequirements{ + Requests: res, + }, + }, + }, + NodeSelector: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + }, + metadata: &metav1.ObjectMeta{ + Annotations: map[string]string{"others": "extra"}, + }, + }, + taskGroups: []TaskGroup{ + { + Name: GenerateTaskGroupName(true, 0), + MinMember: int32(1), + Labels: nil, + Annotations: map[string]string{"others": "extra"}, + MinResource: res, + NodeSelector: nil, + Tolerations: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + }, + { + Name: GenerateTaskGroupName(false, 0), + MinMember: int32(1), + Labels: nil, + Annotations: map[string]string{"others": "extra"}, + MinResource: res, + NodeSelector: nil, + Tolerations: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + }, + }, + }, + { + input: inputFormat{ + workerGroupNum: 2, + podSpec: &v1.PodSpec{ + Containers: []v1.Container{ + { + Resources: v1.ResourceRequirements{ + Requests: res, + }, + }, + }, + NodeSelector: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + }, + metadata: &metav1.ObjectMeta{ + Annotations: map[string]string{"others": "extra"}, + }, + }, + taskGroups: []TaskGroup{ + { + Name: GenerateTaskGroupName(true, 0), + MinMember: int32(1), + Labels: nil, + Annotations: map[string]string{"others": "extra"}, + MinResource: res, + NodeSelector: nil, + Tolerations: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + }, + { + Name: GenerateTaskGroupName(false, 0), + MinMember: int32(1), + Labels: nil, + Annotations: map[string]string{"others": "extra"}, + MinResource: res, + NodeSelector: nil, + Tolerations: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + }, + { + Name: GenerateTaskGroupName(false, 1), + MinMember: int32(2), + Labels: nil, + Annotations: map[string]string{"others": "extra"}, + MinResource: res, + NodeSelector: nil, + Tolerations: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + }, + }, + }, + } + for _, tt := range tests { + t.Run("Create Yunikorn gang scheduling annotations", func(t *testing.T) { + workersSpec := make([]*plugins.WorkerGroupSpec, 0) + for index := 0; index < tt.input.workerGroupNum; index++ { + count := 1 * (1 + index) + max := 2 * (1 + index) + workersSpec = append(workersSpec, &plugins.WorkerGroupSpec{ + Replicas: int32(count), + MinReplicas: int32(count), + MaxReplicas: int32(max), + }) + } + p := NewYunikornPlugin() + if err := p.BuildGangInfo(tt.input.metadata, workersSpec, tt.input.podSpec, 0); tt.input.podSpec == nil { + assert.NotNil(t, err) + } else { + assert.Nil(t, err) + // test worker name + for index := 0; index < tt.input.workerGroupNum; index++ { + name := GenerateTaskGroupName(false, index) + if annotations, ok := p.Annotations[name]; ok { + assert.Equal(t, 1, len(annotations)) + assert.Equal(t, name, annotations[TaskGroupNameKey]) + } else { + t.Errorf("Worker group %d annotatiosn miss", index) + } + } + // Test head name and groups + headName := GenerateTaskGroupName(true, 0) + if annotations, ok := p.Annotations[headName]; ok { + info, err := json.Marshal(tt.taskGroups) + assert.Nil(t, err) + assert.Equal(t, 2, len(annotations)) + assert.Equal(t, headName, annotations[TaskGroupNameKey]) + assert.Equal(t, string(info[:]), annotations[TaskGroupsKey]) + } else { + t.Error("Head annotations miss") + } + } + }) + } +} + +func TestAddGangSchedulingAnnotations(t *testing.T) { + taskGroupsAnnotations := map[string]map[string]string{ + GenerateTaskGroupName(true, 0): { + TaskGroupNameKey: GenerateTaskGroupName(true, 0), + TaskGroupsKey: "TGs", + TaskGroupPrarameters: "parameters", + }, + GenerateTaskGroupName(false, 0): { + TaskGroupNameKey: GenerateTaskGroupName(false, 0), + }, + } + type inputFormat struct { + annotations map[string]map[string]string + metadata *metav1.ObjectMeta + name string + } + var tests = []struct { + input inputFormat + expect *metav1.ObjectMeta + }{ + { + input: inputFormat{ + annotations: nil, + metadata: nil, + name: "", + }, + expect: nil, + }, + { + input: inputFormat{ + annotations: taskGroupsAnnotations, + metadata: nil, + name: "", + }, + expect: nil, + }, + { + input: inputFormat{ + annotations: taskGroupsAnnotations, + metadata: &metav1.ObjectMeta{}, + name: "Unknown", + }, + expect: &metav1.ObjectMeta{}, + }, + { + input: inputFormat{ + annotations: taskGroupsAnnotations, + metadata: &metav1.ObjectMeta{}, + name: GenerateTaskGroupName(true, 0), + }, + expect: &metav1.ObjectMeta{ + Annotations: taskGroupsAnnotations[GenerateTaskGroupName(true, 0)], + }, + }, + { + input: inputFormat{ + annotations: taskGroupsAnnotations, + metadata: &metav1.ObjectMeta{}, + name: GenerateTaskGroupName(false, 0), + }, + expect: &metav1.ObjectMeta{ + Annotations: taskGroupsAnnotations[GenerateTaskGroupName(false, 0)], + }, + }, + } + for _, tt := range tests { + t.Run("Check gang scheduling annotatiosn after labeling", func(t *testing.T) { + p := NewYunikornPlugin() + p.Annotations = tt.input.annotations + p.AddGangSchedulingAnnotations(tt.input.name, tt.input.metadata) + if tt.expect == nil { + assert.Nil(t, tt.expect, tt.input.metadata) + } else { + assert.Equal(t, tt.expect.Annotations, tt.input.metadata.Annotations) + } + }) + } +} diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn_test.go deleted file mode 100644 index 26368a7494..0000000000 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn_test.go +++ /dev/null @@ -1,209 +0,0 @@ -package batchscheduler - -import ( - "encoding/json" - "testing" - - "github.com/stretchr/testify/assert" - v1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/api/resource" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" -) - -var ( - podSpec = &v1.PodSpec{ - Containers: []v1.Container{ - { - Resources: v1.ResourceRequirements{ - Requests: v1.ResourceList{ - "cpu": resource.MustParse("500m"), - "memory": resource.MustParse("1Gi"), - }, - }, - }, - }, - NodeSelector: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - } -) - -func TestSetSchedulerName(t *testing.T) { - t.Run("Set Scheduler Name", func(t *testing.T) { - p := NewYunikornPlugin() - p.SetSchedulerName(podSpec) - assert.Equal(t, p.GetSchedulerName(), podSpec.SchedulerName) - podSpec.SchedulerName = "" - }) -} - -func TestBuildGangInfo(t *testing.T) { - names := []string{GenerateTaskGroupName(true, 0)} - res := v1.ResourceList{ - "cpu": resource.MustParse("500m"), - "memory": resource.MustParse("1Gi"), - } - for index := 0; index < 2; index++ { - names = append(names, GenerateTaskGroupName(false, index)) - } - var tests = []struct { - workerGroupNum int - taskGroups []TaskGroup - }{ - { - workerGroupNum: 2, - taskGroups: []TaskGroup{ - { - Name: names[0], - MinMember: int32(1), - Labels: nil, - Annotations: map[string]string{"others": "extra"}, - MinResource: res, - NodeSelector: nil, - Tolerations: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - }, - { - Name: names[1], - MinMember: int32(1), - Labels: nil, - Annotations: map[string]string{"others": "extra"}, - MinResource: res, - NodeSelector: nil, - Tolerations: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - }, - { - Name: names[2], - MinMember: int32(2), - Labels: nil, - Annotations: map[string]string{"others": "extra"}, - MinResource: res, - NodeSelector: nil, - Tolerations: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - }, - }, - }, - } - for _, tt := range tests { - t.Run("Create Yunikorn gang scheduling annotations", func(t *testing.T) { - workersSpec := make([]*plugins.WorkerGroupSpec, 0) - for index := 0; index < tt.workerGroupNum; index++ { - count := 1 * (1 + index) - max := 2 * (1 + index) - workersSpec = append(workersSpec, &plugins.WorkerGroupSpec{ - Replicas: int32(count), - MinReplicas: int32(count), - MaxReplicas: int32(max), - }) - } - metadata := &metav1.ObjectMeta{ - Annotations: map[string]string{"others": "extra"}, - } - p := NewYunikornPlugin() - err := p.BuildGangInfo(metadata, workersSpec, podSpec, 0) - assert.Nil(t, err) - // test worker name - for index := 0; index < tt.workerGroupNum; index++ { - workerIndex := index + 1 - name := names[workerIndex] - if annotations, ok := p.Annotations[name]; ok { - assert.Equal(t, 1, len(annotations)) - assert.Equal(t, name, annotations[TaskGroupNameKey]) - } else { - t.Errorf("Worker group %d annotatiosn miss", index) - } - } - // Test head name and groups - headName := names[0] - if annotations, ok := p.Annotations[headName]; ok { - info, err := json.Marshal(tt.taskGroups) - assert.Nil(t, err) - assert.Equal(t, 2, len(annotations)) - assert.Equal(t, headName, annotations[TaskGroupNameKey]) - assert.Equal(t, string(info[:]), annotations[TaskGroupsKey]) - } else { - t.Error("Head annotations miss") - } - }) - } -} - -func TestGenerateTaskGroupName(t *testing.T) { - var tests = []struct { - master bool - index int - expect string - }{ - {true, 0, GenerateTaskGroupName(true, 0)}, - {false, 0, GenerateTaskGroupName(false, 0)}, - {false, 1, GenerateTaskGroupName(false, 1)}, - } - for _, tt := range tests { - t.Run("Generating Task group name", func(t *testing.T) { - assert.Equal(t, tt.expect, GenerateTaskGroupName(tt.master, tt.index)) - }) - } -} - -func TestRemoveGangSchedulingAnnotations(t *testing.T) { - var tests = []struct { - input *metav1.ObjectMeta - expect int - }{ - { - input: &metav1.ObjectMeta{ - Annotations: map[string]string{ - "others": "extra", - TaskGroupNameKey: "TGName", - TaskGroupsKey: "TGs", - TaskGroupPrarameters: "parameters", - }, - }, - expect: 1, - }, - { - input: &metav1.ObjectMeta{ - Annotations: map[string]string{ - TaskGroupNameKey: "TGName", - TaskGroupsKey: "TGs", - TaskGroupPrarameters: "parameters", - }, - }, - expect: 0, - }, - { - input: &metav1.ObjectMeta{ - Annotations: map[string]string{ - TaskGroupNameKey: "TGName", - TaskGroupsKey: "TGs", - }, - }, - expect: 0, - }, - { - input: &metav1.ObjectMeta{ - Annotations: map[string]string{ - TaskGroupNameKey: "TGName", - }, - }, - expect: 0, - }, - { - input: &metav1.ObjectMeta{}, - expect: 0, - }, - } - for _, tt := range tests { - t.Run("Remove Gang scheduling labels", func(t *testing.T) { - RemoveGangSchedulingAnnotations(tt.input) - assert.Equal(t, tt.expect, len(tt.input.Annotations)) - }) - } -} diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/config.go b/flyteplugins/go/tasks/plugins/k8s/ray/config.go index ce547a680f..3efed684c7 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/config.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/config.go @@ -9,7 +9,7 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/logs" pluginmachinery "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler" + schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config" "github.com/flyteorg/flyte/flytestdlib/config" ) @@ -24,7 +24,7 @@ var ( DashboardHost: "0.0.0.0", EnableUsageStats: false, ServiceAccount: "default", - BatchScheduler: batchscheduler.NewConfig(), + BatchScheduler: schedulerConfig.NewConfig(), Defaults: DefaultConfig{ HeadNode: NodeConfig{ StartParameters: map[string]string{ @@ -78,7 +78,7 @@ type Config struct { // or 0.0.0.0 (available from all interfaces). By default, this is localhost. DashboardHost string `json:"dashboardHost,omitempty"` - BatchScheduler batchscheduler.Config `json:"BatchScheduler,omitempty"` + BatchScheduler schedulerConfig.Config `json:"BatchScheduler,omitempty"` // DeprecatedNodeIPAddress the IP address of the head node. By default, this is pod ip address. DeprecatedNodeIPAddress string `json:"nodeIPAddress,omitempty" pflag:"-,DEPRECATED. Please use DefaultConfig.[HeadNode|WorkerNode].IPAddress"` diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go index 18224b3c8a..3c1f01d5e6 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go @@ -140,7 +140,7 @@ func constructRayJob(taskCtx pluginsCore.TaskExecutionContext, rayJob plugins.Ra if err != nil { return nil, err } - schedulerPlugin.ProcessHead(objectMeta, headPodSpec) + schedulerPlugin.ProcessHead(objectMeta, headPodSpec, primaryContainerIdx) rayClusterSpec := rayv1.RayClusterSpec{ HeadGroupSpec: rayv1.HeadGroupSpec{ Template: buildHeadPodTemplate( From c14d265514f7e8a51fbd7129bef02057fc6a9d60 Mon Sep 17 00:00:00 2001 From: yuteng Date: Mon, 29 Jul 2024 22:58:33 +0800 Subject: [PATCH 14/30] update test in ray Signed-off-by: yuteng --- .../go/tasks/plugins/k8s/ray/ray_test.go | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go index 7b555e9f23..82b7960d9d 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go @@ -27,6 +27,8 @@ import ( mocks2 "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s/mocks" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" + schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn" ) const ( @@ -474,6 +476,36 @@ func TestDefaultStartParameters(t *testing.T) { assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Tolerations, toleration) } +func TestYunikornAnnotationsCreate(t *testing.T) { + assert.NoError(t, SetConfig(&Config{ + BatchScheduler: schedulerConfig.Config{ + Scheduler: "yunikorn", + Parameters: "gangSchedulingStyle=Soft", + }, + })) + rayJobResourceHandler := rayJobResourceHandler{} + rayJob := &plugins.RayJob{ + RayCluster: &plugins.RayCluster{ + HeadGroupSpec: &plugins.HeadGroupSpec{}, + WorkerGroupSpec: []*plugins.WorkerGroupSpec{{GroupName: workerGroupName, Replicas: 3, MinReplicas: 3, MaxReplicas: 3}}, + EnableAutoscaling: true, + }, + ShutdownAfterJobFinishes: true, + TtlSecondsAfterFinished: 120, + } + + taskTemplate := dummyRayTaskTemplate("ray-id", rayJob) + RayResource, err := rayJobResourceHandler.BuildResource(context.TODO(), dummyRayTaskContext(taskTemplate, resourceRequirements, nil, "", serviceAccount)) + assert.Nil(t, err) + ray, ok := RayResource.(*rayv1.RayJob) + assert.True(t, ok) + headAnnotations := ray.Spec.RayClusterSpec.HeadGroupSpec.Template.ObjectMeta.Annotations + workerAnnotations := ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.ObjectMeta.Annotations + assert.Equal(t, yunikorn.GenerateTaskGroupName(true, 0), headAnnotations[yunikorn.TaskGroupNameKey]) + assert.Equal(t, "gangSchedulingStyle=Soft", headAnnotations[yunikorn.TaskGroupPrarameters]) + assert.Equal(t, yunikorn.GenerateTaskGroupName(false, 0), workerAnnotations[yunikorn.TaskGroupNameKey]) +} + func TestInjectLogsSidecar(t *testing.T) { rayJobObj := transformRayJobToCustomObj(dummyRayCustomObj()) params := []struct { From b047ecd6bffbdc37241cc4e06ae4059fa7efa06e Mon Sep 17 00:00:00 2001 From: yuteng Date: Tue, 30 Jul 2024 23:07:16 +0800 Subject: [PATCH 15/30] pflag Signed-off-by: yuteng --- .../go/tasks/plugins/k8s/ray/config.go | 2 +- .../go/tasks/plugins/k8s/ray/config_flags.go | 2 ++ .../plugins/k8s/ray/config_flags_test.go | 28 +++++++++++++++++++ 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/config.go b/flyteplugins/go/tasks/plugins/k8s/ray/config.go index 3efed684c7..17141983cb 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/config.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/config.go @@ -78,7 +78,7 @@ type Config struct { // or 0.0.0.0 (available from all interfaces). By default, this is localhost. DashboardHost string `json:"dashboardHost,omitempty"` - BatchScheduler schedulerConfig.Config `json:"BatchScheduler,omitempty"` + BatchScheduler schedulerConfig.Config `json:"batchScheduler,omitempty"` // DeprecatedNodeIPAddress the IP address of the head node. By default, this is pod ip address. DeprecatedNodeIPAddress string `json:"nodeIPAddress,omitempty" pflag:"-,DEPRECATED. Please use DefaultConfig.[HeadNode|WorkerNode].IPAddress"` diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/config_flags.go b/flyteplugins/go/tasks/plugins/k8s/ray/config_flags.go index 5048869eab..a725e93c5a 100755 --- a/flyteplugins/go/tasks/plugins/k8s/ray/config_flags.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/config_flags.go @@ -55,6 +55,8 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.String(fmt.Sprintf("%v%v", prefix, "serviceType"), defaultConfig.ServiceType, "") cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "includeDashboard"), defaultConfig.IncludeDashboard, "") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "dashboardHost"), defaultConfig.DashboardHost, "") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "batchScheduler.scheduler"), defaultConfig.BatchScheduler.Scheduler, "") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "batchScheduler.parameters"), defaultConfig.BatchScheduler.Parameters, "") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "remoteClusterConfig.name"), defaultConfig.RemoteClusterConfig.Name, "Friendly name of the remote cluster") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "remoteClusterConfig.endpoint"), defaultConfig.RemoteClusterConfig.Endpoint, " Remote K8s cluster endpoint") cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "remoteClusterConfig.enabled"), defaultConfig.RemoteClusterConfig.Enabled, " Boolean flag to enable or disable") diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/config_flags_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/config_flags_test.go index 05871adc51..960752c0a5 100755 --- a/flyteplugins/go/tasks/plugins/k8s/ray/config_flags_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/config_flags_test.go @@ -169,6 +169,34 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_batchScheduler.scheduler", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("batchScheduler.scheduler", testValue) + if vString, err := cmdFlags.GetString("batchScheduler.scheduler"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.BatchScheduler.Scheduler) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_batchScheduler.parameters", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("batchScheduler.parameters", testValue) + if vString, err := cmdFlags.GetString("batchScheduler.parameters"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.BatchScheduler.Parameters) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) t.Run("Test_remoteClusterConfig.name", func(t *testing.T) { t.Run("Override", func(t *testing.T) { From 43b88b7e01fd5799d6f47bb591c19d3c824a55fb Mon Sep 17 00:00:00 2001 From: yuteng Date: Sat, 3 Aug 2024 14:25:53 +0800 Subject: [PATCH 16/30] codespell Signed-off-by: yuteng --- .../k8s/ray/batchscheduler/scheduler/yunikorn/utils_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/utils_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/utils_test.go index 6e14d9b7f7..b857853670 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/utils_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/utils_test.go @@ -33,7 +33,7 @@ func TestGenerateTaskGroupName(t *testing.T) { expect: fmt.Sprintf("%s-%s-%d", TaskGroupGenericName, "worker", 1), }, } - t.Run("Gernerate ray task group name", func(t *testing.T) { + t.Run("Generate ray task group name", func(t *testing.T) { for _, tt := range tests { got := GenerateTaskGroupName(tt.input.isMaster, tt.input.index) assert.Equal(t, tt.expect, got) From 9e73735e0f94d2cc270ef49a331c50bc8f89f7e6 Mon Sep 17 00:00:00 2001 From: yuteng Date: Sat, 17 Aug 2024 12:34:23 +0800 Subject: [PATCH 17/30] scheduler config description Signed-off-by: yuteng --- .../go/tasks/plugins/k8s/ray/batchscheduler/config/config.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config/config.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config/config.go index e1633a5bc0..483d940ca8 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config/config.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config/config.go @@ -1,8 +1,8 @@ package config type Config struct { - Scheduler string `json:"scheduler,omitempty"` - Parameters string `json:"parameters,omitempty"` + Scheduler string `json:"scheduler,omitempty" pflag:", Specify batch scheduler to"` + Parameters string `json:"parameters,omitempty" pflag:", Specify static parameters"` } func NewConfig() Config { From 58b9a0b7de7ba2049637910d61dbcaeb22a525f5 Mon Sep 17 00:00:00 2001 From: yuteng Date: Sat, 17 Aug 2024 12:41:34 +0800 Subject: [PATCH 18/30] use empty scheduler config Signed-off-by: yuteng --- flyteplugins/go/tasks/plugins/k8s/ray/config.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/config.go b/flyteplugins/go/tasks/plugins/k8s/ray/config.go index 17141983cb..cbe12497c3 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/config.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/config.go @@ -24,7 +24,7 @@ var ( DashboardHost: "0.0.0.0", EnableUsageStats: false, ServiceAccount: "default", - BatchScheduler: schedulerConfig.NewConfig(), + BatchScheduler: schedulerConfig.Config{}, Defaults: DefaultConfig{ HeadNode: NodeConfig{ StartParameters: map[string]string{ From 9c00cbbbc3eade77008ddfcc9e9008329b152a50 Mon Sep 17 00:00:00 2001 From: yuteng Date: Mon, 19 Aug 2024 21:57:49 +0800 Subject: [PATCH 19/30] move to k8s repo Signed-off-by: yuteng --- .../k8s/ray/batchscheduler/config/config.go | 21 - .../ray/batchscheduler/config/config_test.go | 15 - .../plugins/k8s/ray/batchscheduler/plugins.go | 28 - .../k8s/ray/batchscheduler/plugins_test.go | 28 - .../scheduler/kubernetes/default.go | 34 - .../scheduler/kubernetes/default_test.go | 103 --- .../scheduler/yunikorn/taskgroup.go | 24 - .../scheduler/yunikorn/taskgroup_test.go | 46 -- .../scheduler/yunikorn/utils.go | 16 - .../scheduler/yunikorn/utils_test.go | 42 -- .../scheduler/yunikorn/yunikorn.go | 152 ---- .../scheduler/yunikorn/yunikorn_test.go | 650 ------------------ .../go/tasks/plugins/k8s/ray/config.go | 2 +- flyteplugins/go/tasks/plugins/k8s/ray/ray.go | 2 +- .../go/tasks/plugins/k8s/ray/ray_test.go | 4 +- 15 files changed, 4 insertions(+), 1163 deletions(-) delete mode 100644 flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config/config.go delete mode 100644 flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config/config_test.go delete mode 100644 flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins.go delete mode 100644 flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins_test.go delete mode 100644 flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/kubernetes/default.go delete mode 100644 flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/kubernetes/default_test.go delete mode 100644 flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/taskgroup.go delete mode 100644 flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/taskgroup_test.go delete mode 100644 flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/utils.go delete mode 100644 flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/utils_test.go delete mode 100644 flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/yunikorn.go delete mode 100644 flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/yunikorn_test.go diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config/config.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config/config.go deleted file mode 100644 index 483d940ca8..0000000000 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config/config.go +++ /dev/null @@ -1,21 +0,0 @@ -package config - -type Config struct { - Scheduler string `json:"scheduler,omitempty" pflag:", Specify batch scheduler to"` - Parameters string `json:"parameters,omitempty" pflag:", Specify static parameters"` -} - -func NewConfig() Config { - return Config{ - Scheduler: "", - Parameters: "", - } -} - -func (b *Config) GetScheduler() string { - return b.Scheduler -} - -func (b *Config) GetParameters() string { - return b.Parameters -} diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config/config_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config/config_test.go deleted file mode 100644 index b7eb9fc354..0000000000 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config/config_test.go +++ /dev/null @@ -1,15 +0,0 @@ -package config - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestNewConfig(t *testing.T) { - t.Run("New scheduler plugin config", func(t *testing.T) { - config := NewConfig() - assert.Equal(t, "", config.GetScheduler()) - assert.Equal(t, "", config.GetParameters()) - }) -} diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins.go deleted file mode 100644 index 6ee85b8043..0000000000 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins.go +++ /dev/null @@ -1,28 +0,0 @@ -package batchscheduler - -import ( - v1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" - schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/kubernetes" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn" -) - -type SchedulerPlugin interface { - GetSchedulerName() string - ParseJob(config *schedulerConfig.Config, metadata *metav1.ObjectMeta, workerGroupsSpec []*plugins.WorkerGroupSpec, pod *v1.PodSpec, primaryContainerIdx int) error - ProcessHead(metadata *metav1.ObjectMeta, head *v1.PodSpec, index int) - ProcessWorker(metadata *metav1.ObjectMeta, worker *v1.PodSpec, index int) - AfterProcess(metadata *metav1.ObjectMeta) -} - -func NewSchedulerPlugin(config *schedulerConfig.Config) SchedulerPlugin { - switch config.GetScheduler() { - case yunikorn.Yunikorn: - return yunikorn.NewYunikornPlugin() - default: - return kubernetes.NewDefaultPlugin() - } -} diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins_test.go deleted file mode 100644 index 11731ec91f..0000000000 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins_test.go +++ /dev/null @@ -1,28 +0,0 @@ -package batchscheduler - -import ( - "testing" - - "github.com/stretchr/testify/assert" - - schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/kubernetes" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn" -) - -func TestCreateSchedulerPlugin(t *testing.T) { - var tests = []struct { - input *schedulerConfig.Config - expect string - }{ - {input: &schedulerConfig.Config{Scheduler: kubernetes.DefaultScheduler}, expect: kubernetes.DefaultScheduler}, - {input: &schedulerConfig.Config{Scheduler: yunikorn.Yunikorn}, expect: yunikorn.Yunikorn}, - {input: &schedulerConfig.Config{Scheduler: "Unknown"}, expect: kubernetes.DefaultScheduler}, - } - for _, tt := range tests { - t.Run("New scheduler plugin", func(t *testing.T) { - p := NewSchedulerPlugin(tt.input) - assert.Equal(t, tt.expect, p.GetSchedulerName()) - }) - } -} diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/kubernetes/default.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/kubernetes/default.go deleted file mode 100644 index 20241f8752..0000000000 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/kubernetes/default.go +++ /dev/null @@ -1,34 +0,0 @@ -package kubernetes - -import ( - v1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" - schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config" -) - -var ( - DefaultScheduler = "default" -) - -type Plugin struct{} - -func NewDefaultPlugin() *Plugin { - return &Plugin{} -} - -func (d *Plugin) GetSchedulerName() string { return DefaultScheduler } - -func (d *Plugin) ParseJob( - config *schedulerConfig.Config, - metadata *metav1.ObjectMeta, - workerGroupsSpec []*plugins.WorkerGroupSpec, - pod *v1.PodSpec, - primaryContainerIdx int, -) error { - return nil -} -func (d *Plugin) ProcessHead(metadata *metav1.ObjectMeta, head *v1.PodSpec, index int) {} -func (d *Plugin) ProcessWorker(metadata *metav1.ObjectMeta, worker *v1.PodSpec, index int) {} -func (d *Plugin) AfterProcess(metadata *metav1.ObjectMeta) {} diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/kubernetes/default_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/kubernetes/default_test.go deleted file mode 100644 index 3857e296bd..0000000000 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/kubernetes/default_test.go +++ /dev/null @@ -1,103 +0,0 @@ -package kubernetes - -import ( - "testing" - - "github.com/stretchr/testify/assert" - v1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/api/resource" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" - schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config" -) - -var ( - metadata = &metav1.ObjectMeta{ - Labels: map[string]string{"others": "extra"}, - Annotations: map[string]string{"others": "extra"}, - } - res = v1.ResourceList{ - "cpu": resource.MustParse("500m"), - "memory": resource.MustParse("1Gi"), - } - podSpec = &v1.PodSpec{ - Containers: []v1.Container{ - { - Resources: v1.ResourceRequirements{ - Requests: res, - }, - }, - }, - NodeSelector: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - } -) - -func TestNewDefaultPlugin(t *testing.T) { - t.Run("New default scheduler plugin", func(t *testing.T) { - p := NewDefaultPlugin() - assert.NotNil(t, p) - assert.Equal(t, DefaultScheduler, p.GetSchedulerName()) - }) -} - -func TestParseJob(t *testing.T) { - t.Run("Default scheduler plugin parse job", func(t *testing.T) { - p := schedulerConfig.NewConfig() - rayWorkersSpec := []*plugins.WorkerGroupSpec{ - { - GroupName: "g1", - Replicas: int32(2), - MinReplicas: int32(1), - MaxReplicas: int32(3), - RayStartParams: map[string]string{ - "parameters": "specific parameters", - }, - }, - } - index := 0 - err := NewDefaultPlugin().ParseJob(&p, metadata, rayWorkersSpec, podSpec, index) - assert.Nil(t, err) - workerspec := rayWorkersSpec[0] - assert.Equal(t, "g1", workerspec.GroupName) - assert.Equal(t, int32(2), workerspec.Replicas) - assert.Equal(t, int32(1), workerspec.MinReplicas) - assert.Equal(t, int32(3), workerspec.MaxReplicas) - assert.Equal(t, map[string]string{"parameters": "specific parameters"}, workerspec.RayStartParams) - assert.Equal(t, map[string]string{"others": "extra"}, metadata.Annotations) - assert.Equal(t, map[string]string{"others": "extra"}, metadata.Labels) - assert.Equal(t, res, podSpec.Containers[index].Resources.Requests) - assert.Equal(t, "", p.GetScheduler()) - assert.Equal(t, "", p.GetParameters()) - }) -} - -func TestProcessHead(t *testing.T) { - t.Run("Default scheduler plugin process head", func(t *testing.T) { - index := 0 - NewDefaultPlugin().ProcessHead(metadata, podSpec, index) - assert.Equal(t, map[string]string{"others": "extra"}, metadata.Annotations) - assert.Equal(t, map[string]string{"others": "extra"}, metadata.Labels) - assert.Equal(t, res, podSpec.Containers[index].Resources.Requests) - }) -} - -func TestProcessWorker(t *testing.T) { - t.Run("Default scheduler plugin preprocess worker", func(t *testing.T) { - index := 0 - NewDefaultPlugin().ProcessWorker(metadata, podSpec, index) - assert.Equal(t, map[string]string{"others": "extra"}, metadata.Annotations) - assert.Equal(t, map[string]string{"others": "extra"}, metadata.Labels) - assert.Equal(t, res, podSpec.Containers[index].Resources.Requests) - }) -} - -func TestAfterProcess(t *testing.T) { - t.Run("Default scheduler plugin afterly process worker", func(t *testing.T) { - NewDefaultPlugin().AfterProcess(metadata) - assert.Equal(t, map[string]string{"others": "extra"}, metadata.Annotations) - assert.Equal(t, map[string]string{"others": "extra"}, metadata.Labels) - }) -} diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/taskgroup.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/taskgroup.go deleted file mode 100644 index 5a52579ce4..0000000000 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/taskgroup.go +++ /dev/null @@ -1,24 +0,0 @@ -package yunikorn - -import ( - "encoding/json" - - v1 "k8s.io/api/core/v1" -) - -type TaskGroup struct { - Name string - MinMember int32 - Labels map[string]string - Annotations map[string]string - MinResource v1.ResourceList - NodeSelector map[string]string - Tolerations []v1.Toleration - Affinity *v1.Affinity - TopologySpreadConstraints []v1.TopologySpreadConstraint -} - -func Marshal(taskGroups []TaskGroup) ([]byte, error) { - info, err := json.Marshal(taskGroups) - return info, err -} diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/taskgroup_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/taskgroup_test.go deleted file mode 100644 index 180e2a6e84..0000000000 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/taskgroup_test.go +++ /dev/null @@ -1,46 +0,0 @@ -package yunikorn - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestMarshal(t *testing.T) { - t1 := TaskGroup{ - Name: "tg1", - MinMember: int32(1), - Labels: map[string]string{"attr": "value"}, - Annotations: map[string]string{"attr": "value"}, - MinResource: res, - NodeSelector: map[string]string{"node": "gpunode"}, - Tolerations: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - } - t2 := TaskGroup{ - Name: "tg2", - MinMember: int32(1), - Labels: map[string]string{"attr": "value"}, - Annotations: map[string]string{"attr": "value"}, - MinResource: res, - NodeSelector: map[string]string{"node": "gpunode"}, - Tolerations: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - } - var tests = []struct { - input []TaskGroup - }{ - {input: nil}, - {input: []TaskGroup{}}, - {input: []TaskGroup{t1}}, - {input: []TaskGroup{t1, t2}}, - } - t.Run("Serialize task groups", func(t *testing.T) { - for _, tt := range tests { - _, err := Marshal(tt.input) - assert.Nil(t, err) - } - }) -} diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/utils.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/utils.go deleted file mode 100644 index ff94255282..0000000000 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/utils.go +++ /dev/null @@ -1,16 +0,0 @@ -package yunikorn - -import ( - "fmt" -) - -const ( - TaskGroupGenericName = "task-group" -) - -func GenerateTaskGroupName(master bool, index int) string { - if master { - return fmt.Sprintf("%s-%s", TaskGroupGenericName, "head") - } - return fmt.Sprintf("%s-%s-%d", TaskGroupGenericName, "worker", index) -} diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/utils_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/utils_test.go deleted file mode 100644 index b857853670..0000000000 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/utils_test.go +++ /dev/null @@ -1,42 +0,0 @@ -package yunikorn - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestGenerateTaskGroupName(t *testing.T) { - type inputFormat struct { - isMaster bool - index int - } - var tests = []struct { - input inputFormat - expect string - }{ - { - input: inputFormat{isMaster: true, index: 0}, - expect: fmt.Sprintf("%s-%s", TaskGroupGenericName, "head"), - }, - { - input: inputFormat{isMaster: true, index: 1}, - expect: fmt.Sprintf("%s-%s", TaskGroupGenericName, "head"), - }, - { - input: inputFormat{isMaster: false, index: 0}, - expect: fmt.Sprintf("%s-%s-%d", TaskGroupGenericName, "worker", 0), - }, - { - input: inputFormat{isMaster: false, index: 1}, - expect: fmt.Sprintf("%s-%s-%d", TaskGroupGenericName, "worker", 1), - }, - } - t.Run("Generate ray task group name", func(t *testing.T) { - for _, tt := range tests { - got := GenerateTaskGroupName(tt.input.isMaster, tt.input.index) - assert.Equal(t, tt.expect, got) - } - }) -} diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/yunikorn.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/yunikorn.go deleted file mode 100644 index 0eda2e8f2e..0000000000 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/yunikorn.go +++ /dev/null @@ -1,152 +0,0 @@ -package yunikorn - -import ( - "errors" - - v1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" - schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config" -) - -const ( - // Pod lebel - Yunikorn = "yunikorn" - TaskGroupNameKey = "yunikorn.apache.org/task-group-name" - TaskGroupsKey = "yunikorn.apache.org/task-groups" - TaskGroupPrarameters = "yunikorn.apache.org/schedulingPolicyParameters" -) - -type Plugin struct { - Annotations map[string]map[string]string - Parameters string -} - -func NewYunikornPlugin() *Plugin { - return &Plugin{ - Annotations: nil, - Parameters: "", - } -} - -func (s *Plugin) GetSchedulerName() string { return Yunikorn } - -func (s *Plugin) ParseJob(config *schedulerConfig.Config, metadata *metav1.ObjectMeta, workerGroupsSpec []*plugins.WorkerGroupSpec, pod *v1.PodSpec, primaryContainerIdx int) error { - s.Annotations = nil - if parameters := config.GetParameters(); len(parameters) > 0 { - s.Parameters = parameters - } - return s.BuildGangInfo(metadata, workerGroupsSpec, pod, primaryContainerIdx) -} - -func (s *Plugin) ProcessHead(metadata *metav1.ObjectMeta, head *v1.PodSpec, index int) { - s.SetSchedulerName(head) - s.AddGangSchedulingAnnotations(GenerateTaskGroupName(true, index), metadata) -} - -func (s *Plugin) ProcessWorker(metadata *metav1.ObjectMeta, worker *v1.PodSpec, index int) { - s.SetSchedulerName(worker) - s.AddGangSchedulingAnnotations(GenerateTaskGroupName(false, index), metadata) -} - -func (s *Plugin) AfterProcess(metadata *metav1.ObjectMeta) { - if metadata == nil { - return - } - delete(metadata.Annotations, TaskGroupNameKey) - delete(metadata.Annotations, TaskGroupsKey) - delete(metadata.Annotations, TaskGroupPrarameters) -} - -func (s *Plugin) SetSchedulerName(spec *v1.PodSpec) { - spec.SchedulerName = s.GetSchedulerName() -} - -func (s *Plugin) BuildGangInfo( - metadata *metav1.ObjectMeta, - workerGroupsSpec []*plugins.WorkerGroupSpec, - pod *v1.PodSpec, - primaryContainerIdx int, -) error { - if pod == nil { - return errors.New("Ray gang scheduling: pod is nil") - } - // Parsing placeholders from the pod resource among head and workers - var labels, annotations map[string]string = nil, nil - if metadata != nil { - labels = metadata.Labels - annotations = metadata.Annotations - } - TaskGroups := make([]TaskGroup, 0) - headName := GenerateTaskGroupName(true, 0) - TaskGroups = append(TaskGroups, TaskGroup{ - Name: headName, - MinMember: 1, - Labels: labels, - Annotations: annotations, - MinResource: pod.Containers[primaryContainerIdx].Resources.Requests, - NodeSelector: pod.NodeSelector, - Affinity: pod.Affinity, - TopologySpreadConstraints: pod.TopologySpreadConstraints, - }) - - s.Annotations = make(map[string]map[string]string, 0) - for index, spec := range workerGroupsSpec { - name := GenerateTaskGroupName(false, index) - tg := TaskGroup{ - Name: name, - MinMember: spec.Replicas, - Labels: labels, - Annotations: annotations, - MinResource: pod.Containers[primaryContainerIdx].Resources.Requests, - NodeSelector: pod.NodeSelector, - Affinity: pod.Affinity, - TopologySpreadConstraints: pod.TopologySpreadConstraints, - } - s.Annotations[name] = map[string]string{ - TaskGroupNameKey: name, - } - TaskGroups = append(TaskGroups, tg) - } - - // Yunikorn head gang scheduling annotations - var info []byte - info, _ = Marshal(TaskGroups) - headAnnotations := make(map[string]string, 0) - headAnnotations[TaskGroupNameKey] = headName - headAnnotations[TaskGroupsKey] = string(info[:]) - if len(s.Parameters) > 0 { - headAnnotations[TaskGroupPrarameters] = s.Parameters - } - s.Annotations[headName] = headAnnotations - return nil -} - -func (s *Plugin) AddGangSchedulingAnnotations(name string, metadata *metav1.ObjectMeta) { - if s.Annotations == nil || metadata == nil { - return - } - - if _, ok := s.Annotations[name]; !ok { - return - } - - if metadata.Annotations == nil { - metadata.Annotations = make(map[string]string, 0) - } - - // Updating Yunikorn gang scheduling annotations - annotations := s.Annotations[name] - if _, ok := annotations[TaskGroupNameKey]; ok { - metadata.Annotations[TaskGroupNameKey] = annotations[TaskGroupNameKey] - } - if _, ok := annotations[TaskGroupsKey]; ok { - metadata.Annotations[TaskGroupsKey] = annotations[TaskGroupsKey] - } - if _, ok := metadata.Annotations[TaskGroupPrarameters]; !ok { - if parameters, ok := annotations[TaskGroupPrarameters]; ok && len(parameters) > 0 { - metadata.Annotations[TaskGroupPrarameters] = parameters - } - } -} diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/yunikorn_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/yunikorn_test.go deleted file mode 100644 index 768ea8e376..0000000000 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn/yunikorn_test.go +++ /dev/null @@ -1,650 +0,0 @@ -package yunikorn - -import ( - "encoding/json" - "testing" - - "github.com/stretchr/testify/assert" - v1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/api/resource" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" - schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config" -) - -var ( - res = v1.ResourceList{ - "cpu": resource.MustParse("500m"), - "memory": resource.MustParse("1Gi"), - } -) - -func TestParseJob(t *testing.T) { - type inputFormat struct { - config *schedulerConfig.Config - metadata *metav1.ObjectMeta - workerGroupNum int - podSpec *v1.PodSpec - index int - } - type expectFormat struct { - raiseErr bool - parameters string - taskGroups []TaskGroup - } - var tests = []struct { - input inputFormat - expect expectFormat - }{ - { - input: inputFormat{ - config: &schedulerConfig.Config{ - Scheduler: "yunikorn", - Parameters: "placeholderTimeoutInSeconds=15 gangSchedulingStyle=Soft", - }, - workerGroupNum: 1, - podSpec: nil, - metadata: &metav1.ObjectMeta{}, - index: 0, - }, - expect: expectFormat{ - raiseErr: true, - parameters: "placeholderTimeoutInSeconds=15 gangSchedulingStyle=Soft", - taskGroups: []TaskGroup{ - { - Name: GenerateTaskGroupName(true, 0), - MinMember: int32(1), - Labels: nil, - Annotations: map[string]string{"others": "extra"}, - MinResource: res, - NodeSelector: nil, - Tolerations: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - }, - { - Name: GenerateTaskGroupName(false, 0), - MinMember: int32(1), - Labels: nil, - Annotations: map[string]string{"others": "extra"}, - MinResource: res, - NodeSelector: nil, - Tolerations: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - }, - }, - }, - }, - { - input: inputFormat{ - config: &schedulerConfig.Config{ - Scheduler: "yunikorn", - Parameters: "placeholderTimeoutInSeconds=15 gangSchedulingStyle=Soft", - }, - workerGroupNum: 1, - podSpec: &v1.PodSpec{ - Containers: []v1.Container{ - { - Resources: v1.ResourceRequirements{ - Requests: res, - }, - }, - }, - NodeSelector: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - }, - metadata: &metav1.ObjectMeta{ - Annotations: map[string]string{"others": "extra"}, - }, - index: 0, - }, - expect: expectFormat{ - raiseErr: false, - parameters: "placeholderTimeoutInSeconds=15 gangSchedulingStyle=Soft", - taskGroups: []TaskGroup{ - { - Name: GenerateTaskGroupName(true, 0), - MinMember: int32(1), - Labels: nil, - Annotations: map[string]string{"others": "extra"}, - MinResource: res, - NodeSelector: nil, - Tolerations: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - }, - { - Name: GenerateTaskGroupName(false, 0), - MinMember: int32(1), - Labels: nil, - Annotations: map[string]string{"others": "extra"}, - MinResource: res, - NodeSelector: nil, - Tolerations: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - }, - }, - }, - }, - } - for _, tt := range tests { - t.Run("Yunikorn parse job", func(t *testing.T) { - workersSpec := make([]*plugins.WorkerGroupSpec, 0) - for index := 0; index < tt.input.workerGroupNum; index++ { - count := 1 * (1 + index) - max := 2 * (1 + index) - workersSpec = append(workersSpec, &plugins.WorkerGroupSpec{ - Replicas: int32(count), - MinReplicas: int32(count), - MaxReplicas: int32(max), - }) - } - p := NewYunikornPlugin() - err := p.ParseJob(tt.input.config, tt.input.metadata, workersSpec, tt.input.podSpec, tt.input.index) - if tt.expect.raiseErr { - assert.NotNil(t, err) - } else { - assert.Nil(t, err) - assert.Equal(t, Yunikorn, p.GetSchedulerName()) - names := []string{GenerateTaskGroupName(true, 0)} - for index := 0; index < tt.input.workerGroupNum; index++ { - names = append(names, GenerateTaskGroupName(false, index)) - } - // task-groups among head and workers - assert.Equal(t, len(names), len(p.Annotations)) - // check head annotations - head := p.Annotations[names[0]] - assert.Equal(t, names[0], head[TaskGroupNameKey]) - assert.Equal(t, tt.expect.parameters, head[TaskGroupPrarameters]) - // task-groups in head - var taskgroups []TaskGroup - err = json.Unmarshal([]byte(head[TaskGroupsKey]), &taskgroups) - assert.Nil(t, err) - assert.Equal(t, len(names), len(taskgroups)) - for index, tg := range taskgroups { - assert.Equal(t, names[index], tg.Name) - } - } - }) - } -} - -func TestProcessHead(t *testing.T) { - type inputFormat struct { - config *schedulerConfig.Config - metadata *metav1.ObjectMeta - workerGroupNum int - podSpec *v1.PodSpec - index int - } - type expectFormat struct { - name string - taskgroupsNum int - parameters string - } - var tests = []struct { - input inputFormat - expect expectFormat - }{ - { - input: inputFormat{ - config: &schedulerConfig.Config{ - Scheduler: "yunikorn", - Parameters: "placeholderTimeoutInSeconds=15 gangSchedulingStyle=Soft", - }, - workerGroupNum: 1, - podSpec: &v1.PodSpec{ - Containers: []v1.Container{ - { - Resources: v1.ResourceRequirements{ - Requests: res, - }, - }, - }, - NodeSelector: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - }, - metadata: &metav1.ObjectMeta{ - Annotations: map[string]string{"others": "extra"}, - }, - index: 0, - }, - expect: expectFormat{ - name: GenerateTaskGroupName(true, 0), - taskgroupsNum: 2, - parameters: "placeholderTimeoutInSeconds=15 gangSchedulingStyle=Soft", - }, - }, - } - for _, tt := range tests { - t.Run("Yunikorn process head", func(t *testing.T) { - workersSpec := make([]*plugins.WorkerGroupSpec, 0) - for index := 0; index < tt.input.workerGroupNum; index++ { - workersSpec = append(workersSpec, &plugins.WorkerGroupSpec{ - Replicas: int32(1), - MinReplicas: int32(1), - MaxReplicas: int32(2), - }) - } - p := NewYunikornPlugin() - err := p.ParseJob(tt.input.config, tt.input.metadata, workersSpec, tt.input.podSpec, tt.input.index) - assert.Nil(t, err) - p.ProcessHead(tt.input.metadata, tt.input.podSpec, tt.input.index) - assert.Equal(t, Yunikorn, tt.input.podSpec.SchedulerName) - assert.Equal(t, tt.expect.name, tt.input.metadata.Annotations[TaskGroupNameKey]) - assert.Equal(t, tt.expect.parameters, tt.input.metadata.Annotations[TaskGroupPrarameters]) - var taskgroups []TaskGroup - err = json.Unmarshal([]byte(tt.input.metadata.Annotations[TaskGroupsKey]), &taskgroups) - assert.Nil(t, err) - assert.Equal(t, tt.expect.taskgroupsNum, len(taskgroups)) - }) - } -} - -func TestProcessWorker(t *testing.T) { - type inputFormat struct { - config *schedulerConfig.Config - metadata *metav1.ObjectMeta - workerGroupNum int - podSpec *v1.PodSpec - index int - } - type expectFormat struct { - name string - taskgroupsNum int - } - var tests = []struct { - input inputFormat - expect expectFormat - }{ - { - input: inputFormat{ - config: &schedulerConfig.Config{ - Scheduler: "yunikorn", - Parameters: "placeholderTimeoutInSeconds=15 gangSchedulingStyle=Soft", - }, - workerGroupNum: 1, - podSpec: &v1.PodSpec{ - Containers: []v1.Container{ - { - Resources: v1.ResourceRequirements{ - Requests: res, - }, - }, - }, - NodeSelector: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - }, - metadata: &metav1.ObjectMeta{ - Annotations: map[string]string{"others": "extra"}, - }, - index: 0, - }, - expect: expectFormat{ - name: GenerateTaskGroupName(false, 0), - taskgroupsNum: 2, - }, - }, - } - for _, tt := range tests { - t.Run("Yunikorn process worker", func(t *testing.T) { - workersSpec := make([]*plugins.WorkerGroupSpec, 0) - for index := 0; index < tt.input.workerGroupNum; index++ { - workersSpec = append(workersSpec, &plugins.WorkerGroupSpec{ - Replicas: int32(1), - MinReplicas: int32(1), - MaxReplicas: int32(2), - }) - } - p := NewYunikornPlugin() - err := p.ParseJob(tt.input.config, tt.input.metadata, workersSpec, tt.input.podSpec, tt.input.index) - assert.Nil(t, err) - p.ProcessWorker(tt.input.metadata, tt.input.podSpec, tt.input.index) - assert.Equal(t, Yunikorn, tt.input.podSpec.SchedulerName) - assert.Equal(t, tt.expect.name, tt.input.metadata.Annotations[TaskGroupNameKey]) - }) - } -} - -func TestAfterProcess(t *testing.T) { - type expectFormat struct { - isNil bool - length int - } - var tests = []struct { - input *metav1.ObjectMeta - expect expectFormat - }{ - { - input: nil, - expect: expectFormat{isNil: true, length: -1}, - }, - { - input: &metav1.ObjectMeta{ - Annotations: map[string]string{ - "others": "extra", - TaskGroupNameKey: "TGName", - TaskGroupsKey: "TGs", - TaskGroupPrarameters: "parameters", - }, - }, - expect: expectFormat{isNil: false, length: 1}, - }, - { - input: &metav1.ObjectMeta{ - Annotations: map[string]string{ - TaskGroupNameKey: "TGName", - TaskGroupsKey: "TGs", - TaskGroupPrarameters: "parameters", - }, - }, - expect: expectFormat{isNil: false, length: 0}, - }, - { - input: &metav1.ObjectMeta{ - Annotations: map[string]string{ - TaskGroupNameKey: "TGName", - TaskGroupsKey: "TGs", - }, - }, - expect: expectFormat{isNil: false, length: 0}, - }, - { - input: &metav1.ObjectMeta{ - Annotations: map[string]string{ - TaskGroupNameKey: "TGName", - }, - }, - expect: expectFormat{isNil: false, length: 0}, - }, - { - input: &metav1.ObjectMeta{}, - expect: expectFormat{isNil: false, length: 0}, - }, - } - for _, tt := range tests { - t.Run("Remove Gang scheduling labels", func(t *testing.T) { - p := NewYunikornPlugin() - p.AfterProcess(tt.input) - if tt.expect.isNil { - assert.Nil(t, tt.input) - } else { - assert.NotNil(t, tt.input) - assert.Equal(t, tt.expect.length, len(tt.input.Annotations)) - } - }) - } -} - -func TestSetSchedulerName(t *testing.T) { - t.Run("Set Scheduler Name", func(t *testing.T) { - p := NewYunikornPlugin() - podSpec := &v1.PodSpec{ - Containers: []v1.Container{ - { - Resources: v1.ResourceRequirements{ - Requests: res, - }, - }, - }, - NodeSelector: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - } - p.SetSchedulerName(podSpec) - assert.Equal(t, p.GetSchedulerName(), podSpec.SchedulerName) - podSpec.SchedulerName = "" - }) -} - -func TestBuildGangInfo(t *testing.T) { - names := []string{GenerateTaskGroupName(true, 0)} - for index := 0; index < 2; index++ { - names = append(names, GenerateTaskGroupName(false, index)) - } - type inputFormat struct { - workerGroupNum int - podSpec *v1.PodSpec - metadata *metav1.ObjectMeta - } - var tests = []struct { - input inputFormat - taskGroups []TaskGroup - }{ - { - input: inputFormat{ - workerGroupNum: 1, - podSpec: nil, - metadata: &metav1.ObjectMeta{ - Annotations: map[string]string{"others": "extra"}, - }, - }, - taskGroups: nil, - }, - { - input: inputFormat{ - workerGroupNum: 1, - podSpec: &v1.PodSpec{ - Containers: []v1.Container{ - { - Resources: v1.ResourceRequirements{ - Requests: res, - }, - }, - }, - NodeSelector: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - }, - metadata: &metav1.ObjectMeta{ - Annotations: map[string]string{"others": "extra"}, - }, - }, - taskGroups: []TaskGroup{ - { - Name: GenerateTaskGroupName(true, 0), - MinMember: int32(1), - Labels: nil, - Annotations: map[string]string{"others": "extra"}, - MinResource: res, - NodeSelector: nil, - Tolerations: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - }, - { - Name: GenerateTaskGroupName(false, 0), - MinMember: int32(1), - Labels: nil, - Annotations: map[string]string{"others": "extra"}, - MinResource: res, - NodeSelector: nil, - Tolerations: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - }, - }, - }, - { - input: inputFormat{ - workerGroupNum: 2, - podSpec: &v1.PodSpec{ - Containers: []v1.Container{ - { - Resources: v1.ResourceRequirements{ - Requests: res, - }, - }, - }, - NodeSelector: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - }, - metadata: &metav1.ObjectMeta{ - Annotations: map[string]string{"others": "extra"}, - }, - }, - taskGroups: []TaskGroup{ - { - Name: GenerateTaskGroupName(true, 0), - MinMember: int32(1), - Labels: nil, - Annotations: map[string]string{"others": "extra"}, - MinResource: res, - NodeSelector: nil, - Tolerations: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - }, - { - Name: GenerateTaskGroupName(false, 0), - MinMember: int32(1), - Labels: nil, - Annotations: map[string]string{"others": "extra"}, - MinResource: res, - NodeSelector: nil, - Tolerations: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - }, - { - Name: GenerateTaskGroupName(false, 1), - MinMember: int32(2), - Labels: nil, - Annotations: map[string]string{"others": "extra"}, - MinResource: res, - NodeSelector: nil, - Tolerations: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - }, - }, - }, - } - for _, tt := range tests { - t.Run("Create Yunikorn gang scheduling annotations", func(t *testing.T) { - workersSpec := make([]*plugins.WorkerGroupSpec, 0) - for index := 0; index < tt.input.workerGroupNum; index++ { - count := 1 * (1 + index) - max := 2 * (1 + index) - workersSpec = append(workersSpec, &plugins.WorkerGroupSpec{ - Replicas: int32(count), - MinReplicas: int32(count), - MaxReplicas: int32(max), - }) - } - p := NewYunikornPlugin() - if err := p.BuildGangInfo(tt.input.metadata, workersSpec, tt.input.podSpec, 0); tt.input.podSpec == nil { - assert.NotNil(t, err) - } else { - assert.Nil(t, err) - // test worker name - for index := 0; index < tt.input.workerGroupNum; index++ { - name := GenerateTaskGroupName(false, index) - if annotations, ok := p.Annotations[name]; ok { - assert.Equal(t, 1, len(annotations)) - assert.Equal(t, name, annotations[TaskGroupNameKey]) - } else { - t.Errorf("Worker group %d annotatiosn miss", index) - } - } - // Test head name and groups - headName := GenerateTaskGroupName(true, 0) - if annotations, ok := p.Annotations[headName]; ok { - info, err := json.Marshal(tt.taskGroups) - assert.Nil(t, err) - assert.Equal(t, 2, len(annotations)) - assert.Equal(t, headName, annotations[TaskGroupNameKey]) - assert.Equal(t, string(info[:]), annotations[TaskGroupsKey]) - } else { - t.Error("Head annotations miss") - } - } - }) - } -} - -func TestAddGangSchedulingAnnotations(t *testing.T) { - taskGroupsAnnotations := map[string]map[string]string{ - GenerateTaskGroupName(true, 0): { - TaskGroupNameKey: GenerateTaskGroupName(true, 0), - TaskGroupsKey: "TGs", - TaskGroupPrarameters: "parameters", - }, - GenerateTaskGroupName(false, 0): { - TaskGroupNameKey: GenerateTaskGroupName(false, 0), - }, - } - type inputFormat struct { - annotations map[string]map[string]string - metadata *metav1.ObjectMeta - name string - } - var tests = []struct { - input inputFormat - expect *metav1.ObjectMeta - }{ - { - input: inputFormat{ - annotations: nil, - metadata: nil, - name: "", - }, - expect: nil, - }, - { - input: inputFormat{ - annotations: taskGroupsAnnotations, - metadata: nil, - name: "", - }, - expect: nil, - }, - { - input: inputFormat{ - annotations: taskGroupsAnnotations, - metadata: &metav1.ObjectMeta{}, - name: "Unknown", - }, - expect: &metav1.ObjectMeta{}, - }, - { - input: inputFormat{ - annotations: taskGroupsAnnotations, - metadata: &metav1.ObjectMeta{}, - name: GenerateTaskGroupName(true, 0), - }, - expect: &metav1.ObjectMeta{ - Annotations: taskGroupsAnnotations[GenerateTaskGroupName(true, 0)], - }, - }, - { - input: inputFormat{ - annotations: taskGroupsAnnotations, - metadata: &metav1.ObjectMeta{}, - name: GenerateTaskGroupName(false, 0), - }, - expect: &metav1.ObjectMeta{ - Annotations: taskGroupsAnnotations[GenerateTaskGroupName(false, 0)], - }, - }, - } - for _, tt := range tests { - t.Run("Check gang scheduling annotatiosn after labeling", func(t *testing.T) { - p := NewYunikornPlugin() - p.Annotations = tt.input.annotations - p.AddGangSchedulingAnnotations(tt.input.name, tt.input.metadata) - if tt.expect == nil { - assert.Nil(t, tt.expect, tt.input.metadata) - } else { - assert.Equal(t, tt.expect.Annotations, tt.input.metadata.Annotations) - } - }) - } -} diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/config.go b/flyteplugins/go/tasks/plugins/k8s/ray/config.go index cbe12497c3..5658ccc353 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/config.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/config.go @@ -9,7 +9,7 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/logs" pluginmachinery "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" - schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config" + schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config" "github.com/flyteorg/flyte/flytestdlib/config" ) diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go index 3c1f01d5e6..0015c2b057 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go @@ -28,7 +28,7 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler" ) const ( diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go index 82b7960d9d..0709535fab 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go @@ -27,8 +27,8 @@ import ( mocks2 "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s/mocks" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" - schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/scheduler/yunikorn" + schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn" ) const ( From f5474ef36b0787160040d28cc8d9a68e3363f446 Mon Sep 17 00:00:00 2001 From: yuteng Date: Mon, 19 Aug 2024 21:57:58 +0800 Subject: [PATCH 20/30] move to k8s repo Signed-off-by: yuteng --- .../k8s/batchscheduler/config/config.go | 21 + .../k8s/batchscheduler/config/config_test.go | 15 + .../plugins/k8s/batchscheduler/plugins.go | 28 + .../k8s/batchscheduler/plugins_test.go | 28 + .../scheduler/kubernetes/default.go | 34 + .../scheduler/kubernetes/default_test.go | 103 +++ .../scheduler/yunikorn/taskgroup.go | 24 + .../scheduler/yunikorn/taskgroup_test.go | 46 ++ .../scheduler/yunikorn/utils.go | 16 + .../scheduler/yunikorn/utils_test.go | 42 ++ .../scheduler/yunikorn/yunikorn.go | 152 ++++ .../scheduler/yunikorn/yunikorn_test.go | 650 ++++++++++++++++++ 12 files changed, 1159 insertions(+) create mode 100644 flyteplugins/go/tasks/plugins/k8s/batchscheduler/config/config.go create mode 100644 flyteplugins/go/tasks/plugins/k8s/batchscheduler/config/config_test.go create mode 100644 flyteplugins/go/tasks/plugins/k8s/batchscheduler/plugins.go create mode 100644 flyteplugins/go/tasks/plugins/k8s/batchscheduler/plugins_test.go create mode 100644 flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/kubernetes/default.go create mode 100644 flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/kubernetes/default_test.go create mode 100644 flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/taskgroup.go create mode 100644 flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/taskgroup_test.go create mode 100644 flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/utils.go create mode 100644 flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/utils_test.go create mode 100644 flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn.go create mode 100644 flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn_test.go diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config/config.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config/config.go new file mode 100644 index 0000000000..483d940ca8 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config/config.go @@ -0,0 +1,21 @@ +package config + +type Config struct { + Scheduler string `json:"scheduler,omitempty" pflag:", Specify batch scheduler to"` + Parameters string `json:"parameters,omitempty" pflag:", Specify static parameters"` +} + +func NewConfig() Config { + return Config{ + Scheduler: "", + Parameters: "", + } +} + +func (b *Config) GetScheduler() string { + return b.Scheduler +} + +func (b *Config) GetParameters() string { + return b.Parameters +} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config/config_test.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config/config_test.go new file mode 100644 index 0000000000..b7eb9fc354 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config/config_test.go @@ -0,0 +1,15 @@ +package config + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewConfig(t *testing.T) { + t.Run("New scheduler plugin config", func(t *testing.T) { + config := NewConfig() + assert.Equal(t, "", config.GetScheduler()) + assert.Equal(t, "", config.GetParameters()) + }) +} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/plugins.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/plugins.go new file mode 100644 index 0000000000..3f2d277207 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/plugins.go @@ -0,0 +1,28 @@ +package batchscheduler + +import ( + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" + schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/kubernetes" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn" +) + +type SchedulerPlugin interface { + GetSchedulerName() string + ParseJob(config *schedulerConfig.Config, metadata *metav1.ObjectMeta, workerGroupsSpec []*plugins.WorkerGroupSpec, pod *v1.PodSpec, primaryContainerIdx int) error + ProcessHead(metadata *metav1.ObjectMeta, head *v1.PodSpec, index int) + ProcessWorker(metadata *metav1.ObjectMeta, worker *v1.PodSpec, index int) + AfterProcess(metadata *metav1.ObjectMeta) +} + +func NewSchedulerPlugin(config *schedulerConfig.Config) SchedulerPlugin { + switch config.GetScheduler() { + case yunikorn.Yunikorn: + return yunikorn.NewYunikornPlugin() + default: + return kubernetes.NewDefaultPlugin() + } +} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/plugins_test.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/plugins_test.go new file mode 100644 index 0000000000..85a90bad99 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/plugins_test.go @@ -0,0 +1,28 @@ +package batchscheduler + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/kubernetes" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn" +) + +func TestCreateSchedulerPlugin(t *testing.T) { + var tests = []struct { + input *schedulerConfig.Config + expect string + }{ + {input: &schedulerConfig.Config{Scheduler: kubernetes.DefaultScheduler}, expect: kubernetes.DefaultScheduler}, + {input: &schedulerConfig.Config{Scheduler: yunikorn.Yunikorn}, expect: yunikorn.Yunikorn}, + {input: &schedulerConfig.Config{Scheduler: "Unknown"}, expect: kubernetes.DefaultScheduler}, + } + for _, tt := range tests { + t.Run("New scheduler plugin", func(t *testing.T) { + p := NewSchedulerPlugin(tt.input) + assert.Equal(t, tt.expect, p.GetSchedulerName()) + }) + } +} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/kubernetes/default.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/kubernetes/default.go new file mode 100644 index 0000000000..2b6d25721c --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/kubernetes/default.go @@ -0,0 +1,34 @@ +package kubernetes + +import ( + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" + schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config" +) + +var ( + DefaultScheduler = "default" +) + +type Plugin struct{} + +func NewDefaultPlugin() *Plugin { + return &Plugin{} +} + +func (d *Plugin) GetSchedulerName() string { return DefaultScheduler } + +func (d *Plugin) ParseJob( + config *schedulerConfig.Config, + metadata *metav1.ObjectMeta, + workerGroupsSpec []*plugins.WorkerGroupSpec, + pod *v1.PodSpec, + primaryContainerIdx int, +) error { + return nil +} +func (d *Plugin) ProcessHead(metadata *metav1.ObjectMeta, head *v1.PodSpec, index int) {} +func (d *Plugin) ProcessWorker(metadata *metav1.ObjectMeta, worker *v1.PodSpec, index int) {} +func (d *Plugin) AfterProcess(metadata *metav1.ObjectMeta) {} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/kubernetes/default_test.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/kubernetes/default_test.go new file mode 100644 index 0000000000..21ed71ca0d --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/kubernetes/default_test.go @@ -0,0 +1,103 @@ +package kubernetes + +import ( + "testing" + + "github.com/stretchr/testify/assert" + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" + schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config" +) + +var ( + metadata = &metav1.ObjectMeta{ + Labels: map[string]string{"others": "extra"}, + Annotations: map[string]string{"others": "extra"}, + } + res = v1.ResourceList{ + "cpu": resource.MustParse("500m"), + "memory": resource.MustParse("1Gi"), + } + podSpec = &v1.PodSpec{ + Containers: []v1.Container{ + { + Resources: v1.ResourceRequirements{ + Requests: res, + }, + }, + }, + NodeSelector: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + } +) + +func TestNewDefaultPlugin(t *testing.T) { + t.Run("New default scheduler plugin", func(t *testing.T) { + p := NewDefaultPlugin() + assert.NotNil(t, p) + assert.Equal(t, DefaultScheduler, p.GetSchedulerName()) + }) +} + +func TestParseJob(t *testing.T) { + t.Run("Default scheduler plugin parse job", func(t *testing.T) { + p := schedulerConfig.NewConfig() + rayWorkersSpec := []*plugins.WorkerGroupSpec{ + { + GroupName: "g1", + Replicas: int32(2), + MinReplicas: int32(1), + MaxReplicas: int32(3), + RayStartParams: map[string]string{ + "parameters": "specific parameters", + }, + }, + } + index := 0 + err := NewDefaultPlugin().ParseJob(&p, metadata, rayWorkersSpec, podSpec, index) + assert.Nil(t, err) + workerspec := rayWorkersSpec[0] + assert.Equal(t, "g1", workerspec.GroupName) + assert.Equal(t, int32(2), workerspec.Replicas) + assert.Equal(t, int32(1), workerspec.MinReplicas) + assert.Equal(t, int32(3), workerspec.MaxReplicas) + assert.Equal(t, map[string]string{"parameters": "specific parameters"}, workerspec.RayStartParams) + assert.Equal(t, map[string]string{"others": "extra"}, metadata.Annotations) + assert.Equal(t, map[string]string{"others": "extra"}, metadata.Labels) + assert.Equal(t, res, podSpec.Containers[index].Resources.Requests) + assert.Equal(t, "", p.GetScheduler()) + assert.Equal(t, "", p.GetParameters()) + }) +} + +func TestProcessHead(t *testing.T) { + t.Run("Default scheduler plugin process head", func(t *testing.T) { + index := 0 + NewDefaultPlugin().ProcessHead(metadata, podSpec, index) + assert.Equal(t, map[string]string{"others": "extra"}, metadata.Annotations) + assert.Equal(t, map[string]string{"others": "extra"}, metadata.Labels) + assert.Equal(t, res, podSpec.Containers[index].Resources.Requests) + }) +} + +func TestProcessWorker(t *testing.T) { + t.Run("Default scheduler plugin preprocess worker", func(t *testing.T) { + index := 0 + NewDefaultPlugin().ProcessWorker(metadata, podSpec, index) + assert.Equal(t, map[string]string{"others": "extra"}, metadata.Annotations) + assert.Equal(t, map[string]string{"others": "extra"}, metadata.Labels) + assert.Equal(t, res, podSpec.Containers[index].Resources.Requests) + }) +} + +func TestAfterProcess(t *testing.T) { + t.Run("Default scheduler plugin afterly process worker", func(t *testing.T) { + NewDefaultPlugin().AfterProcess(metadata) + assert.Equal(t, map[string]string{"others": "extra"}, metadata.Annotations) + assert.Equal(t, map[string]string{"others": "extra"}, metadata.Labels) + }) +} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/taskgroup.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/taskgroup.go new file mode 100644 index 0000000000..5a52579ce4 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/taskgroup.go @@ -0,0 +1,24 @@ +package yunikorn + +import ( + "encoding/json" + + v1 "k8s.io/api/core/v1" +) + +type TaskGroup struct { + Name string + MinMember int32 + Labels map[string]string + Annotations map[string]string + MinResource v1.ResourceList + NodeSelector map[string]string + Tolerations []v1.Toleration + Affinity *v1.Affinity + TopologySpreadConstraints []v1.TopologySpreadConstraint +} + +func Marshal(taskGroups []TaskGroup) ([]byte, error) { + info, err := json.Marshal(taskGroups) + return info, err +} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/taskgroup_test.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/taskgroup_test.go new file mode 100644 index 0000000000..180e2a6e84 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/taskgroup_test.go @@ -0,0 +1,46 @@ +package yunikorn + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMarshal(t *testing.T) { + t1 := TaskGroup{ + Name: "tg1", + MinMember: int32(1), + Labels: map[string]string{"attr": "value"}, + Annotations: map[string]string{"attr": "value"}, + MinResource: res, + NodeSelector: map[string]string{"node": "gpunode"}, + Tolerations: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + } + t2 := TaskGroup{ + Name: "tg2", + MinMember: int32(1), + Labels: map[string]string{"attr": "value"}, + Annotations: map[string]string{"attr": "value"}, + MinResource: res, + NodeSelector: map[string]string{"node": "gpunode"}, + Tolerations: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + } + var tests = []struct { + input []TaskGroup + }{ + {input: nil}, + {input: []TaskGroup{}}, + {input: []TaskGroup{t1}}, + {input: []TaskGroup{t1, t2}}, + } + t.Run("Serialize task groups", func(t *testing.T) { + for _, tt := range tests { + _, err := Marshal(tt.input) + assert.Nil(t, err) + } + }) +} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/utils.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/utils.go new file mode 100644 index 0000000000..ff94255282 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/utils.go @@ -0,0 +1,16 @@ +package yunikorn + +import ( + "fmt" +) + +const ( + TaskGroupGenericName = "task-group" +) + +func GenerateTaskGroupName(master bool, index int) string { + if master { + return fmt.Sprintf("%s-%s", TaskGroupGenericName, "head") + } + return fmt.Sprintf("%s-%s-%d", TaskGroupGenericName, "worker", index) +} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/utils_test.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/utils_test.go new file mode 100644 index 0000000000..b857853670 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/utils_test.go @@ -0,0 +1,42 @@ +package yunikorn + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGenerateTaskGroupName(t *testing.T) { + type inputFormat struct { + isMaster bool + index int + } + var tests = []struct { + input inputFormat + expect string + }{ + { + input: inputFormat{isMaster: true, index: 0}, + expect: fmt.Sprintf("%s-%s", TaskGroupGenericName, "head"), + }, + { + input: inputFormat{isMaster: true, index: 1}, + expect: fmt.Sprintf("%s-%s", TaskGroupGenericName, "head"), + }, + { + input: inputFormat{isMaster: false, index: 0}, + expect: fmt.Sprintf("%s-%s-%d", TaskGroupGenericName, "worker", 0), + }, + { + input: inputFormat{isMaster: false, index: 1}, + expect: fmt.Sprintf("%s-%s-%d", TaskGroupGenericName, "worker", 1), + }, + } + t.Run("Generate ray task group name", func(t *testing.T) { + for _, tt := range tests { + got := GenerateTaskGroupName(tt.input.isMaster, tt.input.index) + assert.Equal(t, tt.expect, got) + } + }) +} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn.go new file mode 100644 index 0000000000..554191c175 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn.go @@ -0,0 +1,152 @@ +package yunikorn + +import ( + "errors" + + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" + schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config" +) + +const ( + // Pod lebel + Yunikorn = "yunikorn" + TaskGroupNameKey = "yunikorn.apache.org/task-group-name" + TaskGroupsKey = "yunikorn.apache.org/task-groups" + TaskGroupPrarameters = "yunikorn.apache.org/schedulingPolicyParameters" +) + +type Plugin struct { + Annotations map[string]map[string]string + Parameters string +} + +func NewYunikornPlugin() *Plugin { + return &Plugin{ + Annotations: nil, + Parameters: "", + } +} + +func (s *Plugin) GetSchedulerName() string { return Yunikorn } + +func (s *Plugin) ParseJob(config *schedulerConfig.Config, metadata *metav1.ObjectMeta, workerGroupsSpec []*plugins.WorkerGroupSpec, pod *v1.PodSpec, primaryContainerIdx int) error { + s.Annotations = nil + if parameters := config.GetParameters(); len(parameters) > 0 { + s.Parameters = parameters + } + return s.BuildGangInfo(metadata, workerGroupsSpec, pod, primaryContainerIdx) +} + +func (s *Plugin) ProcessHead(metadata *metav1.ObjectMeta, head *v1.PodSpec, index int) { + s.SetSchedulerName(head) + s.AddGangSchedulingAnnotations(GenerateTaskGroupName(true, index), metadata) +} + +func (s *Plugin) ProcessWorker(metadata *metav1.ObjectMeta, worker *v1.PodSpec, index int) { + s.SetSchedulerName(worker) + s.AddGangSchedulingAnnotations(GenerateTaskGroupName(false, index), metadata) +} + +func (s *Plugin) AfterProcess(metadata *metav1.ObjectMeta) { + if metadata == nil { + return + } + delete(metadata.Annotations, TaskGroupNameKey) + delete(metadata.Annotations, TaskGroupsKey) + delete(metadata.Annotations, TaskGroupPrarameters) +} + +func (s *Plugin) SetSchedulerName(spec *v1.PodSpec) { + spec.SchedulerName = s.GetSchedulerName() +} + +func (s *Plugin) BuildGangInfo( + metadata *metav1.ObjectMeta, + workerGroupsSpec []*plugins.WorkerGroupSpec, + pod *v1.PodSpec, + primaryContainerIdx int, +) error { + if pod == nil { + return errors.New("Ray gang scheduling: pod is nil") + } + // Parsing placeholders from the pod resource among head and workers + var labels, annotations map[string]string = nil, nil + if metadata != nil { + labels = metadata.Labels + annotations = metadata.Annotations + } + TaskGroups := make([]TaskGroup, 0) + headName := GenerateTaskGroupName(true, 0) + TaskGroups = append(TaskGroups, TaskGroup{ + Name: headName, + MinMember: 1, + Labels: labels, + Annotations: annotations, + MinResource: pod.Containers[primaryContainerIdx].Resources.Requests, + NodeSelector: pod.NodeSelector, + Affinity: pod.Affinity, + TopologySpreadConstraints: pod.TopologySpreadConstraints, + }) + + s.Annotations = make(map[string]map[string]string, 0) + for index, spec := range workerGroupsSpec { + name := GenerateTaskGroupName(false, index) + tg := TaskGroup{ + Name: name, + MinMember: spec.Replicas, + Labels: labels, + Annotations: annotations, + MinResource: pod.Containers[primaryContainerIdx].Resources.Requests, + NodeSelector: pod.NodeSelector, + Affinity: pod.Affinity, + TopologySpreadConstraints: pod.TopologySpreadConstraints, + } + s.Annotations[name] = map[string]string{ + TaskGroupNameKey: name, + } + TaskGroups = append(TaskGroups, tg) + } + + // Yunikorn head gang scheduling annotations + var info []byte + info, _ = Marshal(TaskGroups) + headAnnotations := make(map[string]string, 0) + headAnnotations[TaskGroupNameKey] = headName + headAnnotations[TaskGroupsKey] = string(info[:]) + if len(s.Parameters) > 0 { + headAnnotations[TaskGroupPrarameters] = s.Parameters + } + s.Annotations[headName] = headAnnotations + return nil +} + +func (s *Plugin) AddGangSchedulingAnnotations(name string, metadata *metav1.ObjectMeta) { + if s.Annotations == nil || metadata == nil { + return + } + + if _, ok := s.Annotations[name]; !ok { + return + } + + if metadata.Annotations == nil { + metadata.Annotations = make(map[string]string, 0) + } + + // Updating Yunikorn gang scheduling annotations + annotations := s.Annotations[name] + if _, ok := annotations[TaskGroupNameKey]; ok { + metadata.Annotations[TaskGroupNameKey] = annotations[TaskGroupNameKey] + } + if _, ok := annotations[TaskGroupsKey]; ok { + metadata.Annotations[TaskGroupsKey] = annotations[TaskGroupsKey] + } + if _, ok := metadata.Annotations[TaskGroupPrarameters]; !ok { + if parameters, ok := annotations[TaskGroupPrarameters]; ok && len(parameters) > 0 { + metadata.Annotations[TaskGroupPrarameters] = parameters + } + } +} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn_test.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn_test.go new file mode 100644 index 0000000000..28214e9d14 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn_test.go @@ -0,0 +1,650 @@ +package yunikorn + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" + schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config" +) + +var ( + res = v1.ResourceList{ + "cpu": resource.MustParse("500m"), + "memory": resource.MustParse("1Gi"), + } +) + +func TestParseJob(t *testing.T) { + type inputFormat struct { + config *schedulerConfig.Config + metadata *metav1.ObjectMeta + workerGroupNum int + podSpec *v1.PodSpec + index int + } + type expectFormat struct { + raiseErr bool + parameters string + taskGroups []TaskGroup + } + var tests = []struct { + input inputFormat + expect expectFormat + }{ + { + input: inputFormat{ + config: &schedulerConfig.Config{ + Scheduler: "yunikorn", + Parameters: "placeholderTimeoutInSeconds=15 gangSchedulingStyle=Soft", + }, + workerGroupNum: 1, + podSpec: nil, + metadata: &metav1.ObjectMeta{}, + index: 0, + }, + expect: expectFormat{ + raiseErr: true, + parameters: "placeholderTimeoutInSeconds=15 gangSchedulingStyle=Soft", + taskGroups: []TaskGroup{ + { + Name: GenerateTaskGroupName(true, 0), + MinMember: int32(1), + Labels: nil, + Annotations: map[string]string{"others": "extra"}, + MinResource: res, + NodeSelector: nil, + Tolerations: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + }, + { + Name: GenerateTaskGroupName(false, 0), + MinMember: int32(1), + Labels: nil, + Annotations: map[string]string{"others": "extra"}, + MinResource: res, + NodeSelector: nil, + Tolerations: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + }, + }, + }, + }, + { + input: inputFormat{ + config: &schedulerConfig.Config{ + Scheduler: "yunikorn", + Parameters: "placeholderTimeoutInSeconds=15 gangSchedulingStyle=Soft", + }, + workerGroupNum: 1, + podSpec: &v1.PodSpec{ + Containers: []v1.Container{ + { + Resources: v1.ResourceRequirements{ + Requests: res, + }, + }, + }, + NodeSelector: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + }, + metadata: &metav1.ObjectMeta{ + Annotations: map[string]string{"others": "extra"}, + }, + index: 0, + }, + expect: expectFormat{ + raiseErr: false, + parameters: "placeholderTimeoutInSeconds=15 gangSchedulingStyle=Soft", + taskGroups: []TaskGroup{ + { + Name: GenerateTaskGroupName(true, 0), + MinMember: int32(1), + Labels: nil, + Annotations: map[string]string{"others": "extra"}, + MinResource: res, + NodeSelector: nil, + Tolerations: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + }, + { + Name: GenerateTaskGroupName(false, 0), + MinMember: int32(1), + Labels: nil, + Annotations: map[string]string{"others": "extra"}, + MinResource: res, + NodeSelector: nil, + Tolerations: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run("Yunikorn parse job", func(t *testing.T) { + workersSpec := make([]*plugins.WorkerGroupSpec, 0) + for index := 0; index < tt.input.workerGroupNum; index++ { + count := 1 * (1 + index) + max := 2 * (1 + index) + workersSpec = append(workersSpec, &plugins.WorkerGroupSpec{ + Replicas: int32(count), + MinReplicas: int32(count), + MaxReplicas: int32(max), + }) + } + p := NewYunikornPlugin() + err := p.ParseJob(tt.input.config, tt.input.metadata, workersSpec, tt.input.podSpec, tt.input.index) + if tt.expect.raiseErr { + assert.NotNil(t, err) + } else { + assert.Nil(t, err) + assert.Equal(t, Yunikorn, p.GetSchedulerName()) + names := []string{GenerateTaskGroupName(true, 0)} + for index := 0; index < tt.input.workerGroupNum; index++ { + names = append(names, GenerateTaskGroupName(false, index)) + } + // task-groups among head and workers + assert.Equal(t, len(names), len(p.Annotations)) + // check head annotations + head := p.Annotations[names[0]] + assert.Equal(t, names[0], head[TaskGroupNameKey]) + assert.Equal(t, tt.expect.parameters, head[TaskGroupPrarameters]) + // task-groups in head + var taskgroups []TaskGroup + err = json.Unmarshal([]byte(head[TaskGroupsKey]), &taskgroups) + assert.Nil(t, err) + assert.Equal(t, len(names), len(taskgroups)) + for index, tg := range taskgroups { + assert.Equal(t, names[index], tg.Name) + } + } + }) + } +} + +func TestProcessHead(t *testing.T) { + type inputFormat struct { + config *schedulerConfig.Config + metadata *metav1.ObjectMeta + workerGroupNum int + podSpec *v1.PodSpec + index int + } + type expectFormat struct { + name string + taskgroupsNum int + parameters string + } + var tests = []struct { + input inputFormat + expect expectFormat + }{ + { + input: inputFormat{ + config: &schedulerConfig.Config{ + Scheduler: "yunikorn", + Parameters: "placeholderTimeoutInSeconds=15 gangSchedulingStyle=Soft", + }, + workerGroupNum: 1, + podSpec: &v1.PodSpec{ + Containers: []v1.Container{ + { + Resources: v1.ResourceRequirements{ + Requests: res, + }, + }, + }, + NodeSelector: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + }, + metadata: &metav1.ObjectMeta{ + Annotations: map[string]string{"others": "extra"}, + }, + index: 0, + }, + expect: expectFormat{ + name: GenerateTaskGroupName(true, 0), + taskgroupsNum: 2, + parameters: "placeholderTimeoutInSeconds=15 gangSchedulingStyle=Soft", + }, + }, + } + for _, tt := range tests { + t.Run("Yunikorn process head", func(t *testing.T) { + workersSpec := make([]*plugins.WorkerGroupSpec, 0) + for index := 0; index < tt.input.workerGroupNum; index++ { + workersSpec = append(workersSpec, &plugins.WorkerGroupSpec{ + Replicas: int32(1), + MinReplicas: int32(1), + MaxReplicas: int32(2), + }) + } + p := NewYunikornPlugin() + err := p.ParseJob(tt.input.config, tt.input.metadata, workersSpec, tt.input.podSpec, tt.input.index) + assert.Nil(t, err) + p.ProcessHead(tt.input.metadata, tt.input.podSpec, tt.input.index) + assert.Equal(t, Yunikorn, tt.input.podSpec.SchedulerName) + assert.Equal(t, tt.expect.name, tt.input.metadata.Annotations[TaskGroupNameKey]) + assert.Equal(t, tt.expect.parameters, tt.input.metadata.Annotations[TaskGroupPrarameters]) + var taskgroups []TaskGroup + err = json.Unmarshal([]byte(tt.input.metadata.Annotations[TaskGroupsKey]), &taskgroups) + assert.Nil(t, err) + assert.Equal(t, tt.expect.taskgroupsNum, len(taskgroups)) + }) + } +} + +func TestProcessWorker(t *testing.T) { + type inputFormat struct { + config *schedulerConfig.Config + metadata *metav1.ObjectMeta + workerGroupNum int + podSpec *v1.PodSpec + index int + } + type expectFormat struct { + name string + taskgroupsNum int + } + var tests = []struct { + input inputFormat + expect expectFormat + }{ + { + input: inputFormat{ + config: &schedulerConfig.Config{ + Scheduler: "yunikorn", + Parameters: "placeholderTimeoutInSeconds=15 gangSchedulingStyle=Soft", + }, + workerGroupNum: 1, + podSpec: &v1.PodSpec{ + Containers: []v1.Container{ + { + Resources: v1.ResourceRequirements{ + Requests: res, + }, + }, + }, + NodeSelector: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + }, + metadata: &metav1.ObjectMeta{ + Annotations: map[string]string{"others": "extra"}, + }, + index: 0, + }, + expect: expectFormat{ + name: GenerateTaskGroupName(false, 0), + taskgroupsNum: 2, + }, + }, + } + for _, tt := range tests { + t.Run("Yunikorn process worker", func(t *testing.T) { + workersSpec := make([]*plugins.WorkerGroupSpec, 0) + for index := 0; index < tt.input.workerGroupNum; index++ { + workersSpec = append(workersSpec, &plugins.WorkerGroupSpec{ + Replicas: int32(1), + MinReplicas: int32(1), + MaxReplicas: int32(2), + }) + } + p := NewYunikornPlugin() + err := p.ParseJob(tt.input.config, tt.input.metadata, workersSpec, tt.input.podSpec, tt.input.index) + assert.Nil(t, err) + p.ProcessWorker(tt.input.metadata, tt.input.podSpec, tt.input.index) + assert.Equal(t, Yunikorn, tt.input.podSpec.SchedulerName) + assert.Equal(t, tt.expect.name, tt.input.metadata.Annotations[TaskGroupNameKey]) + }) + } +} + +func TestAfterProcess(t *testing.T) { + type expectFormat struct { + isNil bool + length int + } + var tests = []struct { + input *metav1.ObjectMeta + expect expectFormat + }{ + { + input: nil, + expect: expectFormat{isNil: true, length: -1}, + }, + { + input: &metav1.ObjectMeta{ + Annotations: map[string]string{ + "others": "extra", + TaskGroupNameKey: "TGName", + TaskGroupsKey: "TGs", + TaskGroupPrarameters: "parameters", + }, + }, + expect: expectFormat{isNil: false, length: 1}, + }, + { + input: &metav1.ObjectMeta{ + Annotations: map[string]string{ + TaskGroupNameKey: "TGName", + TaskGroupsKey: "TGs", + TaskGroupPrarameters: "parameters", + }, + }, + expect: expectFormat{isNil: false, length: 0}, + }, + { + input: &metav1.ObjectMeta{ + Annotations: map[string]string{ + TaskGroupNameKey: "TGName", + TaskGroupsKey: "TGs", + }, + }, + expect: expectFormat{isNil: false, length: 0}, + }, + { + input: &metav1.ObjectMeta{ + Annotations: map[string]string{ + TaskGroupNameKey: "TGName", + }, + }, + expect: expectFormat{isNil: false, length: 0}, + }, + { + input: &metav1.ObjectMeta{}, + expect: expectFormat{isNil: false, length: 0}, + }, + } + for _, tt := range tests { + t.Run("Remove Gang scheduling labels", func(t *testing.T) { + p := NewYunikornPlugin() + p.AfterProcess(tt.input) + if tt.expect.isNil { + assert.Nil(t, tt.input) + } else { + assert.NotNil(t, tt.input) + assert.Equal(t, tt.expect.length, len(tt.input.Annotations)) + } + }) + } +} + +func TestSetSchedulerName(t *testing.T) { + t.Run("Set Scheduler Name", func(t *testing.T) { + p := NewYunikornPlugin() + podSpec := &v1.PodSpec{ + Containers: []v1.Container{ + { + Resources: v1.ResourceRequirements{ + Requests: res, + }, + }, + }, + NodeSelector: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + } + p.SetSchedulerName(podSpec) + assert.Equal(t, p.GetSchedulerName(), podSpec.SchedulerName) + podSpec.SchedulerName = "" + }) +} + +func TestBuildGangInfo(t *testing.T) { + names := []string{GenerateTaskGroupName(true, 0)} + for index := 0; index < 2; index++ { + names = append(names, GenerateTaskGroupName(false, index)) + } + type inputFormat struct { + workerGroupNum int + podSpec *v1.PodSpec + metadata *metav1.ObjectMeta + } + var tests = []struct { + input inputFormat + taskGroups []TaskGroup + }{ + { + input: inputFormat{ + workerGroupNum: 1, + podSpec: nil, + metadata: &metav1.ObjectMeta{ + Annotations: map[string]string{"others": "extra"}, + }, + }, + taskGroups: nil, + }, + { + input: inputFormat{ + workerGroupNum: 1, + podSpec: &v1.PodSpec{ + Containers: []v1.Container{ + { + Resources: v1.ResourceRequirements{ + Requests: res, + }, + }, + }, + NodeSelector: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + }, + metadata: &metav1.ObjectMeta{ + Annotations: map[string]string{"others": "extra"}, + }, + }, + taskGroups: []TaskGroup{ + { + Name: GenerateTaskGroupName(true, 0), + MinMember: int32(1), + Labels: nil, + Annotations: map[string]string{"others": "extra"}, + MinResource: res, + NodeSelector: nil, + Tolerations: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + }, + { + Name: GenerateTaskGroupName(false, 0), + MinMember: int32(1), + Labels: nil, + Annotations: map[string]string{"others": "extra"}, + MinResource: res, + NodeSelector: nil, + Tolerations: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + }, + }, + }, + { + input: inputFormat{ + workerGroupNum: 2, + podSpec: &v1.PodSpec{ + Containers: []v1.Container{ + { + Resources: v1.ResourceRequirements{ + Requests: res, + }, + }, + }, + NodeSelector: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + }, + metadata: &metav1.ObjectMeta{ + Annotations: map[string]string{"others": "extra"}, + }, + }, + taskGroups: []TaskGroup{ + { + Name: GenerateTaskGroupName(true, 0), + MinMember: int32(1), + Labels: nil, + Annotations: map[string]string{"others": "extra"}, + MinResource: res, + NodeSelector: nil, + Tolerations: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + }, + { + Name: GenerateTaskGroupName(false, 0), + MinMember: int32(1), + Labels: nil, + Annotations: map[string]string{"others": "extra"}, + MinResource: res, + NodeSelector: nil, + Tolerations: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + }, + { + Name: GenerateTaskGroupName(false, 1), + MinMember: int32(2), + Labels: nil, + Annotations: map[string]string{"others": "extra"}, + MinResource: res, + NodeSelector: nil, + Tolerations: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + }, + }, + }, + } + for _, tt := range tests { + t.Run("Create Yunikorn gang scheduling annotations", func(t *testing.T) { + workersSpec := make([]*plugins.WorkerGroupSpec, 0) + for index := 0; index < tt.input.workerGroupNum; index++ { + count := 1 * (1 + index) + max := 2 * (1 + index) + workersSpec = append(workersSpec, &plugins.WorkerGroupSpec{ + Replicas: int32(count), + MinReplicas: int32(count), + MaxReplicas: int32(max), + }) + } + p := NewYunikornPlugin() + if err := p.BuildGangInfo(tt.input.metadata, workersSpec, tt.input.podSpec, 0); tt.input.podSpec == nil { + assert.NotNil(t, err) + } else { + assert.Nil(t, err) + // test worker name + for index := 0; index < tt.input.workerGroupNum; index++ { + name := GenerateTaskGroupName(false, index) + if annotations, ok := p.Annotations[name]; ok { + assert.Equal(t, 1, len(annotations)) + assert.Equal(t, name, annotations[TaskGroupNameKey]) + } else { + t.Errorf("Worker group %d annotatiosn miss", index) + } + } + // Test head name and groups + headName := GenerateTaskGroupName(true, 0) + if annotations, ok := p.Annotations[headName]; ok { + info, err := json.Marshal(tt.taskGroups) + assert.Nil(t, err) + assert.Equal(t, 2, len(annotations)) + assert.Equal(t, headName, annotations[TaskGroupNameKey]) + assert.Equal(t, string(info[:]), annotations[TaskGroupsKey]) + } else { + t.Error("Head annotations miss") + } + } + }) + } +} + +func TestAddGangSchedulingAnnotations(t *testing.T) { + taskGroupsAnnotations := map[string]map[string]string{ + GenerateTaskGroupName(true, 0): { + TaskGroupNameKey: GenerateTaskGroupName(true, 0), + TaskGroupsKey: "TGs", + TaskGroupPrarameters: "parameters", + }, + GenerateTaskGroupName(false, 0): { + TaskGroupNameKey: GenerateTaskGroupName(false, 0), + }, + } + type inputFormat struct { + annotations map[string]map[string]string + metadata *metav1.ObjectMeta + name string + } + var tests = []struct { + input inputFormat + expect *metav1.ObjectMeta + }{ + { + input: inputFormat{ + annotations: nil, + metadata: nil, + name: "", + }, + expect: nil, + }, + { + input: inputFormat{ + annotations: taskGroupsAnnotations, + metadata: nil, + name: "", + }, + expect: nil, + }, + { + input: inputFormat{ + annotations: taskGroupsAnnotations, + metadata: &metav1.ObjectMeta{}, + name: "Unknown", + }, + expect: &metav1.ObjectMeta{}, + }, + { + input: inputFormat{ + annotations: taskGroupsAnnotations, + metadata: &metav1.ObjectMeta{}, + name: GenerateTaskGroupName(true, 0), + }, + expect: &metav1.ObjectMeta{ + Annotations: taskGroupsAnnotations[GenerateTaskGroupName(true, 0)], + }, + }, + { + input: inputFormat{ + annotations: taskGroupsAnnotations, + metadata: &metav1.ObjectMeta{}, + name: GenerateTaskGroupName(false, 0), + }, + expect: &metav1.ObjectMeta{ + Annotations: taskGroupsAnnotations[GenerateTaskGroupName(false, 0)], + }, + }, + } + for _, tt := range tests { + t.Run("Check gang scheduling annotatiosn after labeling", func(t *testing.T) { + p := NewYunikornPlugin() + p.Annotations = tt.input.annotations + p.AddGangSchedulingAnnotations(tt.input.name, tt.input.metadata) + if tt.expect == nil { + assert.Nil(t, tt.expect, tt.input.metadata) + } else { + assert.Equal(t, tt.expect.Annotations, tt.input.metadata.Annotations) + } + }) + } +} From ecefa774e9999e81162717bdb7cfb56e22961073 Mon Sep 17 00:00:00 2001 From: yuteng Date: Tue, 27 Aug 2024 20:45:40 +0800 Subject: [PATCH 21/30] refactoring Signed-off-by: yuteng --- .../plugins/k8s/batchscheduler/plugins.go | 18 +- .../k8s/batchscheduler/plugins_test.go | 28 - .../scheduler/kubernetes/default.go | 25 +- .../scheduler/kubernetes/default_test.go | 103 --- .../scheduler/yunikorn/rayhandler.go | 74 ++ .../scheduler/yunikorn/taskgroup_test.go | 46 -- .../scheduler/yunikorn/utils.go | 7 +- .../scheduler/yunikorn/utils_test.go | 42 -- .../scheduler/yunikorn/yunikorn.go | 139 +--- .../scheduler/yunikorn/yunikorn_test.go | 650 ------------------ flyteplugins/go/tasks/plugins/k8s/ray/ray.go | 22 +- .../go/tasks/plugins/k8s/ray/ray_test.go | 32 - 12 files changed, 101 insertions(+), 1085 deletions(-) delete mode 100644 flyteplugins/go/tasks/plugins/k8s/batchscheduler/plugins_test.go delete mode 100644 flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/kubernetes/default_test.go create mode 100644 flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/rayhandler.go delete mode 100644 flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/taskgroup_test.go delete mode 100644 flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/utils_test.go delete mode 100644 flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn_test.go diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/plugins.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/plugins.go index 3f2d277207..f0625fd180 100644 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/plugins.go +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/plugins.go @@ -1,28 +1,20 @@ package batchscheduler import ( - v1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config" "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/kubernetes" "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn" ) type SchedulerPlugin interface { - GetSchedulerName() string - ParseJob(config *schedulerConfig.Config, metadata *metav1.ObjectMeta, workerGroupsSpec []*plugins.WorkerGroupSpec, pod *v1.PodSpec, primaryContainerIdx int) error - ProcessHead(metadata *metav1.ObjectMeta, head *v1.PodSpec, index int) - ProcessWorker(metadata *metav1.ObjectMeta, worker *v1.PodSpec, index int) - AfterProcess(metadata *metav1.ObjectMeta) + Process(app interface{}) error } -func NewSchedulerPlugin(config *schedulerConfig.Config) SchedulerPlugin { - switch config.GetScheduler() { +func NewSchedulerPlugin(cfg *schedulerConfig.Config) SchedulerPlugin { + switch cfg.GetScheduler() { case yunikorn.Yunikorn: - return yunikorn.NewYunikornPlugin() + return yunikorn.NewPlugin(cfg.GetParameters()) default: - return kubernetes.NewDefaultPlugin() + return kubernetes.NewPlugin() } } diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/plugins_test.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/plugins_test.go deleted file mode 100644 index 85a90bad99..0000000000 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/plugins_test.go +++ /dev/null @@ -1,28 +0,0 @@ -package batchscheduler - -import ( - "testing" - - "github.com/stretchr/testify/assert" - - schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/kubernetes" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn" -) - -func TestCreateSchedulerPlugin(t *testing.T) { - var tests = []struct { - input *schedulerConfig.Config - expect string - }{ - {input: &schedulerConfig.Config{Scheduler: kubernetes.DefaultScheduler}, expect: kubernetes.DefaultScheduler}, - {input: &schedulerConfig.Config{Scheduler: yunikorn.Yunikorn}, expect: yunikorn.Yunikorn}, - {input: &schedulerConfig.Config{Scheduler: "Unknown"}, expect: kubernetes.DefaultScheduler}, - } - for _, tt := range tests { - t.Run("New scheduler plugin", func(t *testing.T) { - p := NewSchedulerPlugin(tt.input) - assert.Equal(t, tt.expect, p.GetSchedulerName()) - }) - } -} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/kubernetes/default.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/kubernetes/default.go index 2b6d25721c..bd5fad5976 100644 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/kubernetes/default.go +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/kubernetes/default.go @@ -1,34 +1,13 @@ package kubernetes -import ( - v1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" - schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config" -) - var ( DefaultScheduler = "default" ) type Plugin struct{} -func NewDefaultPlugin() *Plugin { +func NewPlugin() *Plugin { return &Plugin{} } -func (d *Plugin) GetSchedulerName() string { return DefaultScheduler } - -func (d *Plugin) ParseJob( - config *schedulerConfig.Config, - metadata *metav1.ObjectMeta, - workerGroupsSpec []*plugins.WorkerGroupSpec, - pod *v1.PodSpec, - primaryContainerIdx int, -) error { - return nil -} -func (d *Plugin) ProcessHead(metadata *metav1.ObjectMeta, head *v1.PodSpec, index int) {} -func (d *Plugin) ProcessWorker(metadata *metav1.ObjectMeta, worker *v1.PodSpec, index int) {} -func (d *Plugin) AfterProcess(metadata *metav1.ObjectMeta) {} +func (p *Plugin) Process(app interface{}) error { return nil } diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/kubernetes/default_test.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/kubernetes/default_test.go deleted file mode 100644 index 21ed71ca0d..0000000000 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/kubernetes/default_test.go +++ /dev/null @@ -1,103 +0,0 @@ -package kubernetes - -import ( - "testing" - - "github.com/stretchr/testify/assert" - v1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/api/resource" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" - schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config" -) - -var ( - metadata = &metav1.ObjectMeta{ - Labels: map[string]string{"others": "extra"}, - Annotations: map[string]string{"others": "extra"}, - } - res = v1.ResourceList{ - "cpu": resource.MustParse("500m"), - "memory": resource.MustParse("1Gi"), - } - podSpec = &v1.PodSpec{ - Containers: []v1.Container{ - { - Resources: v1.ResourceRequirements{ - Requests: res, - }, - }, - }, - NodeSelector: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - } -) - -func TestNewDefaultPlugin(t *testing.T) { - t.Run("New default scheduler plugin", func(t *testing.T) { - p := NewDefaultPlugin() - assert.NotNil(t, p) - assert.Equal(t, DefaultScheduler, p.GetSchedulerName()) - }) -} - -func TestParseJob(t *testing.T) { - t.Run("Default scheduler plugin parse job", func(t *testing.T) { - p := schedulerConfig.NewConfig() - rayWorkersSpec := []*plugins.WorkerGroupSpec{ - { - GroupName: "g1", - Replicas: int32(2), - MinReplicas: int32(1), - MaxReplicas: int32(3), - RayStartParams: map[string]string{ - "parameters": "specific parameters", - }, - }, - } - index := 0 - err := NewDefaultPlugin().ParseJob(&p, metadata, rayWorkersSpec, podSpec, index) - assert.Nil(t, err) - workerspec := rayWorkersSpec[0] - assert.Equal(t, "g1", workerspec.GroupName) - assert.Equal(t, int32(2), workerspec.Replicas) - assert.Equal(t, int32(1), workerspec.MinReplicas) - assert.Equal(t, int32(3), workerspec.MaxReplicas) - assert.Equal(t, map[string]string{"parameters": "specific parameters"}, workerspec.RayStartParams) - assert.Equal(t, map[string]string{"others": "extra"}, metadata.Annotations) - assert.Equal(t, map[string]string{"others": "extra"}, metadata.Labels) - assert.Equal(t, res, podSpec.Containers[index].Resources.Requests) - assert.Equal(t, "", p.GetScheduler()) - assert.Equal(t, "", p.GetParameters()) - }) -} - -func TestProcessHead(t *testing.T) { - t.Run("Default scheduler plugin process head", func(t *testing.T) { - index := 0 - NewDefaultPlugin().ProcessHead(metadata, podSpec, index) - assert.Equal(t, map[string]string{"others": "extra"}, metadata.Annotations) - assert.Equal(t, map[string]string{"others": "extra"}, metadata.Labels) - assert.Equal(t, res, podSpec.Containers[index].Resources.Requests) - }) -} - -func TestProcessWorker(t *testing.T) { - t.Run("Default scheduler plugin preprocess worker", func(t *testing.T) { - index := 0 - NewDefaultPlugin().ProcessWorker(metadata, podSpec, index) - assert.Equal(t, map[string]string{"others": "extra"}, metadata.Annotations) - assert.Equal(t, map[string]string{"others": "extra"}, metadata.Labels) - assert.Equal(t, res, podSpec.Containers[index].Resources.Requests) - }) -} - -func TestAfterProcess(t *testing.T) { - t.Run("Default scheduler plugin afterly process worker", func(t *testing.T) { - NewDefaultPlugin().AfterProcess(metadata) - assert.Equal(t, map[string]string{"others": "extra"}, metadata.Annotations) - assert.Equal(t, map[string]string{"others": "extra"}, metadata.Labels) - }) -} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/rayhandler.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/rayhandler.go new file mode 100644 index 0000000000..57f6cc3123 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/rayhandler.go @@ -0,0 +1,74 @@ +package yunikorn + +import ( + "encoding/json" + + v1 "k8s.io/api/core/v1" + rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" +) + +func ProcessRay(paras string, app *rayv1.RayJob) error { + jobname := GenerateTaskGroupName(true, 0) + rayjobSpec := &app.Spec + appSpec := rayjobSpec.RayClusterSpec + TaskGroups := make([]TaskGroup, 1) + for index := range appSpec.WorkerGroupSpecs { + worker := &appSpec.WorkerGroupSpecs[index] + worker.Template.Spec.SchedulerName = Yunikorn + meta := worker.Template.ObjectMeta + spec := worker.Template.Spec + name := GenerateTaskGroupName(false, index) + TaskGroups = append(TaskGroups, TaskGroup{ + Name: name, + MinMember: *worker.Replicas, + //Labels: meta.Labels, + //Annotations: meta.Annotations, + MinResource: Allocation(spec.Containers), + //NodeSelector: spec.NodeSelector, + //Affinity: spec.Affinity, + //TopologySpreadConstraints: spec.TopologySpreadConstraints, + }) + meta.Annotations[TaskGroupNameKey] = name + meta.Annotations[AppID] = jobname + } + headSpec := &appSpec.HeadGroupSpec + headSpec.Template.Spec.SchedulerName = Yunikorn + meta := headSpec.Template.ObjectMeta + spec := headSpec.Template.Spec + headName := GenerateTaskGroupName(true, 0) + TaskGroups[0] = TaskGroup{ + Name: headName, + MinMember: 1, + //Labels: meta.Labels, + //Annotations: meta.Annotations, + MinResource: Allocation(spec.Containers), + //NodeSelector: spec.NodeSelector, + //Affinity: spec.Affinity, + //TopologySpreadConstraints: spec.TopologySpreadConstraints, + } + meta.Annotations[TaskGroupNameKey] = headName + info, err := json.Marshal(TaskGroups) + if err != nil { + return err + } + meta.Annotations[TaskGroupsKey] = string(info[:]) + meta.Annotations[TaskGroupPrarameters] = paras + meta.Annotations[AppID] = jobname + return nil +} + +func Allocation(containers []v1.Container) v1.ResourceList { + totalResources := v1.ResourceList{} + for _, c := range containers { + for name, q := range c.Resources.Limits { + if _, exists := totalResources[name]; !exists { + totalResources[name] = q.DeepCopy() + continue + } + total := totalResources[name] + total.Add(q) + totalResources[name] = total + } + } + return totalResources +} \ No newline at end of file diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/taskgroup_test.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/taskgroup_test.go deleted file mode 100644 index 180e2a6e84..0000000000 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/taskgroup_test.go +++ /dev/null @@ -1,46 +0,0 @@ -package yunikorn - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestMarshal(t *testing.T) { - t1 := TaskGroup{ - Name: "tg1", - MinMember: int32(1), - Labels: map[string]string{"attr": "value"}, - Annotations: map[string]string{"attr": "value"}, - MinResource: res, - NodeSelector: map[string]string{"node": "gpunode"}, - Tolerations: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - } - t2 := TaskGroup{ - Name: "tg2", - MinMember: int32(1), - Labels: map[string]string{"attr": "value"}, - Annotations: map[string]string{"attr": "value"}, - MinResource: res, - NodeSelector: map[string]string{"node": "gpunode"}, - Tolerations: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - } - var tests = []struct { - input []TaskGroup - }{ - {input: nil}, - {input: []TaskGroup{}}, - {input: []TaskGroup{t1}}, - {input: []TaskGroup{t1, t2}}, - } - t.Run("Serialize task groups", func(t *testing.T) { - for _, tt := range tests { - _, err := Marshal(tt.input) - assert.Nil(t, err) - } - }) -} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/utils.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/utils.go index ff94255282..b88fcc4d9e 100644 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/utils.go +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/utils.go @@ -2,6 +2,8 @@ package yunikorn import ( "fmt" + + "github.com/google/uuid" ) const ( @@ -9,8 +11,9 @@ const ( ) func GenerateTaskGroupName(master bool, index int) string { + uid := uuid.New().String() if master { - return fmt.Sprintf("%s-%s", TaskGroupGenericName, "head") + return fmt.Sprintf("%s-%s-%s", TaskGroupGenericName, "head", uid) } - return fmt.Sprintf("%s-%s-%d", TaskGroupGenericName, "worker", index) + return fmt.Sprintf("%s-%s-%d-%s", TaskGroupGenericName, "worker", index, uid) } diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/utils_test.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/utils_test.go deleted file mode 100644 index b857853670..0000000000 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/utils_test.go +++ /dev/null @@ -1,42 +0,0 @@ -package yunikorn - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestGenerateTaskGroupName(t *testing.T) { - type inputFormat struct { - isMaster bool - index int - } - var tests = []struct { - input inputFormat - expect string - }{ - { - input: inputFormat{isMaster: true, index: 0}, - expect: fmt.Sprintf("%s-%s", TaskGroupGenericName, "head"), - }, - { - input: inputFormat{isMaster: true, index: 1}, - expect: fmt.Sprintf("%s-%s", TaskGroupGenericName, "head"), - }, - { - input: inputFormat{isMaster: false, index: 0}, - expect: fmt.Sprintf("%s-%s-%d", TaskGroupGenericName, "worker", 0), - }, - { - input: inputFormat{isMaster: false, index: 1}, - expect: fmt.Sprintf("%s-%s-%d", TaskGroupGenericName, "worker", 1), - }, - } - t.Run("Generate ray task group name", func(t *testing.T) { - for _, tt := range tests { - got := GenerateTaskGroupName(tt.input.isMaster, tt.input.index) - assert.Equal(t, tt.expect, got) - } - }) -} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn.go index 554191c175..a5cde4d87a 100644 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn.go +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn.go @@ -1,152 +1,33 @@ package yunikorn import ( - "errors" - - v1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" - schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config" + rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" ) const ( // Pod lebel Yunikorn = "yunikorn" + AppID = "yunikorn.apache.org/app-id" TaskGroupNameKey = "yunikorn.apache.org/task-group-name" TaskGroupsKey = "yunikorn.apache.org/task-groups" TaskGroupPrarameters = "yunikorn.apache.org/schedulingPolicyParameters" ) type Plugin struct { - Annotations map[string]map[string]string Parameters string } -func NewYunikornPlugin() *Plugin { +func NewPlugin(parameters string) *Plugin { return &Plugin{ - Annotations: nil, - Parameters: "", - } -} - -func (s *Plugin) GetSchedulerName() string { return Yunikorn } - -func (s *Plugin) ParseJob(config *schedulerConfig.Config, metadata *metav1.ObjectMeta, workerGroupsSpec []*plugins.WorkerGroupSpec, pod *v1.PodSpec, primaryContainerIdx int) error { - s.Annotations = nil - if parameters := config.GetParameters(); len(parameters) > 0 { - s.Parameters = parameters - } - return s.BuildGangInfo(metadata, workerGroupsSpec, pod, primaryContainerIdx) -} - -func (s *Plugin) ProcessHead(metadata *metav1.ObjectMeta, head *v1.PodSpec, index int) { - s.SetSchedulerName(head) - s.AddGangSchedulingAnnotations(GenerateTaskGroupName(true, index), metadata) -} - -func (s *Plugin) ProcessWorker(metadata *metav1.ObjectMeta, worker *v1.PodSpec, index int) { - s.SetSchedulerName(worker) - s.AddGangSchedulingAnnotations(GenerateTaskGroupName(false, index), metadata) -} - -func (s *Plugin) AfterProcess(metadata *metav1.ObjectMeta) { - if metadata == nil { - return + Parameters: parameters, } - delete(metadata.Annotations, TaskGroupNameKey) - delete(metadata.Annotations, TaskGroupsKey) - delete(metadata.Annotations, TaskGroupPrarameters) -} - -func (s *Plugin) SetSchedulerName(spec *v1.PodSpec) { - spec.SchedulerName = s.GetSchedulerName() } -func (s *Plugin) BuildGangInfo( - metadata *metav1.ObjectMeta, - workerGroupsSpec []*plugins.WorkerGroupSpec, - pod *v1.PodSpec, - primaryContainerIdx int, -) error { - if pod == nil { - return errors.New("Ray gang scheduling: pod is nil") - } - // Parsing placeholders from the pod resource among head and workers - var labels, annotations map[string]string = nil, nil - if metadata != nil { - labels = metadata.Labels - annotations = metadata.Annotations - } - TaskGroups := make([]TaskGroup, 0) - headName := GenerateTaskGroupName(true, 0) - TaskGroups = append(TaskGroups, TaskGroup{ - Name: headName, - MinMember: 1, - Labels: labels, - Annotations: annotations, - MinResource: pod.Containers[primaryContainerIdx].Resources.Requests, - NodeSelector: pod.NodeSelector, - Affinity: pod.Affinity, - TopologySpreadConstraints: pod.TopologySpreadConstraints, - }) - - s.Annotations = make(map[string]map[string]string, 0) - for index, spec := range workerGroupsSpec { - name := GenerateTaskGroupName(false, index) - tg := TaskGroup{ - Name: name, - MinMember: spec.Replicas, - Labels: labels, - Annotations: annotations, - MinResource: pod.Containers[primaryContainerIdx].Resources.Requests, - NodeSelector: pod.NodeSelector, - Affinity: pod.Affinity, - TopologySpreadConstraints: pod.TopologySpreadConstraints, - } - s.Annotations[name] = map[string]string{ - TaskGroupNameKey: name, - } - TaskGroups = append(TaskGroups, tg) - } - - // Yunikorn head gang scheduling annotations - var info []byte - info, _ = Marshal(TaskGroups) - headAnnotations := make(map[string]string, 0) - headAnnotations[TaskGroupNameKey] = headName - headAnnotations[TaskGroupsKey] = string(info[:]) - if len(s.Parameters) > 0 { - headAnnotations[TaskGroupPrarameters] = s.Parameters - } - s.Annotations[headName] = headAnnotations - return nil -} - -func (s *Plugin) AddGangSchedulingAnnotations(name string, metadata *metav1.ObjectMeta) { - if s.Annotations == nil || metadata == nil { - return - } - - if _, ok := s.Annotations[name]; !ok { - return - } - - if metadata.Annotations == nil { - metadata.Annotations = make(map[string]string, 0) - } - - // Updating Yunikorn gang scheduling annotations - annotations := s.Annotations[name] - if _, ok := annotations[TaskGroupNameKey]; ok { - metadata.Annotations[TaskGroupNameKey] = annotations[TaskGroupNameKey] - } - if _, ok := annotations[TaskGroupsKey]; ok { - metadata.Annotations[TaskGroupsKey] = annotations[TaskGroupsKey] - } - if _, ok := metadata.Annotations[TaskGroupPrarameters]; !ok { - if parameters, ok := annotations[TaskGroupPrarameters]; ok && len(parameters) > 0 { - metadata.Annotations[TaskGroupPrarameters] = parameters - } +func (p *Plugin) Process(app interface{}) error { + switch v := app.(type) { + case *rayv1.RayJob: + return ProcessRay(p.Parameters , v) + default: + return nil } } diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn_test.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn_test.go deleted file mode 100644 index 28214e9d14..0000000000 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn_test.go +++ /dev/null @@ -1,650 +0,0 @@ -package yunikorn - -import ( - "encoding/json" - "testing" - - "github.com/stretchr/testify/assert" - v1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/api/resource" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" - schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config" -) - -var ( - res = v1.ResourceList{ - "cpu": resource.MustParse("500m"), - "memory": resource.MustParse("1Gi"), - } -) - -func TestParseJob(t *testing.T) { - type inputFormat struct { - config *schedulerConfig.Config - metadata *metav1.ObjectMeta - workerGroupNum int - podSpec *v1.PodSpec - index int - } - type expectFormat struct { - raiseErr bool - parameters string - taskGroups []TaskGroup - } - var tests = []struct { - input inputFormat - expect expectFormat - }{ - { - input: inputFormat{ - config: &schedulerConfig.Config{ - Scheduler: "yunikorn", - Parameters: "placeholderTimeoutInSeconds=15 gangSchedulingStyle=Soft", - }, - workerGroupNum: 1, - podSpec: nil, - metadata: &metav1.ObjectMeta{}, - index: 0, - }, - expect: expectFormat{ - raiseErr: true, - parameters: "placeholderTimeoutInSeconds=15 gangSchedulingStyle=Soft", - taskGroups: []TaskGroup{ - { - Name: GenerateTaskGroupName(true, 0), - MinMember: int32(1), - Labels: nil, - Annotations: map[string]string{"others": "extra"}, - MinResource: res, - NodeSelector: nil, - Tolerations: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - }, - { - Name: GenerateTaskGroupName(false, 0), - MinMember: int32(1), - Labels: nil, - Annotations: map[string]string{"others": "extra"}, - MinResource: res, - NodeSelector: nil, - Tolerations: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - }, - }, - }, - }, - { - input: inputFormat{ - config: &schedulerConfig.Config{ - Scheduler: "yunikorn", - Parameters: "placeholderTimeoutInSeconds=15 gangSchedulingStyle=Soft", - }, - workerGroupNum: 1, - podSpec: &v1.PodSpec{ - Containers: []v1.Container{ - { - Resources: v1.ResourceRequirements{ - Requests: res, - }, - }, - }, - NodeSelector: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - }, - metadata: &metav1.ObjectMeta{ - Annotations: map[string]string{"others": "extra"}, - }, - index: 0, - }, - expect: expectFormat{ - raiseErr: false, - parameters: "placeholderTimeoutInSeconds=15 gangSchedulingStyle=Soft", - taskGroups: []TaskGroup{ - { - Name: GenerateTaskGroupName(true, 0), - MinMember: int32(1), - Labels: nil, - Annotations: map[string]string{"others": "extra"}, - MinResource: res, - NodeSelector: nil, - Tolerations: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - }, - { - Name: GenerateTaskGroupName(false, 0), - MinMember: int32(1), - Labels: nil, - Annotations: map[string]string{"others": "extra"}, - MinResource: res, - NodeSelector: nil, - Tolerations: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - }, - }, - }, - }, - } - for _, tt := range tests { - t.Run("Yunikorn parse job", func(t *testing.T) { - workersSpec := make([]*plugins.WorkerGroupSpec, 0) - for index := 0; index < tt.input.workerGroupNum; index++ { - count := 1 * (1 + index) - max := 2 * (1 + index) - workersSpec = append(workersSpec, &plugins.WorkerGroupSpec{ - Replicas: int32(count), - MinReplicas: int32(count), - MaxReplicas: int32(max), - }) - } - p := NewYunikornPlugin() - err := p.ParseJob(tt.input.config, tt.input.metadata, workersSpec, tt.input.podSpec, tt.input.index) - if tt.expect.raiseErr { - assert.NotNil(t, err) - } else { - assert.Nil(t, err) - assert.Equal(t, Yunikorn, p.GetSchedulerName()) - names := []string{GenerateTaskGroupName(true, 0)} - for index := 0; index < tt.input.workerGroupNum; index++ { - names = append(names, GenerateTaskGroupName(false, index)) - } - // task-groups among head and workers - assert.Equal(t, len(names), len(p.Annotations)) - // check head annotations - head := p.Annotations[names[0]] - assert.Equal(t, names[0], head[TaskGroupNameKey]) - assert.Equal(t, tt.expect.parameters, head[TaskGroupPrarameters]) - // task-groups in head - var taskgroups []TaskGroup - err = json.Unmarshal([]byte(head[TaskGroupsKey]), &taskgroups) - assert.Nil(t, err) - assert.Equal(t, len(names), len(taskgroups)) - for index, tg := range taskgroups { - assert.Equal(t, names[index], tg.Name) - } - } - }) - } -} - -func TestProcessHead(t *testing.T) { - type inputFormat struct { - config *schedulerConfig.Config - metadata *metav1.ObjectMeta - workerGroupNum int - podSpec *v1.PodSpec - index int - } - type expectFormat struct { - name string - taskgroupsNum int - parameters string - } - var tests = []struct { - input inputFormat - expect expectFormat - }{ - { - input: inputFormat{ - config: &schedulerConfig.Config{ - Scheduler: "yunikorn", - Parameters: "placeholderTimeoutInSeconds=15 gangSchedulingStyle=Soft", - }, - workerGroupNum: 1, - podSpec: &v1.PodSpec{ - Containers: []v1.Container{ - { - Resources: v1.ResourceRequirements{ - Requests: res, - }, - }, - }, - NodeSelector: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - }, - metadata: &metav1.ObjectMeta{ - Annotations: map[string]string{"others": "extra"}, - }, - index: 0, - }, - expect: expectFormat{ - name: GenerateTaskGroupName(true, 0), - taskgroupsNum: 2, - parameters: "placeholderTimeoutInSeconds=15 gangSchedulingStyle=Soft", - }, - }, - } - for _, tt := range tests { - t.Run("Yunikorn process head", func(t *testing.T) { - workersSpec := make([]*plugins.WorkerGroupSpec, 0) - for index := 0; index < tt.input.workerGroupNum; index++ { - workersSpec = append(workersSpec, &plugins.WorkerGroupSpec{ - Replicas: int32(1), - MinReplicas: int32(1), - MaxReplicas: int32(2), - }) - } - p := NewYunikornPlugin() - err := p.ParseJob(tt.input.config, tt.input.metadata, workersSpec, tt.input.podSpec, tt.input.index) - assert.Nil(t, err) - p.ProcessHead(tt.input.metadata, tt.input.podSpec, tt.input.index) - assert.Equal(t, Yunikorn, tt.input.podSpec.SchedulerName) - assert.Equal(t, tt.expect.name, tt.input.metadata.Annotations[TaskGroupNameKey]) - assert.Equal(t, tt.expect.parameters, tt.input.metadata.Annotations[TaskGroupPrarameters]) - var taskgroups []TaskGroup - err = json.Unmarshal([]byte(tt.input.metadata.Annotations[TaskGroupsKey]), &taskgroups) - assert.Nil(t, err) - assert.Equal(t, tt.expect.taskgroupsNum, len(taskgroups)) - }) - } -} - -func TestProcessWorker(t *testing.T) { - type inputFormat struct { - config *schedulerConfig.Config - metadata *metav1.ObjectMeta - workerGroupNum int - podSpec *v1.PodSpec - index int - } - type expectFormat struct { - name string - taskgroupsNum int - } - var tests = []struct { - input inputFormat - expect expectFormat - }{ - { - input: inputFormat{ - config: &schedulerConfig.Config{ - Scheduler: "yunikorn", - Parameters: "placeholderTimeoutInSeconds=15 gangSchedulingStyle=Soft", - }, - workerGroupNum: 1, - podSpec: &v1.PodSpec{ - Containers: []v1.Container{ - { - Resources: v1.ResourceRequirements{ - Requests: res, - }, - }, - }, - NodeSelector: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - }, - metadata: &metav1.ObjectMeta{ - Annotations: map[string]string{"others": "extra"}, - }, - index: 0, - }, - expect: expectFormat{ - name: GenerateTaskGroupName(false, 0), - taskgroupsNum: 2, - }, - }, - } - for _, tt := range tests { - t.Run("Yunikorn process worker", func(t *testing.T) { - workersSpec := make([]*plugins.WorkerGroupSpec, 0) - for index := 0; index < tt.input.workerGroupNum; index++ { - workersSpec = append(workersSpec, &plugins.WorkerGroupSpec{ - Replicas: int32(1), - MinReplicas: int32(1), - MaxReplicas: int32(2), - }) - } - p := NewYunikornPlugin() - err := p.ParseJob(tt.input.config, tt.input.metadata, workersSpec, tt.input.podSpec, tt.input.index) - assert.Nil(t, err) - p.ProcessWorker(tt.input.metadata, tt.input.podSpec, tt.input.index) - assert.Equal(t, Yunikorn, tt.input.podSpec.SchedulerName) - assert.Equal(t, tt.expect.name, tt.input.metadata.Annotations[TaskGroupNameKey]) - }) - } -} - -func TestAfterProcess(t *testing.T) { - type expectFormat struct { - isNil bool - length int - } - var tests = []struct { - input *metav1.ObjectMeta - expect expectFormat - }{ - { - input: nil, - expect: expectFormat{isNil: true, length: -1}, - }, - { - input: &metav1.ObjectMeta{ - Annotations: map[string]string{ - "others": "extra", - TaskGroupNameKey: "TGName", - TaskGroupsKey: "TGs", - TaskGroupPrarameters: "parameters", - }, - }, - expect: expectFormat{isNil: false, length: 1}, - }, - { - input: &metav1.ObjectMeta{ - Annotations: map[string]string{ - TaskGroupNameKey: "TGName", - TaskGroupsKey: "TGs", - TaskGroupPrarameters: "parameters", - }, - }, - expect: expectFormat{isNil: false, length: 0}, - }, - { - input: &metav1.ObjectMeta{ - Annotations: map[string]string{ - TaskGroupNameKey: "TGName", - TaskGroupsKey: "TGs", - }, - }, - expect: expectFormat{isNil: false, length: 0}, - }, - { - input: &metav1.ObjectMeta{ - Annotations: map[string]string{ - TaskGroupNameKey: "TGName", - }, - }, - expect: expectFormat{isNil: false, length: 0}, - }, - { - input: &metav1.ObjectMeta{}, - expect: expectFormat{isNil: false, length: 0}, - }, - } - for _, tt := range tests { - t.Run("Remove Gang scheduling labels", func(t *testing.T) { - p := NewYunikornPlugin() - p.AfterProcess(tt.input) - if tt.expect.isNil { - assert.Nil(t, tt.input) - } else { - assert.NotNil(t, tt.input) - assert.Equal(t, tt.expect.length, len(tt.input.Annotations)) - } - }) - } -} - -func TestSetSchedulerName(t *testing.T) { - t.Run("Set Scheduler Name", func(t *testing.T) { - p := NewYunikornPlugin() - podSpec := &v1.PodSpec{ - Containers: []v1.Container{ - { - Resources: v1.ResourceRequirements{ - Requests: res, - }, - }, - }, - NodeSelector: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - } - p.SetSchedulerName(podSpec) - assert.Equal(t, p.GetSchedulerName(), podSpec.SchedulerName) - podSpec.SchedulerName = "" - }) -} - -func TestBuildGangInfo(t *testing.T) { - names := []string{GenerateTaskGroupName(true, 0)} - for index := 0; index < 2; index++ { - names = append(names, GenerateTaskGroupName(false, index)) - } - type inputFormat struct { - workerGroupNum int - podSpec *v1.PodSpec - metadata *metav1.ObjectMeta - } - var tests = []struct { - input inputFormat - taskGroups []TaskGroup - }{ - { - input: inputFormat{ - workerGroupNum: 1, - podSpec: nil, - metadata: &metav1.ObjectMeta{ - Annotations: map[string]string{"others": "extra"}, - }, - }, - taskGroups: nil, - }, - { - input: inputFormat{ - workerGroupNum: 1, - podSpec: &v1.PodSpec{ - Containers: []v1.Container{ - { - Resources: v1.ResourceRequirements{ - Requests: res, - }, - }, - }, - NodeSelector: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - }, - metadata: &metav1.ObjectMeta{ - Annotations: map[string]string{"others": "extra"}, - }, - }, - taskGroups: []TaskGroup{ - { - Name: GenerateTaskGroupName(true, 0), - MinMember: int32(1), - Labels: nil, - Annotations: map[string]string{"others": "extra"}, - MinResource: res, - NodeSelector: nil, - Tolerations: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - }, - { - Name: GenerateTaskGroupName(false, 0), - MinMember: int32(1), - Labels: nil, - Annotations: map[string]string{"others": "extra"}, - MinResource: res, - NodeSelector: nil, - Tolerations: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - }, - }, - }, - { - input: inputFormat{ - workerGroupNum: 2, - podSpec: &v1.PodSpec{ - Containers: []v1.Container{ - { - Resources: v1.ResourceRequirements{ - Requests: res, - }, - }, - }, - NodeSelector: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - }, - metadata: &metav1.ObjectMeta{ - Annotations: map[string]string{"others": "extra"}, - }, - }, - taskGroups: []TaskGroup{ - { - Name: GenerateTaskGroupName(true, 0), - MinMember: int32(1), - Labels: nil, - Annotations: map[string]string{"others": "extra"}, - MinResource: res, - NodeSelector: nil, - Tolerations: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - }, - { - Name: GenerateTaskGroupName(false, 0), - MinMember: int32(1), - Labels: nil, - Annotations: map[string]string{"others": "extra"}, - MinResource: res, - NodeSelector: nil, - Tolerations: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - }, - { - Name: GenerateTaskGroupName(false, 1), - MinMember: int32(2), - Labels: nil, - Annotations: map[string]string{"others": "extra"}, - MinResource: res, - NodeSelector: nil, - Tolerations: nil, - Affinity: nil, - TopologySpreadConstraints: nil, - }, - }, - }, - } - for _, tt := range tests { - t.Run("Create Yunikorn gang scheduling annotations", func(t *testing.T) { - workersSpec := make([]*plugins.WorkerGroupSpec, 0) - for index := 0; index < tt.input.workerGroupNum; index++ { - count := 1 * (1 + index) - max := 2 * (1 + index) - workersSpec = append(workersSpec, &plugins.WorkerGroupSpec{ - Replicas: int32(count), - MinReplicas: int32(count), - MaxReplicas: int32(max), - }) - } - p := NewYunikornPlugin() - if err := p.BuildGangInfo(tt.input.metadata, workersSpec, tt.input.podSpec, 0); tt.input.podSpec == nil { - assert.NotNil(t, err) - } else { - assert.Nil(t, err) - // test worker name - for index := 0; index < tt.input.workerGroupNum; index++ { - name := GenerateTaskGroupName(false, index) - if annotations, ok := p.Annotations[name]; ok { - assert.Equal(t, 1, len(annotations)) - assert.Equal(t, name, annotations[TaskGroupNameKey]) - } else { - t.Errorf("Worker group %d annotatiosn miss", index) - } - } - // Test head name and groups - headName := GenerateTaskGroupName(true, 0) - if annotations, ok := p.Annotations[headName]; ok { - info, err := json.Marshal(tt.taskGroups) - assert.Nil(t, err) - assert.Equal(t, 2, len(annotations)) - assert.Equal(t, headName, annotations[TaskGroupNameKey]) - assert.Equal(t, string(info[:]), annotations[TaskGroupsKey]) - } else { - t.Error("Head annotations miss") - } - } - }) - } -} - -func TestAddGangSchedulingAnnotations(t *testing.T) { - taskGroupsAnnotations := map[string]map[string]string{ - GenerateTaskGroupName(true, 0): { - TaskGroupNameKey: GenerateTaskGroupName(true, 0), - TaskGroupsKey: "TGs", - TaskGroupPrarameters: "parameters", - }, - GenerateTaskGroupName(false, 0): { - TaskGroupNameKey: GenerateTaskGroupName(false, 0), - }, - } - type inputFormat struct { - annotations map[string]map[string]string - metadata *metav1.ObjectMeta - name string - } - var tests = []struct { - input inputFormat - expect *metav1.ObjectMeta - }{ - { - input: inputFormat{ - annotations: nil, - metadata: nil, - name: "", - }, - expect: nil, - }, - { - input: inputFormat{ - annotations: taskGroupsAnnotations, - metadata: nil, - name: "", - }, - expect: nil, - }, - { - input: inputFormat{ - annotations: taskGroupsAnnotations, - metadata: &metav1.ObjectMeta{}, - name: "Unknown", - }, - expect: &metav1.ObjectMeta{}, - }, - { - input: inputFormat{ - annotations: taskGroupsAnnotations, - metadata: &metav1.ObjectMeta{}, - name: GenerateTaskGroupName(true, 0), - }, - expect: &metav1.ObjectMeta{ - Annotations: taskGroupsAnnotations[GenerateTaskGroupName(true, 0)], - }, - }, - { - input: inputFormat{ - annotations: taskGroupsAnnotations, - metadata: &metav1.ObjectMeta{}, - name: GenerateTaskGroupName(false, 0), - }, - expect: &metav1.ObjectMeta{ - Annotations: taskGroupsAnnotations[GenerateTaskGroupName(false, 0)], - }, - }, - } - for _, tt := range tests { - t.Run("Check gang scheduling annotatiosn after labeling", func(t *testing.T) { - p := NewYunikornPlugin() - p.Annotations = tt.input.annotations - p.AddGangSchedulingAnnotations(tt.input.name, tt.input.metadata) - if tt.expect == nil { - assert.Nil(t, tt.expect, tt.input.metadata) - } else { - assert.Equal(t, tt.expect.Annotations, tt.input.metadata.Annotations) - } - }) - } -} diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go index 0015c2b057..7aa0b014fd 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go @@ -121,7 +121,10 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC headPodSpec := podSpec.DeepCopy() rayjob, err := constructRayJob(taskCtx, rayJob, objectMeta, *podSpec, headPodSpec, headNodeRayStartParams, primaryContainerIdx, *primaryContainer) - + if err != nil { + return rayjob, err + } + err = batchscheduler.NewSchedulerPlugin(&cfg.BatchScheduler).Process(rayjob) return rayjob, err } @@ -129,18 +132,6 @@ func constructRayJob(taskCtx pluginsCore.TaskExecutionContext, rayJob plugins.Ra var err error enableIngress := true cfg := GetConfig() - schedulerPlugin := batchscheduler.NewSchedulerPlugin(&cfg.BatchScheduler) - err = schedulerPlugin.ParseJob( - &cfg.BatchScheduler, - objectMeta, - rayJob.RayCluster.WorkerGroupSpec, - &podSpec, - primaryContainerIdx, - ) - if err != nil { - return nil, err - } - schedulerPlugin.ProcessHead(objectMeta, headPodSpec, primaryContainerIdx) rayClusterSpec := rayv1.RayClusterSpec{ HeadGroupSpec: rayv1.HeadGroupSpec{ Template: buildHeadPodTemplate( @@ -156,11 +147,9 @@ func constructRayJob(taskCtx pluginsCore.TaskExecutionContext, rayJob plugins.Ra WorkerGroupSpecs: []rayv1.WorkerGroupSpec{}, EnableInTreeAutoscaling: &rayJob.RayCluster.EnableAutoscaling, } - schedulerPlugin.AfterProcess(objectMeta) - for index, spec := range rayJob.RayCluster.WorkerGroupSpec { + for _, spec := range rayJob.RayCluster.WorkerGroupSpec { workerPodSpec := podSpec.DeepCopy() - schedulerPlugin.ProcessWorker(objectMeta, workerPodSpec, index) workerPodTemplate := buildWorkerPodTemplate( &workerPodSpec.Containers[primaryContainerIdx], workerPodSpec, @@ -202,7 +191,6 @@ func constructRayJob(taskCtx pluginsCore.TaskExecutionContext, rayJob plugins.Ra } rayClusterSpec.WorkerGroupSpecs = append(rayClusterSpec.WorkerGroupSpecs, workerNodeSpec) - schedulerPlugin.AfterProcess(objectMeta) } serviceAccountName := flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()) diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go index 0709535fab..7b555e9f23 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go @@ -27,8 +27,6 @@ import ( mocks2 "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s/mocks" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" - schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn" ) const ( @@ -476,36 +474,6 @@ func TestDefaultStartParameters(t *testing.T) { assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Tolerations, toleration) } -func TestYunikornAnnotationsCreate(t *testing.T) { - assert.NoError(t, SetConfig(&Config{ - BatchScheduler: schedulerConfig.Config{ - Scheduler: "yunikorn", - Parameters: "gangSchedulingStyle=Soft", - }, - })) - rayJobResourceHandler := rayJobResourceHandler{} - rayJob := &plugins.RayJob{ - RayCluster: &plugins.RayCluster{ - HeadGroupSpec: &plugins.HeadGroupSpec{}, - WorkerGroupSpec: []*plugins.WorkerGroupSpec{{GroupName: workerGroupName, Replicas: 3, MinReplicas: 3, MaxReplicas: 3}}, - EnableAutoscaling: true, - }, - ShutdownAfterJobFinishes: true, - TtlSecondsAfterFinished: 120, - } - - taskTemplate := dummyRayTaskTemplate("ray-id", rayJob) - RayResource, err := rayJobResourceHandler.BuildResource(context.TODO(), dummyRayTaskContext(taskTemplate, resourceRequirements, nil, "", serviceAccount)) - assert.Nil(t, err) - ray, ok := RayResource.(*rayv1.RayJob) - assert.True(t, ok) - headAnnotations := ray.Spec.RayClusterSpec.HeadGroupSpec.Template.ObjectMeta.Annotations - workerAnnotations := ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.ObjectMeta.Annotations - assert.Equal(t, yunikorn.GenerateTaskGroupName(true, 0), headAnnotations[yunikorn.TaskGroupNameKey]) - assert.Equal(t, "gangSchedulingStyle=Soft", headAnnotations[yunikorn.TaskGroupPrarameters]) - assert.Equal(t, yunikorn.GenerateTaskGroupName(false, 0), workerAnnotations[yunikorn.TaskGroupNameKey]) -} - func TestInjectLogsSidecar(t *testing.T) { rayJobObj := transformRayJobToCustomObj(dummyRayCustomObj()) params := []struct { From 189f61861c9a6ba1b5931865658b3803f67b1af5 Mon Sep 17 00:00:00 2001 From: yuteng Date: Wed, 28 Aug 2024 00:15:33 +0800 Subject: [PATCH 22/30] refactor Signed-off-by: yuteng --- .../scheduler/kubernetes/default_test.go | 13 +++++ .../scheduler/yunikorn/rayhandler.go | 6 +-- .../scheduler/yunikorn/taskgroup_test.go | 46 ++++++++++++++++ .../scheduler/yunikorn/utils.go | 10 ++-- .../scheduler/yunikorn/utils_test.go | 53 +++++++++++++++++++ .../scheduler/yunikorn/yunikorn.go | 2 +- .../scheduler/yunikorn/yunikorn_test.go | 46 ++++++++++++++++ 7 files changed, 169 insertions(+), 7 deletions(-) create mode 100644 flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/kubernetes/default_test.go create mode 100644 flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/taskgroup_test.go create mode 100644 flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/utils_test.go create mode 100644 flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn_test.go diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/kubernetes/default_test.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/kubernetes/default_test.go new file mode 100644 index 0000000000..ecc15bf4c0 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/kubernetes/default_test.go @@ -0,0 +1,13 @@ +package kubernetes + +import ( + "testing" + "gotest.tools/assert" +) + +func TestNewPlugin(t *testing.T) { + p := NewPlugin() + t.Run("New default plugin", func(t *testing.T) { + assert.NotNil(t, p) + }) +} \ No newline at end of file diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/rayhandler.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/rayhandler.go index 57f6cc3123..bd65d4eaaf 100644 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/rayhandler.go +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/rayhandler.go @@ -7,8 +7,8 @@ import ( rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" ) -func ProcessRay(paras string, app *rayv1.RayJob) error { - jobname := GenerateTaskGroupName(true, 0) +func ProcessRay(parameters string, app *rayv1.RayJob) error { + jobname := GenerateTaskGroupAppID() rayjobSpec := &app.Spec appSpec := rayjobSpec.RayClusterSpec TaskGroups := make([]TaskGroup, 1) @@ -52,7 +52,7 @@ func ProcessRay(paras string, app *rayv1.RayJob) error { return err } meta.Annotations[TaskGroupsKey] = string(info[:]) - meta.Annotations[TaskGroupPrarameters] = paras + meta.Annotations[TaskGroupPrarameters] = parameters meta.Annotations[AppID] = jobname return nil } diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/taskgroup_test.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/taskgroup_test.go new file mode 100644 index 0000000000..ce5bfe8d88 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/taskgroup_test.go @@ -0,0 +1,46 @@ +package yunikorn + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMarshal(t *testing.T) { + t1 := TaskGroup{ + Name: "tg1", + MinMember: int32(1), + Labels: map[string]string{"attr": "value"}, + Annotations: map[string]string{"attr": "value"}, + MinResource: res, + NodeSelector: map[string]string{"node": "gpunode"}, + Tolerations: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + } + t2 := TaskGroup{ + Name: "tg2", + MinMember: int32(1), + Labels: map[string]string{"attr": "value"}, + Annotations: map[string]string{"attr": "value"}, + MinResource: res, + NodeSelector: map[string]string{"node": "gpunode"}, + Tolerations: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + } + var tests = []struct { + input []TaskGroup + }{ + {input: nil}, + {input: []TaskGroup{}}, + {input: []TaskGroup{t1}}, + {input: []TaskGroup{t1, t2}}, + } + t.Run("Serialize task groups", func(t *testing.T) { + for _, tt := range tests { + _, err := Marshal(tt.input) + assert.Nil(t, err) + } + }) +} \ No newline at end of file diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/utils.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/utils.go index b88fcc4d9e..27dac78e07 100644 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/utils.go +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/utils.go @@ -11,9 +11,13 @@ const ( ) func GenerateTaskGroupName(master bool, index int) string { - uid := uuid.New().String() if master { - return fmt.Sprintf("%s-%s-%s", TaskGroupGenericName, "head", uid) + return fmt.Sprintf("%s-%s-%s", TaskGroupGenericName, "head") } - return fmt.Sprintf("%s-%s-%d-%s", TaskGroupGenericName, "worker", index, uid) + return fmt.Sprintf("%s-%s-%d-%s", TaskGroupGenericName, "worker", index) +} + +func GenerateTaskGroupAppID() string { + uid := uuid.New().String() + return fmt.Sprintf("%s-%s", TaskGroupGenericName, uid) } diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/utils_test.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/utils_test.go new file mode 100644 index 0000000000..14aba4ddbc --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/utils_test.go @@ -0,0 +1,53 @@ +package yunikorn + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGenerateTaskGroupName(t *testing.T) { + type inputFormat struct { + isMaster bool + index int + } + var tests = []struct { + input inputFormat + expect string + }{ + { + input: inputFormat{isMaster: true, index: 0}, + expect: fmt.Sprintf("%s-%s", TaskGroupGenericName, "head"), + }, + { + input: inputFormat{isMaster: true, index: 1}, + expect: fmt.Sprintf("%s-%s", TaskGroupGenericName, "head"), + }, + { + input: inputFormat{isMaster: false, index: 0}, + expect: fmt.Sprintf("%s-%s-%d", TaskGroupGenericName, "worker", 0), + }, + { + input: inputFormat{isMaster: false, index: 1}, + expect: fmt.Sprintf("%s-%s-%d", TaskGroupGenericName, "worker", 1), + }, + } + t.Run("Generate ray task group name", func(t *testing.T) { + for _, tt := range tests { + got := GenerateTaskGroupName(tt.input.isMaster, tt.input.index) + assert.Equal(t, tt.expect, got) + } + }) +} + +func TestGenerateTaskGroupAppID(t *testing.T) { + t.Run("Generate ray app ID", func(t *testing.T) { + for _, tt := range tests { + got := GenerateTaskGroupAppID() + if len(got) <= 0 { + t.Error("Ray app ID is empty") + } + } + }) +} \ No newline at end of file diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn.go index a5cde4d87a..c94174f5dc 100644 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn.go +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn.go @@ -26,7 +26,7 @@ func NewPlugin(parameters string) *Plugin { func (p *Plugin) Process(app interface{}) error { switch v := app.(type) { case *rayv1.RayJob: - return ProcessRay(p.Parameters , v) + return ProcessRay(p.Parameters, v) default: return nil } diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn_test.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn_test.go new file mode 100644 index 0000000000..d8a2619083 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn_test.go @@ -0,0 +1,46 @@ +package yunikorn + +import ( + "testing" + "gotest.tools/assert" +) + +func TestNewPlugin(t *testing.T) { + tests := []struct{ + input string + expect *Plugin + }{ + { + input: "", + expect: &Plugin{ + Parameters: "", + }, + }, + { + input: "placeholderTimeoutInSeconds=30 gangSchedulingStyle=Hard", + expect: &Plugin{ + Parameters: "placeholderTimeoutInSeconds=30 gangSchedulingStyle=Hard", + }, + }, + } + t.Run("New Yunikorn plugin", func(t *testing.T) { + got := NewPlugin(t.input) + assert.NotNil(t, got) + assert.Equal(t, t.input, got.Parameters) + }) +} + +func TestProcess(t *testing.T) { + tests := []struct{ + input interface{} + expect error + }{ + {input: 1, expect: nil}, + {input: "test", expect: nil}, + } + t.Run("Yunikorn plugin process any type", func(t *testing.T) { + got := NewPlugin(t.input) + assert.NotNil(t, got) + assert.Equal(t, t.input, got.Parameters) + }) +} \ No newline at end of file From 3f83a1a88ff3fe8159b4cb972316d69733ef2cef Mon Sep 17 00:00:00 2001 From: yuteng Date: Wed, 28 Aug 2024 00:32:06 +0800 Subject: [PATCH 23/30] fix wrong format Signed-off-by: yuteng --- .../plugins/k8s/batchscheduler/scheduler/yunikorn/utils.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/utils.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/utils.go index 27dac78e07..91a5357fac 100644 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/utils.go +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/utils.go @@ -12,9 +12,9 @@ const ( func GenerateTaskGroupName(master bool, index int) string { if master { - return fmt.Sprintf("%s-%s-%s", TaskGroupGenericName, "head") + return fmt.Sprintf("%s-%s", TaskGroupGenericName, "head") } - return fmt.Sprintf("%s-%s-%d-%s", TaskGroupGenericName, "worker", index) + return fmt.Sprintf("%s-%s-%d", TaskGroupGenericName, "worker", index) } func GenerateTaskGroupAppID() string { From 9cb0a5e10f13ce561a5ec4289a789310ff9d8724 Mon Sep 17 00:00:00 2001 From: yuteng Date: Sun, 1 Sep 2024 22:07:02 +0800 Subject: [PATCH 24/30] fix ray sidecar problem and ready to refactor based on this commit Signed-off-by: yuteng --- .../scheduler/yunikorn/rayhandler.go | 34 ++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/rayhandler.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/rayhandler.go index bd65d4eaaf..562251955f 100644 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/rayhandler.go +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/rayhandler.go @@ -3,6 +3,7 @@ package yunikorn import ( "encoding/json" + "k8s.io/apimachinery/pkg/api/resource" v1 "k8s.io/api/core/v1" rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" ) @@ -36,12 +37,24 @@ func ProcessRay(parameters string, app *rayv1.RayJob) error { meta := headSpec.Template.ObjectMeta spec := headSpec.Template.Spec headName := GenerateTaskGroupName(true, 0) + res := Allocation(spec.Containers) + if ok := *appSpec.EnableInTreeAutoscaling; ok { + res2 := v1.ResourceList{ + v1.ResourceCPU: resource.MustParse("500m"), + v1.ResourceMemory: resource.MustParse("512Mi"), + } + tmp, _ := json.Marshal(res2) + meta.Annotations["tmp"] = string(tmp) + tmp, _ = json.Marshal(Add(res, res2)) + meta.Annotations["Sum"] = string(tmp) + //res = Add(res, res2) + } TaskGroups[0] = TaskGroup{ Name: headName, MinMember: 1, //Labels: meta.Labels, //Annotations: meta.Annotations, - MinResource: Allocation(spec.Containers), + MinResource: res, //NodeSelector: spec.NodeSelector, //Affinity: spec.Affinity, //TopologySpreadConstraints: spec.TopologySpreadConstraints, @@ -71,4 +84,23 @@ func Allocation(containers []v1.Container) v1.ResourceList { } } return totalResources +} + +func Add(a v1.ResourceList, b v1.ResourceList) v1.ResourceList { + result := a + for name, value := range a { + sum := &value + if value2, ok := b[name]; ok { + sum.Add(value2) + result[name] = *sum + } else { + result[name] = value + } + } + for name, value := range b { + if _, ok := a[name]; !ok { + result[name] = value + } + } + return result } \ No newline at end of file From ad68f48a1fab1d732861522d35e0b906e6c6a598 Mon Sep 17 00:00:00 2001 From: yuteng Date: Sun, 1 Sep 2024 22:16:29 +0800 Subject: [PATCH 25/30] fix resource computation Signed-off-by: yuteng --- .../batchscheduler/scheduler/yunikorn/rayhandler.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/rayhandler.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/rayhandler.go index 562251955f..93fe22598c 100644 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/rayhandler.go +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/rayhandler.go @@ -43,11 +43,11 @@ func ProcessRay(parameters string, app *rayv1.RayJob) error { v1.ResourceCPU: resource.MustParse("500m"), v1.ResourceMemory: resource.MustParse("512Mi"), } - tmp, _ := json.Marshal(res2) - meta.Annotations["tmp"] = string(tmp) - tmp, _ = json.Marshal(Add(res, res2)) - meta.Annotations["Sum"] = string(tmp) - //res = Add(res, res2) + //tmp, _ := json.Marshal(res2) + //meta.Annotations["tmp"] = string(tmp) + //tmp, _ = json.Marshal(Add(res, res2)) + //meta.Annotations["Sum"] = string(tmp) + res = Add(res, res2) } TaskGroups[0] = TaskGroup{ Name: headName, From 64b8655fd9093912f77836b5c0c4f2b96ce76aa3 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sun, 8 Sep 2024 02:09:02 -0700 Subject: [PATCH 26/30] kevin refactor Signed-off-by: Kevin Su --- .../go/tasks/pluginmachinery/k8s/plugin.go | 3 + .../workqueue/mocks/processor.go | 4 +- .../plugins/k8s/batchscheduler/plugins.go | 24 ++++++-- .../scheduler/kubernetes/default.go | 13 ---- .../scheduler/kubernetes/default_test.go | 13 ---- .../scheduler/yunikorn/rayhandler.go | 59 ++++++++++++------- .../scheduler/yunikorn/yunikorn.go | 33 ++++------- flyteplugins/go/tasks/plugins/k8s/ray/ray.go | 9 +-- .../nodes/task/k8s/plugin_manager.go | 16 +++++ 9 files changed, 94 insertions(+), 80 deletions(-) delete mode 100644 flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/kubernetes/default.go delete mode 100644 flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/kubernetes/default_test.go diff --git a/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go b/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go index 38a84f9b2b..44efc81ac4 100644 --- a/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go +++ b/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go @@ -2,6 +2,7 @@ package k8s import ( "context" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler" "sigs.k8s.io/controller-runtime/pkg/client" @@ -28,6 +29,8 @@ type PluginEntry struct { IsDefault bool // Returns a new KubeClient to be used instead of the internal controller-runtime client. CustomKubeClient func(ctx context.Context) (pluginsCore.KubeClient, error) + // Return a new scheduler plugin to be used instead of the default k8s scheduler. + Scheduler func(ctx context.Context) batchscheduler.SchedulerPlugin } // System level properties that this Plugin supports diff --git a/flyteplugins/go/tasks/pluginmachinery/workqueue/mocks/processor.go b/flyteplugins/go/tasks/pluginmachinery/workqueue/mocks/processor.go index ec5c0d12a6..b9d59d1b47 100644 --- a/flyteplugins/go/tasks/pluginmachinery/workqueue/mocks/processor.go +++ b/flyteplugins/go/tasks/pluginmachinery/workqueue/mocks/processor.go @@ -23,12 +23,12 @@ func (_m Processor_Process) Return(_a0 workqueue.WorkStatus, _a1 error) *Process } func (_m *Processor) OnProcess(ctx context.Context, workItem workqueue.WorkItem) *Processor_Process { - c_call := _m.On("Process", ctx, workItem) + c_call := _m.On("Mutate", ctx, workItem) return &Processor_Process{Call: c_call} } func (_m *Processor) OnProcessMatch(matchers ...interface{}) *Processor_Process { - c_call := _m.On("Process", matchers...) + c_call := _m.On("Mutate", matchers...) return &Processor_Process{Call: c_call} } diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/plugins.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/plugins.go index f0625fd180..ffc69cfc6d 100644 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/plugins.go +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/plugins.go @@ -1,20 +1,34 @@ package batchscheduler import ( + "context" schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/kubernetes" "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn" + "reflect" + "sigs.k8s.io/controller-runtime/pkg/client" ) type SchedulerPlugin interface { - Process(app interface{}) error + // Mutate is responsible for mutating the object to be scheduled. + // It will add the necessary annotations, labels, etc. to the object. + Mutate(ctx context.Context, object client.Object) error } -func NewSchedulerPlugin(cfg *schedulerConfig.Config) SchedulerPlugin { +type NoopSchedulerPlugin struct{} + +func NewNoopSchedulerPlugin() *NoopSchedulerPlugin { + return &NoopSchedulerPlugin{} +} + +func (p *NoopSchedulerPlugin) Mutate(ctx context.Context, object client.Object) error { + return nil +} + +func NewSchedulerPlugin(t reflect.Type, cfg *schedulerConfig.Config) SchedulerPlugin { switch cfg.GetScheduler() { case yunikorn.Yunikorn: - return yunikorn.NewPlugin(cfg.GetParameters()) + return yunikorn.NewPlugin(t, cfg.GetParameters()) default: - return kubernetes.NewPlugin() + return NewNoopSchedulerPlugin() } } diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/kubernetes/default.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/kubernetes/default.go deleted file mode 100644 index bd5fad5976..0000000000 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/kubernetes/default.go +++ /dev/null @@ -1,13 +0,0 @@ -package kubernetes - -var ( - DefaultScheduler = "default" -) - -type Plugin struct{} - -func NewPlugin() *Plugin { - return &Plugin{} -} - -func (p *Plugin) Process(app interface{}) error { return nil } diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/kubernetes/default_test.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/kubernetes/default_test.go deleted file mode 100644 index ecc15bf4c0..0000000000 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/kubernetes/default_test.go +++ /dev/null @@ -1,13 +0,0 @@ -package kubernetes - -import ( - "testing" - "gotest.tools/assert" -) - -func TestNewPlugin(t *testing.T) { - p := NewPlugin() - t.Run("New default plugin", func(t *testing.T) { - assert.NotNil(t, p) - }) -} \ No newline at end of file diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/rayhandler.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/rayhandler.go index 93fe22598c..ae3ef066f0 100644 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/rayhandler.go +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/rayhandler.go @@ -1,15 +1,26 @@ package yunikorn import ( + "context" "encoding/json" + "sigs.k8s.io/controller-runtime/pkg/client" - "k8s.io/apimachinery/pkg/api/resource" - v1 "k8s.io/api/core/v1" rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" ) +type RayHandler struct { + parameters string +} + +func (h *RayHandler) Mutate(ctx context.Context, object client.Object) error { + rayJob := object.(*rayv1.RayJob) + return ProcessRay(h.parameters, rayJob) +} + func ProcessRay(parameters string, app *rayv1.RayJob) error { - jobname := GenerateTaskGroupAppID() + appID := GenerateTaskGroupAppID() rayjobSpec := &app.Spec appSpec := rayjobSpec.RayClusterSpec TaskGroups := make([]TaskGroup, 1) @@ -20,17 +31,17 @@ func ProcessRay(parameters string, app *rayv1.RayJob) error { spec := worker.Template.Spec name := GenerateTaskGroupName(false, index) TaskGroups = append(TaskGroups, TaskGroup{ - Name: name, - MinMember: *worker.Replicas, + Name: name, + MinMember: *worker.Replicas, //Labels: meta.Labels, //Annotations: meta.Annotations, - MinResource: Allocation(spec.Containers), + MinResource: Allocation(spec.Containers), //NodeSelector: spec.NodeSelector, //Affinity: spec.Affinity, //TopologySpreadConstraints: spec.TopologySpreadConstraints, }) - meta.Annotations[TaskGroupNameKey] = name - meta.Annotations[AppID] = jobname + meta.Annotations[TaskGroupNameKey] = name + meta.Annotations[AppID] = appID } headSpec := &appSpec.HeadGroupSpec headSpec.Template.Spec.SchedulerName = Yunikorn @@ -50,11 +61,11 @@ func ProcessRay(parameters string, app *rayv1.RayJob) error { res = Add(res, res2) } TaskGroups[0] = TaskGroup{ - Name: headName, - MinMember: 1, + Name: headName, + MinMember: 1, //Labels: meta.Labels, //Annotations: meta.Annotations, - MinResource: res, + MinResource: res, //NodeSelector: spec.NodeSelector, //Affinity: spec.Affinity, //TopologySpreadConstraints: spec.TopologySpreadConstraints, @@ -65,25 +76,25 @@ func ProcessRay(parameters string, app *rayv1.RayJob) error { return err } meta.Annotations[TaskGroupsKey] = string(info[:]) - meta.Annotations[TaskGroupPrarameters] = parameters - meta.Annotations[AppID] = jobname + meta.Annotations[TaskGroupParameters] = parameters + meta.Annotations[AppID] = appID return nil } func Allocation(containers []v1.Container) v1.ResourceList { - totalResources := v1.ResourceList{} - for _, c := range containers { - for name, q := range c.Resources.Limits { - if _, exists := totalResources[name]; !exists { + totalResources := v1.ResourceList{} + for _, c := range containers { + for name, q := range c.Resources.Limits { + if _, exists := totalResources[name]; !exists { totalResources[name] = q.DeepCopy() continue - } + } total := totalResources[name] total.Add(q) totalResources[name] = total - } - } - return totalResources + } + } + return totalResources } func Add(a v1.ResourceList, b v1.ResourceList) v1.ResourceList { @@ -103,4 +114,8 @@ func Add(a v1.ResourceList, b v1.ResourceList) v1.ResourceList { } } return result -} \ No newline at end of file +} + +func NewRayHandler(parameters string) *RayHandler { + return &RayHandler{parameters: parameters} +} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn.go index c94174f5dc..df07bdfc0b 100644 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn.go +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn.go @@ -1,33 +1,24 @@ package yunikorn import ( + "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler" rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" + "reflect" ) const ( - // Pod lebel - Yunikorn = "yunikorn" - AppID = "yunikorn.apache.org/app-id" - TaskGroupNameKey = "yunikorn.apache.org/task-group-name" - TaskGroupsKey = "yunikorn.apache.org/task-groups" - TaskGroupPrarameters = "yunikorn.apache.org/schedulingPolicyParameters" + Yunikorn = "yunikorn" + AppID = "yunikorn.apache.org/app-id" + TaskGroupNameKey = "yunikorn.apache.org/task-group-name" + TaskGroupsKey = "yunikorn.apache.org/task-groups" + TaskGroupParameters = "yunikorn.apache.org/schedulingPolicyParameters" ) -type Plugin struct { - Parameters string -} - -func NewPlugin(parameters string) *Plugin { - return &Plugin{ - Parameters: parameters, - } -} - -func (p *Plugin) Process(app interface{}) error { - switch v := app.(type) { - case *rayv1.RayJob: - return ProcessRay(p.Parameters, v) +func NewPlugin(t reflect.Type, parameters string) batchscheduler.SchedulerPlugin { + switch t { + case reflect.TypeOf(rayv1.RayJob{}): + return NewRayHandler(parameters) default: - return nil + return batchscheduler.NewNoopSchedulerPlugin() } } diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go index 7aa0b014fd..e17a383f78 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go @@ -5,6 +5,7 @@ import ( "encoding/base64" "encoding/json" "fmt" + "reflect" "regexp" "strconv" "strings" @@ -121,10 +122,6 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC headPodSpec := podSpec.DeepCopy() rayjob, err := constructRayJob(taskCtx, rayJob, objectMeta, *podSpec, headPodSpec, headNodeRayStartParams, primaryContainerIdx, *primaryContainer) - if err != nil { - return rayjob, err - } - err = batchscheduler.NewSchedulerPlugin(&cfg.BatchScheduler).Process(rayjob) return rayjob, err } @@ -620,5 +617,9 @@ func init() { return k8s.NewDefaultKubeClient(kubeConfig) }, + Scheduler: func(ctx context.Context) batchscheduler.SchedulerPlugin { + cfg := GetConfig().BatchScheduler + return batchscheduler.NewSchedulerPlugin(reflect.TypeOf(rayv1.RayJob{}), &cfg) + }, }) } diff --git a/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager.go b/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager.go index f9c3806ee6..0a372d837a 100644 --- a/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager.go +++ b/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager.go @@ -3,6 +3,7 @@ package k8s import ( "context" "fmt" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler" "time" "golang.org/x/time/rate" @@ -90,6 +91,7 @@ type PluginManager struct { plugin k8s.Plugin resourceToWatch runtime.Object kubeClient pluginsCore.KubeClient + scheduler batchscheduler.SchedulerPlugin metrics PluginMetrics // Per namespace-resource backOffController *backoff.Controller @@ -203,6 +205,12 @@ func (e *PluginManager) launchResource(ctx context.Context, tCtx pluginsCore.Tas key := backoff.ComposeResourceKey(o) + err = e.scheduler.Mutate(ctx, o) + if err != nil { + logger.Errorf(ctx, "Scheduler plugin failed to process object with error: %v", err) + return pluginsCore.Transition{}, err + } + pod, casted := o.(*v1.Pod) if e.backOffController != nil && casted { podRequestedResources := e.getPodEffectiveResourceLimits(ctx, pod) @@ -536,6 +544,13 @@ func NewPluginManager(ctx context.Context, iCtx pluginsCore.SetupContext, entry return nil, errors.Errorf(errors.PluginInitializationFailed, "Failed to initialize K8sResource Plugin, Kubeclient cannot be nil!") } + var scheduler batchscheduler.SchedulerPlugin + if entry.Scheduler != nil { + scheduler = entry.Scheduler(ctx) + } else { + scheduler = batchscheduler.NewNoopSchedulerPlugin() + } + logger.Infof(ctx, "Initializing K8s plugin [%s]", entry.ID) src := source.Kind(iCtx.KubeClient().GetCache(), entry.ResourceToWatch) @@ -650,6 +665,7 @@ func NewPluginManager(ctx context.Context, iCtx pluginsCore.SetupContext, entry resourceToWatch: entry.ResourceToWatch, metrics: newPluginMetrics(metricsScope), kubeClient: kubeClient, + scheduler: scheduler, resourceLevelMonitor: rm, eventWatcher: eventWatcher, }, nil From a4cca9d62990f6e080905232c8e59e5641ba5d9d Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sun, 8 Sep 2024 02:39:45 -0700 Subject: [PATCH 27/30] kevin refactor Signed-off-by: Kevin Su --- .../pluginmachinery/flytek8s/config/config.go | 3 +++ .../go/tasks/pluginmachinery/k8s/plugin.go | 2 +- .../k8s/batchscheduler/{config => }/config.go | 2 +- .../{config => }/config_test.go | 2 +- .../plugins/k8s/batchscheduler/plugins.go | 21 +++++------------- .../scheduler/noop_scheduler.go | 16 ++++++++++++++ .../scheduler/yunikorn/rayhandler.go | 18 +-------------- .../scheduler/yunikorn/yunikorn.go | 22 ++++++++++++++----- .../scheduler/yunikorn/yunikorn_test.go | 20 ++++++++--------- .../go/tasks/plugins/k8s/ray/config.go | 2 +- flyteplugins/go/tasks/plugins/k8s/ray/ray.go | 6 ----- .../nodes/task/k8s/plugin_manager.go | 9 ++------ 12 files changed, 57 insertions(+), 66 deletions(-) rename flyteplugins/go/tasks/plugins/k8s/batchscheduler/{config => }/config.go (94%) rename flyteplugins/go/tasks/plugins/k8s/batchscheduler/{config => }/config_test.go (92%) create mode 100644 flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/noop_scheduler.go diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/config/config.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/config/config.go index 109ef06ba1..f1fae18c22 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/config/config.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/config/config.go @@ -6,6 +6,7 @@ package config import ( + schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler" "time" v1 "k8s.io/api/core/v1" @@ -206,6 +207,8 @@ type K8sPluginConfig struct { // SendObjectEvents indicates whether to send k8s object events in TaskExecutionEvent updates (similar to kubectl get events). SendObjectEvents bool `json:"send-object-events" pflag:",If true, will send k8s object events in TaskExecutionEvent updates."` + + BatchScheduler schedulerConfig.Config `json:"batchScheduler,omitempty"` } // FlyteCoPilotConfig specifies configuration for the Flyte CoPilot system. FlyteCoPilot, allows running flytekit-less containers diff --git a/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go b/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go index 44efc81ac4..307ed68bb9 100644 --- a/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go +++ b/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go @@ -30,7 +30,7 @@ type PluginEntry struct { // Returns a new KubeClient to be used instead of the internal controller-runtime client. CustomKubeClient func(ctx context.Context) (pluginsCore.KubeClient, error) // Return a new scheduler plugin to be used instead of the default k8s scheduler. - Scheduler func(ctx context.Context) batchscheduler.SchedulerPlugin + Scheduler func(ctx context.Context) batchscheduler.SchedulerManager } // System level properties that this Plugin supports diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config/config.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config.go similarity index 94% rename from flyteplugins/go/tasks/plugins/k8s/batchscheduler/config/config.go rename to flyteplugins/go/tasks/plugins/k8s/batchscheduler/config.go index 483d940ca8..e75627b7a5 100644 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config/config.go +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config.go @@ -1,4 +1,4 @@ -package config +package batchscheduler type Config struct { Scheduler string `json:"scheduler,omitempty" pflag:", Specify batch scheduler to"` diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config/config_test.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config_test.go similarity index 92% rename from flyteplugins/go/tasks/plugins/k8s/batchscheduler/config/config_test.go rename to flyteplugins/go/tasks/plugins/k8s/batchscheduler/config_test.go index b7eb9fc354..e234008b53 100644 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config/config_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config_test.go @@ -1,4 +1,4 @@ -package config +package batchscheduler import ( "testing" diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/plugins.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/plugins.go index ffc69cfc6d..944d722787 100644 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/plugins.go +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/plugins.go @@ -2,33 +2,22 @@ package batchscheduler import ( "context" - schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler" "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn" - "reflect" "sigs.k8s.io/controller-runtime/pkg/client" ) -type SchedulerPlugin interface { +type SchedulerManager interface { // Mutate is responsible for mutating the object to be scheduled. // It will add the necessary annotations, labels, etc. to the object. Mutate(ctx context.Context, object client.Object) error } -type NoopSchedulerPlugin struct{} - -func NewNoopSchedulerPlugin() *NoopSchedulerPlugin { - return &NoopSchedulerPlugin{} -} - -func (p *NoopSchedulerPlugin) Mutate(ctx context.Context, object client.Object) error { - return nil -} - -func NewSchedulerPlugin(t reflect.Type, cfg *schedulerConfig.Config) SchedulerPlugin { +func NewSchedulerManager(cfg *Config) SchedulerManager { switch cfg.GetScheduler() { case yunikorn.Yunikorn: - return yunikorn.NewPlugin(t, cfg.GetParameters()) + return yunikorn.NewSchedulerManager(cfg) default: - return NewNoopSchedulerPlugin() + return scheduler.NewNoopSchedulerManager() } } diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/noop_scheduler.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/noop_scheduler.go new file mode 100644 index 0000000000..8f9ca31146 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/noop_scheduler.go @@ -0,0 +1,16 @@ +package scheduler + +import ( + "context" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +type NoopSchedulerManager struct{} + +func NewNoopSchedulerManager() *NoopSchedulerManager { + return &NoopSchedulerManager{} +} + +func (p *NoopSchedulerManager) Mutate(ctx context.Context, object client.Object) error { + return nil +} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/rayhandler.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/rayhandler.go index ae3ef066f0..c39a5a937b 100644 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/rayhandler.go +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/rayhandler.go @@ -1,25 +1,13 @@ package yunikorn import ( - "context" "encoding/json" - "sigs.k8s.io/controller-runtime/pkg/client" - rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" ) -type RayHandler struct { - parameters string -} - -func (h *RayHandler) Mutate(ctx context.Context, object client.Object) error { - rayJob := object.(*rayv1.RayJob) - return ProcessRay(h.parameters, rayJob) -} - -func ProcessRay(parameters string, app *rayv1.RayJob) error { +func MutateRayJob(parameters string, app *rayv1.RayJob) error { appID := GenerateTaskGroupAppID() rayjobSpec := &app.Spec appSpec := rayjobSpec.RayClusterSpec @@ -115,7 +103,3 @@ func Add(a v1.ResourceList, b v1.ResourceList) v1.ResourceList { } return result } - -func NewRayHandler(parameters string) *RayHandler { - return &RayHandler{parameters: parameters} -} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn.go index df07bdfc0b..6abfbbb434 100644 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn.go +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn.go @@ -1,9 +1,10 @@ package yunikorn import ( + "context" "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler" rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" - "reflect" + "sigs.k8s.io/controller-runtime/pkg/client" ) const ( @@ -14,11 +15,20 @@ const ( TaskGroupParameters = "yunikorn.apache.org/schedulingPolicyParameters" ) -func NewPlugin(t reflect.Type, parameters string) batchscheduler.SchedulerPlugin { - switch t { - case reflect.TypeOf(rayv1.RayJob{}): - return NewRayHandler(parameters) +type YunikornSchedulerManager struct { + parameters string +} + +func (y *YunikornSchedulerManager) Mutate(ctx context.Context, object client.Object) error { + switch object.(type) { + case *rayv1.RayJob: + return MutateRayJob(y.parameters, object.(*rayv1.RayJob)) default: - return batchscheduler.NewNoopSchedulerPlugin() + } + return nil +} + +func NewSchedulerManager(cfg *batchscheduler.Config) batchscheduler.SchedulerManager { + return &YunikornSchedulerManager{parameters: cfg.Parameters} } diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn_test.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn_test.go index d8a2619083..08fe5a1761 100644 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn_test.go @@ -1,46 +1,46 @@ package yunikorn import ( - "testing" "gotest.tools/assert" + "testing" ) func TestNewPlugin(t *testing.T) { - tests := []struct{ - input string + tests := []struct { + input string expect *Plugin }{ { input: "", expect: &Plugin{ - Parameters: "", + Parameters: "", }, }, { input: "placeholderTimeoutInSeconds=30 gangSchedulingStyle=Hard", expect: &Plugin{ - Parameters: "placeholderTimeoutInSeconds=30 gangSchedulingStyle=Hard", + Parameters: "placeholderTimeoutInSeconds=30 gangSchedulingStyle=Hard", }, }, } t.Run("New Yunikorn plugin", func(t *testing.T) { - got := NewPlugin(t.input) + got := NewSchedulerManager(t.input) assert.NotNil(t, got) assert.Equal(t, t.input, got.Parameters) }) } func TestProcess(t *testing.T) { - tests := []struct{ - input interface{} + tests := []struct { + input interface{} expect error }{ {input: 1, expect: nil}, {input: "test", expect: nil}, } t.Run("Yunikorn plugin process any type", func(t *testing.T) { - got := NewPlugin(t.input) + got := NewSchedulerManager(t.input) assert.NotNil(t, got) assert.Equal(t, t.input, got.Parameters) }) -} \ No newline at end of file +} diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/config.go b/flyteplugins/go/tasks/plugins/k8s/ray/config.go index 5658ccc353..8be030413b 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/config.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/config.go @@ -2,6 +2,7 @@ package ray import ( "context" + schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler" v1 "k8s.io/api/core/v1" @@ -9,7 +10,6 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/logs" pluginmachinery "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" - schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config" "github.com/flyteorg/flyte/flytestdlib/config" ) diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go index e17a383f78..ed88014728 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go @@ -5,7 +5,6 @@ import ( "encoding/base64" "encoding/json" "fmt" - "reflect" "regexp" "strconv" "strings" @@ -29,7 +28,6 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler" ) const ( @@ -617,9 +615,5 @@ func init() { return k8s.NewDefaultKubeClient(kubeConfig) }, - Scheduler: func(ctx context.Context) batchscheduler.SchedulerPlugin { - cfg := GetConfig().BatchScheduler - return batchscheduler.NewSchedulerPlugin(reflect.TypeOf(rayv1.RayJob{}), &cfg) - }, }) } diff --git a/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager.go b/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager.go index 0a372d837a..837adbd004 100644 --- a/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager.go +++ b/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager.go @@ -91,7 +91,7 @@ type PluginManager struct { plugin k8s.Plugin resourceToWatch runtime.Object kubeClient pluginsCore.KubeClient - scheduler batchscheduler.SchedulerPlugin + scheduler batchscheduler.SchedulerManager metrics PluginMetrics // Per namespace-resource backOffController *backoff.Controller @@ -544,12 +544,7 @@ func NewPluginManager(ctx context.Context, iCtx pluginsCore.SetupContext, entry return nil, errors.Errorf(errors.PluginInitializationFailed, "Failed to initialize K8sResource Plugin, Kubeclient cannot be nil!") } - var scheduler batchscheduler.SchedulerPlugin - if entry.Scheduler != nil { - scheduler = entry.Scheduler(ctx) - } else { - scheduler = batchscheduler.NewNoopSchedulerPlugin() - } + scheduler := batchscheduler.NewSchedulerManager(&config.GetK8sPluginConfig().BatchScheduler) logger.Infof(ctx, "Initializing K8s plugin [%s]", entry.ID) src := source.Kind(iCtx.KubeClient().GetCache(), entry.ResourceToWatch) From 4387f647cf62b3b548b5ed1d23d445fd9bf1fce0 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sun, 8 Sep 2024 02:43:10 -0700 Subject: [PATCH 28/30] nit Signed-off-by: Kevin Su --- flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go | 4 ---- flyteplugins/go/tasks/plugins/k8s/ray/ray.go | 2 +- .../pkg/controller/nodes/task/k8s/plugin_manager.go | 8 ++++---- 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go b/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go index 307ed68bb9..32d0a45a16 100644 --- a/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go +++ b/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go @@ -2,8 +2,6 @@ package k8s import ( "context" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler" - "sigs.k8s.io/controller-runtime/pkg/client" pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" @@ -29,8 +27,6 @@ type PluginEntry struct { IsDefault bool // Returns a new KubeClient to be used instead of the internal controller-runtime client. CustomKubeClient func(ctx context.Context) (pluginsCore.KubeClient, error) - // Return a new scheduler plugin to be used instead of the default k8s scheduler. - Scheduler func(ctx context.Context) batchscheduler.SchedulerManager } // System level properties that this Plugin supports diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go index ed88014728..178d2427ad 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go @@ -124,7 +124,6 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC } func constructRayJob(taskCtx pluginsCore.TaskExecutionContext, rayJob plugins.RayJob, objectMeta *metav1.ObjectMeta, podSpec v1.PodSpec, headPodSpec *v1.PodSpec, headNodeRayStartParams map[string]string, primaryContainerIdx int, primaryContainer v1.Container) (*rayv1.RayJob, error) { - var err error enableIngress := true cfg := GetConfig() rayClusterSpec := rayv1.RayClusterSpec{ @@ -210,6 +209,7 @@ func constructRayJob(taskCtx pluginsCore.TaskExecutionContext, rayJob plugins.Ra // TODO: This is for backward compatibility. Remove this block once runtime_env is removed from ray proto. var runtimeEnvYaml string runtimeEnvYaml = rayJob.RuntimeEnvYaml + var err error // If runtime_env exists but runtime_env_yaml does not, convert runtime_env to runtime_env_yaml if rayJob.RuntimeEnv != "" && rayJob.RuntimeEnvYaml == "" { runtimeEnvYaml, err = convertBase64RuntimeEnvToYaml(rayJob.RuntimeEnv) diff --git a/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager.go b/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager.go index 837adbd004..6ea6bcede1 100644 --- a/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager.go +++ b/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager.go @@ -91,7 +91,7 @@ type PluginManager struct { plugin k8s.Plugin resourceToWatch runtime.Object kubeClient pluginsCore.KubeClient - scheduler batchscheduler.SchedulerManager + schedulerMgr batchscheduler.SchedulerManager metrics PluginMetrics // Per namespace-resource backOffController *backoff.Controller @@ -205,7 +205,7 @@ func (e *PluginManager) launchResource(ctx context.Context, tCtx pluginsCore.Tas key := backoff.ComposeResourceKey(o) - err = e.scheduler.Mutate(ctx, o) + err = e.schedulerMgr.Mutate(ctx, o) if err != nil { logger.Errorf(ctx, "Scheduler plugin failed to process object with error: %v", err) return pluginsCore.Transition{}, err @@ -544,7 +544,7 @@ func NewPluginManager(ctx context.Context, iCtx pluginsCore.SetupContext, entry return nil, errors.Errorf(errors.PluginInitializationFailed, "Failed to initialize K8sResource Plugin, Kubeclient cannot be nil!") } - scheduler := batchscheduler.NewSchedulerManager(&config.GetK8sPluginConfig().BatchScheduler) + schedulerMgr := batchscheduler.NewSchedulerManager(&config.GetK8sPluginConfig().BatchScheduler) logger.Infof(ctx, "Initializing K8s plugin [%s]", entry.ID) src := source.Kind(iCtx.KubeClient().GetCache(), entry.ResourceToWatch) @@ -660,7 +660,7 @@ func NewPluginManager(ctx context.Context, iCtx pluginsCore.SetupContext, entry resourceToWatch: entry.ResourceToWatch, metrics: newPluginMetrics(metricsScope), kubeClient: kubeClient, - scheduler: scheduler, + schedulerMgr: schedulerMgr, resourceLevelMonitor: rm, eventWatcher: eventWatcher, }, nil From 7515dd82e0e0ceef31aebc69add526c9f5558277 Mon Sep 17 00:00:00 2001 From: yuteng Date: Mon, 9 Sep 2024 23:10:54 +0800 Subject: [PATCH 29/30] build completed tg Signed-off-by: yuteng --- .../scheduler/yunikorn/rayhandler.go | 36 +++++++++---------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/rayhandler.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/rayhandler.go index c39a5a937b..0884a68df1 100644 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/rayhandler.go +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/rayhandler.go @@ -19,14 +19,14 @@ func MutateRayJob(parameters string, app *rayv1.RayJob) error { spec := worker.Template.Spec name := GenerateTaskGroupName(false, index) TaskGroups = append(TaskGroups, TaskGroup{ - Name: name, - MinMember: *worker.Replicas, - //Labels: meta.Labels, - //Annotations: meta.Annotations, - MinResource: Allocation(spec.Containers), - //NodeSelector: spec.NodeSelector, - //Affinity: spec.Affinity, - //TopologySpreadConstraints: spec.TopologySpreadConstraints, + Name: name, + MinMember: *worker.Replicas, + Labels: meta.Labels, + Annotations: meta.Annotations, + MinResource: Allocation(spec.Containers), + NodeSelector: spec.NodeSelector, + Affinity: spec.Affinity, + TopologySpreadConstraints: spec.TopologySpreadConstraints, }) meta.Annotations[TaskGroupNameKey] = name meta.Annotations[AppID] = appID @@ -42,21 +42,17 @@ func MutateRayJob(parameters string, app *rayv1.RayJob) error { v1.ResourceCPU: resource.MustParse("500m"), v1.ResourceMemory: resource.MustParse("512Mi"), } - //tmp, _ := json.Marshal(res2) - //meta.Annotations["tmp"] = string(tmp) - //tmp, _ = json.Marshal(Add(res, res2)) - //meta.Annotations["Sum"] = string(tmp) res = Add(res, res2) } TaskGroups[0] = TaskGroup{ - Name: headName, - MinMember: 1, - //Labels: meta.Labels, - //Annotations: meta.Annotations, - MinResource: res, - //NodeSelector: spec.NodeSelector, - //Affinity: spec.Affinity, - //TopologySpreadConstraints: spec.TopologySpreadConstraints, + Name: headName, + MinMember: 1, + Labels: meta.Labels, + Annotations: meta.Annotations, + MinResource: res, + NodeSelector: spec.NodeSelector, + Affinity: spec.Affinity, + TopologySpreadConstraints: spec.TopologySpreadConstraints, } meta.Annotations[TaskGroupNameKey] = headName info, err := json.Marshal(TaskGroups) From bd74f60439cafee276f8effb4e83d53a80a92798 Mon Sep 17 00:00:00 2001 From: yuteng Date: Sat, 12 Oct 2024 23:43:36 +0800 Subject: [PATCH 30/30] refactoring proposal Signed-off-by: yuteng --- .../pluginmachinery/flytek8s/config/config.go | 3 -- .../go/tasks/pluginmachinery/k8s/plugin.go | 10 ++++ .../plugins/k8s/batchscheduler/config.go | 24 +++++---- .../plugins/k8s/batchscheduler/config_test.go | 15 ------ .../k8s/batchscheduler/kueue/helper.go | 16 ++++++ .../plugins/k8s/batchscheduler/plugins.go | 23 -------- .../scheduler/noop_scheduler.go | 16 ------ .../scheduler/yunikorn/yunikorn.go | 34 ------------ .../scheduler/yunikorn/yunikorn_test.go | 46 ---------------- .../k8s/batchscheduler/utils/helper.go | 30 +++++++++++ .../rayhandler.go => yunikorn/helper.go} | 52 +++++++++++++++---- .../{scheduler => }/yunikorn/taskgroup.go | 0 .../yunikorn/taskgroup_test.go | 8 ++- .../{scheduler => }/yunikorn/utils.go | 0 .../{scheduler => }/yunikorn/utils_test.go | 10 ++-- .../go/tasks/plugins/k8s/ray/config.go | 2 +- .../go/tasks/plugins/k8s/ray/config_flags.go | 7 ++- .../plugins/k8s/ray/config_flags_test.go | 50 ++++++++++++++++-- flyteplugins/go/tasks/plugins/k8s/ray/ray.go | 47 +++++++++++++++++ .../nodes/task/k8s/plugin_manager.go | 18 +++---- 20 files changed, 228 insertions(+), 183 deletions(-) delete mode 100644 flyteplugins/go/tasks/plugins/k8s/batchscheduler/config_test.go create mode 100644 flyteplugins/go/tasks/plugins/k8s/batchscheduler/kueue/helper.go delete mode 100644 flyteplugins/go/tasks/plugins/k8s/batchscheduler/plugins.go delete mode 100644 flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/noop_scheduler.go delete mode 100644 flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn.go delete mode 100644 flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn_test.go create mode 100644 flyteplugins/go/tasks/plugins/k8s/batchscheduler/utils/helper.go rename flyteplugins/go/tasks/plugins/k8s/batchscheduler/{scheduler/yunikorn/rayhandler.go => yunikorn/helper.go} (65%) rename flyteplugins/go/tasks/plugins/k8s/batchscheduler/{scheduler => }/yunikorn/taskgroup.go (100%) rename flyteplugins/go/tasks/plugins/k8s/batchscheduler/{scheduler => }/yunikorn/taskgroup_test.go (86%) rename flyteplugins/go/tasks/plugins/k8s/batchscheduler/{scheduler => }/yunikorn/utils.go (100%) rename flyteplugins/go/tasks/plugins/k8s/batchscheduler/{scheduler => }/yunikorn/utils_test.go (89%) diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/config/config.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/config/config.go index f1fae18c22..109ef06ba1 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/config/config.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/config/config.go @@ -6,7 +6,6 @@ package config import ( - schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler" "time" v1 "k8s.io/api/core/v1" @@ -207,8 +206,6 @@ type K8sPluginConfig struct { // SendObjectEvents indicates whether to send k8s object events in TaskExecutionEvent updates (similar to kubectl get events). SendObjectEvents bool `json:"send-object-events" pflag:",If true, will send k8s object events in TaskExecutionEvent updates."` - - BatchScheduler schedulerConfig.Config `json:"batchScheduler,omitempty"` } // FlyteCoPilotConfig specifies configuration for the Flyte CoPilot system. FlyteCoPilot, allows running flytekit-less containers diff --git a/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go b/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go index 32d0a45a16..6f50cfd29e 100644 --- a/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go +++ b/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go @@ -2,8 +2,10 @@ package k8s import ( "context" + "sigs.k8s.io/controller-runtime/pkg/client" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io" "github.com/flyteorg/flyte/flytestdlib/storage" @@ -186,3 +188,11 @@ func MaybeUpdatePhaseVersionFromPluginContext(phaseInfo *pluginsCore.PhaseInfo, MaybeUpdatePhaseVersion(phaseInfo, &pluginState) return nil } + +type YunikornScheduablePlugin interface { + MutateResourceForYunikorn(ctx context.Context, object client.Object, taskTmpl *core.TaskTemplate) (client.Object, error) +} + +type KueueScheduablePlugin interface { + MutateResourceForKueue(ctx context.Context, object client.Object, taskTmpl *core.TaskTemplate) (client.Object, error) +} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config.go index e75627b7a5..78769a1123 100644 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config.go +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config.go @@ -1,21 +1,23 @@ package batchscheduler type Config struct { - Scheduler string `json:"scheduler,omitempty" pflag:", Specify batch scheduler to"` - Parameters string `json:"parameters,omitempty" pflag:", Specify static parameters"` + Scheduler string `json:"scheduler,omitempty" pflag:", Specify batch scheduler to"` + Default SchedulingConfig `json:"default,omitempty" pflag:", Specify default scheduling config which batch scheduler adopts"` + NameSpace map[string]SchedulingConfig `json:"Namespace,omitempty" pflag:"-, Specify namespace scheduling config"` + Domain map[string]SchedulingConfig `json:"Domain,omitempty" pflag:"-, Specify domain scheduling config"` } -func NewConfig() Config { - return Config{ - Scheduler: "", - Parameters: "", - } +type SchedulingConfig struct { + KueueConfig `json:"Kueue,omitempty" pflag:", Specify Kueue scheduling scheduling config"` + YunikornConfig `json:"Yunikorn,omitempty" pflag:", Yunikorn scheduling config"` } -func (b *Config) GetScheduler() string { - return b.Scheduler +type KueueConfig struct { + PriorityClassName string `json:"Priority,omitempty" pflag:", Kueue Prioty class"` + Queue string `json:"Queue,omitempty" pflag:", Specify batch scheduler to"` } -func (b *Config) GetParameters() string { - return b.Parameters +type YunikornConfig struct { + Parameters string `json:"parameters,omitempty" pflag:", Specify gangscheduling policy"` + Queue string `json:"queue,omitempty" pflag:", Specify leaf queue to submit to"` } diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config_test.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config_test.go deleted file mode 100644 index e234008b53..0000000000 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/config_test.go +++ /dev/null @@ -1,15 +0,0 @@ -package batchscheduler - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestNewConfig(t *testing.T) { - t.Run("New scheduler plugin config", func(t *testing.T) { - config := NewConfig() - assert.Equal(t, "", config.GetScheduler()) - assert.Equal(t, "", config.GetParameters()) - }) -} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/kueue/helper.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/kueue/helper.go new file mode 100644 index 0000000000..1889ec451d --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/kueue/helper.go @@ -0,0 +1,16 @@ +package kueue + +import ( + rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" + + "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/utils" +) + +const ( + QueueName = "kueue.x-k8s.io/queue-name" + PriorityClassName = "kueue.x-k8s.io/priority-class" +) + +func UpdateKueueLabels(labels map[string]string, app *rayv1.RayJob) { + utils.UpdateLabels(labels, &app.ObjectMeta) +} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/plugins.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/plugins.go deleted file mode 100644 index 944d722787..0000000000 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/plugins.go +++ /dev/null @@ -1,23 +0,0 @@ -package batchscheduler - -import ( - "context" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn" - "sigs.k8s.io/controller-runtime/pkg/client" -) - -type SchedulerManager interface { - // Mutate is responsible for mutating the object to be scheduled. - // It will add the necessary annotations, labels, etc. to the object. - Mutate(ctx context.Context, object client.Object) error -} - -func NewSchedulerManager(cfg *Config) SchedulerManager { - switch cfg.GetScheduler() { - case yunikorn.Yunikorn: - return yunikorn.NewSchedulerManager(cfg) - default: - return scheduler.NewNoopSchedulerManager() - } -} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/noop_scheduler.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/noop_scheduler.go deleted file mode 100644 index 8f9ca31146..0000000000 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/noop_scheduler.go +++ /dev/null @@ -1,16 +0,0 @@ -package scheduler - -import ( - "context" - "sigs.k8s.io/controller-runtime/pkg/client" -) - -type NoopSchedulerManager struct{} - -func NewNoopSchedulerManager() *NoopSchedulerManager { - return &NoopSchedulerManager{} -} - -func (p *NoopSchedulerManager) Mutate(ctx context.Context, object client.Object) error { - return nil -} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn.go deleted file mode 100644 index 6abfbbb434..0000000000 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn.go +++ /dev/null @@ -1,34 +0,0 @@ -package yunikorn - -import ( - "context" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler" - rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" - "sigs.k8s.io/controller-runtime/pkg/client" -) - -const ( - Yunikorn = "yunikorn" - AppID = "yunikorn.apache.org/app-id" - TaskGroupNameKey = "yunikorn.apache.org/task-group-name" - TaskGroupsKey = "yunikorn.apache.org/task-groups" - TaskGroupParameters = "yunikorn.apache.org/schedulingPolicyParameters" -) - -type YunikornSchedulerManager struct { - parameters string -} - -func (y *YunikornSchedulerManager) Mutate(ctx context.Context, object client.Object) error { - switch object.(type) { - case *rayv1.RayJob: - return MutateRayJob(y.parameters, object.(*rayv1.RayJob)) - default: - - } - return nil -} - -func NewSchedulerManager(cfg *batchscheduler.Config) batchscheduler.SchedulerManager { - return &YunikornSchedulerManager{parameters: cfg.Parameters} -} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn_test.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn_test.go deleted file mode 100644 index 08fe5a1761..0000000000 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/yunikorn_test.go +++ /dev/null @@ -1,46 +0,0 @@ -package yunikorn - -import ( - "gotest.tools/assert" - "testing" -) - -func TestNewPlugin(t *testing.T) { - tests := []struct { - input string - expect *Plugin - }{ - { - input: "", - expect: &Plugin{ - Parameters: "", - }, - }, - { - input: "placeholderTimeoutInSeconds=30 gangSchedulingStyle=Hard", - expect: &Plugin{ - Parameters: "placeholderTimeoutInSeconds=30 gangSchedulingStyle=Hard", - }, - }, - } - t.Run("New Yunikorn plugin", func(t *testing.T) { - got := NewSchedulerManager(t.input) - assert.NotNil(t, got) - assert.Equal(t, t.input, got.Parameters) - }) -} - -func TestProcess(t *testing.T) { - tests := []struct { - input interface{} - expect error - }{ - {input: 1, expect: nil}, - {input: "test", expect: nil}, - } - t.Run("Yunikorn plugin process any type", func(t *testing.T) { - got := NewSchedulerManager(t.input) - assert.NotNil(t, got) - assert.Equal(t, t.input, got.Parameters) - }) -} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/utils/helper.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/utils/helper.go new file mode 100644 index 0000000000..7133a328f3 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/utils/helper.go @@ -0,0 +1,30 @@ +package utils + +import ( + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func UpdateLabels(wanted map[string]string, objectMeta *metav1.ObjectMeta) { + for key, value := range wanted { + if _, exist := objectMeta.Labels[key]; !exist { + objectMeta.Labels[key] = value + } + } +} + +func UpdateAnnotations(wanted map[string]string, objectMeta *metav1.ObjectMeta) { + for key, value := range wanted { + if _, exist := objectMeta.Annotations[key]; !exist { + objectMeta.Annotations[key] = value + } + } +} + +func UpdatePodTemplateAnnotatations(wanted map[string]string, pod *v1.PodTemplateSpec) { + UpdateAnnotations(wanted, &pod.ObjectMeta) +} + +func UpdatePodTemplateLabels(wanted map[string]string, pod *v1.PodTemplateSpec) { + UpdateLabels(wanted, &pod.ObjectMeta) +} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/rayhandler.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go similarity index 65% rename from flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/rayhandler.go rename to flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go index 0884a68df1..2480d80eee 100644 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/rayhandler.go +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go @@ -2,12 +2,25 @@ package yunikorn import ( "encoding/json" + rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/utils" +) + +const ( + Yunikorn = "yunikorn" + AppID = "yunikorn.apache.org/app-id" + Queue = "yunikorn.apache.org/queue" + TaskGroupNameKey = "yunikorn.apache.org/task-group-name" + TaskGroupsKey = "yunikorn.apache.org/task-groups" + TaskGroupParameters = "yunikorn.apache.org/schedulingPolicyParameters" ) -func MutateRayJob(parameters string, app *rayv1.RayJob) error { +func MutateRayJob(app *rayv1.RayJob) error { appID := GenerateTaskGroupAppID() rayjobSpec := &app.Spec appSpec := rayjobSpec.RayClusterSpec @@ -60,11 +73,30 @@ func MutateRayJob(parameters string, app *rayv1.RayJob) error { return err } meta.Annotations[TaskGroupsKey] = string(info[:]) - meta.Annotations[TaskGroupParameters] = parameters meta.Annotations[AppID] = appID return nil } +func UpdateGangSchedulingParameters(parameters string, objectMeta *metav1.ObjectMeta) { + if len(parameters) == 0 { + return + } + utils.UpdateAnnotations( + map[string]string{TaskGroupParameters: parameters}, + objectMeta, + ) +} + +func UpdateAnnotations(labels map[string]string, app *rayv1.RayJob) { + appSpec := app.Spec.RayClusterSpec + headSpec := appSpec.HeadGroupSpec + utils.UpdatePodTemplateAnnotatations(labels, &headSpec.Template) + for index := range appSpec.WorkerGroupSpecs { + worker := appSpec.WorkerGroupSpecs[index] + utils.UpdatePodTemplateAnnotatations(labels, &worker.Template) + } +} + func Allocation(containers []v1.Container) v1.ResourceList { totalResources := v1.ResourceList{} for _, c := range containers { @@ -81,19 +113,19 @@ func Allocation(containers []v1.Container) v1.ResourceList { return totalResources } -func Add(a v1.ResourceList, b v1.ResourceList) v1.ResourceList { - result := a - for name, value := range a { - sum := &value - if value2, ok := b[name]; ok { +func Add(left v1.ResourceList, right v1.ResourceList) v1.ResourceList { + result := left + for name, value := range left { + sum := value + if value2, ok := right[name]; ok { sum.Add(value2) - result[name] = *sum + result[name] = sum } else { result[name] = value } } - for name, value := range b { - if _, ok := a[name]; !ok { + for name, value := range right { + if _, ok := left[name]; !ok { result[name] = value } } diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/taskgroup.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/taskgroup.go similarity index 100% rename from flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/taskgroup.go rename to flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/taskgroup.go diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/taskgroup_test.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/taskgroup_test.go similarity index 86% rename from flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/taskgroup_test.go rename to flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/taskgroup_test.go index ce5bfe8d88..5472b6feb6 100644 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/taskgroup_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/taskgroup_test.go @@ -4,9 +4,15 @@ import ( "testing" "github.com/stretchr/testify/assert" + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" ) func TestMarshal(t *testing.T) { + res := v1.ResourceList{ + v1.ResourceCPU: resource.MustParse("500m"), + v1.ResourceMemory: resource.MustParse("512Mi"), + } t1 := TaskGroup{ Name: "tg1", MinMember: int32(1), @@ -43,4 +49,4 @@ func TestMarshal(t *testing.T) { assert.Nil(t, err) } }) -} \ No newline at end of file +} diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/utils.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/utils.go similarity index 100% rename from flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/utils.go rename to flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/utils.go diff --git a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/utils_test.go b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/utils_test.go similarity index 89% rename from flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/utils_test.go rename to flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/utils_test.go index 14aba4ddbc..ca0502baaf 100644 --- a/flyteplugins/go/tasks/plugins/k8s/batchscheduler/scheduler/yunikorn/utils_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/utils_test.go @@ -43,11 +43,9 @@ func TestGenerateTaskGroupName(t *testing.T) { func TestGenerateTaskGroupAppID(t *testing.T) { t.Run("Generate ray app ID", func(t *testing.T) { - for _, tt := range tests { - got := GenerateTaskGroupAppID() - if len(got) <= 0 { - t.Error("Ray app ID is empty") - } + got := GenerateTaskGroupAppID() + if len(got) <= 0 { + t.Error("Ray app ID is empty") } }) -} \ No newline at end of file +} diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/config.go b/flyteplugins/go/tasks/plugins/k8s/ray/config.go index 8be030413b..c4d84e2a48 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/config.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/config.go @@ -2,7 +2,6 @@ package ray import ( "context" - schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler" v1 "k8s.io/api/core/v1" @@ -10,6 +9,7 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/logs" pluginmachinery "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" + schedulerConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler" "github.com/flyteorg/flyte/flytestdlib/config" ) diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/config_flags.go b/flyteplugins/go/tasks/plugins/k8s/ray/config_flags.go index a725e93c5a..02284b8d5e 100755 --- a/flyteplugins/go/tasks/plugins/k8s/ray/config_flags.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/config_flags.go @@ -55,8 +55,11 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.String(fmt.Sprintf("%v%v", prefix, "serviceType"), defaultConfig.ServiceType, "") cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "includeDashboard"), defaultConfig.IncludeDashboard, "") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "dashboardHost"), defaultConfig.DashboardHost, "") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "batchScheduler.scheduler"), defaultConfig.BatchScheduler.Scheduler, "") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "batchScheduler.parameters"), defaultConfig.BatchScheduler.Parameters, "") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "batchScheduler.scheduler"), defaultConfig.BatchScheduler.Scheduler, " Specify batch scheduler to") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "batchScheduler.default.Kueue.Priority"), defaultConfig.BatchScheduler.Default.KueueConfig.PriorityClassName, " Kueue Prioty class") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "batchScheduler.default.Kueue.Queue"), defaultConfig.BatchScheduler.Default.KueueConfig.Queue, " Specify batch scheduler to") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "batchScheduler.default.Yunikorn.parameters"), defaultConfig.BatchScheduler.Default.YunikornConfig.Parameters, " Specify gangscheduling policy") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "batchScheduler.default.Yunikorn.queue"), defaultConfig.BatchScheduler.Default.YunikornConfig.Queue, " Specify leaf queue to submit to") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "remoteClusterConfig.name"), defaultConfig.RemoteClusterConfig.Name, "Friendly name of the remote cluster") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "remoteClusterConfig.endpoint"), defaultConfig.RemoteClusterConfig.Endpoint, " Remote K8s cluster endpoint") cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "remoteClusterConfig.enabled"), defaultConfig.RemoteClusterConfig.Enabled, " Boolean flag to enable or disable") diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/config_flags_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/config_flags_test.go index 960752c0a5..73c4d6de37 100755 --- a/flyteplugins/go/tasks/plugins/k8s/ray/config_flags_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/config_flags_test.go @@ -183,14 +183,56 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) - t.Run("Test_batchScheduler.parameters", func(t *testing.T) { + t.Run("Test_batchScheduler.default.Kueue.Priority", func(t *testing.T) { t.Run("Override", func(t *testing.T) { testValue := "1" - cmdFlags.Set("batchScheduler.parameters", testValue) - if vString, err := cmdFlags.GetString("batchScheduler.parameters"); err == nil { - testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.BatchScheduler.Parameters) + cmdFlags.Set("batchScheduler.default.Kueue.Priority", testValue) + if vString, err := cmdFlags.GetString("batchScheduler.default.Kueue.Priority"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.BatchScheduler.Default.KueueConfig.PriorityClassName) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_batchScheduler.default.Kueue.Queue", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("batchScheduler.default.Kueue.Queue", testValue) + if vString, err := cmdFlags.GetString("batchScheduler.default.Kueue.Queue"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.BatchScheduler.Default.KueueConfig.Queue) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_batchScheduler.default.Yunikorn.parameters", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("batchScheduler.default.Yunikorn.parameters", testValue) + if vString, err := cmdFlags.GetString("batchScheduler.default.Yunikorn.parameters"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.BatchScheduler.Default.YunikornConfig.Parameters) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_batchScheduler.default.Yunikorn.queue", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("batchScheduler.default.Yunikorn.queue", testValue) + if vString, err := cmdFlags.GetString("batchScheduler.default.Yunikorn.queue"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.BatchScheduler.Default.YunikornConfig.Queue) } else { assert.FailNow(t, err.Error()) diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go index 178d2427ad..31a7e44fcd 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go @@ -28,6 +28,8 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/kueue" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn" ) const ( @@ -551,6 +553,51 @@ func getEventInfoForRayJob(logConfig logs.LogConfig, pluginContext k8s.PluginCon return &pluginsCore.TaskInfo{Logs: taskLogs}, nil } +func (plugin rayJobResourceHandler) MutateResourceForYunikorn(ctx context.Context, object client.Object, taskTmpl *core.TaskTemplate) (client.Object, error) { + rayJob := object.(*rayv1.RayJob) + // Update gang scheduling annotations + if err := yunikorn.MutateRayJob(rayJob); err != nil { + return rayJob, err + } + // Update Yunikorn annotations + cfg := GetConfig().BatchScheduler.Default.YunikornConfig + id := taskTmpl.Id + annotations := make(map[string]string, 0) + queueName := fmt.Sprintf("root.%s.%s", id.Project, id.Domain) + if len(cfg.Queue) > 0 { + if cfg.Queue == "namespace" { + queueName = fmt.Sprintf("%s.%s", queueName, rayJob.ObjectMeta.Namespace) + } else { + queueName = fmt.Sprintf("%s.%s", queueName, cfg.Queue) + } + } else { + queueName = fmt.Sprintf("%s.%s", queueName, id.ResourceType) + } + annotations[yunikorn.Queue] = queueName + annotations[yunikorn.TaskGroupParameters] = cfg.Parameters + yunikorn.UpdateAnnotations(annotations, rayJob) + return rayJob, nil +} + +func (plugin rayJobResourceHandler) MutateResourceForKueue(ctx context.Context, object client.Object, taskTmpl *core.TaskTemplate) (client.Object, error) { + rayJob := object.(*rayv1.RayJob) + cfg := GetConfig().BatchScheduler.Default.KueueConfig + id := taskTmpl.Id + queueName := fmt.Sprintf("%s.%s", id.Project, id.Domain) + if len(cfg.Queue) > 0 { + if cfg.Queue == "namespace" { + queueName = fmt.Sprintf("%s.%s", queueName, rayJob.ObjectMeta.Namespace) + } else { + queueName = fmt.Sprintf("%s.%s", queueName, cfg.Queue) + } + } else { + queueName = fmt.Sprintf("%s.%s", queueName, id.ResourceType) + } + rayJob.ObjectMeta.Labels[kueue.QueueName] = queueName + rayJob.ObjectMeta.Labels[kueue.PriorityClassName] = cfg.PriorityClassName + return object, nil +} + func (plugin rayJobResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) { rayJob := resource.(*rayv1.RayJob) info, err := getEventInfoForRayJob(GetConfig().Logs, pluginContext, rayJob) diff --git a/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager.go b/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager.go index 6ea6bcede1..1c84eb7b38 100644 --- a/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager.go +++ b/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager.go @@ -3,7 +3,6 @@ package k8s import ( "context" "fmt" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler" "time" "golang.org/x/time/rate" @@ -91,7 +90,6 @@ type PluginManager struct { plugin k8s.Plugin resourceToWatch runtime.Object kubeClient pluginsCore.KubeClient - schedulerMgr batchscheduler.SchedulerManager metrics PluginMetrics // Per namespace-resource backOffController *backoff.Controller @@ -199,18 +197,19 @@ func (e *PluginManager) launchResource(ctx context.Context, tCtx pluginsCore.Tas if err != nil { return pluginsCore.UnknownTransition, err } + if p, ok := e.plugin.(k8s.YunikornScheduablePlugin); ok { + o, err = p.MutateResourceForYunikorn(ctx, o, tmpl) + if err != nil { + return pluginsCore.UnknownTransition, err + } + tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() + } e.addObjectMetadata(k8sTaskCtxMetadata, o, config.GetK8sPluginConfig()) logger.Infof(ctx, "Creating Object: Type:[%v], Object:[%v/%v]", o.GetObjectKind().GroupVersionKind(), o.GetNamespace(), o.GetName()) key := backoff.ComposeResourceKey(o) - err = e.schedulerMgr.Mutate(ctx, o) - if err != nil { - logger.Errorf(ctx, "Scheduler plugin failed to process object with error: %v", err) - return pluginsCore.Transition{}, err - } - pod, casted := o.(*v1.Pod) if e.backOffController != nil && casted { podRequestedResources := e.getPodEffectiveResourceLimits(ctx, pod) @@ -544,8 +543,6 @@ func NewPluginManager(ctx context.Context, iCtx pluginsCore.SetupContext, entry return nil, errors.Errorf(errors.PluginInitializationFailed, "Failed to initialize K8sResource Plugin, Kubeclient cannot be nil!") } - schedulerMgr := batchscheduler.NewSchedulerManager(&config.GetK8sPluginConfig().BatchScheduler) - logger.Infof(ctx, "Initializing K8s plugin [%s]", entry.ID) src := source.Kind(iCtx.KubeClient().GetCache(), entry.ResourceToWatch) @@ -660,7 +657,6 @@ func NewPluginManager(ctx context.Context, iCtx pluginsCore.SetupContext, entry resourceToWatch: entry.ResourceToWatch, metrics: newPluginMetrics(metricsScope), kubeClient: kubeClient, - schedulerMgr: schedulerMgr, resourceLevelMonitor: rm, eventWatcher: eventWatcher, }, nil