diff --git a/src/config/config.go b/src/config/config.go index 0a4a366..0e2adcb 100644 --- a/src/config/config.go +++ b/src/config/config.go @@ -36,6 +36,9 @@ var ( // VendorPrefix is the prefix for environment variables that an application vendoring Maru wants to use VendorPrefix string + // MaxStack is the maximum stack size for task references + MaxStack = 2048 + extraEnv = map[string]string{"MARU": "true", "MARU_ARCH": GetArch()} ) diff --git a/src/pkg/runner/actions_test.go b/src/pkg/runner/actions_test.go index a82e467..944dc26 100644 --- a/src/pkg/runner/actions_test.go +++ b/src/pkg/runner/actions_test.go @@ -222,10 +222,10 @@ func Test_validateActionableTaskCall(t *testing.T) { func TestRunner_performAction(t *testing.T) { type fields struct { - TasksFile types.TasksFile - TaskNameMap map[string]bool - envFilePath string - variableConfig *variables.VariableConfig[variables.ExtraVariableInfo] + TasksFile types.TasksFile + ExistingTaskIncludeNameLocation map[string]string + envFilePath string + variableConfig *variables.VariableConfig[variables.ExtraVariableInfo] } type args struct { action types.Action @@ -243,10 +243,10 @@ func TestRunner_performAction(t *testing.T) { { name: "failed action processing due to invalid command", fields: fields{ - TasksFile: types.TasksFile{}, - TaskNameMap: make(map[string]bool), - envFilePath: "", - variableConfig: GetMaruVariableConfig(), + TasksFile: types.TasksFile{}, + ExistingTaskIncludeNameLocation: make(map[string]string), + envFilePath: "", + variableConfig: GetMaruVariableConfig(), }, args: args{ action: types.Action{ @@ -264,10 +264,10 @@ func TestRunner_performAction(t *testing.T) { { name: "Unable to open path", fields: fields{ - TasksFile: types.TasksFile{}, - TaskNameMap: make(map[string]bool), - envFilePath: "test/path", - variableConfig: GetMaruVariableConfig(), + TasksFile: types.TasksFile{}, + ExistingTaskIncludeNameLocation: make(map[string]string), + envFilePath: "test/path", + variableConfig: GetMaruVariableConfig(), }, args: args{ action: types.Action{ @@ -293,10 +293,10 @@ func TestRunner_performAction(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { r := &Runner{ - TasksFile: tt.fields.TasksFile, - TaskNameMap: tt.fields.TaskNameMap, - envFilePath: tt.fields.envFilePath, - variableConfig: tt.fields.variableConfig, + TasksFile: tt.fields.TasksFile, + ExistingTaskIncludeNameLocation: tt.fields.ExistingTaskIncludeNameLocation, + envFilePath: tt.fields.envFilePath, + variableConfig: tt.fields.variableConfig, } err := r.performAction(tt.args.action, tt.args.withs, tt.args.inputs) if (err != nil) != tt.wantErr { @@ -308,10 +308,10 @@ func TestRunner_performAction(t *testing.T) { func TestRunner_processAction(t *testing.T) { type fields struct { - TasksFile types.TasksFile - TaskNameMap map[string]bool - envFilePath string - variableConfig *variables.VariableConfig[variables.ExtraVariableInfo] + TasksFile types.TasksFile + ExistingTaskIncludeNameLocation map[string]string + envFilePath string + variableConfig *variables.VariableConfig[variables.ExtraVariableInfo] } type args struct { task types.Task @@ -326,10 +326,10 @@ func TestRunner_processAction(t *testing.T) { { name: "successful action processing", fields: fields{ - TasksFile: types.TasksFile{}, - TaskNameMap: map[string]bool{}, - envFilePath: "", - variableConfig: GetMaruVariableConfig(), + TasksFile: types.TasksFile{}, + ExistingTaskIncludeNameLocation: map[string]string{}, + envFilePath: "", + variableConfig: GetMaruVariableConfig(), }, args: args{ task: types.Task{ @@ -344,10 +344,10 @@ func TestRunner_processAction(t *testing.T) { { name: "action processing with same task and action reference", fields: fields{ - TasksFile: types.TasksFile{}, - TaskNameMap: map[string]bool{}, - envFilePath: "", - variableConfig: GetMaruVariableConfig(), + TasksFile: types.TasksFile{}, + ExistingTaskIncludeNameLocation: map[string]string{}, + envFilePath: "", + variableConfig: GetMaruVariableConfig(), }, args: args{ task: types.Task{ @@ -362,10 +362,10 @@ func TestRunner_processAction(t *testing.T) { { name: "action processing with empty task reference", fields: fields{ - TasksFile: types.TasksFile{}, - TaskNameMap: map[string]bool{}, - envFilePath: "", - variableConfig: GetMaruVariableConfig(), + TasksFile: types.TasksFile{}, + ExistingTaskIncludeNameLocation: map[string]string{}, + envFilePath: "", + variableConfig: GetMaruVariableConfig(), }, args: args{ task: types.Task{ @@ -380,10 +380,10 @@ func TestRunner_processAction(t *testing.T) { { name: "action processing with non-empty task reference and different task and action reference names", fields: fields{ - TasksFile: types.TasksFile{}, - TaskNameMap: map[string]bool{}, - envFilePath: "", - variableConfig: GetMaruVariableConfig(), + TasksFile: types.TasksFile{}, + ExistingTaskIncludeNameLocation: map[string]string{}, + envFilePath: "", + variableConfig: GetMaruVariableConfig(), }, args: args{ task: types.Task{ @@ -405,9 +405,9 @@ func TestRunner_processAction(t *testing.T) { }, }, }, - TaskNameMap: map[string]bool{}, - envFilePath: "", - variableConfig: GetMaruVariableConfig(), + ExistingTaskIncludeNameLocation: map[string]string{}, + envFilePath: "", + variableConfig: GetMaruVariableConfig(), }, args: args{ task: types.Task{ @@ -423,10 +423,10 @@ func TestRunner_processAction(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { r := &Runner{ - TasksFile: tt.fields.TasksFile, - TaskNameMap: tt.fields.TaskNameMap, - envFilePath: tt.fields.envFilePath, - variableConfig: tt.fields.variableConfig, + TasksFile: tt.fields.TasksFile, + ExistingTaskIncludeNameLocation: tt.fields.ExistingTaskIncludeNameLocation, + envFilePath: tt.fields.envFilePath, + variableConfig: tt.fields.variableConfig, } if got := r.processAction(tt.args.task, tt.args.action); got != tt.want { t.Errorf("processAction() got = %v, want %v", got, tt.want) diff --git a/src/pkg/runner/runner.go b/src/pkg/runner/runner.go index 0a1d686..1f73b3f 100644 --- a/src/pkg/runner/runner.go +++ b/src/pkg/runner/runner.go @@ -23,11 +23,12 @@ import ( // Runner holds the necessary data to run tasks from a tasks file type Runner struct { - TasksFile types.TasksFile - TaskNameMap map[string]bool - envFilePath string - variableConfig *variables.VariableConfig[variables.ExtraVariableInfo] - dryRun bool + TasksFile types.TasksFile + ExistingTaskIncludeNameLocation map[string]string + envFilePath string + variableConfig *variables.VariableConfig[variables.ExtraVariableInfo] + dryRun bool + currStackSize int } // Run runs a task from tasks file @@ -62,10 +63,10 @@ func Run(tasksFile types.TasksFile, taskName string, setVariables map[string]str // Create the runner client to execute the task file runner := Runner{ - TasksFile: tasksFile, - TaskNameMap: map[string]bool{}, - variableConfig: combinedVariableConfig, - dryRun: dryRun, + TasksFile: tasksFile, + ExistingTaskIncludeNameLocation: map[string]string{}, + variableConfig: combinedVariableConfig, + dryRun: dryRun, } task, err := runner.getTask(taskName) @@ -78,7 +79,7 @@ func Run(tasksFile types.TasksFile, taskName string, setVariables map[string]str return err } - if err = runner.checkForTaskLoops(task, runner.TasksFile, setVariables); err != nil { + if err = runner.processTaskReferences(task, runner.TasksFile, setVariables); err != nil { return err } @@ -113,63 +114,69 @@ func (r *Runner) processIncludes(tasksFile types.TasksFile, setVariables map[str func (r *Runner) importTasks(includes []map[string]string, currentFileLocation string, setVariables map[string]string) error { // iterate through includes, open the file, and unmarshal it into a Task - var includeFileLocationKey string - var includeFileLocation string + var includeKey string + var includeLocation string for _, include := range includes { if len(include) > 1 { return fmt.Errorf("included item %s must have only one key", include) } // grab first and only value from include map for k, v := range include { - includeFileLocationKey = k - includeFileLocation = v + includeKey = k + includeLocation = v break } - includeFileLocation = utils.TemplateString(r.variableConfig.GetSetVariables(), includeFileLocation) + includeLocation = utils.TemplateString(r.variableConfig.GetSetVariables(), includeLocation) - absIncludeFileLocation, tasksFile, err := loadIncludeTask(currentFileLocation, includeFileLocation) + absIncludeFileLocation, tasksFile, err := loadIncludeTask(currentFileLocation, includeLocation) if err != nil { return fmt.Errorf("unable to read included file: %w", err) } + // If we arrive here we assume this was a new include due to the later check + r.ExistingTaskIncludeNameLocation[includeKey] = absIncludeFileLocation // prefix task names and actions with the includes key for i, t := range tasksFile.Tasks { - tasksFile.Tasks[i].Name = includeFileLocationKey + ":" + t.Name + tasksFile.Tasks[i].Name = includeKey + ":" + t.Name if len(tasksFile.Tasks[i].Actions) > 0 { for j, a := range tasksFile.Tasks[i].Actions { if a.TaskReference != "" && !strings.Contains(a.TaskReference, ":") { - tasksFile.Tasks[i].Actions[j].TaskReference = includeFileLocationKey + ":" + a.TaskReference + tasksFile.Tasks[i].Actions[j].TaskReference = includeKey + ":" + a.TaskReference } } } } - err = r.checkProcessedTasksForLoops(tasksFile) - if err != nil { - return err - } - r.TasksFile.Tasks = append(r.TasksFile.Tasks, tasksFile.Tasks...) r.mergeVariablesFromIncludedTask(tasksFile) // recursively import tasks from included files if tasksFile.Includes != nil { - if err := r.importTasks(tasksFile.Includes, absIncludeFileLocation, setVariables); err != nil { - return err + newIncludes := []map[string]string{} + var newIncludeKey string + var newIncludeLocation string + for _, newInclude := range tasksFile.Includes { + for k, v := range newInclude { + newIncludeKey = k + newIncludeLocation = v + break + } + if existingLocation, exists := r.ExistingTaskIncludeNameLocation[newIncludeKey]; !exists { + newIncludes = append(newIncludes, map[string]string{newIncludeKey: newIncludeLocation}) + } else { + newAbsIncludeFileLocation, err := includeTaskAbsLocation(absIncludeFileLocation, newIncludeLocation) + if err != nil { + return err + } + if existingLocation != newAbsIncludeFileLocation { + return fmt.Errorf("task include %q attempted to be redefined from %q to %q", includeKey, existingLocation, newAbsIncludeFileLocation) + } + } } - } - } - return nil -} - -func (r *Runner) checkProcessedTasksForLoops(tasksFile types.TasksFile) error { - // The following for loop protects against task loops. Makes sure the task being added hasn't already been processed - for _, taskToAdd := range tasksFile.Tasks { - for _, currentTasks := range r.TasksFile.Tasks { - if taskToAdd.Name == currentTasks.Name { - return fmt.Errorf("task loop detected, ensure no cyclic loops in tasks or includes files") + if err := r.importTasks(newIncludes, absIncludeFileLocation, setVariables); err != nil { + return err } } } @@ -207,17 +214,14 @@ func loadIncludedTaskFile(taskFile types.TasksFile, taskName string, setVariable return taskFile, taskName, nil } -func loadIncludeTask(currentFileLocation, includeFileLocation string) (string, types.TasksFile, error) { - var localPath string - var includedTasksFile types.TasksFile +func includeTaskAbsLocation(currentFileLocation, includeFileLocation string) (string, error) { var absIncludeFileLocation string - var err error if !helpers.IsURL(includeFileLocation) { if helpers.IsURL(currentFileLocation) { currentURL, err := url.Parse(currentFileLocation) if err != nil { - return absIncludeFileLocation, includedTasksFile, err + return absIncludeFileLocation, err } currentURL.Path = path.Join(path.Dir(currentURL.Path), includeFileLocation) absIncludeFileLocation = currentURL.String() @@ -228,6 +232,17 @@ func loadIncludeTask(currentFileLocation, includeFileLocation string) (string, t } else { absIncludeFileLocation = includeFileLocation } + return absIncludeFileLocation, nil +} + +func loadIncludeTask(currentFileLocation, includeFileLocation string) (string, types.TasksFile, error) { + var localPath string + var includedTasksFile types.TasksFile + + absIncludeFileLocation, err := includeTaskAbsLocation(currentFileLocation, includeFileLocation) + if err != nil { + return absIncludeFileLocation, includedTasksFile, err + } // If the file is in fact a URL we need to download and load the YAML if helpers.IsURL(absIncludeFileLocation) { @@ -262,6 +277,15 @@ func (r *Runner) getTask(taskName string) (types.Task, error) { } func (r *Runner) executeTask(task types.Task, withs map[string]string) error { + if r.currStackSize > config.MaxStack { + return fmt.Errorf("task looping exceeded max configured task stack of %d", config.MaxStack) + } + + r.currStackSize++ + defer func() { + r.currStackSize-- + }() + defaultEnv := []string{} for name, inputParam := range task.Inputs { d := inputParam.Default @@ -278,15 +302,24 @@ func (r *Runner) executeTask(task types.Task, withs map[string]string) error { for _, action := range task.Actions { action.Env = utils.MergeEnv(action.Env, defaultEnv) - if err := r.performAction(action, withs, task.Inputs); err != nil { return err } } + return nil } -func (r *Runner) checkForTaskLoops(task types.Task, tasksFile types.TasksFile, setVariables map[string]string) error { +func (r *Runner) processTaskReferences(task types.Task, tasksFile types.TasksFile, setVariables map[string]string) error { + if r.currStackSize > config.MaxStack { + return fmt.Errorf("task looping exceeded max configured task stack of %d", config.MaxStack) + } + + r.currStackSize++ + defer func() { + r.currStackSize-- + }() + // Filtering unique task actions allows for rerunning tasks in the same execution uniqueTaskActions := getUniqueTaskActions(task.Actions) for _, action := range uniqueTaskActions { @@ -296,21 +329,14 @@ func (r *Runner) checkForTaskLoops(task types.Task, tasksFile types.TasksFile, s return err } - exists := r.TaskNameMap[action.TaskReference] - if exists { - return fmt.Errorf("task loop detected, ensure no cyclic loops in tasks or includes files") - } - r.TaskNameMap[action.TaskReference] = true newTask, err := r.getTask(action.TaskReference) if err != nil { return err } - if err = r.checkForTaskLoops(newTask, tasksFile, setVariables); err != nil { + if err = r.processTaskReferences(newTask, tasksFile, setVariables); err != nil { return err } } - // Clear map once we get to a task that doesn't call another task - clear(r.TaskNameMap) } return nil } diff --git a/src/test/e2e/runner_test.go b/src/test/e2e/runner_test.go index 6248bbf..21440a3 100644 --- a/src/test/e2e/runner_test.go +++ b/src/test/e2e/runner_test.go @@ -60,15 +60,30 @@ func TestTaskRunner(t *testing.T) { stdOut, stdErr, err := e2e.Maru("run", "recursive", "--file", "src/test/tasks/tasks.yaml") require.Error(t, err, stdOut, stdErr) - require.Contains(t, stdErr, "task loop detected") + require.Contains(t, stdErr, "task looping exceeded max configured task stack") }) - t.Run("includes task loop", func(t *testing.T) { + t.Run("run direct loop", func(t *testing.T) { t.Parallel() - stdOut, stdErr, err := e2e.Maru("run", "include-loop", "--file", "src/test/tasks/tasks.yaml") + stdOut, stdErr, err := e2e.Maru("run", "direct-loop", "--file", "src/test/tasks/loop-task.yaml") require.Error(t, err, stdOut, stdErr) - require.Contains(t, stdErr, "task loop detected") + require.Contains(t, stdErr, "task looping exceeded max configured task stack") + }) + + t.Run("includes intentional task loop", func(t *testing.T) { + t.Parallel() + + // get current git revision + gitRev, err := e2e.GetGitRevision() + if err != nil { + return + } + setVar := fmt.Sprintf("GIT_REVISION=%s", gitRev) + stdOut, stdErr, err := e2e.Maru("run", "include-loop", "--set", setVar, "--file", "src/test/tasks/tasks.yaml") + require.NoError(t, err, stdOut, stdErr) + require.Contains(t, stdErr, "9") + require.Contains(t, stdErr, "0") }) t.Run("run cmd-set-variable with --set", func(t *testing.T) { @@ -148,7 +163,7 @@ func TestTaskRunner(t *testing.T) { t.Parallel() stdOut, stdErr, err := e2e.Maru("run", "rerun-tasks-recursive", "--file", "src/test/tasks/tasks.yaml") require.Error(t, err, stdOut, stdErr) - require.Contains(t, stdErr, "task loop detected") + require.Contains(t, stdErr, "task looping exceeded max configured task stack") }) t.Run("run interactive (with --no-progress)", func(t *testing.T) { @@ -470,4 +485,12 @@ func TestTaskRunner(t *testing.T) { require.Contains(t, stdErr, "Dry-running \"echo $MARU_ARCH\"") require.Contains(t, stdOut, "echo env var from calling task - $SECRET_KEY") }) + + t.Run("redefined include", func(t *testing.T) { + t.Parallel() + + stdOut, stdErr, err := e2e.Maru("run", "--file", "src/test/tasks/redefined-include.yaml") + require.Error(t, err, stdOut, stdErr) + require.Contains(t, stdErr, "task include \"foo\" attempted to be redefined") + }) } diff --git a/src/test/tasks/loop-task.yaml b/src/test/tasks/loop-task.yaml index 65a15e2..beea0c4 100644 --- a/src/test/tasks/loop-task.yaml +++ b/src/test/tasks/loop-task.yaml @@ -4,7 +4,19 @@ includes: - original: "./tasks.yaml" +variables: + - name: LOOP_COUNT + default: "10" + tasks: - name: loop actions: + - cmd: echo $((LOOP_COUNT - 1)) + setVariables: + - name: LOOP_COUNT - task: original:include-loop + if: ${{ ne .variables.LOOP_COUNT "0" }} + + - name: direct-loop + actions: + - task: direct-loop diff --git a/src/test/tasks/redefined-include.yaml b/src/test/tasks/redefined-include.yaml new file mode 100644 index 0000000..44194b1 --- /dev/null +++ b/src/test/tasks/redefined-include.yaml @@ -0,0 +1,7 @@ +includes: + - foo: "./tasks.yaml" + +tasks: + - name: default + actions: + - task: foo:foobar diff --git a/src/test/tasks/remote-import-tasks.yaml b/src/test/tasks/remote-import-tasks.yaml index 1c26d39..10ed20f 100644 --- a/src/test/tasks/remote-import-tasks.yaml +++ b/src/test/tasks/remote-import-tasks.yaml @@ -1,11 +1,11 @@ includes: - - remote: https://raw.githubusercontent.com/defenseunicorns/maru-runner/${GIT_REVISION}/src/test/tasks/even-more-tasks-to-import.yaml + - remote-more: https://raw.githubusercontent.com/defenseunicorns/maru-runner/${GIT_REVISION}/src/test/tasks/even-more-tasks-to-import.yaml - baz: ./more-tasks/baz.yaml tasks: - name: echo-var actions: - - task: remote:set-var + - task: remote-more:set-var - cmd: | echo "${PRETTY_OK_COMPANY} is a pretty ok company" diff --git a/src/test/tasks/tasks.yaml b/src/test/tasks/tasks.yaml index 6f04a4b..d0c21e9 100644 --- a/src/test/tasks/tasks.yaml +++ b/src/test/tasks/tasks.yaml @@ -1,6 +1,6 @@ includes: - foo: ./more-tasks/foo.yaml - - infinite: ./loop-task.yaml + - intentional: ./loop-task.yaml - remote: https://raw.githubusercontent.com/defenseunicorns/maru-runner/${GIT_REVISION}/src/test/tasks/remote-import-tasks.yaml variables: @@ -139,7 +139,7 @@ tasks: namespace: tasks - name: include-loop actions: - - task: infinite:loop + - task: intentional:loop - name: env-from-file envPath: "./my-env" actions: