Skip to content

Commit

Permalink
fix gpu and do not implicitly set0 the result tile
Browse files Browse the repository at this point in the history
not yet as continuation
  • Loading branch information
albestro committed Feb 22, 2023
1 parent 1c06b55 commit f2d49df
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 13 deletions.
1 change: 1 addition & 0 deletions include/dlaf/eigensolver/reduction_to_band/impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,7 @@ void gemmComputeW2(matrix::Matrix<T, D>& w2, matrix::Panel<Coord::Col, const T,
tile::gemm(dlaf::internal::Policy<B>(thread_priority::high)));
}

ex::start_detached(tile::set0(dlaf::internal::Policy<B>(), w2.readwrite_sender(LocalTileIndex(0, 0))));
ex::start_detached(buffers.reduce(w2.readwrite_sender(LocalTileIndex(0, 0))));
}

Expand Down
48 changes: 35 additions & 13 deletions include/dlaf/matrix/extra_buffers.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,29 +29,51 @@ struct ExtraBuffers : protected Matrix<T, D> {
pika::execution::thread_priority::high)));
}

auto read_sender(SizeType index) {
return Matrix<T, D>::read_sender(internalIndex(index));
}

auto readwrite_sender(SizeType index) {
index %= nbuffers_;
return Matrix<T, D>::readwrite_sender(LocalTileIndex{index, 0});
return Matrix<T, D>::readwrite_sender(internalIndex(index));
}

template <class TileSender>
[[nodiscard]] auto reduce(TileSender tile) {
namespace ex = pika::execution::experimental;

std::vector<pika::future<matrix::Tile<T, D>>> buffers;
for (const auto& ij : common::iterate_range2d(this->distribution().localNrTiles()))
buffers.emplace_back(Matrix<T, D>::operator()(ij));
auto all_buffers = ex::when_all_vector(std::move(buffers));

return ex::when_all(std::move(tile), std::move(all_buffers)) |
ex::then([](const matrix::Tile<T, D>& tile, const std::vector<matrix::Tile<T, D>>& buffers) {
tile::internal::set0(tile);
for (auto& buffer : buffers)
dlaf::tile::internal::add(T(1), buffer, tile);
});
std::vector<ex::any_sender<pika::shared_future<matrix::Tile<const T, D>>>> buffers;
for (SizeType index = 0; index < nbuffers_; ++index)
buffers.emplace_back(read_sender(index));

return ex::when_all(std::move(tile), ex::when_all_vector(std::move(buffers))) |
dlaf::internal::transform(dlaf::internal::Policy<DefaultBackend_v<D>>(),
[](const matrix::Tile<T, D>& tile,
const std::vector<pika::shared_future<matrix::Tile<const T, D>>>&
buffers,
auto&&... ts) {
for (const auto& buffer : buffers) {
if constexpr (D == Device::CPU) {
static_assert(sizeof...(ts) == 0,
"Parameter pack should be empty for MC.");
dlaf::tile::internal::add(T(1), buffer.get(), tile);
}
#ifdef DLAF_WITH_GPU
else if constexpr (D == Device::GPU) {
dlaf::tile::internal::add(T(1), buffer.get(), tile, ts...);
}
#endif
else {
DLAF_STATIC_UNIMPLEMENTED(T);
}
}
});
}

protected:
LocalTileIndex internalIndex(SizeType index) const noexcept {
return LocalTileIndex{index % nbuffers_, 0};
}

SizeType nbuffers_;
};
}

0 comments on commit f2d49df

Please sign in to comment.