diff --git a/manifests/base/crds/kubeflow.org_pytorchjobs.yaml b/manifests/base/crds/kubeflow.org_pytorchjobs.yaml
index daedf9b93b..dcabc5ee19 100644
--- a/manifests/base/crds/kubeflow.org_pytorchjobs.yaml
+++ b/manifests/base/crds/kubeflow.org_pytorchjobs.yaml
@@ -45,6 +45,8 @@ spec:
             properties:
               elasticPolicy:
                 properties:
+                  failurePolicy:
+                    type: string
                   maxReplicas:
                     description: upper limit for the number of pods that can be
                       set by the autoscaler; cannot be smaller than MinReplicas, defaults
@@ -522,6 +524,8 @@ spec:
                       --rdzv_endpoint, --rdzv_id are auto-assigned; any explicitly
                       set values are ignored.
                     type: boolean
+                  successPolicy:
+                    type: string
                 type: object
               pytorchReplicaSpecs:
                 additionalProperties:
diff --git a/pkg/apis/pytorch/v1/defaults.go b/pkg/apis/pytorch/v1/defaults.go
index 12d2274558..4bbd01c953 100644
--- a/pkg/apis/pytorch/v1/defaults.go
+++ b/pkg/apis/pytorch/v1/defaults.go
@@ -73,6 +73,14 @@ func setElasticPolicy(job *PyTorchJob) {
             job.Spec.ElasticPolicy.MaxReplicas = workerReplicas
             job.Spec.ElasticPolicy.MinReplicas = workerReplicas
         }
+        if job.Spec.ElasticPolicy.SuccessPolicy == nil {
+            policy := SuccessPolicyDefault
+            job.Spec.ElasticPolicy.SuccessPolicy = &policy
+        }
+        if job.Spec.ElasticPolicy.FailurePolicy == nil {
+            policy := FailurePolicyDefault
+            job.Spec.ElasticPolicy.FailurePolicy = &policy
+        }
     }
 }
diff --git a/pkg/apis/pytorch/v1/openapi_generated.go b/pkg/apis/pytorch/v1/openapi_generated.go
index 943f4d170c..1b5537908c 100644
--- a/pkg/apis/pytorch/v1/openapi_generated.go
+++ b/pkg/apis/pytorch/v1/openapi_generated.go
@@ -410,7 +410,7 @@ func schema_pkg_apis_pytorch_v1_ElasticPolicy(ref common.ReferenceCallback) comm
                     },
                     "metrics": {
                         SchemaProps: spec.SchemaProps{
-                            Description: "metrics contains the specifications for which to use to calculate the desired replica count (the maximum replica count across all metrics will be used). The desired replica count is calculated multiplying the ratio between the target value and the current value by the current number of pods. Ergo, metrics used must decrease as the pod count is increased, and vice-versa. See the individual metric source types for more information about how each type of metric must respond. If not set, the default metric will be set to 80% average CPU utilization.",
+                            Description: "Metrics contains the specifications which are used to calculate the desired replica count (the maximum replica count across all metrics will be used). The desired replica count is calculated by multiplying the ratio between the target value and the current value by the current number of pods. Ergo, metrics used must decrease as the pod count is increased, and vice-versa. See the individual metric source types for more information about how each type of metric must respond. If not set, the HPA will not be created.",
                             Type: []string{"array"},
                             Items: &spec.SchemaOrArray{
                                 Schema: &spec.Schema{
@@ -421,6 +421,18 @@ func schema_pkg_apis_pytorch_v1_ElasticPolicy(ref common.ReferenceCallback) comm
                             },
                         },
                     },
+                    "successPolicy": {
+                        SchemaProps: spec.SchemaProps{
+                            Type: []string{"string"},
+                            Format: "",
+                        },
+                    },
+                    "failurePolicy": {
+                        SchemaProps: spec.SchemaProps{
+                            Type: []string{"string"},
+                            Format: "",
+                        },
+                    },
                 },
             },
         },
diff --git a/pkg/apis/pytorch/v1/types.go b/pkg/apis/pytorch/v1/types.go
index 2f9e973922..335caa6132 100644
--- a/pkg/apis/pytorch/v1/types.go
+++ b/pkg/apis/pytorch/v1/types.go
@@ -98,8 +98,25 @@ type ElasticPolicy struct {
     // If not set, the HPA will not be created.
     // +optional
     Metrics []autoscalingv2beta2.MetricSpec `json:"metrics,omitempty"`
+
+    SuccessPolicy *SuccessPolicy `json:"successPolicy,omitempty"`
+    FailurePolicy *FailurePolicy `json:"failurePolicy,omitempty"`
 }
 
+type SuccessPolicy string
+
+const (
+    SuccessPolicyDefault    SuccessPolicy = ""           // the job succeeds as soon as worker0 succeeds
+    SuccessPolicyAllWorkers SuccessPolicy = "AllWorkers" // the job succeeds only when all worker pods succeed
+)
+
+type FailurePolicy string
+
+const (
+    FailurePolicyDefault       FailurePolicy = ""              // the job fails as soon as any pod fails
+    FailurePolicyByMinReplicas FailurePolicy = "ByMinReplicas" // the job fails only when the number of running pods drops below MinReplicas
+)
+
 type RDZVConf struct {
     Key   string `json:"key,omitempty"`
     Value string `json:"value,omitempty"`
diff --git a/pkg/apis/pytorch/v1/zz_generated.deepcopy.go b/pkg/apis/pytorch/v1/zz_generated.deepcopy.go
index 1fc845ff7b..1e518ee967 100644
--- a/pkg/apis/pytorch/v1/zz_generated.deepcopy.go
+++ b/pkg/apis/pytorch/v1/zz_generated.deepcopy.go
@@ -85,6 +85,16 @@ func (in *ElasticPolicy) DeepCopyInto(out *ElasticPolicy) {
             (*in)[i].DeepCopyInto(&(*out)[i])
         }
     }
+    if in.SuccessPolicy != nil {
+        in, out := &in.SuccessPolicy, &out.SuccessPolicy
+        *out = new(SuccessPolicy)
+        **out = **in
+    }
+    if in.FailurePolicy != nil {
+        in, out := &in.FailurePolicy, &out.FailurePolicy
+        *out = new(FailurePolicy)
+        **out = **in
+    }
 }
 
 // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ElasticPolicy.
diff --git a/pkg/common/util/util.go b/pkg/common/util/util.go
index 0ddef5ed74..f6b346193f 100644
--- a/pkg/common/util/util.go
+++ b/pkg/common/util/util.go
@@ -72,3 +72,15 @@ func GetReplicaTypes(specs map[commonv1.ReplicaType]*commonv1.ReplicaSpec) []com
     }
     return keys
 }
+
+// GetContainerExitCode returns the exit code of the named container in the given pod.
+func GetContainerExitCode(pod *corev1.Pod, name string) int32 {
+    var exitCode int32 = 0xbeef // sentinel value returned when the named container has no terminated status
+    for _, status := range pod.Status.ContainerStatuses {
+        state := status.State
+        if status.Name == name && state.Terminated != nil {
+            exitCode = state.Terminated.ExitCode
+        }
+    }
+    return exitCode
+}
diff --git a/pkg/controller.v1/pytorch/label.go b/pkg/controller.v1/pytorch/label.go
new file mode 100644
index 0000000000..976b63805f
--- /dev/null
+++ b/pkg/controller.v1/pytorch/label.go
@@ -0,0 +1,40 @@
+package pytorch
+
+import (
+    "fmt"
+    "strconv"
+    "strings"
+
+    pytorchv1 "github.com/kubeflow/training-operator/pkg/apis/pytorch/v1"
+    corev1 "k8s.io/api/core/v1"
+    volcanov1beta1 "volcano.sh/apis/pkg/apis/scheduling/v1beta1"
+)
+
+func setPodLabel(obj interface{}, podTemplateSpec *corev1.PodTemplateSpec, rtype, index string) error {
+    pytorchjob, ok := obj.(*pytorchv1.PyTorchJob)
+    if !ok {
+        return fmt.Errorf("%+v is not a type of PyTorchJob", obj)
+    }
+    if len(podTemplateSpec.Labels) == 0 {
+        podTemplateSpec.Labels = make(map[string]string)
+    }
+    if pytorchjob.Spec.PyTorchReplicaSpecs[pytorchv1.PyTorchReplicaTypeMaster] != nil {
+        if rtype == strings.ToLower(string(pytorchv1.PyTorchReplicaTypeMaster)) {
+            podTemplateSpec.Labels[volcanov1beta1.PodPreemptable] = "false"
+        } else {
+            podTemplateSpec.Labels[volcanov1beta1.PodPreemptable] = "true"
+        }
+    } else {
+        // If there is no master, set volcano.sh/preemptable = "false" on worker0 so that it cannot be preempted.
+        rank, err := strconv.Atoi(index)
+        if err != nil {
+            return err
+        }
+        if rank == 0 {
+            podTemplateSpec.Labels[volcanov1beta1.PodPreemptable] = "false"
+        } else {
+            podTemplateSpec.Labels[volcanov1beta1.PodPreemptable] = "true"
+        }
+    }
+    return nil
+}
diff --git a/pkg/controller.v1/pytorch/pytorchjob_controller.go b/pkg/controller.v1/pytorch/pytorchjob_controller.go
index 6a523a70f2..35ac474d28 100644
--- a/pkg/controller.v1/pytorch/pytorchjob_controller.go
+++ b/pkg/controller.v1/pytorch/pytorchjob_controller.go
@@ -17,6 +17,7 @@ package pytorch
 import (
     "context"
     "fmt"
+    "strings"
     "time"
 
     "github.com/go-logr/logr"
@@ -389,8 +390,15 @@ func (r *PyTorchJobReconciler) UpdateJobStatus(job interface{},
             }
         } else {
             if rtype == pytorchv1.PyTorchReplicaTypeWorker {
-                // TODO(gaocegege): Support SuccessPolicy
-                if expected == 0 {
+                worker0Completed, err := r.IsWorker0Completed(pytorchjob, replicas)
+                if err != nil {
+                    logger.Warnf("check if worker 0 completed error %v", err)
+                    return err
+                }
+                // Leave a succeeded condition for the following two cases:
+                // 1. The default success policy is used and worker 0 has completed.
+                // 2. The `SuccessPolicyAllWorkers` success policy is used and all workers have succeeded.
+                if expected == 0 || (worker0Completed && *pytorchjob.Spec.ElasticPolicy.SuccessPolicy != pytorchv1.SuccessPolicyAllWorkers) {
                     msg := fmt.Sprintf("TFJob %s/%s successfully completed.",
                         pytorchjob.Namespace, pytorchjob.Name)
                     r.recorder.Event(pytorchjob, corev1.EventTypeNormal, commonutil.JobSucceededReason, msg)
@@ -428,7 +436,7 @@ func (r *PyTorchJobReconciler) UpdateJobStatus(job interface{},
                     return err
                 }
                 trainingoperatorcommon.RestartedJobsCounterInc(pytorchjob.Namespace, pytorchv1.FrameworkName)
-            } else {
+            } else if running < *pytorchjob.Spec.ElasticPolicy.MinReplicas || *pytorchjob.Spec.ElasticPolicy.FailurePolicy == pytorchv1.FailurePolicyDefault {
                 msg := fmt.Sprintf("PyTorchJob %s is failed because %d %s replica(s) failed.",
                     pytorchjob.Name, failed, rtype)
                 r.Recorder.Event(pytorchjob, corev1.EventTypeNormal, commonutil.JobFailedReason, msg)
                 if pytorchjob.Status.CompletionTime == nil {
@@ -442,12 +450,58 @@ func (r *PyTorchJobReconciler) UpdateJobStatus(job interface{},
                     return err
                 }
                 trainingoperatorcommon.FailedJobsCounterInc(pytorchjob.Namespace, pytorchv1.FrameworkName)
             }
+            // If the number of running pods is greater than or equal to MinReplicas, the job keeps running.
         }
     }
 
     return nil
 }
 
+// IsWorker0Completed returns true if the worker0 pod has succeeded and exited with code 0.
+func (p *PyTorchJobReconciler) IsWorker0Completed(job *pytorchv1.PyTorchJob,
+    replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec) (bool, error) {
+    worker0Completed := false
+    _, ok := replicas[pytorchv1.PyTorchReplicaTypeWorker]
+    if !ok {
+        return true, nil
+    }
+    podSlices, err := p.getPodSlices(job, replicas[pytorchv1.PyTorchReplicaTypeWorker].Replicas, string(pytorchv1.PyTorchReplicaTypeWorker))
+    if err != nil {
+        return false, err
+    }
+    for index, podSlice := range podSlices {
+        if len(podSlice) == 1 {
+            pod := podSlice[0]
+            exitCode := util.GetContainerExitCode(pod, pytorchv1.DefaultContainerName)
+            if index == 0 && exitCode == 0 && pod.Status.Phase == corev1.PodSucceeded {
+                worker0Completed = true
+            }
+        }
+    }
+    return worker0Completed, nil
+}
+
+// getPodSlices returns a slice whose i-th element holds the pods with replica index i.
+// It gives the caller enough information to decide whether to scale resources up or down.
+func (p *PyTorchJobReconciler) getPodSlices(job *pytorchv1.PyTorchJob, replicasNum *int32, rtype string) ([][]*corev1.Pod, error) {
+    logger := commonutil.LoggerForReplica(job, strings.ToLower(rtype))
+
+    pods, err := p.GetPodsForJob(job)
+    if err != nil {
+        commonutil.LoggerForJob(job).Warnf("GetPodsForJob error %v", err)
+        return nil, err
+    }
+
+    // Keep only the pods of the given replica type.
+    pods, err = p.JobController.FilterPodsForReplicaType(pods, strings.ToLower(rtype))
+    if err != nil {
+        return nil, err
+    }
+
+    podSlices := p.GetPodSlices(pods, int(*replicasNum), logger)
+    return podSlices, nil
+}
+
 // ContainsMasterSpec returns true if the tfjob contains master spec.
 func ContainsMasterSpec(replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec) bool {
     if _, ok := replicas[pytorchv1.PyTorchReplicaTypeMaster]; ok {
@@ -489,6 +543,9 @@ func (r *PyTorchJobReconciler) UpdateJobStatusInApiServer(job interface{}, jobSt
 
 // SetClusterSpec sets the cluster spec and init container for the pod
 func (r *PyTorchJobReconciler) SetClusterSpec(job interface{}, podTemplate *corev1.PodTemplateSpec, rtype, index string) error {
+    if err := setPodLabel(job, podTemplate, rtype, index); err != nil {
+        return err
+    }
     if err := setPodEnv(job, podTemplate, rtype, index); err != nil {
         return err
     }
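For context, a minimal sketch (not part of this diff) of how a user might set the two new elasticPolicy fields in a PyTorchJob manifest. The job name, image, and replica counts are illustrative placeholders, not values taken from this change.

```yaml
# Hypothetical usage sketch for the new successPolicy / failurePolicy fields.
apiVersion: kubeflow.org/v1
kind: PyTorchJob
metadata:
  name: elastic-example          # placeholder name
spec:
  elasticPolicy:
    minReplicas: 2
    maxReplicas: 4
    # AllWorkers: the job succeeds only when every worker pod succeeds
    # (the default "" succeeds as soon as worker0 succeeds).
    successPolicy: AllWorkers
    # ByMinReplicas: the job fails only when the number of running workers
    # drops below minReplicas (the default "" fails on any pod failure).
    failurePolicy: ByMinReplicas
  pytorchReplicaSpecs:
    Worker:
      replicas: 4
      restartPolicy: OnFailure
      template:
        spec:
          containers:
            - name: pytorch                                  # matches DefaultContainerName
              image: example.com/elastic-training:latest     # placeholder image
```

Since no Master replica is defined in this sketch, setPodLabel marks worker 0 with volcano.sh/preemptable=false while the remaining workers stay preemptable, and with ByMinReplicas the job is only marked failed once fewer than minReplicas workers are running.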