From 911b9c2bf9bad5883e14ddca8babedf212a35083 Mon Sep 17 00:00:00 2001
From: Nick Stogner
Date: Wed, 6 Mar 2024 14:21:30 -0500
Subject: [PATCH] Update test cases to include reuse assertions

---
 tpu-provisioner/test/e2e/test/jobset_test.go | 178 +++++++++++++------
 1 file changed, 120 insertions(+), 58 deletions(-)

diff --git a/tpu-provisioner/test/e2e/test/jobset_test.go b/tpu-provisioner/test/e2e/test/jobset_test.go
index f60052861..4e55099e0 100644
--- a/tpu-provisioner/test/e2e/test/jobset_test.go
+++ b/tpu-provisioner/test/e2e/test/jobset_test.go
@@ -38,72 +38,82 @@ const (
 	testCaseLabel = "test-case"
 )
 
+var (
+	// Other TPU configurations, currently unused:
+	/*
+		tpu_v4_2x2x2 = tpuConfig{
+			accelerator:  "tpu-v4-podslice",
+			topoX:        2,
+			topoY:        2,
+			topoZ:        2,
+			chipsPerNode: 4,
+			sliceCount:   1,
+		}
+		tpu_v4_2x2x4 = tpuConfig{
+			accelerator:  "tpu-v4-podslice",
+			topoX:        2,
+			topoY:        2,
+			topoZ:        4,
+			chipsPerNode: 4,
+			sliceCount:   1,
+		}
+
+		tpu_v5e_2x4 = tpuConfig{
+			accelerator:  "tpu-v5-lite-podslice",
+			topoX:        2,
+			topoY:        4,
+			chipsPerNode: 4,
+			sliceCount:   2,
+		}
+	*/
+
+	tpu_v5p_2x2x2 = tpuConfig{
+		accelerator:  "tpu-v5p-slice",
+		topoX:        2,
+		topoY:        2,
+		topoZ:        2,
+		chipsPerNode: 4,
+		sliceCount:   1,
+	}
+)
+
 func TestTPUJobsets(t *testing.T) {
-    var (
-        spot        = os.Getenv("TEST_SPOT") == "true"
-        reservation = os.Getenv("TEST_RESERVATION")
-    )
+	var (
+		spot        = os.Getenv("TEST_SPOT") == "true"
+		reservation = os.Getenv("TEST_RESERVATION")
+	)
 
 	cases := []struct {
-		name   string
-		config tpuConfig
+		name string
+		tpu  tpuConfig
+
+		// uniqueNodeSelector forces the case onto its own, newly created node pool.
+		uniqueNodeSelector bool
+		// shouldReuse expects the case to land on a previously created node pool.
+		shouldReuse bool
 	}{
-		// v4
-		// {
-		// 	name: "v4-2x2x2-tpu",
-		// 	config: tpuConfig{
-		// 		accelerator:  "tpu-v4-podslice",
-		// 		topoX:        2,
-		// 		topoY:        2,
-		// 		topoZ:        2,
-		// 		chipsPerNode: 4,
-		// 		sliceCount:   1,
-		// 	},
-		// },
-		// {
-		// 	name: "v4-2x2x4-tpu",
-		// 	config: tpuConfig{
-		// 		accelerator:  "tpu-v4-podslice",
-		// 		topoX:        2,
-		// 		topoY:        2,
-		// 		topoZ:        4,
-		// 		chipsPerNode: 4,
-		// 		sliceCount:   1,
-		// 	},
-		// },
-		// v5e
-		/*
-			{
-				name: "v5e-2x4-tpu",
-				config: tpuConfig{
-					accelerator:  "tpu-v5-lite-podslice",
-					topoX:        2,
-					topoY:        4,
-					chipsPerNode: 4,
-					sliceCount:   2,
-				},
-			},
-		*/
-		// v5p
 		{
-			name: "v5p-2x2x2-tpu",
-			config: tpuConfig{
-				accelerator:  "tpu-v5p-slice",
-				topoX:        2,
-				topoY:        2,
-				topoZ:        2,
-				chipsPerNode: 4,
-				sliceCount:   1,
-			},
+			name:               "first-unique",
+			tpu:                tpu_v5p_2x2x2,
+			uniqueNodeSelector: true,
+		},
+		{
+			name:               "second-unique",
+			tpu:                tpu_v5p_2x2x2,
+			uniqueNodeSelector: true,
+		},
+		{
+			name:        "third-reuse",
+			tpu:         tpu_v5p_2x2x2,
+			shouldReuse: true,
 		},
 	}
 
+	// Node pools created by earlier cases, used to assert reuse in later cases.
+	historicalNodePools := map[string]struct{}{}
+
 	for _, c := range cases {
 		t.Run(c.name, func(t *testing.T) {
-			c.config.spot = spot
-			c.config.reservation = reservation
+			c.tpu.spot = spot
+			c.tpu.reservation = reservation
 
-			js := newJobset(c.name, c.config)
+			js := newJobset(c.name, c.tpu, c.uniqueNodeSelector)
 			err := client.Create(ctx, js)
 			require.NoError(t, err)
 			util.EnsureCleanup(t, func() {
@@ -111,6 +121,34 @@ func TestTPUJobsets(t *testing.T) {
 				require.NoError(t, err)
 			})
 
+			// Wait for this case's Pods to be scheduled and record the node
+			// pool they landed on.
+			var nodePoolName string
+			require.EventuallyWithT(t, func(t *assert.CollectT) {
+				var pods v1.PodList
+				err := client.List(ctx, &pods, runtimeclient.MatchingLabels{testCaseLabel: c.name})
+				assert.NoError(t, err, "Failed to list pods")
+				for _, pod := range pods.Items {
+					var err error
+					nodePoolName, err = podToNodePoolName(&pod)
+					if err != nil {
+						t.Errorf("pod to node pool name: %v", err)
+						return
+					}
+					if nodePoolName != "" {
+						return
+					}
+				}
+				t.Errorf("no pods scheduled on node pool")
+			}, jobsetCompletionTimeout, time.Second, "Pods not scheduled")
time.Second, "Pods not scheduled") + + require.NotEmpty(t, nodePoolName, "No node pool name found") + if c.shouldReuse { + require.Contains(t, historicalNodePools, nodePoolName, "Should reuse a previously created node pool") + } + if c.uniqueNodeSelector { + require.NotContains(t, historicalNodePools, nodePoolName, "Expected new node pool to be created") + } + historicalNodePools[nodePoolName] = struct{}{} + // Example completed JobSet status: // // status: @@ -134,16 +172,33 @@ func TestTPUJobsets(t *testing.T) { "JobSet is not completed") }, jobsetCompletionTimeout, time.Second, "JobSet did not complete") + }) + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { require.EventuallyWithT(t, func(t *assert.CollectT) { var nodeList v1.NodeList err := client.List(ctx, &nodeList, runtimeclient.MatchingLabels{testCaseLabel: c.name}) assert.NoError(t, err, "Failed to list Nodes") assert.Len(t, nodeList.Items, 0, "Nodes still exist with test case label") - }, nodeDeletionTimeout, time.Second, "Nodes were not deleted after JobSet completed") + }, nodeDeletionTimeout, time.Second, "Nodes were not deleted") }) } } +func podToNodePoolName(pod *v1.Pod) (string, error) { + if pod.Spec.NodeName == "" { + return "", fmt.Errorf("pod %s/%s has no node name", pod.Namespace, pod.Name) + } + var node v1.Node + if err := client.Get(ctx, runtimeclient.ObjectKey{Name: pod.Spec.NodeName}, &node); err != nil { + return "", fmt.Errorf("getting node for pod: %w", err) + } + + return node.Labels["cloud.google.com/gke-nodepool"], nil +} + /* https://cloud.google.com/tpu/docs/tpus-in-gke#v5e @@ -191,13 +246,15 @@ func (t *tpuConfig) nodesPerSlice() int32 { return t.topoX * t.topoY * z / t.chipsPerNode } -func newJobset(name string, c tpuConfig) *jobset.JobSet { +func newJobset(name string, c tpuConfig, uniqueNodeSelector bool) *jobset.JobSet { nodeSelectors := map[string]string{ "cloud.google.com/gke-tpu-accelerator": c.accelerator, "cloud.google.com/gke-tpu-topology": c.topology(), // Ensure each test case triggers its down node pool scale-up. - testCaseLabel: name, + } + if uniqueNodeSelector { + nodeSelectors[testCaseLabel] = name } if c.reservation != "" { nodeSelectors["cloud.google.com/reservation-name"] = c.reservation @@ -232,6 +289,11 @@ func newJobset(name string, c tpuConfig) *jobset.JobSet { Completions: ptr.To(c.nodesPerSlice()), BackoffLimit: ptr.To(c.nodesPerSlice()), Template: v1.PodTemplateSpec{ + ObjectMeta: metav1.ObjectMeta{ + Labels: map[string]string{ + testCaseLabel: name, + }, + }, Spec: v1.PodSpec{ NodeSelector: nodeSelectors, Containers: []v1.Container{