From f6c23458cf8fe4884a66154ee35245a8e308cef6 Mon Sep 17 00:00:00 2001
From: Wolfgang Preimesberger
Date: Thu, 2 May 2024 09:19:22 +0200
Subject: [PATCH] Use joblib for parallel processing

---
 environment.yml          |   1 +
 setup.cfg                |   1 +
 src/repurpose/process.py | 197 +++++++++++++++++++++++++++------------
 tests/test_process.py    |  14 ++-
 4 files changed, 149 insertions(+), 64 deletions(-)

diff --git a/environment.yml b/environment.yml
index 28880cc..a86796d 100644
--- a/environment.yml
+++ b/environment.yml
@@ -26,6 +26,7 @@ dependencies:
   - more_itertools
   - smecv_grid
   - tqdm
+  - joblib
   # Optional, for documentation and testing
   - nbconvert
   - sphinx_rtd_theme
diff --git a/setup.cfg b/setup.cfg
index 823fa8a..0b3a80b 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -38,6 +38,7 @@ install_requires =
     pyresample
     tqdm
     more_itertools
+    joblib
 # The usage of test_requires is discouraged, see `Dependency Management` docs
 # tests_require = pytest; pytest-cov
 # Require a specific Python version, e.g. Python 2.7 or >= 3.4
diff --git a/src/repurpose/process.py b/src/repurpose/process.py
index 029e435..f703193 100644
--- a/src/repurpose/process.py
+++ b/src/repurpose/process.py
@@ -3,26 +3,29 @@
 import warnings
 import os
 
-if 'numpy' in sys.modules:
-    warnings.warn("Numpy is already imported. Environment variables set in "
-                  "repurpose.utils wont have any effect!")
-
-# Note: Must be set BEFORE the first numpy import!!
-os.environ['MKL_NUM_THREADS'] = '1'
-os.environ['NUMEXPR_NUM_THREADS'] = '1'
-os.environ['OMP_NUM_THREADS'] = '1'
-os.environ['MKL_DYNAMIC'] = 'FALSE'
-os.environ['OPENBLAS_NUM_THREADS'] = '1'
+# if 'numpy' in sys.modules:
+#     warnings.warn("Numpy is already imported. Please make sure "
+#                   "`repurpose.process` is imported before numpy to avoid "
+#                   "numpy multi-threading.")
+#
+# # Note: Must be set BEFORE the first numpy import!!
+# os.environ['MKL_NUM_THREADS'] = '1'
+# os.environ['NUMEXPR_NUM_THREADS'] = '1'
+# os.environ['OMP_NUM_THREADS'] = '1'
+# os.environ['MKL_DYNAMIC'] = 'FALSE'
+# os.environ['OPENBLAS_NUM_THREADS'] = '1'
 
 import numpy as np
 from tqdm import tqdm
 import logging
-from multiprocessing import Pool
 from datetime import datetime
 import sys
 from pathlib import Path
-from typing import List
+from typing import List, Any
 from glob import glob
+from joblib import Parallel, delayed, parallel_config
+from logging.handlers import QueueHandler, QueueListener
+from multiprocessing import Manager
 
 
 class ImageBaseConnection:
@@ -102,8 +105,8 @@ def read(self, timestamp, **kwargs):
 
 
 def rootdir() -> Path:
-    return Path(os.path.join(os.path.dirname(
-        os.path.abspath(__file__)))).parents[1]
+    p = str(os.path.join(os.path.dirname(os.path.abspath(__file__))))
+    return Path(p).parents[1]
 
 
 def idx_chunks(idx, n=-1):
@@ -123,6 +126,63 @@ def idx_chunks(idx, n=-1):
     for i in range(0, len(idx.values), n):
         yield idx[i:i + n]
 
+class ProgressParallel(Parallel):
+    def __init__(self, use_tqdm=True, total=None, desc="",
+                 *args, **kwargs) -> None:
+        """
+        Joblib Parallel with a tqdm progress bar
+        """
+        self._use_tqdm = use_tqdm
+        self._total = total
+        self._desc = desc
+        super().__init__(*args, **kwargs)
+
+    def __call__(self, *args, **kwargs):
+        """
+        Wraps a progress bar around the parallel function calls
+        """
+        with tqdm(
+            disable=not self._use_tqdm, total=self._total, desc=self._desc
+        ) as self._pbar:
+            return Parallel.__call__(self, *args, **kwargs)
+
+    def print_progress(self):
+        """
+        Update the progress bar after each completed task
+        """
+        if self._total is None:
+            self._pbar.total = self.n_dispatched_tasks
+        self._pbar.n = self.n_completed_tasks
+        self._pbar.refresh()
+
+def configure_worker_logger(log_queue, log_level):
+    worker_logger = logging.getLogger('worker')
+    if not worker_logger.hasHandlers():
+        h = QueueHandler(log_queue)
+        worker_logger.addHandler(h)
+    worker_logger.setLevel(log_level)
+    return worker_logger
+
+def run_with_error_handling(FUNC,
+        ignore_errors=False, log_queue=None, log_level="WARNING",
+        **kwargs) -> Any:
+
+    if log_queue is not None:
+        logger = configure_worker_logger(log_queue, log_level)
+        logger_name = logger.name
+        kwargs['logger_name'] = logger_name
+    else:
+        logger = logging.getLogger()
+
+    r = None
+
+    try:
+        r = FUNC(**kwargs)
+    except Exception as e:
+        if ignore_errors:
+            logger.error(f"Error: {e}")
+        else:
+            raise e
+    return r
 
 def parallel_process_async(
     FUNC,
@@ -136,8 +196,11 @@ def parallel_process_async(
     log_filename=None,
     loglevel="WARNING",
     verbose=False,
-    progress_bar_label="Processed"
-) -> List:
+    progress_bar_label="Processed",
+    backend="loky",
+    sharedmem=False,
+    parallel_kwargs=None,
+) -> list:
     """
     Applies the passed function to all elements of the passed iterables.
     Parallel function calls are processed ASYNCHRONOUSLY (ie order of
@@ -178,15 +241,21 @@ def parallel_process_async(
     loglevel: str, optional (default: "WARNING")
         Log level to use for logging. Must be one of
         ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"].
-    verbose: float, optional (default: False)
+    verbose: bool, optional (default: False)
         Print all logging messages to stdout, useful for debugging.
     progress_bar_label: str, optional (default: "Processed")
         Label to use for the progress bar.
+    backend: Literal["threading", "multiprocessing", "loky"] = "loky"
+        The backend to use for parallel execution (if n_proc > 1).
+        Defaults to "loky". See the joblib docs for more info.
+    sharedmem: bool, optional (default: False)
+        Activate the shared memory option (slower, but workers share state).
 
     Returns
     -------
-    results: List
-        List of return values from each function call
+    results: list or None
+        List of return values from each function call, or None if none of
+        the calls returned a value.
     """
     if activate_logging:
         logger = logging.getLogger()
@@ -213,7 +282,7 @@ def parallel_process_async(
     if log_file:
         os.makedirs(os.path.dirname(log_file), exist_ok=True)
         logging.basicConfig(
-            filename=log_file,
+            filename=str(log_file),
             level=loglevel.upper(),
             format="%(levelname)s %(asctime)s %(message)s",
             datefmt="%Y-%m-%d %H:%M:%S",
@@ -246,48 +315,53 @@ def parallel_process_async(
             kws.update(STATIC_KWARGS)
         process_kwargs.append(kws)
 
-    if show_progress_bars:
-        pbar = tqdm(total=len(process_kwargs), desc=progress_bar_label)
-    else:
-        pbar = None
-
-    results = []
-
-    def update(r) -> None:
-        if r is not None:
-            results.append(r)
-        if pbar is not None:
-            pbar.update()
-
-    def error(e) -> None:
-        if logger is not None:
-            logging.error(e)
-        if not ignore_errors:
-            raise e
-        if pbar is not None:
-            pbar.update()
-
     if n_proc == 1:
+        logging.info(f"Processing with {n_proc} process.")
+        results = []
+        if show_progress_bars:
+            pbar = tqdm(total=len(process_kwargs), desc=progress_bar_label)
+        else:
+            pbar = None
         for kwargs in process_kwargs:
-            try:
-                r = FUNC(**kwargs)
-                update(r)
-            except Exception as e:
-                error(e)
+            r = run_with_error_handling(FUNC, ignore_errors, **kwargs)
+            if r is not None:
+                results.append(r)
+            if pbar is not None:
+                pbar.update()
     else:
-        with Pool(n_proc) as pool:
-            for kwds in process_kwargs:
-                pool.apply_async(
-                    FUNC,
-                    kwds=kwds,
-                    callback=update,
-                    error_callback=error,
-                )
-            pool.close()
-            pool.join()
-
-    if pbar is not None:
-        pbar.close()
+        logging.info(f"Processing with {n_proc} processes.")
+        if logger is not None:
+            m = Manager()
+            q = m.Queue()
+            listener = QueueListener(q, *logger.handlers)
+            listener.start()
+            log_level = logger.getEffectiveLevel()
+        else:
+            q = None
+            log_level = None
+            listener = None
+
+        with parallel_config(backend=backend, inner_max_num_threads=1):
+            results: list = ProgressParallel(
+                use_tqdm=show_progress_bars,
+                n_jobs=n_proc,
+                verbose=0,
+                total=len(process_kwargs),
+                desc=progress_bar_label,
+                require='sharedmem' if sharedmem else None,
+                return_as="list",
+                **(parallel_kwargs or dict()),
+            )(delayed(run_with_error_handling)(
+                FUNC, ignore_errors,
+                log_queue=q,
+                log_level=log_level,
+                **kwargs)
+              for kwargs in process_kwargs)
+
+        results = [r for r in results if r is not None]
+
+        if listener is not None:
+            listener.stop()
 
     if logger is not None:
         if verbose:
@@ -299,4 +373,7 @@ def error(e) -> None:
             handler.close()
         handlers.clear()
 
-    return results
+    if len(results) == 0:
+        return None
+    else:
+        return results
diff --git a/tests/test_process.py b/tests/test_process.py
index c320a97..fd0c1ec 100644
--- a/tests/test_process.py
+++ b/tests/test_process.py
@@ -3,6 +3,7 @@
 import time
 import tempfile
 import logging
+import pytest
 
 from repurpose.process import parallel_process_async, idx_chunks
 
@@ -24,12 +25,17 @@ def func(x: int, p: int):
     logging.info(f'x={x}, p={p}')
     return x**p
 
-def test_apply_to_elements():
+
+@pytest.mark.parametrize("n_proc,backend", [
+    (1, None),  # backend doesn't matter in this case
+    (2, "threading"), (2, "multiprocessing"), (2, "loky")
+])
+def test_apply_to_elements(n_proc, backend):
     iter_kwargs = {'x': [1, 2, 3, 4]}
     static_kwargs = {'p': 2}
     with tempfile.TemporaryDirectory() as log_path:
         res = parallel_process_async(
-            func, iter_kwargs, static_kwargs, n_proc=1,
+            func, iter_kwargs, static_kwargs, n_proc=n_proc,
             show_progress_bars=False, verbose=False, loglevel="DEBUG",
-            ignore_errors=True, log_path=log_path)
-    assert sorted(res) == [1, 4, 9, 16]
+            ignore_errors=True, log_path=log_path, backend=backend)
+    assert res == [1, 4, 9, 16]
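
Usage sketch (reviewer note, not part of the patch): a minimal example of how
the new joblib-backed API is called once this patch is applied, mirroring the
parametrized test above. `power` is a hypothetical stand-in for a user's FUNC;
the choice of backend is illustrative. Because the parallel branch collects
results with joblib's return_as="list", results come back in input order, which
is why the test no longer needs sorted().

    import tempfile
    from repurpose.process import parallel_process_async

    def power(x: int, p: int) -> int:
        # Called once per element of iter_kwargs['x']; static_kwargs
        # are passed unchanged to every call.
        return x ** p

    with tempfile.TemporaryDirectory() as log_path:
        res = parallel_process_async(
            power, {'x': [1, 2, 3, 4]}, {'p': 2},
            n_proc=2, backend="threading",  # or "multiprocessing" / "loky"
            show_progress_bars=False, loglevel="WARNING",
            ignore_errors=True, log_path=log_path)

    assert res == [1, 4, 9, 16]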