diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go index 111c84f801..0d71abe970 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go @@ -81,9 +81,11 @@ func transformPodSpecToTaskTemplateTarget(podSpec *corev1.PodSpec) *core.TaskTem func dummyRayCustomObj() *plugins.RayJob { return &plugins.RayJob{ RayCluster: &plugins.RayCluster{ - HeadGroupSpec: &plugins.HeadGroupSpec{RayStartParams: map[string]string{"num-cpus": "1"}}, - WorkerGroupSpec: []*plugins.WorkerGroupSpec{{GroupName: workerGroupName, Replicas: 3, MinReplicas: 3, MaxReplicas: 3}}, - EnableAutoscaling: false, + HeadGroupSpec: &plugins.HeadGroupSpec{RayStartParams: map[string]string{"num-cpus": "1"}}, + WorkerGroupSpec: []*plugins.WorkerGroupSpec{{GroupName: workerGroupName, Replicas: 3, MinReplicas: 3, MaxReplicas: 3}}, + EnableAutoscaling: true, + ShutdownAfterJobFinishes: true, + TtlSecondsAfterFinished: 120, }, } } @@ -179,7 +181,9 @@ func TestBuildResourceRay(t *testing.T) { ray, ok := RayResource.(*rayv1alpha1.RayJob) assert.True(t, ok) - assert.Equal(t, *ray.Spec.RayClusterSpec.EnableInTreeAutoscaling, false) + assert.Equal(t, *ray.Spec.RayClusterSpec.EnableInTreeAutoscaling, true) + assert.Equal(t, *ray.Spec.RayClusterSpec.ShutdownAfterJobFinishes, true) + assert.Equal(t, *ray.Spec.RayClusterSpec.TTLSecondsAfterFinished, int32(120)) headReplica := int32(1) assert.Equal(t, *ray.Spec.RayClusterSpec.HeadGroupSpec.Replicas, headReplica) @@ -351,9 +355,11 @@ func TestDefaultStartParameters(t *testing.T) { rayJobResourceHandler := rayJobResourceHandler{} rayJob := &plugins.RayJob{ RayCluster: &plugins.RayCluster{ - HeadGroupSpec: &plugins.HeadGroupSpec{}, - WorkerGroupSpec: []*plugins.WorkerGroupSpec{{GroupName: workerGroupName, Replicas: 3, MinReplicas: 3, MaxReplicas: 3}}, - EnableAutoscaling: false, + HeadGroupSpec: &plugins.HeadGroupSpec{}, + WorkerGroupSpec: []*plugins.WorkerGroupSpec{{GroupName: workerGroupName, Replicas: 3, MinReplicas: 3, MaxReplicas: 3}}, + EnableAutoscaling: true, + ShutdownAfterJobFinishes: true, + TtlSecondsAfterFinished: 120, }, } @@ -374,7 +380,9 @@ func TestDefaultStartParameters(t *testing.T) { ray, ok := RayResource.(*rayv1alpha1.RayJob) assert.True(t, ok) - assert.Equal(t, *ray.Spec.RayClusterSpec.EnableInTreeAutoscaling, false) + assert.Equal(t, *ray.Spec.RayClusterSpec.EnableInTreeAutoscaling, true) + assert.Equal(t, *ray.Spec.RayClusterSpec.ShutdownAfterJobFinishes, true) + assert.Equal(t, *ray.Spec.RayClusterSpec.TTLSecondsAfterFinished, int32(120)) headReplica := int32(1) assert.Equal(t, *ray.Spec.RayClusterSpec.HeadGroupSpec.Replicas, headReplica)