Skip to content

Commit

Permalink
Merge branch 'master' of github.com:flyteorg/flyte into failure_node
Browse files Browse the repository at this point in the history
  • Loading branch information
pingsutw committed Nov 30, 2023
2 parents d7dd52a + 7481b3d commit 4c09a50
Show file tree
Hide file tree
Showing 12 changed files with 545 additions and 376 deletions.

This file was deleted.

123 changes: 123 additions & 0 deletions flyteplugins/go/tasks/pluginmachinery/flytek8s/plugin_exec_context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package flytek8s

import (
v1 "k8s.io/api/core/v1"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core"
)

type pluginTaskOverrides struct {
pluginsCore.TaskOverrides
resources *v1.ResourceRequirements
extendedResources *core.ExtendedResources
}

func (to *pluginTaskOverrides) GetResources() *v1.ResourceRequirements {
if to.resources != nil {
return to.resources
}
return to.TaskOverrides.GetResources()
}

func (to *pluginTaskOverrides) GetExtendedResources() *core.ExtendedResources {
if to.extendedResources != nil {
return to.extendedResources
}
return to.TaskOverrides.GetExtendedResources()
}

type pluginTaskExecutionMetadata struct {
pluginsCore.TaskExecutionMetadata
interruptible *bool
overrides *pluginTaskOverrides
}

func (tm *pluginTaskExecutionMetadata) IsInterruptible() bool {
if tm.interruptible != nil {
return *tm.interruptible
}
return tm.TaskExecutionMetadata.IsInterruptible()
}

func (tm *pluginTaskExecutionMetadata) GetOverrides() pluginsCore.TaskOverrides {
if tm.overrides != nil {
return tm.overrides
}
return tm.TaskExecutionMetadata.GetOverrides()
}

type pluginTaskExecutionContext struct {
pluginsCore.TaskExecutionContext
metadata *pluginTaskExecutionMetadata
}

func (tc *pluginTaskExecutionContext) TaskExecutionMetadata() pluginsCore.TaskExecutionMetadata {
if tc.metadata != nil {
return tc.metadata
}
return tc.TaskExecutionContext.TaskExecutionMetadata()
}

type PluginTaskExecutionContextOption func(*pluginTaskExecutionContext)

func WithInterruptible(v bool) PluginTaskExecutionContextOption {
return func(tc *pluginTaskExecutionContext) {
if tc.metadata == nil {
tc.metadata = &pluginTaskExecutionMetadata{
TaskExecutionMetadata: tc.TaskExecutionContext.TaskExecutionMetadata(),
}
}
tc.metadata.interruptible = &v
}
}

func WithResources(r *v1.ResourceRequirements) PluginTaskExecutionContextOption {
return func(tc *pluginTaskExecutionContext) {
if tc.metadata == nil {
tc.metadata = &pluginTaskExecutionMetadata{
TaskExecutionMetadata: tc.TaskExecutionContext.TaskExecutionMetadata(),
}
}
if tc.metadata.overrides == nil {
tc.metadata.overrides = &pluginTaskOverrides{
TaskOverrides: tc.metadata.TaskExecutionMetadata.GetOverrides(),
}
}
tc.metadata.overrides.resources = r
}
}

func WithExtendedResources(er *core.ExtendedResources) PluginTaskExecutionContextOption {
return func(tc *pluginTaskExecutionContext) {
if tc.metadata == nil {
tc.metadata = &pluginTaskExecutionMetadata{
TaskExecutionMetadata: tc.TaskExecutionContext.TaskExecutionMetadata(),
}
}
if tc.metadata.overrides == nil {
tc.metadata.overrides = &pluginTaskOverrides{
TaskOverrides: tc.metadata.TaskExecutionMetadata.GetOverrides(),
}
}
tc.metadata.overrides.extendedResources = er
}
}

func NewPluginTaskExecutionContext(tc pluginsCore.TaskExecutionContext, options ...PluginTaskExecutionContextOption) pluginsCore.TaskExecutionContext {
tm := tc.TaskExecutionMetadata()
to := tm.GetOverrides()
ctx := &pluginTaskExecutionContext{
TaskExecutionContext: tc,
metadata: &pluginTaskExecutionMetadata{
TaskExecutionMetadata: tm,
overrides: &pluginTaskOverrides{
TaskOverrides: to,
},
},
}
for _, o := range options {
o(ctx)
}
return ctx
}
2 changes: 1 addition & 1 deletion flyteplugins/go/tasks/plugins/k8s/dask/dask.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func (p daskResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
if err != nil {
return nil, err
}
nonInterruptibleTaskCtx := flytek8s.NewNonInterruptibleTaskExecutionContext(taskCtx)
nonInterruptibleTaskCtx := flytek8s.NewPluginTaskExecutionContext(taskCtx, flytek8s.WithInterruptible(false))
nonInterruptiblePodSpec, _, _, err := flytek8s.ToK8sPodSpec(ctx, nonInterruptibleTaskCtx)
if err != nil {
return nil, err
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package common

import (
"context"
"fmt"
"sort"
"time"
Expand All @@ -15,8 +16,10 @@ import (
"github.com/flyteorg/flyte/flyteplugins/go/tasks/logs"
pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s/config"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils"
)

const (
Expand All @@ -25,12 +28,6 @@ const (
PytorchTaskType = "pytorch"
)

type ReplicaEntry struct {
PodSpec *v1.PodSpec
ReplicaNum int32
RestartPolicy commonOp.RestartPolicy
}

// ExtractCurrentCondition will return the first job condition for tensorflow/pytorch
func ExtractCurrentCondition(jobConditions []commonOp.JobCondition) (commonOp.JobCondition, error) {
if jobConditions != nil {
Expand Down Expand Up @@ -254,27 +251,97 @@ func ParseRestartPolicy(flyteRestartPolicy kfplugins.RestartPolicy) commonOp.Res
}

// OverrideContainerSpec overrides the specified container's properties in the given podSpec. The function
// updates the image, resources and command arguments of the container that matches the given containerName.
func OverrideContainerSpec(podSpec *v1.PodSpec, containerName string, image string, resources *core.Resources, args []string) error {
// updates the image and command arguments of the container that matches the given containerName.
func OverrideContainerSpec(podSpec *v1.PodSpec, containerName string, image string, args []string) error {
for idx, c := range podSpec.Containers {
if c.Name == containerName {
if image != "" {
podSpec.Containers[idx].Image = image
}
if resources != nil {
// if resources requests and limits both not set, we will not override the resources
if len(resources.Requests) >= 1 || len(resources.Limits) >= 1 {
resources, err := flytek8s.ToK8sResourceRequirements(resources)
if err != nil {
return flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification on Resources [%v], Err: [%v]", resources, err.Error())
}
podSpec.Containers[idx].Resources = *resources
}
}
if len(args) != 0 {
podSpec.Containers[idx].Args = args
}
}
}
return nil
}

func ToReplicaSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, primaryContainerName string) (*commonOp.ReplicaSpec, error) {
podSpec, objectMeta, oldPrimaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx)
if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error())
}

OverridePrimaryContainerName(podSpec, oldPrimaryContainerName, primaryContainerName)

cfg := config.GetK8sPluginConfig()
objectMeta.Annotations = utils.UnionMaps(cfg.DefaultAnnotations, objectMeta.Annotations, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations()))
objectMeta.Labels = utils.UnionMaps(cfg.DefaultLabels, objectMeta.Labels, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()))

replicas := int32(0)
return &commonOp.ReplicaSpec{
Replicas: &replicas,
Template: v1.PodTemplateSpec{
ObjectMeta: *objectMeta,
Spec: *podSpec,
},
RestartPolicy: commonOp.RestartPolicyNever,
}, nil
}

type kfDistributedReplicaSpec interface {
GetReplicas() int32
GetImage() string
GetResources() *core.Resources
GetRestartPolicy() kfplugins.RestartPolicy
}

type allowsCommandOverride interface {
GetCommand() []string
}

func ToReplicaSpecWithOverrides(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, rs kfDistributedReplicaSpec, primaryContainerName string, isMaster bool) (*commonOp.ReplicaSpec, error) {
taskCtxOptions := []flytek8s.PluginTaskExecutionContextOption{}
if rs != nil && rs.GetResources() != nil {
resources, err := flytek8s.ToK8sResourceRequirements(rs.GetResources())
if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification on Resources [%v], Err: [%v]", resources, err.Error())
}
taskCtxOptions = append(taskCtxOptions, flytek8s.WithResources(resources))
}
newTaskCtx := flytek8s.NewPluginTaskExecutionContext(taskCtx, taskCtxOptions...)
replicaSpec, err := ToReplicaSpec(ctx, newTaskCtx, primaryContainerName)
if err != nil {
return nil, err
}

// Master should have a single replica
if isMaster {
replicas := int32(1)
replicaSpec.Replicas = &replicas
}

if rs != nil {
var command []string
if v, ok := rs.(allowsCommandOverride); ok {
command = v.GetCommand()
}
if err := OverrideContainerSpec(
&replicaSpec.Template.Spec,
primaryContainerName,
rs.GetImage(),
command,
); err != nil {
return nil, err
}

replicaSpec.RestartPolicy = ParseRestartPolicy(rs.GetRestartPolicy())

if !isMaster {
replicas := rs.GetReplicas()
replicaSpec.Replicas = &replicas
}
}

return replicaSpec, nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,9 @@ func dummyPodSpec() v1.PodSpec {
return v1.PodSpec{
Containers: []v1.Container{
{
Name: "primary container",
Args: []string{"pyflyte-execute", "--task-module", "tests.flytekit.unit.sdk.tasks.test_sidecar_tasks", "--task-name", "simple_sidecar_task", "--inputs", "{{.input}}", "--output-prefix", "{{.outputPrefix}}"},
Name: "primary container",
Args: []string{"pyflyte-execute", "--task-module", "tests.flytekit.unit.sdk.tasks.test_sidecar_tasks", "--task-name", "simple_sidecar_task", "--inputs", "{{.input}}", "--output-prefix", "{{.outputPrefix}}"},
Image: "dummy-image",
Resources: v1.ResourceRequirements{
Limits: v1.ResourceList{
"cpu": resource.MustParse("2"),
Expand Down Expand Up @@ -270,50 +271,21 @@ func TestOverrideContainerSpec(t *testing.T) {
podSpec := dummyPodSpec()
err := OverrideContainerSpec(
&podSpec, "primary container", "testing-image",
&core.Resources{
Requests: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "250m"},
},
Limits: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "500m"},
},
},
[]string{"python", "-m", "run.py"},
)
assert.NoError(t, err)
assert.Equal(t, 2, len(podSpec.Containers))
assert.Equal(t, "testing-image", podSpec.Containers[0].Image)
assert.NotNil(t, podSpec.Containers[0].Resources.Limits)
assert.NotNil(t, podSpec.Containers[0].Resources.Requests)
// verify resources not overridden if empty resources
assert.True(t, podSpec.Containers[0].Resources.Requests.Cpu().Equal(resource.MustParse("250m")))
assert.True(t, podSpec.Containers[0].Resources.Limits.Cpu().Equal(resource.MustParse("500m")))
assert.Equal(t, []string{"python", "-m", "run.py"}, podSpec.Containers[0].Args)
}

func TestOverrideContainerSpecEmptyFields(t *testing.T) {
podSpec := dummyPodSpec()
err := OverrideContainerSpec(&podSpec, "primary container", "", &core.Resources{}, []string{})
assert.NoError(t, err)
assert.Equal(t, 2, len(podSpec.Containers))
assert.NotNil(t, podSpec.Containers[0].Resources.Limits)
assert.NotNil(t, podSpec.Containers[0].Resources.Requests)
// verify resources not overridden if empty resources
assert.True(t, podSpec.Containers[0].Resources.Requests.Cpu().Equal(resource.MustParse("1")))
assert.True(t, podSpec.Containers[0].Resources.Requests.Memory().Equal(resource.MustParse("100Mi")))
assert.True(t, podSpec.Containers[0].Resources.Limits.Cpu().Equal(resource.MustParse("2")))
assert.True(t, podSpec.Containers[0].Resources.Limits.Memory().Equal(resource.MustParse("200Mi")))
}

func TestOverrideContainerNilResources(t *testing.T) {
podSpec := dummyPodSpec()
podSpecCopy := podSpec.DeepCopy()

err := OverrideContainerSpec(&podSpec, "primary container", "", nil, []string{})
err := OverrideContainerSpec(&podSpec, "primary container", "", []string{})
assert.NoError(t, err)
assert.Equal(t, 2, len(podSpec.Containers))
assert.Equal(t, podSpec.Containers[0].Resources.Limits, podSpecCopy.Containers[0].Resources.Limits)
assert.Equal(t, podSpec.Containers[0].Resources.Requests, podSpecCopy.Containers[0].Resources.Requests)
assert.Equal(t, "dummy-image", podSpec.Containers[0].Image)
assert.Equal(t, []string{"pyflyte-execute", "--task-module", "tests.flytekit.unit.sdk.tasks.test_sidecar_tasks", "--task-name", "simple_sidecar_task", "--inputs", "{{.input}}", "--output-prefix", "{{.outputPrefix}}"}, podSpec.Containers[0].Args)
}

func dummyTaskContext() pluginsCore.TaskExecutionContext {
Expand Down
Loading

0 comments on commit 4c09a50

Please sign in to comment.