Skip to content

Commit

Permalink
Review updates
Browse files Browse the repository at this point in the history
Co-authored-by: Yu-Hsiang Tsai <[email protected]>
Co-authored-by: Marcel Koch <[email protected]>
  • Loading branch information
3 people committed Oct 30, 2023
1 parent 6459e2f commit fb50eaf
Show file tree
Hide file tree
Showing 18 changed files with 254 additions and 275 deletions.
2 changes: 1 addition & 1 deletion common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ __global__ __launch_bounds__(


template <typename Group, typename ValueType>
__device__ __forceinline__ void single_rhs_compute_dot(Group subgroup,
__device__ __forceinline__ void single_rhs_compute_conj_dot(Group subgroup,
const int num_rows,
const ValueType* x,
const ValueType* y,
Expand Down
13 changes: 3 additions & 10 deletions common/cuda_hip/preconditioner/batch_identity.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,9 @@ public:
return 0;
}

__device__ __forceinline__ void generate(
size_type,
const gko::batch::matrix::ell::batch_item<const ValueType, gko::int32>&,
ValueType*)
{}

__device__ __forceinline__ void generate(
size_type,
const gko::batch::matrix::dense::batch_item<const ValueType>&,
ValueType*)
template <typename batch_item_type>
__device__ __forceinline__ void generate(size_type, const batch_item_type&,
ValueType*)
{}

__device__ __forceinline__ void apply(const int num_rows,
Expand Down
26 changes: 15 additions & 11 deletions common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ __device__ __forceinline__ void initialize(
const ValueType* const x_global_entry, ValueType& rho_old, ValueType& omega,
ValueType& alpha, ValueType* const x_shared_entry,
ValueType* const r_shared_entry, ValueType* const r_hat_shared_entry,
ValueType* const p_shared_entry, ValueType* const v_shared_entry,
ValueType* const p_shared_entry, ValueType* const p_hat_shared_entry,
ValueType* const v_shared_entry,
typename gko::remove_complex<ValueType>& rhs_norm,
typename gko::remove_complex<ValueType>& res_norm)
{
Expand Down Expand Up @@ -70,6 +71,7 @@ __device__ __forceinline__ void initialize(
for (int iz = threadIdx.x; iz < num_rows; iz += blockDim.x) {
r_hat_shared_entry[iz] = r_shared_entry[iz];
p_shared_entry[iz] = zero<ValueType>();
p_hat_shared_entry[iz] = zero<ValueType>();
v_shared_entry[iz] = zero<ValueType>();
}
}
Expand All @@ -82,8 +84,8 @@ __device__ __forceinline__ void update_p(
const ValueType* const r_shared_entry,
const ValueType* const v_shared_entry, ValueType* const p_shared_entry)
{
const ValueType beta = (rho_new / rho_old) * (alpha / omega);
for (int r = threadIdx.x; r < num_rows; r += blockDim.x) {
const ValueType beta = (rho_new / rho_old) * (alpha / omega);
p_shared_entry[r] =
r_shared_entry[r] +
beta * (p_shared_entry[r] - omega * v_shared_entry[r]);
Expand All @@ -97,8 +99,8 @@ __device__ __forceinline__ void compute_alpha(
const ValueType* const v_shared_entry, ValueType& alpha)
{
if (threadIdx.x / config::warp_size == 0) {
single_rhs_compute_dot(subgroup, num_rows, r_hat_shared_entry,
v_shared_entry, alpha);
single_rhs_compute_conj_dot(subgroup, num_rows, r_hat_shared_entry,
v_shared_entry, alpha);
}
__syncthreads();
if (threadIdx.x == 0) {
Expand Down Expand Up @@ -126,11 +128,11 @@ __device__ __forceinline__ void compute_omega(
const ValueType* const s_shared_entry, ValueType& temp, ValueType& omega)
{
if (threadIdx.x / config::warp_size == 0) {
single_rhs_compute_dot(subgroup, num_rows, t_shared_entry,
s_shared_entry, omega);
single_rhs_compute_conj_dot(subgroup, num_rows, t_shared_entry,
s_shared_entry, omega);
} else if (threadIdx.x / config::warp_size == 1) {
single_rhs_compute_dot(subgroup, num_rows, t_shared_entry,
t_shared_entry, temp);
single_rhs_compute_conj_dot(subgroup, num_rows, t_shared_entry,
t_shared_entry, temp);
}

__syncthreads();
Expand Down Expand Up @@ -278,10 +280,12 @@ __global__ void apply_kernel(
// compute residual norms
// r_hat = r
// p = 0
// p_hat = 0
// v = 0
initialize(subgroup, num_rows, mat_entry, b_entry_ptr, x_gl_entry_ptr,
rho_old_sh[0], omega_sh[0], alpha_sh[0], x_sh, r_sh,
r_hat_sh, p_sh, v_sh, norms_rhs_sh[0], norms_res_sh[0]);
r_hat_sh, p_sh, p_hat_sh, v_sh, norms_rhs_sh[0],
norms_res_sh[0]);
__syncthreads();

// stopping criterion object
Expand All @@ -296,8 +300,8 @@ __global__ void apply_kernel(

// rho_new = < r_hat , r > = (r_hat)' * (r)
if (threadIdx.x / config::warp_size == 0) {
single_rhs_compute_dot(subgroup, num_rows, r_hat_sh, r_sh,
rho_new_sh[0]);
single_rhs_compute_conj_dot(subgroup, num_rows, r_hat_sh, r_sh,
rho_new_sh[0]);
}
__syncthreads();

Expand Down
18 changes: 9 additions & 9 deletions core/solver/batch_bicgstab_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,7 @@ void set_gmem_stride_bytes(storage_config& sconf,
}
// align global memory chunks
sconf.gmem_stride_bytes =
gmem_stride > 0 ? ((gmem_stride - 1) / align_bytes + 1) * align_bytes
: 0;
gmem_stride > 0 ? ceildiv(gmem_stride, align_bytes) * align_bytes : 0;
}


Expand All @@ -143,8 +142,8 @@ void set_gmem_stride_bytes(storage_config& sconf,
* - rhs_norms
* - res_norms
*
* @param shared_mem_per_blk The amount of shared memory per block to use for
* keeping intermediate vectors. In case keeping the matrix in L1 cache etc.
* @param available_shared_mem The amount of shared memory per block to use
* for keeping intermediate vectors. In case keeping the matrix in L1 cache etc.
* should be prioritized, the cache configuration must be updated separately
* and the needed space should be subtracted before passing to this
* function.
Expand All @@ -154,7 +153,7 @@ void set_gmem_stride_bytes(storage_config& sconf,
* @return A struct containing allocation information specific to Bicgstab.
*/
template <typename Prectype, typename ValueType, int align_bytes = 32>
storage_config compute_shared_storage(const int shared_mem_per_blk,
storage_config compute_shared_storage(const int available_shared_mem,
const int num_rows, const int num_nz,
const int num_rhs)
{
Expand All @@ -163,10 +162,11 @@ storage_config compute_shared_storage(const int shared_mem_per_blk,
const int num_main_vecs = 9;
const int prec_storage =
Prectype::dynamic_work_size(num_rows, num_nz) * sizeof(ValueType);
int rem_shared = shared_mem_per_blk;
// Set default values. All vecs are in global.
int rem_shared = available_shared_mem;
// Set default values. Initially all vecs are in global memory.
// {prec_shared, n_shared, n_global, gmem_stride_bytes, padded_vec_len}
storage_config sconf{false, 0, num_main_vecs, 0, num_rows};
// If available shared mem, is zero, set all vecs to global.
// If available shared mem is zero, set all vecs to global.
if (rem_shared <= 0) {
set_gmem_stride_bytes<align_bytes>(sconf, vec_size, prec_storage);
return sconf;
Expand All @@ -177,13 +177,13 @@ storage_config compute_shared_storage(const int shared_mem_per_blk,
const int num_vecs_shared = min(initial_vecs_available, num_main_vecs);
sconf.n_shared += num_vecs_shared;
sconf.n_global -= num_vecs_shared;
rem_shared -= num_vecs_shared * vec_size;
// Set the storage configuration with preconditioner workspace in global if
// there are any vectors in global memory.
if (sconf.n_global > 0) {
set_gmem_stride_bytes<align_bytes>(sconf, vec_size, prec_storage);
return sconf;
}
rem_shared -= num_vecs_shared * vec_size;
// If more shared memory space is available and preconditioner workspace is
// needed, enable preconditioner workspace to use shared memory.
if (rem_shared >= prec_storage && prec_storage > 0) {
Expand Down
18 changes: 9 additions & 9 deletions core/test/utils/batch_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ struct LinearSystem {

std::shared_ptr<const MatrixType> matrix;
std::shared_ptr<multi_vec> rhs;
std::shared_ptr<real_vec> rhs_norm;
std::shared_ptr<real_vec> host_rhs_norm;
std::shared_ptr<multi_vec> exact_sol;
};

Expand All @@ -250,8 +250,8 @@ LinearSystem<MatrixType> generate_batch_linear_system(
// A * x_{exact} = b
sys.matrix->apply(sys.exact_sol, sys.rhs);
const gko::batch_dim<2> norm_dim(num_batch_items, gko::dim<2>(1, num_rhs));
sys.rhs_norm = real_vec::create(exec->get_master(), norm_dim);
sys.rhs->compute_norm2(sys.rhs_norm.get());
sys.host_rhs_norm = real_vec::create(exec->get_master(), norm_dim);
sys.rhs->compute_norm2(sys.host_rhs_norm.get());
return sys;
}

Expand All @@ -273,13 +273,13 @@ compute_residual_norms(
const gko::batch_dim<2> norm_dim(num_batch_items, gko::dim<2>(1, num_rhs));

auto residual_vec = b->clone();
auto res_norms = real_vec::create(exec->get_master(), norm_dim);
auto res_norm = real_vec::create(exec->get_master(), norm_dim);
auto alpha =
gko::batch::initialize<multi_vec>(num_batch_items, {-1.0}, exec);
auto beta = gko::batch::initialize<multi_vec>(num_batch_items, {1.0}, exec);
mtx->apply(alpha, x, beta, residual_vec);
residual_vec->compute_norm2(res_norms);
return res_norms;
residual_vec->compute_norm2(res_norm);
return res_norm;
}


Expand All @@ -289,7 +289,7 @@ struct Result {
using real_vec = batch::MultiVector<remove_complex<ValueType>>;

std::shared_ptr<multi_vec> x;
std::shared_ptr<real_vec> res_norm;
std::shared_ptr<real_vec> host_res_norm;
};


Expand Down Expand Up @@ -323,7 +323,7 @@ Result<typename MatrixType::value_type> solve_linear_system(
result.x->fill(zero<value_type>());

solver->apply(sys.rhs, result.x);
result.res_norm =
result.host_res_norm =
compute_residual_norms(sys.matrix.get(), sys.rhs.get(), result.x.get());

return std::move(result);
Expand Down Expand Up @@ -369,7 +369,7 @@ ResultWithLogData<typename MatrixType::value_type> solve_linear_system(
result.log_data->iter_counts = log_data->iter_counts;
result.log_data->res_norms = log_data->res_norms;

result.res_norm =
result.host_res_norm =
compute_residual_norms(sys.matrix.get(), sys.rhs.get(), result.x.get());

return std::move(result);
Expand Down
3 changes: 2 additions & 1 deletion cuda/matrix/batch_struct.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ get_batch_struct(batch::matrix::Dense<ValueType>* const op)
* Generates an immutable uniform batch struct from a batch of ell matrices.
*/
template <typename ValueType, typename IndexType>
inline batch::matrix::ell::uniform_batch<const cuda_type<ValueType>, IndexType>
inline batch::matrix::ell::uniform_batch<const cuda_type<ValueType>,
const IndexType>
get_batch_struct(const batch::matrix::Ell<ValueType, IndexType>* const op)
{
return {as_cuda_type(op->get_const_values()),
Expand Down
17 changes: 8 additions & 9 deletions cuda/solver/batch_bicgstab_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,10 @@ int get_num_threads_per_block(std::shared_ptr<const DefaultExecutor> exec,
cudaDeviceGetAttribute(&max_regs_blk, cudaDevAttrMaxRegistersPerBlock,
exec->get_device_id());
const int max_threads_regs =
((max_regs_blk /
static_cast<int>((static_cast<double>(num_regs_used)))) /
warp_sz) *
warp_sz;
((max_regs_blk / static_cast<int>(num_regs_used)) / warp_sz) * warp_sz;
int max_threads = std::min(max_threads_regs, device_max_threads);
max_threads = max_threads <= 1024 ? max_threads : 1024;
return std::min(num_warps * warp_sz, max_threads);
return std::max(std::min(num_warps * warp_sz, max_threads), min_block_size);
}


Expand Down Expand Up @@ -136,12 +133,12 @@ using settings = gko::kernels::batch_bicgstab::settings<T>;


template <typename CuValueType>
class KernelCaller {
class kernel_caller {
public:
using value_type = CuValueType;

KernelCaller(std::shared_ptr<const DefaultExecutor> exec,
const settings<remove_complex<value_type>> settings)
kernel_caller(std::shared_ptr<const DefaultExecutor> exec,
const settings<remove_complex<value_type>> settings)
: exec_{std::move(exec)}, settings_{settings}
{}

Expand Down Expand Up @@ -263,6 +260,8 @@ public:
sconf, logger, prec, mat, b.values, x.values,
workspace_data, block_size, shared_size);
break;
default:
GKO_NOT_IMPLEMENTED;
}
}
Expand All @@ -286,7 +285,7 @@ void apply(std::shared_ptr<const DefaultExecutor> exec,
{
using cu_value_type = cuda_type<ValueType>;
auto dispatcher = batch::solver::create_dispatcher<ValueType>(
KernelCaller<cu_value_type>(exec, settings), settings, mat, precon);
kernel_caller<cu_value_type>(exec, settings), settings, mat, precon);
dispatcher.apply(b, x, logdata);
}
Expand Down
Loading

0 comments on commit fb50eaf

Please sign in to comment.