diff --git a/.secrets.baseline b/.secrets.baseline
index ed5d527ed4..47dc1ac67d 100644
--- a/.secrets.baseline
+++ b/.secrets.baseline
@@ -272,5 +272,5 @@
       }
     ]
   },
-  "generated_at": "2025-06-05T07:39:32Z"
+  "generated_at": "2025-06-04T19:30:24Z"
 }
diff --git a/changes/20250530171020.feature b/changes/20250530171020.feature
new file mode 100644
index 0000000000..96daee4141
--- /dev/null
+++ b/changes/20250530171020.feature
@@ -0,0 +1 @@
+:sparkles: `subprocess` Add support for terminating processes instead of killing them
diff --git a/changes/20250604202816.feature b/changes/20250604202816.feature
new file mode 100644
index 0000000000..b1d90080a2
--- /dev/null
+++ b/changes/20250604202816.feature
@@ -0,0 +1 @@
+:sparkles: Introducing the `diodes` module, a copy of [Cloud Foundry's](https://github.com/cloudfoundry/go-diodes) go-diodes library.
diff --git a/utils/diodes/README b/utils/diodes/README
new file mode 100644
index 0000000000..ee897e9dcc
--- /dev/null
+++ b/utils/diodes/README
@@ -0,0 +1,2 @@
+Vendoring the [Cloud Foundry](https://github.com/cloudfoundry/go-diodes) diode library to avoid importing test dependencies.
+
diff --git a/utils/diodes/many_to_one.go b/utils/diodes/many_to_one.go
new file mode 100644
index 0000000000..3acc9f30f2
--- /dev/null
+++ b/utils/diodes/many_to_one.go
@@ -0,0 +1,130 @@
+package diodes
+
+import (
+	"log"
+	"sync/atomic"
+	"unsafe"
+)
+
+// ManyToOne diode is optimal for many writers (go-routines B-n) and a single
+// reader (go-routine A). It is not thread safe for multiple readers.
+type ManyToOne struct {
+	writeIndex uint64
+	buffer     []unsafe.Pointer
+	readIndex  uint64
+	alerter    Alerter
+}
+
+// NewManyToOne creates a new diode (ring buffer). The ManyToOne diode
+// is optimized for many writers (on go-routines B-n) and a single reader
+// (on go-routine A). The alerter is invoked on the reader's go-routine. It is
+// called when it notices that the writer go-routine has passed it and wrote
+// over data. A nil alerter can be used to ignore alerts.
+func NewManyToOne(size int, alerter Alerter) *ManyToOne {
+	if alerter == nil {
+		alerter = AlertFunc(func(int) {})
+	}
+
+	d := &ManyToOne{
+		buffer:  make([]unsafe.Pointer, size),
+		alerter: alerter,
+	}
+
+	// Start the write index at the value before 0
+	// to allow the first write to use AddUint64
+	// and still have a beginning index of 0.
+	d.writeIndex = ^d.writeIndex
+	return d
+}
+
+// Set sets the data in the next slot of the ring buffer.
+func (d *ManyToOne) Set(data GenericDataType) {
+	for {
+		writeIndex := atomic.AddUint64(&d.writeIndex, 1)
+		idx := writeIndex % uint64(len(d.buffer))
+		old := atomic.LoadPointer(&d.buffer[idx])
+
+		if old != nil &&
+			(*bucket)(old) != nil &&
+			(*bucket)(old).seq > writeIndex-uint64(len(d.buffer)) {
+			log.Println("Diode set collision: consider using a larger diode")
+			continue
+		}
+
+		newBucket := &bucket{
+			data: data,
+			seq:  writeIndex,
+		}
+
+		if !atomic.CompareAndSwapPointer(&d.buffer[idx], old, unsafe.Pointer(newBucket)) {
+			log.Println("Diode set collision: consider using a larger diode")
+			continue
+		}
+
+		return
+	}
+}
+
+// TryNext will attempt to read from the next slot of the ring buffer.
+// If there is no data available, it will return (nil, false).
+func (d *ManyToOne) TryNext() (data GenericDataType, ok bool) {
+	// Read a value from the ring buffer based on the readIndex.
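+	// The read index wraps modulo the buffer length, so slots are reused once
+	// the ring is full.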
+	idx := d.readIndex % uint64(len(d.buffer))
+	result := (*bucket)(atomic.SwapPointer(&d.buffer[idx], nil))
+
+	// When the result is nil that means the writer has not had the
+	// opportunity to write a value into the diode. This value must be ignored
+	// and the read head must not increment.
+	if result == nil {
+		return nil, false
+	}
+
+	// When the seq value is less than the current read index that means a
+	// value was read from idx that was previously written but has since
+	// been dropped. This value must be ignored and the read head must not
+	// increment.
+	//
+	// The simulation for this scenario assumes the fast forward occurred as
+	// detailed below.
+	//
+	// 5. The reader reads again getting seq 5. It then reads again expecting
+	//    seq 6 but gets seq 2. This is a read of a stale value that was
+	//    effectively "dropped" so the read fails and the read head stays put.
+	//    `| 4 | 5 | 2 | 3 |` r: 7, w: 6
+	//
+	if result.seq < d.readIndex {
+		return nil, false
+	}
+
+	// When the seq value is greater than the current read index that means a
+	// value was read from idx that overwrote the value that was expected to
+	// be at this idx. This happens when the writer has lapped the reader. The
+	// reader needs to catch up to the writer so it moves its read head to
+	// the new seq, effectively dropping the messages that were not read in
+	// between the two values.
+	//
+	// Here is a simulation of this scenario:
+	//
+	// 1. Both the read and write heads start at 0.
+	//    `| nil | nil | nil | nil |` r: 0, w: 0
+	// 2. The writer fills the buffer.
+	//    `| 0 | 1 | 2 | 3 |` r: 0, w: 4
+	// 3. The writer laps the read head.
+	//    `| 4 | 5 | 2 | 3 |` r: 0, w: 6
+	// 4. The reader reads the first value, expecting a seq of 0 but reads 4,
+	//    this forces the reader to fast forward to 5.
+	//    `| 4 | 5 | 2 | 3 |` r: 5, w: 6
+	//
+	if result.seq > d.readIndex {
+		dropped := result.seq - d.readIndex
+		d.readIndex = result.seq
+		d.alerter.Alert(int(dropped)) // nolint:gosec
+	}
+
+	// Only increment the read index if a regular read occurred (where seq was
+	// equal to readIndex) or a value was read that caused a fast forward
+	// (where seq was greater than readIndex).
+	//
+	d.readIndex++
+	return result.data, true
+}
diff --git a/utils/diodes/one_to_one.go b/utils/diodes/one_to_one.go
new file mode 100644
index 0000000000..9b628b5a2e
--- /dev/null
+++ b/utils/diodes/one_to_one.go
@@ -0,0 +1,129 @@
+package diodes
+
+import (
+	"sync/atomic"
+	"unsafe"
+)
+
+// GenericDataType is the data type the diodes operate on.
+type GenericDataType unsafe.Pointer
+
+// Alerter is used to report how many values were overwritten since the
+// last write.
+type Alerter interface {
+	Alert(missed int)
+}
+
+// AlertFunc type is an adapter to allow the use of ordinary functions as
+// Alert handlers.
+type AlertFunc func(missed int)
+
+// Alert calls f(missed).
+func (f AlertFunc) Alert(missed int) {
+	f(missed)
+}
+
+type bucket struct {
+	data GenericDataType
+	seq  uint64 // seq is the recorded write index at the time of writing
+}
+
+// OneToOne diode is meant to be used by a single reader and a single writer.
+// It is not thread safe if used otherwise.
+type OneToOne struct {
+	buffer     []unsafe.Pointer
+	writeIndex uint64
+	readIndex  uint64
+	alerter    Alerter
+}
+
+// NewOneToOne creates a new diode that is meant to be used by a single reader
+// and a single writer. The alerter is invoked on the reader's go-routine.
+// It is called when it notices that the writer go-routine has passed it and
+// wrote over data. A nil alerter can be used to ignore alerts.
+func NewOneToOne(size int, alerter Alerter) *OneToOne {
+	if alerter == nil {
+		alerter = AlertFunc(func(int) {})
+	}
+
+	return &OneToOne{
+		buffer:  make([]unsafe.Pointer, size),
+		alerter: alerter,
+	}
+}
+
+// Set sets the data in the next slot of the ring buffer.
+func (d *OneToOne) Set(data GenericDataType) {
+	idx := d.writeIndex % uint64(len(d.buffer))
+
+	newBucket := &bucket{
+		data: data,
+		seq:  d.writeIndex,
+	}
+	d.writeIndex++
+
+	atomic.StorePointer(&d.buffer[idx], unsafe.Pointer(newBucket))
+}
+
+// TryNext will attempt to read from the next slot of the ring buffer.
+// If there is no data available, it will return (nil, false).
+func (d *OneToOne) TryNext() (data GenericDataType, ok bool) {
+	// Read a value from the ring buffer based on the readIndex.
+	idx := d.readIndex % uint64(len(d.buffer))
+	result := (*bucket)(atomic.SwapPointer(&d.buffer[idx], nil))
+
+	// When the result is nil that means the writer has not had the
+	// opportunity to write a value into the diode. This value must be ignored
+	// and the read head must not increment.
+	if result == nil {
+		return nil, false
+	}
+
+	// When the seq value is less than the current read index that means a
+	// value was read from idx that was previously written but has since
+	// been dropped. This value must be ignored and the read head must not
+	// increment.
+	//
+	// The simulation for this scenario assumes the fast forward occurred as
+	// detailed below.
+	//
+	// 5. The reader reads again getting seq 5. It then reads again expecting
+	//    seq 6 but gets seq 2. This is a read of a stale value that was
+	//    effectively "dropped" so the read fails and the read head stays put.
+	//    `| 4 | 5 | 2 | 3 |` r: 7, w: 6
+	//
+	if result.seq < d.readIndex {
+		return nil, false
+	}
+
+	// When the seq value is greater than the current read index that means a
+	// value was read from idx that overwrote the value that was expected to
+	// be at this idx. This happens when the writer has lapped the reader. The
+	// reader needs to catch up to the writer so it moves its read head to
+	// the new seq, effectively dropping the messages that were not read in
+	// between the two values.
+	//
+	// Here is a simulation of this scenario:
+	//
+	// 1. Both the read and write heads start at 0.
+	//    `| nil | nil | nil | nil |` r: 0, w: 0
+	// 2. The writer fills the buffer.
+	//    `| 0 | 1 | 2 | 3 |` r: 0, w: 4
+	// 3. The writer laps the read head.
+	//    `| 4 | 5 | 2 | 3 |` r: 0, w: 6
+	// 4. The reader reads the first value, expecting a seq of 0 but reads 4,
+	//    this forces the reader to fast forward to 5.
+	//    `| 4 | 5 | 2 | 3 |` r: 5, w: 6
+	//
+	if result.seq > d.readIndex {
+		dropped := result.seq - d.readIndex
+		d.readIndex = result.seq
+		d.alerter.Alert(int(dropped)) // nolint:gosec
+	}
+
+	// Only increment the read index if a regular read occurred (where seq was
+	// equal to readIndex) or a value was read that caused a fast forward
+	// (where seq was greater than readIndex).
+	d.readIndex++
+	return result.data, true
+}
diff --git a/utils/diodes/poller.go b/utils/diodes/poller.go
new file mode 100644
index 0000000000..575837355c
--- /dev/null
+++ b/utils/diodes/poller.go
@@ -0,0 +1,80 @@
+package diodes
+
+import (
+	"context"
+	"time"
+)
+
+// Diode is any implementation of a diode.
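+// Both ManyToOne and OneToOne satisfy this interface.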
+type Diode interface { + Set(GenericDataType) + TryNext() (GenericDataType, bool) +} + +// Poller will poll a diode until a value is available. +type Poller struct { + Diode + interval time.Duration + ctx context.Context +} + +// PollerConfigOption can be used to setup the poller. +type PollerConfigOption func(*Poller) + +// WithPollingInterval sets the interval at which the diode is queried +// for new data. The default is 10ms. +func WithPollingInterval(interval time.Duration) PollerConfigOption { + return PollerConfigOption(func(c *Poller) { + c.interval = interval + }) +} + +// WithPollingContext sets the context to cancel any retrieval (Next()). It +// will not change any results for adding data (Set()). Default is +// context.Background(). +func WithPollingContext(ctx context.Context) PollerConfigOption { + return PollerConfigOption(func(c *Poller) { + c.ctx = ctx + }) +} + +// NewPoller returns a new Poller that wraps the given diode. +func NewPoller(d Diode, opts ...PollerConfigOption) *Poller { + p := &Poller{ + Diode: d, + interval: 10 * time.Millisecond, + ctx: context.Background(), + } + + for _, o := range opts { + o(p) + } + + return p +} + +// Next polls the diode until data is available or until the context is done. +// If the context is done, then nil will be returned. +func (p *Poller) Next() GenericDataType { + for { + data, ok := p.Diode.TryNext() // nolint:staticcheck + if !ok { + if p.IsDone() { + return nil + } + + time.Sleep(p.interval) + continue + } + return data + } +} + +func (p *Poller) IsDone() bool { + select { + case <-p.ctx.Done(): + return true + default: + return false + } +} diff --git a/utils/diodes/waiter.go b/utils/diodes/waiter.go new file mode 100644 index 0000000000..fc203108da --- /dev/null +++ b/utils/diodes/waiter.go @@ -0,0 +1,71 @@ +package diodes + +import ( + "context" +) + +// Waiter will use a channel signal to alert the reader to when data is +// available. +type Waiter struct { + Diode + c chan struct{} + ctx context.Context +} + +// WaiterConfigOption can be used to setup the waiter. +type WaiterConfigOption func(*Waiter) + +// WithWaiterContext sets the context to cancel any retrieval (Next()). It +// will not change any results for adding data (Set()). Default is +// context.Background(). +func WithWaiterContext(ctx context.Context) WaiterConfigOption { + return WaiterConfigOption(func(c *Waiter) { + c.ctx = ctx + }) +} + +// NewWaiter returns a new Waiter that wraps the given diode. +func NewWaiter(d Diode, opts ...WaiterConfigOption) *Waiter { + w := new(Waiter) + w.Diode = d + w.c = make(chan struct{}, 1) + w.ctx = context.Background() + + for _, opt := range opts { + opt(w) + } + + return w +} + +// Set invokes the wrapped diode's Set with the given data and uses broadcast +// to wake up any readers. +func (w *Waiter) Set(data GenericDataType) { + w.Diode.Set(data) + w.broadcast() +} + +// broadcast sends to the channel if it can. +func (w *Waiter) broadcast() { + select { + case w.c <- struct{}{}: + default: + } +} + +// Next returns the next data point on the wrapped diode. If there is no new +// data, it will wait for Set to be called or the context to be done. If the +// context is done, then nil will be returned. 
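+//
+// A minimal usage sketch (the stored type, the producer side, and ctx are
+// illustrative assumptions; the diode itself only handles unsafe pointers):
+//
+//	w := diodes.NewWaiter(diodes.NewOneToOne(1024, nil), diodes.WithWaiterContext(ctx))
+//	for {
+//		data := w.Next()
+//		if data == nil {
+//			break // the context was cancelled or timed out
+//		}
+//		msg := *(*[]byte)(data) // recover the concrete type stored by the writer
+//		_ = msg
+//	}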
+func (w *Waiter) Next() GenericDataType { + for { + data, ok := w.Diode.TryNext() // nolint:staticcheck + if ok { + return data + } + select { + case <-w.ctx.Done(): + return nil + case <-w.c: + } + } +} diff --git a/utils/logs/fifo_logger.go b/utils/logs/fifo_logger.go index de5297fc19..a98e34ea45 100644 --- a/utils/logs/fifo_logger.go +++ b/utils/logs/fifo_logger.go @@ -4,148 +4,196 @@ import ( "bytes" "context" "fmt" - "io" "iter" - "log" - "sync" + "strings" "time" + "github.com/ARM-software/golang-utils/utils/commonerrors" + "github.com/ARM-software/golang-utils/utils/diodes" "github.com/ARM-software/golang-utils/utils/parallelisation" ) -var _ Loggers = &FIFOLoggers{} +const ( + newLine = '\n' + bufferSize = 10000 +) -type FIFOWriter struct { - io.WriteCloser - mu sync.RWMutex - Logs bytes.Buffer +type loggerAlerter struct { + log Loggers } -func (w *FIFOWriter) Write(p []byte) (n int, err error) { - w.mu.RLock() - defer w.mu.RUnlock() - w.Logs.Write(p) - return +func (l *loggerAlerter) Alert(missed int) { + if l.log != nil { + l.log.LogError(fmt.Sprintf("Logger dropped %d messages", missed)) + } } -func (w *FIFOWriter) Close() (err error) { - w.mu.Lock() - defer w.mu.Unlock() - w.Logs.Reset() - return +func newLoggerAlerter(logs Loggers) diodes.Alerter { + return &loggerAlerter{log: logs} } -func (w *FIFOWriter) Read() string { - w.mu.Lock() - defer w.mu.Unlock() - n := w.Logs.Len() - if n == 0 { - return "" +func newFIFODiode(ctx context.Context, ringBufferSize int, pollingPeriod time.Duration, droppedMessagesLogger Loggers) *fifoDiode { + dCtx, cancel := context.WithCancel(ctx) + cancelStore := parallelisation.NewCancelFunctionsStore() + cancelStore.RegisterCancelFunction(cancel) + return &fifoDiode{ + d: diodes.NewPoller(diodes.NewManyToOne(ringBufferSize, newLoggerAlerter(droppedMessagesLogger)), diodes.WithPollingInterval(pollingPeriod), diodes.WithPollingContext(dCtx)), + cancelStore: cancelStore, } - bytes := w.Logs.Next(n) - return string(bytes) } -func (w *FIFOWriter) ReadLines(ctx context.Context) iter.Seq[string] { +type fifoDiode struct { + d *diodes.Poller + cancelStore *parallelisation.CancelFunctionStore +} + +func (d *fifoDiode) Set(data []byte) { + d.d.Set(diodes.GenericDataType(&data)) +} + +func (d *fifoDiode) Close() error { + d.cancelStore.Cancel() + return nil +} + +// LineIterator returns an iterator over lines. It should only be called within the context of the same goroutine. +func (d *fifoDiode) LineIterator(ctx context.Context) iter.Seq[string] { return func(yield func(string) bool) { - var partial []byte - for { - if err := parallelisation.DetermineContextError(ctx); err != nil { + err := IterateOverLines(ctx, func(fCtx context.Context) (b []byte, err error) { + err = parallelisation.DetermineContextError(fCtx) + if err != nil { return } - - buf := func() []byte { - w.mu.Lock() - defer w.mu.Unlock() - defer w.Logs.Reset() - tmp := w.Logs.Bytes() - buf := make([]byte, len(tmp)) - copy(buf, tmp) - return buf - }() - - if len(buf) == 0 { - if err := parallelisation.DetermineContextError(ctx); err != nil { - if len(partial) > 0 { - yield(string(partial)) - } - return - } - - parallelisation.SleepWithContext(ctx, 50*time.Millisecond) - continue + data, has := d.d.TryNext() + if has { + b = *(*[]byte)(data) + return } - - if len(partial) > 0 { - buf = append(partial, buf...) 
- partial = nil + if d.d.IsDone() { + err = commonerrors.ErrEOF + return } + return + }, yield) + if err != nil { + return + } + } +} - for { - idx := bytes.IndexByte(buf, '\n') - if idx < 0 { - break - } - line := buf[:idx] - - if len(line) > 0 && line[len(line)-1] == '\r' { - line = line[:len(line)-1] - } - buf = buf[idx+1:] - if len(line) == 0 { - continue - } - - if !yield(string(line)) { - return - } - } +func cleanseLine(line string) string { + return strings.TrimSuffix(strings.ReplaceAll(line, "\r", ""), string(newLine)) +} - if len(buf) > 0 { - partial = buf +func iterateOverLines(ctx context.Context, b *bytes.Buffer, yield func(string) bool) (err error) { + for { + subErr := parallelisation.DetermineContextError(ctx) + if subErr != nil { + err = subErr + return + } + line, foundErr := b.ReadString(newLine) + if foundErr == nil { + if !yield(line) { + err = commonerrors.ErrEOF + return + } + } else { + b.Reset() + _, subErr = b.Write([]byte(line)) + if subErr != nil { + err = subErr + return } + return + } + } +} + +func IterateOverLines(ctx context.Context, fetchNext func(fCtx context.Context) ([]byte, error), yield func(string) bool) (err error) { + extendedYield := func(s string) bool { + return yield(cleanseLine(s)) + } + b := bytes.NewBuffer(make([]byte, 0, 512)) + for { + subErr := parallelisation.DetermineContextError(ctx) + if subErr != nil { + err = subErr + return + } + nextBuf, subErr := fetchNext(ctx) + if subErr != nil { + err = subErr + return + } + if len(nextBuf) == 0 { + parallelisation.SleepWithContext(ctx, 10*time.Millisecond) + continue + } + _, subErr = b.Write(nextBuf) + if subErr != nil { + err = subErr + return + } + subErr = iterateOverLines(ctx, b, extendedYield) + if subErr != nil { + err = subErr + return } } } type FIFOLoggers struct { - GenericLoggers - LogWriter FIFOWriter + d *fifoDiode + newline bool } -func (l *FIFOLoggers) Check() error { - return l.GenericLoggers.Check() +func (l *FIFOLoggers) SetLogSource(_ string) error { + return nil +} + +func (l *FIFOLoggers) SetLoggerSource(_ string) error { + return nil +} + +func (l *FIFOLoggers) Log(output ...any) { + l.log(output...) } -func (l *FIFOLoggers) Read() string { - return l.LogWriter.Read() +func (l *FIFOLoggers) LogError(err ...any) { + l.log(err...) } -func (l *FIFOLoggers) ReadLines(ctx context.Context) iter.Seq[string] { - return l.LogWriter.ReadLines(ctx) +func (l *FIFOLoggers) log(args ...any) { + b := bytes.NewBufferString(fmt.Sprint(args...)) + if l.newline { + _, _ = b.Write([]byte{newLine}) + } + l.d.Set(b.Bytes()) +} + +func (l *FIFOLoggers) Check() error { + if l.d == nil { + return commonerrors.UndefinedVariable("FIFO diode") + } + return nil +} + +// LineIterator returns an iterator over lines. It should only be called within the context of the same goroutine. +func (l *FIFOLoggers) LineIterator(ctx context.Context) iter.Seq[string] { + return l.d.LineIterator(ctx) } // Close closes the logger func (l *FIFOLoggers) Close() (err error) { - err = l.LogWriter.Close() - if err != nil { - return - } - err = l.GenericLoggers.Close() - return + return l.d.Close() } // NewFIFOLogger creates a logger to a bytes buffer. // All messages (whether they are output or error) are merged together. 
// Once messages have been accessed they are gone -func NewFIFOLogger(loggerSource string) (loggers *FIFOLoggers, err error) { - loggers = &FIFOLoggers{ - LogWriter: FIFOWriter{}, - } - loggers.GenericLoggers = GenericLoggers{ - Output: log.New(&loggers.LogWriter, fmt.Sprintf("[%v] Output: ", loggerSource), log.LstdFlags), - Error: log.New(&loggers.LogWriter, fmt.Sprintf("[%v] Error: ", loggerSource), log.LstdFlags), - } +func NewFIFOLogger() (loggers *FIFOLoggers, err error) { + loggers, err = newDefaultFIFOLogger(true) return } @@ -153,12 +201,22 @@ func NewFIFOLogger(loggerSource string) (loggers *FIFOLoggers, err error) { // All messages (whether they are output or error) are merged together. // Once messages have been accessed they are gone func NewPlainFIFOLogger() (loggers *FIFOLoggers, err error) { - loggers = &FIFOLoggers{ - LogWriter: FIFOWriter{}, + loggers, err = newDefaultFIFOLogger(false) + return +} + +func newDefaultFIFOLogger(addNewLine bool) (loggers *FIFOLoggers, err error) { + l, err := NewNoopLogger("FIFO") + if err != nil { + return } - loggers.GenericLoggers = GenericLoggers{ - Output: log.New(&loggers.LogWriter, "", 0), - Error: log.New(&loggers.LogWriter, "", 0), + return NewFIFOLoggerWithBuffer(addNewLine, bufferSize, 50*time.Millisecond, l) +} + +func NewFIFOLoggerWithBuffer(addNewLine bool, ringBufferSize int, pollingPeriod time.Duration, droppedMessageLogger Loggers) (loggers *FIFOLoggers, err error) { + loggers = &FIFOLoggers{ + d: newFIFODiode(context.Background(), ringBufferSize, pollingPeriod, droppedMessageLogger), + newline: addNewLine, } return } diff --git a/utils/logs/fifo_logger_test.go b/utils/logs/fifo_logger_test.go index 8aa757f98c..ece2dba43b 100644 --- a/utils/logs/fifo_logger_test.go +++ b/utils/logs/fifo_logger_test.go @@ -1,104 +1,146 @@ package logs import ( + "bytes" "context" - "regexp" + "fmt" + "io" "strings" "testing" "time" + "github.com/go-faker/faker/v4" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/ARM-software/golang-utils/utils/commonerrors" + "github.com/ARM-software/golang-utils/utils/commonerrors/errortest" + "github.com/ARM-software/golang-utils/utils/parallelisation" ) -func TestFIFOLoggerRead(t *testing.T) { - loggers, err := NewFIFOLogger("Test") - require.NoError(t, err) - testLog(t, loggers) - loggers.LogError("Test err") - loggers.Log("Test1") - contents := loggers.Read() - require.NotEmpty(t, contents) - require.True(t, strings.Contains(contents, "Test err")) - require.True(t, strings.Contains(contents, "Test1")) - loggers.Log("Test2") - contents = loggers.Read() - require.NotEmpty(t, contents) - require.False(t, strings.Contains(contents, "Test err")) - require.False(t, strings.Contains(contents, "Test1")) - require.True(t, strings.Contains(contents, "Test2")) - contents = loggers.Read() - require.Empty(t, contents) -} +func TestFIFOLoggerLineIterator(t *testing.T) { + t.Run("logger tests", func(t *testing.T) { + loggers, err := NewFIFOLogger() + require.NoError(t, err) + defer func() { _ = loggers.Close() }() + testLog(t, loggers) + }) + t.Run("read lines", func(t *testing.T) { + loggers, err := NewFIFOLogger() + require.NoError(t, err) + defer func() { _ = loggers.Close() }() + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + loggers.LogError("Test\r err\n") + loggers.Log("\rTest1\r") + count := 0 -func TestPlainFIFOLoggerRead(t *testing.T) { - loggers, err := NewPlainFIFOLogger() - require.NoError(t, err) - testLog(t, loggers) - 
loggers.LogError("Test err") - loggers.Log("Test1") - contents := loggers.Read() - require.NotEmpty(t, contents) - require.True(t, strings.Contains(contents, "Test err")) - require.True(t, strings.Contains(contents, "Test1")) - loggers.Log("Test2") - contents = loggers.Read() - require.NotEmpty(t, contents) - require.False(t, strings.Contains(contents, "Test err")) - require.False(t, strings.Contains(contents, "Test1")) - require.True(t, strings.Contains(contents, "Test2")) - contents = loggers.Read() - require.Empty(t, contents) -} + var b strings.Builder + for line := range loggers.LineIterator(ctx) { + _, err := b.WriteString(line + "\n") + require.NoError(t, err) + count++ + } -func TestFIFOLoggerReadlines(t *testing.T) { - loggers, err := NewFIFOLogger("Test") - require.NoError(t, err) - testLog(t, loggers) - loggers.LogError("Test err\n") - loggers.Log("Test1") - count := 0 - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() + assert.Equal(t, "Test err\n\nTest1\n", b.String()) + assert.Equal(t, 3, count) // log error added a line + }) +} - var b strings.Builder - for line := range loggers.ReadLines(ctx) { - _, err := b.WriteString(line + "\n") +func TestPlainFIFOLoggerLineIterator(t *testing.T) { + t.Run("logger tests", func(t *testing.T) { + loggers, err := NewPlainFIFOLogger() require.NoError(t, err) - count++ - } - - assert.Regexp(t, regexp.MustCompile(`\[Test\] Error: .* .* Test err\n\[Test\] Output: .* .* Test1\n`), b.String()) - assert.Equal(t, 2, count) -} + defer func() { _ = loggers.Close() }() + testLog(t, loggers) + }) + t.Run("read lines", func(t *testing.T) { + loggers, err := NewPlainFIFOLogger() + require.NoError(t, err) + defer func() { _ = loggers.Close() }() + go func() { + time.Sleep(500 * time.Millisecond) + loggers.LogError("Test err") + loggers.Log("") + time.Sleep(100 * time.Millisecond) + loggers.Log("Test1") + loggers.Log("\n\n\n") + time.Sleep(200 * time.Millisecond) + loggers.Log("Test2\n") + }() -func TestPlainFIFOLoggerReadlines(t *testing.T) { - loggers, err := NewPlainFIFOLogger() - require.NoError(t, err) - testLog(t, loggers) + count := 0 + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() - go func() { - time.Sleep(500 * time.Millisecond) - loggers.LogError("Test err") - loggers.Log("") - time.Sleep(100 * time.Millisecond) - loggers.Log("Test1") - loggers.Log("\n\n\n") - time.Sleep(200 * time.Millisecond) - loggers.Log("Test2") - }() + var b strings.Builder + for line := range loggers.LineIterator(ctx) { + _, err := b.WriteString(line + "\n") + require.NoError(t, err) + count++ + } - count := 0 - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() + assert.Equal(t, "Test errTest1\n\n\nTest2\n", b.String()) + assert.Equal(t, 4, count) + }) +} - var b strings.Builder - for line := range loggers.ReadLines(ctx) { - _, err := b.WriteString(line + "\n") - require.NoError(t, err) - count++ +func Test_iterateOverLines(t *testing.T) { + endIncompleteLine := faker.Word() + testLines := fmt.Sprintf("%v\n%v", strings.ReplaceAll(faker.Paragraph(), " ", "/r/n"), endIncompleteLine) + buf := bytes.NewBufferString(testLines) + numberOfLines := strings.Count(testLines, "\n") + lineCounter := 0 + yield := func(string) bool { + lineCounter++ + return true } + t.Run("cancelled", func(t *testing.T) { + cancelCtx, cancel := context.WithCancel(context.Background()) + cancel() + errortest.AssertError(t, iterateOverLines(cancelCtx, buf, yield), 
commonerrors.ErrCancelled) + assert.Zero(t, lineCounter) + }) + t.Run("success", func(t *testing.T) { + err := iterateOverLines(context.Background(), buf, yield) + require.NoError(t, err) + assert.Equal(t, numberOfLines, lineCounter) + assert.Equal(t, len(endIncompleteLine), buf.Len()) + line, err := buf.ReadString(newLine) + require.Error(t, err) + assert.Equal(t, io.EOF, err) + assert.Equal(t, endIncompleteLine, line) + }) +} - assert.Equal(t, "Test err\nTest1\nTest2\n", b.String()) - assert.Equal(t, 3, count) +func Test_IterateOverLines(t *testing.T) { + lastIncompleteLine := faker.Sentence() + overallLines := []string{fmt.Sprintf("%v\n%v", faker.Word(), faker.Word()), fmt.Sprintf("%v\n%v", strings.ReplaceAll(faker.Sentence(), " ", "\r"), faker.Name()), fmt.Sprintf("%v\n%v\n%v", faker.DomainName(), faker.IPv4(), lastIncompleteLine)} + expectedLines := strings.Split(strings.ReplaceAll(strings.TrimSuffix(strings.Join(overallLines, ""), "\n"+lastIncompleteLine), "\r", ""), "\n") + index := 0 + nextLine := func(fCtx context.Context) ([]byte, error) { + err := parallelisation.DetermineContextError(fCtx) + if err != nil { + return nil, err + } + if index >= len(overallLines) { + return nil, nil + } + b := []byte(overallLines[index]) + index++ + return b, nil + } + lineCounter := 0 + var readLines []string + yield := func(l string) bool { + lineCounter++ + readLines = append(readLines, l) + return true + } + cctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + err := IterateOverLines(cctx, nextLine, yield) + errortest.AssertError(t, err, commonerrors.ErrTimeout, commonerrors.ErrCancelled) + assert.Equal(t, 4, lineCounter) + assert.EqualValues(t, expectedLines, readLines) } diff --git a/utils/subprocess/command_wrapper.go b/utils/subprocess/command_wrapper.go index 9961e5c590..a0949edb35 100644 --- a/utils/subprocess/command_wrapper.go +++ b/utils/subprocess/command_wrapper.go @@ -59,15 +59,23 @@ func (c *cmdWrapper) Run() error { return ConvertCommandError(c.cmd.Run()) } -func (c *cmdWrapper) Stop() error { +type interruptType int + +const ( + kill interruptType = 9 + term interruptType = 15 +) + +func (c *cmdWrapper) interrupt(interrupt interruptType) error { c.mu.RLock() defer c.mu.RUnlock() if c.cmd == nil { - return fmt.Errorf("%w:undefined command", commonerrors.ErrUndefined) + return commonerrors.New(commonerrors.ErrUndefined, "undefined command") } subprocess := c.cmd.Process ctx, cancel := context.WithCancel(context.Background()) defer cancel() + var stopErr error if subprocess != nil { pid := subprocess.Pid parallelisation.ScheduleAfter(ctx, 10*time.Millisecond, func(time.Time) { @@ -75,11 +83,26 @@ func (c *cmdWrapper) Stop() error { if process == nil || err != nil { return } - _ = process.KillWithChildren(ctx) + switch interrupt { + case kill: + _ = process.KillWithChildren(ctx) + case term: + _ = process.Terminate(ctx) + default: + stopErr = commonerrors.New(commonerrors.ErrInvalid, "unknown interrupt type for process") + } }) } _ = c.cmd.Wait() - return nil + return stopErr +} + +func (c *cmdWrapper) Stop() error { + return c.interrupt(kill) +} + +func (c *cmdWrapper) Interrupt() error { + return c.interrupt(term) } func (c *cmdWrapper) Pid() (pid int, err error) { diff --git a/utils/subprocess/command_wrapper_test.go b/utils/subprocess/command_wrapper_test.go index 19ec07a7c1..fb2f65451d 100644 --- a/utils/subprocess/command_wrapper_test.go +++ b/utils/subprocess/command_wrapper_test.go @@ -162,6 +162,10 @@ func TestCmdStartStop(t 
*testing.T) { require.Error(t, err) err = wrapper.Stop() require.NoError(t, err) + err = wrapper.Start() + require.Error(t, err) + err = wrapper.Interrupt() + require.NoError(t, err) }) } } diff --git a/utils/subprocess/executor.go b/utils/subprocess/executor.go index b089644671..99d96a1872 100644 --- a/utils/subprocess/executor.go +++ b/utils/subprocess/executor.go @@ -192,6 +192,14 @@ func (s *Subprocess) IsOn() bool { return s.isRunning.Load() && s.processMonitoring.IsOn() } +// Wait waits for the command to exit and waits for any copying to +// stdin or copying from stdout or stderr to complete. +// +// The command must have been started by Start. +func (s *Subprocess) Wait() error { + return s.command.cmdWrapper.cmd.Wait() +} + // Start starts the process if not already started. // This method is idempotent. func (s *Subprocess) Start() (err error) { @@ -268,6 +276,13 @@ func (s *Subprocess) Stop() (err error) { return s.stop(true) } +// Interrupt terminates the process +// This method should be used in combination with `Start`. +// This method is idempotent +func (s *Subprocess) Interrupt() (err error) { + return s.interrupt() +} + // Restart restarts a process. It will stop the process if currently running. func (s *Subprocess) Restart() (err error) { err = s.stop(false) @@ -314,3 +329,25 @@ func (s *Subprocess) stop(cancel bool) (err error) { s.messaging.LogEnd(nil) return } + +func (s *Subprocess) interrupt() (err error) { + if !s.IsOn() { + return + } + err = s.Check() + if err != nil { + return + } + s.mu.Lock() + defer s.mu.Unlock() + defer s.Cancel() + if !s.IsOn() { + return + } + s.messaging.LogStopping() + err = s.getCmd().Interrupt() + s.command.Reset() + s.isRunning.Store(false) + s.messaging.LogEnd(nil) + return +} diff --git a/utils/subprocess/executor_test.go b/utils/subprocess/executor_test.go index 3d994c488f..2eb6628604 100644 --- a/utils/subprocess/executor_test.go +++ b/utils/subprocess/executor_test.go @@ -178,6 +178,74 @@ func TestStartStop(t *testing.T) { } } +func TestStartInterrupt(t *testing.T) { + currentDir, err := os.Getwd() + require.NoError(t, err) + tests := []struct { + name string + cmdWindows string + argWindows []string + cmdOther string + argOther []string + }{ + { + name: "ShortProcess", + cmdWindows: "cmd", + argWindows: []string{"dir", currentDir}, + cmdOther: "ls", + argOther: []string{"-l", currentDir}, + }, + { + name: "LongProcess", + cmdWindows: "cmd", + argWindows: []string{"SLEEP 1"}, + cmdOther: "sleep", + argOther: []string{"1"}, + }, + } + + for i := range tests { + test := tests[i] + t.Run(test.name, func(t *testing.T) { + defer goleak.VerifyNone(t) + loggers, err := logs.NewLogrLogger(logstest.NewTestLogger(t), "test") + require.NoError(t, err) + + var p *Subprocess + if platform.IsWindows() { + p, err = New(context.Background(), loggers, "", "", "", test.cmdWindows, test.argWindows...) + } else { + p, err = New(context.Background(), loggers, "", "", "", test.cmdOther, test.argOther...) 
+ } + require.NoError(t, err) + require.NotNil(t, p) + assert.False(t, p.IsOn()) + err = p.Start() + require.NoError(t, err) + assert.True(t, p.IsOn()) + + // Checking idempotence + err = p.Start() + require.NoError(t, err) + err = p.Check() + require.NoError(t, err) + + time.Sleep(200 * time.Millisecond) + err = p.Restart() + require.NoError(t, err) + assert.True(t, p.IsOn()) + err = p.Interrupt() + require.NoError(t, err) + assert.False(t, p.IsOn()) + // Checking idempotence + err = p.Interrupt() + require.NoError(t, err) + time.Sleep(100 * time.Millisecond) + err = p.Execute() + require.NoError(t, err) + }) + } +} func TestExecute(t *testing.T) { currentDir, err := os.Getwd() require.NoError(t, err) diff --git a/utils/subprocess/supervisor/supervisor.go b/utils/subprocess/supervisor/supervisor.go index 5570c71363..da67a7c1a4 100644 --- a/utils/subprocess/supervisor/supervisor.go +++ b/utils/subprocess/supervisor/supervisor.go @@ -24,6 +24,7 @@ type Supervisor struct { haltingErrors []error restartDelay time.Duration count uint + cmd *subprocess.Subprocess } type SupervisorOption func(*Supervisor) @@ -117,17 +118,18 @@ func (s *Supervisor) Run(ctx context.Context) (err error) { } g, _ := errgroup.WithContext(ctx) - cmd, err := s.newCommand(ctx) + s.cmd, err = s.newCommand(ctx) if err != nil { if commonerrors.Any(err, commonerrors.ErrCancelled, commonerrors.ErrTimeout) { return err } return fmt.Errorf("%w: error occurred when creating new command: %v", commonerrors.ErrUnexpected, err.Error()) } - if cmd == nil { + if s.cmd == nil { return fmt.Errorf("%w: command was undefined", commonerrors.ErrUndefined) } - g.Go(cmd.Execute) + + g.Go(s.cmd.Execute) if s.postStart != nil { err = s.postStart(ctx)