From 3bf997487afc27c03f7c2f74b04e3d22b3023a19 Mon Sep 17 00:00:00 2001 From: Omer Zidkoni Date: Tue, 9 Apr 2024 10:22:39 +0300 Subject: [PATCH] CR --- http/filestream/filestream.go | 39 +++++++++++++++++++----------- http/filestream/filestream_test.go | 7 +++--- 2 files changed, 29 insertions(+), 17 deletions(-) diff --git a/http/filestream/filestream.go b/http/filestream/filestream.go index 40f5402..343a604 100644 --- a/http/filestream/filestream.go +++ b/http/filestream/filestream.go @@ -62,45 +62,56 @@ type FileInfo struct { Path string } -func WriteFilesToStream(multipartWriter *multipart.Writer, filesList []*FileInfo) error { +func WriteFilesToStream(multipartWriter *multipart.Writer, filesList []*FileInfo) (err error) { + var isContentWritten bool + defer func() { + // The multipartWriter.Close() function automatically writes the closing boundary to the underlying writer, + // regardless of whether any content was written to it. Therefore, if no content was written + // (i.e., no parts were created using the multipartWriter), there is no need to explicitly close the + // multipartWriter. The closing boundary will be correctly handled by calling multipartWriter.Close() + // when it goes out of scope or when explicitly called, ensuring the proper termination of the multipart request. + if isContentWritten { + err = errors.Join(err, multipartWriter.Close()) + } + }() for _, file := range filesList { - if err := writeFile(multipartWriter, file); err != nil { - return writeErrPart(multipartWriter, file, err) + if err = writeFile(multipartWriter, file); err != nil { + isContentWritten, err = writeErrPart(multipartWriter, file, err) + return err } + isContentWritten = true } - // Close finishes the multipart message and writes the trailing - // boundary end line to the output. - // We don't use defer for this because the multipart.Writer's Close() method writes regardless of whether there was an error or if writing hadn't started at all - return multipartWriter.Close() + return nil } func writeFile(multipartWriter *multipart.Writer, file *FileInfo) (err error) { fileReader, err := os.Open(file.Path) if err != nil { - return fmt.Errorf("failed opening %q: %w", file.Name, err) + return fmt.Errorf("failed opening file %q: %w", file.Name, err) } defer ioutils.Close(fileReader, &err) fileWriter, err := multipartWriter.CreateFormFile(FileType, file.Name) if err != nil { - return fmt.Errorf("failed to CreateFormFile: %w", err) + return fmt.Errorf("failed to create form file for %q: %w", file.Name, err) } _, err = io.Copy(fileWriter, fileReader) return err } -func writeErrPart(multipartWriter *multipart.Writer, file *FileInfo, writeFileErr error) error { +func writeErrPart(multipartWriter *multipart.Writer, file *FileInfo, writeFileErr error) (bool, error) { + var isPartWritten bool fileWriter, err := multipartWriter.CreateFormField(ErrorType) if err != nil { - return fmt.Errorf("failed to CreateFormField: %w", err) + return isPartWritten, fmt.Errorf("failed to create form field: %w", err) } - + isPartWritten = true multipartErr := NewMultipartError(file.Name, writeFileErr.Error()) multipartErrJSON, err := json.Marshal(multipartErr) if err != nil { - return fmt.Errorf("failed to marshal multipart error: %w", err) + return isPartWritten, fmt.Errorf("failed to marshal multipart error for file %q: %w", file.Name, err) } _, err = io.Copy(fileWriter, bytes.NewReader(multipartErrJSON)) - return err + return isPartWritten, err } diff --git a/http/filestream/filestream_test.go b/http/filestream/filestream_test.go index b0e8627..1947a53 100644 --- a/http/filestream/filestream_test.go +++ b/http/filestream/filestream_test.go @@ -10,6 +10,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) var targetDir string @@ -61,11 +62,11 @@ func TestWriteFilesToStreamWithError(t *testing.T) { // Call WriteFilesToStream and expect an error err := WriteFilesToStream(multipartWriter, []*FileInfo{file}) - assert.NoError(t, err) + require.NoError(t, err) multipartReader := multipart.NewReader(body, multipartWriter.Boundary()) - form, err := multipartReader.ReadForm(1024) - assert.NoError(t, err) + form, err := multipartReader.ReadForm(10 * 1024) + require.NoError(t, err) assert.Len(t, form.Value[ErrorType], 1) var multipartErr MultipartError