diff --git a/go.mod b/go.mod index a60ea6b..06f55ba 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/jfrog/archiver/v3 v3.6.0 github.com/pkg/errors v0.9.1 github.com/stretchr/testify v1.8.4 + github.com/zeebo/xxh3 v1.0.2 ) require ( @@ -14,6 +15,7 @@ require ( github.com/dsnet/compress v0.0.1 // indirect github.com/golang/snappy v0.0.4 // indirect github.com/klauspost/compress v1.17.4 // indirect + github.com/klauspost/cpuid/v2 v2.0.9 // indirect github.com/klauspost/pgzip v1.2.6 // indirect github.com/nwaples/rardecode v1.1.3 // indirect github.com/pierrec/lz4/v4 v4.1.21 // indirect diff --git a/go.sum b/go.sum index b9baeb6..8a65411 100644 --- a/go.sum +++ b/go.sum @@ -13,6 +13,8 @@ github.com/klauspost/compress v1.4.1/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0 github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= github.com/klauspost/cpuid v1.2.0/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek= +github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/pgzip v1.2.6 h1:8RXeL5crjEUFnR2/Sn6GJNWtSQ3Dk8pq4CL3jvdDyjU= github.com/klauspost/pgzip v1.2.6/go.mod h1:Ch1tH69qFZu15pkjo5kYi6mth2Zzwzt50oCQKQE9RUs= github.com/nwaples/rardecode v1.1.3 h1:cWCaZwfM5H7nAD6PyEdcVnczzV8i/JtotnyW/dD9lEc= @@ -30,6 +32,9 @@ github.com/ulikunitz/xz v0.5.11 h1:kpFauv27b6ynzBNT/Xy+1k+fK4WswhN/6PN5WhFAGw8= github.com/ulikunitz/xz v0.5.11/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 h1:nIPpBwaJSVYIxUFsDv3M8ofmx9yWTog9BfvIu0q41lo= github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8/go.mod h1:HUYIGzjTL3rfEspMxjDjgmT5uz5wzYJKVo23qUhYTos= +github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ= +github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= +github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/http/filestream/filestream.go b/http/filestream/filestream.go new file mode 100644 index 0000000..9f81879 --- /dev/null +++ b/http/filestream/filestream.go @@ -0,0 +1,80 @@ +package filestream + +import ( + "errors" + "fmt" + ioutils "github.com/jfrog/gofrog/io" + "io" + "mime/multipart" + "net/http" + "os" +) + +const ( + contentType = "Content-Type" + FileType = "file" +) + +// The expected type of function that should be provided to the ReadFilesFromStream func, that returns the writer that should handle each file +type FileHandlerFunc func(fileName string) (writer io.WriteCloser, err error) + +func ReadFilesFromStream(multipartReader *multipart.Reader, fileHandlerFunc FileHandlerFunc) error { + for { + // Read the next file streamed from client + fileReader, err := multipartReader.NextPart() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return fmt.Errorf("failed to read file: %w", err) + } + err = readFile(fileReader, fileHandlerFunc) + if err != nil { + return err + } + + } + return nil +} + +func readFile(fileReader *multipart.Part, fileHandlerFunc FileHandlerFunc) (err error) { + fileName := fileReader.FileName() + fileWriter, err := fileHandlerFunc(fileName) + if err != nil { + return err + } + defer ioutils.Close(fileWriter, &err) + if _, err = io.Copy(fileWriter, fileReader); err != nil { + return fmt.Errorf("failed writing '%s' file: %w", fileName, err) + } + return err +} + +func WriteFilesToStream(responseWriter http.ResponseWriter, filePaths []string) (err error) { + multipartWriter := multipart.NewWriter(responseWriter) + responseWriter.Header().Set(contentType, multipartWriter.FormDataContentType()) + + for _, filePath := range filePaths { + if err = writeFile(multipartWriter, filePath); err != nil { + return + } + } + + // Close finishes the multipart message and writes the trailing + // boundary end line to the output. + return multipartWriter.Close() +} + +func writeFile(multipartWriter *multipart.Writer, filePath string) (err error) { + fileReader, err := os.Open(filePath) + if err != nil { + return fmt.Errorf("failed to open file: %w", err) + } + defer ioutils.Close(fileReader, &err) + fileWriter, err := multipartWriter.CreateFormFile(FileType, filePath) + if err != nil { + return fmt.Errorf("failed to CreateFormFile: %w", err) + } + _, err = io.Copy(fileWriter, fileReader) + return err +} diff --git a/http/filestream/filestream_test.go b/http/filestream/filestream_test.go new file mode 100644 index 0000000..f021314 --- /dev/null +++ b/http/filestream/filestream_test.go @@ -0,0 +1,58 @@ +package filestream + +import ( + "github.com/stretchr/testify/assert" + "io" + "mime/multipart" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" +) + +var targetDir string + +func TestWriteFilesToStreamAndReadFilesFromStream(t *testing.T) { + sourceDir := t.TempDir() + // Create 2 file to be transferred via our multipart stream + file1 := filepath.Join(sourceDir, "test1.txt") + file2 := filepath.Join(sourceDir, "test2.txt") + file1Content := []byte("test content1") + file2Content := []byte("test content2") + assert.NoError(t, os.WriteFile(file1, file1Content, 0600)) + assert.NoError(t, os.WriteFile(file2, file2Content, 0600)) + + // Create the multipart writer that will stream our files + responseWriter := httptest.NewRecorder() + assert.NoError(t, WriteFilesToStream(responseWriter, []string{file1, file2})) + + // Create local temp dir that will store our files + targetDir = t.TempDir() + + // Get boundary hash from writer + boundary := strings.Split(responseWriter.Header().Get(contentType), "boundary=")[1] + // Create the multipart reader that will read the files from the stream + multipartReader := multipart.NewReader(responseWriter.Body, boundary) + assert.NoError(t, ReadFilesFromStream(multipartReader, simpleFileHandler)) + + // Validate file 1 transferred successfully + file1 = filepath.Join(targetDir, "test1.txt") + assert.FileExists(t, file1) + content, err := os.ReadFile(file1) + assert.NoError(t, err) + assert.Equal(t, file1Content, content) + assert.NoError(t, os.Remove(file1)) + + // Validate file 2 transferred successfully + file2 = filepath.Join(targetDir, "test2.txt") + assert.FileExists(t, file2) + content, err = os.ReadFile(file2) + assert.NoError(t, err) + assert.Equal(t, file2Content, content) + assert.NoError(t, os.Remove(file2)) +} + +func simpleFileHandler(fileName string) (fileWriter io.WriteCloser, err error) { + return os.Create(filepath.Join(targetDir, fileName)) +}