Skip to content

Commit daac294

Browse files
committed
[POC] Prototype multi-host indexing
Signed-off-by: Aaron Liang <[email protected]>
1 parent cb86f9f commit daac294

File tree

7 files changed

+235
-31
lines changed

7 files changed

+235
-31
lines changed

ray-operator/controllers/ray/common/pod.go

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717

1818
rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
1919
"github.com/ray-project/kuberay/ray-operator/controllers/ray/utils"
20+
"github.com/ray-project/kuberay/ray-operator/pkg/features"
2021
)
2122

2223
const (
@@ -244,7 +245,7 @@ func getEnableProbesInjection() bool {
244245
}
245246

246247
// DefaultWorkerPodTemplate sets the config values
247-
func DefaultWorkerPodTemplate(ctx context.Context, instance rayv1.RayCluster, workerSpec rayv1.WorkerGroupSpec, podName string, fqdnRayIP string, headPort string) corev1.PodTemplateSpec {
248+
func DefaultWorkerPodTemplate(ctx context.Context, instance rayv1.RayCluster, workerSpec rayv1.WorkerGroupSpec, podName string, fqdnRayIP string, headPort string, replicaGrpName string, numHostIndex int) corev1.PodTemplateSpec {
248249
podTemplate := workerSpec.Template
249250
podTemplate.GenerateName = podName
250251
// Pods created by RayCluster should be restricted to the namespace of the RayCluster.
@@ -315,6 +316,11 @@ func DefaultWorkerPodTemplate(ctx context.Context, instance rayv1.RayCluster, wo
315316
podTemplate.Labels = make(map[string]string)
316317
}
317318
podTemplate.Labels = labelPod(rayv1.WorkerNode, instance.Name, workerSpec.GroupName, workerSpec.Template.ObjectMeta.Labels)
319+
// Add additional labels for RayMultihostIndexing
320+
multihostIndexingEnabled := features.Enabled(features.RayMulithostIndexing) && workerSpec.NumOfHosts > 1
321+
if multihostIndexingEnabled {
322+
podTemplate.Labels = addMultihostIndexingPodLabels(podTemplate.Labels, replicaGrpName, numHostIndex)
323+
}
318324
workerSpec.RayStartParams = setMissingRayStartParams(ctx, workerSpec.RayStartParams, rayv1.WorkerNode, headPort, fqdnRayIP)
319325

320326
initTemplateAnnotations(instance, &podTemplate)
@@ -628,6 +634,15 @@ func labelPod(rayNodeType rayv1.RayNodeType, rayClusterName string, groupName st
628634
return labels
629635
}
630636

637+
// addMultihostIndexingPodLabels returns labels that contain RayMultihostIndexing feature labels
638+
func addMultihostIndexingPodLabels(currentLabels map[string]string, replicaGrpName string, numHostIndex int) map[string]string {
639+
labels := currentLabels
640+
labels[utils.RayWorkerReplicaIndexKey] = replicaGrpName
641+
labels[utils.RayHostIndexKey] = strconv.Itoa(numHostIndex)
642+
643+
return labels
644+
}
645+
631646
func setInitContainerEnvVars(container *corev1.Container, fqdnRayIP string) {
632647
if len(container.Env) == 0 {
633648
container.Env = []corev1.EnvVar{}

ray-operator/controllers/ray/common/pod_test.go

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020

2121
rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
2222
"github.com/ray-project/kuberay/ray-operator/controllers/ray/utils"
23+
"github.com/ray-project/kuberay/ray-operator/pkg/features"
2324
)
2425

2526
var testMemoryLimit = resource.MustParse("1Gi")
@@ -681,7 +682,7 @@ func TestBuildPod(t *testing.T) {
681682
worker := cluster.Spec.WorkerGroupSpecs[0]
682683
podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
683684
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
684-
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379")
685+
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
685686
pod = BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.GetCRDType(""), fqdnRayIP)
686687

687688
// Check resources
@@ -752,7 +753,7 @@ func TestBuildPod_WithNoCPULimits(t *testing.T) {
752753
worker := cluster.Spec.WorkerGroupSpecs[0]
753754
podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
754755
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
755-
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379")
756+
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
756757
pod = BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.GetCRDType(""), fqdnRayIP)
757758
expectedCommandArg = splitAndSort("ulimit -n 65536; ray start --block --dashboard-agent-listen-port=52365 --memory=1073741824 --num-cpus=2 --num-gpus=3 --address=raycluster-sample-head-svc.default.svc.cluster.local:6379 --port=6379 --metrics-export-port=8080")
758759
actualCommandArg = splitAndSort(pod.Spec.Containers[0].Args[0])
@@ -783,7 +784,7 @@ func TestBuildPod_WithOverwriteCommand(t *testing.T) {
783784
worker := cluster.Spec.WorkerGroupSpecs[0]
784785
podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
785786
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
786-
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379")
787+
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
787788
workerPod := BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.GetCRDType(""), fqdnRayIP)
788789
workerContainer := workerPod.Spec.Containers[utils.RayContainerIndex]
789790
assert.Equal(t, []string{"I am worker"}, workerContainer.Command)
@@ -838,7 +839,7 @@ func TestBuildPod_WithCreatedByRayService(t *testing.T) {
838839
worker := cluster.Spec.WorkerGroupSpecs[0]
839840
podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
840841
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
841-
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379")
842+
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
842843
pod = BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.RayServiceCRD, fqdnRayIP)
843844

844845
val, ok = pod.Labels[utils.RayClusterServingServiceLabelKey]
@@ -894,7 +895,7 @@ func TestBuildPod_WithLoginBash(t *testing.T) {
894895
worker := cluster.Spec.WorkerGroupSpecs[0]
895896
podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
896897
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
897-
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379")
898+
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
898899
workerPod := BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.RayServiceCRD, fqdnRayIP)
899900

900901
// Verify worker container command
@@ -1157,11 +1158,33 @@ func TestDefaultWorkerPodTemplateWithName(t *testing.T) {
11571158
expectedWorker := *worker.DeepCopy()
11581159

11591160
// Pass a deep copy of worker (*worker.DeepCopy()) to prevent "worker" from updating.
1160-
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379")
1161+
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379", "", 0)
11611162
assert.Empty(t, podTemplateSpec.ObjectMeta.Name)
11621163
assert.Equal(t, expectedWorker, worker)
11631164
}
11641165

1166+
func TestDeafultWorkerPodTemplateWithReplicaGrpAndIndex(t *testing.T) {
1167+
ctx := context.Background()
1168+
1169+
cluster := instance.DeepCopy()
1170+
1171+
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
1172+
worker := cluster.Spec.WorkerGroupSpecs[0]
1173+
1174+
features.SetFeatureGateDuringTest(t, features.RayMulithostIndexing, true)
1175+
1176+
worker.Template.ObjectMeta.Name = "ray-worker-test"
1177+
worker.NumOfHosts = 4
1178+
podName := cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
1179+
groupReplicaName := utils.GenerateRayWorkerReplicaGroupName(worker.GroupName)
1180+
1181+
// Pass a deep copy of worker (*worker.DeepCopy()) to prevent "worker" from updating.
1182+
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379", groupReplicaName, 2)
1183+
assert.Empty(t, podTemplateSpec.ObjectMeta.Name)
1184+
assert.Equal(t, podTemplateSpec.Labels[utils.RayWorkerReplicaIndexKey], groupReplicaName)
1185+
assert.Equal(t, "2", podTemplateSpec.Labels[utils.RayHostIndexKey])
1186+
}
1187+
11651188
func containerPortExists(ports []corev1.ContainerPort, containerPort int32) error {
11661189
name := utils.MetricsPortName
11671190
for _, port := range ports {
@@ -1204,7 +1227,7 @@ func TestDefaultWorkerPodTemplateWithConfigurablePorts(t *testing.T) {
12041227
worker := cluster.Spec.WorkerGroupSpecs[0]
12051228
podName := cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
12061229
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
1207-
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379")
1230+
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
12081231
// DefaultWorkerPodTemplate will add the default metrics port if user doesn't specify it.
12091232
// Verify the default metrics port exists.
12101233
require.NoError(t, containerPortExists(podTemplateSpec.Spec.Containers[0].Ports, int32(utils.DefaultMetricsPort)))
@@ -1214,7 +1237,7 @@ func TestDefaultWorkerPodTemplateWithConfigurablePorts(t *testing.T) {
12141237
ContainerPort: customMetricsPort,
12151238
}
12161239
cluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[0].Ports = []corev1.ContainerPort{metricsPort}
1217-
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379")
1240+
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
12181241
// Verify the custom metrics port exists.
12191242
require.NoError(t, containerPortExists(podTemplateSpec.Spec.Containers[0].Ports, customMetricsPort))
12201243
}
@@ -1253,7 +1276,7 @@ func TestDefaultWorkerPodTemplate_Autoscaling(t *testing.T) {
12531276

12541277
for name, tc := range tests {
12551278
t.Run(name, func(t *testing.T) {
1256-
podTemplateSpec := DefaultWorkerPodTemplate(ctx, tc.cluster, tc.cluster.Spec.WorkerGroupSpecs[0], podName, fqdnRayIP, "6379")
1279+
podTemplateSpec := DefaultWorkerPodTemplate(ctx, tc.cluster, tc.cluster.Spec.WorkerGroupSpecs[0], podName, fqdnRayIP, "6379", "", 0)
12571280
assert.Equal(t, tc.expectedRestartPolicy, podTemplateSpec.Spec.RestartPolicy)
12581281
})
12591282
}
@@ -1269,7 +1292,7 @@ func TestDefaultInitContainer(t *testing.T) {
12691292
expectedResult := len(cluster.Spec.WorkerGroupSpecs[0].Template.Spec.InitContainers) + 1
12701293

12711294
// Pass a deep copy of worker (*worker.DeepCopy()) to prevent "worker" from updating.
1272-
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379")
1295+
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379", "", 0)
12731296
numInitContainers := len(podTemplateSpec.Spec.InitContainers)
12741297
assert.Equal(t, expectedResult, numInitContainers, "A default init container is expected to be added.")
12751298

@@ -1328,7 +1351,7 @@ func TestDefaultInitContainerImagePullPolicy(t *testing.T) {
13281351
// set ray container imagePullPolicy
13291352
worker.Template.Spec.Containers[utils.RayContainerIndex].ImagePullPolicy = tc.imagePullPolicy
13301353

1331-
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379")
1354+
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379", "", 0)
13321355

13331356
healthCheckContainer := podTemplateSpec.Spec.InitContainers[len(podTemplateSpec.Spec.InitContainers)-1]
13341357
assert.Equal(t, tc.expectedPullPolicy, healthCheckContainer.ImagePullPolicy, "The ImagePullPolicy of the init container should be the same as the Ray container.")

0 commit comments

Comments
 (0)