diff --git a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp index 36a82807011..aab068d103e 100644 --- a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp +++ b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp @@ -135,10 +135,10 @@ class kernel_caller { static_cast(max_group_size)); // reserve 5 for intermediate rho-s, norms, - // alpha, omega, temp and for reduce_over_group + // alpha, omega, temp // If the value available is negative, then set it to 0 const int static_var_mem = - (group_size + 5) * sizeof(ValueType) + 2 * sizeof(real_type); + 5 * sizeof(ValueType) + 2 * sizeof(real_type); int shmem_per_blk = std::max( static_cast( device.get_info()) - diff --git a/dpcpp/solver/batch_cg_kernels.dp.cpp b/dpcpp/solver/batch_cg_kernels.dp.cpp index 31f6ad0e40d..02c40424a35 100644 --- a/dpcpp/solver/batch_cg_kernels.dp.cpp +++ b/dpcpp/solver/batch_cg_kernels.dp.cpp @@ -135,10 +135,10 @@ class kernel_caller { static_cast(max_group_size)); // reserve 3 for intermediate rho, - // alpha, reduce_over_group, and two norms + // alpha and two norms // If the value available is negative, then set it to 0 const int static_var_mem = - (group_size + 3) * sizeof(ValueType) + 2 * sizeof(real_type); + 3 * sizeof(ValueType) + 2 * sizeof(real_type); int shmem_per_blk = std::max( static_cast( device.get_info()) -