Skip to content

Commit

Permalink
add dpcpp kernels
Browse files Browse the repository at this point in the history
Co-authored-by: Phuong Nguyen <[email protected]>
  • Loading branch information
pratikvn and Phuong Nguyen committed May 10, 2024
1 parent 3864556 commit 03b7cac
Show file tree
Hide file tree
Showing 6 changed files with 444 additions and 10 deletions.
5 changes: 5 additions & 0 deletions cuda/solver/batch_cg_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,11 @@ public:
sconf, logger, prec, mat, b.values, x.values,
workspace_data, block_size, shared_size);
break;
case 5:
launch_apply_kernel<StopType, 5, false>(
sconf, logger, prec, mat, b.values, x.values,
workspace_data, block_size, shared_size);
break;
default:
GKO_NOT_IMPLEMENTED;
}
Expand Down
11 changes: 5 additions & 6 deletions dpcpp/solver/batch_bicgstab_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@ __dpct_inline__ int get_group_size(int value,


template <typename ValueType>
class KernelCaller {
class kernel_caller {
public:
KernelCaller(std::shared_ptr<const DefaultExecutor> exec,
const settings<remove_complex<ValueType>> settings)
kernel_caller(std::shared_ptr<const DefaultExecutor> exec,
const settings<remove_complex<ValueType>> settings)
: exec_{std::move(exec)}, settings_{settings}
{}

Expand Down Expand Up @@ -167,8 +167,7 @@ class KernelCaller {
int n_shared_total = sconf.n_shared + int(sconf.prec_shared);

// template
// launch_apply_kernel<StopType, subgroup_size, n_shared_total,
// sg_kernel_all>
// launch_apply_kernel<StopType, subgroup_size, n_shared_total>
if (num_rows <= 32 && n_shared_total == 10) {
launch_apply_kernel<StopType, 32, 10>(
sconf, logger, prec, mat, b.values, x.values, workspace_data,
Expand Down Expand Up @@ -256,7 +255,7 @@ void apply(std::shared_ptr<const DefaultExecutor> exec,
batch::log::detail::log_data<remove_complex<ValueType>>& logdata)
{
auto dispatcher = batch::solver::create_dispatcher<ValueType>(
KernelCaller<ValueType>(exec, settings), settings, mat, precond);
kernel_caller<ValueType>(exec, settings), settings, mat, precond);
dispatcher.apply(b, x, logdata);
}

Expand Down
7 changes: 5 additions & 2 deletions dpcpp/solver/batch_bicgstab_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ __dpct_inline__ void initialize(
const auto sg_id = sg.get_group_id();
const auto tid = item_ct1.get_local_linear_id();
const auto group_size = item_ct1.get_local_range().size();
const auto group = item_ct1.get_group();

rho_old = one<ValueType>();
omega = one<ValueType>();
Expand Down Expand Up @@ -296,11 +295,15 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf,
p_sh, p_hat_sh, v_sh, norms_rhs_sh[0], norms_res_sh[0],
item_ct1);
item_ct1.barrier(sycl::access::fence_space::global_and_local);
int iter = 0;

// stopping criterion object
StopType stop(tol, norms_rhs_sh);
if (stop.check_converged(norms_res_sh)) {
logger.log_iteration(batch_id, iter, norms_res_sh[0]);
return;
}

int iter = 0;
for (; iter < max_iter; iter++) {
if (stop.check_converged(norms_res_sh)) {
logger.log_iteration(batch_id, iter, norms_res_sh[0]);
Expand Down
177 changes: 176 additions & 1 deletion dpcpp/solver/batch_cg_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,185 @@ namespace batch_cg {
#include "dpcpp/matrix/batch_csr_kernels.hpp.inc"
#include "dpcpp/matrix/batch_dense_kernels.hpp.inc"
#include "dpcpp/matrix/batch_ell_kernels.hpp.inc"
#include "dpcpp/solver/batch_cg_kernels.hpp.inc"


// Shorthand so the code in this namespace can write settings<T> instead of
// the fully qualified gko::kernels::batch_cg::settings<T>.
template <typename T>
using settings = gko::kernels::batch_cg::settings<T>;


/**
 * Rounds `value` up to a whole multiple of `subgroup_size`.
 *
 * @param value  the count to round up
 * @param subgroup_size  the rounding granularity; defaults to
 *                       config::warp_size
 *
 * @return ceildiv(value, subgroup_size) multiplied by subgroup_size
 */
__dpct_inline__ int get_group_size(int value,
                                   int subgroup_size = config::warp_size)
{
    // The sub-group count is narrowed to int before the multiplication so
    // the product is computed in int arithmetic.
    return static_cast<int>(ceildiv(value, subgroup_size)) * subgroup_size;
}


/**
 * Helper that picks a launch configuration for the batch CG apply kernel and
 * submits it to the SYCL queue of the stored executor.
 *
 * An instance is handed to the batch solver dispatcher, which is expected to
 * invoke call_kernel() once the matrix, preconditioner, stopping criterion
 * and logger types are resolved.
 *
 * @tparam ValueType  the value type of the batch system entries
 */
template <typename ValueType>
class kernel_caller {
public:
    /**
     * @param exec  the executor whose queue the kernels are submitted to
     * @param settings  solver settings; max_iterations and residual_tol are
     *                  read when a kernel is launched
     */
    kernel_caller(std::shared_ptr<const DefaultExecutor> exec,
                  const settings<remove_complex<ValueType>> settings)
        : exec_{std::move(exec)}, settings_{settings}
    {}

    /**
     * Submits a single command group that runs apply_kernel with one
     * work-group per batch item and `group_size` work-items per work-group.
     *
     * @tparam StopType  the stopping criterion type forwarded to apply_kernel
     * @tparam subgroup_size  the required sub-group size of the kernel
     * @tparam n_shared_total  compile-time count forwarded to apply_kernel
     *
     * @param sconf  the storage configuration forwarded to apply_kernel
     * @param logger  the logger forwarded to apply_kernel
     * @param prec  the preconditioner forwarded to apply_kernel
     * @param mat  the batch matrix; each work-group extracts its own item
     * @param b_values  values of the batch right-hand side
     * @param x_values  values of the batch solution
     * @param workspace  global-memory workspace forwarded to apply_kernel
     * @param group_size  number of work-items per work-group
     * @param shared_size  number of ValueType elements of local memory
     */
    template <typename StopType, const int subgroup_size,
              const int n_shared_total, typename PrecType, typename LogType,
              typename BatchMatrixType>
    __dpct_inline__ void launch_apply_kernel(
        const gko::kernels::batch_cg::storage_config& sconf, LogType& logger,
        PrecType& prec, const BatchMatrixType mat,
        const ValueType* const __restrict__ b_values,
        ValueType* const __restrict__ x_values,
        ValueType* const __restrict__ workspace, const int& group_size,
        const int& shared_size) const
    {
        // Plain copies so that the device lambda captures values, not
        // members of *this.
        auto rows_per_item = mat.num_rows;
        auto iteration_limit = settings_.max_iterations;
        auto tolerance = settings_.residual_tol;

        const dim3 work_group_dims(group_size);
        const dim3 grid_dims(mat.num_batch_items);

        exec_->get_queue()->submit([&](sycl::handler& cgh) {
            // Local-memory scratch space handed to apply_kernel.
            sycl::accessor<ValueType, 1, sycl::access_mode::read_write,
                           sycl::access::target::local>
                slm_values(sycl::range<1>(shared_size), cgh);

            cgh.parallel_for(
                sycl_nd_range(grid_dims, work_group_dims),
                [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(
                    subgroup_size)]] [[intel::kernel_args_restrict]] {
                    // One work-group handles one batch item.
                    const auto batch_idx = item_ct1.get_group_linear_id();
                    const auto mat_item =
                        gko::batch::matrix::extract_batch_item(mat, batch_idx);
                    const ValueType* const b_item =
                        gko::batch::multi_vector::batch_item_ptr(
                            b_values, 1, rows_per_item, batch_idx);
                    ValueType* const x_item =
                        gko::batch::multi_vector::batch_item_ptr(
                            x_values, 1, rows_per_item, batch_idx);
                    apply_kernel<StopType, n_shared_total>(
                        sconf, iteration_limit, tolerance, logger, prec,
                        mat_item, b_item, x_item, rows_per_item,
                        mat.get_single_item_num_nnz(),
                        static_cast<ValueType*>(slm_values.get_pointer()),
                        item_ct1, workspace);
                });
        });
    }

    /**
     * Computes the work-group size, the local-memory budget and the global
     * workspace for the given batch system, then launches the matching
     * instantiation of launch_apply_kernel().
     *
     * @param logger  the logger passed on to the kernel
     * @param mat  the batch matrix
     * @param prec  the preconditioner passed on to the kernel
     * @param b  the batch right-hand side; must have exactly one column
     * @param x  the batch solution
     */
    template <typename BatchMatrixType, typename PrecType, typename StopType,
              typename LogType>
    void call_kernel(
        LogType logger, const BatchMatrixType& mat, PrecType prec,
        const gko::batch::multi_vector::uniform_batch<const ValueType>& b,
        const gko::batch::multi_vector::uniform_batch<ValueType>& x) const
    {
        using real_type = typename gko::remove_complex<ValueType>;
        const size_type num_batch_items = mat.num_batch_items;
        const auto num_rows = mat.num_rows;
        const auto num_rhs = b.num_rhs;
        GKO_ASSERT(num_rhs == 1);

        auto device = exec_->get_queue()->get_device();

        // Start from the device limit, shrink to the row count rounded up to
        // whole sub-groups when that is smaller, then clamp the result to
        // [2 * config::warp_size, max_group_size].
        const int device_group_limit =
            device.get_info<sycl::info::device::max_work_group_size>();
        const int fitted_group_size = device_group_limit > num_rows
                                          ? get_group_size(num_rows)
                                          : device_group_limit;
        const int group_size = std::min(
            std::max(fitted_group_size,
                     static_cast<int>(2 * config::warp_size)),
            static_cast<int>(max_group_size));

        // Local memory set aside for kernel-internal scalars (rho, norms,
        // alpha) and for reduce_over_group; whatever is left, floored at
        // zero, is the budget given to the storage configuration.
        const int static_var_mem =
            (group_size + 3) * sizeof(ValueType) + 2 * sizeof(real_type);
        const int device_local_mem = static_cast<int>(
            device.get_info<sycl::info::device::local_mem_size>());
        int shmem_per_blk = std::max(device_local_mem - static_var_mem, 0);

        const int padded_num_rows = num_rows;
        const auto nnz_per_item = mat.get_single_item_num_nnz();
        const size_type prec_size =
            PrecType::dynamic_work_size(padded_num_rows, nnz_per_item);
        const auto sconf =
            gko::kernels::batch_cg::compute_shared_storage<PrecType, ValueType>(
                shmem_per_blk, padded_num_rows, nnz_per_item, b.num_rhs);

        // The preconditioner workspace only counts towards local memory when
        // the configuration places it there.
        const size_type prec_local_elems = sconf.prec_shared ? prec_size : 0;
        const size_t shared_size =
            sconf.n_shared * padded_num_rows + prec_local_elems;

        const auto workspace_elems =
            sconf.gmem_stride_bytes * num_batch_items / sizeof(ValueType);
        gko::array<ValueType> workspace(exec_, workspace_elems);
        GKO_ASSERT(sconf.gmem_stride_bytes % sizeof(ValueType) == 0);

        ValueType* const workspace_data = workspace.get_data();
        const int n_shared_total =
            sconf.n_shared + static_cast<int>(sconf.prec_shared);

        // Template arguments of launch_apply_kernel:
        // <StopType, subgroup_size, n_shared_total>
        switch (n_shared_total) {
        case 0:
            launch_apply_kernel<StopType, 32, 0>(
                sconf, logger, prec, mat, b.values, x.values, workspace_data,
                group_size, shared_size);
            break;
        case 1:
            launch_apply_kernel<StopType, 32, 1>(
                sconf, logger, prec, mat, b.values, x.values, workspace_data,
                group_size, shared_size);
            break;
        case 2:
            launch_apply_kernel<StopType, 32, 2>(
                sconf, logger, prec, mat, b.values, x.values, workspace_data,
                group_size, shared_size);
            break;
        case 3:
            launch_apply_kernel<StopType, 32, 3>(
                sconf, logger, prec, mat, b.values, x.values, workspace_data,
                group_size, shared_size);
            break;
        case 4:
            launch_apply_kernel<StopType, 32, 4>(
                sconf, logger, prec, mat, b.values, x.values, workspace_data,
                group_size, shared_size);
            break;
        case 5:
            launch_apply_kernel<StopType, 32, 5>(
                sconf, logger, prec, mat, b.values, x.values, workspace_data,
                group_size, shared_size);
            break;
        case 6:
            // Systems with at most 32 rows run with a sub-group size of 16;
            // all larger systems use 32.
            if (num_rows <= 32) {
                launch_apply_kernel<StopType, 16, 6>(
                    sconf, logger, prec, mat, b.values, x.values,
                    workspace_data, group_size, shared_size);
            } else {
                launch_apply_kernel<StopType, 32, 6>(
                    sconf, logger, prec, mat, b.values, x.values,
                    workspace_data, group_size, shared_size);
            }
            break;
        default:
            GKO_NOT_IMPLEMENTED;
        }
    }

private:
    std::shared_ptr<const DefaultExecutor> exec_;
    const settings<remove_complex<ValueType>> settings_;
};


template <typename ValueType>
void apply(std::shared_ptr<const DefaultExecutor> exec,
const settings<remove_complex<ValueType>>& settings,
Expand All @@ -58,7 +231,9 @@ void apply(std::shared_ptr<const DefaultExecutor> exec,
batch::MultiVector<ValueType>* const x,
batch::log::detail::log_data<remove_complex<ValueType>>& logdata)
{
GKO_NOT_IMPLEMENTED;
auto dispatcher = batch_solver::create_dispatcher<ValueType>(
kernel_caller<ValueType>(exec, settings), settings, mat, prec);
dispatcher.apply(b, x, logdata);
}


Expand Down
Loading

0 comments on commit 03b7cac

Please sign in to comment.