diff --git a/ci/envs/310-conda-forge.yaml b/ci/envs/310-conda-forge.yaml index 8929971b..ab40f2cf 100644 --- a/ci/envs/310-conda-forge.yaml +++ b/ci/envs/310-conda-forge.yaml @@ -12,6 +12,7 @@ dependencies: - requests - joblib - xyzservices + - tqdm # testing - pip - pytest diff --git a/ci/envs/311-conda-forge.yaml b/ci/envs/311-conda-forge.yaml index 8b9c6983..7cb429b2 100644 --- a/ci/envs/311-conda-forge.yaml +++ b/ci/envs/311-conda-forge.yaml @@ -12,6 +12,7 @@ dependencies: - requests - joblib - xyzservices + - tqdm # testing - pip - pytest diff --git a/ci/envs/312-latest-conda-forge.yaml b/ci/envs/312-latest-conda-forge.yaml index 4fbffb4e..021a88d8 100644 --- a/ci/envs/312-latest-conda-forge.yaml +++ b/ci/envs/312-latest-conda-forge.yaml @@ -12,6 +12,7 @@ dependencies: - requests - joblib - xyzservices + - tqdm # testing - pip - pytest diff --git a/ci/envs/313-latest-conda-forge.yaml b/ci/envs/313-latest-conda-forge.yaml index 0989bde7..b3e62007 100644 --- a/ci/envs/313-latest-conda-forge.yaml +++ b/ci/envs/313-latest-conda-forge.yaml @@ -12,6 +12,7 @@ dependencies: - requests - joblib - xyzservices + - tqdm # testing - pip - pytest diff --git a/contextily/__init__.py b/contextily/__init__.py index a64a2130..a15bb1c8 100644 --- a/contextily/__init__.py +++ b/contextily/__init__.py @@ -6,6 +6,7 @@ from .place import Place, plot_map from .tile import * from .plotting import add_basemap, add_attribution +from .progress import set_progress_bar from importlib.metadata import PackageNotFoundError, version diff --git a/contextily/progress.py b/contextily/progress.py new file mode 100644 index 00000000..2fd5c993 --- /dev/null +++ b/contextily/progress.py @@ -0,0 +1,56 @@ +from tqdm import tqdm as _default_progress_bar +from contextlib import nullcontext + +# Default progress bar class (can be changed by set_progress_bar) +_progress_bar = _default_progress_bar + +def set_progress_bar(progress_bar=None): + """ + Set the progress bar class to be used for downloading tiles. + + Parameters + ---------- + progress_bar : class, optional + A tqdm-compatible progress bar class. If None, progress bar is disabled. + The progress bar class should have the same interface as tqdm. + Common alternatives include: + - tqdm.notebook.tqdm for Jupyter notebooks + - custom implementations with the same interface + """ + global _progress_bar + _progress_bar = progress_bar + + +def get_progress_bar(): + """ + Returns the progress bar class to be used for downloading tiles. + If progress bars are disabled (set to None), returns a no-op context + manager that doesn't display progress. + + Returns + ---------- + progress_bar : callable + A callable that returns either a tqdm-compatible progress bar or a + no-op context manager with update/close methods if progress bars + are disabled. + """ + if _progress_bar is None: + class NoOpProgress: + def __init__(self, *args, **kwargs): + self.context = nullcontext() + + def __enter__(self): + self.context.__enter__() + return self + + def __exit__(self, *args): + self.context.__exit__(*args) + + def update(self, n=1): + pass + + def close(self): + pass + + return lambda *args, **kwargs: NoOpProgress(*args, **kwargs) + return _progress_bar diff --git a/contextily/tile.py b/contextily/tile.py index 2e64ac31..fb7f4ae0 100644 --- a/contextily/tile.py +++ b/contextily/tile.py @@ -17,8 +17,9 @@ import rasterio as rio from PIL import Image, UnidentifiedImageError from joblib import Memory as _Memory -from joblib import Parallel, delayed +from concurrent.futures import ThreadPoolExecutor, as_completed from rasterio.transform import from_origin +from .progress import get_progress_bar from rasterio.io import MemoryFile from rasterio.vrt import WarpedVRT from rasterio.enums import Resampling @@ -273,16 +274,28 @@ def bounds2img( # download tiles if n_connections < 1 or not isinstance(n_connections, int): raise ValueError(f"n_connections must be a positive integer value.") - # Use threads for a single connection to avoid the overhead of spawning a process. Use processes for multiple - # connections if caching is enabled, as threads lead to memory issues when used in combination with the joblib - # memory caching (used for the _fetch_tile() function). - preferred_backend = ( - "threads" if (n_connections == 1 or not use_cache) else "processes" - ) + fetch_tile_fn = memory.cache(_fetch_tile) if use_cache else _fetch_tile - arrays = Parallel(n_jobs=n_connections, prefer=preferred_backend)( - delayed(fetch_tile_fn)(tile_url, wait, max_retries) for tile_url in tile_urls - ) + + arrays = [None] * len(tile_urls) # Pre-allocate result list + with get_progress_bar()(total=len(tile_urls), desc="Downloading tiles") as pbar: + with ThreadPoolExecutor(max_workers=n_connections) as executor: + # Submit all tasks and store futures with their indices + future_to_index = { + executor.submit(fetch_tile_fn, url, wait, max_retries): idx + for idx, url in enumerate(tile_urls) + } + + # Process completed futures as they finish + for future in as_completed(future_to_index): + idx = future_to_index[future] + try: + arrays[idx] = future.result() + except Exception as e: + # Re-raise any exceptions from the worker + raise e from None + pbar.update(1) + # merge downloaded tiles merged, extent = _merge_tiles(tiles, arrays) # lon/lat extent --> Spheric Mercator diff --git a/notebooks/intro_guide.ipynb b/notebooks/intro_guide.ipynb index 90f7bf74..5e334e7c 100644 --- a/notebooks/intro_guide.ipynb +++ b/notebooks/intro_guide.ipynb @@ -11,7 +11,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -34,7 +34,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -50,7 +50,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -68,7 +68,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -103,7 +103,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -138,7 +138,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -18614,7 +18614,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -18985,7 +18985,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -19006,7 +19006,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -19028,7 +19028,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -19055,7 +19055,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -19082,7 +19082,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -19122,7 +19122,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -19139,7 +19139,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -19180,7 +19180,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -19198,7 +19198,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -19216,7 +19216,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -19253,7 +19253,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -19269,7 +19269,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -19297,7 +19297,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -19321,7 +19321,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -19353,7 +19353,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -19390,7 +19390,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -19435,7 +19435,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -19462,7 +19462,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -19489,7 +19489,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -19516,7 +19516,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -19553,7 +19553,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -19582,7 +19582,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -19612,7 +19612,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -19631,6 +19631,23 @@ "cx.add_basemap(ax, crs=df.crs.to_string(), source=cx.providers.CartoDB.Positron, zoom=12)\n", "cx.add_basemap(ax, crs=df.crs.to_string(), source=cx.providers.CartoDB.PositronOnlyLabels, zoom=10)" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Progress bars\n", + "\n", + "The tile download progress bar uses [`tqdm`](https://tqdm.github.io/) to display a progress bar. You can override the default `tqdm` instance. For example if you are running in a Jupyter Notebook environment, you might want to use `tqdm.notebook.tqdm` instead of the default `tqdm.tqdm`. You can also build custom `tqdm` implementations.\n", + "\n", + "```python\n", + "import contextily as cx\n", + "from tqdm.notebook import tqdm\n", + "cx.set_progress_bar(tqdm)\n", + "```\n", + "\n", + "If you don't want to display any progress bar, call `cx.set_progress_bar(None)`" + ] } ], "metadata": { diff --git a/pyproject.toml b/pyproject.toml index 913367ad..7fa8ba7f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,8 @@ dependencies = [ "rasterio", "requests", "joblib", - "xyzservices" + "xyzservices", + "tqdm" ] [project.urls]