Skip to content

Commit

Permalink
calculate byte lengths with a HEAD request
Browse files Browse the repository at this point in the history
  • Loading branch information
cosmicexplorer committed Aug 19, 2024
1 parent e0dab39 commit 90703cc
Showing 1 changed file with 37 additions and 8 deletions.
45 changes: 37 additions & 8 deletions src/pip/_internal/network/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pathlib import Path
from queue import Queue
from threading import Event, Semaphore, Thread
from typing import Iterable, Optional, Tuple, Union
from typing import Iterable, List, Optional, Tuple, Union

from pip._vendor.requests.models import Response

Expand Down Expand Up @@ -127,6 +127,15 @@ def _http_get_download(session: PipSession, link: Link) -> Response:
return resp


def _http_content_length(session: PipSession, link: Link) -> Optional[int]:
    """Return the byte length the server advertises for ``link``, if any.

    Issues a HEAD request for the link's URL (with any ``#fragment`` removed)
    and parses the ``Content-Length`` response header.

    :param session: Session used to perform the HEAD request.
    :param link: Link whose target size should be looked up.
    :return: The advertised length in bytes, or ``None`` when the header is
        missing, empty, or not a valid non-negative integer.

    Any exception raised by ``raise_for_status(resp)`` is propagated.
    """
    target_url = link.url.split("#", 1)[0]
    # Unlike get(), requests' head() does not follow redirects by default.
    # Follow them so the length describes the file rather than a 3xx reply.
    resp = session.head(target_url, allow_redirects=True)
    raise_for_status(resp)
    content_length = resp.headers.get("content-length")
    if not content_length:
        return None
    try:
        length = int(content_length)
    except ValueError:
        # The length is only informational; a malformed header should not
        # abort the download, so treat it as "unknown".
        return None
    return length if length >= 0 else None


class Downloader:
def __init__(
self,
Expand Down Expand Up @@ -164,8 +173,9 @@ def _copy_chunks(
semaphore: Semaphore,
session: PipSession,
location: Path,
link: Link,
download_info: Tuple[Link, Optional[int]],
) -> None:
link, total_length = download_info
with semaphore:
try:
try:
Expand All @@ -188,11 +198,19 @@ def _copy_chunks(

with filepath.open("wb") as output_file:
chunk_index = 0
current_bytes = 0
for chunk in chunks:
# Check if another thread exited with an exception between chunks.
if event.is_set():
return
logger.debug("reading chunk %d for link %s", chunk_index, link)
current_bytes += len(chunk)
logger.debug(
"reading chunk %d for file %s [%d/%s bytes]",
chunk_index,
filename,
current_bytes,
str(total_length or 0),
)
chunk_index += 1
# Copy chunk directly to output file, without any
# additional buffering.
Expand All @@ -214,13 +232,24 @@ def __init__(
logger.info("Ignoring progress bar %s for parallel downloads", progress_bar)

def __call__(
self, input_links: Iterable[Link], location: Path
self, links: Iterable[Link], location: Path
) -> Iterable[Tuple[Link, Tuple[Path, Optional[str]]]]:
"""Download the files given by links into location."""
links = list(input_links)
# Calculate the byte length for each file, if available.
links_with_lengths: List[Tuple[Link, Optional[int]]] = [
(link, _http_content_length(self._session, link)) for link in links
]
# Sum up the total length we'll be downloading.
total_length: Optional[int] = 0
for _link, maybe_len in links_with_lengths:
if maybe_len is None:
total_length = None
break
assert total_length is not None
total_length += maybe_len

# Set up state to track thread progress, including inner exceptions.
total_downloads: int = len(links)
total_downloads: int = len(links_with_lengths)
completed_downloads: int = 0
q: "Queue[Union[Tuple[Link, Path, Optional[str]], BaseException]]" = Queue()
event = Event()
Expand All @@ -239,9 +268,9 @@ def __call__(
workers = [
Thread(
target=_copy_chunks,
args=(q, event, semaphore, self._session, location, link),
args=(q, event, semaphore, self._session, location, download_info),
)
for link in links
for download_info in links_with_lengths
]
for w in workers:
w.start()
Expand Down

0 comments on commit 90703cc

Please sign in to comment.