diff --git a/flyteidl/gen/pb-go/flyteidl/plugins/spark.pb.go b/flyteidl/gen/pb-go/flyteidl/plugins/spark.pb.go index 47063ecb9f..5180512190 100644 --- a/flyteidl/gen/pb-go/flyteidl/plugins/spark.pb.go +++ b/flyteidl/gen/pb-go/flyteidl/plugins/spark.pb.go @@ -134,8 +134,8 @@ type SparkJob struct { // This instance name can be set in either flytepropeller or flytekit. DatabricksInstance string `protobuf:"bytes,9,opt,name=databricksInstance,proto3" json:"databricksInstance,omitempty"` - DriverPod *core.K8SPod `protobuf:"bytes,2,opt,name=driverPod,json=driverPod,proto3" json:"driverPod,omitempty"` - ExecutorPod *core.K8SPod `protobuf:"bytes,2,opt,name=executorPod,json=executorPod,proto3" json:"executorPod,omitempty"` + DriverPod *core.K8SPod `protobuf:"bytes,10,opt,name=driverPod,json=driverPod,proto3" json:"driverPod,omitempty"` + ExecutorPod *core.K8SPod `protobuf:"bytes,11,opt,name=executorPod,json=executorPod,proto3" json:"executorPod,omitempty"` } func (x *SparkJob) Reset() { diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go index 58f12382a3..80172d9bf2 100644 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go @@ -141,10 +141,10 @@ func serviceAccountName(metadata pluginsCore.TaskExecutionMetadata) string { return name } -func createSparkPodSpec(taskCtx pluginsCore.TaskExecutionContext, podSpec *v1.PodSpec, container *v1.Container, podAnnotations map[string]string, podLabels map[string]string) *sparkOp.SparkPodSpec { +func createSparkPodSpec(taskCtx pluginsCore.TaskExecutionContext, podSpec *v1.PodSpec, container *v1.Container, k8sPod core.K8SPod) *sparkOp.SparkPodSpec { // TODO: check whether merge annotations/labels together or other ways? - annotations := utils.UnionMaps(config.GetK8sPluginConfig().DefaultAnnotations, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations()), podAnnotations) - labels := utils.UnionMaps(config.GetK8sPluginConfig().DefaultLabels, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()), podLabels) + annotations := utils.UnionMaps(config.GetK8sPluginConfig().DefaultAnnotations, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations()), k8sPod.Metadata.Annotations) + labels := utils.UnionMaps(config.GetK8sPluginConfig().DefaultLabels, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()), k8sPod.Metadata.Labels) sparkEnv := make([]v1.EnvVar, 0) for _, envVar := range container.Env { @@ -183,22 +183,18 @@ func createDriverSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionCont // TODO: Validate whether the following function is correct // If DriverPod exist in sparkJob and is primary, use it instead - var podAnnotations map[string]string - var podLabels map[string]string if sparkJob.DriverPod != nil { podSpec, err = unmarshalK8sPod(podSpec, sparkJob.DriverPod, primaryContainerName) if err != nil { return nil, err } - podAnnotations = sparkJob.DriverPod.Metadata.Annotations - podLabels = sparkJob.DriverPod.Metadata.Labels } primaryContainer, err := flytek8s.GetContainer(podSpec, primaryContainerName) if err != nil { return nil, err } - sparkPodSpec := createSparkPodSpec(nonInterruptibleTaskCtx, podSpec, primaryContainer, podAnnotations, podLabels) + sparkPodSpec := createSparkPodSpec(nonInterruptibleTaskCtx, podSpec, primaryContainer, *sparkJob.DriverPod) serviceAccountName := serviceAccountName(nonInterruptibleTaskCtx.TaskExecutionMetadata()) spec := driverSpec{ &sparkOp.DriverSpec{ @@ -260,22 +256,18 @@ func createExecutorSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionCo // TODO: Validate whether the following function is correct // If DriverPod exist in sparkJob and is primary, use it instead - var podAnnotations map[string]string - var podLabels map[string]string if sparkJob.ExecutorPod != nil { podSpec, err = unmarshalK8sPod(podSpec, sparkJob.ExecutorPod, primaryContainerName) if err != nil { return nil, err } - podAnnotations = sparkJob.ExecutorPod.Metadata.Annotations - podLabels = sparkJob.ExecutorPod.Metadata.Labels } primaryContainer, err := flytek8s.GetContainer(podSpec, primaryContainerName) if err != nil { return nil, err } - sparkPodSpec := createSparkPodSpec(taskCtx, podSpec, primaryContainer, podAnnotations, podLabels) + sparkPodSpec := createSparkPodSpec(taskCtx, podSpec, primaryContainer, *sparkJob.ExecutorPod) serviceAccountName := serviceAccountName(taskCtx.TaskExecutionMetadata()) spec := executorSpec{ primaryContainer,