diff --git a/flyteplugins/go/tasks/plugins/array/catalog.go b/flyteplugins/go/tasks/plugins/array/catalog.go index 8544555609..836ff52fda 100644 --- a/flyteplugins/go/tasks/plugins/array/catalog.go +++ b/flyteplugins/go/tasks/plugins/array/catalog.go @@ -74,11 +74,10 @@ func DetermineDiscoverability(ctx context.Context, tCtx core.TaskExecutionContex return state, errors.Errorf(errors.MetadataAccessFailed, "Could not read inputs and therefore failed to determine array job size") } + // identify and validate the size of the array job size := -1 var literalCollection *idlCore.LiteralCollection - literals := make([][]*idlCore.Literal, 0) - discoveredInputNames := make([]string, 0) - for inputName, literal := range inputs.Literals { + for _, literal := range inputs.Literals { if literalCollection = literal.GetCollection(); literalCollection != nil { // validate length of input list if size != -1 && size != len(literalCollection.Literals) { @@ -86,9 +85,6 @@ func DetermineDiscoverability(ctx context.Context, tCtx core.TaskExecutionContex return state, nil } - literals = append(literals, literalCollection.Literals) - discoveredInputNames = append(discoveredInputNames, inputName) - size = len(literalCollection.Literals) } } @@ -110,7 +106,7 @@ func DetermineDiscoverability(ctx context.Context, tCtx core.TaskExecutionContex arrayJobSize = int64(size) // build input readers - inputReaders = ConstructStaticInputReaders(tCtx.InputReader(), literals, discoveredInputNames) + inputReaders = ConstructStaticInputReaders(tCtx.InputReader(), inputs.Literals, size) } if arrayJobSize > maxArrayJobSize { @@ -246,18 +242,7 @@ func WriteToDiscovery(ctx context.Context, tCtx core.TaskExecutionContext, state return state, externalResources, errors.Errorf(errors.MetadataAccessFailed, "Could not read inputs and therefore failed to determine array job size") } - var literalCollection *idlCore.LiteralCollection - literals := make([][]*idlCore.Literal, 0) - discoveredInputNames := make([]string, 0) - for inputName, literal := range inputs.Literals { - if literalCollection = literal.GetCollection(); literalCollection != nil { - literals = append(literals, literalCollection.Literals) - discoveredInputNames = append(discoveredInputNames, inputName) - } - } - - // build input readers - inputReaders = ConstructStaticInputReaders(tCtx.InputReader(), literals, discoveredInputNames) + inputReaders = ConstructStaticInputReaders(tCtx.InputReader(), inputs.Literals, arrayJobSize) } // output reader @@ -476,16 +461,19 @@ func ConstructCatalogReaderWorkItems(ctx context.Context, taskReader core.TaskRe // ConstructStaticInputReaders constructs input readers that comply with the io.InputReader interface but have their // inputs already populated. -func ConstructStaticInputReaders(inputPaths io.InputFilePaths, inputs [][]*idlCore.Literal, inputNames []string) []io.InputReader { - inputReaders := make([]io.InputReader, 0, len(inputs)) - if len(inputs) == 0 { - return inputReaders - } +func ConstructStaticInputReaders(inputPaths io.InputFilePaths, inputLiterals map[string]*idlCore.Literal, arrayJobSize int) []io.InputReader { + var literalCollection *idlCore.LiteralCollection - for i := 0; i < len(inputs[0]); i++ { + inputReaders := make([]io.InputReader, 0, arrayJobSize) + for i := 0; i < arrayJobSize; i++ { literals := make(map[string]*idlCore.Literal) - for j := 0; j < len(inputNames); j++ { - literals[inputNames[j]] = inputs[j][i] + for inputName, inputLiteral := range inputLiterals { + if literalCollection = inputLiteral.GetCollection(); literalCollection != nil { + // if literal is a collection then we need to retrieve the specific literal for this subtask index + literals[inputName] = literalCollection.Literals[i] + } else { + literals[inputName] = inputLiteral + } } inputReaders = append(inputReaders, NewStaticInputReader(inputPaths, &idlCore.LiteralMap{Literals: literals})) diff --git a/flytepropeller/pkg/controller/nodes/array/handler.go b/flytepropeller/pkg/controller/nodes/array/handler.go index 35076d0a97..dddcd0e7c5 100644 --- a/flytepropeller/pkg/controller/nodes/array/handler.go +++ b/flytepropeller/pkg/controller/nodes/array/handler.go @@ -494,7 +494,12 @@ func (a *arrayNodeHandler) buildArrayNodeContext(ctx context.Context, nCtx inter taskPhase := int(arrayNodeState.SubNodeTaskPhases.GetItem(subNodeIndex)) // need to initialize the inputReader every time to ensure TaskHandler can access for cache lookups / population - inputLiteralMap, err := constructLiteralMap(ctx, nCtx.InputReader(), subNodeIndex) + inputs, err := nCtx.InputReader().Get(ctx) + if err != nil { + return nil, nil, nil, nil, nil, nil, err + } + + inputLiteralMap, err := constructLiteralMap(inputs, subNodeIndex) if err != nil { return nil, nil, nil, nil, nil, nil, err } diff --git a/flytepropeller/pkg/controller/nodes/array/node_execution_context.go b/flytepropeller/pkg/controller/nodes/array/node_execution_context.go index c639c543e7..17d46d2944 100644 --- a/flytepropeller/pkg/controller/nodes/array/node_execution_context.go +++ b/flytepropeller/pkg/controller/nodes/array/node_execution_context.go @@ -2,6 +2,7 @@ package array import ( "context" + "fmt" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io" @@ -26,16 +27,16 @@ func newStaticInputReader(inputPaths io.InputFilePaths, input *core.LiteralMap) } } -func constructLiteralMap(ctx context.Context, inputReader io.InputReader, index int) (*core.LiteralMap, error) { - inputs, err := inputReader.Get(ctx) - if err != nil { - return nil, err - } - +func constructLiteralMap(inputs *core.LiteralMap, index int) (*core.LiteralMap, error) { literals := make(map[string]*core.Literal) for name, literal := range inputs.Literals { if literalCollection := literal.GetCollection(); literalCollection != nil { + if index >= len(literalCollection.Literals) { + return nil, fmt.Errorf("index %v out of bounds for literal collection %v", index, name) + } literals[name] = literalCollection.Literals[index] + } else { + literals[name] = literal } } diff --git a/flytepropeller/pkg/controller/nodes/array/node_execution_context_test.go b/flytepropeller/pkg/controller/nodes/array/node_execution_context_test.go new file mode 100644 index 0000000000..0a5546bd25 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/array/node_execution_context_test.go @@ -0,0 +1,160 @@ +package array + +import ( + "testing" + + "github.com/golang/protobuf/proto" + "github.com/stretchr/testify/assert" + + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" +) + +var ( + literalOne = &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: &core.Primitive{ + Value: &core.Primitive_Integer{ + Integer: 1, + }, + }, + }, + }, + }, + } + literalTwo = &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: &core.Primitive{ + Value: &core.Primitive_Integer{ + Integer: 2, + }, + }, + }, + }, + }, + } +) + +func TestConstructLiteralMap(t *testing.T) { + tests := []struct { + name string + inputLiteralMaps *core.LiteralMap + expectedLiteralMaps []*core.LiteralMap + }{ + { + "SingleList", + &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "foo": &core.Literal{ + Value: &core.Literal_Collection{ + Collection: &core.LiteralCollection{ + Literals: []*core.Literal{ + literalOne, + literalTwo, + }, + }, + }, + }, + }, + }, + []*core.LiteralMap{ + &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "foo": literalOne, + }, + }, + &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "foo": literalTwo, + }, + }, + }, + }, + { + "MultiList", + &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "foo": &core.Literal{ + Value: &core.Literal_Collection{ + Collection: &core.LiteralCollection{ + Literals: []*core.Literal{ + literalOne, + literalTwo, + }, + }, + }, + }, + "bar": &core.Literal{ + Value: &core.Literal_Collection{ + Collection: &core.LiteralCollection{ + Literals: []*core.Literal{ + literalTwo, + literalOne, + }, + }, + }, + }, + }, + }, + []*core.LiteralMap{ + &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "foo": literalOne, + "bar": literalTwo, + }, + }, + &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "foo": literalTwo, + "bar": literalOne, + }, + }, + }, + }, + { + "Partial", + &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "foo": &core.Literal{ + Value: &core.Literal_Collection{ + Collection: &core.LiteralCollection{ + Literals: []*core.Literal{ + literalOne, + literalTwo, + }, + }, + }, + }, + "bar": literalTwo, + }, + }, + []*core.LiteralMap{ + &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "foo": literalOne, + "bar": literalTwo, + }, + }, + &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "foo": literalTwo, + "bar": literalTwo, + }, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for i := 0; i < len(test.expectedLiteralMaps); i++ { + outputLiteralMap, err := constructLiteralMap(test.inputLiteralMaps, i) + assert.NoError(t, err) + assert.True(t, proto.Equal(test.expectedLiteralMaps[i], outputLiteralMap)) + } + }) + } +}