Skip to content

Commit

Permalink
Address comments and panic on error
Browse files Browse the repository at this point in the history
  • Loading branch information
mounchin committed Oct 7, 2024
1 parent 2f3bfc9 commit 8478742
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 25 deletions.
44 changes: 22 additions & 22 deletions ray-operator/controllers/ray/common/pod.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ import (
"strconv"
"strings"

"github.com/go-logr/logr"

"github.com/ray-project/kuberay/ray-operator/controllers/ray/utils"

"k8s.io/apimachinery/pkg/api/resource"
Expand Down Expand Up @@ -760,7 +758,10 @@ func generateRayStartCommand(ctx context.Context, nodeType rayv1.RayNodeType, ra
}

// Add GPU and custom accelerator resources to rayStartParams if not already present.
addWellKnownAcceleratorResources(log, rayStartParams, resource.Limits)
err := addWellKnownAcceleratorResources(rayStartParams, resource.Limits)
if err != nil {
panic(fmt.Errorf("failed to add accelerator resources to rayStartParams: %w", err))
}

rayStartCmd := ""
switch nodeType {
Expand All @@ -775,8 +776,11 @@ func generateRayStartCommand(ctx context.Context, nodeType rayv1.RayNodeType, ra
return rayStartCmd
}

func addWellKnownAcceleratorResources(log logr.Logger, rayStartParams map[string]string, resourceLimits corev1.ResourceList) {
resourcesMap, _ := getResourcesMap(rayStartParams)
func addWellKnownAcceleratorResources(rayStartParams map[string]string, resourceLimits corev1.ResourceList) error {
resourcesMap, err := getResourcesMap(rayStartParams)
if err != nil {
return fmt.Errorf("failed to get resources map from rayStartParams: %w", err)
}

// Flag to track if any custom accelerator resource are present/added in rayStartParams resources.
isCustomAcceleratorResourceAdded := isCustomAcceleratorPresentInResources(resourcesMap)
Expand All @@ -792,15 +796,25 @@ func addWellKnownAcceleratorResources(log logr.Logger, rayStartParams map[string
}

// Add the first encountered custom accelerator resource from the resource limits to the rayStartParams if not already present
if resourcesMap != nil && !isCustomAcceleratorResourceAdded {
if !isCustomAcceleratorResourceAdded {
if rayResourceName, ok := customAcceleratorToRayResourceMap[resourceKeyString]; ok && !resourceValue.IsZero() {
if err := addCustomAcceleratorToResourcesIfNotExists(rayStartParams, resourcesMap, rayResourceName, resourceValue.Value()); err != nil {
log.Error(err, fmt.Sprintf("failed to add %s to resources", rayResourceName))
if _, exists := resourcesMap[rayResourceName]; !exists {
resourcesMap[rayResourceName] = float64(resourceValue.Value())

// Update the resources map in the rayStartParams
updatedResourcesStr, err := json.Marshal(resourcesMap)
if err != nil {
return fmt.Errorf("failed to marshal resources map to string: %w", err)
}

rayStartParams["resources"] = string(updatedResourcesStr)
}
isCustomAcceleratorResourceAdded = true
}
}
}

return nil
}

func isCustomAcceleratorPresentInResources(resourcesMap map[string]float64) bool {
Expand All @@ -816,20 +830,6 @@ func isCustomAcceleratorPresentInResources(resourcesMap map[string]float64) bool
return false
}

func addCustomAcceleratorToResourcesIfNotExists(rayStartParams map[string]string, resourcesMap map[string]float64, resourceName string, resourceCount int64) error {
if _, exists := resourcesMap[resourceName]; !exists {
resourcesMap[resourceName] = float64(resourceCount)
}

updatedResourcesStr, err := json.Marshal(resourcesMap)
if err != nil {
return fmt.Errorf("failed to marshal resources map to string %w", err)
}

rayStartParams["resources"] = string(updatedResourcesStr)
return nil
}

func getResourcesMap(rayStartParams map[string]string) (map[string]float64, error) {
var resources map[string]float64
if resourcesStr, ok := rayStartParams["resources"]; !ok {
Expand Down
25 changes: 22 additions & 3 deletions ray-operator/controllers/ray/common/pod_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1161,6 +1161,7 @@ func TestGenerateRayStartCommand(t *testing.T) {
rayStartParams map[string]string
name string
expected string
err string
nodeType rayv1.RayNodeType
resource corev1.ResourceRequirements
}{
Expand All @@ -1186,6 +1187,18 @@ func TestGenerateRayStartCommand(t *testing.T) {
},
expected: `ray start --head --resources={"neuron_cores":4} `,
},
{
name: "HeadNode with multiple accelerators",
nodeType: rayv1.HeadNode,
rayStartParams: map[string]string{},
resource: corev1.ResourceRequirements{
Limits: corev1.ResourceList{
"aws.amazon.com/neuroncore": resource.MustParse("4"),
"nvidia.com/gpu": resource.MustParse("1"),
},
},
expected: `ray start --head --resources={"neuron_cores":4} --num-gpus=1 `,
},
{
name: "HeadNode with existing resources",
nodeType: rayv1.HeadNode,
Expand Down Expand Up @@ -1223,7 +1236,7 @@ func TestGenerateRayStartCommand(t *testing.T) {
"aws.amazon.com/neuroncore": resource.MustParse("4"),
},
},
expected: "ray start --head --resources={ ",
err: "failed to add accelerator resources to rayStartParams: failed to get resources map from rayStartParams: failed to unmarshal resources unexpected end of JSON input",
},
{
name: "Invalid node type",
Expand All @@ -1236,8 +1249,14 @@ func TestGenerateRayStartCommand(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := generateRayStartCommand(context.TODO(), tt.nodeType, tt.rayStartParams, tt.resource)
assert.Equal(t, tt.expected, result)
if tt.err != "" {
assert.PanicsWithError(t, tt.err, func() {
generateRayStartCommand(context.TODO(), tt.nodeType, tt.rayStartParams, tt.resource)
})
} else {
result := generateRayStartCommand(context.TODO(), tt.nodeType, tt.rayStartParams, tt.resource)
assert.Equal(t, tt.expected, result)
}
})
}
}

0 comments on commit 8478742

Please sign in to comment.