Skip to content

Commit

Permalink
Add Dense matrix view creation
Browse files Browse the repository at this point in the history
  • Loading branch information
pratikvn committed Aug 1, 2023
1 parent 4df67c3 commit d8e7343
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 33 deletions.
73 changes: 43 additions & 30 deletions core/base/batch_multi_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,38 @@ batch_dim<2> compute_batch_size(
} // namespace detail


template <typename ValueType>
std::unique_ptr<matrix::Dense<ValueType>>
MultiVector<ValueType>::create_view_for_item(size_type item_id)
{
auto exec = this->get_executor();
auto num_rows = this->get_common_size()[0];
auto stride = this->get_common_size()[1];
auto mat = unbatch_type::create(
exec, this->get_common_size(),
make_array_view(exec, num_rows * stride,
this->get_values_for_item(item_id)),
stride);
return mat;
}


template <typename ValueType>
std::unique_ptr<const matrix::Dense<ValueType>>
MultiVector<ValueType>::create_const_view_for_item(size_type item_id) const
{
auto exec = this->get_executor();
auto num_rows = this->get_common_size()[0];
auto stride = this->get_common_size()[1];
auto mat = unbatch_type::create_const(
exec, this->get_common_size(),
make_const_array_view(exec, num_rows * stride,
this->get_const_values_for_item(item_id)),
stride);
return mat;
}


template <typename ValueType>
MultiVector<ValueType>::MultiVector(std::shared_ptr<const Executor> exec,
const batch_dim<2>& size)
Expand Down Expand Up @@ -164,18 +196,13 @@ template <typename ValueType>
std::vector<std::unique_ptr<matrix::Dense<ValueType>>>
MultiVector<ValueType>::unbatch() const
{
using unbatch_type = matrix::Dense<ValueType>;
auto exec = this->get_executor();
auto unbatch_mats = std::vector<std::unique_ptr<unbatch_type>>{};
auto unbatched_mats = std::vector<std::unique_ptr<unbatch_type>>{};
for (size_type b = 0; b < this->get_num_batch_items(); ++b) {
auto mat = unbatch_type::create(exec, this->get_common_size());
exec->copy_from(exec.get(), mat->get_num_stored_elements(),
this->get_const_values() +
this->get_size().get_cumulative_offset(b),
mat->get_values());
unbatch_mats.emplace_back(std::move(mat));
unbatched_mats.emplace_back(
this->create_const_view_for_item(b)->clone());
}
return unbatch_mats;
return unbatched_mats;
}


Expand Down Expand Up @@ -336,19 +363,15 @@ void read_impl(MatrixType* mtx, const std::vector<MatrixData>& data)
GKO_THROW_IF_INVALID(data.size() > 0, "Input data is empty");

auto common_size = data[0].size;
auto batch_size = batch_dim<2>(data.size(), common_size);
for (const auto& b : data) {
auto b_size = b.size;
GKO_ASSERT_EQUAL_DIMENSIONS(common_size, b_size);
}
auto num_batch_items = data.size();
auto batch_size = batch_dim<2>(num_batch_items, common_size);
auto tmp =
MatrixType::create(mtx->get_executor()->get_master(), batch_size);
tmp->fill(zero<typename MatrixType::value_type>());
for (size_type b = 0; b < data.size(); ++b) {
for (const auto& elem : data[b].nonzeros) {
tmp->at(b, elem.row, elem.column) = elem.value;
}
for (size_type b = 0; b < num_batch_items; ++b) {
assert(data[b].size() == common_size);
tmp->create_view_for_item(b)->read(data[b]);
}

tmp->move_to(mtx);
}

Expand All @@ -370,20 +393,10 @@ void MultiVector<ValueType>::read(const std::vector<mat_data64>& data)
template <typename MatrixType, typename MatrixData>
void write_impl(const MatrixType* mtx, std::vector<MatrixData>& data)
{
auto tmp = make_temporary_clone(mtx->get_executor()->get_master(), mtx);

data = std::vector<MatrixData>(mtx->get_num_batch_items());
for (size_type b = 0; b < mtx->get_num_batch_items(); ++b) {
data[b] = {mtx->get_common_size(), {}};
for (size_type row = 0; row < data[b].size[0]; ++row) {
for (size_type col = 0; col < data[b].size[1]; ++col) {
if (tmp->at(b, row, col) !=
zero<typename MatrixType::value_type>()) {
data[b].nonzeros.emplace_back(row, col,
tmp->at(b, row, col));
}
}
}
mtx->create_const_view_for_item(b)->write(data[b]);
}
}

Expand Down
12 changes: 11 additions & 1 deletion core/test/base/batch_multi_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ class MultiVector : public ::testing::Test {
mtx(gko::batch::initialize<gko::batch::MultiVector<value_type>>(
{{{-1.0, 2.0, 3.0}, {-1.5, 2.5, 3.5}},
{{1.0, 2.5, 3.0}, {1.0, 2.0, 3.0}}},
exec))
exec)),
dense_mtx(gko::initialize<gko::matrix::Dense<value_type>>(
{{1.0, 2.5, 3.0}, {1.0, 2.0, 3.0}}, exec))
{}


Expand Down Expand Up @@ -89,6 +91,7 @@ class MultiVector : public ::testing::Test {

std::shared_ptr<const gko::Executor> exec;
std::unique_ptr<gko::batch::MultiVector<value_type>> mtx;
std::unique_ptr<gko::matrix::Dense<value_type>> dense_mtx;
};

TYPED_TEST_SUITE(MultiVector, gko::test::ValueTypes);
Expand Down Expand Up @@ -118,6 +121,13 @@ TYPED_TEST(MultiVector, CanGetValuesForEntry)
}


TYPED_TEST(MultiVector, CanCreateDenseItemView)
{
GKO_ASSERT_MTX_NEAR(this->mtx->create_view_for_item(1), this->dense_mtx,
0.0);
}


TYPED_TEST(MultiVector, CanBeCopied)
{
auto mtx_copy = gko::batch::MultiVector<TypeParam>::create(this->exec);
Expand Down
22 changes: 20 additions & 2 deletions include/ginkgo/core/base/batch_multi_vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,24 @@ class MultiVector

void write(std::vector<mat_data64>& data) const override;

/**
* Creates a mutable view (of matrix::Dense type) of one item of the Batch
* MultiVector object. Does not perform any deep copies, but only returns a
* view of the data.
*
* @param item_id The index of the batch item
*
* @return a matrix::Dense object with the data from the batch item at the
* given index.
*/
std::unique_ptr<unbatch_type> create_view_for_item(size_type item_id);

/**
* @copydoc create_view_for_item(size_type)
*/
std::unique_ptr<const unbatch_type> create_const_view_for_item(
size_type item_id) const;

/**
* Unbatches the batched multi-vector and creates a std::vector of Dense
* matrices
Expand Down Expand Up @@ -208,8 +226,8 @@ class MultiVector
* significantly more memory efficient than the non-constant version,
* so always prefer this version.
*/
const value_type* get_const_values_for_item(
size_type batch_id) const noexcept
const value_type* get_const_values_for_item(size_type batch_id) const
noexcept
{
GKO_ASSERT(batch_id < this->get_num_batch_items());
return values_.get_const_data() +
Expand Down

0 comments on commit d8e7343

Please sign in to comment.