From 552e42bb18c2bd241c2746523a4358e5596b2663 Mon Sep 17 00:00:00 2001 From: qiankunli Date: Mon, 11 Apr 2022 20:53:04 +0800 Subject: [PATCH] support successPolicy and failurePolicy Signed-off-by: qiankunli run codegen Signed-off-by: qiankunli support watch pg and preemptable label Signed-off-by: qiankunli fix test case Signed-off-by: qiankunli fix test case Signed-off-by: qiankunli fix test case Signed-off-by: qiankunli refactor Signed-off-by: qiankunli fix make Signed-off-by: qiankunli fix test Signed-off-by: qiankunli add corev1 schema Signed-off-by: qiankunli add podgroups crd Signed-off-by: qiankunli --- cmd/training-operator.v1/main.go | 2 + .../base/crds/kubeflow.org_pytorchjobs.yaml | 4 + .../crds/scheduling.volcano.sh_podgroups.yaml | 301 ++++++++++++++++++ pkg/apis/pytorch/v1/defaults.go | 8 + pkg/apis/pytorch/v1/openapi_generated.go | 14 +- pkg/apis/pytorch/v1/types.go | 17 + pkg/apis/pytorch/v1/zz_generated.deepcopy.go | 10 + pkg/common/util/util.go | 12 + pkg/controller.v1/pytorch/label.go | 40 +++ .../pytorch/pytorchjob_controller.go | 77 ++++- .../pytorchjob_controller_suite_test.go | 16 +- 11 files changed, 492 insertions(+), 9 deletions(-) create mode 100644 manifests/base/crds/scheduling.volcano.sh_podgroups.yaml create mode 100644 pkg/controller.v1/pytorch/label.go diff --git a/cmd/training-operator.v1/main.go b/cmd/training-operator.v1/main.go index 0f849492b9..fa1854d35c 100644 --- a/cmd/training-operator.v1/main.go +++ b/cmd/training-operator.v1/main.go @@ -39,6 +39,7 @@ import ( "github.com/kubeflow/training-operator/pkg/config" controllerv1 "github.com/kubeflow/training-operator/pkg/controller.v1" //+kubebuilder:scaffold:imports + volcanov1beta1 "volcano.sh/apis/pkg/apis/scheduling/v1beta1" ) var ( @@ -54,6 +55,7 @@ func init() { utilruntime.Must(mxnetv1.AddToScheme(scheme)) utilruntime.Must(mpiv1.AddToScheme(scheme)) //+kubebuilder:scaffold:scheme + utilruntime.Must(volcanov1beta1.AddToScheme(scheme)) } func main() { diff --git a/manifests/base/crds/kubeflow.org_pytorchjobs.yaml b/manifests/base/crds/kubeflow.org_pytorchjobs.yaml index daedf9b93b..dcabc5ee19 100644 --- a/manifests/base/crds/kubeflow.org_pytorchjobs.yaml +++ b/manifests/base/crds/kubeflow.org_pytorchjobs.yaml @@ -45,6 +45,8 @@ spec: properties: elasticPolicy: properties: + failurePolicy: + type: string maxReplicas: description: upper limit for the number of pods that can be set by the autoscaler; cannot be smaller than MinReplicas, defaults @@ -522,6 +524,8 @@ spec: --rdzv_endpoint, --rdzv_id are auto-assigned; any explicitly set values are ignored. type: boolean + successPolicy: + type: string type: object pytorchReplicaSpecs: additionalProperties: diff --git a/manifests/base/crds/scheduling.volcano.sh_podgroups.yaml b/manifests/base/crds/scheduling.volcano.sh_podgroups.yaml new file mode 100644 index 0000000000..dabca2f3a5 --- /dev/null +++ b/manifests/base/crds/scheduling.volcano.sh_podgroups.yaml @@ -0,0 +1,301 @@ +apiVersion: apiextensions.k8s.io/v1 +kind: CustomResourceDefinition +metadata: + annotations: + controller-gen.kubebuilder.io/version: v0.6.0 + creationTimestamp: null + name: podgroups.scheduling.volcano.sh +spec: + group: scheduling.volcano.sh + names: + kind: PodGroup + listKind: PodGroupList + plural: podgroups + shortNames: + - pg + - podgroup-v1beta1 + singular: podgroup + scope: Namespaced + versions: + - name: v1beta1 + additionalPrinterColumns: + - name: STATUS + type: string + jsonPath: .status.phase + - name: MINMEMBER + type: integer + jsonPath: .spec.minMember + - name: RUNNINGS + type: integer + jsonPath: .status.running + - name: AGE + type: date + jsonPath: .metadata.creationTimestamp + schema: + openAPIV3Schema: + description: PodGroup is a collection of Pod; used for batch workload. + properties: + apiVersion: + description: 'APIVersion defines the versioned schema of this representation + of an object. Servers should convert recognized schemas to the latest + internal value, and may reject unrecognized values. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources' + type: string + kind: + description: 'Kind is a string value representing the REST resource this + object represents. Servers may infer this from the endpoint the client + submits requests to. Cannot be updated. In CamelCase. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds' + type: string + metadata: + type: object + spec: + description: 'Specification of the desired behavior of the pod group. + More info: https://git.k8s.io/community/contributors/devel/api-conventions.md#spec-and-status' + properties: + minMember: + description: MinMember defines the minimal number of members/tasks + to run the pod group; if there's not enough resources to start all + tasks, the scheduler will not start anyone. + format: int32 + type: integer + minResources: + additionalProperties: + anyOf: + - type: integer + - type: string + pattern: ^(\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))))?$ + x-kubernetes-int-or-string: true + description: MinResources defines the minimal resource of members/tasks + to run the pod group; if there's not enough resources to start all + tasks, the scheduler will not start anyone. + type: object + minTaskMember: + additionalProperties: + format: int32 + type: integer + description: MinTaskMember defines the minimal number of pods to run + each task in the pod group; if there's not enough resources to start + each task, the scheduler will not start anyone. + type: object + priorityClassName: + description: If specified, indicates the PodGroup's priority. "system-node-critical" + and "system-cluster-critical" are two special keywords which indicate + the highest priorities with the former being the highest priority. + Any other name must be defined by creating a PriorityClass object + with that name. If not specified, the PodGroup priority will be + default or zero if there is no default. + type: string + queue: + description: Queue defines the queue to allocate resource for PodGroup; + if queue does not exist, the PodGroup will not be scheduled. Defaults + to `default` Queue with the lowest weight. + type: string + type: object + status: + description: Status represents the current information about a pod group. + This data may not be up to date. + properties: + conditions: + description: The conditions of PodGroup. + items: + description: PodGroupCondition contains details for the current + state of this pod group. + properties: + lastTransitionTime: + description: Last time the phase transitioned from another to + current phase. + format: date-time + type: string + message: + description: Human-readable message indicating details about + last transition. + type: string + reason: + description: Unique, one-word, CamelCase reason for the phase's + last transition. + type: string + status: + description: Status is the status of the condition. + type: string + transitionID: + description: The ID of condition transition. + type: string + type: + description: Type is the type of the condition + type: string + type: object + type: array + failed: + description: The number of pods which reached phase Failed. + format: int32 + type: integer + phase: + description: Current phase of PodGroup. + type: string + running: + description: The number of actively running pods. + format: int32 + type: integer + succeeded: + description: The number of pods which reached phase Succeeded. + format: int32 + type: integer + type: object + type: object + served: true + storage: true +status: + acceptedNames: + kind: "" + plural: "" + conditions: [] + storedVersions: [] +--- +# Source: volcano/templates/scheduling_v1beta1_queue.yaml +apiVersion: apiextensions.k8s.io/v1 +kind: CustomResourceDefinition +metadata: + annotations: + controller-gen.kubebuilder.io/version: v0.6.0 + creationTimestamp: null + name: queues.scheduling.volcano.sh +spec: + group: scheduling.volcano.sh + names: + kind: Queue + listKind: QueueList + plural: queues + shortNames: + - q + - queue-v1beta1 + singular: queue + scope: Cluster + versions: + - name: v1beta1 + schema: + openAPIV3Schema: + description: Queue is a queue of PodGroup. + properties: + apiVersion: + description: 'APIVersion defines the versioned schema of this representation + of an object. Servers should convert recognized schemas to the latest + internal value, and may reject unrecognized values. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources' + type: string + kind: + description: 'Kind is a string value representing the REST resource this + object represents. Servers may infer this from the endpoint the client + submits requests to. Cannot be updated. In CamelCase. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds' + type: string + metadata: + type: object + spec: + description: 'Specification of the desired behavior of the queue. More + info: https://git.k8s.io/community/contributors/devel/api-conventions.md#spec-and-status' + properties: + capability: + additionalProperties: + anyOf: + - type: integer + - type: string + pattern: ^(\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))))?$ + x-kubernetes-int-or-string: true + description: ResourceList is a set of (resource name, quantity) pairs. + type: object + extendClusters: + description: extendCluster indicate the jobs in this Queue will be + dispatched to these clusters. + items: + description: CluterSpec represents the template of Cluster + properties: + capacity: + additionalProperties: + anyOf: + - type: integer + - type: string + pattern: ^(\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))))?$ + x-kubernetes-int-or-string: true + description: ResourceList is a set of (resource name, quantity) + pairs. + type: object + name: + type: string + weight: + format: int32 + type: integer + type: object + type: array + guarantee: + description: Guarantee indicate configuration about resource reservation + properties: + resource: + additionalProperties: + anyOf: + - type: integer + - type: string + pattern: ^(\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))))?$ + x-kubernetes-int-or-string: true + description: The amount of cluster resource reserved for queue. + Just set either `percentage` or `resource` + type: object + type: object + reclaimable: + description: Reclaimable indicate whether the queue can be reclaimed + by other queue + type: boolean + weight: + format: int32 + type: integer + type: object + status: + description: The status of queue. + properties: + inqueue: + description: The number of `Inqueue` PodGroup in this queue. + format: int32 + type: integer + pending: + description: The number of 'Pending' PodGroup in this queue. + format: int32 + type: integer + reservation: + description: Reservation is the profile of resource reservation for + queue + properties: + nodes: + description: Nodes are Locked nodes for queue + items: + type: string + type: array + resource: + additionalProperties: + anyOf: + - type: integer + - type: string + pattern: ^(\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))))?$ + x-kubernetes-int-or-string: true + description: Resource is a list of total idle resource in locked + nodes. + type: object + type: object + running: + description: The number of 'Running' PodGroup in this queue. + format: int32 + type: integer + state: + description: State is state of queue + type: string + unknown: + description: The number of 'Unknown' PodGroup in this queue. + format: int32 + type: integer + type: object + type: object + served: true + storage: true + subresources: + status: {} +status: + acceptedNames: + kind: "" + plural: "" + conditions: [] + storedVersions: [] \ No newline at end of file diff --git a/pkg/apis/pytorch/v1/defaults.go b/pkg/apis/pytorch/v1/defaults.go index 12d2274558..4bbd01c953 100644 --- a/pkg/apis/pytorch/v1/defaults.go +++ b/pkg/apis/pytorch/v1/defaults.go @@ -73,6 +73,14 @@ func setElasticPolicy(job *PyTorchJob) { job.Spec.ElasticPolicy.MaxReplicas = workerReplicas job.Spec.ElasticPolicy.MinReplicas = workerReplicas } + if job.Spec.ElasticPolicy.SuccessPolicy == nil { + policy := SuccessPolicyDefault + job.Spec.ElasticPolicy.SuccessPolicy = &policy + } + if job.Spec.ElasticPolicy.FailurePolicy == nil { + policy := FailurePolicyDefault + job.Spec.ElasticPolicy.FailurePolicy = &policy + } } } diff --git a/pkg/apis/pytorch/v1/openapi_generated.go b/pkg/apis/pytorch/v1/openapi_generated.go index 943f4d170c..1b5537908c 100644 --- a/pkg/apis/pytorch/v1/openapi_generated.go +++ b/pkg/apis/pytorch/v1/openapi_generated.go @@ -410,7 +410,7 @@ func schema_pkg_apis_pytorch_v1_ElasticPolicy(ref common.ReferenceCallback) comm }, "metrics": { SchemaProps: spec.SchemaProps{ - Description: "metrics contains the specifications for which to use to calculate the desired replica count (the maximum replica count across all metrics will be used). The desired replica count is calculated multiplying the ratio between the target value and the current value by the current number of pods. Ergo, metrics used must decrease as the pod count is increased, and vice-versa. See the individual metric source types for more information about how each type of metric must respond. If not set, the default metric will be set to 80% average CPU utilization.", + Description: "Metrics contains the specifications which are used to calculate the desired replica count (the maximum replica count across all metrics will be used). The desired replica count is calculated with multiplying the ratio between the target value and the current value by the current number of pods. Ergo, metrics used must decrease as the pod count is increased, and vice-versa. See the individual metric source types for more information about how each type of metric must respond. If not set, the HPA will not be created.", Type: []string{"array"}, Items: &spec.SchemaOrArray{ Schema: &spec.Schema{ @@ -421,6 +421,18 @@ func schema_pkg_apis_pytorch_v1_ElasticPolicy(ref common.ReferenceCallback) comm }, }, }, + "successPolicy": { + SchemaProps: spec.SchemaProps{ + Type: []string{"string"}, + Format: "", + }, + }, + "failurePolicy": { + SchemaProps: spec.SchemaProps{ + Type: []string{"string"}, + Format: "", + }, + }, }, }, }, diff --git a/pkg/apis/pytorch/v1/types.go b/pkg/apis/pytorch/v1/types.go index 2f9e973922..335caa6132 100644 --- a/pkg/apis/pytorch/v1/types.go +++ b/pkg/apis/pytorch/v1/types.go @@ -98,8 +98,25 @@ type ElasticPolicy struct { // If not set, the HPA will not be created. // +optional Metrics []autoscalingv2beta2.MetricSpec `json:"metrics,omitempty"` + + SuccessPolicy *SuccessPolicy `json:"successPolicy,omitempty"` + FailurePolicy *FailurePolicy `json:"failurePolicy,omitempty"` } +type SuccessPolicy string + +const ( + SuccessPolicyDefault SuccessPolicy = "" // if worker0 is success, the job is set to be success + SuccessPolicyAllWorkers SuccessPolicy = "AllWorkers" // only if all pods is success, the job is set to be success +) + +type FailurePolicy string + +const ( + FailurePolicyDefault FailurePolicy = "" // if one pods fails, the job is set to be fail + FailurePolicyByMinReplicas FailurePolicy = "ByMinReplicas" // only if running pods is less than MinReplicas, the job is set to be fail +) + type RDZVConf struct { Key string `json:"key,omitempty"` Value string `json:"value,omitempty"` diff --git a/pkg/apis/pytorch/v1/zz_generated.deepcopy.go b/pkg/apis/pytorch/v1/zz_generated.deepcopy.go index 1fc845ff7b..1e518ee967 100644 --- a/pkg/apis/pytorch/v1/zz_generated.deepcopy.go +++ b/pkg/apis/pytorch/v1/zz_generated.deepcopy.go @@ -85,6 +85,16 @@ func (in *ElasticPolicy) DeepCopyInto(out *ElasticPolicy) { (*in)[i].DeepCopyInto(&(*out)[i]) } } + if in.SuccessPolicy != nil { + in, out := &in.SuccessPolicy, &out.SuccessPolicy + *out = new(SuccessPolicy) + **out = **in + } + if in.FailurePolicy != nil { + in, out := &in.FailurePolicy, &out.FailurePolicy + *out = new(FailurePolicy) + **out = **in + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ElasticPolicy. diff --git a/pkg/common/util/util.go b/pkg/common/util/util.go index 0ddef5ed74..f6b346193f 100644 --- a/pkg/common/util/util.go +++ b/pkg/common/util/util.go @@ -72,3 +72,15 @@ func GetReplicaTypes(specs map[commonv1.ReplicaType]*commonv1.ReplicaSpec) []com } return keys } + +// GetContainerExitCode gets the container exit code from the given pod. +func GetContainerExitCode(pod *corev1.Pod, name string) int32 { + var exitCode int32 = 0xbeef // magic number + for _, status := range pod.Status.ContainerStatuses { + state := status.State + if status.Name == name && state.Terminated != nil { + exitCode = state.Terminated.ExitCode + } + } + return exitCode +} diff --git a/pkg/controller.v1/pytorch/label.go b/pkg/controller.v1/pytorch/label.go new file mode 100644 index 0000000000..976b63805f --- /dev/null +++ b/pkg/controller.v1/pytorch/label.go @@ -0,0 +1,40 @@ +package pytorch + +import ( + "fmt" + "strconv" + "strings" + + pytorchv1 "github.com/kubeflow/training-operator/pkg/apis/pytorch/v1" + corev1 "k8s.io/api/core/v1" + volcanov1beta1 "volcano.sh/apis/pkg/apis/scheduling/v1beta1" +) + +func setPodLabel(obj interface{}, podTemplateSpec *corev1.PodTemplateSpec, rtype, index string) error { + pytorchjob, ok := obj.(*pytorchv1.PyTorchJob) + if !ok { + return fmt.Errorf("%+v is not a type of PyTorchJob", obj) + } + if len(podTemplateSpec.Labels) == 0 { + podTemplateSpec.Labels = make(map[string]string) + } + if pytorchjob.Spec.PyTorchReplicaSpecs[pytorchv1.PyTorchReplicaTypeMaster] != nil { + if rtype == strings.ToLower(string(pytorchv1.PyTorchReplicaTypeMaster)) { + podTemplateSpec.Labels[volcanov1beta1.PodPreemptable] = "false" + } else { + podTemplateSpec.Labels[volcanov1beta1.PodPreemptable] = "true" + } + } else { + // If the master is null, then we need to set the volcano.sh/preemptable = false to make sure that work0 can not be preempted + rank, err := strconv.Atoi(index) + if err != nil { + return err + } + if rank == 0 { + podTemplateSpec.Labels[volcanov1beta1.PodPreemptable] = "false" + } else { + podTemplateSpec.Labels[volcanov1beta1.PodPreemptable] = "true" + } + } + return nil +} diff --git a/pkg/controller.v1/pytorch/pytorchjob_controller.go b/pkg/controller.v1/pytorch/pytorchjob_controller.go index 6a523a70f2..d47e4e8d00 100644 --- a/pkg/controller.v1/pytorch/pytorchjob_controller.go +++ b/pkg/controller.v1/pytorch/pytorchjob_controller.go @@ -17,6 +17,7 @@ package pytorch import ( "context" "fmt" + "strings" "time" "github.com/go-logr/logr" @@ -47,6 +48,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/manager" "sigs.k8s.io/controller-runtime/pkg/predicate" "sigs.k8s.io/controller-runtime/pkg/source" + volcanov1beta1 "volcano.sh/apis/pkg/apis/scheduling/v1beta1" volcanoclient "volcano.sh/apis/pkg/client/clientset/versioned" pytorchv1 "github.com/kubeflow/training-operator/pkg/apis/pytorch/v1" @@ -109,6 +111,7 @@ type PyTorchJobReconciler struct { //+kubebuilder:rbac:groups=kubeflow.org,resources=pytorchjobs/finalizers,verbs=update //+kubebuilder:rbac:groups="",resources=pods,verbs=get;list;watch;create;update;patch;delete //+kubebuilder:rbac:groups="",resources=services,verbs=get;list;watch;create;delete +//+kubebuilder:rbac:groups=scheduling.volcano.sh,resources=podgroups,verbs=get;list;watch;create;delete // Reconcile is part of the main kubernetes reconciliation loop which aims to // move the current state of the cluster closer to the desired state. @@ -207,6 +210,18 @@ func (r *PyTorchJobReconciler) SetupWithManager(mgr ctrl.Manager) error { return err } + // inject watching for job related podgroup + if err = c.Watch(&source.Kind{Type: &volcanov1beta1.PodGroup{}}, &handler.EnqueueRequestForOwner{ + IsController: true, + OwnerType: &pytorchv1.PyTorchJob{}, + }, predicate.Funcs{ + CreateFunc: util.OnDependentCreateFunc(r.Expectations), + UpdateFunc: util.OnDependentUpdateFunc(&r.JobController), + DeleteFunc: util.OnDependentDeleteFunc(r.Expectations), + }); err != nil { + return err + } + return nil } @@ -389,8 +404,15 @@ func (r *PyTorchJobReconciler) UpdateJobStatus(job interface{}, } } else { if rtype == pytorchv1.PyTorchReplicaTypeWorker { - // TODO(gaocegege): Support SuccessPolicy - if expected == 0 { + worker0Completed, err := r.IsWorker0Completed(pytorchjob, replicas) + if err != nil { + logger.Warnf("check if worker 0 completed error %v", err) + return err + } + // Leave a succeeded condition for the following two cases: + // 1. If default success policy is used and worker 0 has completed. + // 2. If `SuccessPolicyAllWorkers` success policy is used and all workers are succeeded. + if expected == 0 || (worker0Completed && *pytorchjob.Spec.ElasticPolicy.SuccessPolicy != pytorchv1.SuccessPolicyAllWorkers) { msg := fmt.Sprintf("TFJob %s/%s successfully completed.", pytorchjob.Namespace, pytorchjob.Name) r.recorder.Event(pytorchjob, corev1.EventTypeNormal, commonutil.JobSucceededReason, msg) @@ -428,7 +450,7 @@ func (r *PyTorchJobReconciler) UpdateJobStatus(job interface{}, return err } trainingoperatorcommon.RestartedJobsCounterInc(pytorchjob.Namespace, pytorchv1.FrameworkName) - } else { + } else if running < *pytorchjob.Spec.ElasticPolicy.MinReplicas || *pytorchjob.Spec.ElasticPolicy.FailurePolicy == pytorchv1.FailurePolicyDefault { msg := fmt.Sprintf("PyTorchJob %s is failed because %d %s replica(s) failed.", pytorchjob.Name, failed, rtype) r.Recorder.Event(pytorchjob, corev1.EventTypeNormal, commonutil.JobFailedReason, msg) if pytorchjob.Status.CompletionTime == nil { @@ -442,12 +464,58 @@ func (r *PyTorchJobReconciler) UpdateJobStatus(job interface{}, } trainingoperatorcommon.FailedJobsCounterInc(pytorchjob.Namespace, pytorchv1.FrameworkName) } + // if running pods is greater or equal than MinReplicas, the job is running } } return nil } +// IsWorker0Completed returns true if pod of worker0 succeeded and qexited with 0 +func (p *PyTorchJobReconciler) IsWorker0Completed(job *pytorchv1.PyTorchJob, + replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec) (bool, error) { + worker0Completed := false + _, ok := replicas[pytorchv1.PyTorchReplicaTypeWorker] + if !ok { + return true, nil + } + podSlices, err := p.getPodSlices(job, replicas[pytorchv1.PyTorchReplicaTypeWorker].Replicas, string(pytorchv1.PyTorchReplicaTypeWorker)) + if err != nil { + return false, err + } + for index, podSlice := range podSlices { + if len(podSlice) == 1 { + pod := podSlice[0] + exitCode := util.GetContainerExitCode(pod, pytorchv1.DefaultContainerName) + if index == 0 && exitCode == 0 && pod.Status.Phase == corev1.PodSucceeded { + worker0Completed = true + } + } + } + return worker0Completed, nil +} + +// getPodSlices returns a slice, which element is the slice of pod. +// It gives enough information to caller to make decision to up/down scale resources. +func (p *PyTorchJobReconciler) getPodSlices(job *pytorchv1.PyTorchJob, replicasNum *int32, rtype string) ([][]*corev1.Pod, error) { + logger := commonutil.LoggerForReplica(job, strings.ToLower(rtype)) + + pods, err := p.GetPodsForJob(job) + if err != nil { + commonutil.LoggerForJob(job).Warnf("getPodsForTFJob error %v", err) + return nil, err + } + + // Get all pods for the type rt. + pods, err = p.JobController.FilterPodsForReplicaType(pods, strings.ToLower(rtype)) + if err != nil { + return nil, err + } + + podSlices := p.GetPodSlices(pods, int(*replicasNum), logger) + return podSlices, nil +} + // ContainsMasterSpec returns true if the tfjob contains master spec. func ContainsMasterSpec(replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec) bool { if _, ok := replicas[pytorchv1.PyTorchReplicaTypeMaster]; ok { @@ -489,6 +557,9 @@ func (r *PyTorchJobReconciler) UpdateJobStatusInApiServer(job interface{}, jobSt // SetClusterSpec sets the cluster spec and init container for the pod func (r *PyTorchJobReconciler) SetClusterSpec(job interface{}, podTemplate *corev1.PodTemplateSpec, rtype, index string) error { + if err := setPodLabel(job, podTemplate, rtype, index); err != nil { + return err + } if err := setPodEnv(job, podTemplate, rtype, index); err != nil { return err } diff --git a/pkg/controller.v1/pytorch/pytorchjob_controller_suite_test.go b/pkg/controller.v1/pytorch/pytorchjob_controller_suite_test.go index 208edbd657..861005b318 100644 --- a/pkg/controller.v1/pytorch/pytorchjob_controller_suite_test.go +++ b/pkg/controller.v1/pytorch/pytorchjob_controller_suite_test.go @@ -21,17 +21,18 @@ import ( v1 "github.com/kubeflow/training-operator/pkg/apis/pytorch/v1" "github.com/kubeflow/training-operator/pkg/config" - . "github.com/onsi/ginkgo" "github.com/onsi/gomega" . "github.com/onsi/gomega" - "k8s.io/client-go/kubernetes/scheme" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/runtime" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/envtest" "sigs.k8s.io/controller-runtime/pkg/envtest/printer" logf "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/log/zap" + volcanov1beta1 "volcano.sh/apis/pkg/apis/scheduling/v1beta1" //+kubebuilder:scaffold:imports ) @@ -68,7 +69,12 @@ var _ = BeforeSuite(func() { Expect(err).NotTo(HaveOccurred()) Expect(cfg).NotTo(BeNil()) - err = v1.AddToScheme(scheme.Scheme) + s := runtime.NewScheme() + err = corev1.AddToScheme(s) + Expect(err).NotTo(HaveOccurred()) + err = v1.AddToScheme(s) + Expect(err).NotTo(HaveOccurred()) + err = volcanov1beta1.AddToScheme(s) Expect(err).NotTo(HaveOccurred()) // Set default config. @@ -76,12 +82,12 @@ var _ = BeforeSuite(func() { config.Config.PyTorchInitContainerTemplateFile = config.PyTorchInitContainerTemplateFileDefault //+kubebuilder:scaffold:scheme - - testK8sClient, err = client.New(cfg, client.Options{Scheme: scheme.Scheme}) + testK8sClient, err = client.New(cfg, client.Options{Scheme: s}) Expect(err).NotTo(HaveOccurred()) Expect(testK8sClient).NotTo(BeNil()) mgr, err := ctrl.NewManager(cfg, ctrl.Options{ + Scheme: s, MetricsBindAddress: "0", }) Expect(err).NotTo(gomega.HaveOccurred())