Skip to content

Commit d05d922

Browse files
committed
Fix: validation in right place
1. Remove validation logic in GetWorkerGroupDesiredReplicas (utils.go), and append to ValidateRayClusterSpec. 2. Add other validation logic for worker group specs. 3. Remove unnecessary test cases in GetWorkerGroupDesiredReplicas.
1 parent 1d04c34 commit d05d922

File tree

4 files changed

+32
-36
lines changed

4 files changed

+32
-36
lines changed

ray-operator/controllers/ray/raycluster_controller.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -642,7 +642,7 @@ func (r *RayClusterReconciler) reconcilePods(ctx context.Context, instance *rayv
642642
continue
643643
}
644644
// workerReplicas will store the target number of pods for this worker group.
645-
numExpectedWorkerPods := int(utils.GetWorkerGroupDesiredReplicas(ctx, worker))
645+
numExpectedWorkerPods := int(utils.GetWorkerGroupDesiredReplicas(worker))
646646
logger.Info("reconcilePods", "desired workerReplicas (always adhering to minReplicas/maxReplica)", numExpectedWorkerPods, "worker group", worker.GroupName, "maxReplicas", worker.MaxReplicas, "minReplicas", worker.MinReplicas, "replicas", worker.Replicas)
647647

648648
workerPods := corev1.PodList{}

ray-operator/controllers/ray/utils/util.go

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -335,23 +335,16 @@ func GenerateIdentifier(clusterName string, nodeType rayv1.RayNodeType) string {
335335
return fmt.Sprintf("%s-%s", clusterName, nodeType)
336336
}
337337

338-
func GetWorkerGroupDesiredReplicas(ctx context.Context, workerGroupSpec rayv1.WorkerGroupSpec) int32 {
339-
log := ctrl.LoggerFrom(ctx)
338+
func GetWorkerGroupDesiredReplicas(workerGroupSpec rayv1.WorkerGroupSpec) int32 {
340339
// Always adhere to min/max replicas constraints.
341340
var workerReplicas int32
342341
if workerGroupSpec.Suspend != nil && *workerGroupSpec.Suspend {
343342
return 0
344343
}
345-
if *workerGroupSpec.MinReplicas > *workerGroupSpec.MaxReplicas {
346-
log.Info("minReplicas is greater than maxReplicas, using maxReplicas as desired replicas. "+
347-
"Please fix this to avoid any unexpected behaviors.", "minReplicas", *workerGroupSpec.MinReplicas, "maxReplicas", *workerGroupSpec.MaxReplicas)
348-
workerReplicas = *workerGroupSpec.MaxReplicas
349-
} else if workerGroupSpec.Replicas == nil || *workerGroupSpec.Replicas < *workerGroupSpec.MinReplicas {
350-
// Replicas is impossible to be nil as it has a default value assigned in the CRD.
351-
// Add this check to make testing easier.
344+
// Validation for replicas/min/max should be enforced in validation.go before reconcile proceeds.
345+
// Here we only compute the desired replicas within the already-validated bounds.
346+
if workerGroupSpec.Replicas == nil {
352347
workerReplicas = *workerGroupSpec.MinReplicas
353-
} else if *workerGroupSpec.Replicas > *workerGroupSpec.MaxReplicas {
354-
workerReplicas = *workerGroupSpec.MaxReplicas
355348
} else {
356349
workerReplicas = *workerGroupSpec.Replicas
357350
}
@@ -362,7 +355,7 @@ func GetWorkerGroupDesiredReplicas(ctx context.Context, workerGroupSpec rayv1.Wo
362355
func CalculateDesiredReplicas(ctx context.Context, cluster *rayv1.RayCluster) int32 {
363356
count := int32(0)
364357
for _, nodeGroup := range cluster.Spec.WorkerGroupSpecs {
365-
count += GetWorkerGroupDesiredReplicas(ctx, nodeGroup)
358+
count += GetWorkerGroupDesiredReplicas(nodeGroup)
366359
}
367360

368361
return count

ray-operator/controllers/ray/utils/util_test.go

Lines changed: 6 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,6 @@ func TestGenerateHeadServiceName(t *testing.T) {
550550
}
551551

552552
func TestGetWorkerGroupDesiredReplicas(t *testing.T) {
553-
ctx := context.Background()
554553
// Test 1: `WorkerGroupSpec.Replicas` is nil.
555554
// `Replicas` is impossible to be nil in a real RayCluster CR as it has a default value assigned in the CRD.
556555
numOfHosts := int32(1)
@@ -562,37 +561,21 @@ func TestGetWorkerGroupDesiredReplicas(t *testing.T) {
562561
MinReplicas: &minReplicas,
563562
MaxReplicas: &maxReplicas,
564563
}
565-
assert.Equal(t, GetWorkerGroupDesiredReplicas(ctx, workerGroupSpec), minReplicas)
564+
assert.Equal(t, GetWorkerGroupDesiredReplicas(workerGroupSpec), minReplicas)
566565

567566
// Test 2: `WorkerGroupSpec.Replicas` is not nil and is within the range.
568567
replicas := int32(3)
569568
workerGroupSpec.Replicas = &replicas
570-
assert.Equal(t, GetWorkerGroupDesiredReplicas(ctx, workerGroupSpec), replicas)
569+
assert.Equal(t, GetWorkerGroupDesiredReplicas(workerGroupSpec), replicas)
571570

572-
// Test 3: `WorkerGroupSpec.Replicas` is not nil but is more than maxReplicas.
573-
replicas = int32(6)
574-
workerGroupSpec.Replicas = &replicas
575-
assert.Equal(t, GetWorkerGroupDesiredReplicas(ctx, workerGroupSpec), maxReplicas)
576-
577-
// Test 4: `WorkerGroupSpec.Replicas` is not nil but is less than minReplicas.
578-
replicas = int32(0)
579-
workerGroupSpec.Replicas = &replicas
580-
assert.Equal(t, GetWorkerGroupDesiredReplicas(ctx, workerGroupSpec), minReplicas)
581-
582-
// Test 5: `WorkerGroupSpec.Replicas` is nil and minReplicas is less than maxReplicas.
583-
workerGroupSpec.Replicas = nil
584-
workerGroupSpec.MinReplicas = &maxReplicas
585-
workerGroupSpec.MaxReplicas = &minReplicas
586-
assert.Equal(t, GetWorkerGroupDesiredReplicas(ctx, workerGroupSpec), *workerGroupSpec.MaxReplicas)
587-
588-
// Test 6: `WorkerGroupSpec.Suspend` is true.
571+
// Test 3: `WorkerGroupSpec.Suspend` is true.
589572
suspend := true
590573
workerGroupSpec.MinReplicas = &maxReplicas
591574
workerGroupSpec.MaxReplicas = &minReplicas
592575
workerGroupSpec.Suspend = &suspend
593-
assert.Zero(t, GetWorkerGroupDesiredReplicas(ctx, workerGroupSpec))
576+
assert.Zero(t, GetWorkerGroupDesiredReplicas(workerGroupSpec))
594577

595-
// Test 7: `WorkerGroupSpec.NumOfHosts` is 4.
578+
// Test 4: `WorkerGroupSpec.NumOfHosts` is 4.
596579
numOfHosts = int32(4)
597580
replicas = int32(5)
598581
suspend = false
@@ -601,7 +584,7 @@ func TestGetWorkerGroupDesiredReplicas(t *testing.T) {
601584
workerGroupSpec.Suspend = &suspend
602585
workerGroupSpec.MinReplicas = &minReplicas
603586
workerGroupSpec.MaxReplicas = &maxReplicas
604-
assert.Equal(t, GetWorkerGroupDesiredReplicas(ctx, workerGroupSpec), replicas*numOfHosts)
587+
assert.Equal(t, GetWorkerGroupDesiredReplicas(workerGroupSpec), replicas*numOfHosts)
605588
}
606589

607590
func TestCalculateMinAndMaxReplicas(t *testing.T) {

ray-operator/controllers/ray/utils/validation.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,26 @@ func ValidateRayClusterSpec(spec *rayv1.RayClusterSpec, annotations map[string]s
4343
if len(workerGroup.Template.Spec.Containers) == 0 {
4444
return fmt.Errorf("workerGroupSpec should have at least one container")
4545
}
46+
if workerGroup.MinReplicas == nil || workerGroup.MaxReplicas == nil {
47+
return fmt.Errorf("worker group %s must set both minReplicas and maxReplicas", workerGroup.GroupName)
48+
}
49+
if *workerGroup.MinReplicas < 0 {
50+
return fmt.Errorf("worker group %s has negative minReplicas %d", workerGroup.GroupName, *workerGroup.MinReplicas)
51+
}
52+
if *workerGroup.MaxReplicas < 0 {
53+
return fmt.Errorf("worker group %s has negative maxReplicas %d", workerGroup.GroupName, *workerGroup.MaxReplicas)
54+
}
55+
if *workerGroup.MinReplicas > *workerGroup.MaxReplicas {
56+
return fmt.Errorf("worker group %s has minReplicas %d greater than maxReplicas %d", workerGroup.GroupName, *workerGroup.MinReplicas, *workerGroup.MaxReplicas)
57+
}
58+
if workerGroup.Replicas != nil {
59+
if *workerGroup.Replicas < *workerGroup.MinReplicas {
60+
return fmt.Errorf("worker group %s has replicas %d smaller than minReplicas %d", workerGroup.GroupName, *workerGroup.Replicas, *workerGroup.MinReplicas)
61+
}
62+
if *workerGroup.Replicas > *workerGroup.MaxReplicas {
63+
return fmt.Errorf("worker group %s has replicas %d greater than maxReplicas %d", workerGroup.GroupName, *workerGroup.Replicas, *workerGroup.MaxReplicas)
64+
}
65+
}
4666
}
4767

4868
if annotations[RayFTEnabledAnnotationKey] != "" && spec.GcsFaultToleranceOptions != nil {

0 commit comments

Comments
 (0)