Skip to content

Commit

Permalink
add checked_lookup parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Sep 24, 2024
1 parent a2c69b1 commit f7ee9ec
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 25 deletions.
38 changes: 25 additions & 13 deletions common/cuda_hip/factorization/lu_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ __global__ __launch_bounds__(default_block_size) void initialize(
}


template <typename ValueType, typename IndexType>
template <bool checked_lookup, typename ValueType, typename IndexType>
__global__ __launch_bounds__(default_block_size) void factorize(
const IndexType* __restrict__ row_ptrs, const IndexType* __restrict__ cols,
const IndexType* __restrict__ storage_offsets,
Expand Down Expand Up @@ -130,12 +130,14 @@ __global__ __launch_bounds__(default_block_size) void factorize(
upper_nz += config::warp_size) {
const auto upper_col = cols[upper_nz];
const auto upper_val = vals[upper_nz];
// const auto output_pos = lookup[upper_col];
const auto output_pos = lookup.lookup_unsafe(upper_col) + row_begin;
if (output_pos >= row_begin && output_pos < row_end &&
cols[output_pos] == upper_col) {
// if (output_pos != invalid_index<IndexType>()) {
// output_pos += row_begin;
if (checked_lookup) {
const auto pos = lookup[upper_col];
if (pos != invalid_index<IndexType>()) {
vals[row_begin + pos] -= scale * upper_val;
}
} else {
const auto output_pos =
lookup.lookup_unsafe(upper_col) + row_begin;
vals[output_pos] -= scale * upper_val;
}
}
Expand Down Expand Up @@ -258,19 +260,29 @@ template <typename ValueType, typename IndexType>
void factorize(std::shared_ptr<const DefaultExecutor> exec,
const IndexType* lookup_offsets, const int64* lookup_descs,
const int32* lookup_storage, const IndexType* diag_idxs,
matrix::Csr<ValueType, IndexType>* factors,
matrix::Csr<ValueType, IndexType>* factors, bool checked_lookup,
array<int>& tmp_storage)
{
const auto num_rows = factors->get_size()[0];
if (num_rows > 0) {
syncfree_storage storage(exec, tmp_storage, num_rows);
const auto num_blocks =
ceildiv(num_rows, default_block_size / config::warp_size);
kernel::factorize<<<num_blocks, default_block_size, 0,
exec->get_stream()>>>(
factors->get_const_row_ptrs(), factors->get_const_col_idxs(),
lookup_offsets, lookup_storage, lookup_descs, diag_idxs,
as_device_type(factors->get_values()), storage, num_rows);
if (checked_lookup) {
kernel::factorize<true>
<<<num_blocks, default_block_size, 0, exec->get_stream()>>>(
factors->get_const_row_ptrs(),
factors->get_const_col_idxs(), lookup_offsets,
lookup_storage, lookup_descs, diag_idxs,
as_device_type(factors->get_values()), storage, num_rows);
} else {
kernel::factorize<false>
<<<num_blocks, default_block_size, 0, exec->get_stream()>>>(
factors->get_const_row_ptrs(),
factors->get_const_col_idxs(), lookup_offsets,
lookup_storage, lookup_descs, diag_idxs,
as_device_type(factors->get_values()), storage, num_rows);
}
}
}

Expand Down
8 changes: 4 additions & 4 deletions core/factorization/lu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,10 @@ std::unique_ptr<LinOp> Lu<ValueType, IndexType>::generate_impl(
storage.get_const_data(), diag_idxs.get_data(), factors.get()));
// run numerical factorization
array<int> tmp{exec};
exec->run(make_factorize(storage_offsets.get_const_data(),
row_descs.get_const_data(),
storage.get_const_data(),
diag_idxs.get_const_data(), factors.get(), tmp));
exec->run(make_factorize(
storage_offsets.get_const_data(), row_descs.get_const_data(),
storage.get_const_data(), diag_idxs.get_const_data(), factors.get(),
parameters_.checked_lookup, tmp));
return factorization_type::create_from_combined_lu(std::move(factors));
}

Expand Down
2 changes: 1 addition & 1 deletion core/factorization/lu_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ namespace kernels {
const IndexType* lookup_offsets, const int64* lookup_descs, \
const int32* lookup_storage, const IndexType* diag_idxs, \
matrix::Csr<ValueType, IndexType>* factors, \
array<int>& tmp_storage)
bool checked_lookup array<int>& tmp_storage)


#define GKO_DECLARE_LU_SYMMETRIC_FACTORIZE_SIMPLE(IndexType) \
Expand Down
2 changes: 1 addition & 1 deletion dpcpp/factorization/lu_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ template <typename ValueType, typename IndexType>
void factorize(std::shared_ptr<const DefaultExecutor> exec,
const IndexType* lookup_offsets, const int64* lookup_descs,
const int32* lookup_storage, const IndexType* diag_idxs,
matrix::Csr<ValueType, IndexType>* factors,
matrix::Csr<ValueType, IndexType>* factors, bool checked_lookup,
array<int>& tmp_storage) GKO_NOT_IMPLEMENTED;

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_LU_FACTORIZE);
Expand Down
9 changes: 9 additions & 0 deletions include/ginkgo/core/factorization/lu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,15 @@ class Lu
* incorrect results or crash.
*/
bool GKO_FACTORY_PARAMETER_SCALAR(skip_sorting, false);

/**
* The symbolic factoization should contains the fill-in information. If
* it is not the case, users might face hang or illegal access issue.
* Please enable this option when the symbolic factorization does not
* contain the full fill-in information. Symbolic factorization must
* still contain the entry for the original matrix.
*/
bool GKO_FACTORY_PARAMETER_SCALAR(checked_lookup, false);
};

/**
Expand Down
13 changes: 10 additions & 3 deletions omp/factorization/lu_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ template <typename ValueType, typename IndexType>
void factorize(std::shared_ptr<const DefaultExecutor> exec,
const IndexType* lookup_offsets, const int64* lookup_descs,
const int32* lookup_storage, const IndexType* diag_idxs,
matrix::Csr<ValueType, IndexType>* factors,
matrix::Csr<ValueType, IndexType>* factors, bool checked_lookup,
array<int>& tmp_storage)
{
const auto num_rows = factors->get_size()[0];
Expand All @@ -89,8 +89,15 @@ void factorize(std::shared_ptr<const DefaultExecutor> exec,
for (auto dep_nz = dep_diag_idx + 1; dep_nz < dep_end; dep_nz++) {
const auto col = cols[dep_nz];
const auto val = vals[dep_nz];
const auto nz = row_begin + lookup.lookup_unsafe(col);
vals[nz] -= scale * val;
if (checked_lookup) {
const auto idx = lookup[col];
if (idx != invalid_index<IndexType>()) {
vals[row_begin + idx] -= scale * val;
}
} else {
const auto nz = row_begin + lookup.lookup_unsafe(col);
vals[nz] -= scale * val;
}
}
}
}
Expand Down
13 changes: 10 additions & 3 deletions reference/factorization/lu_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ template <typename ValueType, typename IndexType>
void factorize(std::shared_ptr<const DefaultExecutor> exec,
const IndexType* lookup_offsets, const int64* lookup_descs,
const int32* lookup_storage, const IndexType* diag_idxs,
matrix::Csr<ValueType, IndexType>* factors,
matrix::Csr<ValueType, IndexType>* factors, bool checked_lookup,
array<int>& tmp_storage)
{
const auto num_rows = factors->get_size()[0];
Expand All @@ -87,8 +87,15 @@ void factorize(std::shared_ptr<const DefaultExecutor> exec,
for (auto dep_nz = dep_diag_idx + 1; dep_nz < dep_end; dep_nz++) {
const auto col = cols[dep_nz];
const auto val = vals[dep_nz];
const auto nz = row_begin + lookup.lookup_unsafe(col);
vals[nz] -= scale * val;
if (checked_lookup) {
const auto idx = lookup[col];
if (idx != invalid_index<IndexType>()) {
vals[row_begin + idx] -= scale * val;
}
} else {
const auto nz = row_begin + lookup.lookup_unsafe(col);
vals[nz] -= scale * val;
}
}
}
}
Expand Down

0 comments on commit f7ee9ec

Please sign in to comment.