Skip to content

Commit

Permalink
test: update test for SparkJob K8sPod (unfinished)
Browse files Browse the repository at this point in the history
Signed-off-by: machichima <[email protected]>
  • Loading branch information
machichima committed Dec 11, 2024
1 parent f64e171 commit 1faaa00
Showing 1 changed file with 202 additions and 1 deletion.
203 changes: 202 additions & 1 deletion flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@ package spark

import (
"context"
"encoding/json"
"os"
"reflect"
"strconv"
"testing"

sj "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2"
sparkOp "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2"
structpb "github.com/golang/protobuf/ptypes/struct"
// NOTE: this file also uses names from the google.golang.org structpb package,
// structpb "github.com/golang/protobuf/ptypes/struct"
"google.golang.org/protobuf/types/known/structpb"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
corev1 "k8s.io/api/core/v1"
Expand Down Expand Up @@ -283,6 +286,19 @@ func dummySparkCustomObj(sparkConf map[string]string) *plugins.SparkJob {
return &sparkJob
}

// dummySparkCustomObjDriverExecutor builds a SparkJob proto configured with the
// given spark conf plus explicit driver and executor K8SPod overrides, for use
// as the Custom payload of a task template in tests.
func dummySparkCustomObjDriverExecutor(sparkConf map[string]string, driverPod *core.K8SPod, executorPod *core.K8SPod) *plugins.SparkJob {
	return &plugins.SparkJob{
		MainClass:           sparkMainClass,
		MainApplicationFile: sparkApplicationFile,
		SparkConf:           sparkConf,
		ApplicationType:     plugins.SparkApplication_PYTHON,
		DriverPod:           driverPod,
		ExecutorPod:         executorPod,
	}
}

func dummyPodSpec() *corev1.PodSpec {
return &corev1.PodSpec{
InitContainers: []corev1.Container{
Expand Down Expand Up @@ -337,7 +353,33 @@ func dummySparkTaskTemplateContainer(id string, sparkConf map[string]string) *co
}
}

// dummySparkTaskTemplateDriverExecutor wraps a SparkJob carrying driver and
// executor pod overrides into a "container" TaskTemplate. It panics if the
// SparkJob cannot be marshalled into a protobuf Struct (test-fixture helper).
func dummySparkTaskTemplateDriverExecutor(id string, sparkConf map[string]string, driverPod *core.K8SPod, executorPod *core.K8SPod) *core.TaskTemplate {
	job := dummySparkCustomObjDriverExecutor(sparkConf, driverPod, executorPod)

	custom, err := utils.MarshalObjToStruct(job)
	if err != nil {
		panic(err)
	}

	container := &core.Container{
		Image: testImage,
		Args:  testArgs,
		Env:   dummyEnvVars,
	}

	return &core.TaskTemplate{
		Id:     &core.Identifier{Name: id},
		Type:   "container",
		Target: &core.TaskTemplate_Container{Container: container},
		Config: map[string]string{flytek8s.PrimaryContainerKey: "primary"},
		Custom: custom,
	}
}

func dummySparkTaskTemplatePod(id string, sparkConf map[string]string, podSpec *corev1.PodSpec) *core.TaskTemplate {
// add driver/executor pod below
sparkJob := dummySparkCustomObj(sparkConf)
sparkJobJSON, err := utils.MarshalToString(sparkJob)
if err != nil {
Expand Down Expand Up @@ -930,3 +972,162 @@ func TestGetPropertiesSpark(t *testing.T) {
expected := k8s.PluginProperties{}
assert.Equal(t, expected, sparkResourceHandler.GetProperties())
}

// TestBuildResourceCustomK8SPod checks that explicit driver/executor K8SPod
// overrides supplied on the SparkJob custom proto flow through BuildResource
// into the generated SparkApplication: the extra tolerations and node
// selectors set on the override pod specs must appear alongside the plugin
// config defaults on the driver and executor specs.
//
// NOTE(review): marked unfinished upstream — several env/security-context
// assertions below are still commented out pending the override implementation.
func TestBuildResourceCustomK8SPod(t *testing.T) {
	// TODO: edit below tests for custom driver and executor
	// the TestBuildResourcePodTemplate test whether the custom Toleration is displayed

	// create dummy driver and executor pod
	// dummy sparkJob that takes in dummy driver and executor pod
	// see whether the driver and worker podSpec is what we set
	// what properties to test

	defaultConfig := defaultPluginConfig()
	assert.NoError(t, config.SetK8sPluginConfig(defaultConfig))

	// add extraDriverToleration and extraExecutorToleration
	driverExtraToleration := corev1.Toleration{
		Key:      "x/flyte-driver",
		Value:    "extra-driver",
		Operator: "Equal",
	}
	executorExtraToleration := corev1.Toleration{
		Key:      "x/flyte-executor",
		Value:    "extra-executor",
		Operator: "Equal",
	}

	// pod for driver and executor: defaults plus the extra toleration and a
	// custom node selector on each side
	driverPodSpec := dummyPodSpec()
	executorPodSpec := dummyPodSpec()
	driverPodSpec.Tolerations = append(driverPodSpec.Tolerations, driverExtraToleration)
	driverPodSpec.NodeSelector = map[string]string{"x/custom": "foo-driver"}
	executorPodSpec.Tolerations = append(executorPodSpec.Tolerations, executorExtraToleration)
	executorPodSpec.NodeSelector = map[string]string{"x/custom": "foo-executor"}

	// wrap each pod spec as a structpb payload inside a K8SPod proto
	driverK8SPod := &core.K8SPod{
		PodSpec: transformStructToStructPB(t, driverPodSpec),
	}
	executorK8SPod := &core.K8SPod{
		PodSpec: transformStructToStructPB(t, executorPodSpec),
	}
	// put the driver/executor podspec (add custom tolerations) to below function
	taskTemplate := dummySparkTaskTemplateDriverExecutor("blah-1", dummySparkConf, driverK8SPod, executorK8SPod)
	sparkResourceHandler := sparkResourceHandler{}

	// interruptible=true so the executor additionally picks up the
	// interruptible tolerations / node selector asserted below
	taskCtx := dummySparkTaskContext(taskTemplate, true, k8s.PluginState{})
	resource, err := sparkResourceHandler.BuildResource(context.TODO(), taskCtx)

	assert.Nil(t, err)
	assert.NotNil(t, resource)
	sparkApp, ok := resource.(*sj.SparkApplication)
	assert.True(t, ok)

	// Application
	assert.Equal(t, v1.TypeMeta{
		Kind:       KindSparkApplication,
		APIVersion: sparkOp.SchemeGroupVersion.String(),
	}, sparkApp.TypeMeta)

	// Application spec
	assert.Equal(t, flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()), *sparkApp.Spec.ServiceAccount)
	assert.Equal(t, sparkOp.PythonApplicationType, sparkApp.Spec.Type)
	assert.Equal(t, testImage, *sparkApp.Spec.Image)
	assert.Equal(t, testArgs, sparkApp.Spec.Arguments)
	assert.Equal(t, sparkOp.RestartPolicy{
		Type:                       sparkOp.OnFailure,
		OnSubmissionFailureRetries: intPtr(int32(14)),
	}, sparkApp.Spec.RestartPolicy)
	assert.Equal(t, sparkMainClass, *sparkApp.Spec.MainClass)
	assert.Equal(t, sparkApplicationFile, *sparkApp.Spec.MainApplicationFile)

	// Driver
	assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultAnnotations, map[string]string{"annotation-1": "val1"}), sparkApp.Spec.Driver.Annotations)
	assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{"label-1": "val1"}), sparkApp.Spec.Driver.Labels)
	assert.Equal(t, len(findEnvVarByName(sparkApp.Spec.Driver.Env, "FLYTE_MAX_ATTEMPTS").Value), 1)
	// assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], findEnvVarByName(sparkApp.Spec.Driver.Env, "foo").Value)
	// assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], findEnvVarByName(sparkApp.Spec.Driver.Env, "fooEnv").Value)
	assert.Equal(t, findEnvVarByName(dummyEnvVarsWithSecretRef, "SECRET"), findEnvVarByName(sparkApp.Spec.Driver.Env, "SECRET"))
	// assert.Equal(t, 9, len(sparkApp.Spec.Driver.Env))
	assert.Equal(t, testImage, *sparkApp.Spec.Driver.Image)
	assert.Equal(t, flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()), *sparkApp.Spec.Driver.ServiceAccount)
	// assert.Equal(t, defaultConfig.DefaultPodSecurityContext, sparkApp.Spec.Driver.SecurityContenxt)
	// assert.Equal(t, defaultConfig.DefaultPodDNSConfig, sparkApp.Spec.Driver.DNSConfig)
	// assert.Equal(t, defaultConfig.EnableHostNetworkingPod, sparkApp.Spec.Driver.HostNetwork)
	// assert.Equal(t, defaultConfig.SchedulerName, *sparkApp.Spec.Driver.SchedulerName)
	assert.Equal(t, []corev1.Toleration{
		defaultConfig.DefaultTolerations[0],
		driverExtraToleration,
	}, sparkApp.Spec.Driver.Tolerations)
	assert.Equal(t, map[string]string{
		"x/default": "true",
		"x/custom":  "foo-driver",
	}, sparkApp.Spec.Driver.NodeSelector)
	assert.Equal(t, &corev1.NodeAffinity{
		RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{
			NodeSelectorTerms: []corev1.NodeSelectorTerm{
				{
					MatchExpressions: []corev1.NodeSelectorRequirement{
						defaultConfig.DefaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0],
						*defaultConfig.NonInterruptibleNodeSelectorRequirement,
					},
				},
			},
		},
	}, sparkApp.Spec.Driver.Affinity.NodeAffinity)
	cores, _ := strconv.ParseInt(dummySparkConf["spark.driver.cores"], 10, 32)
	assert.Equal(t, intPtr(int32(cores)), sparkApp.Spec.Driver.Cores)
	assert.Equal(t, dummySparkConf["spark.driver.memory"], *sparkApp.Spec.Driver.Memory)

	// Executor
	assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultAnnotations, map[string]string{"annotation-1": "val1"}), sparkApp.Spec.Executor.Annotations)
	assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{"label-1": "val1"}), sparkApp.Spec.Executor.Labels)
	assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], findEnvVarByName(sparkApp.Spec.Executor.Env, "foo").Value)
	assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], findEnvVarByName(sparkApp.Spec.Executor.Env, "fooEnv").Value)
	assert.Equal(t, findEnvVarByName(dummyEnvVarsWithSecretRef, "SECRET"), findEnvVarByName(sparkApp.Spec.Executor.Env, "SECRET"))
	assert.Equal(t, 9, len(sparkApp.Spec.Executor.Env))
	assert.Equal(t, testImage, *sparkApp.Spec.Executor.Image)
	assert.Equal(t, defaultConfig.DefaultPodSecurityContext, sparkApp.Spec.Executor.SecurityContenxt)
	assert.Equal(t, defaultConfig.DefaultPodDNSConfig, sparkApp.Spec.Executor.DNSConfig)
	assert.Equal(t, defaultConfig.EnableHostNetworkingPod, sparkApp.Spec.Executor.HostNetwork)
	assert.Equal(t, defaultConfig.SchedulerName, *sparkApp.Spec.Executor.SchedulerName)
	// executor also carries the interruptible toleration because the task
	// context above was built with interruptible=true
	assert.ElementsMatch(t, []corev1.Toleration{
		defaultConfig.DefaultTolerations[0],
		executorExtraToleration,
		defaultConfig.InterruptibleTolerations[0],
	}, sparkApp.Spec.Executor.Tolerations)
	assert.Equal(t, map[string]string{
		"x/default":       "true",
		"x/custom":        "foo-executor",
		"x/interruptible": "true",
	}, sparkApp.Spec.Executor.NodeSelector)
	assert.Equal(t, &corev1.NodeAffinity{
		RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{
			NodeSelectorTerms: []corev1.NodeSelectorTerm{
				{
					MatchExpressions: []corev1.NodeSelectorRequirement{
						defaultConfig.DefaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0],
						*defaultConfig.InterruptibleNodeSelectorRequirement,
					},
				},
			},
		},
	}, sparkApp.Spec.Executor.Affinity.NodeAffinity)
	cores, _ = strconv.ParseInt(dummySparkConf["spark.executor.cores"], 10, 32)
	instances, _ := strconv.ParseInt(dummySparkConf["spark.executor.instances"], 10, 32)
	assert.Equal(t, intPtr(int32(instances)), sparkApp.Spec.Executor.Instances)
	assert.Equal(t, intPtr(int32(cores)), sparkApp.Spec.Executor.Cores)
	assert.Equal(t, dummySparkConf["spark.executor.memory"], *sparkApp.Spec.Executor.Memory)
}


// transformStructToStructPB round-trips an arbitrary Go value through JSON to
// produce its *structpb.Struct representation, failing the test (via assert)
// on any marshal/unmarshal error.
func transformStructToStructPB(t *testing.T, obj interface{}) *structpb.Struct {
	raw, err := json.Marshal(obj)
	assert.Nil(t, err)

	asMap := make(map[string]interface{})
	assert.Nil(t, json.Unmarshal(raw, &asMap))

	pb, err := structpb.NewStruct(asMap)
	assert.Nil(t, err)
	return pb
}

0 comments on commit 1faaa00

Please sign in to comment.