diff --git a/api/sda/sda.go b/api/sda/sda.go index c86731f..943100d 100644 --- a/api/sda/sda.go +++ b/api/sda/sda.go @@ -12,7 +12,6 @@ import ( "time" "github.com/gin-gonic/gin" - "github.com/neicnordic/crypt4gh/model/headers" "github.com/neicnordic/crypt4gh/streaming" "github.com/neicnordic/sda-download/api/middleware" "github.com/neicnordic/sda-download/internal/config" @@ -176,15 +175,6 @@ func Download(c *gin.Context) { return } - // Get coordinates - coordinates, err := parseCoordinates(c.Request) - if err != nil { - log.Errorf("parsing of query param coordinates to crypt4gh format failed, reason: %v", err) - c.String(http.StatusBadRequest, err.Error()) - - return - } - c.Header("Content-Length", fmt.Sprint(fileDetails.DecryptedSize)) c.Header("Content-Type", "application/octet-stream") if c.GetBool("S3") { @@ -206,90 +196,101 @@ func Download(c *gin.Context) { return } - // Stitch file and prepare it for streaming - fileStream, err := stitchFile(fileDetails.Header, file, coordinates) + hr := bytes.NewReader(fileDetails.Header) + mr := io.MultiReader(hr, file) + c4ghr, err := streaming.NewCrypt4GHReader(mr, *config.Config.App.Crypt4GHKey, nil) if err != nil { log.Errorf("could not prepare file for streaming, %s", err) c.String(http.StatusInternalServerError, "file stream error") return } + defer c4ghr.Close() + + sendStream(c, c4ghr) - sendStream(c.Writer, fileStream) } -// stitchFile stitches the header and file body together for Crypt4GHReader -// and returns a streamable Reader -var stitchFile = func(header []byte, file io.ReadCloser, coordinates *headers.DataEditListHeaderPacket) (*streaming.Crypt4GHReader, error) { - log.Debugf("stitching header to file %s for streaming", file) - // Stitch header and file body together - hr := bytes.NewReader(header) - mr := io.MultiReader(hr, file) - c4ghr, err := streaming.NewCrypt4GHReader(mr, *config.Config.App.Crypt4GHKey, coordinates) +// sendStream streams file contents from a reader +var sendStream = func(c *gin.Context, c4ghr *streaming.Crypt4GHReader) { + log.Debug("begin data stream") + + // Get query params + qStart := c.DefaultQuery("startCoordinate", "0") + qEnd := c.DefaultQuery("endCoordinate", "0") + + // Parse and verify coordinates are valid + start, err := strconv.ParseInt(qStart, 10, 0) if err != nil { - log.Errorf("failed to create Crypt4GH stream reader, %v", err) + log.Errorf("failed to convert start coordinate %d to integer, %s", start, err) + c.String(http.StatusInternalServerError, "startCoordinate must be an integer") - return nil, err + return } - log.Debugf("file stream for %s constructed", file) - - return c4ghr, nil -} + end, err := strconv.ParseInt(qEnd, 10, 0) + if err != nil { + log.Errorf("failed to convert end coordinate %d to integer, %s", end, err) -// parseCoordinates takes query param coordinates and converts them to -// Crypt4GH reader format -var parseCoordinates = func(r *http.Request) (*headers.DataEditListHeaderPacket, error) { + c.String(http.StatusInternalServerError, "endCoordinate must be an integer") - coordinates := &headers.DataEditListHeaderPacket{} + return + } + if end < start { + log.Errorf("endCoordinate=%d must be greater than startCoordinate=%d", end, start) - // Get query params - qStart := r.URL.Query().Get("startCoordinate") - qEnd := r.URL.Query().Get("endCoordinate") + c.String(http.StatusInternalServerError, "endCoordinate must be greater than startCoordinate") - // Parse and verify coordinates are valid - if len(qStart) > 0 && len(qEnd) > 0 { - start, err := strconv.ParseUint(qStart, 10, 64) - if err != nil { - log.Errorf("failed to convert start coordinate %d to integer, %s", start, err) + return + } - return nil, errors.New("startCoordinate must be an integer") - } - end, err := strconv.ParseUint(qEnd, 10, 64) - if err != nil { - log.Errorf("failed to convert end coordinate %d to integer, %s", end, err) + if start != 0 { + // We don't want to read from start, skip ahead to where we should be + if _, err := c4ghr.Seek(start, 0); err != nil { - return nil, errors.New("endCoordinate must be an integer") - } - if end < start { - log.Errorf("endCoordinate=%d must be greater than startCoordinate=%d", end, start) + c.String(http.StatusInternalServerError, "endCoordinate must be greater than startCoordinate") - return nil, errors.New("endCoordinate must be greater than startCoordinate") + return } - // API query params take a coordinate range to read "start...end" - // But Crypt4GHReader takes a start byte and number of bytes to read "start...(end-start)" - bytesToRead := end - start - coordinates.NumberLengths = 2 - coordinates.Lengths = []uint64{start, bytesToRead} - } else { - coordinates = nil } - return coordinates, nil -} + // Calculate how much we should read (if given) + togo := end - start -// sendStream streams file contents from a reader -var sendStream = func(w http.ResponseWriter, file io.Reader) { - log.Debug("begin data stream") + buf := make([]byte, 4096) - n, err := io.Copy(w, file) - log.Debug("end data stream") + // Loop until we've read what we should (if no/faulty end given, that's EOF) + for end == 0 || togo > 0 { + rbuf := buf - if err != nil { - log.Errorf("file streaming failed, reason: %v", err) - http.Error(w, "file streaming failed", 500) + if end != 0 && togo < 4096 { + // If we don't want to read as much as 4096 bytes + rbuf = buf[:togo] + } + r, err := c4ghr.Read(rbuf) + togo -= int64(r) - return - } + // Nothing more to read? + if err == io.EOF && r == 0 { + // Fall out without error if we had EOF (if we got any data, do one + // more lap in the loop) + return + } - log.Debugf("Sent %d bytes", n) + if err != nil && err != io.EOF { + // An error we want to signal? + return + } + + wbuf := rbuf[:r] + for len(wbuf) > 0 { + // Loop until we've written all that we could read, + // fall out on error + w, err := c.Writer.Write(wbuf) + + if err != nil { + return + } + wbuf = wbuf[w:] + } + } } diff --git a/api/sda/sda_test.go b/api/sda/sda_test.go index 148a777..d5cb80a 100644 --- a/api/sda/sda_test.go +++ b/api/sda/sda_test.go @@ -4,13 +4,11 @@ import ( "bytes" "errors" "io" - "net/http" "net/http/httptest" - "os" "testing" "github.com/gin-gonic/gin" - "github.com/neicnordic/crypt4gh/model/headers" + "github.com/neicnordic/crypt4gh/streaming" "github.com/neicnordic/sda-download/api/middleware" "github.com/neicnordic/sda-download/internal/config" @@ -323,103 +321,6 @@ func TestFiles_Success(t *testing.T) { } -func TestParseCoordinates_Fail_Start(t *testing.T) { - - // Test case - // startCoordinate must be an integer - r := httptest.NewRequest("GET", "https://testing.fi?startCoordinate=x&endCoordinate=100", nil) - - // Run test target - coordinates, err := parseCoordinates(r) - - // Expected results - expectedError := "startCoordinate must be an integer" - - if err.Error() != expectedError { - t.Errorf("TestParseCoordinates_Fail_Start failed, got %s expected %s", err.Error(), expectedError) - } - if coordinates != nil { - t.Errorf("TestParseCoordinates_Fail_Start failed, got %v expected nil", coordinates) - } - -} - -func TestParseCoordinates_Fail_End(t *testing.T) { - - // Test case - // endCoordinate must be an integer - r := httptest.NewRequest("GET", "https://testing.fi?startCoordinate=0&endCoordinate=y", nil) - - // Run test target - coordinates, err := parseCoordinates(r) - - // Expected results - expectedError := "endCoordinate must be an integer" - - if err.Error() != expectedError { - t.Errorf("TestParseCoordinates_Fail_End failed, got %s expected %s", err.Error(), expectedError) - } - if coordinates != nil { - t.Errorf("TestParseCoordinates_Fail_End failed, got %v expected nil", coordinates) - } - -} - -func TestParseCoordinates_Fail_SizeComparison(t *testing.T) { - - // Test case - // endCoordinate must be greater than startCoordinate - r := httptest.NewRequest("GET", "https://testing.fi?startCoordinate=50&endCoordinate=100", nil) - - // Run test target - coordinates, err := parseCoordinates(r) - - // Expected results - expectedLength := uint32(2) - expectedStart := uint64(50) - expectedBytesToRead := uint64(50) - - if err != nil { - t.Errorf("TestParseCoordinates_Fail_SizeComparison failed, got %v expected nil", err) - } - // nolint:staticcheck - if coordinates == nil { - t.Error("TestParseCoordinates_Fail_SizeComparison failed, got nil expected not nil") - } - // nolint:staticcheck - if coordinates.NumberLengths != expectedLength { - t.Errorf("TestParseCoordinates_Fail_SizeComparison failed, got %d expected %d", coordinates.Lengths, expectedLength) - } - if coordinates.Lengths[0] != expectedStart { - t.Errorf("TestParseCoordinates_Fail_SizeComparison failed, got %d expected %d", coordinates.Lengths, expectedLength) - } - if coordinates.Lengths[1] != expectedBytesToRead { - t.Errorf("TestParseCoordinates_Fail_SizeComparison failed, got %d expected %d", coordinates.Lengths, expectedLength) - } - -} - -func TestParseCoordinates_Success(t *testing.T) { - - // Test case - // endCoordinate must be greater than startCoordinate - r := httptest.NewRequest("GET", "https://testing.fi?startCoordinate=100&endCoordinate=50", nil) - - // Run test target - coordinates, err := parseCoordinates(r) - - // Expected results - expectedError := "endCoordinate must be greater than startCoordinate" - - if err.Error() != expectedError { - t.Errorf("TestParseCoordinates_Fail_SizeComparison failed, got %s expected %s", err.Error(), expectedError) - } - if coordinates != nil { - t.Errorf("TestParseCoordinates_Fail_SizeComparison failed, got %v expected nil", coordinates) - } - -} - func TestDownload_Fail_FileNotFound(t *testing.T) { // Save original to-be-mocked functions @@ -611,7 +512,6 @@ func TestDownload_Fail_ParseCoordinates(t *testing.T) { originalCheckFilePermission := database.CheckFilePermission originalGetCacheFromContext := middleware.GetCacheFromContext originalGetFile := database.GetFile - originalParseCoordinates := parseCoordinates config.Config.Archive.Posix.Location = "." Backend, _ = storage.NewBackend(config.Config.Archive) @@ -633,9 +533,6 @@ func TestDownload_Fail_ParseCoordinates(t *testing.T) { return fileDetails, nil } - parseCoordinates = func(r *http.Request) (*headers.DataEditListHeaderPacket, error) { - return nil, errors.New("bad params") - } // Mock request and response holders w := httptest.NewRecorder() @@ -663,7 +560,6 @@ func TestDownload_Fail_ParseCoordinates(t *testing.T) { database.CheckFilePermission = originalCheckFilePermission middleware.GetCacheFromContext = originalGetCacheFromContext database.GetFile = originalGetFile - parseCoordinates = originalParseCoordinates } @@ -673,8 +569,6 @@ func TestDownload_Fail_StreamFile(t *testing.T) { originalCheckFilePermission := database.CheckFilePermission originalGetCacheFromContext := middleware.GetCacheFromContext originalGetFile := database.GetFile - originalParseCoordinates := parseCoordinates - originalStitchFile := stitchFile config.Config.Archive.Posix.Location = "." Backend, _ = storage.NewBackend(config.Config.Archive) @@ -697,12 +591,6 @@ func TestDownload_Fail_StreamFile(t *testing.T) { return fileDetails, nil } - parseCoordinates = func(r *http.Request) (*headers.DataEditListHeaderPacket, error) { - return nil, nil - } - stitchFile = func(header []byte, file io.ReadCloser, coordinates *headers.DataEditListHeaderPacket) (*streaming.Crypt4GHReader, error) { - return nil, errors.New("file stream error") - } // Mock request and response holders w := httptest.NewRecorder() @@ -731,8 +619,6 @@ func TestDownload_Fail_StreamFile(t *testing.T) { database.CheckFilePermission = originalCheckFilePermission middleware.GetCacheFromContext = originalGetCacheFromContext database.GetFile = originalGetFile - parseCoordinates = originalParseCoordinates - stitchFile = originalStitchFile } @@ -742,8 +628,6 @@ func TestDownload_Success(t *testing.T) { originalCheckFilePermission := database.CheckFilePermission originalGetCacheFromContext := middleware.GetCacheFromContext originalGetFile := database.GetFile - originalParseCoordinates := parseCoordinates - originalStitchFile := stitchFile originalSendStream := sendStream config.Config.Archive.Posix.Location = "." Backend, _ = storage.NewBackend(config.Config.Archive) @@ -767,15 +651,9 @@ func TestDownload_Success(t *testing.T) { return fileDetails, nil } - parseCoordinates = func(r *http.Request) (*headers.DataEditListHeaderPacket, error) { - return nil, nil - } - stitchFile = func(header []byte, file io.ReadCloser, coordinates *headers.DataEditListHeaderPacket) (*streaming.Crypt4GHReader, error) { - return nil, nil - } - sendStream = func(w http.ResponseWriter, file io.Reader) { + sendStream = func(ctx *gin.Context, file *streaming.Crypt4GHReader) { fileReader := bytes.NewReader([]byte("hello\n")) - _, _ = io.Copy(w, fileReader) + _, _ = io.Copy(ctx.Writer, fileReader) } // Mock request and response holders @@ -805,114 +683,6 @@ func TestDownload_Success(t *testing.T) { database.CheckFilePermission = originalCheckFilePermission middleware.GetCacheFromContext = originalGetCacheFromContext database.GetFile = originalGetFile - parseCoordinates = originalParseCoordinates - stitchFile = originalStitchFile sendStream = originalSendStream } - -func TestSendStream(t *testing.T) { - // Mock file - file := []byte("hello\n") - fileReader := bytes.NewReader(file) - - // Mock stream response - w := httptest.NewRecorder() - w.Header().Add("Content-Length", "5") - - // Send file to streamer - sendStream(w, fileReader) - response := w.Result() - defer response.Body.Close() - body, _ := io.ReadAll(response.Body) - expectedContentLen := "5" - expectedBody := []byte("hello\n") - - // Verify that stream received contents - if contentLen := response.Header.Get("Content-Length"); contentLen != expectedContentLen { - t.Errorf("TestSendStream failed, got %s, expected %s", contentLen, expectedContentLen) - } - if !bytes.Equal(body, []byte(expectedBody)) { - t.Errorf("TestSendStream failed, got %s, expected %s", string(body), string(expectedBody)) - } -} - -func TestStitchFile_Fail(t *testing.T) { - - // Set test decryption key - config.Config.App.Crypt4GHKey = &[32]byte{} - - // Test header - header := []byte("header") - - // Test file body - testFile, err := os.CreateTemp("/tmp", "_sda_download_test_file") - if err != nil { - t.Errorf("TestStitchFile_Fail failed to create temp file, %v", err) - } - defer os.Remove(testFile.Name()) - defer testFile.Close() - const data = "hello, here is some test data\n" - _, _ = io.WriteString(testFile, data) - - // Test - fileStream, err := stitchFile(header, testFile, nil) - - // Expected results - expectedError := "not a Crypt4GH file" - - if err.Error() != expectedError { - t.Errorf("TestStitchFile_Fail failed, got %s expected %s", err.Error(), expectedError) - } - if fileStream != nil { - t.Errorf("TestStitchFile_Fail failed, got %v expected nil", fileStream) - } - -} - -func TestStitchFile_Success(t *testing.T) { - - // Set test decryption key - config.Config.App.Crypt4GHKey = &[32]byte{104, 35, 143, 159, 198, 120, 0, 145, 227, 124, 101, 127, 223, - 22, 252, 57, 224, 114, 205, 70, 150, 10, 28, 79, 192, 242, 151, 202, 44, 51, 36, 97} - - // Test header - header := []byte{99, 114, 121, 112, 116, 52, 103, 104, 1, 0, 0, 0, 1, 0, 0, 0, 108, 0, 0, 0, 0, 0, 0, 0, - 44, 219, 36, 17, 144, 78, 250, 192, 85, 103, 229, 122, 90, 11, 223, 131, 246, 165, 142, 191, 83, 97, - 206, 225, 206, 114, 10, 235, 239, 160, 206, 82, 55, 101, 76, 39, 217, 91, 249, 206, 122, 241, 69, 142, - 155, 97, 24, 47, 112, 45, 165, 197, 159, 60, 92, 214, 160, 112, 21, 129, 73, 31, 159, 54, 210, 4, 44, - 147, 108, 119, 178, 95, 194, 195, 11, 249, 60, 53, 133, 77, 93, 62, 31, 218, 29, 65, 143, 123, 208, 234, - 249, 34, 58, 163, 32, 149, 156, 110, 68, 49} - - // Test file body - testFile, err := os.CreateTemp("/tmp", "_sda_download_test_file") - if err != nil { - t.Errorf("TestStitchFile_Fail failed to create temp file, %v", err) - } - defer os.Remove(testFile.Name()) - defer testFile.Close() - testData := []byte{237, 0, 67, 9, 203, 239, 12, 187, 86, 6, 195, 174, 56, 234, 44, 78, 140, 2, 195, 5, 252, - 199, 244, 189, 150, 209, 144, 197, 61, 72, 73, 155, 205, 210, 206, 160, 226, 116, 242, 134, 63, 224, 178, - 153, 13, 181, 78, 210, 151, 219, 156, 18, 210, 70, 194, 76, 152, 178} - _, _ = testFile.Write(testData) - - // Test - // The decryption passes, but for some reason the temp test file doesn't return any data, so we can just check for error here - _, err = stitchFile(header, testFile, nil) - // fileStream, err := stitchFile(header, testFile, nil) - // data, err := io.ReadAll(fileStream) - - // Expected results - // expectedData := "hello, here is some test data" - - if err != nil { - t.Errorf("TestStitchFile_Success failed, got %v expected nil", err) - } - // if !bytes.Equal(data, []byte(expectedData)) { - // // visual byte comparison in terminal (easier to find string differences) - // t.Error(data) - // t.Error([]byte(expectedData)) - // t.Errorf("TestStitchFile_Success failed, got %s expected %s", string(data), string(expectedData)) - // } - -}