Skip to content

Commit

Permalink
review updates:
Browse files Browse the repository at this point in the history
- ensure equal bounds
- don't make row_ptrs owning
- disable conversion assignment
- add in-code comment
- fix tests
- use this->
- add tests for incompatible reads

Co-authored-by: Tobias Ribizel <[email protected]>
Co-authored-by: Pratik Nayak <[email protected]>
Co-authored-by: Yu-Hsiang M. Tsai <[email protected]>
  • Loading branch information
4 people committed Dec 8, 2023
1 parent a18a2e3 commit c99bce3
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 58 deletions.
13 changes: 8 additions & 5 deletions core/matrix/coo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,11 +185,11 @@ void Coo<ValueType, IndexType>::read(const mat_data& data)
auto size = data.size;
auto exec = this->get_executor();
this->set_size(size);
row_idxs_.resize_and_reset(data.nonzeros.size());
col_idxs_.resize_and_reset(data.nonzeros.size());
values_.resize_and_reset(data.nonzeros.size());
device_mat_data view{exec, size, row_idxs_.as_view(), col_idxs_.as_view(),
values_.as_view()};
this->row_idxs_.resize_and_reset(data.nonzeros.size());
this->col_idxs_.resize_and_reset(data.nonzeros.size());
this->values_.resize_and_reset(data.nonzeros.size());
device_mat_data view{exec, size, this->row_idxs_.as_view(),
this->col_idxs_.as_view(), this->values_.as_view()};
const auto host_data =
make_array_view(exec->get_master(), data.nonzeros.size(),
const_cast<matrix_data_entry<ValueType, IndexType>*>(
Expand All @@ -203,6 +203,9 @@ template <typename ValueType, typename IndexType>
void Coo<ValueType, IndexType>::read(const device_mat_data& data)
{
this->set_size(data.get_size());
// copy the arrays from device matrix data into the arrays of
// this. Compared to the read(device_mat_data&&) version, the internal
// arrays keep their current ownership status
this->values_ = make_const_array_view(data.get_executor(),
data.get_num_stored_elements(),
data.get_const_values());
Expand Down
19 changes: 13 additions & 6 deletions core/matrix/csr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -426,12 +426,16 @@ void Csr<ValueType, IndexType>::read(const mat_data& data)
{
auto size = data.size;
auto exec = this->get_executor();
row_ptrs_.resize_and_reset(size[0] + 1);
col_idxs_.resize_and_reset(data.nonzeros.size());
values_.resize_and_reset(data.nonzeros.size());
this->set_size(size);
this->row_ptrs_.resize_and_reset(size[0] + 1);
this->col_idxs_.resize_and_reset(data.nonzeros.size());
this->values_.resize_and_reset(data.nonzeros.size());
// the device matrix data contains views on the column indices
// and values array of this matrix, and an owning array for the
// row indices (which doesn't exist in this matrix)
device_mat_data view{exec, size,
array<IndexType>{exec, data.nonzeros.size()},
col_idxs_.as_view(), values_.as_view()};
this->col_idxs_.as_view(), this->values_.as_view()};
const auto host_data =
make_array_view(exec->get_master(), data.nonzeros.size(),
const_cast<matrix_data_entry<ValueType, IndexType>*>(
Expand All @@ -440,7 +444,7 @@ void Csr<ValueType, IndexType>::read(const mat_data& data)
csr::make_aos_to_soa(*make_temporary_clone(exec, &host_data), view));
exec->run(csr::make_convert_idxs_to_ptrs(view.get_const_row_idxs(),
view.get_num_stored_elements(),
size[0], get_row_ptrs()));
size[0], this->get_row_ptrs()));
this->make_srow();
}

Expand All @@ -452,6 +456,9 @@ void Csr<ValueType, IndexType>::read(const device_mat_data& data)
auto exec = this->get_executor();
this->row_ptrs_.resize_and_reset(size[0] + 1);
this->set_size(size);
// copy the column indices and values array from the device matrix data
// into this. Compared to the read(device_mat_data&&) version, the internal
// arrays keep their current ownership status.
this->values_ = make_const_array_view(data.get_executor(),
data.get_num_stored_elements(),
data.get_const_values());
Expand All @@ -476,7 +483,7 @@ void Csr<ValueType, IndexType>::read(device_mat_data&& data)
auto size = data.get_size();
auto exec = this->get_executor();
auto arrays = data.empty_out();
this->row_ptrs_ = array<IndexType>{exec, size[0] + 1};
this->row_ptrs_.resize_and_reset(size[0] + 1);
this->set_size(size);
this->values_ = std::move(arrays.values);
this->col_idxs_ = std::move(arrays.col_idxs);
Expand Down
8 changes: 3 additions & 5 deletions core/test/base/array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,6 @@ TYPED_TEST(Array, CopyArrayToView)
{
TypeParam data[] = {1, 2, 3};
auto view = gko::make_array_view(this->exec, 2, data);
gko::array<TypeParam> array_size1(this->exec, {5});
gko::array<TypeParam> array_size2(this->exec, {5, 4});
gko::array<TypeParam> array_size4(this->exec, {5, 4, 2, 1});

Expand All @@ -534,7 +533,6 @@ TYPED_TEST(Array, CopyArrayToView)
EXPECT_EQ(data[1], TypeParam{4});
EXPECT_EQ(view.get_size(), 2);
EXPECT_EQ(array_size2.get_size(), 2);
ASSERT_THROW(view = array_size1, gko::OutOfBoundsError);
ASSERT_THROW(view = array_size4, gko::OutOfBoundsError);
}

Expand All @@ -546,6 +544,7 @@ TYPED_TEST(Array, CopyConstViewToArray)
gko::array<TypeParam> array(this->exec, {5, 4, 2});

array = const_view;
data[1] = 7;

EXPECT_EQ(array.get_data()[0], TypeParam{1});
EXPECT_EQ(array.get_data()[1], TypeParam{2});
Expand All @@ -561,19 +560,18 @@ TYPED_TEST(Array, CopyConstViewToView)
TypeParam data1[] = {1, 2, 3, 4};
TypeParam data2[] = {5, 4, 2};
auto view = gko::make_array_view(this->exec, 3, data2);
auto const_view2 = gko::make_const_array_view(this->exec, 2, data1);
auto const_view3 = gko::make_const_array_view(this->exec, 3, data1);
auto const_view4 = gko::make_const_array_view(this->exec, 4, data1);

view = const_view3;
data1[1] = 7;

EXPECT_EQ(view.get_data()[0], TypeParam{1});
EXPECT_EQ(view.get_data()[1], TypeParam{2});
EXPECT_EQ(view.get_data()[2], TypeParam{3});
EXPECT_EQ(view.get_size(), 3);
EXPECT_EQ(const_view3.get_size(), 3);
ASSERT_THROW(view = const_view2, gko::ValueMismatch);
ASSERT_THROW(view = const_view4, gko::ValueMismatch);
ASSERT_THROW(view = const_view4, gko::OutOfBoundsError);
}

TYPED_TEST(Array, MoveArrayToArray)
Expand Down
56 changes: 48 additions & 8 deletions core/test/matrix/coo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,22 @@ TYPED_TEST(Coo, CanBeReadFromMatrixDataIntoViews)
}


TYPED_TEST(Coo, ThrowsOnIncompatibleReadFromMatrixDataIntoViews)
{
using Mtx = typename TestFixture::Mtx;
using value_type = typename TestFixture::value_type;
using index_type = typename TestFixture::index_type;
auto row_idxs = gko::array<index_type>(this->exec, 1);
auto col_idxs = gko::array<index_type>(this->exec, 1);
auto values = gko::array<value_type>(this->exec, 1);
auto m = Mtx::create(this->exec, gko::dim<2>{2, 3}, values.as_view(),
col_idxs.as_view(), row_idxs.as_view());

ASSERT_THROW(m->read({{2, 3}, {{0, 0, 1.0}, {0, 1, 3.0}}}),
gko::NotSupported);
}


TYPED_TEST(Coo, CanBeReadFromMatrixAssemblyData)
{
using Mtx = typename TestFixture::Mtx;
Expand Down Expand Up @@ -255,7 +271,7 @@ TYPED_TEST(Coo, CanBeReadFromDeviceMatrixData)
this->assert_equal_to_original_mtx(m);
ASSERT_EQ(device_data.get_num_stored_elements(),
m->get_num_stored_elements());
GKO_ASSERT_EQUAL_DIMENSIONS(&device_data, m);
ASSERT_EQ(device_data.get_size(), m->get_size());
}


Expand Down Expand Up @@ -287,6 +303,27 @@ TYPED_TEST(Coo, CanBeReadFromDeviceMatrixDataIntoViews)
}


TYPED_TEST(Coo, ThrowsOnIncompatibleReadFromDeviceMatrixDataIntoViews)
{
using Mtx = typename TestFixture::Mtx;
using value_type = typename TestFixture::value_type;
using index_type = typename TestFixture::index_type;
auto row_idxs = gko::array<index_type>(this->exec, 1);
auto col_idxs = gko::array<index_type>(this->exec, 1);
auto values = gko::array<value_type>(this->exec, 1);
auto m = Mtx::create(this->exec, gko::dim<2>{2, 3}, values.as_view(),
col_idxs.as_view(), row_idxs.as_view());
gko::matrix_assembly_data<value_type, index_type> data(gko::dim<2>{2, 3});
data.set_value(0, 0, 1.0);
data.set_value(0, 1, 3.0);
auto device_data =
gko::device_matrix_data<value_type, index_type>::create_from_host(
this->exec, data.get_ordered_data());

ASSERT_THROW(m->read(device_data), gko::OutOfBoundsError);
}


TYPED_TEST(Coo, CanBeReadFromMovedDeviceMatrixData)
{
using Mtx = typename TestFixture::Mtx;
Expand All @@ -305,7 +342,7 @@ TYPED_TEST(Coo, CanBeReadFromMovedDeviceMatrixData)
m->read(std::move(device_data));

this->assert_equal_to_original_mtx(m);
GKO_ASSERT_EQUAL_DIMENSIONS(&device_data, gko::dim<2>{});
ASSERT_EQ(device_data.get_size(), gko::dim<2>{});
ASSERT_EQ(device_data.get_num_stored_elements(), 0);
}

Expand All @@ -315,9 +352,9 @@ TYPED_TEST(Coo, CanBeReadFromMovedDeviceMatrixDataIntoViews)
using Mtx = typename TestFixture::Mtx;
using value_type = typename TestFixture::value_type;
using index_type = typename TestFixture::index_type;
auto row_idxs = gko::array<index_type>(this->exec, 4);
auto col_idxs = gko::array<index_type>(this->exec, 4);
auto values = gko::array<value_type>(this->exec, 4);
auto row_idxs = gko::array<index_type>(this->exec, 2);
auto col_idxs = gko::array<index_type>(this->exec, 2);
auto values = gko::array<value_type>(this->exec, 2);
row_idxs.fill(0);
col_idxs.fill(0);
values.fill(gko::zero<value_type>());
Expand All @@ -331,13 +368,16 @@ TYPED_TEST(Coo, CanBeReadFromMovedDeviceMatrixDataIntoViews)
auto device_data =
gko::device_matrix_data<value_type, index_type>::create_from_host(
this->exec, data.get_ordered_data());
auto orig_row_idxs = device_data.get_row_idxs();
auto orig_col_idxs = device_data.get_col_idxs();
auto orig_values = device_data.get_values();

m->read(std::move(device_data));

this->assert_equal_to_original_mtx(m);
ASSERT_NE(row_idxs.get_data(), m->get_row_idxs());
ASSERT_NE(col_idxs.get_data(), m->get_col_idxs());
ASSERT_NE(values.get_data(), m->get_values());
ASSERT_EQ(orig_row_idxs, m->get_row_idxs());
ASSERT_EQ(orig_col_idxs, m->get_col_idxs());
ASSERT_EQ(orig_values, m->get_values());
}


Expand Down
71 changes: 53 additions & 18 deletions core/test/matrix/csr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,14 @@ class Csr : public ::testing::Test {
std::shared_ptr<const gko::Executor> exec;
std::unique_ptr<Mtx> mtx;

void assert_equal_to_original_mtx(const index_type* r, const index_type* c,
const value_type* v)
void assert_equal_to_original_mtx(gko::ptr_param<const Mtx> m)
{
auto v = m->get_const_values();
auto c = m->get_const_col_idxs();
auto r = m->get_const_row_ptrs();
auto s = m->get_const_srow();
ASSERT_EQ(m->get_size(), gko::dim<2>(2, 3));
ASSERT_EQ(m->get_num_stored_elements(), 4);
EXPECT_EQ(r[0], 0);
EXPECT_EQ(r[1], 3);
EXPECT_EQ(r[2], 4);
Expand All @@ -67,17 +72,6 @@ class Csr : public ::testing::Test {
EXPECT_EQ(v[1], value_type{3.0});
EXPECT_EQ(v[2], value_type{2.0});
EXPECT_EQ(v[3], value_type{5.0});
}

void assert_equal_to_original_mtx(gko::ptr_param<const Mtx> m)
{
auto v = m->get_const_values();
auto c = m->get_const_col_idxs();
auto r = m->get_const_row_ptrs();
auto s = m->get_const_srow();
ASSERT_EQ(m->get_size(), gko::dim<2>(2, 3));
ASSERT_EQ(m->get_num_stored_elements(), 4);
assert_equal_to_original_mtx(r, c, v);
EXPECT_EQ(s[0], 0);
}

Expand Down Expand Up @@ -241,6 +235,23 @@ TYPED_TEST(Csr, CanBeReadFromMatrixDataIntoViews)
}


TYPED_TEST(Csr, ThrowsOnIncompatibleReadFromMatrixDataIntoViews)
{
using Mtx = typename TestFixture::Mtx;
using value_type = typename TestFixture::value_type;
using index_type = typename TestFixture::index_type;
auto row_ptrs = gko::array<index_type>(this->exec, 3);
auto col_idxs = gko::array<index_type>(this->exec, 1);
auto values = gko::array<value_type>(this->exec, 1);
auto m = Mtx::create(this->exec, gko::dim<2>{2, 3}, values.as_view(),
col_idxs.as_view(), row_ptrs.as_view(),
std::make_shared<typename Mtx::load_balance>(2));

ASSERT_THROW(m->read({{2, 3}, {{0, 0, 1.0}, {0, 1, 3.0}}}),
gko::NotSupported);
}


TYPED_TEST(Csr, CanBeReadFromMatrixAssemblyData)
{
using Mtx = typename TestFixture::Mtx;
Expand Down Expand Up @@ -281,7 +292,7 @@ TYPED_TEST(Csr, CanBeReadFromDeviceMatrixData)
this->assert_equal_to_original_mtx(m);
ASSERT_EQ(device_data.get_num_stored_elements(),
m->get_num_stored_elements());
GKO_ASSERT_EQUAL_DIMENSIONS(&device_data, m);
ASSERT_EQ(device_data.get_size(), m->get_size());
}


Expand Down Expand Up @@ -314,6 +325,28 @@ TYPED_TEST(Csr, CanBeReadFromDeviceMatrixDataIntoViews)
}


TYPED_TEST(Csr, ThrowsOnIncompatibleReadFromDeviceMatrixDataIntoViews)
{
using Mtx = typename TestFixture::Mtx;
using value_type = typename TestFixture::value_type;
using index_type = typename TestFixture::index_type;
auto row_ptrs = gko::array<index_type>(this->exec, 3);
auto col_idxs = gko::array<index_type>(this->exec, 1);
auto values = gko::array<value_type>(this->exec, 1);
auto m = Mtx::create(this->exec, gko::dim<2>{2, 3}, values.as_view(),
col_idxs.as_view(), row_ptrs.as_view(),
std::make_shared<typename Mtx::load_balance>(2));
gko::matrix_assembly_data<value_type, index_type> data(m->get_size());
data.set_value(0, 0, 1.0);
data.set_value(0, 1, 3.0);
auto device_data =
gko::device_matrix_data<value_type, index_type>::create_from_host(
this->exec, data.get_ordered_data());

ASSERT_THROW(m->read(device_data), gko::OutOfBoundsError);
}


TYPED_TEST(Csr, CanBeReadFromMovedDeviceMatrixData)
{
using Mtx = typename TestFixture::Mtx;
Expand All @@ -333,7 +366,7 @@ TYPED_TEST(Csr, CanBeReadFromMovedDeviceMatrixData)
m->read(std::move(device_data));

this->assert_equal_to_original_mtx(m);
GKO_ASSERT_EQUAL_DIMENSIONS(&device_data, gko::dim<2>{});
ASSERT_EQ(device_data.get_size(), gko::dim<2>{});
ASSERT_EQ(device_data.get_num_stored_elements(), 0);
}

Expand All @@ -360,13 +393,15 @@ TYPED_TEST(Csr, CanBeReadFromMovedDeviceMatrixDataIntoViews)
auto device_data =
gko::device_matrix_data<value_type, index_type>::create_from_host(
this->exec, data.get_ordered_data());
auto orig_col_idxs = device_data.get_col_idxs();
auto orig_values = device_data.get_values();

m->read(std::move(device_data));

this->assert_equal_to_original_mtx(m);
ASSERT_NE(row_ptrs.get_data(), m->get_row_ptrs());
ASSERT_NE(col_idxs.get_data(), m->get_col_idxs());
ASSERT_NE(values.get_data(), m->get_values());
ASSERT_EQ(row_ptrs.get_data(), m->get_row_ptrs());
ASSERT_EQ(orig_col_idxs, m->get_col_idxs());
ASSERT_EQ(orig_values, m->get_values());
}


Expand Down
23 changes: 7 additions & 16 deletions include/ginkgo/core/base/array.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ class array {
}

/**
* Copies or converts data from a const_array_view.
* Copies data from a const_array_view.
*
* In the case of an array target, the array is resized to match the
* source's size. In the case of a view target, if the dimensions are not
Expand All @@ -564,14 +564,10 @@ class array {
* executor, it will inherit the executor of other.
*
* @param other the const_array_view to copy from
* @tparam OtherValueType the value type of `other`
*
* @return this
*/
template <typename OtherValueType>
std::enable_if_t<std::is_convertible<OtherValueType, ValueType>::value,
array>&
operator=(const detail::const_array_view<OtherValueType>& other)
array& operator=(const detail::const_array_view<ValueType>& other)
{
if (this->exec_ == nullptr) {
this->exec_ = other.get_executor();
Expand All @@ -587,20 +583,15 @@ class array {
} else {
GKO_ENSURE_COMPATIBLE_BOUNDS(other.get_size(), this->get_size());
}
array<OtherValueType> tmp{this->exec_};
const OtherValueType* source = other.get_const_data();
// if we are on different executors: copy, then convert
array tmp{this->exec_};
const ValueType* source = other.get_const_data();
// if we are on different executors: copy
if (this->exec_ != other.get_executor()) {
tmp = other.copy_to_array();
source = tmp.get_const_data();
}
if (std::is_same<OtherValueType, ValueType>::value) {
exec_->copy_from(other.get_executor(), other.get_size(),
other.get_const_data(), this->get_data());
} else {
detail::convert_data(this->exec_, other.get_size(), source,
this->get_data());
}
exec_->copy_from(other.get_executor(), other.get_size(), source,
this->get_data());
return *this;
}

Expand Down

0 comments on commit c99bce3

Please sign in to comment.