Add batch::cg solver device kernels #1609
Merged
Commits (11)
3631744  Add cg solver cuda/hip kernels (pratikvn)
cd4e614  add dpcpp kernels (pratikvn)
24e6920  review updates (pratikvn)
80e634f  dpcpp kernel fix WIP (pratikvn)
9b1dc02  review updates (pratikvn)
22ca18b  update tolerances (pratikvn)
c5d63e8  hip fixes (pratikvn)
8dc3aee  dpcpp fixes (pratikvn)
958e6d2  update tol (pratikvn)
e739dcb  update static slm sizes (pratikvn)
229db3f  reduce num batch items (pratikvn)
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

template <typename Group, typename BatchMatrixType_entry, typename PrecType,
          typename ValueType>
__device__ __forceinline__ void initialize(
    Group subgroup, const int num_rows, const BatchMatrixType_entry& mat_entry,
    const ValueType* const b_global_entry,
    const ValueType* const x_global_entry, ValueType* const x_shared_entry,
    ValueType* const r_shared_entry, const PrecType& prec_shared,
    ValueType* const z_shared_entry, ValueType& rho_old_shared_entry,
    ValueType* const p_shared_entry,
    typename gko::remove_complex<ValueType>& rhs_norms_sh)
{
    // copy x from global to shared memory
    // r = b
    for (int iz = threadIdx.x; iz < num_rows; iz += blockDim.x) {
        x_shared_entry[iz] = x_global_entry[iz];
        r_shared_entry[iz] = b_global_entry[iz];
    }
    __syncthreads();

    // r = b - A*x
    advanced_apply(static_cast<ValueType>(-1.0), mat_entry, x_shared_entry,
                   static_cast<ValueType>(1.0), r_shared_entry);
    __syncthreads();

    // z = precond * r
    prec_shared.apply(num_rows, r_shared_entry, z_shared_entry);
    __syncthreads();

    if (threadIdx.x / config::warp_size == 0) {
        // compute norms of rhs
        single_rhs_compute_norm2(subgroup, num_rows, b_global_entry,
                                 rhs_norms_sh);
    } else if (threadIdx.x / config::warp_size == 1) {
        // rho_old = r' * z
        single_rhs_compute_conj_dot(subgroup, num_rows, r_shared_entry,
                                    z_shared_entry, rho_old_shared_entry);
    }

    // p = z
    for (int iz = threadIdx.x; iz < num_rows; iz += blockDim.x) {
        p_shared_entry[iz] = z_shared_entry[iz];
    }
}
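
// Note: two warps work concurrently in initialize above: warp 0 reduces the
// norm of the rhs while warp 1 reduces rho_old = r' * z. The __syncthreads()
// issued by the caller right after initialize makes both results visible
// block-wide.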


template <typename ValueType>
__device__ __forceinline__ void update_p(const int num_rows,
                                         const ValueType& rho_new_shared_entry,
                                         const ValueType& rho_old_shared_entry,
                                         const ValueType* const z_shared_entry,
                                         ValueType* const p_shared_entry)
{
    // p = z + beta * p, with beta = rho_new / rho_old
    for (int li = threadIdx.x; li < num_rows; li += blockDim.x) {
        const ValueType beta = rho_new_shared_entry / rho_old_shared_entry;
        p_shared_entry[li] = z_shared_entry[li] + beta * p_shared_entry[li];
    }
}


template <typename Group, typename ValueType>
__device__ __forceinline__ void update_x_and_r(
    Group subgroup, const int num_rows, const ValueType& rho_old_shared_entry,
    const ValueType* const p_shared_entry,
    const ValueType* const Ap_shared_entry, ValueType& alpha_shared_entry,
    ValueType* const x_shared_entry, ValueType* const r_shared_entry)
{
    // alpha_shared_entry = p' * Ap
    if (threadIdx.x / config::warp_size == 0) {
        single_rhs_compute_conj_dot(subgroup, num_rows, p_shared_entry,
                                    Ap_shared_entry, alpha_shared_entry);
    }
    __syncthreads();

    // x = x + alpha * p ;  r = r - alpha * Ap, with alpha = rho_old / (p' * Ap)
    for (int li = threadIdx.x; li < num_rows; li += blockDim.x) {
        const ValueType alpha = rho_old_shared_entry / alpha_shared_entry;
        x_shared_entry[li] += alpha * p_shared_entry[li];
        r_shared_entry[li] -= alpha * Ap_shared_entry[li];
    }
}
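
// Editor's sketch (not part of the PR): single_rhs_compute_conj_dot and
// single_rhs_compute_norm2 are Ginkgo helpers defined elsewhere in the tree.
// Assuming a cooperative-groups-style tile that provides thread_rank(),
// size(), and shfl_down(), a subgroup-cooperative conjugate dot product
// could look like this; all names here are illustrative only:
//
// template <typename Group, typename ValueType>
// __device__ void conj_dot_sketch(Group subgroup, const int num_rows,
//                                 const ValueType* const x,
//                                 const ValueType* const y, ValueType& result)
// {
//     ValueType sum{};
//     for (int r = subgroup.thread_rank(); r < num_rows;
//          r += subgroup.size()) {
//         sum += conj(x[r]) * y[r];  // conj(.) is the identity for real types
//     }
//     // warp-level tree reduction (subgroup.size() is a power of two)
//     for (int offset = subgroup.size() / 2; offset > 0; offset /= 2) {
//         sum += subgroup.shfl_down(sum, offset);
//     }
//     if (subgroup.thread_rank() == 0) {
//         result = sum;
//     }
// }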


template <typename StopType, const int n_shared, const bool prec_shared_bool,
          typename PrecType, typename LogType, typename BatchMatrixType,
          typename ValueType>
__global__ void apply_kernel(const gko::kernels::batch_cg::storage_config sconf,
                             const int max_iter,
                             const gko::remove_complex<ValueType> tol,
                             LogType logger, PrecType prec_shared,
                             const BatchMatrixType mat,
                             const ValueType* const __restrict__ b,
                             ValueType* const __restrict__ x,
                             ValueType* const __restrict__ workspace = nullptr)
{
    using real_type = typename gko::remove_complex<ValueType>;
    const auto num_batch_items = mat.num_batch_items;
    const auto num_rows = mat.num_rows;

    constexpr auto tile_size = config::warp_size;
    auto thread_block = group::this_thread_block();
    auto subgroup = group::tiled_partition<tile_size>(thread_block);

    for (size_type batch_id = blockIdx.x; batch_id < num_batch_items;
         batch_id += gridDim.x) {
        const int gmem_offset =
            batch_id * sconf.gmem_stride_bytes / sizeof(ValueType);
        extern __shared__ char local_mem_sh[];

        ValueType* r_sh;
        ValueType* z_sh;
        ValueType* p_sh;
        ValueType* Ap_sh;
        ValueType* x_sh;
        ValueType* prec_work_sh;

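        // The first n_shared of the vectors r, z, p, Ap, x are carved out of
        // the shared-memory buffer local_mem_sh; the first vector that no
        // longer fits is placed at this batch item's slice of the global
        // workspace (workspace + gmem_offset), and the remaining ones follow
        // it contiguously, each padded_vec_len entries apart.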
        if (n_shared >= 1) {
            r_sh = reinterpret_cast<ValueType*>(local_mem_sh);
        } else {
            r_sh = workspace + gmem_offset;
        }
        if (n_shared == 1) {
            z_sh = workspace + gmem_offset;
        } else {
            z_sh = r_sh + sconf.padded_vec_len;
        }
        if (n_shared == 2) {
            p_sh = workspace + gmem_offset;
        } else {
            p_sh = z_sh + sconf.padded_vec_len;
        }
        if (n_shared == 3) {
            Ap_sh = workspace + gmem_offset;
        } else {
            Ap_sh = p_sh + sconf.padded_vec_len;
        }
        if (n_shared == 4) {
            x_sh = workspace + gmem_offset;
        } else {
            x_sh = Ap_sh + sconf.padded_vec_len;
        }
        if (!prec_shared_bool && n_shared == 5) {
            prec_work_sh = workspace + gmem_offset;
        } else {
            prec_work_sh = x_sh + sconf.padded_vec_len;
        }

        __shared__ uninitialized_array<ValueType, 1> rho_old_sh;
        __shared__ uninitialized_array<ValueType, 1> rho_new_sh;
        __shared__ uninitialized_array<ValueType, 1> alpha_sh;
        __shared__ real_type norms_rhs_sh[1];
        __shared__ real_type norms_res_sh[1];

        const auto mat_entry =
            gko::batch::matrix::extract_batch_item(mat, batch_id);
        const ValueType* const b_global_entry =
            gko::batch::multi_vector::batch_item_ptr(b, 1, num_rows, batch_id);
        ValueType* const x_global_entry =
            gko::batch::multi_vector::batch_item_ptr(x, 1, num_rows, batch_id);

        // generate preconditioner
        prec_shared.generate(batch_id, mat_entry, prec_work_sh);

        // initialization
        // compute b norms
        // r = b - A*x
        // z = precond*r
        // rho_old = r' * z (' is for hermitian transpose)
        // p = z
        initialize(subgroup, num_rows, mat_entry, b_global_entry,
                   x_global_entry, x_sh, r_sh, prec_shared, z_sh, rho_old_sh[0],
                   p_sh, norms_rhs_sh[0]);
        __syncthreads();

        // stopping criterion object
        StopType stop(tol, norms_rhs_sh);

        int iter = 0;
        for (; iter < max_iter; iter++) {
            norms_res_sh[0] = sqrt(abs(rho_old_sh[0]));
            __syncthreads();
            if (stop.check_converged(norms_res_sh)) {
                logger.log_iteration(batch_id, iter, norms_res_sh[0]);
                break;
            }

            // Ap = A * p
            simple_apply(mat_entry, p_sh, Ap_sh);
            __syncthreads();

            // alpha = rho_old / (p' * Ap)
            // x = x + alpha * p
            // r = r - alpha * Ap
            update_x_and_r(subgroup, num_rows, rho_old_sh[0], p_sh, Ap_sh,
                           alpha_sh[0], x_sh, r_sh);
            __syncthreads();

            // z = precond * r
            prec_shared.apply(num_rows, r_sh, z_sh);
            __syncthreads();

            if (threadIdx.x / config::warp_size == 0) {
                // rho_new = (r)' * (z)
                single_rhs_compute_conj_dot(subgroup, num_rows, r_sh, z_sh,
                                            rho_new_sh[0]);
            }
            __syncthreads();

            // beta = rho_new / rho_old
            // p = z + beta * p
            update_p(num_rows, rho_new_sh[0], rho_old_sh[0], z_sh, p_sh);
            __syncthreads();

            // rho_old = rho_new
            if (threadIdx.x == 0) {
                rho_old_sh[0] = rho_new_sh[0];
            }
            __syncthreads();
        }

        logger.log_iteration(batch_id, iter, norms_res_sh[0]);

        // copy x back to global memory
        single_rhs_copy(num_rows, x_sh, x_global_entry);
        __syncthreads();
    }
}
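
For orientation, the dynamic shared memory a launch of this kernel must provide follows from the storage configuration used above: n_shared padded work vectors of ValueType, plus the preconditioner work area when it is kept in shared memory. Below is a minimal sketch of that sizing under those assumptions; the helper name, the prec_shared_size parameter, and the launch line are hypothetical, not the actual Ginkgo dispatch code.

#include <cstddef>

// Illustrative only: computes the dynamic shared-memory size a launch of
// apply_kernel could request. The field names mirror the kernel's usage of
// sconf; prec_shared_size is a hypothetical placeholder for PrecType's
// shared work-area size.
template <typename ValueType>
std::size_t shared_mem_size(int n_shared, int padded_vec_len,
                            bool prec_in_shared, std::size_t prec_shared_size)
{
    // n_shared of the five work vectors live in shared memory
    std::size_t bytes = static_cast<std::size_t>(n_shared) * padded_vec_len *
                        sizeof(ValueType);
    if (prec_in_shared) {
        bytes += prec_shared_size;  // preconditioner work area, if shared
    }
    return bytes;
}

// e.g. (hypothetical launch, all five vectors and the preconditioner shared):
// apply_kernel<Stop, 5, true><<<num_blocks, block_size,
//     shared_mem_size<double>(5, padded_vec_len, true, prec_bytes)>>>(
//     sconf, max_iter, tol, logger, prec, mat, b, x, workspace);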
Review discussion:
- Is that some rebase left-over?
- Yes, but I moved it to #1600 now. I think that will be merged first, so I will rebase this on that afterwards.
- Then maybe change the base of the PR? That makes it easier to review.