TPU Provisioner: Fix skipped scale ups & garbage collect zombie node pools #277

32 changes: 28 additions & 4 deletions tpu-provisioner/cmd/main.go
@@ -22,6 +22,7 @@ import (
"net/http"
"os"
"strings"
"sync"
"time"

"k8s.io/apimachinery/pkg/runtime/schema"
@@ -76,12 +77,15 @@ func main() {

GCPNodeTags []string `envconfig:"GCP_NODE_TAGS"`
GCPNodeSecondaryDisk string `envconfig:"GCP_NODE_SECONDARY_DISK" default:""`
GCPNodeSecureBoot bool `envconfig:"GCP_NODE_SECURE_BOOT" default:"true"`

// NodeMinLifespan is the amount of time that should pass between a Node object
// creation and a cleanup of that Node. This needs to be long enough to allow
// the node to become Ready and for a pending Pod to be scheduled on it.
NodeMinLifespan time.Duration `envconfig:"NODE_MIN_LIFESPAN" default:"3m"`

NodepoolDeletionDelay time.Duration `envconfig:"NODEPOOL_DELETION_DELAY" default:"30s"`

PodResourceType string `envconfig:"POD_RESOURCE_TYPE" default:"google.com/tpu"`

Concurrency int `envconfig:"CONCURRENCY" default:"3"`
@@ -197,7 +201,9 @@ func main() {
NodeServiceAccount: cfg.GCPNodeServiceAccount,
NodeSecondaryDisk: cfg.GCPNodeSecondaryDisk,
NodeTags: cfg.GCPNodeTags,
NodeSecureBoot: cfg.GCPNodeSecureBoot,
},
Recorder: mgr.GetEventRecorderFor("tpu-provisioner"),
}
case "mock":
provider = &cloud.Mock{}
@@ -209,7 +215,7 @@ func main() {
if err := (&controller.CreationReconciler{
Client: mgr.GetClient(),
Scheme: mgr.GetScheme(),
Recorder: mgr.GetEventRecorderFor("tpu-provisioner-creator"),
Recorder: mgr.GetEventRecorderFor("tpu-provisioner"),
Provider: provider,
PodCriteria: controller.PodCriteria{
ResourceType: cfg.PodResourceType,
@@ -222,10 +228,11 @@ func main() {
if err := (&controller.DeletionReconciler{
Client: mgr.GetClient(),
Scheme: mgr.GetScheme(),
Recorder: mgr.GetEventRecorderFor("tpu-provisioner-deleter"),
Recorder: mgr.GetEventRecorderFor("tpu-provisioner"),
Provider: provider,
NodeCriteria: controller.NodeCriteria{
MinLifetime: cfg.NodeMinLifespan,
PoolDeletionDelay: cfg.NodepoolDeletionDelay,
},
}).SetupWithManager(mgr); err != nil {
setupLog.Error(err, "unable to create controller", "controller", "DeletionReconciler")
@@ -241,10 +248,27 @@ func main() {
setupLog.Error(err, "unable to set up ready check")
os.Exit(1)
}
ctx := ctrl.SetupSignalHandler()

gc := &controller.NodePoolGarbageCollector{
Interval: time.Minute,
Client: mgr.GetClient(),
Provider: provider,
}
var wg sync.WaitGroup
wg.Add(1)
go func() {
gc.Run(ctx)
wg.Done()
}()

setupLog.Info("starting manager")
if err := mgr.Start(ctrl.SetupSignalHandler()); err != nil {
if err := mgr.Start(ctx); err != nil {
setupLog.Error(err, "problem running manager")
os.Exit(1)
}

setupLog.Info("waiting for all goroutines to finish")
wg.Wait()
setupLog.Info("exiting")
}
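
Note: the controller.NodePoolGarbageCollector wired up above is defined elsewhere in this PR and does not appear in this portion of the diff. Below is a minimal sketch of what its Run loop might look like, based only on the fields set in main.go and the Provider methods added in common.go further down; the import path, the error-state criterion, and the object passed for event recording are assumptions, not the PR's actual implementation.

```go
package controller

import (
	"context"
	"time"

	corev1 "k8s.io/api/core/v1"
	apierrors "k8s.io/apimachinery/pkg/api/errors"
	"sigs.k8s.io/controller-runtime/pkg/client"

	"example.com/tpu-provisioner/internal/cloud" // placeholder import path
)

// NodePoolGarbageCollector periodically deletes node pools whose creating Pod
// no longer exists, e.g. pools stuck in an error state that never produced a
// Ready Node and therefore never trigger the Node-driven DeletionReconciler.
type NodePoolGarbageCollector struct {
	Interval time.Duration
	Client   client.Client
	Provider cloud.Provider
}

// Run loops until ctx is cancelled, which lets main.go wait on it via the
// sync.WaitGroup after the manager stops.
func (gc *NodePoolGarbageCollector) Run(ctx context.Context) {
	ticker := time.NewTicker(gc.Interval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
		}

		pools, err := gc.Provider.ListNodePools()
		if err != nil {
			continue // retry on the next tick
		}

		for _, np := range pools {
			// Only consider pools this provisioner created for a Pod and
			// that are stuck in an error state (assumed criterion).
			if np.CreatedForPod.Name == "" || !np.Error {
				continue
			}

			var pod corev1.Pod
			if err := gc.Client.Get(ctx, np.CreatedForPod, &pod); !apierrors.IsNotFound(err) {
				continue // the Pod still exists (or the lookup failed); keep the pool
			}

			// The Pod the pool was created for is gone: treat the pool as a zombie.
			// Passing the (empty) Pod as the event object is a placeholder; the real
			// implementation may record events against a different object.
			_ = gc.Provider.DeleteNodePool(np.Name, &pod, "garbage collecting zombie node pool")
		}
	}
}
```

Running the collector as a separate goroutine alongside mgr.Start, with a WaitGroup for shutdown, matches the wiring shown above: it covers pools that no Node or Pod event will ever reconcile away.
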
55 changes: 55 additions & 0 deletions tpu-provisioner/internal/cloud/common.go
@@ -0,0 +1,55 @@
package cloud

import (
"errors"
"time"

corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/controller-runtime/pkg/client"
)

const (
keyPrefix = "google.com/"

LabelNodepoolManager = keyPrefix + "nodepool-manager"
LabelNodepoolManagerTPUPodinator = "tpu-provisioner"

LabelParentKind = keyPrefix + "tpu-provisioner-parent-kind"
LabelParentName = keyPrefix + "tpu-provisioner-parent-name"
LabelParentNamespace = keyPrefix + "tpu-provisioner-parent-namespace"

LabelPodName = keyPrefix + "tpu-provisioner-pod-name"
LabelPodNamespace = keyPrefix + "tpu-provisioner-pod-namespace"

EventNodePoolCreationStarted = "NodePoolCreationStarted"
EventNodePoolCreationSucceeded = "NodePoolCreationSucceeded"
EventNodePoolCreationFailed = "NodePoolCreationFailed"

EventNodePoolDeletionStarted = "NodePoolDeletionStarted"
EventNodePoolDeletionSucceeded = "NodePoolDeletionSucceeded"
EventNodePoolDeletionFailed = "NodePoolDeletionFailed"

EventNodePoolNotFound = "NodePoolNotFound"
)

type Provider interface {
NodePoolLabelKey() string
EnsureNodePoolForPod(*corev1.Pod, string) error
DeleteNodePoolForNode(*corev1.Node, string) error
DeleteNodePool(string, client.Object, string) error
ListNodePools() ([]NodePoolRef, error)
}

var ErrDuplicateRequest = errors.New("duplicate request")

type NodePoolRef struct {
Name string

CreationTime time.Time

CreatedForPod types.NamespacedName

Error bool
Message string
}
110 changes: 81 additions & 29 deletions tpu-provisioner/internal/cloud/gke.go
@@ -13,6 +13,9 @@ import (
"google.golang.org/api/googleapi"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"
"k8s.io/client-go/tools/record"
"sigs.k8s.io/controller-runtime/pkg/client"
logf "sigs.k8s.io/controller-runtime/pkg/log"
)

@@ -37,21 +40,22 @@ const (
maxPodsPerNode = 15
)

var _ Provider = &GKE{}

type GKE struct {
Service *containerv1beta1.Service
ClusterContext GKEContext

Recorder record.EventRecorder

inProgressDeletes sync.Map
inProgressCreates sync.Map
}

func (g *GKE) NodePoolLabelKey() string { return GKENodePoolNameLabel }

func (g *GKE) EnsureNodePoolForPod(p *corev1.Pod) error {
name, err := podToNodePoolName(p, GKENodePoolNamePrefix, "")
if err != nil {
return fmt.Errorf("determining node pool name: %w", err)
}
func (g *GKE) EnsureNodePoolForPod(p *corev1.Pod, why string) error {
name := podToNodePoolName(p, GKENodePoolNamePrefix, "")

exists, err := g.nodePoolExists(name)
if err != nil {
@@ -66,8 +70,6 @@ func (g *GKE) EnsureNodePoolForPod(p *corev1.Pod) error {
return fmt.Errorf("determining node pool for pod: %w", err)
}

log.Info("creating node pool", "name", name, "nodeCount", np.InitialNodeCount)

req := &containerv1beta1.CreateNodePoolRequest{
NodePool: np,
Parent: g.ClusterContext.ClusterName(),
@@ -83,20 +85,58 @@ func (g *GKE) EnsureNodePoolForPod(p *corev1.Pod) error {
g.inProgressCreates.Store(name, struct{}{})
defer g.inProgressCreates.Delete(name)

g.Recorder.Eventf(p, corev1.EventTypeNormal, EventNodePoolCreationStarted, "Starting creation of Node Pool %s (size = %v) because %s", name, np.InitialNodeCount, why)
call := g.Service.Projects.Locations.Clusters.NodePools.Create(g.ClusterContext.ClusterName(), req)
op, err := call.Do()
if err != nil {
g.Recorder.Eventf(p, corev1.EventTypeWarning, EventNodePoolCreationFailed, "Request to create Node Pool %s failed: %v.", name, err)
return fmt.Errorf("do: %w", err)
}

return waitForGkeOp(g.Service, g.ClusterContext, op)
if err := waitForGkeOp(g.Service, g.ClusterContext, op); err != nil {
g.Recorder.Eventf(p, corev1.EventTypeWarning, EventNodePoolCreationFailed, "Operation to create Node Pool %s failed: %v.", name, err)
return fmt.Errorf("waiting for operation: %w", err)
}

g.Recorder.Eventf(p, corev1.EventTypeNormal, EventNodePoolCreationSucceeded, "Successfully created Node Pool %s.", name)

return nil
}

func (g *GKE) DeleteNodePoolForNode(node *corev1.Node) error {
func (g *GKE) ListNodePools() ([]NodePoolRef, error) {
var refs []NodePoolRef

resp, err := g.Service.Projects.Locations.Clusters.NodePools.List(g.ClusterContext.ClusterName()).Do()
if err != nil {
return nil, fmt.Errorf("listing node pools: %w", err)

}

for _, np := range resp.NodePools {
refs = append(refs, NodePoolRef{
Name: np.Name,
Error: np.Status == "ERROR",
Message: np.StatusMessage,
CreatedForPod: types.NamespacedName{
Name: np.Config.Labels[LabelPodName],
Namespace: np.Config.Labels[LabelPodNamespace],
},
})
}

return refs, nil
}

func (g *GKE) DeleteNodePoolForNode(node *corev1.Node, why string) error {
name, ok := node.GetLabels()[g.NodePoolLabelKey()]
if !ok {
return fmt.Errorf("node %q does not have node pool label", node.Name)
}

return g.DeleteNodePool(name, node, why)
}

func (g *GKE) DeleteNodePool(name string, eventObj client.Object, why string) error {
// Due to concurrent reconciles, multiple deletes for the same
// Node Pool will occur at the same time. The result is an error.
// To avoid a bunch of failed requests, we deduplicate here.
@@ -106,23 +146,42 @@ func (g *GKE) DeleteNodePoolForNode(node *corev1.Node) error {
g.inProgressDeletes.Store(name, struct{}{})
defer g.inProgressDeletes.Delete(name)

g.Recorder.Eventf(eventObj, corev1.EventTypeNormal, EventNodePoolDeletionStarted, "Starting deletion of Node Pool %s because %s", name, why)
op, err := g.Service.Projects.Locations.Clusters.Delete(g.ClusterContext.NodePoolName(name)).Do()
if err != nil {
if gerr, ok := err.(*googleapi.Error); ok && gerr.Code == http.StatusNotFound {
g.Recorder.Eventf(eventObj, corev1.EventTypeNormal, EventNodePoolNotFound, "Node Pool %s not found - ignoring deletion attempt.", name)
return nil
}
g.Recorder.Eventf(eventObj, corev1.EventTypeWarning, EventNodePoolDeletionFailed, "Request to delete Node Pool %s failed: %v.", name, err)
return fmt.Errorf("deleting node pool %q: %w", name, err)
}

return waitForGkeOp(g.Service, g.ClusterContext, op)
if err := waitForGkeOp(g.Service, g.ClusterContext, op); err != nil {
g.Recorder.Eventf(eventObj, corev1.EventTypeWarning, EventNodePoolDeletionFailed, "Operation to delete Node Pool %s failed: %v.", name, err)
return err
}

g.Recorder.Eventf(eventObj, corev1.EventTypeNormal, EventNodePoolDeletionSucceeded, "Successfully deleted Node Pool %s.", name)

return nil
}

var ErrNodePoolStopping = errors.New("node pool stopping")

func (g *GKE) nodePoolExists(name string) (bool, error) {
call := g.Service.Projects.Locations.Clusters.NodePools.Get(g.ClusterContext.NodePoolName(name))
_, err := call.Do()
np, err := call.Do()
if err == nil {
if np.Status == "STOPPING" {
return false, ErrNodePoolStopping
}
return true, nil
}
if gerr, ok := err.(*googleapi.Error); ok && gerr.Code == http.StatusNotFound {
return false, nil
}

return false, err
}

@@ -142,6 +201,9 @@ func (g *GKE) nodePoolForPod(name string, p *corev1.Pod) (*containerv1beta1.Node
LabelParentName: strings.ToLower(ref.Name),
// Assuming a Namespaced parent here...
LabelParentNamespace: strings.ToLower(p.Namespace),

LabelPodName: p.Name,
LabelPodNamespace: p.Namespace,
}

for k, v := range p.Spec.NodeSelector {
@@ -202,7 +264,7 @@ func (g *GKE) nodePoolForPod(name string, p *corev1.Pod) (*containerv1beta1.Node
ServiceAccount: g.ClusterContext.NodeServiceAccount,
ShieldedInstanceConfig: &containerv1beta1.ShieldedInstanceConfig{
EnableIntegrityMonitoring: true,
EnableSecureBoot: true,
EnableSecureBoot: g.ClusterContext.NodeSecureBoot,
},
Tags: g.ClusterContext.NodeTags,
// NOTE: vendor/ was manually updated to include the field because
@@ -248,25 +310,15 @@ func sumTPURequests(p *corev1.Pod) (int, error) {
return n, nil
}

func podToNodePoolName(p *corev1.Pod, prefix, suffix string) (string, error) {
// If JobSet job key annotation (SHA1 hash of namespaced job key) exists,
// use it as the owner ID.
// This annotation is stable through Job recreations, so the node pool name
// generated here will be the same if the JobSet is restarted.
var ownerID string
if jobKey, exists := p.Annotations["jobset.sigs.k8s.io/job-key"]; exists {
ownerID = jobKey
func podToNodePoolName(p *corev1.Pod, prefix, suffix string) string {
var uid string
ref := metav1.GetControllerOf(p)
if ref != nil {
uid = string(ref.UID)
} else {
// Otherwise, fall back to the Job UID. The Job UID is not stable through
// recreations, so if a Job is recreated, the node pool name generated here
// will be different.
ref := metav1.GetControllerOf(p)
if ref == nil {
return "", errors.New("no owner reference")
}
ownerID = string(ref.UID)
uid = string(p.UID)
}
return prefix + ownerID[0:12] + suffix, nil
return prefix + uid[0:12] + suffix
}

func tpuTopologyToNodeCount(accelerator, topo string) (int, error) {
1 change: 1 addition & 0 deletions tpu-provisioner/internal/cloud/gke_context.go
@@ -10,6 +10,7 @@ type GKEContext struct {
NodeServiceAccount string
NodeSecondaryDisk string
NodeTags []string
NodeSecureBoot bool
}

func (c GKEContext) ClusterName() string {
15 changes: 0 additions & 15 deletions tpu-provisioner/internal/cloud/interface.go

This file was deleted.

12 changes: 0 additions & 12 deletions tpu-provisioner/internal/cloud/labels.go

This file was deleted.
