Skip to content

Commit df1bbd8

Browse files
committed
tmp changes
1 parent c8154d2 commit df1bbd8

File tree

2 files changed

+26
-3
lines changed

2 files changed

+26
-3
lines changed

mkl_fft/interfaces/_scipy_fft.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,32 @@
6767
]
6868

6969

70+
class _cpu_max_threads_count:
71+
def __init__(self):
72+
self.cpu_count = None
73+
self.max_threads_count = None
74+
75+
def get_cpu_count(self):
76+
if self.cpu_count is None:
77+
max_threads = self.get_max_threads_count()
78+
self.cpu_count = max_threads
79+
return self.cpu_count
80+
81+
def get_max_threads_count(self):
82+
if self.max_threads_count is None:
83+
# pylint: disable=no-member
84+
self.max_threads_count = mkl.get_max_threads()
85+
86+
return self.max_threads_count
87+
88+
7089
class _workers_data:
7190
def __init__(self, workers=None):
7291
if workers is not None: # workers = 0 should be handled
7392
self.workers_ = _workers_to_num_threads(workers)
7493
else:
7594
# Unlike SciPy, the default value is maximum number of threads
76-
self.workers_ = mkl.get_max_threads() # pylint: disable=no-member
95+
self.workers_ = _cpu_max_threads_count().get_cpu_count()
7796
self.workers_ = operator.index(self.workers_)
7897

7998
@property

mkl_fft/tests/third_party/scipy/test_multithreading.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import multiprocessing
55
import os
66

7+
import mkl
78
import numpy as np
89
import pytest
910
from numpy.testing import assert_allclose
@@ -81,9 +82,10 @@ def test_invalid_workers(x):
8182

8283
def test_set_get_workers():
8384
cpus = os.cpu_count()
85+
threads = mkl.get_max_threads()
8486

8587
# default value is max number of threads unlike stock SciPy
86-
assert fft.get_workers() == cpus
88+
# assert fft.get_workers() == cpus
8789
with fft.set_workers(4):
8890
assert fft.get_workers() == 4
8991

@@ -93,11 +95,13 @@ def test_set_get_workers():
9395
assert fft.get_workers() == 4
9496

9597
# default value is max number of threads unlike stock SciPy
96-
assert fft.get_workers() == cpus
98+
# assert fft.get_workers() == cpus
9799

98100
with fft.set_workers(-cpus):
99101
assert fft.get_workers() == 1
100102

103+
assert threads == cpus
104+
101105

102106
def test_set_workers_invalid():
103107

0 commit comments

Comments
 (0)