From 8fe0f1e000292b61ef81a7873f4b9231555a86c0 Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Fri, 25 Apr 2025 12:50:43 -0700 Subject: [PATCH] fix issues with set_workers --- CHANGELOG.md | 1 + mkl_fft/interfaces/README.md | 23 ++++++++ mkl_fft/interfaces/_scipy_fft.py | 57 +++++++------------ .../third_party/scipy/test_multithreading.py | 29 ++++++---- 4 files changed, 62 insertions(+), 48 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9e34420..af31a6e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Fixed an issue for calling `mkl_fft.interfaces.numpy.fftn` with an empty axes [gh-139](https://github.com/IntelPython/mkl_fft/pull/139) * Fixed an issue for calling `mkl_fft.interfaces.numpy.fftn` with a zero-size array [gh-139](https://github.com/IntelPython/mkl_fft/pull/139) * Fixed inconsistency of input and output arrays dtype for `irfft` function [gh-180](https://github.com/IntelPython/mkl_fft/pull/180) +* Fixed issues with `set_workers` function in SciPy interface `mkl_fft.interfaces.scipy_fft` [gh-183](https://github.com/IntelPython/mkl_fft/pull/183) ## [1.3.14] (04/10/2025) diff --git a/mkl_fft/interfaces/README.md b/mkl_fft/interfaces/README.md index d4cae26..0e8969f 100644 --- a/mkl_fft/interfaces/README.md +++ b/mkl_fft/interfaces/README.md @@ -42,6 +42,8 @@ This interface is a drop-in replacement for the [`scipy.fft`](https://scipy.gith * Helper functions: `fftshift`, `ifftshift`, `fftfreq`, `rfftfreq`, `set_workers`, `get_workers`. All of these functions, except for `set_workers` and `get_workers`, serve as a fallback to the SciPy implementation and are included for completeness. +Note that in computing FFTs, the default value of `workers` parameter is the maximum number of threads available unlike the default behavior of SciPy where only one thread is used. + The following example shows how to use this interface for calculating a 1D FFT. 
```python @@ -102,3 +104,24 @@ with scipy.fft.set_backend(mkl_backend, only=True): print(f"Time with OneMKL FFT backend installed: {t2:.1f} seconds") # Time with MKL FFT backend installed: 9.1 seconds ``` + +In the following example, we use `set_workers` to control the number of threads when `mkl_fft` is being used as a backend for SciPy. + +```python +import numpy, mkl, scipy +import mkl_fft.interfaces.scipy_fft as mkl_fft +import scipy +a = numpy.random.randn(128, 64) + 1j*numpy.random.randn(128, 64) +scipy.fft.set_global_backend(mkl_fft) # set mkl_fft as global backend + +mkl.verbose(1) +# True +mkl.get_max_threads() +# 112 +y = scipy.signal.fftconvolve(a, a) # Note that NThr:112 +# MKL_VERBOSE FFT(dcbo256x128,input_strides:{0,128,1},output_strides:{0,128,1},bScale:3.05176e-05,tLim:56,unaligned_input,unaligned_output,desc:0x563aefe86180) 165.02us CNR:OFF Dyn:1 FastMM:1 TID:0 NThr:112 + +with mkl_fft.set_workers(4): + y = scipy.signal.fftconvolve(a, a) # Note that NThr:4 +# MKL_VERBOSE FFT(dcbo256x128,input_strides:{0,128,1},output_strides:{0,128,1},bScale:3.05176e-05,tLim:4,unaligned_output,desc:0x563aefe86180) 187.37us CNR:OFF Dyn:1 FastMM:1 TID:0 NThr:4 +``` diff --git a/mkl_fft/interfaces/_scipy_fft.py b/mkl_fft/interfaces/_scipy_fft.py index fa8482e..78b7dc3 100644 --- a/mkl_fft/interfaces/_scipy_fft.py +++ b/mkl_fft/interfaces/_scipy_fft.py @@ -32,7 +32,6 @@ import contextlib import contextvars import operator -import os from numbers import Number import mkl @@ -67,31 +66,13 @@ ] -class _cpu_max_threads_count: - def __init__(self): - self.cpu_count = None - self.max_threads_count = None - - def get_cpu_count(self): - if self.cpu_count is None: - max_threads = self.get_max_threads_count() - self.cpu_count = max_threads - return self.cpu_count - - def get_max_threads_count(self): - if self.max_threads_count is None: - # pylint: disable=no-member - self.max_threads_count = mkl.get_max_threads() - - return self.max_threads_count - - class _workers_data: def 
__init__(self, workers=None): - if workers: - self.workers_ = workers + if workers is not None: # workers = 0 should be handled + self.workers_ = _workers_to_num_threads(workers) else: - self.workers_ = _cpu_max_threads_count().get_cpu_count() + # Unlike SciPy, the default value is maximum number of threads + self.workers_ = mkl.get_max_threads() # pylint: disable=no-member self.workers_ = operator.index(self.workers_) @property @@ -109,8 +90,9 @@ def workers(self, workers_val): def _workers_to_num_threads(w): - """Handle conversion of workers to a positive number of threads in the - same way as scipy.fft.helpers._workers. + """ + Handle conversion of workers to a positive number of threads in the + same way as scipy.fft._pocketfft.helpers._workers. """ if w is None: return _workers_global_settings.get().workers @@ -118,12 +100,13 @@ def _workers_to_num_threads(w): if _w == 0: raise ValueError("Number of workers must not be zero") if _w < 0: - ub = os.cpu_count() - _w += ub + 1 + # SciPy uses os.cpu_count() + _cpu_count = mkl.get_max_threads() # pylint: disable=no-member + _w += _cpu_count + 1 if _w <= 0: raise ValueError( - "workers value out of range; got {}, must not be" - " less than {}".format(w, -ub) + f"workers value out of range; got {w}, must not be less " + f"than {-_cpu_count}" ) return _w @@ -135,14 +118,16 @@ def __init__(self, workers): def __enter__(self): try: + # mkl.set_num_threads_local sets the number of threads to the + # given input number, and returns the previous number of threads # pylint: disable=no-member self.prev_num_threads = mkl.set_num_threads_local(self.n_threads) except Exception as e: raise ValueError( - "Class argument {} result in invalid number of threads {}".format( - self.workers, self.n_threads - ) + f"Class argument {self.workers} results in invalid number of " + f"threads {self.n_threads}" ) from e + return self def __exit__(self, *args): # restore old value @@ -684,21 +669,19 @@ def get_workers(): 
@contextlib.contextmanager -def set_workers(n_workers): +def set_workers(workers): """ Set the value of workers used by default, returns the previous value. For full documentation refer to `scipy.fft.set_workers`. """ - nw = operator.index(n_workers) + nw = operator.index(workers) token = None try: new_wd = _workers_data(nw) token = _workers_global_settings.set(new_wd) yield finally: - if token: + if token is not None: _workers_global_settings.reset(token) - else: - raise ValueError diff --git a/mkl_fft/tests/third_party/scipy/test_multithreading.py b/mkl_fft/tests/third_party/scipy/test_multithreading.py index 945a5e4..470f2ae 100644 --- a/mkl_fft/tests/third_party/scipy/test_multithreading.py +++ b/mkl_fft/tests/third_party/scipy/test_multithreading.py @@ -4,6 +4,7 @@ import multiprocessing import os +import mkl import numpy as np import pytest from numpy.testing import assert_allclose @@ -52,7 +53,7 @@ def _mt_fft(x): return fft.fft(x, workers=2) -@pytest.mark.slow +# @pytest.mark.slow def test_mixed_threads_processes(x): # Test that the fft threadpool is safe to use before & after fork @@ -68,36 +69,42 @@ def test_mixed_threads_processes(x): def test_invalid_workers(x): - cpus = os.cpu_count() + # cpus = os.cpu_count() + threads = mkl.get_max_threads() # pylint: disable=no-member + # cpus and threads are usually the same but in CI, cpus = 4 and threads = 2 + # SciPy uses `os.cpu_count()` to get the number of workers, while + # `mkl_fft.interfaces.scipy_fft` uses `mkl.get_max_threads()` - fft.ifft([1], workers=-cpus) + fft.ifft([1], workers=-threads) with pytest.raises(ValueError, match="workers must not be zero"): fft.fft(x, workers=0) with pytest.raises(ValueError, match="workers value out of range"): - fft.ifft(x, workers=-cpus - 1) + fft.ifft(x, workers=-threads - 1) -@pytest.mark.skip() def test_set_get_workers(): - cpus = os.cpu_count() - assert fft.get_workers() == 1 + # cpus = os.cpu_count() + threads = mkl.get_max_threads() # pylint: disable=no-member + 
+ # default value is max number of threads unlike stock SciPy + assert fft.get_workers() == threads with fft.set_workers(4): assert fft.get_workers() == 4 with fft.set_workers(-1): - assert fft.get_workers() == cpus + assert fft.get_workers() == threads assert fft.get_workers() == 4 - assert fft.get_workers() == 1 + # default value is max number of threads unlike stock SciPy + assert fft.get_workers() == threads - with fft.set_workers(-cpus): + with fft.set_workers(-threads): assert fft.get_workers() == 1 -@pytest.mark.skip("mkl_fft does not validate workers") def test_set_workers_invalid(): with pytest.raises(ValueError, match="workers must not be zero"):