diff --git a/http/filestream/filestream.go b/http/filestream/filestream.go index bd69dd9..2c455e2 100644 --- a/http/filestream/filestream.go +++ b/http/filestream/filestream.go @@ -15,9 +15,9 @@ const ( ) // The expected type of function that should be provided to the ReadFilesFromStream func, that returns the writer that should handle each file -type FileWriterFunc func(fileName string) (writer io.WriteCloser, err error) +type FileWriterFunc func(fileName string) (writer []io.WriteCloser, err error) -func ReadFilesFromStream(multipartReader *multipart.Reader, fileWriterFunc FileWriterFunc) error { +func ReadFilesFromStream(multipartReader *multipart.Reader, fileWritersFunc FileWriterFunc) error { for { // Read the next file streamed from client fileReader, err := multipartReader.NextPart() @@ -27,7 +27,7 @@ func ReadFilesFromStream(multipartReader *multipart.Reader, fileWriterFunc FileW } return fmt.Errorf("failed to read file: %w", err) } - err = readFile(fileReader, fileWriterFunc) + err = readFile(fileReader, fileWritersFunc) if err != nil { return err } @@ -42,11 +42,17 @@ func readFile(fileReader *multipart.Part, fileWriterFunc FileWriterFunc) (err er if err != nil { return err } - defer ioutils.Close(fileWriter, &err) - if _, err = io.Copy(fileWriter, fileReader); err != nil { + var writers []io.Writer + for _, writer := range fileWriter { + defer ioutils.Close(writer, &err) + // Create a multi writer that will write the file to all the provided writers + // We read multipart once and write to multiple writers, so we can't use the same multipart writer multiple times + writers = append(writers, writer) + } + if _, err = io.Copy(ioutils.AsyncMultiWriter(writers...), fileReader); err != nil { return fmt.Errorf("failed writing '%s' file: %w", fileName, err) } - return err + return nil } type FileInfo struct { diff --git a/http/filestream/filestream_test.go b/http/filestream/filestream_test.go index 465c139..a5203b4 100644 --- a/http/filestream/filestream_test.go +++ b/http/filestream/filestream_test.go @@ -2,12 +2,13 @@ package filestream import ( "bytes" - "github.com/stretchr/testify/assert" "io" "mime/multipart" "os" "path/filepath" "testing" + + "github.com/stretchr/testify/assert" ) var targetDir string @@ -45,6 +46,10 @@ func TestWriteFilesToStreamAndReadFilesFromStream(t *testing.T) { assert.Equal(t, file2Content, content) } -func simpleFileWriter(fileName string) (fileWriter io.WriteCloser, err error) { - return os.Create(filepath.Join(targetDir, fileName)) +func simpleFileWriter(fileName string) (fileWriter []io.WriteCloser, err error) { + writer, err := os.Create(filepath.Join(targetDir, fileName)) + if err != nil { + return nil, err + } + return []io.WriteCloser{writer}, nil } diff --git a/io/multiwriter.go b/io/multiwriter.go new file mode 100644 index 0000000..c165d4b --- /dev/null +++ b/io/multiwriter.go @@ -0,0 +1,56 @@ +package io + +import ( + "errors" + "io" + "sync" +) + +var ErrShortWrite = errors.New("short write") + +type asyncMultiWriter struct { + writers []io.Writer +} + +// AsyncMultiWriter creates a writer that duplicates its writes to all the +// provided writers asynchronous +func AsyncMultiWriter(writers ...io.Writer) io.Writer { + w := make([]io.Writer, len(writers)) + copy(w, writers) + return &asyncMultiWriter{w} +} + +// Writes data asynchronously to each writer and waits for all of them to complete. +// In case of an error, the writing will not complete. +func (t *asyncMultiWriter) Write(p []byte) (int, error) { + var wg sync.WaitGroup + wg.Add(len(t.writers)) + errChannel := make(chan error) + finished := make(chan bool, 1) + for _, w := range t.writers { + go writeData(p, w, &wg, errChannel) + } + go func() { + wg.Wait() + close(finished) + }() + // This select will block until one of the two channels returns a value. + select { + case <-finished: + case err := <-errChannel: + if err != nil { + return 0, err + } + } + return len(p), nil +} +func writeData(p []byte, w io.Writer, wg *sync.WaitGroup, errChan chan error) { + n, err := w.Write(p) + if err != nil { + errChan <- err + } + if n != len(p) { + errChan <- ErrShortWrite + } + wg.Done() +}