Skip to content

Commit

Permalink
Use errgroup for concurrency writes
Browse files Browse the repository at this point in the history
  • Loading branch information
Or-Geva committed Mar 18, 2024
1 parent f057ef3 commit 6a41275
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 33 deletions.
56 changes: 23 additions & 33 deletions io/multiwriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,54 +3,44 @@ package io
import (
"errors"
"io"
"sync"

"golang.org/x/sync/errgroup"
)

var ErrShortWrite = errors.New("short write")

type asyncMultiWriter struct {
writers []io.Writer
limit int
}

// AsyncMultiWriter creates a writer that duplicates its writes to all the
// provided writers asynchronous
func AsyncMultiWriter(writers ...io.Writer) io.Writer {
func AsyncMultiWriter(limit int,writers ...io.Writer) io.Writer {
w := make([]io.Writer, len(writers))
copy(w, writers)
return &asyncMultiWriter{w}
return &asyncMultiWriter{writers: w, limit: limit}
}

// Writes data asynchronously to each writer and waits for all of them to complete.
// In case of an error, the writing will not complete.
func (t *asyncMultiWriter) Write(p []byte) (int, error) {
var wg sync.WaitGroup
wg.Add(len(t.writers))
errChannel := make(chan error)
finished := make(chan bool, 1)
for _, w := range t.writers {
go writeData(p, w, &wg, errChannel)
}
go func() {
wg.Wait()
close(finished)
}()
// This select will block until one of the two channels returns a value.
select {
case <-finished:
case err := <-errChannel:
if err != nil {
return 0, err
}
}
return len(p), nil
}
func writeData(p []byte, w io.Writer, wg *sync.WaitGroup, errChan chan error) {
n, err := w.Write(p)
if err != nil {
errChan <- err
}
if n != len(p) {
errChan <- ErrShortWrite
func (amw *asyncMultiWriter) Write(p []byte) (int, error) {
eg := errgroup.Group{}
eg.SetLimit(amw.limit)
for _, w := range amw.writers {
func(w io.Writer) {
eg.Go(func() error {
n, err := w.Write(p)
if err != nil {
return err
}
if n != len(p) {
return ErrShortWrite
}
return nil
})
}(w)
}
wg.Done()

return len(p), eg.Wait()
}
49 changes: 49 additions & 0 deletions io/multiwriter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package io

import (
"bytes"
"errors"
"testing"

"github.com/stretchr/testify/assert"
)

func TestAsyncMultiWriter(t *testing.T) {
for _, limit := range []int{1, 2} {
var buf1, buf2 bytes.Buffer
multiWriter := AsyncMultiWriter(limit, &buf1, &buf2)

data := []byte("test data")
n, err := multiWriter.Write(data)
assert.NoError(t, err)
assert.Equal(t, len(data), n)

// Check if data is correctly written to both writers
if buf1.String() != string(data) || buf2.String() != string(data) {
t.Errorf("Data not written correctly to all writers")
}
}
}

// TestAsyncMultiWriter_Error tests the error handling behavior of AsyncMultiWriter.
func TestAsyncMultiWriter_Error(t *testing.T) {
expectedErr := errors.New("write error")

// Mock writer that always returns an error
mockWriter := &mockWriter{writeErr: expectedErr}
multiWriter := AsyncMultiWriter(2, mockWriter)

_, err := multiWriter.Write([]byte("test data"))
if err != expectedErr {
t.Errorf("Expected error: %v, got: %v", expectedErr, err)
}
}

// Mock writer to simulate Write errors
type mockWriter struct {
writeErr error
}

func (m *mockWriter) Write(p []byte) (int, error) {
return 0, m.writeErr
}

0 comments on commit 6a41275

Please sign in to comment.