Skip to content

Commit

Permalink
File Stream - Support multi-writers
Browse files Browse the repository at this point in the history
  • Loading branch information
Or-Geva committed Mar 17, 2024
1 parent b0ac460 commit 8f0e0e7
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 9 deletions.
18 changes: 12 additions & 6 deletions http/filestream/filestream.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ const (
)

// The expected type of function that should be provided to the ReadFilesFromStream func, that returns the writer that should handle each file
type FileWriterFunc func(fileName string) (writer io.WriteCloser, err error)
type FileWriterFunc func(fileName string) (writer []io.WriteCloser, err error)

func ReadFilesFromStream(multipartReader *multipart.Reader, fileWriterFunc FileWriterFunc) error {
func ReadFilesFromStream(multipartReader *multipart.Reader, fileWritersFunc FileWriterFunc) error {
for {
// Read the next file streamed from client
fileReader, err := multipartReader.NextPart()
Expand All @@ -27,7 +27,7 @@ func ReadFilesFromStream(multipartReader *multipart.Reader, fileWriterFunc FileW
}
return fmt.Errorf("failed to read file: %w", err)
}
err = readFile(fileReader, fileWriterFunc)
err = readFile(fileReader, fileWritersFunc)
if err != nil {
return err
}
Expand All @@ -42,11 +42,17 @@ func readFile(fileReader *multipart.Part, fileWriterFunc FileWriterFunc) (err er
if err != nil {
return err
}
defer ioutils.Close(fileWriter, &err)
if _, err = io.Copy(fileWriter, fileReader); err != nil {
var writers []io.Writer
for _, writer := range fileWriter {
defer ioutils.Close(writer, &err)
// Create a multi writer that will write the file to all the provided writers
// We read multipart once and write to multiple writers, so we can't use the same multipart writer multiple times
writers = append(writers, writer)
}
if _, err = io.Copy(ioutils.AsyncMultiWriter(writers...), fileReader); err != nil {
return fmt.Errorf("failed writing '%s' file: %w", fileName, err)
}
return err
return nil
}

type FileInfo struct {
Expand Down
11 changes: 8 additions & 3 deletions http/filestream/filestream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@ package filestream

import (
"bytes"
"github.com/stretchr/testify/assert"
"io"
"mime/multipart"
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
)

var targetDir string
Expand Down Expand Up @@ -45,6 +46,10 @@ func TestWriteFilesToStreamAndReadFilesFromStream(t *testing.T) {
assert.Equal(t, file2Content, content)
}

func simpleFileWriter(fileName string) (fileWriter io.WriteCloser, err error) {
return os.Create(filepath.Join(targetDir, fileName))
func simpleFileWriter(fileName string) (fileWriter []io.WriteCloser, err error) {
writer, err := os.Create(filepath.Join(targetDir, fileName))
if err != nil {
return nil, err
}
return []io.WriteCloser{writer}, nil
}
56 changes: 56 additions & 0 deletions io/multiwriter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package io

import (
"errors"
"io"
"sync"
)

var ErrShortWrite = errors.New("short write")

type asyncMultiWriter struct {
writers []io.Writer
}

// AsyncMultiWriter creates a writer that duplicates its writes to all the
// provided writers asynchronous
func AsyncMultiWriter(writers ...io.Writer) io.Writer {
w := make([]io.Writer, len(writers))
copy(w, writers)
return &asyncMultiWriter{w}
}

// Writes data asynchronously to each writer and waits for all of them to complete.
// In case of an error, the writing will not complete.
func (t *asyncMultiWriter) Write(p []byte) (int, error) {
var wg sync.WaitGroup
wg.Add(len(t.writers))
errChannel := make(chan error)
finished := make(chan bool, 1)
for _, w := range t.writers {
go writeData(p, w, &wg, errChannel)
}
go func() {
wg.Wait()
close(finished)
}()
// This select will block until one of the two channels returns a value.
select {
case <-finished:
case err := <-errChannel:
if err != nil {
return 0, err
}
}
return len(p), nil
}
func writeData(p []byte, w io.Writer, wg *sync.WaitGroup, errChan chan error) {
n, err := w.Write(p)
if err != nil {
errChan <- err
}
if n != len(p) {
errChan <- ErrShortWrite
}
wg.Done()
}

0 comments on commit 8f0e0e7

Please sign in to comment.