Skip to content

Commit

Permalink
Merge pull request #1640 from ginkgo-project/distributed_row_col_scale
Browse files Browse the repository at this point in the history
Adds a row_scale and a col_scale function to the distributed matrix.

    The row scaling is straight forward, with a distributed vector of scaling coefficients that has the same row partitioning as the matrix, we can just scale the rows of local and non-local matrices with the local vectors.
    For column scaling, we need to communicate the non-local scaling coefficients to scale the columns of the non-local matrices. This can be simply done with the matrix's communicate.

Related PR: #1640
  • Loading branch information
fritzgoebel authored Jul 19, 2024
2 parents 6ed7108 + 3c16ba4 commit 26eb276
Show file tree
Hide file tree
Showing 3 changed files with 253 additions and 0 deletions.
71 changes: 71 additions & 0 deletions core/distributed/matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <ginkgo/core/distributed/vector.hpp>
#include <ginkgo/core/matrix/coo.hpp>
#include <ginkgo/core/matrix/csr.hpp>
#include <ginkgo/core/matrix/diagonal.hpp>

#include "core/distributed/matrix_kernels.hpp"

Expand Down Expand Up @@ -504,6 +505,76 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::apply_impl(
}


template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
void Matrix<ValueType, LocalIndexType, GlobalIndexType>::col_scale(
ptr_param<const global_vector_type> scaling_factors)
{
GKO_ASSERT_CONFORMANT(this, scaling_factors.get());
GKO_ASSERT_EQ(scaling_factors->get_size()[1], 1);
auto exec = this->get_executor();
auto comm = this->get_communicator();
size_type n_local_cols = local_mtx_->get_size()[1];
size_type n_non_local_cols = non_local_mtx_->get_size()[1];
std::unique_ptr<global_vector_type> scaling_factors_single_stride;
auto stride = scaling_factors->get_stride();
if (stride != 1) {
scaling_factors_single_stride = global_vector_type::create(exec, comm);
scaling_factors_single_stride->copy_from(scaling_factors.get());
}
const auto scale_values =
stride == 1 ? scaling_factors->get_const_local_values()
: scaling_factors_single_stride->get_const_local_values();
const auto scale_diag = gko::matrix::Diagonal<ValueType>::create_const(
exec, n_local_cols,
make_const_array_view(exec, n_local_cols, scale_values));

auto req = this->communicate(
stride == 1 ? scaling_factors->get_local_vector()
: scaling_factors_single_stride->get_local_vector());
scale_diag->rapply(local_mtx_, local_mtx_);
req.wait();
if (n_non_local_cols > 0) {
auto use_host_buffer = mpi::requires_host_buffer(exec, comm);
if (use_host_buffer) {
recv_buffer_->copy_from(host_recv_buffer_.get());
}
const auto non_local_scale_diag =
gko::matrix::Diagonal<ValueType>::create_const(
exec, n_non_local_cols,
make_const_array_view(exec, n_non_local_cols,
recv_buffer_->get_const_values()));
non_local_scale_diag->rapply(non_local_mtx_, non_local_mtx_);
}
}


template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
void Matrix<ValueType, LocalIndexType, GlobalIndexType>::row_scale(
ptr_param<const global_vector_type> scaling_factors)
{
GKO_ASSERT_EQUAL_ROWS(this, scaling_factors.get());
GKO_ASSERT_EQ(scaling_factors->get_size()[1], 1);
auto exec = this->get_executor();
auto comm = this->get_communicator();
size_type n_local_rows = local_mtx_->get_size()[0];
std::unique_ptr<global_vector_type> scaling_factors_single_stride;
auto stride = scaling_factors->get_stride();
if (stride != 1) {
scaling_factors_single_stride = global_vector_type::create(exec, comm);
scaling_factors_single_stride->copy_from(scaling_factors.get());
}
const auto scale_values =
stride == 1 ? scaling_factors->get_const_local_values()
: scaling_factors_single_stride->get_const_local_values();
const auto scale_diag = gko::matrix::Diagonal<ValueType>::create_const(
exec, n_local_rows,
make_const_array_view(exec, n_local_rows, scale_values));

scale_diag->apply(local_mtx_, local_mtx_);
scale_diag->apply(non_local_mtx_, non_local_mtx_);
}


template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
Matrix<ValueType, LocalIndexType, GlobalIndexType>::Matrix(const Matrix& other)
: EnableDistributedLinOp<Matrix<value_type, local_index_type,
Expand Down
18 changes: 18 additions & 0 deletions include/ginkgo/core/distributed/matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,24 @@ class Matrix
std::vector<comm_index_type> recv_offsets,
array<local_index_type> recv_gather_idxs);

/**
* Scales the columns of the matrix by the respective entries of the vector.
* The vector's row partition has to be the same as the matrix's column
* partition. The scaling is done in-place.
*
* @param scaling_factors The vector containing the scaling factors.
*/
void col_scale(ptr_param<const global_vector_type> scaling_factors);

/**
* Scales the rows of the matrix by the respective entries of the vector.
* The vector and the matrix have to have the same row partition.
* The scaling is done in-place.
*
* @param scaling_factors The vector containing the scaling factors.
*/
void row_scale(ptr_param<const global_vector_type> scaling_factors);

protected:
explicit Matrix(std::shared_ptr<const Executor> exec,
mpi::communicator comm);
Expand Down
164 changes: 164 additions & 0 deletions test/mpi/matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <ginkgo/core/matrix/csr.hpp>

#include "core/test/utils.hpp"
#include "ginkgo/core/base/exception.hpp"
#include "test/utils/mpi/common_fixture.hpp"


Expand Down Expand Up @@ -516,6 +517,169 @@ TYPED_TEST(Matrix, CanAdvancedApplyToMultipleVectorsLarge)
}


TYPED_TEST(Matrix, CanColScale)
{
using value_type = typename TestFixture::value_type;
using index_type = typename TestFixture::global_index_type;
using csr = typename TestFixture::local_matrix_type;
using dist_vec_type = typename TestFixture::dist_vec_type;
auto vec_md = gko::matrix_data<value_type, index_type>{
I<I<value_type>>{{1}, {2}, {3}, {4}, {5}}};
I<I<value_type>> res_col_scale_local[] = {
{{8, 0}, {0, 0}}, {{0, 10}, {0, 0}}, {{0}}};
I<I<value_type>> res_col_scale_non_local[] = {
{{2, 0}, {6, 12}}, {{0, 0, 18}, {32, 35, 0}}, {{50, 9}}};
auto rank = this->comm.rank();
auto col_scaling_factors = dist_vec_type::create(this->exec, this->comm);
col_scaling_factors->read_distributed(vec_md, this->col_part);

this->dist_mat->col_scale(col_scaling_factors);

GKO_ASSERT_MTX_NEAR(gko::as<csr>(this->dist_mat->get_local_matrix()),
res_col_scale_local[rank], 0);
GKO_ASSERT_MTX_NEAR(gko::as<csr>(this->dist_mat->get_non_local_matrix()),
res_col_scale_non_local[rank], 0);
}


TYPED_TEST(Matrix, CanRowScale)
{
using value_type = typename TestFixture::value_type;
using index_type = typename TestFixture::global_index_type;
using csr = typename TestFixture::local_matrix_type;
using dist_vec_type = typename TestFixture::dist_vec_type;
auto vec_md = gko::matrix_data<value_type, index_type>{
I<I<value_type>>{{1}, {2}, {3}, {4}, {5}}};
I<I<value_type>> res_row_scale_local[] = {
{{2, 0}, {0, 0}}, {{0, 15}, {0, 0}}, {{0}}};
I<I<value_type>> res_row_scale_non_local[] = {
{{1, 0}, {6, 8}}, {{0, 0, 18}, {32, 28, 0}}, {{50, 45}}};
auto rank = this->comm.rank();
auto row_scaling_factors = dist_vec_type::create(this->exec, this->comm);
row_scaling_factors->read_distributed(vec_md, this->row_part);

this->dist_mat->row_scale(row_scaling_factors);

GKO_ASSERT_MTX_NEAR(gko::as<csr>(this->dist_mat->get_local_matrix()),
res_row_scale_local[rank], 0);
GKO_ASSERT_MTX_NEAR(gko::as<csr>(this->dist_mat->get_non_local_matrix()),
res_row_scale_non_local[rank], 0);
}


TYPED_TEST(Matrix, CanColScaleWithStride)
{
using value_type = typename TestFixture::value_type;
using index_type = typename TestFixture::global_index_type;
using csr = typename TestFixture::local_matrix_type;
using dist_vec_type = typename TestFixture::dist_vec_type;
auto vec_md = gko::matrix_data<value_type, index_type>{
I<I<value_type>>{{1}, {2}, {3}, {4}, {5}}};
I<I<value_type>> res_col_scale_local[] = {
{{8, 0}, {0, 0}}, {{0, 10}, {0, 0}}, {{0}}};
I<I<value_type>> res_col_scale_non_local[] = {
{{2, 0}, {6, 12}}, {{0, 0, 18}, {32, 35, 0}}, {{50, 9}}};
gko::dim<2> local_sizes[] = {{2, 1}, {2, 1}, {1, 1}};
auto rank = this->comm.rank();
auto col_scaling_factors = dist_vec_type::create(
this->exec, this->comm, gko::dim<2>{5, 1}, local_sizes[rank], 2);
col_scaling_factors->read_distributed(vec_md, this->col_part);

this->dist_mat->col_scale(col_scaling_factors);

ASSERT_EQ(col_scaling_factors->get_stride(), 2);
GKO_ASSERT_MTX_NEAR(gko::as<csr>(this->dist_mat->get_local_matrix()),
res_col_scale_local[rank], 0);
GKO_ASSERT_MTX_NEAR(gko::as<csr>(this->dist_mat->get_non_local_matrix()),
res_col_scale_non_local[rank], 0);
}


TYPED_TEST(Matrix, CanRowScaleWithStride)
{
using value_type = typename TestFixture::value_type;
using index_type = typename TestFixture::global_index_type;
using csr = typename TestFixture::local_matrix_type;
using dist_vec_type = typename TestFixture::dist_vec_type;
auto vec_md = gko::matrix_data<value_type, index_type>{
I<I<value_type>>{{1}, {2}, {3}, {4}, {5}}};
I<I<value_type>> res_row_scale_local[] = {
{{2, 0}, {0, 0}}, {{0, 15}, {0, 0}}, {{0}}};
I<I<value_type>> res_row_scale_non_local[] = {
{{1, 0}, {6, 8}}, {{0, 0, 18}, {32, 28, 0}}, {{50, 45}}};
gko::dim<2> local_sizes[] = {{2, 1}, {2, 1}, {1, 1}};
auto rank = this->comm.rank();
auto row_scaling_factors = dist_vec_type::create(
this->exec, this->comm, gko::dim<2>{5, 1}, local_sizes[rank], 2);
row_scaling_factors->read_distributed(vec_md, this->row_part);

this->dist_mat->row_scale(row_scaling_factors);

ASSERT_EQ(row_scaling_factors->get_stride(), 2);
GKO_ASSERT_MTX_NEAR(gko::as<csr>(this->dist_mat->get_local_matrix()),
res_row_scale_local[rank], 0);
GKO_ASSERT_MTX_NEAR(gko::as<csr>(this->dist_mat->get_non_local_matrix()),
res_row_scale_non_local[rank], 0);
}


TYPED_TEST(Matrix, ColScaleThrowsOnWrongDimension)
{
using value_type = typename TestFixture::value_type;
using index_type = typename TestFixture::global_index_type;
using dist_vec_type = typename TestFixture::dist_vec_type;
using part_type = typename TestFixture::part_type;
auto vec_md = gko::matrix_data<value_type, index_type>{
I<I<value_type>>{{1}, {2}, {3}, {4}}};
auto two_vec_md = gko::matrix_data<value_type, index_type>{
I<I<value_type>>{{1, 1}, {2, 2}, {3, 3}, {4, 4}, {5, 5}}};
auto rank = this->comm.rank();
auto col_part = part_type::build_from_mapping(
this->exec,
gko::array<gko::experimental::distributed::comm_index_type>(
this->exec,
I<gko::experimental::distributed::comm_index_type>{1, 2, 0, 0}),
3);
auto col_scaling_factors = dist_vec_type::create(this->exec, this->comm);
col_scaling_factors->read_distributed(vec_md, col_part);
auto two_col_scaling_factors =
dist_vec_type::create(this->exec, this->comm);
two_col_scaling_factors->read_distributed(two_vec_md, this->col_part);

ASSERT_THROW(this->dist_mat->col_scale(col_scaling_factors),
gko::DimensionMismatch);
ASSERT_THROW(this->dist_mat->col_scale(two_col_scaling_factors),
gko::ValueMismatch);
}


TYPED_TEST(Matrix, RowScaleThrowsOnWrongDimension)
{
using value_type = typename TestFixture::value_type;
using index_type = typename TestFixture::global_index_type;
using dist_vec_type = typename TestFixture::dist_vec_type;
using part_type = typename TestFixture::part_type;
auto vec_md = gko::matrix_data<value_type, index_type>{
I<I<value_type>>{{1}, {2}, {3}, {4}}};
auto two_vec_md = gko::matrix_data<value_type, index_type>{
I<I<value_type>>{{1, 1}, {2, 2}, {3, 3}, {4, 4}, {5, 5}}};
auto rank = this->comm.rank();
auto row_part = part_type::build_from_contiguous(
this->exec,
gko::array<index_type>(this->exec, I<index_type>{0, 2, 3, 4}));
auto row_scaling_factors = dist_vec_type::create(this->exec, this->comm);
row_scaling_factors->read_distributed(vec_md, row_part);
auto two_row_scaling_factors =
dist_vec_type::create(this->exec, this->comm);
two_row_scaling_factors->read_distributed(two_vec_md, this->col_part);

ASSERT_THROW(this->dist_mat->row_scale(row_scaling_factors),
gko::DimensionMismatch);
ASSERT_THROW(this->dist_mat->row_scale(two_row_scaling_factors),
gko::ValueMismatch);
}


TYPED_TEST(Matrix, CanConvertToNextPrecision)
{
using T = typename TestFixture::value_type;
Expand Down

0 comments on commit 26eb276

Please sign in to comment.