Skip to content

Commit

Permalink
Build SparkApplicationSpec using ToK8sPodSpec
Browse files Browse the repository at this point in the history
Signed-off-by: Andrew Dye <[email protected]>
  • Loading branch information
andrewwdye committed Oct 10, 2023
1 parent 4a780c3 commit b326791
Show file tree
Hide file tree
Showing 5 changed files with 384 additions and 193 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package flytek8s

import (
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core"
)

// Wraps a regular TaskExecutionMetadata and overrides the IsInterruptible method to always return false
// This is useful as the runner and the scheduler pods should never be interruptible
type NonInterruptibleTaskExecutionMetadata struct {
metadata pluginsCore.TaskExecutionMetadata
}

func (n NonInterruptibleTaskExecutionMetadata) GetOwnerID() types.NamespacedName {
return n.metadata.GetOwnerID()
}

func (n NonInterruptibleTaskExecutionMetadata) GetTaskExecutionID() pluginsCore.TaskExecutionID {
return n.metadata.GetTaskExecutionID()
}

func (n NonInterruptibleTaskExecutionMetadata) GetNamespace() string {
return n.metadata.GetNamespace()
}

func (n NonInterruptibleTaskExecutionMetadata) GetOwnerReference() metav1.OwnerReference {
return n.metadata.GetOwnerReference()
}

func (n NonInterruptibleTaskExecutionMetadata) GetOverrides() pluginsCore.TaskOverrides {
return n.metadata.GetOverrides()
}

func (n NonInterruptibleTaskExecutionMetadata) GetLabels() map[string]string {
return n.metadata.GetLabels()
}

func (n NonInterruptibleTaskExecutionMetadata) GetMaxAttempts() uint32 {
return n.metadata.GetMaxAttempts()
}

func (n NonInterruptibleTaskExecutionMetadata) GetAnnotations() map[string]string {
return n.metadata.GetAnnotations()
}

func (n NonInterruptibleTaskExecutionMetadata) GetK8sServiceAccount() string {
return n.metadata.GetK8sServiceAccount()
}

func (n NonInterruptibleTaskExecutionMetadata) GetSecurityContext() core.SecurityContext {
return n.metadata.GetSecurityContext()
}

func (n NonInterruptibleTaskExecutionMetadata) GetPlatformResources() *v1.ResourceRequirements {
return n.metadata.GetPlatformResources()
}

func (n NonInterruptibleTaskExecutionMetadata) GetInterruptibleFailureThreshold() int32 {
return n.metadata.GetInterruptibleFailureThreshold()
}

func (n NonInterruptibleTaskExecutionMetadata) GetEnvironmentVariables() map[string]string {
return n.metadata.GetEnvironmentVariables()
}

func (n NonInterruptibleTaskExecutionMetadata) IsInterruptible() bool {
return false
}

// A wrapper around a regular TaskExecutionContext allowing to inject a custom TaskExecutionMetadata which is
// non-interruptible
type NonInterruptibleTaskExecutionContext struct {
pluginsCore.TaskExecutionContext
metadata NonInterruptibleTaskExecutionMetadata
}

func (n NonInterruptibleTaskExecutionContext) TaskExecutionMetadata() pluginsCore.TaskExecutionMetadata {
return n.metadata
}

func NewNonInterruptibleTaskExecutionContext(ctx pluginsCore.TaskExecutionContext) NonInterruptibleTaskExecutionContext {
return NonInterruptibleTaskExecutionContext{
TaskExecutionContext: ctx,
metadata: NonInterruptibleTaskExecutionMetadata{
metadata: ctx.TaskExecutionMetadata(),
},
}
}
9 changes: 9 additions & 0 deletions flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,15 @@ func ToK8sPodSpec(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (*
return podSpec, objectMeta, primaryContainerName, nil
}

func GetContainer(podSpec *v1.PodSpec, name string) (*v1.Container, error) {
for _, container := range podSpec.Containers {
if container.Name == name {
return &container, nil
}
}
return nil, pluginserrors.Errorf(pluginserrors.BadTaskSpecification, "invalid TaskSpecification, container [%s] not defined", name)
}

// getBasePodTemplate attempts to retrieve the PodTemplate to use as the base for k8s Pod configuration. This value can
// come from one of the following:
// (1) PodTemplate name in the TaskMetadata: This name is then looked up in the PodTemplateStore.
Expand Down
50 changes: 10 additions & 40 deletions flyteplugins/go/tasks/plugins/k8s/dask/dask.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ import (
"time"

daskAPI "github.com/dask/dask-kubernetes/v2023/dask_kubernetes/operator/go_client/pkg/apis/kubernetes.dask.org/v1"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/intstr"
"k8s.io/client-go/kubernetes/scheme"
"sigs.k8s.io/controller-runtime/pkg/client"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/errors"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/logs"
Expand All @@ -15,54 +21,19 @@ import (
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/intstr"
"k8s.io/client-go/kubernetes/scheme"
"sigs.k8s.io/controller-runtime/pkg/client"
)

const (
daskTaskType = "dask"
KindDaskJob = "DaskJob"
)

// Wraps a regular TaskExecutionMetadata and overrides the IsInterruptible method to always return false
// This is useful as the runner and the scheduler pods should never be interruptible
type nonInterruptibleTaskExecutionMetadata struct {
pluginsCore.TaskExecutionMetadata
}

func (n nonInterruptibleTaskExecutionMetadata) IsInterruptible() bool {
return false
}

// A wrapper around a regular TaskExecutionContext allowing to inject a custom TaskExecutionMetadata which is
// non-interruptible
type nonInterruptibleTaskExecutionContext struct {
pluginsCore.TaskExecutionContext
metadata nonInterruptibleTaskExecutionMetadata
}

func (n nonInterruptibleTaskExecutionContext) TaskExecutionMetadata() pluginsCore.TaskExecutionMetadata {
return n.metadata
}

func mergeMapInto(src map[string]string, dst map[string]string) {
for key, value := range src {
dst[key] = value
}
}

func getPrimaryContainer(spec *v1.PodSpec, primaryContainerName string) (*v1.Container, error) {
for _, container := range spec.Containers {
if container.Name == primaryContainerName {
return &container, nil
}
}
return nil, errors.Errorf(errors.BadTaskSpecification, "primary container [%v] not found in pod spec", primaryContainerName)
}

func replacePrimaryContainer(spec *v1.PodSpec, primaryContainerName string, container v1.Container) error {
for i, c := range spec.Containers {
if c.Name == primaryContainerName {
Expand Down Expand Up @@ -104,8 +75,7 @@ func (p daskResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
if err != nil {
return nil, err
}
nonInterruptibleTaskMetadata := nonInterruptibleTaskExecutionMetadata{taskCtx.TaskExecutionMetadata()}
nonInterruptibleTaskCtx := nonInterruptibleTaskExecutionContext{taskCtx, nonInterruptibleTaskMetadata}
nonInterruptibleTaskCtx := flytek8s.NewNonInterruptibleTaskExecutionContext(taskCtx)
nonInterruptiblePodSpec, _, _, err := flytek8s.ToK8sPodSpec(ctx, nonInterruptibleTaskCtx)
if err != nil {
return nil, err
Expand Down Expand Up @@ -144,7 +114,7 @@ func (p daskResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC

func createWorkerSpec(cluster plugins.DaskWorkerGroup, podSpec *v1.PodSpec, primaryContainerName string) (*daskAPI.WorkerSpec, error) {
workerPodSpec := podSpec.DeepCopy()
primaryContainer, err := getPrimaryContainer(workerPodSpec, primaryContainerName)
primaryContainer, err := flytek8s.GetContainer(workerPodSpec, primaryContainerName)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -206,7 +176,7 @@ func createWorkerSpec(cluster plugins.DaskWorkerGroup, podSpec *v1.PodSpec, prim

func createSchedulerSpec(scheduler plugins.DaskScheduler, clusterName string, podSpec *v1.PodSpec, primaryContainerName string) (*daskAPI.SchedulerSpec, error) {
schedulerPodSpec := podSpec.DeepCopy()
primaryContainer, err := getPrimaryContainer(schedulerPodSpec, primaryContainerName)
primaryContainer, err := flytek8s.GetContainer(schedulerPodSpec, primaryContainerName)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -283,7 +253,7 @@ func createJobSpec(workerSpec daskAPI.WorkerSpec, schedulerSpec daskAPI.Schedule
jobPodSpec := podSpec.DeepCopy()
jobPodSpec.RestartPolicy = v1.RestartPolicyNever

primaryContainer, err := getPrimaryContainer(jobPodSpec, primaryContainerName)
primaryContainer, err := flytek8s.GetContainer(jobPodSpec, primaryContainerName)
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit b326791

Please sign in to comment.