Ray TPU Webhook Autoscaling Support (GoogleCloudPlatform#740)
* Generate hostnames at Pod creation

Signed-off-by: Ryan O'Leary <[email protected]>

* Should not fatal log in deletePod

Signed-off-by: Ryan O'Leary <[email protected]>

* deletePod admission always succeeds

Signed-off-by: Ryan O'Leary <[email protected]>

* Remove unused tests make command

Signed-off-by: Ryan O'Leary <[email protected]>

* Update tests and add error checking

Signed-off-by: Ryan O'Leary <[email protected]>

* Just return an error instead of logging

Signed-off-by: Ryan O'Leary <[email protected]>

---------

Signed-off-by: Ryan O'Leary <[email protected]>
ryanaoleary committed Jul 19, 2024
1 parent 1374b77 commit dc3a615
Showing 3 changed files with 35 additions and 87 deletions.
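
The thrust of the change: TPU_WORKER_HOSTNAMES are now generated when each Pod is mutated, rather than being precomputed during RayCluster admission and cached in a package-level sliceToHostnames map. A minimal sketch of the new flow, condensed from the mutatePod hunk below (patch construction and the surrounding admission plumbing are omitted):

// Condensed from the new mutatePod path: hostnames are computed on demand
// from values the webhook already has for the Pod being admitted.
if numOfHosts > 1 {
	// genDNSHostnames now takes plain values instead of a ray.WorkerGroupSpec.
	hostnames, err := genDNSHostnames(numOfHosts, groupName, clusterName, namespace, replicaIndex)
	if err != nil {
		return nil, err
	}
	injectHostnames(clusterName, hostnames, path, container, &patches)
}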
5 changes: 0 additions & 5 deletions ray-on-gke/tpu/kuberay-tpu-webhook/Makefile
@@ -39,8 +39,3 @@ deploy-cert:
uninstall-cert:
kubectl delete -f certs/

- tests:
- kubectl apply -f tests/
-
- delete-tests:
- kubectl delete -f tests/
74 changes: 32 additions & 42 deletions ray-on-gke/tpu/kuberay-tpu-webhook/main.go
@@ -50,9 +50,6 @@ var (
// map of pod slices to workers in the slice
sliceToWorkers map[slice][]worker

- // map of pod slices to TPU_WORKER_HOSTNAMES in that pod slice
- sliceToHostnames map[slice]string
-
// Flag arguments.
BindAddr string
CACert string
@@ -134,18 +131,17 @@ func extractRayCluster(admissionReview *admissionv1.AdmissionReview) (*ray.RayCl
return &rayCluster, nil
}

- func genDNSHostnames(workerGroupSpec ray.WorkerGroupSpec, clusterName string, namespace string, replicaIndex int) (string, error) {
- numHosts := workerGroupSpec.NumOfHosts
- if numHosts == 0 {
- return "", errors.New("workerGroupSpec NumOfHosts not set")
+ func genDNSHostnames(numOfHosts int32, groupName string, clusterName string, namespace string, replicaIndex int) (string, error) {
+ if numOfHosts == 0 {
+ err := errors.New("workerGroupSpec NumOfHosts not set")
+ return "", err
}
- workerGroupName := workerGroupSpec.GroupName
- hostNames := make([]string, numHosts)
+ hostNames := make([]string, numOfHosts)
// Host names will be of the form {WORKER_GROUP_NAME}-{REPLICA_INDEX}-{HOST_INDEX}.headless-worker-svc
- for j := 0; j < int(numHosts); j++ {
- hostNames[j] = fmt.Sprintf("%s-%d-%d.%s-%s", workerGroupName, replicaIndex, j, clusterName, headlessServiceSuffix)
+ for j := 0; j < int(numOfHosts); j++ {
+ hostNames[j] = fmt.Sprintf("%s-%d-%d.%s-%s", groupName, replicaIndex, j, clusterName, headlessServiceSuffix)
}
- klog.V(1).InfoS("genDNSHostnames", "RayCluster", namespace+"/"+clusterName, "NumOfHosts", numHosts, "Replica Index", replicaIndex)
+ klog.V(1).InfoS("genDNSHostnames", "RayCluster", namespace+"/"+clusterName, "NumOfHosts", numOfHosts, "Replica Index", replicaIndex)
return strings.Join(hostNames, ","), nil
}
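
As an aside, the hostname format itself is unchanged: {WORKER_GROUP_NAME}-{REPLICA_INDEX}-{HOST_INDEX} followed by the cluster's headless service name. A hypothetical call (values invented for illustration, assuming headlessServiceSuffix is "headless-worker-svc"):

hostnames, _ := genDNSHostnames(2, "tpu-group", "my-cluster", "default", 0)
// hostnames == "tpu-group-0-0.my-cluster-headless-worker-svc,tpu-group-0-1.my-cluster-headless-worker-svc"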

@@ -268,23 +264,13 @@ func validateRayCluster(admissionReview *admissionv1.AdmissionReview) (*admissio
workerGroupSpec := workerGroupSpecs[i]
if containerRequestingTPUs(workerGroupSpec.Template.Spec.Containers...) {
klog.V(1).InfoS("validateRayCluster", "RayCluster", namespace+"/"+clusterName, "Worker Group", workerGroupSpec.GroupName, "Requests TPUs", true)
- // create mapping for pod slices -> TPU_WORKER_HOSTNAMES in cluster
replicas := int(*workerGroupSpec.Replicas)
numOfHosts := workerGroupSpec.NumOfHosts
for replicaIndex := 0; replicaIndex < replicas; replicaIndex++ {
- // reset past sliceToWorkers and sliceToHostnames entries for slice in ray cluster
+ // reset past sliceToWorkers entries for slice in ray cluster
groupName := workerGroupSpec.GroupName
podSlice := slice{clusterName, groupName, namespace, replicaIndex, numOfHosts}
sliceToWorkers[podSlice] = nil
- sliceToHostnames[podSlice] = ""
- // generate TPU_WORKER_HOSTNAMES
- if numOfHosts > 1 {
- joinedHostNames, err := genDNSHostnames(workerGroupSpec, clusterName, namespace, replicaIndex)
- if err != nil {
- klog.Error("Failed to generate DNS Hostnames")
- }
- sliceToHostnames[podSlice] = joinedHostNames
- }
}
} else {
// RayCluster worker group does not request TPUs
@@ -480,10 +466,14 @@ func mutatePod(admissionReview *admissionv1.AdmissionReview) (*admissionv1.Admis
if containerRequestingTPUs(container) {
path := fmt.Sprintf("/spec/containers/%d/env", i)
if numOfHosts > 1 {
- // inject TPU_WORKER_HOSTNAMES set during RayCluster interception
- klog.V(1).InfoS("mutatePod", "RayCluster", namespace+"/"+clusterName, "TPU_WORKER_HOSTNAMES", sliceToHostnames[podSlice])
+ // inject TPU_WORKER_HOSTNAMES
+ hostnames, err := genDNSHostnames(numOfHosts, groupName, clusterName, namespace, replicaIndex)
+ if err != nil {
+ return nil, err
+ }
+ klog.V(1).InfoS("mutatePod", "RayCluster", namespace+"/"+clusterName, "TPU_WORKER_HOSTNAMES", hostnames)
klog.V(1).InfoS("mutatePod", "RayCluster", namespace+"/"+clusterName, "subdomain", clusterName+"-"+headlessServiceSuffix)
- injectHostnames(clusterName, sliceToHostnames[podSlice], path, container, &patches)
+ injectHostnames(clusterName, hostnames, path, container, &patches)
}
// inject TPU_WORKER_ID
if getEnvironmentVariable("TPU_WORKER_ID", container) == "" {
@@ -545,18 +535,28 @@ func mutatePod(admissionReview *admissionv1.AdmissionReview) (*admissionv1.Admis

// update sliceToWorkers map on pod deletion
func deletePod(admissionReview *admissionv1.AdmissionReview) (*admissionv1.AdmissionResponse, error) {
+ // Create AdmissionResponse - we never deny the deletion request
+ admissionResponse := &admissionv1.AdmissionResponse{
+ UID: admissionReview.Request.UID,
+ Allowed: true,
+ Result: &metav1.Status{
+ Status: "Success",
+ Message: "",
+ },
+ }
+
pod, err := extractPod(admissionReview)
if err != nil {
- klog.Fatalf("Pod extraction failed: %s", err)
+ klog.Errorf("Pod extraction failed: %s", err)
}

clusterName := pod.Labels["ray.io/cluster"]
if clusterName == "" {
return nil, errors.New("Kuberay Pod missing RayCluster label")
return admissionResponse, errors.New("Kuberay Pod missing RayCluster label")
}
groupName := pod.Labels["ray.io/group"]
if groupName == "" {
return nil, errors.New("Kuberay Pod missing Ray group label")
return admissionResponse, errors.New("Kuberay Pod missing Ray group label")
}
namespace := pod.Namespace
replicaIndexLabel := pod.Labels["replicaIndex"]
@@ -567,20 +567,20 @@ func deletePod(admissionReview *admissionv1.AdmissionRevis

containers := pod.Spec.Containers
if containers == nil {
return nil, errors.New("Pod spec missing containers")
return admissionResponse, errors.New("Pod spec missing containers")
}
tpuWorkerID := -1
for _, container := range pod.Spec.Containers {
if containerRequestingTPUs(container) {
tpuWorkerID, err = strconv.Atoi(getEnvironmentVariable("TPU_WORKER_ID", container))
if err != nil {
return nil, errors.New("Unable to extract TPU_WORKER_ID")
return admissionResponse, errors.New("Unable to extract TPU_WORKER_ID")
}
break
}
}
if tpuWorkerID == -1 {
return nil, errors.New("Kuberay Pod missing TPU_WORKER_ID")
return admissionResponse, errors.New("Kuberay Pod missing TPU_WORKER_ID")
}
// update sliceToWorkers map
for slice, _ := range sliceToWorkers {
@@ -598,15 +598,6 @@ func deletePod(admissionReview *admissionv1.AdmissionRevis
}
}

- // Create AdmissionResponse - we never deny the deletion request
- admissionResponse := &admissionv1.AdmissionResponse{
- UID: admissionReview.Request.UID,
- Allowed: true,
- Result: &metav1.Status{
- Status: "Success",
- Message: "",
- },
- }
return admissionResponse, nil
}

@@ -621,7 +612,6 @@ func writeCertfile(filename string, encodedData string) error {

func init() {
sliceToWorkers = make(map[slice][]worker)
- sliceToHostnames = make(map[slice]string)

flag.StringVar(&BindAddr, "bind-address", ":443", "Address to bind HTTPS service to")
flag.StringVar(&CACert, "ca-cert", "", "base64-encoded root certificate for TLS")
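Net effect of the deletePod changes above: the Allowed response is built before anything can fail, so missing labels or an unparsable TPU_WORKER_ID surface as returned errors without ever blocking the deletion. A hedged sketch of how a caller might use the new contract (the serving handler is not part of this diff; names are illustrative):

response, err := deletePod(admissionReview)
if err != nil {
	// Log the bookkeeping failure, but still answer the AdmissionReview;
	// response.Allowed is true even on the error paths above.
	klog.Errorf("deletePod: %s", err)
}
writeAdmissionResponse(w, response) // hypothetical helper, not shown in this diff
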
43 changes: 3 additions & 40 deletions ray-on-gke/tpu/kuberay-tpu-webhook/webhook_main_test.go
@@ -303,16 +303,6 @@ func setupTest(t *testing.T) {
}
}

- // helper function used by tests which mutate sliceToHostnames
- func deepCopySliceToHostnames() map[slice]string {
- deepCopy := make(map[slice]string)
- for slice, hostnames := range sliceToHostnames {
- deepCopy[slice] = hostnames
- }
-
- return deepCopy
- }
-
// helper function used by tests which mutate sliceToWorkers
func deepCopySliceToWorkers() map[slice][]worker {
deepCopy := make(map[slice][]worker)
@@ -895,30 +885,26 @@ func Test_GenDNSHostnames(t *testing.T) {
setupTest(t)

tests := map[string]struct {
- workerGroupSpec *rayv1.WorkerGroupSpec
replicaIndex int
numOfHosts int32
expectedHostnames string
expectedError error
}{
"genDNSHostnames with NumOfHosts == 0": {
- // you can't have a workergroup with NumOfHosts set to 0 so this should error out
- workerGroupSpec: testWorkerGroupSpec,
+ // a workergroup can't have NumOfHosts set to 0 so this should error out
replicaIndex: 0,
numOfHosts: int32(0),
expectedError: errors.New("workerGroupSpec NumOfHosts not set"),
},
"genDNSHostnames with NumOfHosts == 1": {
// Single-host worker group, should return a single DNS hostname. This function will
// never be called for single-host groups, but we don't necessarily want it to error if it does.
- workerGroupSpec: testWorkerGroupSpec,
replicaIndex: 0,
numOfHosts: int32(1),
expectedHostnames: fmt.Sprintf("%s-%d-%d.%s-%s", groupNameStr, 0, 0, instanceName, headlessServiceSuffix),
},
"genDNSHostnames with NumOfHosts > 1": {
// multi-host worker group, should return a string list of DNS hostnames for the given replica
- workerGroupSpec: testWorkerGroupSpec,
replicaIndex: 1,
numOfHosts: int32(4),
expectedHostnames: strings.Join([]string{fmt.Sprintf("%s-%d-%d.%s-%s", groupNameStr, 1, 0, instanceName, headlessServiceSuffix),
@@ -932,8 +918,7 @@ func Test_GenDNSHostnames(t *testing.T) {
// validate that genDNSHostnames correctly returns a string list of DNS addressable hostnames
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
- tc.workerGroupSpec.NumOfHosts = tc.numOfHosts
- hostnames, err := genDNSHostnames(*tc.workerGroupSpec, instanceName, namespaceStr, tc.replicaIndex)
+ hostnames, err := genDNSHostnames(tc.numOfHosts, groupNameStr, instanceName, namespaceStr, tc.replicaIndex)
if err != nil {
assert.Equal(t, tc.expectedError, err)
} else {
@@ -1241,7 +1226,6 @@ func Test_ValidateRayCluster(t *testing.T) {
// check validateRayCluster
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
- sliceToHostnamesCopy := deepCopySliceToHostnames()
sliceToWorkersCopy := deepCopySliceToWorkers()

// set up admissionReview object
@@ -1272,22 +1256,7 @@ func Test_ValidateRayCluster(t *testing.T) {
assert.Equal(t, tc.expectedResult.Message, admissionResponse.Result.Message)
}

- // check that sliceToHostnames entry is generated
- if tc.topology != "" && tc.numOfHosts > 1 {
- for replicaIndex := 0; replicaIndex < int(*tc.replicas); replicaIndex++ {
- // generate TPU_WORKER_HOSTNAME values
- var expectedHostnames []string
- for hostIndex := 0; hostIndex < int(tc.numOfHosts); hostIndex++ {
- expectedHostnames = append(expectedHostnames, fmt.Sprintf("%s-%d-%d.%s-%s", groupNameStr, replicaIndex, hostIndex, instanceName, headlessServiceSuffix))
- }
- // check that expectedHostnames have been set for each slice
- testSlice := slice{instanceName, groupNameStr, namespaceStr, replicaIndex, tc.numOfHosts}
- assert.Equal(t, strings.Join(expectedHostnames, ","), sliceToHostnames[testSlice])
- }
- }
-
- // set maps back to their previous values
- sliceToHostnames = sliceToHostnamesCopy
+ // reset map previous values
sliceToWorkers = sliceToWorkersCopy
})
}
@@ -1418,11 +1387,6 @@ func Test_MutatePod(t *testing.T) {
t.Run(name, func(t *testing.T) {
// save copy of sliceToWorkers
sliceToWorkersCopy := deepCopySliceToWorkers()
- sliceToHostnamesCopy := deepCopySliceToHostnames()
-
- // set sliceToHostnames value to be injected during mutatePod
- testSlice := slice{instanceName, groupNameStr, namespaceStr, tc.expectedReplicaID, tc.numOfHosts}
- sliceToHostnames[testSlice] = tc.expectedHostnames

// set up Pod object
if tc.missingClusterLabel {
@@ -1478,7 +1442,6 @@ func Test_MutatePod(t *testing.T) {
}
// reset map values after test
sliceToWorkers = sliceToWorkersCopy
- sliceToHostnames = sliceToHostnamesCopy
}
})
}
