Skip to content

Commit

Permalink
review updates:
Browse files Browse the repository at this point in the history
- disable conversion assignment
- add in-code comment
- fix tests

Co-authored-by: Tobias Ribizel <[email protected]>
  • Loading branch information
MarcelKoch and upsj committed Dec 5, 2023
1 parent 2b04fd1 commit b86c726
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 23 deletions.
3 changes: 3 additions & 0 deletions core/matrix/coo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,9 @@ template <typename ValueType, typename IndexType>
void Coo<ValueType, IndexType>::read(const device_mat_data& data)
{
this->set_size(data.get_size());
// copy the arrays from device mnatrix data into the arrays of
// this. Compared to the read(device_mat_data&&) version, the internal
// arrays keep their current ownership status
this->values_ = make_const_array_view(data.get_executor(),
data.get_num_stored_elements(),
data.get_const_values());
Expand Down
7 changes: 7 additions & 0 deletions core/matrix/csr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -426,9 +426,13 @@ void Csr<ValueType, IndexType>::read(const mat_data& data)
{
auto size = data.size;
auto exec = this->get_executor();
this->set_size(size);
row_ptrs_.resize_and_reset(size[0] + 1);
col_idxs_.resize_and_reset(data.nonzeros.size());
values_.resize_and_reset(data.nonzeros.size());
// the device matrix data contains views on the column indices
// and values array of this matrix, and an owning array for the
// row indices (which doesn't exist in this matrix)
device_mat_data view{exec, size,
array<IndexType>{exec, data.nonzeros.size()},
col_idxs_.as_view(), values_.as_view()};
Expand All @@ -452,6 +456,9 @@ void Csr<ValueType, IndexType>::read(const device_mat_data& data)
auto exec = this->get_executor();
this->row_ptrs_.resize_and_reset(size[0] + 1);
this->set_size(size);
// copy the column indices and values array from the device matrix data
// into this. Compared to the read(device_mat_data&&) version, the internal
// arrays keep their current ownership status.
this->values_ = make_const_array_view(data.get_executor(),
data.get_num_stored_elements(),
data.get_const_values());
Expand Down
10 changes: 5 additions & 5 deletions core/test/matrix/coo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ TYPED_TEST(Coo, CanBeReadFromDeviceMatrixData)
this->assert_equal_to_original_mtx(m);
ASSERT_EQ(device_data.get_num_stored_elements(),
m->get_num_stored_elements());
GKO_ASSERT_EQUAL_DIMENSIONS(&device_data, m);
ASSERT_EQ(device_data.get_size(), m->get_size());
}


Expand Down Expand Up @@ -305,7 +305,7 @@ TYPED_TEST(Coo, CanBeReadFromMovedDeviceMatrixData)
m->read(std::move(device_data));

this->assert_equal_to_original_mtx(m);
GKO_ASSERT_EQUAL_DIMENSIONS(&device_data, gko::dim<2>{});
ASSERT_EQ(device_data.get_size(), gko::dim<2>{});
ASSERT_EQ(device_data.get_num_stored_elements(), 0);
}

Expand Down Expand Up @@ -335,9 +335,9 @@ TYPED_TEST(Coo, CanBeReadFromMovedDeviceMatrixDataIntoViews)
m->read(std::move(device_data));

this->assert_equal_to_original_mtx(m);
ASSERT_NE(row_idxs.get_data(), m->get_row_idxs());
ASSERT_NE(col_idxs.get_data(), m->get_col_idxs());
ASSERT_NE(values.get_data(), m->get_values());
ASSERT_EQ(device_data.get_row_idxs(), m->get_row_idxs());
ASSERT_EQ(device_data.get_col_idxs(), m->get_col_idxs());
ASSERT_EQ(device_data.get_values(), m->get_values());
}


Expand Down
8 changes: 4 additions & 4 deletions core/test/matrix/csr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ TYPED_TEST(Csr, CanBeReadFromDeviceMatrixData)
this->assert_equal_to_original_mtx(m);
ASSERT_EQ(device_data.get_num_stored_elements(),
m->get_num_stored_elements());
GKO_ASSERT_EQUAL_DIMENSIONS(&device_data, m);
ASSERT_EQ(device_data.get_size(), m->get_size());
}


Expand Down Expand Up @@ -333,7 +333,7 @@ TYPED_TEST(Csr, CanBeReadFromMovedDeviceMatrixData)
m->read(std::move(device_data));

this->assert_equal_to_original_mtx(m);
GKO_ASSERT_EQUAL_DIMENSIONS(&device_data, gko::dim<2>{});
ASSERT_EQ(device_data.get_size(), gko::dim<2>{});
ASSERT_EQ(device_data.get_num_stored_elements(), 0);
}

Expand Down Expand Up @@ -365,8 +365,8 @@ TYPED_TEST(Csr, CanBeReadFromMovedDeviceMatrixDataIntoViews)

this->assert_equal_to_original_mtx(m);
ASSERT_EQ(row_ptrs.get_data(), m->get_row_ptrs());
ASSERT_NE(col_idxs.get_data(), m->get_col_idxs());
ASSERT_NE(values.get_data(), m->get_values());
ASSERT_EQ(device_data.get_col_idxs(), m->get_col_idxs());
ASSERT_EQ(device_data.get_values(), m->get_values());
}


Expand Down
20 changes: 6 additions & 14 deletions include/ginkgo/core/base/array.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ class array {
}

/**
* Copies or converts data from a const_array_view.
* Copies data from a const_array_view.
*
* In the case of an array target, the array is resized to match the
* source's size. In the case of a view target, if the dimensions are not
Expand All @@ -568,10 +568,7 @@ class array {
*
* @return this
*/
template <typename OtherValueType>
std::enable_if_t<std::is_convertible<OtherValueType, ValueType>::value,
array>&
operator=(const detail::const_array_view<OtherValueType>& other)
array& operator=(const detail::const_array_view<ValueType>& other)
{
if (this->exec_ == nullptr) {
this->exec_ = other.get_executor();
Expand All @@ -587,20 +584,15 @@ class array {
} else {
GKO_ENSURE_COMPATIBLE_BOUNDS(other.get_size(), this->get_size());
}
array<OtherValueType> tmp{this->exec_};
const OtherValueType* source = other.get_const_data();
array tmp{this->exec_};
const ValueType* source = other.get_const_data();
// if we are on different executors: copy, then convert
if (this->exec_ != other.get_executor()) {
tmp = other;
source = tmp.get_const_data();
}
if (std::is_same<OtherValueType, ValueType>::value) {
exec_->copy_from(other.get_executor(), other.get_size(),
other.get_const_data(), this->get_data());
} else {
detail::convert_data(this->exec_, other.get_size(), source,
this->get_data());
}
exec_->copy_from(other.get_executor(), other.get_size(), source,
this->get_data());
return *this;
}

Expand Down

0 comments on commit b86c726

Please sign in to comment.