diff --git a/internal/step/foreach/provider.go b/internal/step/foreach/provider.go index 906db40f..f6fd375a 100644 --- a/internal/step/foreach/provider.go +++ b/internal/step/foreach/provider.go @@ -6,6 +6,7 @@ import ( "go.arcalot.io/dgraph" "reflect" "sync" + "sync/atomic" "go.arcalot.io/log/v2" "go.flow.arcalot.io/engine/internal/step" @@ -36,6 +37,14 @@ const ( StageIDOutputs StageID = "outputs" // StageIDFailed is providing the error reason from the subworkflow. StageIDFailed StageID = "failed" + // StageIDEnabling is a stage that indicates that the step is waiting to be enabled. + // This is required to be separate to ensure that it exits immediately if disabled. + StageIDEnabling StageID = "enabling" + // StageIDDisabled is indicating that the step was disabled. + StageIDDisabled StageID = "disabled" + // StageIDClosed is a stage that indicates that the workflow has exited or did not start + // due to workflow termination or step cancellation. + StageIDClosed StageID = "closed" ) var executeLifecycleStage = step.LifecycleStage{ @@ -71,6 +80,33 @@ var errorLifecycleStage = step.LifecycleStage{ NextStages: map[string]dgraph.DependencyType{}, Fatal: true, } +var enablingLifecycleStage = step.LifecycleStage{ + ID: string(StageIDEnabling), + WaitingName: "waiting to be enabled", + RunningName: "enabling", + FinishedName: "enablement determined", + InputFields: map[string]struct{}{ + "enabled": {}, + }, + NextStages: map[string]dgraph.DependencyType{ + string(StageIDExecute): dgraph.AndDependency, + string(StageIDDisabled): dgraph.AndDependency, + string(StageIDClosed): dgraph.CompletionAndDependency, + }, +} +var disabledLifecycleStage = step.LifecycleStage{ + ID: string(StageIDDisabled), + WaitingName: "waiting for the step to be disabled", + RunningName: "disabling", + FinishedName: "disabled", + InputFields: map[string]struct{}{}, +} +var closedLifecycleStage = step.LifecycleStage{ + ID: string(StageIDClosed), + WaitingName: "closed", + RunningName: "closed", + FinishedName: "closed", +} type forEachProvider struct { logger log.Logger @@ -89,6 +125,9 @@ func (l *forEachProvider) Lifecycle() step.Lifecycle[step.LifecycleStage] { executeLifecycleStage, outputLifecycleStage, errorLifecycleStage, + enablingLifecycleStage, + disabledLifecycleStage, + closedLifecycleStage, }, } } @@ -324,6 +363,68 @@ func (r *runnableStep) Lifecycle(_ map[string]any) (step.Lifecycle[step.Lifecycl ), }, }, + { + LifecycleStage: enablingLifecycleStage, + InputSchema: map[string]*schema.PropertySchema{ + "enabled": schema.NewPropertySchema( + schema.NewBoolSchema(), + schema.NewDisplayValue( + schema.PointerTo("Enabled"), + schema.PointerTo("Used to set whether the step is enabled."), + nil, + ), + false, + nil, + nil, + nil, + nil, + nil, + ), + }, + Outputs: map[string]*schema.StepOutputSchema{ + "resolved": step.EnabledOutputSchema(), + }, + }, + { + LifecycleStage: disabledLifecycleStage, + InputSchema: nil, + Outputs: map[string]*schema.StepOutputSchema{ + "output": step.DisabledOutputSchema(), + }, + }, + { + LifecycleStage: closedLifecycleStage, + InputSchema: nil, + Outputs: map[string]*schema.StepOutputSchema{ + "result": { + SchemaValue: schema.NewScopeSchema( + schema.NewObjectSchema( + "ClosedInfo", + map[string]*schema.PropertySchema{ + // Unlike a normal step, it cannot be cancelled at this time. + // That feature can be added later if there is demand. + "close_requested": schema.NewPropertySchema( + schema.NewBoolSchema(), + schema.NewDisplayValue( + schema.PointerTo("close requested"), + schema.PointerTo("Whether the step was closed with Close()"), + nil, + ), + true, + nil, + nil, + nil, + nil, + nil, + ), + }, + ), + ), + DisplayValue: nil, + ErrorValue: true, + }, + }, + }, }, }, nil } @@ -339,9 +440,10 @@ func (r *runnableStep) Start(_ map[string]any, runID string, stageChangeHandler ctx: ctx, cancel: cancel, lock: &sync.Mutex{}, - currentStage: StageIDExecute, + currentStage: StageIDEnabling, currentState: step.RunningStepStateStarting, inputData: make(chan []any, 1), + enabledInput: make(chan bool, 1), workflow: r.workflow, stageChangeHandler: stageChangeHandler, parallelism: r.parallelism, @@ -352,26 +454,28 @@ func (r *runnableStep) Start(_ map[string]any, runID string, stageChangeHandler } type runningStep struct { - runID string - workflow workflow.ExecutableWorkflow - currentStage StageID - lock *sync.Mutex - currentState step.RunningStepState - inputAvailable bool - inputData chan []any - ctx context.Context - closed bool - wg sync.WaitGroup - cancel context.CancelFunc - stageChangeHandler step.StageChangeHandler - parallelism int64 - logger log.Logger + runID string + workflow workflow.ExecutableWorkflow + currentStage StageID + lock *sync.Mutex + currentState step.RunningStepState + executionInputAvailable bool + inputData chan []any + enabledInput chan bool + enabledInputAvailable bool + ctx context.Context + closed atomic.Bool + wg sync.WaitGroup + cancel context.CancelFunc + stageChangeHandler step.StageChangeHandler + parallelism int64 + logger log.Logger } func (r *runningStep) ProvideStageInput(stage string, input map[string]any) error { r.lock.Lock() defer r.lock.Unlock() - if r.closed { + if r.closed.Load() { r.logger.Debugf("exiting foreach ProvideStageInput due to step being closed") return nil } @@ -388,24 +492,41 @@ func (r *runningStep) ProvideStageInput(stage string, input map[string]any) erro } input[i] = item } - if r.inputAvailable { + if r.executionInputAvailable { return fmt.Errorf("input for execute workflow provided twice for run/step %s", r.runID) } if r.currentState == step.RunningStepStateWaitingForInput && r.currentStage == StageIDExecute { r.currentState = step.RunningStepStateRunning } - r.inputAvailable = true + r.executionInputAvailable = true r.inputData <- input // Send before unlock to ensure that it never gets closed before sending. return nil case string(StageIDOutputs): return nil case string(StageIDFailed): return nil + case string(StageIDEnabling): + return r.provideEnablingInput(input) + case string(StageIDDisabled): + return nil default: return fmt.Errorf("invalid stage: %s", stage) } } +func (r *runningStep) provideEnablingInput(input map[string]any) error { + // Note: The calling function must have the step mutex locked + if r.enabledInputAvailable { + return fmt.Errorf("enabled input provided more than once") + } + // Check to make sure it's enabled. + // This is an optional field, so no input means enabled. + enabled := input["enabled"] == nil || input["enabled"] == true + r.enabledInputAvailable = true + r.enabledInput <- enabled + return nil +} + func (r *runningStep) CurrentStage() string { r.lock.Lock() defer r.lock.Unlock() @@ -420,7 +541,7 @@ func (r *runningStep) State() step.RunningStepState { func (r *runningStep) Close() error { r.lock.Lock() - r.closed = true + r.closed.Load() r.lock.Unlock() r.cancel() r.wg.Wait() @@ -441,14 +562,33 @@ func (r *runningStep) run() { r.wg.Done() }() waitingForInput := false + enabled, contextDoneEarly := r.enableStage() + if contextDoneEarly { + r.closedEarly(StageIDExecute, true) + return + } + if !enabled { + r.transitionToDisabled() + return + } + + var newState step.RunningStepState r.lock.Lock() - if !r.inputAvailable { - r.currentState = step.RunningStepStateWaitingForInput + if !r.executionInputAvailable { + newState = step.RunningStepStateWaitingForInput waitingForInput = true } else { - r.currentState = step.RunningStepStateRunning + newState = step.RunningStepStateRunning } r.lock.Unlock() + enabledOutput := any(map[any]any{"enabled": true}) + // End Enabling with resolved output, and start starting + r.transitionStageWithOutput( + StageIDExecute, + newState, + schema.PointerTo("resolved"), + &enabledOutput, + ) r.stageChangeHandler.OnStageChange( r, nil, @@ -461,6 +601,160 @@ func (r *runningStep) run() { r.runOnInput() } +// enableStage returns the result of whether the stage was enabled or not. +// Return values: +// - bool: Whether the step was enabled. +// - bool: True if the step was disabled due to context done. +func (r *runningStep) enableStage() (bool, bool) { + // Enabling is the first stage, so do not transition out of it. + var enabled bool + select { + case enabled = <-r.enabledInput: + case <-r.ctx.Done(): + return false, true + } + + if enabled { + // It's enabled, so the disabled stage will not occur. + r.stageChangeHandler.OnStepStageFailure(r, string(StageIDDisabled), &r.wg, fmt.Errorf("step enabled; cannot be disabled anymore")) + } + return enabled, false +} + +func (r *runningStep) closedEarly(stageToMarkUnresolvable StageID, priorStageFailed bool) { + r.logger.Infof("Step foreach %s closed", r.runID) + // Follow the convention of transitioning to running then finished. + if priorStageFailed { + r.transitionFromFailedStage(StageIDClosed, step.RunningStepStateRunning, fmt.Errorf("step closed early")) + } else { + r.transitionRunningStage(StageIDClosed) + } + closedOutput := any(map[any]any{"close_requested": r.closed.Load()}) + + r.completeStep( + StageIDClosed, + step.RunningStepStateFinished, + schema.PointerTo("result"), + &closedOutput, + ) + + err := fmt.Errorf("step foreach %s closed due to workflow termination", r.runID) + r.markStageFailures(stageToMarkUnresolvable, err) +} + +func (r *runningStep) transitionToDisabled() { + r.logger.Infof("Step foreach %s disabled", r.runID) + enabledOutput := any(map[any]any{"enabled": false}) + // End prior stage "enabling" with "resolved" output, and start "disabled" stage. + r.transitionStageWithOutput( + StageIDDisabled, + step.RunningStepStateRunning, + schema.PointerTo("resolved"), + &enabledOutput, + ) + disabledOutput := any(map[any]any{"message": fmt.Sprintf("Step foreach %s disabled", r.runID)}) + r.completeStep( + StageIDDisabled, + step.RunningStepStateFinished, // Must set the stage to finished for the engine realize the step is done. + schema.PointerTo("output"), + &disabledOutput, + ) + + err := fmt.Errorf("step foreach %s disabled", r.runID) + r.markStageFailures(StageIDExecute, err) + r.markNotClosable(err) +} + +// Closable is the graceful case, so this is necessary if it crashes. +func (r *runningStep) markNotClosable(err error) { + r.stageChangeHandler.OnStepStageFailure(r, string(StageIDClosed), &r.wg, err) +} + +// TransitionStage transitions the running step to the specified stage, and the state running. +// For other situations, use transitionFromFailedStage, completeStep, or transitionStageWithOutput. +func (r *runningStep) transitionRunningStage(newStage StageID) { + r.transitionStageWithOutput(newStage, step.RunningStepStateRunning, nil, nil) +} + +func (r *runningStep) transitionFromFailedStage(newStage StageID, state step.RunningStepState, err error) { + r.lock.Lock() + previousStage := string(r.currentStage) + r.currentStage = newStage + // Don't forget to update this, or else it will behave very oddly. + // First running, then finished. You can't skip states. + r.currentState = state + r.lock.Unlock() + r.stageChangeHandler.OnStepStageFailure( + r, + previousStage, + &r.wg, + err, + ) +} + +// TransitionStage transitions the stage to the specified stage, and the state to the specified state. +func (r *runningStep) transitionStageWithOutput( + newStage StageID, + state step.RunningStepState, + outputID *string, + previousStageOutput *any, +) { + // A current lack of observability into the atp client prevents + // non-fragile testing of this function. + r.lock.Lock() + previousStage := string(r.currentStage) + r.currentStage = newStage + // Don't forget to update this, or else it will behave very oddly. + // First running, then finished. You can't skip states. + r.currentState = state + r.lock.Unlock() + r.stageChangeHandler.OnStageChange( + r, + &previousStage, + outputID, + previousStageOutput, + string(newStage), + false, + &r.wg, + ) +} + +//nolint:unparam // Currently only gets state finished, but that's okay. +//nolint:nolintlint // Differing versions of the linter do or do not care. +func (r *runningStep) completeStep(currentStage StageID, state step.RunningStepState, outputID *string, previousStageOutput *any) { + r.lock.Lock() + previousStage := string(r.currentStage) + r.currentStage = currentStage + r.currentState = state + r.lock.Unlock() + + r.stageChangeHandler.OnStepComplete( + r, + previousStage, + outputID, + previousStageOutput, + &r.wg, + ) +} + +func (r *runningStep) markStageFailures(firstStage StageID, err error) { + switch firstStage { + case StageIDEnabling: + r.stageChangeHandler.OnStepStageFailure(r, string(StageIDEnabling), &r.wg, err) + fallthrough + case StageIDDisabled: + r.stageChangeHandler.OnStepStageFailure(r, string(StageIDDisabled), &r.wg, err) + fallthrough + case StageIDExecute: + r.stageChangeHandler.OnStepStageFailure(r, string(StageIDExecute), &r.wg, err) + fallthrough + case StageIDOutputs: + r.stageChangeHandler.OnStepStageFailure(r, string(StageIDOutputs), &r.wg, err) + default: + panic("unknown StageID " + firstStage) + } +} + func (r *runningStep) runOnInput() { select { case loopData, ok := <-r.inputData: @@ -476,6 +770,7 @@ func (r *runningStep) runOnInput() { } func (r *runningStep) processInput(inputData []any) { + // TODO: Transition to reusable functions r.logger.Debugf("Executing subworkflow for step %s...", r.runID) outputs, errors := r.executeSubWorkflows(inputData) diff --git a/internal/step/plugin/provider.go b/internal/step/plugin/provider.go index 23ed21a4..8a8a7437 100644 --- a/internal/step/plugin/provider.go +++ b/internal/step/plugin/provider.go @@ -393,60 +393,6 @@ func (r *runnableStep) StartedSchema() *schema.StepOutputSchema { ) } -func (r *runnableStep) EnabledOutputSchema() *schema.StepOutputSchema { - return schema.NewStepOutputSchema( - schema.NewScopeSchema( - schema.NewObjectSchema( - "EnabledOutput", - map[string]*schema.PropertySchema{ - "enabled": schema.NewPropertySchema( - schema.NewBoolSchema(), - schema.NewDisplayValue( - schema.PointerTo("enabled"), - schema.PointerTo("Whether the step was enabled"), - nil), - true, - nil, - nil, - nil, - nil, - nil, - ), - }, - ), - ), - nil, - false, - ) -} - -func (r *runnableStep) DisabledOutputSchema() *schema.StepOutputSchema { - return schema.NewStepOutputSchema( - schema.NewScopeSchema( - schema.NewObjectSchema( - "DisabledMessageOutput", - map[string]*schema.PropertySchema{ - "message": schema.NewPropertySchema( - schema.NewStringSchema(nil, nil, nil), - schema.NewDisplayValue( - schema.PointerTo("message"), - schema.PointerTo("A human readable message stating that the step was disabled."), - nil), - true, - nil, - nil, - nil, - nil, - nil, - ), - }, - ), - ), - nil, - false, - ) -} - const defaultClosureTimeout = 5000 func closureTimeoutSchema() *schema.PropertySchema { @@ -584,7 +530,7 @@ func (r *runnableStep) Lifecycle(input map[string]any) (result step.Lifecycle[st ), }, Outputs: map[string]*schema.StepOutputSchema{ - "resolved": r.EnabledOutputSchema(), + "resolved": step.EnabledOutputSchema(), }, }, { @@ -636,7 +582,7 @@ func (r *runnableStep) Lifecycle(input map[string]any) (result step.Lifecycle[st LifecycleStage: disabledLifecycleStage, InputSchema: nil, Outputs: map[string]*schema.StepOutputSchema{ - "output": r.DisabledOutputSchema(), + "output": step.DisabledOutputSchema(), }, }, { @@ -866,6 +812,8 @@ func (r *runningStep) ProvideStageInput(stage string, input map[string]any) erro return nil case string(StageIDOutput): return nil + case string(StageIDDisabled): + return nil default: return fmt.Errorf("bug: invalid stage: %s", stage) } diff --git a/internal/step/shared_schema.go b/internal/step/shared_schema.go new file mode 100644 index 00000000..4a6ab88e --- /dev/null +++ b/internal/step/shared_schema.go @@ -0,0 +1,57 @@ +package step + +import "go.flow.arcalot.io/pluginsdk/schema" + +func EnabledOutputSchema() *schema.StepOutputSchema { + return schema.NewStepOutputSchema( + schema.NewScopeSchema( + schema.NewObjectSchema( + "EnabledOutput", + map[string]*schema.PropertySchema{ + "enabled": schema.NewPropertySchema( + schema.NewBoolSchema(), + schema.NewDisplayValue( + schema.PointerTo("enabled"), + schema.PointerTo("Whether the step was enabled"), + nil), + true, + nil, + nil, + nil, + nil, + nil, + ), + }, + ), + ), + nil, + false, + ) +} + +func DisabledOutputSchema() *schema.StepOutputSchema { + return schema.NewStepOutputSchema( + schema.NewScopeSchema( + schema.NewObjectSchema( + "DisabledMessageOutput", + map[string]*schema.PropertySchema{ + "message": schema.NewPropertySchema( + schema.NewStringSchema(nil, nil, nil), + schema.NewDisplayValue( + schema.PointerTo("message"), + schema.PointerTo("A human readable message stating that the step was disabled."), + nil), + true, + nil, + nil, + nil, + nil, + nil, + ), + }, + ), + ), + nil, + false, + ) +} diff --git a/workflow/workflow_test.go b/workflow/workflow_test.go index ae1c7c8a..0abf285c 100644 --- a/workflow/workflow_test.go +++ b/workflow/workflow_test.go @@ -1327,6 +1327,125 @@ func TestGracefullyDisabledStepWorkflow(t *testing.T) { assert.Equals(t, outputDataMap["result"], "disabled_wait_output") } +var gracefullyDisabledForeachStepWorkflow = ` +version: v0.2.0 +input: + root: WorkflowInput + objects: + WorkflowInput: + id: WorkflowInput + properties: + step_enabled: + type: + type_id: bool +steps: + subwf_step: + kind: foreach + items: + - {} + workflow: subworkflow.yaml + enabled: !expr $.input.step_enabled +outputs: + parent_success: + subwf_output: !ordisabled $.steps.subwf_step.outputs +` + +var simpleSubWf = ` +version: v0.2.0 +input: + root: WorkflowInput + objects: + WorkflowInput: + id: WorkflowInput + properties: {} +steps: + simple_wait: + plugin: + src: "n/a" + deployment_type: "builtin" + step: wait + input: + wait_time_ms: 0 +outputs: + success: !expr $.steps.simple_wait.outputs.success +` + +func TestGracefullyDisabledForeachStepWorkflow(t *testing.T) { + // Run a workflow where both the disabled output and the success output + // result in a single valid workflow output, but use + logConfig := log.Config{ + Level: log.LevelDebug, + Destination: log.DestinationStdout, + } + logger := log.New( + logConfig, + ) + cfg := &config.Config{ + Log: logConfig, + } + factories := workflowFactory{ + config: cfg, + } + deployerRegistry := deployerregistry.New( + deployer.Any(testimpl.NewFactory()), + ) + + pluginProvider := assert.NoErrorR[step.Provider](t)( + plugin.New(logger, deployerRegistry, map[string]interface{}{ + "builtin": map[string]any{ + "deployer_name": "test-impl", + "deploy_time": "0", + }, + }), + ) + stepRegistry, err := stepregistry.New( + pluginProvider, + lang.Must2(foreach.New(logger, factories.createYAMLParser, factories.createWorkflow)), + ) + assert.NoError(t, err) + + factories.stepRegistry = stepRegistry + executor := lang.Must2(workflow.NewExecutor( + logger, + cfg, + stepRegistry, + builtinfunctions.GetFunctions(), + )) + wf := lang.Must2(workflow.NewYAMLConverter(stepRegistry).FromYAML([]byte(gracefullyDisabledForeachStepWorkflow))) + preparedWorkflow := lang.Must2(executor.Prepare(wf, map[string][]byte{ + "subworkflow.yaml": []byte(simpleSubWf), + })) + outputID, outputData, err := preparedWorkflow.Execute(context.Background(), map[string]any{ + "step_enabled": true, + }) + assert.NoError(t, err) + assert.Equals(t, outputID, "parent_success") + assert.Equals(t, outputData.(map[any]any), map[any]any{ + "subwf_output": map[any]any{ + "result": "enabled", + "success": map[string]any{ + "data": []any{ + map[any]any{ + "message": "Plugin slept for 0 ms.", + }, + }, + }, + }, + }) + // Test step disabled case + outputID, outputData, err = preparedWorkflow.Execute(context.Background(), map[string]any{ + "step_enabled": false, + }) + assert.NoError(t, err) + assert.Equals(t, outputID, "parent_success") + assert.Equals(t, outputData.(map[any]any), map[any]any{ + "subwf_output": map[any]any{ + "result": "disabled", + "message": "Step foreach subwf_step disabled", + }, + }) +} + var shorthandGracefullyDisabledStepWorkflow = ` version: v0.2.0 input: