Skip to content

Commit

Permalink
first usage of extra buffers for W2 in red2band
Browse files Browse the repository at this point in the history
  • Loading branch information
albestro committed Feb 17, 2023
1 parent 3fd2578 commit 7bec672
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions include/dlaf/eigensolver/reduction_to_band/impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "dlaf/lapack/tile.h"
#include "dlaf/matrix/copy_tile.h"
#include "dlaf/matrix/distribution.h"
#include "dlaf/matrix/extra_buffers.h"
#include "dlaf/matrix/index.h"
#include "dlaf/matrix/matrix.h"
#include "dlaf/matrix/panel.h"
Expand Down Expand Up @@ -455,20 +456,25 @@ void gemmComputeW2(matrix::Matrix<T, D>& w2, matrix::Panel<Coord::Col, const T,

namespace ex = pika::execution::experimental;

// Note:
// Not all ranks in the column always hold at least a tile in the panel Ai, but all ranks in
// the column are going to participate to the reduce. For them, it is important to set the
// partial result W2 to zero.
ex::start_detached(w2.readwrite_sender(LocalTileIndex(0, 0)) |
tile::set0(dlaf::internal::Policy<B>(thread_priority::high)));
ExtraBuffers<T, D> buffers(w2.blockSize(), 6);

//// Note:
//// Not all ranks in the column always hold at least a tile in the panel Ai, but all ranks in
//// the column are going to participate to the reduce. For them, it is important to set the
//// partial result W2 to zero.
// ex::start_detached(w2.readwrite_sender(LocalTileIndex(0, 0)) |
// tile::set0(dlaf::internal::Policy<B>(thread_priority::high)));

using namespace blas;
// GEMM W2 = W* . X
for (const auto& index_tile : w.iteratorLocal())
for (const auto& index_tile : w.iteratorLocal()) {
ex::start_detached(dlaf::internal::whenAllLift(Op::ConjTrans, Op::NoTrans, T(1),
w.read_sender(index_tile), x.read_sender(index_tile),
T(1), w2.readwrite_sender(LocalTileIndex(0, 0))) |
T(1), buffers.readwrite_sender(index_tile.row())) |
tile::gemm(dlaf::internal::Policy<B>(thread_priority::high)));
}

ex::start_detached(buffers.reduce(w2.readwrite_sender(LocalTileIndex(0, 0))));
}

template <Backend B, Device D, class T>
Expand Down Expand Up @@ -959,7 +965,7 @@ common::internal::vector<pika::shared_future<common::internal::vector<T>>> Reduc
const LocalTileIndex t_idx(0, 0);
// TODO used just by the column, maybe we can re-use a panel tile?
// TODO probably the first one in any panel is ok?
Matrix<T, D> t({nrefls_block, nrefls_block}, dist.blockSize());
Matrix<T, D> t({nrefls_block, nrefls_block}, {nrefls_block, nrefls_block});

computeTFactor<B>(v, taus.back(), t.readwrite_sender(t_idx));

Expand Down Expand Up @@ -1107,7 +1113,7 @@ common::internal::vector<pika::shared_future<common::internal::vector<T>>> Reduc
const LocalTileIndex t_idx(0, 0);
// TODO used just by the column, maybe we can re-use a panel tile?
// TODO or we can keep just the sh_future and allocate just inside if (is_panel_rank_col)
matrix::Matrix<T, D> t({nrefls_block, nrefls_block}, dist.blockSize());
matrix::Matrix<T, D> t({nrefls_block, nrefls_block}, {nrefls_block, nrefls_block});

// PANEL
const matrix::SubPanelView panel_view(dist, ij_offset, band_size);
Expand Down

0 comments on commit 7bec672

Please sign in to comment.