Support for Multiple Separate TPU Worker Groups per RayCluster (#467)
* Support for multiple separate TPU worker groups per RayCluster

* Add namespace to slice struct, logs, and comments

* Added unit tests for getReplicaIndex and getNextWorkerID

* Added two more test cases for edge cases

* Fixed comments
ryanaoleary authored and kfswain committed Apr 15, 2024
1 parent c989934 commit 77c3be2
Showing 4 changed files with 520 additions and 49 deletions.
2 changes: 2 additions & 0 deletions applications/ray/kuberay-tpu-webhook/go.mod
@@ -4,6 +4,7 @@ go 1.21

require (
github.com/ray-project/kuberay/ray-operator v1.1.0-rc.0
github.com/stretchr/testify v1.8.4
k8s.io/api v0.29.1
k8s.io/apimachinery v0.29.1
k8s.io/klog/v2 v2.120.1
@@ -37,6 +38,7 @@ require (
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_golang v1.18.0 // indirect
github.com/prometheus/client_model v0.5.0 // indirect
github.com/prometheus/common v0.45.0 // indirect
2 changes: 2 additions & 0 deletions applications/ray/kuberay-tpu-webhook/go.sum
@@ -53,6 +53,8 @@ github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+
github.com/ianlancetaylor/demangle v0.0.0-20210905161508-09a460cdf81d/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w=
github.com/imdario/mergo v0.3.12 h1:b6R2BslTbIEToALKP7LxUvijTsNI9TAe80pLWN2g/HU=
github.com/imdario/mergo v0.3.12/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA=
github.com/jarcoal/httpmock v1.2.0 h1:gSvTxxFR/MEMfsGrvRbdfpRUMBStovlSRLw0Ep1bwwc=
github.com/jarcoal/httpmock v1.2.0/go.mod h1:oCoTsnAz4+UoOUIf5lJOWV2QQIW5UoeUI6aM2YnWAZk=
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
121 changes: 72 additions & 49 deletions applications/ray/kuberay-tpu-webhook/main.go
@@ -24,6 +24,7 @@ import (
type slice struct {
clusterName string
groupName string
namespace string
replicaIndex int
numOfHosts int32
}
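The new namespace field makes slice, which keys the webhook's package-level sliceToWorkers and sliceToHostnames maps, unique per namespace, so identically named worker groups in different namespaces no longer collide. A minimal sketch of that keying with placeholder values (the cluster, group, and namespace names below are made up):

package main

import "fmt"

type slice struct {
    clusterName  string
    groupName    string
    namespace    string
    replicaIndex int
    numOfHosts   int32
}

func main() {
    // slice is comparable, so it can be used directly as a map key.
    sliceToHostnames := make(map[slice]string)
    a := slice{"ray-cluster", "tpu-group", "team-a", 0, 4}
    b := slice{"ray-cluster", "tpu-group", "team-b", 0, 4}
    sliceToHostnames[a] = "tpu-group-0-0.headless-svc"
    // Same cluster and group name, different namespace: distinct entries.
    fmt.Println(sliceToHostnames[a] == sliceToHostnames[b]) // false
}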
@@ -77,7 +78,7 @@ func containerRequestingTPUs(containers ...corev1.Container) bool {
return false
}

func getNumTPUHostsFromTopology(clusterName string, namespace string, topology string, acceleratorType string) (int32, error) {
func getNumTPUHostsFromTopology(clusterName string, groupName string, namespace string, topology string, acceleratorType string) (int32, error) {
if topology == "" {
return 0, errors.New("TPU topology not specified")
}
@@ -86,7 +87,7 @@ func getNumTPUHostsFromTopology(clusterName string, namespace string, topology s
for i := 0; i < len(topologyVals); i++ {
dim, err := strconv.Atoi(topologyVals[i])
if err != nil {
klog.ErrorS(err, "RayCluster", namespace+"/"+clusterName, "gke-tpu-topology", topology)
klog.ErrorS(err, "RayCluster", namespace+"/"+clusterName, "Worker Group", groupName, "gke-tpu-topology", topology)
return 0, err
}
chips *= dim
@@ -98,19 +99,19 @@ func getNumTPUHostsFromTopology(clusterName string, namespace string, topology s
// v5e TPU VMs can have 1, 4 or 8 chips
chipsPerHost, err := strconv.Atoi(acceleratorTypeValues[1])
if err != nil {
klog.ErrorS(err, "RayCluster", namespace+"/"+clusterName, "gke-tpu-accelerator", acceleratorType)
klog.ErrorS(err, "RayCluster", namespace+"/"+clusterName, "Worker Group", groupName, "gke-tpu-accelerator", acceleratorType)
return 0, err
}
chipsPerHost = min(chipsPerHost, 8) // max of 8 chips per host
}
hosts := int32(max(chips/chipsPerHost, 1))
klog.V(1).InfoS("getNumTPUHostsFromTopology", "RayCluster", namespace+"/"+clusterName, "hosts", hosts)
klog.V(1).InfoS("getNumTPUHostsFromTopology", "RayCluster", namespace+"/"+clusterName, "Worker Group", groupName, "hosts", hosts)
return hosts, nil
}

// check if request is for TPU multi-host
func isTPUMultiHost(clusterName string, namespace string, topology string, acceleratorType string) (bool, error) {
vms, err := getNumTPUHostsFromTopology(clusterName, namespace, topology, acceleratorType)
func isTPUMultiHost(clusterName string, groupName string, namespace string, topology string, acceleratorType string) (bool, error) {
vms, err := getNumTPUHostsFromTopology(clusterName, groupName, namespace, topology, acceleratorType)
if err != nil {
return false, err
}
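For reference, the host count is the product of the topology dimensions divided by the chips per VM and floored at 1, and a worker group is multi-host exactly when that count is greater than 1. A simplified sketch of the arithmetic, with chipsPerHost passed in directly rather than derived from the gke-tpu-accelerator value as the webhook does:

package main

import (
    "fmt"
    "strconv"
    "strings"
)

// hostsFromTopology multiplies the topology dimensions into a chip count and
// converts it to a number of TPU VMs (hosts), never returning less than 1.
func hostsFromTopology(topology string, chipsPerHost int) (int32, error) {
    chips := 1
    for _, dim := range strings.Split(topology, "x") {
        d, err := strconv.Atoi(dim)
        if err != nil {
            return 0, err
        }
        chips *= d
    }
    hosts := chips / chipsPerHost
    if hosts < 1 {
        hosts = 1
    }
    return int32(hosts), nil
}

func main() {
    hosts, _ := hostsFromTopology("2x2x4", 4) // e.g. a slice with 4 chips per VM
    fmt.Println(hosts, hosts > 1)             // 4 true -> multi-host
}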
@@ -133,7 +134,7 @@ func extractRayCluster(admissionReview *admissionv1.AdmissionReview) (*ray.RayCl
return &rayCluster, nil
}

func genDNSHostnames(workerGroupSpec ray.WorkerGroupSpec, replicaIndex int) (string, error) {
func genDNSHostnames(workerGroupSpec ray.WorkerGroupSpec, clusterName string, namespace string, replicaIndex int) (string, error) {
numHosts := workerGroupSpec.NumOfHosts
if numHosts == 0 {
return "", errors.New("workerGroupSpec NumOfHosts not set")
@@ -144,6 +145,7 @@ func genDNSHostnames(workerGroupSpec ray.WorkerGroupSpec, replicaIndex int) (str
for j := 0; j < int(numHosts); j++ {
hostNames[j] = fmt.Sprintf("%s-%d-%d.%s", workerGroupName, replicaIndex, j, headlessServiceName)
}
klog.V(1).InfoS("genDNSHostnames", "RayCluster", namespace+"/"+clusterName, "NumOfHosts", numHosts, "Replica Index", replicaIndex)
return strings.Join(hostNames, ","), nil
}
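The generated names follow <groupName>-<replicaIndex>-<hostIndex>.<headlessServiceName> and are joined with commas to form the TPU_WORKER_HOSTNAMES value for every pod in the slice. A small sketch of that pattern, with "headless-worker-svc" standing in for the webhook's headless service name:

package main

import (
    "fmt"
    "strings"
)

func main() {
    groupName, replicaIndex, numHosts := "tpu-group", 1, 2
    hostNames := make([]string, numHosts)
    for j := 0; j < numHosts; j++ {
        hostNames[j] = fmt.Sprintf("%s-%d-%d.%s", groupName, replicaIndex, j, "headless-worker-svc")
    }
    fmt.Println(strings.Join(hostNames, ","))
    // tpu-group-1-0.headless-worker-svc,tpu-group-1-1.headless-worker-svc
}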

@@ -218,6 +220,7 @@ func checkWorkersMatchTopology(clusterName string, namespace string, workerGroup
if numHosts == 0 {
return false, errors.New("workerGroupSpec NumOfHosts not set")
}
groupName := workerGroupSpec.GroupName
containers := workerGroupSpec.Template.Spec.Containers
if containers == nil {
return false, errors.New("Container path not specified")
@@ -227,12 +230,12 @@ func checkWorkersMatchTopology(clusterName string, namespace string, workerGroup
acceleratorType := workerGroupSpec.Template.Spec.NodeSelector["cloud.google.com/gke-tpu-accelerator"]
klog.V(1).InfoS("checkWorkersMatchTopology", "RayCluster", namespace+"/"+clusterName, "topology", topology, "AcceleratorType", acceleratorType, "NumOfHosts", numHosts)
if topology == "" {
klog.ErrorS(errors.New("TPU topology not specified"), "RayCluster", namespace+"/"+clusterName, "gke-tpu-topology", topology)
klog.ErrorS(errors.New("TPU topology not specified"), "checkWorkersMatchTopology", "RayCluster", namespace+"/"+clusterName, "gke-tpu-topology", topology)
}
if acceleratorType == "" {
klog.ErrorS(errors.New("TPU accelerator not specified"), "RayCluster", namespace+"/"+clusterName, "gke-tpu-accelerator", acceleratorType)
klog.ErrorS(errors.New("TPU accelerator not specified"), "checkWorkersMatchTopology", "RayCluster", namespace+"/"+clusterName, "gke-tpu-accelerator", acceleratorType)
}
expectedHosts, err := getNumTPUHostsFromTopology(clusterName, namespace, topology, acceleratorType)
expectedHosts, err := getNumTPUHostsFromTopology(clusterName, groupName, namespace, topology, acceleratorType)
if err != nil {
return false, err
}
@@ -263,23 +266,29 @@ func validateRayCluster(admissionReview *admissionv1.AdmissionReview) (*admissio
}
for i := 0; i < len(workerGroupSpecs); i++ {
workerGroupSpec := workerGroupSpecs[i]
// create mapping for pod slices -> TPU_WORKER_HOSTNAMES in cluster
replicas := int(*workerGroupSpec.Replicas)
numOfHosts := workerGroupSpec.NumOfHosts
for replicaIndex := 0; replicaIndex < replicas; replicaIndex++ {
// reset past sliceToWorkers and sliceToHostnames entries for slice in ray cluster
groupName := workerGroupSpec.GroupName
podSlice := slice{clusterName, groupName, replicaIndex, numOfHosts}
sliceToWorkers[podSlice] = nil
sliceToHostnames[podSlice] = ""
// generate TPU_WORKER_HOSTNAMES
if numOfHosts > 1 {
joinedHostNames, err := genDNSHostnames(workerGroupSpec, replicaIndex)
if err != nil {
klog.Error("Failed to generate DNS Hostnames")
if containerRequestingTPUs(workerGroupSpec.Template.Spec.Containers...) {
klog.V(0).InfoS("validateRayCluster", "RayCluster", namespace+"/"+clusterName, "Worker Group", workerGroupSpec.GroupName, "Requests TPUs", true)
// create mapping for pod slices -> TPU_WORKER_HOSTNAMES in cluster
replicas := int(*workerGroupSpec.Replicas)
numOfHosts := workerGroupSpec.NumOfHosts
for replicaIndex := 0; replicaIndex < replicas; replicaIndex++ {
// reset past sliceToWorkers and sliceToHostnames entries for slice in ray cluster
groupName := workerGroupSpec.GroupName
podSlice := slice{clusterName, groupName, namespace, replicaIndex, numOfHosts}
sliceToWorkers[podSlice] = nil
sliceToHostnames[podSlice] = ""
// generate TPU_WORKER_HOSTNAMES
if numOfHosts > 1 {
joinedHostNames, err := genDNSHostnames(workerGroupSpec, clusterName, namespace, replicaIndex)
if err != nil {
klog.Error("Failed to generate DNS Hostnames")
}
sliceToHostnames[podSlice] = joinedHostNames
}
sliceToHostnames[podSlice] = joinedHostNames
}
} else {
// RayCluster worker group does not request TPUs
klog.V(0).InfoS("validateRayCluster", "RayCluster", namespace+"/"+clusterName, "Worker Group", workerGroupSpec.GroupName, "Requests TPUs", false)
}
// validate NumOfHosts for worker group matches topology nodeSelector
workersMatchTopology, err := checkWorkersMatchTopology(clusterName, namespace, workerGroupSpec)
@@ -291,8 +300,8 @@ func validateRayCluster(admissionReview *admissionv1.AdmissionReview) (*admissio
admit = false
status = "Failure"
message = "Number of workers in worker group not equal to specified topology"
break
}
break
}

// Create AdmissionResponse
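With this change, validation only builds slice bookkeeping for worker groups whose containers actually request TPUs. A hedged sketch of that gate, assuming the containerRequestingTPUs helper (its body is outside this hunk) looks for the google.com/tpu resource on each container's requests and limits:

package main

import (
    "fmt"

    corev1 "k8s.io/api/core/v1"
    "k8s.io/apimachinery/pkg/api/resource"
)

// Sketch only: report whether any container asks for google.com/tpu.
func containerRequestingTPUs(containers ...corev1.Container) bool {
    for _, c := range containers {
        if limit, ok := c.Resources.Limits["google.com/tpu"]; ok && !limit.IsZero() {
            return true
        }
        if request, ok := c.Resources.Requests["google.com/tpu"]; ok && !request.IsZero() {
            return true
        }
    }
    return false
}

func main() {
    tpuWorker := corev1.Container{
        Name: "ray-worker",
        Resources: corev1.ResourceRequirements{
            Limits: corev1.ResourceList{"google.com/tpu": resource.MustParse("4")},
        },
    }
    fmt.Println(containerRequestingTPUs(tpuWorker)) // true -> group gets slice bookkeeping
}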
@@ -318,15 +327,28 @@ func getEnvironmentVariable(varName string, container corev1.Container) string {
return ""
}

// get next lowest-index pod slice to assign a pod to in the RayCluster
// this will be the first pod slice with # created pods < NumOfHosts
func getReplicaIndex(clusterName string, namespace string) int {
// gets the next lowest-index pod slice (worker group replica) to assign a pod to in the RayCluster
// there are three possible cases here:
// 1. sliceToWorkers is empty, this is the first pod the webhook intercepts
// - assign this pod to replica 0
// 2. The pod slice exists in sliceToWorkers, but has # created workers < NumOfHosts
// - assign this pod to the lowest index replica with # created workers < NumOfHosts
// - since we update isCreated when a worker is deleted, this allows us to assign re-created
// pods to the same replica
// 3. sliceToWorkers isn't empty, but all slices have # workers == NumOfHosts
// - this occurs when the pod we intercept is the first pod of a different slice in the cluster
// - we keep track of how many replicas of the same worker group have been added to sliceToWorkers
// so far, and assign this pod to the next integer replicaIndex
func getReplicaIndex(clusterName string, groupName string, namespace string) int {
// first pod created in cluster
if sliceToWorkers == nil {
return 0
}
nextLowestId := math.MaxInt32
numReplicas := 0 // tracks # of replicas in worker group created so far
for slice, workerList := range sliceToWorkers {
if slice.clusterName == clusterName {
if slice.clusterName == clusterName && slice.groupName == groupName && slice.namespace == namespace {
numReplicas++
createdPods := 0
for _, worker := range workerList {
if worker.isCreated {
@@ -340,10 +362,11 @@ func getReplicaIndex(clusterName string, namespace string) int {
}
}
}
// first pod of new slice in cluster
if nextLowestId == math.MaxInt32 {
klog.ErrorS(errors.New("Replica Index never set"), "RayCluster", namespace+"/"+clusterName, "Replica Index", nextLowestId)
nextLowestId = numReplicas
}
klog.V(0).InfoS("getReplicaIndex", "RayCluster", namespace+"/"+clusterName, "Replica Index", nextLowestId)
klog.V(0).InfoS("getReplicaIndex", "RayCluster", namespace+"/"+clusterName, "Worker Group", groupName, "Replica Index", nextLowestId)
return nextLowestId
}
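A simplified re-implementation of the selection logic for illustration; sliceKey and the created-pod counts below stand in for the webhook's sliceToWorkers state and are not its exact code. The pod goes to the lowest-index slice of this cluster, group, and namespace that still has fewer created pods than NumOfHosts, otherwise it opens the next integer replica index:

package main

import (
    "fmt"
    "math"
)

type sliceKey struct {
    clusterName, groupName, namespace string
    replicaIndex                      int
    numOfHosts                        int32
}

func nextReplicaIndex(created map[sliceKey]int, cluster, group, ns string) int {
    nextLowest, numReplicas := math.MaxInt32, 0
    for k, createdPods := range created {
        if k.clusterName == cluster && k.groupName == group && k.namespace == ns {
            numReplicas++
            if createdPods < int(k.numOfHosts) && k.replicaIndex < nextLowest {
                nextLowest = k.replicaIndex
            }
        }
    }
    if nextLowest == math.MaxInt32 {
        return numReplicas // first pod of a brand-new slice for this group
    }
    return nextLowest
}

func main() {
    created := map[sliceKey]int{
        {"rc", "tpu-group", "default", 0, 2}: 2, // replica 0 is full
        {"rc", "tpu-group", "default", 1, 2}: 1, // replica 1 still has room
    }
    fmt.Println(nextReplicaIndex(created, "rc", "tpu-group", "default")) // 1
}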

@@ -379,7 +402,7 @@ func getNextWorkerID(podSlice slice, namespace string, replicaIndex int) int {
}
tpuWorkerID = nextLowestID
}
klog.V(0).InfoS("getNextWorkerID", "RayCluster", namespace+"/"+podSlice.clusterName, "TPU_WORKER_ID", tpuWorkerID)
klog.V(0).InfoS("getNextWorkerID", "RayCluster", namespace+"/"+podSlice.clusterName, "Worker Group", podSlice.groupName, "TPU_WORKER_ID", tpuWorkerID)
return tpuWorkerID
}
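getNextWorkerID applies the same idea within a single slice: the next TPU_WORKER_ID is the lowest index not held by a currently created worker, so a recreated pod takes back the ID of the pod it replaces, and the first pod of a single-host slice gets 0. A simplified illustration (worker mirrors the webhook's struct, but the selection below is not its exact code):

package main

import "fmt"

type worker struct {
    workerIndex int
    isCreated   bool
}

// nextWorkerID returns the lowest ID not used by a created worker.
func nextWorkerID(workers []worker) int {
    inUse := make(map[int]bool)
    for _, w := range workers {
        if w.isCreated {
            inUse[w.workerIndex] = true
        }
    }
    for id := 0; ; id++ {
        if !inUse[id] {
            return id
        }
    }
}

func main() {
    // Worker 1 was deleted (isCreated == false), so its ID is handed out next.
    workers := []worker{{0, true}, {1, false}, {2, true}, {3, true}}
    fmt.Println(nextWorkerID(workers)) // 1
}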

@@ -417,31 +440,31 @@ func mutatePod(admissionReview *admissionv1.AdmissionReview) (*admissionv1.Admis
if clusterName == "" {
return nil, errors.New("Kuberay Pod missing RayCluster label")
}
namespace := pod.Namespace
groupName := pod.Labels["ray.io/group"]
topology := pod.Spec.NodeSelector["cloud.google.com/gke-tpu-topology"]
acceleratorType := pod.Spec.NodeSelector["cloud.google.com/gke-tpu-accelerator"]
if topology == "" {
klog.ErrorS(errors.New("TPU topology not specified"), "RayCluster", namespace+"/"+clusterName, "gke-tpu-topology", topology)
}
if acceleratorType == "" {
klog.ErrorS(errors.New("TPU accelerator not specified"), "RayCluster", namespace+"/"+clusterName, "gke-tpu-accelerator", acceleratorType)
}
containers := pod.Spec.Containers
if containers == nil {
return nil, errors.New("Container path not specified")
}
if containerRequestingTPUs(containers...) {
namespace := pod.Namespace
groupName := pod.Labels["ray.io/group"]
topology := pod.Spec.NodeSelector["cloud.google.com/gke-tpu-topology"]
acceleratorType := pod.Spec.NodeSelector["cloud.google.com/gke-tpu-accelerator"]
if topology == "" {
klog.ErrorS(errors.New("TPU topology not specified"), "mutatePod", "RayCluster", namespace+"/"+clusterName, "gke-tpu-topology", topology)
}
if acceleratorType == "" {
klog.ErrorS(errors.New("TPU accelerator not specified"), "mutatePod", "RayCluster", namespace+"/"+clusterName, "gke-tpu-accelerator", acceleratorType)
}
// assign worker to the next unique ID in the pod slice and update map
numOfHosts, _ := getNumTPUHostsFromTopology(clusterName, namespace, topology, acceleratorType) // ignore error here because topology may not be set yet
replicaIndex := getReplicaIndex(clusterName, namespace)
podSlice := slice{clusterName, groupName, replicaIndex, numOfHosts}
numOfHosts, _ := getNumTPUHostsFromTopology(clusterName, groupName, namespace, topology, acceleratorType) // ignore error here because topology may not be set yet
replicaIndex := getReplicaIndex(clusterName, groupName, namespace)
podSlice := slice{clusterName, groupName, namespace, replicaIndex, numOfHosts}
tpuWorkerID := getNextWorkerID(podSlice, namespace, replicaIndex) // defaults to 0 for single-host

// inject replica index label
injectReplicaLabel(clusterName, namespace, replicaIndex, groupName, &patches)

isMultiHost, _ := isTPUMultiHost(clusterName, namespace, topology, acceleratorType) // ignore error here because topology may not be set yet
isMultiHost, _ := isTPUMultiHost(clusterName, groupName, namespace, topology, acceleratorType) // ignore error here because topology may not be set yet
if isMultiHost {
// inject hostname into pod spec for DNS records
hostname := fmt.Sprintf(groupName+"-%d-%d", replicaIndex, tpuWorkerID)
@@ -545,7 +568,7 @@ func deletePod(admissionReview *admissionv1.AdmissionReview) (*admissionv1.Admis
if replicaIndexLabel != "" {
replicaIndexLabelValues := strings.Split(replicaIndexLabel, "-")
replicaIndex, _ := strconv.Atoi(replicaIndexLabelValues[1]) // ignore error here since must be set

containers := pod.Spec.Containers
if containers == nil {
return nil, errors.New("Pod spec missing containers")
@@ -565,7 +588,7 @@ func deletePod(admissionReview *admissionv1.AdmissionReview) (*admissionv1.Admis
}
// update sliceToWorkers map
for slice, _ := range sliceToWorkers {
if slice.clusterName == clusterName && slice.groupName == groupName && slice.replicaIndex == replicaIndex {
if slice.clusterName == clusterName && slice.groupName == groupName && slice.namespace == namespace && slice.replicaIndex == replicaIndex {
// set the pod state to indicate it is not running
for index, worker := range sliceToWorkers[slice] {
if worker.workerIndex == tpuWorkerID {
(Diff for the fourth changed file was not loaded.)
