diff --git a/ray-operator/controllers/ray/common/pod.go b/ray-operator/controllers/ray/common/pod.go index 4d440dd62c..c8c86736b8 100644 --- a/ray-operator/controllers/ray/common/pod.go +++ b/ray-operator/controllers/ray/common/pod.go @@ -9,8 +9,6 @@ import ( "strconv" "strings" - "github.com/go-logr/logr" - "github.com/ray-project/kuberay/ray-operator/controllers/ray/utils" "k8s.io/apimachinery/pkg/api/resource" @@ -760,7 +758,10 @@ func generateRayStartCommand(ctx context.Context, nodeType rayv1.RayNodeType, ra } // Add GPU and custom accelerator resources to rayStartParams if not already present. - addWellKnownAcceleratorResources(log, rayStartParams, resource.Limits) + err := addWellKnownAcceleratorResources(rayStartParams, resource.Limits) + if err != nil { + panic(fmt.Errorf("failed to add accelerator resources to rayStartParams: %w", err)) + } rayStartCmd := "" switch nodeType { @@ -775,8 +776,11 @@ func generateRayStartCommand(ctx context.Context, nodeType rayv1.RayNodeType, ra return rayStartCmd } -func addWellKnownAcceleratorResources(log logr.Logger, rayStartParams map[string]string, resourceLimits corev1.ResourceList) { - resourcesMap, _ := getResourcesMap(rayStartParams) +func addWellKnownAcceleratorResources(rayStartParams map[string]string, resourceLimits corev1.ResourceList) error { + resourcesMap, err := getResourcesMap(rayStartParams) + if err != nil { + return fmt.Errorf("failed to get resources map from rayStartParams: %w", err) + } // Flag to track if any custom accelerator resource are present/added in rayStartParams resources. isCustomAcceleratorResourceAdded := isCustomAcceleratorPresentInResources(resourcesMap) @@ -792,15 +796,25 @@ func addWellKnownAcceleratorResources(log logr.Logger, rayStartParams map[string } // Add the first encountered custom accelerator resource from the resource limits to the rayStartParams if not already present - if resourcesMap != nil && !isCustomAcceleratorResourceAdded { + if !isCustomAcceleratorResourceAdded { if rayResourceName, ok := customAcceleratorToRayResourceMap[resourceKeyString]; ok && !resourceValue.IsZero() { - if err := addCustomAcceleratorToResourcesIfNotExists(rayStartParams, resourcesMap, rayResourceName, resourceValue.Value()); err != nil { - log.Error(err, fmt.Sprintf("failed to add %s to resources", rayResourceName)) + if _, exists := resourcesMap[rayResourceName]; !exists { + resourcesMap[rayResourceName] = float64(resourceValue.Value()) + + // Update the resources map in the rayStartParams + updatedResourcesStr, err := json.Marshal(resourcesMap) + if err != nil { + return fmt.Errorf("failed to marshal resources map to string: %w", err) + } + + rayStartParams["resources"] = string(updatedResourcesStr) } isCustomAcceleratorResourceAdded = true } } } + + return nil } func isCustomAcceleratorPresentInResources(resourcesMap map[string]float64) bool { @@ -816,20 +830,6 @@ func isCustomAcceleratorPresentInResources(resourcesMap map[string]float64) bool return false } -func addCustomAcceleratorToResourcesIfNotExists(rayStartParams map[string]string, resourcesMap map[string]float64, resourceName string, resourceCount int64) error { - if _, exists := resourcesMap[resourceName]; !exists { - resourcesMap[resourceName] = float64(resourceCount) - } - - updatedResourcesStr, err := json.Marshal(resourcesMap) - if err != nil { - return fmt.Errorf("failed to marshal resources map to string %w", err) - } - - rayStartParams["resources"] = string(updatedResourcesStr) - return nil -} - func getResourcesMap(rayStartParams map[string]string) (map[string]float64, error) { var resources map[string]float64 if resourcesStr, ok := rayStartParams["resources"]; !ok { diff --git a/ray-operator/controllers/ray/common/pod_test.go b/ray-operator/controllers/ray/common/pod_test.go index a77c7ec872..ca00c33dbb 100644 --- a/ray-operator/controllers/ray/common/pod_test.go +++ b/ray-operator/controllers/ray/common/pod_test.go @@ -1161,6 +1161,7 @@ func TestGenerateRayStartCommand(t *testing.T) { rayStartParams map[string]string name string expected string + err string nodeType rayv1.RayNodeType resource corev1.ResourceRequirements }{ @@ -1186,6 +1187,18 @@ func TestGenerateRayStartCommand(t *testing.T) { }, expected: `ray start --head --resources={"neuron_cores":4} `, }, + { + name: "HeadNode with multiple accelerators", + nodeType: rayv1.HeadNode, + rayStartParams: map[string]string{}, + resource: corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + "aws.amazon.com/neuroncore": resource.MustParse("4"), + "nvidia.com/gpu": resource.MustParse("1"), + }, + }, + expected: `ray start --head --resources={"neuron_cores":4} --num-gpus=1 `, + }, { name: "HeadNode with existing resources", nodeType: rayv1.HeadNode, @@ -1223,7 +1236,7 @@ func TestGenerateRayStartCommand(t *testing.T) { "aws.amazon.com/neuroncore": resource.MustParse("4"), }, }, - expected: "ray start --head --resources={ ", + err: "failed to add accelerator resources to rayStartParams: failed to get resources map from rayStartParams: failed to unmarshal resources unexpected end of JSON input", }, { name: "Invalid node type", @@ -1236,8 +1249,14 @@ func TestGenerateRayStartCommand(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := generateRayStartCommand(context.TODO(), tt.nodeType, tt.rayStartParams, tt.resource) - assert.Equal(t, tt.expected, result) + if tt.err != "" { + assert.PanicsWithError(t, tt.err, func() { + generateRayStartCommand(context.TODO(), tt.nodeType, tt.rayStartParams, tt.resource) + }) + } else { + result := generateRayStartCommand(context.TODO(), tt.nodeType, tt.rayStartParams, tt.resource) + assert.Equal(t, tt.expected, result) + } }) } }