diff --git a/core/base/batch_multi_vector.cpp b/core/base/batch_multi_vector.cpp index f17f1479f5f..6796b77e932 100644 --- a/core/base/batch_multi_vector.cpp +++ b/core/base/batch_multi_vector.cpp @@ -85,6 +85,38 @@ batch_dim<2> compute_batch_size( } // namespace detail +template +std::unique_ptr> +MultiVector::create_view_for_item(size_type item_id) +{ + auto exec = this->get_executor(); + auto num_rows = this->get_common_size()[0]; + auto stride = this->get_common_size()[1]; + auto mat = unbatch_type::create( + exec, this->get_common_size(), + make_array_view(exec, num_rows * stride, + this->get_values_for_item(item_id)), + stride); + return mat; +} + + +template +std::unique_ptr> +MultiVector::create_const_view_for_item(size_type item_id) const +{ + auto exec = this->get_executor(); + auto num_rows = this->get_common_size()[0]; + auto stride = this->get_common_size()[1]; + auto mat = unbatch_type::create_const( + exec, this->get_common_size(), + make_const_array_view(exec, num_rows * stride, + this->get_const_values_for_item(item_id)), + stride); + return mat; +} + + template MultiVector::MultiVector(std::shared_ptr exec, const batch_dim<2>& size) @@ -164,18 +196,13 @@ template std::vector>> MultiVector::unbatch() const { - using unbatch_type = matrix::Dense; auto exec = this->get_executor(); - auto unbatch_mats = std::vector>{}; + auto unbatched_mats = std::vector>{}; for (size_type b = 0; b < this->get_num_batch_items(); ++b) { - auto mat = unbatch_type::create(exec, this->get_common_size()); - exec->copy_from(exec.get(), mat->get_num_stored_elements(), - this->get_const_values() + - this->get_size().get_cumulative_offset(b), - mat->get_values()); - unbatch_mats.emplace_back(std::move(mat)); + unbatched_mats.emplace_back( + this->create_const_view_for_item(b)->clone()); } - return unbatch_mats; + return unbatched_mats; } @@ -336,19 +363,15 @@ void read_impl(MatrixType* mtx, const std::vector& data) GKO_THROW_IF_INVALID(data.size() > 0, "Input data is empty"); auto common_size = data[0].size; - auto batch_size = batch_dim<2>(data.size(), common_size); - for (const auto& b : data) { - auto b_size = b.size; - GKO_ASSERT_EQUAL_DIMENSIONS(common_size, b_size); - } + auto num_batch_items = data.size(); + auto batch_size = batch_dim<2>(num_batch_items, common_size); auto tmp = MatrixType::create(mtx->get_executor()->get_master(), batch_size); - tmp->fill(zero()); - for (size_type b = 0; b < data.size(); ++b) { - for (const auto& elem : data[b].nonzeros) { - tmp->at(b, elem.row, elem.column) = elem.value; - } + for (size_type b = 0; b < num_batch_items; ++b) { + assert(data[b].size() == common_size); + tmp->create_view_for_item(b)->read(data[b]); } + tmp->move_to(mtx); } @@ -370,20 +393,10 @@ void MultiVector::read(const std::vector& data) template void write_impl(const MatrixType* mtx, std::vector& data) { - auto tmp = make_temporary_clone(mtx->get_executor()->get_master(), mtx); - data = std::vector(mtx->get_num_batch_items()); for (size_type b = 0; b < mtx->get_num_batch_items(); ++b) { data[b] = {mtx->get_common_size(), {}}; - for (size_type row = 0; row < data[b].size[0]; ++row) { - for (size_type col = 0; col < data[b].size[1]; ++col) { - if (tmp->at(b, row, col) != - zero()) { - data[b].nonzeros.emplace_back(row, col, - tmp->at(b, row, col)); - } - } - } + mtx->create_const_view_for_item(b)->write(data[b]); } } diff --git a/core/test/base/batch_multi_vector.cpp b/core/test/base/batch_multi_vector.cpp index e87cedca913..055c2b899d0 100644 --- a/core/test/base/batch_multi_vector.cpp +++ b/core/test/base/batch_multi_vector.cpp @@ -55,7 +55,9 @@ class MultiVector : public ::testing::Test { mtx(gko::batch::initialize>( {{{-1.0, 2.0, 3.0}, {-1.5, 2.5, 3.5}}, {{1.0, 2.5, 3.0}, {1.0, 2.0, 3.0}}}, - exec)) + exec)), + dense_mtx(gko::initialize>( + {{1.0, 2.5, 3.0}, {1.0, 2.0, 3.0}}, exec)) {} @@ -89,6 +91,7 @@ class MultiVector : public ::testing::Test { std::shared_ptr exec; std::unique_ptr> mtx; + std::unique_ptr> dense_mtx; }; TYPED_TEST_SUITE(MultiVector, gko::test::ValueTypes); @@ -118,6 +121,13 @@ TYPED_TEST(MultiVector, CanGetValuesForEntry) } +TYPED_TEST(MultiVector, CanCreateDenseItemView) +{ + GKO_ASSERT_MTX_NEAR(this->mtx->create_view_for_item(1), this->dense_mtx, + 0.0); +} + + TYPED_TEST(MultiVector, CanBeCopied) { auto mtx_copy = gko::batch::MultiVector::create(this->exec); diff --git a/include/ginkgo/core/base/batch_multi_vector.hpp b/include/ginkgo/core/base/batch_multi_vector.hpp index 0e011f6b3ef..77171569320 100644 --- a/include/ginkgo/core/base/batch_multi_vector.hpp +++ b/include/ginkgo/core/base/batch_multi_vector.hpp @@ -130,6 +130,24 @@ class MultiVector void write(std::vector& data) const override; + /** + * Creates a mutable view (of matrix::Dense type) of one item of the Batch + * MultiVector object. Does not perform any deep copies, but only returns a + * view of the data. + * + * @param item_id The index of the batch item + * + * @return a matrix::Dense object with the data from the batch item at the + * given index. + */ + std::unique_ptr create_view_for_item(size_type item_id); + + /** + * @copydoc create_view_for_item(size_type) + */ + std::unique_ptr create_const_view_for_item( + size_type item_id) const; + /** * Unbatches the batched multi-vector and creates a std::vector of Dense * matrices @@ -208,8 +226,8 @@ class MultiVector * significantly more memory efficient than the non-constant version, * so always prefer this version. */ - const value_type* get_const_values_for_item( - size_type batch_id) const noexcept + const value_type* get_const_values_for_item(size_type batch_id) const + noexcept { GKO_ASSERT(batch_id < this->get_num_batch_items()); return values_.get_const_data() +