From 7c5f67f1fe79d06b3625c46bf3bb84077f711416 Mon Sep 17 00:00:00 2001 From: Pin-Lun Hsu Date: Wed, 20 Mar 2024 10:25:19 -0700 Subject: [PATCH] fix tests Signed-off-by: Pin-Lun Hsu --- .../go/tasks/plugins/k8s/ray/ray_test.go | 68 +++++++++++++++++-- 1 file changed, 63 insertions(+), 5 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go index 6ebcdb05e2..de213eff68 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go @@ -105,7 +105,7 @@ func dummyRayTaskTemplate(id string, rayJob *plugins.RayJob) *core.TaskTemplate } } -func dummyRayTaskContext(taskTemplate *core.TaskTemplate, resources *corev1.ResourceRequirements, extendedResources *core.ExtendedResources) pluginsCore.TaskExecutionContext { +func dummyRayTaskContext(taskTemplate *core.TaskTemplate, resources *corev1.ResourceRequirements, extendedResources *core.ExtendedResources, containerImage string) pluginsCore.TaskExecutionContext { taskCtx := &mocks.TaskExecutionContext{} inputReader := &pluginIOMocks.InputReader{} inputReader.OnGetInputPrefixPath().Return("/input/prefix") @@ -140,6 +140,7 @@ func dummyRayTaskContext(taskTemplate *core.TaskTemplate, resources *corev1.Reso overrides := &mocks.TaskOverrides{} overrides.OnGetResources().Return(resources) overrides.OnGetExtendedResources().Return(extendedResources) + overrides.OnGetContainerImage().Return(containerImage) taskExecutionMetadata := &mocks.TaskExecutionMetadata{} taskExecutionMetadata.OnGetTaskExecutionID().Return(tID) @@ -174,7 +175,7 @@ func TestBuildResourceRay(t *testing.T) { err := config.SetK8sPluginConfig(&config.K8sPluginConfig{DefaultTolerations: toleration}) assert.Nil(t, err) - RayResource, err := rayJobResourceHandler.BuildResource(context.TODO(), dummyRayTaskContext(taskTemplate, resourceRequirements, nil)) + RayResource, err := rayJobResourceHandler.BuildResource(context.TODO(), dummyRayTaskContext(taskTemplate, resourceRequirements, nil, "")) assert.Nil(t, err) assert.NotNil(t, RayResource) @@ -207,6 +208,63 @@ func TestBuildResourceRay(t *testing.T) { assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Tolerations, toleration) } +func TestBuildResourceRayContainerImage(t *testing.T) { + assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{})) + + fixtures := []struct { + name string + resources *corev1.ResourceRequirements + containerImageOverride string + }{ + { + "without overrides", + &corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + }, + "", + }, + { + "with overrides", + &corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + }, + "container-image-override", + }, + } + + for _, f := range fixtures { + t.Run(f.name, func(t *testing.T) { + taskTemplate := dummyRayTaskTemplate("id", dummyRayCustomObj()) + taskContext := dummyRayTaskContext(taskTemplate, f.resources, nil, f.containerImageOverride) + rayJobResourceHandler := rayJobResourceHandler{} + r, err := rayJobResourceHandler.BuildResource(context.TODO(), taskContext) + assert.Nil(t, err) + assert.NotNil(t, r) + rayJob, ok := r.(*rayv1.RayJob) + assert.True(t, ok) + + var expectedContainerImage string + if len(f.containerImageOverride) > 0 { + expectedContainerImage = f.containerImageOverride + } else { + expectedContainerImage = testImage + } + + // Head node + headNodeSpec := rayJob.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec + assert.Equal(t, expectedContainerImage, headNodeSpec.Containers[0].Image) + + // Worker node + workerNodeSpec := rayJob.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec + assert.Equal(t, expectedContainerImage, workerNodeSpec.Containers[0].Image) + }) + } +} + func TestBuildResourceRayExtendedResources(t *testing.T) { assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{ GpuDeviceNodeLabel: "gpu-node-label", @@ -312,7 +370,7 @@ func TestBuildResourceRayExtendedResources(t *testing.T) { t.Run(p.name, func(t *testing.T) { taskTemplate := dummyRayTaskTemplate("ray-id", dummyRayCustomObj()) taskTemplate.ExtendedResources = p.extendedResourcesBase - taskContext := dummyRayTaskContext(taskTemplate, p.resources, p.extendedResourcesOverride) + taskContext := dummyRayTaskContext(taskTemplate, p.resources, p.extendedResourcesOverride, "") rayJobResourceHandler := rayJobResourceHandler{} r, err := rayJobResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) @@ -371,7 +429,7 @@ func TestDefaultStartParameters(t *testing.T) { err := config.SetK8sPluginConfig(&config.K8sPluginConfig{DefaultTolerations: toleration}) assert.Nil(t, err) - RayResource, err := rayJobResourceHandler.BuildResource(context.TODO(), dummyRayTaskContext(taskTemplate, resourceRequirements, nil)) + RayResource, err := rayJobResourceHandler.BuildResource(context.TODO(), dummyRayTaskContext(taskTemplate, resourceRequirements, nil, "")) assert.Nil(t, err) assert.NotNil(t, RayResource) @@ -577,7 +635,7 @@ func TestInjectLogsSidecar(t *testing.T) { assert.NoError(t, SetConfig(&Config{ LogsSidecar: p.logsSidecarCfg, })) - taskContext := dummyRayTaskContext(&p.taskTemplate, resourceRequirements, nil) + taskContext := dummyRayTaskContext(&p.taskTemplate, resourceRequirements, nil, "") rayJobResourceHandler := rayJobResourceHandler{} r, err := rayJobResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err)