From 679807eed8561aec4cbfe58d9d4ef36e175194b7 Mon Sep 17 00:00:00 2001 From: qiankunli Date: Mon, 11 Apr 2022 20:53:04 +0800 Subject: [PATCH] support successPolicy and failurePolicy Signed-off-by: qiankunli run codegen Signed-off-by: qiankunli support watch pg and preemptable label Signed-off-by: qiankunli fix test case Signed-off-by: qiankunli fix test case Signed-off-by: qiankunli fix test case Signed-off-by: qiankunli --- cmd/training-operator.v1/main.go | 2 + .../base/crds/kubeflow.org_pytorchjobs.yaml | 4 + pkg/apis/pytorch/v1/defaults.go | 8 ++ pkg/apis/pytorch/v1/openapi_generated.go | 14 +++- pkg/apis/pytorch/v1/types.go | 17 ++++ pkg/apis/pytorch/v1/zz_generated.deepcopy.go | 10 +++ pkg/common/util/util.go | 12 +++ pkg/controller.v1/pytorch/envvar.go | 27 +++++++ .../pytorch/pytorchjob_controller.go | 78 ++++++++++++++++++- .../pytorchjob_controller_suite_test.go | 15 ++-- 10 files changed, 178 insertions(+), 9 deletions(-) diff --git a/cmd/training-operator.v1/main.go b/cmd/training-operator.v1/main.go index 0f849492b9..fa1854d35c 100644 --- a/cmd/training-operator.v1/main.go +++ b/cmd/training-operator.v1/main.go @@ -39,6 +39,7 @@ import ( "github.com/kubeflow/training-operator/pkg/config" controllerv1 "github.com/kubeflow/training-operator/pkg/controller.v1" //+kubebuilder:scaffold:imports + volcanov1beta1 "volcano.sh/apis/pkg/apis/scheduling/v1beta1" ) var ( @@ -54,6 +55,7 @@ func init() { utilruntime.Must(mxnetv1.AddToScheme(scheme)) utilruntime.Must(mpiv1.AddToScheme(scheme)) //+kubebuilder:scaffold:scheme + utilruntime.Must(volcanov1beta1.AddToScheme(scheme)) } func main() { diff --git a/manifests/base/crds/kubeflow.org_pytorchjobs.yaml b/manifests/base/crds/kubeflow.org_pytorchjobs.yaml index daedf9b93b..dcabc5ee19 100644 --- a/manifests/base/crds/kubeflow.org_pytorchjobs.yaml +++ b/manifests/base/crds/kubeflow.org_pytorchjobs.yaml @@ -45,6 +45,8 @@ spec: properties: elasticPolicy: properties: + failurePolicy: + type: string maxReplicas: description: upper limit for the number of pods that can be set by the autoscaler; cannot be smaller than MinReplicas, defaults @@ -522,6 +524,8 @@ spec: --rdzv_endpoint, --rdzv_id are auto-assigned; any explicitly set values are ignored. type: boolean + successPolicy: + type: string type: object pytorchReplicaSpecs: additionalProperties: diff --git a/pkg/apis/pytorch/v1/defaults.go b/pkg/apis/pytorch/v1/defaults.go index 12d2274558..4bbd01c953 100644 --- a/pkg/apis/pytorch/v1/defaults.go +++ b/pkg/apis/pytorch/v1/defaults.go @@ -73,6 +73,14 @@ func setElasticPolicy(job *PyTorchJob) { job.Spec.ElasticPolicy.MaxReplicas = workerReplicas job.Spec.ElasticPolicy.MinReplicas = workerReplicas } + if job.Spec.ElasticPolicy.SuccessPolicy == nil { + policy := SuccessPolicyDefault + job.Spec.ElasticPolicy.SuccessPolicy = &policy + } + if job.Spec.ElasticPolicy.FailurePolicy == nil { + policy := FailurePolicyDefault + job.Spec.ElasticPolicy.FailurePolicy = &policy + } } } diff --git a/pkg/apis/pytorch/v1/openapi_generated.go b/pkg/apis/pytorch/v1/openapi_generated.go index 943f4d170c..1b5537908c 100644 --- a/pkg/apis/pytorch/v1/openapi_generated.go +++ b/pkg/apis/pytorch/v1/openapi_generated.go @@ -410,7 +410,7 @@ func schema_pkg_apis_pytorch_v1_ElasticPolicy(ref common.ReferenceCallback) comm }, "metrics": { SchemaProps: spec.SchemaProps{ - Description: "metrics contains the specifications for which to use to calculate the desired replica count (the maximum replica count across all metrics will be used). The desired replica count is calculated multiplying the ratio between the target value and the current value by the current number of pods. Ergo, metrics used must decrease as the pod count is increased, and vice-versa. See the individual metric source types for more information about how each type of metric must respond. If not set, the default metric will be set to 80% average CPU utilization.", + Description: "Metrics contains the specifications which are used to calculate the desired replica count (the maximum replica count across all metrics will be used). The desired replica count is calculated with multiplying the ratio between the target value and the current value by the current number of pods. Ergo, metrics used must decrease as the pod count is increased, and vice-versa. See the individual metric source types for more information about how each type of metric must respond. If not set, the HPA will not be created.", Type: []string{"array"}, Items: &spec.SchemaOrArray{ Schema: &spec.Schema{ @@ -421,6 +421,18 @@ func schema_pkg_apis_pytorch_v1_ElasticPolicy(ref common.ReferenceCallback) comm }, }, }, + "successPolicy": { + SchemaProps: spec.SchemaProps{ + Type: []string{"string"}, + Format: "", + }, + }, + "failurePolicy": { + SchemaProps: spec.SchemaProps{ + Type: []string{"string"}, + Format: "", + }, + }, }, }, }, diff --git a/pkg/apis/pytorch/v1/types.go b/pkg/apis/pytorch/v1/types.go index 2f9e973922..09b699f9fb 100644 --- a/pkg/apis/pytorch/v1/types.go +++ b/pkg/apis/pytorch/v1/types.go @@ -98,8 +98,25 @@ type ElasticPolicy struct { // If not set, the HPA will not be created. // +optional Metrics []autoscalingv2beta2.MetricSpec `json:"metrics,omitempty"` + + SuccessPolicy *SuccessPolicy `json:"successPolicy,omitempty"` + FailurePolicy *FailurePolicy `json:"failurePolicy,omitempty"` } +type SuccessPolicy string + +const ( + SuccessPolicyDefault SuccessPolicy = "" + SuccessPolicyAllWorkers SuccessPolicy = "AllWorkers" +) + +type FailurePolicy string + +const ( + FailurePolicyDefault FailurePolicy = "" // if one pods fails, the job is set to be fail + FailurePolicyByMinReplicas FailurePolicy = "ByMinReplicas" // only if running pods is less than MinReplicas, the job is set to be fail +) + type RDZVConf struct { Key string `json:"key,omitempty"` Value string `json:"value,omitempty"` diff --git a/pkg/apis/pytorch/v1/zz_generated.deepcopy.go b/pkg/apis/pytorch/v1/zz_generated.deepcopy.go index 1fc845ff7b..1e518ee967 100644 --- a/pkg/apis/pytorch/v1/zz_generated.deepcopy.go +++ b/pkg/apis/pytorch/v1/zz_generated.deepcopy.go @@ -85,6 +85,16 @@ func (in *ElasticPolicy) DeepCopyInto(out *ElasticPolicy) { (*in)[i].DeepCopyInto(&(*out)[i]) } } + if in.SuccessPolicy != nil { + in, out := &in.SuccessPolicy, &out.SuccessPolicy + *out = new(SuccessPolicy) + **out = **in + } + if in.FailurePolicy != nil { + in, out := &in.FailurePolicy, &out.FailurePolicy + *out = new(FailurePolicy) + **out = **in + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ElasticPolicy. diff --git a/pkg/common/util/util.go b/pkg/common/util/util.go index 0ddef5ed74..f6b346193f 100644 --- a/pkg/common/util/util.go +++ b/pkg/common/util/util.go @@ -72,3 +72,15 @@ func GetReplicaTypes(specs map[commonv1.ReplicaType]*commonv1.ReplicaSpec) []com } return keys } + +// GetContainerExitCode gets the container exit code from the given pod. +func GetContainerExitCode(pod *corev1.Pod, name string) int32 { + var exitCode int32 = 0xbeef // magic number + for _, status := range pod.Status.ContainerStatuses { + state := status.State + if status.Name == name && state.Terminated != nil { + exitCode = state.Terminated.ExitCode + } + } + return exitCode +} diff --git a/pkg/controller.v1/pytorch/envvar.go b/pkg/controller.v1/pytorch/envvar.go index 74b30bc84a..e03368e4a9 100644 --- a/pkg/controller.v1/pytorch/envvar.go +++ b/pkg/controller.v1/pytorch/envvar.go @@ -22,6 +22,7 @@ import ( commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" pytorchv1 "github.com/kubeflow/training-operator/pkg/apis/pytorch/v1" corev1 "k8s.io/api/core/v1" + volcanov1beta1 "volcano.sh/apis/pkg/apis/scheduling/v1beta1" ) // EnvVarGenerator is the environment variable generator interface. @@ -29,6 +30,32 @@ type EnvVarGenerator interface { Generate(job *pytorchv1.PyTorchJob) ([]corev1.EnvVar, error) } +func setPodLabel(obj interface{}, podTemplateSpec *corev1.PodTemplateSpec, rtype, index string) error { + pytorchjob, ok := obj.(*pytorchv1.PyTorchJob) + if !ok { + return fmt.Errorf("%+v is not a type of PyTorchJob", obj) + } + if pytorchjob.Spec.PyTorchReplicaSpecs[pytorchv1.PyTorchReplicaTypeMaster] != nil { + if rtype == strings.ToLower(string(pytorchv1.PyTorchReplicaTypeMaster)) { + podTemplateSpec.Labels[volcanov1beta1.PodPreemptable] = "false" + } else { + podTemplateSpec.Labels[volcanov1beta1.PodPreemptable] = "true" + } + } else { + // If the master is null, then we need to set the volcano.sh/preemptable = false to make sure that work0 can not be preempted + rank, err := strconv.Atoi(index) + if err != nil { + return err + } + if rank == 0 { + podTemplateSpec.Labels[volcanov1beta1.PodPreemptable] = "false" + } else { + podTemplateSpec.Labels[volcanov1beta1.PodPreemptable] = "true" + } + } + return nil +} + func setPodEnv(obj interface{}, podTemplateSpec *corev1.PodTemplateSpec, rtype, index string) error { pytorchjob, ok := obj.(*pytorchv1.PyTorchJob) if !ok { diff --git a/pkg/controller.v1/pytorch/pytorchjob_controller.go b/pkg/controller.v1/pytorch/pytorchjob_controller.go index 6a523a70f2..09df99c997 100644 --- a/pkg/controller.v1/pytorch/pytorchjob_controller.go +++ b/pkg/controller.v1/pytorch/pytorchjob_controller.go @@ -17,6 +17,7 @@ package pytorch import ( "context" "fmt" + "strings" "time" "github.com/go-logr/logr" @@ -47,6 +48,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/manager" "sigs.k8s.io/controller-runtime/pkg/predicate" "sigs.k8s.io/controller-runtime/pkg/source" + volcanov1beta1 "volcano.sh/apis/pkg/apis/scheduling/v1beta1" volcanoclient "volcano.sh/apis/pkg/client/clientset/versioned" pytorchv1 "github.com/kubeflow/training-operator/pkg/apis/pytorch/v1" @@ -109,6 +111,7 @@ type PyTorchJobReconciler struct { //+kubebuilder:rbac:groups=kubeflow.org,resources=pytorchjobs/finalizers,verbs=update //+kubebuilder:rbac:groups="",resources=pods,verbs=get;list;watch;create;update;patch;delete //+kubebuilder:rbac:groups="",resources=services,verbs=get;list;watch;create;delete +//+kubebuilder:rbac:groups=scheduling.volcano.sh,resources=podgroups,verbs=get;list;watch;create;delete // Reconcile is part of the main kubernetes reconciliation loop which aims to // move the current state of the cluster closer to the desired state. @@ -207,6 +210,18 @@ func (r *PyTorchJobReconciler) SetupWithManager(mgr ctrl.Manager) error { return err } + // inject watching for job related podgroup + if err = c.Watch(&source.Kind{Type: &volcanov1beta1.PodGroup{}}, &handler.EnqueueRequestForOwner{ + IsController: true, + OwnerType: &pytorchv1.PyTorchJob{}, + }, predicate.Funcs{ + CreateFunc: util.OnDependentCreateFunc(r.Expectations), + UpdateFunc: util.OnDependentUpdateFunc(&r.JobController), + DeleteFunc: util.OnDependentDeleteFunc(r.Expectations), + }); err != nil { + return err + } + return nil } @@ -389,8 +404,15 @@ func (r *PyTorchJobReconciler) UpdateJobStatus(job interface{}, } } else { if rtype == pytorchv1.PyTorchReplicaTypeWorker { - // TODO(gaocegege): Support SuccessPolicy - if expected == 0 { + worker0Completed, err := r.IsWorker0Completed(pytorchjob, replicas) + if err != nil { + logger.Warnf("check if worker 0 completed error %v", err) + return err + } + // Leave a succeeded condition for the following two cases: + // 1. If default success policy is used and worker 0 has completed. + // 2. If `SuccessPolicyAllWorkers` success policy is used and all workers are succeeded. + if expected == 0 || (worker0Completed && *pytorchjob.Spec.ElasticPolicy.SuccessPolicy != pytorchv1.SuccessPolicyAllWorkers) { msg := fmt.Sprintf("TFJob %s/%s successfully completed.", pytorchjob.Namespace, pytorchjob.Name) r.recorder.Event(pytorchjob, corev1.EventTypeNormal, commonutil.JobSucceededReason, msg) @@ -428,6 +450,8 @@ func (r *PyTorchJobReconciler) UpdateJobStatus(job interface{}, return err } trainingoperatorcommon.RestartedJobsCounterInc(pytorchjob.Namespace, pytorchv1.FrameworkName) + } else if running >= *pytorchjob.Spec.ElasticPolicy.MinReplicas && *pytorchjob.Spec.ElasticPolicy.FailurePolicy == pytorchv1.FailurePolicyByMinReplicas { + // if running pods is greater or equal than MinReplicas, the job is running } else { msg := fmt.Sprintf("PyTorchJob %s is failed because %d %s replica(s) failed.", pytorchjob.Name, failed, rtype) r.Recorder.Event(pytorchjob, corev1.EventTypeNormal, commonutil.JobFailedReason, msg) @@ -448,6 +472,53 @@ func (r *PyTorchJobReconciler) UpdateJobStatus(job interface{}, return nil } +// IsWorker0Completed returns true if pod of worker0 succeeded and exited with 0 +func (p *PyTorchJobReconciler) IsWorker0Completed(job *pytorchv1.PyTorchJob, + replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec) (bool, error) { + worker0Completed := false + _, ok := replicas[pytorchv1.PyTorchReplicaTypeWorker] + if !ok { + return true, nil + } + podSlices, err := p.getPodSlices(job, + replicas[pytorchv1.PyTorchReplicaTypeWorker].Replicas) + if err != nil { + return false, err + } + for index, podSlice := range podSlices { + if len(podSlice) == 1 { + pod := podSlice[0] + exitCode := util.GetContainerExitCode(pod, pytorchv1.DefaultContainerName) + if index == 0 && exitCode == 0 && pod.Status.Phase == corev1.PodSucceeded { + worker0Completed = true + } + } + } + return worker0Completed, nil +} + +// getPodSlices returns a slice, which element is the slice of pod. +// It gives enough information to caller to make decision to up/down scale resources. +func (p *PyTorchJobReconciler) getPodSlices( + job *pytorchv1.PyTorchJob, replicasNum *int32) ([][]*corev1.Pod, error) { + logger := commonutil.LoggerForReplica(job, strings.ToLower(string(pytorchv1.PyTorchReplicaTypeWorker))) + + pods, err := p.GetPodsForJob(job) + if err != nil { + commonutil.LoggerForJob(job).Warnf("getPodsForTFJob error %v", err) + return nil, err + } + + // Get all pods for the type rt. + pods, err = p.JobController.FilterPodsForReplicaType(pods, strings.ToLower(string(pytorchv1.PyTorchReplicaTypeWorker))) + if err != nil { + return nil, err + } + + podSlices := p.GetPodSlices(pods, int(*replicasNum), logger) + return podSlices, nil +} + // ContainsMasterSpec returns true if the tfjob contains master spec. func ContainsMasterSpec(replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec) bool { if _, ok := replicas[pytorchv1.PyTorchReplicaTypeMaster]; ok { @@ -489,6 +560,9 @@ func (r *PyTorchJobReconciler) UpdateJobStatusInApiServer(job interface{}, jobSt // SetClusterSpec sets the cluster spec and init container for the pod func (r *PyTorchJobReconciler) SetClusterSpec(job interface{}, podTemplate *corev1.PodTemplateSpec, rtype, index string) error { + if err := setPodLabel(job, podTemplate, rtype, index); err != nil { + return err + } if err := setPodEnv(job, podTemplate, rtype, index); err != nil { return err } diff --git a/pkg/controller.v1/pytorch/pytorchjob_controller_suite_test.go b/pkg/controller.v1/pytorch/pytorchjob_controller_suite_test.go index 208edbd657..3e0b657f0d 100644 --- a/pkg/controller.v1/pytorch/pytorchjob_controller_suite_test.go +++ b/pkg/controller.v1/pytorch/pytorchjob_controller_suite_test.go @@ -16,22 +16,22 @@ package pytorch import ( "context" - "path/filepath" - "testing" - v1 "github.com/kubeflow/training-operator/pkg/apis/pytorch/v1" "github.com/kubeflow/training-operator/pkg/config" - . "github.com/onsi/ginkgo" "github.com/onsi/gomega" . "github.com/onsi/gomega" + "k8s.io/apimachinery/pkg/runtime" "k8s.io/client-go/kubernetes/scheme" + "path/filepath" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/envtest" "sigs.k8s.io/controller-runtime/pkg/envtest/printer" logf "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/log/zap" + "testing" + volcanov1beta1 "volcano.sh/apis/pkg/apis/scheduling/v1beta1" //+kubebuilder:scaffold:imports ) @@ -68,20 +68,23 @@ var _ = BeforeSuite(func() { Expect(err).NotTo(HaveOccurred()) Expect(cfg).NotTo(BeNil()) + s := runtime.NewScheme() err = v1.AddToScheme(scheme.Scheme) Expect(err).NotTo(HaveOccurred()) + err = volcanov1beta1.AddToScheme(s) + Expect(err).NotTo(HaveOccurred()) // Set default config. config.Config.PyTorchInitContainerImage = config.PyTorchInitContainerImageDefault config.Config.PyTorchInitContainerTemplateFile = config.PyTorchInitContainerTemplateFileDefault //+kubebuilder:scaffold:scheme - - testK8sClient, err = client.New(cfg, client.Options{Scheme: scheme.Scheme}) + testK8sClient, err = client.New(cfg, client.Options{Scheme: s}) Expect(err).NotTo(HaveOccurred()) Expect(testK8sClient).NotTo(BeNil()) mgr, err := ctrl.NewManager(cfg, ctrl.Options{ + Scheme: s, MetricsBindAddress: "0", }) Expect(err).NotTo(gomega.HaveOccurred())