diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index b5e7c95..5eeabea 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -11,19 +11,23 @@ on:
   pull_request:
   workflow_dispatch:
   schedule:
-    - cron: '0 0 * * *' # daily
+    - cron: '0 0 * * *' # nightly build

 jobs:
   build:
-    name: Build py${{ matrix.python-version }} @ ${{ matrix.os }} 🐍
+    name: py${{ matrix.python-version }} @ ${{ matrix.os }}
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
-        python-version: ['3.8', '3.9', '3.10', '3.11']
-        os: ["ubuntu-latest"]
         include:
+          - os: "ubuntu-latest"
+            python-version: '3.8'  # first supported
+          - os: "ubuntu-latest"
+            python-version: '3.12'  # latest supported
           - os: "windows-latest"
-            python-version: '3.11'
+            python-version: '3.12'  # latest supported
+          - os: "macos-latest"
+            python-version: '3.12'  # latest supported
     steps:
       - uses: actions/checkout@v4
         with:
@@ -106,7 +110,7 @@ jobs:
           echo "GITHUB_REF = $GITHUB_REF"
           echo "GITHUB_REPOSITORY = $GITHUB_REPOSITORY"
       - name: Download Artifacts
-        uses: actions/download-artifact@v2
+        uses: actions/download-artifact@v4
       - name: Display downloaded files
         run: ls -aR
       - name: Upload to PyPI
diff --git a/environment.yml b/environment.yml
index 28880cc..a86796d 100644
--- a/environment.yml
+++ b/environment.yml
@@ -26,6 +26,7 @@ dependencies:
   - more_itertools
   - smecv_grid
   - tqdm
+  - joblib
   # Optional, for documentation and testing
   - nbconvert
   - sphinx_rtd_theme
diff --git a/setup.cfg b/setup.cfg
index 823fa8a..0b3a80b 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -38,6 +38,7 @@ install_requires =
     pyresample
     tqdm
     more_itertools
+    joblib
 # The usage of test_requires is discouraged, see `Dependency Management` docs
 # tests_require = pytest; pytest-cov
 # Require a specific Python version, e.g. Python 2.7 or >= 3.4
diff --git a/src/repurpose/process.py b/src/repurpose/process.py
index 029e435..9340d73 100644
--- a/src/repurpose/process.py
+++ b/src/repurpose/process.py
@@ -1,12 +1,6 @@
-import sys
 import time
-import warnings
 import os

-if 'numpy' in sys.modules:
-    warnings.warn("Numpy is already imported. Environment variables set in "
-                  "repurpose.utils wont have any effect!")
-
 # Note: Must be set BEFORE the first numpy import!!
 os.environ['MKL_NUM_THREADS'] = '1'
 os.environ['NUMEXPR_NUM_THREADS'] = '1'
@@ -14,15 +8,18 @@
 os.environ['MKL_DYNAMIC'] = 'FALSE'
 os.environ['OPENBLAS_NUM_THREADS'] = '1'

+import traceback
 import numpy as np
 from tqdm import tqdm
 import logging
-from multiprocessing import Pool
 from datetime import datetime
 import sys
 from pathlib import Path
-from typing import List
+from typing import List, Any
 from glob import glob
+from joblib import Parallel, delayed, parallel_config
+from logging.handlers import QueueHandler, QueueListener
+from multiprocessing import Manager


 class ImageBaseConnection:
@@ -102,8 +99,8 @@ def read(self, timestamp, **kwargs):


 def rootdir() -> Path:
-    return Path(os.path.join(os.path.dirname(
-        os.path.abspath(__file__)))).parents[1]
+    p = os.path.dirname(os.path.abspath(__file__))
+    return Path(p).parents[1]


 def idx_chunks(idx, n=-1):
@@ -123,6 +120,68 @@ def idx_chunks(idx, n=-1):
         for i in range(0, len(idx.values), n):
             yield idx[i:i + n]

+class ProgressParallel(Parallel):
+    def __init__(self, use_tqdm=True, total=None, desc="",
+                 *args, **kwargs) -> None:
+        """
+        joblib Parallel with an attached tqdm progress bar
+        """
+        self._use_tqdm = use_tqdm
+        self._total = total
+        self._desc = desc
+        super().__init__(*args, **kwargs)
+
+    def __call__(self, *args, **kwargs):
+        """
+        Wrap a progress bar around the parallel function calls
+        """
+        with tqdm(
+            disable=not self._use_tqdm, total=self._total, desc=self._desc
+        ) as self._pbar:
+            return Parallel.__call__(self, *args, **kwargs)
+    def print_progress(self):
+        """
+        Update the progress bar after each completed task
+        """
+        if self._total is None:
+            self._pbar.total = self.n_dispatched_tasks
+        self._pbar.n = self.n_completed_tasks
+        self._pbar.refresh()
+
+def configure_worker_logger(log_queue, log_level, name):
+    worker_logger = logging.getLogger(name)
+    if not worker_logger.hasHandlers():
+        h = QueueHandler(log_queue)
+        worker_logger.addHandler(h)
+    worker_logger.setLevel(log_level)
+    return worker_logger
+
+def run_with_error_handling(FUNC,
+                            ignore_errors=False,
+                            log_queue=None,
+                            log_level="WARNING",
+                            logger_name=None,
+                            **kwargs) -> Any:
+
+    if log_queue is not None:
+        logger = configure_worker_logger(log_queue, log_level, logger_name)
+    else:
+        # normal logger
+        logger = logging.getLogger(logger_name)
+
+    r = None
+
+    try:
+        r = FUNC(**kwargs)
+    except Exception:
+        if ignore_errors:
+            logger.error(f"The following ERROR was raised in the parallelized "
+                         f"function `{FUNC.__name__}` but was ignored due to "
+                         f"the chosen settings: "
+                         f"{traceback.format_exc()}")
+        else:
+            raise
+    return r

 def parallel_process_async(
         FUNC,
@@ -135,9 +194,13 @@
         log_path=None,
         log_filename=None,
         loglevel="WARNING",
+        logger_name=None,
         verbose=False,
-        progress_bar_label="Processed"
-) -> List:
+        progress_bar_label="Processed",
+        backend="loky",
+        sharedmem=False,
+        parallel_kwargs=None,
+) -> list:
     """
     Applies the passed function to all elements of the passed iterables.
     Parallel function calls are processed ASYNCHRONOUSLY (ie order of
@@ -176,29 +239,41 @@
         Name of the logfile in `log_path to create. If None is chosen, a name
         is created automatically. If `log_path is None, this has no effect.
     loglevel: str, optional (default: "WARNING")
-        Log level to use for logging. Must be one of
+        Log level to use; messages below it are ignored. Must be one of
         ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"].
-    verbose: float, optional (default: False)
+    logger_name: str, optional (default: None)
+        Name of the logger that FUNC can retrieve to send log messages
+        to. If not given, the root logger is used, e.g.
+        ```
+        logger = logging.getLogger()
+        logger.error("Some error message")
+        ```
+    verbose: bool, optional (default: False)
         Print all logging messages to stdout, useful for debugging.
     progress_bar_label: str, optional (default: "Processed")
         Label to use for the progress bar.
+    backend: Literal["threading", "multiprocessing", "loky"], optional
+        The joblib backend to use for parallel execution (if n_proc > 1).
+        Defaults to "loky". See the joblib docs for more info.
+    sharedmem: bool, optional (default: False)
+        Activate joblib's shared memory option (can be slow).

     Returns
     -------
-    results: List
-        List of return values from each function call
+    results: list or None
+        List of return values from each function call, or None if no
+        call returned a value.
     """

     if activate_logging:
-        logger = logging.getLogger()
+        logger = logging.getLogger(logger_name)
+        logger.setLevel(loglevel.upper())

         if STATIC_KWARGS is None:
             STATIC_KWARGS = dict()

         if verbose:
+            # also print ALL log messages to stdout
             streamHandler = logging.StreamHandler(sys.stdout)
-            formatter = logging.Formatter(
-                '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
-            streamHandler.setFormatter(formatter)
             logger.setLevel('DEBUG')
             logger.addHandler(streamHandler)
@@ -211,14 +286,14 @@
             log_file = None

         if log_file:
+            # the logger additionally writes to a log file
             os.makedirs(os.path.dirname(log_file), exist_ok=True)
-            logging.basicConfig(
-                filename=log_file,
-                level=loglevel.upper(),
-                format="%(levelname)s %(asctime)s %(message)s",
+            filehandler = logging.FileHandler(log_file)
+            filehandler.setFormatter(logging.Formatter(
+                "%(levelname)s %(asctime)s %(message)s",
                 datefmt="%Y-%m-%d %H:%M:%S",
-                force=True,
-            )
+            ))
+            logger.addHandler(filehandler)

     else:
         logger = None
@@ -246,48 +321,57 @@
             kws.update(STATIC_KWARGS)
             process_kwargs.append(kws)

-    if show_progress_bars:
-        pbar = tqdm(total=len(process_kwargs), desc=progress_bar_label)
-    else:
-        pbar = None
-
-    results = []
-
-    def update(r) -> None:
-        if r is not None:
-            results.append(r)
-        if pbar is not None:
-            pbar.update()
-
-    def error(e) -> None:
-        if logger is not None:
-            logging.error(e)
-        if not ignore_errors:
-            raise e
-        if pbar is not None:
-            pbar.update()
-
     if n_proc == 1:
+        results = []
+        if show_progress_bars:
+            pbar = tqdm(total=len(process_kwargs), desc=progress_bar_label)
+        else:
+            pbar = None
+
         for kwargs in process_kwargs:
-            try:
-                r = FUNC(**kwargs)
-                update(r)
-            except Exception as e:
-                error(e)
+            r = run_with_error_handling(FUNC, ignore_errors,
+                                        logger_name=logger_name,
+                                        **kwargs)
+            if r is not None:
+                results.append(r)
+            if pbar is not None:
+                pbar.update()
     else:
-        with Pool(n_proc) as pool:
-            for kwds in process_kwargs:
-                pool.apply_async(
-                    FUNC,
-                    kwds=kwds,
-                    callback=update,
-                    error_callback=error,
-                )
-            pool.close()
-            pool.join()
-
-        if pbar is not None:
-            pbar.close()
+        if logger is not None:
+            log_level = logger.getEffectiveLevel()
+            m = Manager()
+            q = m.Queue()
+            listener = QueueListener(q, *logger.handlers,
+                                     respect_handler_level=True)
+            listener.start()
+        else:
+            q = None
+            log_level = None
+            listener = None
+
+        n = 1 if backend == 'loky' else None  # cap native threads in loky workers
+        with parallel_config(backend=backend, inner_max_num_threads=n):
+            results: list = ProgressParallel(
+                use_tqdm=show_progress_bars,
+                n_jobs=n_proc,
+                verbose=0,
+                total=len(process_kwargs),
+                desc=progress_bar_label,
+                require='sharedmem' if sharedmem else None,
+                return_as="list",
+                **(parallel_kwargs or dict()),
+            )(delayed(run_with_error_handling)(
+                FUNC, ignore_errors,
+                log_queue=q,
+                log_level=log_level,
+                logger_name=logger_name,
+                **kwargs)
+                for kwargs in process_kwargs)
+
+        results = [r for r in results if r is not None]
+
+        if listener is not None:
+            listener.stop()

     if logger is not None:
         if verbose:
@@ -299,4 +383,7 @@ def error(e) -> None:
             handler.close()
             handlers.clear()

-    return results
+    if len(results) == 0:
+        return None
+    else:
+        return results
diff --git a/tests/test_process.py b/tests/test_process.py
index c320a97..fd0c1ec 100644
--- a/tests/test_process.py
+++ b/tests/test_process.py
@@ -3,6 +3,7 @@ import time
 import tempfile
 import logging

+import pytest

 from repurpose.process import parallel_process_async, idx_chunks
@@ -24,12 +25,17 @@ def func(x: int, p: int):
     logging.info(f'x={x}, p={p}')
     return x**p

-def test_apply_to_elements():
+
+@pytest.mark.parametrize("n_proc,backend", [
+    (1, None),  # backend doesn't matter in the serial case
+    (2, "threading"), (2, "multiprocessing"), (2, "loky")
+    ])
+def test_apply_to_elements(n_proc, backend):
     iter_kwargs = {'x': [1, 2, 3, 4]}
     static_kwargs = {'p': 2}
     with tempfile.TemporaryDirectory() as log_path:
         res = parallel_process_async(
-            func, iter_kwargs, static_kwargs, n_proc=1,
+            func, iter_kwargs, static_kwargs, n_proc=n_proc,
             show_progress_bars=False, verbose=False, loglevel="DEBUG",
-            ignore_errors=True, log_path=log_path)
-        assert sorted(res) == [1, 4, 9, 16]
+            ignore_errors=True, log_path=log_path, backend=backend)
+        assert res == [1, 4, 9, 16]
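
A minimal usage sketch of the reworked `parallel_process_async` follows (not part of the patch). It mirrors the call pattern in `tests/test_process.py`; the function `square`, the logger name `repurpose_demo` and the `batch_size` value are illustrative assumptions:

```python
import logging

from repurpose.process import parallel_process_async


def square(x: int, p: int):
    # FUNC may log via the named logger; inside workers, records are
    # forwarded to the parent process through a queue (see process.py).
    logging.getLogger("repurpose_demo").info(f"x={x}, p={p}")
    return x ** p


if __name__ == "__main__":
    res = parallel_process_async(
        square,
        {"x": [1, 2, 3, 4]},    # ITER_KWARGS: one call per element
        {"p": 2},               # STATIC_KWARGS: shared by all calls
        n_proc=2,
        backend="loky",         # or "threading" / "multiprocessing"
        logger_name="repurpose_demo",
        loglevel="INFO",
        show_progress_bars=True,
        ignore_errors=True,     # failing calls are logged, not raised
        parallel_kwargs={"batch_size": 2},  # forwarded to joblib.Parallel
    )
    print(res)  # [1, 4, 9, 16], order preserved via return_as="list"
```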
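The logging rework relies on the standard library's queue pattern: each worker attaches a `QueueHandler`, and a `QueueListener` in the parent drains the queue and replays records through the parent's handlers (cf. `configure_worker_logger` and the `Manager().Queue()` setup above). A self-contained sketch of that pattern, with a plain function call standing in for a joblib worker:

```python
import logging
from logging.handlers import QueueHandler, QueueListener
from multiprocessing import Manager


def worker_task(q):
    # In a real worker process this logger starts without handlers, so
    # only the QueueHandler is attached; records travel via the queue.
    logger = logging.getLogger("worker")
    if not logger.handlers:
        logger.addHandler(QueueHandler(q))
    logger.propagate = False
    logger.setLevel(logging.INFO)
    logger.info("hello from a worker")


if __name__ == "__main__":
    main_logger = logging.getLogger("demo")
    main_logger.addHandler(logging.StreamHandler())

    # A Manager queue is picklable, so it can be shipped to loky workers.
    m = Manager()
    q = m.Queue()

    # Drain the queue in the parent and replay each record through the
    # parent's handlers, as parallel_process_async does.
    listener = QueueListener(q, *main_logger.handlers,
                             respect_handler_level=True)
    listener.start()

    worker_task(q)      # would normally run inside a joblib worker
    listener.stop()     # flushes any remaining records
```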