Skip to content

fix issues with set_workers #183

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* Fixed an issue for calling `mkl_fft.interfaces.numpy.fftn` with an empty axes [gh-139](https://github.com/IntelPython/mkl_fft/pull/139)
* Fixed an issue for calling `mkl_fft.interfaces.numpy.fftn` with a zero-size array [gh-139](https://github.com/IntelPython/mkl_fft/pull/139)
* Fixed inconsistency of input and output arrays dtype for `irfft` function [gh-180](https://github.com/IntelPython/mkl_fft/pull/180)
* Fixed issues with `set_workers` function in SciPy interface `mkl_fft.interfaces.scipy_fft` [gh-183](https://github.com/IntelPython/mkl_fft/pull/183)

## [1.3.14] (04/10/2025)

Expand Down
23 changes: 23 additions & 0 deletions mkl_fft/interfaces/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ This interface is a drop-in replacement for the [`scipy.fft`](https://scipy.gith

* Helper functions: `fftshift`, `ifftshift`, `fftfreq`, `rfftfreq`, `set_workers`, `get_workers`. All of these functions, except for `set_workers` and `get_workers`, serve as a fallback to the SciPy implementation and are included for completeness.

Note that in computing FFTs, the default value of `workers` parameter is the maximum number of threads available unlike the default behavior of SciPy where only one thread is used.

The following example shows how to use this interface for calculating a 1D FFT.

```python
Expand Down Expand Up @@ -102,3 +104,24 @@ with scipy.fft.set_backend(mkl_backend, only=True):
print(f"Time with OneMKL FFT backend installed: {t2:.1f} seconds")
# Time with MKL FFT backend installed: 9.1 seconds
```

In the following example, we use `set_worker` to control the number of threads when `mkl_fft` is being used as a backend for SciPy.

```python
import numpy, mkl, scipy
import mkl_fft.interfaces.scipy_fft as mkl_fft
import scipy
a = numpy.random.randn(128, 64) + 1j*numpy.random.randn(128, 64)
scipy.fft.set_global_backend(mkl_fft) # set mkl_fft as global backend

mkl.verbose(1)
# True
mkl.get_max_threads()
# 112
y = scipy.signal.fftconvolve(a, a) # Note that Nthr:112
# MKL_VERBOSE FFT(dcbo256x128,input_strides:{0,128,1},output_strides:{0,128,1},bScale:3.05176e-05,tLim:56,unaligned_input,unaligned_output,desc:0x563aefe86180) 165.02us CNR:OFF Dyn:1 FastMM:1 TID:0 NThr:112

with mkl_fft.set_workers(4):
y = scipy.signal.fftconvolve(a, a) # Note that Nthr:4
# MKL_VERBOSE FFT(dcbo256x128,input_strides:{0,128,1},output_strides:{0,128,1},bScale:3.05176e-05,tLim:4,unaligned_output,desc:0x563aefe86180) 187.37us CNR:OFF Dyn:1 FastMM:1 TID:0 NThr:4
```
57 changes: 20 additions & 37 deletions mkl_fft/interfaces/_scipy_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import contextlib
import contextvars
import operator
import os
from numbers import Number

import mkl
Expand Down Expand Up @@ -67,31 +66,13 @@
]


class _cpu_max_threads_count:
def __init__(self):
self.cpu_count = None
self.max_threads_count = None

def get_cpu_count(self):
if self.cpu_count is None:
max_threads = self.get_max_threads_count()
self.cpu_count = max_threads
return self.cpu_count

def get_max_threads_count(self):
if self.max_threads_count is None:
# pylint: disable=no-member
self.max_threads_count = mkl.get_max_threads()

return self.max_threads_count


class _workers_data:
def __init__(self, workers=None):
if workers:
self.workers_ = workers
if workers is not None: # workers = 0 should be handled
self.workers_ = _workers_to_num_threads(workers)
else:
self.workers_ = _cpu_max_threads_count().get_cpu_count()
# Unlike SciPy, the default value is maximum number of threads
self.workers_ = mkl.get_max_threads() # pylint: disable=no-member
self.workers_ = operator.index(self.workers_)

@property
Expand All @@ -109,21 +90,23 @@ def workers(self, workers_val):


def _workers_to_num_threads(w):
"""Handle conversion of workers to a positive number of threads in the
same way as scipy.fft.helpers._workers.
"""
Handle conversion of workers to a positive number of threads in the
same way as scipy.fft._pocketfft.helpers._workers.
"""
if w is None:
return _workers_global_settings.get().workers
_w = operator.index(w)
if _w == 0:
raise ValueError("Number of workers must not be zero")
if _w < 0:
ub = os.cpu_count()
_w += ub + 1
# SciPy uses os.cpu_count()
_cpu_count = mkl.get_max_threads() # pylint: disable=no-member
_w += _cpu_count + 1
if _w <= 0:
raise ValueError(
"workers value out of range; got {}, must not be"
" less than {}".format(w, -ub)
f"workers value out of range; got {w}, must not be less "
f"than {-_cpu_count}"
)
return _w

Expand All @@ -135,14 +118,16 @@ def __init__(self, workers):

def __enter__(self):
try:
# mkl.set_num_threads_local sets the number of threads to the
# given input number, and returns the previous number of threads
# pylint: disable=no-member
self.prev_num_threads = mkl.set_num_threads_local(self.n_threads)
except Exception as e:
raise ValueError(
"Class argument {} result in invalid number of threads {}".format(
self.workers, self.n_threads
)
f"Class argument {self.workers} results in invalid number of "
f"threads {self.n_threads}"
) from e
return self

def __exit__(self, *args):
# restore old value
Expand Down Expand Up @@ -684,21 +669,19 @@ def get_workers():


@contextlib.contextmanager
def set_workers(n_workers):
def set_workers(workers):
"""
Set the value of workers used by default, returns the previous value.

For full documentation refer to `scipy.fft.set_workers`.

"""
nw = operator.index(n_workers)
nw = operator.index(workers)
token = None
try:
new_wd = _workers_data(nw)
token = _workers_global_settings.set(new_wd)
yield
finally:
if token:
if token is not None:
_workers_global_settings.reset(token)
else:
raise ValueError
29 changes: 18 additions & 11 deletions mkl_fft/tests/third_party/scipy/test_multithreading.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import multiprocessing
import os

import mkl
import numpy as np
import pytest
from numpy.testing import assert_allclose
Expand Down Expand Up @@ -52,7 +53,7 @@ def _mt_fft(x):
return fft.fft(x, workers=2)


@pytest.mark.slow
# @pytest.mark.slow
def test_mixed_threads_processes(x):
# Test that the fft threadpool is safe to use before & after fork

Expand All @@ -68,36 +69,42 @@ def test_mixed_threads_processes(x):


def test_invalid_workers(x):
cpus = os.cpu_count()
# cpus = os.cpu_count()
threads = mkl.get_max_threads() # pylint: disable=no-member
# cpus and threads are usually the same but in CI, cpus = 4 and threads = 2
# SciPy uses `os.cpu_count()` to get the number of workers, while
# `mkl_fft.interfaces.scipy_fft` uses `mkl.get_max_threads()`

fft.ifft([1], workers=-cpus)
fft.ifft([1], workers=-threads)

with pytest.raises(ValueError, match="workers must not be zero"):
fft.fft(x, workers=0)

with pytest.raises(ValueError, match="workers value out of range"):
fft.ifft(x, workers=-cpus - 1)
fft.ifft(x, workers=-threads - 1)


@pytest.mark.skip()
def test_set_get_workers():
cpus = os.cpu_count()
assert fft.get_workers() == 1
# cpus = os.cpu_count()
threads = mkl.get_max_threads() # pylint: disable=no-member

# default value is max number of threads unlike stock SciPy
assert fft.get_workers() == threads
with fft.set_workers(4):
assert fft.get_workers() == 4

with fft.set_workers(-1):
assert fft.get_workers() == cpus
assert fft.get_workers() == threads

assert fft.get_workers() == 4

assert fft.get_workers() == 1
# default value is max number of threads unlike stock SciPy
assert fft.get_workers() == threads

with fft.set_workers(-cpus):
with fft.set_workers(-threads):
assert fft.get_workers() == 1


@pytest.mark.skip("mkl_fft does not validate workers")
def test_set_workers_invalid():

with pytest.raises(ValueError, match="workers must not be zero"):
Expand Down
Loading