From 77c3be27ebb2841950646fe669c5cac54b6150dc Mon Sep 17 00:00:00 2001 From: ryanaoleary <113500783+ryanaoleary@users.noreply.github.com> Date: Fri, 5 Apr 2024 13:38:19 -0700 Subject: [PATCH] Support for Multiple Separate TPU Worker Groups per RayCluster (#467) * Support for multiple separate TPU worker groups per RayCluster * Add namespace to slice struct, logs, and comments * Added unit tests for getReplicaIndex and getNextWorkerID * Added two more test cases for edge cases * Fixed comments --- applications/ray/kuberay-tpu-webhook/go.mod | 2 + applications/ray/kuberay-tpu-webhook/go.sum | 2 + applications/ray/kuberay-tpu-webhook/main.go | 121 +++-- .../kuberay-tpu-webhook/webhook_main_test.go | 444 ++++++++++++++++++ 4 files changed, 520 insertions(+), 49 deletions(-) create mode 100644 applications/ray/kuberay-tpu-webhook/webhook_main_test.go diff --git a/applications/ray/kuberay-tpu-webhook/go.mod b/applications/ray/kuberay-tpu-webhook/go.mod index a5e8cfab8..d7e720fcc 100644 --- a/applications/ray/kuberay-tpu-webhook/go.mod +++ b/applications/ray/kuberay-tpu-webhook/go.mod @@ -4,6 +4,7 @@ go 1.21 require ( github.com/ray-project/kuberay/ray-operator v1.1.0-rc.0 + github.com/stretchr/testify v1.8.4 k8s.io/api v0.29.1 k8s.io/apimachinery v0.29.1 k8s.io/klog/v2 v2.120.1 @@ -37,6 +38,7 @@ require ( github.com/modern-go/reflect2 v1.0.2 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_golang v1.18.0 // indirect github.com/prometheus/client_model v0.5.0 // indirect github.com/prometheus/common v0.45.0 // indirect diff --git a/applications/ray/kuberay-tpu-webhook/go.sum b/applications/ray/kuberay-tpu-webhook/go.sum index e4cf08f34..fa134e573 100644 --- a/applications/ray/kuberay-tpu-webhook/go.sum +++ b/applications/ray/kuberay-tpu-webhook/go.sum @@ -53,6 +53,8 @@ github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+ github.com/ianlancetaylor/demangle v0.0.0-20210905161508-09a460cdf81d/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w= github.com/imdario/mergo v0.3.12 h1:b6R2BslTbIEToALKP7LxUvijTsNI9TAe80pLWN2g/HU= github.com/imdario/mergo v0.3.12/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA= +github.com/jarcoal/httpmock v1.2.0 h1:gSvTxxFR/MEMfsGrvRbdfpRUMBStovlSRLw0Ep1bwwc= +github.com/jarcoal/httpmock v1.2.0/go.mod h1:oCoTsnAz4+UoOUIf5lJOWV2QQIW5UoeUI6aM2YnWAZk= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= diff --git a/applications/ray/kuberay-tpu-webhook/main.go b/applications/ray/kuberay-tpu-webhook/main.go index 7e2aaeb64..61da8534f 100755 --- a/applications/ray/kuberay-tpu-webhook/main.go +++ b/applications/ray/kuberay-tpu-webhook/main.go @@ -24,6 +24,7 @@ import ( type slice struct { clusterName string groupName string + namespace string replicaIndex int numOfHosts int32 } @@ -77,7 +78,7 @@ func containerRequestingTPUs(containers ...corev1.Container) bool { return false } -func getNumTPUHostsFromTopology(clusterName string, namespace string, topology string, acceleratorType string) (int32, error) { +func getNumTPUHostsFromTopology(clusterName string, groupName string, namespace string, topology string, acceleratorType string) (int32, error) { if topology == "" { return 0,
errors.New("TPU topology not specified") } @@ -86,7 +87,7 @@ func getNumTPUHostsFromTopology(clusterName string, namespace string, topology s for i := 0; i < len(topologyVals); i++ { dim, err := strconv.Atoi(topologyVals[i]) if err != nil { - klog.ErrorS(err, "RayCluster", namespace+"/"+clusterName, "gke-tpu-topology", topology) + klog.ErrorS(err, "RayCluster", namespace+"/"+clusterName, "Worker Group", groupName, "gke-tpu-topology", topology) return 0, err } chips *= dim @@ -98,19 +99,19 @@ func getNumTPUHostsFromTopology(clusterName string, namespace string, topology s // v5e TPU VMs can have 1, 4 or 8 chips chipsPerHost, err := strconv.Atoi(acceleratorTypeValues[1]) if err != nil { - klog.ErrorS(err, "RayCluster", namespace+"/"+clusterName, "gke-tpu-accelerator", acceleratorType) + klog.ErrorS(err, "RayCluster", namespace+"/"+clusterName, "Worker Group", groupName, "gke-tpu-accelerator", acceleratorType) return 0, err } chipsPerHost = min(chipsPerHost, 8) // max of 8 chips per host } hosts := int32(max(chips/chipsPerHost, 1)) - klog.V(1).InfoS("getNumTPUHostsFromTopology", "RayCluster", namespace+"/"+clusterName, "hosts", hosts) + klog.V(1).InfoS("getNumTPUHostsFromTopology", "RayCluster", namespace+"/"+clusterName, "Worker Group", groupName, "hosts", hosts) return hosts, nil } // check if request is for TPU multi-host -func isTPUMultiHost(clusterName string, namespace string, topology string, acceleratorType string) (bool, error) { - vms, err := getNumTPUHostsFromTopology(clusterName, namespace, topology, acceleratorType) +func isTPUMultiHost(clusterName string, groupName string, namespace string, topology string, acceleratorType string) (bool, error) { + vms, err := getNumTPUHostsFromTopology(clusterName, groupName, namespace, topology, acceleratorType) if err != nil { return false, err } @@ -133,7 +134,7 @@ func extractRayCluster(admissionReview *admissionv1.AdmissionReview) (*ray.RayCl return &rayCluster, nil } -func genDNSHostnames(workerGroupSpec ray.WorkerGroupSpec, replicaIndex int) (string, error) { +func genDNSHostnames(workerGroupSpec ray.WorkerGroupSpec, clusterName string, namespace string, replicaIndex int) (string, error) { numHosts := workerGroupSpec.NumOfHosts if numHosts == 0 { return "", errors.New("workerGroupSpec NumOfHosts not set") @@ -144,6 +145,7 @@ func genDNSHostnames(workerGroupSpec ray.WorkerGroupSpec, replicaIndex int) (str for j := 0; j < int(numHosts); j++ { hostNames[j] = fmt.Sprintf("%s-%d-%d.%s", workerGroupName, replicaIndex, j, headlessServiceName) } + klog.V(1).InfoS("genDNSHostnames", "RayCluster", namespace+"/"+clusterName, "NumOfHosts", numHosts, "Replica Index", replicaIndex) return strings.Join(hostNames, ","), nil } @@ -218,6 +220,7 @@ func checkWorkersMatchTopology(clusterName string, namespace string, workerGroup if numHosts == 0 { return false, errors.New("workerGroupSpec NumOfHosts not set") } + groupName := workerGroupSpec.GroupName containers := workerGroupSpec.Template.Spec.Containers if containers == nil { return false, errors.New("Container path not specified") @@ -227,12 +230,12 @@ func checkWorkersMatchTopology(clusterName string, namespace string, workerGroup acceleratorType := workerGroupSpec.Template.Spec.NodeSelector["cloud.google.com/gke-tpu-accelerator"] klog.V(1).InfoS("checkWorkersMatchTopology", "RayCluster", namespace+"/"+clusterName, "topology", topology, "AcceleratorType", acceleratorType, "NumOfHosts", numHosts) if topology == "" { - klog.ErrorS(errors.New("TPU topology not specified"), "RayCluster", 
namespace+"/"+clusterName, "gke-tpu-topology", topology) + klog.ErrorS(errors.New("TPU topology not specified"), "checkWorkersMatchTopology", "RayCluster", namespace+"/"+clusterName, "gke-tpu-topology", topology) } if acceleratorType == "" { - klog.ErrorS(errors.New("TPU accelerator not specified"), "RayCluster", namespace+"/"+clusterName, "gke-tpu-accelerator", acceleratorType) + klog.ErrorS(errors.New("TPU accelerator not specified"), "checkWorkersMatchTopology", "RayCluster", namespace+"/"+clusterName, "gke-tpu-accelerator", acceleratorType) } - expectedHosts, err := getNumTPUHostsFromTopology(clusterName, namespace, topology, acceleratorType) + expectedHosts, err := getNumTPUHostsFromTopology(clusterName, groupName, namespace, topology, acceleratorType) if err != nil { return false, err } @@ -263,23 +266,29 @@ func validateRayCluster(admissionReview *admissionv1.AdmissionReview) (*admissio } for i := 0; i < len(workerGroupSpecs); i++ { workerGroupSpec := workerGroupSpecs[i] - // create mapping for pod slices -> TPU_WORKER_HOSTNAMES in cluster - replicas := int(*workerGroupSpec.Replicas) - numOfHosts := workerGroupSpec.NumOfHosts - for replicaIndex := 0; replicaIndex < replicas; replicaIndex++ { - // reset past sliceToWorkers and sliceToHostnames entries for slice in ray cluster - groupName := workerGroupSpec.GroupName - podSlice := slice{clusterName, groupName, replicaIndex, numOfHosts} - sliceToWorkers[podSlice] = nil - sliceToHostnames[podSlice] = "" - // generate TPU_WORKER_HOSTNAMES - if numOfHosts > 1 { - joinedHostNames, err := genDNSHostnames(workerGroupSpec, replicaIndex) - if err != nil { - klog.Error("Failed to generate DNS Hostnames") + if containerRequestingTPUs(workerGroupSpec.Template.Spec.Containers...) { + klog.V(0).InfoS("validateRayCluster", "RayCluster", namespace+"/"+clusterName, "Worker Group", workerGroupSpec.GroupName, "Requests TPUs", true) + // create mapping for pod slices -> TPU_WORKER_HOSTNAMES in cluster + replicas := int(*workerGroupSpec.Replicas) + numOfHosts := workerGroupSpec.NumOfHosts + for replicaIndex := 0; replicaIndex < replicas; replicaIndex++ { + // reset past sliceToWorkers and sliceToHostnames entries for slice in ray cluster + groupName := workerGroupSpec.GroupName + podSlice := slice{clusterName, groupName, namespace, replicaIndex, numOfHosts} + sliceToWorkers[podSlice] = nil + sliceToHostnames[podSlice] = "" + // generate TPU_WORKER_HOSTNAMES + if numOfHosts > 1 { + joinedHostNames, err := genDNSHostnames(workerGroupSpec, clusterName, namespace, replicaIndex) + if err != nil { + klog.Error("Failed to generate DNS Hostnames") + } + sliceToHostnames[podSlice] = joinedHostNames } - sliceToHostnames[podSlice] = joinedHostNames } + } else { + // RayCluster worker group does not request TPUs + klog.V(0).InfoS("validateRayCluster", "RayCluster", namespace+"/"+clusterName, "Worker Group", workerGroupSpec.GroupName, "Requests TPUs", false) } // validate NumOfHosts for worker group matches topology nodeSelector workersMatchTopology, err := checkWorkersMatchTopology(clusterName, namespace, workerGroupSpec) @@ -291,8 +300,8 @@ func validateRayCluster(admissionReview *admissionv1.AdmissionReview) (*admissio admit = false status = "Failure" message = "Number of workers in worker group not equal to specified topology" + break } - break } // Create AdmissionResponse @@ -318,15 +327,28 @@ func getEnvironmentVariable(varName string, container corev1.Container) string { return "" } -// get next lowest-index pod slice to assign a pod to in the RayCluster -// 
this will be the first pod slice with # created pods < NumOfHosts -func getReplicaIndex(clusterName string, namespace string) int { +// gets the next lowest-index pod slice (worker group replica) to assign a pod to in the RayCluster +// there are three possible cases here: +// 1. sliceToWorkers is empty, this is the first pod the webhook intercepts +// - assign this pod to replica 0 +// 2. The pod slice exists in sliceToWorkers, but has # created workers < NumOfHosts +// - assign this pod to the lowest index replica with # created workers < NumOfHosts +// - since we update isCreated when a worker is deleted, this allows us to assign re-created +// pods to the same replica +// 3. sliceToWorkers isn't empty, but all slices have # workers == NumOfHosts +// - this occurs when the pod we intercept is the first pod of a different slice in the cluster +// - we keep track of how many replicas of the same worker group have been added to sliceToWorkers +// so far, and assign this pod to the next integer replicaIndex +func getReplicaIndex(clusterName string, groupName string, namespace string) int { + // first pod created in cluster if sliceToWorkers == nil { return 0 } nextLowestId := math.MaxInt32 + numReplicas := 0 // tracks # of replicas in worker group created so far for slice, workerList := range sliceToWorkers { - if slice.clusterName == clusterName { + if slice.clusterName == clusterName && slice.groupName == groupName && slice.namespace == namespace { + numReplicas++ createdPods := 0 for _, worker := range workerList { if worker.isCreated { @@ -340,10 +362,11 @@ func getReplicaIndex(clusterName string, namespace string) int { } } } + // first pod of new slice in cluster if nextLowestId == math.MaxInt32 { - klog.ErrorS(errors.New("Replica Index never set"), "RayCluster", namespace+"/"+clusterName, "Replica Index", nextLowestId) + nextLowestId = numReplicas } - klog.V(0).InfoS("getReplicaIndex", "RayCluster", namespace+"/"+clusterName, "Replica Index", nextLowestId) + klog.V(0).InfoS("getReplicaIndex", "RayCluster", namespace+"/"+clusterName, "Worker Group", groupName, "Replica Index", nextLowestId) return nextLowestId } @@ -379,7 +402,7 @@ func getNextWorkerID(podSlice slice, namespace string, replicaIndex int) int { } tpuWorkerID = nextLowestID } - klog.V(0).InfoS("getNextWorkerID", "RayCluster", namespace+"/"+podSlice.clusterName, "TPU_WORKER_ID", tpuWorkerID) + klog.V(0).InfoS("getNextWorkerID", "RayCluster", namespace+"/"+podSlice.clusterName, "Worker Group", podSlice.groupName, "TPU_WORKER_ID", tpuWorkerID) return tpuWorkerID } @@ -417,31 +440,31 @@ func mutatePod(admissionReview *admissionv1.AdmissionReview) (*admissionv1.Admis if clusterName == "" { return nil, errors.New("Kuberay Pod missing RayCluster label") } - namespace := pod.Namespace - groupName := pod.Labels["ray.io/group"] - topology := pod.Spec.NodeSelector["cloud.google.com/gke-tpu-topology"] - acceleratorType := pod.Spec.NodeSelector["cloud.google.com/gke-tpu-accelerator"] - if topology == "" { - klog.ErrorS(errors.New("TPU topology not specified"), "RayCluster", namespace+"/"+clusterName, "gke-tpu-topology", topology) - } - if acceleratorType == "" { - klog.ErrorS(errors.New("TPU accelerator not specified"), "RayCluster", namespace+"/"+clusterName, "gke-tpu-accelerator", acceleratorType) - } containers := pod.Spec.Containers if containers == nil { return nil, errors.New("Container path not specified") } if containerRequestingTPUs(containers...) 
{ + namespace := pod.Namespace + groupName := pod.Labels["ray.io/group"] + topology := pod.Spec.NodeSelector["cloud.google.com/gke-tpu-topology"] + acceleratorType := pod.Spec.NodeSelector["cloud.google.com/gke-tpu-accelerator"] + if topology == "" { + klog.ErrorS(errors.New("TPU topology not specified"), "mutatePod", "RayCluster", namespace+"/"+clusterName, "gke-tpu-topology", topology) + } + if acceleratorType == "" { + klog.ErrorS(errors.New("TPU accelerator not specified"), "mutatePod", "RayCluster", namespace+"/"+clusterName, "gke-tpu-accelerator", acceleratorType) + } // assign worker to the next unique ID in the pod slice and update map - numOfHosts, _ := getNumTPUHostsFromTopology(clusterName, namespace, topology, acceleratorType) // ignore error here because topology may not be set yet - replicaIndex := getReplicaIndex(clusterName, namespace) - podSlice := slice{clusterName, groupName, replicaIndex, numOfHosts} + numOfHosts, _ := getNumTPUHostsFromTopology(clusterName, groupName, namespace, topology, acceleratorType) // ignore error here because topology may not be set yet + replicaIndex := getReplicaIndex(clusterName, groupName, namespace) + podSlice := slice{clusterName, groupName, namespace, replicaIndex, numOfHosts} tpuWorkerID := getNextWorkerID(podSlice, namespace, replicaIndex) // defaults to 0 for single-host // inject replica index label injectReplicaLabel(clusterName, namespace, replicaIndex, groupName, &patches) - isMultiHost, _ := isTPUMultiHost(clusterName, namespace, topology, acceleratorType) // ignore error here because topology may not be set yet + isMultiHost, _ := isTPUMultiHost(clusterName, groupName, namespace, topology, acceleratorType) // ignore error here because topology may not be set yet if isMultiHost { // inject hostname into pod spec for DNS records hostname := fmt.Sprintf(groupName+"-%d-%d", replicaIndex, tpuWorkerID) @@ -545,7 +568,7 @@ func deletePod(admissionReview *admissionv1.AdmissionReview) (*admissionv1.Admis if replicaIndexLabel != "" { replicaIndexLabelValues := strings.Split(replicaIndexLabel, "-") replicaIndex, _ := strconv.Atoi(replicaIndexLabelValues[1]) // ignore error here since must be set - + containers := pod.Spec.Containers if containers == nil { return nil, errors.New("Pod spec missing containers") @@ -565,7 +588,7 @@ func deletePod(admissionReview *admissionv1.AdmissionReview) (*admissionv1.Admis } // update sliceToWorkers map for slice, _ := range sliceToWorkers { - if slice.clusterName == clusterName && slice.groupName == groupName && slice.replicaIndex == replicaIndex { + if slice.clusterName == clusterName && slice.groupName == groupName && slice.namespace == namespace && slice.replicaIndex == replicaIndex { // set the pod state to indicate it is not running for index, worker := range sliceToWorkers[slice] { if worker.workerIndex == tpuWorkerID { diff --git a/applications/ray/kuberay-tpu-webhook/webhook_main_test.go b/applications/ray/kuberay-tpu-webhook/webhook_main_test.go new file mode 100644 index 000000000..35bd2b5f5 --- /dev/null +++ b/applications/ray/kuberay-tpu-webhook/webhook_main_test.go @@ -0,0 +1,444 @@ +package main + +import ( + "testing" + + rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" + "github.com/ray-project/kuberay/ray-operator/controllers/ray/utils" + admissionv1 "k8s.io/api/admission/v1" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/labels" + "k8s.io/apimachinery/pkg/runtime" + + 
"github.com/stretchr/testify/assert" +) + +var ( + namespaceStr string + instanceName string + groupNameStr string + headGroupNameStr string + testPodAdmissionReviews *admissionv1.AdmissionReview + testCPUWorker *corev1.Pod + testTPUWorker *corev1.Pod + testRayClusterAdmissionReview *admissionv1.AdmissionReview + testRayClusterNoTPUs *rayv1.RayCluster + testRayClusterSingleHostTPU *rayv1.RayCluster + testRayClusterMultiHostTPU *rayv1.RayCluster + testServices []runtime.Object + workerSelector labels.Selector + headNodeIP string +) + +func setupTest(t *testing.T) { + namespaceStr = "test" + instanceName = "raycluster-test-sample" + headNodeIP = "1.2.3.4" + groupNameStr = "test-group-name" + headlessServiceSuffix = "headless-worker-svc" + + // CPU pod - doesn't request TPUs + testCPUWorker = &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "cpu-pod", + Namespace: namespaceStr, + Labels: map[string]string{ + utils.RayNodeLabelKey: "yes", + utils.RayClusterLabelKey: instanceName, + utils.RayNodeGroupLabelKey: groupNameStr, + }, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "ray-worker", + }, + }, + }, + Status: corev1.PodStatus{ + Phase: corev1.PodRunning, + ContainerStatuses: []corev1.ContainerStatus{ + { + Name: "ray-worker", + State: corev1.ContainerState{}, + }, + }, + }, + } + + // TPU Ray worker pod + testTPUWorker = &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "tpu-pod", + Namespace: namespaceStr, + Labels: map[string]string{ + utils.RayNodeLabelKey: "yes", + utils.RayClusterLabelKey: instanceName, + utils.RayNodeGroupLabelKey: groupNameStr, + }, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "ray-worker", + Resources: corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + "cpu": resource.MustParse("1"), + "google.com/tpu": resource.MustParse("4"), + "memory": resource.MustParse("40G"), + "ephemeral-storage": resource.MustParse("20Gi"), + }, + Requests: corev1.ResourceList{ + "cpu": resource.MustParse("1"), + "google.com/tpu": resource.MustParse("4"), + "memory": resource.MustParse("40G"), + "ephemeral-storage": resource.MustParse("20Gi"), + }, + }, + }, + }, + }, + Status: corev1.PodStatus{ + Phase: corev1.PodRunning, + ContainerStatuses: []corev1.ContainerStatus{ + { + Name: "ray-worker", + State: corev1.ContainerState{}, + }, + }, + }, + } +} + +// helper function used by tests which mutate sliceToWorkers +func deepCopySliceToWorkers() map[slice][]worker { + deepCopy := make(map[slice][]worker) + for slice, workerList := range sliceToWorkers { + deepCopy[slice] = []worker{} + for _, worker := range workerList { + deepCopy[slice] = append(deepCopy[slice], worker) + } + } + + return deepCopy +} + +func Test_GetReplicaIndex(t *testing.T) { + setupTest(t) + + tests := map[string]struct { + sliceToWorkers map[slice][]worker + numOfHosts int32 + numReplicas int + additionalGroupStr string + additionalNumOfHosts int32 + additionalNumReplicas int + workersToDelete []worker + }{ + "single-host, single-slice worker group": { + // single-slice, replicaIndex should always be 0 + numOfHosts: 1, + numReplicas: 1, + }, + "single-host, multi-slice worker group": { + // multi-slice, replicaIndex should always be 0-numReplicas + numOfHosts: 1, + numReplicas: 4, + }, + "multi-host, single-slice worker group": { + // single-slice, replicaIndex should always be 0 + numOfHosts: 4, + numReplicas: 1, + }, + "multi-host, multi-slice worker group": { + // multi-slice, replicaIndex should always be 0-numReplicas for 0-numOfHosts pods 
+ numOfHosts: 4, + numReplicas: 4, + }, + "multiple worker groups": { + // should assign replicaIndex 0-numReplicas and TPU_WORKER_ID 0-numOfHosts + // for each respective worker group + numOfHosts: 4, + numReplicas: 4, + additionalGroupStr: "another-worker-group", + additionalNumOfHosts: 2, + additionalNumReplicas: 3, + }, + "deleted pods from replica": { + // should re-assign pods to lowest index replicas with # isCreated pods < NumOfHosts + numOfHosts: 4, + numReplicas: 4, + workersToDelete: []worker{worker{0, 0, true}, worker{2, 1, true}, worker{3, 2, true}}, + }, + "delete pods from different multi-host groups": { + // pods should be reassigned the lowest replica ID with # isCreated pods < NumOfHosts + // in each respective worker group + numOfHosts: 4, + numReplicas: 4, + additionalGroupStr: "another-worker-group", + additionalNumOfHosts: 4, + additionalNumReplicas: 3, + workersToDelete: []worker{worker{1, 0, true}, worker{2, 1, true}, worker{3, 2, true}}, + }, + } + + // validate getReplicaIndex() returns the expected Replica ID for TPU pods in varying pod slices + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + sliceToWorkersCopy := deepCopySliceToWorkers() + for i := 0; i < tc.numReplicas; i++ { + testPodSlice := slice{instanceName, groupNameStr, namespaceStr, i, tc.numOfHosts} + for j := 0; j < int(tc.numOfHosts); j++ { + replicaIndex := getReplicaIndex(instanceName, groupNameStr, namespaceStr) + assert.Equal(t, i, replicaIndex) + + // add the worker to sliceToWorkers - this would happen in getNextWorkerID + testWorker := worker{j, replicaIndex, true} + if sliceToWorkers[testPodSlice] == nil { + sliceToWorkers[testPodSlice] = []worker{testWorker} + } else { + sliceToWorkers[testPodSlice] = append(sliceToWorkers[testPodSlice], testWorker) + } + } + } + + if len(tc.workersToDelete) > 0 { + // test deleting and then re-assigning one pod at a time + for _, workerToDelete := range tc.workersToDelete { + // "delete" the pod + replicaToDeleteFrom := workerToDelete.replicaIndex + testPodSlice := slice{instanceName, groupNameStr, namespaceStr, replicaToDeleteFrom, tc.numOfHosts} + // set the pod isCreated value to false to simulate pod deletion + for index, worker := range sliceToWorkers[testPodSlice] { + if worker.workerIndex == workerToDelete.workerIndex { + sliceToWorkers[testPodSlice][index].isCreated = false + } + } + + // should re-assign the pod to the same replica + replicaIndex := getReplicaIndex(instanceName, groupNameStr, namespaceStr) + // set the pod isCreated value back to true to simulate pod re-creation + for index, worker := range sliceToWorkers[testPodSlice] { + if worker.workerIndex == workerToDelete.workerIndex { + sliceToWorkers[testPodSlice][index].isCreated = true + } + } + assert.Equal(t, replicaToDeleteFrom, replicaIndex) + } + + // test deleting pods simultaneously and then re-assigning + for _, workerToDelete := range tc.workersToDelete { + testPodSlice := slice{instanceName, groupNameStr, namespaceStr, workerToDelete.replicaIndex, tc.numOfHosts} + + // set the pod isCreated value to false to simulate pod deletion + for index, worker := range sliceToWorkers[testPodSlice] { + if worker.workerIndex == workerToDelete.workerIndex { + sliceToWorkers[testPodSlice][index].isCreated = false + } + } + } + } + + // test assigning pods to replicas for a different worker group + if tc.additionalGroupStr != "" { + for i := 0; i < tc.additionalNumReplicas; i++ { + testAdditionalPodSlice := slice{instanceName, tc.additionalGroupStr, namespaceStr, i, 
tc.additionalNumOfHosts} + for j := 0; j < int(tc.additionalNumOfHosts); j++ { + replicaIndex := getReplicaIndex(instanceName, tc.additionalGroupStr, namespaceStr) + assert.Equal(t, i, replicaIndex) + + // add the worker to sliceToWorkers - this would happen in getNextWorkerID + testWorker := worker{j, replicaIndex, true} + if sliceToWorkers[testAdditionalPodSlice] == nil { + sliceToWorkers[testAdditionalPodSlice] = []worker{testWorker} + } else { + sliceToWorkers[testAdditionalPodSlice] = append(sliceToWorkers[testAdditionalPodSlice], testWorker) + } + } + } + + // test deleting pods from a different worker group + if len(tc.workersToDelete) > 0 { + for _, workerToDelete := range tc.workersToDelete { + replicaToDeleteFrom := workerToDelete.replicaIndex + testAdditionalPodSlice := slice{instanceName, tc.additionalGroupStr, namespaceStr, replicaToDeleteFrom, tc.additionalNumOfHosts} + for index, worker := range sliceToWorkers[testAdditionalPodSlice] { + if worker.workerIndex == workerToDelete.workerIndex { + sliceToWorkers[testAdditionalPodSlice][index].isCreated = false + } + } + } + } + } + + // should re-assign the pod to the same replica for each respective worker group + if len(tc.workersToDelete) > 0 { + for _, workerToDelete := range tc.workersToDelete { + testPodSlice := slice{instanceName, groupNameStr, namespaceStr, workerToDelete.replicaIndex, tc.numOfHosts} + replicaIndex := getReplicaIndex(instanceName, groupNameStr, namespaceStr) + // "re-create" the worker pod + for index, worker := range sliceToWorkers[testPodSlice] { + if worker.workerIndex == workerToDelete.workerIndex { + sliceToWorkers[testPodSlice][index].isCreated = true + } + } + assert.Equal(t, workerToDelete.replicaIndex, replicaIndex) + + if tc.additionalGroupStr != "" { + testAdditionalPodSlice := slice{instanceName, tc.additionalGroupStr, namespaceStr, workerToDelete.replicaIndex, tc.additionalNumOfHosts} + additionalReplicaIndex := getReplicaIndex(instanceName, tc.additionalGroupStr, namespaceStr) + // "re-create" the worker pod + for index, worker := range sliceToWorkers[testAdditionalPodSlice] { + if worker.workerIndex == workerToDelete.workerIndex { + sliceToWorkers[testAdditionalPodSlice][index].isCreated = true + } + } + assert.Equal(t, workerToDelete.replicaIndex, additionalReplicaIndex) + } + } + } + + assert.Equal(t, tc.numReplicas+tc.additionalNumReplicas, len(sliceToWorkers)) + sliceToWorkers = sliceToWorkersCopy // reset sliceToWorkers to previous state + }) + } +} + +func Test_GetNextWorkerID(t *testing.T) { + setupTest(t) + + tests := map[string]struct { + numOfHosts int32 + numReplicas int + workersToDelete []worker + additionalGroupStr string + additionalNumOfHosts int32 + additionalNumReplicas int + }{ + "single-host, single-slice worker group": { + // single-host, TPU_WORKER_ID should always be 0 + numOfHosts: 1, + numReplicas: 1, + }, + "single-host, multi-slice worker group": { + // multi-slice, TPU_WORKER_ID should be 0 for all replicas + numOfHosts: 1, + numReplicas: 4, + }, + "multi-host, single-slice worker group": { + // multi-host, TPU_WORKER_ID should range from 0 to NumOfHosts-1 + numOfHosts: 4, + numReplicas: 1, + }, + "multi-host, multi-slice worker group": { + // multi-slice, unique TPU_WORKER_IDs should range from 0 to NumOfHosts-1 for each replica + numOfHosts: 4, + numReplicas: 4, + }, + "delete pods from multi-host group": { + // pods should be reassigned the lowest integer ID with isCreated == false belonging to the replica + numOfHosts: 4, + numReplicas: 4, + workersToDelete: 
[]worker{worker{0, 0, true}, worker{2, 1, true}, worker{3, 2, true}}, + }, + "delete pods from different multi-host groups": { + // pods should be reassigned the lowest TPU_WORKER_ID with isCreated == false belonging to the replica + // in each respective worker group + numOfHosts: 4, + numReplicas: 4, + workersToDelete: []worker{worker{0, 0, true}, worker{2, 1, true}, worker{3, 2, true}}, + additionalGroupStr: "another-worker-group", + additionalNumOfHosts: 4, + additionalNumReplicas: 3, + }, + } + + // validate getNextWorkerID() returns the expected TPU_WORKER_ID for different worker group specifications + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + sliceToWorkersCopy := deepCopySliceToWorkers() + for i := 0; i < tc.numReplicas; i++ { + testPodSlice := slice{instanceName, groupNameStr, namespaceStr, i, tc.numOfHosts} + for j := 0; j < int(tc.numOfHosts); j++ { + workerID := getNextWorkerID(testPodSlice, namespaceStr, i) + assert.Equal(t, j, workerID) + } + } + + if len(tc.workersToDelete) > 0 { + // test deleting and then re-assigning one pod at a time + for _, workerToDelete := range tc.workersToDelete { + replicaToDeleteFrom := workerToDelete.replicaIndex + testPodSlice := slice{instanceName, groupNameStr, namespaceStr, replicaToDeleteFrom, tc.numOfHosts} + // set the pod isCreated value to false to simulate pod deletion + for index, worker := range sliceToWorkers[testPodSlice] { + if worker.workerIndex == workerToDelete.workerIndex { + sliceToWorkers[testPodSlice][index].isCreated = false + } + } + workerID := getNextWorkerID(testPodSlice, namespaceStr, replicaToDeleteFrom) + assert.Equal(t, workerToDelete.workerIndex, workerID) + } + + // test deleting pods simultaneously and then re-assigning + for _, workerToDelete := range tc.workersToDelete { + testPodSlice := slice{instanceName, groupNameStr, namespaceStr, workerToDelete.replicaIndex, tc.numOfHosts} + // set the pod isCreated value to false to simulate pod deletion + for index, worker := range sliceToWorkers[testPodSlice] { + if worker.workerIndex == workerToDelete.workerIndex { + sliceToWorkers[testPodSlice][index].isCreated = false + } + } + } + } + + // test assigning TPU_WORKER_IDs to pods for a different worker group + if tc.additionalGroupStr != "" { + for i := 0; i < tc.additionalNumReplicas; i++ { + testAdditionalPodSlice := slice{instanceName, tc.additionalGroupStr, namespaceStr, i, tc.additionalNumOfHosts} + for j := 0; j < int(tc.additionalNumOfHosts); j++ { + workerID := getNextWorkerID(testAdditionalPodSlice, namespaceStr, i) + assert.Equal(t, j, workerID) + } + } + + // test deleting pods from a different worker group + if len(tc.workersToDelete) > 0 { + for _, workerToDelete := range tc.workersToDelete { + replicaToDeleteFrom := workerToDelete.replicaIndex + testAdditionalPodSlice := slice{instanceName, tc.additionalGroupStr, namespaceStr, replicaToDeleteFrom, tc.additionalNumOfHosts} + for index, worker := range sliceToWorkers[testAdditionalPodSlice] { + if worker.workerIndex == workerToDelete.workerIndex { + sliceToWorkers[testAdditionalPodSlice][index].isCreated = false + } + } + } + } + } + + // should re-assign the pod to the same replica for each respective worker group + if len(tc.workersToDelete) > 0 { + for _, workerToDelete := range tc.workersToDelete { + testPodSlice := slice{instanceName, groupNameStr, namespaceStr, workerToDelete.replicaIndex, tc.numOfHosts} + workerID := getNextWorkerID(testPodSlice, namespaceStr, workerToDelete.replicaIndex) + assert.Equal(t,
workerToDelete.workerIndex, workerID) + + if tc.additionalGroupStr != "" { + testAdditionalPodSlice := slice{instanceName, tc.additionalGroupStr, namespaceStr, workerToDelete.replicaIndex, tc.additionalNumOfHosts} + additionalWorkerID := getNextWorkerID(testAdditionalPodSlice, namespaceStr, workerToDelete.replicaIndex) + // "re-create" the worker pod + for index, worker := range sliceToWorkers[testAdditionalPodSlice] { + if worker.workerIndex == workerToDelete.workerIndex { + sliceToWorkers[testAdditionalPodSlice][index].isCreated = true + } + } + assert.Equal(t, workerToDelete.workerIndex, additionalWorkerID) + } + } + } + sliceToWorkers = sliceToWorkersCopy // reset sliceToWorkers to previous state + }) + } +}
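
A standalone, simplified sketch (not part of the patch) of the bookkeeping behavior this change introduces: because slice is now keyed on the worker group name and namespace in addition to the cluster name, replica indices and TPU_WORKER_IDs are assigned independently for each worker group. The type and helper names mirror main.go, but the implementations below are reduced rewrites for illustration under that assumption, and the cluster, group, and namespace strings are made up.

package main

import (
	"fmt"
	"math"
)

// Simplified mirrors of the webhook's bookkeeping types (illustration only).
type slice struct {
	clusterName  string
	groupName    string
	namespace    string
	replicaIndex int
	numOfHosts   int32
}

type worker struct {
	workerIndex  int
	replicaIndex int
	isCreated    bool
}

var sliceToWorkers = make(map[slice][]worker)

// nextReplicaIndex loosely mimics getReplicaIndex: pick the lowest replica of this
// (cluster, group, namespace) that still has fewer created workers than NumOfHosts,
// otherwise start the next replica index for that group.
func nextReplicaIndex(cluster, group, namespace string, numOfHosts int32) int {
	nextLowest := math.MaxInt32
	numReplicas := 0
	for s, workers := range sliceToWorkers {
		if s.clusterName == cluster && s.groupName == group && s.namespace == namespace {
			numReplicas++
			created := 0
			for _, w := range workers {
				if w.isCreated {
					created++
				}
			}
			if created < int(numOfHosts) && s.replicaIndex < nextLowest {
				nextLowest = s.replicaIndex
			}
		}
	}
	if nextLowest == math.MaxInt32 {
		nextLowest = numReplicas
	}
	return nextLowest
}

// nextWorkerID loosely mimics getNextWorkerID: return the lowest TPU_WORKER_ID not
// currently in use within the given pod slice, and record the new worker as created.
func nextWorkerID(s slice) int {
	used := make(map[int]bool)
	for _, w := range sliceToWorkers[s] {
		if w.isCreated {
			used[w.workerIndex] = true
		}
	}
	id := 0
	for used[id] {
		id++
	}
	sliceToWorkers[s] = append(sliceToWorkers[s], worker{workerIndex: id, replicaIndex: s.replicaIndex, isCreated: true})
	return id
}

func main() {
	// Two hypothetical TPU worker groups in the same RayCluster and namespace,
	// each assumed to have NumOfHosts=2 per replica.
	for _, group := range []string{"tpu-group-a", "tpu-group-b"} {
		for pod := 0; pod < 4; pod++ {
			r := nextReplicaIndex("raycluster-sample", group, "default", 2)
			s := slice{"raycluster-sample", group, "default", r, 2}
			id := nextWorkerID(s)
			fmt.Printf("group=%s pod=%d -> replicaIndex=%d TPU_WORKER_ID=%d\n", group, pod, r, id)
		}
	}
}

Running the sketch prints replica indices 0 and 1 with TPU_WORKER_IDs 0 and 1 for each of the two hypothetical groups independently, which is the per-group assignment the unit tests above assert.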