diff --git a/channel/read.go b/channel/read.go index 5ce07cc..283060b 100644 --- a/channel/read.go +++ b/channel/read.go @@ -161,7 +161,7 @@ func (c *Channel) ReadAll() ([]byte, error) { } // ReadUntilFuzzy reads until a fuzzy match of the input is found. -func (c *Channel) ReadUntilFuzzy(b []byte) ([]byte, error) { +func (c *Channel) ReadUntilFuzzy(ctx context.Context, b []byte) ([]byte, error) { if len(b) == 0 { return nil, nil } @@ -169,6 +169,12 @@ func (c *Channel) ReadUntilFuzzy(b []byte) ([]byte, error) { var rb []byte for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + nb, err := c.Read() if err != nil { return nil, err @@ -193,10 +199,16 @@ func (c *Channel) ReadUntilFuzzy(b []byte) ([]byte, error) { // ReadUntilExplicit reads bytes out of the channel Q object until the bytes b are seen in the // output. Once the bytes are seen all read bytes are returned. -func (c *Channel) ReadUntilExplicit(b []byte) ([]byte, error) { +func (c *Channel) ReadUntilExplicit(ctx context.Context, b []byte) ([]byte, error) { var rb []byte for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + nb, err := c.Read() if err != nil { return nil, err diff --git a/channel/sendinput.go b/channel/sendinput.go index 9019cdc..d84a1e7 100644 --- a/channel/sendinput.go +++ b/channel/sendinput.go @@ -42,7 +42,7 @@ func (c *Channel) SendInputB(input []byte, opts ...util.Option) ([]byte, error) return } - _, err = readUntilF(input) + _, err = readUntilF(ctx, input) if err != nil { cr <- &result{b: b, err: err} @@ -85,30 +85,21 @@ func (c *Channel) SendInputB(input []byte, opts ...util.Option) ([]byte, error) } }() - select { - case r := <-cr: - if r.err != nil { - if errors.Is(r.err, context.DeadlineExceeded) { - c.l.Critical("channel timeout sending input to device") + r := <-cr + if r.err != nil { + if errors.Is(r.err, context.DeadlineExceeded) { + c.l.Critical("channel timeout sending input to device") - return nil, fmt.Errorf( - "%w: channel timeout sending input to device", - util.ErrTimeoutError, - ) - } - - return nil, r.err + return nil, fmt.Errorf( + "%w: channel timeout sending input to device", + util.ErrTimeoutError, + ) } - return r.b, nil - case <-ctx.Done(): - c.l.Critical("channel timeout sending input to device") - - return nil, fmt.Errorf( - "%w: channel timeout sending input to device", - util.ErrTimeoutError, - ) + return nil, r.err } + + return r.b, nil } // SendInput sends the input string to the target device. Any bytes output is returned. diff --git a/channel/sendinteractive.go b/channel/sendinteractive.go index b20dd40..771aef3 100644 --- a/channel/sendinteractive.go +++ b/channel/sendinteractive.go @@ -24,7 +24,7 @@ func (c *Channel) sendInteractive( cr chan *result, events []*SendInteractiveEvent, op *OperationOptions, - readUntilF func(b []byte) ([]byte, error), + readUntilF func(ctx context.Context, b []byte) ([]byte, error), ) { defer close(cr) @@ -48,7 +48,7 @@ func (c *Channel) sendInteractive( if e.ChannelResponse != "" && !e.HideInput { var nb []byte - nb, err = readUntilF([]byte(e.ChannelInput)) + nb, err = readUntilF(ctx, []byte(e.ChannelInput)) if err != nil { cr <- &result{b: nil, err: err}