Skip to content

Commit f5d3eb0

Browse files
committed
tmp changes
1 parent c8154d2 commit f5d3eb0

File tree

2 files changed

+12
-7
lines changed

2 files changed

+12
-7
lines changed

mkl_fft/interfaces/_scipy_fft.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
import contextlib
3333
import contextvars
3434
import operator
35-
import os
3635
from numbers import Number
3736

3837
import mkl
@@ -101,7 +100,8 @@ def _workers_to_num_threads(w):
101100
if _w == 0:
102101
raise ValueError("Number of workers must not be zero")
103102
if _w < 0:
104-
_cpu_count = os.cpu_count()
103+
# SciPy uses os.cpu_count()
104+
_cpu_count = mkl.get_max_threads() # pylint: disable=no-member
105105
_w += _cpu_count + 1
106106
if _w <= 0:
107107
raise ValueError(

mkl_fft/tests/third_party/scipy/test_multithreading.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import multiprocessing
55
import os
66

7+
import mkl
78
import numpy as np
89
import pytest
910
from numpy.testing import assert_allclose
@@ -80,22 +81,26 @@ def test_invalid_workers(x):
8081

8182

8283
def test_set_get_workers():
83-
cpus = os.cpu_count()
84+
# cpus = os.cpu_count()
85+
threads = mkl.get_max_threads() # pylint: disable=no-member
86+
# cpus and threads are usually the same but in CI, cpus = 4 and threads = 2
87+
# SciPy uses `os.cpu_count()` to get the number of workers, while
88+
# `mkl_fft.interfaces.scipy_fft` uses `mkl.get_max_threads()`
8489

8590
# default value is max number of threads unlike stock SciPy
86-
assert fft.get_workers() == cpus
91+
assert fft.get_workers() == threads
8792
with fft.set_workers(4):
8893
assert fft.get_workers() == 4
8994

9095
with fft.set_workers(-1):
91-
assert fft.get_workers() == cpus
96+
assert fft.get_workers() == threads
9297

9398
assert fft.get_workers() == 4
9499

95100
# default value is max number of threads unlike stock SciPy
96-
assert fft.get_workers() == cpus
101+
assert fft.get_workers() == threads
97102

98-
with fft.set_workers(-cpus):
103+
with fft.set_workers(-threads):
99104
assert fft.get_workers() == 1
100105

101106

0 commit comments

Comments
 (0)