diff --git a/ray-operator/controllers/ray/common/pod.go b/ray-operator/controllers/ray/common/pod.go
index 7c7a3a031d..dba8056e7e 100644
--- a/ray-operator/controllers/ray/common/pod.go
+++ b/ray-operator/controllers/ray/common/pod.go
@@ -748,6 +748,12 @@ func generateRayStartCommand(ctx context.Context, nodeType rayv1.RayNodeType, ra
 		cpu := resource.Limits[corev1.ResourceCPU]
 		if !cpu.IsZero() {
 			rayStartParams["num-cpus"] = strconv.FormatInt(cpu.Value(), 10)
+		} else {
+			// Fall back to the CPU request if no CPU limit is specified.
+			cpu := resource.Requests[corev1.ResourceCPU]
+			if !cpu.IsZero() {
+				rayStartParams["num-cpus"] = strconv.FormatInt(cpu.Value(), 10)
+			}
 		}
 	}
 
diff --git a/ray-operator/controllers/ray/common/pod_test.go b/ray-operator/controllers/ray/common/pod_test.go
index 831bfd4e5a..5411dc7873 100644
--- a/ray-operator/controllers/ray/common/pod_test.go
+++ b/ray-operator/controllers/ray/common/pod_test.go
@@ -70,8 +70,7 @@ var instance = rayv1.RayCluster{
 				MaxReplicas: ptr.To[int32](10000),
 				GroupName:   "small-group",
 				RayStartParams: map[string]string{
-					"port":     "6379",
-					"num-cpus": "1",
+					"port": "6379",
 				},
 				Template: corev1.PodTemplateSpec{
 					ObjectMeta: metav1.ObjectMeta{
@@ -385,6 +384,54 @@ func TestBuildPod(t *testing.T) {
 	checkContainerEnv(t, rayContainer, "TEST_ENV_NAME", "TEST_ENV_VALUE")
 }
 
+func TestBuildPod_WithNoCPULimits(t *testing.T) {
+	cluster := instance.DeepCopy()
+	ctx := context.Background()
+
+	cluster.Spec.HeadGroupSpec.Template.Spec.Containers[utils.RayContainerIndex].Resources = corev1.ResourceRequirements{
+		Requests: corev1.ResourceList{
+			corev1.ResourceCPU:    resource.MustParse("2"),
+			corev1.ResourceMemory: testMemoryLimit,
+		},
+		Limits: corev1.ResourceList{
+			corev1.ResourceMemory: testMemoryLimit,
+		},
+	}
+	cluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[utils.RayContainerIndex].Resources = corev1.ResourceRequirements{
+		Requests: corev1.ResourceList{
+			corev1.ResourceCPU:    resource.MustParse("2"),
+			corev1.ResourceMemory: testMemoryLimit,
+		},
+
+		Limits: corev1.ResourceList{
+			corev1.ResourceMemory: testMemoryLimit,
+			"nvidia.com/gpu":      resource.MustParse("3"),
+		},
+	}
+
+	// Test the head pod: --num-cpus should fall back to the CPU request.
+	podName := strings.ToLower(cluster.Name + utils.DashSymbol + string(rayv1.HeadNode) + utils.DashSymbol + utils.FormatInt32(0))
+	podTemplateSpec := DefaultHeadPodTemplate(ctx, *cluster, cluster.Spec.HeadGroupSpec, podName, "6379")
+	pod := BuildPod(ctx, podTemplateSpec, rayv1.HeadNode, cluster.Spec.HeadGroupSpec.RayStartParams, "6379", nil, utils.GetCRDType(""), "")
+	expectedCommandArg := splitAndSort("ulimit -n 65536; ray start --head --block --dashboard-agent-listen-port=52365 --memory=1073741824 --num-cpus=2 --metrics-export-port=8080 --dashboard-host=0.0.0.0")
+	actualCommandArg := splitAndSort(pod.Spec.Containers[0].Args[0])
+	if !reflect.DeepEqual(expectedCommandArg, actualCommandArg) {
+		t.Fatalf("Expected `%v` but got `%v`", expectedCommandArg, actualCommandArg)
+	}
+
+	// Test the worker pod: --num-cpus should likewise fall back to the CPU request.
+	worker := cluster.Spec.WorkerGroupSpecs[0]
+	podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
+	fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
+	podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379")
+	pod = BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", nil, utils.GetCRDType(""), fqdnRayIP)
+	expectedCommandArg = splitAndSort("ulimit -n 65536; ray start --block --dashboard-agent-listen-port=52365 --memory=1073741824 --num-cpus=2 --num-gpus=3 --address=raycluster-sample-head-svc.default.svc.cluster.local:6379 --port=6379 --metrics-export-port=8080")
+	actualCommandArg = splitAndSort(pod.Spec.Containers[0].Args[0])
+	if !reflect.DeepEqual(expectedCommandArg, actualCommandArg) {
+		t.Fatalf("Expected `%v` but got `%v`", expectedCommandArg, actualCommandArg)
+	}
+}
+
 func TestBuildPod_WithOverwriteCommand(t *testing.T) {
 	ctx := context.Background()
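
For context: the pod.go hunk changes how --num-cpus is derived. Previously it was set only from the container's CPU limit; with this patch it falls back to the CPU request when no limit is set, and when neither is set the flag is still omitted so Ray autodetects the node's CPUs. A user-supplied num-cpus in RayStartParams still takes precedence, which is why the test fixture drops its hard-coded "num-cpus": "1". Below is a minimal standalone sketch of that fallback rule, assuming the k8s.io/api and k8s.io/apimachinery modules; the helper name numCPUs and the main harness are illustrative only, the real logic lives inline in generateRayStartCommand:

package main

import (
	"fmt"
	"strconv"

	corev1 "k8s.io/api/core/v1"
	"k8s.io/apimachinery/pkg/api/resource"
)

// numCPUs mirrors the fallback order introduced in pod.go: prefer the
// CPU limit, then the CPU request; report false when neither is set so
// the caller can omit --num-cpus and let Ray autodetect.
func numCPUs(res corev1.ResourceRequirements) (string, bool) {
	if cpu := res.Limits[corev1.ResourceCPU]; !cpu.IsZero() {
		return strconv.FormatInt(cpu.Value(), 10), true
	}
	if cpu := res.Requests[corev1.ResourceCPU]; !cpu.IsZero() {
		return strconv.FormatInt(cpu.Value(), 10), true
	}
	return "", false
}

func main() {
	// A container with only a CPU request, as in TestBuildPod_WithNoCPULimits.
	res := corev1.ResourceRequirements{
		Requests: corev1.ResourceList{
			corev1.ResourceCPU: resource.MustParse("2"),
		},
	}
	if v, ok := numCPUs(res); ok {
		fmt.Println("--num-cpus=" + v) // prints: --num-cpus=2
	}
}

Note that Quantity.Value() rounds up, so a fractional request such as 500m yields --num-cpus=1; this matches the behavior of the existing limit-based path.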