diff --git a/tests/odh/mnist_ray_test.go b/tests/odh/mnist_ray_test.go index 0d88e6e6..6e86151b 100644 --- a/tests/odh/mnist_ray_test.go +++ b/tests/odh/mnist_ray_test.go @@ -33,29 +33,45 @@ import ( ) func TestMnistRayCpu(t *testing.T) { - mnistRay(t, 0) + mnistRay(t, 0, false) } func TestMnistRayGpu(t *testing.T) { - mnistRay(t, 1) + mnistRay(t, 1, false) +} + +func TestMnistRayAMDGpu(t *testing.T) { + mnistRay(t, 1, true) } func TestMnistCustomRayImageCpu(t *testing.T) { - mnistRay(t, 0) + mnistRay(t, 0, false) } func TestMnistCustomRayImageGpu(t *testing.T) { - mnistRay(t, 1) + mnistRay(t, 1, false) } -func mnistRay(t *testing.T, numGpus int) { +func mnistRay(t *testing.T, numGpus int, amd bool) { test := With(t) // Create a namespace namespace := test.NewTestNamespace() + var gpuResource string + if amd { + gpuResource = "amd.com/gpu" + } else { + gpuResource = "nvidia.com/gpu" + } + // Get ray image - rayImage := GetRayImage() + var rayImage string + if amd { + rayImage = GetRayAMDGpuImage() + } else { + rayImage = GetRayImage() + } // Create Kueue resources resourceFlavor := CreateKueueResourceFlavor(test, v1beta1.ResourceFlavorSpec{}) @@ -64,7 +80,7 @@ func mnistRay(t *testing.T, numGpus int) { NamespaceSelector: &metav1.LabelSelector{}, ResourceGroups: []v1beta1.ResourceGroup{ { - CoveredResources: []corev1.ResourceName{corev1.ResourceName("cpu"), corev1.ResourceName("memory"), corev1.ResourceName("nvidia.com/gpu")}, + CoveredResources: []corev1.ResourceName{corev1.ResourceName("cpu"), corev1.ResourceName("memory"), corev1.ResourceName(gpuResource)}, Flavors: []v1beta1.FlavorQuotas{ { Name: v1beta1.ResourceFlavorReference(resourceFlavor.Name), @@ -78,7 +94,7 @@ func mnistRay(t *testing.T, numGpus int) { NominalQuota: resource.MustParse("12Gi"), }, { - Name: corev1.ResourceName("nvidia.com/gpu"), + Name: corev1.ResourceName(gpuResource), NominalQuota: resource.MustParse(fmt.Sprint(numGpus)), }, }, @@ -99,9 +115,13 @@ func mnistRay(t *testing.T, numGpus int) { } else { mnist = bytes.Replace(mnist, []byte("accelerator=\"has to be specified\""), []byte("accelerator=\"cpu\""), 1) } + jupyterNotebook := ReadFile(test, "resources/mnist_ray_mini.ipynb") + if amd { + jupyterNotebook = bytes.ReplaceAll(jupyterNotebook, []byte("nvidia.com/gpu"), []byte("amd.com/gpu")) + } config := CreateConfigMap(test, namespace.Name, map[string][]byte{ // MNIST Ray Notebook - jupyterNotebookConfigMapFileName: ReadFile(test, "resources/mnist_ray_mini.ipynb"), + jupyterNotebookConfigMapFileName: jupyterNotebook, "mnist.py": mnist, "requirements.txt": ReadFile(test, "resources/requirements.txt"), })