Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[dpcpp] move to proper headers
Browse files Browse the repository at this point in the history
pratikvn committed Aug 19, 2024

Verified

This commit was signed with the committer’s verified signature.
pratikvn Pratik Nayak
1 parent 8cbcae2 commit 43e788f
Showing 6 changed files with 102 additions and 63 deletions.
64 changes: 30 additions & 34 deletions dpcpp/base/batch_multi_vector_kernels.dp.cpp
Original file line number Diff line number Diff line change
@@ -15,6 +15,7 @@

#include "core/base/batch_struct.hpp"
#include "core/components/prefix_sum_kernels.hpp"
#include "dpcpp/base/batch_multi_vector_kernels.hpp"
#include "dpcpp/base/batch_struct.hpp"
#include "dpcpp/base/config.hpp"
#include "dpcpp/base/dim3.dp.hpp"
@@ -29,17 +30,9 @@
namespace gko {
namespace kernels {
namespace dpcpp {
/**
* @brief The MultiVector matrix format namespace.
* @ref MultiVector
* @ingroup batch_multi_vector
*/
namespace batch_multi_vector {


#include "dpcpp/base/batch_multi_vector_kernels.hpp.inc"


template <typename ValueType>
void scale(std::shared_ptr<const DefaultExecutor> exec,
const batch::MultiVector<ValueType>* const alpha,
@@ -71,7 +64,7 @@ void scale(std::shared_ptr<const DefaultExecutor> exec,
const auto alpha_b =
batch::extract_batch_item(alpha_ub, group_id);
const auto x_b = batch::extract_batch_item(x_ub, group_id);
scale_kernel(
batch_single_kernels::scale_kernel(
alpha_b, x_b, item_ct1,
[](int row, int col, int stride) { return 0; });
});
@@ -85,10 +78,11 @@ void scale(std::shared_ptr<const DefaultExecutor> exec,
const auto alpha_b =
batch::extract_batch_item(alpha_ub, group_id);
const auto x_b = batch::extract_batch_item(x_ub, group_id);
scale_kernel(alpha_b, x_b, item_ct1,
[](int row, int col, int stride) {
return row * stride + col;
});
batch_single_kernels::scale_kernel(
alpha_b, x_b, item_ct1,
[](int row, int col, int stride) {
return row * stride + col;
});
});
});
} else {
@@ -100,7 +94,7 @@ void scale(std::shared_ptr<const DefaultExecutor> exec,
const auto alpha_b =
batch::extract_batch_item(alpha_ub, group_id);
const auto x_b = batch::extract_batch_item(x_ub, group_id);
scale_kernel(
batch_single_kernels::scale_kernel(
alpha_b, x_b, item_ct1,
[](int row, int col, int stride) { return col; });
});
@@ -144,8 +138,9 @@ void add_scaled(std::shared_ptr<const DefaultExecutor> exec,
batch::extract_batch_item(alpha_ub, group_id);
const auto x_b = batch::extract_batch_item(x_ub, group_id);
const auto y_b = batch::extract_batch_item(y_ub, group_id);
add_scaled_kernel(alpha_b, x_b, y_b, item_ct1,
[](auto col) { return 0; });
batch_single_kernels::add_scaled_kernel(
alpha_b, x_b, y_b, item_ct1,
[](auto col) { return 0; });
});
});
} else {
@@ -158,8 +153,9 @@ void add_scaled(std::shared_ptr<const DefaultExecutor> exec,
batch::extract_batch_item(alpha_ub, group_id);
const auto x_b = batch::extract_batch_item(x_ub, group_id);
const auto y_b = batch::extract_batch_item(y_ub, group_id);
add_scaled_kernel(alpha_b, x_b, y_b, item_ct1,
[](auto col) { return col; });
batch_single_kernels::add_scaled_kernel(
alpha_b, x_b, y_b, item_ct1,
[](auto col) { return col; });
});
});
}
@@ -206,7 +202,7 @@ void compute_dot(std::shared_ptr<const DefaultExecutor> exec,
batch::extract_batch_item(y_ub, group_id);
const auto res_b =
batch::extract_batch_item(res_ub, group_id);
single_rhs_compute_conj_dot_sg(
batch_single_kernels::single_rhs_compute_conj_dot_sg(
x_b.num_rows, x_b.values, y_b.values,
res_b.values[0], item_ct1);
});
@@ -226,7 +222,7 @@ void compute_dot(std::shared_ptr<const DefaultExecutor> exec,
batch::extract_batch_item(y_ub, group_id);
const auto res_b =
batch::extract_batch_item(res_ub, group_id);
compute_gen_dot_product_kernel(
batch_single_kernels::compute_gen_dot_product_kernel(
x_b, y_b, res_b, item_ct1,
[](auto val) { return val; });
});
@@ -272,7 +268,7 @@ void compute_conj_dot(std::shared_ptr<const DefaultExecutor> exec,
const auto y_b = batch::extract_batch_item(y_ub, group_id);
const auto res_b =
batch::extract_batch_item(res_ub, group_id);
compute_gen_dot_product_kernel(
batch_single_kernels::compute_gen_dot_product_kernel(
x_b, y_b, res_b, item_ct1,
[](auto val) { return conj(val); });
});
@@ -308,17 +304,16 @@ void compute_norm2(std::shared_ptr<const DefaultExecutor> exec,
exec->get_queue()->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl_nd_range(grid, block),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(max_subgroup_size)]] {
auto group = item_ct1.get_group();
auto group_id = group.get_group_linear_id();
const auto x_b =
batch::extract_batch_item(x_ub, group_id);
const auto res_b =
batch::extract_batch_item(res_ub, group_id);
single_rhs_compute_norm2_sg(x_b.num_rows, x_b.values,
res_b.values[0], item_ct1);
});
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(
max_subgroup_size)]] {
auto group = item_ct1.get_group();
auto group_id = group.get_group_linear_id();
const auto x_b = batch::extract_batch_item(x_ub, group_id);
const auto res_b =
batch::extract_batch_item(res_ub, group_id);
batch_single_kernels::single_rhs_compute_norm2_sg(
x_b.num_rows, x_b.values, res_b.values[0], item_ct1);
});
});
} else {
exec->get_queue()->submit([&](sycl::handler& cgh) {
@@ -332,7 +327,8 @@ void compute_norm2(std::shared_ptr<const DefaultExecutor> exec,
batch::extract_batch_item(x_ub, group_id);
const auto res_b =
batch::extract_batch_item(res_ub, group_id);
compute_norm2_kernel(x_b, res_b, item_ct1);
batch_single_kernels::compute_norm2_kernel(x_b, res_b,
item_ct1);
});
});
}
@@ -371,7 +367,7 @@ void copy(std::shared_ptr<const DefaultExecutor> exec,
const auto x_b = batch::extract_batch_item(x_ub, group_id);
const auto result_b =
batch::extract_batch_item(result_ub, group_id);
copy_kernel(x_b, result_b, item_ct1);
batch_single_kernels::copy_kernel(x_b, result_b, item_ct1);
});
});
}
Original file line number Diff line number Diff line change
@@ -2,6 +2,29 @@
//
// SPDX-License-Identifier: BSD-3-Clause


#include <memory>

#include <CL/sycl.hpp>

#include "core/base/batch_struct.hpp"
#include "dpcpp/base/batch_struct.hpp"
#include "dpcpp/base/config.hpp"
#include "dpcpp/base/dim3.dp.hpp"
#include "dpcpp/base/dpct.hpp"
#include "dpcpp/base/helper.hpp"
#include "dpcpp/components/cooperative_groups.dp.hpp"
#include "dpcpp/components/intrinsics.dp.hpp"
#include "dpcpp/components/reduction.dp.hpp"
#include "dpcpp/components/thread_ids.dp.hpp"


namespace gko {
namespace kernels {
namespace GKO_DEVICE_NAMESPACE {
namespace batch_single_kernels {


template <typename ValueType, typename Mapping>
__dpct_inline__ void scale_kernel(
const gko::batch::multi_vector::batch_item<const ValueType>& alpha,
@@ -229,3 +252,9 @@ __dpct_inline__ void copy_kernel(
out.values[i * out.stride + j] = in.values[i * in.stride + j];
}
}


} // namespace batch_single_kernels
} // namespace GKO_DEVICE_NAMESPACE
} // namespace kernels
} // namespace gko
2 changes: 1 addition & 1 deletion dpcpp/solver/batch_bicgstab_kernels.dp.cpp
Original file line number Diff line number Diff line change
@@ -13,6 +13,7 @@
#include "core/base/batch_struct.hpp"
#include "core/matrix/batch_struct.hpp"
#include "core/solver/batch_dispatch.hpp"
#include "dpcpp/base/batch_multi_vector_kernels.hpp"
#include "dpcpp/base/batch_struct.hpp"
#include "dpcpp/base/config.hpp"
#include "dpcpp/base/dim3.dp.hpp"
@@ -36,7 +37,6 @@ namespace dpcpp {
namespace batch_bicgstab {


#include "dpcpp/base/batch_multi_vector_kernels.hpp.inc"
#include "dpcpp/matrix/batch_csr_kernels.hpp.inc"
#include "dpcpp/matrix/batch_dense_kernels.hpp.inc"
#include "dpcpp/matrix/batch_ell_kernels.hpp.inc"
43 changes: 26 additions & 17 deletions dpcpp/solver/batch_bicgstab_kernels.hpp.inc
Original file line number Diff line number Diff line change
@@ -39,11 +39,13 @@ __dpct_inline__ void initialize(
item_ct1.barrier(sycl::access::fence_space::global_and_local);

if (sg_id == 0) {
single_rhs_compute_norm2_sg(num_rows, r_shared_entry, res_norm,
item_ct1);
gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
single_rhs_compute_norm2_sg(num_rows, r_shared_entry, res_norm,
item_ct1);
} else if (sg_id == 1) {
single_rhs_compute_norm2_sg(num_rows, b_global_entry, rhs_norm,
item_ct1);
gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
single_rhs_compute_norm2_sg(num_rows, b_global_entry, rhs_norm,
item_ct1);
}
item_ct1.barrier(sycl::access::fence_space::global_and_local);

@@ -86,8 +88,9 @@ __dpct_inline__ void compute_alpha(const int num_rows, const ValueType& rho_new,
const auto sg_id = sg.get_group_id();
const auto tid = item_ct1.get_local_linear_id();
if (sg_id == 0) {
single_rhs_compute_conj_dot_sg(num_rows, r_hat_shared_entry,
v_shared_entry, alpha, item_ct1);
gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
single_rhs_compute_conj_dot_sg(num_rows, r_hat_shared_entry,
v_shared_entry, alpha, item_ct1);
}
item_ct1.barrier(sycl::access::fence_space::global_and_local);
if (tid == 0) {
@@ -123,11 +126,13 @@ __dpct_inline__ void compute_omega(const int num_rows,
const auto sg_id = sg.get_group_id();
const auto tid = item_ct1.get_local_linear_id();
if (sg_id == 0) {
single_rhs_compute_conj_dot_sg(num_rows, t_shared_entry, s_shared_entry,
omega, item_ct1);
gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
single_rhs_compute_conj_dot_sg(num_rows, t_shared_entry,
s_shared_entry, omega, item_ct1);
} else if (sg_id == 1) {
single_rhs_compute_conj_dot_sg(num_rows, t_shared_entry, t_shared_entry,
temp, item_ct1);
gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
single_rhs_compute_conj_dot_sg(num_rows, t_shared_entry,
t_shared_entry, temp, item_ct1);
}
item_ct1.barrier(sycl::access::fence_space::global_and_local);
if (tid == 0) {
@@ -308,8 +313,9 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf,

// rho_new = < r_hat , r > = (r_hat)' * (r)
if (sg_id == 0) {
single_rhs_compute_conj_dot_sg(num_rows, r_hat_sh, r_sh,
rho_new_sh[0], item_ct1);
gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
single_rhs_compute_conj_dot_sg(num_rows, r_hat_sh, r_sh,
rho_new_sh[0], item_ct1);
}
item_ct1.barrier(sycl::access::fence_space::global_and_local);

@@ -338,8 +344,9 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf,

// an estimate of residual norms
if (sg_id == 0) {
single_rhs_compute_norm2_sg(num_rows, s_sh, norms_res_sh[0],
item_ct1);
gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
single_rhs_compute_norm2_sg(num_rows, s_sh, norms_res_sh[0],
item_ct1);
}
item_ct1.barrier(sycl::access::fence_space::global_and_local);

@@ -368,8 +375,9 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf,
item_ct1.barrier(sycl::access::fence_space::global_and_local);

if (sg_id == 0)
single_rhs_compute_norm2_sg(num_rows, r_sh, norms_res_sh[0],
item_ct1);
gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
single_rhs_compute_norm2_sg(num_rows, r_sh, norms_res_sh[0],
item_ct1);
if (tid == group_size - 1) {
rho_old_sh[0] = rho_new_sh[0];
}
@@ -379,6 +387,7 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf,
logger.log_iteration(batch_id, iter, norms_res_sh[0]);

// copy x back to global memory
copy_kernel(num_rows, x_sh, x_global_entry, item_ct1);
gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::copy_kernel(
num_rows, x_sh, x_global_entry, item_ct1);
item_ct1.barrier(sycl::access::fence_space::global_and_local);
}
2 changes: 1 addition & 1 deletion dpcpp/solver/batch_cg_kernels.dp.cpp
Original file line number Diff line number Diff line change
@@ -13,6 +13,7 @@
#include "core/base/batch_struct.hpp"
#include "core/matrix/batch_struct.hpp"
#include "core/solver/batch_dispatch.hpp"
#include "dpcpp/base/batch_multi_vector_kernels.hpp"
#include "dpcpp/base/batch_struct.hpp"
#include "dpcpp/base/config.hpp"
#include "dpcpp/base/dim3.dp.hpp"
@@ -36,7 +37,6 @@ namespace dpcpp {
namespace batch_cg {


#include "dpcpp/base/batch_multi_vector_kernels.hpp.inc"
#include "dpcpp/matrix/batch_csr_kernels.hpp.inc"
#include "dpcpp/matrix/batch_dense_kernels.hpp.inc"
#include "dpcpp/matrix/batch_ell_kernels.hpp.inc"
25 changes: 15 additions & 10 deletions dpcpp/solver/batch_cg_kernels.hpp.inc
Original file line number Diff line number Diff line change
@@ -40,11 +40,13 @@ __dpct_inline__ void initialize(
// Compute norms of rhs
// and rho_old = r' * z
if (sg_id == 0) {
single_rhs_compute_norm2_sg(num_rows, b_global_entry, rhs_norms,
item_ct1);
gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
single_rhs_compute_norm2_sg(num_rows, b_global_entry, rhs_norms,
item_ct1);
} else if (sg_id == 1) {
single_rhs_compute_conj_dot_sg(num_rows, r_shared_entry, z_shared_entry,
rho_old, item_ct1);
gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
single_rhs_compute_conj_dot_sg(num_rows, r_shared_entry,
z_shared_entry, rho_old, item_ct1);
}
item_ct1.barrier(sycl::access::fence_space::global_and_local);

@@ -80,9 +82,10 @@ __dpct_inline__ void update_x_and_r(
auto sg = item_ct1.get_sub_group();
const auto tid = item_ct1.get_local_linear_id();
if (sg.get_group_id() == 0) {
single_rhs_compute_conj_dot_sg(num_rows, p_shared_entry,
Ap_shared_entry, alpha_shared_entry,
item_ct1);
gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
single_rhs_compute_conj_dot_sg(num_rows, p_shared_entry,
Ap_shared_entry, alpha_shared_entry,
item_ct1);
}
item_ct1.barrier(sycl::access::fence_space::global_and_local);
if (tid == 0) {
@@ -221,8 +224,9 @@ __dpct_inline__ void apply_kernel(

// rho_new = (r)' * (z)
if (sg_id == 0) {
single_rhs_compute_conj_dot_sg(num_rows, r_sh, z_sh, rho_new_sh[0],
item_ct1);
gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
single_rhs_compute_conj_dot_sg(num_rows, r_sh, z_sh,
rho_new_sh[0], item_ct1);
}
item_ct1.barrier(sycl::access::fence_space::global_and_local);

@@ -239,6 +243,7 @@ __dpct_inline__ void apply_kernel(
logger.log_iteration(batch_id, iter, norms_res_sh[0]);

// copy x back to global memory
copy_kernel(num_rows, x_sh, x_global_entry, item_ct1);
gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::copy_kernel(
num_rows, x_sh, x_global_entry, item_ct1);
item_ct1.barrier(sycl::access::fence_space::global_and_local);
}

0 comments on commit 43e788f

Please sign in to comment.