Skip to content

Commit

Permalink
Fix CR
Browse files Browse the repository at this point in the history
  • Loading branch information
Or-Geva committed Mar 18, 2024
1 parent 6a41275 commit 4a4c0cc
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 23 deletions.
27 changes: 13 additions & 14 deletions io/multiwriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"golang.org/x/sync/errgroup"
)

var ErrShortWrite = errors.New("short write")
var ErrShortWrite = errors.New("The number of bytes written is less than the length of the input")

type asyncMultiWriter struct {
writers []io.Writer
Expand All @@ -16,7 +16,7 @@ type asyncMultiWriter struct {

// AsyncMultiWriter creates a writer that duplicates its writes to all the
// provided writers asynchronous
func AsyncMultiWriter(limit int,writers ...io.Writer) io.Writer {
func AsyncMultiWriter(limit int, writers ...io.Writer) io.Writer {
w := make([]io.Writer, len(writers))
copy(w, writers)
return &asyncMultiWriter{writers: w, limit: limit}
Expand All @@ -28,18 +28,17 @@ func (amw *asyncMultiWriter) Write(p []byte) (int, error) {
eg := errgroup.Group{}
eg.SetLimit(amw.limit)
for _, w := range amw.writers {
func(w io.Writer) {
eg.Go(func() error {
n, err := w.Write(p)
if err != nil {
return err
}
if n != len(p) {
return ErrShortWrite
}
return nil
})
}(w)
currentWriter := w
eg.Go(func() error {
n, err := currentWriter.Write(p)
if err != nil {
return err
}
if n != len(p) {
return ErrShortWrite
}
return nil
})
}

return len(p), eg.Wait()
Expand Down
10 changes: 3 additions & 7 deletions io/multiwriter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,11 @@ func TestAsyncMultiWriter(t *testing.T) {
assert.Equal(t, len(data), n)

// Check if data is correctly written to both writers
if buf1.String() != string(data) || buf2.String() != string(data) {
t.Errorf("Data not written correctly to all writers")
}
assert.Equal(t, string(data), buf1.String())
assert.Equal(t, string(data), buf2.String())
}
}

// TestAsyncMultiWriter_Error tests the error handling behavior of AsyncMultiWriter.
func TestAsyncMultiWriter_Error(t *testing.T) {
expectedErr := errors.New("write error")

Expand All @@ -34,9 +32,7 @@ func TestAsyncMultiWriter_Error(t *testing.T) {
multiWriter := AsyncMultiWriter(2, mockWriter)

_, err := multiWriter.Write([]byte("test data"))
if err != expectedErr {
t.Errorf("Expected error: %v, got: %v", expectedErr, err)
}
assert.Equal(t, expectedErr, err)
}

// Mock writer to simulate Write errors
Expand Down
4 changes: 2 additions & 2 deletions unarchive/archive.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,8 @@ func (u *Unarchiver) byExtension(filename string) (interface{}, error) {

// Make sure the archive is free from Zip Slip and Zip symlinks attacks
func inspectArchive(archive interface{}, localArchivePath, destinationDir string) error {
// If the destination directory ends with a slash, delete it.
// This is necessary to handle a situation where the entry path might be at the root of the destination directory,
// If the destination directory ends with a slash, delete it.
// This is necessary to handle a situation where the entry path might be at the root of the destination directory,
// but in such a case "<destination-dir>/" is not a prefix of "<destination-dir>".
destinationDir = strings.TrimSuffix(destinationDir, string(os.PathSeparator))
walker, ok := archive.(archiver.Walker)
Expand Down

0 comments on commit 4a4c0cc

Please sign in to comment.