From 549781b4b70251246d7275bdfb4c78d81a10e6a5 Mon Sep 17 00:00:00 2001 From: Danny McClanahan <1305167+cosmicexplorer@users.noreply.github.com> Date: Mon, 19 Aug 2024 11:10:39 -0400 Subject: [PATCH 1/3] improve pooled progress output for BatchDownloader - use more specific types for BatchDownloader#__call__ - calculate byte lengths with a HEAD request - quiet all progress output from -q - don't write colored output with --no-color - write a lot more documentation for the new progress bar logic - use ProgressBarType enum for --progress-bar CLI flag --- news/12925.feature.rst | 1 + src/pip/_internal/cli/cmdoptions.py | 11 +- src/pip/_internal/cli/progress_bars.py | 530 ++++++++++++++++++++++-- src/pip/_internal/cli/req_command.py | 5 +- src/pip/_internal/network/download.py | 180 ++++++-- src/pip/_internal/network/utils.py | 1 + src/pip/_internal/operations/prepare.py | 35 +- tests/unit/test_network_download.py | 3 +- tests/unit/test_operations_prepare.py | 5 +- tests/unit/test_req.py | 5 +- 10 files changed, 676 insertions(+), 100 deletions(-) create mode 100644 news/12925.feature.rst diff --git a/news/12925.feature.rst b/news/12925.feature.rst new file mode 100644 index 00000000000..8255f66d01b --- /dev/null +++ b/news/12925.feature.rst @@ -0,0 +1 @@ +Use very rich progress output for batch downloading. Use ``ProgressBarType`` enum class for ``--progress-bar`` choices. diff --git a/src/pip/_internal/cli/cmdoptions.py b/src/pip/_internal/cli/cmdoptions.py index 0b7cff77bdd..3bb7cbf4246 100644 --- a/src/pip/_internal/cli/cmdoptions.py +++ b/src/pip/_internal/cli/cmdoptions.py @@ -22,6 +22,7 @@ from pip._vendor.packaging.utils import canonicalize_name from pip._internal.cli.parser import ConfigOptionParser +from pip._internal.cli.progress_bars import ProgressBarType from pip._internal.exceptions import CommandError from pip._internal.locations import USER_CACHE_DIR, get_src_prefix from pip._internal.models.format_control import FormatControl @@ -226,11 +227,15 @@ class PipOption(Option): "--progress-bar", dest="progress_bar", type="choice", - choices=["on", "off", "raw"], - default="on", - help="Specify whether the progress bar should be used [on, off, raw] (default: on)", + choices=ProgressBarType.choices(), + default=ProgressBarType.ON.value, + help=( + "Specify whether the progress bar should be used" + f" {ProgressBarType.help_choices()} (default: %default)" + ), ) + log: Callable[..., Option] = partial( PipOption, "--log", diff --git a/src/pip/_internal/cli/progress_bars.py b/src/pip/_internal/cli/progress_bars.py index 883359c9ce7..1e9aa3c7ac7 100644 --- a/src/pip/_internal/cli/progress_bars.py +++ b/src/pip/_internal/cli/progress_bars.py @@ -1,19 +1,37 @@ +import abc import functools import sys -from typing import Callable, Generator, Iterable, Iterator, Optional, Tuple +from enum import Enum +from typing import ( + Any, + Callable, + Iterable, + Iterator, + List, + Optional, + Tuple, + Type, + TypeVar, +) +from pip._vendor.rich.console import Console +from pip._vendor.rich.live import Live +from pip._vendor.rich.panel import Panel from pip._vendor.rich.progress import ( BarColumn, DownloadColumn, FileSizeColumn, + MofNCompleteColumn, Progress, ProgressColumn, SpinnerColumn, + TaskID, TextColumn, TimeElapsedColumn, TimeRemainingColumn, TransferSpeedColumn, ) +from pip._vendor.rich.table import Table from pip._internal.cli.spinners import RateLimiter from pip._internal.utils.logging import get_indentation @@ -21,36 +39,79 @@ DownloadProgressRenderer = Callable[[Iterable[bytes]], Iterator[bytes]] +def _unknown_size_columns() -> Tuple[ProgressColumn, ...]: + """Rich progress with a spinner for completion of a download of unknown size. + + This is employed for downloads where the server does not return a 'Content-Length' + header, which currently cannot be inferred from e.g. wheel metadata.""" + return ( + TextColumn("[progress.description]{task.description}"), + SpinnerColumn("line", speed=1.5), + FileSizeColumn(), + TransferSpeedColumn(), + TimeElapsedColumn(), + ) + + +def _known_size_columns() -> Tuple[ProgressColumn, ...]: + """Rich progress for %completion of a download task in terms of bytes, with ETA.""" + return ( + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), + DownloadColumn(), + TransferSpeedColumn(), + TextColumn("eta"), + TimeRemainingColumn(), + ) + + +def _task_columns() -> Tuple[ProgressColumn, ...]: + """Rich progress for %complete out of a fixed positive number of known tasks.""" + return ( + TextColumn("[progress.description]{task.description}"), + SpinnerColumn("line", speed=1.5), + BarColumn(), + TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), + MofNCompleteColumn(), + ) + + +def _progress_task_prefix() -> str: + """For output that doesn't take up the whole terminal, make it align with current + logger indentation.""" + return " " * (get_indentation() + 2) + + def _rich_progress_bar( iterable: Iterable[bytes], *, - bar_type: str, - size: int, -) -> Generator[bytes, None, None]: - assert bar_type == "on", "This should only be used in the default mode." + size: Optional[int], + quiet: bool, + color: bool, +) -> Iterator[bytes]: + """Deploy a single rich progress bar to wrap a single download task. - if not size: + This provides a single line of updating output, prefixed with the appropriate + indentation. ETA and %completion are provided if ``size`` is known; otherwise, + a spinner with size, transfer speed, and time elapsed are provided.""" + if size is None: total = float("inf") - columns: Tuple[ProgressColumn, ...] = ( - TextColumn("[progress.description]{task.description}"), - SpinnerColumn("line", speed=1.5), - FileSizeColumn(), - TransferSpeedColumn(), - TimeElapsedColumn(), - ) + columns = _unknown_size_columns() else: total = size - columns = ( - TextColumn("[progress.description]{task.description}"), - BarColumn(), - DownloadColumn(), - TransferSpeedColumn(), - TextColumn("eta"), - TimeRemainingColumn(), - ) + columns = _known_size_columns() - progress = Progress(*columns, refresh_per_second=5) - task_id = progress.add_task(" " * (get_indentation() + 2), total=total) + progress = Progress( + *columns, + # TODO: consider writing to stderr over stdout? + console=Console(stderr=False, quiet=quiet, no_color=not color), + refresh_per_second=5, + ) + # This adds a task with no name, just enough indentation to align with log + # output. We rely upon the name of the download being printed beforehand on the + # previous line for context. + task_id = progress.add_task(_progress_task_prefix(), total=total) with progress: for chunk in iterable: yield chunk @@ -61,34 +122,431 @@ def _raw_progress_bar( iterable: Iterable[bytes], *, size: Optional[int], -) -> Generator[bytes, None, None]: - def write_progress(current: int, total: int) -> None: - sys.stdout.write("Progress %d of %d\n" % (current, total)) - sys.stdout.flush() + quiet: bool, +) -> Iterator[bytes]: + """Hand-write progress to stdout. + + Use subsequent lines for each chunk, with manual rate limiting. + """ + prefix = _progress_task_prefix() + total_fmt = "?" if size is None else str(size) + stream = sys.stdout + + def write_progress(current: int) -> None: + if quiet: + return + stream.write(f"{prefix}Progress {current} of {total_fmt} bytes\n") + stream.flush() current = 0 - total = size or 0 rate_limiter = RateLimiter(0.25) - write_progress(current, total) + write_progress(current) for chunk in iterable: current += len(chunk) - if rate_limiter.ready() or current == total: - write_progress(current, total) + if rate_limiter.ready() or current == size: + write_progress(current) rate_limiter.reset() yield chunk +class ProgressBarType(Enum): + """Types of progress output to show, for single or batched downloads. + + The values of this enum are used as the choices for the --progress-var CLI flag.""" + + ON = "on" + OFF = "off" + RAW = "raw" + + @classmethod + def choices(cls) -> List[str]: + return [x.value for x in cls] + + @classmethod + def help_choices(cls) -> str: + inner = ", ".join(cls.choices()) + return f"[{inner}]" + + def get_download_progress_renderer( - *, bar_type: str, size: Optional[int] = None + *, + bar_type: ProgressBarType, + size: Optional[int] = None, + quiet: bool = False, + color: bool = True, ) -> DownloadProgressRenderer: """Get an object that can be used to render the download progress. Returns a callable, that takes an iterable to "wrap". """ - if bar_type == "on": - return functools.partial(_rich_progress_bar, bar_type=bar_type, size=size) - elif bar_type == "raw": - return functools.partial(_raw_progress_bar, size=size) + if size is not None: + assert size >= 0 + + # TODO: use 3.10+ match statement! + if bar_type == ProgressBarType.ON: + return functools.partial( + _rich_progress_bar, size=size, quiet=quiet, color=color + ) + elif bar_type == ProgressBarType.RAW: + return functools.partial(_raw_progress_bar, size=size, quiet=quiet) else: + assert bar_type == ProgressBarType.OFF return iter # no-op, when passed an iterator + + +_ProgressClass = TypeVar("_ProgressClass", bound="BatchedProgress") + + +class BatchedProgress(abc.ABC): + """Interface for reporting progress output on batched download tasks. + + For batched downloads, we want to be able to express progress on several parallel + tasks at once. This means that instead of transforming an ``Iterator[bytes]`` like + ``DownloadProgressRenderer``, we instead want to receive asynchronous notifications + about progress over several separate tasks. These tasks may not start all at once, + and will end at different times. We assume progress over all of these tasks can be + uniformly summed up to get a measure of total progress. + """ + + @abc.abstractmethod + def add_subtask(self, description: str, total: Optional[int]) -> TaskID: + """Given a specific subtask description and known total length, add it to the + set of tracked tasks. + + This method is generally expected to be called before __enter__, but this is not + required.""" + ... + + @abc.abstractmethod + def start_subtask(self, task_id: TaskID) -> None: + """Given a subtask id returned by .add_subtask(), signal that the task + has begun. + + This information is used in progress reporting to calculate ETA. This method is + generally expected to be called after __enter__, but this is not required.""" + ... + + @abc.abstractmethod + def advance_subtask(self, task_id: TaskID, steps: int) -> None: + """Given a subtask id returned by .add_subtask(), progress the given number of + steps. + + Since tasks correspond to downloaded files, ``steps`` refers to the number of + bytes received. This is expected not to overflow the ``total`` number provided + to .add_subtask(), since the total is expected to be exact, but no error will + occur if it does.""" + ... + + @abc.abstractmethod + def finish_subtask(self, task_id: TaskID) -> None: + """Given a subtask id returned by .add_subtask(), indicate the task is complete. + + This is generally used to remove the task progress from the set of tracked + tasks, or to log that the task has completed. It does not need to be called in + the case of an exception.""" + ... + + @abc.abstractmethod + def __enter__(self) -> "BatchedProgress": + """Begin writing output to the terminal to track task progress. + + This may do nothing for no-op progress recorders, or it may write log messages, + or it may produce a rich output taking up the entire terminal.""" + ... + + @abc.abstractmethod + def __exit__(self, ty: Any, val: Any, tb: Any) -> None: + """Clean up any output written to the terminal. + + This is generally a no-op except for the rich progress recorder, which will give + back the terminal to the rest of pip.""" + ... + + @classmethod + @abc.abstractmethod + def create( + cls: Type[_ProgressClass], + num_tasks: int, + known_total_length: Optional[int], + quiet: bool, + color: bool, + ) -> _ProgressClass: + """Generate a progress recorder for a static number of known tasks. + + These tasks are intended to correspond to file downloads, so their "length" + corresponds to byte length. These tasks may not have their individual byte + lengths known, depending upon whether the server provides a 'Content-Length' + header. + + Progress recorders are expected to produce no output when ``quiet=True``, and + should not write colored output to the terminal when ``color=False``.""" + ... + + @classmethod + def select_progress_bar(cls, bar_type: ProgressBarType) -> "Type[BatchedProgress]": + """Factory method to produce a progress recorder according to CLI flag.""" + # TODO: use 3.10+ match statement! + if bar_type == ProgressBarType.ON: + return BatchedRichProgressBar + if bar_type == ProgressBarType.RAW: + return BatchedRawProgressBar + assert bar_type == ProgressBarType.OFF + return BatchedNoOpProgressBar + + +class BatchedNoOpProgressBar(BatchedProgress): + """Do absolutely nothing with the info.""" + + def add_subtask(self, description: str, total: Optional[int]) -> TaskID: + return TaskID(0) + + def start_subtask(self, task_id: TaskID) -> None: + pass + + def advance_subtask(self, task_id: TaskID, steps: int) -> None: + pass + + def finish_subtask(self, task_id: TaskID) -> None: + pass + + def __enter__(self) -> "BatchedNoOpProgressBar": + return self + + def __exit__(self, ty: Any, val: Any, tb: Any) -> None: + pass + + @classmethod + def create( + cls, + num_tasks: int, + known_total_length: Optional[int], + quiet: bool, + color: bool, + ) -> "BatchedNoOpProgressBar": + return cls() + + +class BatchedRawProgressBar(BatchedProgress): + """Manually write progress output to stdout. + + This will notify when subtasks have started, when they've completed, and how much + progress was made in the overall byte download (the sum of all bytes downloaded as + a fraction of the known total bytes, if provided).""" + + def __init__( + self, + total_bytes: Optional[int], + prefix: str, + quiet: bool, + ) -> None: + self._total_bytes = total_bytes + self._prefix = prefix + self._total_progress = 0 + self._subtasks: List[Tuple[str, Optional[int]]] = [] + self._rate_limiter = RateLimiter(0.25) + self._stream = sys.stdout + self._quiet = quiet + + def add_subtask(self, description: str, total: Optional[int]) -> TaskID: + task_id = len(self._subtasks) + self._subtasks.append((description, total)) + return TaskID(task_id) + + def _write_immediate(self, line: str) -> None: + if self._quiet: + return + self._stream.write(f"{self._prefix}{line}\n") + self._stream.flush() + + @staticmethod + def _format_total(total: Optional[int]) -> str: + if total is None: + return "?" + return str(total) + + def _total_tasks(self) -> int: + return len(self._subtasks) + + def start_subtask(self, task_id: TaskID) -> None: + description, total = self._subtasks[task_id] + total_fmt = self._format_total(total) + task_index = task_id + 1 + n = self._total_tasks() + self._write_immediate( + f"Starting download [{task_index}/{n}] {description} ({total_fmt} bytes)" + ) + + def _write_progress(self) -> None: + total_fmt = self._format_total(self._total_bytes) + if self._total_bytes is not None: + raw_pcnt = float(self._total_progress) / float(self._total_bytes) * 100 + pcnt = str(round(raw_pcnt, 1)) + else: + pcnt = "?" + self._write_immediate( + f"Progress {pcnt}% {self._total_progress} of {total_fmt} bytes" + ) + + def advance_subtask(self, task_id: TaskID, steps: int) -> None: + self._total_progress += steps + if self._rate_limiter.ready() or self._total_progress == self._total_bytes: + self._write_progress() + self._rate_limiter.reset() + + def finish_subtask(self, task_id: TaskID) -> None: + description, _total = self._subtasks[task_id] + task_index = task_id + 1 + n = self._total_tasks() + self._write_immediate(f"Completed download [{task_index}/{n}] {description}") + + def __enter__(self) -> "BatchedRawProgressBar": + self._write_progress() + return self + + def __exit__(self, ty: Any, val: Any, tb: Any) -> None: + pass + + @classmethod + def create( + cls, + num_tasks: int, + known_total_length: Optional[int], + quiet: bool, + color: bool, + ) -> "BatchedRawProgressBar": + prefix = _progress_task_prefix() + return cls(known_total_length, prefix, quiet=quiet) + + +class BatchedRichProgressBar(BatchedProgress): + """Extremely rich progress output for download tasks. + + Provides overall byte progress as well as a separate progress for # of tasks + completed, with individual lines for each subtask. Subtasks are removed from the + table upon completion. ETA and %completion is generated for all subtasks as well as + the overall byte download task.""" + + def __init__( + self, + task_progress: Progress, + total_task_id: TaskID, + progress: Progress, + total_bytes_task_id: TaskID, + quiet: bool, + color: bool, + ) -> None: + self._task_progress = task_progress + self._total_task_id = total_task_id + self._progress = progress + self._total_bytes_task_id = total_bytes_task_id + self._quiet = quiet + self._color = color + self._live: Optional[Live] = None + + _TRIM_LEN = 20 + + def add_subtask(self, description: str, total: Optional[int]) -> TaskID: + if len(description) > self._TRIM_LEN: + description_trimmed = description[: self._TRIM_LEN] + "..." + else: + description_trimmed = description + return self._progress.add_task( + description=f"[green]{description_trimmed}", + start=False, + total=total, + ) + + def start_subtask(self, task_id: TaskID) -> None: + self._progress.start_task(task_id) + + def advance_subtask(self, task_id: TaskID, steps: int) -> None: + self._progress.advance(self._total_bytes_task_id, steps) + self._progress.advance(task_id, steps) + + def finish_subtask(self, task_id: TaskID) -> None: + self._task_progress.advance(self._total_task_id) + self._progress.remove_task(task_id) + + def __enter__(self) -> "BatchedRichProgressBar": + """Generate a table with two rows so different columns can be used. + + Overall progress in terms of # tasks completed is shown at top, while a box of + all individual tasks is provided below. Tasks are removed from the table (making + it shorter) when completed, and are shown with indeterminate ETA before they are + started.""" + table = Table.grid() + table.add_row( + Panel( + self._task_progress, + title="Download Progress", + border_style="cyan", + padding=(0, 1), + ) + ) + table.add_row( + Panel( + self._progress, + title="[b]Individual Request Progress", + border_style="green", + padding=(0, 0), + ) + ) + self._live = Live( + table, + # TODO: consider writing to stderr over stdout? + console=Console(stderr=False, quiet=self._quiet, no_color=not self._color), + refresh_per_second=5, + ) + self._task_progress.start_task(self._total_task_id) + self._progress.start_task(self._total_bytes_task_id) + self._live.__enter__() + return self + + def __exit__(self, ty: Any, val: Any, tb: Any) -> None: + assert self._live is not None + self._live.__exit__(ty, val, tb) + + @classmethod + def create( + cls, + num_tasks: int, + known_total_length: Optional[int], + quiet: bool, + color: bool, + ) -> "BatchedRichProgressBar": + # This progress indicator is for completion of download subtasks, separate from + # counting overall progress by summing chunk byte lengths. + task_columns = _task_columns() + task_progress = Progress(*task_columns) + # Create the single task in this progress indicator, tracking # of + # completed tasks. + total_task_id = task_progress.add_task( + description="[yellow]total downloads", + start=False, + total=num_tasks, + ) + + # This progress indicator is for individual byte downloads. + if known_total_length is None: + total = float("inf") + columns = _unknown_size_columns() + else: + total = known_total_length + columns = _known_size_columns() + progress = Progress(*columns) + # Create a task for total progress in byte downloads. + total_bytes_task_id = progress.add_task( + description="[cyan]total bytes", + start=False, + total=total, + ) + + return cls( + task_progress, + total_task_id, + progress, + total_bytes_task_id, + quiet=quiet, + color=color, + ) diff --git a/src/pip/_internal/cli/req_command.py b/src/pip/_internal/cli/req_command.py index 92900f94ff4..521cd7c917c 100644 --- a/src/pip/_internal/cli/req_command.py +++ b/src/pip/_internal/cli/req_command.py @@ -14,6 +14,7 @@ from pip._internal.cli import cmdoptions from pip._internal.cli.index_command import IndexGroupCommand from pip._internal.cli.index_command import SessionCommandMixin as SessionCommandMixin +from pip._internal.cli.progress_bars import ProgressBarType from pip._internal.exceptions import CommandError, PreviousBuildDirError from pip._internal.index.collector import LinkCollector from pip._internal.index.package_finder import PackageFinder @@ -135,12 +136,14 @@ def make_requirement_preparer( check_build_deps=options.check_build_deps, build_tracker=build_tracker, session=session, - progress_bar=options.progress_bar, + progress_bar=ProgressBarType(options.progress_bar), finder=finder, require_hashes=options.require_hashes, use_user_site=use_user_site, lazy_wheel=lazy_wheel, verbosity=verbosity, + quietness=options.quiet, + color=not options.no_color, legacy_resolver=legacy_resolver, ) diff --git a/src/pip/_internal/network/download.py b/src/pip/_internal/network/download.py index 5c3bce3d2fd..ebae94fc7f1 100644 --- a/src/pip/_internal/network/download.py +++ b/src/pip/_internal/network/download.py @@ -5,11 +5,17 @@ import logging import mimetypes import os -from typing import Iterable, Optional, Tuple +from pathlib import Path +from typing import Iterable, List, Mapping, Optional, Tuple from pip._vendor.requests.models import Response +from pip._vendor.rich.progress import TaskID -from pip._internal.cli.progress_bars import get_download_progress_renderer +from pip._internal.cli.progress_bars import ( + BatchedProgress, + ProgressBarType, + get_download_progress_renderer, +) from pip._internal.exceptions import NetworkConnectionError from pip._internal.models.index import PyPI from pip._internal.models.link import Link @@ -28,28 +34,42 @@ def _get_http_response_size(resp: Response) -> Optional[int]: return None -def _prepare_download( - resp: Response, - link: Link, - progress_bar: str, -) -> Iterable[bytes]: - total_length = _get_http_response_size(resp) - +def _format_download_log_url(link: Link) -> str: if link.netloc == PyPI.file_storage_domain: url = link.show_url else: url = link.url_without_fragment - logged_url = redact_auth_from_url(url) + return redact_auth_from_url(url) + + +def _log_download_link( + link: Link, + total_length: Optional[int], + link_is_from_cache: bool = False, +) -> None: + logged_url = _format_download_log_url(link) if total_length: logged_url = f"{logged_url} ({format_size(total_length)})" - if is_from_cache(resp): + if link_is_from_cache: logger.info("Using cached %s", logged_url) else: logger.info("Downloading %s", logged_url) + +def _prepare_download( + resp: Response, + link: Link, + progress_bar: ProgressBarType, + quiet: bool = False, + color: bool = True, +) -> Iterable[bytes]: + total_length = _get_http_response_size(resp) + + _log_download_link(link, total_length, is_from_cache(resp)) + if logger.getEffectiveLevel() > logging.INFO: show_progress = False elif is_from_cache(resp): @@ -66,7 +86,12 @@ def _prepare_download( if not show_progress: return chunks - renderer = get_download_progress_renderer(bar_type=progress_bar, size=total_length) + renderer = get_download_progress_renderer( + bar_type=progress_bar, + size=total_length, + quiet=quiet, + color=color, + ) return renderer(chunks) @@ -92,22 +117,24 @@ def parse_content_disposition(content_disposition: str, default_filename: str) - return filename or default_filename -def _get_http_response_filename(resp: Response, link: Link) -> str: +def _get_http_response_filename( + headers: Mapping[str, str], resp_url: str, link: Link +) -> str: """Get an ideal filename from the given HTTP response, falling back to the link filename if not provided. """ filename = link.filename # fallback # Have a look at the Content-Disposition header for a better guess - content_disposition = resp.headers.get("content-disposition") + content_disposition = headers.get("content-disposition", None) if content_disposition: filename = parse_content_disposition(content_disposition, filename) ext: Optional[str] = splitext(filename)[1] if not ext: - ext = mimetypes.guess_extension(resp.headers.get("content-type", "")) + ext = mimetypes.guess_extension(headers.get("content-type", "")) if ext: filename += ext - if not ext and link.url != resp.url: - ext = os.path.splitext(resp.url)[1] + if not ext and link.url != resp_url: + ext = os.path.splitext(resp_url)[1] if ext: filename += ext return filename @@ -120,14 +147,35 @@ def _http_get_download(session: PipSession, link: Link) -> Response: return resp +def _http_head_content_info( + session: PipSession, + link: Link, +) -> Tuple[Optional[int], str]: + target_url = link.url.split("#", 1)[0] + resp = session.head(target_url) + raise_for_status(resp) + + if length := resp.headers.get("content-length", None): + content_length = int(length) + else: + content_length = None + + filename = _get_http_response_filename(resp.headers, resp.url, link) + return content_length, filename + + class Downloader: def __init__( self, session: PipSession, - progress_bar: str, + progress_bar: ProgressBarType, + quiet: bool = False, + color: bool = True, ) -> None: self._session = session self._progress_bar = progress_bar + self._quiet = quiet + self._color = color def __call__(self, link: Link, location: str) -> Tuple[str, str]: """Download the file given by link into location.""" @@ -140,10 +188,12 @@ def __call__(self, link: Link, location: str) -> Tuple[str, str]: ) raise - filename = _get_http_response_filename(resp, link) + filename = _get_http_response_filename(resp.headers, resp.url, link) filepath = os.path.join(location, filename) - chunks = _prepare_download(resp, link, self._progress_bar) + chunks = _prepare_download( + resp, link, self._progress_bar, quiet=self._quiet, color=self._color + ) with open(filepath, "wb") as content_file: for chunk in chunks: content_file.write(chunk) @@ -155,33 +205,75 @@ class BatchDownloader: def __init__( self, session: PipSession, - progress_bar: str, + progress_bar: ProgressBarType, + quiet: bool = False, + color: bool = True, ) -> None: self._session = session self._progress_bar = progress_bar + self._quiet = quiet + self._color = color def __call__( - self, links: Iterable[Link], location: str - ) -> Iterable[Tuple[Link, Tuple[str, str]]]: + self, links: Iterable[Link], location: Path + ) -> Iterable[Tuple[Link, Tuple[Path, Optional[str]]]]: """Download the files given by links into location.""" - for link in links: - try: - resp = _http_get_download(self._session, link) - except NetworkConnectionError as e: - assert e.response is not None - logger.critical( - "HTTP error %s while getting %s", - e.response.status_code, - link, - ) - raise - - filename = _get_http_response_filename(resp, link) - filepath = os.path.join(location, filename) - - chunks = _prepare_download(resp, link, self._progress_bar) - with open(filepath, "wb") as content_file: - for chunk in chunks: - content_file.write(chunk) - content_type = resp.headers.get("Content-Type", "") - yield link, (filepath, content_type) + # Calculate the byte length for each file, if available. + links_with_lengths: List[Tuple[Link, Tuple[Optional[int], str]]] = [ + (link, _http_head_content_info(self._session, link)) for link in links + ] + # Sum up the total length we'll be downloading. + # TODO: filter out responses from cache from total download size? + total_length: Optional[int] = 0 + for _link, (maybe_len, _filename) in links_with_lengths: + if maybe_len is None: + total_length = None + break + assert total_length is not None + total_length += maybe_len + + batched_progress = BatchedProgress.select_progress_bar( + self._progress_bar + ).create( + num_tasks=len(links_with_lengths), + known_total_length=total_length, + quiet=self._quiet, + color=self._color, + ) + + link_tasks: List[Tuple[Link, TaskID, str]] = [] + for link, (maybe_len, filename) in links_with_lengths: + _log_download_link(link, maybe_len) + task_id = batched_progress.add_subtask(filename, maybe_len) + link_tasks.append((link, task_id, filename)) + + with batched_progress: + for link, task_id, filename in link_tasks: + try: + resp = _http_get_download(self._session, link) + except NetworkConnectionError as e: + assert e.response is not None + logger.critical( + "HTTP error %s while getting %s", + e.response.status_code, + link, + ) + raise + + filepath = location / filename + content_type = resp.headers.get("Content-Type") + # TODO: different chunk size for batched downloads? + chunks = response_chunks(resp) + with open(filepath, "wb") as content_file: + # Notify that the current task has begun. + batched_progress.start_subtask(task_id) + for chunk in chunks: + # Copy chunk directly to output file, without any + # additional buffering. + content_file.write(chunk) + # Update progress. + batched_progress.advance_subtask(task_id, len(chunk)) + # Notify of completion. + batched_progress.finish_subtask(task_id) + # Yield completed link and download path. + yield link, (filepath, content_type) diff --git a/src/pip/_internal/network/utils.py b/src/pip/_internal/network/utils.py index bba4c265e89..6127715f5a5 100644 --- a/src/pip/_internal/network/utils.py +++ b/src/pip/_internal/network/utils.py @@ -56,6 +56,7 @@ def raise_for_status(resp: Response) -> None: raise NetworkConnectionError(http_error_msg, response=resp) +# TODO: consider reading into a bytearray? def response_chunks( response: Response, chunk_size: int = DOWNLOAD_CHUNK_SIZE ) -> Generator[bytes, None, None]: diff --git a/src/pip/_internal/operations/prepare.py b/src/pip/_internal/operations/prepare.py index e6aa3447200..cd716ea8903 100644 --- a/src/pip/_internal/operations/prepare.py +++ b/src/pip/_internal/operations/prepare.py @@ -13,6 +13,7 @@ from pip._vendor.packaging.utils import canonicalize_name +from pip._internal.cli.progress_bars import ProgressBarType from pip._internal.distributions import make_distribution_for_install_requirement from pip._internal.distributions.installed import InstalledDistribution from pip._internal.exceptions import ( @@ -215,7 +216,7 @@ def _check_download_dir( class RequirementPreparer: """Prepares a Requirement""" - def __init__( + def __init__( # noqa: PLR0913 self, build_dir: str, download_dir: Optional[str], @@ -224,12 +225,14 @@ def __init__( check_build_deps: bool, build_tracker: BuildTracker, session: PipSession, - progress_bar: str, + progress_bar: ProgressBarType, finder: PackageFinder, require_hashes: bool, use_user_site: bool, lazy_wheel: bool, verbosity: int, + quietness: int, + color: bool, legacy_resolver: bool, ) -> None: super().__init__() @@ -238,8 +241,15 @@ def __init__( self.build_dir = build_dir self.build_tracker = build_tracker self._session = session - self._download = Downloader(session, progress_bar) - self._batch_download = BatchDownloader(session, progress_bar) + self._download = Downloader( + session, progress_bar, quiet=quietness > 0, color=color + ) + self._batch_download = BatchDownloader( + session, + progress_bar, + quiet=quietness > 0, + color=color, + ) self.finder = finder # Where still-packed archives should be written to. If None, they are @@ -464,28 +474,29 @@ def _complete_partial_requirements( batch_download = self._batch_download( links_to_fully_download.keys(), - temp_dir, + Path(temp_dir), ) + # Process completely-downloaded files in parallel with the worker threads + # spawned by the BatchDownloader. for link, (filepath, _) in batch_download: - logger.debug("Downloading link %s to %s", link, filepath) + logger.debug("Completed download for link %s into %s", link, filepath) req = links_to_fully_download[link] # Record the downloaded file path so wheel reqs can extract a Distribution # in .get_dist(). - req.local_file_path = filepath + req.local_file_path = str(filepath) # Record that the file is downloaded so we don't do it again in # _prepare_linked_requirement(). - self._downloaded[req.link.url] = filepath + self._downloaded[req.link.url] = str(filepath) # If this is an sdist, we need to unpack it after downloading, but the # .source_dir won't be set up until we are in _prepare_linked_requirement(). # Add the downloaded archive to the install requirement to unpack after # preparing the source dir. if not req.is_wheel: - req.needs_unpacked_archive(Path(filepath)) + req.needs_unpacked_archive(filepath) - # This step is necessary to ensure all lazy wheels are processed - # successfully by the 'download', 'wheel', and 'install' commands. - for req in partially_downloaded_reqs: + # This step is necessary to ensure all lazy wheels are processed + # successfully by the 'download', 'wheel', and 'install' commands. self._prepare_linked_requirement(req, parallel_builds) def prepare_linked_requirement( diff --git a/tests/unit/test_network_download.py b/tests/unit/test_network_download.py index 53200f2e511..06307f0aa73 100644 --- a/tests/unit/test_network_download.py +++ b/tests/unit/test_network_download.py @@ -4,6 +4,7 @@ import pytest +from pip._internal.cli.progress_bars import ProgressBarType from pip._internal.models.link import Link from pip._internal.network.download import ( _prepare_download, @@ -63,7 +64,7 @@ def test_prepare_download__log( if from_cache: resp.from_cache = from_cache link = Link(url) - _prepare_download(resp, link, progress_bar="on") + _prepare_download(resp, link, progress_bar=ProgressBarType.ON) assert len(caplog.records) == 1 record = caplog.records[0] diff --git a/tests/unit/test_operations_prepare.py b/tests/unit/test_operations_prepare.py index d06733e8503..608c32f6147 100644 --- a/tests/unit/test_operations_prepare.py +++ b/tests/unit/test_operations_prepare.py @@ -8,6 +8,7 @@ import pytest +from pip._internal.cli.progress_bars import ProgressBarType from pip._internal.exceptions import HashMismatch from pip._internal.models.link import Link from pip._internal.network.download import Downloader @@ -31,7 +32,7 @@ def _fake_session_get(*args: Any, **kwargs: Any) -> Dict[str, str]: session = Mock() session.get = _fake_session_get - download = Downloader(session, progress_bar="on") + download = Downloader(session, progress_bar=ProgressBarType.ON) uri = data.packages.joinpath("simple-1.0.tar.gz").as_uri() link = Link(uri) @@ -77,7 +78,7 @@ def test_download_http_url__no_directory_traversal( "content-disposition": 'attachment;filename="../out_dir_file"', } session.get.return_value = resp - download = Downloader(session, progress_bar="on") + download = Downloader(session, progress_bar=ProgressBarType.ON) download_dir = os.fspath(tmpdir.joinpath("download")) os.mkdir(download_dir) diff --git a/tests/unit/test_req.py b/tests/unit/test_req.py index 8a95c058706..460a4ae7549 100644 --- a/tests/unit/test_req.py +++ b/tests/unit/test_req.py @@ -14,6 +14,7 @@ from pip._vendor.packaging.requirements import Requirement from pip._internal.cache import WheelCache +from pip._internal.cli.progress_bars import ProgressBarType from pip._internal.commands import create_command from pip._internal.commands.install import InstallCommand from pip._internal.exceptions import ( @@ -100,12 +101,14 @@ def _basic_resolver( check_build_deps=False, build_tracker=tracker, session=session, - progress_bar="on", + progress_bar=ProgressBarType.ON, finder=finder, require_hashes=require_hashes, use_user_site=False, lazy_wheel=False, verbosity=0, + quietness=0, + color=True, legacy_resolver=True, ) yield Resolver( From 11a7a0f58e91f579d577a8ad47a0ac5cdd1f1c05 Mon Sep 17 00:00:00 2001 From: Danny McClanahan <1305167+cosmicexplorer@users.noreply.github.com> Date: Tue, 20 Aug 2024 13:05:36 -0400 Subject: [PATCH 2/3] download larger files first --- src/pip/_internal/network/download.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/pip/_internal/network/download.py b/src/pip/_internal/network/download.py index ebae94fc7f1..61509457312 100644 --- a/src/pip/_internal/network/download.py +++ b/src/pip/_internal/network/download.py @@ -6,7 +6,7 @@ import mimetypes import os from pathlib import Path -from typing import Iterable, List, Mapping, Optional, Tuple +from typing import Iterable, List, Mapping, Optional, Tuple, cast from pip._vendor.requests.models import Response from pip._vendor.rich.progress import TaskID @@ -231,6 +231,10 @@ def __call__( break assert total_length is not None total_length += maybe_len + # Sort downloads to perform larger downloads first. + if total_length is not None: + # Extract the length from each tuple entry. + links_with_lengths.sort(key=lambda t: cast(int, t[1][0]), reverse=True) batched_progress = BatchedProgress.select_progress_bar( self._progress_bar From 6eeb26fef64a86d343424c20d10b0a887a8bbd10 Mon Sep 17 00:00:00 2001 From: Danny McClanahan <1305167+cosmicexplorer@users.noreply.github.com> Date: Tue, 20 Aug 2024 10:51:27 -0400 Subject: [PATCH 3/3] execute batch downloads in parallel worker threads - limit downloads to 10 at a time instead of starting all at once - add cli arg to limit download parallelism - factor out receiving thread exceptions into a contextmanager - default batch parallelism to 10 - make batch download parallelism 1 in html index test - explicitly yield threads to help ensure correct download ordering --- news/12923.feature.rst | 1 + src/pip/_internal/cli/cmdoptions.py | 17 ++ src/pip/_internal/cli/req_command.py | 2 + src/pip/_internal/commands/download.py | 2 + src/pip/_internal/commands/install.py | 2 + src/pip/_internal/commands/wheel.py | 2 + src/pip/_internal/network/download.py | 230 +++++++++++++++++++----- src/pip/_internal/operations/prepare.py | 2 + tests/functional/test_download.py | 2 + tests/unit/test_req.py | 1 + 10 files changed, 215 insertions(+), 46 deletions(-) create mode 100644 news/12923.feature.rst diff --git a/news/12923.feature.rst b/news/12923.feature.rst new file mode 100644 index 00000000000..e152a8d5f1a --- /dev/null +++ b/news/12923.feature.rst @@ -0,0 +1 @@ +Download concrete dists for metadata-only resolves in parallel using worker threads. Add ``--batch-download-parallelism`` CLI flag to limit parallelism. diff --git a/src/pip/_internal/cli/cmdoptions.py b/src/pip/_internal/cli/cmdoptions.py index 3bb7cbf4246..4b7a3e9b51b 100644 --- a/src/pip/_internal/cli/cmdoptions.py +++ b/src/pip/_internal/cli/cmdoptions.py @@ -236,6 +236,23 @@ class PipOption(Option): ) +batch_download_parallelism: Callable[..., Option] = partial( + Option, + "--batch-download-parallelism", + dest="batch_download_parallelism", + type="int", + default=10, + help=( + "Maximum parallelism employed for batch downloading of metadata-only dists" + " (default %default parallel requests)." + " Note that more than 10 downloads may overflow the requests connection pool," + " which may affect performance." + " Note also that commands such as 'install --dry-run' should avoid downloads" + " entirely, and so will not be affected by this option." + ), +) + + log: Callable[..., Option] = partial( PipOption, "--log", diff --git a/src/pip/_internal/cli/req_command.py b/src/pip/_internal/cli/req_command.py index 521cd7c917c..e215d64fe04 100644 --- a/src/pip/_internal/cli/req_command.py +++ b/src/pip/_internal/cli/req_command.py @@ -101,6 +101,7 @@ def make_requirement_preparer( use_user_site: bool, download_dir: Optional[str] = None, verbosity: int = 0, + batch_download_parallelism: Optional[int] = None, ) -> RequirementPreparer: """ Create a RequirementPreparer instance for the given parameters. @@ -144,6 +145,7 @@ def make_requirement_preparer( verbosity=verbosity, quietness=options.quiet, color=not options.no_color, + batch_download_parallelism=batch_download_parallelism, legacy_resolver=legacy_resolver, ) diff --git a/src/pip/_internal/commands/download.py b/src/pip/_internal/commands/download.py index 917bbb91d83..c2a090de1a3 100644 --- a/src/pip/_internal/commands/download.py +++ b/src/pip/_internal/commands/download.py @@ -47,6 +47,7 @@ def add_options(self) -> None: self.cmd_opts.add_option(cmdoptions.pre()) self.cmd_opts.add_option(cmdoptions.require_hashes()) self.cmd_opts.add_option(cmdoptions.progress_bar()) + self.cmd_opts.add_option(cmdoptions.batch_download_parallelism()) self.cmd_opts.add_option(cmdoptions.no_build_isolation()) self.cmd_opts.add_option(cmdoptions.use_pep517()) self.cmd_opts.add_option(cmdoptions.no_use_pep517()) @@ -116,6 +117,7 @@ def run(self, options: Values, args: List[str]) -> int: download_dir=options.download_dir, use_user_site=False, verbosity=self.verbosity, + batch_download_parallelism=options.batch_download_parallelism, ) resolver = self.make_resolver( diff --git a/src/pip/_internal/commands/install.py b/src/pip/_internal/commands/install.py index ad45a2f2a57..fda0e76ba23 100644 --- a/src/pip/_internal/commands/install.py +++ b/src/pip/_internal/commands/install.py @@ -237,6 +237,7 @@ def add_options(self) -> None: self.cmd_opts.add_option(cmdoptions.prefer_binary()) self.cmd_opts.add_option(cmdoptions.require_hashes()) self.cmd_opts.add_option(cmdoptions.progress_bar()) + self.cmd_opts.add_option(cmdoptions.batch_download_parallelism()) self.cmd_opts.add_option(cmdoptions.root_user_action()) index_opts = cmdoptions.make_option_group( @@ -359,6 +360,7 @@ def run(self, options: Values, args: List[str]) -> int: finder=finder, use_user_site=options.use_user_site, verbosity=self.verbosity, + batch_download_parallelism=options.batch_download_parallelism, ) resolver = self.make_resolver( preparer=preparer, diff --git a/src/pip/_internal/commands/wheel.py b/src/pip/_internal/commands/wheel.py index 278719f4e0c..21f11bfe49c 100644 --- a/src/pip/_internal/commands/wheel.py +++ b/src/pip/_internal/commands/wheel.py @@ -67,6 +67,7 @@ def add_options(self) -> None: self.cmd_opts.add_option(cmdoptions.ignore_requires_python()) self.cmd_opts.add_option(cmdoptions.no_deps()) self.cmd_opts.add_option(cmdoptions.progress_bar()) + self.cmd_opts.add_option(cmdoptions.batch_download_parallelism()) self.cmd_opts.add_option( "--no-verify", @@ -131,6 +132,7 @@ def run(self, options: Values, args: List[str]) -> int: download_dir=options.wheel_dir, use_user_site=False, verbosity=self.verbosity, + batch_download_parallelism=options.batch_download_parallelism, ) resolver = self.make_resolver( diff --git a/src/pip/_internal/network/download.py b/src/pip/_internal/network/download.py index 61509457312..98f6779d484 100644 --- a/src/pip/_internal/network/download.py +++ b/src/pip/_internal/network/download.py @@ -5,8 +5,12 @@ import logging import mimetypes import os +import time +from contextlib import contextmanager from pathlib import Path -from typing import Iterable, List, Mapping, Optional, Tuple, cast +from queue import Queue +from threading import Event, Semaphore, Thread +from typing import Iterable, Iterator, List, Mapping, Optional, Tuple, Union, cast from pip._vendor.requests.models import Response from pip._vendor.rich.progress import TaskID @@ -16,7 +20,7 @@ ProgressBarType, get_download_progress_renderer, ) -from pip._internal.exceptions import NetworkConnectionError +from pip._internal.exceptions import CommandError, NetworkConnectionError from pip._internal.models.index import PyPI from pip._internal.models.link import Link from pip._internal.network.cache import is_from_cache @@ -140,11 +144,20 @@ def _get_http_response_filename( return filename +def _maybe_log_http_error(response: Response, link: Link) -> Response: + try: + raise_for_status(response) + return response + except NetworkConnectionError as e: + assert e.response is not None + logger.critical("HTTP error %s while getting %s", e.response.status_code, link) + raise + + def _http_get_download(session: PipSession, link: Link) -> Response: target_url = link.url.split("#", 1)[0] resp = session.get(target_url, headers=HEADERS, stream=True) - raise_for_status(resp) - return resp + return _maybe_log_http_error(resp, link) def _http_head_content_info( @@ -152,10 +165,9 @@ def _http_head_content_info( link: Link, ) -> Tuple[Optional[int], str]: target_url = link.url.split("#", 1)[0] - resp = session.head(target_url) - raise_for_status(resp) + resp = _maybe_log_http_error(session.head(target_url), link) - if length := resp.headers.get("content-length", None): + if length := resp.headers.get("content-length"): content_length = int(length) else: content_length = None @@ -179,14 +191,7 @@ def __init__( def __call__(self, link: Link, location: str) -> Tuple[str, str]: """Download the file given by link into location.""" - try: - resp = _http_get_download(self._session, link) - except NetworkConnectionError as e: - assert e.response is not None - logger.critical( - "HTTP error %s while getting %s", e.response.status_code, link - ) - raise + resp = _http_get_download(self._session, link) filename = _get_http_response_filename(resp.headers, resp.url, link) filepath = os.path.join(location, filename) @@ -201,6 +206,108 @@ def __call__(self, link: Link, location: str) -> Tuple[str, str]: return filepath, content_type +class _ErrorReceiver: + def __init__(self, error_flag: Event) -> None: + self._error_flag = error_flag + self._thread_exception: Optional[BaseException] = None + + def receive_error(self, exc: BaseException) -> None: + self._error_flag.set() + self._thread_exception = exc + + def stored_error(self) -> Optional[BaseException]: + return self._thread_exception + + +@contextmanager +def _spawn_workers( + workers: List[Thread], error_flag: Event +) -> Iterator[_ErrorReceiver]: + err_recv = _ErrorReceiver(error_flag) + try: + for w in workers: + w.start() + # We've sorted the list of worker threads so they correspond to the largest + # downloads first. Each thread immediately waits upon a semaphore to limit + # maximum parallel downloads, and we would like the semaphore's internal + # wait queue to retain the same order we established earlier (otherwise, we + # would end up nondeterministically downloading files out of our desired + # order). Yielding to the scheduler here is intended to give the thread we + # just started time to jump into the semaphore, either to execute further + # (until the semaphore is full) or to jump into the queue at the desired + # position. We seem to get the ordering reliably even without this explicit + # yield, and ideally we would like to somehow ensure this deterministically, + # but this is relatively idiomatic and lets us lean on much fewer + # synchronization constructs. We can revisit this if users find the ordering + # is unreliable. It's easy to see if we've messed up, as the rich progress + # table prominently shows each download size and which ones are executing. + time.sleep(0) + yield err_recv + except BaseException as e: + err_recv.receive_error(e) + finally: + thread_exception = err_recv.stored_error() + if thread_exception is not None: + logger.critical("Received exception, shutting down downloader threads...") + + # Ensure each thread is complete by the time the queue has exited, either by + # writing the full request contents, or by checking the Event from an exception. + for w in workers: + # If the user (reasonably) wants to hit ^C again to try to make it close + # faster, we want to avoid spewing out a ton of error text, but at least + # let's let them know we hear them and we're trying to shut down! + while w.is_alive(): + try: + w.join() + except BaseException: + logger.critical("Shutting down worker threads, please wait...") + + if thread_exception is not None: + raise thread_exception + + +def _copy_chunks( + output_queue: "Queue[Union[Tuple[Link, Path, Optional[str]], BaseException]]", + error_flag: Event, + semaphore: Semaphore, + session: PipSession, + location: Path, + batched_progress: BatchedProgress, + download_info: Tuple[Link, TaskID, str], +) -> None: + link, task_id, filename = download_info + + with semaphore: + # Check if another thread exited with an exception before we started. + if error_flag.is_set(): + return + try: + resp = _http_get_download(session, link) + + filepath = location / filename + content_type = resp.headers.get("Content-Type") + # TODO: different chunk size for batched downloads? + chunks = response_chunks(resp) + with filepath.open("wb") as output_file: + # Notify that the current task has begun. + batched_progress.start_subtask(task_id) + for chunk in chunks: + # Check if another thread exited with an exception between chunks. + if error_flag.is_set(): + return + # Copy chunk directly to output file, without any + # additional buffering. + output_file.write(chunk) + # Update progress. + batched_progress.advance_subtask(task_id, len(chunk)) + + output_queue.put((link, filepath, content_type)) + except BaseException as e: + output_queue.put(e) + finally: + batched_progress.finish_subtask(task_id) + + class BatchDownloader: def __init__( self, @@ -208,12 +315,21 @@ def __init__( progress_bar: ProgressBarType, quiet: bool = False, color: bool = True, + max_parallelism: Optional[int] = None, ) -> None: self._session = session self._progress_bar = progress_bar self._quiet = quiet self._color = color + if max_parallelism is None: + max_parallelism = 1 + if max_parallelism < 1: + raise CommandError( + f"invalid batch download parallelism {max_parallelism}: must be >=1" + ) + self._max_parallelism: int = max_parallelism + def __call__( self, links: Iterable[Link], location: Path ) -> Iterable[Tuple[Link, Tuple[Path, Optional[str]]]]: @@ -231,53 +347,75 @@ def __call__( break assert total_length is not None total_length += maybe_len - # Sort downloads to perform larger downloads first. + # If lengths are available, sort downloads to perform larger downloads first. if total_length is not None: # Extract the length from each tuple entry. links_with_lengths.sort(key=lambda t: cast(int, t[1][0]), reverse=True) + # Set up state to track thread progress, including inner exceptions. + total_downloads: int = len(links_with_lengths) + completed_downloads: int = 0 + q: "Queue[Union[Tuple[Link, Path, Optional[str]], BaseException]]" = Queue() + error_flag = Event() + # Limit downloads to 10 at a time so we can reuse our connection pool. + semaphore = Semaphore(value=self._max_parallelism) batched_progress = BatchedProgress.select_progress_bar( self._progress_bar ).create( - num_tasks=len(links_with_lengths), + num_tasks=total_downloads, known_total_length=total_length, quiet=self._quiet, color=self._color, ) + # Log the link we're about to download, and add it to the progress table. link_tasks: List[Tuple[Link, TaskID, str]] = [] for link, (maybe_len, filename) in links_with_lengths: _log_download_link(link, maybe_len) task_id = batched_progress.add_subtask(filename, maybe_len) link_tasks.append((link, task_id, filename)) + # Distribute request i/o across equivalent threads. + # NB: event-based/async is likely a better model than thread-per-request, but + # (1) pip doesn't use async anywhere else yet, + # (2) this is at most one thread per dependency in the graph (less if any + # are cached) + # (3) pip is fundamentally run in a synchronous context with a clear start + # and end, instead of e.g. as a server which needs to process + # arbitrary further requests at the same time. + # For these reasons, thread-per-request should be sufficient for our needs. + workers = [ + Thread( + target=_copy_chunks, + args=( + q, + error_flag, + semaphore, + self._session, + location, + batched_progress, + download_info, + ), + ) + for download_info in link_tasks + ] + with batched_progress: - for link, task_id, filename in link_tasks: - try: - resp = _http_get_download(self._session, link) - except NetworkConnectionError as e: - assert e.response is not None - logger.critical( - "HTTP error %s while getting %s", - e.response.status_code, - link, - ) - raise - - filepath = location / filename - content_type = resp.headers.get("Content-Type") - # TODO: different chunk size for batched downloads? - chunks = response_chunks(resp) - with open(filepath, "wb") as content_file: - # Notify that the current task has begun. - batched_progress.start_subtask(task_id) - for chunk in chunks: - # Copy chunk directly to output file, without any - # additional buffering. - content_file.write(chunk) - # Update progress. - batched_progress.advance_subtask(task_id, len(chunk)) - # Notify of completion. - batched_progress.finish_subtask(task_id) - # Yield completed link and download path. - yield link, (filepath, content_type) + with _spawn_workers(workers, error_flag) as err_recv: + # Read completed downloads from queue, or extract the exception. + while completed_downloads < total_downloads: + # Get item from queue, but also check for ^C from user! + try: + item = q.get() + except BaseException as e: + err_recv.receive_error(e) + break + # Now see if the worker thread failed with an exception (unlikely). + if isinstance(item, BaseException): + err_recv.receive_error(item) + break + # Otherwise, the thread succeeded, and we can pass it to + # the preparer! + link, filepath, content_type = item + completed_downloads += 1 + yield link, (filepath, content_type) diff --git a/src/pip/_internal/operations/prepare.py b/src/pip/_internal/operations/prepare.py index cd716ea8903..a445c5e17f2 100644 --- a/src/pip/_internal/operations/prepare.py +++ b/src/pip/_internal/operations/prepare.py @@ -233,6 +233,7 @@ def __init__( # noqa: PLR0913 verbosity: int, quietness: int, color: bool, + batch_download_parallelism: Optional[int], legacy_resolver: bool, ) -> None: super().__init__() @@ -249,6 +250,7 @@ def __init__( # noqa: PLR0913 progress_bar, quiet=quietness > 0, color=color, + max_parallelism=batch_download_parallelism, ) self.finder = finder diff --git a/tests/functional/test_download.py b/tests/functional/test_download.py index d469e71c360..c620773837a 100644 --- a/tests/functional/test_download.py +++ b/tests/functional/test_download.py @@ -1288,6 +1288,8 @@ def run_for_generated_index( str(download_dir), "-i", "http://localhost:8000", + "--batch-download-parallelism", + "1", *args, ] result = script.pip(*pip_args, allow_error=allow_error) diff --git a/tests/unit/test_req.py b/tests/unit/test_req.py index 460a4ae7549..5fa4a2e8159 100644 --- a/tests/unit/test_req.py +++ b/tests/unit/test_req.py @@ -109,6 +109,7 @@ def _basic_resolver( verbosity=0, quietness=0, color=True, + batch_download_parallelism=None, legacy_resolver=True, ) yield Resolver(