Skip to content

Commit

Permalink
unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: yuteng <[email protected]>
  • Loading branch information
0yukali0 committed Jul 26, 2024
1 parent dbc07cf commit a5c8fd5
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package batchscheduler

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestCreateSchedulerPlugin(t *testing.T) {
Expand All @@ -15,9 +17,8 @@ func TestCreateSchedulerPlugin(t *testing.T) {
}
for _, tt := range tests {
t.Run("New scheduler plugin", func(t *testing.T) {
if got := NewSchedulerPlugin(tt.input); got.GetSchedulerName() != tt.expect {
t.Errorf("got %s, expect %s", got, tt.expect)
}
p := NewSchedulerPlugin(tt.input)
assert.Equal(t, tt.expect, p.GetSchedulerName())
})
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ func NewYunikornPlugin() *YunikornGangSchedulingConfig {
func (s *YunikornGangSchedulingConfig) GetSchedulerName() string { return Yunikorn }

func (s *YunikornGangSchedulingConfig) ParseJob(config *BatchSchedulerConfig, metadata *metav1.ObjectMeta, workerGroupsSpec []*plugins.WorkerGroupSpec, pod *v1.PodSpec, primaryContainerIdx int) error {
s.Annotations = nil
s.Parameters = config.GetParameters()
return s.BuildGangInfo(metadata, workerGroupsSpec, pod, primaryContainerIdx)
}
Expand Down Expand Up @@ -120,7 +121,9 @@ func (s *YunikornGangSchedulingConfig) BuildGangInfo(
headAnnotations := make(map[string]string, 0)
headAnnotations[TaskGroupNameKey] = headName
headAnnotations[TaskGroupsKey] = string(info[:])
headAnnotations[TaskGroupPrarameters] = s.Parameters
if len(s.Parameters) > 0 {
headAnnotations[TaskGroupPrarameters] = s.Parameters
}
s.Annotations[headName] = headAnnotations
return nil
}
Expand Down
129 changes: 103 additions & 26 deletions flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package batchscheduler

import (
"testing"
"encoding/json"

"github.com/stretchr/testify/assert"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins"
v1 "k8s.io/api/core/v1"
Expand All @@ -25,35 +28,113 @@ var (
Affinity: nil,
TopologySpreadConstraints: nil,
}
rayWorkersSpec = []*plugins.WorkerGroupSpec{
{
GroupName: "group1",
Replicas: int32(1),
MinReplicas: int32(1),
MaxReplicas: int32(2),
RayStartParams: nil,
},
{
GroupName: "group2",
Replicas: int32(1),
MinReplicas: int32(1),
MaxReplicas: int32(2),
RayStartParams: nil,
},
}
)

func TestSetSchedulerName(t *testing.T) {
t.Run("Set Scheduler Name", func(t *testing.T){
t.Run("Set Scheduler Name", func(t *testing.T) {
p := NewYunikornPlugin()
p.SetSchedulerName(podSpec)
if got := podSpec.SchedulerName; got != p.GetSchedulerName() {
t.Errorf("got %s, expect %s", got, p.GetSchedulerName())
}
assert.Equal(t, p.GetSchedulerName(), podSpec.SchedulerName)
podSpec.SchedulerName = ""
})
}

func TestBuildGangInfo(t *testing.T) {
names := []string{GenerateTaskGroupName(true, 0)}
res := v1.ResourceList{
"cpu": resource.MustParse("500m"),
"memory": resource.MustParse("1Gi"),
}
for index := 0; index < 2; index++ {
names = append(names, GenerateTaskGroupName(false, index))
}
var tests = []struct {
workerGroupNum int
taskGroups []TaskGroup
}{
{
workerGroupNum: 2,
taskGroups: []TaskGroup{
{
Name: names[0],
MinMember: int32(1),
Labels: nil,
Annotations: map[string]string{"others": "extra"},
MinResource: res,
NodeSelector: nil,
Tolerations: nil,
Affinity: nil,
TopologySpreadConstraints: nil,
},
{
Name: names[1],
MinMember: int32(1),
Labels: nil,
Annotations: map[string]string{"others": "extra"},
MinResource: res,
NodeSelector: nil,
Tolerations: nil,
Affinity: nil,
TopologySpreadConstraints: nil,
},
{
Name: names[2],
MinMember: int32(2),
Labels: nil,
Annotations: map[string]string{"others": "extra"},
MinResource: res,
NodeSelector: nil,
Tolerations: nil,
Affinity: nil,
TopologySpreadConstraints: nil,
},
},
},
}
for _, tt := range tests {
t.Run("Create Yunikorn gang scheduling annotations", func(t *testing.T) {
workersSpec := make([]*plugins.WorkerGroupSpec, 0)
for index := 0; index < tt.workerGroupNum; index++ {
count := 1 * (1 + index)
max := 2 * (1 + index)
workersSpec = append(workersSpec, &plugins.WorkerGroupSpec{
Replicas: int32(count),
MinReplicas: int32(count),
MaxReplicas: int32(max),
})
}
metadata := &metav1.ObjectMeta{
Annotations: map[string]string{"others": "extra"},
}
p := NewYunikornPlugin()
err := p.BuildGangInfo(metadata, workersSpec, podSpec, 0)
assert.Nil(t, err)
// test worker name
for index := 0; index < tt.workerGroupNum; index++ {
workerIndex := index + 1
name := names[workerIndex]
if annotations, ok := p.Annotations[name]; ok {
assert.Equal(t, 1, len(annotations))
assert.Equal(t, name, annotations[TaskGroupNameKey])
} else {
t.Errorf("Worker group %d annotatiosn miss", index)
}
}
// Test head name and groups
headName := names[0]
if annotations, ok := p.Annotations[headName]; ok {
info, err := json.Marshal(tt.taskGroups)
assert.Nil(t, err)
assert.Equal(t, 2, len(annotations))
assert.Equal(t, headName, annotations[TaskGroupNameKey])
assert.Equal(t, string(info[:]), annotations[TaskGroupsKey])
} else {
t.Error("Head annotations miss")
}
})
}
}

func TestGenerateTaskGroupName(t *testing.T) {
var tests = []struct {
master bool
Expand All @@ -66,9 +147,7 @@ func TestGenerateTaskGroupName(t *testing.T) {
}
for _, tt := range tests {
t.Run("Generating Task group name", func(t *testing.T) {
if got := GenerateTaskGroupName(tt.master, tt.index); got != tt.expect {
t.Errorf("got %s, expect %s", got, tt.expect)
}
assert.Equal(t, tt.expect, GenerateTaskGroupName(tt.master, tt.index))
})
}
}
Expand Down Expand Up @@ -124,9 +203,7 @@ func TestRemoveGangSchedulingAnnotations(t *testing.T) {
for _, tt := range tests {
t.Run("Remove Gang scheduling labels", func(t *testing.T) {
RemoveGangSchedulingAnnotations(tt.input)
if got := len(tt.input.Annotations); got != tt.expect {
t.Errorf("got %d, expect %d", got, tt.expect)
}
assert.Equal(t, tt.expect, len(tt.input.Annotations))
})
}
}
2 changes: 1 addition & 1 deletion flyteplugins/go/tasks/plugins/k8s/ray/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ type Config struct {
// or 0.0.0.0 (available from all interfaces). By default, this is localhost.
DashboardHost string `json:"dashboardHost,omitempty"`

BatchScheduler batchscheduler.BatchSchedulerConfig `json:"batchSchedulerConfig,omitempty"`
BatchScheduler batchscheduler.BatchSchedulerConfig `json:"BatchScheduler,omitempty"`

// DeprecatedNodeIPAddress the IP address of the head node. By default, this is pod ip address.
DeprecatedNodeIPAddress string `json:"nodeIPAddress,omitempty" pflag:"-,DEPRECATED. Please use DefaultConfig.[HeadNode|WorkerNode].IPAddress"`
Expand Down

0 comments on commit a5c8fd5

Please sign in to comment.