Skip to content

Commit

Permalink
Fix read bug and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
pratikvn committed Jul 27, 2023
1 parent a618e7f commit f0ef0fa
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 54 deletions.
8 changes: 2 additions & 6 deletions core/base/batch_multi_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,12 +269,8 @@ void read_impl(MatrixType* mtx, const std::vector<MatrixData>& data)
MatrixType::create(mtx->get_executor()->get_master(), batch_size);
tmp->fill(zero<typename MatrixType::value_type>());
for (size_type b = 0; b < data.size(); ++b) {
size_type ind = 0;
for (size_type row = 0; row < data[b].size[0]; ++row) {
for (size_type col = 0; col < data[b].size[1]; ++col) {
tmp->at(b, row, col) = data[b].nonzeros[ind].value;
++ind;
}
for (const auto& elem : data[b].nonzeros) {
tmp->at(b, elem.row, elem.column) = elem.value;
}
}
tmp->move_to(mtx);
Expand Down
57 changes: 27 additions & 30 deletions core/test/base/batch_multi_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,33 @@ TYPED_TEST(BatchMultiVector, CanBeReadFromMatrixData)
}


TYPED_TEST(BatchMultiVector, CanBeReadFromSparseMatrixData)
{
using value_type = typename TestFixture::value_type;
auto m = gko::BatchMultiVector<TypeParam>::create(this->exec);
// clang-format off
m->read({gko::matrix_data<TypeParam>{{2, 2},
{{0, 0, 1.0},
{0, 1, 3.0},
{1, 1, 5.0}}},
gko::matrix_data<TypeParam>{{2, 2},
{{0, 0, -1.0},
{0, 1, 0.5},
{1, 1, 9.0}}}});
// clang-format on

ASSERT_EQ(m->get_common_size(), gko::dim<2>(2, 2));
EXPECT_EQ(m->at(0, 0, 0), value_type{1.0});
EXPECT_EQ(m->at(0, 0, 1), value_type{3.0});
EXPECT_EQ(m->at(0, 1, 0), value_type{0.0});
EXPECT_EQ(m->at(0, 1, 1), value_type{5.0});
EXPECT_EQ(m->at(1, 0, 0), value_type{-1.0});
EXPECT_EQ(m->at(1, 0, 1), value_type{0.5});
EXPECT_EQ(m->at(1, 1, 0), value_type{0.0});
EXPECT_EQ(m->at(1, 1, 1), value_type{9.0});
}


TYPED_TEST(BatchMultiVector, GeneratesCorrectMatrixData)
{
using value_type = typename TestFixture::value_type;
Expand All @@ -422,33 +449,3 @@ TYPED_TEST(BatchMultiVector, GeneratesCorrectMatrixData)
EXPECT_EQ(data[1].nonzeros[4], tpl(1, 1, value_type{2.0}));
EXPECT_EQ(data[1].nonzeros[5], tpl(1, 2, value_type{3.0}));
}


TYPED_TEST(BatchMultiVector, CanBeReadFromMatrixAssemblyData)
{
using value_type = typename TestFixture::value_type;
auto m = gko::BatchMultiVector<TypeParam>::create(this->exec);
gko::matrix_assembly_data<TypeParam> data1(gko::dim<2>{2, 2});
data1.set_value(0, 0, 1.0);
data1.set_value(0, 1, 3.0);
data1.set_value(1, 0, 0.0);
data1.set_value(1, 1, 5.0);
gko::matrix_assembly_data<TypeParam> data2(gko::dim<2>{2, 2});
data2.set_value(0, 0, 2.0);
data2.set_value(0, 1, 1.0);
data2.set_value(1, 0, 5.0);
data2.set_value(1, 1, 4.0);
auto data = std::vector<gko::matrix_assembly_data<TypeParam>>{data1, data2};

m->read(data);

ASSERT_EQ(m->get_common_size(), gko::dim<2>(2, 2));
EXPECT_EQ(m->at(0, 0, 0), value_type{1.0});
EXPECT_EQ(m->at(0, 1, 0), value_type{0.0});
EXPECT_EQ(m->at(0, 0, 1), value_type{3.0});
EXPECT_EQ(m->at(0, 1, 1), value_type{5.0});
EXPECT_EQ(m->at(1, 0, 0), value_type{2.0});
EXPECT_EQ(m->at(1, 1, 0), value_type{5.0});
EXPECT_EQ(m->at(1, 0, 1), value_type{1.0});
EXPECT_EQ(m->at(1, 1, 1), value_type{4.0});
}
18 changes: 0 additions & 18 deletions include/ginkgo/core/base/batch_lin_op_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,24 +75,6 @@ class BatchReadableFromMatrixData {
*/
virtual void read(
const std::vector<matrix_data<ValueType, IndexType>>& data) = 0;

/**
* Reads a matrix from a std::vector of matrix_assembly_data objects.
*
* @param data the std::vector of matrix_assembly_data objects
*/
void read(const std::vector<matrix_assembly_data<ValueType, IndexType>>&
assembly_data)
{
auto mat_data = std::vector<matrix_data<ValueType, IndexType>>(
assembly_data.size());
size_type ind = 0;
for (const auto& i : assembly_data) {
mat_data[ind] = i.get_ordered_data();
++ind;
}
this->read(mat_data);
}
};


Expand Down

0 comments on commit f0ef0fa

Please sign in to comment.