From e00d42d40aaff1365b0543e60260fbe28a63abea Mon Sep 17 00:00:00 2001 From: Nick Stogner Date: Mon, 9 Sep 2024 12:23:07 -0400 Subject: [PATCH] Add annotation to specify which labels to copy to nodes --- tpu-provisioner/internal/cloud/common.go | 3 ++ tpu-provisioner/internal/cloud/gke.go | 18 +++++++++ tpu-provisioner/internal/cloud/gke_test.go | 43 +++++++++++++++++++--- 3 files changed, 59 insertions(+), 5 deletions(-) diff --git a/tpu-provisioner/internal/cloud/common.go b/tpu-provisioner/internal/cloud/common.go index ee4305af5..468b3c5d9 100644 --- a/tpu-provisioner/internal/cloud/common.go +++ b/tpu-provisioner/internal/cloud/common.go @@ -24,6 +24,9 @@ const ( LabelProvisionerNodepoolID = "provisioner-nodepool-id" + // AnnotationCopyLabels is a comma-separated list of labels to copy from the Pod to the node pool config (Nodes). + AnnotationCopyLabels = "tpu-provisioner.cloud.google.com/copy-labels" + EventNodePoolCreationStarted = "NodePoolCreationStarted" EventNodePoolCreationSucceeded = "NodePoolCreationSucceeded" EventNodePoolCreationFailed = "NodePoolCreationFailed" diff --git a/tpu-provisioner/internal/cloud/gke.go b/tpu-provisioner/internal/cloud/gke.go index 4c8a4a92a..cba11e6b7 100644 --- a/tpu-provisioner/internal/cloud/gke.go +++ b/tpu-provisioner/internal/cloud/gke.go @@ -269,6 +269,17 @@ func (g *GKE) nodePoolForPod(name string, p *corev1.Pod) (*containerv1beta1.Node } } + // Copy labels specified by annotation to the Node. + for _, key := range strings.Split(getAnnotation(p, AnnotationCopyLabels), ",") { + key = strings.TrimSpace(key) + if key == "" { + continue + } + if val, ok := p.Labels[key]; ok { + labels[key] = val + } + } + for labelKey, labelValue := range p.Spec.NodeSelector { switch labelKey { case ICIResiliencyLabel: @@ -492,3 +503,10 @@ func min(a, b int) int { } return b } + +func getAnnotation(p *corev1.Pod, key string) string { + if p.Annotations == nil { + return "" + } + return p.Annotations[key] +} diff --git a/tpu-provisioner/internal/cloud/gke_test.go b/tpu-provisioner/internal/cloud/gke_test.go index 38df47c61..55bdd795f 100644 --- a/tpu-provisioner/internal/cloud/gke_test.go +++ b/tpu-provisioner/internal/cloud/gke_test.go @@ -220,11 +220,12 @@ func TestPodToNodePoolName(t *testing.T) { func TestNodePoolForPod(t *testing.T) { trueVar := true tests := []struct { - desc string - gkeContext GKEContext - additionalLabels map[string]string - selector map[string]string - want *containerv1beta1.NodePool + desc string + gkeContext GKEContext + additionalLabels map[string]string + additionalAnnotations map[string]string + selector map[string]string + want *containerv1beta1.NodePool }{ { desc: "simple case", @@ -482,6 +483,38 @@ func TestNodePoolForPod(t *testing.T) { UpgradeSettings: &container.UpgradeSettings{MaxSurge: 1}, }, }, + { + desc: "labels to copy from pod to node by annotation", + additionalLabels: map[string]string{ + "copy-me": "val-x", + "dont-copy-me": "val-y", + }, + additionalAnnotations: map[string]string{ + "tpu-provisioner.cloud.google.com/copy-labels": "copy-me", + }, + want: &containerv1beta1.NodePool{ + Config: &container.NodeConfig{ + Labels: map[string]string{ + "google.com/nodepool-manager": "tpu-provisioner", + "google.com/tpu-provisioner-jobset-name": "jobset-test", + "google.com/tpu-provisioner-jobset-namespace": "default", + "google.com/tpu-provisioner-parent-kind": "job", + "google.com/tpu-provisioner-parent-name": "jobset-test-job-1-0", + "google.com/tpu-provisioner-parent-namespace": "default", + "copy-me": "val-x", + }, + MachineType: "ct5p-hightpu-4t", + ShieldedInstanceConfig: &container.ShieldedInstanceConfig{EnableIntegrityMonitoring: true}, + }, + InitialNodeCount: 512, + Locations: []string{""}, + Management: &container.NodeManagement{AutoRepair: true, AutoUpgrade: false}, + MaxPodsConstraint: &container.MaxPodsConstraint{MaxPodsPerNode: 15}, + Name: "test-pool", + PlacementPolicy: &container.PlacementPolicy{TpuTopology: "8x16x16", Type: "COMPACT"}, + UpgradeSettings: &container.UpgradeSettings{MaxSurge: 1}, + }, + }, } for _, tc := range tests { t.Run(tc.desc, func(t *testing.T) {