Add cg solver cuda/hip kernels
Co-authored-by: Isha Aggarwal <[email protected]>
Co-authored-by: Aditya Kashi <[email protected]>
3 people committed May 10, 2024
1 parent 0a6484e commit 3864556
Showing 9 changed files with 556 additions and 19 deletions.
2 changes: 1 addition & 1 deletion common/cuda_hip/preconditioner/batch_scalar_jacobi.hpp.inc
@@ -17,7 +17,7 @@ public:
     __host__ __device__ static constexpr int dynamic_work_size(
         const int num_rows, int)
     {
-        return num_rows;
+        return num_rows * sizeof(value_type);
     }

     /**
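
With this, the preconditioner's dynamic_work_size reports its per-system workspace in bytes rather than in number of rows, which is why the call sites further down drop their "* sizeof(ValueType)" factor. A minimal standalone sketch of that convention (illustrative names, not Ginkgo's actual types):

#include <cstdio>

template <typename ValueType>
struct scalar_jacobi_like {
    // mirrors the convention above: per-system workspace in bytes, not rows
    static constexpr int dynamic_work_size(const int num_rows, int)
    {
        return num_rows * sizeof(ValueType);
    }
};

int main()
{
    // 1000 rows of double -> 8000 bytes of preconditioner workspace
    std::printf("%d bytes\n",
                scalar_jacobi_like<double>::dynamic_work_size(1000, 0));
    return 0;
}
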
228 changes: 228 additions & 0 deletions common/cuda_hip/solver/batch_cg_kernels.hpp.inc
@@ -0,0 +1,228 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

template <typename Group, typename BatchMatrixType_entry, typename PrecType,
          typename ValueType>
__device__ __forceinline__ void initialize(
    Group subgroup, const int num_rows, const BatchMatrixType_entry& mat_entry,
    const ValueType* const b_global_entry,
    const ValueType* const x_global_entry, ValueType* const x_shared_entry,
    ValueType* const r_shared_entry, const PrecType& prec_shared,
    ValueType* const z_shared_entry, ValueType& rho_old_shared_entry,
    ValueType* const p_shared_entry,
    typename gko::remove_complex<ValueType>& rhs_norms_sh)
{
    // copy x from global to shared memory
    // r = b
    for (int iz = threadIdx.x; iz < num_rows; iz += blockDim.x) {
        x_shared_entry[iz] = x_global_entry[iz];
        r_shared_entry[iz] = b_global_entry[iz];
    }
    __syncthreads();

    // r = b - A*x
    advanced_apply(static_cast<ValueType>(-1.0), mat_entry, x_shared_entry,
                   static_cast<ValueType>(1.0), r_shared_entry);
    __syncthreads();

    // z = precond * r
    prec_shared.apply(num_rows, r_shared_entry, z_shared_entry);
    __syncthreads();

    if (threadIdx.x / config::warp_size == 0) {
        // compute norms of rhs
        single_rhs_compute_norm2(subgroup, num_rows, b_global_entry,
                                 rhs_norms_sh);
    } else if (threadIdx.x / config::warp_size == 1) {
        // rho_old = r' * z
        single_rhs_compute_conj_dot(subgroup, num_rows, r_shared_entry,
                                    z_shared_entry, rho_old_shared_entry);
    }

    // p = z
    for (int iz = threadIdx.x; iz < num_rows; iz += blockDim.x) {
        p_shared_entry[iz] = z_shared_entry[iz];
    }
}


template <typename ValueType>
__device__ __forceinline__ void update_p(const int num_rows,
                                         const ValueType& rho_new_shared_entry,
                                         const ValueType& rho_old_shared_entry,
                                         const ValueType* const z_shared_entry,
                                         ValueType* const p_shared_entry)
{
    for (int li = threadIdx.x; li < num_rows; li += blockDim.x) {
        const ValueType beta = rho_new_shared_entry / rho_old_shared_entry;
        p_shared_entry[li] = z_shared_entry[li] + beta * p_shared_entry[li];
    }
}


template <typename Group, typename ValueType>
__device__ __forceinline__ void update_x_and_r(
    Group subgroup, const int num_rows, const ValueType& rho_old_shared_entry,
    const ValueType* const p_shared_entry,
    const ValueType* const Ap_shared_entry, ValueType& alpha_shared_entry,
    ValueType* const x_shared_entry, ValueType* const r_shared_entry)
{
    if (threadIdx.x / config::warp_size == 0) {
        single_rhs_compute_conj_dot(subgroup, num_rows, p_shared_entry,
                                    Ap_shared_entry, alpha_shared_entry);
    }
    __syncthreads();

    for (int li = threadIdx.x; li < num_rows; li += blockDim.x) {
        const ValueType alpha = rho_old_shared_entry / alpha_shared_entry;
        x_shared_entry[li] += alpha * p_shared_entry[li];
        r_shared_entry[li] -= alpha * Ap_shared_entry[li];
    }
}


template <typename StopType, const int n_shared, const bool prec_shared_bool,
          typename PrecType, typename LogType, typename BatchMatrixType,
          typename ValueType>
__global__ void apply_kernel(const gko::kernels::batch_cg::storage_config sconf,
                             const int max_iter,
                             const gko::remove_complex<ValueType> tol,
                             LogType logger, PrecType prec_shared,
                             const BatchMatrixType mat,
                             const ValueType* const __restrict__ b,
                             ValueType* const __restrict__ x,
                             ValueType* const __restrict__ workspace = nullptr)
{
    using real_type = typename gko::remove_complex<ValueType>;
    const auto num_batch_items = mat.num_batch_items;
    const auto num_rows = mat.num_rows;

    constexpr auto tile_size = config::warp_size;
    auto thread_block = group::this_thread_block();
    auto subgroup = group::tiled_partition<tile_size>(thread_block);

    for (size_type batch_id = blockIdx.x; batch_id < num_batch_items;
         batch_id += gridDim.x) {
        const int gmem_offset =
            batch_id * sconf.gmem_stride_bytes / sizeof(ValueType);
        extern __shared__ char local_mem_sh[];

        ValueType* r_sh;
        ValueType* z_sh;
        ValueType* p_sh;
        ValueType* Ap_sh;
        ValueType* x_sh;
        ValueType* prec_work_sh;

        if (n_shared >= 1) {
            r_sh = reinterpret_cast<ValueType*>(local_mem_sh);
        } else {
            r_sh = workspace + gmem_offset;
        }
        if (n_shared == 1) {
            z_sh = workspace + gmem_offset;
        } else {
            z_sh = r_sh + sconf.padded_vec_len;
        }
        if (n_shared == 2) {
            p_sh = workspace + gmem_offset;
        } else {
            p_sh = z_sh + sconf.padded_vec_len;
        }
        if (n_shared == 3) {
            Ap_sh = workspace + gmem_offset;
        } else {
            Ap_sh = p_sh + sconf.padded_vec_len;
        }
        if (n_shared == 4) {
            x_sh = workspace + gmem_offset;
        } else {
            x_sh = Ap_sh + sconf.padded_vec_len;
        }
        if (!prec_shared_bool && n_shared == 5) {
            prec_work_sh = workspace + gmem_offset;
        } else {
            prec_work_sh = x_sh + sconf.padded_vec_len;
        }

        __shared__ uninitialized_array<ValueType, 1> rho_old_sh;
        __shared__ uninitialized_array<ValueType, 1> rho_new_sh;
        __shared__ uninitialized_array<ValueType, 1> alpha_sh;
        __shared__ real_type norms_rhs_sh[1];
        __shared__ real_type norms_res_sh[1];

        const auto mat_entry =
            gko::batch::matrix::extract_batch_item(mat, batch_id);
        const ValueType* const b_global_entry =
            gko::batch::multi_vector::batch_item_ptr(b, 1, num_rows, batch_id);
        ValueType* const x_global_entry =
            gko::batch::multi_vector::batch_item_ptr(x, 1, num_rows, batch_id);

        // generate preconditioner
        prec_shared.generate(batch_id, mat_entry, prec_work_sh);

        // initialization
        // compute b norms
        // r = b - A*x
        // z = precond*r
        // rho_old = r' * z (' is for hermitian transpose)
        // p = z
        initialize(subgroup, num_rows, mat_entry, b_global_entry,
                   x_global_entry, x_sh, r_sh, prec_shared, z_sh, rho_old_sh[0],
                   p_sh, norms_rhs_sh[0]);
        __syncthreads();

        // stopping criterion object
        StopType stop(tol, norms_rhs_sh);

        int iter = 0;
        for (; iter < max_iter; iter++) {
            norms_res_sh[0] = sqrt(abs(rho_old_sh[0]));
            __syncthreads();
            if (stop.check_converged(norms_res_sh)) {
                logger.log_iteration(batch_id, iter, norms_res_sh[0]);
                break;
            }

            // Ap = A * p
            simple_apply(mat_entry, p_sh, Ap_sh);
            __syncthreads();

            // alpha = rho_old / (p' * Ap)
            // x = x + alpha * p
            // r = r - alpha * Ap
            update_x_and_r(subgroup, num_rows, rho_old_sh[0], p_sh, Ap_sh,
                           alpha_sh[0], x_sh, r_sh);
            __syncthreads();

            // z = precond * r
            prec_shared.apply(num_rows, r_sh, z_sh);
            __syncthreads();

            if (threadIdx.x / config::warp_size == 0) {
                // rho_new = (r)' * (z)
                single_rhs_compute_conj_dot(subgroup, num_rows, r_sh, z_sh,
                                            rho_new_sh[0]);
            }
            __syncthreads();

            // beta = rho_new / rho_old
            // p = z + beta * p
            update_p(num_rows, rho_new_sh[0], rho_old_sh[0], z_sh, p_sh);
            __syncthreads();

            // rho_old = rho_new
            if (threadIdx.x == 0) {
                rho_old_sh[0] = rho_new_sh[0];
            }
            __syncthreads();
        }

        logger.log_iteration(batch_id, iter, norms_res_sh[0]);

        // copy x back to global memory
        single_rhs_copy(num_rows, x_sh, x_global_entry);
        __syncthreads();
    }
}
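
The kernel keeps the first n_shared work vectors in dynamic shared memory and spills the rest to a global workspace strided per batch item, processing batch items in a grid-stride loop over blocks. A standalone CUDA sketch of that storage pattern and a typical launch (example sizes and names, not Ginkgo's actual dispatch code):

#include <cuda_runtime.h>

template <int n_shared>
__global__ void demo_kernel(const int num_batch_items,
                            const int padded_vec_len,
                            double* const workspace)
{
    extern __shared__ char local_mem[];
    // grid-stride loop over batch items, as in apply_kernel above
    for (int batch_id = blockIdx.x; batch_id < num_batch_items;
         batch_id += gridDim.x) {
        // first vector lives in shared memory if n_shared >= 1,
        // otherwise in the per-item slice of the global workspace
        double* r = (n_shared >= 1)
                        ? reinterpret_cast<double*>(local_mem)
                        : workspace + batch_id * padded_vec_len;
        for (int i = threadIdx.x; i < padded_vec_len; i += blockDim.x) {
            r[i] = 0.0;  // placeholder work touching the buffer
        }
        __syncthreads();
    }
}

int main()
{
    const int num_batch_items = 100;  // example sizes
    const int padded_vec_len = 256;
    constexpr int n_shared = 1;       // keep r on-chip
    const size_t shared_bytes = n_shared * padded_vec_len * sizeof(double);

    double* workspace = nullptr;  // global fallback, used when n_shared == 0
    cudaMalloc(&workspace,
               sizeof(double) * size_t(num_batch_items) * padded_vec_len);
    demo_kernel<n_shared><<<num_batch_items, 256, shared_bytes>>>(
        num_batch_items, padded_vec_len, workspace);
    cudaDeviceSynchronize();
    cudaFree(workspace);
    return 0;
}
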
6 changes: 2 additions & 4 deletions core/solver/batch_bicgstab_kernels.hpp
@@ -88,8 +88,7 @@ void set_gmem_stride_bytes(storage_config& sconf,
         gmem_stride += prec_storage_bytes;
     }
     // align global memory chunks
-    sconf.gmem_stride_bytes =
-        gmem_stride > 0 ? ceildiv(gmem_stride, align_bytes) * align_bytes : 0;
+    sconf.gmem_stride_bytes = ceildiv(gmem_stride, align_bytes) * align_bytes;
 }


@@ -134,8 +133,7 @@ storage_config compute_shared_storage(const int available_shared_mem,
     using real_type = remove_complex<ValueType>;
     const int vec_size = num_rows * num_rhs * sizeof(ValueType);
     const int num_main_vecs = 9;
-    const int prec_storage =
-        Prectype::dynamic_work_size(num_rows, num_nz) * sizeof(ValueType);
+    const int prec_storage = Prectype::dynamic_work_size(num_rows, num_nz);
     int rem_shared = available_shared_mem;
     // Set default values. Initially all vecs are in global memory.
     // {prec_shared, n_shared, n_global, gmem_stride_bytes, padded_vec_len}
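
The simplified alignment rounds the stride up to the next multiple of align_bytes; assuming the usual integer ceildiv(a, b) = (a + b - 1) / b, a zero stride still maps to zero, so the earlier gmem_stride > 0 guard was redundant. A small worked example:

#include <cstdio>
#include <initializer_list>

// the usual integer ceiling division (assumed definition of ceildiv)
constexpr int ceildiv(int num, int den) { return (num + den - 1) / den; }

int main()
{
    const int align_bytes = 32;
    for (int gmem_stride : {0, 1, 31, 32, 1000}) {
        // rounds up to the next multiple of align_bytes; 0 stays 0
        std::printf("%4d -> %4d\n", gmem_stride,
                    ceildiv(gmem_stride, align_bytes) * align_bytes);
    }
    // prints: 0 -> 0, 1 -> 32, 31 -> 32, 32 -> 32, 1000 -> 1024
    return 0;
}
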
17 changes: 8 additions & 9 deletions core/solver/batch_cg_kernels.hpp
@@ -121,8 +121,7 @@ storage_config compute_shared_storage(const int available_shared_mem,
     using real_type = remove_complex<ValueType>;
     const int vec_bytes = num_rows * num_rhs * sizeof(ValueType);
     const int num_main_vecs = 5;
-    const int prec_storage =
-        Prectype::dynamic_work_size(num_rows, num_nz) * sizeof(ValueType);
+    const int prec_storage = Prectype::dynamic_work_size(num_rows, num_nz);
     int rem_shared = available_shared_mem;
     // Set default values. Initially all vecs are in global memory.
     // {prec_shared, n_shared, n_global, gmem_stride_bytes, padded_vec_len}
@@ -160,13 +159,13 @@ storage_config compute_shared_storage(const int available_shared_mem,
 } // namespace batch_cg


-#define GKO_DECLARE_BATCH_CG_APPLY_KERNEL(_type)                             \
-    void apply(                                                              \
-        std::shared_ptr<const DefaultExecutor> exec,                         \
-        const gko::kernels::batch_cg::settings<remove_complex<_type>>&       \
-            options,                                                         \
-        const batch::BatchLinOp* a, const batch::BatchLinOp* preconditioner, \
-        const batch::MultiVector<_type>* b, batch::MultiVector<_type>* x,    \
+#define GKO_DECLARE_BATCH_CG_APPLY_KERNEL(_type)                              \
+    void apply(                                                               \
+        std::shared_ptr<const DefaultExecutor> exec,                          \
+        const gko::kernels::batch_cg::settings<remove_complex<_type>>&        \
+            options,                                                          \
+        const batch::BatchLinOp* mat, const batch::BatchLinOp* preconditioner,\
+        const batch::MultiVector<_type>* b, batch::MultiVector<_type>* x,     \
         gko::batch::log::detail::log_data<remove_complex<_type>>& logdata)


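
For reference, a hand-written expansion of the declaration macro for a concrete value type (illustration only; remove_complex<double> is just double):

// GKO_DECLARE_BATCH_CG_APPLY_KERNEL(double); expands to:
void apply(
    std::shared_ptr<const DefaultExecutor> exec,
    const gko::kernels::batch_cg::settings<remove_complex<double>>& options,
    const batch::BatchLinOp* mat, const batch::BatchLinOp* preconditioner,
    const batch::MultiVector<double>* b, batch::MultiVector<double>* x,
    gko::batch::log::detail::log_data<remove_complex<double>>& logdata);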