Skip to content

Commit

Permalink
Review updates
Browse files Browse the repository at this point in the history
Co-authored-by: Marcel Koch <[email protected]>
  • Loading branch information
pratikvn and MarcelKoch committed Nov 26, 2023
1 parent c5c9612 commit aa40e0a
Show file tree
Hide file tree
Showing 7 changed files with 12 additions and 33 deletions.
28 changes: 9 additions & 19 deletions core/matrix/batch_struct.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ struct batch_item {
ValueType* values;
const index_type* col_idxs;
const index_type* row_ptrs;
index_type stride;
index_type num_rows;
index_type num_cols;
};
Expand All @@ -49,7 +48,6 @@ struct uniform_batch {
const index_type* col_idxs;
const index_type* row_ptrs;
size_type num_batch_items;
index_type stride;
index_type num_rows;
index_type num_cols;
index_type num_nnz_per_item;
Expand Down Expand Up @@ -155,16 +153,16 @@ template <typename ValueType, typename IndexType>
GKO_ATTRIBUTES GKO_INLINE csr::batch_item<const ValueType, const IndexType>
to_const(const csr::batch_item<ValueType, IndexType>& b)
{
return {b.values, b.col_idxs, b.row_ptrs, b.stride, b.num_rows, b.num_cols};
return {b.values, b.col_idxs, b.row_ptrs, b.num_rows, b.num_cols};
}


template <typename ValueType, typename IndexType>
GKO_ATTRIBUTES GKO_INLINE csr::uniform_batch<const ValueType, const IndexType>
to_const(const csr::uniform_batch<ValueType, IndexType>& ub)
{
return {ub.values, ub.col_idxs, ub.row_ptrs, ub.num_batch_items,
ub.stride, ub.num_rows, ub.num_cols, ub.num_nnz_per_item};
return {ub.values, ub.col_idxs, ub.row_ptrs, ub.num_batch_items,
ub.num_rows, ub.num_cols, ub.num_nnz_per_item};
}


Expand All @@ -173,28 +171,20 @@ GKO_ATTRIBUTES GKO_INLINE csr::batch_item<ValueType, IndexType>
extract_batch_item(const csr::uniform_batch<ValueType, IndexType>& batch,
const size_type batch_idx)
{
return {batch.values + batch_idx * batch.num_nnz_per_item,
batch.col_idxs,
batch.row_ptrs,
batch.stride,
batch.num_rows,
batch.num_cols};
return {batch.values + batch_idx * batch.num_nnz_per_item, batch.col_idxs,
batch.row_ptrs, batch.num_rows, batch.num_cols};
}

template <typename ValueType, typename IndexType>
GKO_ATTRIBUTES GKO_INLINE csr::batch_item<ValueType, IndexType>
extract_batch_item(ValueType* const batch_values,
IndexType* const batch_col_idxs,
IndexType* const batch_row_ptrs, const int stride,
const int num_rows, const int num_cols, int num_nnz_per_item,
IndexType* const batch_row_ptrs, const int num_rows,
const int num_cols, int num_nnz_per_item,
const size_type batch_idx)
{
return {batch_values + batch_idx * num_nnz_per_item,
batch_col_idxs,
batch_row_ptrs,
stride,
num_rows,
num_cols};
return {batch_values + batch_idx * num_nnz_per_item, batch_col_idxs,
batch_row_ptrs, num_rows, num_cols};
}


Expand Down
7 changes: 2 additions & 5 deletions core/test/matrix/batch_csr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ TYPED_TEST(Csr, CanBeConstructedFromCsrMatricesByDuplication)
auto m =
gko::batch::create_from_item<BatchCsrMtx>(this->exec, 3, mat1.get(), 2);

GKO_ASSERT_BATCH_MTX_NEAR(bat_m.get(), m.get(), 1e-14);
GKO_ASSERT_BATCH_MTX_NEAR(bat_m.get(), m.get(), 0.);
}


Expand All @@ -324,7 +324,7 @@ TYPED_TEST(Csr, CanBeConstructedByDuplicatingCsrMatrices)

auto m2 = gko::batch::duplicate<BatchCsrMtx>(this->exec, 3, m.get(), 2);

GKO_ASSERT_BATCH_MTX_NEAR(m2.get(), m_ref.get(), 1e-14);
GKO_ASSERT_BATCH_MTX_NEAR(m2.get(), m_ref.get(), 0.);
}


Expand All @@ -349,7 +349,6 @@ TYPED_TEST(Csr, CanBeListConstructed)
using value_type = typename TestFixture::value_type;
using index_type = typename TestFixture::index_type;
using BatchCsrMtx = typename TestFixture::BatchCsrMtx;
using CsrMtx = typename TestFixture::CsrMtx;

auto m = gko::batch::initialize<BatchCsrMtx>({{0.0, -1.0}, {0.0, -5.0}},
this->exec, 1);
Expand Down Expand Up @@ -452,7 +451,6 @@ TYPED_TEST(Csr, ThrowsForDataWithDifferentNnz)
{
using value_type = typename TestFixture::value_type;
using index_type = typename TestFixture::index_type;
using BatchCsrMtx = typename TestFixture::BatchCsrMtx;
auto vec_data = std::vector<gko::matrix_data<value_type, index_type>>{};
vec_data.emplace_back(
gko::matrix_data<value_type, index_type>({2, 3}, {
Expand All @@ -474,7 +472,6 @@ TYPED_TEST(Csr, ThrowsForDataWithDifferentSparsity)
{
using value_type = typename TestFixture::value_type;
using index_type = typename TestFixture::index_type;
using BatchCsrMtx = typename TestFixture::BatchCsrMtx;
auto vec_data = std::vector<gko::matrix_data<value_type, index_type>>{};
vec_data.emplace_back(
gko::matrix_data<value_type, index_type>({2, 3}, {
Expand Down
2 changes: 0 additions & 2 deletions cuda/matrix/batch_struct.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ get_batch_struct(const batch::matrix::Csr<ValueType, IndexType>* const op)
op->get_const_col_idxs(),
op->get_const_row_ptrs(),
op->get_num_batch_items(),
static_cast<IndexType>(op->get_common_size()[1]),
static_cast<IndexType>(op->get_common_size()[0]),
static_cast<IndexType>(op->get_common_size()[1]),
static_cast<IndexType>(op->get_num_elements_per_item())};
Expand All @@ -62,7 +61,6 @@ get_batch_struct(batch::matrix::Csr<ValueType, IndexType>* const op)
op->get_col_idxs(),
op->get_row_ptrs(),
op->get_num_batch_items(),
static_cast<IndexType>(op->get_common_size()[1]),
static_cast<IndexType>(op->get_common_size()[0]),
static_cast<IndexType>(op->get_common_size()[1]),
static_cast<IndexType>(op->get_num_elements_per_item())};
Expand Down
2 changes: 0 additions & 2 deletions dpcpp/matrix/batch_struct.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ get_batch_struct(const batch::matrix::Csr<ValueType, IndexType>* const op)
op->get_const_col_idxs(),
op->get_const_row_ptrs(),
op->get_num_batch_items(),
static_cast<IndexType>(op->get_common_size()[1]),
static_cast<IndexType>(op->get_common_size()[0]),
static_cast<IndexType>(op->get_common_size()[1]),
static_cast<IndexType>(op->get_num_elements_per_item())};
Expand All @@ -60,7 +59,6 @@ inline batch::matrix::csr::uniform_batch<ValueType, IndexType> get_batch_struct(
op->get_col_idxs(),
op->get_row_ptrs(),
op->get_num_batch_items(),
static_cast<IndexType>(op->get_common_size()[1]),
static_cast<IndexType>(op->get_common_size()[0]),
static_cast<IndexType>(op->get_common_size()[1]),
static_cast<IndexType>(op->get_num_elements_per_item())};
Expand Down
2 changes: 0 additions & 2 deletions hip/matrix/batch_struct.hip.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ get_batch_struct(const batch::matrix::Csr<ValueType, IndexType>* const op)
op->get_const_col_idxs(),
op->get_const_row_ptrs(),
op->get_num_batch_items(),
static_cast<IndexType>(op->get_common_size()[1]),
static_cast<IndexType>(op->get_common_size()[0]),
static_cast<IndexType>(op->get_common_size()[1]),
static_cast<IndexType>(op->get_num_elements_per_item())};
Expand All @@ -62,7 +61,6 @@ get_batch_struct(batch::matrix::Csr<ValueType, IndexType>* const op)
op->get_col_idxs(),
op->get_row_ptrs(),
op->get_num_batch_items(),
static_cast<IndexType>(op->get_common_size()[1]),
static_cast<IndexType>(op->get_common_size()[0]),
static_cast<IndexType>(op->get_common_size()[1]),
static_cast<IndexType>(op->get_num_elements_per_item())};
Expand Down
2 changes: 0 additions & 2 deletions reference/matrix/batch_struct.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ get_batch_struct(const batch::matrix::Csr<ValueType, IndexType>* const op)
op->get_const_col_idxs(),
op->get_const_row_ptrs(),
op->get_num_batch_items(),
static_cast<IndexType>(op->get_common_size()[1]),
static_cast<IndexType>(op->get_common_size()[0]),
static_cast<IndexType>(op->get_common_size()[1]),
static_cast<IndexType>(op->get_num_elements_per_item())};
Expand All @@ -65,7 +64,6 @@ inline batch::matrix::csr::uniform_batch<ValueType, IndexType> get_batch_struct(
op->get_col_idxs(),
op->get_row_ptrs(),
op->get_num_batch_items(),
static_cast<IndexType>(op->get_common_size()[1]),
static_cast<IndexType>(op->get_common_size()[0]),
static_cast<IndexType>(op->get_common_size()[1]),
static_cast<IndexType>(op->get_num_elements_per_item())};
Expand Down
2 changes: 1 addition & 1 deletion test/matrix/batch_csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class Csr : public CommonTestFixture {
dresult = gko::clone(exec, expected);
}

std::ranlux48 rand_engine;
std::default_random_engine rand_engine;

const gko::size_type batch_size = 11;
std::unique_ptr<BMtx> mat;
Expand Down

0 comments on commit aa40e0a

Please sign in to comment.