Skip to content

Commit

Permalink
Reshape hessenberg_iter view to 'logical layout' (one column per rhs)
Browse files Browse the repository at this point in the history
for kernels that do not use the full Hessenberg matrix
  • Loading branch information
nbeams committed Aug 8, 2024
1 parent 0f2706c commit adf30a2
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 90 deletions.
29 changes: 13 additions & 16 deletions common/unified/solver/common_gmres_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,30 +69,28 @@ void hessenberg_qr(std::shared_ptr<const DefaultExecutor> exec,
exec,
[] GKO_KERNEL(auto rhs, auto givens_sin, auto givens_cos,
auto residual_norm, auto residual_norm_collection,
auto hessenberg_iter, auto iter, auto num_rhs,
auto final_iter_nums, auto stop_status) {
auto hessenberg_iter, auto iter, auto final_iter_nums,
auto stop_status) {
using value_type = std::decay_t<decltype(givens_sin(0, 0))>;
if (stop_status[rhs].has_stopped()) {
return;
}
// increment iteration count
final_iter_nums[rhs]++;
auto hess_this =
hessenberg_iter(0, rhs); // hessenberg_iter(0, rhs);
auto hess_next =
hessenberg_iter(0, num_rhs + rhs); // hessenberg_iter(1, rhs);
auto hess_this = hessenberg_iter(0, rhs);
auto hess_next = hessenberg_iter(1, rhs);
// apply previous Givens rotations to column
for (decltype(iter) j = 0; j < iter; ++j) {
// in here: hess_this = hessenberg_iter(j, rhs);
// hess_next = hessenberg_iter(j+1, rhs);
hess_next = hessenberg_iter(0, (j + 1) * num_rhs + rhs);
hess_next = hessenberg_iter(j + 1, rhs);
const auto gc = givens_cos(j, rhs);
const auto gs = givens_sin(j, rhs);
const auto out1 = gc * hess_this + gs * hess_next;
const auto out2 = -conj(gs) * hess_this + conj(gc) * hess_next;
hessenberg_iter(0, j * num_rhs + rhs) = out1;
hessenberg_iter(0, (j + 1) * num_rhs + rhs) = hess_this = out2;
hess_next = hessenberg_iter(0, (j + 2) * num_rhs + rhs);
hessenberg_iter(j, rhs) = out1;
hessenberg_iter(j + 1, rhs) = hess_this = out2;
hess_next = hessenberg_iter(j + 2, rhs);
}
// hess_this is hessenberg_iter(iter, rhs) and
// hess_next is hessenberg_iter(iter + 1, rhs)
Expand All @@ -112,9 +110,8 @@ void hessenberg_qr(std::shared_ptr<const DefaultExecutor> exec,
givens_sin(iter, rhs) = gs = conj(hess_next) / hypotenuse;
}
// apply new Givens rotation to column
hessenberg_iter(0, iter * num_rhs + rhs) =
gc * hess_this + gs * hess_next;
hessenberg_iter(0, (iter + 1) * num_rhs + rhs) = zero<value_type>();
hessenberg_iter(iter, rhs) = gc * hess_this + gs * hess_next;
hessenberg_iter(iter + 1, rhs) = zero<value_type>();
// apply new Givens rotation to RHS of least-squares problem
const auto rnc_new =
-conj(gs) * residual_norm_collection(iter, rhs);
Expand All @@ -123,9 +120,9 @@ void hessenberg_qr(std::shared_ptr<const DefaultExecutor> exec,
gc * residual_norm_collection(iter, rhs);
residual_norm(0, rhs) = abs(rnc_new);
},
residual_norm->get_size()[1], givens_sin, givens_cos, residual_norm,
residual_norm_collection, hessenberg_iter, iter,
residual_norm->get_size()[1], final_iter_nums, stop_status);
hessenberg_iter->get_size()[1], givens_sin, givens_cos, residual_norm,
residual_norm_collection, hessenberg_iter, iter, final_iter_nums,
stop_status);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(
Expand Down
6 changes: 4 additions & 2 deletions common/unified/solver/gmres_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,10 @@ void multi_dot(std::shared_ptr<const DefaultExecutor> exec,
next_krylov(row, irhs);
},
GKO_KERNEL_REDUCE_SUM(ValueType), hessenberg_col->get_values(),
gko::dim<2>{next_krylov->get_size()[0],
hessenberg_col->get_size()[1] - next_krylov->get_size()[1]},
gko::dim<2>{
next_krylov->get_size()[0],
hessenberg_col->get_size()[0] * hessenberg_col->get_size()[1] -
next_krylov->get_size()[1]},
krylov_bases, next_krylov, next_krylov->get_size()[1],
next_krylov->get_size()[0]);
}
Expand Down
49 changes: 27 additions & 22 deletions core/solver/gmres.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,8 @@ void orthogonalize_mgs(matrix::Dense<ValueType>* hessenberg_iter,
// i)
// next_krylov -= hessenberg(i, restart_iter) * krylov_bases(:,
// i)
auto hessenberg_entry = hessenberg_iter->create_submatrix(
span{0, 1}, span{i * num_rhs, (i + 1) * num_rhs});
auto hessenberg_entry =
hessenberg_iter->create_submatrix(span{i, i + 1}, span{0, num_rhs});
auto krylov_basis = ::gko::detail::create_submatrix_helper(
krylov_bases, dim<2>{num_rows, num_rhs},
span{local_num_rows * i, local_num_rows * (i + 1)},
Expand Down Expand Up @@ -191,19 +191,18 @@ void finish_reduce(matrix::Dense<ValueType>* hessenberg_iter,
// below the diagonal in the "full" matrix are skipped, because they will
// be used to hold the norm of next_krylov for each rhs.
auto hessenberg_reduce = hessenberg_iter->create_submatrix(
span{0, 1}, span{0, num_rhs * (restart_iter + 1)});
span{0, restart_iter + 1}, span{0, num_rhs});
int message_size = static_cast<int>((restart_iter + 1) * num_rhs);
if (experimental::mpi::requires_host_buffer(exec, comm)) {
::gko::detail::DenseCache<ValueType> host_reduction_buffer;
host_reduction_buffer.init(exec->get_master(),
hessenberg_reduce->get_size());
host_reduction_buffer->copy_from(hessenberg_reduce);
comm.all_reduce(exec->get_master(), host_reduction_buffer->get_values(),
static_cast<int>(hessenberg_reduce->get_size()[1]),
MPI_SUM);
message_size, MPI_SUM);
hessenberg_reduce->copy_from(host_reduction_buffer.get());
} else {
comm.all_reduce(exec, hessenberg_reduce->get_values(),
static_cast<int>(hessenberg_reduce->get_size()[1]),
comm.all_reduce(exec, hessenberg_reduce->get_values(), message_size,
MPI_SUM);
}
}
Expand All @@ -228,8 +227,8 @@ void orthogonalize_cgs(matrix::Dense<ValueType>* hessenberg_iter,
for (size_type i = 0; i <= restart_iter; i++) {
// next_krylov -= hessenberg(i, restart_iter) * krylov_bases(:,
// i)
auto hessenberg_entry = hessenberg_iter->create_submatrix(
span{0, 1}, span{i * num_rhs, (i + 1) * num_rhs});
auto hessenberg_entry =
hessenberg_iter->create_submatrix(span{i, i + 1}, span{0, num_rhs});
auto krylov_col = ::gko::detail::create_submatrix_helper(
krylov_bases, dim<2>{num_rows, num_rhs},
span{local_num_rows * i, local_num_rows * (i + 1)},
Expand Down Expand Up @@ -260,8 +259,8 @@ void orthogonalize_cgs2(matrix::Dense<ValueType>* hessenberg_iter,
for (size_type i = 0; i <= restart_iter; i++) {
// next_krylov -= hessenberg(i, restart_iter) * krylov_bases(:,
// i)
auto hessenberg_entry = hessenberg_iter->create_submatrix(
span{0, 1}, span{i * num_rhs, (i + 1) * num_rhs});
auto hessenberg_entry =
hessenberg_iter->create_submatrix(span{i, i + 1}, span{0, num_rhs});
auto krylov_col = ::gko::detail::create_submatrix_helper(
krylov_bases, dim<2>{num_rows, num_rhs},
span{local_num_rows * i, local_num_rows * (i + 1)},
Expand All @@ -270,7 +269,7 @@ void orthogonalize_cgs2(matrix::Dense<ValueType>* hessenberg_iter,
}
// Re-orthogonalize
auto hessenberg_aux_iter = hessenberg_aux->create_submatrix(
span{0, 1}, span{0, (restart_iter + 2) * num_rhs});
span{0, restart_iter + 2}, span{0, num_rhs});
exec->run(gmres::make_multi_dot(
gko::detail::get_local(krylov_basis_small.get()),
gko::detail::get_local(next_krylov), hessenberg_aux_iter.get()));
Expand All @@ -280,8 +279,8 @@ void orthogonalize_cgs2(matrix::Dense<ValueType>* hessenberg_iter,
for (size_type i = 0; i <= restart_iter; i++) {
// next_krylov -= hessenberg(i, restart_iter) * krylov_bases(:,
// i)
auto hessenberg_entry = hessenberg_aux->create_submatrix(
span{0, 1}, span{i * num_rhs, (i + 1) * num_rhs});
auto hessenberg_entry =
hessenberg_aux->create_submatrix(span{i, i + 1}, span{0, num_rhs});
auto krylov_col = ::gko::detail::create_submatrix_helper(
krylov_bases, dim<2>{num_rows, num_rhs},
span{local_num_rows * i, local_num_rows * (i + 1)},
Expand Down Expand Up @@ -353,10 +352,13 @@ void Gmres<ValueType>::apply_dense_impl(const VectorType* dense_b,
// Krylov basis vector, for the (j % num_rhs)th RHS vector.
auto hessenberg = this->template create_workspace_op<LocalVector>(
ws::hessenberg, dim<2>{krylov_dim, (krylov_dim + 1) * num_rhs});
// Because the auxiliary Hessenberg workspace only ever stores one
// iteration of data at a time, we store it in the "logical" layout
// from the start.
LocalVector* hessenberg_aux = nullptr;
if (this->parameters_.orthog_method == gmres::orthog_method::cgs2) {
hessenberg_aux = this->template create_workspace_op<LocalVector>(
ws::hessenberg_aux, dim<2>{1, (krylov_dim + 1) * num_rhs});
ws::hessenberg_aux, dim<2>{(krylov_dim + 1), num_rhs});
}
auto givens_sin = this->template create_workspace_op<LocalVector>(
ws::givens_sin, dim<2>{krylov_dim, num_rhs});
Expand Down Expand Up @@ -506,12 +508,16 @@ void Gmres<ValueType>::apply_dense_impl(const VectorType* dense_b,
this->get_preconditioner()->apply(this_krylov,
preconditioned_krylov_vector);

// Create view of current "column" in the hessenberg matrix:
// Create view of current column in the hessenberg matrix:
// hessenberg_iter = hessenberg(:, restart_iter), which
// is actually stored as a row, hessenberg(restart_iter, :)
auto hessenberg_iter =
hessenberg->create_submatrix(span{restart_iter, restart_iter + 1},
span{0, num_rhs * (restart_iter + 2)});
// is actually stored as a row, hessenberg(restart_iter, :),
// but we will reshape it for viewing in hessenberg_iter.
auto hessenberg_iter = LocalVector::create(
exec, dim<2>{restart_iter + 2, num_rhs},
make_array_view(exec, (restart_iter + 2) * num_rhs,
hessenberg->get_values() +
restart_iter * hessenberg->get_size()[1]),
num_rhs);

// Start of Arnoldi
// next_krylov = A * preconditioned_krylov_vector
Expand All @@ -537,8 +543,7 @@ void Gmres<ValueType>::apply_dense_impl(const VectorType* dense_b,
// (stored in hessenberg(restart_iter, (restart_iter + 1) * num_rhs))
// next_krylov /= hessenberg(restart_iter+1, restart_iter)
auto hessenberg_norm_entry = hessenberg_iter->create_submatrix(
span{0, 1},
span{(restart_iter + 1) * num_rhs, (restart_iter + 2) * num_rhs});
span{restart_iter + 1, restart_iter + 2}, span{0, num_rhs});
help_compute_norm<ValueType>::compute_next_krylov_norm_into_hessenberg(
next_krylov.get(), hessenberg_norm_entry.get(),
next_krylov_norm_tmp, reduction_tmp);
Expand Down
44 changes: 17 additions & 27 deletions reference/solver/common_gmres_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,14 @@ template <typename ValueType>
void calculate_sin_and_cos(matrix::Dense<ValueType>* givens_sin,
matrix::Dense<ValueType>* givens_cos,
matrix::Dense<ValueType>* hessenberg_iter,
size_type iter, const size_type num_rhs,
const size_type rhs)
size_type iter, const size_type rhs)
{
if (is_zero(hessenberg_iter->at(0, iter * num_rhs + rhs))) {
if (is_zero(hessenberg_iter->at(iter, rhs))) {
givens_cos->at(iter, rhs) = zero<ValueType>();
givens_sin->at(iter, rhs) = one<ValueType>();
} else {
auto this_hess = hessenberg_iter->at(0, iter * num_rhs + rhs);
auto next_hess = hessenberg_iter->at(0, (iter + 1) * num_rhs + rhs);
auto this_hess = hessenberg_iter->at(iter, rhs);
auto next_hess = hessenberg_iter->at(iter + 1, rhs);
const auto scale = abs(this_hess) + abs(next_hess);
const auto hypotenuse =
scale * sqrt(abs(this_hess / scale) * abs(this_hess / scale) +
Expand All @@ -53,40 +52,32 @@ template <typename ValueType>
void givens_rotation(matrix::Dense<ValueType>* givens_sin,
matrix::Dense<ValueType>* givens_cos,
matrix::Dense<ValueType>* hessenberg_iter, size_type iter,
const size_type num_rhs,
const stopping_status* stop_status)
{
for (size_type i = 0; i < num_rhs; ++i) {
for (size_type i = 0; i < hessenberg_iter->get_size()[1]; ++i) {
if (stop_status[i].has_stopped()) {
continue;
}
for (size_type j = 0; j < iter; ++j) {
auto temp =
givens_cos->at(j, i) * hessenberg_iter->at(0, j * num_rhs + i) +
givens_sin->at(j, i) *
hessenberg_iter->at(0, (j + 1) * num_rhs + i);
hessenberg_iter->at(0, (j + 1) * num_rhs + i) =
-conj(givens_sin->at(j, i)) *
hessenberg_iter->at(0, j * num_rhs + i) +
conj(givens_cos->at(j, i)) *
hessenberg_iter->at(0, (j + 1) * num_rhs + i);
hessenberg_iter->at(0, j * num_rhs + i) = temp;
auto temp = givens_cos->at(j, i) * hessenberg_iter->at(j, i) +
givens_sin->at(j, i) * hessenberg_iter->at(j + 1, i);
hessenberg_iter->at(j + 1, i) =
-conj(givens_sin->at(j, i)) * hessenberg_iter->at(j, i) +
conj(givens_cos->at(j, i)) * hessenberg_iter->at(j + 1, i);
hessenberg_iter->at(j, i) = temp;
// temp = cos(j)*hessenberg(j) +
// sin(j)*hessenberg(j+1)
// hessenberg(j+1) = -conj(sin(j))*hessenberg(j) +
// conj(cos(j))*hessenberg(j+1)
// hessenberg(j) = temp;
}

calculate_sin_and_cos(givens_sin, givens_cos, hessenberg_iter, iter,
num_rhs, i);
calculate_sin_and_cos(givens_sin, givens_cos, hessenberg_iter, iter, i);

hessenberg_iter->at(0, iter * num_rhs + i) =
givens_cos->at(iter, i) *
hessenberg_iter->at(0, iter * num_rhs + i) +
givens_sin->at(iter, i) *
hessenberg_iter->at(0, (iter + 1) * num_rhs + i);
hessenberg_iter->at(0, (iter + 1) * num_rhs + i) = zero<ValueType>();
hessenberg_iter->at(iter, i) =
givens_cos->at(iter, i) * hessenberg_iter->at(iter, i) +
givens_sin->at(iter, i) * hessenberg_iter->at(iter + 1, i);
hessenberg_iter->at(iter + 1, i) = zero<ValueType>();
// hessenberg(iter) = cos(iter)*hessenberg(iter) +
// sin(iter)*hessenberg(iter + 1)
// hessenberg(iter+1) = 0
Expand Down Expand Up @@ -160,8 +151,7 @@ void hessenberg_qr(std::shared_ptr<const ReferenceExecutor> exec,
}
}

givens_rotation(givens_sin, givens_cos, hessenberg_iter, iter,
residual_norm->get_size()[1], stop_status);
givens_rotation(givens_sin, givens_cos, hessenberg_iter, iter, stop_status);
calculate_next_residual_norm(givens_sin, givens_cos, residual_norm,
residual_norm_collection, iter, stop_status);
}
Expand Down
16 changes: 8 additions & 8 deletions reference/solver/gmres_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,14 @@ void multi_dot(std::shared_ptr<const ReferenceExecutor> exec,
{
auto num_rhs = next_krylov->get_size()[1];
auto krylov_bases_rowoffset = next_krylov->get_size()[0];
for (size_type i = 0; i < hessenberg_col->get_size()[1]; ++i) {
auto ivec = i / num_rhs;
auto irhs = i % num_rhs;
hessenberg_col->at(0, i) = zero<ValueType>();
for (size_type j = 0; j < krylov_bases_rowoffset; ++j) {
hessenberg_col->at(0, i) +=
krylov_bases->at(ivec * krylov_bases_rowoffset + j, irhs) *
next_krylov->at(j, irhs);
for (size_type i = 0; i < hessenberg_col->get_size()[0] - 1; ++i) {
for (size_type k = 0; k < num_rhs; ++k) {
hessenberg_col->at(i, k) = zero<ValueType>();
for (size_type j = 0; j < krylov_bases_rowoffset; ++j) {
hessenberg_col->at(i, k) +=
conj(krylov_bases->at(i * krylov_bases_rowoffset + j, k)) *
next_krylov->at(j, k);
}
}
}
}
Expand Down
Loading

0 comments on commit adf30a2

Please sign in to comment.