From c81721681514b6ef0fe93b900a01d5f724181f18 Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Wed, 20 Mar 2024 10:47:15 -0500 Subject: [PATCH] prepopulate output literals with TaskNode interface output variables Signed-off-by: Daniel Rammer --- .../pkg/controller/nodes/array/handler.go | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/flytepropeller/pkg/controller/nodes/array/handler.go b/flytepropeller/pkg/controller/nodes/array/handler.go index 72b2c511eb..5be0bed120 100644 --- a/flytepropeller/pkg/controller/nodes/array/handler.go +++ b/flytepropeller/pkg/controller/nodes/array/handler.go @@ -478,7 +478,38 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu gatherOutputsRequests = append(gatherOutputsRequests, gatherOutputsRequest) } + // attempt best effort at initializing outputLiterals with output variable names. currently + // only TaskNode and WorkflowNode contain node interfaces. outputLiterals := make(map[string]*idlcore.Literal) + + switch arrayNode.GetSubNodeSpec().GetKind() { + case v1alpha1.NodeKindTask: + taskID := *arrayNode.GetSubNodeSpec().TaskRef + taskNode, err := nCtx.ExecutionContext().GetTask(taskID) + if err != nil { + return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(idlcore.ExecutionError_SYSTEM, + errors.BadSpecificationError, fmt.Sprintf("failed to find ArrayNode subNode task with id: '%s'", taskID), nil)), nil + } + + if outputs := taskNode.CoreTask().GetInterface().GetOutputs(); outputs != nil { + for name := range outputs.Variables { + outputLiteral := &idlcore.Literal{ + Value: &idlcore.Literal_Collection{ + Collection: &idlcore.LiteralCollection{ + Literals: make([]*idlcore.Literal, 0, len(arrayNodeState.SubNodePhases.GetItems())), + }, + }, + } + + outputLiterals[name] = outputLiteral + } + } + case v1alpha1.NodeKindWorkflow: + fallthrough + default: + logger.Warnf(ctx, "ArrayNode does not support pre-populating outputLiteral collections for node kind '%s'", arrayNode.GetSubNodeSpec().GetKind()) + } + workerErrorCollector := errorcollector.NewErrorMessageCollector() for i, gatherOutputsRequest := range gatherOutputsRequests { outputResponse := <-gatherOutputsRequest.responseChannel