Joblib for parallelisation #25

Merged: 5 commits on May 2, 2024
16 changes: 10 additions & 6 deletions .github/workflows/build.yml
@@ -11,19 +11,23 @@ on:
pull_request:
workflow_dispatch:
schedule:
- cron: '0 0 * * *' # daily
- cron: '0 0 * * *' # nightly build

jobs:
build:
name: Build py${{ matrix.python-version }} @ ${{ matrix.os }} 🐍
name: py${{ matrix.python-version }} @ ${{ matrix.os }}
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: ['3.8', '3.9', '3.10', '3.11']
os: ["ubuntu-latest"]
include:
- os: "ubuntu-latest"
python-version: '3.8' # first supported
- os: "ubuntu-latest"
python-version: '3.12' # latest supported
- os: "windows-latest"
python-version: '3.11'
python-version: '3.12' # latest supported
- os: "macos-latest"
python-version: '3.12' # latest supported
steps:
- uses: actions/checkout@v4
with:
@@ -106,7 +110,7 @@ jobs:
echo "GITHUB_REF = $GITHUB_REF"
echo "GITHUB_REPOSITORY = $GITHUB_REPOSITORY"
- name: Download Artifacts
uses: actions/download-artifact@v2
uses: actions/download-artifact@v4
- name: Display downloaded files
run: ls -aR
- name: Upload to PyPI
1 change: 1 addition & 0 deletions environment.yml
@@ -26,6 +26,7 @@ dependencies:
- more_itertools
- smecv_grid
- tqdm
- joblib
# Optional, for documentation and testing
- nbconvert
- sphinx_rtd_theme
1 change: 1 addition & 0 deletions setup.cfg
@@ -38,6 +38,7 @@ install_requires =
pyresample
tqdm
more_itertools
joblib
# The usage of test_requires is discouraged, see `Dependency Management` docs
# tests_require = pytest; pytest-cov
# Require a specific Python version, e.g. Python 2.7 or >= 3.4
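joblib is the new runtime dependency this PR introduces in place of multiprocessing.Pool. For orientation, a minimal sketch of the Parallel/delayed pattern it provides (plain joblib API, nothing repo-specific):

```python
from joblib import Parallel, delayed

def square(x):
    return x ** 2

# Dispatch four calls to two worker processes; the default "loky" backend
# returns the results in dispatch order.
results = Parallel(n_jobs=2)(delayed(square)(x) for x in (1, 2, 3, 4))
assert results == [1, 4, 9, 16]
```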
219 changes: 153 additions & 66 deletions src/repurpose/process.py
@@ -1,28 +1,25 @@
import sys
import time
import warnings
import os

if 'numpy' in sys.modules:
warnings.warn("Numpy is already imported. Environment variables set in "
"repurpose.utils wont have any effect!")

# Note: Must be set BEFORE the first numpy import!!
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['NUMEXPR_NUM_THREADS'] = '1'
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['MKL_DYNAMIC'] = 'FALSE'
os.environ['OPENBLAS_NUM_THREADS'] = '1'

import traceback
import numpy as np
from tqdm import tqdm
import logging
from multiprocessing import Pool
from datetime import datetime
import sys
from pathlib import Path
from typing import List
from typing import List, Any
from glob import glob
from joblib import Parallel, delayed, parallel_config
from logging.handlers import QueueHandler, QueueListener
from multiprocessing import Manager
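
These thread-cap variables only take effect if they are set before numpy first loads its BLAS/OpenMP backends, hence the warning above when numpy is already imported. A quick way to verify the cap, assuming the optional threadpoolctl package is available (an assumption; it is not a dependency added by this PR):

```python
from threadpoolctl import threadpool_info
import numpy as np  # numpy is imported only after the variables are set

for pool in threadpool_info():
    # expect num_threads == 1 for every BLAS/OpenMP pool numpy loaded
    print(pool["user_api"], pool["num_threads"])
```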


class ImageBaseConnection:
@@ -102,8 +99,8 @@ def read(self, timestamp, **kwargs):


def rootdir() -> Path:
return Path(os.path.join(os.path.dirname(
os.path.abspath(__file__)))).parents[1]
p = str(os.path.join(os.path.dirname(os.path.abspath(__file__))))
return Path(p).parents[1]


def idx_chunks(idx, n=-1):
@@ -123,6 +120,68 @@ def idx_chunks(idx, n=-1):
for i in range(0, len(idx.values), n):
yield idx[i:i + n]

class ProgressParallel(Parallel):
def __init__(self, use_tqdm=True, total=None, desc="",
*args, **kwargs) -> None:
"""
Joblib parallel with progress bar
"""
self._use_tqdm = use_tqdm
self._total = total
self._desc = desc
super().__init__(*args, **kwargs)

def __call__(self, *args, **kwargs):
"""
Wraps progress bar around function calls
"""
with tqdm(
disable=not self._use_tqdm, total=self._total, desc=self._desc
) as self._pbar:
return Parallel.__call__(self, *args, **kwargs)

    def print_progress(self):
        """
        Update the progress bar after each successful call
        (joblib calls this internally as tasks complete).
        """
if self._total is None:
self._pbar.total = self.n_dispatched_tasks
self._pbar.n = self.n_completed_tasks
self._pbar.refresh()
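
A minimal usage sketch of ProgressParallel as defined above (the square helper is hypothetical; passing total up front keeps the bar from growing as tasks are dispatched):

```python
from joblib import delayed

def square(x):
    return x * x

# joblib invokes print_progress() as tasks complete, which drives the bar
results = ProgressParallel(use_tqdm=True, total=100, desc="squares", n_jobs=4)(
    delayed(square)(i) for i in range(100))
assert results == [i * i for i in range(100)]
```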

def configure_worker_logger(log_queue, log_level, name):
worker_logger = logging.getLogger(name)
if not worker_logger.hasHandlers():
h = QueueHandler(log_queue)
worker_logger.addHandler(h)
worker_logger.setLevel(log_level)
return worker_logger
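
configure_worker_logger is the worker-side half of the standard queue-based logging pattern: workers enqueue records, the parent drains them. A sketch of the parent side, using stdlib names only, mirroring what parallel_process_async sets up further down:

```python
import logging
from logging.handlers import QueueListener
from multiprocessing import Manager

q = Manager().Queue()
listener = QueueListener(q, logging.StreamHandler(),
                         respect_handler_level=True)
listener.start()
# ... spawn workers that call configure_worker_logger(q, "WARNING", "worker")
#     and log as usual; records are re-emitted via the parent's handlers ...
listener.stop()
```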

def run_with_error_handling(FUNC,
                            ignore_errors=False,
                            log_queue=None,
                            log_level="WARNING",
                            logger_name=None,
                            **kwargs) -> Any:
    """
    Call FUNC(**kwargs), logging any exception it raises and, if
    ignore_errors is set, swallowing it so other parallel jobs continue.
    """
    if log_queue is not None:
        # worker process: ship log records to the parent via the queue
        logger = configure_worker_logger(log_queue, log_level, logger_name)
    else:
        # normal logger (no queue, e.g. single-process mode)
        logger = logging.getLogger(logger_name)

r = None

try:
r = FUNC(**kwargs)
except Exception as e:
if ignore_errors:
logger.error(f"The following ERROR was raised in the parallelized "
f"function `{FUNC.__name__}` but was ignored due to "
f"the chosen settings: "
f"{traceback.format_exc()}")
else:
raise e
return r
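
The effect of ignore_errors in a small sketch (fails is a hypothetical function): the traceback is logged and None is returned instead of the exception propagating, and parallel_process_async later drops these None results:

```python
def fails(x):
    raise ValueError(f"bad input: {x}")

# logged, swallowed, and None returned
assert run_with_error_handling(fails, ignore_errors=True, x=1) is None
```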

def parallel_process_async(
FUNC,
@@ -135,9 +194,13 @@ def parallel_process_async(
log_path=None,
log_filename=None,
loglevel="WARNING",
logger_name=None,
verbose=False,
progress_bar_label="Processed"
) -> List:
progress_bar_label="Processed",
backend="loky",
sharedmem=False,
parallel_kwargs=None,
) -> list:
"""
Applies the passed function to all elements of the passed iterables.
Parallel function calls are processed ASYNCHRONOUSLY (i.e. order of
@@ -176,29 +239,41 @@
    Name of the logfile in `log_path` to create. If None is chosen, a name
    is created automatically. If `log_path` is None, this has no effect.
loglevel: str, optional (default: "WARNING")
Log level to use for logging. Must be one of
Which level should be logged. Must be one of
["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"].
verbose: float, optional (default: False)
logger_name: str, optional (default: None)
    Name of the logger that FUNC can look up to write log messages to.
    If not given, the root logger is used, e.g.
    ```
    logger = logging.getLogger(<logger_name>)
    logger.error("Some error message")
    ```
verbose: bool, optional (default: False)
Print all logging messages to stdout, useful for debugging.
progress_bar_label: str, optional (default: "Processed")
Label to use for the progress bar.
backend: str, optional (default: "loky")
    Backend for parallel execution, one of "threading",
    "multiprocessing", "loky" (used when n_proc != 1).
    See the joblib docs for details.
sharedmem: bool, optional (default: False)
    Pass require='sharedmem' to joblib, i.e. force a backend whose
    workers share memory with the parent process (can be slow).
parallel_kwargs: dict, optional (default: None)
    Additional keyword arguments passed on to joblib.Parallel.

Returns
-------
results: List
List of return values from each function call
results: list or None
    List of return values from each function call, or None if no call
    returned a value.
"""
if activate_logging:
logger = logging.getLogger()
logger = logging.getLogger(logger_name)
logger.setLevel(loglevel.upper())

if STATIC_KWARGS is None:
STATIC_KWARGS = dict()

if verbose:
# in this case we also print ALL log messages
streamHandler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
streamHandler.setFormatter(formatter)
logger.setLevel('DEBUG')
logger.addHandler(streamHandler)

@@ -211,14 +286,14 @@
log_file = None

if log_file:
# in this case the logger should write to file
os.makedirs(os.path.dirname(log_file), exist_ok=True)
logging.basicConfig(
filename=log_file,
level=loglevel.upper(),
format="%(levelname)s %(asctime)s %(message)s",
filehandler = logging.FileHandler(log_file)
filehandler.setFormatter(logging.Formatter(
"%(levelname)s %(asctime)s %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
force=True,
)
))
logger.addHandler(filehandler)
else:
logger = None

@@ -246,48 +321,57 @@
kws.update(STATIC_KWARGS)
process_kwargs.append(kws)

if show_progress_bars:
pbar = tqdm(total=len(process_kwargs), desc=progress_bar_label)
else:
pbar = None

results = []

def update(r) -> None:
if r is not None:
results.append(r)
if pbar is not None:
pbar.update()

def error(e) -> None:
if logger is not None:
logging.error(e)
if not ignore_errors:
raise e
if pbar is not None:
pbar.update()

if n_proc == 1:
results = []
if show_progress_bars:
pbar = tqdm(total=len(process_kwargs), desc=progress_bar_label)
else:
pbar = None

for kwargs in process_kwargs:
try:
r = FUNC(**kwargs)
update(r)
except Exception as e:
error(e)
r = run_with_error_handling(FUNC, ignore_errors,
logger_name=logger_name,
**kwargs)
if r is not None:
results.append(r)
if pbar is not None:
pbar.update()
else:
with Pool(n_proc) as pool:
for kwds in process_kwargs:
pool.apply_async(
FUNC,
kwds=kwds,
callback=update,
error_callback=error,
)
pool.close()
pool.join()

if pbar is not None:
pbar.close()
if logger is not None:
log_level = logger.getEffectiveLevel()
m = Manager()
q = m.Queue()
listener = QueueListener(q, *logger.handlers,
respect_handler_level=True)
listener.start()
else:
q = None
log_level = None
listener = None

        # loky: additionally cap BLAS/OpenMP threads inside each worker
        n = 1 if backend == 'loky' else None
with parallel_config(backend=backend, inner_max_num_threads=n):
results: list = ProgressParallel(
use_tqdm=show_progress_bars,
n_jobs=n_proc,
verbose=0,
total=len(process_kwargs),
desc=progress_bar_label,
require='sharedmem' if sharedmem else None,
return_as="list",
**parallel_kwargs or dict(),
)(delayed(run_with_error_handling)(
FUNC, ignore_errors,
log_queue=q,
log_level=log_level,
logger_name=logger_name,
**kwargs)
for kwargs in process_kwargs)

results = [r for r in results if r is not None]

if listener is not None:
listener.stop()

if logger is not None:
if verbose:
@@ -299,4 +383,7 @@ def error(e) -> None:
handler.close()
handlers.clear()

return results
if len(results) == 0:
return None
else:
return results
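
End to end, the call pattern matches the updated test below; a sketch with a hypothetical power function, passing the iterated and static kwargs positionally as the test does:

```python
def power(x, p):
    return x ** p

res = parallel_process_async(
    power,
    {"x": [1, 2, 3, 4]},  # iterated: one call per element
    {"p": 2},             # static: shared by every call
    n_proc=2, backend="loky", show_progress_bars=False, ignore_errors=True)
assert res == [1, 4, 9, 16]
```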
14 changes: 10 additions & 4 deletions tests/test_process.py
@@ -3,6 +3,7 @@
import time
import tempfile
import logging
import pytest

from repurpose.process import parallel_process_async, idx_chunks

@@ -24,12 +25,17 @@ def func(x: int, p: int):
logging.info(f'x={x}, p={p}')
return x**p

def test_apply_to_elements():

@pytest.mark.parametrize("n_proc,backend", [
("1", None), # backend doesn't matter in this case
("2", "threading"), ("2", "multiprocessing"), ("2", "loky")
])
def test_apply_to_elements(n_proc, backend):
iter_kwargs = {'x': [1, 2, 3, 4]}
static_kwargs = {'p': 2}
with tempfile.TemporaryDirectory() as log_path:
res = parallel_process_async(
func, iter_kwargs, static_kwargs, n_proc=1,
func, iter_kwargs, static_kwargs, n_proc=int(n_proc),
show_progress_bars=False, verbose=False, loglevel="DEBUG",
ignore_errors=True, log_path=log_path)
assert sorted(res) == [1, 4, 9, 16]
ignore_errors=True, log_path=log_path, backend=backend)
assert res == [1, 4, 9, 16]