diff --git a/tpu-provisioner/cmd/main.go b/tpu-provisioner/cmd/main.go
index aaabea79a..2a841beb7 100644
--- a/tpu-provisioner/cmd/main.go
+++ b/tpu-provisioner/cmd/main.go
@@ -22,6 +22,7 @@ import (
 	"net/http"
 	"os"
 	"strings"
+	"sync"
 	"time"
 
 	"k8s.io/apimachinery/pkg/runtime/schema"
@@ -76,12 +77,15 @@ func main() {
 		GCPNodeTags          []string `envconfig:"GCP_NODE_TAGS"`
 		GCPNodeSecondaryDisk string   `envconfig:"GCP_NODE_SECONDARY_DISK" default:""`
+		GCPNodeSecureBoot    bool     `envconfig:"GCP_NODE_SECURE_BOOT" default:"true"`
 
 		// NodeMinLifespan is the amount of time that should pass between a Node object
 		// creation and a cleanup of that Node. This needs to be long enough to allow
 		// the node to become Ready and for a pending Pod to be scheduled on it.
 		NodeMinLifespan time.Duration `envconfig:"NODE_MIN_LIFESPAN" default:"3m"`
 
+		NodepoolDeletionDelay time.Duration `envconfig:"NODEPOOL_DELETION_DELAY" default:"30s"`
+
 		PodResourceType string `envconfig:"POD_RESOURCE_TYPE" default:"google.com/tpu"`
 
 		Concurrency int `envconfig:"CONCURRENCY" default:"3"`
@@ -197,7 +201,9 @@ func main() {
 				NodeServiceAccount: cfg.GCPNodeServiceAccount,
 				NodeSecondaryDisk:  cfg.GCPNodeSecondaryDisk,
 				NodeTags:           cfg.GCPNodeTags,
+				NodeSecureBoot:     cfg.GCPNodeSecureBoot,
 			},
+			Recorder: mgr.GetEventRecorderFor("tpu-provisioner"),
 		}
 	case "mock":
 		provider = &cloud.Mock{}
@@ -209,7 +215,7 @@ func main() {
 	if err := (&controller.CreationReconciler{
 		Client:   mgr.GetClient(),
 		Scheme:   mgr.GetScheme(),
-		Recorder: mgr.GetEventRecorderFor("tpu-provisioner-creator"),
+		Recorder: mgr.GetEventRecorderFor("tpu-provisioner"),
 		Provider: provider,
 		PodCriteria: controller.PodCriteria{
 			ResourceType: cfg.PodResourceType,
@@ -222,10 +228,11 @@ func main() {
 	if err := (&controller.DeletionReconciler{
 		Client:   mgr.GetClient(),
 		Scheme:   mgr.GetScheme(),
-		Recorder: mgr.GetEventRecorderFor("tpu-provisioner-deleter"),
+		Recorder: mgr.GetEventRecorderFor("tpu-provisioner"),
 		Provider: provider,
 		NodeCriteria: controller.NodeCriteria{
-			MinLifetime: cfg.NodeMinLifespan,
+			MinLifetime:       cfg.NodeMinLifespan,
+			PoolDeletionDelay: cfg.NodepoolDeletionDelay,
 		},
 	}).SetupWithManager(mgr); err != nil {
 		setupLog.Error(err, "unable to create controller", "controller", "DeletionReconciler")
@@ -241,10 +248,27 @@ func main() {
 		setupLog.Error(err, "unable to set up ready check")
 		os.Exit(1)
 	}
+	ctx := ctrl.SetupSignalHandler()
+
+	gc := &controller.NodePoolGarbageCollector{
+		Interval: time.Minute,
+		Client:   mgr.GetClient(),
+		Provider: provider,
+	}
+	var wg sync.WaitGroup
+	wg.Add(1)
+	go func() {
+		gc.Run(ctx)
+		wg.Done()
+	}()
 
 	setupLog.Info("starting manager")
-	if err := mgr.Start(ctrl.SetupSignalHandler()); err != nil {
+	if err := mgr.Start(ctx); err != nil {
 		setupLog.Error(err, "problem running manager")
 		os.Exit(1)
 	}
+
+	setupLog.Info("waiting for all goroutines to finish")
+	wg.Wait()
+	setupLog.Info("exiting")
 }
diff --git a/tpu-provisioner/internal/cloud/common.go b/tpu-provisioner/internal/cloud/common.go
new file mode 100644
index 000000000..f4ac2b6c1
--- /dev/null
+++ b/tpu-provisioner/internal/cloud/common.go
@@ -0,0 +1,55 @@
+package cloud
+
+import (
+	"errors"
+	"time"
+
+	corev1 "k8s.io/api/core/v1"
+	"k8s.io/apimachinery/pkg/types"
+	"sigs.k8s.io/controller-runtime/pkg/client"
+)
+
+const (
+	keyPrefix = "google.com/"
+
+	LabelNodepoolManager             = keyPrefix + "nodepool-manager"
+	LabelNodepoolManagerTPUPodinator = "tpu-provisioner"
+
+	LabelParentKind      = keyPrefix + "tpu-provisioner-parent-kind"
+	LabelParentName      = keyPrefix + "tpu-provisioner-parent-name"
+	LabelParentNamespace = keyPrefix + "tpu-provisioner-parent-namespace"
+
+	LabelPodName      = keyPrefix + "tpu-provisioner-pod-name"
+	LabelPodNamespace = keyPrefix + "tpu-provisioner-pod-namespace"
+
+	EventNodePoolCreationStarted   = "NodePoolCreationStarted"
+	EventNodePoolCreationSucceeded = "NodePoolCreationSucceeded"
+	EventNodePoolCreationFailed    = "NodePoolCreationFailed"
+
+	EventNodePoolDeletionStarted   = "NodePoolDeletionStarted"
+	EventNodePoolDeletionSucceeded = "NodePoolDeletionSucceeded"
+	EventNodePoolDeletionFailed    = "NodePoolDeletionFailed"
+
+	EventNodePoolNotFound = "NodePoolNotFound"
+)
+
+type Provider interface {
+	NodePoolLabelKey() string
+	EnsureNodePoolForPod(*corev1.Pod, string) error
+	DeleteNodePoolForNode(*corev1.Node, string) error
+	DeleteNodePool(string, client.Object, string) error
+	ListNodePools() ([]NodePoolRef, error)
+}
+
+var ErrDuplicateRequest = errors.New("duplicate request")
+
+type NodePoolRef struct {
+	Name string
+
+	CreationTime time.Time
+
+	CreatedForPod types.NamespacedName
+
+	Error   bool
+	Message string
+}
diff --git a/tpu-provisioner/internal/cloud/gke.go b/tpu-provisioner/internal/cloud/gke.go
index fe872440e..f1f7fab79 100644
--- a/tpu-provisioner/internal/cloud/gke.go
+++ b/tpu-provisioner/internal/cloud/gke.go
@@ -13,6 +13,9 @@ import (
 	"google.golang.org/api/googleapi"
 	corev1 "k8s.io/api/core/v1"
 	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/apimachinery/pkg/types"
+	"k8s.io/client-go/tools/record"
+	"sigs.k8s.io/controller-runtime/pkg/client"
 	logf "sigs.k8s.io/controller-runtime/pkg/log"
 )
 
@@ -37,21 +40,22 @@ const (
 	maxPodsPerNode = 15
 )
 
+var _ Provider = &GKE{}
+
 type GKE struct {
 	Service        *containerv1beta1.Service
 	ClusterContext GKEContext
 
+	Recorder record.EventRecorder
+
 	inProgressDeletes sync.Map
 	inProgressCreates sync.Map
 }
 
 func (g *GKE) NodePoolLabelKey() string { return GKENodePoolNameLabel }
 
-func (g *GKE) EnsureNodePoolForPod(p *corev1.Pod) error {
-	name, err := podToNodePoolName(p, GKENodePoolNamePrefix, "")
-	if err != nil {
-		return fmt.Errorf("determining node pool name: %w", err)
-	}
+func (g *GKE) EnsureNodePoolForPod(p *corev1.Pod, why string) error {
+	name := podToNodePoolName(p, GKENodePoolNamePrefix, "")
 
 	exists, err := g.nodePoolExists(name)
 	if err != nil {
@@ -66,8 +70,6 @@ func (g *GKE) EnsureNodePoolForPod(p *corev1.Pod) error {
 		return fmt.Errorf("determining node pool for pod: %w", err)
 	}
 
-	log.Info("creating node pool", "name", name, "nodeCount", np.InitialNodeCount)
-
 	req := &containerv1beta1.CreateNodePoolRequest{
 		NodePool: np,
 		Parent:   g.ClusterContext.ClusterName(),
@@ -83,20 +85,58 @@ func (g *GKE) EnsureNodePoolForPod(p *corev1.Pod) error {
 	g.inProgressCreates.Store(name, struct{}{})
 	defer g.inProgressCreates.Delete(name)
 
+	g.Recorder.Eventf(p, corev1.EventTypeNormal, EventNodePoolCreationStarted, "Starting creation of Node Pool %s (size = %v) because %s", name, np.InitialNodeCount, why)
 	call := g.Service.Projects.Locations.Clusters.NodePools.Create(g.ClusterContext.ClusterName(), req)
 	op, err := call.Do()
 	if err != nil {
+		g.Recorder.Eventf(p, corev1.EventTypeWarning, EventNodePoolCreationFailed, "Request to create Node Pool %s failed: %v.", name, err)
 		return fmt.Errorf("do: %w", err)
 	}
 
-	return waitForGkeOp(g.Service, g.ClusterContext, op)
+	if err := waitForGkeOp(g.Service, g.ClusterContext, op); err != nil {
+		g.Recorder.Eventf(p, corev1.EventTypeWarning, EventNodePoolCreationFailed, "Operation to create Node Pool %s failed: %v.", name, err)
+		return fmt.Errorf("waiting for operation: %w", err)
err) + } + + g.Recorder.Eventf(p, corev1.EventTypeNormal, EventNodePoolCreationSucceeded, "Successfully created Node Pool %s.", name) + + return nil } -func (g *GKE) DeleteNodePoolForNode(node *corev1.Node) error { +func (g *GKE) ListNodePools() ([]NodePoolRef, error) { + var refs []NodePoolRef + + resp, err := g.Service.Projects.Locations.Clusters.NodePools.List(g.ClusterContext.ClusterName()).Do() + if err != nil { + return nil, fmt.Errorf("listing node pools: %w", err) + + } + + for _, np := range resp.NodePools { + refs = append(refs, NodePoolRef{ + Name: np.Name, + Error: np.Status == "ERROR", + Message: np.StatusMessage, + CreatedForPod: types.NamespacedName{ + Name: np.Config.Labels[LabelPodName], + Namespace: np.Config.Labels[LabelPodNamespace], + }, + }) + } + + return refs, nil +} + +func (g *GKE) DeleteNodePoolForNode(node *corev1.Node, why string) error { name, ok := node.GetLabels()[g.NodePoolLabelKey()] if !ok { return fmt.Errorf("node %q does not have node pool label", node.Name) } + + return g.DeleteNodePool(name, node, why) +} + +func (g *GKE) DeleteNodePool(name string, eventObj client.Object, why string) error { // Due to concurrent reconciles, multiple deletes for the same // Node Pool will occur at the same time. The result is an error: // To avoid a bunch of failed requests, we dedeuplicate here. @@ -106,23 +146,42 @@ func (g *GKE) DeleteNodePoolForNode(node *corev1.Node) error { g.inProgressDeletes.Store(name, struct{}{}) defer g.inProgressDeletes.Delete(name) + g.Recorder.Eventf(eventObj, corev1.EventTypeNormal, EventNodePoolDeletionStarted, "Starting deletion of Node Pool %s because %s", name, why) op, err := g.Service.Projects.Locations.Clusters.Delete(g.ClusterContext.NodePoolName(name)).Do() if err != nil { + if gerr, ok := err.(*googleapi.Error); ok && gerr.Code == http.StatusNotFound { + g.Recorder.Eventf(eventObj, corev1.EventTypeNormal, EventNodePoolNotFound, "Node pool not found - ignoring deletion attempt.", name) + return nil + } + g.Recorder.Eventf(eventObj, corev1.EventTypeWarning, EventNodePoolDeletionFailed, "Request to delete Node Pool %s failed: %v.", name, err) return fmt.Errorf("deleting node pool %q: %w", name, err) } - return waitForGkeOp(g.Service, g.ClusterContext, op) + if err := waitForGkeOp(g.Service, g.ClusterContext, op); err != nil { + g.Recorder.Eventf(eventObj, corev1.EventTypeWarning, EventNodePoolDeletionFailed, "Operation to delete Node Pool %s failed: %v.", name, err) + return err + } + + g.Recorder.Eventf(eventObj, corev1.EventTypeNormal, EventNodePoolDeletionSucceeded, "Successfully deleted Node Pool %s.", name) + + return nil } +var ErrNodePoolStopping = errors.New("node pool stopping") + func (g *GKE) nodePoolExists(name string) (bool, error) { call := g.Service.Projects.Locations.Clusters.NodePools.Get(g.ClusterContext.NodePoolName(name)) - _, err := call.Do() + np, err := call.Do() if err == nil { return true, nil } if gerr, ok := err.(*googleapi.Error); ok && gerr.Code == http.StatusNotFound { return false, nil } + if np.Status == "STOPPING" { + return false, ErrNodePoolStopping + } + return false, err } @@ -142,6 +201,9 @@ func (g *GKE) nodePoolForPod(name string, p *corev1.Pod) (*containerv1beta1.Node LabelParentName: strings.ToLower(ref.Name), // Assuming a Namespaced parent here... 
 		LabelParentNamespace: strings.ToLower(p.Namespace),
+
+		LabelPodName:      p.Name,
+		LabelPodNamespace: p.Namespace,
 	}
 
 	for k, v := range p.Spec.NodeSelector {
@@ -202,7 +264,7 @@ func (g *GKE) nodePoolForPod(name string, p *corev1.Pod) (*containerv1beta1.Node
 			ServiceAccount: g.ClusterContext.NodeServiceAccount,
 			ShieldedInstanceConfig: &containerv1beta1.ShieldedInstanceConfig{
 				EnableIntegrityMonitoring: true,
-				EnableSecureBoot:          true,
+				EnableSecureBoot:          g.ClusterContext.NodeSecureBoot,
 			},
 			Tags: g.ClusterContext.NodeTags,
 			// NOTE: vendor/ was manually updated to include the field because
@@ -248,25 +310,15 @@ func sumTPURequests(p *corev1.Pod) (int, error) {
 	return n, nil
 }
 
-func podToNodePoolName(p *corev1.Pod, prefix, suffix string) (string, error) {
-	// If JobSet job key annotation (SHA1 hash of namespaced job key) exists,
-	// use it as the owner ID.
-	// This annotation is stable through Job recreations, so the node pool name
-	// generated here will be the same if the JobSet is restarted.
-	var ownerID string
-	if jobKey, exists := p.Annotations["jobset.sigs.k8s.io/job-key"]; exists {
-		ownerID = jobKey
+func podToNodePoolName(p *corev1.Pod, prefix, suffix string) string {
+	var uid string
+	ref := metav1.GetControllerOf(p)
+	if ref != nil {
+		uid = string(ref.UID)
 	} else {
-		// Otherwise, fall back to the Job UID. The Job UID is not stable through
-		// recreations, so if a Job is recreated, the node pool name generated here
-		// will be different.
-		ref := metav1.GetControllerOf(p)
-		if ref == nil {
-			return "", errors.New("no owner reference")
-		}
-		ownerID = string(ref.UID)
+		uid = string(p.UID)
 	}
-	return prefix + ownerID[0:12] + suffix, nil
+	return prefix + uid[0:12] + suffix
 }
 
 func tpuTopologyToNodeCount(accelerator, topo string) (int, error) {
diff --git a/tpu-provisioner/internal/cloud/gke_context.go b/tpu-provisioner/internal/cloud/gke_context.go
index 78510708f..70c16178e 100644
--- a/tpu-provisioner/internal/cloud/gke_context.go
+++ b/tpu-provisioner/internal/cloud/gke_context.go
@@ -10,6 +10,7 @@ type GKEContext struct {
 	NodeServiceAccount string
 	NodeSecondaryDisk  string
 	NodeTags           []string
+	NodeSecureBoot     bool
 }
 
 func (c GKEContext) ClusterName() string {
diff --git a/tpu-provisioner/internal/cloud/interface.go b/tpu-provisioner/internal/cloud/interface.go
deleted file mode 100644
index 21f1622c0..000000000
--- a/tpu-provisioner/internal/cloud/interface.go
+++ /dev/null
@@ -1,15 +0,0 @@
-package cloud
-
-import (
-	"errors"
-
-	corev1 "k8s.io/api/core/v1"
-)
-
-type Provider interface {
-	NodePoolLabelKey() string
-	EnsureNodePoolForPod(*corev1.Pod) error
-	DeleteNodePoolForNode(*corev1.Node) error
-}
-
-var ErrDuplicateRequest = errors.New("duplicate request")
diff --git a/tpu-provisioner/internal/cloud/labels.go b/tpu-provisioner/internal/cloud/labels.go
deleted file mode 100644
index a06493634..000000000
--- a/tpu-provisioner/internal/cloud/labels.go
+++ /dev/null
@@ -1,12 +0,0 @@
-package cloud
-
-const (
-	keyPrefix = "google.com/"
-
-	LabelNodepoolManager             = keyPrefix + "nodepool-manager"
-	LabelNodepoolManagerTPUPodinator = "tpu-provisioner"
-
-	LabelParentKind      = keyPrefix + "tpu-provisioner-parent-kind"
-	LabelParentName      = keyPrefix + "tpu-provisioner-parent-name"
-	LabelParentNamespace = keyPrefix + "tpu-provisioner-parent-namespace"
-)
diff --git a/tpu-provisioner/internal/cloud/mock.go b/tpu-provisioner/internal/cloud/mock.go
index b761b2ea6..43c0a05b8 100644
--- a/tpu-provisioner/internal/cloud/mock.go
+++ b/tpu-provisioner/internal/cloud/mock.go
@@ -1,12 +1,19 @@
 package cloud
 
-import corev1 "k8s.io/api/core/v1"
+import (
+	corev1 "k8s.io/api/core/v1"
+	"sigs.k8s.io/controller-runtime/pkg/client"
+)
+
+var _ Provider = &Mock{}
 
 // Mock is useful for local development or debugging purposes to understand what
 // the controller would do without it doing anything.
 type Mock struct{}
 
 // TODO: Find a better mock node pool label key.
-func (m *Mock) NodePoolLabelKey() string                 { return "kubernetes.io/os" }
-func (m *Mock) EnsureNodePoolForPod(*corev1.Pod) error   { return nil }
-func (m *Mock) DeleteNodePoolForNode(*corev1.Node) error { return nil }
+func (m *Mock) NodePoolLabelKey() string                           { return "kubernetes.io/os" }
+func (m *Mock) EnsureNodePoolForPod(*corev1.Pod, string) error     { return nil }
+func (m *Mock) DeleteNodePoolForNode(*corev1.Node, string) error   { return nil }
+func (m *Mock) DeleteNodePool(string, client.Object, string) error { return nil }
+func (m *Mock) ListNodePools() ([]NodePoolRef, error)              { return nil, nil }
diff --git a/tpu-provisioner/internal/controller/creation_controller.go b/tpu-provisioner/internal/controller/creation_controller.go
index 41f7a0046..a4eca0eb3 100644
--- a/tpu-provisioner/internal/controller/creation_controller.go
+++ b/tpu-provisioner/internal/controller/creation_controller.go
@@ -20,6 +20,7 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"time"
 
 	"github.com/GoogleCloudPlatform/ai-on-gke/tpu-provisioner/internal/cloud"
 
@@ -29,7 +30,7 @@ import (
 	"k8s.io/client-go/tools/record"
 	ctrl "sigs.k8s.io/controller-runtime"
 	"sigs.k8s.io/controller-runtime/pkg/client"
-	"sigs.k8s.io/controller-runtime/pkg/log"
+	ctrllog "sigs.k8s.io/controller-runtime/pkg/log"
 )
 
 // CreationReconciler watches Pods and creates Node Pools.
@@ -52,7 +53,7 @@ type PodCriteria struct {
 //+kubebuilder:rbac:groups="",resources=pods/finalizers,verbs=update
 
 func (r *CreationReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
-	lg := log.FromContext(ctx)
+	lg := ctrllog.FromContext(ctx)
 
 	lg.V(3).Info("Reconciling Pod")
 
@@ -66,22 +67,26 @@ func (r *CreationReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c
 	}
 
 	// Return early if Pod should not trigger a scale up.
-	if !isPending(&pod) || !isUnschedulable(&pod) || !doesRequestResource(&pod, r.PodCriteria.ResourceType) || !hasNodeSelectors(&pod, cloud.GKETPUNodeSelector) {
+	if !isPending(&pod) || !isUnschedulable(&pod) ||
+		!doesRequestResource(&pod, r.PodCriteria.ResourceType) ||
+		!hasNodeSelectors(&pod, cloud.GKETPUNodeSelector) ||
+		pod.DeletionTimestamp != nil {
 		lg.V(3).Info("Ignoring pod")
 		return ctrl.Result{}, nil
 	}
 
 	lg.Info("Ensuring node pool for unschedulable pod")
-	r.Recorder.Eventf(&pod, corev1.EventTypeNormal, EventEnsuringNodePool, "Ensuring Node Pool, triggered by Pod %s/%s.", pod.Namespace, pod.Name)
-	if err := r.Provider.EnsureNodePoolForPod(&pod); err != nil {
+	if err := r.Provider.EnsureNodePoolForPod(&pod, "pod is currently unschedulable"); err != nil {
 		if errors.Is(err, cloud.ErrDuplicateRequest) {
 			lg.Info("Ignoring duplicate request to create node pool")
+		} else if errors.Is(err, cloud.ErrNodePoolStopping) {
+			wait := 5 * time.Second
+			lg.Info("Attempted to create a node pool that is currently undergoing deletion, retrying soon",
+				"wait", wait)
+			return ctrl.Result{RequeueAfter: wait}, nil
 		} else {
-			r.Recorder.Event(&pod, corev1.EventTypeWarning, EventFailedEnsuringNodePool, "Failed to ensure existance of Node Pool: "+err.Error())
 			return ctrl.Result{}, err
 		}
-	} else {
-		r.Recorder.Event(&pod, corev1.EventTypeNormal, EventNodePoolEnsured, "Node Pool Ensured.")
 	}
 
 	return ctrl.Result{}, nil
diff --git a/tpu-provisioner/internal/controller/deletion_controller.go b/tpu-provisioner/internal/controller/deletion_controller.go
index 457ee6ac0..8b2ed9e4b 100644
--- a/tpu-provisioner/internal/controller/deletion_controller.go
+++ b/tpu-provisioner/internal/controller/deletion_controller.go
@@ -18,18 +18,11 @@ import (
 	ctrl "sigs.k8s.io/controller-runtime"
 	"sigs.k8s.io/controller-runtime/pkg/client"
 	"sigs.k8s.io/controller-runtime/pkg/handler"
-	"sigs.k8s.io/controller-runtime/pkg/log"
+	ctrllog "sigs.k8s.io/controller-runtime/pkg/log"
 	"sigs.k8s.io/controller-runtime/pkg/reconcile"
 	"sigs.k8s.io/controller-runtime/pkg/source"
 )
 
-// nodePoolDeletionCheckInterval is the interval between the first and
-// second node pool deletion checks. Once the node pool deletion check
-// has passed twice, the node pool can be safely deleted. This second
-// check is ensure the node pool is not prematurely deleted, in the case
-// where a JobSet is restarted, but no pods have been created yet.
-var nodePoolDeletionCheckInterval = 30 * time.Second
-
 // DeletionReconciler watches Pods and Nodes and deletes Node Pools.
 type DeletionReconciler struct {
 	client.Client
@@ -43,6 +36,13 @@ type DeletionReconciler struct {
 
 type NodeCriteria struct {
 	MinLifetime time.Duration
+
+	// PoolDeletionDelay is the interval between the first and
+	// second node pool deletion checks. Once the node pool deletion check
+	// has passed twice, the node pool can be safely deleted. This second
+	// check is to ensure the node pool is not prematurely deleted, in the case
+	// where a JobSet is restarted, but no pods have been created yet.
+	PoolDeletionDelay time.Duration
 }
 
 //+kubebuilder:rbac:groups="",resources=nodes,verbs=get;list;watch;create;update;patch;delete
@@ -50,7 +50,7 @@ type NodeCriteria struct {
 //+kubebuilder:rbac:groups="",resources=nodes/finalizers,verbs=update
 
 func (r *DeletionReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
-	lg := log.FromContext(ctx)
+	lg := ctrllog.FromContext(ctx)
 
 	lg.V(3).Info("Reconciling Node")
 
@@ -128,31 +128,28 @@ func (r *DeletionReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c
 	if !exists {
 		lg.Info(fmt.Sprintf("Node pool %q passed deletion check once", nodePoolName))
 		r.NodePoolsMarkedForDeletion.Store(nodePoolName, time.Now())
-		return ctrl.Result{RequeueAfter: nodePoolDeletionCheckInterval}, nil
+		return ctrl.Result{RequeueAfter: r.NodeCriteria.PoolDeletionDelay}, nil
 	}
 
 	// If we haven't reached the node pool deletion check interval, this reconcile was
 	// caused by something else, we can return early, and wait for the manually requeued
 	// reconcile we did after the first deletion check passed.
 	firstDeletionCheckTime := value.(time.Time)
-	if time.Now().Sub(firstDeletionCheckTime) < nodePoolDeletionCheckInterval {
+	if time.Now().Sub(firstDeletionCheckTime) < r.NodeCriteria.PoolDeletionDelay {
 		return ctrl.Result{}, nil
 	}
 
 	// If this point is reached, the node pool has passed the deletion check twice
 	// and can be deleted.
 	lg.Info(fmt.Sprintf("Node pool %q passed deletion check twice. Ensuring Node Pool is deleted", nodePoolName))
-	r.Recorder.Event(&node, corev1.EventTypeNormal, EventDeletingNodePool, DeletingNodePoolEventMessage)
-	if err := r.Provider.DeleteNodePoolForNode(&node); err != nil {
+	if err := r.Provider.DeleteNodePoolForNode(&node, "no user Pods are running on any of the Nodes in this node pool"); err != nil {
 		if errors.Is(err, cloud.ErrDuplicateRequest) {
 			lg.Info("Ignoring duplicate request to delete node pool")
 			return ctrl.Result{}, nil
 		} else {
-			r.Recorder.Event(&node, corev1.EventTypeWarning, EventFailedDeletingNodePool, "Failed to delete Node Pool: "+err.Error())
-			return ctrl.Result{}, client.IgnoreNotFound(err)
+			return ctrl.Result{}, err
 		}
 	}
-	r.Recorder.Event(&node, corev1.EventTypeNormal, EventNodePoolDeleted, DeletedNodePoolEventMessage)
 
 	// Remove node pool from the map tracking node pools marked for deletion, in case the JobSet
 	// is reran in the future, as this will result in node pools with the same name being recreated,
diff --git a/tpu-provisioner/internal/controller/deletion_controller_test.go b/tpu-provisioner/internal/controller/deletion_controller_test.go
index 13bfc7bb7..232ea7fac 100644
--- a/tpu-provisioner/internal/controller/deletion_controller_test.go
+++ b/tpu-provisioner/internal/controller/deletion_controller_test.go
@@ -67,7 +67,7 @@ var _ = Describe("Deletion controller", func() {
 			By("Checking the first deletion attempt only occurred after the node had existed for >= nodeDeletionInterval")
 			actualDuration := deletionTimestamp.Sub(createdNode.CreationTimestamp.Time)
-			requiredDuration := nodePoolDeletionCheckInterval + minNodeLifetime
+			requiredDuration := nodepoolDeletionDelay + minNodeLifetime
 			Expect(actualDuration).Should(BeNumerically(">=", requiredDuration))
 
 			By("Checking that other Nodes were ignored")
diff --git a/tpu-provisioner/internal/controller/events.go b/tpu-provisioner/internal/controller/events.go
index 235be19ac..9ccf4b3bc 100644
--- a/tpu-provisioner/internal/controller/events.go
+++ b/tpu-provisioner/internal/controller/events.go
@@ -2,13 +2,4 @@ package controller
 
 //+kubebuilder:rbac:groups="",resources=events,verbs=create;patch
 
-const (
-	EventFailedEnsuringNodePool  = "FailedEnsuringNodePool"
-	EventFailedDeletingNodePool  = "FailedDeletingNodePool"
-	EventEnsuringNodePool        = "EnsuringNodePool"
-	EventNodePoolEnsured         = "NodePoolEnsured"
-	EventDeletingNodePool        = "DeletingNodePool"
-	EventNodePoolDeleted         = "NodePoolDeleted"
-	DeletingNodePoolEventMessage = "Deleted Node Pool."
-	DeletedNodePoolEventMessage  = "Deleted Node Pool."
-)
+// See pkg cloud for events.
diff --git a/tpu-provisioner/internal/controller/nodepool_garbage_collector.go b/tpu-provisioner/internal/controller/nodepool_garbage_collector.go
new file mode 100644
index 000000000..0fb4337c1
--- /dev/null
+++ b/tpu-provisioner/internal/controller/nodepool_garbage_collector.go
@@ -0,0 +1,98 @@
+package controller
+
+import (
+	"context"
+	"fmt"
+	"time"
+
+	"github.com/GoogleCloudPlatform/ai-on-gke/tpu-provisioner/internal/cloud"
+	v1 "k8s.io/api/core/v1"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"sigs.k8s.io/controller-runtime/pkg/client"
+	ctrllog "sigs.k8s.io/controller-runtime/pkg/log"
+)
+
+// NodePoolGarbageCollector deletes node pools that have no Nodes,
+// are in an errored state, and where the Pod that created the node pool
+// no longer exists (the deletion reconciler would not see these b/c there
+// are no Node objects).
+type NodePoolGarbageCollector struct {
+	Interval time.Duration
+	client.Client
+	Provider cloud.Provider
+}
+
+func (g *NodePoolGarbageCollector) Run(ctx context.Context) {
+	log := ctrllog.Log.WithName("nodepool-garbage-collector")
+
+	t := time.NewTicker(g.Interval)
+
+	for {
+		select {
+		case <-ctx.Done():
+			t.Stop()
+			return
+		case <-t.C:
+		}
+
+		log.Info("starting node pool garbage collection loop")
+
+		nodepools, err := g.Provider.ListNodePools()
+		if err != nil {
+			log.Error(err, "failed to list errored node pools")
+			continue
+		}
+
+		for _, np := range nodepools {
+			log := log.WithValues(
+				"nodepool", np.Name,
+				"createdForPodName", np.CreatedForPod.Name,
+				"createdForPodNamespace", np.CreatedForPod.Namespace,
+			)
+
+			if !np.Error {
+				continue
+			}
+
+			if np.CreatedForPod.Name == "" || np.CreatedForPod.Namespace == "" {
+				log.Info("skipping garbage collection of node pool, no pod reference")
+				continue
+			}
+
+			// Check if the Pod that triggered the Node Pool creation still exists.
+			err := g.Get(ctx, np.CreatedForPod, &v1.Pod{})
+			if err == nil {
+				log.Info("skipping garbage collection of node pool, pod still exists",
+					"podName", np.CreatedForPod.Name,
+					"podNamespace", np.CreatedForPod.Namespace,
+				)
+				continue
+			}
+			if client.IgnoreNotFound(err) != nil {
+				log.Error(err, "failed to get pod node pool was created for")
+				continue
+			}
+			// Pod not found if this point is reached.
+
+			// Ignore node pools that have Nodes registered for them (these will be handled by the deletion controller).
+			var nodes v1.NodeList
+			if err := g.List(ctx, &nodes, client.MatchingLabels{g.Provider.NodePoolLabelKey(): np.Name}); err != nil {
+				log.Error(err, "failed to list nodes for node pool")
+				continue
+			}
+			if len(nodes.Items) > 0 {
+				log.Info("skipping garbage collection of node pool, nodes exist")
+				continue
+			}
+
+			log.Info("garbage collecting node pool in error state")
+			// TODO: Lookup namespace from env with downward API.
+			whyDelete := fmt.Sprintf("the node pool has no corresponding Nodes, the Pod (%s/%s) that triggered its creation no longer exists, and the node pool is in an error state: %s",
+				np.CreatedForPod.Namespace, np.CreatedForPod.Name, np.Message)
+			if err := g.Provider.DeleteNodePool(np.Name, &v1.Namespace{ObjectMeta: metav1.ObjectMeta{Name: "tpu-provisioner-system"}}, whyDelete); err != nil {
+				log.Error(err, "failed to garbage collect node pool")
+				continue
+			}
+		}
+	}
+}
diff --git a/tpu-provisioner/internal/controller/provider_test.go b/tpu-provisioner/internal/controller/provider_test.go
index 63fd2f66d..5e7e25de7 100644
--- a/tpu-provisioner/internal/controller/provider_test.go
+++ b/tpu-provisioner/internal/controller/provider_test.go
@@ -4,19 +4,24 @@ import (
 	"sync"
 	"time"
 
+	"github.com/GoogleCloudPlatform/ai-on-gke/tpu-provisioner/internal/cloud"
 	corev1 "k8s.io/api/core/v1"
 	"k8s.io/apimachinery/pkg/types"
 )
 
+var _ cloud.Provider = &testProvider{}
+
 type testProvider struct {
 	sync.Mutex
 	created map[types.NamespacedName]bool
 	deleted map[string]time.Time
+
+	cloud.Provider
 }
 
 func (p *testProvider) NodePoolLabelKey() string { return "cloud.test.com/test-nodepool" }
 
-func (p *testProvider) EnsureNodePoolForPod(pod *corev1.Pod) error {
+func (p *testProvider) EnsureNodePoolForPod(pod *corev1.Pod, _ string) error {
 	p.Lock()
 	defer p.Unlock()
 	p.created[types.NamespacedName{Namespace: pod.Namespace, Name: pod.Name}] = true
@@ -29,7 +34,7 @@ func (p *testProvider) getCreated(nn types.NamespacedName) bool {
 	return p.created[nn]
 }
 
-func (p *testProvider) DeleteNodePoolForNode(node *corev1.Node) error {
+func (p *testProvider) DeleteNodePoolForNode(node *corev1.Node, _ string) error {
 	p.Lock()
 	defer p.Unlock()
 	if _, exists := p.deleted[node.Name]; !exists {
diff --git a/tpu-provisioner/internal/controller/suite_test.go b/tpu-provisioner/internal/controller/suite_test.go
index c81307c5a..39b46c9a3 100644
--- a/tpu-provisioner/internal/controller/suite_test.go
+++ b/tpu-provisioner/internal/controller/suite_test.go
@@ -51,8 +51,9 @@ var (
 )
 
 const (
-	resourceName    = "test.com/tpu"
-	minNodeLifetime = time.Second
+	resourceName          = "test.com/tpu"
+	minNodeLifetime       = time.Second
+	nodepoolDeletionDelay = 5 * time.Second
 )
 
 func TestAPIs(t *testing.T) {
@@ -103,7 +104,8 @@ var _ = BeforeSuite(func() {
 		Recorder: mgr.GetEventRecorderFor("tpu-provisioner-deleter"),
 		Provider: provider,
 		NodeCriteria: NodeCriteria{
-			MinLifetime: minNodeLifetime,
+			MinLifetime:       minNodeLifetime,
+			PoolDeletionDelay: nodepoolDeletionDelay,
 		},
 	}).SetupWithManager(mgr)
 	Expect(err).ToNot(HaveOccurred())
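
Note (not part of the patch): a minimal standalone sketch of the shutdown pattern cmd/main.go adopts above, assuming only the standard library. A ticker-driven loop (standing in for NodePoolGarbageCollector.Run) runs on its own goroutine, the same signal context that stops the manager cancels the loop, and a sync.WaitGroup keeps the process from exiting before the loop has returned. The names run and the one-second interval are illustrative only.

package main

import (
	"context"
	"fmt"
	"os/signal"
	"sync"
	"syscall"
	"time"
)

// run mimics the garbage-collector loop: it does periodic work until the
// context is cancelled, then returns.
func run(ctx context.Context, interval time.Duration) {
	t := time.NewTicker(interval)
	defer t.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-t.C:
			fmt.Println("garbage collection pass")
		}
	}
}

func main() {
	// Rough equivalent of ctrl.SetupSignalHandler(): cancel on SIGINT/SIGTERM.
	ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
	defer stop()

	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		run(ctx, time.Second)
	}()

	<-ctx.Done() // stand-in for mgr.Start(ctx) blocking until shutdown

	wg.Wait() // do not exit until the background loop has returned
	fmt.Println("exiting")
}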