From a5c8fd549ae4b5802d802c2b3cecf7ed55590b88 Mon Sep 17 00:00:00 2001 From: yuteng Date: Thu, 25 Jul 2024 15:35:46 +0800 Subject: [PATCH] unit tests Signed-off-by: yuteng --- .../k8s/ray/batchscheduler/plugins_test.go | 7 +- .../k8s/ray/batchscheduler/yunikorn.go | 5 +- .../k8s/ray/batchscheduler/yunikorn_test.go | 129 ++++++++++++++---- .../go/tasks/plugins/k8s/ray/config.go | 2 +- 4 files changed, 112 insertions(+), 31 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins_test.go index 2d0bc3e9573..28142e5334b 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins_test.go @@ -2,6 +2,8 @@ package batchscheduler import ( "testing" + + "github.com/stretchr/testify/assert" ) func TestCreateSchedulerPlugin(t *testing.T) { @@ -15,9 +17,8 @@ func TestCreateSchedulerPlugin(t *testing.T) { } for _, tt := range tests { t.Run("New scheduler plugin", func(t *testing.T) { - if got := NewSchedulerPlugin(tt.input); got.GetSchedulerName() != tt.expect { - t.Errorf("got %s, expect %s", got, tt.expect) - } + p := NewSchedulerPlugin(tt.input) + assert.Equal(t, tt.expect, p.GetSchedulerName()) }) } } \ No newline at end of file diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go index 8d70a46ae08..35717429fe1 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go @@ -32,6 +32,7 @@ func NewYunikornPlugin() *YunikornGangSchedulingConfig { func (s *YunikornGangSchedulingConfig) GetSchedulerName() string { return Yunikorn } func (s *YunikornGangSchedulingConfig) ParseJob(config *BatchSchedulerConfig, metadata *metav1.ObjectMeta, workerGroupsSpec []*plugins.WorkerGroupSpec, pod *v1.PodSpec, primaryContainerIdx int) error { + s.Annotations = nil s.Parameters = config.GetParameters() return s.BuildGangInfo(metadata, workerGroupsSpec, pod, primaryContainerIdx) } @@ -120,7 +121,9 @@ func (s *YunikornGangSchedulingConfig) BuildGangInfo( headAnnotations := make(map[string]string, 0) headAnnotations[TaskGroupNameKey] = headName headAnnotations[TaskGroupsKey] = string(info[:]) - headAnnotations[TaskGroupPrarameters] = s.Parameters + if len(s.Parameters) > 0 { + headAnnotations[TaskGroupPrarameters] = s.Parameters + } s.Annotations[headName] = headAnnotations return nil } diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn_test.go index 79e7cd76a6b..59e730329c0 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn_test.go @@ -2,6 +2,9 @@ package batchscheduler import ( "testing" + "encoding/json" + + "github.com/stretchr/testify/assert" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" v1 "k8s.io/api/core/v1" @@ -25,35 +28,113 @@ var ( Affinity: nil, TopologySpreadConstraints: nil, } - rayWorkersSpec = []*plugins.WorkerGroupSpec{ - { - GroupName: "group1", - Replicas: int32(1), - MinReplicas: int32(1), - MaxReplicas: int32(2), - RayStartParams: nil, - }, - { - GroupName: "group2", - Replicas: int32(1), - MinReplicas: int32(1), - MaxReplicas: int32(2), - RayStartParams: nil, - }, - } ) func TestSetSchedulerName(t *testing.T) { - t.Run("Set Scheduler Name", func(t *testing.T){ + t.Run("Set Scheduler Name", func(t *testing.T) { p := NewYunikornPlugin() p.SetSchedulerName(podSpec) - if got := podSpec.SchedulerName; got != p.GetSchedulerName() { - t.Errorf("got %s, expect %s", got, p.GetSchedulerName()) - } + assert.Equal(t, p.GetSchedulerName(), podSpec.SchedulerName) podSpec.SchedulerName = "" }) } +func TestBuildGangInfo(t *testing.T) { + names := []string{GenerateTaskGroupName(true, 0)} + res := v1.ResourceList{ + "cpu": resource.MustParse("500m"), + "memory": resource.MustParse("1Gi"), + } + for index := 0; index < 2; index++ { + names = append(names, GenerateTaskGroupName(false, index)) + } + var tests = []struct { + workerGroupNum int + taskGroups []TaskGroup + }{ + { + workerGroupNum: 2, + taskGroups: []TaskGroup{ + { + Name: names[0], + MinMember: int32(1), + Labels: nil, + Annotations: map[string]string{"others": "extra"}, + MinResource: res, + NodeSelector: nil, + Tolerations: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + }, + { + Name: names[1], + MinMember: int32(1), + Labels: nil, + Annotations: map[string]string{"others": "extra"}, + MinResource: res, + NodeSelector: nil, + Tolerations: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + }, + { + Name: names[2], + MinMember: int32(2), + Labels: nil, + Annotations: map[string]string{"others": "extra"}, + MinResource: res, + NodeSelector: nil, + Tolerations: nil, + Affinity: nil, + TopologySpreadConstraints: nil, + }, + }, + }, + } + for _, tt := range tests { + t.Run("Create Yunikorn gang scheduling annotations", func(t *testing.T) { + workersSpec := make([]*plugins.WorkerGroupSpec, 0) + for index := 0; index < tt.workerGroupNum; index++ { + count := 1 * (1 + index) + max := 2 * (1 + index) + workersSpec = append(workersSpec, &plugins.WorkerGroupSpec{ + Replicas: int32(count), + MinReplicas: int32(count), + MaxReplicas: int32(max), + }) + } + metadata := &metav1.ObjectMeta{ + Annotations: map[string]string{"others": "extra"}, + } + p := NewYunikornPlugin() + err := p.BuildGangInfo(metadata, workersSpec, podSpec, 0) + assert.Nil(t, err) + // test worker name + for index := 0; index < tt.workerGroupNum; index++ { + workerIndex := index + 1 + name := names[workerIndex] + if annotations, ok := p.Annotations[name]; ok { + assert.Equal(t, 1, len(annotations)) + assert.Equal(t, name, annotations[TaskGroupNameKey]) + } else { + t.Errorf("Worker group %d annotatiosn miss", index) + } + } + // Test head name and groups + headName := names[0] + if annotations, ok := p.Annotations[headName]; ok { + info, err := json.Marshal(tt.taskGroups) + assert.Nil(t, err) + assert.Equal(t, 2, len(annotations)) + assert.Equal(t, headName, annotations[TaskGroupNameKey]) + assert.Equal(t, string(info[:]), annotations[TaskGroupsKey]) + } else { + t.Error("Head annotations miss") + } + }) + } +} + func TestGenerateTaskGroupName(t *testing.T) { var tests = []struct { master bool @@ -66,9 +147,7 @@ func TestGenerateTaskGroupName(t *testing.T) { } for _, tt := range tests { t.Run("Generating Task group name", func(t *testing.T) { - if got := GenerateTaskGroupName(tt.master, tt.index); got != tt.expect { - t.Errorf("got %s, expect %s", got, tt.expect) - } + assert.Equal(t, tt.expect, GenerateTaskGroupName(tt.master, tt.index)) }) } } @@ -124,9 +203,7 @@ func TestRemoveGangSchedulingAnnotations(t *testing.T) { for _, tt := range tests { t.Run("Remove Gang scheduling labels", func(t *testing.T) { RemoveGangSchedulingAnnotations(tt.input) - if got := len(tt.input.Annotations); got != tt.expect { - t.Errorf("got %d, expect %d", got, tt.expect) - } + assert.Equal(t, tt.expect, len(tt.input.Annotations)) }) } } diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/config.go b/flyteplugins/go/tasks/plugins/k8s/ray/config.go index 286ba59aff7..3d96e66ec7a 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/config.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/config.go @@ -78,7 +78,7 @@ type Config struct { // or 0.0.0.0 (available from all interfaces). By default, this is localhost. DashboardHost string `json:"dashboardHost,omitempty"` - BatchScheduler batchscheduler.BatchSchedulerConfig `json:"batchSchedulerConfig,omitempty"` + BatchScheduler batchscheduler.BatchSchedulerConfig `json:"BatchScheduler,omitempty"` // DeprecatedNodeIPAddress the IP address of the head node. By default, this is pod ip address. DeprecatedNodeIPAddress string `json:"nodeIPAddress,omitempty" pflag:"-,DEPRECATED. Please use DefaultConfig.[HeadNode|WorkerNode].IPAddress"`