Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updating annotations of pods belonging to a Ray cluster in order to adopt Yunikorn gang scheduling #5594

Draft
wants to merge 30 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (

"sigs.k8s.io/controller-runtime/pkg/client"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io"
"github.com/flyteorg/flyte/flytestdlib/storage"
Expand Down Expand Up @@ -187,3 +188,11 @@ func MaybeUpdatePhaseVersionFromPluginContext(phaseInfo *pluginsCore.PhaseInfo,
MaybeUpdatePhaseVersion(phaseInfo, &pluginState)
return nil
}

// YunikornScheduablePlugin declares a single hook that receives a Kubernetes
// object together with the task template and returns a client.Object (or an
// error).
// NOTE(review): presumably implemented by plugins whose resources can be
// adapted for the Yunikorn scheduler; no implementation is visible here.
type YunikornScheduablePlugin interface {
// MutateResourceForYunikorn takes the object to be scheduled and the task
// template it was built from, and returns the resulting object or an error.
MutateResourceForYunikorn(ctx context.Context, object client.Object, taskTmpl *core.TaskTemplate) (client.Object, error)
}

// KueueScheduablePlugin declares a single hook that receives a Kubernetes
// object together with the task template and returns a client.Object (or an
// error).
// NOTE(review): presumably implemented by plugins whose resources can be
// adapted for Kueue; no implementation is visible here.
type KueueScheduablePlugin interface {
// MutateResourceForKueue takes the object to be scheduled and the task
// template it was built from, and returns the resulting object or an error.
MutateResourceForKueue(ctx context.Context, object client.Object, taskTmpl *core.TaskTemplate) (client.Object, error)
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 23 additions & 0 deletions flyteplugins/go/tasks/plugins/k8s/batchscheduler/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package batchscheduler

// Config is the top-level batch-scheduler configuration: the scheduler to
// use, a default SchedulingConfig, and per-namespace / per-domain overrides
// keyed by name.
type Config struct {
	// Scheduler names the batch scheduler to use.
	Scheduler string `json:"scheduler,omitempty" pflag:", Specify batch scheduler to use"`
	// Default is the scheduling config used when no override applies.
	// NOTE(review): the override precedence is not visible here — confirm with the consumer.
	Default SchedulingConfig `json:"default,omitempty" pflag:", Specify default scheduling config which batch scheduler adopts"`
	// NameSpace maps a namespace to its scheduling config. The JSON key is
	// the capitalized "Namespace"; it is kept as-is for compatibility.
	NameSpace map[string]SchedulingConfig `json:"Namespace,omitempty" pflag:"-, Specify namespace scheduling config"`
	// Domain maps a domain to its scheduling config.
	Domain map[string]SchedulingConfig `json:"Domain,omitempty" pflag:"-, Specify domain scheduling config"`
}

// SchedulingConfig bundles the per-scheduler settings. Both structs are
// embedded with explicit JSON names, so they serialize under the "Kueue" and
// "Yunikorn" keys respectively.
//
// Both embedded types declare a Queue field, so a bare cfg.Queue selector is
// ambiguous; always qualify it (cfg.KueueConfig.Queue / cfg.YunikornConfig.Queue).
type SchedulingConfig struct {
	KueueConfig    `json:"Kueue,omitempty" pflag:", Specify Kueue scheduling config"`
	YunikornConfig `json:"Yunikorn,omitempty" pflag:", Yunikorn scheduling config"`
}

// KueueConfig holds the Kueue-specific scheduling settings.
type KueueConfig struct {
	// PriorityClassName is the Kueue priority class; serialized as "Priority".
	PriorityClassName string `json:"Priority,omitempty" pflag:", Kueue priority class"`
	// Queue is the Kueue queue to submit to; serialized as "Queue".
	Queue string `json:"Queue,omitempty" pflag:", Specify Kueue queue to submit to"`
}

// YunikornConfig holds the Yunikorn-specific scheduling settings.
type YunikornConfig struct {
// Parameters carries the gang-scheduling policy string.
// NOTE(review): presumably the value later written under the
// schedulingPolicyParameters annotation — confirm with the caller.
Parameters string `json:"parameters,omitempty" pflag:", Specify gangscheduling policy"`
// Queue is the leaf queue to submit to.
Queue string `json:"queue,omitempty" pflag:", Specify leaf queue to submit to"`
}
16 changes: 16 additions & 0 deletions flyteplugins/go/tasks/plugins/k8s/batchscheduler/kueue/helper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package kueue

import (
rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"

"github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/utils"
)

// Metadata keys in the kueue.x-k8s.io namespace.
// NOTE(review): presumably the label keys Kueue reads to select the queue and
// priority class — confirm against the Kueue documentation.
const (
// QueueName is the key carrying the target queue name.
QueueName = "kueue.x-k8s.io/queue-name"
// PriorityClassName is the key carrying the priority class.
PriorityClassName = "kueue.x-k8s.io/priority-class"
)

func UpdateKueueLabels(labels map[string]string, app *rayv1.RayJob) {
utils.UpdateLabels(labels, &app.ObjectMeta)

Check warning on line 15 in flyteplugins/go/tasks/plugins/k8s/batchscheduler/kueue/helper.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/batchscheduler/kueue/helper.go#L14-L15

Added lines #L14 - L15 were not covered by tests
}
30 changes: 30 additions & 0 deletions flyteplugins/go/tasks/plugins/k8s/batchscheduler/utils/helper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package utils

import (
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

func UpdateLabels(wanted map[string]string, objectMeta *metav1.ObjectMeta) {
for key, value := range wanted {
if _, exist := objectMeta.Labels[key]; !exist {
objectMeta.Labels[key] = value

Check warning on line 11 in flyteplugins/go/tasks/plugins/k8s/batchscheduler/utils/helper.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/batchscheduler/utils/helper.go#L8-L11

Added lines #L8 - L11 were not covered by tests
}
}
}

func UpdateAnnotations(wanted map[string]string, objectMeta *metav1.ObjectMeta) {
for key, value := range wanted {
if _, exist := objectMeta.Annotations[key]; !exist {
objectMeta.Annotations[key] = value

Check warning on line 19 in flyteplugins/go/tasks/plugins/k8s/batchscheduler/utils/helper.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/batchscheduler/utils/helper.go#L16-L19

Added lines #L16 - L19 were not covered by tests
}
}
}

func UpdatePodTemplateAnnotatations(wanted map[string]string, pod *v1.PodTemplateSpec) {
UpdateAnnotations(wanted, &pod.ObjectMeta)

Check warning on line 25 in flyteplugins/go/tasks/plugins/k8s/batchscheduler/utils/helper.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/batchscheduler/utils/helper.go#L24-L25

Added lines #L24 - L25 were not covered by tests
}

func UpdatePodTemplateLabels(wanted map[string]string, pod *v1.PodTemplateSpec) {
UpdateLabels(wanted, &pod.ObjectMeta)

Check warning on line 29 in flyteplugins/go/tasks/plugins/k8s/batchscheduler/utils/helper.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/batchscheduler/utils/helper.go#L28-L29

Added lines #L28 - L29 were not covered by tests
}
133 changes: 133 additions & 0 deletions flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package yunikorn

import (
"encoding/json"

rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

"github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/batchscheduler/utils"
)

// Scheduler name and annotation keys used when preparing pods for Yunikorn.
const (
// Yunikorn is the value written into PodSpec.SchedulerName.
Yunikorn = "yunikorn"
// AppID is the annotation key carrying the generated application ID.
AppID = "yunikorn.apache.org/app-id"
// Queue is the yunikorn.apache.org queue key.
// NOTE(review): not referenced in this file — presumably set by callers.
Queue = "yunikorn.apache.org/queue"
// TaskGroupNameKey is the annotation key naming the task group a pod belongs to.
TaskGroupNameKey = "yunikorn.apache.org/task-group-name"
// TaskGroupsKey is the annotation key holding the JSON-serialized task groups.
TaskGroupsKey = "yunikorn.apache.org/task-groups"
// TaskGroupParameters is the annotation key holding the gang-scheduling policy parameters.
TaskGroupParameters = "yunikorn.apache.org/schedulingPolicyParameters"
)

func MutateRayJob(app *rayv1.RayJob) error {
appID := GenerateTaskGroupAppID()
rayjobSpec := &app.Spec
appSpec := rayjobSpec.RayClusterSpec
TaskGroups := make([]TaskGroup, 1)
for index := range appSpec.WorkerGroupSpecs {
worker := &appSpec.WorkerGroupSpecs[index]
worker.Template.Spec.SchedulerName = Yunikorn
meta := worker.Template.ObjectMeta
spec := worker.Template.Spec
name := GenerateTaskGroupName(false, index)
TaskGroups = append(TaskGroups, TaskGroup{
Name: name,
MinMember: *worker.Replicas,
Labels: meta.Labels,
Annotations: meta.Annotations,
MinResource: Allocation(spec.Containers),
NodeSelector: spec.NodeSelector,
Affinity: spec.Affinity,
TopologySpreadConstraints: spec.TopologySpreadConstraints,
})
meta.Annotations[TaskGroupNameKey] = name
meta.Annotations[AppID] = appID

Check warning on line 45 in flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go#L23-L45

Added lines #L23 - L45 were not covered by tests
}
headSpec := &appSpec.HeadGroupSpec
headSpec.Template.Spec.SchedulerName = Yunikorn
meta := headSpec.Template.ObjectMeta
spec := headSpec.Template.Spec
headName := GenerateTaskGroupName(true, 0)
res := Allocation(spec.Containers)
if ok := *appSpec.EnableInTreeAutoscaling; ok {
res2 := v1.ResourceList{
v1.ResourceCPU: resource.MustParse("500m"),
v1.ResourceMemory: resource.MustParse("512Mi"),

Check warning on line 56 in flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go#L47-L56

Added lines #L47 - L56 were not covered by tests
}
res = Add(res, res2)

Check warning on line 58 in flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go#L58

Added line #L58 was not covered by tests
}
TaskGroups[0] = TaskGroup{
Name: headName,
MinMember: 1,
Labels: meta.Labels,
Annotations: meta.Annotations,
MinResource: res,
NodeSelector: spec.NodeSelector,
Affinity: spec.Affinity,
TopologySpreadConstraints: spec.TopologySpreadConstraints,

Check warning on line 68 in flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go#L60-L68

Added lines #L60 - L68 were not covered by tests
}
meta.Annotations[TaskGroupNameKey] = headName
info, err := json.Marshal(TaskGroups)
if err != nil {
return err

Check warning on line 73 in flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go#L70-L73

Added lines #L70 - L73 were not covered by tests
}
meta.Annotations[TaskGroupsKey] = string(info[:])
meta.Annotations[AppID] = appID
return nil

Check warning on line 77 in flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go#L75-L77

Added lines #L75 - L77 were not covered by tests
}

func UpdateGangSchedulingParameters(parameters string, objectMeta *metav1.ObjectMeta) {
if len(parameters) == 0 {
return

Check warning on line 82 in flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go#L80-L82

Added lines #L80 - L82 were not covered by tests
}
utils.UpdateAnnotations(
map[string]string{TaskGroupParameters: parameters},
objectMeta,
)

Check warning on line 87 in flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go#L84-L87

Added lines #L84 - L87 were not covered by tests
}

func UpdateAnnotations(labels map[string]string, app *rayv1.RayJob) {
appSpec := app.Spec.RayClusterSpec
headSpec := appSpec.HeadGroupSpec
utils.UpdatePodTemplateAnnotatations(labels, &headSpec.Template)
for index := range appSpec.WorkerGroupSpecs {
worker := appSpec.WorkerGroupSpecs[index]
utils.UpdatePodTemplateAnnotatations(labels, &worker.Template)

Check warning on line 96 in flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go#L90-L96

Added lines #L90 - L96 were not covered by tests
}
}

func Allocation(containers []v1.Container) v1.ResourceList {
totalResources := v1.ResourceList{}
for _, c := range containers {
for name, q := range c.Resources.Limits {
if _, exists := totalResources[name]; !exists {
totalResources[name] = q.DeepCopy()
continue

Check warning on line 106 in flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go#L100-L106

Added lines #L100 - L106 were not covered by tests
}
total := totalResources[name]
total.Add(q)
totalResources[name] = total

Check warning on line 110 in flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go#L108-L110

Added lines #L108 - L110 were not covered by tests
}
}
return totalResources

Check warning on line 113 in flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go#L113

Added line #L113 was not covered by tests
}

func Add(left v1.ResourceList, right v1.ResourceList) v1.ResourceList {
result := left
for name, value := range left {
sum := value
if value2, ok := right[name]; ok {
sum.Add(value2)
result[name] = sum
} else {
result[name] = value

Check warning on line 124 in flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go#L116-L124

Added lines #L116 - L124 were not covered by tests
}
}
for name, value := range right {
if _, ok := left[name]; !ok {
result[name] = value

Check warning on line 129 in flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go#L127-L129

Added lines #L127 - L129 were not covered by tests
}
}
return result

Check warning on line 132 in flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/helper.go#L132

Added line #L132 was not covered by tests
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package yunikorn

import (
"encoding/json"

v1 "k8s.io/api/core/v1"
)

// TaskGroup describes one gang-scheduling task group. The struct has no JSON
// tags, so encoding/json serializes it using the Go field names as keys.
// NOTE(review): the field set presumably mirrors the task-group schema that
// Yunikorn reads from the task-groups annotation — confirm against its docs.
type TaskGroup struct {
// Name identifies the task group.
Name string
// MinMember is the minimum number of members in the group.
MinMember int32
// Labels and Annotations are metadata maps associated with the group.
Labels map[string]string
Annotations map[string]string
// MinResource is the resource amount associated with the group.
MinResource v1.ResourceList
// Scheduling constraints associated with the group.
NodeSelector map[string]string
Tolerations []v1.Toleration
Affinity *v1.Affinity
TopologySpreadConstraints []v1.TopologySpreadConstraint
}

// Marshal serializes the task groups to JSON and returns the encoded bytes
// together with any encoding error.
func Marshal(taskGroups []TaskGroup) ([]byte, error) {
	return json.Marshal(taskGroups)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package yunikorn

import (
"testing"

"github.com/stretchr/testify/assert"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
)

// TestMarshal checks that Marshal succeeds for nil, empty, single-element and
// multi-element task group slices.
func TestMarshal(t *testing.T) {
	minResource := v1.ResourceList{
		v1.ResourceCPU:    resource.MustParse("500m"),
		v1.ResourceMemory: resource.MustParse("512Mi"),
	}
	// newGroup builds a fully populated group; the optional scheduling
	// constraints are left at their nil zero values.
	newGroup := func(name string) TaskGroup {
		return TaskGroup{
			Name:         name,
			MinMember:    int32(1),
			Labels:       map[string]string{"attr": "value"},
			Annotations:  map[string]string{"attr": "value"},
			MinResource:  minResource,
			NodeSelector: map[string]string{"node": "gpunode"},
		}
	}
	first, second := newGroup("tg1"), newGroup("tg2")

	cases := [][]TaskGroup{
		nil,
		{},
		{first},
		{first, second},
	}
	t.Run("Serialize task groups", func(t *testing.T) {
		for _, groups := range cases {
			_, err := Marshal(groups)
			assert.Nil(t, err)
		}
	})
}
23 changes: 23 additions & 0 deletions flyteplugins/go/tasks/plugins/k8s/batchscheduler/yunikorn/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package yunikorn

import (
"fmt"

"github.com/google/uuid"
)

const (
	// TaskGroupGenericName is the common prefix of generated task-group
	// names and application IDs.
	TaskGroupGenericName = "task-group"
)

// GenerateTaskGroupName returns "<prefix>-head" for the head group (index is
// ignored) and "<prefix>-worker-<index>" for a worker group.
func GenerateTaskGroupName(master bool, index int) string {
	if !master {
		return fmt.Sprintf("%s-worker-%d", TaskGroupGenericName, index)
	}
	return TaskGroupGenericName + "-head"
}

// GenerateTaskGroupAppID returns a new application ID of the form
// "<TaskGroupGenericName>-<uuid>", using a freshly generated UUID on every
// call.
func GenerateTaskGroupAppID() string {
	id := uuid.New()
	return TaskGroupGenericName + "-" + id.String()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package yunikorn

import (
"fmt"
"testing"

"github.com/stretchr/testify/assert"
)

// TestGenerateTaskGroupName checks the generated names for head and worker
// groups, including that the index is ignored for the head group.
func TestGenerateTaskGroupName(t *testing.T) {
	headName := fmt.Sprintf("%s-%s", TaskGroupGenericName, "head")
	workerName := func(index int) string {
		return fmt.Sprintf("%s-%s-%d", TaskGroupGenericName, "worker", index)
	}

	cases := []struct {
		master bool
		index  int
		want   string
	}{
		{master: true, index: 0, want: headName},
		{master: true, index: 1, want: headName},
		{master: false, index: 0, want: workerName(0)},
		{master: false, index: 1, want: workerName(1)},
	}
	t.Run("Generate ray task group name", func(t *testing.T) {
		for _, tc := range cases {
			assert.Equal(t, tc.want, GenerateTaskGroupName(tc.master, tc.index))
		}
	})
}

// TestGenerateTaskGroupAppID checks that a generated application ID is not
// empty.
func TestGenerateTaskGroupAppID(t *testing.T) {
	t.Run("Generate ray app ID", func(t *testing.T) {
		if id := GenerateTaskGroupAppID(); id == "" {
			t.Error("Ray app ID is empty")
		}
	})
}
Loading
Loading