Skip to content

Commit

Permalink
Add wrapped cuda implementation of calc_doppler_width
Browse files Browse the repository at this point in the history
  • Loading branch information
smokestacklightnin committed Aug 10, 2023
1 parent af60053 commit a7d15bd
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 1 deletion.
31 changes: 31 additions & 0 deletions stardis/opacities/broadening.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,37 @@ def _calc_doppler_width_cuda(res, nu_line, temperature, atomic_mass):
res[tid] = _calc_doppler_width(nu_line[tid], temperature[tid], atomic_mass[tid])


def calc_doppler_width_cuda(
    nu_line,
    temperature,
    atomic_mass,
    nthreads=256,
    ret_np_ndarray=True,
    dtype=float,
):
    """
    Host-side wrapper that launches the ``_calc_doppler_width_cuda`` kernel.

    The three inputs are copied into CuPy arrays of ``dtype``, an output
    array shaped like the shortest input is allocated, and the kernel is
    launched with ``nthreads`` threads per block.

    Parameters
    ----------
    nu_line : array-like
        Forwarded to the kernel as its ``nu_line`` argument.
    temperature : array-like
        Forwarded to the kernel as its ``temperature`` argument.
    atomic_mass : array-like
        Forwarded to the kernel as its ``atomic_mass`` argument.
    nthreads : int, optional
        Number of threads per block used for the kernel launch.
    ret_np_ndarray : bool, optional
        If True (default) the result is converted with ``cp.asnumpy``
        before being returned; otherwise the CuPy array is returned.
    dtype : data-type, optional
        dtype used for the device copies of the inputs and for the result.

    Returns
    -------
    numpy.ndarray or cupy.ndarray
        The array filled in by the kernel, with the shape of the shortest
        input.
    """
    arg_list = (
        nu_line,
        temperature,
        atomic_mass,
    )

    # The lengths must be materialised before calling np.argmin: a bare
    # ``map`` object is not iterated by NumPy, so the first argument was
    # always selected regardless of the actual lengths.
    arg_lengths = [len(arg) for arg in arg_list]
    shortest_arg_idx = int(np.argmin(arg_lengths))
    size = arg_lengths[shortest_arg_idx]

    # NOTE(review): this always launches at least one thread more than
    # ``size``; assumes the kernel ignores thread indices past the end of
    # ``res`` -- confirm against the kernel body.
    nblocks = 1 + (size // nthreads)

    arg_list = tuple(cp.array(arg, dtype=dtype) for arg in arg_list)

    res = cp.empty_like(arg_list[shortest_arg_idx], dtype=dtype)

    _calc_doppler_width_cuda[nblocks, nthreads](
        res,
        *arg_list,
    )

    return cp.asnumpy(res) if ret_np_ndarray else res


@numba.njit
def calc_n_effective(ion_number, ionization_energy, level_energy):
"""
Expand Down
37 changes: 36 additions & 1 deletion stardis/opacities/tests/test_broadening.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
from astropy import constants as const
from numba import cuda

from stardis.opacities.broadening import calc_doppler_width, _calc_doppler_width_cuda
from stardis.opacities.broadening import (
calc_doppler_width,
_calc_doppler_width_cuda,
calc_doppler_width_cuda,
)

GPUs_available = cuda.is_available()

Expand Down Expand Up @@ -93,3 +97,34 @@ def test_calc_doppler_width_cuda_unwrapped_sample_values(
cp.asnumpy(result_values),
calc_doppler_width_cuda_unwrapped_sample_values_expected_result,
)


@pytest.mark.skipif(
    not GPUs_available, reason="No GPU is available to test CUDA function"
)
@pytest.mark.parametrize(
    "calc_doppler_width_cuda_sample_values_input_nu_line, calc_doppler_width_cuda_sample_values_input_temperature, calc_doppler_width_cuda_sample_values_input_atomic_mass, calc_doppler_width_cuda_wrapped_sample_cuda_values_expected_result",
    [
        (
            np.array([SPEED_OF_LIGHT, SPEED_OF_LIGHT]),
            np.array([0.5, 0.5]),
            np.array([BOLTZMANN_CONSTANT, BOLTZMANN_CONSTANT]),
            np.array([1.0, 1.0]),
        ),
    ],
)
def test_calc_doppler_width_cuda_wrapped_sample_cuda_values(
    calc_doppler_width_cuda_sample_values_input_nu_line,
    calc_doppler_width_cuda_sample_values_input_temperature,
    calc_doppler_width_cuda_sample_values_input_atomic_mass,
    calc_doppler_width_cuda_wrapped_sample_cuda_values_expected_result,
):
    """
    Check ``calc_doppler_width_cuda`` against known values when it is fed
    CuPy arrays as inputs.
    """
    # Move each host-side sample array onto the device before the call.
    device_inputs = [
        cp.asarray(host_values)
        for host_values in (
            calc_doppler_width_cuda_sample_values_input_nu_line,
            calc_doppler_width_cuda_sample_values_input_temperature,
            calc_doppler_width_cuda_sample_values_input_atomic_mass,
        )
    ]

    computed = calc_doppler_width_cuda(*device_inputs)

    assert np.allclose(
        computed,
        calc_doppler_width_cuda_wrapped_sample_cuda_values_expected_result,
    )

0 comments on commit a7d15bd

Please sign in to comment.