Skip to content

Commit

Permalink
Add more test coverage to TestBuildResourcePodTemplate
Browse the repository at this point in the history
Signed-off-by: Andrew Dye <[email protected]>
  • Loading branch information
andrewwdye committed Oct 10, 2023
1 parent b9f01c0 commit de2e42a
Showing 1 changed file with 98 additions and 23 deletions.
121 changes: 98 additions & 23 deletions flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"testing"

sj "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2"
sparkOp "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2"
"github.com/golang/protobuf/jsonpb"
structpb "github.com/golang/protobuf/ptypes/struct"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -743,45 +744,119 @@ func TestBuildResourceContainer(t *testing.T) {
}

func TestBuildResourcePodTemplate(t *testing.T) {
defaultToleration := corev1.Toleration{

Key: "x/flyte",
Value: "default",
Operator: "Equal",
}
err := config.SetK8sPluginConfig(&config.K8sPluginConfig{
DefaultTolerations: []corev1.Toleration{
defaultToleration,
},
})
assert.NoError(t, err)
defaultConfig := defaultPluginConfig()
assert.NoError(t, config.SetK8sPluginConfig(defaultConfig))
extraToleration := corev1.Toleration{
Key: "x/flyte",
Value: "extra",
Operator: "Equal",
}
podSpec := dummyPodSpec()
podSpec.Tolerations = append(podSpec.Tolerations, extraToleration)
podSpec.NodeSelector["x/custom"] = "foo"
podSpec.NodeSelector = map[string]string{"x/custom": "foo"}
taskTemplate := dummySparkTaskTemplatePod("blah-1", dummySparkConf, podSpec)
taskTemplate.GetK8SPod()
sparkResourceHandler := sparkResourceHandler{}
resource, err := sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, false))
assert.Nil(t, err)

taskCtx := dummySparkTaskContext(taskTemplate, true)
resource, err := sparkResourceHandler.BuildResource(context.TODO(), taskCtx)

assert.Nil(t, err)
assert.NotNil(t, resource)
sparkApp, ok := resource.(*sj.SparkApplication)
assert.True(t, ok)
assert.Equal(t, 2, len(sparkApp.Spec.Driver.Tolerations))
assert.Equal(t, sparkApp.Spec.Driver.Tolerations, []corev1.Toleration{
defaultToleration,

// Application
assert.Equal(t, v1.TypeMeta{
Kind: KindSparkApplication,
APIVersion: sparkOp.SchemeGroupVersion.String(),
}, sparkApp.TypeMeta)

// Application spec
assert.Equal(t, flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()), *sparkApp.Spec.ServiceAccount)
assert.Equal(t, sparkOp.PythonApplicationType, sparkApp.Spec.Type)
assert.Equal(t, testImage, *sparkApp.Spec.Image)
assert.Equal(t, testArgs, sparkApp.Spec.Arguments)
assert.Equal(t, sparkOp.RestartPolicy{
Type: sparkOp.OnFailure,
OnSubmissionFailureRetries: intPtr(int32(14)),
}, sparkApp.Spec.RestartPolicy)
assert.Equal(t, sparkMainClass, *sparkApp.Spec.MainClass)
assert.Equal(t, sparkApplicationFile, *sparkApp.Spec.MainApplicationFile)

// Driver
assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultAnnotations, map[string]string{"annotation-1": "val1"}), sparkApp.Spec.Driver.Annotations)
assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{"label-1": "val1"}), sparkApp.Spec.Driver.Labels)
assert.Equal(t, len(sparkApp.Spec.Driver.EnvVars["FLYTE_MAX_ATTEMPTS"]), 1)
assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], sparkApp.Spec.Driver.EnvVars["foo"])
assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], sparkApp.Spec.Driver.EnvVars["fooEnv"])
assert.Equal(t, testImage, *sparkApp.Spec.Driver.Image)
assert.Equal(t, flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()), *sparkApp.Spec.Driver.ServiceAccount)
assert.Equal(t, defaultConfig.DefaultPodSecurityContext, sparkApp.Spec.Driver.SecurityContenxt)
assert.Equal(t, defaultConfig.DefaultPodDNSConfig, sparkApp.Spec.Driver.DNSConfig)
assert.Equal(t, defaultConfig.EnableHostNetworkingPod, sparkApp.Spec.Driver.HostNetwork)
assert.Equal(t, defaultConfig.SchedulerName, *sparkApp.Spec.Driver.SchedulerName)
assert.Equal(t, []corev1.Toleration{
defaultConfig.DefaultTolerations[0],
extraToleration,
})
assert.Equal(t, 2, len(sparkApp.Spec.Executor.Tolerations))
assert.Equal(t, sparkApp.Spec.Executor.Tolerations, []corev1.Toleration{
defaultToleration,
}, sparkApp.Spec.Driver.Tolerations)
assert.Equal(t, map[string]string{
"x/default": "true",
"x/custom": "foo",
}, sparkApp.Spec.Driver.NodeSelector)
assert.Equal(t, &corev1.NodeAffinity{
RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{
NodeSelectorTerms: []corev1.NodeSelectorTerm{
{
MatchExpressions: []corev1.NodeSelectorRequirement{
defaultConfig.DefaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0],
*defaultConfig.NonInterruptibleNodeSelectorRequirement,
},
},
},
},
}, sparkApp.Spec.Driver.Affinity.NodeAffinity)
cores, _ := strconv.ParseInt(dummySparkConf["spark.driver.cores"], 10, 32)
assert.Equal(t, intPtr(int32(cores)), sparkApp.Spec.Driver.Cores)
assert.Equal(t, dummySparkConf["spark.driver.memory"], *sparkApp.Spec.Driver.Memory)

// Executor
assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultAnnotations, map[string]string{"annotation-1": "val1"}), sparkApp.Spec.Executor.Annotations)
assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{"label-1": "val1"}), sparkApp.Spec.Executor.Labels)
assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], sparkApp.Spec.Executor.EnvVars["foo"])
assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], sparkApp.Spec.Executor.EnvVars["fooEnv"])
assert.Equal(t, testImage, *sparkApp.Spec.Executor.Image)
assert.Equal(t, defaultConfig.DefaultPodSecurityContext, sparkApp.Spec.Executor.SecurityContenxt)
assert.Equal(t, defaultConfig.DefaultPodDNSConfig, sparkApp.Spec.Executor.DNSConfig)
assert.Equal(t, defaultConfig.EnableHostNetworkingPod, sparkApp.Spec.Executor.HostNetwork)
assert.Equal(t, defaultConfig.SchedulerName, *sparkApp.Spec.Executor.SchedulerName)
assert.ElementsMatch(t, []corev1.Toleration{
defaultConfig.DefaultTolerations[0],
extraToleration,
})
defaultConfig.InterruptibleTolerations[0],
}, sparkApp.Spec.Executor.Tolerations)
assert.Equal(t, map[string]string{
"x/default": "true",
"x/custom": "foo",
"x/interruptible": "true",
}, sparkApp.Spec.Executor.NodeSelector)
assert.Equal(t, &corev1.NodeAffinity{
RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{
NodeSelectorTerms: []corev1.NodeSelectorTerm{
{
MatchExpressions: []corev1.NodeSelectorRequirement{
defaultConfig.DefaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0],
*defaultConfig.InterruptibleNodeSelectorRequirement,
},
},
},
},
}, sparkApp.Spec.Executor.Affinity.NodeAffinity)
cores, _ = strconv.ParseInt(dummySparkConf["spark.executor.cores"], 10, 32)
instances, _ := strconv.ParseInt(dummySparkConf["spark.executor.instances"], 10, 32)
assert.Equal(t, intPtr(int32(instances)), sparkApp.Spec.Executor.Instances)
assert.Equal(t, intPtr(int32(cores)), sparkApp.Spec.Executor.Cores)
assert.Equal(t, dummySparkConf["spark.executor.memory"], *sparkApp.Spec.Executor.Memory)
}

func TestGetPropertiesSpark(t *testing.T) {
Expand Down

0 comments on commit de2e42a

Please sign in to comment.