Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reimplement process exit expect failure #13

Merged
merged 9 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ name: unit-tests

on:
push:
branches: [ master ]
branches: [ master, v2-wip ]
pull_request:
branches: [ master ]
branches: [ master, v2-wip]

jobs:

Expand Down
10 changes: 5 additions & 5 deletions expect.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func (tt *TermTest) ExpectCustom(consumer consumer, opts ...SetExpectOpt) (rerr
return fmt.Errorf("could not create expect options: %w", err)
}

cons, err := tt.outputProducer.addConsumer(tt, consumer, expectOpts.ToConsumerOpts()...)
cons, err := tt.outputProducer.addConsumer(consumer, expectOpts.ToConsumerOpts()...)
if err != nil {
return fmt.Errorf("could not add consumer: %w", err)
}
Expand Down Expand Up @@ -180,11 +180,11 @@ func (tt *TermTest) expectExitCode(exitCode int, match bool, opts ...SetExpectOp
select {
case <-time.After(timeoutV):
return fmt.Errorf("after %s: %w", timeoutV, TimeoutError)
case state := <-tt.Exited(false): // do not wait for unread output since it's not read by this select{}
if state.Err != nil && (state.ProcessState == nil || state.ProcessState.ExitCode() == 0) {
return fmt.Errorf("cmd wait failed: %w", state.Err)
case err := <-waitChan(tt.cmd.Wait):
if err != nil && (tt.cmd.ProcessState == nil || tt.cmd.ProcessState.ExitCode() == 0) {
return fmt.Errorf("cmd wait failed: %w", err)
}
if err := tt.assertExitCode(state.ProcessState.ExitCode(), exitCode, match); err != nil {
if err := tt.assertExitCode(tt.cmd.ProcessState.ExitCode(), exitCode, match); err != nil {
return err
}
}
Expand Down
6 changes: 3 additions & 3 deletions expect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func Test_ExpectCustom(t *testing.T) {
[]SetExpectOpt{OptExpectTimeout(time.Second)},
},
"",
TimeoutError,
PtyEOF,
},
{
"Custom error",
Expand Down Expand Up @@ -167,7 +167,7 @@ func Test_ExpectCustom_Cmd(t *testing.T) {
},
[]SetExpectOpt{OptExpectTimeout(time.Second)},
},
TimeoutError,
PtyEOF,
},
{
"Custom error",
Expand All @@ -194,7 +194,7 @@ func Test_ExpectCustom_Cmd(t *testing.T) {
}

func Test_Expect_Timeout(t *testing.T) {
tt := newTermTest(t, exec.Command("bash", "-c", "echo HELLO"), false)
tt := newTermTest(t, exec.Command("bash", "-c", "echo HELLO && sleep 1"), false)
durations := []time.Duration{
100 * time.Millisecond,
200 * time.Millisecond,
Expand Down
15 changes: 2 additions & 13 deletions helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"bytes"
"errors"
"os"
"os/exec"
"strings"
"time"
)
Expand All @@ -23,20 +22,10 @@ type cmdExit struct {
Err error
}

// waitForCmdExit turns process.wait() into a channel so that it can be used within a select{} statement
func waitForCmdExit(cmd *exec.Cmd) chan *cmdExit {
exit := make(chan *cmdExit, 1)
go func() {
err := cmd.Wait()
exit <- &cmdExit{ProcessState: cmd.ProcessState, Err: err}
}()
return exit
}

func waitChan[T any](wait func() T) chan T {
done := make(chan T)
done := make(chan T, 1)
go func() {
done <- wait()
wait()
close(done)
}()
return done
Expand Down
27 changes: 18 additions & 9 deletions outputconsumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ type outputConsumer struct {
opts *OutputConsumerOpts
isalive bool
mutex *sync.Mutex
tt *TermTest
}

type OutputConsumerOpts struct {
Expand All @@ -37,7 +36,7 @@ func OptsConsTimeout(timeout time.Duration) func(o *OutputConsumerOpts) {
}
}

func newOutputConsumer(tt *TermTest, consume consumer, opts ...SetConsOpt) *outputConsumer {
func newOutputConsumer(consume consumer, opts ...SetConsOpt) *outputConsumer {
oc := &outputConsumer{
consume: consume,
opts: &OutputConsumerOpts{
Expand All @@ -47,7 +46,6 @@ func newOutputConsumer(tt *TermTest, consume consumer, opts ...SetConsOpt) *outp
waiter: make(chan error, 1),
isalive: true,
mutex: &sync.Mutex{},
tt: tt,
}

for _, optSetter := range opts {
Expand Down Expand Up @@ -83,6 +81,23 @@ func (e *outputConsumer) Report(buffer []byte) (int, error) {
return pos, err
}

type errConsumerStopped struct {
reason error
}

func (e errConsumerStopped) Error() string {
return fmt.Sprintf("consumer stopped, reason: %s", e.reason)
}

func (e errConsumerStopped) Unwrap() error {
return e.reason
}

func (e *outputConsumer) Stop(reason error) {
e.opts.Logger.Printf("stopping consumer, reason: %s\n", reason)
e.waiter <- errConsumerStopped{reason}
}

func (e *outputConsumer) wait() error {
e.opts.Logger.Println("started waiting")
defer e.opts.Logger.Println("stopped waiting")
Expand All @@ -103,11 +118,5 @@ func (e *outputConsumer) wait() error {
e.mutex.Lock()
e.opts.Logger.Println("Encountered timeout")
return fmt.Errorf("after %s: %w", e.opts.Timeout, TimeoutError)
case state := <-e.tt.Exited(true): // allow for output to be read first by first case in this select{}
e.mutex.Lock()
if state.Err != nil {
e.opts.Logger.Println("Encountered error waiting for process to exit: %s\n", state.Err.Error())
}
return fmt.Errorf("process exited (status: %d)", state.ProcessState.ExitCode())
}
}
26 changes: 20 additions & 6 deletions outputproducer.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ func (o *outputProducer) listen(r io.Reader, w io.Writer, appendBuffer func([]by
for {
o.opts.Logger.Println("listen: loop")
if err := o.processNextRead(br, w, appendBuffer, size); err != nil {
if errors.Is(err, ptyEOF) {
o.opts.Logger.Println("listen: reached EOF")
if errors.Is(err, PtyEOF) {
return nil
} else {
return fmt.Errorf("could not poll reader: %w", err)
Expand All @@ -64,7 +63,7 @@ func (o *outputProducer) listen(r io.Reader, w io.Writer, appendBuffer func([]by
}
}

var ptyEOF = errors.New("pty closed")
var PtyEOF = errors.New("pty closed")

func (o *outputProducer) processNextRead(r io.Reader, w io.Writer, appendBuffer func([]byte, bool) error, size int) error {
o.opts.Logger.Printf("processNextRead started with size: %d\n", size)
Expand All @@ -78,6 +77,7 @@ func (o *outputProducer) processNextRead(r io.Reader, w io.Writer, appendBuffer
pathError := &fs.PathError{}
if errors.Is(errRead, fs.ErrClosed) || errors.Is(errRead, io.EOF) || (runtime.GOOS == "linux" && errors.As(errRead, &pathError)) {
isEOF = true
o.opts.Logger.Println("reached EOF")
}
}

Expand All @@ -96,7 +96,8 @@ func (o *outputProducer) processNextRead(r io.Reader, w io.Writer, appendBuffer

if errRead != nil {
if isEOF {
return errors.Join(errRead, ptyEOF)
o.closeConsumers(PtyEOF)
return errors.Join(errRead, PtyEOF)
}
return fmt.Errorf("could not read pty output: %w", errRead)
}
Expand Down Expand Up @@ -194,6 +195,19 @@ func (o *outputProducer) processDirtyOutput(output []byte, cursorPos int, cleanU
return append(append(alreadyCleanedOutput, processedOutput...), unprocessedOutput...), processedCursorPos, newCleanUptoPos, nil
}

func (o *outputProducer) closeConsumers(reason error) {
o.opts.Logger.Println("closing consumers")
defer o.opts.Logger.Println("closed consumers")

o.mutex.Lock()
defer o.mutex.Unlock()

for n := 0; n < len(o.consumers); n++ {
o.consumers[n].Stop(reason)
o.consumers = append(o.consumers[:n], o.consumers[n+1:]...)
}
}

func (o *outputProducer) flushConsumers() error {
o.opts.Logger.Println("flushing consumers")
defer o.opts.Logger.Println("flushed consumers")
Expand Down Expand Up @@ -238,12 +252,12 @@ func (o *outputProducer) flushConsumers() error {
return nil
}

func (o *outputProducer) addConsumer(tt *TermTest, consume consumer, opts ...SetConsOpt) (*outputConsumer, error) {
func (o *outputProducer) addConsumer(consume consumer, opts ...SetConsOpt) (*outputConsumer, error) {
o.opts.Logger.Printf("adding consumer")
defer o.opts.Logger.Printf("added consumer")

opts = append(opts, OptConsInherit(o.opts))
listener := newOutputConsumer(tt, consume, opts...)
listener := newOutputConsumer(consume, opts...)
o.consumers = append(o.consumers, listener)

if err := o.flushConsumers(); err != nil {
Expand Down
43 changes: 12 additions & 31 deletions termtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ type TermTest struct {
ptmx pty.Pty
outputProducer *outputProducer
listenError chan error
waitError chan error
opts *Opts
exited *cmdExit
}

type ErrorHandler func(*TermTest, error) error
Expand Down Expand Up @@ -79,6 +79,7 @@ func New(cmd *exec.Cmd, opts ...SetOpt) (*TermTest, error) {
cmd: cmd,
outputProducer: newOutputProducer(optv),
listenError: make(chan error, 1),
waitError: make(chan error, 1),
opts: optv,
}

Expand Down Expand Up @@ -228,6 +229,7 @@ func (tt *TermTest) start() (rerr error) {
tt.term = vt10x.New(vt10x.WithWriter(ptmx), vt10x.WithSize(tt.opts.Cols, tt.opts.Rows))

// Start listening for output
// We use a waitgroup here to ensure the listener is active before consumers are attached.
wg := &sync.WaitGroup{}
wg.Add(1)
go func() {
Expand All @@ -236,12 +238,18 @@ func (tt *TermTest) start() (rerr error) {
err := tt.outputProducer.Listen(tt.ptmx, tt.term)
tt.listenError <- err
}()
wg.Wait()

go func() {
tt.exited = <-waitForCmdExit(tt.cmd)
// We start waiting right away, because on Windows the PTY isn't closed until the process exits, which in turn
// can't happen unless we've told the pty we're ready for it to close.
// This of course isn't ideal, but until the pty library fixes the cross-platform inconsistencies we have to
// work around these limitations.
defer tt.opts.Logger.Printf("waitIndefinitely finished")
tt.waitError <- tt.waitIndefinitely()
}()

wg.Wait()

return nil
}

Expand All @@ -252,13 +260,8 @@ func (tt *TermTest) Wait(timeout time.Duration) (rerr error) {
tt.opts.Logger.Println("wait called")
defer tt.opts.Logger.Println("wait closed")

errc := make(chan error, 1)
go func() {
errc <- tt.WaitIndefinitely()
}()

select {
case err := <-errc:
case err := <-tt.waitError:
// WaitIndefinitely already invokes the expect error handler
return err
case <-time.After(timeout):
Expand Down Expand Up @@ -324,28 +327,6 @@ func (tt *TermTest) SendCtrlC() {
tt.Send(string([]byte{0x03})) // 0x03 is ASCII character for ^C
}

// Exited returns a channel that sends the given termtest's command cmdExit info when available.
// This can be used within a select{} statement.
// If waitExtra is given, waits a little bit before sending cmdExit info. This allows any fellow
// switch cases with output consumers to handle unprocessed stdout. If there are no such cases
// (e.g. ExpectExit(), where we want to catch an exit ASAP), waitExtra should be false.
func (tt *TermTest) Exited(waitExtra bool) chan *cmdExit {
return waitChan(func() *cmdExit {
ticker := time.NewTicker(processExitPollInterval)
for {
select {
case <-ticker.C:
if tt.exited != nil {
if waitExtra { // allow sibling output consumer cases to handle their output
time.Sleep(processExitExtraWait)
}
return tt.exited
}
}
}
})
}

func (tt *TermTest) errorHandler(rerr *error) {
err := *rerr
if err == nil {
Expand Down
2 changes: 1 addition & 1 deletion termtest_other.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ func syscallErrorCode(err error) int {
return -1
}

func (tt *TermTest) WaitIndefinitely() error {
func (tt *TermTest) waitIndefinitely() error {
tt.opts.Logger.Println("WaitIndefinitely called")
defer tt.opts.Logger.Println("WaitIndefinitely closed")

Expand Down
4 changes: 2 additions & 2 deletions termtest_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ func syscallErrorCode(err error) int {
return 0
}

// WaitIndefinitely on Windows has to work around a Windows PTY bug where the PTY will NEVER exit by itself:
// waitIndefinitely on Windows has to work around a Windows PTY bug where the PTY will NEVER exit by itself:
// https://github.com/photostorm/pty/issues/3
// Instead we wait for the process itself to exit, and after a grace period will shut down the pty.
func (tt *TermTest) WaitIndefinitely() error {
func (tt *TermTest) waitIndefinitely() error {
tt.opts.Logger.Println("WaitIndefinitely called")
defer tt.opts.Logger.Println("WaitIndefinitely closed")

Expand Down
Loading