Skip to content

Commit

Permalink
correct array_to_device for non-contiguous case
Browse files Browse the repository at this point in the history
  • Loading branch information
KrisThielemans committed Jul 7, 2024
1 parent 165da28 commit 7ac03fd
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/include/stir/cuda_utilities.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
\author Kris Thielemans
*/
#include "stir/Array.h"
#include "stir/info.h"
#include <vector>

START_NAMESPACE_STIR
Expand All @@ -28,15 +29,17 @@ array_to_device(elemT* dev_data, const Array<num_dimensions, elemT>& stir_array)
{
if (stir_array.is_contiguous())
{
info("array_to_device contiguous", 100);
cudaMemcpy(dev_data, stir_array.get_const_full_data_ptr(), stir_array.size_all() * sizeof(elemT), cudaMemcpyHostToDevice);
stir_array.release_const_full_data_ptr();
}
else
{
info("array_to_device non-contiguous", 100);
// Allocate host memory to get contiguous vector, copy array to it and copy from device to host
std::vector<elemT> tmp_data(stir_array.size_all());
std::copy(stir_array.begin_all(), stir_array.end_all(), tmp_data.begin());
cudaMemcpy(tmp_data.data(), dev_data, stir_array.size_all() * sizeof(elemT), cudaMemcpyHostToDevice);
cudaMemcpy(dev_data, tmp_data.data(), stir_array.size_all() * sizeof(elemT), cudaMemcpyHostToDevice);
}
}

Expand All @@ -46,11 +49,13 @@ array_to_host(Array<num_dimensions, elemT>& stir_array, const elemT* dev_data)
{
if (stir_array.is_contiguous())
{
info("array_to_host contiguous", 100);
cudaMemcpy(stir_array.get_full_data_ptr(), dev_data, stir_array.size_all() * sizeof(elemT), cudaMemcpyDeviceToHost);
stir_array.release_full_data_ptr();
}
else
{
info("array_to_host non-contiguous", 100);
// Allocate host memory for the result and copy from device to host
std::vector<elemT> tmp_data(stir_array.size_all());
cudaMemcpy(tmp_data.data(), dev_data, stir_array.size_all() * sizeof(elemT), cudaMemcpyDeviceToHost);
Expand Down

0 comments on commit 7ac03fd

Please sign in to comment.