Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Migrate ArrayJob to use TaskTemplate Config and deprecate ArrayJob proto #229

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
31 changes: 10 additions & 21 deletions go/tasks/plugins/array/array_tests_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@ import (
"github.com/flyteorg/flyteplugins/tests"

idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"
"github.com/flyteorg/flytestdlib/utils"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"

"context"
Expand Down Expand Up @@ -49,13 +45,11 @@ func RunArrayTestsEndToEnd(t *testing.T, executor core.Plugin, iter AdvanceItera
}

var err error
template.Custom, err = utils.MarshalPbToStruct(&plugins.ArrayJob{
Parallelism: 10,
Size: 1,
SuccessCriteria: &plugins.ArrayJob_MinSuccesses{
MinSuccesses: 1,
},
})
template.Config = map[string]string{
"Parallelism": "10",
"Size": "1",
"MinSuccesses": "1",
}

assert.NoError(t, err)

Expand Down Expand Up @@ -83,16 +77,11 @@ func RunArrayTestsEndToEnd(t *testing.T, executor core.Plugin, iter AdvanceItera
},
}

var err error
template.Custom, err = utils.MarshalPbToStruct(&plugins.ArrayJob{
Parallelism: 10,
Size: 2,
SuccessCriteria: &plugins.ArrayJob_MinSuccesses{
MinSuccesses: 1,
},
})

assert.NoError(t, err)
template.Config = map[string]string{
"Parallelism": "10",
"Size": "2",
"MinSuccesses": "1",
}

expectedOutputs := coreutils.MustMakeLiteral(map[string]interface{}{
"x": []interface{}{5, 5},
Expand Down
4 changes: 2 additions & 2 deletions go/tasks/plugins/array/awsbatch/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,10 @@ func init() {
pluginmachinery.PluginRegistry().RegisterCorePlugin(
core.PluginEntry{
ID: executorName,
RegisteredTaskTypes: []core.TaskType{arrayTaskType, array.AwsBatchTaskType},
RegisteredTaskTypes: []core.TaskType{arrayTaskType, arrayCore.AwsBatchTaskType},
LoadPlugin: createNewExecutorPlugin,
IsDefault: false,
DefaultForTaskTypes: []core.TaskType{arrayTaskType, array.AwsBatchTaskType},
DefaultForTaskTypes: []core.TaskType{arrayTaskType, arrayCore.AwsBatchTaskType},
})
}

Expand Down
7 changes: 3 additions & 4 deletions go/tasks/plugins/array/awsbatch/transformer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,14 @@ import (

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils"
"github.com/flyteorg/flyteplugins/go/tasks/plugins/array"
"github.com/flyteorg/flyteplugins/go/tasks/plugins/array/awsbatch/config"

v12 "k8s.io/api/core/v1"
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"

"github.com/aws/aws-sdk-go/service/batch"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"
arrayCore "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -137,7 +136,7 @@ func TestArrayJobToBatchInput(t *testing.T) {
},
}

input := &plugins.ArrayJob{
input := &arrayCore.ArrayJob{
Size: 10,
Parallelism: 5,
}
Expand Down Expand Up @@ -207,7 +206,7 @@ func TestArrayJobToBatchInput(t *testing.T) {
assert.NotNil(t, batchInput)
assert.Equal(t, *expectedBatchInput, *batchInput)

taskTemplate.Type = array.AwsBatchTaskType
taskTemplate.Type = arrayCore.AwsBatchTaskType
tr.OnReadMatch(mock.Anything).Return(taskTemplate, nil)
taskCtx.OnTaskReader().Return(tr)

Expand Down
19 changes: 2 additions & 17 deletions go/tasks/plugins/array/catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ import (
"math"
"strconv"

idlPlugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"

arrayCore "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core"

"github.com/flyteorg/flytestdlib/bitarray"
Expand All @@ -23,8 +21,6 @@ import (
idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
)

const AwsBatchTaskType = "aws-batch"

// DetermineDiscoverability checks if there are any previously cached tasks. If there are we will only submit an
// ArrayJob for the non-cached tasks. The ArrayJob is now a different size, and each task will get a new index location
// which is different than their original location. To find the original index we construct an indexLookup array.
Expand All @@ -42,18 +38,7 @@ func DetermineDiscoverability(ctx context.Context, tCtx core.TaskExecutionContex
}

// Extract the array job settings from the task template
var arrayJob *idlPlugins.ArrayJob
if taskTemplate.Type == AwsBatchTaskType {
arrayJob = &idlPlugins.ArrayJob{
Parallelism: 1,
Size: 1,
SuccessCriteria: &idlPlugins.ArrayJob_MinSuccesses{
MinSuccesses: 1,
},
}
} else {
arrayJob, err = arrayCore.ToArrayJob(taskTemplate.GetCustom(), taskTemplate.TaskTypeVersion)
}
arrayJob, err := arrayCore.ToArrayJob(taskTemplate, taskTemplate.TaskTypeVersion)
if err != nil {
return state, err
}
Expand Down Expand Up @@ -96,7 +81,7 @@ func DetermineDiscoverability(ctx context.Context, tCtx core.TaskExecutionContex
return state, errors.Errorf(errors.BadTaskSpecification, "Unable to determine array size from inputs")
}

minSuccesses := math.Ceil(float64(arrayJob.GetMinSuccessRatio()) * float64(size))
minSuccesses := math.Ceil(arrayJob.GetMinSuccessRatio() * float64(size))

logger.Debugf(ctx, "Computed state: size [%d] and minSuccesses [%d]", int64(size), int64(minSuccesses))
state = state.SetOriginalArraySize(int64(size))
Expand Down
35 changes: 26 additions & 9 deletions go/tasks/plugins/array/catalog_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ func TestDetermineDiscoverability(t *testing.T) {

t.Run("Run AWS Batch single job", func(t *testing.T) {
toCache := arrayCore.InvertBitSet(bitarray.NewBitSet(1), 1)
template.Type = AwsBatchTaskType
template.Type = arrayCore.AwsBatchTaskType
runDetermineDiscoverabilityTest(t, template, f, &arrayCore.State{
CurrentPhase: arrayCore.PhasePreLaunch,
PhaseVersion: core2.DefaultPhaseVersion,
Expand Down Expand Up @@ -258,14 +258,9 @@ func TestDiscoverabilityTaskType1(t *testing.T) {
download.OnGetCachedResults().Return(bitarray.NewBitSet(1)).Once()
toCache := arrayCore.InvertBitSet(bitarray.NewBitSet(uint(3)), uint(3))

arrayJob := &plugins.ArrayJob{
SuccessCriteria: &plugins.ArrayJob_MinSuccessRatio{
MinSuccessRatio: 0.5,
},
arrayJob := map[string]string{
"MinSuccessRatio": "0.5",
}
var arrayJobCustom structpb.Struct
err := utils.MarshalStruct(arrayJob, &arrayJobCustom)
assert.NoError(t, err)
templateType1 := &core.TaskTemplate{
Id: &core.Identifier{
ResourceType: core.ResourceType_TASK,
Expand All @@ -290,8 +285,30 @@ func TestDiscoverabilityTaskType1(t *testing.T) {
},
},
TaskTypeVersion: 1,
Custom: &arrayJobCustom,
Config: arrayJob,
}

runDetermineDiscoverabilityTest(t, templateType1, f, &arrayCore.State{
CurrentPhase: arrayCore.PhasePreLaunch,
PhaseVersion: core2.DefaultPhaseVersion,
ExecutionArraySize: 3,
OriginalArraySize: 3,
OriginalMinSuccesses: 2,
IndexesToCache: toCache,
Reason: "Task is not discoverable.",
}, nil)

// Get ArrayJob information from the deprecated ArrayJob proto in taskTemplate.Custom (backward-compatibility path)
arrayJobProto := &plugins.ArrayJob{
SuccessCriteria: &plugins.ArrayJob_MinSuccessRatio{
MinSuccessRatio: 0.5,
},
}
var arrayJobCustom structpb.Struct
err := utils.MarshalStruct(arrayJobProto, &arrayJobCustom)
assert.NoError(t, err)
templateType1.Config = nil
templateType1.Custom = &arrayJobCustom

runDetermineDiscoverabilityTest(t, templateType1, f, &arrayCore.State{
CurrentPhase: arrayCore.PhasePreLaunch,
Expand Down
24 changes: 24 additions & 0 deletions go/tasks/plugins/array/core/array_job.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package core

// ArrayJob is the plugin-local, plain-Go description of an array job. It
// carries the same four settings that ToArrayJob resolves from a task
// template: parallelism, size, and the two alternative success criteria.
type ArrayJob struct {
	Parallelism, Size, MinSuccesses int64

	// MinSuccessRatio is the success criterion expressed as a ratio; callers
	// multiply it by the array size to derive a required number of successes.
	MinSuccessRatio float64
}

// GetParallelism returns the Parallelism field.
func (j ArrayJob) GetParallelism() int64 { return j.Parallelism }

// GetSize returns the Size field.
func (j ArrayJob) GetSize() int64 { return j.Size }

// GetMinSuccesses returns the MinSuccesses field.
func (j ArrayJob) GetMinSuccesses() int64 { return j.MinSuccesses }

// GetMinSuccessRatio returns the MinSuccessRatio field.
func (j ArrayJob) GetMinSuccessRatio() float64 { return j.MinSuccessRatio }
70 changes: 48 additions & 22 deletions go/tasks/plugins/array/core/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ package core
import (
"context"
"fmt"
"strconv"
"time"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils"

"github.com/flyteorg/flytestdlib/errors"

"github.com/flyteorg/flyteplugins/go/tasks/plugins/array/arraystatus"
Expand All @@ -13,9 +16,7 @@ import (
idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
idlPlugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils"
"github.com/flyteorg/flytestdlib/logger"
structpb "github.com/golang/protobuf/ptypes/struct"
)

//go:generate mockery -all -case=underscore
Expand All @@ -38,6 +39,8 @@ const (
PhasePermanentFailure
)

const AwsBatchTaskType = "aws-batch"

type State struct {
CurrentPhase Phase `json:"phase"`
PhaseVersion uint32 `json:"phaseVersion"`
Expand Down Expand Up @@ -139,30 +142,53 @@ const (
ErrorK8sArrayGeneric errors.ErrorCode = "ARRAY_JOB_GENERIC_FAILURE"
)

func ToArrayJob(structObj *structpb.Struct, taskTypeVersion int32) (*idlPlugins.ArrayJob, error) {
if structObj == nil {
if taskTypeVersion == 0 {

return &idlPlugins.ArrayJob{
Parallelism: 1,
Size: 1,
SuccessCriteria: &idlPlugins.ArrayJob_MinSuccesses{
MinSuccesses: 1,
},
}, nil
func ToArrayJob(taskTemplate *idlCore.TaskTemplate, taskTypeVersion int32) (*ArrayJob, error) {
if taskTemplate != nil && taskTemplate.GetConfig() != nil {
config := taskTemplate.GetConfig()
arrayJob := &ArrayJob{}
var err error
if len(config["Parallelism"]) != 0 {
arrayJob.Parallelism, err = strconv.ParseInt(config["Parallelism"], 10, 64)
}
if len(config["Size"]) != 0 {
arrayJob.Size, err = strconv.ParseInt(config["Size"], 10, 64)
}
if len(config["MinSuccesses"]) != 0 {
arrayJob.MinSuccesses, err = strconv.ParseInt(config["MinSuccesses"], 10, 64)
}
if len(config["MinSuccessRatio"]) != 0 {
arrayJob.MinSuccessRatio, err = strconv.ParseFloat(config["MinSuccessRatio"], 64)
}
return arrayJob, err
}

// Keep backward compatibility for those who use arrayJob proto
if taskTemplate != nil && taskTemplate.GetCustom() != nil {
arrayJob := &idlPlugins.ArrayJob{}
err := utils.UnmarshalStruct(taskTemplate.GetCustom(), arrayJob)
if err != nil {
return nil, err
}
return &idlPlugins.ArrayJob{
Parallelism: 1,
Size: 1,
SuccessCriteria: &idlPlugins.ArrayJob_MinSuccessRatio{
MinSuccessRatio: 1.0,
},
return &ArrayJob{
Parallelism: arrayJob.GetParallelism(),
Size: arrayJob.GetSize(),
MinSuccessRatio: float64(arrayJob.GetMinSuccessRatio()),
MinSuccesses: arrayJob.GetMinSuccesses(),
}, nil
}

arrayJob := &idlPlugins.ArrayJob{}
err := utils.UnmarshalStruct(structObj, arrayJob)
return arrayJob, err
if taskTypeVersion == 0 || (taskTemplate != nil && taskTemplate.Type == AwsBatchTaskType) {
return &ArrayJob{
Parallelism: 1,
Size: 1,
MinSuccesses: 1,
}, nil
}
return &ArrayJob{
Parallelism: 1,
Size: 1,
MinSuccessRatio: 1.0,
}, nil
}

func GetPhaseVersionOffset(currentPhase Phase, length int64) uint32 {
Expand Down
Loading