Skip to content

Commit

Permalink
Assert data is device_accessible in halo_exchange on_device
Browse files Browse the repository at this point in the history
  • Loading branch information
wdeconinck committed Aug 22, 2024
1 parent 068532f commit cd856d3
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 4 deletions.
12 changes: 10 additions & 2 deletions src/atlas/parallel/HaloExchange.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,14 @@ void HaloExchange::execute(array::Array& field, bool on_device) const {
DATA_TYPE* inner_buffer = allocate_buffer<DATA_TYPE>(inner_size, on_device);
DATA_TYPE* halo_buffer = allocate_buffer<DATA_TYPE>(halo_size, on_device);

#if ATLAS_HAVE_GPU
if (on_device) {
ATLAS_ASSERT( is_device_accessible(inner_buffer) );
ATLAS_ASSERT( is_device_accessible(halo_buffer) );
ATLAS_ASSERT( is_device_accessible(field_dv.data()) );
}
#endif

counts_displs_setup<DATA_TYPE>(var_size, inner_counts_init, halo_counts_init, inner_counts, halo_counts,
inner_displs, halo_displs);

Expand Down Expand Up @@ -294,7 +302,7 @@ void HaloExchange::ireceive(int tag, std::vector<int>& recv_displs, std::vector<
for (size_t jproc = 0; jproc < static_cast<size_t>(nproc); ++jproc) {
if (recv_counts[jproc] > 0) {
recv_req[jproc] =
comm().iReceive(&recv_buffer[recv_displs[jproc]], recv_counts[jproc], jproc, tag);
comm().iReceive(recv_buffer+recv_displs[jproc], recv_counts[jproc], jproc, tag);
}
}
}
Expand All @@ -309,7 +317,7 @@ void HaloExchange::isend_and_wait_for_receive(int tag, std::vector<int>& recv_co
ATLAS_TRACE_MPI(ISEND) {
for (size_t jproc = 0; jproc < static_cast<size_t>(nproc); ++jproc) {
if (send_counts[jproc] > 0) {
send_req[jproc] = comm().iSend(&send_buffer[send_displs[jproc]], send_counts[jproc], jproc, tag);
send_req[jproc] = comm().iSend(send_buffer+send_displs[jproc], send_counts[jproc], jproc, tag);
}
}
}
Expand Down
3 changes: 3 additions & 0 deletions src/atlas/parallel/HaloExchangeGPU.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,8 @@ struct halo_packer_hic {
array::ArrayView<DATA_TYPE, RANK>& dfield);
};

bool is_device_accessible(const void* data);


} // namespace parallel
} // namespace atlas
15 changes: 13 additions & 2 deletions src/atlas/parallel/HaloExchangeGPU.hic
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,16 @@
namespace atlas {
namespace parallel {

bool is_device_accessible(const void* ptr) {
hicPointerAttributes attr;
hicError_t code = hicPointerGetAttributes(&attr, ptr);
if( code != hicSuccess ) {
static_cast<void>(hicGetLastError());
return false;
}
return ptr == attr.devicePointer;
}

template<int ParallelDim, int RANK>
struct get_buffer_index{
template<typename DATA_TYPE>
Expand Down Expand Up @@ -147,10 +157,11 @@ void halo_packer_hic<ParallelDim, DATA_TYPE, RANK>::pack( const int sendcnt, arr
const unsigned int block_size_x = 32;
const unsigned int block_size_y = (RANK==1) ? 1 : 4;

unsigned int nblocks_y = get_n_hic_blocks<ParallelDim, RANK>::apply(hfield, block_size_y);
const unsigned int nblocks_x = (sendcnt+block_size_x-1)/block_size_x;
const unsigned int nblocks_y = get_n_hic_blocks<ParallelDim, RANK>::apply(hfield, block_size_y);

dim3 threads(block_size_x, block_size_y);
dim3 blocks((sendcnt+block_size_x-1)/block_size_x, nblocks_y);
dim3 blocks(nblocks_x, nblocks_y);
hicDeviceSynchronize();
hicError_t err = hicGetLastError();
if (err != hicSuccess) {
Expand Down

0 comments on commit cd856d3

Please sign in to comment.