Add batch::cg solver device kernels #1609

Merged: 11 commits, May 23, 2024
Changes from all commits
2 changes: 1 addition & 1 deletion common/cuda_hip/preconditioner/batch_scalar_jacobi.hpp.inc
@@ -17,7 +17,7 @@ public:
__host__ __device__ static constexpr int dynamic_work_size(
const int num_rows, int)
{
-return num_rows;
+return num_rows * sizeof(value_type);
Member: is that some rebase left over?

Member Author: Yes, but I moved it to #1600 now. I think that will be merged first, so I will rebase this on that afterwards.

Member: Then maybe change the base of the PR? That makes it easier to review.

}

/**
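For context: the change above makes dynamic_work_size report its result in bytes rather than in elements, which is why the callers in core/solver/batch_bicgstab_kernels.hpp and core/solver/batch_cg_kernels.hpp further down drop their sizeof(ValueType) factor. A minimal self-contained sketch of the new convention (an illustrative stand-in, not the actual Ginkgo class):

template <typename ValueType>
struct example_scalar_jacobi {
    // report the workspace size in bytes so callers can use it directly
    static constexpr int dynamic_work_size(int num_rows, int)
    {
        return num_rows * static_cast<int>(sizeof(ValueType));
    }
};

static_assert(example_scalar_jacobi<double>::dynamic_work_size(4, 0) == 32,
              "4 rows of doubles need 32 bytes of workspace");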
228 changes: 228 additions & 0 deletions common/cuda_hip/solver/batch_cg_kernels.hpp.inc
@@ -0,0 +1,228 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

template <typename Group, typename BatchMatrixType_entry, typename PrecType,
typename ValueType>
__device__ __forceinline__ void initialize(
Group subgroup, const int num_rows, const BatchMatrixType_entry& mat_entry,
const ValueType* const b_global_entry,
const ValueType* const x_global_entry, ValueType* const x_shared_entry,
ValueType* const r_shared_entry, const PrecType& prec_shared,
ValueType* const z_shared_entry, ValueType& rho_old_shared_entry,
ValueType* const p_shared_entry,
typename gko::remove_complex<ValueType>& rhs_norms_sh)
{
// copy x from global to shared memory
// r = b
for (int iz = threadIdx.x; iz < num_rows; iz += blockDim.x) {
x_shared_entry[iz] = x_global_entry[iz];
r_shared_entry[iz] = b_global_entry[iz];
}
__syncthreads();

// r = b - A*x
advanced_apply(static_cast<ValueType>(-1.0), mat_entry, x_shared_entry,
static_cast<ValueType>(1.0), r_shared_entry);
__syncthreads();

// z = precond * r
prec_shared.apply(num_rows, r_shared_entry, z_shared_entry);
__syncthreads();

if (threadIdx.x / config::warp_size == 0) {
// Compute norms of rhs
single_rhs_compute_norm2(subgroup, num_rows, b_global_entry,
rhs_norms_sh);
} else if (threadIdx.x / config::warp_size == 1) {
// rho_old = r' * z
single_rhs_compute_conj_dot(subgroup, num_rows, r_shared_entry,
z_shared_entry, rho_old_shared_entry);
}

// p = z
for (int iz = threadIdx.x; iz < num_rows; iz += blockDim.x) {
p_shared_entry[iz] = z_shared_entry[iz];
}
}


template <typename ValueType>
__device__ __forceinline__ void update_p(const int num_rows,
const ValueType& rho_new_shared_entry,
const ValueType& rho_old_shared_entry,
const ValueType* const z_shared_entry,
ValueType* const p_shared_entry)
{
for (int li = threadIdx.x; li < num_rows; li += blockDim.x) {
const ValueType beta = rho_new_shared_entry / rho_old_shared_entry;
p_shared_entry[li] = z_shared_entry[li] + beta * p_shared_entry[li];
}
}


template <typename Group, typename ValueType>
__device__ __forceinline__ void update_x_and_r(
Group subgroup, const int num_rows, const ValueType& rho_old_shared_entry,
const ValueType* const p_shared_entry,
const ValueType* const Ap_shared_entry, ValueType& alpha_shared_entry,
ValueType* const x_shared_entry, ValueType* const r_shared_entry)
{
if (threadIdx.x / config::warp_size == 0) {
single_rhs_compute_conj_dot(subgroup, num_rows, p_shared_entry,
Ap_shared_entry, alpha_shared_entry);
}
__syncthreads();

for (int li = threadIdx.x; li < num_rows; li += blockDim.x) {
const ValueType alpha = rho_old_shared_entry / alpha_shared_entry;
x_shared_entry[li] += alpha * p_shared_entry[li];
r_shared_entry[li] -= alpha * Ap_shared_entry[li];
}
}


template <typename StopType, const int n_shared, const bool prec_shared_bool,
typename PrecType, typename LogType, typename BatchMatrixType,
typename ValueType>
__global__ void apply_kernel(const gko::kernels::batch_cg::storage_config sconf,
const int max_iter,
const gko::remove_complex<ValueType> tol,
LogType logger, PrecType prec_shared,
const BatchMatrixType mat,
const ValueType* const __restrict__ b,
ValueType* const __restrict__ x,
ValueType* const __restrict__ workspace = nullptr)
{
using real_type = typename gko::remove_complex<ValueType>;
const auto num_batch_items = mat.num_batch_items;
const auto num_rows = mat.num_rows;

constexpr auto tile_size = config::warp_size;
auto thread_block = group::this_thread_block();
auto subgroup = group::tiled_partition<tile_size>(thread_block);

for (size_type batch_id = blockIdx.x; batch_id < num_batch_items;
batch_id += gridDim.x) {
const int gmem_offset =
batch_id * sconf.gmem_stride_bytes / sizeof(ValueType);
extern __shared__ char local_mem_sh[];

ValueType* r_sh;
ValueType* z_sh;
ValueType* p_sh;
ValueType* Ap_sh;
ValueType* x_sh;
ValueType* prec_work_sh;

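// the first n_shared work vectors are kept in shared memory; the remaining
// ones (and, when prec_shared_bool is false, the preconditioner workspace)
// fall back to the global workspace at this batch item's offset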
if (n_shared >= 1) {
r_sh = reinterpret_cast<ValueType*>(local_mem_sh);
} else {
r_sh = workspace + gmem_offset;
}
if (n_shared == 1) {
z_sh = workspace + gmem_offset;
} else {
z_sh = r_sh + sconf.padded_vec_len;
}
if (n_shared == 2) {
p_sh = workspace + gmem_offset;
} else {
p_sh = z_sh + sconf.padded_vec_len;
}
if (n_shared == 3) {
Ap_sh = workspace + gmem_offset;
} else {
Ap_sh = p_sh + sconf.padded_vec_len;
}
if (n_shared == 4) {
x_sh = workspace + gmem_offset;
} else {
x_sh = Ap_sh + sconf.padded_vec_len;
}
if (!prec_shared_bool && n_shared == 5) {
prec_work_sh = workspace + gmem_offset;
} else {
prec_work_sh = x_sh + sconf.padded_vec_len;
}

__shared__ uninitialized_array<ValueType, 1> rho_old_sh;
__shared__ uninitialized_array<ValueType, 1> rho_new_sh;
__shared__ uninitialized_array<ValueType, 1> alpha_sh;
__shared__ real_type norms_rhs_sh[1];
__shared__ real_type norms_res_sh[1];

const auto mat_entry =
gko::batch::matrix::extract_batch_item(mat, batch_id);
const ValueType* const b_global_entry =
gko::batch::multi_vector::batch_item_ptr(b, 1, num_rows, batch_id);
ValueType* const x_global_entry =
gko::batch::multi_vector::batch_item_ptr(x, 1, num_rows, batch_id);

// generate preconditioner
prec_shared.generate(batch_id, mat_entry, prec_work_sh);

// initialization
// compute b norms
// r = b - A*x
// z = precond*r
// rho_old = r' * z (' is for hermitian transpose)
// p = z
initialize(subgroup, num_rows, mat_entry, b_global_entry,
x_global_entry, x_sh, r_sh, prec_shared, z_sh, rho_old_sh[0],
p_sh, norms_rhs_sh[0]);
__syncthreads();

// stopping criterion object
StopType stop(tol, norms_rhs_sh);

int iter = 0;
for (; iter < max_iter; iter++) {
norms_res_sh[0] = sqrt(abs(rho_old_sh[0]));
__syncthreads();
if (stop.check_converged(norms_res_sh)) {
logger.log_iteration(batch_id, iter, norms_res_sh[0]);
break;
}

// Ap = A * p
simple_apply(mat_entry, p_sh, Ap_sh);
__syncthreads();

// alpha = rho_old / (p' * Ap)
// x = x + alpha * p
// r = r - alpha * Ap
update_x_and_r(subgroup, num_rows, rho_old_sh[0], p_sh, Ap_sh,
alpha_sh[0], x_sh, r_sh);
__syncthreads();

// z = precond * r
prec_shared.apply(num_rows, r_sh, z_sh);
__syncthreads();

if (threadIdx.x / config::warp_size == 0) {
// rho_new = (r)' * (z)
single_rhs_compute_conj_dot(subgroup, num_rows, r_sh, z_sh,
rho_new_sh[0]);
}
__syncthreads();

// beta = rho_new / rho_old
// p = z + beta * p
update_p(num_rows, rho_new_sh[0], rho_old_sh[0], z_sh, p_sh);
__syncthreads();

// rho_old = rho_new
if (threadIdx.x == 0) {
rho_old_sh[0] = rho_new_sh[0];
}
__syncthreads();
}

logger.log_iteration(batch_id, iter, norms_res_sh[0]);

// copy x back to global memory
single_rhs_copy(num_rows, x_sh, x_global_entry);
__syncthreads();
}
}
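For context, a host-side sketch of how a kernel like this might be launched, with one thread block per batch item. This is hypothetical: block_size, prec_bytes, and the chosen template arguments are illustrative, and the actual dispatch (which selects StopType and n_shared at runtime) lives in the CUDA/HIP solver files:

// hypothetical launch with all five work vectors and the preconditioner
// workspace resident in shared memory
constexpr int n_shared = 5;
constexpr bool prec_is_shared = true;
const int shared_size =
    n_shared * sconf.padded_vec_len * sizeof(ValueType) +
    (prec_is_shared ? prec_bytes : 0);
apply_kernel<StopType, n_shared, prec_is_shared>
    <<<num_batch_items, block_size, shared_size, exec->get_stream()>>>(
        sconf, max_iter, tol, logger, prec, mat, b, x, workspace_data);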
6 changes: 2 additions & 4 deletions core/solver/batch_bicgstab_kernels.hpp
@@ -88,8 +88,7 @@ void set_gmem_stride_bytes(storage_config& sconf,
gmem_stride += prec_storage_bytes;
}
// align global memory chunks
-sconf.gmem_stride_bytes =
-gmem_stride > 0 ? ceildiv(gmem_stride, align_bytes) * align_bytes : 0;
+sconf.gmem_stride_bytes = ceildiv(gmem_stride, align_bytes) * align_bytes;
}
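The dropped gmem_stride > 0 special case is safe because rounding zero up to a multiple of align_bytes already yields zero. A standalone check of that arithmetic (ceildiv_demo is an illustrative helper, not Ginkgo's ceildiv):

#include <cassert>

constexpr int ceildiv_demo(int num, int den) { return (num + den - 1) / den; }

int main()
{
    const int align_bytes = 32;
    // 520 bytes of vector storage round up to the next 32-byte boundary
    assert(ceildiv_demo(520, align_bytes) * align_bytes == 544);
    // zero stays zero, so no special case is needed
    assert(ceildiv_demo(0, align_bytes) * align_bytes == 0);
}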


@@ -134,8 +133,7 @@ storage_config compute_shared_storage(const int available_shared_mem,
using real_type = remove_complex<ValueType>;
const int vec_size = num_rows * num_rhs * sizeof(ValueType);
const int num_main_vecs = 9;
-const int prec_storage =
-Prectype::dynamic_work_size(num_rows, num_nz) * sizeof(ValueType);
+const int prec_storage = Prectype::dynamic_work_size(num_rows, num_nz);
int rem_shared = available_shared_mem;
// Set default values. Initially all vecs are in global memory.
// {prec_shared, n_shared, n_global, gmem_stride_bytes, padded_vec_len}
17 changes: 8 additions & 9 deletions core/solver/batch_cg_kernels.hpp
@@ -121,8 +121,7 @@ storage_config compute_shared_storage(const int available_shared_mem,
using real_type = remove_complex<ValueType>;
const int vec_bytes = num_rows * num_rhs * sizeof(ValueType);
const int num_main_vecs = 5;
-const int prec_storage =
-Prectype::dynamic_work_size(num_rows, num_nz) * sizeof(ValueType);
+const int prec_storage = Prectype::dynamic_work_size(num_rows, num_nz);
int rem_shared = available_shared_mem;
// Set default values. Initially all vecs are in global memory.
// {prec_shared, n_shared, n_global, gmem_stride_bytes, padded_vec_len}
@@ -160,13 +159,13 @@ storage_config compute_shared_storage(const int available_shared_mem,
} // namespace batch_cg


#define GKO_DECLARE_BATCH_CG_APPLY_KERNEL(_type) \
void apply( \
std::shared_ptr<const DefaultExecutor> exec, \
const gko::kernels::batch_cg::settings<remove_complex<_type>>& \
options, \
-const batch::BatchLinOp* a, const batch::BatchLinOp* preconditioner, \
+const batch::BatchLinOp* mat, const batch::BatchLinOp* preconditioner, \
const batch::MultiVector<_type>* b, batch::MultiVector<_type>* x, \
gko::batch::log::detail::log_data<remove_complex<_type>>& logdata)


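For context, such declaration macros are normally expanded once per supported value type in the backend kernel files; a sketch assuming Ginkgo's usual instantiation macro:

// hypothetical usage in a backend source file
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_CG_APPLY_KERNEL);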
2 changes: 1 addition & 1 deletion cuda/solver/batch_bicgstab_kernels.cu
@@ -170,7 +170,7 @@ public:
auto workspace = gko::array<value_type>(
exec_,
sconf.gmem_stride_bytes * num_batch_items / sizeof(value_type));
-assert(sconf.gmem_stride_bytes % sizeof(value_type) == 0);
+GKO_ASSERT(sconf.gmem_stride_bytes % sizeof(value_type) == 0);

value_type* const workspace_data = workspace.get_data();
