Skip to content

Commit

Permalink
review updates:
Browse files Browse the repository at this point in the history
- ensure equal bounds
- don't make row_ptrs owning
- disable conversion assignment
- add in-code comment
- fix tests
- use this->
- use assert_eq instead of compatible_bounds
- remove compatible_bounds

Co-authored-by: Tobias Ribizel <[email protected]>
Co-authored-by: Pratik Nayak <[email protected]>
  • Loading branch information
3 people committed Dec 7, 2023
1 parent a031529 commit 284f979
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 60 deletions.
13 changes: 8 additions & 5 deletions core/matrix/coo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,11 +185,11 @@ void Coo<ValueType, IndexType>::read(const mat_data& data)
auto size = data.size;
auto exec = this->get_executor();
this->set_size(size);
row_idxs_.resize_and_reset(data.nonzeros.size());
col_idxs_.resize_and_reset(data.nonzeros.size());
values_.resize_and_reset(data.nonzeros.size());
device_mat_data view{exec, size, row_idxs_.as_view(), col_idxs_.as_view(),
values_.as_view()};
this->row_idxs_.resize_and_reset(data.nonzeros.size());
this->col_idxs_.resize_and_reset(data.nonzeros.size());
this->values_.resize_and_reset(data.nonzeros.size());
device_mat_data view{exec, size, this->row_idxs_.as_view(),
this->col_idxs_.as_view(), this->values_.as_view()};
const auto host_data =
make_array_view(exec->get_master(), data.nonzeros.size(),
const_cast<matrix_data_entry<ValueType, IndexType>*>(
Expand All @@ -203,6 +203,9 @@ template <typename ValueType, typename IndexType>
void Coo<ValueType, IndexType>::read(const device_mat_data& data)
{
this->set_size(data.get_size());
// copy the arrays from device mnatrix data into the arrays of
// this. Compared to the read(device_mat_data&&) version, the internal
// arrays keep their current ownership status
this->values_ = make_const_array_view(data.get_executor(),
data.get_num_stored_elements(),
data.get_const_values());
Expand Down
19 changes: 13 additions & 6 deletions core/matrix/csr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -426,12 +426,16 @@ void Csr<ValueType, IndexType>::read(const mat_data& data)
{
auto size = data.size;
auto exec = this->get_executor();
row_ptrs_.resize_and_reset(size[0] + 1);
col_idxs_.resize_and_reset(data.nonzeros.size());
values_.resize_and_reset(data.nonzeros.size());
this->set_size(size);
this->row_ptrs_.resize_and_reset(size[0] + 1);
this->col_idxs_.resize_and_reset(data.nonzeros.size());
this->values_.resize_and_reset(data.nonzeros.size());
// the device matrix data contains views on the column indices
// and values array of this matrix, and an owning array for the
// row indices (which doesn't exist in this matrix)
device_mat_data view{exec, size,
array<IndexType>{exec, data.nonzeros.size()},
col_idxs_.as_view(), values_.as_view()};
this->col_idxs_.as_view(), this->values_.as_view()};
const auto host_data =
make_array_view(exec->get_master(), data.nonzeros.size(),
const_cast<matrix_data_entry<ValueType, IndexType>*>(
Expand All @@ -440,7 +444,7 @@ void Csr<ValueType, IndexType>::read(const mat_data& data)
csr::make_aos_to_soa(*make_temporary_clone(exec, &host_data), view));
exec->run(csr::make_convert_idxs_to_ptrs(view.get_const_row_idxs(),
view.get_num_stored_elements(),
size[0], get_row_ptrs()));
size[0], this->get_row_ptrs()));
this->make_srow();
}

Expand All @@ -452,6 +456,9 @@ void Csr<ValueType, IndexType>::read(const device_mat_data& data)
auto exec = this->get_executor();
this->row_ptrs_.resize_and_reset(size[0] + 1);
this->set_size(size);
// copy the column indices and values array from the device matrix data
// into this. Compared to the read(device_mat_data&&) version, the internal
// arrays keep their current ownership status.
this->values_ = make_const_array_view(data.get_executor(),
data.get_num_stored_elements(),
data.get_const_values());
Expand All @@ -476,7 +483,7 @@ void Csr<ValueType, IndexType>::read(device_mat_data&& data)
auto size = data.get_size();
auto exec = this->get_executor();
auto arrays = data.empty_out();
this->row_ptrs_ = array<IndexType>{exec, size[0] + 1};
this->row_ptrs_.resize_and_reset(size[0] + 1);
this->set_size(size);
this->values_ = std::move(arrays.values);
this->col_idxs_ = std::move(arrays.col_idxs);
Expand Down
6 changes: 3 additions & 3 deletions core/test/base/array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ TYPED_TEST(Array, CopyViewToView)
EXPECT_EQ(view.get_size(), 3);
EXPECT_EQ(view2.get_size(), 3);
EXPECT_EQ(view2.get_data()[0], TypeParam{2});
ASSERT_THROW(view2 = view_size4, gko::OutOfBoundsError);
ASSERT_THROW(view2 = view_size4, gko::ValueMismatch);
}


Expand Down Expand Up @@ -534,8 +534,8 @@ TYPED_TEST(Array, CopyArrayToView)
EXPECT_EQ(data[1], TypeParam{4});
EXPECT_EQ(view.get_size(), 2);
EXPECT_EQ(array_size2.get_size(), 2);
ASSERT_THROW(view = array_size1, gko::OutOfBoundsError);
ASSERT_THROW(view = array_size4, gko::OutOfBoundsError);
ASSERT_THROW(view = array_size1, gko::ValueMismatch);
ASSERT_THROW(view = array_size4, gko::ValueMismatch);
}


Expand Down
13 changes: 8 additions & 5 deletions core/test/matrix/coo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ TYPED_TEST(Coo, CanBeReadFromDeviceMatrixData)
this->assert_equal_to_original_mtx(m);
ASSERT_EQ(device_data.get_num_stored_elements(),
m->get_num_stored_elements());
GKO_ASSERT_EQUAL_DIMENSIONS(&device_data, m);
ASSERT_EQ(device_data.get_size(), m->get_size());
}


Expand Down Expand Up @@ -305,7 +305,7 @@ TYPED_TEST(Coo, CanBeReadFromMovedDeviceMatrixData)
m->read(std::move(device_data));

this->assert_equal_to_original_mtx(m);
GKO_ASSERT_EQUAL_DIMENSIONS(&device_data, gko::dim<2>{});
ASSERT_EQ(device_data.get_size(), gko::dim<2>{});
ASSERT_EQ(device_data.get_num_stored_elements(), 0);
}

Expand All @@ -331,13 +331,16 @@ TYPED_TEST(Coo, CanBeReadFromMovedDeviceMatrixDataIntoViews)
auto device_data =
gko::device_matrix_data<value_type, index_type>::create_from_host(
this->exec, data.get_ordered_data());
auto orig_row_idxs = device_data.get_row_idxs();
auto orig_col_idxs = device_data.get_col_idxs();
auto orig_values = device_data.get_values();

m->read(std::move(device_data));

this->assert_equal_to_original_mtx(m);
ASSERT_NE(row_idxs.get_data(), m->get_row_idxs());
ASSERT_NE(col_idxs.get_data(), m->get_col_idxs());
ASSERT_NE(values.get_data(), m->get_values());
ASSERT_EQ(orig_row_idxs, m->get_row_idxs());
ASSERT_EQ(orig_col_idxs, m->get_col_idxs());
ASSERT_EQ(orig_values, m->get_values());
}


Expand Down
12 changes: 7 additions & 5 deletions core/test/matrix/csr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ TYPED_TEST(Csr, CanBeReadFromDeviceMatrixData)
this->assert_equal_to_original_mtx(m);
ASSERT_EQ(device_data.get_num_stored_elements(),
m->get_num_stored_elements());
GKO_ASSERT_EQUAL_DIMENSIONS(&device_data, m);
ASSERT_EQ(device_data.get_size(), m->get_size());
}


Expand Down Expand Up @@ -333,7 +333,7 @@ TYPED_TEST(Csr, CanBeReadFromMovedDeviceMatrixData)
m->read(std::move(device_data));

this->assert_equal_to_original_mtx(m);
GKO_ASSERT_EQUAL_DIMENSIONS(&device_data, gko::dim<2>{});
ASSERT_EQ(device_data.get_size(), gko::dim<2>{});
ASSERT_EQ(device_data.get_num_stored_elements(), 0);
}

Expand All @@ -360,13 +360,15 @@ TYPED_TEST(Csr, CanBeReadFromMovedDeviceMatrixDataIntoViews)
auto device_data =
gko::device_matrix_data<value_type, index_type>::create_from_host(
this->exec, data.get_ordered_data());
auto orig_col_idxs = device_data.get_col_idxs();
auto orig_values = device_data.get_values();

m->read(std::move(device_data));

this->assert_equal_to_original_mtx(m);
ASSERT_NE(row_ptrs.get_data(), m->get_row_ptrs());
ASSERT_NE(col_idxs.get_data(), m->get_col_idxs());
ASSERT_NE(values.get_data(), m->get_values());
ASSERT_EQ(row_ptrs.get_data(), m->get_row_ptrs());
ASSERT_EQ(orig_col_idxs, m->get_col_idxs());
ASSERT_EQ(orig_values, m->get_values());
}


Expand Down
26 changes: 9 additions & 17 deletions include/ginkgo/core/base/array.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ class array {
if (this->is_owning()) {
this->resize_and_reset(other.get_size());
} else {
GKO_ENSURE_COMPATIBLE_BOUNDS(other.get_size(), this->get_size());
GKO_ASSERT_EQ(other.get_size(), this->get_size());
}
exec_->copy_from(other.get_executor(), other.get_size(),
other.get_const_data(), this->get_data());
Expand Down Expand Up @@ -536,7 +536,7 @@ class array {
if (this->is_owning()) {
this->resize_and_reset(other.get_size());
} else {
GKO_ENSURE_COMPATIBLE_BOUNDS(other.get_size(), this->get_size());
GKO_ASSERT_EQ(other.get_size(), this->get_size());
}
array<OtherValueType> tmp{this->exec_};
const OtherValueType* source = other.get_const_data();
Expand All @@ -551,7 +551,7 @@ class array {
}

/**
* Copies or converts data from a const_array_view.
* Copies data from a const_array_view.
*
* In the case of an array target, the array is resized to match the
* source's size. In the case of a view target, if the dimensions are not
Expand All @@ -568,10 +568,7 @@ class array {
*
* @return this
*/
template <typename OtherValueType>
std::enable_if_t<std::is_convertible<OtherValueType, ValueType>::value,
array>&
operator=(const detail::const_array_view<OtherValueType>& other)
array& operator=(const detail::const_array_view<ValueType>& other)
{
if (this->exec_ == nullptr) {
this->exec_ = other.get_executor();
Expand All @@ -585,22 +582,17 @@ class array {
if (this->is_owning()) {
this->resize_and_reset(other.get_size());
} else {
GKO_ENSURE_COMPATIBLE_BOUNDS(other.get_size(), this->get_size());
GKO_ASSERT_EQ(other.get_size(), this->get_size());
}
array<OtherValueType> tmp{this->exec_};
const OtherValueType* source = other.get_const_data();
array tmp{this->exec_};
const ValueType* source = other.get_const_data();
// if we are on different executors: copy, then convert
if (this->exec_ != other.get_executor()) {
tmp = other.copy_to_array();
source = tmp.get_const_data();
}
if (std::is_same<OtherValueType, ValueType>::value) {
exec_->copy_from(other.get_executor(), other.get_size(),
other.get_const_data(), this->get_data());
} else {
detail::convert_data(this->exec_, other.get_size(), source,
this->get_data());
}
exec_->copy_from(other.get_executor(), other.get_size(), source,
this->get_data());
return *this;
}

Expand Down
19 changes: 0 additions & 19 deletions include/ginkgo/core/base/exception_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -755,25 +755,6 @@ inline T ensure_allocated_impl(T ptr, const std::string& file, int line,
"semi-colon warnings")


/**
* Ensures that two dimensions have compatible bounds, in particular before a
* copy operation. This means the target should have at least as much elements
* as the source.
*
* @param _source the source of the expected copy operation
* @param _target the destination of the expected copy operation
*
* @throw OutOfBoundsError if `_source > _target`
*/
#define GKO_ENSURE_COMPATIBLE_BOUNDS(_source, _target) \
if (_source > _target) { \
throw ::gko::OutOfBoundsError(__FILE__, __LINE__, _source, _target); \
} \
static_assert(true, \
"This assert is used to counter the false positive extra " \
"semi-colon warnings")


/**
* Creates a StreamError exception.
* This macro sets the correct information about the location of the error
Expand Down

0 comments on commit 284f979

Please sign in to comment.