diff --git a/pkg/runner/step_runner.go b/pkg/runner/step_runner.go index 60d91d33..266652c5 100644 --- a/pkg/runner/step_runner.go +++ b/pkg/runner/step_runner.go @@ -15,29 +15,28 @@ import ( "github.com/linuxboot/contest/pkg/xcontext" ) -type OnTargetResult func(res error) +type OnTargetResult func(tgt *target.Target, res error) type OnStepRunnerStopped func(err error) type StepRunner struct { - ctx xcontext.Context - cancel context.CancelFunc - mu sync.Mutex - stopCallback OnStepRunnerStopped + ctx xcontext.Context + cancel context.CancelFunc + mu sync.Mutex + targetCallback OnTargetResult + stopCallback OnStepRunnerStopped stepIn chan *target.Target - addedTargets map[string]OnTargetResult + stopOnce sync.Once + reportedTargets map[string]struct{} + started uint32 runningLoopActive bool - - stopped chan struct{} + finished chan struct{} resultErr error resultResumeState json.RawMessage } -func (sr *StepRunner) AddTarget(tgt *target.Target, callback OnTargetResult) error { - if callback == nil { - return fmt.Errorf("callback should not be nil") - } - err := func(tgt *target.Target, callback OnTargetResult) error { +func (sr *StepRunner) AddTarget(tgt *target.Target) error { + err := func(tgt *target.Target) error { sr.mu.Lock() defer sr.mu.Unlock() @@ -48,14 +47,8 @@ func (sr *StepRunner) AddTarget(tgt *target.Target, callback OnTargetResult) err if sr.stepIn == nil { return fmt.Errorf("step runner is stopped") } - - existingCb := sr.addedTargets[tgt.ID] - if existingCb != nil { - return fmt.Errorf("existing target") - } - sr.addedTargets[tgt.ID] = callback return nil - }(tgt, callback) + }(tgt) if err != nil { return err } @@ -71,18 +64,61 @@ func (sr *StepRunner) AddTarget(tgt *target.Target, callback OnTargetResult) err return nil } -func (sr *StepRunner) IsRunning() bool { +func (sr *StepRunner) Run( + bundle test.TestStepBundle, + ev testevent.Emitter, + resumeState json.RawMessage, +) { + if !atomic.CompareAndSwapUint32(&sr.started, 0, 1) { + return + } + + var activeLoopsCount int32 = 2 + onFinished := func() { + if atomic.AddInt32(&activeLoopsCount, -1) != 0 { + return + } + sr.mu.Lock() + defer sr.mu.Unlock() + + close(sr.finished) + sr.finished = nil + + // if an error occurred, this callback was invoked early + sr.notifyStoppedLocked(nil) + sr.ctx.Debugf("StepRunner finished") + } + + stepOut := make(chan test.TestStepResult) + go func() { + defer onFinished() + sr.runningLoop(stepOut, bundle, ev, resumeState) + sr.ctx.Debugf("Running loop finished") + }() + + go func() { + defer onFinished() + sr.readingLoop(stepOut, bundle.TestStepLabel) + sr.ctx.Debugf("Reading loop finished") + }() +} + +func (sr *StepRunner) Started() bool { + return atomic.LoadUint32(&sr.started) == 1 +} + +func (sr *StepRunner) Running() bool { sr.mu.Lock() defer sr.mu.Unlock() - return sr.stopped != nil + return sr.Started() && sr.finished != nil } func (sr *StepRunner) WaitResults(ctx context.Context) (json.RawMessage, error) { sr.mu.Lock() resultErr := sr.resultErr resultResumeState := sr.resultResumeState - stepStopped := sr.stopped + stepStopped := sr.finished sr.mu.Unlock() if resultErr != nil { @@ -104,24 +140,22 @@ func (sr *StepRunner) WaitResults(ctx context.Context) (json.RawMessage, error) // Stop triggers TestStep to stop running by closing input channel func (sr *StepRunner) Stop() { - sr.mu.Lock() - defer sr.mu.Unlock() - - if sr.stepIn != nil { - sr.ctx.Debugf("Close input channel") + sr.stopOnce.Do(func() { close(sr.stepIn) - sr.stepIn = nil - } + sr.ctx.Debugf("Input channel was closed") + }) } -func (sr *StepRunner) readingLoop(stepOut chan test.TestStepResult, testStepLabel string) { - invokePanicSafe := func(callback OnTargetResult, res error) { +func (sr *StepRunner) targetCallbackPanicSafe(tgt *target.Target, res error) { + defer func() { if r := recover(); r != nil { sr.ctx.Errorf("Callback panic, stack: %s", debug.Stack()) } - callback(res) - } + }() + sr.targetCallback(tgt, res) +} +func (sr *StepRunner) readingLoop(stepOut chan test.TestStepResult, testStepLabel string) { for { select { case res, ok := <-stepOut: @@ -141,21 +175,18 @@ func (sr *StepRunner) readingLoop(stepOut chan test.TestStepResult, testStepLabe sr.setErr(&cerrors.ErrTestStepReturnedNoTarget{StepName: testStepLabel}) return } + sr.ctx.Infof("Obtained '%v' for target '%s'", res, res.Target.ID) sr.mu.Lock() - callback, found := sr.addedTargets[res.Target.ID] - sr.addedTargets[res.Target.ID] = nil + _, found := sr.reportedTargets[res.Target.ID] + sr.reportedTargets[res.Target.ID] = struct{}{} sr.mu.Unlock() - if !found { - sr.setErr(&cerrors.ErrTestStepReturnedUnexpectedResult{StepName: testStepLabel, Target: res.Target.ID}) - return - } - if callback == nil { + if found { sr.setErr(&cerrors.ErrTestStepReturnedDuplicateResult{StepName: testStepLabel, Target: res.Target.ID}) return } - invokePanicSafe(callback, res.Err) + sr.targetCallbackPanicSafe(res.Target, res.Err) case <-sr.ctx.Done(): sr.ctx.Debugf("canceled readingLoop") @@ -165,7 +196,7 @@ func (sr *StepRunner) readingLoop(stepOut chan test.TestStepResult, testStepLabe } func (sr *StepRunner) runningLoop( - stepIn chan *target.Target, stepOut chan test.TestStepResult, + stepOut chan test.TestStepResult, bundle test.TestStepBundle, ev testevent.Emitter, resumeState json.RawMessage, ) { defer func() { @@ -190,9 +221,13 @@ func (sr *StepRunner) runningLoop( } }() - inChannels := test.TestStepChannels{In: stepIn, Out: stepOut} + sr.mu.Lock() + sr.runningLoopActive = true + sr.mu.Unlock() + + inChannels := test.TestStepChannels{In: sr.stepIn, Out: stepOut} resultResumeState, err := bundle.TestStep.Run(sr.ctx, inChannels, bundle.Parameters, ev, resumeState) - sr.ctx.Debugf("Step runner finished '%v', rs %s", err, string(resultResumeState)) + sr.ctx.Debugf("TestStep finished '%v', rs %s", err, string(resultResumeState)) sr.mu.Lock() sr.setErrLocked(err) @@ -246,48 +281,24 @@ func safeCloseOutCh(ch chan test.TestStepResult) (recoverOccurred bool) { // NewStepRunner creates a new StepRunner object func NewStepRunner( ctx xcontext.Context, - bundle test.TestStepBundle, - ev testevent.Emitter, - resumeState json.RawMessage, + targetCallback OnTargetResult, stoppedCallback OnStepRunnerStopped, ) *StepRunner { + if targetCallback == nil { + panic("target callback should not be nil") + } stepIn := make(chan *target.Target) - stepOut := make(chan test.TestStepResult) srCrx, cancel := xcontext.WithCancel(ctx) sr := &StepRunner{ ctx: srCrx, cancel: cancel, stepIn: stepIn, - addedTargets: make(map[string]OnTargetResult), - runningLoopActive: true, - stopped: make(chan struct{}), + targetCallback: targetCallback, stopCallback: stoppedCallback, + reportedTargets: make(map[string]struct{}), + runningLoopActive: true, + finished: make(chan struct{}), } - - var activeLoopsCount int32 = 2 - onFinished := func() { - if atomic.AddInt32(&activeLoopsCount, -1) != 0 { - return - } - sr.mu.Lock() - defer sr.mu.Unlock() - - close(sr.stopped) - sr.stopped = nil - - // if an error occurred, this callback was invoked early - sr.notifyStoppedLocked(nil) - } - - go func() { - defer onFinished() - sr.runningLoop(stepIn, stepOut, bundle, ev, resumeState) - }() - - go func() { - defer onFinished() - sr.readingLoop(stepOut, bundle.TestStepLabel) - }() return sr } diff --git a/pkg/runner/test_runner.go b/pkg/runner/test_runner.go index 144786de..4d447dcb 100644 --- a/pkg/runner/test_runner.go +++ b/pkg/runner/test_runner.go @@ -9,6 +9,7 @@ import ( "context" "encoding/json" "fmt" + "strconv" "sync" "time" @@ -138,19 +139,41 @@ func (tr *TestRunner) Run( // Set up the pipeline for i, sb := range t.TestStepsBundles { stepCtx, stepCancel := xcontext.WithCancel(stepsCtx) + stepCtx = stepCtx.WithField("step_index", strconv.Itoa(i)) + stepCtx = stepCtx.WithField("step_label", sb.TestStepLabel) + var srs json.RawMessage if i < len(rs.StepResumeState) && string(rs.StepResumeState[i]) != "null" { srs = rs.StepResumeState[i] } - tr.steps = append(tr.steps, &stepState{ + ss := &stepState{ ctx: stepCtx, cancel: stepCancel, stepIndex: i, sb: sb, ev: emitterFactory.New(sb.TestStepLabel), resumeState: srs, - }) + } // Step handlers will be started from target handlers as targets reach them. + ss.stepRunner = NewStepRunner( + ss.ctx, + func(tgt *target.Target, res error) { + if err := tr.reportTargetResult(ss, tgt, res); err != nil { + ctx.Errorf("Reporting target result failed: %v", err) + tr.mu.Lock() + ss.setErrLocked(err) + tr.mu.Unlock() + } + tr.cond.Signal() + }, + func(err error) { + tr.mu.Lock() + defer tr.mu.Unlock() + ss.setErrLocked(err) + tr.cond.Signal() + }, + ) + tr.steps = append(tr.steps, ss) } // Set up the targets @@ -231,13 +254,13 @@ func (tr *TestRunner) Run( if stepErr != nil && stepErr != xcontext.ErrPaused { resumeOk = false } - ctx.Debugf(" %d %s %v %t", i, tgs, stepErr, resumeOk) + ctx.Debugf(" %d %s %v %t", i, tgs.String(), stepErr, resumeOk) } ctx.Debugf("- %d in flight, ok to resume? %t", numInFlightTargets, resumeOk) ctx.Debugf("step states:") for i, ss := range tr.steps { ctx.Debugf(" %d %s %t %t %v %s", - i, ss, ss.stepRunner != nil, ss.stepRunner != nil && ss.stepRunner.IsRunning(), ss.runErr, ss.resumeState) + i, ss, ss.stepRunner.Started(), ss.stepRunner.Running(), ss.runErr, ss.resumeState) } // Is there a useful error to report? @@ -287,14 +310,10 @@ func (tr *TestRunner) waitStepRunners(ctx xcontext.Context) error { var neverReturnedErr *cerrors.ErrTestStepsNeverReturned for _, ss := range tr.steps { - tr.mu.Lock() - stepRunner := ss.stepRunner - tr.mu.Unlock() - - if stepRunner == nil { + if !ss.stepRunner.Started() { continue } - resumeState, err := stepRunner.WaitResults(shutdownCtx) + resumeState, err := ss.stepRunner.WaitResults(shutdownCtx) if err == context.DeadlineExceeded { err = &cerrors.ErrTestStepsNeverReturned{StepNames: []string{ss.sb.TestStepLabel}} if neverReturnedErr == nil { @@ -331,15 +350,7 @@ func (tr *TestRunner) injectTarget(ctx xcontext.Context, tgs *targetState, ss *s ctx.Debugf("%s: injecting into %s", tgs, ss) tgt := tgs.tgt - err := ss.stepRunner.AddTarget(tgt, func(res error) { - if err := tr.reportTargetResult(ss, tgt, res); err != nil { - ctx.Errorf("Reporting target result failed: %v", err) - tr.mu.Lock() - defer tr.mu.Unlock() - ss.setErrLocked(err) - } - tr.cond.Signal() - }) + err := ss.stepRunner.AddTarget(tgt) if err == nil { if err = ss.emitEvent(ctx, target.EventTargetIn, tgs.tgt, nil); err != nil { err = fmt.Errorf("failed to report target injection: %w", err) @@ -493,20 +504,7 @@ loop: // runStepIfNeeded starts the step runner goroutine if not already running. func (tr *TestRunner) runStepIfNeeded(ss *stepState) { - tr.mu.Lock() - defer tr.mu.Unlock() - - if ss.stepRunner != nil { - return - } - ctx := xcontext.WithValue(ss.ctx, "step_index", ss.stepIndex) - ctx = xcontext.WithValue(ctx, "step_label", ss.sb.TestStepLabel) - ss.stepRunner = NewStepRunner(ctx, ss.sb, ss.ev, ss.resumeState, func(err error) { - tr.mu.Lock() - defer tr.mu.Unlock() - ss.setErrLocked(err) - tr.cond.Signal() - }) + ss.stepRunner.Run(ss.sb, ss.ev, ss.resumeState) } // setErrLocked sets step runner error unless already set. @@ -624,11 +622,11 @@ func (tr *TestRunner) runMonitor(ctx xcontext.Context, minStep int) error { stepLoop: for step := minStep; step < len(tr.steps); pass++ { ss := tr.steps[step] - ctx.Debugf("monitor pass %d: current step %s", pass, ss) + ctx.Debugf("monitor pass %d: current step %s", pass, ss.String()) // Check if all the targets have either made it past the injection phase or terminated. ok := true for _, tgs := range tr.targets { - ctx.Debugf("monitor pass %d: %s: %s", pass, ss, tgs) + ctx.Debugf("monitor pass %d: %s: %s", pass, ss, tgs.String()) if !tgs.handlerRunning { // Not running anymore continue } @@ -648,9 +646,7 @@ stepLoop: } // All targets ok, close the step's input channel. ctx.Debugf("monitor pass %d: %s: no more targets, closing input channel", pass, ss) - if ss.stepRunner != nil { - ss.stepRunner.Stop() - } + ss.stepRunner.Stop() step++ } // Wait for all the targets to finish. @@ -671,7 +667,7 @@ tgtLoop: // It's been paused, this is fine. continue } - if ss.stepRunner != nil && !ss.stepRunner.IsRunning() { + if ss.stepRunner.Started() && !ss.stepRunner.Running() { // Target has been injected but step runner has exited without a valid reason, this target has been lost. runErr = &cerrors.ErrTestStepLostTargets{ StepName: ss.sb.TestStepLabel, @@ -741,5 +737,5 @@ func (tgs *targetState) String() string { } finished := !tgs.handlerRunning return fmt.Sprintf("[%s %d %s %t %s]", - tgs.tgt, tgs.CurStep, tgs.CurPhase, finished, resText) + tgs.tgt, tgs.CurStep, tgs.CurPhase.String(), finished, resText) }