From 23adf3a0f638ba4b0220ab09097f032660f6d5de Mon Sep 17 00:00:00 2001 From: Michael Sverdlov Date: Mon, 11 Mar 2024 16:17:56 +0200 Subject: [PATCH] Improve WriteFilesToStream generic function (#55) --- http/filestream/filestream.go | 35 ++++++++++++++---------------- http/filestream/filestream_test.go | 34 +++++++++++------------------ 2 files changed, 29 insertions(+), 40 deletions(-) diff --git a/http/filestream/filestream.go b/http/filestream/filestream.go index 9f81879..dcaf579 100644 --- a/http/filestream/filestream.go +++ b/http/filestream/filestream.go @@ -6,19 +6,17 @@ import ( ioutils "github.com/jfrog/gofrog/io" "io" "mime/multipart" - "net/http" "os" ) const ( - contentType = "Content-Type" - FileType = "file" + FileType = "file" ) // The expected type of function that should be provided to the ReadFilesFromStream func, that returns the writer that should handle each file -type FileHandlerFunc func(fileName string) (writer io.WriteCloser, err error) +type FileWriterFunc func(fileName string) (writer io.WriteCloser, err error) -func ReadFilesFromStream(multipartReader *multipart.Reader, fileHandlerFunc FileHandlerFunc) error { +func ReadFilesFromStream(multipartReader *multipart.Reader, fileWriterFunc FileWriterFunc) error { for { // Read the next file streamed from client fileReader, err := multipartReader.NextPart() @@ -28,7 +26,7 @@ func ReadFilesFromStream(multipartReader *multipart.Reader, fileHandlerFunc File } return fmt.Errorf("failed to read file: %w", err) } - err = readFile(fileReader, fileHandlerFunc) + err = readFile(fileReader, fileWriterFunc) if err != nil { return err } @@ -37,9 +35,9 @@ func ReadFilesFromStream(multipartReader *multipart.Reader, fileHandlerFunc File return nil } -func readFile(fileReader *multipart.Part, fileHandlerFunc FileHandlerFunc) (err error) { +func readFile(fileReader *multipart.Part, fileWriterFunc FileWriterFunc) (err error) { fileName := fileReader.FileName() - fileWriter, err := fileHandlerFunc(fileName) + fileWriter, err := fileWriterFunc(fileName) if err != nil { return err } @@ -50,12 +48,14 @@ func readFile(fileReader *multipart.Part, fileHandlerFunc FileHandlerFunc) (err return err } -func WriteFilesToStream(responseWriter http.ResponseWriter, filePaths []string) (err error) { - multipartWriter := multipart.NewWriter(responseWriter) - responseWriter.Header().Set(contentType, multipartWriter.FormDataContentType()) +type FileInfo struct { + Name string + Path string +} - for _, filePath := range filePaths { - if err = writeFile(multipartWriter, filePath); err != nil { +func WriteFilesToStream(multipartWriter *multipart.Writer, filesList []FileInfo) (err error) { + for _, file := range filesList { + if err = writeFile(multipartWriter, file); err != nil { return } } @@ -65,13 +65,10 @@ func WriteFilesToStream(responseWriter http.ResponseWriter, filePaths []string) return multipartWriter.Close() } -func writeFile(multipartWriter *multipart.Writer, filePath string) (err error) { - fileReader, err := os.Open(filePath) - if err != nil { - return fmt.Errorf("failed to open file: %w", err) - } +func writeFile(multipartWriter *multipart.Writer, file FileInfo) (err error) { + fileReader, err := os.Open(file.Path) defer ioutils.Close(fileReader, &err) - fileWriter, err := multipartWriter.CreateFormFile(FileType, filePath) + fileWriter, err := multipartWriter.CreateFormFile(FileType, file.Name) if err != nil { return fmt.Errorf("failed to CreateFormFile: %w", err) } diff --git a/http/filestream/filestream_test.go b/http/filestream/filestream_test.go index f021314..3e10a7a 100644 --- a/http/filestream/filestream_test.go +++ b/http/filestream/filestream_test.go @@ -1,13 +1,12 @@ package filestream import ( + "bytes" "github.com/stretchr/testify/assert" "io" "mime/multipart" - "net/http/httptest" "os" "path/filepath" - "strings" "testing" ) @@ -16,43 +15,36 @@ var targetDir string func TestWriteFilesToStreamAndReadFilesFromStream(t *testing.T) { sourceDir := t.TempDir() // Create 2 file to be transferred via our multipart stream - file1 := filepath.Join(sourceDir, "test1.txt") - file2 := filepath.Join(sourceDir, "test2.txt") + file1 := FileInfo{Name: "test1.txt", Path: filepath.Join(sourceDir, "test1.txt")} + file2 := FileInfo{Name: "test2.txt", Path: filepath.Join(sourceDir, "test2.txt")} file1Content := []byte("test content1") file2Content := []byte("test content2") - assert.NoError(t, os.WriteFile(file1, file1Content, 0600)) - assert.NoError(t, os.WriteFile(file2, file2Content, 0600)) + assert.NoError(t, os.WriteFile(file1.Path, file1Content, 0600)) + assert.NoError(t, os.WriteFile(file2.Path, file2Content, 0600)) // Create the multipart writer that will stream our files - responseWriter := httptest.NewRecorder() - assert.NoError(t, WriteFilesToStream(responseWriter, []string{file1, file2})) + body := &bytes.Buffer{} + multipartWriter := multipart.NewWriter(body) + assert.NoError(t, WriteFilesToStream(multipartWriter, []FileInfo{file1, file2})) // Create local temp dir that will store our files targetDir = t.TempDir() - // Get boundary hash from writer - boundary := strings.Split(responseWriter.Header().Get(contentType), "boundary=")[1] // Create the multipart reader that will read the files from the stream - multipartReader := multipart.NewReader(responseWriter.Body, boundary) - assert.NoError(t, ReadFilesFromStream(multipartReader, simpleFileHandler)) + multipartReader := multipart.NewReader(body, multipartWriter.Boundary()) + assert.NoError(t, ReadFilesFromStream(multipartReader, simpleFileWriter)) // Validate file 1 transferred successfully - file1 = filepath.Join(targetDir, "test1.txt") - assert.FileExists(t, file1) - content, err := os.ReadFile(file1) + content, err := os.ReadFile(filepath.Join(targetDir, file1.Name)) assert.NoError(t, err) assert.Equal(t, file1Content, content) - assert.NoError(t, os.Remove(file1)) // Validate file 2 transferred successfully - file2 = filepath.Join(targetDir, "test2.txt") - assert.FileExists(t, file2) - content, err = os.ReadFile(file2) + content, err = os.ReadFile(filepath.Join(targetDir, file2.Name)) assert.NoError(t, err) assert.Equal(t, file2Content, content) - assert.NoError(t, os.Remove(file2)) } -func simpleFileHandler(fileName string) (fileWriter io.WriteCloser, err error) { +func simpleFileWriter(fileName string) (fileWriter io.WriteCloser, err error) { return os.Create(filepath.Join(targetDir, fileName)) }