Skip to content

Commit

Permalink
Merge (#1443): Add device kernels for batch Bicgstab solver.
Browse files Browse the repository at this point in the history
This PR adds the batched BiCGSTAB solver kernels for the CUDA, HIP and DPCPP backends. It also adds some single-right-hand-side vector kernels to the batch multi-vector kernels.

Related PR: #1443
  • Loading branch information
pratikvn authored Nov 5, 2023
2 parents 4a4eeb8 + a1b84d4 commit 47b3267
Show file tree
Hide file tree
Showing 37 changed files with 2,128 additions and 236 deletions.
54 changes: 54 additions & 0 deletions common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,28 @@ __global__ __launch_bounds__(
}


/**
 * Computes the conjugated dot product of two single-column vectors,
 * reduced cooperatively over one subgroup (subwarp).
 *
 * Each thread of the subgroup accumulates a strided partial sum of
 * conj(x[r]) * y[r]; the partials are then combined via a subgroup-level
 * reduction. Only the subgroup's rank-0 thread writes the final value to
 * `result`; the other lanes leave it untouched.
 */
template <typename Group, typename ValueType>
__device__ __forceinline__ void single_rhs_compute_conj_dot(Group subgroup,
                                                            const int num_rows,
                                                            const ValueType* x,
                                                            const ValueType* y,
                                                            ValueType& result)
{
    const auto rank = subgroup.thread_rank();
    auto partial = zero<ValueType>();
    for (int row = rank; row < num_rows; row += subgroup.size()) {
        partial += conj(x[row]) * y[row];
    }

    // combine the per-thread partial sums across the subgroup
    partial = reduce(subgroup, partial, thrust::plus<ValueType>{});

    if (rank == 0) {
        result = partial;
    }
}


template <typename Group, typename ValueType, typename Mapping>
__device__ __forceinline__ void gen_one_dot(
const gko::batch::multi_vector::batch_item<const ValueType>& x,
Expand Down Expand Up @@ -165,6 +187,27 @@ __launch_bounds__(default_block_size, sm_oversubscription) void compute_gen_dot_
}


/**
 * Computes the Euclidean (2-)norm of a single-column vector, reduced
 * cooperatively over one subgroup (subwarp).
 *
 * Each thread accumulates a strided partial sum of squared norms; the
 * partials are combined via a subgroup-level reduction and the square root
 * of the total is written to `result` by the subgroup's rank-0 thread only.
 */
template <typename Group, typename ValueType>
__device__ __forceinline__ void single_rhs_compute_norm2(
    Group subgroup, const int num_rows, const ValueType* x,
    remove_complex<ValueType>& result)
{
    // `remove_complex` is an alias template (used without `typename` in the
    // signature above), so no `typename` is needed here.
    using real_type = gko::remove_complex<ValueType>;
    real_type val = zero<real_type>();

    for (int r = subgroup.thread_rank(); r < num_rows; r += subgroup.size()) {
        val += squared_norm(x[r]);
    }

    // subgroup level reduction; reuse the real_type alias for consistency
    val = reduce(subgroup, val, thrust::plus<real_type>{});

    if (subgroup.thread_rank() == 0) {
        result = sqrt(val);
    }
}


template <typename Group, typename ValueType>
__device__ __forceinline__ void one_norm2(
const gko::batch::multi_vector::batch_item<const ValueType>& x,
Expand Down Expand Up @@ -238,6 +281,17 @@ __global__ __launch_bounds__(
}


/**
 * Copies a single-column vector, distributing the rows over all threads of
 * the launching thread block (block-stride loop).
 *
 * Every thread of the block must call this; no synchronization is performed.
 */
template <typename ValueType>
__device__ __forceinline__ void single_rhs_copy(const int num_rows,
                                                const ValueType* in,
                                                ValueType* out)
{
    for (int row = threadIdx.x; row < num_rows; row += blockDim.x) {
        out[row] = in[row];
    }
}


/**
* Copies the values of one multi-vector into another.
*
Expand Down
2 changes: 1 addition & 1 deletion common/cuda_hip/log/batch_logger.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
template <typename RealType>
class SimpleFinalLogger final {
public:
using real_type = remove_complex<RealType>;
using real_type = RealType;

SimpleFinalLogger(real_type* const batch_residuals, int* const batch_iters)
: final_residuals_{batch_residuals}, final_iters_{batch_iters}
Expand Down
13 changes: 3 additions & 10 deletions common/cuda_hip/preconditioner/batch_identity.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,9 @@ public:
return 0;
}

__device__ __forceinline__ void generate(
size_type,
const gko::batch::matrix::ell::batch_item<const ValueType, gko::int32>&,
ValueType*)
{}

__device__ __forceinline__ void generate(
size_type,
const gko::batch::matrix::dense::batch_item<const ValueType>&,
ValueType*)
template <typename batch_item_type>
__device__ __forceinline__ void generate(size_type, const batch_item_type&,
ValueType*)
{}

__device__ __forceinline__ void apply(const int num_rows,
Expand Down
Loading

0 comments on commit 47b3267

Please sign in to comment.