Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Migrate ArrayJob to use TaskTemplate Config and deprecate ArrayJob proto #229

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
28 changes: 10 additions & 18 deletions go/tasks/plugins/array/array_tests_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@ package array
import (
"testing"

arrayCore "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core"
"github.com/flyteorg/flyteplugins/tests"

idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"
"github.com/flyteorg/flytestdlib/utils"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
Expand Down Expand Up @@ -49,12 +48,10 @@ func RunArrayTestsEndToEnd(t *testing.T, executor core.Plugin, iter AdvanceItera
}

var err error
template.Custom, err = utils.MarshalPbToStruct(&plugins.ArrayJob{
Parallelism: 10,
Size: 1,
SuccessCriteria: &plugins.ArrayJob_MinSuccesses{
MinSuccesses: 1,
},
template.Custom, err = utils.MarshalObjToStruct(&arrayCore.ArrayJob{
Parallelism: 10,
Size: 1,
MinSuccesses: 1,
})

assert.NoError(t, err)
Expand Down Expand Up @@ -83,16 +80,11 @@ func RunArrayTestsEndToEnd(t *testing.T, executor core.Plugin, iter AdvanceItera
},
}

var err error
template.Custom, err = utils.MarshalPbToStruct(&plugins.ArrayJob{
Parallelism: 10,
Size: 2,
SuccessCriteria: &plugins.ArrayJob_MinSuccesses{
MinSuccesses: 1,
},
})

assert.NoError(t, err)
template.Config = map[string]string{
"Parallelism": "10",
"Size": "2",
"MinSuccesses": "1",
}

expectedOutputs := coreutils.MustMakeLiteral(map[string]interface{}{
"x": []interface{}{5, 5},
Expand Down
4 changes: 2 additions & 2 deletions go/tasks/plugins/array/awsbatch/transformer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import (

"github.com/aws/aws-sdk-go/service/batch"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"
arrayCore "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -136,7 +136,7 @@ func TestArrayJobToBatchInput(t *testing.T) {
},
}

input := &plugins.ArrayJob{
input := &arrayCore.ArrayJob{
Size: 10,
Parallelism: 5,
}
Expand Down
4 changes: 2 additions & 2 deletions go/tasks/plugins/array/catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func DetermineDiscoverability(ctx context.Context, tCtx core.TaskExecutionContex
}

// Extract the custom plugin pb
arrayJob, err := arrayCore.ToArrayJob(taskTemplate.GetCustom(), taskTemplate.TaskTypeVersion)
arrayJob, err := arrayCore.ToArrayJob(taskTemplate.GetConfig(), taskTemplate.TaskTypeVersion)
pingsutw marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return state, err
}
Expand Down Expand Up @@ -81,7 +81,7 @@ func DetermineDiscoverability(ctx context.Context, tCtx core.TaskExecutionContex
return state, errors.Errorf(errors.BadTaskSpecification, "Unable to determine array size from inputs")
}

minSuccesses := math.Ceil(float64(arrayJob.GetMinSuccessRatio()) * float64(size))
minSuccesses := math.Ceil(arrayJob.GetMinSuccessRatio() * float64(size))

logger.Debugf(ctx, "Computed state: size [%d] and minSuccesses [%d]", int64(size), int64(minSuccesses))
state = state.SetOriginalArraySize(int64(size))
Expand Down
15 changes: 3 additions & 12 deletions go/tasks/plugins/array/catalog_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,6 @@ import (
"errors"
"testing"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils"
structpb "github.com/golang/protobuf/ptypes/struct"

stdErrors "github.com/flyteorg/flytestdlib/errors"

pluginErrors "github.com/flyteorg/flyteplugins/go/tasks/errors"
Expand Down Expand Up @@ -244,14 +240,9 @@ func TestDiscoverabilityTaskType1(t *testing.T) {
download.OnGetCachedResults().Return(bitarray.NewBitSet(1)).Once()
toCache := arrayCore.InvertBitSet(bitarray.NewBitSet(uint(3)), uint(3))

arrayJob := &plugins.ArrayJob{
SuccessCriteria: &plugins.ArrayJob_MinSuccessRatio{
MinSuccessRatio: 0.5,
},
arrayJob := map[string]string{
"MinSuccessRatio": "0.5",
}
var arrayJobCustom structpb.Struct
err := utils.MarshalStruct(arrayJob, &arrayJobCustom)
assert.NoError(t, err)
templateType1 := &core.TaskTemplate{
Id: &core.Identifier{
ResourceType: core.ResourceType_TASK,
Expand All @@ -276,7 +267,7 @@ func TestDiscoverabilityTaskType1(t *testing.T) {
},
},
TaskTypeVersion: 1,
Custom: &arrayJobCustom,
Config: arrayJob,
}

runDetermineDiscoverabilityTest(t, templateType1, f, &arrayCore.State{
Expand Down
24 changes: 24 additions & 0 deletions go/tasks/plugins/array/core/array_job.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package core

type ArrayJob struct {
Parallelism int64
Size int64
MinSuccesses int64
MinSuccessRatio float64
}

func (a ArrayJob) GetParallelism() int64 {
return a.Parallelism
}

func (a ArrayJob) GetSize() int64 {
return a.Size
}

func (a ArrayJob) GetMinSuccesses() int64 {
return a.MinSuccesses
}

func (a ArrayJob) GetMinSuccessRatio() float64 {
return a.MinSuccessRatio
}
45 changes: 25 additions & 20 deletions go/tasks/plugins/array/core/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package core
import (
"context"
"fmt"
"strconv"
"time"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event"
Expand All @@ -13,11 +14,8 @@ import (
"github.com/flyteorg/flytestdlib/bitarray"

idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
idlPlugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils"
"github.com/flyteorg/flytestdlib/logger"
structpb "github.com/golang/protobuf/ptypes/struct"
)

//go:generate mockery -all -case=underscore
Expand Down Expand Up @@ -133,29 +131,36 @@ const (
ErrorK8sArrayGeneric errors.ErrorCode = "ARRAY_JOB_GENERIC_FAILURE"
)

func ToArrayJob(structObj *structpb.Struct, taskTypeVersion int32) (*idlPlugins.ArrayJob, error) {
if structObj == nil {
func ToArrayJob(config map[string]string, taskTypeVersion int32) (*ArrayJob, error) {
if config == nil {
if taskTypeVersion == 0 {

return &idlPlugins.ArrayJob{
Parallelism: 1,
Size: 1,
SuccessCriteria: &idlPlugins.ArrayJob_MinSuccesses{
MinSuccesses: 1,
},
return &ArrayJob{
Parallelism: 1,
Size: 1,
MinSuccesses: 1,
}, nil
}
return &idlPlugins.ArrayJob{
Parallelism: 1,
Size: 1,
SuccessCriteria: &idlPlugins.ArrayJob_MinSuccessRatio{
MinSuccessRatio: 1.0,
},
return &ArrayJob{
Parallelism: 1,
Size: 1,
MinSuccessRatio: 1.0,
}, nil
}

arrayJob := &idlPlugins.ArrayJob{}
err := utils.UnmarshalStruct(structObj, arrayJob)
arrayJob := &ArrayJob{}
var err error
if len(config["Parallelism"]) != 0 {
arrayJob.Parallelism, err = strconv.ParseInt(config["Parallelism"], 10, 64)
}
if len(config["Size"]) != 0 {
arrayJob.Size, err = strconv.ParseInt(config["Size"], 10, 64)
}
if len(config["MinSuccesses"]) != 0 {
arrayJob.MinSuccesses, err = strconv.ParseInt(config["MinSuccesses"], 10, 64)
}
if len(config["MinSuccessRatio"]) != 0 {
arrayJob.MinSuccessRatio, err = strconv.ParseFloat(config["MinSuccessRatio"], 64)
}
return arrayJob, err
}

Expand Down
41 changes: 26 additions & 15 deletions go/tasks/plugins/array/core/state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"
"github.com/golang/protobuf/proto"

"github.com/flyteorg/flytestdlib/bitarray"
Expand Down Expand Up @@ -247,24 +246,36 @@ func TestToArrayJob(t *testing.T) {
t.Run("task_type_version == 0", func(t *testing.T) {
arrayJob, err := ToArrayJob(nil, 0)
assert.NoError(t, err)
assert.True(t, proto.Equal(arrayJob, &plugins.ArrayJob{
Parallelism: 1,
Size: 1,
SuccessCriteria: &plugins.ArrayJob_MinSuccesses{
MinSuccesses: 1,
},
}))
assert.True(t, *arrayJob == ArrayJob{
Parallelism: 1,
Size: 1,
MinSuccesses: 1,
})
})

t.Run("task_type_version == 1", func(t *testing.T) {
arrayJob, err := ToArrayJob(nil, 1)
assert.NoError(t, err)
assert.True(t, proto.Equal(arrayJob, &plugins.ArrayJob{
Parallelism: 1,
Size: 1,
SuccessCriteria: &plugins.ArrayJob_MinSuccessRatio{
MinSuccessRatio: 1.0,
},
}))
assert.True(t, *arrayJob == ArrayJob{
Parallelism: 1,
Size: 1,
MinSuccessRatio: 1.0,
})
})

t.Run("ToArrayJob with config", func(t *testing.T) {
config := map[string]string{
"Parallelism": "10",
"Size": "10",
"MinSuccesses": "1",
"MinSuccessRatio": "1.0",
}

arrayJob, err := ToArrayJob(config, 0)
assert.NoError(t, err)
assert.Equal(t, arrayJob.GetParallelism(), int64(10))
assert.Equal(t, arrayJob.GetSize(), int64(10))
assert.Equal(t, arrayJob.GetMinSuccesses(), int64(1))
assert.Equal(t, arrayJob.GetMinSuccessRatio(), 1.0)
})
}
13 changes: 5 additions & 8 deletions go/tasks/plugins/array/k8s/transformer.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ import (
"github.com/flyteorg/flyteplugins/go/tasks/plugins/array"

idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
idlPlugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"
"github.com/flyteorg/flyteplugins/go/tasks/errors"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s"
arrayCore "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core"
core2 "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
Expand Down Expand Up @@ -97,7 +97,7 @@ func buildPodMapTask(task *idlCore.TaskTemplate, metadata core.TaskExecutionMeta
// FlyteArrayJobToK8sPodTemplate returns a pod template for the given task context. Note that Name is not set on the
// result object. It's up to the caller to set the Name before creating the object in K8s.
func FlyteArrayJobToK8sPodTemplate(ctx context.Context, tCtx core.TaskExecutionContext, namespaceTemplate string) (
podTemplate v1.Pod, job *idlPlugins.ArrayJob, err error) {
podTemplate v1.Pod, job *arrayCore.ArrayJob, err error) {

// Check that the taskTemplate is valid
taskTemplate, err := tCtx.TaskReader().Read(ctx)
Expand All @@ -117,12 +117,9 @@ func FlyteArrayJobToK8sPodTemplate(ctx context.Context, tCtx core.TaskExecutionC
arrayInputReader: array.GetInputReader(tCtx, taskTemplate),
}

var arrayJob *idlPlugins.ArrayJob
if taskTemplate.GetCustom() != nil {
arrayJob, err = core2.ToArrayJob(taskTemplate.GetCustom(), taskTemplate.TaskTypeVersion)
if err != nil {
return v1.Pod{}, nil, err
}
arrayJob, err := core2.ToArrayJob(taskTemplate.GetConfig(), taskTemplate.TaskTypeVersion)
if err != nil {
return v1.Pod{}, nil, err
}

annotations := utils.UnionMaps(config.GetK8sPluginConfig().DefaultAnnotations, tCtx.TaskExecutionMetadata().GetAnnotations())
Expand Down
12 changes: 3 additions & 9 deletions go/tasks/plugins/array/k8s/transformer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@ import (
"github.com/flyteorg/flytestdlib/storage"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
idlPlugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks"
mocks2 "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils"
arrayCore "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core"
structpb "github.com/golang/protobuf/ptypes/struct"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
Expand Down Expand Up @@ -42,7 +41,7 @@ var podSpec = v1.PodSpec{
},
}

var arrayJob = idlPlugins.ArrayJob{
var arrayJob = arrayCore.ArrayJob{
Size: 100,
}

Expand All @@ -57,15 +56,11 @@ func getK8sPodTask(t *testing.T, annotations map[string]string) *core.TaskTempla
t.Fatal(err)
}

custom := &structpb.Struct{}
if err := utils.MarshalStruct(&arrayJob, custom); err != nil {
t.Fatal(err)
}

return &core.TaskTemplate{
TaskTypeVersion: 2,
Config: map[string]string{
primaryContainerKey: testPrimaryContainerName,
"Size": "100",
},
Target: &core.TaskTemplate_K8SPod{
K8SPod: &core.K8SPod{
Expand All @@ -78,7 +73,6 @@ func getK8sPodTask(t *testing.T, annotations map[string]string) *core.TaskTempla
},
},
},
Custom: custom,
}
}

Expand Down