diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go
index 0a6f51d0e2..9f9830fc17 100644
--- a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go
+++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go
@@ -2,6 +2,7 @@ package spark

 import (
 	"context"
+	"encoding/json"
 	"os"
 	"reflect"
 	"strconv"
@@ -9,7 +10,9 @@ import (

 	sj "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2"
 	sparkOp "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2"
-	structpb "github.com/golang/protobuf/ptypes/struct"
+	// NOTE: use google.golang.org/protobuf's structpb here; the legacy
+	// github.com/golang/protobuf/ptypes/struct package has no structpb.NewStruct.
+	"google.golang.org/protobuf/types/known/structpb"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/mock"
 	corev1 "k8s.io/api/core/v1"
@@ -283,6 +286,19 @@ func dummySparkCustomObj(sparkConf map[string]string) *plugins.SparkJob {
 	return &sparkJob
 }

+func dummySparkCustomObjDriverExecutor(sparkConf map[string]string, driverPod *core.K8SPod, executorPod *core.K8SPod) *plugins.SparkJob {
+	sparkJob := plugins.SparkJob{}
+
+	sparkJob.MainClass = sparkMainClass
+	sparkJob.MainApplicationFile = sparkApplicationFile
+	sparkJob.SparkConf = sparkConf
+	sparkJob.ApplicationType = plugins.SparkApplication_PYTHON
+
+	sparkJob.DriverPod = driverPod
+	sparkJob.ExecutorPod = executorPod
+	return &sparkJob
+}
+
 func dummyPodSpec() *corev1.PodSpec {
 	return &corev1.PodSpec{
 		InitContainers: []corev1.Container{
@@ -337,7 +353,32 @@ func dummySparkTaskTemplateContainer(id string, sparkConf map[string]string) *co
 	}
 }

+func dummySparkTaskTemplateDriverExecutor(id string, sparkConf map[string]string, driverPod *core.K8SPod, executorPod *core.K8SPod) *core.TaskTemplate {
+	sparkJob := dummySparkCustomObjDriverExecutor(sparkConf, driverPod, executorPod)
+
+	structObj, err := utils.MarshalObjToStruct(sparkJob)
+	if err != nil {
+		panic(err)
+	}
+
+	return &core.TaskTemplate{
+		Id:   &core.Identifier{Name: id},
+		Type: "container",
+		Target: &core.TaskTemplate_Container{
+			Container: &core.Container{
+				Image: testImage,
+				Args:  testArgs,
+				Env:   dummyEnvVars,
+			},
+		},
+		Config: map[string]string{
+			flytek8s.PrimaryContainerKey: "primary",
+		},
+		Custom: structObj,
+	}
+}
+
 func dummySparkTaskTemplatePod(id string, sparkConf map[string]string, podSpec *corev1.PodSpec) *core.TaskTemplate {
 	sparkJob := dummySparkCustomObj(sparkConf)
 	sparkJobJSON, err := utils.MarshalToString(sparkJob)
 	if err != nil {
@@ -930,3 +972,165 @@ func TestGetPropertiesSpark(t *testing.T) {
 	expected := k8s.PluginProperties{}
 	assert.Equal(t, expected, sparkResourceHandler.GetProperties())
 }
+
+func TestBuildResourceCustomK8SPod(t *testing.T) {
+	// TestBuildResourcePodTemplate already covers tolerations supplied via a
+	// pod template; this test exercises the custom driver/executor K8SPod path.
+
+	// Build dummy driver and executor pod specs, attach them to a dummy
+	// SparkJob, and verify that the resulting SparkApplication driver and
+	// executor specs carry the custom tolerations and node selectors we set,
+	// merged with the plugin defaults.
+
+	defaultConfig := defaultPluginConfig()
+	assert.NoError(t, config.SetK8sPluginConfig(defaultConfig))
+
+	// Extra tolerations for the custom driver and executor pods.
+	driverExtraToleration := corev1.Toleration{
+		Key:      "x/flyte-driver",
+		Value:    "extra-driver",
+		Operator: "Equal",
+	}
+	executorExtraToleration := corev1.Toleration{
"x/flyte-executor", + Value: "extra-executor", + Operator: "Equal", + } + + // pod for driver and executor + driverPodSpec := dummyPodSpec() + executorPodSpec := dummyPodSpec() + driverPodSpec.Tolerations = append(driverPodSpec.Tolerations, driverExtraToleration) + driverPodSpec.NodeSelector = map[string]string{"x/custom": "foo-driver"} + executorPodSpec.Tolerations = append(executorPodSpec.Tolerations, executorExtraToleration) + executorPodSpec.NodeSelector = map[string]string{"x/custom": "foo-executor"} + + driverK8SPod := &core.K8SPod{ + PodSpec: transformStructToStructPB(t, driverPodSpec), + } + executorK8SPod := &core.K8SPod{ + PodSpec: transformStructToStructPB(t, executorPodSpec), + } + // put the driver/executor podspec (add custom tolerations) to below function + taskTemplate := dummySparkTaskTemplateDriverExecutor("blah-1", dummySparkConf, driverK8SPod, executorK8SPod) + sparkResourceHandler := sparkResourceHandler{} + + taskCtx := dummySparkTaskContext(taskTemplate, true, k8s.PluginState{}) + resource, err := sparkResourceHandler.BuildResource(context.TODO(), taskCtx) + + assert.Nil(t, err) + assert.NotNil(t, resource) + sparkApp, ok := resource.(*sj.SparkApplication) + assert.True(t, ok) + + // Application + assert.Equal(t, v1.TypeMeta{ + Kind: KindSparkApplication, + APIVersion: sparkOp.SchemeGroupVersion.String(), + }, sparkApp.TypeMeta) + + // Application spec + assert.Equal(t, flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()), *sparkApp.Spec.ServiceAccount) + assert.Equal(t, sparkOp.PythonApplicationType, sparkApp.Spec.Type) + assert.Equal(t, testImage, *sparkApp.Spec.Image) + assert.Equal(t, testArgs, sparkApp.Spec.Arguments) + assert.Equal(t, sparkOp.RestartPolicy{ + Type: sparkOp.OnFailure, + OnSubmissionFailureRetries: intPtr(int32(14)), + }, sparkApp.Spec.RestartPolicy) + assert.Equal(t, sparkMainClass, *sparkApp.Spec.MainClass) + assert.Equal(t, sparkApplicationFile, *sparkApp.Spec.MainApplicationFile) + + // Driver + assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultAnnotations, map[string]string{"annotation-1": "val1"}), sparkApp.Spec.Driver.Annotations) + assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{"label-1": "val1"}), sparkApp.Spec.Driver.Labels) + assert.Equal(t, len(findEnvVarByName(sparkApp.Spec.Driver.Env, "FLYTE_MAX_ATTEMPTS").Value), 1) + // assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], findEnvVarByName(sparkApp.Spec.Driver.Env, "foo").Value) + // assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], findEnvVarByName(sparkApp.Spec.Driver.Env, "fooEnv").Value) + assert.Equal(t, findEnvVarByName(dummyEnvVarsWithSecretRef, "SECRET"), findEnvVarByName(sparkApp.Spec.Driver.Env, "SECRET")) + // assert.Equal(t, 9, len(sparkApp.Spec.Driver.Env)) + assert.Equal(t, testImage, *sparkApp.Spec.Driver.Image) + assert.Equal(t, flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()), *sparkApp.Spec.Driver.ServiceAccount) + // assert.Equal(t, defaultConfig.DefaultPodSecurityContext, sparkApp.Spec.Driver.SecurityContenxt) + // assert.Equal(t, defaultConfig.DefaultPodDNSConfig, sparkApp.Spec.Driver.DNSConfig) + // assert.Equal(t, defaultConfig.EnableHostNetworkingPod, sparkApp.Spec.Driver.HostNetwork) + // assert.Equal(t, defaultConfig.SchedulerName, *sparkApp.Spec.Driver.SchedulerName) + assert.Equal(t, []corev1.Toleration{ + defaultConfig.DefaultTolerations[0], + driverExtraToleration, + }, sparkApp.Spec.Driver.Tolerations) + assert.Equal(t, 
+		"x/default": "true",
+		"x/custom":  "foo-driver",
+	}, sparkApp.Spec.Driver.NodeSelector)
+	assert.Equal(t, &corev1.NodeAffinity{
+		RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{
+			NodeSelectorTerms: []corev1.NodeSelectorTerm{
+				{
+					MatchExpressions: []corev1.NodeSelectorRequirement{
+						defaultConfig.DefaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0],
+						*defaultConfig.NonInterruptibleNodeSelectorRequirement,
+					},
+				},
+			},
+		},
+	}, sparkApp.Spec.Driver.Affinity.NodeAffinity)
+	cores, _ := strconv.ParseInt(dummySparkConf["spark.driver.cores"], 10, 32)
+	assert.Equal(t, intPtr(int32(cores)), sparkApp.Spec.Driver.Cores)
+	assert.Equal(t, dummySparkConf["spark.driver.memory"], *sparkApp.Spec.Driver.Memory)
+
+	// Executor
+	assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultAnnotations, map[string]string{"annotation-1": "val1"}), sparkApp.Spec.Executor.Annotations)
+	assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{"label-1": "val1"}), sparkApp.Spec.Executor.Labels)
+	assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], findEnvVarByName(sparkApp.Spec.Executor.Env, "foo").Value)
+	assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], findEnvVarByName(sparkApp.Spec.Executor.Env, "fooEnv").Value)
+	assert.Equal(t, findEnvVarByName(dummyEnvVarsWithSecretRef, "SECRET"), findEnvVarByName(sparkApp.Spec.Executor.Env, "SECRET"))
+	assert.Equal(t, 9, len(sparkApp.Spec.Executor.Env))
+	assert.Equal(t, testImage, *sparkApp.Spec.Executor.Image)
+	assert.Equal(t, defaultConfig.DefaultPodSecurityContext, sparkApp.Spec.Executor.SecurityContenxt)
+	assert.Equal(t, defaultConfig.DefaultPodDNSConfig, sparkApp.Spec.Executor.DNSConfig)
+	assert.Equal(t, defaultConfig.EnableHostNetworkingPod, sparkApp.Spec.Executor.HostNetwork)
+	assert.Equal(t, defaultConfig.SchedulerName, *sparkApp.Spec.Executor.SchedulerName)
+	assert.ElementsMatch(t, []corev1.Toleration{
+		defaultConfig.DefaultTolerations[0],
+		executorExtraToleration,
+		defaultConfig.InterruptibleTolerations[0],
+	}, sparkApp.Spec.Executor.Tolerations)
+	assert.Equal(t, map[string]string{
+		"x/default":       "true",
+		"x/custom":        "foo-executor",
+		"x/interruptible": "true",
+	}, sparkApp.Spec.Executor.NodeSelector)
+	assert.Equal(t, &corev1.NodeAffinity{
+		RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{
+			NodeSelectorTerms: []corev1.NodeSelectorTerm{
+				{
+					MatchExpressions: []corev1.NodeSelectorRequirement{
+						defaultConfig.DefaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0],
+						*defaultConfig.InterruptibleNodeSelectorRequirement,
+					},
+				},
+			},
+		},
+	}, sparkApp.Spec.Executor.Affinity.NodeAffinity)
+	cores, _ = strconv.ParseInt(dummySparkConf["spark.executor.cores"], 10, 32)
+	instances, _ := strconv.ParseInt(dummySparkConf["spark.executor.instances"], 10, 32)
+	assert.Equal(t, intPtr(int32(instances)), sparkApp.Spec.Executor.Instances)
+	assert.Equal(t, intPtr(int32(cores)), sparkApp.Spec.Executor.Cores)
+	assert.Equal(t, dummySparkConf["spark.executor.memory"], *sparkApp.Spec.Executor.Memory)
+}
+
+// transformStructToStructPB converts an arbitrary object into a structpb.Struct
+// by round-tripping it through JSON, as expected by core.K8SPod's PodSpec field.
+func transformStructToStructPB(t *testing.T, obj interface{}) *structpb.Struct {
+	data, err := json.Marshal(obj)
+	assert.Nil(t, err)
+	objMap := make(map[string]interface{})
+	err = json.Unmarshal(data, &objMap)
+	assert.Nil(t, err)
+	s, err := structpb.NewStruct(objMap)
+	assert.Nil(t, err)
+	return s
+}