diff --git a/ray-on-gke/tpu/kuberay-tpu-webhook/Makefile b/ray-on-gke/tpu/kuberay-tpu-webhook/Makefile index 91183f60b..b2ee4da3c 100644 --- a/ray-on-gke/tpu/kuberay-tpu-webhook/Makefile +++ b/ray-on-gke/tpu/kuberay-tpu-webhook/Makefile @@ -39,8 +39,3 @@ deploy-cert: uninstall-cert: kubectl delete -f certs/ -tests: - kubectl apply -f tests/ - -delete-tests: - kubectl delete -f tests/ diff --git a/ray-on-gke/tpu/kuberay-tpu-webhook/main.go b/ray-on-gke/tpu/kuberay-tpu-webhook/main.go index 2a7ce691d..dbb819581 100755 --- a/ray-on-gke/tpu/kuberay-tpu-webhook/main.go +++ b/ray-on-gke/tpu/kuberay-tpu-webhook/main.go @@ -50,9 +50,6 @@ var ( // map of pod slices to workers in the slice sliceToWorkers map[slice][]worker - // map of pod slices to TPU_WORKER_HOSTNAMES in that pod slice - sliceToHostnames map[slice]string - // Flag arguments. BindAddr string CACert string @@ -134,18 +131,17 @@ func extractRayCluster(admissionReview *admissionv1.AdmissionReview) (*ray.RayCl return &rayCluster, nil } -func genDNSHostnames(workerGroupSpec ray.WorkerGroupSpec, clusterName string, namespace string, replicaIndex int) (string, error) { - numHosts := workerGroupSpec.NumOfHosts - if numHosts == 0 { - return "", errors.New("workerGroupSpec NumOfHosts not set") +func genDNSHostnames(numOfHosts int32, groupName string, clusterName string, namespace string, replicaIndex int) (string, error) { + if numOfHosts == 0 { + err := errors.New("workerGroupSpec NumOfHosts not set") + return "", err } - workerGroupName := workerGroupSpec.GroupName - hostNames := make([]string, numHosts) + hostNames := make([]string, numOfHosts) // Host names will be of the form {WORKER_GROUP_NAME}-{REPLICA_INDEX}-{HOST_INDEX}.headless-worker-svc - for j := 0; j < int(numHosts); j++ { - hostNames[j] = fmt.Sprintf("%s-%d-%d.%s-%s", workerGroupName, replicaIndex, j, clusterName, headlessServiceSuffix) + for j := 0; j < int(numOfHosts); j++ { + hostNames[j] = fmt.Sprintf("%s-%d-%d.%s-%s", groupName, replicaIndex, j, clusterName, headlessServiceSuffix) } - klog.V(1).InfoS("genDNSHostnames", "RayCluster", namespace+"/"+clusterName, "NumOfHosts", numHosts, "Replica Index", replicaIndex) + klog.V(1).InfoS("genDNSHostnames", "RayCluster", namespace+"/"+clusterName, "NumOfHosts", numOfHosts, "Replica Index", replicaIndex) return strings.Join(hostNames, ","), nil } @@ -268,23 +264,13 @@ func validateRayCluster(admissionReview *admissionv1.AdmissionReview) (*admissio workerGroupSpec := workerGroupSpecs[i] if containerRequestingTPUs(workerGroupSpec.Template.Spec.Containers...) { klog.V(1).InfoS("validateRayCluster", "RayCluster", namespace+"/"+clusterName, "Worker Group", workerGroupSpec.GroupName, "Requests TPUs", true) - // create mapping for pod slices -> TPU_WORKER_HOSTNAMES in cluster replicas := int(*workerGroupSpec.Replicas) numOfHosts := workerGroupSpec.NumOfHosts for replicaIndex := 0; replicaIndex < replicas; replicaIndex++ { - // reset past sliceToWorkers and sliceToHostnames entries for slice in ray cluster + // reset past sliceToWorkers entries for slice in ray cluster groupName := workerGroupSpec.GroupName podSlice := slice{clusterName, groupName, namespace, replicaIndex, numOfHosts} sliceToWorkers[podSlice] = nil - sliceToHostnames[podSlice] = "" - // generate TPU_WORKER_HOSTNAMES - if numOfHosts > 1 { - joinedHostNames, err := genDNSHostnames(workerGroupSpec, clusterName, namespace, replicaIndex) - if err != nil { - klog.Error("Failed to generate DNS Hostnames") - } - sliceToHostnames[podSlice] = joinedHostNames - } } } else { // RayCluster worker group does not request TPUs @@ -480,10 +466,14 @@ func mutatePod(admissionReview *admissionv1.AdmissionReview) (*admissionv1.Admis if containerRequestingTPUs(container) { path := fmt.Sprintf("/spec/containers/%d/env", i) if numOfHosts > 1 { - // inject TPU_WORKER_HOSTNAMES set during RayCluster interception - klog.V(1).InfoS("mutatePod", "RayCluster", namespace+"/"+clusterName, "TPU_WORKER_HOSTNAMES", sliceToHostnames[podSlice]) + // inject TPU_WORKER_HOSTNAMES + hostnames, err := genDNSHostnames(numOfHosts, groupName, clusterName, namespace, replicaIndex) + if err != nil { + return nil, err + } + klog.V(1).InfoS("mutatePod", "RayCluster", namespace+"/"+clusterName, "TPU_WORKER_HOSTNAMES", hostnames) klog.V(1).InfoS("mutatePod", "RayCluster", namespace+"/"+clusterName, "subdomain", clusterName+"-"+headlessServiceSuffix) - injectHostnames(clusterName, sliceToHostnames[podSlice], path, container, &patches) + injectHostnames(clusterName, hostnames, path, container, &patches) } // inject TPU_WORKER_ID if getEnvironmentVariable("TPU_WORKER_ID", container) == "" { @@ -545,18 +535,28 @@ func mutatePod(admissionReview *admissionv1.AdmissionReview) (*admissionv1.Admis // update sliceToWorkers map on pod deletion func deletePod(admissionReview *admissionv1.AdmissionReview) (*admissionv1.AdmissionResponse, error) { + // Create AdmissionResponse - we never deny the deletion request + admissionResponse := &admissionv1.AdmissionResponse{ + UID: admissionReview.Request.UID, + Allowed: true, + Result: &metav1.Status{ + Status: "Success", + Message: "", + }, + } + pod, err := extractPod(admissionReview) if err != nil { - klog.Fatalf("Pod extraction failed: %s", err) + klog.Errorf("Pod extraction failed: %s", err) } clusterName := pod.Labels["ray.io/cluster"] if clusterName == "" { - return nil, errors.New("Kuberay Pod missing RayCluster label") + return admissionResponse, errors.New("Kuberay Pod missing RayCluster label") } groupName := pod.Labels["ray.io/group"] if groupName == "" { - return nil, errors.New("Kuberay Pod missing Ray group label") + return admissionResponse, errors.New("Kuberay Pod missing Ray group label") } namespace := pod.Namespace replicaIndexLabel := pod.Labels["replicaIndex"] @@ -567,20 +567,20 @@ func deletePod(admissionReview *admissionv1.AdmissionReview) (*admissionv1.Admis containers := pod.Spec.Containers if containers == nil { - return nil, errors.New("Pod spec missing containers") + return admissionResponse, errors.New("Pod spec missing containers") } tpuWorkerID := -1 for _, container := range pod.Spec.Containers { if containerRequestingTPUs(container) { tpuWorkerID, err = strconv.Atoi(getEnvironmentVariable("TPU_WORKER_ID", container)) if err != nil { - return nil, errors.New("Unable to extract TPU_WORKER_ID") + return admissionResponse, errors.New("Unable to extract TPU_WORKER_ID") } break } } if tpuWorkerID == -1 { - return nil, errors.New("Kuberay Pod missing TPU_WORKER_ID") + return admissionResponse, errors.New("Kuberay Pod missing TPU_WORKER_ID") } // update sliceToWorkers map for slice, _ := range sliceToWorkers { @@ -598,15 +598,6 @@ func deletePod(admissionReview *admissionv1.AdmissionReview) (*admissionv1.Admis } } - // Create AdmissionResponse - we never deny the deletion request - admissionResponse := &admissionv1.AdmissionResponse{ - UID: admissionReview.Request.UID, - Allowed: true, - Result: &metav1.Status{ - Status: "Success", - Message: "", - }, - } return admissionResponse, nil } @@ -621,7 +612,6 @@ func writeCertfile(filename string, encodedData string) error { func init() { sliceToWorkers = make(map[slice][]worker) - sliceToHostnames = make(map[slice]string) flag.StringVar(&BindAddr, "bind-address", ":443", "Address to bind HTTPS service to") flag.StringVar(&CACert, "ca-cert", "", "base64-encoded root certificate for TLS") diff --git a/ray-on-gke/tpu/kuberay-tpu-webhook/webhook_main_test.go b/ray-on-gke/tpu/kuberay-tpu-webhook/webhook_main_test.go index 67e457678..d90333da1 100644 --- a/ray-on-gke/tpu/kuberay-tpu-webhook/webhook_main_test.go +++ b/ray-on-gke/tpu/kuberay-tpu-webhook/webhook_main_test.go @@ -303,16 +303,6 @@ func setupTest(t *testing.T) { } } -// helper function used by tests which mutate sliceToHostnames -func deepCopySliceToHostnames() map[slice]string { - deepCopy := make(map[slice]string) - for slice, hostnames := range sliceToHostnames { - deepCopy[slice] = hostnames - } - - return deepCopy -} - // helper function used by tests which mutate sliceToWorkers func deepCopySliceToWorkers() map[slice][]worker { deepCopy := make(map[slice][]worker) @@ -895,15 +885,13 @@ func Test_GenDNSHostnames(t *testing.T) { setupTest(t) tests := map[string]struct { - workerGroupSpec *rayv1.WorkerGroupSpec replicaIndex int numOfHosts int32 expectedHostnames string expectedError error }{ "genDNSHostnames with NumOfHosts == 0": { - // you can't have a workergroup with NumOfHosts set to 0 so this should error out - workerGroupSpec: testWorkerGroupSpec, + // a workergroup can't have NumOfHosts set to 0 so this should error out replicaIndex: 0, numOfHosts: int32(0), expectedError: errors.New("workerGroupSpec NumOfHosts not set"), @@ -911,14 +899,12 @@ func Test_GenDNSHostnames(t *testing.T) { "genDNSHostnames with NumOfHosts == 1": { // Single-host worker group, should return a single DNS hostname. This function will // never be called for single-host groups, but we don't necessarily want it to error if it does. - workerGroupSpec: testWorkerGroupSpec, replicaIndex: 0, numOfHosts: int32(1), expectedHostnames: fmt.Sprintf("%s-%d-%d.%s-%s", groupNameStr, 0, 0, instanceName, headlessServiceSuffix), }, "genDNSHostnames with NumOfHosts > 1": { // multi-host worker group, should return a string list of DNS hostnames for the given replica - workerGroupSpec: testWorkerGroupSpec, replicaIndex: 1, numOfHosts: int32(4), expectedHostnames: strings.Join([]string{fmt.Sprintf("%s-%d-%d.%s-%s", groupNameStr, 1, 0, instanceName, headlessServiceSuffix), @@ -932,8 +918,7 @@ func Test_GenDNSHostnames(t *testing.T) { // validate that genDNSHostnames correctly returns a string list of DNS addressable hostnames for name, tc := range tests { t.Run(name, func(t *testing.T) { - tc.workerGroupSpec.NumOfHosts = tc.numOfHosts - hostnames, err := genDNSHostnames(*tc.workerGroupSpec, instanceName, namespaceStr, tc.replicaIndex) + hostnames, err := genDNSHostnames(tc.numOfHosts, groupNameStr, instanceName, namespaceStr, tc.replicaIndex) if err != nil { assert.Equal(t, tc.expectedError, err) } else { @@ -1241,7 +1226,6 @@ func Test_ValidateRayCluster(t *testing.T) { // check validateRayCluster for name, tc := range tests { t.Run(name, func(t *testing.T) { - sliceToHostnamesCopy := deepCopySliceToHostnames() sliceToWorkersCopy := deepCopySliceToWorkers() // set up admissionReview object @@ -1272,22 +1256,7 @@ func Test_ValidateRayCluster(t *testing.T) { assert.Equal(t, tc.expectedResult.Message, admissionResponse.Result.Message) } - // check that sliceToHostnames entry is generated - if tc.topology != "" && tc.numOfHosts > 1 { - for replicaIndex := 0; replicaIndex < int(*tc.replicas); replicaIndex++ { - // generate TPU_WORKER_HOSTNAME values - var expectedHostnames []string - for hostIndex := 0; hostIndex < int(tc.numOfHosts); hostIndex++ { - expectedHostnames = append(expectedHostnames, fmt.Sprintf("%s-%d-%d.%s-%s", groupNameStr, replicaIndex, hostIndex, instanceName, headlessServiceSuffix)) - } - // check that expectedHostnames have been set for each slice - testSlice := slice{instanceName, groupNameStr, namespaceStr, replicaIndex, tc.numOfHosts} - assert.Equal(t, strings.Join(expectedHostnames, ","), sliceToHostnames[testSlice]) - } - } - - // set maps back to their previous values - sliceToHostnames = sliceToHostnamesCopy + // reset map previous values sliceToWorkers = sliceToWorkersCopy }) } @@ -1418,11 +1387,6 @@ func Test_MutatePod(t *testing.T) { t.Run(name, func(t *testing.T) { // save copy of sliceToWorkers sliceToWorkersCopy := deepCopySliceToWorkers() - sliceToHostnamesCopy := deepCopySliceToHostnames() - - // set sliceToHostnames value to be injected during mutatePod - testSlice := slice{instanceName, groupNameStr, namespaceStr, tc.expectedReplicaID, tc.numOfHosts} - sliceToHostnames[testSlice] = tc.expectedHostnames // set up Pod object if tc.missingClusterLabel { @@ -1478,7 +1442,6 @@ func Test_MutatePod(t *testing.T) { } // reset map values after test sliceToWorkers = sliceToWorkersCopy - sliceToHostnames = sliceToHostnamesCopy } }) }