[batch] split cg compilation (hip)
MarcelKoch committed Jun 24, 2024
1 parent ace85cd commit 0fc6b86
Showing 5 changed files with 317 additions and 176 deletions.
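
This commit applies to the batch CG solver the same compilation split already used for batch BiCGStab: the heavyweight explicit kernel instantiations move into a dedicated solver/batch_cg_kernels_launch.instantiate.hip.cpp, which add_instantiation_files expands into several independently compiled sources. A minimal sketch of the general technique, using hypothetical file and function names rather than Ginkgo's actual sources:

    // solver_kernels.inc.hpp -- generic definition, included by each
    // .instantiate source (a stand-in for the real kernel body)
    template <typename ValueType>
    void apply_solver(const ValueType* b, ValueType* x, int n)
    {
        for (int i = 0; i < n; ++i) {
            x[i] = b[i];
        }
    }

    // solver_kernels_float.instantiate.cpp -- one translation unit per value
    // type, so the expensive instantiations compile in parallel
    template void apply_solver<float>(const float*, float*, int);

    // solver_kernels_double.instantiate.cpp
    template void apply_solver<double>(const double*, double*, int);

The CMake half of the pattern is visible in the hunk below: add_instantiation_files generates the per-instantiation sources and returns their list in BATCH_CG_INSTANTIATE, which is then appended to GINKGO_HIP_SOURCES.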
2 changes: 2 additions & 0 deletions hip/CMakeLists.txt
@@ -3,6 +3,7 @@ include(${PROJECT_SOURCE_DIR}/cmake/template_instantiation.cmake)
add_instantiation_files(. matrix/csr_kernels.instantiate.hip.cpp CSR_INSTANTIATE)
add_instantiation_files(. matrix/fbcsr_kernels.instantiate.hip.cpp FBCSR_INSTANTIATE)
add_instantiation_files(. solver/batch_bicgstab_kernels_launch.instantiate.hip.cpp BATCH_BICGSTAB_INSTANTIATE)
add_instantiation_files(. solver/batch_cg_kernels_launch.instantiate.hip.cpp BATCH_CG_INSTANTIATE)
# we don't split up the dense kernels into distinct compilations
list(APPEND GKO_UNIFIED_COMMON_SOURCES ${PROJECT_SOURCE_DIR}/common/unified/matrix/dense_kernels.instantiate.cpp)
set(GINKGO_HIP_SOURCES
@@ -60,6 +61,7 @@ set(GINKGO_HIP_SOURCES
solver/batch_bicgstab_kernels.hip.cpp
${BATCH_BICGSTAB_INSTANTIATE}
solver/batch_cg_kernels.hip.cpp
${BATCH_CG_INSTANTIATE}
solver/cb_gmres_kernels.hip.cpp
solver/idr_kernels.hip.cpp
solver/lower_trs_kernels.hip.cpp
4 changes: 2 additions & 2 deletions hip/solver/batch_bicgstab_kernels.hip.hpp
@@ -168,8 +168,8 @@ class kernel_caller {

hip_value_type* const workspace_data = workspace.get_data();

// Template parameters launch_apply_kernel<StopType, n_shared,
// prec_shared)
// Template parameters launch_apply_kernel<ValueType, n_shared,
// prec_shared, StopType>
if (sconf.prec_shared) {
launch_apply_kernel<ValueType, 9, true, StopType>(
exec_, sconf, settings_, logger, prec, mat, b.values, x.values,
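
The comment fix above aligns the documentation with the actual template-parameter order of the launch helper, which takes the value type first and the stop criterion last. A sketch of the shape the corrected comment describes (simplified; the real signature also carries the preconditioner, logger, and matrix types and lives in the surrounding header):

    template <typename ValueType, const int n_shared, const bool prec_shared,
              typename StopType>
    void launch_apply_kernel(/* executor, storage config, settings, logger,
                                preconditioner, matrix, b, x, workspace,
                                block size, shared memory size */);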
176 changes: 2 additions & 174 deletions hip/solver/batch_cg_kernels.hip.cpp
@@ -6,37 +6,21 @@


#include <hip/hip_runtime.h>
#include <thrust/functional.h>
#include <thrust/transform.h>


#include <ginkgo/core/base/exception_helpers.hpp>
#include <ginkgo/core/base/math.hpp>


#include "core/base/batch_struct.hpp"
#include "core/matrix/batch_struct.hpp"
#include "core/solver/batch_dispatch.hpp"
#include "hip/base/batch_struct.hip.hpp"
#include "hip/base/config.hip.hpp"
#include "hip/base/math.hip.hpp"
#include "hip/base/thrust.hip.hpp"
#include "hip/base/types.hip.hpp"
#include "hip/components/cooperative_groups.hip.hpp"
#include "hip/components/reduction.hip.hpp"
#include "hip/components/thread_ids.hip.hpp"
#include "hip/components/uninitialized_array.hip.hpp"
#include "hip/matrix/batch_struct.hip.hpp"
#include "hip/solver/batch_cg_kernels.hip.hpp"


namespace gko {
namespace kernels {
namespace hip {


constexpr int default_block_size = 256;
constexpr int sm_oversubscription = 4;

/**
* @brief The batch Cg solver namespace.
*
@@ -45,161 +29,6 @@ constexpr int sm_oversubscription = 4;
namespace batch_cg {


#include "common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc"
#include "common/cuda_hip/components/uninitialized_array.hpp.inc"
#include "common/cuda_hip/matrix/batch_csr_kernels.hpp.inc"
#include "common/cuda_hip/matrix/batch_dense_kernels.hpp.inc"
#include "common/cuda_hip/matrix/batch_ell_kernels.hpp.inc"
#include "common/cuda_hip/solver/batch_cg_kernels.hpp.inc"


template <typename BatchMatrixType>
int get_num_threads_per_block(std::shared_ptr<const DefaultExecutor> exec,
const int num_rows)
{
int num_warps = std::max(num_rows / 4, 2);
constexpr int warp_sz = static_cast<int>(config::warp_size);
const int min_block_size = 2 * warp_sz;
const int device_max_threads =
((std::max(num_rows, min_block_size)) / warp_sz) * warp_sz;
// This value has been taken from ROCm docs. This is the number of registers
// that maximizes the occupancy on an AMD GPU (MI200). HIP does not have an
// API to query the number of registers a function uses.
const int num_regs_used_per_thread = 64;
int max_regs_blk = 0;
GKO_ASSERT_NO_HIP_ERRORS(hipDeviceGetAttribute(
&max_regs_blk, hipDeviceAttributeMaxRegistersPerBlock,
exec->get_device_id()));
int max_threads_regs = (max_regs_blk / num_regs_used_per_thread);
max_threads_regs = (max_threads_regs / warp_sz) * warp_sz;
int max_threads = std::min(max_threads_regs, device_max_threads);
max_threads = max_threads <= 1024 ? max_threads : 1024;
return std::max(std::min(num_warps * warp_sz, max_threads), min_block_size);
}
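// Worked example of the heuristic above (illustrative numbers, not taken
// from a specific device): num_rows = 512, warp_size = 64 (an AMD
// wavefront), max_regs_blk = 65536:
//   num_warps          = max(512 / 4, 2)              = 128
//   min_block_size     = 2 * 64                       = 128
//   device_max_threads = (max(512, 128) / 64) * 64    = 512
//   max_threads_regs   = ((65536 / 64) / 64) * 64     = 1024
//   max_threads        = min(1024, 512)               = 512  (already <= 1024)
//   returned           = max(min(128 * 64, 512), 128) = 512 threads per block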


template <typename T>
using settings = gko::kernels::batch_cg::settings<T>;


template <typename HipValueType>
class kernel_caller {
public:
using value_type = HipValueType;

kernel_caller(std::shared_ptr<const DefaultExecutor> exec,
const settings<remove_complex<value_type>> settings)
: exec_{exec}, settings_{settings}
{}

template <typename StopType, const int n_shared,
const bool prec_shared_bool, typename PrecType, typename LogType,
typename BatchMatrixType>
void launch_apply_kernel(
const gko::kernels::batch_cg::storage_config& sconf, LogType& logger,
PrecType& prec, const BatchMatrixType& mat,
const value_type* const __restrict__ b_values,
value_type* const __restrict__ x_values,
value_type* const __restrict__ workspace_data, const int& block_size,
const size_t& shared_size) const
{
apply_kernel<StopType, n_shared, prec_shared_bool>
<<<mat.num_batch_items, block_size, shared_size,
exec_->get_stream()>>>(sconf, settings_.max_iterations,
settings_.residual_tol, logger, prec, mat,
b_values, x_values, workspace_data);
}


template <typename BatchMatrixType, typename PrecType, typename StopType,
typename LogType>
void call_kernel(
LogType logger, const BatchMatrixType& mat, PrecType prec,
const gko::batch::multi_vector::uniform_batch<const value_type>& b,
const gko::batch::multi_vector::uniform_batch<value_type>& x) const
{
using real_type = gko::remove_complex<value_type>;
const size_type num_batch_items = mat.num_batch_items;
constexpr int align_multiple = 8;
const int padded_num_rows =
ceildiv(mat.num_rows, align_multiple) * align_multiple;
int shmem_per_blk = 0;
GKO_ASSERT_NO_HIP_ERRORS(hipDeviceGetAttribute(
&shmem_per_blk, hipDeviceAttributeMaxSharedMemoryPerBlock,
exec_->get_device_id()));
const int block_size =
get_num_threads_per_block<BatchMatrixType>(exec_, mat.num_rows);
GKO_ASSERT(block_size >= 2 * config::warp_size);
GKO_ASSERT(block_size % config::warp_size == 0);

// Returns amount required in bytes
const size_t prec_size = PrecType::dynamic_work_size(
padded_num_rows, mat.get_single_item_num_nnz());
const auto sconf =
gko::kernels::batch_cg::compute_shared_storage<PrecType,
value_type>(
shmem_per_blk, padded_num_rows, mat.get_single_item_num_nnz(),
b.num_rhs);
const size_t shared_size =
sconf.n_shared * padded_num_rows * sizeof(value_type) +
(sconf.prec_shared ? prec_size : 0);
auto workspace = gko::array<value_type>(
exec_,
sconf.gmem_stride_bytes * num_batch_items / sizeof(value_type));
GKO_ASSERT(sconf.gmem_stride_bytes % sizeof(value_type) == 0);

value_type* const workspace_data = workspace.get_data();

// Template parameters launch_apply_kernel<StopType, n_shared,
// prec_shared>
if (sconf.prec_shared) {
launch_apply_kernel<StopType, 5, true>(
sconf, logger, prec, mat, b.values, x.values, workspace_data,
block_size, shared_size);
} else {
switch (sconf.n_shared) {
case 0:
launch_apply_kernel<StopType, 0, false>(
sconf, logger, prec, mat, b.values, x.values,
workspace_data, block_size, shared_size);
break;
case 1:
launch_apply_kernel<StopType, 1, false>(
sconf, logger, prec, mat, b.values, x.values,
workspace_data, block_size, shared_size);
break;
case 2:
launch_apply_kernel<StopType, 2, false>(
sconf, logger, prec, mat, b.values, x.values,
workspace_data, block_size, shared_size);
break;
case 3:
launch_apply_kernel<StopType, 3, false>(
sconf, logger, prec, mat, b.values, x.values,
workspace_data, block_size, shared_size);
break;
case 4:
launch_apply_kernel<StopType, 4, false>(
sconf, logger, prec, mat, b.values, x.values,
workspace_data, block_size, shared_size);
break;
case 5:
launch_apply_kernel<StopType, 5, false>(
sconf, logger, prec, mat, b.values, x.values,
workspace_data, block_size, shared_size);
break;
default:
GKO_NOT_IMPLEMENTED;
}
}
}
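// Worked example of the sizing above (illustrative): num_rows = 500 is
// padded to 504 (the next multiple of align_multiple = 8); with double
// values, sconf.n_shared = 3, and the preconditioner left in global memory:
//   shared_size = 3 * 504 * sizeof(double) + 0 = 12096 bytes per block
// Vectors that do not fit in shared memory live in the workspace, which
// reserves gmem_stride_bytes for every batch item.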

private:
std::shared_ptr<const DefaultExecutor> exec_;
const settings<remove_complex<value_type>> settings_;
};


template <typename ValueType>
void apply(std::shared_ptr<const DefaultExecutor> exec,
const settings<remove_complex<ValueType>>& settings,
Expand All @@ -209,9 +38,8 @@ void apply(std::shared_ptr<const DefaultExecutor> exec,
batch::MultiVector<ValueType>* const x,
batch::log::detail::log_data<remove_complex<ValueType>>& logdata)
{
using hip_value_type = hip_type<ValueType>;
auto dispatcher = batch::solver::create_dispatcher<ValueType>(
kernel_caller<hip_value_type>(exec, settings), settings, mat, precon);
kernel_caller<ValueType>(exec, settings), settings, mat, precon);
dispatcher.apply(b, x, logdata);
}

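Two things are worth noting about the removal above. First, apply no longer converts to hip_value_type before constructing the kernel caller; the dispatcher now receives ValueType directly, and the device-type mapping presumably happens inside the new split launch files. Second, the large switch over sconf.n_shared that moved out of this file is the standard idiom for turning a run-time configuration value into a compile-time template argument: n_shared sizes per-block shared memory inside the kernel, so each supported value has to be enumerated and instantiated separately. A self-contained sketch of that idiom in plain C++ (hypothetical names, no HIP required):

    #include <array>
    #include <cstdio>

    // n_shared must be a compile-time constant, e.g. because it sizes a buffer
    template <int n_shared>
    void apply_kernel_demo()
    {
        std::array<double, (n_shared == 0 ? 1 : n_shared)> buffer{};
        std::printf("instantiated with n_shared = %d (buffer holds %zu)\n",
                    n_shared, buffer.size());
    }

    // maps the run-time value onto the enumerated compile-time instantiations
    void dispatch_demo(int n_shared_runtime)
    {
        switch (n_shared_runtime) {
        case 0: apply_kernel_demo<0>(); break;
        case 1: apply_kernel_demo<1>(); break;
        case 2: apply_kernel_demo<2>(); break;
        default: std::printf("unsupported configuration\n"); break;
        }
    }

    int main() { dispatch_demo(2); }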
(Diffs for the remaining two changed files are not shown.)
