Skip to content

Commit

Permalink
[core] allow filling the device_matrix_data
Browse files Browse the repository at this point in the history
The main use case is in combination with `sum_duplicates` and `remove_zeros` to simplify the assembly setup.
  • Loading branch information
MarcelKoch committed Sep 17, 2024
1 parent b5745ac commit 0893e50
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 0 deletions.
20 changes: 20 additions & 0 deletions core/base/device_matrix_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,16 @@ device_matrix_data<ValueType, IndexType>::device_matrix_data(
{}


template <typename ValueType, typename IndexType>
device_matrix_data<ValueType, IndexType>::device_matrix_data(
std::shared_ptr<const Executor> exec, dim<2> size, size_type num_entries,
matrix_data_entry<ValueType, IndexType> value)
: device_matrix_data(std::move(exec), size, num_entries)
{
fill(value);
}


template <typename ValueType, typename IndexType>
device_matrix_data<ValueType, IndexType>::device_matrix_data(
std::shared_ptr<const Executor> exec, const device_matrix_data& data)
Expand Down Expand Up @@ -93,6 +103,16 @@ device_matrix_data<ValueType, IndexType>::create_from_host(
}


template <typename ValueType, typename IndexType>
void device_matrix_data<ValueType, IndexType>::fill(
matrix_data_entry<ValueType, IndexType> value)
{
row_idxs_.fill(value.row);
col_idxs_.fill(value.column);
values_.fill(value.value);
}


template <typename ValueType, typename IndexType>
void device_matrix_data<ValueType, IndexType>::sort_row_major()
{
Expand Down
21 changes: 21 additions & 0 deletions include/ginkgo/core/base/device_matrix_data.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,21 @@ class device_matrix_data {
explicit device_matrix_data(std::shared_ptr<const Executor> exec,
dim<2> size = {}, size_type num_entries = 0);

/**
* Initializes a new device_matrix_data object.
* It uses the given executor to allocate storage for the given number of
* entries and matrix dimensions, and initializes all entries with a given
* value.
*
* @param exec the executor to be used to store the matrix entries
* @param size the matrix dimensions
* @param num_entries the number of entries to be stored
* @param value fills the matrix data with this value
*/
explicit device_matrix_data(std::shared_ptr<const Executor> exec,
dim<2> size, size_type num_entries,
matrix_data_entry<ValueType, IndexType> value);

/**
* Initializes a device_matrix_data object by copying an existing object on
* another executor.
Expand Down Expand Up @@ -114,6 +129,12 @@ class device_matrix_data {
static device_matrix_data create_from_host(
std::shared_ptr<const Executor> exec, const host_type& data);

/**
* Fills the matrix entries with a specified value
*/
void fill(matrix_data_entry<ValueType, IndexType> value);


/**
* Sorts the matrix entries in row-major order
* This means that they will be sorted by row index first, and then by
Expand Down
46 changes: 46 additions & 0 deletions test/base/device_matrix_data_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,30 @@ TYPED_TEST(DeviceMatrixData, ConstructsCorrectly)
}


TYPED_TEST(DeviceMatrixData, ConstructsWithValueCorrectly)
{
using value_type = typename TestFixture::value_type;
using index_type = typename TestFixture::index_type;

gko::device_matrix_data<value_type, index_type> local_data{
this->exec, gko::dim<2>{4, 3}, 10, {0, 1, value_type{2.0}}};

ASSERT_EQ((gko::dim<2>{4, 3}), local_data.get_size());
ASSERT_EQ(this->exec, local_data.get_executor());
ASSERT_EQ(local_data.get_num_stored_elements(), 10);
auto arrays = local_data.empty_out();
auto expected_row_idxs = gko::array<index_type>(this->exec, 10);
auto expected_col_idxs = gko::array<index_type>(this->exec, 10);
auto expected_values = gko::array<value_type>(this->exec, 10);
expected_row_idxs.fill(0);
expected_col_idxs.fill(1);
expected_values.fill(2.0);
GKO_ASSERT_ARRAY_EQ(arrays.row_idxs, expected_row_idxs);
GKO_ASSERT_ARRAY_EQ(arrays.col_idxs, expected_col_idxs);
GKO_ASSERT_ARRAY_EQ(arrays.values, expected_values);
}


TYPED_TEST(DeviceMatrixData, CopyConstructsOnOtherExecutorCorrectly)
{
using value_type = typename TestFixture::value_type;
Expand Down Expand Up @@ -241,6 +265,28 @@ TYPED_TEST(DeviceMatrixData, CopiesToHost)
}


TYPED_TEST(DeviceMatrixData, CanFillEntries)
{
using value_type = typename TestFixture::value_type;
using index_type = typename TestFixture::index_type;
using device_matrix_data = gko::device_matrix_data<value_type, index_type>;
auto device_data = device_matrix_data{this->exec, gko::dim<2>{4, 3}, 10};

device_data.fill({0, 1, value_type{2.0}});

auto arrays = device_data.empty_out();
auto expected_row_idxs = gko::array<index_type>(this->exec, 10);
auto expected_col_idxs = gko::array<index_type>(this->exec, 10);
auto expected_values = gko::array<value_type>(this->exec, 10);
expected_row_idxs.fill(0);
expected_col_idxs.fill(1);
expected_values.fill(2.0);
GKO_ASSERT_ARRAY_EQ(arrays.row_idxs, expected_row_idxs);
GKO_ASSERT_ARRAY_EQ(arrays.col_idxs, expected_col_idxs);
GKO_ASSERT_ARRAY_EQ(arrays.values, expected_values);
}


TYPED_TEST(DeviceMatrixData, SortsRowMajor)
{
using value_type = typename TestFixture::value_type;
Expand Down

0 comments on commit 0893e50

Please sign in to comment.