Skip to content

Commit

Permalink
support successPolicy and failurePolicy
Browse files Browse the repository at this point in the history
Signed-off-by: qiankunli <[email protected]>

run codegen

Signed-off-by: qiankunli <[email protected]>

support watch pg and preemptable label

Signed-off-by: qiankunli <[email protected]>

fix test case

Signed-off-by: qiankunli <[email protected]>

fix test case

Signed-off-by: qiankunli <[email protected]>
  • Loading branch information
qiankunli committed Apr 18, 2022
1 parent 8c43231 commit c50681b
Show file tree
Hide file tree
Showing 10 changed files with 175 additions and 7 deletions.
2 changes: 2 additions & 0 deletions cmd/training-operator.v1/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import (
"github.com/kubeflow/training-operator/pkg/config"
controllerv1 "github.com/kubeflow/training-operator/pkg/controller.v1"
//+kubebuilder:scaffold:imports
volcanov1beta1 "volcano.sh/apis/pkg/apis/scheduling/v1beta1"
)

var (
Expand All @@ -54,6 +55,7 @@ func init() {
utilruntime.Must(mxnetv1.AddToScheme(scheme))
utilruntime.Must(mpiv1.AddToScheme(scheme))
//+kubebuilder:scaffold:scheme
utilruntime.Must(volcanov1beta1.AddToScheme(scheme))
}

func main() {
Expand Down
4 changes: 4 additions & 0 deletions manifests/base/crds/kubeflow.org_pytorchjobs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ spec:
properties:
elasticPolicy:
properties:
failurePolicy:
type: string
maxReplicas:
description: upper limit for the number of pods that can be set
by the autoscaler; cannot be smaller than MinReplicas, defaults
Expand Down Expand Up @@ -522,6 +524,8 @@ spec:
--rdzv_endpoint, --rdzv_id are auto-assigned; any explicitly
set values are ignored.
type: boolean
successPolicy:
type: string
type: object
pytorchReplicaSpecs:
additionalProperties:
Expand Down
8 changes: 8 additions & 0 deletions pkg/apis/pytorch/v1/defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,14 @@ func setElasticPolicy(job *PyTorchJob) {
job.Spec.ElasticPolicy.MaxReplicas = workerReplicas
job.Spec.ElasticPolicy.MinReplicas = workerReplicas
}
if job.Spec.ElasticPolicy.SuccessPolicy == nil {
policy := SuccessPolicyDefault
job.Spec.ElasticPolicy.SuccessPolicy = &policy
}
if job.Spec.ElasticPolicy.FailurePolicy == nil {
policy := FailurePolicyDefault
job.Spec.ElasticPolicy.FailurePolicy = &policy
}
}
}

Expand Down
14 changes: 13 additions & 1 deletion pkg/apis/pytorch/v1/openapi_generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 17 additions & 0 deletions pkg/apis/pytorch/v1/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,25 @@ type ElasticPolicy struct {
// If not set, the HPA will not be created.
// +optional
Metrics []autoscalingv2beta2.MetricSpec `json:"metrics,omitempty"`

SuccessPolicy *SuccessPolicy `json:"successPolicy,omitempty"`
FailurePolicy *FailurePolicy `json:"failurePolicy,omitempty"`
}

type SuccessPolicy string

const (
SuccessPolicyDefault SuccessPolicy = ""
SuccessPolicyAllWorkers SuccessPolicy = "AllWorkers"
)

type FailurePolicy string

const (
FailurePolicyDefault FailurePolicy = "" // if one pods fails, the job is set to be fail
FailurePolicyByMinReplicas FailurePolicy = "ByMinReplicas" // only if running pods is less than MinReplicas, the job is set to be fail
)

type RDZVConf struct {
Key string `json:"key,omitempty"`
Value string `json:"value,omitempty"`
Expand Down
10 changes: 10 additions & 0 deletions pkg/apis/pytorch/v1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 12 additions & 0 deletions pkg/common/util/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,15 @@ func GetReplicaTypes(specs map[commonv1.ReplicaType]*commonv1.ReplicaSpec) []com
}
return keys
}

// GetContainerExitCode gets the container exit code from the given pod.
func GetContainerExitCode(pod *corev1.Pod, name string) int32 {
var exitCode int32 = 0xbeef // magic number
for _, status := range pod.Status.ContainerStatuses {
state := status.State
if status.Name == name && state.Terminated != nil {
exitCode = state.Terminated.ExitCode
}
}
return exitCode
}
27 changes: 27 additions & 0 deletions pkg/controller.v1/pytorch/envvar.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,40 @@ import (
commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"
pytorchv1 "github.com/kubeflow/training-operator/pkg/apis/pytorch/v1"
corev1 "k8s.io/api/core/v1"
volcanov1beta1 "volcano.sh/apis/pkg/apis/scheduling/v1beta1"
)

// EnvVarGenerator is the environment variable generator interface.
type EnvVarGenerator interface {
Generate(job *pytorchv1.PyTorchJob) ([]corev1.EnvVar, error)
}

func setPodLabel(obj interface{}, podTemplateSpec *corev1.PodTemplateSpec, rtype, index string) error {
pytorchjob, ok := obj.(*pytorchv1.PyTorchJob)
if !ok {
return fmt.Errorf("%+v is not a type of PyTorchJob", obj)
}
if pytorchjob.Spec.PyTorchReplicaSpecs[pytorchv1.PyTorchReplicaTypeMaster] != nil {
if rtype == strings.ToLower(string(pytorchv1.PyTorchReplicaTypeMaster)) {
podTemplateSpec.Labels[volcanov1beta1.PodPreemptable] = "false"
} else {
podTemplateSpec.Labels[volcanov1beta1.PodPreemptable] = "true"
}
} else {
// If the master is null, then we need to set the volcano.sh/preemptable = false to make sure that work0 can not be preempted
rank, err := strconv.Atoi(index)
if err != nil {
return err
}
if rank == 0 {
podTemplateSpec.Labels[volcanov1beta1.PodPreemptable] = "false"
} else {
podTemplateSpec.Labels[volcanov1beta1.PodPreemptable] = "true"
}
}
return nil
}

func setPodEnv(obj interface{}, podTemplateSpec *corev1.PodTemplateSpec, rtype, index string) error {
pytorchjob, ok := obj.(*pytorchv1.PyTorchJob)
if !ok {
Expand Down
78 changes: 76 additions & 2 deletions pkg/controller.v1/pytorch/pytorchjob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package pytorch
import (
"context"
"fmt"
"strings"
"time"

"github.com/go-logr/logr"
Expand Down Expand Up @@ -47,6 +48,7 @@ import (
"sigs.k8s.io/controller-runtime/pkg/manager"
"sigs.k8s.io/controller-runtime/pkg/predicate"
"sigs.k8s.io/controller-runtime/pkg/source"
volcanov1beta1 "volcano.sh/apis/pkg/apis/scheduling/v1beta1"
volcanoclient "volcano.sh/apis/pkg/client/clientset/versioned"

pytorchv1 "github.com/kubeflow/training-operator/pkg/apis/pytorch/v1"
Expand Down Expand Up @@ -109,6 +111,7 @@ type PyTorchJobReconciler struct {
//+kubebuilder:rbac:groups=kubeflow.org,resources=pytorchjobs/finalizers,verbs=update
//+kubebuilder:rbac:groups="",resources=pods,verbs=get;list;watch;create;update;patch;delete
//+kubebuilder:rbac:groups="",resources=services,verbs=get;list;watch;create;delete
//+kubebuilder:rbac:groups=scheduling.volcano.sh,resources=podgroups,verbs=get;list;watch;create;delete

// Reconcile is part of the main kubernetes reconciliation loop which aims to
// move the current state of the cluster closer to the desired state.
Expand Down Expand Up @@ -207,6 +210,18 @@ func (r *PyTorchJobReconciler) SetupWithManager(mgr ctrl.Manager) error {
return err
}

// inject watching for job related podgroup
if err = c.Watch(&source.Kind{Type: &volcanov1beta1.PodGroup{}}, &handler.EnqueueRequestForOwner{
IsController: true,
OwnerType: &pytorchv1.PyTorchJob{},
}, predicate.Funcs{
CreateFunc: util.OnDependentCreateFunc(r.Expectations),
UpdateFunc: util.OnDependentUpdateFunc(&r.JobController),
DeleteFunc: util.OnDependentDeleteFunc(r.Expectations),
}); err != nil {
return err
}

return nil
}

Expand Down Expand Up @@ -389,8 +404,15 @@ func (r *PyTorchJobReconciler) UpdateJobStatus(job interface{},
}
} else {
if rtype == pytorchv1.PyTorchReplicaTypeWorker {
// TODO(gaocegege): Support SuccessPolicy
if expected == 0 {
worker0Completed, err := r.IsWorker0Completed(pytorchjob, replicas)
if err != nil {
logger.Warnf("check if worker 0 completed error %v", err)
return err
}
// Leave a succeeded condition for the following two cases:
// 1. If default success policy is used and worker 0 has completed.
// 2. If `SuccessPolicyAllWorkers` success policy is used and all workers are succeeded.
if expected == 0 || (worker0Completed && *pytorchjob.Spec.ElasticPolicy.SuccessPolicy != pytorchv1.SuccessPolicyAllWorkers) {
msg := fmt.Sprintf("TFJob %s/%s successfully completed.",
pytorchjob.Namespace, pytorchjob.Name)
r.recorder.Event(pytorchjob, corev1.EventTypeNormal, commonutil.JobSucceededReason, msg)
Expand Down Expand Up @@ -428,6 +450,8 @@ func (r *PyTorchJobReconciler) UpdateJobStatus(job interface{},
return err
}
trainingoperatorcommon.RestartedJobsCounterInc(pytorchjob.Namespace, pytorchv1.FrameworkName)
} else if running >= *pytorchjob.Spec.ElasticPolicy.MinReplicas && *pytorchjob.Spec.ElasticPolicy.FailurePolicy == pytorchv1.FailurePolicyByMinReplicas {
// if running pods is greater or equal than MinReplicas, the job is running
} else {
msg := fmt.Sprintf("PyTorchJob %s is failed because %d %s replica(s) failed.", pytorchjob.Name, failed, rtype)
r.Recorder.Event(pytorchjob, corev1.EventTypeNormal, commonutil.JobFailedReason, msg)
Expand All @@ -448,6 +472,53 @@ func (r *PyTorchJobReconciler) UpdateJobStatus(job interface{},
return nil
}

// IsWorker0Completed returns true if pod of worker0 succeeded and exited with 0
func (p *PyTorchJobReconciler) IsWorker0Completed(job *pytorchv1.PyTorchJob,
replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec) (bool, error) {
worker0Completed := false
_, ok := replicas[pytorchv1.PyTorchReplicaTypeWorker]
if !ok {
return true, nil
}
podSlices, err := p.getPodSlices(job,
replicas[pytorchv1.PyTorchReplicaTypeWorker].Replicas)
if err != nil {
return false, err
}
for index, podSlice := range podSlices {
if len(podSlice) == 1 {
pod := podSlice[0]
exitCode := util.GetContainerExitCode(pod, pytorchv1.DefaultContainerName)
if index == 0 && exitCode == 0 && pod.Status.Phase == corev1.PodSucceeded {
worker0Completed = true
}
}
}
return worker0Completed, nil
}

// getPodSlices returns a slice, which element is the slice of pod.
// It gives enough information to caller to make decision to up/down scale resources.
func (p *PyTorchJobReconciler) getPodSlices(
job *pytorchv1.PyTorchJob, replicasNum *int32) ([][]*corev1.Pod, error) {
logger := commonutil.LoggerForReplica(job, strings.ToLower(string(pytorchv1.PyTorchReplicaTypeWorker)))

pods, err := p.GetPodsForJob(job)
if err != nil {
commonutil.LoggerForJob(job).Warnf("getPodsForTFJob error %v", err)
return nil, err
}

// Get all pods for the type rt.
pods, err = p.JobController.FilterPodsForReplicaType(pods, strings.ToLower(string(pytorchv1.PyTorchReplicaTypeWorker)))
if err != nil {
return nil, err
}

podSlices := p.GetPodSlices(pods, int(*replicasNum), logger)
return podSlices, nil
}

// ContainsMasterSpec returns true if the tfjob contains master spec.
func ContainsMasterSpec(replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec) bool {
if _, ok := replicas[pytorchv1.PyTorchReplicaTypeMaster]; ok {
Expand Down Expand Up @@ -489,6 +560,9 @@ func (r *PyTorchJobReconciler) UpdateJobStatusInApiServer(job interface{}, jobSt

// SetClusterSpec sets the cluster spec and init container for the pod
func (r *PyTorchJobReconciler) SetClusterSpec(job interface{}, podTemplate *corev1.PodTemplateSpec, rtype, index string) error {
if err := setPodLabel(job, podTemplate, rtype, index); err != nil {
return err
}
if err := setPodEnv(job, podTemplate, rtype, index); err != nil {
return err
}
Expand Down
10 changes: 6 additions & 4 deletions pkg/controller.v1/pytorch/pytorchjob_controller_suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ package pytorch

import (
"context"
"path/filepath"
"testing"

v1 "github.com/kubeflow/training-operator/pkg/apis/pytorch/v1"
"github.com/kubeflow/training-operator/pkg/config"
"path/filepath"
"testing"
volcanov1beta1 "volcano.sh/apis/pkg/apis/scheduling/v1beta1"

. "github.com/onsi/ginkgo"
"github.com/onsi/gomega"
Expand Down Expand Up @@ -75,8 +75,10 @@ var _ = BeforeSuite(func() {
config.Config.PyTorchInitContainerImage = config.PyTorchInitContainerImageDefault
config.Config.PyTorchInitContainerTemplateFile = config.PyTorchInitContainerTemplateFileDefault

//+kubebuilder:scaffold:scheme
err = volcanov1beta1.AddToScheme(scheme.Scheme)
Expect(err).NotTo(HaveOccurred())

//+kubebuilder:scaffold:scheme
testK8sClient, err = client.New(cfg, client.Options{Scheme: scheme.Scheme})
Expect(err).NotTo(HaveOccurred())
Expect(testK8sClient).NotTo(BeNil())
Expand Down

0 comments on commit c50681b

Please sign in to comment.