Skip to content

Commit

Permalink
Add inter-pod anti-affinity rule in master node spec to schedule node…
Browse files Browse the repository at this point in the history
…s on different cluster GPU/CPU nodes
  • Loading branch information
abhijeet-dhumal committed Dec 20, 2024
1 parent c5686c8 commit 4b92b1d
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 77 deletions.
98 changes: 23 additions & 75 deletions tests/kfto/kfto_mnist_training_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,13 @@ package kfto
import (
"bytes"
"fmt"
"os"
"testing"

kftov1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
. "github.com/onsi/gomega"
. "github.com/project-codeflare/codeflare-common/support"

corev1 "k8s.io/api/core/v1"
storagev1 "k8s.io/api/storage/v1"
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)
Expand All @@ -47,32 +45,12 @@ func TestPyTorchJobMnistWithROCm(t *testing.T) {
func runKFTOPyTorchMnistJob(t *testing.T, numGpus int, workerReplicas int, gpuLabel string, image string, requirementsFile string) {
test := With(t)

storageClasses, err := test.Client().Core().StorageV1().StorageClasses().List(test.Ctx(), metav1.ListOptions{})
test.Expect(err).NotTo(HaveOccurred(), "Failed to list StorageClasses")

// Verify at least one StorageClass supports RWX
foundRWX := false
var storageClassWithRWX storagev1.StorageClass
for _, sc := range storageClasses.Items {
// Check whether the StorageClass provisioner is one known to support RWX
if checkStorageClassSupportsRWX(sc) {
foundRWX = true
storageClassWithRWX = sc
break
}
}
test.Expect(foundRWX).To(BeTrue(), "No StorageClass found with RWX access mode")

// Create a namespace
namespace := test.NewTestNamespace()

workingDirectory, err := os.Getwd()
test.Expect(err).ToNot(HaveOccurred())
mnist := ReadFile(test, "resources/mnist.py")
requirementsFileName := ReadFile(test, requirementsFile)

mnist, err := os.ReadFile(workingDirectory + "/resources/mnist.py")
test.Expect(err).ToNot(HaveOccurred())

requirementsFileName, err := os.ReadFile(workingDirectory + "/" + requirementsFile)
if numGpus > 0 {
mnist = bytes.Replace(mnist, []byte("accelerator=\"has to be specified\""), []byte("accelerator=\"gpu\""), 1)
} else {
Expand All @@ -84,7 +62,7 @@ func runKFTOPyTorchMnistJob(t *testing.T, numGpus int, workerReplicas int, gpuLa
"requirements.txt": requirementsFileName,
})

outputPvc := CreatePersistentVolumeClaimWithStorageClass(test, namespace.Name, storageClassWithRWX, "50Gi", corev1.ReadWriteMany)
outputPvc := CreatePersistentVolumeClaim(test, namespace.Name, "50Gi", corev1.ReadWriteOnce)
defer test.Client().Core().CoreV1().PersistentVolumeClaims(namespace.Name).Delete(test.Ctx(), outputPvc.Name, metav1.DeleteOptions{})

// Create training PyTorch job
Expand Down Expand Up @@ -126,7 +104,26 @@ func createKFTOPyTorchMnistJob(test Test, namespace string, config corev1.Config
Replicas: Ptr(int32(1)),
RestartPolicy: kftov1.RestartPolicyOnFailure,
Template: corev1.PodTemplateSpec{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
"app": "kfto-mnist",
},
},
Spec: corev1.PodSpec{
Affinity: &corev1.Affinity{
PodAntiAffinity: &corev1.PodAntiAffinity{
RequiredDuringSchedulingIgnoredDuringExecution: []corev1.PodAffinityTerm{
{
LabelSelector: &metav1.LabelSelector{
MatchLabels: map[string]string{
"app": "kfto-mnist",
},
},
TopologyKey: "kubernetes.io/hostname",
},
},
},
},
Containers: []corev1.Container{
{
Name: "pytorch",
Expand Down Expand Up @@ -217,7 +214,7 @@ func createKFTOPyTorchMnistJob(test Test, namespace string, config corev1.Config
"/bin/bash", "-c",
fmt.Sprintf(`mkdir -p /tmp/lib && export PYTHONPATH=$PYTHONPATH:/tmp/lib && \
pip install --no-cache-dir -r /mnt/files/requirements.txt --target=/tmp/lib && \
python /mnt/files/mnist.py --epochs 1 --save-model --output-path /mnt/output --backend %s`, backend),
python /mnt/files/mnist.py --epochs 1 --save-model --backend %s`, backend),
},
VolumeMounts: []corev1.VolumeMount{
{
Expand All @@ -228,10 +225,6 @@ func createKFTOPyTorchMnistJob(test Test, namespace string, config corev1.Config
Name: "tmp-volume",
MountPath: "/tmp",
},
{
Name: "output-volume",
MountPath: "/mnt/output",
},
},
},
},
Expand All @@ -252,14 +245,6 @@ func createKFTOPyTorchMnistJob(test Test, namespace string, config corev1.Config
EmptyDir: &corev1.EmptyDirVolumeSource{},
},
},
{
Name: "output-volume",
VolumeSource: corev1.VolumeSource{
PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{
ClaimName: outputPvcName,
},
},
},
},
RestartPolicy: corev1.RestartPolicyOnFailure,
},
Expand Down Expand Up @@ -307,40 +292,3 @@ func createKFTOPyTorchMnistJob(test Test, namespace string, config corev1.Config

return tuningJob
}

// checkStorageClassSupportsRWX reports whether sc uses one of the NFS
// provisioners that this test treats as ReadWriteMany-capable.
//
// Only sc.Provisioner is inspected; any provisioner other than the two
// listed below yields false.
func checkStorageClassSupportsRWX(sc storagev1.StorageClass) bool {
	switch sc.Provisioner {
	case "nfs.csi.k8s.io", "kubernetes.io/nfs":
		// NFS-backed provisioners usually support RWX by default.
		return true
	default:
		return false
	}
}

// CreatePersistentVolumeClaimWithStorageClass creates a PersistentVolumeClaim
// in namespace that is bound to the given storageClass and returns the object
// returned by the API server.
//
// The claim gets a generated name with the "pvc-" prefix, requests
// storageSize of storage (parsed with resource.MustParse, so the string must
// be a valid quantity), and uses the supplied accessMode values as its access
// modes. The create call's error is asserted to be nil through t.Expect, and
// the created claim's namespace and name are logged.
func CreatePersistentVolumeClaimWithStorageClass(t Test, namespace string, storageClass storagev1.StorageClass, storageSize string, accessMode ...corev1.PersistentVolumeAccessMode) *corev1.PersistentVolumeClaim {
	t.T().Helper()

	// Assemble the spec first so the claim literal below stays compact.
	spec := corev1.PersistentVolumeClaimSpec{
		AccessModes:      accessMode,
		StorageClassName: &storageClass.Name,
		Resources: corev1.VolumeResourceRequirements{
			Requests: corev1.ResourceList{
				corev1.ResourceStorage: resource.MustParse(storageSize),
			},
		},
	}

	claim := &corev1.PersistentVolumeClaim{
		TypeMeta: metav1.TypeMeta{
			APIVersion: corev1.SchemeGroupVersion.String(),
			Kind:       "PersistentVolumeClaim",
		},
		ObjectMeta: metav1.ObjectMeta{
			GenerateName: "pvc-",
			Namespace:    namespace,
		},
		Spec: spec,
	}

	claims := t.Client().Core().CoreV1().PersistentVolumeClaims(namespace)
	created, err := claims.Create(t.Ctx(), claim, metav1.CreateOptions{})
	t.Expect(err).NotTo(HaveOccurred())
	t.T().Logf("Created PersistentVolumeClaim %s/%s successfully", created.Namespace, created.Name)

	return created
}
2 changes: 0 additions & 2 deletions tests/kfto/resources/mnist.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from __future__ import print_function

import argparse
import os

Expand Down

0 comments on commit 4b92b1d

Please sign in to comment.