Skip to content

Commit

Permalink
Comments
Browse files Browse the repository at this point in the history
Signed-off-by: Andrew Dye <[email protected]>
  • Loading branch information
andrewwdye committed Oct 10, 2023
1 parent b326791 commit 072cda8
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 67 deletions.
Original file line number Diff line number Diff line change
@@ -1,70 +1,13 @@
package flytek8s

import (
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core"
)

// Wraps a regular TaskExecutionMetadata and overrides the IsInterruptible method to always return false
// This is useful as the runner and the scheduler pods should never be interruptible
type NonInterruptibleTaskExecutionMetadata struct {
metadata pluginsCore.TaskExecutionMetadata
}

func (n NonInterruptibleTaskExecutionMetadata) GetOwnerID() types.NamespacedName {
return n.metadata.GetOwnerID()
}

func (n NonInterruptibleTaskExecutionMetadata) GetTaskExecutionID() pluginsCore.TaskExecutionID {
return n.metadata.GetTaskExecutionID()
}

func (n NonInterruptibleTaskExecutionMetadata) GetNamespace() string {
return n.metadata.GetNamespace()
}

func (n NonInterruptibleTaskExecutionMetadata) GetOwnerReference() metav1.OwnerReference {
return n.metadata.GetOwnerReference()
}

func (n NonInterruptibleTaskExecutionMetadata) GetOverrides() pluginsCore.TaskOverrides {
return n.metadata.GetOverrides()
}

func (n NonInterruptibleTaskExecutionMetadata) GetLabels() map[string]string {
return n.metadata.GetLabels()
}

func (n NonInterruptibleTaskExecutionMetadata) GetMaxAttempts() uint32 {
return n.metadata.GetMaxAttempts()
}

func (n NonInterruptibleTaskExecutionMetadata) GetAnnotations() map[string]string {
return n.metadata.GetAnnotations()
}

func (n NonInterruptibleTaskExecutionMetadata) GetK8sServiceAccount() string {
return n.metadata.GetK8sServiceAccount()
}

func (n NonInterruptibleTaskExecutionMetadata) GetSecurityContext() core.SecurityContext {
return n.metadata.GetSecurityContext()
}

func (n NonInterruptibleTaskExecutionMetadata) GetPlatformResources() *v1.ResourceRequirements {
return n.metadata.GetPlatformResources()
}

func (n NonInterruptibleTaskExecutionMetadata) GetInterruptibleFailureThreshold() int32 {
return n.metadata.GetInterruptibleFailureThreshold()
}

func (n NonInterruptibleTaskExecutionMetadata) GetEnvironmentVariables() map[string]string {
return n.metadata.GetEnvironmentVariables()
pluginsCore.TaskExecutionMetadata
}

func (n NonInterruptibleTaskExecutionMetadata) IsInterruptible() bool {
Expand All @@ -86,7 +29,7 @@ func NewNonInterruptibleTaskExecutionContext(ctx pluginsCore.TaskExecutionContex
return NonInterruptibleTaskExecutionContext{
TaskExecutionContext: ctx,
metadata: NonInterruptibleTaskExecutionMetadata{
metadata: ctx.TaskExecutionMetadata(),
ctx.TaskExecutionMetadata(),
},
}
}
10 changes: 2 additions & 8 deletions flyteplugins/go/tasks/plugins/k8s/spark/spark.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,16 +161,14 @@ func createSparkPodSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionCo
SecurityContenxt: podSpec.SecurityContext.DeepCopy(),
DNSConfig: podSpec.DNSConfig.DeepCopy(),
Tolerations: podSpec.Tolerations,
SchedulerName: &config.GetK8sPluginConfig().SchedulerName,
SchedulerName: &podSpec.SchedulerName,
NodeSelector: podSpec.NodeSelector,
HostNetwork: config.GetK8sPluginConfig().EnableHostNetworkingPod,
HostNetwork: &podSpec.HostNetwork,
}
return &spec, nil
}

type driverSpec struct {
podSpec *v1.PodSpec
container *v1.Container
sparkSpec *sparkOp.DriverSpec
}

Expand All @@ -191,8 +189,6 @@ func createDriverSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionCont
}
serviceAccountName := serviceAccountName(nonInterruptibleTaskCtx.TaskExecutionMetadata())
spec := driverSpec{
podSpec,
primaryContainer,
&sparkOp.DriverSpec{
SparkPodSpec: *sparkPodSpec,
ServiceAccount: &serviceAccountName,
Expand All @@ -206,7 +202,6 @@ func createDriverSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionCont
}

type executorSpec struct {
podSpec *v1.PodSpec
container *v1.Container
sparkSpec *sparkOp.ExecutorSpec
serviceAccountName string
Expand All @@ -227,7 +222,6 @@ func createExecutorSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionCo
}
serviceAccountName := serviceAccountName(taskCtx.TaskExecutionMetadata())
spec := executorSpec{
podSpec,
primaryContainer,
&sparkOp.ExecutorSpec{
SparkPodSpec: *sparkPodSpec,
Expand Down

0 comments on commit 072cda8

Please sign in to comment.