From b3267916fd5e886805b4ec472c7e39c12b9ef397 Mon Sep 17 00:00:00 2001 From: Andrew Dye Date: Mon, 9 Oct 2023 20:18:38 -0700 Subject: [PATCH 1/6] Build SparkApplicationSpec using ToK8sPodSpec Signed-off-by: Andrew Dye --- .../flytek8s/non_interruptible.go | 92 ++++++ .../pluginmachinery/flytek8s/pod_helper.go | 9 + .../go/tasks/plugins/k8s/dask/dask.go | 50 +--- .../go/tasks/plugins/k8s/spark/spark.go | 281 ++++++++++-------- .../go/tasks/plugins/k8s/spark/spark_test.go | 145 +++++++-- 5 files changed, 384 insertions(+), 193 deletions(-) create mode 100644 flyteplugins/go/tasks/pluginmachinery/flytek8s/non_interruptible.go diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/non_interruptible.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/non_interruptible.go new file mode 100644 index 0000000000..daa00241bb --- /dev/null +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/non_interruptible.go @@ -0,0 +1,92 @@ +package flytek8s + +import ( + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" +) + +// Wraps a regular TaskExecutionMetadata and overrides the IsInterruptible method to always return false +// This is useful as the runner and the scheduler pods should never be interruptible +type NonInterruptibleTaskExecutionMetadata struct { + metadata pluginsCore.TaskExecutionMetadata +} + +func (n NonInterruptibleTaskExecutionMetadata) GetOwnerID() types.NamespacedName { + return n.metadata.GetOwnerID() +} + +func (n NonInterruptibleTaskExecutionMetadata) GetTaskExecutionID() pluginsCore.TaskExecutionID { + return n.metadata.GetTaskExecutionID() +} + +func (n NonInterruptibleTaskExecutionMetadata) GetNamespace() string { + return n.metadata.GetNamespace() +} + +func (n NonInterruptibleTaskExecutionMetadata) GetOwnerReference() metav1.OwnerReference { + 
return n.metadata.GetOwnerReference() +} + +func (n NonInterruptibleTaskExecutionMetadata) GetOverrides() pluginsCore.TaskOverrides { + return n.metadata.GetOverrides() +} + +func (n NonInterruptibleTaskExecutionMetadata) GetLabels() map[string]string { + return n.metadata.GetLabels() +} + +func (n NonInterruptibleTaskExecutionMetadata) GetMaxAttempts() uint32 { + return n.metadata.GetMaxAttempts() +} + +func (n NonInterruptibleTaskExecutionMetadata) GetAnnotations() map[string]string { + return n.metadata.GetAnnotations() +} + +func (n NonInterruptibleTaskExecutionMetadata) GetK8sServiceAccount() string { + return n.metadata.GetK8sServiceAccount() +} + +func (n NonInterruptibleTaskExecutionMetadata) GetSecurityContext() core.SecurityContext { + return n.metadata.GetSecurityContext() +} + +func (n NonInterruptibleTaskExecutionMetadata) GetPlatformResources() *v1.ResourceRequirements { + return n.metadata.GetPlatformResources() +} + +func (n NonInterruptibleTaskExecutionMetadata) GetInterruptibleFailureThreshold() int32 { + return n.metadata.GetInterruptibleFailureThreshold() +} + +func (n NonInterruptibleTaskExecutionMetadata) GetEnvironmentVariables() map[string]string { + return n.metadata.GetEnvironmentVariables() +} + +func (n NonInterruptibleTaskExecutionMetadata) IsInterruptible() bool { + return false +} + +// A wrapper around a regular TaskExecutionContext allowing to inject a custom TaskExecutionMetadata which is +// non-interruptible +type NonInterruptibleTaskExecutionContext struct { + pluginsCore.TaskExecutionContext + metadata NonInterruptibleTaskExecutionMetadata +} + +func (n NonInterruptibleTaskExecutionContext) TaskExecutionMetadata() pluginsCore.TaskExecutionMetadata { + return n.metadata +} + +func NewNonInterruptibleTaskExecutionContext(ctx pluginsCore.TaskExecutionContext) NonInterruptibleTaskExecutionContext { + return NonInterruptibleTaskExecutionContext{ + TaskExecutionContext: ctx, + metadata: NonInterruptibleTaskExecutionMetadata{ + 
metadata: ctx.TaskExecutionMetadata(), + }, + } +} diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go index 6ee8e41722..93de6a0f39 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go @@ -271,6 +271,15 @@ func ToK8sPodSpec(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (* return podSpec, objectMeta, primaryContainerName, nil } +func GetContainer(podSpec *v1.PodSpec, name string) (*v1.Container, error) { + for _, container := range podSpec.Containers { + if container.Name == name { + return &container, nil + } + } + return nil, pluginserrors.Errorf(pluginserrors.BadTaskSpecification, "invalid TaskSpecification, container [%s] not defined", name) +} + // getBasePodTemplate attempts to retrieve the PodTemplate to use as the base for k8s Pod configuration. This value can // come from one of the following: // (1) PodTemplate name in the TaskMetadata: This name is then looked up in the PodTemplateStore. 
diff --git a/flyteplugins/go/tasks/plugins/k8s/dask/dask.go b/flyteplugins/go/tasks/plugins/k8s/dask/dask.go index f8272b919a..65050f5bb2 100644 --- a/flyteplugins/go/tasks/plugins/k8s/dask/dask.go +++ b/flyteplugins/go/tasks/plugins/k8s/dask/dask.go @@ -6,6 +6,12 @@ import ( "time" daskAPI "github.com/dask/dask-kubernetes/v2023/dask_kubernetes/operator/go_client/pkg/apis/kubernetes.dask.org/v1" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/intstr" + "k8s.io/client-go/kubernetes/scheme" + "sigs.k8s.io/controller-runtime/pkg/client" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" "github.com/flyteorg/flyte/flyteplugins/go/tasks/errors" "github.com/flyteorg/flyte/flyteplugins/go/tasks/logs" @@ -15,11 +21,6 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" - v1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/util/intstr" - "k8s.io/client-go/kubernetes/scheme" - "sigs.k8s.io/controller-runtime/pkg/client" ) const ( @@ -27,42 +28,12 @@ const ( KindDaskJob = "DaskJob" ) -// Wraps a regular TaskExecutionMetadata and overrides the IsInterruptible method to always return false -// This is useful as the runner and the scheduler pods should never be interruptible -type nonInterruptibleTaskExecutionMetadata struct { - pluginsCore.TaskExecutionMetadata -} - -func (n nonInterruptibleTaskExecutionMetadata) IsInterruptible() bool { - return false -} - -// A wrapper around a regular TaskExecutionContext allowing to inject a custom TaskExecutionMetadata which is -// non-interruptible -type nonInterruptibleTaskExecutionContext struct { - pluginsCore.TaskExecutionContext - metadata nonInterruptibleTaskExecutionMetadata -} - -func (n nonInterruptibleTaskExecutionContext) 
TaskExecutionMetadata() pluginsCore.TaskExecutionMetadata { - return n.metadata -} - func mergeMapInto(src map[string]string, dst map[string]string) { for key, value := range src { dst[key] = value } } -func getPrimaryContainer(spec *v1.PodSpec, primaryContainerName string) (*v1.Container, error) { - for _, container := range spec.Containers { - if container.Name == primaryContainerName { - return &container, nil - } - } - return nil, errors.Errorf(errors.BadTaskSpecification, "primary container [%v] not found in pod spec", primaryContainerName) -} - func replacePrimaryContainer(spec *v1.PodSpec, primaryContainerName string, container v1.Container) error { for i, c := range spec.Containers { if c.Name == primaryContainerName { @@ -104,8 +75,7 @@ func (p daskResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC if err != nil { return nil, err } - nonInterruptibleTaskMetadata := nonInterruptibleTaskExecutionMetadata{taskCtx.TaskExecutionMetadata()} - nonInterruptibleTaskCtx := nonInterruptibleTaskExecutionContext{taskCtx, nonInterruptibleTaskMetadata} + nonInterruptibleTaskCtx := flytek8s.NewNonInterruptibleTaskExecutionContext(taskCtx) nonInterruptiblePodSpec, _, _, err := flytek8s.ToK8sPodSpec(ctx, nonInterruptibleTaskCtx) if err != nil { return nil, err @@ -144,7 +114,7 @@ func (p daskResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC func createWorkerSpec(cluster plugins.DaskWorkerGroup, podSpec *v1.PodSpec, primaryContainerName string) (*daskAPI.WorkerSpec, error) { workerPodSpec := podSpec.DeepCopy() - primaryContainer, err := getPrimaryContainer(workerPodSpec, primaryContainerName) + primaryContainer, err := flytek8s.GetContainer(workerPodSpec, primaryContainerName) if err != nil { return nil, err } @@ -206,7 +176,7 @@ func createWorkerSpec(cluster plugins.DaskWorkerGroup, podSpec *v1.PodSpec, prim func createSchedulerSpec(scheduler plugins.DaskScheduler, clusterName string, podSpec *v1.PodSpec, primaryContainerName string) 
(*daskAPI.SchedulerSpec, error) { schedulerPodSpec := podSpec.DeepCopy() - primaryContainer, err := getPrimaryContainer(schedulerPodSpec, primaryContainerName) + primaryContainer, err := flytek8s.GetContainer(schedulerPodSpec, primaryContainerName) if err != nil { return nil, err } @@ -283,7 +253,7 @@ func createJobSpec(workerSpec daskAPI.WorkerSpec, schedulerSpec daskAPI.Schedule jobPodSpec := podSpec.DeepCopy() jobPodSpec.RestartPolicy = v1.RestartPolicyNever - primaryContainer, err := getPrimaryContainer(jobPodSpec, primaryContainerName) + primaryContainer, err := flytek8s.GetContainer(jobPodSpec, primaryContainerName) if err != nil { return nil, err } diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go index 7c7ba34f9a..8365c023aa 100644 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go @@ -3,35 +3,29 @@ package spark import ( "context" "fmt" - - "sigs.k8s.io/controller-runtime/pkg/client" - + "regexp" "strconv" + "strings" + "time" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" - - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/template" - - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" + sparkOp "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2" + sparkOpConfig "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/config" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes/scheme" + "sigs.k8s.io/controller-runtime/pkg/client" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" "github.com/flyteorg/flyte/flyteplugins/go/tasks/errors" 
"github.com/flyteorg/flyte/flyteplugins/go/tasks/logs" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery" pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" - + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" - "k8s.io/client-go/kubernetes/scheme" - - sparkOp "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2" - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - - "regexp" - "strings" - "time" ) const KindSparkApplication = "SparkApplication" @@ -80,70 +74,20 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo return nil, errors.Wrapf(errors.BadTaskSpecification, err, "invalid TaskSpecification [%v].", taskTemplate.GetCustom()) } - annotations := utils.UnionMaps(config.GetK8sPluginConfig().DefaultAnnotations, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations())) - labels := utils.UnionMaps(config.GetK8sPluginConfig().DefaultLabels, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels())) - container := taskTemplate.GetContainer() - - envVars := flytek8s.DecorateEnvVars(ctx, flytek8s.ToK8sEnvVar(container.GetEnv()), - taskCtx.TaskExecutionMetadata().GetEnvironmentVariables(), taskCtx.TaskExecutionMetadata().GetTaskExecutionID()) - - sparkEnvVars := make(map[string]string) - for _, envVar := range envVars { - sparkEnvVars[envVar.Name] = envVar.Value - } - - sparkEnvVars["FLYTE_MAX_ATTEMPTS"] = strconv.Itoa(int(taskCtx.TaskExecutionMetadata().GetMaxAttempts())) - - serviceAccountName := 
flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()) - - if len(serviceAccountName) == 0 { - serviceAccountName = sparkTaskType - } - driverSpec := sparkOp.DriverSpec{ - SparkPodSpec: sparkOp.SparkPodSpec{ - Affinity: config.GetK8sPluginConfig().DefaultAffinity, - Annotations: annotations, - Labels: labels, - EnvVars: sparkEnvVars, - Image: &container.Image, - SecurityContenxt: config.GetK8sPluginConfig().DefaultPodSecurityContext.DeepCopy(), - DNSConfig: config.GetK8sPluginConfig().DefaultPodDNSConfig.DeepCopy(), - Tolerations: config.GetK8sPluginConfig().DefaultTolerations, - SchedulerName: &config.GetK8sPluginConfig().SchedulerName, - NodeSelector: config.GetK8sPluginConfig().DefaultNodeSelector, - HostNetwork: config.GetK8sPluginConfig().EnableHostNetworkingPod, - }, - ServiceAccount: &serviceAccountName, - } - - executorSpec := sparkOp.ExecutorSpec{ - SparkPodSpec: sparkOp.SparkPodSpec{ - Affinity: config.GetK8sPluginConfig().DefaultAffinity.DeepCopy(), - Annotations: annotations, - Labels: labels, - Image: &container.Image, - EnvVars: sparkEnvVars, - SecurityContenxt: config.GetK8sPluginConfig().DefaultPodSecurityContext.DeepCopy(), - DNSConfig: config.GetK8sPluginConfig().DefaultPodDNSConfig.DeepCopy(), - Tolerations: config.GetK8sPluginConfig().DefaultTolerations, - SchedulerName: &config.GetK8sPluginConfig().SchedulerName, - NodeSelector: config.GetK8sPluginConfig().DefaultNodeSelector, - HostNetwork: config.GetK8sPluginConfig().EnableHostNetworkingPod, - }, + sparkConfig := getSparkConfig(taskCtx, &sparkJob) + driverSpec, err := createDriverSpec(ctx, taskCtx, sparkConfig) + if err != nil { + return nil, err } - - modifiedArgs, err := template.Render(ctx, container.GetArgs(), template.Parameters{ - TaskExecMetadata: taskCtx.TaskExecutionMetadata(), - Inputs: taskCtx.InputReader(), - OutputPath: taskCtx.OutputWriter(), - Task: taskCtx.TaskReader(), - }) + executorSpec, err := createExecutorSpec(ctx, taskCtx, sparkConfig) 
if err != nil { return nil, err } + app := createSparkApplication(&sparkJob, sparkConfig, driverSpec, executorSpec) + return app, nil +} - // Hack: Retry submit failures in-case of resource limits hit. - submissionFailureRetries := int32(14) +func getSparkConfig(taskCtx pluginsCore.TaskExecutionContext, sparkJob *plugins.SparkJob) map[string]string { // Start with default config values. sparkConfig := make(map[string]string) for k, v := range GetSparkConfig().DefaultSparkConfig { @@ -165,57 +109,158 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo } // Set pod limits. - if len(sparkConfig["spark.kubernetes.driver.limit.cores"]) == 0 { + if len(sparkConfig[sparkOpConfig.SparkDriverCoreLimitKey]) == 0 { // spark.kubernetes.driver.request.cores takes precedence over spark.driver.cores - if len(sparkConfig["spark.kubernetes.driver.request.cores"]) != 0 { - sparkConfig["spark.kubernetes.driver.limit.cores"] = sparkConfig["spark.kubernetes.driver.request.cores"] + if len(sparkConfig[sparkOpConfig.SparkDriverCoreRequestKey]) != 0 { + sparkConfig[sparkOpConfig.SparkDriverCoreLimitKey] = sparkConfig[sparkOpConfig.SparkDriverCoreRequestKey] } else if len(sparkConfig["spark.driver.cores"]) != 0 { - sparkConfig["spark.kubernetes.driver.limit.cores"] = sparkConfig["spark.driver.cores"] + sparkConfig[sparkOpConfig.SparkDriverCoreLimitKey] = sparkConfig["spark.driver.cores"] } } - if len(sparkConfig["spark.kubernetes.executor.limit.cores"]) == 0 { + if len(sparkConfig[sparkOpConfig.SparkExecutorCoreLimitKey]) == 0 { // spark.kubernetes.executor.request.cores takes precedence over spark.executor.cores - if len(sparkConfig["spark.kubernetes.executor.request.cores"]) != 0 { - sparkConfig["spark.kubernetes.executor.limit.cores"] = sparkConfig["spark.kubernetes.executor.request.cores"] + if len(sparkConfig[sparkOpConfig.SparkExecutorCoreRequestKey]) != 0 { + sparkConfig[sparkOpConfig.SparkExecutorCoreLimitKey] = 
sparkConfig[sparkOpConfig.SparkExecutorCoreRequestKey] } else if len(sparkConfig["spark.executor.cores"]) != 0 { - sparkConfig["spark.kubernetes.executor.limit.cores"] = sparkConfig["spark.executor.cores"] + sparkConfig[sparkOpConfig.SparkExecutorCoreLimitKey] = sparkConfig["spark.executor.cores"] } } sparkConfig["spark.kubernetes.executor.podNamePrefix"] = taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() sparkConfig["spark.kubernetes.driverEnv.FLYTE_START_TIME"] = strconv.FormatInt(time.Now().UnixNano()/1000000, 10) - // Add driver/executor defaults to CRD Driver/Executor Spec as well. - cores, err := strconv.ParseInt(sparkConfig["spark.driver.cores"], 10, 32) - if err == nil { - driverSpec.Cores = intPtr(int32(cores)) + return sparkConfig +} + +func serviceAccountName(metadata pluginsCore.TaskExecutionMetadata) string { + name := flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(metadata) + if len(name) == 0 { + name = sparkTaskType } - driverSpec.Memory = strPtr(sparkConfig["spark.driver.memory"]) + return name +} - execCores, err := strconv.ParseInt(sparkConfig["spark.executor.cores"], 10, 32) - if err == nil { - executorSpec.Cores = intPtr(int32(execCores)) +func createSparkPodSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, podSpec *v1.PodSpec, container *v1.Container) ( + *sparkOp.SparkPodSpec, error) { + annotations := utils.UnionMaps(config.GetK8sPluginConfig().DefaultAnnotations, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations())) + labels := utils.UnionMaps(config.GetK8sPluginConfig().DefaultLabels, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels())) + + sparkEnvVars := make(map[string]string) + for _, envVar := range container.Env { + sparkEnvVars[envVar.Name] = envVar.Value } + sparkEnvVars["FLYTE_MAX_ATTEMPTS"] = strconv.Itoa(int(taskCtx.TaskExecutionMetadata().GetMaxAttempts())) + + spec := sparkOp.SparkPodSpec{ + Affinity: podSpec.Affinity, + Annotations: annotations, + Labels: 
labels, + EnvVars: sparkEnvVars, + Image: &container.Image, + SecurityContenxt: podSpec.SecurityContext.DeepCopy(), + DNSConfig: podSpec.DNSConfig.DeepCopy(), + Tolerations: podSpec.Tolerations, + SchedulerName: &config.GetK8sPluginConfig().SchedulerName, + NodeSelector: podSpec.NodeSelector, + HostNetwork: config.GetK8sPluginConfig().EnableHostNetworkingPod, + } + return &spec, nil +} - execCount, err := strconv.ParseInt(sparkConfig["spark.executor.instances"], 10, 32) - if err == nil { - executorSpec.Instances = intPtr(int32(execCount)) +type driverSpec struct { + podSpec *v1.PodSpec + container *v1.Container + sparkSpec *sparkOp.DriverSpec +} + +func createDriverSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, sparkConfig map[string]string) (*driverSpec, error) { + // Spark driver pods should always run as non-interruptible + nonInterruptibleTaskCtx := flytek8s.NewNonInterruptibleTaskExecutionContext(taskCtx) + podSpec, _, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, nonInterruptibleTaskCtx) + if err != nil { + return nil, err + } + primaryContainer, err := flytek8s.GetContainer(podSpec, primaryContainerName) + if err != nil { + return nil, err + } + sparkPodSpec, err := createSparkPodSpec(ctx, nonInterruptibleTaskCtx, podSpec, primaryContainer) + if err != nil { + return nil, err } - executorSpec.Memory = strPtr(sparkConfig["spark.executor.memory"]) + serviceAccountName := serviceAccountName(nonInterruptibleTaskCtx.TaskExecutionMetadata()) + spec := driverSpec{ + podSpec, + primaryContainer, + &sparkOp.DriverSpec{ + SparkPodSpec: *sparkPodSpec, + ServiceAccount: &serviceAccountName, + }, + } + if cores, err := strconv.ParseInt(sparkConfig["spark.driver.cores"], 10, 32); err == nil { + spec.sparkSpec.Cores = intPtr(int32(cores)) + } + spec.sparkSpec.Memory = strPtr(sparkConfig["spark.driver.memory"]) + return &spec, nil +} - j := &sparkOp.SparkApplication{ +type executorSpec struct { + podSpec *v1.PodSpec + container *v1.Container + 
sparkSpec *sparkOp.ExecutorSpec + serviceAccountName string +} + +func createExecutorSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, sparkConfig map[string]string) (*executorSpec, error) { + podSpec, _, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx) + if err != nil { + return nil, err + } + primaryContainer, err := flytek8s.GetContainer(podSpec, primaryContainerName) + if err != nil { + return nil, err + } + sparkPodSpec, err := createSparkPodSpec(ctx, taskCtx, podSpec, primaryContainer) + if err != nil { + return nil, err + } + serviceAccountName := serviceAccountName(taskCtx.TaskExecutionMetadata()) + spec := executorSpec{ + podSpec, + primaryContainer, + &sparkOp.ExecutorSpec{ + SparkPodSpec: *sparkPodSpec, + }, + serviceAccountName, + } + if execCores, err := strconv.ParseInt(sparkConfig["spark.executor.cores"], 10, 32); err == nil { + spec.sparkSpec.Cores = intPtr(int32(execCores)) + } + if execCount, err := strconv.ParseInt(sparkConfig["spark.executor.instances"], 10, 32); err == nil { + spec.sparkSpec.Instances = intPtr(int32(execCount)) + } + spec.sparkSpec.Memory = strPtr(sparkConfig["spark.executor.memory"]) + return &spec, nil +} + +func createSparkApplication(sparkJob *plugins.SparkJob, sparkConfig map[string]string, driverSpec *driverSpec, + executorSpec *executorSpec) *sparkOp.SparkApplication { + // Hack: Retry submit failures in-case of resource limits hit. 
+ submissionFailureRetries := int32(14) + + app := &sparkOp.SparkApplication{ TypeMeta: metav1.TypeMeta{ Kind: KindSparkApplication, APIVersion: sparkOp.SchemeGroupVersion.String(), }, Spec: sparkOp.SparkApplicationSpec{ - ServiceAccount: &serviceAccountName, + ServiceAccount: &executorSpec.serviceAccountName, Type: getApplicationType(sparkJob.GetApplicationType()), - Image: &container.Image, - Arguments: modifiedArgs, - Driver: driverSpec, - Executor: executorSpec, + Image: &executorSpec.container.Image, + Arguments: executorSpec.container.Args, + Driver: *driverSpec.sparkSpec, + Executor: *executorSpec.sparkSpec, SparkConf: sparkConfig, HadoopConf: sparkJob.GetHadoopConf(), // SubmissionFailures handled here. Task Failures handled at Propeller/Job level. @@ -227,32 +272,16 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo } if val, ok := sparkConfig["spark.batchScheduler"]; ok { - j.Spec.BatchScheduler = &val + app.Spec.BatchScheduler = &val } if sparkJob.MainApplicationFile != "" { - j.Spec.MainApplicationFile = &sparkJob.MainApplicationFile + app.Spec.MainApplicationFile = &sparkJob.MainApplicationFile } if sparkJob.MainClass != "" { - j.Spec.MainClass = &sparkJob.MainClass - } - - // Spark driver pods should always run as non-interruptible. As such, we hardcode - // `interruptible=false` to explicitly add non-interruptible node selector - // requirements to the driver pods - flytek8s.ApplyInterruptibleNodeSelectorRequirement(false, j.Spec.Driver.Affinity) - - // Add Interruptible Tolerations/NodeSelector to only Executor pods. - // The Interruptible NodeSelector takes precedence over the DefaultNodeSelector - if taskCtx.TaskExecutionMetadata().IsInterruptible() { - j.Spec.Executor.Tolerations = append(j.Spec.Executor.Tolerations, config.GetK8sPluginConfig().InterruptibleTolerations...) 
- j.Spec.Executor.NodeSelector = config.GetK8sPluginConfig().InterruptibleNodeSelector + app.Spec.MainClass = &sparkJob.MainClass } - - // Add interruptible/non-interruptible node selector requirements to executor pod - flytek8s.ApplyInterruptibleNodeSelectorRequirement(taskCtx.TaskExecutionMetadata().IsInterruptible(), j.Spec.Executor.Affinity) - - return j, nil + return app } func addConfig(sparkConfig map[string]string, key string, value string) { diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go index e981c0dce1..13bf1153ac 100644 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go @@ -6,28 +6,24 @@ import ( "strconv" "testing" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" - - "github.com/stretchr/testify/mock" - - "github.com/flyteorg/flyte/flyteplugins/go/tasks/logs" - - pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" - - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" - - pluginIOMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks" - sj "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2" - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" "github.com/golang/protobuf/jsonpb" structpb "github.com/golang/protobuf/ptypes/struct" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" corev1 "k8s.io/api/core/v1" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" + 
"github.com/flyteorg/flyte/flyteplugins/go/tasks/logs" + pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" + pluginIOMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" ) const sparkMainClass = "MainClass" @@ -87,7 +83,7 @@ func TestGetEventInfo(t *testing.T) { }, }, })) - taskCtx := dummySparkTaskContext(dummySparkTaskTemplate("blah-1", dummySparkConf), false) + taskCtx := dummySparkTaskContext(dummySparkTaskTemplateContainer("blah-1", dummySparkConf), false) info, err := getEventInfoForSpark(taskCtx, dummySparkApplication(sj.RunningState)) assert.NoError(t, err) assert.Len(t, info.Logs, 6) @@ -157,7 +153,7 @@ func TestGetTaskPhase(t *testing.T) { sparkResourceHandler := sparkResourceHandler{} ctx := context.TODO() - taskCtx := dummySparkTaskContext(dummySparkTaskTemplate("", dummySparkConf), false) + taskCtx := dummySparkTaskContext(dummySparkTaskTemplateContainer("", dummySparkConf), false) taskPhase, err := sparkResourceHandler.GetTaskPhase(ctx, taskCtx, dummySparkApplication(sj.NewState)) assert.NoError(t, err) assert.Equal(t, taskPhase.Phase(), pluginsCore.PhaseQueued) @@ -250,8 +246,20 @@ func dummySparkCustomObj(sparkConf map[string]string) *plugins.SparkJob { return &sparkJob } -func dummySparkTaskTemplate(id string, sparkConf map[string]string) *core.TaskTemplate { +func dummyPodSpec() *corev1.PodSpec { + return &corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "primary", + Image: testImage, + Args: testArgs, + Env: flytek8s.ToK8sEnvVar(dummyEnvVars), + }, + }, + } +} +func 
dummySparkTaskTemplateContainer(id string, sparkConf map[string]string) *core.TaskTemplate { sparkJob := dummySparkCustomObj(sparkConf) sparkJobJSON, err := utils.MarshalToString(sparkJob) if err != nil { @@ -279,6 +287,40 @@ func dummySparkTaskTemplate(id string, sparkConf map[string]string) *core.TaskTe } } +func dummySparkTaskTemplatePod(id string, sparkConf map[string]string, podSpec *corev1.PodSpec) *core.TaskTemplate { + sparkJob := dummySparkCustomObj(sparkConf) + sparkJobJSON, err := utils.MarshalToString(sparkJob) + if err != nil { + panic(err) + } + + structObj := structpb.Struct{} + + err = jsonpb.UnmarshalString(sparkJobJSON, &structObj) + if err != nil { + panic(err) + } + + podSpecPb, err := utils.MarshalObjToStruct(podSpec) + if err != nil { + panic(err) + } + + return &core.TaskTemplate{ + Id: &core.Identifier{Name: id}, + Type: "k8s_pod", + Target: &core.TaskTemplate_K8SPod{ + K8SPod: &core.K8SPod{ + PodSpec: podSpecPb, + }, + }, + Config: map[string]string{ + flytek8s.PrimaryContainerKey: "primary", + }, + Custom: &structObj, + } +} + func dummySparkTaskContext(taskTemplate *core.TaskTemplate, interruptible bool) pluginsCore.TaskExecutionContext { taskCtx := &mocks.TaskExecutionContext{} inputReader := &pluginIOMocks.InputReader{} @@ -312,6 +354,9 @@ func dummySparkTaskContext(taskTemplate *core.TaskTemplate, interruptible bool) }) tID.On("GetGeneratedName").Return("some-acceptable-name") + overrides := &mocks.TaskOverrides{} + overrides.On("GetResources").Return(&corev1.ResourceRequirements{}) + taskExecutionMetadata := &mocks.TaskExecutionMetadata{} taskExecutionMetadata.On("GetTaskExecutionID").Return(tID) taskExecutionMetadata.On("GetNamespace").Return("test-namespace") @@ -327,6 +372,9 @@ func dummySparkTaskContext(taskTemplate *core.TaskTemplate, interruptible bool) taskExecutionMetadata.On("IsInterruptible").Return(interruptible) taskExecutionMetadata.On("GetMaxAttempts").Return(uint32(1)) 
taskExecutionMetadata.On("GetEnvironmentVariables").Return(nil) + taskExecutionMetadata.On("GetPlatformResources").Return(nil) + taskExecutionMetadata.On("GetOverrides").Return(overrides) + taskExecutionMetadata.On("GetK8sServiceAccount").Return("new-val") taskCtx.On("TaskExecutionMetadata").Return(taskExecutionMetadata) return taskCtx } @@ -335,7 +383,7 @@ func TestBuildResourceSpark(t *testing.T) { sparkResourceHandler := sparkResourceHandler{} // Case1: Valid Spark Task-Template - taskTemplate := dummySparkTaskTemplate("blah-1", dummySparkConf) + taskTemplate := dummySparkTaskTemplateContainer("blah-1", dummySparkConf) // Set spark custom feature config. assert.NoError(t, setSparkConfig(&Config{ @@ -522,6 +570,7 @@ func TestBuildResourceSpark(t *testing.T) { // Validate // * Interruptible Toleration and NodeSelector set for Executor but not Driver. + // TODO: confirm expected behavior // * Validate Default NodeSelector set for Driver but overwritten with Interruptible NodeSelector for Executor. // * Default Tolerations set for both Driver and Executor. // * Interruptible/Non-Interruptible NodeSelectorRequirements set for Executor Affinity but not Driver Affinity. 
@@ -538,17 +587,18 @@ func TestBuildResourceSpark(t *testing.T) { assert.Equal(t, 1, len(sparkApp.Spec.Executor.NodeSelector)) assert.Equal(t, interruptibleNodeSelector, sparkApp.Spec.Executor.NodeSelector) - tolExecDefault := sparkApp.Spec.Executor.Tolerations[0] + tolExecInterrupt := sparkApp.Spec.Executor.Tolerations[0] + assert.Equal(t, tolExecInterrupt.Key, "x/flyte") + assert.Equal(t, tolExecInterrupt.Value, "interruptible") + assert.Equal(t, tolExecInterrupt.Operator, corev1.TolerationOperator("Equal")) + assert.Equal(t, tolExecInterrupt.Effect, corev1.TaintEffect("NoSchedule")) + + tolExecDefault := sparkApp.Spec.Executor.Tolerations[1] assert.Equal(t, tolExecDefault.Key, "x/flyte") assert.Equal(t, tolExecDefault.Value, "default") assert.Equal(t, tolExecDefault.Operator, corev1.TolerationOperator("Equal")) assert.Equal(t, tolExecDefault.Effect, corev1.TaintEffect("NoSchedule")) - tolExecInterrupt := sparkApp.Spec.Executor.Tolerations[1] - assert.Equal(t, tolExecInterrupt.Key, "x/flyte") - assert.Equal(t, tolExecInterrupt.Value, "interruptible") - assert.Equal(t, tolExecInterrupt.Operator, corev1.TolerationOperator("Equal")) - assert.Equal(t, tolExecInterrupt.Effect, corev1.TaintEffect("NoSchedule")) assert.Equal(t, "true", sparkApp.Spec.Executor.NodeSelector["x/interruptible"]) for confKey, confVal := range dummySparkConf { @@ -619,7 +669,7 @@ func TestBuildResourceSpark(t *testing.T) { dummyConfWithRequest["spark.kubernetes.driver.request.cores"] = "3" dummyConfWithRequest["spark.kubernetes.executor.request.cores"] = "4" - taskTemplate = dummySparkTaskTemplate("blah-1", dummyConfWithRequest) + taskTemplate = dummySparkTaskTemplateContainer("blah-1", dummyConfWithRequest) resource, err = sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, false)) assert.Nil(t, err) assert.NotNil(t, resource) @@ -678,6 +728,47 @@ func TestBuildResourceSpark(t *testing.T) { assert.Nil(t, resource) } +func TestBuildResourcePodTemplate(t 
*testing.T) { + defaultToleration := corev1.Toleration{ + + Key: "x/flyte", + Value: "default", + Operator: "Equal", + } + err := config.SetK8sPluginConfig(&config.K8sPluginConfig{ + DefaultTolerations: []corev1.Toleration{ + defaultToleration, + }, + }) + assert.NoError(t, err) + extraToleration := corev1.Toleration{ + Key: "x/flyte", + Value: "extra", + Operator: "Equal", + } + podSpec := dummyPodSpec() + podSpec.Tolerations = append(podSpec.Tolerations, extraToleration) + taskTemplate := dummySparkTaskTemplatePod("blah-1", dummySparkConf, podSpec) + taskTemplate.GetK8SPod() + sparkResourceHandler := sparkResourceHandler{} + resource, err := sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, false)) + assert.Nil(t, err) + + assert.NotNil(t, resource) + sparkApp, ok := resource.(*sj.SparkApplication) + assert.True(t, ok) + assert.Equal(t, 2, len(sparkApp.Spec.Driver.Tolerations)) + assert.Equal(t, sparkApp.Spec.Driver.Tolerations, []corev1.Toleration{ + defaultToleration, + extraToleration, + }) + assert.Equal(t, 2, len(sparkApp.Spec.Executor.Tolerations)) + assert.Equal(t, sparkApp.Spec.Executor.Tolerations, []corev1.Toleration{ + defaultToleration, + extraToleration, + }) +} + func TestGetPropertiesSpark(t *testing.T) { sparkResourceHandler := sparkResourceHandler{} expected := k8s.PluginProperties{} From 072cda812cf02bcc5b5b5076df5f834296225dee Mon Sep 17 00:00:00 2001 From: Andrew Dye Date: Tue, 10 Oct 2023 09:54:44 -0700 Subject: [PATCH 2/6] Comments Signed-off-by: Andrew Dye --- .../flytek8s/non_interruptible.go | 61 +------------------ .../go/tasks/plugins/k8s/spark/spark.go | 10 +-- 2 files changed, 4 insertions(+), 67 deletions(-) diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/non_interruptible.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/non_interruptible.go index daa00241bb..d2f5042cf8 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/non_interruptible.go +++ 
b/flyteplugins/go/tasks/pluginmachinery/flytek8s/non_interruptible.go @@ -1,70 +1,13 @@ package flytek8s import ( - v1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/types" - - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" ) // Wraps a regular TaskExecutionMetadata and overrides the IsInterruptible method to always return false // This is useful as the runner and the scheduler pods should never be interruptible type NonInterruptibleTaskExecutionMetadata struct { - metadata pluginsCore.TaskExecutionMetadata -} - -func (n NonInterruptibleTaskExecutionMetadata) GetOwnerID() types.NamespacedName { - return n.metadata.GetOwnerID() -} - -func (n NonInterruptibleTaskExecutionMetadata) GetTaskExecutionID() pluginsCore.TaskExecutionID { - return n.metadata.GetTaskExecutionID() -} - -func (n NonInterruptibleTaskExecutionMetadata) GetNamespace() string { - return n.metadata.GetNamespace() -} - -func (n NonInterruptibleTaskExecutionMetadata) GetOwnerReference() metav1.OwnerReference { - return n.metadata.GetOwnerReference() -} - -func (n NonInterruptibleTaskExecutionMetadata) GetOverrides() pluginsCore.TaskOverrides { - return n.metadata.GetOverrides() -} - -func (n NonInterruptibleTaskExecutionMetadata) GetLabels() map[string]string { - return n.metadata.GetLabels() -} - -func (n NonInterruptibleTaskExecutionMetadata) GetMaxAttempts() uint32 { - return n.metadata.GetMaxAttempts() -} - -func (n NonInterruptibleTaskExecutionMetadata) GetAnnotations() map[string]string { - return n.metadata.GetAnnotations() -} - -func (n NonInterruptibleTaskExecutionMetadata) GetK8sServiceAccount() string { - return n.metadata.GetK8sServiceAccount() -} - -func (n NonInterruptibleTaskExecutionMetadata) GetSecurityContext() core.SecurityContext { - return n.metadata.GetSecurityContext() -} - -func (n NonInterruptibleTaskExecutionMetadata) 
GetPlatformResources() *v1.ResourceRequirements { - return n.metadata.GetPlatformResources() -} - -func (n NonInterruptibleTaskExecutionMetadata) GetInterruptibleFailureThreshold() int32 { - return n.metadata.GetInterruptibleFailureThreshold() -} - -func (n NonInterruptibleTaskExecutionMetadata) GetEnvironmentVariables() map[string]string { - return n.metadata.GetEnvironmentVariables() + pluginsCore.TaskExecutionMetadata } func (n NonInterruptibleTaskExecutionMetadata) IsInterruptible() bool { @@ -86,7 +29,7 @@ func NewNonInterruptibleTaskExecutionContext(ctx pluginsCore.TaskExecutionContex return NonInterruptibleTaskExecutionContext{ TaskExecutionContext: ctx, metadata: NonInterruptibleTaskExecutionMetadata{ - metadata: ctx.TaskExecutionMetadata(), + ctx.TaskExecutionMetadata(), }, } } diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go index 8365c023aa..7f791eda29 100644 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go @@ -161,16 +161,14 @@ func createSparkPodSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionCo SecurityContenxt: podSpec.SecurityContext.DeepCopy(), DNSConfig: podSpec.DNSConfig.DeepCopy(), Tolerations: podSpec.Tolerations, - SchedulerName: &config.GetK8sPluginConfig().SchedulerName, + SchedulerName: &podSpec.SchedulerName, NodeSelector: podSpec.NodeSelector, - HostNetwork: config.GetK8sPluginConfig().EnableHostNetworkingPod, + HostNetwork: &podSpec.HostNetwork, } return &spec, nil } type driverSpec struct { - podSpec *v1.PodSpec - container *v1.Container sparkSpec *sparkOp.DriverSpec } @@ -191,8 +189,6 @@ func createDriverSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionCont } serviceAccountName := serviceAccountName(nonInterruptibleTaskCtx.TaskExecutionMetadata()) spec := driverSpec{ - podSpec, - primaryContainer, &sparkOp.DriverSpec{ SparkPodSpec: *sparkPodSpec, ServiceAccount: &serviceAccountName, @@ -206,7 
+202,6 @@ func createDriverSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionCont } type executorSpec struct { - podSpec *v1.PodSpec container *v1.Container sparkSpec *sparkOp.ExecutorSpec serviceAccountName string @@ -227,7 +222,6 @@ func createExecutorSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionCo } serviceAccountName := serviceAccountName(taskCtx.TaskExecutionMetadata()) spec := executorSpec{ - podSpec, primaryContainer, &sparkOp.ExecutorSpec{ SparkPodSpec: *sparkPodSpec, From 6fc00ea6a8af7dab7f9a29364db224b54266a2b1 Mon Sep 17 00:00:00 2001 From: Andrew Dye Date: Tue, 10 Oct 2023 10:37:33 -0700 Subject: [PATCH 3/6] Expect merged interruptible node selectors Signed-off-by: Andrew Dye --- flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go index 13bf1153ac..25f1b0d575 100644 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go @@ -570,7 +570,6 @@ func TestBuildResourceSpark(t *testing.T) { // Validate // * Interruptible Toleration and NodeSelector set for Executor but not Driver. - // TODO: confirm expected behavior // * Validate Default NodeSelector set for Driver but overwritten with Interruptible NodeSelector for Executor. // * Default Tolerations set for both Driver and Executor. // * Interruptible/Non-Interruptible NodeSelectorRequirements set for Executor Affinity but not Driver Affinity. 
@@ -584,8 +583,11 @@ func TestBuildResourceSpark(t *testing.T) { assert.Equal(t, tolDriverDefault.Effect, corev1.TaintEffect("NoSchedule")) assert.Equal(t, 2, len(sparkApp.Spec.Executor.Tolerations)) - assert.Equal(t, 1, len(sparkApp.Spec.Executor.NodeSelector)) - assert.Equal(t, interruptibleNodeSelector, sparkApp.Spec.Executor.NodeSelector) + assert.Equal(t, 2, len(sparkApp.Spec.Executor.NodeSelector)) + assert.Equal(t, map[string]string{ + "x/default": "true", + "x/interruptible": "true", + }, sparkApp.Spec.Executor.NodeSelector) tolExecInterrupt := sparkApp.Spec.Executor.Tolerations[0] assert.Equal(t, tolExecInterrupt.Key, "x/flyte") @@ -599,8 +601,6 @@ func TestBuildResourceSpark(t *testing.T) { assert.Equal(t, tolExecDefault.Operator, corev1.TolerationOperator("Equal")) assert.Equal(t, tolExecDefault.Effect, corev1.TaintEffect("NoSchedule")) - assert.Equal(t, "true", sparkApp.Spec.Executor.NodeSelector["x/interruptible"]) - for confKey, confVal := range dummySparkConf { exists := false From 94cbc222efa192f1a1853bb7d01e287cab32386d Mon Sep 17 00:00:00 2001 From: Andrew Dye Date: Tue, 10 Oct 2023 11:25:10 -0700 Subject: [PATCH 4/6] Fix lints Signed-off-by: Andrew Dye --- flyteplugins/go/tasks/plugins/k8s/spark/spark.go | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go index 7f791eda29..d0506ccfb5 100644 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go @@ -141,8 +141,7 @@ func serviceAccountName(metadata pluginsCore.TaskExecutionMetadata) string { return name } -func createSparkPodSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, podSpec *v1.PodSpec, container *v1.Container) ( - *sparkOp.SparkPodSpec, error) { +func createSparkPodSpec(taskCtx pluginsCore.TaskExecutionContext, podSpec *v1.PodSpec, container *v1.Container) *sparkOp.SparkPodSpec { 
annotations := utils.UnionMaps(config.GetK8sPluginConfig().DefaultAnnotations, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations())) labels := utils.UnionMaps(config.GetK8sPluginConfig().DefaultLabels, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels())) @@ -165,7 +164,7 @@ func createSparkPodSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionCo NodeSelector: podSpec.NodeSelector, HostNetwork: &podSpec.HostNetwork, } - return &spec, nil + return &spec } type driverSpec struct { @@ -183,10 +182,7 @@ func createDriverSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionCont if err != nil { return nil, err } - sparkPodSpec, err := createSparkPodSpec(ctx, nonInterruptibleTaskCtx, podSpec, primaryContainer) - if err != nil { - return nil, err - } + sparkPodSpec := createSparkPodSpec(nonInterruptibleTaskCtx, podSpec, primaryContainer) serviceAccountName := serviceAccountName(nonInterruptibleTaskCtx.TaskExecutionMetadata()) spec := driverSpec{ &sparkOp.DriverSpec{ @@ -216,10 +212,7 @@ func createExecutorSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionCo if err != nil { return nil, err } - sparkPodSpec, err := createSparkPodSpec(ctx, taskCtx, podSpec, primaryContainer) - if err != nil { - return nil, err - } + sparkPodSpec := createSparkPodSpec(taskCtx, podSpec, primaryContainer) serviceAccountName := serviceAccountName(taskCtx.TaskExecutionMetadata()) spec := executorSpec{ primaryContainer, From b9f01c0bdb1f23341021f65c3839cf8bc32bf4ff Mon Sep 17 00:00:00 2001 From: Andrew Dye Date: Tue, 10 Oct 2023 14:39:14 -0700 Subject: [PATCH 5/6] Rename and cleanup TestBuildResourceContainer Signed-off-by: Andrew Dye --- .../go/tasks/plugins/k8s/spark/spark_test.go | 211 ++++++++++-------- 1 file changed, 113 insertions(+), 98 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go index 25f1b0d575..95bae4dceb 100644 --- 
a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go @@ -248,6 +248,13 @@ func dummySparkCustomObj(sparkConf map[string]string) *plugins.SparkJob { func dummyPodSpec() *corev1.PodSpec { return &corev1.PodSpec{ + InitContainers: []corev1.Container{ + { + Name: "init", + Image: testImage, + Args: testArgs, + }, + }, Containers: []corev1.Container{ { Name: "primary", @@ -255,6 +262,12 @@ func dummyPodSpec() *corev1.PodSpec { Args: testArgs, Env: flytek8s.ToK8sEnvVar(dummyEnvVars), }, + { + Name: "secondary", + Image: testImage, + Args: testArgs, + Env: flytek8s.ToK8sEnvVar(dummyEnvVars), + }, }, } } @@ -379,26 +392,7 @@ func dummySparkTaskContext(taskTemplate *core.TaskTemplate, interruptible bool) return taskCtx } -func TestBuildResourceSpark(t *testing.T) { - sparkResourceHandler := sparkResourceHandler{} - - // Case1: Valid Spark Task-Template - taskTemplate := dummySparkTaskTemplateContainer("blah-1", dummySparkConf) - - // Set spark custom feature config. 
- assert.NoError(t, setSparkConfig(&Config{ - Features: []Feature{ - { - Name: "feature1", - SparkConfig: map[string]string{"spark.hadoop.feature1": "true"}, - }, - { - Name: "feature2", - SparkConfig: map[string]string{"spark.hadoop.feature2": "true"}, - }, - }, - })) - +func defaultPluginConfig() *config.K8sPluginConfig { // Set Interruptible Config runAsUser := int64(1000) dnsOptVal1 := "1" @@ -448,7 +442,7 @@ func TestBuildResourceSpark(t *testing.T) { }, } - // interruptible/non-interruptible nodeselector requirement + // Interruptible/non-interruptible nodeselector requirement interruptibleNodeSelectorRequirement := &corev1.NodeSelectorRequirement{ Key: "x/interruptible", Operator: corev1.NodeSelectorOpIn, @@ -461,9 +455,7 @@ func TestBuildResourceSpark(t *testing.T) { Values: []string{"true"}, } - // NonInterruptibleNodeSelectorRequirement - - assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{ + config := &config.K8sPluginConfig{ DefaultAffinity: defaultAffinity, DefaultPodSecurityContext: &corev1.PodSecurityContext{ RunAsUser: &runAsUser, @@ -513,8 +505,32 @@ func TestBuildResourceSpark(t *testing.T) { EnableHostNetworkingPod: &defaultPodHostNetwork, DefaultEnvVars: defaultEnvVars, DefaultEnvVarsFromEnv: defaultEnvVarsFromEnv, - }), - ) + } + return config +} + +func TestBuildResourceContainer(t *testing.T) { + sparkResourceHandler := sparkResourceHandler{} + + // Case1: Valid Spark Task-Template + taskTemplate := dummySparkTaskTemplateContainer("blah-1", dummySparkConf) + + // Set spark custom feature config. 
+ assert.NoError(t, setSparkConfig(&Config{ + Features: []Feature{ + { + Name: "feature1", + SparkConfig: map[string]string{"spark.hadoop.feature1": "true"}, + }, + { + Name: "feature2", + SparkConfig: map[string]string{"spark.hadoop.feature2": "true"}, + }, + }, + })) + + defaultConfig := defaultPluginConfig() + assert.NoError(t, config.SetK8sPluginConfig(defaultConfig)) resource, err := sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, true)) assert.Nil(t, err) @@ -527,28 +543,16 @@ func TestBuildResourceSpark(t *testing.T) { assert.Equal(t, testArgs, sparkApp.Spec.Arguments) assert.Equal(t, testImage, *sparkApp.Spec.Image) assert.NotNil(t, sparkApp.Spec.Driver.SparkPodSpec.SecurityContenxt) - assert.Equal(t, *sparkApp.Spec.Driver.SparkPodSpec.SecurityContenxt.RunAsUser, runAsUser) + assert.Equal(t, *sparkApp.Spec.Driver.SparkPodSpec.SecurityContenxt.RunAsUser, *defaultConfig.DefaultPodSecurityContext.RunAsUser) assert.NotNil(t, sparkApp.Spec.Driver.DNSConfig) assert.Equal(t, []string{"8.8.8.8", "8.8.4.4"}, sparkApp.Spec.Driver.DNSConfig.Nameservers) - assert.Equal(t, "ndots", sparkApp.Spec.Driver.DNSConfig.Options[0].Name) - assert.Equal(t, dnsOptVal1, *sparkApp.Spec.Driver.DNSConfig.Options[0].Value) - assert.Equal(t, "single-request-reopen", sparkApp.Spec.Driver.DNSConfig.Options[1].Name) - assert.Equal(t, "timeout", sparkApp.Spec.Driver.DNSConfig.Options[2].Name) - assert.Equal(t, dnsOptVal2, *sparkApp.Spec.Driver.DNSConfig.Options[2].Value) - assert.Equal(t, "attempts", sparkApp.Spec.Driver.DNSConfig.Options[3].Name) - assert.Equal(t, dnsOptVal3, *sparkApp.Spec.Driver.DNSConfig.Options[3].Value) + assert.ElementsMatch(t, defaultConfig.DefaultPodDNSConfig.Options, sparkApp.Spec.Driver.DNSConfig.Options) assert.Equal(t, []string{"ns1.svc.cluster-domain.example", "my.dns.search.suffix"}, sparkApp.Spec.Driver.DNSConfig.Searches) assert.NotNil(t, sparkApp.Spec.Executor.SparkPodSpec.SecurityContenxt) - assert.Equal(t, 
*sparkApp.Spec.Executor.SparkPodSpec.SecurityContenxt.RunAsUser, runAsUser) + assert.Equal(t, *sparkApp.Spec.Executor.SparkPodSpec.SecurityContenxt.RunAsUser, *defaultConfig.DefaultPodSecurityContext.RunAsUser) assert.NotNil(t, sparkApp.Spec.Executor.DNSConfig) assert.NotNil(t, sparkApp.Spec.Executor.DNSConfig) - assert.Equal(t, "ndots", sparkApp.Spec.Executor.DNSConfig.Options[0].Name) - assert.Equal(t, dnsOptVal1, *sparkApp.Spec.Executor.DNSConfig.Options[0].Value) - assert.Equal(t, "single-request-reopen", sparkApp.Spec.Executor.DNSConfig.Options[1].Name) - assert.Equal(t, "timeout", sparkApp.Spec.Executor.DNSConfig.Options[2].Name) - assert.Equal(t, dnsOptVal2, *sparkApp.Spec.Executor.DNSConfig.Options[2].Value) - assert.Equal(t, "attempts", sparkApp.Spec.Executor.DNSConfig.Options[3].Name) - assert.Equal(t, dnsOptVal3, *sparkApp.Spec.Executor.DNSConfig.Options[3].Value) + assert.ElementsMatch(t, defaultConfig.DefaultPodDNSConfig.Options, sparkApp.Spec.Executor.DNSConfig.Options) assert.Equal(t, []string{"ns1.svc.cluster-domain.example", "my.dns.search.suffix"}, sparkApp.Spec.Executor.DNSConfig.Searches) //Validate Driver/Executor Spec. 
@@ -563,19 +567,19 @@ func TestBuildResourceSpark(t *testing.T) { assert.Equal(t, dummySparkConf["spark.driver.memory"], *sparkApp.Spec.Driver.Memory) assert.Equal(t, dummySparkConf["spark.executor.memory"], *sparkApp.Spec.Executor.Memory) assert.Equal(t, dummySparkConf["spark.batchScheduler"], *sparkApp.Spec.BatchScheduler) - assert.Equal(t, schedulerName, *sparkApp.Spec.Executor.SchedulerName) - assert.Equal(t, schedulerName, *sparkApp.Spec.Driver.SchedulerName) - assert.Equal(t, defaultPodHostNetwork, *sparkApp.Spec.Executor.HostNetwork) - assert.Equal(t, defaultPodHostNetwork, *sparkApp.Spec.Driver.HostNetwork) + assert.Equal(t, defaultConfig.SchedulerName, *sparkApp.Spec.Executor.SchedulerName) + assert.Equal(t, defaultConfig.SchedulerName, *sparkApp.Spec.Driver.SchedulerName) + assert.Equal(t, *defaultConfig.EnableHostNetworkingPod, *sparkApp.Spec.Executor.HostNetwork) + assert.Equal(t, *defaultConfig.EnableHostNetworkingPod, *sparkApp.Spec.Driver.HostNetwork) // Validate - // * Interruptible Toleration and NodeSelector set for Executor but not Driver. - // * Validate Default NodeSelector set for Driver but overwritten with Interruptible NodeSelector for Executor. - // * Default Tolerations set for both Driver and Executor. - // * Interruptible/Non-Interruptible NodeSelectorRequirements set for Executor Affinity but not Driver Affinity. + // * Default tolerations set for both Driver and Executor. + // * Interruptible tolerations and node selector set for Executor but not Driver. + // * Default node selector set for both Driver and Executor. + // * Interruptible node selector requirements set for Executor Affinity, non-interruptible for Driver Affinity.
assert.Equal(t, 1, len(sparkApp.Spec.Driver.Tolerations)) assert.Equal(t, 1, len(sparkApp.Spec.Driver.NodeSelector)) - assert.Equal(t, defaultNodeSelector, sparkApp.Spec.Driver.NodeSelector) + assert.Equal(t, defaultConfig.DefaultNodeSelector, sparkApp.Spec.Driver.NodeSelector) tolDriverDefault := sparkApp.Spec.Driver.Tolerations[0] assert.Equal(t, tolDriverDefault.Key, "x/flyte") assert.Equal(t, tolDriverDefault.Value, "default") @@ -633,31 +637,36 @@ func TestBuildResourceSpark(t *testing.T) { assert.Equal(t, dummySparkConf["spark.flyteorg.feature3.enabled"], sparkApp.Spec.SparkConf["spark.flyteorg.feature3.enabled"]) assert.Equal(t, len(sparkApp.Spec.Driver.EnvVars["FLYTE_MAX_ATTEMPTS"]), 1) - assert.Equal(t, sparkApp.Spec.Driver.EnvVars["foo"], defaultEnvVars["foo"]) - assert.Equal(t, sparkApp.Spec.Executor.EnvVars["foo"], defaultEnvVars["foo"]) - assert.Equal(t, sparkApp.Spec.Driver.EnvVars["fooEnv"], targetValueFromEnv) - assert.Equal(t, sparkApp.Spec.Executor.EnvVars["fooEnv"], targetValueFromEnv) - - assert.Equal( - t, - sparkApp.Spec.Driver.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0], - defaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0], - ) - assert.Equal( - t, - sparkApp.Spec.Driver.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[1], - *nonInterruptibleNodeSelectorRequirement, - ) - assert.Equal( - t, - sparkApp.Spec.Executor.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0], - defaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0], - ) - assert.Equal( - t, - sparkApp.Spec.Executor.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[1], - *interruptibleNodeSelectorRequirement, - ) + 
assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], sparkApp.Spec.Driver.EnvVars["foo"]) + assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], sparkApp.Spec.Executor.EnvVars["foo"]) + assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], sparkApp.Spec.Driver.EnvVars["fooEnv"]) + assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], sparkApp.Spec.Executor.EnvVars["fooEnv"]) + + assert.Equal(t, &corev1.NodeAffinity{ + RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{ + NodeSelectorTerms: []corev1.NodeSelectorTerm{ + { + MatchExpressions: []corev1.NodeSelectorRequirement{ + defaultConfig.DefaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0], + *defaultConfig.NonInterruptibleNodeSelectorRequirement, + }, + }, + }, + }, + }, sparkApp.Spec.Driver.Affinity.NodeAffinity) + + assert.Equal(t, &corev1.NodeAffinity{ + RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{ + NodeSelectorTerms: []corev1.NodeSelectorTerm{ + { + MatchExpressions: []corev1.NodeSelectorRequirement{ + defaultConfig.DefaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0], + *defaultConfig.InterruptibleNodeSelectorRequirement, + }, + }, + }, + }, + }, sparkApp.Spec.Executor.Affinity.NodeAffinity) // Case 2: Driver/Executor request cores set. dummyConfWithRequest := make(map[string]string) @@ -690,36 +699,41 @@ func TestBuildResourceSpark(t *testing.T) { // Validate that the default Toleration and NodeSelector are set for both Driver and Executors. 
assert.Equal(t, 1, len(sparkApp.Spec.Driver.Tolerations)) assert.Equal(t, 1, len(sparkApp.Spec.Driver.NodeSelector)) - assert.Equal(t, defaultNodeSelector, sparkApp.Spec.Driver.NodeSelector) + assert.Equal(t, defaultConfig.DefaultNodeSelector, sparkApp.Spec.Driver.NodeSelector) assert.Equal(t, 1, len(sparkApp.Spec.Executor.Tolerations)) assert.Equal(t, 1, len(sparkApp.Spec.Executor.NodeSelector)) - assert.Equal(t, defaultNodeSelector, sparkApp.Spec.Executor.NodeSelector) + assert.Equal(t, defaultConfig.DefaultNodeSelector, sparkApp.Spec.Executor.NodeSelector) assert.Equal(t, sparkApp.Spec.Executor.Tolerations[0].Key, "x/flyte") assert.Equal(t, sparkApp.Spec.Executor.Tolerations[0].Value, "default") assert.Equal(t, sparkApp.Spec.Driver.Tolerations[0].Key, "x/flyte") assert.Equal(t, sparkApp.Spec.Driver.Tolerations[0].Value, "default") // Validate correct affinity and nodeselector requirements are set for both Driver and Executors. - assert.Equal( - t, - sparkApp.Spec.Driver.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0], - defaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0], - ) - assert.Equal( - t, - sparkApp.Spec.Driver.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[1], - *nonInterruptibleNodeSelectorRequirement, - ) - assert.Equal( - t, - sparkApp.Spec.Executor.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0], - defaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0], - ) - assert.Equal( - t, - sparkApp.Spec.Executor.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[1], - *nonInterruptibleNodeSelectorRequirement, - ) + assert.Equal(t, &corev1.NodeAffinity{ + 
RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{ + NodeSelectorTerms: []corev1.NodeSelectorTerm{ + { + MatchExpressions: []corev1.NodeSelectorRequirement{ + defaultConfig.DefaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0], + *defaultConfig.NonInterruptibleNodeSelectorRequirement, + }, + }, + }, + }, + }, sparkApp.Spec.Driver.Affinity.NodeAffinity) + + assert.Equal(t, &corev1.NodeAffinity{ + RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{ + NodeSelectorTerms: []corev1.NodeSelectorTerm{ + { + MatchExpressions: []corev1.NodeSelectorRequirement{ + defaultConfig.DefaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0], + *defaultConfig.NonInterruptibleNodeSelectorRequirement, + }, + }, + }, + }, + }, sparkApp.Spec.Executor.Affinity.NodeAffinity) // Case 4: Invalid Spark Task-Template taskTemplate.Custom = nil @@ -748,6 +762,7 @@ func TestBuildResourcePodTemplate(t *testing.T) { } podSpec := dummyPodSpec() podSpec.Tolerations = append(podSpec.Tolerations, extraToleration) + podSpec.NodeSelector["x/custom"] = "foo" taskTemplate := dummySparkTaskTemplatePod("blah-1", dummySparkConf, podSpec) taskTemplate.GetK8SPod() sparkResourceHandler := sparkResourceHandler{} From de2e42af8e284beab1adc920e23ca1374e3b9b53 Mon Sep 17 00:00:00 2001 From: Andrew Dye Date: Tue, 10 Oct 2023 16:28:30 -0700 Subject: [PATCH 6/6] Add more test coverage to TestBuildResourcePodTemplate Signed-off-by: Andrew Dye --- .../go/tasks/plugins/k8s/spark/spark_test.go | 121 ++++++++++++++---- 1 file changed, 98 insertions(+), 23 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go index 95bae4dceb..18565fc0a4 100644 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go @@ -7,6 +7,7 @@ import ( 
"testing" sj "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2" + sparkOp "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2" "github.com/golang/protobuf/jsonpb" structpb "github.com/golang/protobuf/ptypes/struct" "github.com/stretchr/testify/assert" @@ -743,18 +744,8 @@ func TestBuildResourceContainer(t *testing.T) { } func TestBuildResourcePodTemplate(t *testing.T) { - defaultToleration := corev1.Toleration{ - - Key: "x/flyte", - Value: "default", - Operator: "Equal", - } - err := config.SetK8sPluginConfig(&config.K8sPluginConfig{ - DefaultTolerations: []corev1.Toleration{ - defaultToleration, - }, - }) - assert.NoError(t, err) + defaultConfig := defaultPluginConfig() + assert.NoError(t, config.SetK8sPluginConfig(defaultConfig)) extraToleration := corev1.Toleration{ Key: "x/flyte", Value: "extra", @@ -762,26 +753,110 @@ func TestBuildResourcePodTemplate(t *testing.T) { } podSpec := dummyPodSpec() podSpec.Tolerations = append(podSpec.Tolerations, extraToleration) - podSpec.NodeSelector["x/custom"] = "foo" + podSpec.NodeSelector = map[string]string{"x/custom": "foo"} taskTemplate := dummySparkTaskTemplatePod("blah-1", dummySparkConf, podSpec) taskTemplate.GetK8SPod() sparkResourceHandler := sparkResourceHandler{} - resource, err := sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, false)) - assert.Nil(t, err) + taskCtx := dummySparkTaskContext(taskTemplate, true) + resource, err := sparkResourceHandler.BuildResource(context.TODO(), taskCtx) + + assert.Nil(t, err) assert.NotNil(t, resource) sparkApp, ok := resource.(*sj.SparkApplication) assert.True(t, ok) - assert.Equal(t, 2, len(sparkApp.Spec.Driver.Tolerations)) - assert.Equal(t, sparkApp.Spec.Driver.Tolerations, []corev1.Toleration{ - defaultToleration, + + // Application + assert.Equal(t, v1.TypeMeta{ + Kind: KindSparkApplication, + APIVersion: sparkOp.SchemeGroupVersion.String(), + }, 
sparkApp.TypeMeta) + + // Application spec + assert.Equal(t, flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()), *sparkApp.Spec.ServiceAccount) + assert.Equal(t, sparkOp.PythonApplicationType, sparkApp.Spec.Type) + assert.Equal(t, testImage, *sparkApp.Spec.Image) + assert.Equal(t, testArgs, sparkApp.Spec.Arguments) + assert.Equal(t, sparkOp.RestartPolicy{ + Type: sparkOp.OnFailure, + OnSubmissionFailureRetries: intPtr(int32(14)), + }, sparkApp.Spec.RestartPolicy) + assert.Equal(t, sparkMainClass, *sparkApp.Spec.MainClass) + assert.Equal(t, sparkApplicationFile, *sparkApp.Spec.MainApplicationFile) + + // Driver + assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultAnnotations, map[string]string{"annotation-1": "val1"}), sparkApp.Spec.Driver.Annotations) + assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{"label-1": "val1"}), sparkApp.Spec.Driver.Labels) + assert.Equal(t, len(sparkApp.Spec.Driver.EnvVars["FLYTE_MAX_ATTEMPTS"]), 1) + assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], sparkApp.Spec.Driver.EnvVars["foo"]) + assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], sparkApp.Spec.Driver.EnvVars["fooEnv"]) + assert.Equal(t, testImage, *sparkApp.Spec.Driver.Image) + assert.Equal(t, flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()), *sparkApp.Spec.Driver.ServiceAccount) + assert.Equal(t, defaultConfig.DefaultPodSecurityContext, sparkApp.Spec.Driver.SecurityContenxt) + assert.Equal(t, defaultConfig.DefaultPodDNSConfig, sparkApp.Spec.Driver.DNSConfig) + assert.Equal(t, defaultConfig.EnableHostNetworkingPod, sparkApp.Spec.Driver.HostNetwork) + assert.Equal(t, defaultConfig.SchedulerName, *sparkApp.Spec.Driver.SchedulerName) + assert.Equal(t, []corev1.Toleration{ + defaultConfig.DefaultTolerations[0], extraToleration, - }) - assert.Equal(t, 2, len(sparkApp.Spec.Executor.Tolerations)) - assert.Equal(t, sparkApp.Spec.Executor.Tolerations, []corev1.Toleration{ 
- defaultToleration, + }, sparkApp.Spec.Driver.Tolerations) + assert.Equal(t, map[string]string{ + "x/default": "true", + "x/custom": "foo", + }, sparkApp.Spec.Driver.NodeSelector) + assert.Equal(t, &corev1.NodeAffinity{ + RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{ + NodeSelectorTerms: []corev1.NodeSelectorTerm{ + { + MatchExpressions: []corev1.NodeSelectorRequirement{ + defaultConfig.DefaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0], + *defaultConfig.NonInterruptibleNodeSelectorRequirement, + }, + }, + }, + }, + }, sparkApp.Spec.Driver.Affinity.NodeAffinity) + cores, _ := strconv.ParseInt(dummySparkConf["spark.driver.cores"], 10, 32) + assert.Equal(t, intPtr(int32(cores)), sparkApp.Spec.Driver.Cores) + assert.Equal(t, dummySparkConf["spark.driver.memory"], *sparkApp.Spec.Driver.Memory) + + // Executor + assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultAnnotations, map[string]string{"annotation-1": "val1"}), sparkApp.Spec.Executor.Annotations) + assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{"label-1": "val1"}), sparkApp.Spec.Executor.Labels) + assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], sparkApp.Spec.Executor.EnvVars["foo"]) + assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], sparkApp.Spec.Executor.EnvVars["fooEnv"]) + assert.Equal(t, testImage, *sparkApp.Spec.Executor.Image) + assert.Equal(t, defaultConfig.DefaultPodSecurityContext, sparkApp.Spec.Executor.SecurityContenxt) + assert.Equal(t, defaultConfig.DefaultPodDNSConfig, sparkApp.Spec.Executor.DNSConfig) + assert.Equal(t, defaultConfig.EnableHostNetworkingPod, sparkApp.Spec.Executor.HostNetwork) + assert.Equal(t, defaultConfig.SchedulerName, *sparkApp.Spec.Executor.SchedulerName) + assert.ElementsMatch(t, []corev1.Toleration{ + defaultConfig.DefaultTolerations[0], extraToleration, - }) + defaultConfig.InterruptibleTolerations[0], + }, 
sparkApp.Spec.Executor.Tolerations) + assert.Equal(t, map[string]string{ + "x/default": "true", + "x/custom": "foo", + "x/interruptible": "true", + }, sparkApp.Spec.Executor.NodeSelector) + assert.Equal(t, &corev1.NodeAffinity{ + RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{ + NodeSelectorTerms: []corev1.NodeSelectorTerm{ + { + MatchExpressions: []corev1.NodeSelectorRequirement{ + defaultConfig.DefaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0], + *defaultConfig.InterruptibleNodeSelectorRequirement, + }, + }, + }, + }, + }, sparkApp.Spec.Executor.Affinity.NodeAffinity) + cores, _ = strconv.ParseInt(dummySparkConf["spark.executor.cores"], 10, 32) + instances, _ := strconv.ParseInt(dummySparkConf["spark.executor.instances"], 10, 32) + assert.Equal(t, intPtr(int32(instances)), sparkApp.Spec.Executor.Instances) + assert.Equal(t, intPtr(int32(cores)), sparkApp.Spec.Executor.Cores) + assert.Equal(t, dummySparkConf["spark.executor.memory"], *sparkApp.Spec.Executor.Memory) } func TestGetPropertiesSpark(t *testing.T) {