From 631cd7c49ced8306cde07079eb54a473a0b0939e Mon Sep 17 00:00:00 2001 From: Antonin Stefanutti Date: Fri, 20 Oct 2023 19:09:50 +0200 Subject: [PATCH] [RayJob]: Always use target RayCluster image as default RayJob submitter image (#1548) --- ray-operator/controllers/ray/common/job.go | 20 +++------- .../controllers/ray/common/job_test.go | 21 ++++++++++ .../controllers/ray/rayjob_controller.go | 6 +-- .../ray/rayjob_controller_unit_test.go | 40 ++++++++++++++++--- 4 files changed, 63 insertions(+), 24 deletions(-) diff --git a/ray-operator/controllers/ray/common/job.go b/ray-operator/controllers/ray/common/job.go index c4dad4e288..14a25a7852 100644 --- a/ray-operator/controllers/ray/common/job.go +++ b/ray-operator/controllers/ray/common/job.go @@ -136,25 +136,15 @@ func GetK8sJobCommand(rayJobInstance *rayv1.RayJob) ([]string, error) { return k8sJobCommand, nil } -// getDefaultSubmitterTemplate creates a default submitter template for the Ray job. -func GetDefaultSubmitterTemplate(rayJobInstance *rayv1.RayJob) v1.PodTemplateSpec { - // Use the image of the Ray head to be defensive against version mismatch issues - var image string - if rayJobInstance.Spec.RayClusterSpec != nil && - len(rayJobInstance.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.Containers) > 0 { - image = rayJobInstance.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.Containers[0].Image - } - - if len(image) == 0 { - // If we can't find the image of the Ray head, fall back to the latest stable release. - image = "rayproject/ray:latest" - } +// GetDefaultSubmitterTemplate creates a default submitter template for the Ray job. +func GetDefaultSubmitterTemplate(rayClusterInstance *rayv1.RayCluster) v1.PodTemplateSpec { return v1.PodTemplateSpec{ Spec: v1.PodSpec{ Containers: []v1.Container{ { - Name: "ray-job-submitter", - Image: image, + Name: "ray-job-submitter", + // Use the image of the Ray head to be defensive against version mismatch issues + Image: rayClusterInstance.Spec.HeadGroupSpec.Template.Spec.Containers[RayContainerIndex].Image, Resources: v1.ResourceRequirements{ Limits: v1.ResourceList{ v1.ResourceCPU: resource.MustParse("1"), diff --git a/ray-operator/controllers/ray/common/job_test.go b/ray-operator/controllers/ray/common/job_test.go index 2b468c2159..ed33b90947 100644 --- a/ray-operator/controllers/ray/common/job_test.go +++ b/ray-operator/controllers/ray/common/job_test.go @@ -6,6 +6,7 @@ import ( rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" "github.com/stretchr/testify/assert" + corev1 "k8s.io/api/core/v1" ) var testRayJob = &rayv1.RayJob{ @@ -178,3 +179,23 @@ func TestMetadataRaisesErrorBeforeRay26(t *testing.T) { _, err := GetMetadataJson(rayJob.Spec.Metadata, rayJob.Spec.RayClusterSpec.RayVersion) assert.Error(t, err) } + +func TestGetDefaultSubmitterTemplate(t *testing.T) { + rayCluster := &rayv1.RayCluster{ + Spec: rayv1.RayClusterSpec{ + HeadGroupSpec: rayv1.HeadGroupSpec{ + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Image: "rayproject/ray:test-submitter-template", + }, + }, + }, + }, + }, + }, + } + template := GetDefaultSubmitterTemplate(rayCluster) + assert.Equal(t, template.Spec.Containers[0].Image, rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers[RayContainerIndex].Image) +} diff --git a/ray-operator/controllers/ray/rayjob_controller.go b/ray-operator/controllers/ray/rayjob_controller.go index f6c948e726..ee88b2ba92 100644 --- a/ray-operator/controllers/ray/rayjob_controller.go +++ b/ray-operator/controllers/ray/rayjob_controller.go @@ -336,7 +336,7 @@ func (r *RayJobReconciler) getOrCreateK8sJob(ctx context.Context, rayJobInstance job := &batchv1.Job{} if err := r.Client.Get(ctx, client.ObjectKey{Namespace: jobNamespace, Name: jobName}, job); err != nil { if errors.IsNotFound(err) { - submitterTemplate, err := r.getSubmitterTemplate(rayJobInstance) + submitterTemplate, err := r.getSubmitterTemplate(rayJobInstance, rayClusterInstance) if err != nil { r.Log.Error(err, "failed to get submitter template") return "", false, err @@ -354,12 +354,12 @@ func (r *RayJobReconciler) getOrCreateK8sJob(ctx context.Context, rayJobInstance } // getSubmitterTemplate builds the submitter pod template for the Ray job. -func (r *RayJobReconciler) getSubmitterTemplate(rayJobInstance *rayv1.RayJob) (v1.PodTemplateSpec, error) { +func (r *RayJobReconciler) getSubmitterTemplate(rayJobInstance *rayv1.RayJob, rayClusterInstance *rayv1.RayCluster) (v1.PodTemplateSpec, error) { var submitterTemplate v1.PodTemplateSpec // Set the default value for the optional field SubmitterPodTemplate if not provided. if rayJobInstance.Spec.SubmitterPodTemplate == nil { - submitterTemplate = common.GetDefaultSubmitterTemplate(rayJobInstance) + submitterTemplate = common.GetDefaultSubmitterTemplate(rayClusterInstance) r.Log.Info("default submitter template is used") } else { submitterTemplate = *rayJobInstance.Spec.SubmitterPodTemplate.DeepCopy() diff --git a/ray-operator/controllers/ray/rayjob_controller_unit_test.go b/ray-operator/controllers/ray/rayjob_controller_unit_test.go index 57b0f00833..b5f907b491 100644 --- a/ray-operator/controllers/ray/rayjob_controller_unit_test.go +++ b/ray-operator/controllers/ray/rayjob_controller_unit_test.go @@ -27,6 +27,19 @@ func TestGetOrCreateK8sJob(t *testing.T) { Name: "test-raycluster", Namespace: "default", }, + Spec: rayv1.RayClusterSpec{ + HeadGroupSpec: rayv1.HeadGroupSpec{ + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Image: "rayproject/ray", + }, + }, + }, + }, + }, + }, } rayJob := &rayv1.RayJob{ @@ -114,30 +127,45 @@ func TestGetSubmitterTemplate(t *testing.T) { DashboardURL: "test-url", }, } + rayClusterInstance := &rayv1.RayCluster{ + Spec: rayv1.RayClusterSpec{ + HeadGroupSpec: rayv1.HeadGroupSpec{ + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Image: "rayproject/ray:custom-version", + }, + }, + }, + }, + }, + }, + } r := &RayJobReconciler{ Log: ctrl.Log.WithName("controllers").WithName("RayJob"), } // Test 1: User provided template with command - submitterTemplate, err := r.getSubmitterTemplate(rayJobInstanceWithTemplate) + submitterTemplate, err := r.getSubmitterTemplate(rayJobInstanceWithTemplate, nil) assert.NoError(t, err) assert.Equal(t, "user-command", submitterTemplate.Spec.Containers[common.RayContainerIndex].Command[0]) // Test 2: User provided template without command rayJobInstanceWithTemplate.Spec.SubmitterPodTemplate.Spec.Containers[common.RayContainerIndex].Command = []string{} - submitterTemplate, err = r.getSubmitterTemplate(rayJobInstanceWithTemplate) + submitterTemplate, err = r.getSubmitterTemplate(rayJobInstanceWithTemplate, nil) assert.NoError(t, err) - assert.Equal(t, ([]string{"ray", "job", "submit", "--address", "http://test-url", "--", "echo", "hello", "world"}), submitterTemplate.Spec.Containers[common.RayContainerIndex].Command) + assert.Equal(t, []string{"ray", "job", "submit", "--address", "http://test-url", "--", "echo", "hello", "world"}, submitterTemplate.Spec.Containers[common.RayContainerIndex].Command) // Test 3: User did not provide template, should use the image of the Ray Head - submitterTemplate, err = r.getSubmitterTemplate(rayJobInstanceWithoutTemplate) + submitterTemplate, err = r.getSubmitterTemplate(rayJobInstanceWithoutTemplate, rayClusterInstance) assert.NoError(t, err) - assert.Equal(t, ([]string{"ray", "job", "submit", "--address", "http://test-url", "--", "echo", "hello", "world"}), submitterTemplate.Spec.Containers[common.RayContainerIndex].Command) + assert.Equal(t, []string{"ray", "job", "submit", "--address", "http://test-url", "--", "echo", "hello", "world"}, submitterTemplate.Spec.Containers[common.RayContainerIndex].Command) assert.Equal(t, "rayproject/ray:custom-version", submitterTemplate.Spec.Containers[common.RayContainerIndex].Image) // Test 4: Check default PYTHONUNBUFFERED setting - submitterTemplate, err = r.getSubmitterTemplate(rayJobInstanceWithoutTemplate) + submitterTemplate, err = r.getSubmitterTemplate(rayJobInstanceWithoutTemplate, rayClusterInstance) assert.NoError(t, err) found := false for _, envVar := range submitterTemplate.Spec.Containers[common.RayContainerIndex].Env {