diff --git a/include/dlaf/communication/broadcast_panel.h b/include/dlaf/communication/broadcast_panel.h index e83e35a98d..c3cf7e0725 100644 --- a/include/dlaf/communication/broadcast_panel.h +++ b/include/dlaf/communication/broadcast_panel.h @@ -18,11 +18,13 @@ #include #include +#include #include #include #include #include #include +#include #include #include #include @@ -89,7 +91,6 @@ auto& get_taskchain(comm::CommunicatorPipeline& row return col_task_chain; } } -} // namespace internal /// Broadcast /// @@ -105,9 +106,6 @@ auto& get_taskchain(comm::CommunicatorPipeline& row /// - linking as external tile, if the tile is already available locally for the rank /// - receiving the tile from the owning rank (via a broadcast) /// -/// Be aware that the last tile will just be available on @p panel, but it won't be transposed to -/// @p panelT. -/// /// @param rank_root specifies on which rank the @p panel is the source of the data /// @param panel /// on rank_root it is the source panel (a) @@ -125,7 +123,8 @@ template & panel, matrix::Panel& panelT, comm::CommunicatorPipeline& row_task_chain, - comm::CommunicatorPipeline& col_task_chain) { + comm::CommunicatorPipeline& col_task_chain, + common::IterableRange2D range) { constexpr Coord axisT = orthogonal(axis); constexpr Coord coord = std::decay_t::coord; @@ -183,13 +182,6 @@ void broadcast(comm::IndexT_MPI rank_root, matrix::Panel& p auto& chain_step2 = internal::get_taskchain(row_task_chain, col_task_chain); - const SizeType last_tile = std::max(panelT.rangeStart(), panelT.rangeEnd() - 1); - const auto owner = dist.template rankGlobalTile(last_tile); - const auto range = dist.rankIndex().get(coordT) == owner - ? common::iterate_range2d(*panelT.iteratorLocal().begin(), - LocalTileIndex(coordT, panelT.rangeEndLocal() - 1, 1)) - : panelT.iteratorLocal(); - for (const auto& indexT : range) { auto [index_diag, owner_diag] = internal::transposedOwner(dist, indexT); @@ -208,6 +200,39 @@ void broadcast(comm::IndexT_MPI rank_root, matrix::Panel& p } } } +} // namespace internal + +template >> +void broadcast(comm::IndexT_MPI rank_root, matrix::Panel& panel, + matrix::Panel& panelT, + comm::CommunicatorPipeline& row_task_chain, + comm::CommunicatorPipeline& col_task_chain) { + constexpr Coord coordT = std::decay_t::coord; + const auto& dist = panel.parentDistribution(); + + const SizeType last_tile = std::max(panelT.rangeStart(), panelT.rangeEnd() - 1); + + if (panel.rangeStart() == panel.rangeEnd()) + return; + + const auto owner = dist.template rankGlobalTile(last_tile); + const auto range = dist.rankIndex().get(coordT) == owner + ? 
common::iterate_range2d(*panelT.iteratorLocal().begin(), + LocalTileIndex(coordT, panelT.rangeEndLocal() - 1, 1)) + : panelT.iteratorLocal(); + + internal::broadcast(rank_root, panel, panelT, row_task_chain, col_task_chain, range); +} + +template >> +void broadcast_all(comm::IndexT_MPI rank_root, matrix::Panel& panel, + matrix::Panel& panelT, + comm::CommunicatorPipeline& row_task_chain, + comm::CommunicatorPipeline& col_task_chain) { + internal::broadcast(rank_root, panel, panelT, row_task_chain, col_task_chain, panelT.iteratorLocal()); +} } } diff --git a/include/dlaf/eigensolver/reduction_to_band.h b/include/dlaf/eigensolver/reduction_to_band.h index 9a22ee168d..1f7dc13a3b 100644 --- a/include/dlaf/eigensolver/reduction_to_band.h +++ b/include/dlaf/eigensolver/reduction_to_band.h @@ -119,4 +119,18 @@ Matrix reduction_to_band(comm::CommunicatorGrid& grid, Matrix::call(grid, mat_a, band_size); } + +template +internal::CARed2BandResult ca_reduction_to_band(comm::CommunicatorGrid& grid, Matrix& mat_a, + const SizeType band_size) { + DLAF_ASSERT(matrix::square_size(mat_a), mat_a); + DLAF_ASSERT(matrix::square_blocksize(mat_a), mat_a); + DLAF_ASSERT(matrix::single_tile_per_block(mat_a), mat_a); + DLAF_ASSERT(matrix::equal_process_grid(mat_a, grid), mat_a, grid); + + DLAF_ASSERT(band_size >= 2, band_size); + DLAF_ASSERT(mat_a.blockSize().rows() % band_size == 0, mat_a.blockSize().rows(), band_size); + + return CAReductionToBand::call(grid, mat_a, band_size); +} } diff --git a/include/dlaf/eigensolver/reduction_to_band/api.h b/include/dlaf/eigensolver/reduction_to_band/api.h index 5ead6f0325..a107ee282c 100644 --- a/include/dlaf/eigensolver/reduction_to_band/api.h +++ b/include/dlaf/eigensolver/reduction_to_band/api.h @@ -24,9 +24,24 @@ struct ReductionToBand { const SizeType band_size); }; +template +struct CARed2BandResult { + Matrix taus_1st; + // hh_1st are stored in-place + Matrix taus_2nd; + Matrix hh_2nd; +}; + +template +struct CAReductionToBand { + static CARed2BandResult call(comm::CommunicatorGrid& grid, Matrix& mat_a, + const SizeType band_size); +}; + // ETI #define DLAF_EIGENSOLVER_REDUCTION_TO_BAND_ETI(KWORD, BACKEND, DEVICE, DATATYPE) \ - KWORD template struct ReductionToBand; + KWORD template struct ReductionToBand; \ + KWORD template struct CAReductionToBand; DLAF_EIGENSOLVER_REDUCTION_TO_BAND_ETI(extern, Backend::MC, Device::CPU, float) DLAF_EIGENSOLVER_REDUCTION_TO_BAND_ETI(extern, Backend::MC, Device::CPU, double) diff --git a/include/dlaf/eigensolver/reduction_to_band/ca-impl.h b/include/dlaf/eigensolver/reduction_to_band/ca-impl.h new file mode 100644 index 0000000000..5031af338b --- /dev/null +++ b/include/dlaf/eigensolver/reduction_to_band/ca-impl.h @@ -0,0 +1,1043 @@ +// +// Distributed Linear Algebra with Future (DLAF) +// +// Copyright (c) 2018-2024, ETH Zurich +// All rights reserved. +// +// Please, refer to the LICENSE file in the root directory. 
+// SPDX-License-Identifier: BSD-3-Clause +// + +#pragma once + +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace dlaf::eigensolver::internal { + +namespace ca_red2band { +template +void hemm(comm::Index2D rank_qr, matrix::Panel& W1, + matrix::Panel& W1T, + const matrix::SubMatrixView& at_view, matrix::Matrix& A, + matrix::Panel& W0, + matrix::Panel& W0T, + comm::CommunicatorPipeline& mpi_row_chain, + comm::CommunicatorPipeline& mpi_col_chain) { + namespace ex = pika::execution::experimental; + + using red2band::hemmDiag; + using red2band::hemmOffDiag; + + using pika::execution::thread_priority; + + const auto dist = A.distribution(); + const auto rank = dist.rankIndex(); + + // Note: + // They have to be set to zero, because all tiles are going to be reduced, and some tiles may not get + // "initialized" during computation, so they should not contribute with any spurious value to final result. + matrix::util::set0(thread_priority::high, W1); + matrix::util::set0(thread_priority::high, W1T); + + const LocalTileIndex at_offset = at_view.begin(); + + for (SizeType i_lc = at_offset.row(); i_lc < dist.local_nr_tiles().rows(); ++i_lc) { + // Note: + // diagonal included: get where the first upper tile is in local coordinates + const SizeType i = dist.template global_tile_from_local_tile(i_lc); + const auto j_end_lc = dist.template next_local_tile_from_global_tile(i + 1); + + for (SizeType j_lc = j_end_lc - 1; j_lc >= at_offset.col(); --j_lc) { + const LocalTileIndex ij_lc{i_lc, j_lc}; + const GlobalTileIndex ij = dist.global_tile_index(ij_lc); + + const bool is_diagonal_tile = (ij.row() == ij.col()); + + auto getSubA = [&A, &at_view, ij_lc]() { return splitTile(A.read(ij_lc), at_view(ij_lc)); }; + + if (is_diagonal_tile) { + const comm::IndexT_MPI id_qr_R = dist.template rank_global_tile(ij.col()); + + // Note: + // Use W0 just if the tile belongs to the current local transformation. + if (id_qr_R != rank_qr.row()) + continue; + + hemmDiag(thread_priority::high, getSubA(), W0.read(ij_lc), W1.readwrite(ij_lc)); + } + else { + const GlobalTileIndex ijL = ij; + const comm::IndexT_MPI id_qr_lower_R = dist.template rank_global_tile(ijL.col()); + if (id_qr_lower_R == rank_qr.row()) { + // Note: + // Since it is not a diagonal tile, otherwise it would have been managed in the previous + // branch, the second operand might not be available in W but it is accessible through the + // support panel W1T. + // However, since we are still computing the "straight" part, the result can be stored + // in the "local" panel W1. + hemmOffDiag(thread_priority::high, blas::Op::NoTrans, getSubA(), W0T.read(ij_lc), + W1.readwrite(ij_lc)); + } + + const GlobalTileIndex ijU = transposed(ij); + const comm::IndexT_MPI id_qr_upper_R = dist.template rank_global_tile(ijU.col()); + if (id_qr_upper_R == rank_qr.row()) { + // Note: + // Here we are considering the hermitian part of A, so pretend to deal with transposed coordinate. + // Check if the result still belongs to the same rank, otherwise store it in the support panel. + const comm::IndexT_MPI owner_row = dist.template rank_global_tile(ijU.row()); + const SizeType iU_lc = dist.template local_tile_from_global_tile(ij.col()); + const LocalTileIndex i_w1_lc(iU_lc, 0); + const LocalTileIndex i_w1t_lc(0, ij_lc.col()); + auto tile_w1 = (rank.row() == owner_row) ? 
W1.readwrite(i_w1_lc) : W1T.readwrite(i_w1t_lc);
+
+          hemmOffDiag(thread_priority::high, blas::Op::ConjTrans, getSubA(), W0.read(ij_lc),
+                      std::move(tile_w1));
+        }
+      }
+    }
+  }
+
+  // Note:
+  // At this point, partial results of W1 are available in the panels, and they have to be reduced,
+  // both row-wise and col-wise. The final W1 result will be available just on the Ai panel column.
+
+  // Note:
+  // The first step in reducing the partial results distributed over W1 and W1T is to reduce the row
+  // panel W1T col-wise, by collecting all W1T results on the rank that can "mirror" the result on
+  // its rows (i.e. the diagonal one). So, for each tile of the row panel, select the "diagonal" rank
+  // that can mirror and reduce on it.
+  if (mpi_col_chain.size() > 1) {
+    for (const auto& i_wt_lc : W1T.iteratorLocal()) {
+      const auto i_diag = dist.template global_tile_from_local_tile(i_wt_lc.col());
+      const auto rank_owner_row = dist.template rank_global_tile(i_diag);
+
+      if (rank_owner_row == rank.row()) {
+        // Note:
+        // Since it is the owner, it has to perform the "mirroring" of the results from columns to
+        // rows.
+        // Moreover, it reduces in place because the owner of the diagonal stores the partial result
+        // directly in W1 (without using W1T).
+        const auto i_w1_lc = dist.template local_tile_from_global_tile(i_diag);
+        ex::start_detached(comm::schedule_reduce_recv_in_place(mpi_col_chain.exclusive(), MPI_SUM,
+                                                               W1.readwrite({i_w1_lc, 0})));
+      }
+      else {
+        ex::start_detached(comm::schedule_reduce_send(mpi_col_chain.exclusive(), rank_owner_row, MPI_SUM,
+                                                      W1T.read(i_wt_lc)));
+      }
+    }
+  }
+
+  // Note:
+  // At this point the partial results are all collected in W1 (W1T has been embedded in the previous
+  // step), so the last step needed is to reduce these last partial results into the final result.
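+  // Note (illustrative summary, added for clarity):
+  // exploiting the Hermitian A through its lower triangular tiles only, each local tile A_ij
+  // contributed above as
+  //   W1_i += A_ij  . W0_j    (diag/lower, accumulated in W1)
+  //   W1_j += A_ij* . W0_i    (mirrored upper, accumulated in W1 or, if not local, in W1T)
+  // so after the col-wise reduction above and the row-wise reduction below, the panel column of
+  // ranks holds the complete W1 = A . W0, restricted to the columns of this local transformation.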
+ if (mpi_row_chain.size() > 1) { + for (const auto& i_w1_lc : W1.iteratorLocal()) { + if (rank_qr.col() == rank.col()) + ex::start_detached(comm::schedule_reduce_recv_in_place(mpi_row_chain.exclusive(), MPI_SUM, + W1.readwrite(i_w1_lc))); + else + ex::start_detached(comm::schedule_reduce_send(mpi_row_chain.exclusive(), rank_qr.col(), MPI_SUM, + W1.read(i_w1_lc))); + } + } +} + +template +void her2kUpdateTrailingMatrix(comm::Index2D rank_qr, const matrix::SubMatrixView& at_view, + matrix::Matrix& a, matrix::Panel& W3, + matrix::Panel& V) { + static_assert(std::is_signed_v>, "alpha in computations requires to be -1"); + + using pika::execution::thread_priority; + using red2band::her2kDiag; + using red2band::her2kOffDiag; + + const auto dist = a.distribution(); + const comm::Index2D rank = dist.rank_index(); + + const LocalTileIndex at_offset = at_view.begin(); + + if (rank_qr.row() != rank.row()) + return; + + for (SizeType i_lc = at_offset.row(); i_lc < dist.local_nr_tiles().rows(); ++i_lc) { + // Note: + // diagonal included: get where the first upper tile is in local coordinates + const SizeType i = dist.template global_tile_from_local_tile(i_lc); + const auto j_end_lc = dist.template next_local_tile_from_global_tile(i + 1); + + for (SizeType j_lc = j_end_lc - 1; j_lc >= at_offset.col(); --j_lc) { + const LocalTileIndex ij_lc{i_lc, j_lc}; + const GlobalTileIndex ij = dist.global_tile_index(ij_lc); + + const comm::IndexT_MPI id_qr_L = dist.template rank_global_tile(ij.row()); + const comm::IndexT_MPI id_qr_R = dist.template rank_global_tile(ij.col()); + + // Note: this computation applies just to tiles where transformation applies both from L and R + if (id_qr_L != id_qr_R) + continue; + + const bool is_diagonal_tile = (ij.row() == ij.col()); + + auto getSubA = [&a, &at_view, ij_lc]() { return splitTile(a.readwrite(ij_lc), at_view(ij_lc)); }; + + // The first column of the trailing matrix (except for the very first global tile) has to be + // updated first, in order to unlock the next iteration as soon as possible. + const auto priority = (j_lc == at_offset.col()) ? thread_priority::high : thread_priority::normal; + + if (is_diagonal_tile) { + her2kDiag(priority, V.read(ij_lc), W3.read(ij_lc), getSubA()); + } + else { + // TODO fix doc + // Note: + // - We are updating from both L and R. + // - We are computing all combinations of W3 and V (and viceversa), and putting results in A + // - By looping on position of A that will contain the result + // - We use the same row for the first operand + // - We use the col as the row for the second operand + const SizeType iT_lc = dist.template local_tile_from_global_tile(ij.col()); + + // A -= W3 . V* + her2kOffDiag(priority, W3.read(ij_lc), V.read({iT_lc, 0}), getSubA()); + // A -= V . 
W3* + her2kOffDiag(priority, V.read(ij_lc), W3.read({iT_lc, 0}), getSubA()); + } + } + } +} + +template +void hemm2nd(comm::IndexT_MPI rank_panel, matrix::Panel& W1, + matrix::Panel& W1T, + const matrix::SubMatrixView& at_view, const SizeType j_end, matrix::Matrix& A, + matrix::Panel& W0, + matrix::Panel& W0T, + comm::CommunicatorPipeline& mpi_row_chain, + comm::CommunicatorPipeline& mpi_col_chain) { + namespace ex = pika::execution::experimental; + + using red2band::hemmDiag; + using red2band::hemmOffDiag; + + using pika::execution::thread_priority; + + const auto dist = A.distribution(); + const auto rank = dist.rankIndex(); + + // Note: + // They have to be set to zero, because all tiles are going to be reduced, and some tiles may not get + // "initialized" during computation, so they should not contribute with any spurious value to final result. + matrix::util::set0(thread_priority::high, W1); + matrix::util::set0(thread_priority::high, W1T); + + const LocalTileIndex at_offset = at_view.begin(); + + const SizeType jR_end_lc = dist.template next_local_tile_from_global_tile(j_end); + + for (SizeType i_lc = at_offset.row(); i_lc < dist.localNrTiles().rows(); ++i_lc) { + const auto j_end_lc = + std::min(jR_end_lc, dist.template next_local_tile_from_global_tile( + dist.template global_tile_from_local_tile(i_lc) + 1)); + for (SizeType j_lc = at_offset.col(); j_lc < j_end_lc; ++j_lc) { + const LocalTileIndex ij_lc(i_lc, j_lc); + const GlobalTileIndex ij = dist.global_tile_index(ij_lc); + + // skip upper + if (ij.row() < ij.col()) { + continue; + } + + const bool is_diag = (ij.row() == ij.col()); + + if (is_diag) { + hemmDiag(thread_priority::high, A.read(ij_lc), W0.read(ij_lc), W1.readwrite(ij_lc)); + } + else { + // Lower + hemmOffDiag(thread_priority::high, blas::Op::NoTrans, A.read(ij_lc), W0T.read(ij_lc), + W1.readwrite(ij_lc)); + + // Upper + const GlobalTileIndex ijU = transposed(ij); + + // Note: if it is out of the "sub-matrix" + if (ijU.col() >= j_end) + continue; + + const comm::IndexT_MPI owner_row = dist.template rank_global_tile(ijU.row()); + const SizeType iU_lc = dist.template local_tile_from_global_tile(ij.col()); + const LocalTileIndex i_w1_lc(iU_lc, 0); + const LocalTileIndex i_w1t_lc(0, ij_lc.col()); + auto tile_w1 = (rank.row() == owner_row) ? W1.readwrite(i_w1_lc) : W1T.readwrite(i_w1t_lc); + + hemmOffDiag(thread_priority::high, blas::Op::ConjTrans, A.read(ij_lc), W0.read(ij_lc), + std::move(tile_w1)); + } + } + } + + // Note: + // At this point, partial results of W1 are available in the panels, and they have to be reduced, + // both row-wise and col-wise. The final W1 result will be available just on Ai panel column. + + // Note: + // The first step in reducing partial results distributed over W1 and W1T, it is to reduce the row + // panel W1T col-wise, by collecting all W1T results on the rank which can "mirror" the result on its + // rows (i.e. diagonal). So, for each tile of the row panel, select who is the "diagonal" rank that can + // mirror and reduce on it. + if (mpi_col_chain.size() > 1) { + for (const auto& i_wt_lc : W1T.iteratorLocal()) { + const auto i_diag = dist.template global_tile_from_local_tile(i_wt_lc.col()); + const auto rank_owner_row = dist.template rank_global_tile(i_diag); + + if (rank_owner_row == rank.row()) { + // Note: + // Since it is the owner, it has to perform the "mirroring" of the results from columns to + // rows. 
+        // Moreover, it reduces in place because the owner of the diagonal stores the partial result
+        // directly in W1 (without using W1T).
+        const auto i_w1_lc = dist.template local_tile_from_global_tile(i_diag);
+        ex::start_detached(comm::schedule_reduce_recv_in_place(mpi_col_chain.exclusive(), MPI_SUM,
+                                                               W1.readwrite({i_w1_lc, 0})));
+      }
+      else {
+        ex::start_detached(comm::schedule_reduce_send(mpi_col_chain.exclusive(), rank_owner_row, MPI_SUM,
+                                                      W1T.read(i_wt_lc)));
+      }
+    }
+  }
+
+  // Note:
+  // At this point the partial results are all collected in W1 (W1T has been embedded in the previous
+  // step), so the last step needed is to reduce these last partial results into the final result.
+  if (mpi_row_chain.size() > 1) {
+    for (const auto& i_w1_lc : W1.iteratorLocal()) {
+      if (rank_panel == rank.col())
+        ex::start_detached(comm::schedule_reduce_recv_in_place(mpi_row_chain.exclusive(), MPI_SUM,
+                                                               W1.readwrite(i_w1_lc)));
+      else
+        ex::start_detached(comm::schedule_reduce_send(mpi_row_chain.exclusive(), rank_panel, MPI_SUM,
+                                                      W1.read(i_w1_lc)));
+    }
+  }
+}
+
+template
+void her2k_2nd(const SizeType i_end, const SizeType j_end, const matrix::SubMatrixView& at_view,
+               matrix::Matrix& a, matrix::Panel& W1,
+               matrix::Panel& W1T,
+               matrix::Panel& V,
+               matrix::Panel& VT) {
+  static_assert(std::is_signed_v>, "alpha in computations requires to be -1");
+
+  using pika::execution::thread_priority;
+  using red2band::her2kDiag;
+  using red2band::her2kOffDiag;
+
+  const auto dist = a.distribution();
+
+  const LocalTileIndex at_offset_lc = at_view.begin();
+
+  const SizeType iL_end_lc = dist.template next_local_tile_from_global_tile(i_end);
+  const SizeType jR_end_lc = dist.template next_local_tile_from_global_tile(j_end);
+  for (SizeType i_lc = at_offset_lc.row(); i_lc < iL_end_lc; ++i_lc) {
+    const auto j_end_lc =
+        std::min(jR_end_lc, dist.template next_local_tile_from_global_tile(
+                                dist.template global_tile_from_local_tile(i_lc) + 1));
+    for (SizeType j_lc = at_offset_lc.col(); j_lc < j_end_lc; ++j_lc) {
+      const LocalTileIndex ij_local{i_lc, j_lc};
+      const GlobalTileIndex ij = dist.globalTileIndex(ij_local);
+
+      const bool is_diagonal_tile = (ij.row() == ij.col());
+
+      auto getSubA = [&a, &at_view, ij_local]() {
+        return splitTile(a.readwrite(ij_local), at_view(ij_local));
+      };
+
+      // The first column of the trailing matrix (except for the very first global tile) has to be
+      // updated first, in order to unlock the next iteration as soon as possible.
+      const auto priority =
+          (j_lc == at_offset_lc.col()) ? thread_priority::high : thread_priority::normal;
+
+      if (is_diagonal_tile) {
+        her2kDiag(priority, V.read(ij_local), W1.read(ij_local), getSubA());
+      }
+      else {
+        // A -= W1 . V*
+        her2kOffDiag(priority, W1.read(ij_local), VT.read(ij_local), getSubA());
+
+        // A -= V . W1*
+        her2kOffDiag(priority, V.read(ij_local), W1T.read(ij_local), getSubA());
+      }
+    }
+  }
+
+  // This loop updates only the rows that are updated exclusively from the right.
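+  // Note (added for clarity):
+  // V is non-zero only on the rows in [at_offset, i_end), so for the rows below that range the
+  // left-update term V . W1* vanishes and only A -= W1 . V* has to be applied.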
+ for (SizeType i_lc = iL_end_lc; i_lc < dist.local_nr_tiles().rows(); ++i_lc) { + const auto j_end_lc = + std::min(jR_end_lc, dist.template next_local_tile_from_global_tile( + dist.template global_tile_from_local_tile(i_lc) + 1)); + for (SizeType j_lc = at_offset_lc.col(); j_lc < j_end_lc; ++j_lc) { + const LocalTileIndex ij_lc{i_lc, j_lc}; + + auto getSubA = [&a, &at_view, ij_lc]() { return splitTile(a.readwrite(ij_lc), at_view(ij_lc)); }; + + // The first column of the trailing matrix (except for the very first global tile) has to be + // updated first, in order to unlock the next iteration as soon as possible. + const auto priority = + (j_lc == at_offset_lc.col()) ? thread_priority::high : thread_priority::normal; + + // A -= X . V* + her2kOffDiag(priority, W1.read(ij_lc), VT.read(ij_lc), getSubA()); + } + } +} +} + +// Distributed implementation of reduction to band +template +CARed2BandResult CAReductionToBand::call(comm::CommunicatorGrid& grid, + Matrix& mat_a, const SizeType band_size) { + namespace ex = pika::execution::experimental; + namespace di = dlaf::internal; + + using common::RoundRobin; + using matrix::Panel; + using matrix::StoreTransposed; + + const auto& dist = mat_a.distribution(); + const comm::Index2D rank = dist.rank_index(); + + // Note: + // Reflector of size = 1 is not considered whatever T is (i.e. neither real nor complex) + const SizeType nrefls = std::max(0, dist.size().cols() - band_size - 1); + + // Note: + // Each rank has space for storing taus for one tile. It is distributed as the input matrix (i.e. 2D) + // Note: + // It is distributed "transposed" because of implicit assumptions in functions, i.e. + // computePanelReflectors and computeTFactor, which expect a column vector. + DLAF_ASSERT(dist.block_size().cols() % band_size == 0, dist.block_size().cols(), band_size); + Matrix mat_taus_1st(matrix::Distribution( + GlobalElementSize(nrefls, dist.grid_size().rows()), TileElementSize(dist.block_size().cols(), 1), + transposed(dist.grid_size()), transposed(rank), transposed(dist.source_rank_index()))); + + // Note: + // It has room for storing one tile per rank and it is distributed as the input matrix (i.e. 2D) + const matrix::Distribution dist_hh_2nd( + GlobalElementSize(dist.grid_size().rows() * dist.block_size().rows(), dist.size().cols()), + dist.block_size(), dist.grid_size(), rank, dist.source_rank_index()); + Matrix mat_hh_2nd(dist_hh_2nd); + + // Note: + // Is stored as a column but it acts as a row vector. It is replicated over rows and it is distributed + // over columns (i.e. 
1D distributed) + DLAF_ASSERT(dist.block_size().cols() % band_size == 0, dist.block_size().cols(), band_size); + Matrix mat_taus_2nd( + matrix::Distribution(GlobalElementSize(nrefls, 1), TileElementSize(dist.block_size().cols(), 1), + comm::Size2D(dist.grid_size().cols(), 1), comm::Index2D(rank.col(), 0), + comm::Index2D(dist.source_rank_index().col(), 0))); + + if (nrefls == 0) + return {std::move(mat_taus_1st), std::move(mat_taus_2nd), std::move(mat_hh_2nd)}; + + auto mpi_col_chain = grid.col_communicator_pipeline(); + auto mpi_row_chain = grid.row_communicator_pipeline(); + + constexpr std::size_t n_workspaces = 2; + + // TODO HEADS workspace + // - column vector + // - has to be fully local + // - no more than grid_size.rows() tiles (1 tile per rank in the column) + // - we use panel just because it offers the ability to shrink width/height + const matrix::Distribution dist_heads( + LocalElementSize(dist.grid_size().rows() * dist.block_size().rows(), dist.block_size().cols()), + dist.block_size()); + + RoundRobin> panels_heads(n_workspaces, dist_heads); + + // update trailing matrix workspaces + RoundRobin> panels_v(n_workspaces, dist); + RoundRobin> panels_vt(n_workspaces, dist); + + RoundRobin> panels_w0(n_workspaces, dist); + RoundRobin> panels_w0t(n_workspaces, dist); + + RoundRobin> panels_w1(n_workspaces, dist); + RoundRobin> panels_w1t(n_workspaces, dist); + + RoundRobin> panels_w3(n_workspaces, dist); + + DLAF_ASSERT(mat_a.block_size().cols() == band_size, mat_a.block_size().cols(), band_size); + const SizeType ntiles = (nrefls - 1) / band_size + 1; + + const bool is_full_band = (band_size == dist.blockSize().cols()); + DLAF_ASSERT(is_full_band, is_full_band); + + for (SizeType j = 0; j < ntiles; ++j) { + const SizeType i = j + 1; + const SizeType j_lc = dist.template local_tile_from_global_tile(j); + + const SizeType nrefls_1st = [&]() -> SizeType { + const SizeType i_head_lc = dist.template next_local_tile_from_global_tile(i); + + if (i_head_lc >= dist.local_nr_tiles().rows()) + return 0; + + const SizeType i_head = dist.template global_tile_from_local_tile(i_head_lc); + const GlobalTileIndex ij_head(i_head, j); + const TileElementSize head_size = mat_a.tile_size_of(ij_head); + + if (i_head_lc == dist.local_nr_tiles().rows() - 1) + return head_size.rows() - 1; + else + return head_size.rows(); + }(); + + auto get_tile_tau = [&]() { + if (nrefls_1st == band_size) + return mat_taus_1st.readwrite(LocalTileIndex(j_lc, 0)); + return splitTile(mat_taus_1st.readwrite(LocalTileIndex(j_lc, 0)), {{0, 0}, {nrefls_1st, 1}}); + }; + + auto get_tile_tau_ro = [&]() { + if (nrefls_1st == band_size) + return mat_taus_1st.read(LocalTileIndex(j_lc, 0)); + return splitTile(mat_taus_1st.read(LocalTileIndex(j_lc, 0)), {{0, 0}, {nrefls_1st, 1}}); + }; + + // panel + const GlobalTileIndex panel_offset(i, j); + const GlobalElementIndex panel_offset_el(panel_offset.row() * band_size, + panel_offset.col() * band_size); + matrix::SubPanelView panel_view(dist, panel_offset_el, band_size); + + const comm::IndexT_MPI rank_panel(dist.template rank_global_tile(panel_offset.col())); + + const SizeType n_qr_heads = + std::min(panel_view.offset().row() + grid.size().rows(), dist.nr_tiles().rows()) - + panel_view.offset().row(); + + // trailing + const GlobalTileIndex at_offset(i, j + 1); + const GlobalElementIndex at_offset_el(at_offset.row() * band_size, at_offset.col() * band_size); + const LocalTileIndex at_offset_lc( + dist.template next_local_tile_from_global_tile(at_offset.row()), + dist.template 
next_local_tile_from_global_tile(at_offset.col())); + matrix::SubMatrixView at_view(dist, at_offset_el); + + // PANEL: just ranks in the current column + // QR local (HH reflectors stored in-place) + if (rank_panel == rank.col()) { + using red2band::local::computePanelReflectors; + computePanelReflectors(mat_a, get_tile_tau(), panel_view); + } + + // TRAILING 1st pass + if (at_offset_el.isIn(mat_a.size())) { + // TODO FIXME workaround for possibly non-participating ranks (they cannot have width = 0) + const SizeType nrefls_1st_min = std::max(nrefls_1st, 1); + + const LocalTileIndex zero_lc(0, 0); + matrix::Matrix ws_T({nrefls_1st_min, nrefls_1st_min}, dist.block_size()); + + auto& ws_V = panels_v.nextResource(); + ws_V.setRangeStart(at_offset); + ws_V.setWidth(nrefls_1st_min); + + auto& ws_W0 = panels_w0.nextResource(); + ws_W0.setRangeStart(at_offset); + ws_W0.setWidth(nrefls_1st_min); + + if (rank_panel == rank.col() && nrefls_1st != 0) { + using factorization::internal::computeTFactor; + using red2band::local::setupReflectorPanelV; + + const bool has_head = !panel_view.iteratorLocal().empty(); + setupReflectorPanelV(has_head, panel_view, nrefls_1st_min, ws_V, mat_a, !is_full_band); + + computeTFactor(ws_V, get_tile_tau_ro(), ws_T.readwrite(zero_lc)); + + // W = V T + red2band::local::trmmComputeW(ws_W0, ws_V, ws_T.read(zero_lc)); + } + + // Note: apply local transformations, one after the other + // matrix::Matrix ws_W2 = std::move(ws_T); + + for (int idx_qr_head = 0; idx_qr_head < n_qr_heads; ++idx_qr_head) { + const SizeType head_qr = at_view.offset().row() + idx_qr_head; + const comm::Index2D rank_qr(dist.template rank_global_tile(head_qr), rank_panel); + + const bool is_row_involved = rank_qr.row() == rank.row(); + const bool is_col_involved = [head_qr, dist, rank]() { + for (SizeType k = head_qr; k < dist.nr_tiles().rows(); k += dist.grid_size().rows()) { + const comm::IndexT_MPI rank_owner_col = dist.template rank_global_tile(k); + if (rank_owner_col == rank.col()) + return true; + } + return false; + }(); + + const SizeType nrtiles_transf = + util::ceilDiv(dist.nr_tiles().rows() - head_qr, dist.grid_size().rows()); + + // TODO FIXME cannot skip because otherwise no W1 reduction if any rank skip?! 
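+      // Note (added for clarity):
+      // the hemm step below participates in reductions over the full row/col communicators, so
+      // every rank has to enter it; letting non-involved ranks skip would leave those collective
+      // calls short of participants.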
+ // // this rank is not involved at all + // if (!is_row_involved && !is_col_involved) + // continue; + + // number of reflectors of this "local" transformation + const SizeType nrefls_this = [&]() { + using matrix::internal::distribution::global_tile_from_local_tile_on_rank; + const SizeType i_head = head_qr; + if (nrtiles_transf == 1) { + const TileElementSize tile_size = dist.tile_size_of({i_head, j}); + return std::min(tile_size.rows() - 1, tile_size.cols()); + } + return dist.block_size().cols(); + }(); + + if (nrefls_this == 0) + continue; + + // "local" broadcast along rows involved in this local transformation + if (is_row_involved) { + comm::broadcast(rank_panel, ws_V, mpi_row_chain); + comm::broadcast(rank_panel, ws_W0, mpi_row_chain); + } + + auto& ws_VT = panels_vt.nextResource(); + // TODO FIXME workaround for panel problem on reset about range + ws_VT.setRange(at_offset, common::indexFromOrigin(dist.nr_tiles())); + ws_VT.setHeight(nrefls_this); + + auto& ws_W0T = panels_w0t.nextResource(); + // TODO FIXME workaround for panel problem on reset about range + ws_W0T.setRange(at_offset, common::indexFromOrigin(dist.nr_tiles())); + ws_W0T.setHeight(nrefls_this); + + // broadcast along cols involved in this local transformation + if (is_col_involved) { // diagonal + // set diagonal tiles + for (const auto ij_lc : ws_VT.iteratorLocal()) { + const SizeType k = dist.template global_tile_from_local_tile(ij_lc.col()); + const comm::IndexT_MPI rank_src = dist.template rank_global_tile(k); + + if (rank_qr.row() != rank_src) + continue; + + using comm::schedule_bcast_recv; + using comm::schedule_bcast_send; + + if (rank_src == rank.row()) { + const SizeType i_lc = dist.template local_tile_from_global_tile(k); + + ws_VT.setTile(ij_lc, ws_V.read({i_lc, 0})); + ws_W0T.setTile(ij_lc, ws_W0.read({i_lc, 0})); + + // if (nrtiles_transf > 1) { + ex::start_detached(schedule_bcast_send(mpi_col_chain.exclusive(), ws_V.read({i_lc, 0}))); + ex::start_detached(schedule_bcast_send(mpi_col_chain.exclusive(), ws_W0.read({i_lc, 0}))); + // } + } + else { + // if (nrtiles_transf > 1) { + ex::start_detached(schedule_bcast_recv(mpi_col_chain.exclusive(), rank_src, + ws_VT.readwrite(ij_lc))); + ex::start_detached(schedule_bcast_recv(mpi_col_chain.exclusive(), rank_src, + ws_W0T.readwrite(ij_lc))); + // } + } + } + } + + // W1 = A W0 + // Note: + // it will fill up all W1 (distributed) and it will read just part of A, i.e. columns where + // this local transformation should be applied from the right. + auto& ws_W1 = panels_w1.nextResource(); + auto& ws_W1T = panels_w1t.nextResource(); + ws_W1.setRangeStart(at_offset); + ws_W1.setWidth(nrefls_this); + + ws_W1T.setRangeStart(at_offset); + ws_W1T.setHeight(nrefls_this); + + // TODO FIXME restrict communication to just interested ones + using ca_red2band::hemm; + hemm(rank_qr, ws_W1, ws_W1T, at_view, mat_a, ws_W0, ws_W0T, mpi_row_chain, mpi_col_chain); + + // Note: + // W1T has been used as support panel, so reset it again. 
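+      // Note (added for clarity):
+      // after the reductions in hemm, W1 holds the final result while W1T still holds stale partial
+      // sums, so W1T is reset and re-populated below as a plain transposed mirror of W1.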
+ ws_W1T.reset(); + ws_W1T.setRangeStart(at_offset); + ws_W1T.setHeight(nrefls_this); + + // TODO FIXME restrict communication to just interested ones + comm::broadcast(rank_panel, ws_W1, ws_W1T, mpi_row_chain, mpi_col_chain); + + // LR + // A -= V W1* + W1 V* - V W0* W1 V* + if (rank.row() == rank_qr.row()) { + matrix::Matrix ws_W2({nrefls_this, nrefls_this}, dist.block_size()); + // W2 = W0.T W1 + red2band::local::gemmComputeW2(ws_W2, ws_W0, ws_W1); + + // Note: + // Next steps for L and R need W1, so we create a copy that we are going to update for this step. + auto& ws_W3 = panels_w3.nextResource(); + ws_W3.setRangeStart(at_offset); + ws_W3.setWidth(nrefls_this); + + for (const auto& idx : ws_W1.iteratorLocal()) + ex::start_detached(ex::when_all(ws_W1.read(idx), ws_W3.readwrite(idx)) | + matrix::copy(di::Policy{})); + + // W1 -= 0.5 V W2 + red2band::local::gemmUpdateX(ws_W3, ws_W2, ws_V); + // A -= W1 V.T + V W1.T + ca_red2band::her2kUpdateTrailingMatrix(rank_qr, at_view, mat_a, ws_W3, ws_V); + + ws_W3.reset(); + } + + // R (exclusively) + // A -= W1 V* + // Note: all rows, but just the columns that are in the local transformation rank + for (SizeType j_lc = at_offset_lc.col(); j_lc < dist.local_nr_tiles().cols(); ++j_lc) { + const SizeType j = dist.template global_tile_from_local_tile(j_lc); + const comm::IndexT_MPI id_qr_R = dist.template rank_global_tile(j); + + if (rank_qr.row() != id_qr_R) + continue; + + for (SizeType i_lc = at_offset_lc.row(); i_lc < dist.local_nr_tiles().rows(); ++i_lc) { + const LocalTileIndex ij_lc(i_lc, j_lc); + const GlobalTileIndex ij = dist.global_tile_index(ij_lc); + + // TODO just lower part of trailing matrix + if (ij.row() < ij.col()) + continue; + + const comm::IndexT_MPI id_qr_L = dist.template rank_global_tile(ij.row()); + + // Note: exclusively from R, if it is an LR tile, it is computed elsewhere + if (id_qr_L == id_qr_R) + continue; + + ex::start_detached(di::whenAllLift(blas::Op::NoTrans, blas::Op::ConjTrans, T(-1), + ws_W1.read(ij_lc), ws_VT.read(ij_lc), T(1), + mat_a.readwrite(ij_lc)) | + tile::gemm(di::Policy())); + } + } + + // L (exclusively) + // A -= V W1* + // Note: all cols, but just the rows of current transformation + if (rank_qr.row() == rank.row()) { + for (SizeType i_lc = at_offset_lc.row(); i_lc < dist.local_nr_tiles().rows(); ++i_lc) { + const comm::IndexT_MPI id_qr_L = rank_qr.row(); + + for (SizeType j_lc = at_offset_lc.col(); j_lc < dist.local_nr_tiles().cols(); ++j_lc) { + const LocalTileIndex ij_lc(i_lc, j_lc); + const GlobalTileIndex ij = dist.global_tile_index(ij_lc); + + // TODO just lower part of trailing matrix + if (ij.row() < ij.col()) + continue; + + const comm::IndexT_MPI id_qr_R = dist.template rank_global_tile(ij.col()); + + // Note: exclusively from L, if it is an LR tile, it is computed elsewhere + if (id_qr_L == id_qr_R) + continue; + + ex::start_detached(di::whenAllLift(blas::Op::NoTrans, blas::Op::ConjTrans, T(-1), + ws_V.read(ij_lc), ws_W1T.read(ij_lc), T(1), + mat_a.readwrite(ij_lc)) | + tile::gemm(di::Policy())); + } + } + } + + ws_W1T.reset(); + ws_W1.reset(); + + ws_W0T.reset(); + ws_VT.reset(); + } + + ws_W0.reset(); + ws_V.reset(); + } + + // ===== 2nd pass + const matrix::Distribution dist_heads_current = [&]() { + using matrix::internal::distribution::global_tile_element_distance; + const SizeType i_begin = i; + const SizeType i_end = std::min(i + dist.grid_size().rows(), dist.nr_tiles().rows()); + const SizeType nrows = global_tile_element_distance(dist, i_begin, i_end); + return 
matrix::Distribution{LocalElementSize(nrows, band_size), dist.block_size()}; + }(); + + const SizeType nrefls_step = [&]() { + const SizeType reflector_size = dist_heads_current.size().rows(); + return std::min(dist_heads_current.size().cols(), reflector_size - 1); + }(); + + auto get_tile_tau2 = [&]() { + if (nrefls_step == band_size) + return mat_taus_2nd.readwrite(LocalTileIndex(j_lc, 0)); + return splitTile(mat_taus_2nd.readwrite(LocalTileIndex(j_lc, 0)), {{0, 0}, {nrefls_step, 1}}); + }; + + auto get_tile_tau2_ro = [&]() { + if (nrefls_step == band_size) + return mat_taus_2nd.read(LocalTileIndex(j_lc, 0)); + return splitTile(mat_taus_2nd.read(LocalTileIndex(j_lc, 0)), {{0, 0}, {nrefls_step, 1}}); + }; + + // PANEL: just ranks in the current column + // QR local with just heads (HH reflectors have to be computed elsewhere, 1st-pass in-place) + auto&& panel_heads = panels_heads.nextResource(); + panel_heads.setRangeEnd({n_qr_heads, 0}); + + const matrix::SubPanelView panel_heads_view(dist_heads_current, {0, 0}, band_size); + + const bool rank_has_head_row = !panel_view.iteratorLocal().empty(); + if (rank_panel == rank.col()) { + const comm::IndexT_MPI rank_hoh = dist.template rank_global_tile(panel_offset.row()); + + for (int idx_head = 0; idx_head < n_qr_heads; ++idx_head) { + using dlaf::comm::schedule_bcast_recv; + using dlaf::comm::schedule_bcast_send; + + const LocalTileIndex idx_panel_head(idx_head, 0); + + const GlobalTileIndex ij_head(panel_view.offset().row() + idx_head, j); + const comm::IndexT_MPI rank_head = dist.template rank_global_tile(ij_head.row()); + + if (rank.row() == rank_head) { + // copy - set - send + ex::start_detached(ex::when_all(mat_a.read(ij_head), panel_heads.readwrite(idx_panel_head)) | + di::transform(di::Policy(), [=](const auto& head_in, auto&& head) { + // TODO FIXME workaround for over-sized panel + if (head_in.size() != head.size()) + tile::internal::set0(head); + + // TODO FIXME change copy and if possible just upper + // matrix::internal::copy(head_in, head); + lapack::lacpy(blas::Uplo::General, head_in.size().rows(), + head_in.size().cols(), head_in.ptr(), head_in.ld(), + head.ptr(), head.ld()); + lapack::laset(blas::Uplo::Lower, head.size().rows() - 1, + head.size().cols(), T(0), T(0), head.ptr({1, 0}), + head.ld()); + })); + ex::start_detached(schedule_bcast_send(mpi_col_chain.exclusive(), + panel_heads.read(idx_panel_head))); + } + else { + // receive + ex::start_detached(schedule_bcast_recv(mpi_col_chain.exclusive(), rank_head, + panel_heads.readwrite(idx_panel_head))); + } + } + + // QR local on heads + using red2band::local::computePanelReflectors; + computePanelReflectors(panel_heads, get_tile_tau2(), panel_heads_view); + + // copy back data + { + // - just head of heads upper to mat_a + // - reflectors to hh_2nd + const GlobalTileIndex ij_hoh(panel_view.offset().row(), j); + if (rank.row() == dist.template rank_global_tile(ij_hoh.row())) + ex::start_detached(ex::when_all(panel_heads.read({0, 0}), mat_a.readwrite(ij_hoh)) | + di::transform(di::Policy(), [](const auto& hoh, auto&& hoh_a) { + common::internal::SingleThreadedBlasScope single; + lapack::lacpy(blas::Uplo::Upper, hoh.size().rows(), hoh.size().cols(), + hoh.ptr(), hoh.ld(), hoh_a.ptr(), hoh_a.ld()); + })); + + // Note: not all ranks might have an head + if (rank_has_head_row) { + const auto i_head_lc = + dist.template next_local_tile_from_global_tile(panel_view.offset().row()); + const auto i_head = dist.template global_tile_from_local_tile(i_head_lc); + const auto idx_head = 
i_head - panel_view.offset().row(); + const LocalTileIndex idx_panel_head(idx_head, 0); + const LocalTileIndex ij_head(0, j_lc); + + auto sender_heads = + ex::when_all(panel_heads.read(idx_panel_head), mat_hh_2nd.readwrite(ij_head)); + + if (rank.row() == rank_hoh) { + ex::start_detached(std::move(sender_heads) | + di::transform(di::Policy(), [](const auto& head, auto&& head_a) { + common::internal::SingleThreadedBlasScope single; + lapack::laset(blas::Uplo::Upper, head_a.size().rows(), + head_a.size().cols(), T(0), T(1), head_a.ptr(), + head_a.ld()); + lapack::lacpy(blas::Uplo::Lower, head.size().rows() - 1, + head.size().cols() - 1, head.ptr({1, 0}), head.ld(), + head_a.ptr({1, 0}), head_a.ld()); + })); + } + else { + ex::start_detached(std::move(sender_heads) | + di::transform(di::Policy(), [](const auto& head, auto&& head_a) { + common::internal::SingleThreadedBlasScope single; + lapack::lacpy(blas::Uplo::General, head.size().rows(), + head.size().cols(), head.ptr(), head.ld(), head_a.ptr(), + head_a.ld()); + })); + } + } + } + + panel_heads.reset(); + } + + // TRAILING 2nd pass + { + panel_heads.setRangeEnd({n_qr_heads, 0}); + panel_heads.setWidth(nrefls_step); + + const GlobalTileIndex at_end_L(at_offset.row() + n_qr_heads, 0); + const GlobalTileIndex at_end_R(0, at_offset.col() + n_qr_heads); + + const LocalTileIndex zero_lc(0, 0); + matrix::Matrix ws_T({nrefls_step, nrefls_step}, dist.block_size()); + + auto& ws_V = panels_v.nextResource(); + ws_V.setRange(at_offset, at_end_L); + ws_V.setWidth(nrefls_step); + + if (rank_panel == rank.col()) { + // setup reflector panel + const LocalTileIndex ij_head(0, j_lc); + const LocalTileIndex ij_vhh_lc(ws_V.rangeStartLocal(), 0); + + if (rank_has_head_row) { + ex::start_detached(ex::when_all(mat_hh_2nd.read(ij_head), ws_V.readwrite(ij_vhh_lc)) | + di::transform(di::Policy(), [=](const auto& head_in, auto&& head) { + lapack::lacpy(blas::Uplo::General, head.size().rows(), head.size().cols(), + head_in.ptr(), head_in.ld(), head.ptr(), head.ld()); + })); + } + + using factorization::internal::computeTFactor; + const GlobalTileIndex j_tau(j, 0); + computeTFactor(panel_heads, get_tile_tau2_ro(), ws_T.readwrite(zero_lc)); + } + + auto& ws_VT = panels_vt.nextResource(); + + ws_VT.setRange(at_offset, at_end_R); + ws_VT.setHeight(nrefls_step); + + comm::broadcast_all(rank_panel, ws_V, ws_VT, mpi_row_chain, mpi_col_chain); + + // Note: + // Differently from 1st pass, where transformations are independent one from the other, + // this 2nd pass is a single QR transformation that has to be applied from L and R. 
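+      // Note (summary of the update sequence implemented below):
+      //   W0  = V T
+      //   W1  = A W0                  (hemm2nd)
+      //   W2  = W0* W1                (all-reduced over mpi_col_chain)
+      //   W1 -= 0.5 V W2
+      //   A  -= W1 V* + V W1*         (her2k_2nd)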
+ + // W0 = V T + auto& ws_W0 = panels_w0.nextResource(); + ws_W0.setRange(at_offset, at_end_L); + ws_W0.setWidth(nrefls_step); + + if (rank.col() == rank_panel) + red2band::local::trmmComputeW(ws_W0, ws_V, ws_T.read(zero_lc)); + + // distribute W0 -> W0T + auto& ws_W0T = panels_w0t.nextResource(); + ws_W0T.setRange(at_offset, at_end_R); + ws_W0T.setHeight(nrefls_step); + + comm::broadcast_all(rank_panel, ws_W0, ws_W0T, mpi_row_chain, mpi_col_chain); + + // W1 = A W0 + auto& ws_W1 = panels_w1.nextResource(); + ws_W1.setRangeStart(at_offset); + ws_W1.setWidth(nrefls_step); + + auto& ws_W1T = panels_w1t.nextResource(); + ws_W1T.setRangeStart(at_offset); + ws_W1T.setHeight(nrefls_step); + + ca_red2band::hemm2nd(rank_panel, ws_W1, ws_W1T, at_view, at_end_R.col(), mat_a, ws_W0, + ws_W0T, mpi_row_chain, mpi_col_chain); + + // W1 = W1 - 0.5 V W0* W1 + if (rank.col() == rank_panel) { + matrix::Matrix ws_W2 = std::move(ws_T); + + // W2 = W0T W1 + red2band::local::gemmComputeW2(ws_W2, ws_W0, ws_W1); + if (mpi_col_chain.size() > 1) { + ex::start_detached(comm::schedule_all_reduce_in_place(mpi_col_chain.exclusive(), MPI_SUM, + ws_W2.readwrite(zero_lc))); + } + + // W1 = W1 - 0.5 V W2 + red2band::local::gemmUpdateX(ws_W1, ws_W2, ws_V); + } + + // distribute W1 -> W1T + ws_W1T.reset(); + ws_W1T.setRangeStart(at_offset); + ws_W1T.setHeight(nrefls_step); + + comm::broadcast(rank_panel, ws_W1, ws_W1T, mpi_row_chain, mpi_col_chain); + + // LR: A -= W1 VT + V W1T + // R : [at_end_L.row():, :at_endR_col()] A = A - W1 V.T + ca_red2band::her2k_2nd(at_end_L.row(), at_end_R.col(), at_view, mat_a, ws_W1, ws_W1T, ws_V, + ws_VT); + + ws_W1T.reset(); + ws_W1.reset(); + ws_W0T.reset(); + ws_W0.reset(); + ws_VT.reset(); + ws_V.reset(); + } + + panel_heads.reset(); + } + + return {std::move(mat_taus_1st), std::move(mat_taus_2nd), std::move(mat_hh_2nd)}; +} +} diff --git a/include/dlaf/eigensolver/reduction_to_band/common.h b/include/dlaf/eigensolver/reduction_to_band/common.h new file mode 100644 index 0000000000..9c0b8ab240 --- /dev/null +++ b/include/dlaf/eigensolver/reduction_to_band/common.h @@ -0,0 +1,475 @@ +// +// Distributed Linear Algebra with Future (DLAF) +// +// Copyright (c) 2018-2024, ETH Zurich +// All rights reserved. +// +// Please, refer to the LICENSE file in the root directory. 
+// SPDX-License-Identifier: BSD-3-Clause +// + +#pragma once + +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace dlaf::eigensolver::internal { + +// Given a vector of vectors, reduce all vectors in the first one using sum operation +template +void reduceColumnVectors(std::vector>& columnVectors) { + for (std::size_t i = 1; i < columnVectors.size(); ++i) { + DLAF_ASSERT_HEAVY(columnVectors[0].size() == columnVectors[i].size(), columnVectors[0].size(), + columnVectors[i].size()); + for (SizeType j = 0; j < columnVectors[0].size(); ++j) + columnVectors[0][j] += columnVectors[i][j]; + } +} + +namespace red2band { + +// Extract x0 and compute local cumulative sum of squares of the reflector column +template +std::array computeX0AndSquares(const bool has_head, const std::vector>& panel, + SizeType j) { + std::array x0_and_squares{0, 0}; + auto it_begin = panel.begin(); + auto it_end = panel.end(); + + common::internal::SingleThreadedBlasScope single; + + if (has_head) { + auto& tile_v0 = *it_begin++; + + const TileElementIndex idx_x0(j, j); + x0_and_squares[0] = tile_v0(idx_x0); + + T* reflector_ptr = tile_v0.ptr(idx_x0); + x0_and_squares[1] = + blas::dot(tile_v0.size().rows() - idx_x0.row(), reflector_ptr, 1, reflector_ptr, 1); + } + + for (auto it = it_begin; it != it_end; ++it) { + const auto& tile = *it; + + T* reflector_ptr = tile.ptr({0, j}); + x0_and_squares[1] += blas::dot(tile.size().rows(), reflector_ptr, 1, reflector_ptr, 1); + } + return x0_and_squares; +} + +template +T computeReflectorAndTau(const bool has_head, const std::vector>& panel, + const SizeType j, std::array x0_and_squares) { + if (x0_and_squares[1] == T(0)) + return T(0); + + const T norm = std::sqrt(x0_and_squares[1]); + const T x0 = x0_and_squares[0]; + const T y = std::signbit(std::real(x0_and_squares[0])) ? norm : -norm; + const T tau = (y - x0) / y; + + auto it_begin = panel.begin(); + auto it_end = panel.end(); + + common::internal::SingleThreadedBlasScope single; + + if (has_head) { + const auto& tile_v0 = *it_begin++; + + const TileElementIndex idx_x0(j, j); + tile_v0(idx_x0) = y; + + if (j + 1 < tile_v0.size().rows()) { + T* v = tile_v0.ptr({j + 1, j}); + blas::scal(tile_v0.size().rows() - (j + 1), T(1) / (x0 - y), v, 1); + } + } + + for (auto it = it_begin; it != it_end; ++it) { + auto& tile_v = *it; + T* v = tile_v.ptr({0, j}); + blas::scal(tile_v.size().rows(), T(1) / (x0 - y), v, 1); + } + + return tau; +} + +template +void computeWTrailingPanel(const bool has_head, const std::vector>& panel, + common::internal::vector& w, SizeType j, const SizeType pt_cols, + const std::size_t begin, const std::size_t end) { + // for each tile in the panel, consider just the trailing panel + // i.e. all rows (height = reflector), just columns to the right of the current reflector + if (!(pt_cols > 0)) + return; + + const TileElementIndex index_el_x0(j, j); + bool has_first_component = has_head; + + common::internal::SingleThreadedBlasScope single; + + // W = Pt* . V + for (auto index = begin; index < end; ++index) { + const matrix::Tile& tile_a = panel[index]; + const SizeType first_element = has_first_component ? 
index_el_x0.row() : 0; + + TileElementIndex pt_start{first_element, index_el_x0.col() + 1}; + TileElementSize pt_size{tile_a.size().rows() - pt_start.row(), pt_cols}; + TileElementIndex v_start{first_element, index_el_x0.col()}; + + if (has_first_component) { + const TileElementSize offset{1, 0}; + + const T fake_v = 1; + blas::gemv(blas::Layout::ColMajor, blas::Op::ConjTrans, offset.rows(), pt_size.cols(), T(1), + tile_a.ptr(pt_start), tile_a.ld(), &fake_v, 1, T(0), w.data(), 1); + + pt_start = pt_start + offset; + v_start = v_start + offset; + pt_size = pt_size - offset; + + has_first_component = false; + } + + if (pt_start.isIn(tile_a.size())) { + // W += 1 . A* . V + blas::gemv(blas::Layout::ColMajor, blas::Op::ConjTrans, pt_size.rows(), pt_size.cols(), T(1), + tile_a.ptr(pt_start), tile_a.ld(), tile_a.ptr(v_start), 1, T(1), w.data(), 1); + } + } +} + +template +void updateTrailingPanel(const bool has_head, const std::vector>& panel, SizeType j, + const std::vector& w, const T tau, const std::size_t begin, + const std::size_t end) { + const TileElementIndex index_el_x0(j, j); + + bool has_first_component = has_head; + + common::internal::SingleThreadedBlasScope single; + + // GER Pt = Pt - tau . v . w* + for (auto index = begin; index < end; ++index) { + const matrix::Tile& tile_a = panel[index]; + const SizeType first_element = has_first_component ? index_el_x0.row() : 0; + + TileElementIndex pt_start{first_element, index_el_x0.col() + 1}; + TileElementSize pt_size{tile_a.size().rows() - pt_start.row(), + tile_a.size().cols() - pt_start.col()}; + TileElementIndex v_start{first_element, index_el_x0.col()}; + + if (has_first_component) { + const TileElementSize offset{1, 0}; + + // Pt = Pt - tau * v[0] * w* + const T fake_v = 1; + blas::ger(blas::Layout::ColMajor, 1, pt_size.cols(), -dlaf::conj(tau), &fake_v, 1, w.data(), 1, + tile_a.ptr(pt_start), tile_a.ld()); + + pt_start = pt_start + offset; + v_start = v_start + offset; + pt_size = pt_size - offset; + + has_first_component = false; + } + + if (pt_start.isIn(tile_a.size())) { + // Pt = Pt - tau * v * w* + blas::ger(blas::Layout::ColMajor, pt_size.rows(), pt_size.cols(), -dlaf::conj(tau), + tile_a.ptr(v_start), 1, w.data(), 1, tile_a.ptr(pt_start), tile_a.ld()); + } + } +} + +template +void hemmDiag(pika::execution::thread_priority priority, ASender&& tile_a, WSender&& tile_w, + XSender&& tile_x) { + using T = dlaf::internal::SenderElementType; + using pika::execution::thread_stacksize; + + pika::execution::experimental::start_detached( + dlaf::internal::whenAllLift(blas::Side::Left, blas::Uplo::Lower, T(1), + std::forward(tile_a), std::forward(tile_w), T(1), + std::forward(tile_x)) | + tile::hemm(dlaf::internal::Policy(priority, thread_stacksize::nostack))); +} + +// X += op(A) * W +template +void hemmOffDiag(pika::execution::thread_priority priority, blas::Op op, ASender&& tile_a, + WSender&& tile_w, XSender&& tile_x) { + using T = dlaf::internal::SenderElementType; + using pika::execution::thread_stacksize; + + pika::execution::experimental::start_detached( + dlaf::internal::whenAllLift(op, blas::Op::NoTrans, T(1), std::forward(tile_a), + std::forward(tile_w), T(1), std::forward(tile_x)) | + tile::gemm(dlaf::internal::Policy(priority, thread_stacksize::nostack))); +} + +template +void her2kDiag(pika::execution::thread_priority priority, VSender&& tile_v, XSender&& tile_x, + ASender&& tile_a) { + using T = dlaf::internal::SenderElementType; + using pika::execution::thread_stacksize; + + 
pika::execution::experimental::start_detached( + dlaf::internal::whenAllLift(blas::Uplo::Lower, blas::Op::NoTrans, T(-1), + std::forward(tile_v), std::forward(tile_x), + BaseType(1), std::forward(tile_a)) | + tile::her2k(dlaf::internal::Policy(priority, thread_stacksize::nostack))); +} + +// C -= A . B* +template +void her2kOffDiag(pika::execution::thread_priority priority, ASender&& tile_a, BSender&& tile_b, + CSender&& tile_c) { + using T = dlaf::internal::SenderElementType; + using pika::execution::thread_stacksize; + + pika::execution::experimental::start_detached( + dlaf::internal::whenAllLift(blas::Op::NoTrans, blas::Op::ConjTrans, T(-1), + std::forward(tile_a), std::forward(tile_b), T(1), + std::forward(tile_c)) | + tile::gemm(dlaf::internal::Policy(priority, thread_stacksize::nostack))); +} + +} + +namespace red2band::local { + +template +T computeReflector(const std::vector>& panel, SizeType j) { + constexpr bool has_head = true; + + std::array x0_and_squares = computeX0AndSquares(has_head, panel, j); + + auto tau = computeReflectorAndTau(has_head, panel, j, std::move(x0_and_squares)); + + return tau; +} + +template +void computePanelReflectors(MatrixLikeA& mat_a, matrix::ReadWriteTileSender tile_tau, + const matrix::SubPanelView& panel_view) { + static_assert(D == MatrixLikeA::device); + static_assert(std::is_same_v); + + using pika::execution::thread_priority; + namespace ex = pika::execution::experimental; + namespace di = dlaf::internal; + + std::vector> panel_tiles; + const auto panel_range = panel_view.iteratorLocal(); + const std::size_t panel_ntiles = to_sizet(std::distance(panel_range.begin(), panel_range.end())); + + if (panel_ntiles == 0) { + return; + } + + panel_tiles.reserve(panel_ntiles); + for (const auto& i : panel_range) { + const matrix::SubTileSpec& spec = panel_view(i); + panel_tiles.emplace_back(matrix::splitTile(mat_a.readwrite(i), spec)); + } + + const std::size_t nthreads = getReductionToBandPanelNWorkers(); + ex::start_detached( + ex::when_all(ex::just(std::make_unique>(nthreads), + std::vector>{}), // w (internally required) + std::move(tile_tau), ex::when_all_vector(std::move(panel_tiles))) | + ex::transfer(di::getBackendScheduler(thread_priority::high)) | + ex::bulk(nthreads, [nthreads, cols = panel_view.cols()](const std::size_t index, auto& barrier_ptr, + auto& w, auto& taus, auto& tiles) { + const auto barrier_busy_wait = getReductionToBandBarrierBusyWait(); + const std::size_t batch_size = util::ceilDiv(tiles.size(), nthreads); + const std::size_t begin = index * batch_size; + const std::size_t end = std::min(index * batch_size + batch_size, tiles.size()); + const SizeType nrefls = taus.size().rows(); + + if (index == 0) { + w.resize(nthreads); + } + + for (SizeType j = 0; j < nrefls; ++j) { + // STEP1: compute tau and reflector (single-thread) + if (index == 0) { + taus({j, 0}) = computeReflector(tiles, j); + } + + barrier_ptr->arrive_and_wait(barrier_busy_wait); + + // STEP2a: compute w (multi-threaded) + const SizeType pt_cols = cols - (j + 1); + if (pt_cols == 0) + break; + const bool has_head = (index == 0); + + w[index] = common::internal::vector(pt_cols, 0); + computeWTrailingPanel(has_head, tiles, w[index], j, pt_cols, begin, end); + barrier_ptr->arrive_and_wait(barrier_busy_wait); + + // STEP2b: reduce w results (single-threaded) + if (index == 0) + dlaf::eigensolver::internal::reduceColumnVectors(w); + barrier_ptr->arrive_and_wait(barrier_busy_wait); + + // STEP3: update trailing panel (multi-threaded) + updateTrailingPanel(has_head, 
tiles, j, w[0], taus({j, 0}), begin, end);
+          barrier_ptr->arrive_and_wait(barrier_busy_wait);
+        }
+      }));
+}
+
+template
+void setupReflectorPanelV(bool has_head, const matrix::SubPanelView& panel_view, const SizeType nrefls,
+                          matrix::Panel& v, matrix::Matrix& mat_a,
+                          bool force_copy = false) {
+  namespace ex = pika::execution::experimental;
+
+  using pika::execution::thread_priority;
+  using pika::execution::thread_stacksize;
+
+  // Note:
+  // Reflectors are stored in the lower triangular part of the A matrix, leading to sharing memory
+  // between reflectors and results, which are in the upper triangular part. The problem exists only
+  // for the first tile (of V, i.e. band excluded). Since reflectors will be used in the next
+  // computations, they should be well-formed, i.e. a unit lower trapezoidal matrix. For this reason,
+  // a support tile is used, where just the reflector values are copied, the diagonal is set to 1
+  // and the rest is zeroed out.
+  auto it_begin = panel_view.iteratorLocal().begin();
+  auto it_end = panel_view.iteratorLocal().end();
+
+  if (has_head) {
+    const LocalTileIndex i = *it_begin;
+    matrix::SubTileSpec spec = panel_view(i);
+
+    // Note:
+    // If the number of reflectors is limited by height (|reflector| > 1), the panel is narrower than
+    // the blocksize, leading to just using a part of A (first full nrefls columns)
+    spec.size = {spec.size.rows(), std::min(nrefls, spec.size.cols())};
+
+    // Note:
+    // copy + laset is done in two independent tasks, but they could theoretically be merged into a
+    // single task doing both.
+    const auto p = dlaf::internal::Policy(thread_priority::high, thread_stacksize::nostack);
+    ex::start_detached(dlaf::internal::whenAllLift(splitTile(mat_a.read(i), spec), v.readwrite(i)) |
+                       matrix::copy(p));
+    ex::start_detached(dlaf::internal::whenAllLift(blas::Uplo::Upper, T(0), T(1), v.readwrite(i)) |
+                       tile::laset(p));
+
+    ++it_begin;
+  }
+
+  // The rest of the V panel of reflectors can just point to the values in A, since they are
+  // well formed in-place.
+  for (auto it = it_begin; it < it_end; ++it) {
+    const LocalTileIndex idx = *it;
+    const matrix::SubTileSpec& spec = panel_view(idx);
+
+    // Note: This is a workaround for the deadlock problem with sub-tiles.
+    // Without this copy, during the matrix update the same tile would get accessed at the same
+    // time both in readonly mode (for reflectors) and in readwrite mode (for updating the
+    // matrix). This would result in a deadlock, so instead of linking the panel to an external
+    // tile, memory provided internally by the panel is used as support. In this way, the two
+    // subtiles used in the operation belong to different tiles.
+ if (force_copy) + ex::start_detached(ex::when_all(matrix::splitTile(mat_a.read(idx), spec), v.readwrite(idx)) | + matrix::copy(dlaf::internal::Policy(thread_priority::high, + thread_stacksize::nostack))); + else + v.setTile(idx, matrix::splitTile(mat_a.read(idx), spec)); + } +} + +template +void trmmComputeW(matrix::Panel& w, matrix::Panel& v, + matrix::ReadOnlyTileSender tile_t) { + namespace ex = pika::execution::experimental; + + using pika::execution::thread_priority; + using pika::execution::thread_stacksize; + using namespace blas; + + auto it = w.iteratorLocal(); + + for (const auto& index_i : it) { + ex::start_detached(dlaf::internal::whenAllLift(Side::Right, Uplo::Upper, Op::NoTrans, Diag::NonUnit, + T(1), tile_t, v.read(index_i), w.readwrite(index_i)) | + tile::trmm3(dlaf::internal::Policy(thread_priority::high, + thread_stacksize::nostack))); + } + + if (it.empty()) { + ex::start_detached(std::move(tile_t)); + } +} + +template +void gemmComputeW2(matrix::Matrix& w2, matrix::Panel& w, + matrix::Panel& x) { + using pika::execution::thread_priority; + using pika::execution::thread_stacksize; + + namespace ex = pika::execution::experimental; + + // Note: + // Not all ranks in the column always hold at least a tile in the panel Ai, but all ranks in + // the column are going to participate to the reduce. For them, it is important to set the + // partial result W2 to zero. + ex::start_detached(w2.readwrite(LocalTileIndex(0, 0)) | + tile::set0(dlaf::internal::Policy(thread_priority::high, + thread_stacksize::nostack))); + + using namespace blas; + // GEMM W2 = W* . X + for (const auto& index_tile : w.iteratorLocal()) + ex::start_detached( + dlaf::internal::whenAllLift(Op::ConjTrans, Op::NoTrans, T(1), w.read(index_tile), + x.read(index_tile), T(1), w2.readwrite(LocalTileIndex(0, 0))) | + tile::gemm(dlaf::internal::Policy(thread_priority::high, thread_stacksize::nostack))); +} + +template +void gemmUpdateX(matrix::Panel& x, matrix::Matrix& w2, + matrix::Panel& v) { + namespace ex = pika::execution::experimental; + + using pika::execution::thread_priority; + using pika::execution::thread_stacksize; + using namespace blas; + + // GEMM X = X - 0.5 . V . 
W2 + for (const auto& index_i : v.iteratorLocal()) + ex::start_detached( + dlaf::internal::whenAllLift(Op::NoTrans, Op::NoTrans, T(-0.5), v.read(index_i), + w2.read(LocalTileIndex(0, 0)), T(1), x.readwrite(index_i)) | + tile::gemm(dlaf::internal::Policy(thread_priority::high, thread_stacksize::nostack))); +} + +} + +} diff --git a/include/dlaf/eigensolver/reduction_to_band/impl.h b/include/dlaf/eigensolver/reduction_to_band/impl.h index 2a7882fd82..165b87c546 100644 --- a/include/dlaf/eigensolver/reduction_to_band/impl.h +++ b/include/dlaf/eigensolver/reduction_to_band/impl.h @@ -38,9 +38,8 @@ #include #include #include -#include -#include #include +#include #include #include #include @@ -59,408 +58,9 @@ namespace dlaf::eigensolver::internal { -// Given a vector of vectors, reduce all vectors in the first one using sum operation -template -void reduceColumnVectors(std::vector>& columnVectors) { - for (std::size_t i = 1; i < columnVectors.size(); ++i) { - DLAF_ASSERT_HEAVY(columnVectors[0].size() == columnVectors[i].size(), columnVectors[0].size(), - columnVectors[i].size()); - for (SizeType j = 0; j < columnVectors[0].size(); ++j) - columnVectors[0][j] += columnVectors[i][j]; - } -} - namespace red2band { -// Extract x0 and compute local cumulative sum of squares of the reflector column -template -std::array computeX0AndSquares(const bool has_head, const std::vector>& panel, - SizeType j) { - std::array x0_and_squares{0, 0}; - auto it_begin = panel.begin(); - auto it_end = panel.end(); - - common::internal::SingleThreadedBlasScope single; - - if (has_head) { - auto& tile_v0 = *it_begin++; - - const TileElementIndex idx_x0(j, j); - x0_and_squares[0] = tile_v0(idx_x0); - - T* reflector_ptr = tile_v0.ptr({idx_x0}); - x0_and_squares[1] = - blas::dot(tile_v0.size().rows() - idx_x0.row(), reflector_ptr, 1, reflector_ptr, 1); - } - - for (auto it = it_begin; it != it_end; ++it) { - const auto& tile = *it; - - T* reflector_ptr = tile.ptr({0, j}); - x0_and_squares[1] += blas::dot(tile.size().rows(), reflector_ptr, 1, reflector_ptr, 1); - } - return x0_and_squares; -} - -template -T computeReflectorAndTau(const bool has_head, const std::vector>& panel, - const SizeType j, std::array x0_and_squares) { - if (x0_and_squares[1] == T(0)) - return T(0); - - const T norm = std::sqrt(x0_and_squares[1]); - const T x0 = x0_and_squares[0]; - const T y = std::signbit(std::real(x0_and_squares[0])) ? norm : -norm; - const T tau = (y - x0) / y; - - auto it_begin = panel.begin(); - auto it_end = panel.end(); - - common::internal::SingleThreadedBlasScope single; - - if (has_head) { - const auto& tile_v0 = *it_begin++; - - const TileElementIndex idx_x0(j, j); - tile_v0(idx_x0) = y; - - if (j + 1 < tile_v0.size().rows()) { - T* v = tile_v0.ptr({j + 1, j}); - blas::scal(tile_v0.size().rows() - (j + 1), T(1) / (x0 - y), v, 1); - } - } - - for (auto it = it_begin; it != it_end; ++it) { - auto& tile_v = *it; - T* v = tile_v.ptr({0, j}); - blas::scal(tile_v.size().rows(), T(1) / (x0 - y), v, 1); - } - - return tau; -} - -template -void computeWTrailingPanel(const bool has_head, const std::vector>& panel, - common::internal::vector& w, SizeType j, const SizeType pt_cols, - const std::size_t begin, const std::size_t end) { - // for each tile in the panel, consider just the trailing panel - // i.e. 
all rows (height = reflector), just columns to the right of the current reflector - if (!(pt_cols > 0)) - return; - - const TileElementIndex index_el_x0(j, j); - bool has_first_component = has_head; - - common::internal::SingleThreadedBlasScope single; - - // W = Pt* . V - for (auto index = begin; index < end; ++index) { - const matrix::Tile& tile_a = panel[index]; - const SizeType first_element = has_first_component ? index_el_x0.row() : 0; - - TileElementIndex pt_start{first_element, index_el_x0.col() + 1}; - TileElementSize pt_size{tile_a.size().rows() - pt_start.row(), pt_cols}; - TileElementIndex v_start{first_element, index_el_x0.col()}; - - if (has_first_component) { - const TileElementSize offset{1, 0}; - - const T fake_v = 1; - blas::gemv(blas::Layout::ColMajor, blas::Op::ConjTrans, offset.rows(), pt_size.cols(), T(1), - tile_a.ptr(pt_start), tile_a.ld(), &fake_v, 1, T(0), w.data(), 1); - - pt_start = pt_start + offset; - v_start = v_start + offset; - pt_size = pt_size - offset; - - has_first_component = false; - } - - if (pt_start.isIn(tile_a.size())) { - // W += 1 . A* . V - blas::gemv(blas::Layout::ColMajor, blas::Op::ConjTrans, pt_size.rows(), pt_size.cols(), T(1), - tile_a.ptr(pt_start), tile_a.ld(), tile_a.ptr(v_start), 1, T(1), w.data(), 1); - } - } -} - -template -void updateTrailingPanel(const bool has_head, const std::vector>& panel, SizeType j, - const std::vector& w, const T tau, const std::size_t begin, - const std::size_t end) { - const TileElementIndex index_el_x0(j, j); - - bool has_first_component = has_head; - - common::internal::SingleThreadedBlasScope single; - - // GER Pt = Pt - tau . v . w* - for (auto index = begin; index < end; ++index) { - const matrix::Tile& tile_a = panel[index]; - const SizeType first_element = has_first_component ? 
index_el_x0.row() : 0; - - TileElementIndex pt_start{first_element, index_el_x0.col() + 1}; - TileElementSize pt_size{tile_a.size().rows() - pt_start.row(), - tile_a.size().cols() - pt_start.col()}; - TileElementIndex v_start{first_element, index_el_x0.col()}; - - if (has_first_component) { - const TileElementSize offset{1, 0}; - - // Pt = Pt - tau * v[0] * w* - const T fake_v = 1; - blas::ger(blas::Layout::ColMajor, 1, pt_size.cols(), -dlaf::conj(tau), &fake_v, 1, w.data(), 1, - tile_a.ptr(pt_start), tile_a.ld()); - - pt_start = pt_start + offset; - v_start = v_start + offset; - pt_size = pt_size - offset; - - has_first_component = false; - } - - if (pt_start.isIn(tile_a.size())) { - // Pt = Pt - tau * v * w* - blas::ger(blas::Layout::ColMajor, pt_size.rows(), pt_size.cols(), -dlaf::conj(tau), - tile_a.ptr(v_start), 1, w.data(), 1, tile_a.ptr(pt_start), tile_a.ld()); - } - } -} - -template -void hemmDiag(pika::execution::thread_priority priority, ASender&& tile_a, WSender&& tile_w, - XSender&& tile_x) { - using T = dlaf::internal::SenderElementType; - using pika::execution::thread_stacksize; - - pika::execution::experimental::start_detached( - dlaf::internal::whenAllLift(blas::Side::Left, blas::Uplo::Lower, T(1), - std::forward(tile_a), std::forward(tile_w), T(1), - std::forward(tile_x)) | - tile::hemm(dlaf::internal::Policy(priority, thread_stacksize::nostack))); -} - -// X += op(A) * W -template -void hemmOffDiag(pika::execution::thread_priority priority, blas::Op op, ASender&& tile_a, - WSender&& tile_w, XSender&& tile_x) { - using T = dlaf::internal::SenderElementType; - using pika::execution::thread_stacksize; - - pika::execution::experimental::start_detached( - dlaf::internal::whenAllLift(op, blas::Op::NoTrans, T(1), std::forward(tile_a), - std::forward(tile_w), T(1), std::forward(tile_x)) | - tile::gemm(dlaf::internal::Policy(priority, thread_stacksize::nostack))); -} - -template -void her2kDiag(pika::execution::thread_priority priority, VSender&& tile_v, XSender&& tile_x, - ASender&& tile_a) { - using T = dlaf::internal::SenderElementType; - using pika::execution::thread_stacksize; - - pika::execution::experimental::start_detached( - dlaf::internal::whenAllLift(blas::Uplo::Lower, blas::Op::NoTrans, T(-1), - std::forward(tile_v), std::forward(tile_x), - BaseType(1), std::forward(tile_a)) | - tile::her2k(dlaf::internal::Policy(priority, thread_stacksize::nostack))); -} - -// C -= A . 
B* -template -void her2kOffDiag(pika::execution::thread_priority priority, ASender&& tile_a, BSender&& tile_b, - CSender&& tile_c) { - using T = dlaf::internal::SenderElementType; - using pika::execution::thread_stacksize; - - pika::execution::experimental::start_detached( - dlaf::internal::whenAllLift(blas::Op::NoTrans, blas::Op::ConjTrans, T(-1), - std::forward(tile_a), std::forward(tile_b), T(1), - std::forward(tile_c)) | - tile::gemm(dlaf::internal::Policy(priority, thread_stacksize::nostack))); -} - namespace local { - -template -T computeReflector(const std::vector>& panel, SizeType j) { - constexpr bool has_head = true; - - std::array x0_and_squares = computeX0AndSquares(has_head, panel, j); - - auto tau = computeReflectorAndTau(has_head, panel, j, std::move(x0_and_squares)); - - return tau; -} - -template -void computePanelReflectors(MatrixLikeA& mat_a, MatrixLikeTaus& mat_taus, const SizeType j_sub, - const matrix::SubPanelView& panel_view) { - static Device constexpr D = MatrixLikeA::device; - using T = typename MatrixLikeA::ElementType; - using pika::execution::thread_priority; - namespace ex = pika::execution::experimental; - namespace di = dlaf::internal; - - std::vector> panel_tiles; - panel_tiles.reserve(to_sizet(std::distance(panel_view.iteratorLocal().begin(), - panel_view.iteratorLocal().end()))); - for (const auto& i : panel_view.iteratorLocal()) { - const matrix::SubTileSpec& spec = panel_view(i); - panel_tiles.emplace_back(matrix::splitTile(mat_a.readwrite(i), spec)); - } - - const std::size_t nthreads = getReductionToBandPanelNWorkers(); - auto s = - ex::when_all(ex::just(std::make_unique>(nthreads), - std::vector>{}), // w (internally required) - mat_taus.readwrite(LocalTileIndex(j_sub, 0)), - ex::when_all_vector(std::move(panel_tiles))) | - ex::transfer(di::getBackendScheduler(thread_priority::high)) | - ex::bulk(nthreads, [nthreads, cols = panel_view.cols()](const std::size_t index, auto& barrier_ptr, - auto& w, auto& taus, auto& tiles) { - const auto barrier_busy_wait = getReductionToBandBarrierBusyWait(); - const std::size_t batch_size = util::ceilDiv(tiles.size(), nthreads); - const std::size_t begin = index * batch_size; - const std::size_t end = std::min(index * batch_size + batch_size, tiles.size()); - const SizeType nrefls = taus.size().rows(); - - if (index == 0) { - w.resize(nthreads); - } - - for (SizeType j = 0; j < nrefls; ++j) { - // STEP1: compute tau and reflector (single-thread) - if (index == 0) { - taus({j, 0}) = computeReflector(tiles, j); - } - - barrier_ptr->arrive_and_wait(barrier_busy_wait); - - // STEP2a: compute w (multi-threaded) - const SizeType pt_cols = cols - (j + 1); - if (pt_cols == 0) - break; - const bool has_head = (index == 0); - - w[index] = common::internal::vector(pt_cols, 0); - computeWTrailingPanel(has_head, tiles, w[index], j, pt_cols, begin, end); - barrier_ptr->arrive_and_wait(barrier_busy_wait); - - // STEP2b: reduce w results (single-threaded) - if (index == 0) - dlaf::eigensolver::internal::reduceColumnVectors(w); - barrier_ptr->arrive_and_wait(barrier_busy_wait); - - // STEP3: update trailing panel (multi-threaded) - updateTrailingPanel(has_head, tiles, j, w[0], taus({j, 0}), begin, end); - barrier_ptr->arrive_and_wait(barrier_busy_wait); - } - }); - ex::start_detached(std::move(s)); -} - -template -void setupReflectorPanelV(bool has_head, const matrix::SubPanelView& panel_view, const SizeType nrefls, - matrix::Panel& v, matrix::Matrix& mat_a, - bool force_copy = false) { - namespace ex = 
pika::execution::experimental; - - using pika::execution::thread_priority; - using pika::execution::thread_stacksize; - - // Note: - // Reflectors are stored in the lower triangular part of the A matrix leading to sharing memory - // between reflectors and results, which are in the upper triangular part. The problem exists only - // for the first tile (of the V, i.e. band excluded). Since refelectors will be used in next - // computations, they should be well-formed, i.e. a unit lower trapezoidal matrix. For this reason, - // a support tile is used, where just the reflectors values are copied, the diagonal is set to 1 - // and the rest is zeroed out. - auto it_begin = panel_view.iteratorLocal().begin(); - auto it_end = panel_view.iteratorLocal().end(); - - if (has_head) { - const LocalTileIndex i = *it_begin; - matrix::SubTileSpec spec = panel_view(i); - - // Note: - // If the number of reflectors are limited by height (|reflector| > 1), the panel is narrower than - // the blocksize, leading to just using a part of A (first full nrefls columns) - spec.size = {spec.size.rows(), std::min(nrefls, spec.size.cols())}; - - // Note: - // copy + laset is done in two independent tasks, but it could be theoretically merged to into a - // single task doing both. - const auto p = dlaf::internal::Policy(thread_priority::high, thread_stacksize::nostack); - ex::start_detached(dlaf::internal::whenAllLift(splitTile(mat_a.read(i), spec), v.readwrite(i)) | - matrix::copy(p)); - ex::start_detached(dlaf::internal::whenAllLift(blas::Uplo::Upper, T(0), T(1), v.readwrite(i)) | - tile::laset(p)); - - ++it_begin; - } - - // The rest of the V panel of reflectors can just point to the values in A, since they are - // well formed in-place. - for (auto it = it_begin; it < it_end; ++it) { - const LocalTileIndex idx = *it; - const matrix::SubTileSpec& spec = panel_view(idx); - - // Note: This is a workaround for the deadlock problem with sub-tiles. - // Without this copy, during matrix update the same tile would get accessed at the same - // time both in readonly mode (for reflectors) and in readwrite mode (for updating the - // matrix). This would result in a deadlock, so instead of linking the panel to an external - // tile, memory provided internally by the panel is used as support. In this way, the two - // subtiles used in the operation belong to different tiles. 
- if (force_copy) - ex::start_detached(ex::when_all(matrix::splitTile(mat_a.read(idx), spec), v.readwrite(idx)) | - matrix::copy(dlaf::internal::Policy(thread_priority::high, - thread_stacksize::nostack))); - else - v.setTile(idx, matrix::splitTile(mat_a.read(idx), spec)); - } -} - -template -void trmmComputeW(matrix::Panel& w, matrix::Panel& v, - matrix::ReadOnlyTileSender tile_t) { - namespace ex = pika::execution::experimental; - - using pika::execution::thread_priority; - using pika::execution::thread_stacksize; - using namespace blas; - - auto it = w.iteratorLocal(); - - for (const auto& index_i : it) { - ex::start_detached(dlaf::internal::whenAllLift(Side::Right, Uplo::Upper, Op::NoTrans, Diag::NonUnit, - T(1), tile_t, v.read(index_i), w.readwrite(index_i)) | - tile::trmm3(dlaf::internal::Policy(thread_priority::high, - thread_stacksize::nostack))); - } - - if (it.empty()) { - ex::start_detached(std::move(tile_t)); - } -} - -template -void gemmUpdateX(matrix::Panel& x, matrix::Matrix& w2, - matrix::Panel& v) { - namespace ex = pika::execution::experimental; - - using pika::execution::thread_priority; - using pika::execution::thread_stacksize; - using namespace blas; - - // GEMM X = X - 0.5 . V . W2 - for (const auto& index_i : v.iteratorLocal()) - ex::start_detached( - dlaf::internal::whenAllLift(Op::NoTrans, Op::NoTrans, T(-0.5), v.read(index_i), - w2.read(LocalTileIndex(0, 0)), T(1), x.readwrite(index_i)) | - tile::gemm(dlaf::internal::Policy(thread_priority::high, thread_stacksize::nostack))); -} - template void hemmComputeX(matrix::Panel& x, const matrix::SubMatrixView& view, matrix::Matrix& a, matrix::Panel& w) { @@ -516,31 +116,6 @@ void hemmComputeX(matrix::Panel& x, const matrix::SubMatrixVie } } -template -void gemmComputeW2(matrix::Matrix& w2, matrix::Panel& w, - matrix::Panel& x) { - using pika::execution::thread_priority; - using pika::execution::thread_stacksize; - - namespace ex = pika::execution::experimental; - - // Note: - // Not all ranks in the column always hold at least a tile in the panel Ai, but all ranks in - // the column are going to participate to the reduce. For them, it is important to set the - // partial result W2 to zero. - ex::start_detached(w2.readwrite(LocalTileIndex(0, 0)) | - tile::set0(dlaf::internal::Policy(thread_priority::high, - thread_stacksize::nostack))); - - using namespace blas; - // GEMM W2 = W* . 
X - for (const auto& index_tile : w.iteratorLocal()) - ex::start_detached( - dlaf::internal::whenAllLift(Op::ConjTrans, Op::NoTrans, T(1), w.read(index_tile), - x.read(index_tile), T(1), w2.readwrite(LocalTileIndex(0, 0))) | - tile::gemm(dlaf::internal::Policy(thread_priority::high, thread_stacksize::nostack))); -} - template void her2kUpdateTrailingMatrix(const matrix::SubMatrixView& view, matrix::Matrix& a, matrix::Panel& x, @@ -859,10 +434,10 @@ template struct ComputePanelHelper { ComputePanelHelper(const std::size_t, matrix::Distribution) {} - void call(Matrix& mat_a, Matrix& mat_taus, const SizeType j_sub, + void call(Matrix& mat_a, matrix::ReadWriteTileSender tile_tau, const matrix::SubPanelView& panel_view) { using red2band::local::computePanelReflectors; - computePanelReflectors(mat_a, mat_taus, j_sub, panel_view); + computePanelReflectors(mat_a, std::move(tile_tau), panel_view); } template @@ -896,7 +471,7 @@ struct ComputePanelHelper { auto& v = panels_v.nextResource(); copyToCPU(panel_view, mat_a, v); - computePanelReflectors(v, mat_taus, j_sub, panel_view); + computePanelReflectors(v, mat_taus.readwrite(GlobalTileIndex(j_sub, 0)), panel_view); copyFromCPU(panel_view, v, mat_a); } @@ -1039,7 +614,7 @@ Matrix ReductionToBand::call(Matrix& mat_a, const v.setWidth(nrefls_tile); // PANEL - compute_panel_helper.call(mat_a, mat_taus_retiled, j_sub, panel_view); + compute_panel_helper.call(mat_a, mat_taus_retiled.readwrite(GlobalTileIndex(j_sub, 0)), panel_view); // Note: // - has_reflector_head tells if this rank owns the first tile of the panel (being local, always true) @@ -1321,7 +896,7 @@ Matrix ReductionToBand::call(comm::CommunicatorGrid& gr xt.setRangeStart(at_offset); xt.setHeight(nrefls_tile); - comm::broadcast(rank_v0.col(), x, xt, mpi_row_chain, mpi_col_chain); + comm::broadcast_all(rank_v0.col(), x, xt, mpi_row_chain, mpi_col_chain); // TRAILING MATRIX UPDATE @@ -1378,30 +953,6 @@ Matrix ReductionToBand::call(comm::CommunicatorGrid& gr dist.template nextLocalTileFromGlobalElement(at_offset.col()), }; - // Note: - // This additional communication of the last tile is a workaround for supporting following trigger - // when b < mb. - // Indeed, if b < mb the last column have (at least) a panel to compute, but differently from - // other columns, broadcast transposed doesn't communicate the last tile, which is an assumption - // needed to make the following trigger work correctly. 
- const SizeType at_tile_col = - dist.template globalTileFromGlobalElement(at_offset.col()); - - if (at_tile_col == dist.nrTiles().cols() - 1) { - const comm::IndexT_MPI owner = rank_v0.row(); - if (rank.row() == owner) { - xt.setTile(at, x.read(at)); - - if (dist.commGridSize().rows() > 1) - ex::start_detached(comm::schedule_bcast_send(mpi_col_chain.exclusive(), xt.read(at))); - } - else { - if (dist.commGridSize().rows() > 1) - ex::start_detached(comm::schedule_bcast_recv(mpi_col_chain.exclusive(), owner, - xt.readwrite(at))); - } - } - if constexpr (dlaf::comm::CommunicationDevice_v == D) { // Note: // if there is no need for additional buffers, we can just wait that xt[0] is ready for diff --git a/include/dlaf/factorization/qr/t_factor_impl.h b/include/dlaf/factorization/qr/t_factor_impl.h index 87ed14e845..49a9ac73b4 100644 --- a/include/dlaf/factorization/qr/t_factor_impl.h +++ b/include/dlaf/factorization/qr/t_factor_impl.h @@ -58,9 +58,9 @@ struct Helpers { std::forward(t)); } - template - static auto gemvColumnT(SizeType first_row_tile, VISender tile_vi, - matrix::ReadOnlyTileSender taus, TSender&& tile_t) { + static auto gemvColumnT(SizeType first_row_tile, matrix::ReadOnlyTileSender tile_vi, + matrix::ReadOnlyTileSender taus, + matrix::ReadWriteTileSender&& tile_t) { namespace ex = pika::execution::experimental; auto gemv_func = [first_row_tile](const auto& tile_v, const auto& taus, auto&& tile_t) noexcept { @@ -76,12 +76,13 @@ struct Helpers { // Position of the 1 in the diagonal in the current column. SizeType i_diag = j - first_row_tile; - const SizeType first_element_in_col = std::max(0, i_diag); // Break if the reflector starts in the next tile. if (i_diag >= tile_v.size().rows()) break; + const SizeType first_element_in_col = std::max(0, i_diag); + // T(0:j, j) = -tau . V(j:, 0:j)* . V(j:, j) // [j x 1] = [(n-j) x j]* . [(n-j) x 1] TileElementIndex va_start{first_element_in_col, 0}; @@ -93,20 +94,18 @@ struct Helpers { } blas::gemv(blas::Layout::ColMajor, blas::Op::ConjTrans, va_size.rows(), va_size.cols(), -tau, - tile_v.ptr(va_start), tile_v.ld(), tile_v.ptr(vb_start), 1, 1, tile_t.ptr(t_start), + tile_v.ptr(va_start), tile_v.ld(), tile_v.ptr(vb_start), 1, T(1), tile_t.ptr(t_start), 1); } return std::move(tile_t); }; return dlaf::internal::transform( dlaf::internal::Policy(pika::execution::thread_priority::high), - std::move(gemv_func), ex::when_all(tile_vi, std::move(taus), std::forward(tile_t))); + std::move(gemv_func), ex::when_all(tile_vi, std::move(taus), std::move(tile_t))); } template static auto trmvUpdateColumn(TSender&& tile_t) noexcept { - namespace ex = pika::execution::experimental; - // Update each column (in order) t = T . t // remember that T is upper triangular, so it is possible to use TRMV auto trmv_func = [](matrix::Tile&& tile_t) { @@ -121,6 +120,7 @@ struct Helpers { // TODO: Why return if the tile is unused? return std::move(tile_t); }; + return dlaf::internal::transform( dlaf::internal::Policy(pika::execution::thread_priority::high), std::move(trmv_func), std::forward(tile_t)); @@ -237,7 +237,8 @@ void QR_Tfactor::call(matrix::Panel& if (hh_panel.getWidth() == 0) return; - const auto v_start = hh_panel.offsetElement(); + const SizeType bs = hh_panel.parentDistribution().blockSize().rows(); + const SizeType offset_lc = (bs - hh_panel.tile_size_of_local_head().rows()); matrix::ReadWriteTileSender t_local = Helpers::set0(std::move(t)); @@ -261,7 +262,7 @@ void QR_Tfactor::call(matrix::Panel& // -tau(j) . V(j:, 0:j)* . 
V(j:, j)
   for (const auto& v_i : hh_panel.iteratorLocal()) {
     const SizeType first_row_tile =
-        std::max(0, v_i.row() * hh_panel.parentDistribution().blockSize().rows() - v_start);
+        std::max(0, (v_i.row() - hh_panel.rangeStartLocal()) * bs - offset_lc);
 
     // Note:
     // Since we are writing always on the same t, the gemv are serialized
diff --git a/include/dlaf/matrix/panel.h b/include/dlaf/matrix/panel.h
index 498ed60697..16b79476d0 100644
--- a/include/dlaf/matrix/panel.h
+++ b/include/dlaf/matrix/panel.h
@@ -304,6 +304,10 @@ struct Panel {
     has_been_used_ = false;
   }
 
+  TileElementSize tile_size_of_local_head() const {
+    return tileSize(LocalTileIndex(coord, rangeStartLocal()));
+  }
+
 protected:
   using ReadWriteSenderType = typename BaseT::ReadWriteSenderType;
diff --git a/miniapp/miniapp_reduction_to_band.cpp b/miniapp/miniapp_reduction_to_band.cpp
index 299af230b0..147aa073bd 100644
--- a/miniapp/miniapp_reduction_to_band.cpp
+++ b/miniapp/miniapp_reduction_to_band.cpp
@@ -47,6 +47,7 @@ struct Options
   SizeType m;
   SizeType mb;
   SizeType b;
+  bool use_ca_algo;
 #ifdef DLAF_WITH_HDF5
   std::filesystem::path input_file;
   std::string input_dataset;
@@ -55,7 +56,7 @@ struct Options
   Options(const pika::program_options::variables_map& vm)
       : MiniappOptions(vm), m(vm["matrix-size"].as<SizeType>()), mb(vm["block-size"].as<SizeType>()),
-        b(vm["band-size"].as<SizeType>()) {
+        b(vm["band-size"].as<SizeType>()), use_ca_algo(vm["ca"].as<bool>()) {
     DLAF_ASSERT(m > 0, m);
     DLAF_ASSERT(mb > 0, mb);
@@ -148,17 +149,20 @@ struct reductionToBandMiniapp {
     DLAF_MPI_CHECK_ERROR(MPI_Barrier(world));
     dlaf::common::Timer<> timeit;
-    auto bench = [&]() {
-      if (opts.local)
-        return dlaf::eigensolver::internal::reduction_to_band(matrix, opts.b);
-      else
-        return dlaf::eigensolver::internal::reduction_to_band(comm_grid, matrix, opts.b);
-    };
-    auto taus = bench();
+    if (opts.use_ca_algo) {
+      auto res =
+          dlaf::eigensolver::internal::ca_reduction_to_band(comm_grid, matrix, opts.b);
+      res.taus_1st.waitLocalTiles();
+      res.taus_2nd.waitLocalTiles();
+      res.hh_2nd.waitLocalTiles();
+    }
+    else {
+      auto taus = dlaf::eigensolver::internal::reduction_to_band(comm_grid, matrix, opts.b);
+      taus.waitLocalTiles();
+    }
 
     // wait and barrier for all ranks
     matrix.waitLocalTiles();
-    taus.waitLocalTiles();
     DLAF_MPI_CHECK_ERROR(MPI_Barrier(world));
     elapsed_time = timeit.elapsed();
@@ -244,9 +248,10 @@ int main(int argc, char** argv) {
   // clang-format off
   desc_commandline.add_options()
-    ("matrix-size", value<SizeType>() ->default_value(4096), "Matrix rows")
-    ("block-size",  value<SizeType>() ->default_value( 256), "Block cyclic distribution size")
-    ("band-size",   value<SizeType>() ->default_value(  -1), "Band size (a negative value implies band-size=block-size")
+    ("matrix-size", value<SizeType>() ->default_value(4096),  "Matrix rows")
+    ("block-size",  value<SizeType>() ->default_value( 256),  "Block cyclic distribution size")
+    ("band-size",   value<SizeType>() ->default_value(  -1),  "Band size (a negative value implies band-size=block-size)")
+    ("ca",          bool_switch()     ->default_value(false), "Use the communication-avoiding algorithm")
 #ifdef DLAF_WITH_HDF5
   ("input-file", value() , "Load matrix from given HDF5 file")
   ("output-file", value() , "Save band matrix to given HDF5 file")
diff --git a/src/eigensolver/reduction_to_band/gpu.cpp b/src/eigensolver/reduction_to_band/gpu.cpp
index 7e0cce7e77..e079f22fe1 100644
--- a/src/eigensolver/reduction_to_band/gpu.cpp
+++ b/src/eigensolver/reduction_to_band/gpu.cpp
@@ -8,6 +8,7 @@
 // SPDX-License-Identifier: BSD-3-Clause
 //
 
+#include <dlaf/eigensolver/reduction_to_band/ca-impl.h>
 #include <dlaf/eigensolver/reduction_to_band/impl.h>
 
 namespace dlaf::eigensolver::internal {
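For reference, this is how the new entry point is meant to be consumed, modeled on the miniapp changes above. A minimal sketch, assuming an initialized pika/DLAF runtime, a comm::CommunicatorGrid grid, and a block-cyclic distributed Matrix<double, Device::CPU> mat_a with square blocks; the explicit Backend::MC template argument is an assumption, not shown in the patch:

    // band_size must be >= 2 and divide the block size (see the asserts in
    // ca_reduction_to_band above).
    auto result =
        dlaf::eigensolver::internal::ca_reduction_to_band<dlaf::Backend::MC>(grid, mat_a, band_size);

    // The band and the 1st-stage reflectors are left in-place in mat_a; the returned
    // CARed2BandResult carries taus_1st, taus_2nd and the explicit 2nd-stage reflectors hh_2nd.
    result.taus_1st.waitLocalTiles();
    result.taus_2nd.waitLocalTiles();
    result.hh_2nd.waitLocalTiles();
    mat_a.waitLocalTiles();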
diff --git a/src/eigensolver/reduction_to_band/mc.cpp b/src/eigensolver/reduction_to_band/mc.cpp
index 93c8be8806..5b3a041356 100644
--- a/src/eigensolver/reduction_to_band/mc.cpp
+++ b/src/eigensolver/reduction_to_band/mc.cpp
@@ -8,6 +8,7 @@
 // SPDX-License-Identifier: BSD-3-Clause
 //
 
+#include <dlaf/eigensolver/reduction_to_band/ca-impl.h>
 #include <dlaf/eigensolver/reduction_to_band/impl.h>
 
 namespace dlaf::eigensolver::internal {
diff --git a/test/unit/eigensolver/test_reduction_to_band.cpp b/test/unit/eigensolver/test_reduction_to_band.cpp
index c4bee9db09..0a9dbae12b 100644
--- a/test/unit/eigensolver/test_reduction_to_band.cpp
+++ b/test/unit/eigensolver/test_reduction_to_band.cpp
@@ -8,12 +8,17 @@
 // SPDX-License-Identifier: BSD-3-Clause
 //
 
+#include
 #include
+#include
+#include
+#include
 
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -24,6 +29,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -84,10 +90,14 @@
 std::vector configs{
     {{0, 0}, {3, 3}},    // full-tile band
     {{3, 3}, {3, 3}},    // single tile (nothing to do)
+    {{6, 6}, {3, 3}},    // tile always full size (less room for distribution over ranks)
+    {{9, 9}, {3, 3}},    // tile always full size (less room for distribution over ranks)
     {{12, 12}, {3, 3}},  // tile always full size (less room for distribution over ranks)
-    {{13, 13}, {3, 3}},  // tile incomplete
     {{24, 24}, {3, 3}},  // tile always full size (more room for distribution)
     {{40, 40}, {5, 5}},
+    // tile incomplete
+    {{8, 8}, {3, 3}},
+    {{13, 13}, {3, 3}},
 };
 
 std::vector configs_subband{
@@ -489,3 +499,283 @@ TYPED_TEST(ReductionToBandTestGPU, CorrectnessDistributedSubBand) {
   }
 }
 #endif
+
+template <class T>
+struct CAReductionToBandTest : public TestWithCommGrids {};
+
+template <class T>
+using CAReductionToBandTestMC = ReductionToBandTest<T>;
+
+TYPED_TEST_SUITE(CAReductionToBandTestMC, MatrixElementTypes);
+
+template <class T>
+MatrixLocal<const T> allGatherT(Matrix<const T, Device::CPU>& source,
+                                comm::CommunicatorGrid& comm_grid) {
+  // TODO transposed distribution
+  // DLAF_ASSERT(matrix::equal_process_grid(source, comm_grid), source, comm_grid);
+
+  namespace tt = pika::this_thread::experimental;
+
+  MatrixLocal<std::remove_const_t<T>> dest(source.size(), source.baseTileSize());
+
+  const auto& dist_source = source.distribution();
+  const auto rank = transposed(dist_source.rank_index());
+
+  for (const auto& ij : iterate_range2d(dist_source.nr_tiles())) {
+    const comm::Index2D owner = transposed(dist_source.rank_global_tile(ij));
+
+    auto& dest_tile = dest.tile(ij);
+
+    if (owner == rank) {
+      const auto source_tile_holder = tt::sync_wait(source.read(ij));
+      const auto& source_tile = source_tile_holder.get();
+      comm::sync::broadcast::send(comm_grid.fullCommunicator(), source_tile);
+      matrix::internal::copy(source_tile, dest_tile);
+    }
+    else {
+      comm::sync::broadcast::receive_from(comm_grid.rankFullCommunicator(owner),
+                                          comm_grid.fullCommunicator(), dest_tile);
+    }
+  }
+
+  return MatrixLocal<const T>(std::move(dest));
+}
+
+template <class T>
+auto checkResult(const Distribution dist, const SizeType band_size,
+                 Matrix<const T, Device::CPU>& reference, const MatrixLocal<T>& mat_b,
+                 const MatrixLocal<T>& mat_hh_1st, const MatrixLocal<const T>& taus_1st,
+                 const MatrixLocal<T>& mat_hh_2nd, const std::vector<T>& taus_2nd) {
+  const GlobalElementIndex offset(band_size, 0);
+  // Now that all inputs are collected locally, it's time to apply the transformation,
+  // ...but only if there is any.
+  if (offset.isIn(mat_hh_1st.size())) {
+    dlaf::common::internal::SingleThreadedBlasScope single;
+
+    const SizeType ntiles = mat_b.nrTiles().cols() - 1;
+
+    // Apply in reverse order (blocked algorithm): from the last step to the first, and with the
+    // order inside each step inverted too, i.e. 2nd stage first and 1st stage last.
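+    //
+    // Illustrative note (hedged): each step j of the factorization applies
+    //   A_j = (Q1_j Q2_j)^H A_{j-1} (Q1_j Q2_j)
+    // so B = Q^H A Q with Q = (Q1_0 Q2_0) ... (Q1_n Q2_n). Reconstructing A = Q B Q^H
+    // expands innermost-first: the 2nd-stage factors of the last step are applied first,
+    // then its 1st-stage ones, proceeding back towards the first step.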
+    for (SizeType j = ntiles - 1; j >= 0; --j) {
+      const SizeType i = j + 1;
+      const SizeType i_el =
+          dist.template global_element_from_global_tile_and_tile_element<Coord::Row>(i, 0);
+
+      const std::size_t nranks_with_data =
+          to_sizet(std::min(mat_b.nrTiles().rows() - i, dist.grid_size().rows()));
+
+      // === 2nd pass
+      // prepare workspace (height = max(nranks)) + reorder heads
+      std::vector<comm::IndexT_MPI> col_rank_order(nranks_with_data, -1);
+      const comm::IndexT_MPI first_rank = dist.template rank_global_tile<Coord::Row>(i);
+      std::iota(col_rank_order.begin(), col_rank_order.end(), first_rank);
+      std::transform(col_rank_order.begin(), col_rank_order.end(), col_rank_order.begin(),
+                     [size = dist.grid_size().rows()](const comm::IndexT_MPI& value) {
+                       return std::modulus{}(value, size);
+                     });
+
+      // HH2
+      const matrix::Distribution dist_hh_2nd = [&]() {
+        using matrix::internal::distribution::global_tile_element_distance;
+        const SizeType i_begin = i;
+        const SizeType i_end = std::min(i + dist.grid_size().rows(), dist.nr_tiles().rows());
+        const SizeType nrows = global_tile_element_distance<Coord::Row>(dist, i_begin, i_end);
+        return matrix::Distribution({nrows, mat_b.blockSize().cols()}, mat_b.blockSize());
+      }();
+      MatrixLocal<T> hh_2nd(dist_hh_2nd.size(), dist_hh_2nd.block_size());
+
+      const SizeType nrefls = [&]() {
+        const SizeType reflector_size = hh_2nd.size().rows();
+        return std::min(hh_2nd.size().cols(), reflector_size - 1);
+      }();
+
+      if (nrefls > 0) {
+        for (SizeType i = 0; i < to_SizeType(col_rank_order.size()); ++i) {
+          const SizeType ii = to_SizeType(col_rank_order[to_sizet(i)]);
+
+          const bool is_last = i == (to_SizeType(col_rank_order.size()) - 1);
+          const SizeType last_rows = hh_2nd.size().rows() % dist.block_size().rows();
+          if (!is_last || last_rows == 0) {
+            matrix::internal::copy(mat_hh_2nd.tile({ii, j}), hh_2nd.tile({i, 0}));
+          }
+          else {
+            const auto& tile = mat_hh_2nd.tile({ii, j}).subTileReference(
+                {{0, 0}, {last_rows, dist.block_size().cols()}});
+            matrix::internal::copy(tile, hh_2nd.tile({i, 0}));
+          }
+        }
+
+        // T2
+        const SizeType j_el =
+            dist.template global_element_from_global_tile_and_tile_element<Coord::Col>(j, 0);
+
+        MatrixLocal<T> T_2nd({nrefls, nrefls}, mat_b.blockSize());
+        lapack::larft(lapack::Direction::Forward, lapack::StoreV::Columnwise, hh_2nd.size().rows(),
+                      nrefls, hh_2nd.ptr(), hh_2nd.ld(), taus_2nd.data() + j_el, T_2nd.ptr(),
+                      T_2nd.ld());
+
+        // Apply HH2 from L and R
+        lapack::larfb(lapack::Side::Left, lapack::Op::NoTrans, lapack::Direction::Forward,
+                      lapack::StoreV::Columnwise, hh_2nd.size().rows(), mat_b.size().cols() - j_el,
+                      nrefls, hh_2nd.ptr(), hh_2nd.ld(), T_2nd.ptr(), T_2nd.ld(),
+                      mat_b.ptr({i_el, j_el}), mat_b.ld());
+        lapack::larfb(lapack::Side::Right, lapack::Op::ConjTrans, lapack::Direction::Forward,
+                      lapack::StoreV::Columnwise, mat_b.size().rows() - j_el, hh_2nd.size().rows(),
+                      nrefls, hh_2nd.ptr(), hh_2nd.ld(), T_2nd.ptr(), T_2nd.ld(),
+                      mat_b.ptr({j_el, i_el}), mat_b.ld());
+      }
+
+      // === 1st pass
+      // HH1 (for all ranks)
+      // prepare workspace (height = local matrix for each rank) with zeros to fill the voids
+      // Note: the HH_1st workspace is stored as a matrix where each column of tiles belongs to a
+      // specific rank.
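+      //
+      // Illustrative layout (hedged), e.g. 3 ranks owning row tiles round-robin from tile i:
+      //
+      //              rank r0    rank r1    rank r2
+      //   tile i     [ HH1  ]   [  0   ]   [  0   ]
+      //   tile i+1   [  0   ]   [ HH1  ]   [  0   ]
+      //   tile i+2   [  0   ]   [  0   ]   [ HH1  ]
+      //   tile i+3   [ HH1  ]   [  0   ]   [  0   ]
+      //
+      // Each column collects the 1st-stage reflectors computed by one rank, and the zero tiles
+      // pad the voids so that a single larft/larfb pair per rank can be applied below.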
+      const matrix::Distribution dist_hh_1st({dist.size().rows() - i_el,
+                                              to_SizeType(nranks_with_data) * dist.tile_size().cols()},
+                                             dist.tile_size());
+      MatrixLocal<T> hh_1st(dist_hh_1st.size(), dist_hh_1st.tile_size());
+
+      std::size_t col_rank_current = 0;
+      for (SizeType i_a = i; i_a < dist.nr_tiles().rows(); ++i_a, ++col_rank_current) {
+        col_rank_current %= col_rank_order.size();
+
+        const SizeType i_hh = i_a - i;
+
+        for (SizeType j_hh = 0; j_hh < hh_1st.nrTiles().cols(); ++j_hh) {
+          const auto& tile_hh = hh_1st.tile({i_hh, j_hh});
+
+          if (j_hh == to_SizeType(col_rank_current)) {
+            dlaf::matrix::internal::copy(mat_hh_1st.tile({i_a, j}), tile_hh);
+          }
+          else {
+            dlaf::tile::internal::set0(tile_hh);
+          }
+        }
+      }
+
+      // Note: well-formed heads
+      for (SizeType j = 0; j < hh_1st.nrTiles().cols(); ++j) {
+        const auto& tile_hh = hh_1st.tile({j, j});
+        dlaf::tile::internal::laset(blas::Uplo::Upper, T(0), T(1), tile_hh);
+      }
+
+      // Note: apply one HH1 at a time, independently; the order is not relevant
+      for (SizeType col_rank = 0; col_rank < to_SizeType(col_rank_order.size()); ++col_rank) {
+        const SizeType rank = to_SizeType(col_rank_order[to_sizet(col_rank)]);
+
+        const SizeType i_begin = col_rank;
+        const SizeType i_end_gap = dist_hh_1st.nr_tiles().rows();
+        const SizeType i_end = i_begin + dlaf::util::ceilDiv(dist_hh_1st.nr_tiles().rows() - i_begin,
+                                                             to_SizeType(col_rank_order.size()));
+        using matrix::internal::distribution::global_tile_element_distance;
+
+        const SizeType refl_size = global_tile_element_distance<Coord::Row>(dist_hh_1st, i_begin, i_end);
+
+        const auto& hh_1st_head = hh_1st.tile({col_rank, col_rank});
+        const SizeType nrefls = std::min(refl_size - 1, hh_1st_head.size().cols());
+
+        if (nrefls <= 0)
+          continue;
+
+        // Compute T1
+        const auto& tile_taus = taus_1st.tile({j, rank});
+
+        const SizeType refl_size_gap =
+            global_tile_element_distance<Coord::Row>(dist_hh_1st, i_begin, i_end_gap);
+
+        MatrixLocal<T> T_1st({nrefls, nrefls}, mat_b.blockSize());
+        lapack::larft(lapack::Direction::Forward, lapack::StoreV::Columnwise, refl_size_gap, nrefls,
+                      hh_1st_head.ptr(), hh_1st_head.ld(), tile_taus.ptr(), T_1st.ptr(), T_1st.ld());
+
+        // Apply HH1 (of a rank) from L and R
+        {
+          const SizeType m = dist.size().rows() - (i + col_rank) * dist.tile_size().rows();
+          const SizeType n = dist.size().cols() - j * dist.tile_size().cols();
+          lapack::larfb(lapack::Side::Left, lapack::Op::NoTrans, lapack::Direction::Forward,
+                        lapack::StoreV::Columnwise, m, n, nrefls, hh_1st_head.ptr(), hh_1st.ld(),
+                        T_1st.ptr(), T_1st.ld(), mat_b.tile({i + col_rank, j}).ptr(), mat_b.ld());
+        }
+        {
+          const SizeType m = dist.size().rows() - j * dist.tile_size().rows();
+          const SizeType n = dist.size().cols() - (i + col_rank) * dist.tile_size().cols();
+          lapack::larfb(lapack::Side::Right, lapack::Op::ConjTrans, lapack::Direction::Forward,
+                        lapack::StoreV::Columnwise, m, n, nrefls, hh_1st_head.ptr(), hh_1st.ld(),
+                        T_1st.ptr(), T_1st.ld(), mat_b.tile({j, i + col_rank}).ptr(), mat_b.ld());
+        }
+      }
+    }
+  }
+
+  // Finally, check that the result obtained by applying the inverse transformation equals the
+  // original matrix.
+  auto result = [&dist = reference.distribution(),
+                 &mat_local = mat_b](const GlobalElementIndex& element) {
+    const auto tile_index = dist.globalTileIndex(element);
+    const auto tile_element = dist.tileElementIndex(element);
+    return mat_local.tile_read(tile_index)(tile_element);
+  };
+
+  CHECK_MATRIX_NEAR(result, reference, 0,
+                    std::max(1, mat_b.size().linear_size()) * TypeUtilities<T>::error);
+}
+
+template <class T, Backend B, Device D>
+void
testCAReductionToBand(comm::CommunicatorGrid& grid, const LocalElementSize size, + const TileElementSize block_size, const SizeType band_size, + const InputMatrixStructure input_matrix_structure) { + const SizeType k_reflectors = std::max(SizeType(0), size.rows() - band_size - 1); + DLAF_ASSERT(block_size.rows() % band_size == 0, block_size.rows(), band_size); + + const Distribution dist({size.rows(), size.cols()}, block_size, grid.size(), grid.rank(), {0, 0}); + + // setup the reference input matrix + Matrix reference = [&]() { + Matrix reference(dist); + if (input_matrix_structure == InputMatrixStructure::banded) + // Matrix already in band form, with band smaller than band_size + matrix::util::set_random_hermitian_banded(reference, band_size - 1); + else + matrix::util::set_random_hermitian(reference); + return reference; + }(); + + Matrix matrix_a_h(dist); + copy(reference, matrix_a_h); + + eigensolver::internal::CARed2BandResult red2band_result = [&]() { + MatrixMirror matrix_a(matrix_a_h); + return eigensolver::internal::ca_reduction_to_band(grid, matrix_a.get(), band_size); + }(); + + ASSERT_EQ(red2band_result.taus_1st.block_size().rows(), block_size.rows()); + ASSERT_EQ(red2band_result.taus_2nd.block_size().rows(), block_size.rows()); + + checkUpperPartUnchanged(reference, matrix_a_h); + + // Wait for all work to finish before doing blocking communication + pika::wait(); + + auto mat_hh_1st = allGather(blas::Uplo::Lower, matrix_a_h, grid); + + auto taus_1st = allGatherT(red2band_result.taus_1st, grid); + ASSERT_EQ(taus_1st.size().rows(), k_reflectors); + ASSERT_EQ(taus_1st.size().cols(), grid.size().rows()); + + auto mat_hh_2nd = allGather(blas::Uplo::General, red2band_result.hh_2nd, grid); + + auto taus_2nd = allGatherTaus(k_reflectors, red2band_result.taus_2nd, grid); + ASSERT_EQ(taus_2nd.size(), k_reflectors); + + auto mat_band = makeLocal(matrix_a_h); + splitReflectorsAndBand(mat_hh_1st, mat_band, band_size); + + checkResult(dist, band_size, reference, mat_band, mat_hh_1st, taus_1st, mat_hh_2nd, taus_2nd); +} + +TYPED_TEST(CAReductionToBandTestMC, CorrectnessDistributed) { + for (auto&& comm_grid : this->commGrids()) { + for (const auto& [size, block_size, band_size] : configs) { + for (auto input_matrix_structure : {InputMatrixStructure::full}) { + testCAReductionToBand(comm_grid, size, block_size, + band_size, input_matrix_structure); + } + } + } +} diff --git a/test/unit/matrix/test_layout_info.cpp b/test/unit/matrix/test_layout_info.cpp index 935507853b..e160ca80dc 100644 --- a/test/unit/matrix/test_layout_info.cpp +++ b/test/unit/matrix/test_layout_info.cpp @@ -20,6 +20,7 @@ using namespace dlaf; using namespace testing; +// size, block_size, ld, row_offset, col_offset, min_memory const std::vector> values({{{31, 17}, {7, 11}, 31, 7, 341, 527}, // Scalapack like layout {{31, 17}, {32, 11}, 31, 31, 341, 527}, // only one row of tiles @@ -53,12 +54,7 @@ const std::vector(v); - auto block_size = std::get<1>(v); - auto ld = std::get<2>(v); - auto row_offset = std::get<3>(v); - auto col_offset = std::get<4>(v); - auto min_memory = std::get<5>(v); + auto [size, block_size, ld, row_offset, col_offset, min_memory] = v; matrix::LayoutInfo layout(size, block_size, ld, row_offset, col_offset); @@ -102,12 +98,7 @@ TEST(LayoutInfoTest, ComparisonOperator) { matrix::LayoutInfo layout0({25, 25}, {5, 5}, 50, 8, 1000); for (const auto& v : comp_values) { - auto size = std::get<0>(v); - auto block_size = std::get<1>(v); - auto ld = std::get<2>(v); - auto row_offset = std::get<3>(v); - auto 
col_offset = std::get<4>(v); - auto is_equal = std::get<5>(v); + auto [size, block_size, ld, row_offset, col_offset, is_equal] = v; matrix::LayoutInfo layout(size, block_size, ld, row_offset, col_offset); @@ -135,12 +126,7 @@ const std::vector(v); - auto block_size = std::get<1>(v); - auto ld = std::get<2>(v); - auto row_offset = std::get<3>(v); - auto col_offset = std::get<4>(v); - auto min_memory = std::get<5>(v); + auto [size, block_size, ld, row_offset, col_offset, min_memory] = v; matrix::LayoutInfo exp_layout(size, block_size, ld, row_offset, col_offset); matrix::LayoutInfo layout = colMajorLayout(size, block_size, ld);
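For reference, the correctness checks in test_reduction_to_band.cpp above repeatedly pair lapack::larft with lapack::larfb to apply block reflectors. A minimal self-contained sketch of that pairing (an illustrative helper, not part of the patch; assumes lapackpp and double precision):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    #include <lapack.hh>

    // Applies the block reflector Q = I - V T V^H to C (m x n) from the left, where the
    // k reflectors are stored columnwise in V (m x k) and taus holds their scaling factors.
    void apply_block_reflector(std::int64_t m, std::int64_t n, std::int64_t k, const double* V,
                               std::int64_t ldv, const double* taus, double* C, std::int64_t ldc) {
      // Form the k x k triangular factor T of the compact WY representation.
      std::vector<double> T(static_cast<std::size_t>(k * k));
      lapack::larft(lapack::Direction::Forward, lapack::StoreV::Columnwise, m, k, V, ldv, taus,
                    T.data(), k);
      // C = (I - V T V^H) C
      lapack::larfb(lapack::Side::Left, lapack::Op::NoTrans, lapack::Direction::Forward,
                    lapack::StoreV::Columnwise, m, n, k, V, ldv, T.data(), k, C, ldc);
    }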