diff --git a/include/dlaf/eigensolver/reduction_to_band/impl.h b/include/dlaf/eigensolver/reduction_to_band/impl.h index 7e03220b8a..897cc16ddd 100644 --- a/include/dlaf/eigensolver/reduction_to_band/impl.h +++ b/include/dlaf/eigensolver/reduction_to_band/impl.h @@ -474,6 +474,7 @@ void gemmComputeW2(matrix::Matrix& w2, matrix::Panel(thread_priority::high))); } + ex::start_detached(tile::set0(dlaf::internal::Policy(), w2.readwrite_sender(LocalTileIndex(0, 0)))); ex::start_detached(buffers.reduce(w2.readwrite_sender(LocalTileIndex(0, 0)))); } diff --git a/include/dlaf/matrix/extra_buffers.h b/include/dlaf/matrix/extra_buffers.h index 55cd9c8dc9..f546e38dbb 100644 --- a/include/dlaf/matrix/extra_buffers.h +++ b/include/dlaf/matrix/extra_buffers.h @@ -29,29 +29,51 @@ struct ExtraBuffers : protected Matrix { pika::execution::thread_priority::high))); } + auto read_sender(SizeType index) { + return Matrix::read_sender(internalIndex(index)); + } + auto readwrite_sender(SizeType index) { - index %= nbuffers_; - return Matrix::readwrite_sender(LocalTileIndex{index, 0}); + return Matrix::readwrite_sender(internalIndex(index)); } template [[nodiscard]] auto reduce(TileSender tile) { namespace ex = pika::execution::experimental; - std::vector>> buffers; - for (const auto& ij : common::iterate_range2d(this->distribution().localNrTiles())) - buffers.emplace_back(Matrix::operator()(ij)); - auto all_buffers = ex::when_all_vector(std::move(buffers)); - - return ex::when_all(std::move(tile), std::move(all_buffers)) | - ex::then([](const matrix::Tile& tile, const std::vector>& buffers) { - tile::internal::set0(tile); - for (auto& buffer : buffers) - dlaf::tile::internal::add(T(1), buffer, tile); - }); + std::vector>>> buffers; + for (SizeType index = 0; index < nbuffers_; ++index) + buffers.emplace_back(read_sender(index)); + + return ex::when_all(std::move(tile), ex::when_all_vector(std::move(buffers))) | + dlaf::internal::transform(dlaf::internal::Policy>(), + [](const matrix::Tile& tile, + const std::vector>>& + buffers, + auto&&... ts) { + for (const auto& buffer : buffers) { + if constexpr (D == Device::CPU) { + static_assert(sizeof...(ts) == 0, + "Parameter pack should be empty for MC."); + dlaf::tile::internal::add(T(1), buffer.get(), tile); + } +#ifdef DLAF_WITH_GPU + else if constexpr (D == Device::GPU) { + dlaf::tile::internal::add(T(1), buffer.get(), tile, ts...); + } +#endif + else { + DLAF_STATIC_UNIMPLEMENTED(T); + } + } + }); } protected: + LocalTileIndex internalIndex(SizeType index) const noexcept { + return LocalTileIndex{index % nbuffers_, 0}; + } + SizeType nbuffers_; }; }