Skip to content

Commit

Permalink
nvector: token pasting macro for CUDA/HIP overlap;
Browse files Browse the repository at this point in the history
preprocessor improvements
  • Loading branch information
jsdomine committed Jul 6, 2023
1 parent e0b9e73 commit 921873b
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 29 deletions.
44 changes: 24 additions & 20 deletions include/nvector/nvector_parhyp.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,31 +93,35 @@

#if defined(SUNDIALS_HYPRE_BACKENDS_CUDA)

#define NV_SUNMemoryHelper_TYPE_PH SUNMemoryHelper_Cuda
#define NV_SUNExecPolicy_TYPE_PH SUNCudaExecPolicy
#define NV_Stream_TYPE_PH cudaStream_t

#define NV_LANG_STRING_PH "CUDA"
#define NV_lang_TOKEN_PH cuda

#define NV_DeviceSynchronize_CALL_PH cudaDeviceSynchronize
#define NV_GetLastError_CALL_PH cudaGetLastError
#define NV_MemsetAsync_CALL_PH cudaMemsetAsync
#define NV_Verify_CALL_PH SUNDIALS_CUDA_VERIFY
#define NV_GPU_LANG_TOKEN_PH cuda
#define NV_GPU_LANG_STRING_PH "CUDA"
#define NV_ADD_LANG_PREFIX_PH(token) cuda##token // token pasting; expands to ```cuda[token]```
/* Example usage: NV_ADD_LANG_PREFIX_PH(MemsetAsync)(...) -> cudaMemsetAsync(...) */

#define NV_EXECPOLICY_TYPE_PH SUNCudaExecPolicy
#define NV_MEMHELP_STRUCT_PH SUNMemoryHelper_Cuda
#define NV_VERIFY_CALL_PH SUNDIALS_CUDA_VERIFY

// #define NV_lang_TOKEN_PH cuda
// #define NV_Stream_TYPE_PH cudaStream_t
// #define NV_DeviceSynchronize_CALL_PH cudaDeviceSynchronize
// #define NV_GetLastError_CALL_PH cudaGetLastError
// #define NV_MemsetAsync_CALL_PH cudaMemsetAsync
// #define NV_Memcpy_CALL_PH cudaMemcpy
// #define NV_MemcpyDeviceToHost_TOKEN_PH cudaMemcpyDeviceToHost
// #define NV_MemcpyHostToDevice_TOKEN_PH cudaMemcpyHostToDevice

#elif defined(SUNDIALS_HYPRE_BACKENDS_HIP)

#define NV_SUNMemoryHelper_TYPE_PH SUNMemoryHelper_Hip
#define NV_SUNExecPolicy_TYPE_PH SUNHipExecPolicy
#define NV_Stream_TYPE_PH hipStream_t
#define NV_GPU_LANG_TOKEN_PH hip
#define NV_GPU_LANG_STRING_PH "HIP"
#define NV_ADD_LANG_PREFIX_PH(token) hip##token // token pasting; expands to ```hip[token]```
/* Example usage: NV_ADD_LANG_PREFIX_PH(MemsetAsync)(...) -> hipMemsetAsync(...) */

#define NV_LANG_STRING_PH "HIP"
#define NV_lang_TOKEN_PH hip
#define NV_EXECPOLICY_TYPE_PH SUNHipExecPolicy
#define NV_MEMHELP_STRUCT_PH SUNMemoryHelper_Hip
#define NV_VERIFY_CALL_PH SUNDIALS_HIP_VERIFY

#define NV_DeviceSynchronize_CALL_PH hipDeviceSynchronize
#define NV_GetLastError_CALL_PH hipGetLastError
#define NV_MemsetAsync_CALL_PH hipMemsetAsync
#define NV_Verify_CALL_PH SUNDIALS_HIP_VERIFY
#endif

/* --- Wrapper to enable C++ usage --- */
Expand Down
24 changes: 15 additions & 9 deletions src/nvector/parhyp/nvector_parhyp.c
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ N_Vector N_VMake_ParHyp(HYPRE_ParVector x, SUNContext sunctx)

/* Attach CUDA/HIP-only content */
#if defined(SUNDIALS_HYPRE_BACKENDS_CUDA_OR_HIP)
NV_MEMHELP_PH(v) = NV_SUNMemoryHelper_TYPE_PH(sunctx);
NV_MEMHELP_PH(v) = NV_MEMHELP_STRUCT_PH(sunctx);
NV_STREAM_POLICY_PH(v) = DEFAULT_STREAMING_EXECPOLICY.clone();
NV_REDUCE_POLICY_PH(v) = DEFAULT_REDUCTION_EXECPOLICY.clone();

Expand Down Expand Up @@ -487,6 +487,12 @@ void N_VPrintFile_ParHyp(N_Vector x, FILE *outfile)
N = NV_LOCLENGTH_PH(x);
xd = NV_DATA_PH(x);

// #if defined(SUNDIALS_HYPRE_BACKENDS_CUDA_OR_HIP)
// realtype *host_data = (realtype*)malloc(sizeof(realtype)*local_length);
// NV_ADD_LANG_PREFIX_PH(Memcpy)(host_data,Xdata,sizeof(realtype)*local_length,NV_ADD_LANG_PREFIX_PH(MemcpyDeviceToHost));
// Xdata = host_data;
// #endif

for (i = 0; i < N; i++) {
#if defined(SUNDIALS_EXTENDED_PRECISION)
fprintf(outfile, "%Lg\n", xd[i]);
Expand Down Expand Up @@ -2346,7 +2352,7 @@ static int InitializeDeviceCounter(N_Vector v)
&(NV_PRIVATE_PH(v)->device_counter), sizeof(unsigned int),
SUNMEMTYPE_DEVICE, (void*) NV_STREAM_PH(v));
}
NV_MemsetAsync_CALL_PH(NV_DCOUNTERp_PH(v), 0, sizeof(unsigned int), *NV_STREAM_PH(v));
NV_ADD_LANG_PREFIX_PH(MemsetAsync)(NV_DCOUNTERp_PH(v), 0, sizeof(unsigned int), *NV_STREAM_PH(v));
return retval;
}

Expand Down Expand Up @@ -2453,12 +2459,12 @@ static void FreeReductionBuffer(N_Vector v)

static int GetKernelParameters(N_Vector v, booleantype reduction, size_t& grid,
size_t& block, size_t& shMemSize,
NV_Stream_TYPE_PH& stream, bool& atomic, size_t n)
NV_ADD_LANG_PREFIX_PH(Stream_t)& stream, bool& atomic, size_t n)
{
n = (n == 0) ? NV_CONTENT_PH(v)->length : n;
if (reduction)
{
NV_SUNExecPolicy_TYPE_PH* reduce_exec_policy = NV_CONTENT_PH(v)->reduce_exec_policy;
NV_EXECPOLICY_TYPE_PH* reduce_exec_policy = NV_CONTENT_PH(v)->reduce_exec_policy;
grid = reduce_exec_policy->gridSize(n);
block = reduce_exec_policy->blockSize();
shMemSize = 0;
Expand All @@ -2476,7 +2482,7 @@ static int GetKernelParameters(N_Vector v, booleantype reduction, size_t& grid,
}
}

if (block % sundials::NV_lang_TOKEN_PH::WARP_SIZE)
if (block % sundials::NV_GPU_LANG_TOKEN_PH::WARP_SIZE)
{
#ifdef SUNDIALS_DEBUG
throw std::runtime_error("the block size must be a multiple must be of the "NV_LANG_STRING_PH" warp size");
Expand All @@ -2486,7 +2492,7 @@ static int GetKernelParameters(N_Vector v, booleantype reduction, size_t& grid,
}
else
{
NV_SUNExecPolicy_TYPE_PH* stream_exec_policy = NV_CONTENT_PH(v)->stream_exec_policy;
NV_EXECPOLICY_TYPE_PH* stream_exec_policy = NV_CONTENT_PH(v)->stream_exec_policy;
grid = stream_exec_policy->gridSize(n);
block = stream_exec_policy->blockSize();
shMemSize = 0;
Expand All @@ -2513,7 +2519,7 @@ static int GetKernelParameters(N_Vector v, booleantype reduction, size_t& grid,
}

static int GetKernelParameters(N_Vector v, booleantype reduction, size_t& grid,
size_t& block, size_t& shMemSize, NV_Stream_TYPE_PH& stream,
size_t& block, size_t& shMemSize, NV_ADD_LANG_PREFIX_PH(Stream_t)& stream,
size_t n)
{
bool atomic;
Expand All @@ -2524,8 +2530,8 @@ static void PostKernelLaunch()
{
// TODO: implement "SUNDIALS_DEBUG_PARHYP_LASTERROR"?
#if defined(SUNDIALS_DEBUG_CUDA_LASTERROR) || defined(SUNDIALS_DEBUG_HIP_LASTERROR)
NV_DeviceSynchronize_CALL_PH();
NV_Verify_CALL_PH(NV_GetLastError_CALL_PH());
NV_ADD_LANG_PREFIX_PH(DeviceSynchronize)();
NV_VERIFY_CALL_PH(NV_ADD_LANG_PREFIX_PH(GetLastError)());
#endif
}

Expand Down

0 comments on commit 921873b

Please sign in to comment.