Skip to content

Commit

Permalink
Add in check for v5e TPU pods
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanaoleary committed Mar 8, 2024
1 parent 2d53056 commit 8a74e2c
Showing 1 changed file with 36 additions and 12 deletions.
48 changes: 36 additions & 12 deletions applications/ray/kuberay-tpu-webhook/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ import (

// our representation of a pod slice
type slice struct {
clusterName string
groupName string
replicaIndex int
numOfHosts int32
clusterName string
groupName string
replicaIndex int
numOfHosts int32
}

// our representation of a worker pod
Expand Down Expand Up @@ -77,7 +77,7 @@ func containerRequestingTPUs(containers ...corev1.Container) bool {
return false
}

func getNumTPUHostsFromTopology(topology string) (int32, error) {
func getNumTPUHostsFromTopology(topology string, acceleratorType string) (int32, error) {
if topology == "" {
return 0, errors.New("TPU topology not specified")
}
Expand All @@ -91,13 +91,29 @@ func getNumTPUHostsFromTopology(topology string) (int32, error) {
}
chips *= dim
}
// number VMs = number chips / 4
return int32(max(chips/4, 1)), nil
// calculate the # of VMs using # of chips per host
acceleratorTypeValues := strings.Split(acceleratorType, "-")
var vms int32
if acceleratorTypeValues[0] == "v5litepod" {
// v5e TPU VMs can have 1, 4 or 8 chips
chipsPerHost, err := strconv.Atoi(acceleratorTypeValues[1])
if err != nil {
klog.Errorf("Unexpected acceleratorType: %s", acceleratorType)
}
if chipsPerHost > 8 {
chipsPerHost = 4
}
vms = int32(max(chips/chipsPerHost, 1))
} else {
// otherwise default to 4 chips per VM
vms = int32(max(chips/4, 1))
}
return vms, nil
}

// check if request is for TPU multi-host
func isTPUMultiHost(topology string) (bool, error) {
vms, err := getNumTPUHostsFromTopology(topology)
func isTPUMultiHost(topology string, acceleratorType string) (bool, error) {
vms, err := getNumTPUHostsFromTopology(topology, acceleratorType)
if err != nil {
return false, err
}
Expand Down Expand Up @@ -207,10 +223,14 @@ func checkWorkersMatchTopology(workerGroupSpec ray.WorkerGroupSpec) (bool, error
}
if containerRequestingTPUs(containers...) {
topology := workerGroupSpec.Template.Spec.NodeSelector["cloud.google.com/gke-tpu-topology"]
acceleratorType := workerGroupSpec.Template.Spec.NodeSelector["cloud.google.com/gke-tpu-accelerator"]
if topology == "" {
klog.Error("TPU topology not specified")
}
expectedHosts, err := getNumTPUHostsFromTopology(topology)
if acceleratorType == "" {
klog.Error("TPU accelerator type not specified")
}
expectedHosts, err := getNumTPUHostsFromTopology(topology, acceleratorType)
if err != nil {
return false, err
}
Expand Down Expand Up @@ -387,21 +407,25 @@ func mutatePod(admissionReview *admissionv1.AdmissionReview) (*admissionv1.Admis
clusterName := pod.Labels["ray.io/cluster"]
groupName := pod.Labels["ray.io/group"]
topology := pod.Spec.NodeSelector["cloud.google.com/gke-tpu-topology"]
acceleratorType := pod.Spec.NodeSelector["cloud.google.com/gke-tpu-accelerator"]
if topology == "" {
klog.Error("TPU topology not specified")
}
if acceleratorType == "" {
klog.Error("TPU accelerator type not specified")
}
containers := pod.Spec.Containers
if containers == nil {
return nil, errors.New("Container path not specified")
}
if containerRequestingTPUs(containers...) {
// assign worker to the next unique ID in the pod slice and update map
numOfHosts, _ := getNumTPUHostsFromTopology(topology) // ignore error here because topology may not be set yet
numOfHosts, _ := getNumTPUHostsFromTopology(topology, acceleratorType) // ignore error here because topology may not be set yet
replicaIndex := getReplicaIndex(clusterName)
podSlice := slice{clusterName, groupName, replicaIndex, numOfHosts}
tpuWorkerID := getNextWorkerID(podSlice, replicaIndex)

isMultiHost, _ := isTPUMultiHost(topology) // ignore error here because topology may not be set yet
isMultiHost, _ := isTPUMultiHost(topology, acceleratorType) // ignore error here because topology may not be set yet
if isMultiHost {
// inject hostname into pod spec for DNS records
hostname := fmt.Sprintf(groupName+"-%d-%d", replicaIndex, tpuWorkerID)
Expand Down

0 comments on commit 8a74e2c

Please sign in to comment.