diff --git a/pycbc/events/simd_threshold_ccode.cpp b/pycbc/events/simd_threshold_ccode.cpp index 95a05549640..015c023efc1 100644 --- a/pycbc/events/simd_threshold_ccode.cpp +++ b/pycbc/events/simd_threshold_ccode.cpp @@ -92,11 +92,14 @@ void _parallel_threshold(int64_t N, std::complex * __restrict arr, #pragma omp ordered { - t+=c; + t += c; + memmove( + outl + t - c, outl + start, sizeof(unsigned int)*c + ); + memmove( + outv + t - c, outv + start, sizeof(std::complex)*c + ); } - memmove(outl+t-c, outl+start, sizeof(unsigned int)*c); - memmove(outv+t-c, outv+start, sizeof(std::complex)*c); - } count[0] = t; diff --git a/test/test_threshold.py b/test/test_threshold.py index 387ed55fe71..e1f434c9757 100644 --- a/test/test_threshold.py +++ b/test/test_threshold.py @@ -14,47 +14,41 @@ # with this program; if not, write to the Free Software Foundation, Inc., # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. -# -# ============================================================================= -# -# Preamble -# -# ============================================================================= -# """ -These are the unittests for the pycbc.waveform module +Unit tests for PyCBC's thresholding code. """ + import unittest import numpy -from pycbc.types import * -from pycbc.scheme import * -from pycbc.events import * +from pycbc.types import Array, complex64 +from pycbc.events import threshold from utils import parse_args_all_schemes, simple_exit _scheme, _context = parse_args_all_schemes("Threshold") -from pycbc.events.threshold_cpu import threshold_numpy -trusted_threshold = threshold_numpy +from pycbc.events.threshold_cpu import threshold_numpy as trusted_threshold + class TestThreshold(unittest.TestCase): - def setUp(self,*args): + def setUp(self, *args): self.context = _context self.scheme = _scheme r = numpy.random.uniform(low=-1, high=1.0, size=2**20) i = numpy.random.uniform(low=-1, high=1.0, size=2**20) - v = r + i*1.0j + v = r + i * 1.0j self.series = Array(v, dtype=complex64) self.threshold = 1.3 self.locs, self.vals = trusted_threshold(self.series, self.threshold) - self.tolerance = 1e-6 - print(len(self.locs), len(self.vals)) + print(f'Reference: {len(self.locs)} locs, {len(self.vals)} vals') def test_threshold(self): with self.context: locs, vals = threshold(self.series, self.threshold) + print(f'Test: {len(locs)} locs, {len(vals)} vals') self.assertTrue((locs == self.locs).all()) self.assertTrue((vals == self.vals).all()) - print(len(locs), len(vals)) + + suite = unittest.TestSuite() suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestThreshold))