diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go index 95bae4dceb..18565fc0a4 100644 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go @@ -7,6 +7,7 @@ import ( "testing" sj "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2" + sparkOp "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2" "github.com/golang/protobuf/jsonpb" structpb "github.com/golang/protobuf/ptypes/struct" "github.com/stretchr/testify/assert" @@ -743,18 +744,8 @@ func TestBuildResourceContainer(t *testing.T) { } func TestBuildResourcePodTemplate(t *testing.T) { - defaultToleration := corev1.Toleration{ - - Key: "x/flyte", - Value: "default", - Operator: "Equal", - } - err := config.SetK8sPluginConfig(&config.K8sPluginConfig{ - DefaultTolerations: []corev1.Toleration{ - defaultToleration, - }, - }) - assert.NoError(t, err) + defaultConfig := defaultPluginConfig() + assert.NoError(t, config.SetK8sPluginConfig(defaultConfig)) extraToleration := corev1.Toleration{ Key: "x/flyte", Value: "extra", @@ -762,26 +753,110 @@ func TestBuildResourcePodTemplate(t *testing.T) { } podSpec := dummyPodSpec() podSpec.Tolerations = append(podSpec.Tolerations, extraToleration) - podSpec.NodeSelector["x/custom"] = "foo" + podSpec.NodeSelector = map[string]string{"x/custom": "foo"} taskTemplate := dummySparkTaskTemplatePod("blah-1", dummySparkConf, podSpec) taskTemplate.GetK8SPod() sparkResourceHandler := sparkResourceHandler{} - resource, err := sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, false)) - assert.Nil(t, err) + taskCtx := dummySparkTaskContext(taskTemplate, true) + resource, err := sparkResourceHandler.BuildResource(context.TODO(), taskCtx) + + assert.Nil(t, err) assert.NotNil(t, resource) sparkApp, ok := resource.(*sj.SparkApplication) assert.True(t, ok) - assert.Equal(t, 2, len(sparkApp.Spec.Driver.Tolerations)) - assert.Equal(t, sparkApp.Spec.Driver.Tolerations, []corev1.Toleration{ - defaultToleration, + + // Application + assert.Equal(t, v1.TypeMeta{ + Kind: KindSparkApplication, + APIVersion: sparkOp.SchemeGroupVersion.String(), + }, sparkApp.TypeMeta) + + // Application spec + assert.Equal(t, flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()), *sparkApp.Spec.ServiceAccount) + assert.Equal(t, sparkOp.PythonApplicationType, sparkApp.Spec.Type) + assert.Equal(t, testImage, *sparkApp.Spec.Image) + assert.Equal(t, testArgs, sparkApp.Spec.Arguments) + assert.Equal(t, sparkOp.RestartPolicy{ + Type: sparkOp.OnFailure, + OnSubmissionFailureRetries: intPtr(int32(14)), + }, sparkApp.Spec.RestartPolicy) + assert.Equal(t, sparkMainClass, *sparkApp.Spec.MainClass) + assert.Equal(t, sparkApplicationFile, *sparkApp.Spec.MainApplicationFile) + + // Driver + assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultAnnotations, map[string]string{"annotation-1": "val1"}), sparkApp.Spec.Driver.Annotations) + assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{"label-1": "val1"}), sparkApp.Spec.Driver.Labels) + assert.Equal(t, len(sparkApp.Spec.Driver.EnvVars["FLYTE_MAX_ATTEMPTS"]), 1) + assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], sparkApp.Spec.Driver.EnvVars["foo"]) + assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], sparkApp.Spec.Driver.EnvVars["fooEnv"]) + assert.Equal(t, testImage, *sparkApp.Spec.Driver.Image) + assert.Equal(t, flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()), *sparkApp.Spec.Driver.ServiceAccount) + assert.Equal(t, defaultConfig.DefaultPodSecurityContext, sparkApp.Spec.Driver.SecurityContenxt) + assert.Equal(t, defaultConfig.DefaultPodDNSConfig, sparkApp.Spec.Driver.DNSConfig) + assert.Equal(t, defaultConfig.EnableHostNetworkingPod, sparkApp.Spec.Driver.HostNetwork) + assert.Equal(t, defaultConfig.SchedulerName, *sparkApp.Spec.Driver.SchedulerName) + assert.Equal(t, []corev1.Toleration{ + defaultConfig.DefaultTolerations[0], extraToleration, - }) - assert.Equal(t, 2, len(sparkApp.Spec.Executor.Tolerations)) - assert.Equal(t, sparkApp.Spec.Executor.Tolerations, []corev1.Toleration{ - defaultToleration, + }, sparkApp.Spec.Driver.Tolerations) + assert.Equal(t, map[string]string{ + "x/default": "true", + "x/custom": "foo", + }, sparkApp.Spec.Driver.NodeSelector) + assert.Equal(t, &corev1.NodeAffinity{ + RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{ + NodeSelectorTerms: []corev1.NodeSelectorTerm{ + { + MatchExpressions: []corev1.NodeSelectorRequirement{ + defaultConfig.DefaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0], + *defaultConfig.NonInterruptibleNodeSelectorRequirement, + }, + }, + }, + }, + }, sparkApp.Spec.Driver.Affinity.NodeAffinity) + cores, _ := strconv.ParseInt(dummySparkConf["spark.driver.cores"], 10, 32) + assert.Equal(t, intPtr(int32(cores)), sparkApp.Spec.Driver.Cores) + assert.Equal(t, dummySparkConf["spark.driver.memory"], *sparkApp.Spec.Driver.Memory) + + // Executor + assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultAnnotations, map[string]string{"annotation-1": "val1"}), sparkApp.Spec.Executor.Annotations) + assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{"label-1": "val1"}), sparkApp.Spec.Executor.Labels) + assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], sparkApp.Spec.Executor.EnvVars["foo"]) + assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], sparkApp.Spec.Executor.EnvVars["fooEnv"]) + assert.Equal(t, testImage, *sparkApp.Spec.Executor.Image) + assert.Equal(t, defaultConfig.DefaultPodSecurityContext, sparkApp.Spec.Executor.SecurityContenxt) + assert.Equal(t, defaultConfig.DefaultPodDNSConfig, sparkApp.Spec.Executor.DNSConfig) + assert.Equal(t, defaultConfig.EnableHostNetworkingPod, sparkApp.Spec.Executor.HostNetwork) + assert.Equal(t, defaultConfig.SchedulerName, *sparkApp.Spec.Executor.SchedulerName) + assert.ElementsMatch(t, []corev1.Toleration{ + defaultConfig.DefaultTolerations[0], extraToleration, - }) + defaultConfig.InterruptibleTolerations[0], + }, sparkApp.Spec.Executor.Tolerations) + assert.Equal(t, map[string]string{ + "x/default": "true", + "x/custom": "foo", + "x/interruptible": "true", + }, sparkApp.Spec.Executor.NodeSelector) + assert.Equal(t, &corev1.NodeAffinity{ + RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{ + NodeSelectorTerms: []corev1.NodeSelectorTerm{ + { + MatchExpressions: []corev1.NodeSelectorRequirement{ + defaultConfig.DefaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0], + *defaultConfig.InterruptibleNodeSelectorRequirement, + }, + }, + }, + }, + }, sparkApp.Spec.Executor.Affinity.NodeAffinity) + cores, _ = strconv.ParseInt(dummySparkConf["spark.executor.cores"], 10, 32) + instances, _ := strconv.ParseInt(dummySparkConf["spark.executor.instances"], 10, 32) + assert.Equal(t, intPtr(int32(instances)), sparkApp.Spec.Executor.Instances) + assert.Equal(t, intPtr(int32(cores)), sparkApp.Spec.Executor.Cores) + assert.Equal(t, dummySparkConf["spark.executor.memory"], *sparkApp.Spec.Executor.Memory) } func TestGetPropertiesSpark(t *testing.T) {