diff --git a/cmd/arcaflow/main.go b/cmd/arcaflow/main.go index 967ebf90..b87ca800 100644 --- a/cmd/arcaflow/main.go +++ b/cmd/arcaflow/main.go @@ -6,6 +6,7 @@ import ( "flag" "fmt" "os" + "os/signal" "path/filepath" "strings" @@ -164,13 +165,21 @@ Options: func runWorkflow(flow engine.WorkflowEngine, dirContext map[string][]byte, workflowFile string, logger log.Logger, inputData []byte) int { ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctrlC := make(chan os.Signal, 4) // handleOSInterrupt receives up to three ctrl-C inputs; the fourth slot is a spare buffer. + signal.Notify(ctrlC, os.Interrupt) + + go handleOSInterrupt(ctrlC, cancel, logger) + defer func() { + close(ctrlC) // Closing the channel makes the pending receive in handleOSInterrupt return !ok, so the goroutine exits. NOTE(review): ctrlC is still registered with signal.Notify at this point; confirm whether signal.Stop(ctrlC) is needed before the close. + cancel() + }() workflow, err := flow.Parse(dirContext, workflowFile) if err != nil { logger.Errorf("Invalid workflow (%v)", err) return ExitCodeInvalidData } + outputID, outputData, outputError, err := workflow.Run(ctx, inputData) if err != nil { logger.Errorf("Workflow execution failed (%v)", err) @@ -193,6 +202,28 @@ func runWorkflow(flow engine.WorkflowEngine, dirContext map[string][]byte, workf return ExitCodeOK } +func handleOSInterrupt(ctrlC chan os.Signal, cancel context.CancelFunc, logger log.Logger) { + _, ok := <-ctrlC // First interrupt: ask the workflow to stop by cancelling its context. + if !ok { // The channel was closed by runWorkflow's deferred cleanup; nothing more to do. + return + } + logger.Infof("Requesting graceful shutdown.") + cancel() + + _, ok = <-ctrlC // Second interrupt: only warn that one more will force an exit. + if !ok { + return + } + logger.Warningf("Hit CTRL-C again to forcefully exit workflow without cleanup. You may need to manually delete pods or containers.") + + _, ok = <-ctrlC // Third interrupt: exit the process immediately. + if !ok { + return + } + logger.Warningf("Force exiting. 
You may need to manually delete pods or containers.") + os.Exit(1) +} + func loadContext(dir string) (map[string][]byte, error) { absDir, err := filepath.Abs(dir) if err != nil { diff --git a/cmd/run-plugin/run.go b/cmd/run-plugin/run.go index f3a19138..7675b704 100644 --- a/cmd/run-plugin/run.go +++ b/cmd/run-plugin/run.go @@ -91,7 +91,7 @@ func main() { if err != nil { panic(err) } - ctrlC := make(chan os.Signal, 1) + ctrlC := make(chan os.Signal, 1) // Buffer of one so that a signal sent before the receiver is ready is held in the channel. signal.Notify(ctrlC, os.Interrupt) // Set up the signal channel to send cancel signal on ctrl-c diff --git a/engine.go b/engine.go index d82687a3..0411a9dc 100644 --- a/engine.go +++ b/engine.go @@ -62,7 +62,12 @@ type workflowEngine struct { config *config.Config } -func (w workflowEngine) RunWorkflow(ctx context.Context, input []byte, workflowContext map[string][]byte, workflowFileName string) (outputID string, outputData any, outputError bool, err error) { +func (w workflowEngine) RunWorkflow( + ctx context.Context, + input []byte, + workflowContext map[string][]byte, + workflowFileName string, +) (outputID string, outputData any, outputError bool, err error) { wf, err := w.Parse(workflowContext, workflowFileName) if err != nil { return "", nil, true, err @@ -126,7 +131,10 @@ type engineWorkflow struct { workflow workflow.ExecutableWorkflow } -func (e engineWorkflow) Run(ctx context.Context, input []byte) (outputID string, outputData any, outputIsError bool, err error) { +func (e engineWorkflow) Run( + ctx context.Context, + input []byte, +) (outputID string, outputData any, outputIsError bool, err error) { decodedInput, err := yaml.New().Parse(input) if err != nil { return "", nil, true, fmt.Errorf("failed to YAML decode input (%w)", err) diff --git a/workflow/workflow_test.go b/workflow/workflow_test.go index 330b7ed0..350cca7c 100644 --- a/workflow/workflow_test.go +++ b/workflow/workflow_test.go @@ -493,3 +493,65 @@ func 
TestMissingInputsWrongOutput(t *testing.T) { assert.Error(t, err) assert.Equals(t, outputID, "") } + +var fiveSecWaitWorkflowDefinition = ` +version: v0.2.0 +input: + root: RootObject + objects: + RootObject: + id: RootObject + properties: {} +steps: + long_wait: + plugin: + src: "n/a" + deployment_type: "builtin" + step: wait + input: + wait_time_ms: 5000 +outputs: + success: + first_step_output: !expr $.steps.long_wait.outputs +` + +func TestEarlyContextCancellation(t *testing.T) { + // For this test, a workflow runs a single wait step (long_wait) configured with wait_time_ms: 5000. + // The context passed to Execute is created with a 3ms timeout, + // so the context is done long before the 5s wait would elapse. + // The time spent in Execute is measured, and the test fails + // if it is 1000ms or more, since that suggests the workflow did not stop early. + // The step registry comes from NewTestImplStepRegistry; the only step + // in the workflow declares deployment_type "builtin". + logConfig := log.Config{ + Level: log.LevelInfo, + Destination: log.DestinationStdout, + } + logger := log.New( + logConfig, + ) + cfg := &config.Config{ + Log: logConfig, + } + stepRegistry := NewTestImplStepRegistry(logger, t) + + executor := lang.Must2(workflow.NewExecutor( + logger, + cfg, + stepRegistry, + )) + wf := lang.Must2(workflow.NewYAMLConverter(stepRegistry).FromYAML([]byte(fiveSecWaitWorkflowDefinition))) + preparedWorkflow := lang.Must2(executor.Prepare(wf, map[string][]byte{})) + // Cancel the context after 3 ms to simulate cancellation with ctrl-c. + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*3) + startTime := time.Now() // Right before execute to not include pre-processing time. 
+ //nolint:dogsled + _, _, _ = preparedWorkflow.Execute(ctx, map[string]any{}) + cancel() + + duration := time.Since(startTime) + t.Logf("Test execution time: %s", duration) + if duration >= 1000*time.Millisecond { + t.Fatalf("Test execution time is greater than 1000 milliseconds; Is the workflow properly cancelling?") + } +}