Skip to content

Commit

Permalink
[batch] add launch bounds and fix register check
Browse files — browse the repository at this point in the history
  • Loading branch information
MarcelKoch committed Oct 2, 2024
1 parent 85b80df commit 988743f
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 28 deletions.
17 changes: 11 additions & 6 deletions common/cuda_hip/solver/batch_bicgstab_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ namespace GKO_DEVICE_NAMESPACE {
namespace batch_single_kernels {


// Upper bound on threads per block for the batched BiCGSTAB apply kernel.
// It is the argument to __launch_bounds__ on apply_kernel and is also used
// by the host-side launch heuristic as the cap on the chosen block size,
// keeping the compile-time bound and the runtime choice consistent.
constexpr int max_bicgstab_threads = 1 << 10;  // == 1024


template <typename Group, typename BatchMatrixType_entry, typename ValueType>
__device__ __forceinline__ void initialize(
Group subgroup, const int num_rows, const BatchMatrixType_entry& mat_entry,
Expand Down Expand Up @@ -170,12 +173,14 @@ __device__ __forceinline__ void update_x_middle(
template <typename StopType, int n_shared, bool prec_shared_bool,
typename PrecType, typename LogType, typename BatchMatrixType,
typename ValueType>
__global__ void apply_kernel(
const gko::kernels::batch_bicgstab::storage_config sconf,
const int max_iter, const gko::remove_complex<ValueType> tol,
LogType logger, PrecType prec_shared, const BatchMatrixType mat,
const ValueType* const __restrict__ b, ValueType* const __restrict__ x,
ValueType* const __restrict__ workspace = nullptr)
__global__ void __launch_bounds__(max_bicgstab_threads)
apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf,
const int max_iter, const gko::remove_complex<ValueType> tol,
LogType logger, PrecType prec_shared,
const BatchMatrixType mat,
const ValueType* const __restrict__ b,
ValueType* const __restrict__ x,
ValueType* const __restrict__ workspace = nullptr)
{
using real_type = typename gko::remove_complex<ValueType>;
const auto num_batch_items = mat.num_batch_items;
Expand Down
19 changes: 11 additions & 8 deletions common/cuda_hip/solver/batch_cg_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ namespace GKO_DEVICE_NAMESPACE {
namespace batch_single_kernels {


// Upper bound on threads per block for the batched CG apply kernel.
// It is the argument to __launch_bounds__ on apply_kernel and is also used
// by the host-side launch heuristic as the cap on the chosen block size,
// keeping the compile-time bound and the runtime choice consistent.
constexpr int max_cg_threads = 1 << 10;  // == 1024


template <typename Group, typename BatchMatrixType_entry, typename PrecType,
typename ValueType>
__device__ __forceinline__ void initialize(
Expand Down Expand Up @@ -115,14 +118,14 @@ __device__ __forceinline__ void update_x_and_r(
template <typename StopType, const int n_shared, const bool prec_shared_bool,
typename PrecType, typename LogType, typename BatchMatrixType,
typename ValueType>
__global__ void apply_kernel(const gko::kernels::batch_cg::storage_config sconf,
const int max_iter,
const gko::remove_complex<ValueType> tol,
LogType logger, PrecType prec_shared,
const BatchMatrixType mat,
const ValueType* const __restrict__ b,
ValueType* const __restrict__ x,
ValueType* const __restrict__ workspace = nullptr)
__global__ void __launch_bounds__(max_cg_threads)
apply_kernel(const gko::kernels::batch_cg::storage_config sconf,
const int max_iter, const gko::remove_complex<ValueType> tol,
LogType logger, PrecType prec_shared,
const BatchMatrixType mat,
const ValueType* const __restrict__ b,
ValueType* const __restrict__ x,
ValueType* const __restrict__ workspace = nullptr)
{
using real_type = typename gko::remove_complex<ValueType>;
const auto num_batch_items = mat.num_batch_items;
Expand Down
27 changes: 19 additions & 8 deletions cuda/solver/batch_bicgstab_launch.instantiate.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,29 @@ int get_num_threads_per_block(std::shared_ptr<const DefaultExecutor> exec,
constexpr int warp_sz = static_cast<int>(config::warp_size);
const int min_block_size = 2 * warp_sz;
const int device_max_threads =
((std::max(num_rows, min_block_size)) / warp_sz) * warp_sz;
cudaFuncAttributes funcattr;
cudaFuncGetAttributes(
&funcattr,
batch_single_kernels::apply_kernel<StopType, 9, true, PrecType, LogType,
BatchMatrixType, ValueType>);
const int num_regs_used = funcattr.numRegs;
(std::max(num_rows, min_block_size) / warp_sz) * warp_sz;
auto get_num_regs = [](const auto func) {
cudaFuncAttributes funcattr;
cudaFuncGetAttributes(&funcattr, func);
return funcattr.numRegs;
};
const int num_regs_used = std::max(
get_num_regs(
batch_single_kernels::apply_kernel<StopType, 9, true, PrecType,
LogType, BatchMatrixType,
ValueType>),
get_num_regs(
batch_single_kernels::apply_kernel<StopType, 0, false, PrecType,
LogType, BatchMatrixType,
ValueType>));
int max_regs_blk = 0;
cudaDeviceGetAttribute(&max_regs_blk, cudaDevAttrMaxRegistersPerBlock,
exec->get_device_id());
const int max_threads_regs =
((max_regs_blk / static_cast<int>(num_regs_used)) / warp_sz) * warp_sz;
int max_threads = std::min(max_threads_regs, device_max_threads);
max_threads = max_threads <= 1024 ? max_threads : 1024;
max_threads = max_threads <= max_bicgstab_threads ? max_threads
: max_bicgstab_threads;
return std::max(std::min(num_warps * warp_sz, max_threads), min_block_size);
}

Expand Down Expand Up @@ -78,6 +87,8 @@ void launch_apply_kernel(
ValueType* const __restrict__ workspace_data, const int& block_size,
const size_t& shared_size)
{
std::cout << n_shared << " " << prec_shared << " " << shared_size
<< std::endl;
batch_single_kernels::apply_kernel<StopType, n_shared, prec_shared>
<<<mat.num_batch_items, block_size, shared_size, exec->get_stream()>>>(
sconf, settings.max_iterations, as_cuda_type(settings.residual_tol),
Expand Down
21 changes: 15 additions & 6 deletions cuda/solver/batch_cg_launch.instantiate.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,28 @@ int get_num_threads_per_block(std::shared_ptr<const DefaultExecutor> exec,
const int min_block_size = 2 * warp_sz;
const int device_max_threads =
(std::max(num_rows, min_block_size) / warp_sz) * warp_sz;
cudaFuncAttributes funcattr;
cudaFuncGetAttributes(
&funcattr,
batch_single_kernels::apply_kernel<StopType, 5, true, PrecType, LogType,
BatchMatrixType, ValueType>);
auto get_num_regs = [](const auto func) {
cudaFuncAttributes funcattr;
cudaFuncGetAttributes(&funcattr, func);
return funcattr.numRegs;
};
const int num_regs_used = std::max(
get_num_regs(
batch_single_kernels::apply_kernel<StopType, 5, true, PrecType,
LogType, BatchMatrixType,
ValueType>),
get_num_regs(
batch_single_kernels::apply_kernel<StopType, 0, false, PrecType,
LogType, BatchMatrixType,
ValueType>));
const int num_regs_used = funcattr.numRegs;
int max_regs_blk = 0;
cudaDeviceGetAttribute(&max_regs_blk, cudaDevAttrMaxRegistersPerBlock,
exec->get_device_id());
const int max_threads_regs =
((max_regs_blk / static_cast<int>(num_regs_used)) / warp_sz) * warp_sz;
int max_threads = std::min(max_threads_regs, device_max_threads);
max_threads = max_threads <= 1024 ? max_threads : 1024;
max_threads = max_threads <= max_cg_threads ? max_threads : max_cg_threads;
return std::max(std::min(num_warps * warp_sz, max_threads), min_block_size);
}

Expand Down

0 comments on commit 988743f

Please sign in to comment.