Skip to content

Commit

Permalink
Merge pull request #13 from ActiveState/DX-2901
Browse files Browse the repository at this point in the history
Reimplement process exit expect failure
  • Loading branch information
Naatan authored Jun 28, 2024
2 parents 45f7444 + e176b0b commit 3b5b4ea
Show file tree
Hide file tree
Showing 9 changed files with 65 additions and 72 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ name: unit-tests

on:
push:
branches: [ master ]
branches: [ master, v2-wip ]
pull_request:
branches: [ master ]
branches: [ master, v2-wip]

jobs:

Expand Down
10 changes: 5 additions & 5 deletions expect.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func (tt *TermTest) ExpectCustom(consumer consumer, opts ...SetExpectOpt) (rerr
return fmt.Errorf("could not create expect options: %w", err)
}

cons, err := tt.outputProducer.addConsumer(tt, consumer, expectOpts.ToConsumerOpts()...)
cons, err := tt.outputProducer.addConsumer(consumer, expectOpts.ToConsumerOpts()...)
if err != nil {
return fmt.Errorf("could not add consumer: %w", err)
}
Expand Down Expand Up @@ -180,11 +180,11 @@ func (tt *TermTest) expectExitCode(exitCode int, match bool, opts ...SetExpectOp
select {
case <-time.After(timeoutV):
return fmt.Errorf("after %s: %w", timeoutV, TimeoutError)
case state := <-tt.Exited(false): // do not wait for unread output since it's not read by this select{}
if state.Err != nil && (state.ProcessState == nil || state.ProcessState.ExitCode() == 0) {
return fmt.Errorf("cmd wait failed: %w", state.Err)
case err := <-waitChan(tt.cmd.Wait):
if err != nil && (tt.cmd.ProcessState == nil || tt.cmd.ProcessState.ExitCode() == 0) {
return fmt.Errorf("cmd wait failed: %w", err)
}
if err := tt.assertExitCode(state.ProcessState.ExitCode(), exitCode, match); err != nil {
if err := tt.assertExitCode(tt.cmd.ProcessState.ExitCode(), exitCode, match); err != nil {
return err
}
}
Expand Down
6 changes: 3 additions & 3 deletions expect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func Test_ExpectCustom(t *testing.T) {
[]SetExpectOpt{OptExpectTimeout(time.Second)},
},
"",
TimeoutError,
PtyEOF,
},
{
"Custom error",
Expand Down Expand Up @@ -167,7 +167,7 @@ func Test_ExpectCustom_Cmd(t *testing.T) {
},
[]SetExpectOpt{OptExpectTimeout(time.Second)},
},
TimeoutError,
PtyEOF,
},
{
"Custom error",
Expand All @@ -194,7 +194,7 @@ func Test_ExpectCustom_Cmd(t *testing.T) {
}

func Test_Expect_Timeout(t *testing.T) {
tt := newTermTest(t, exec.Command("bash", "-c", "echo HELLO"), false)
tt := newTermTest(t, exec.Command("bash", "-c", "echo HELLO && sleep 1"), false)
durations := []time.Duration{
100 * time.Millisecond,
200 * time.Millisecond,
Expand Down
15 changes: 2 additions & 13 deletions helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"bytes"
"errors"
"os"
"os/exec"
"strings"
"time"
)
Expand All @@ -23,20 +22,10 @@ type cmdExit struct {
Err error
}

// waitForCmdExit turns process.wait() into a channel so that it can be used within a select{} statement
func waitForCmdExit(cmd *exec.Cmd) chan *cmdExit {
exit := make(chan *cmdExit, 1)
go func() {
err := cmd.Wait()
exit <- &cmdExit{ProcessState: cmd.ProcessState, Err: err}
}()
return exit
}

func waitChan[T any](wait func() T) chan T {
done := make(chan T)
done := make(chan T, 1)
go func() {
done <- wait()
wait()
close(done)
}()
return done
Expand Down
27 changes: 18 additions & 9 deletions outputconsumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ type outputConsumer struct {
opts *OutputConsumerOpts
isalive bool
mutex *sync.Mutex
tt *TermTest
}

type OutputConsumerOpts struct {
Expand All @@ -37,7 +36,7 @@ func OptsConsTimeout(timeout time.Duration) func(o *OutputConsumerOpts) {
}
}

func newOutputConsumer(tt *TermTest, consume consumer, opts ...SetConsOpt) *outputConsumer {
func newOutputConsumer(consume consumer, opts ...SetConsOpt) *outputConsumer {
oc := &outputConsumer{
consume: consume,
opts: &OutputConsumerOpts{
Expand All @@ -47,7 +46,6 @@ func newOutputConsumer(tt *TermTest, consume consumer, opts ...SetConsOpt) *outp
waiter: make(chan error, 1),
isalive: true,
mutex: &sync.Mutex{},
tt: tt,
}

for _, optSetter := range opts {
Expand Down Expand Up @@ -83,6 +81,23 @@ func (e *outputConsumer) Report(buffer []byte) (int, error) {
return pos, err
}

type errConsumerStopped struct {
reason error
}

func (e errConsumerStopped) Error() string {
return fmt.Sprintf("consumer stopped, reason: %s", e.reason)
}

func (e errConsumerStopped) Unwrap() error {
return e.reason
}

func (e *outputConsumer) Stop(reason error) {
e.opts.Logger.Printf("stopping consumer, reason: %s\n", reason)
e.waiter <- errConsumerStopped{reason}
}

func (e *outputConsumer) wait() error {
e.opts.Logger.Println("started waiting")
defer e.opts.Logger.Println("stopped waiting")
Expand All @@ -103,11 +118,5 @@ func (e *outputConsumer) wait() error {
e.mutex.Lock()
e.opts.Logger.Println("Encountered timeout")
return fmt.Errorf("after %s: %w", e.opts.Timeout, TimeoutError)
case state := <-e.tt.Exited(true): // allow for output to be read first by first case in this select{}
e.mutex.Lock()
if state.Err != nil {
e.opts.Logger.Println("Encountered error waiting for process to exit: %s\n", state.Err.Error())
}
return fmt.Errorf("process exited (status: %d)", state.ProcessState.ExitCode())
}
}
26 changes: 20 additions & 6 deletions outputproducer.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ func (o *outputProducer) listen(r io.Reader, w io.Writer, appendBuffer func([]by
for {
o.opts.Logger.Println("listen: loop")
if err := o.processNextRead(br, w, appendBuffer, size); err != nil {
if errors.Is(err, ptyEOF) {
o.opts.Logger.Println("listen: reached EOF")
if errors.Is(err, PtyEOF) {
return nil
} else {
return fmt.Errorf("could not poll reader: %w", err)
Expand All @@ -64,7 +63,7 @@ func (o *outputProducer) listen(r io.Reader, w io.Writer, appendBuffer func([]by
}
}

var ptyEOF = errors.New("pty closed")
var PtyEOF = errors.New("pty closed")

func (o *outputProducer) processNextRead(r io.Reader, w io.Writer, appendBuffer func([]byte, bool) error, size int) error {
o.opts.Logger.Printf("processNextRead started with size: %d\n", size)
Expand All @@ -78,6 +77,7 @@ func (o *outputProducer) processNextRead(r io.Reader, w io.Writer, appendBuffer
pathError := &fs.PathError{}
if errors.Is(errRead, fs.ErrClosed) || errors.Is(errRead, io.EOF) || (runtime.GOOS == "linux" && errors.As(errRead, &pathError)) {
isEOF = true
o.opts.Logger.Println("reached EOF")
}
}

Expand All @@ -96,7 +96,8 @@ func (o *outputProducer) processNextRead(r io.Reader, w io.Writer, appendBuffer

if errRead != nil {
if isEOF {
return errors.Join(errRead, ptyEOF)
o.closeConsumers(PtyEOF)
return errors.Join(errRead, PtyEOF)
}
return fmt.Errorf("could not read pty output: %w", errRead)
}
Expand Down Expand Up @@ -194,6 +195,19 @@ func (o *outputProducer) processDirtyOutput(output []byte, cursorPos int, cleanU
return append(append(alreadyCleanedOutput, processedOutput...), unprocessedOutput...), processedCursorPos, newCleanUptoPos, nil
}

func (o *outputProducer) closeConsumers(reason error) {
o.opts.Logger.Println("closing consumers")
defer o.opts.Logger.Println("closed consumers")

o.mutex.Lock()
defer o.mutex.Unlock()

for n := 0; n < len(o.consumers); n++ {
o.consumers[n].Stop(reason)
o.consumers = append(o.consumers[:n], o.consumers[n+1:]...)
}
}

func (o *outputProducer) flushConsumers() error {
o.opts.Logger.Println("flushing consumers")
defer o.opts.Logger.Println("flushed consumers")
Expand Down Expand Up @@ -238,12 +252,12 @@ func (o *outputProducer) flushConsumers() error {
return nil
}

func (o *outputProducer) addConsumer(tt *TermTest, consume consumer, opts ...SetConsOpt) (*outputConsumer, error) {
func (o *outputProducer) addConsumer(consume consumer, opts ...SetConsOpt) (*outputConsumer, error) {
o.opts.Logger.Printf("adding consumer")
defer o.opts.Logger.Printf("added consumer")

opts = append(opts, OptConsInherit(o.opts))
listener := newOutputConsumer(tt, consume, opts...)
listener := newOutputConsumer(consume, opts...)
o.consumers = append(o.consumers, listener)

if err := o.flushConsumers(); err != nil {
Expand Down
43 changes: 12 additions & 31 deletions termtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ type TermTest struct {
ptmx pty.Pty
outputProducer *outputProducer
listenError chan error
waitError chan error
opts *Opts
exited *cmdExit
}

type ErrorHandler func(*TermTest, error) error
Expand Down Expand Up @@ -79,6 +79,7 @@ func New(cmd *exec.Cmd, opts ...SetOpt) (*TermTest, error) {
cmd: cmd,
outputProducer: newOutputProducer(optv),
listenError: make(chan error, 1),
waitError: make(chan error, 1),
opts: optv,
}

Expand Down Expand Up @@ -228,6 +229,7 @@ func (tt *TermTest) start() (rerr error) {
tt.term = vt10x.New(vt10x.WithWriter(ptmx), vt10x.WithSize(tt.opts.Cols, tt.opts.Rows))

// Start listening for output
// We use a waitgroup here to ensure the listener is active before consumers are attached.
wg := &sync.WaitGroup{}
wg.Add(1)
go func() {
Expand All @@ -236,12 +238,18 @@ func (tt *TermTest) start() (rerr error) {
err := tt.outputProducer.Listen(tt.ptmx, tt.term)
tt.listenError <- err
}()
wg.Wait()

go func() {
tt.exited = <-waitForCmdExit(tt.cmd)
// We start waiting right away, because on Windows the PTY isn't closed until the process exits, which in turn
// can't happen unless we've told the pty we're ready for it to close.
// This of course isn't ideal, but until the pty library fixes the cross-platform inconsistencies we have to
// work around these limitations.
defer tt.opts.Logger.Printf("waitIndefinitely finished")
tt.waitError <- tt.waitIndefinitely()
}()

wg.Wait()

return nil
}

Expand All @@ -252,13 +260,8 @@ func (tt *TermTest) Wait(timeout time.Duration) (rerr error) {
tt.opts.Logger.Println("wait called")
defer tt.opts.Logger.Println("wait closed")

errc := make(chan error, 1)
go func() {
errc <- tt.WaitIndefinitely()
}()

select {
case err := <-errc:
case err := <-tt.waitError:
// WaitIndefinitely already invokes the expect error handler
return err
case <-time.After(timeout):
Expand Down Expand Up @@ -324,28 +327,6 @@ func (tt *TermTest) SendCtrlC() {
tt.Send(string([]byte{0x03})) // 0x03 is ASCII character for ^C
}

// Exited returns a channel that sends the given termtest's command cmdExit info when available.
// This can be used within a select{} statement.
// If waitExtra is given, waits a little bit before sending cmdExit info. This allows any fellow
// switch cases with output consumers to handle unprocessed stdout. If there are no such cases
// (e.g. ExpectExit(), where we want to catch an exit ASAP), waitExtra should be false.
func (tt *TermTest) Exited(waitExtra bool) chan *cmdExit {
return waitChan(func() *cmdExit {
ticker := time.NewTicker(processExitPollInterval)
for {
select {
case <-ticker.C:
if tt.exited != nil {
if waitExtra { // allow sibling output consumer cases to handle their output
time.Sleep(processExitExtraWait)
}
return tt.exited
}
}
}
})
}

func (tt *TermTest) errorHandler(rerr *error) {
err := *rerr
if err == nil {
Expand Down
2 changes: 1 addition & 1 deletion termtest_other.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ func syscallErrorCode(err error) int {
return -1
}

func (tt *TermTest) WaitIndefinitely() error {
func (tt *TermTest) waitIndefinitely() error {
tt.opts.Logger.Println("WaitIndefinitely called")
defer tt.opts.Logger.Println("WaitIndefinitely closed")

Expand Down
4 changes: 2 additions & 2 deletions termtest_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ func syscallErrorCode(err error) int {
return 0
}

// WaitIndefinitely on Windows has to work around a Windows PTY bug where the PTY will NEVER exit by itself:
// waitIndefinitely on Windows has to work around a Windows PTY bug where the PTY will NEVER exit by itself:
// https://github.com/photostorm/pty/issues/3
// Instead we wait for the process itself to exit, and after a grace period will shut down the pty.
func (tt *TermTest) WaitIndefinitely() error {
func (tt *TermTest) waitIndefinitely() error {
tt.opts.Logger.Println("WaitIndefinitely called")
defer tt.opts.Logger.Println("WaitIndefinitely closed")

Expand Down

0 comments on commit 3b5b4ea

Please sign in to comment.