Skip to content

Commit

Permalink
refactor: add ctx to readUntilF to fix timing out on read untils in s…
Browse files Browse the repository at this point in the history
…end input/interactive
  • Loading branch information
carlmontanari committed Dec 14, 2024
1 parent 8bdb939 commit 83f5204
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 25 deletions.
16 changes: 14 additions & 2 deletions channel/read.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,14 +161,20 @@ func (c *Channel) ReadAll() ([]byte, error) {
}

// ReadUntilFuzzy reads until a fuzzy match of the input is found.
func (c *Channel) ReadUntilFuzzy(b []byte) ([]byte, error) {
func (c *Channel) ReadUntilFuzzy(ctx context.Context, b []byte) ([]byte, error) {
if len(b) == 0 {
return nil, nil
}

var rb []byte

for {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}

nb, err := c.Read()
if err != nil {
return nil, err
Expand All @@ -193,10 +199,16 @@ func (c *Channel) ReadUntilFuzzy(b []byte) ([]byte, error) {

// ReadUntilExplicit reads bytes out of the channel Q object until the bytes b are seen in the
// output. Once the bytes are seen all read bytes are returned.
func (c *Channel) ReadUntilExplicit(b []byte) ([]byte, error) {
func (c *Channel) ReadUntilExplicit(ctx context.Context, b []byte) ([]byte, error) {
var rb []byte

for {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}

nb, err := c.Read()
if err != nil {
return nil, err
Expand Down
33 changes: 12 additions & 21 deletions channel/sendinput.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (c *Channel) SendInputB(input []byte, opts ...util.Option) ([]byte, error)
return
}

_, err = readUntilF(input)
_, err = readUntilF(ctx, input)
if err != nil {
cr <- &result{b: b, err: err}

Expand Down Expand Up @@ -85,30 +85,21 @@ func (c *Channel) SendInputB(input []byte, opts ...util.Option) ([]byte, error)
}
}()

select {
case r := <-cr:
if r.err != nil {
if errors.Is(r.err, context.DeadlineExceeded) {
c.l.Critical("channel timeout sending input to device")
r := <-cr
if r.err != nil {
if errors.Is(r.err, context.DeadlineExceeded) {
c.l.Critical("channel timeout sending input to device")

return nil, fmt.Errorf(
"%w: channel timeout sending input to device",
util.ErrTimeoutError,
)
}

return nil, r.err
return nil, fmt.Errorf(
"%w: channel timeout sending input to device",
util.ErrTimeoutError,
)
}

return r.b, nil
case <-ctx.Done():
c.l.Critical("channel timeout sending input to device")

return nil, fmt.Errorf(
"%w: channel timeout sending input to device",
util.ErrTimeoutError,
)
return nil, r.err
}

return r.b, nil
}

// SendInput sends the input string to the target device. Any bytes output is returned.
Expand Down
4 changes: 2 additions & 2 deletions channel/sendinteractive.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func (c *Channel) sendInteractive(
cr chan *result,
events []*SendInteractiveEvent,
op *OperationOptions,
readUntilF func(b []byte) ([]byte, error),
readUntilF func(ctx context.Context, b []byte) ([]byte, error),
) {
defer close(cr)

Expand All @@ -48,7 +48,7 @@ func (c *Channel) sendInteractive(
if e.ChannelResponse != "" && !e.HideInput {
var nb []byte

nb, err = readUntilF([]byte(e.ChannelInput))
nb, err = readUntilF(ctx, []byte(e.ChannelInput))
if err != nil {
cr <- &result{b: nil, err: err}

Expand Down

0 comments on commit 83f5204

Please sign in to comment.