diff --git a/stardis/opacities/tests/test_voigt.py b/stardis/opacities/tests/test_voigt.py index 3f981b15..edc83824 100644 --- a/stardis/opacities/tests/test_voigt.py +++ b/stardis/opacities/tests/test_voigt.py @@ -2,6 +2,7 @@ import numpy as np from math import sqrt from numba import cuda +import cupy as cp from stardis.opacities.voigt import ( faddeeva, _faddeeva_cuda, @@ -45,15 +46,17 @@ def test_faddeeva_cuda_unwrapped_sample_values( faddeeva_cuda_unwrapped_sample_values_input, faddeeva_cuda_unwrapped_sample_values_expected_result, ): - test_values = cuda.to_device(faddeeva_cuda_unwrapped_sample_values_input) - result_values = cuda.device_array_like(test_values) + test_values = cp.asarray(faddeeva_cuda_unwrapped_sample_values_input) + result_values = cp.empty_like(test_values) + nthreads = 256 length = len(faddeeva_cuda_unwrapped_sample_values_input) + nblocks = 1 + (length // nthreads) - _faddeeva_cuda.forall(length)(result_values, test_values) + _faddeeva_cuda[nblocks, nthreads](result_values, test_values) assert np.allclose( - result_values.copy_to_host(), + cp.asnumpy(result_values), faddeeva_cuda_unwrapped_sample_values_expected_result, ) @@ -95,32 +98,11 @@ def test_faddeeva_cuda_wrapped_sample_cuda_values( faddeeva_cuda_wrapped_sample_cuda_values_expected_result, ): assert np.allclose( - faddeeva_cuda( - cuda.device_array_like(faddeeva_cuda_wrapped_sample_cuda_values_input) - ), + faddeeva_cuda(cp.asarray(faddeeva_cuda_wrapped_sample_cuda_values_input)), faddeeva_cuda_wrapped_sample_cuda_values_expected_result, ) -@pytest.mark.skipif( - not GPUs_available, reason="No GPU is available to test CUDA function" -) -@pytest.mark.parametrize( - "faddeeva_cuda_wrapped_noncomplex_input_input", - [ - np.array([0, 0], dtype=int), - np.array([0, 0], dtype=float), - ], -) -def test_faddeeva_cuda_wrapped_noncomplex_input( - faddeeva_cuda_wrapped_noncomplex_input_input, -): - with pytest.raises(TypeError): - _ = faddeeva_cuda( - cuda.device_array_like(faddeeva_cuda_wrapped_noncomplex_input_input) - ) - - test_voigt_profile_division_by_zero_test_values = [ -100, -5, diff --git a/stardis/opacities/voigt.py b/stardis/opacities/voigt.py index c9f848d5..76a86d98 100644 --- a/stardis/opacities/voigt.py +++ b/stardis/opacities/voigt.py @@ -1,6 +1,7 @@ import numpy as np import numba from numba import cuda +import cupy as cp import cmath SQRT_PI = np.sqrt(np.pi, dtype=float) @@ -91,18 +92,14 @@ def _faddeeva_cuda(res, z): res[tid] = _faddeeva(z[tid]) -def faddeeva_cuda(z): +def faddeeva_cuda(z, nthreads=256, ret_np_ndarray=True): size = len(z) - if hasattr(z, "astype"): - z = z.astype(complex) - if not np.iscomplexobj(z): - raise TypeError( - f"Faddeeva with cuda only works with complex arguments. Expected any complex datatyep and instead got {z.dtype}." - ) - res = cuda.device_array_like(z) + nblocks = 1 + (size // nthreads) + z = cp.asarray(z, dtype=complex) + res = cp.empty_like(z) - _faddeeva_cuda.forall(size)(res, z) - return res.copy_to_host() + _faddeeva_cuda[nblocks, nthreads](res, z) + return cp.asnumpy(res) if ret_np_ndarray else res @numba.njit