diff --git a/client.go b/client.go index 5ab9605d..3c749df1 100644 --- a/client.go +++ b/client.go @@ -8,6 +8,7 @@ import ( "io" "log" "net/http" + "strconv" ) // Client offers methods to download video metadata and video streams. @@ -235,13 +236,45 @@ func (c *Client) GetStreamContext(ctx context.Context, video *Video, format *For } r, w := io.Pipe() + contentLength := format.ContentLength - go c.download(req, w, format) + if contentLength == 0 { + // some videos don't have length information + contentLength = c.downloadOnce(req, w, format) + } else { + // we have length information, let's download by chunks! + go c.downloadChunked(req, w, format) + } + + return r, contentLength, nil +} + +func (c *Client) downloadOnce(req *http.Request, w *io.PipeWriter, format *Format) int64 { + resp, err := c.httpDo(req) + if err != nil { + //nolint:errcheck + w.CloseWithError(err) + return 0 + } + + go func() { + defer resp.Body.Close() + _, err := io.Copy(w, resp.Body) + if err == nil { + w.Close() + } else { + //nolint:errcheck + w.CloseWithError(err) + } + }() - return r, format.ContentLength, nil + contentLength := resp.Header.Get("Content-Length") + len, _ := strconv.ParseInt(contentLength, 10, 64) + + return len } -func (c *Client) download(req *http.Request, w *io.PipeWriter, format *Format) { +func (c *Client) downloadChunked(req *http.Request, w *io.PipeWriter, format *Format) { const chunkSize int64 = 10_000_000 // Loads a chunk a returns the written bytes. // Downloading in multiple chunks is much faster: @@ -263,19 +296,6 @@ func (c *Client) download(req *http.Request, w *io.PipeWriter, format *Format) { } defer w.Close() - //nolint:revive,errcheck - if format.ContentLength == 0 { - resp, err := c.httpDo(req) - if err != nil { - w.CloseWithError(err) - return - } - - defer resp.Body.Close() - - io.Copy(w, resp.Body) - return - } //nolint:revive,errcheck // load all the chunks diff --git a/client_test.go b/client_test.go index 9ec18b2a..ee363d92 100644 --- a/client_test.go +++ b/client_test.go @@ -107,11 +107,19 @@ func TestGetVideoWithManifestURL(t *testing.T) { require.NoError(err) require.NotNil(video) + assert.NotEmpty(video.Formats) assert.NotEmpty(video.Thumbnails) assert.Greater(len(video.Thumbnails), 0) assert.NotEmpty(video.Thumbnails[0].URL) assert.NotEmpty(video.HLSManifestURL) assert.NotEmpty(video.DASHManifestURL) + + format := video.Formats[0] + assert.Zero(format.ContentLength) + + _, size, err := testClient.GetStream(video, &format) + require.NoError(err) + require.NotZero(size) } func TestGetStream(t *testing.T) {