fix: timeout when pause cluster
Signed-off-by: vietanhduong <[email protected]>
vietanhduong committed Jun 10, 2023
1 parent a41186b commit c42147e
Showing 1 changed file with 94 additions and 79 deletions.
173 changes: 94 additions & 79 deletions pkg/gcloud/gke/gke.go
@@ -37,11 +37,11 @@ func (c *Client) ListClusters(project string) ([]*apis.Cluster, error) {
}
defer conn.Close()

- igConn, err := c.newInstanceGroupsConn()
+ migConn, err := c.newManagedInstanceGroupConn()
if err != nil {
return nil, errors.Wrap(err, "list clusters")
}
- defer igConn.Close()
+ defer migConn.Close()

resp, err := conn.ListClusters(context.TODO(), &containerpb.ListClustersRequest{
Parent: fmt.Sprintf("projects/%s/locations/-", project),
@@ -64,7 +64,7 @@ func (c *Client) ListClusters(project string) ([]*apis.Cluster, error) {
InstanceGroups: p.GetInstanceGroupUrls(),
Locations: p.GetLocations(),
InitialNodeCount: p.GetInitialNodeCount(),
- CurrentSize: int32(getNodePoolSize(igConn, project, e.GetName(), p.GetName(), p.GetLocations())),
+ CurrentSize: int32(getNodePoolSize(migConn, project, e.GetName(), p.GetName(), p.GetLocations())),
Spot: p.GetConfig().GetSpot(),
Preemptible: p.GetConfig().GetPreemptible(),
}
@@ -100,11 +100,11 @@ func (c *Client) GetCluster(project, location, name string) (*apis.Cluster, erro
}
defer conn.Close()

- igConn, err := c.newInstanceGroupsConn()
+ migConn, err := c.newManagedInstanceGroupConn()
if err != nil {
return nil, errors.Wrap(err, "get cluster")
}
- defer igConn.Close()
+ defer migConn.Close()

req := &containerpb.GetClusterRequest{
Name: fmt.Sprintf("projects/%s/locations/%s/clusters/%s", project, location, name),
@@ -129,7 +129,7 @@ func (c *Client) GetCluster(project, location, name string) (*apis.Cluster, erro
InstanceGroups: p.GetInstanceGroupUrls(),
Locations: p.GetLocations(),
InitialNodeCount: p.GetInitialNodeCount(),
- CurrentSize: int32(getNodePoolSize(igConn, project, cluster.GetName(), p.GetName(), p.GetLocations())),
+ CurrentSize: int32(getNodePoolSize(migConn, project, cluster.GetName(), p.GetName(), p.GetLocations())),
Spot: p.GetConfig().GetSpot(),
Preemptible: p.GetConfig().GetPreemptible(),
}
@@ -158,11 +158,11 @@ func (c *Client) PauseCluster(cluster *apis.Cluster) error {
}
defer conn.Close()

- igConn, err := c.newInstanceGroupsConn()
+ migConn, err := c.newManagedInstanceGroupConn()
if err != nil {
return errors.Wrap(err, "pause cluster")
}
- defer igConn.Close()
+ defer migConn.Close()

pause := func(cluster *apis.Cluster, pool *apis.Cluster_NodePool) error {
if pool.GetAutoscaling() != nil && pool.GetAutoscaling().GetEnabled() {
@@ -174,16 +174,38 @@ func (c *Client) PauseCluster(cluster *apis.Cluster) error {
log.Printf("WARN: disable node pool autoscaling for '%s/%s' failed\n", cluster.GetName(), cluster.GetLocation())
return errors.Wrapf(err, "pause '%s/%s'", cluster.GetName(), cluster.GetLocation())
}
- if err = waitOp(conn, op); err != nil {
+ if err = waitContainerOp(conn, op); err != nil {
return errors.Wrapf(err, "pause '%s/%s'", cluster.GetName(), cluster.GetLocation())
}
log.Printf("INFO: disabled autoscaling for '%s/%s'\n", cluster.Name, pool.Name)
}

- if err := resize(conn, igConn, cluster, pool, 0); err != nil {
+ if err := resize(migConn, cluster, pool, 0); err != nil {
return errors.Wrap(err, "pause cluster")
}
- return nil

+ interval := time.NewTicker(10 * time.Second)
+ counter := 3
+ for {
+ select {
+ case <-interval.C:
+ // if the size has been 0 for 3 consecutive checks, we can mark this pool as scaled down
+ if current := getNodePoolSize(migConn, cluster.GetProject(), cluster.GetName(), pool.GetName(), pool.GetLocations()); current == 0 {
+ if counter > 0 {
+ counter--
+ } else {
+ return nil
+ }
+ } else {
+ // reset the counter if the reported size is != 0
+ log.Printf("INFO: resizing pool '%s/%s', current size is %d...\n", cluster.GetName(), pool.GetName(), current)
+ counter = 3
+ _ = resize(migConn, cluster, pool, 0)
+ }
+ case <-context.Background().Done():
+ return nil
+ }
+ }
}

defer func() {
@@ -212,11 +234,11 @@ func (c *Client) UnpauseCluster(cluster *apis.Cluster) error {
}
defer conn.Close()

- igConn, err := c.newInstanceGroupsConn()
+ migConn, err := c.newManagedInstanceGroupConn()
if err != nil {
return errors.Wrap(err, "unpause cluster")
}
- defer igConn.Close()
+ defer migConn.Close()
unpause := func(cluster *apis.Cluster, p *apis.Cluster_NodePool) error {
if err := waitClusterOperation(conn, cluster); err != nil {
return errors.Wrap(err, "unpaise cluster")
@@ -236,13 +258,13 @@ func (c *Client) UnpauseCluster(cluster *apis.Cluster) error {
log.Printf("WARN: enable node pool autoscaling for '%s/%s' failed\n", cluster.GetName(), cluster.GetLocation())
return errors.Wrapf(err, "unpause '%s/%s'", cluster.GetName(), cluster.GetLocation())
}
- if err = waitOp(conn, op); err != nil {
+ if err = waitContainerOp(conn, op); err != nil {
return errors.Wrapf(err, "unpause '%s/%s'", cluster.GetName(), cluster.GetLocation())
}
log.Printf("INFO: enabled autoscaling for '%s/%s'\n", cluster.Name, p.Name)
}

- if err := resize(conn, igConn, cluster, p, int(p.GetCurrentSize())); err != nil {
+ if err := resize(migConn, cluster, p, int(p.GetCurrentSize())); err != nil {
return errors.Wrap(err, "unpause cluster")
}
return nil
@@ -332,51 +354,37 @@ func (c *Client) RefreshCluster(cluster *apis.Cluster) error {
return eg.Wait()
}

- // resize the input node pool
- func resize(clusterConn *container_v1.ClusterManagerClient, igConn *compute_v1.InstanceGroupsClient, cluster *apis.Cluster, pool *apis.Cluster_NodePool, size int) error {
- _resize := func() (bool, error) {
- currentSize := getNodePoolSize(igConn, cluster.GetProject(), cluster.GetName(), pool.GetName(), pool.GetLocations())
- if size == 0 && size == currentSize {
- log.Printf("INFO: node pool '%s/%s' has been resized to 0!", cluster.GetName(), pool.GetName())
- return true, nil
- }
- if currentSize >= size && size > 0 {
- log.Printf("INFO: node pool '%s/%s' has been resized to %d (current=%d)!\n", cluster.GetName(), pool.GetName(), size, currentSize)
- return true, nil
- } else {
- log.Printf("INFO: resizing node pool '%s/%s'! current=%d; expect=%d\n", cluster.GetName(), pool.GetName(), currentSize, size)
- }
- if err := waitClusterOperation(clusterConn, cluster); err != nil {
- return false, errors.Wrap(err, "resize")
- }
+ // resize the input node pool. This function resizes the MIGs backing the pool instead of going through the GKE API
+ func resize(conn *compute_v1.InstanceGroupManagersClient, cluster *apis.Cluster, pool *apis.Cluster_NodePool, size int) error {
+ // get all MIGs of the input pool
+ migs, err := findMIGs(conn, cluster.GetProject(), cluster.GetName(), pool.GetName(), pool.GetLocations())
+ if err != nil {
+ log.Printf("WARN: find MIGs of '%s/%s' got error: %v\n", cluster.GetName(), pool.GetName(), err)
+ return err
+ }

- req := &containerpb.SetNodePoolSizeRequest{
- Name: fmt.Sprintf("projects/%s/locations/%s/clusters/%s/nodePools/%s", cluster.GetProject(), cluster.GetLocation(), cluster.GetName(), pool.GetName()),
- NodeCount: int32(size),
+ _resize := func(mig *computepb.InstanceGroupManager) error {
+ req := &computepb.ResizeInstanceGroupManagerRequest{
+ InstanceGroupManager: mig.GetName(),
+ Project: cluster.GetProject(),
+ Size: int32(size),
+ Zone: basename(mig.GetZone()),
}
- op, err := clusterConn.SetNodePoolSize(context.Background(), req)
+ _, err := conn.Resize(context.Background(), req)
if err != nil {
return false, errors.Wrap(err, "resize")
log.Printf("WARN: resize MIG %q got error: %v\n", mig.GetName(), err)
}
- return false, waitOp(clusterConn, op)
}
- var stop bool
- var err error
- if stop, err = _resize(); err != nil || stop {
return err
}

- ticker := time.NewTicker(3 * time.Second)
- for {
- select {
- case <-ticker.C:
- if stop, err = _resize(); err != nil || stop {
- return err
- }
- case <-context.Background().Done():
- return nil
- }
+ var eg errgroup.Group
+ for _, mig := range migs {
+ mig := mig
+ eg.Go(func() error { return _resize(mig) })
}
+ if err = eg.Wait(); err != nil {
+ return err
+ }
+ return nil
}

func (c *Client) newClusterConn() (*container_v1.ClusterManagerClient, error) {
@@ -395,14 +403,6 @@ func (c *Client) newManagedInstanceGroupConn() (*compute_v1.InstanceGroupManager
return compute_v1.NewInstanceGroupManagersRESTClient(context.Background(), opts...)
}

- func (c *Client) newInstanceGroupsConn() (*compute_v1.InstanceGroupsClient, error) {
- var opts []option.ClientOption
- if c.options.Credentials != "" {
- opts = append(opts, option.WithCredentialsFile(c.options.Credentials))
- }
- return compute_v1.NewInstanceGroupsRESTClient(context.Background(), opts...)
- }

// getMangedInstanceNames returns the URLs of the managed instances in the input MIG
func getMangedInstanceNames(conn *compute_v1.InstanceGroupManagersClient, project string, mig *computepb.InstanceGroupManager) ([]string, error) {
req := &computepb.ListManagedInstancesInstanceGroupManagersRequest{
@@ -425,40 +425,55 @@ func getMangedInstanceNames(conn *compute_v1.InstanceGroupManagersClient, projec
return names, nil
}

- func getNodePoolSize(conn *compute_v1.InstanceGroupsClient, project, cluster, pool string, zones []string) int {
- ret := make([]int, len(zones))
- getSize := func(i int, z string) error {
+ func getNodePoolSize(conn *compute_v1.InstanceGroupManagersClient, project, cluster, pool string, zones []string) int {
+ migs, err := findMIGs(conn, project, cluster, pool, zones)
+ if err != nil {
+ log.Printf("WARN: get pool size '%s/%s/%s' got error: %v\n", project, cluster, pool, err)
+ return 0
+ }
+
+ var size int32
+ for _, mig := range migs {
+ size += mig.GetTargetSize()
+ }
+ return int(size)
+ }

+ func findMIGs(conn *compute_v1.InstanceGroupManagersClient, project, cluster, pool string, zones []string) ([]*computepb.InstanceGroupManager, error) {
+ tmp := make([][]*computepb.InstanceGroupManager, len(zones))
+ getMigs := func(i int, z string) error {
filterQuery := fmt.Sprintf("name:gke-%s-%s-*", cluster, pool)
- req := &computepb.ListInstanceGroupsRequest{
+ req := &computepb.ListInstanceGroupManagersRequest{
Project: project,
Filter: &filterQuery,
Zone: z,
}
it := conn.List(context.Background(), req)
for {
resp, err := it.Next()
- if err == nil || err == iterator.Done {
- ret[i] = int(resp.GetSize())
- return nil
+ if err == iterator.Done {
+ break
}
if err != nil {
log.Printf("WARN: get pool size '%s/%s/%s' got error: %v\n", project, cluster, pool, err)
return nil
return err
}
+ tmp[i] = append(tmp[i], resp)
}
+ return nil
}
var eg errgroup.Group
for i, z := range zones {
i, z := i, z
- eg.Go(func() error { return getSize(i, z) })
+ eg.Go(func() error { return getMigs(i, z) })
}
- _ = eg.Wait()

- var size int
- for _, val := range ret {
- size += val
+ if err := eg.Wait(); err != nil {
+ return nil, err
}
+ var out []*computepb.InstanceGroupManager
+ for _, l := range tmp {
+ out = append(out, l...)
+ }
- return size
+ return out, nil
}

func waitClusterOperation(conn *container_v1.ClusterManagerClient, cluster *apis.Cluster) error {
@@ -471,7 +486,7 @@ func waitClusterOperation(conn *container_v1.ClusterManagerClient, cluster *apis
if !strings.Contains(op.GetTargetLink(), fmt.Sprintf("clusters/%s", cluster.GetName())) || op.GetStatus() == containerpb.Operation_DONE {
return nil
}
- if err := waitOp(conn, op); err != nil {
+ if err := waitContainerOp(conn, op); err != nil {
return errors.Wrapf(err, "wait operation %q", op.GetName())
}
log.Printf("INFO: handle operation '%s/%s' has been completed\n", cluster.GetName(), op.GetName())
@@ -486,7 +501,7 @@ func waitClusterOperation(conn *container_v1.ClusterManagerClient, cluster *apis
return eg.Wait()
}

- func waitOp(conn *container_v1.ClusterManagerClient, op *containerpb.Operation) error {
+ func waitContainerOp(conn *container_v1.ClusterManagerClient, op *containerpb.Operation) error {
ticker := time.NewTicker(time.Second)
for {
select {

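The heart of this change is the confirm-and-retry loop that PauseCluster now runs after requesting the scale-down: poll the pool size on a ticker, require several consecutive zero readings before declaring success, and re-issue the resize whenever the size drifts back up. Below is a minimal standalone sketch of that pattern, not the committed code: pollUntilStable, getSize, and resize are hypothetical names introduced for this example, and the sketch takes a caller-supplied context so the loop can actually time out (the committed loop waits on context.Background().Done(), which never fires).

package main

import (
	"context"
	"fmt"
	"time"
)

// pollUntilStable polls getSize on a ticker and returns nil once it has
// observed the target size for `confirmations` consecutive checks. Whenever
// the observed size drifts away from target, it resets the counter and
// re-issues the resize, mirroring the loop this commit adds to PauseCluster.
func pollUntilStable(ctx context.Context, target, confirmations int, interval time.Duration, getSize func() int, resize func(int) error) error {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	remaining := confirmations
	for {
		select {
		case <-ticker.C:
			if current := getSize(); current == target {
				remaining--
				if remaining <= 0 {
					return nil // size held steady long enough
				}
			} else {
				// e.g. an autoscaler or a pending operation recreated
				// instances: start counting again and resize once more
				remaining = confirmations
				_ = resize(target)
			}
		case <-ctx.Done():
			return ctx.Err() // caller-supplied deadline, so the loop cannot hang forever
		}
	}
}

func main() {
	size := 3 // stand-in for getNodePoolSize
	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
	defer cancel()
	err := pollUntilStable(ctx, 0, 3, 100*time.Millisecond,
		func() int { return size },
		func(n int) error { size = n; return nil })
	fmt.Println("scaled down:", err)
}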