Skip to content

Commit 384fcb6

Browse files
authored
fix issues with set_workers (#183)
* fix issues with set_workers * address comment
1 parent d3f7e2d commit 384fcb6

File tree

4 files changed

+64
-50
lines changed

4 files changed

+64
-50
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2020
* Fixed an issue for calling `mkl_fft.interfaces.numpy.fftn` with an empty axes [gh-139](https://github.com/IntelPython/mkl_fft/pull/139)
2121
* Fixed an issue for calling `mkl_fft.interfaces.numpy.fftn` with a zero-size array [gh-139](https://github.com/IntelPython/mkl_fft/pull/139)
2222
* Fixed inconsistency of input and output arrays dtype for `irfft` function [gh-180](https://github.com/IntelPython/mkl_fft/pull/180)
23+
* Fixed issues with `set_workers` function in SciPy interface `mkl_fft.interfaces.scipy_fft` [gh-183](https://github.com/IntelPython/mkl_fft/pull/183)
2324

2425
## [1.3.14] (04/10/2025)
2526

mkl_fft/interfaces/README.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ This interface is a drop-in replacement for the [`scipy.fft`](https://scipy.gith
4242

4343
* Helper functions: `fftshift`, `ifftshift`, `fftfreq`, `rfftfreq`, `set_workers`, `get_workers`. All of these functions, except for `set_workers` and `get_workers`, serve as a fallback to the SciPy implementation and are included for completeness.
4444

45+
Note that when computing FFTs, the default value of the `workers` parameter is the maximum number of threads available, unlike the default behavior of SciPy, where only one thread is used.
46+
4547
The following example shows how to use this interface for calculating a 1D FFT.
4648

4749
```python
@@ -102,3 +104,24 @@ with scipy.fft.set_backend(mkl_backend, only=True):
102104
print(f"Time with OneMKL FFT backend installed: {t2:.1f} seconds")
103105
# Time with MKL FFT backend installed: 9.1 seconds
104106
```
107+
108+
In the following example, we use `set_workers` to control the number of threads when `mkl_fft` is being used as a backend for SciPy.
109+
110+
```python
111+
import numpy, mkl, scipy
112+
import mkl_fft.interfaces.scipy_fft as mkl_fft
113+
import scipy
114+
a = numpy.random.randn(128, 64) + 1j*numpy.random.randn(128, 64)
115+
scipy.fft.set_global_backend(mkl_fft) # set mkl_fft as global backend
116+
117+
mkl.verbose(1)
118+
# True
119+
mkl.get_max_threads()
120+
# 112
121+
y = scipy.signal.fftconvolve(a, a) # Note that Nthr:112
122+
# MKL_VERBOSE FFT(dcbo256x128,input_strides:{0,128,1},output_strides:{0,128,1},bScale:3.05176e-05,tLim:56,unaligned_input,unaligned_output,desc:0x563aefe86180) 165.02us CNR:OFF Dyn:1 FastMM:1 TID:0 NThr:112
123+
124+
with mkl_fft.set_workers(4):
125+
y = scipy.signal.fftconvolve(a, a) # Note that Nthr:4
126+
# MKL_VERBOSE FFT(dcbo256x128,input_strides:{0,128,1},output_strides:{0,128,1},bScale:3.05176e-05,tLim:4,unaligned_output,desc:0x563aefe86180) 187.37us CNR:OFF Dyn:1 FastMM:1 TID:0 NThr:4
127+
```

mkl_fft/interfaces/_scipy_fft.py

Lines changed: 20 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
import contextlib
3333
import contextvars
3434
import operator
35-
import os
3635
from numbers import Number
3736

3837
import mkl
@@ -67,31 +66,13 @@
6766
]
6867

6968

70-
class _cpu_max_threads_count:
71-
def __init__(self):
72-
self.cpu_count = None
73-
self.max_threads_count = None
74-
75-
def get_cpu_count(self):
76-
if self.cpu_count is None:
77-
max_threads = self.get_max_threads_count()
78-
self.cpu_count = max_threads
79-
return self.cpu_count
80-
81-
def get_max_threads_count(self):
82-
if self.max_threads_count is None:
83-
# pylint: disable=no-member
84-
self.max_threads_count = mkl.get_max_threads()
85-
86-
return self.max_threads_count
87-
88-
8969
class _workers_data:
9070
def __init__(self, workers=None):
91-
if workers:
92-
self.workers_ = workers
71+
if workers is not None: # workers = 0 should be handled
72+
self.workers_ = _workers_to_num_threads(workers)
9373
else:
94-
self.workers_ = _cpu_max_threads_count().get_cpu_count()
74+
# Unlike SciPy, the default value is maximum number of threads
75+
self.workers_ = mkl.get_max_threads() # pylint: disable=no-member
9576
self.workers_ = operator.index(self.workers_)
9677

9778
@property
@@ -109,21 +90,23 @@ def workers(self, workers_val):
10990

11091

11192
def _workers_to_num_threads(w):
112-
"""Handle conversion of workers to a positive number of threads in the
113-
same way as scipy.fft.helpers._workers.
93+
"""
94+
Handle conversion of workers to a positive number of threads in the
95+
same way as scipy.fft._pocketfft.helpers._workers.
11496
"""
11597
if w is None:
11698
return _workers_global_settings.get().workers
11799
_w = operator.index(w)
118100
if _w == 0:
119101
raise ValueError("Number of workers must not be zero")
120102
if _w < 0:
121-
ub = os.cpu_count()
122-
_w += ub + 1
103+
# SciPy uses os.cpu_count()
104+
_cpu_count = mkl.get_max_threads() # pylint: disable=no-member
105+
_w += _cpu_count + 1
123106
if _w <= 0:
124107
raise ValueError(
125-
"workers value out of range; got {}, must not be"
126-
" less than {}".format(w, -ub)
108+
f"workers value out of range; got {w}, must not be less "
109+
f"than {-_cpu_count}"
127110
)
128111
return _w
129112

@@ -135,14 +118,16 @@ def __init__(self, workers):
135118

136119
def __enter__(self):
137120
try:
121+
# mkl.set_num_threads_local sets the number of threads to the
122+
# given input number, and returns the previous number of threads
138123
# pylint: disable=no-member
139124
self.prev_num_threads = mkl.set_num_threads_local(self.n_threads)
140125
except Exception as e:
141126
raise ValueError(
142-
"Class argument {} result in invalid number of threads {}".format(
143-
self.workers, self.n_threads
144-
)
127+
f"Class argument {self.workers} results in invalid number of "
128+
f"threads {self.n_threads}"
145129
) from e
130+
return self
146131

147132
def __exit__(self, *args):
148133
# restore old value
@@ -684,21 +669,19 @@ def get_workers():
684669

685670

686671
@contextlib.contextmanager
687-
def set_workers(n_workers):
672+
def set_workers(workers):
688673
"""
689674
Set the value of workers used by default, returns the previous value.
690675
691676
For full documentation refer to `scipy.fft.set_workers`.
692677
693678
"""
694-
nw = operator.index(n_workers)
679+
nw = operator.index(workers)
695680
token = None
696681
try:
697682
new_wd = _workers_data(nw)
698683
token = _workers_global_settings.set(new_wd)
699684
yield
700685
finally:
701-
if token:
686+
if token is not None:
702687
_workers_global_settings.reset(token)
703-
else:
704-
raise ValueError

mkl_fft/tests/third_party/scipy/test_multithreading.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
# https://github.com/scipy/scipy/blob/main/scipy/fft/tests/test_multithreading.py
33

44
import multiprocessing
5-
import os
65

6+
import mkl
77
import numpy as np
88
import pytest
99
from numpy.testing import assert_allclose
@@ -52,7 +52,7 @@ def _mt_fft(x):
5252
return fft.fft(x, workers=2)
5353

5454

55-
@pytest.mark.slow
55+
# @pytest.mark.slow
5656
def test_mixed_threads_processes(x):
5757
# Test that the fft threadpool is safe to use before & after fork
5858

@@ -68,42 +68,49 @@ def test_mixed_threads_processes(x):
6868

6969

7070
def test_invalid_workers(x):
71-
cpus = os.cpu_count()
71+
# cpus = os.cpu_count()
72+
threads = mkl.get_max_threads() # pylint: disable=no-member
73+
# cpus and threads are usually the same but in CI, cpus = 4 and threads = 2
74+
# SciPy uses `os.cpu_count()` to get the number of workers, while
75+
# `mkl_fft.interfaces.scipy_fft` uses `mkl.get_max_threads()`
7276

73-
fft.ifft([1], workers=-cpus)
77+
fft.ifft([1], workers=-threads)
7478

7579
with pytest.raises(ValueError, match="workers must not be zero"):
7680
fft.fft(x, workers=0)
7781

7882
with pytest.raises(ValueError, match="workers value out of range"):
79-
fft.ifft(x, workers=-cpus - 1)
83+
fft.ifft(x, workers=-threads - 1)
8084

8185

82-
@pytest.mark.skip()
8386
def test_set_get_workers():
84-
cpus = os.cpu_count()
85-
assert fft.get_workers() == 1
87+
# cpus = os.cpu_count()
88+
threads = mkl.get_max_threads() # pylint: disable=no-member
89+
90+
# default value is max number of threads unlike stock SciPy
91+
assert fft.get_workers() == threads
8692
with fft.set_workers(4):
8793
assert fft.get_workers() == 4
8894

8995
with fft.set_workers(-1):
90-
assert fft.get_workers() == cpus
96+
assert fft.get_workers() == threads
9197

9298
assert fft.get_workers() == 4
9399

94-
assert fft.get_workers() == 1
100+
# default value is max number of threads unlike stock SciPy
101+
assert fft.get_workers() == threads
95102

96-
with fft.set_workers(-cpus):
103+
with fft.set_workers(-threads):
97104
assert fft.get_workers() == 1
98105

99106

100-
@pytest.mark.skip("mkl_fft does not validate workers")
101107
def test_set_workers_invalid():
102108

103109
with pytest.raises(ValueError, match="workers must not be zero"):
104110
with fft.set_workers(0):
105111
pass
106112

107113
with pytest.raises(ValueError, match="workers value out of range"):
108-
with fft.set_workers(-os.cpu_count() - 1):
114+
# pylint: disable=no-member
115+
with fft.set_workers(-mkl.get_max_threads() - 1):
109116
pass

0 commit comments

Comments
 (0)