Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Review updates
Browse files Browse the repository at this point in the history
Co-authored-by: Yu-Hsiang Tsai <[email protected]>
Co-authored-by: Marcel Koch <[email protected]>
3 people committed Oct 25, 2023

Verified

This commit was signed with the committer’s verified signature.
johnjeffers John Jeffers
1 parent d7e4535 commit e17e58d
Showing 23 changed files with 399 additions and 379 deletions.
2 changes: 1 addition & 1 deletion core/base/batch_struct.hpp
Original file line number Diff line number Diff line change
@@ -71,7 +71,7 @@ struct uniform_batch {
int32 num_rows;
int32 num_rhs;

size_type get_entry_storage() const
size_type get_storage_size() const
{
return num_rows * stride * sizeof(value_type);
}
2 changes: 2 additions & 0 deletions core/log/logger.cpp
Original file line number Diff line number Diff line change
@@ -75,6 +75,8 @@ constexpr Logger::mask_type Logger::linop_factory_generate_completed_mask;
constexpr Logger::mask_type Logger::criterion_check_started_mask;
constexpr Logger::mask_type Logger::criterion_check_completed_mask;

constexpr Logger::mask_type Logger::batch_solver_completed_mask;

constexpr Logger::mask_type Logger::iteration_complete_mask;


6 changes: 2 additions & 4 deletions core/matrix/batch_struct.hpp
Original file line number Diff line number Diff line change
@@ -56,7 +56,6 @@ struct batch_item {
int32 stride;
int32 num_rows;
int32 num_cols;
int32 num_nnz = num_rows * stride;
};


@@ -73,14 +72,13 @@ struct uniform_batch {
int32 stride;
int32 num_rows;
int32 num_cols;
int32 num_nnz = num_rows * stride;

/**
* Returns stride * num_rows as a size_type.
*
* Each int32 factor is widened to size_type before the multiplication, so
* the product is computed in size_type arithmetic instead of overflowing in
* 32-bit signed arithmetic and only being widened afterwards.
*/
inline size_type get_num_nnz() const
{
    return static_cast<size_type>(stride) * static_cast<size_type>(num_rows);
}

inline size_type get_entry_storage() const
inline size_type get_storage_size() const
{
return get_num_nnz() * sizeof(value_type);
}
@@ -132,7 +130,7 @@ struct uniform_batch {
return static_cast<size_type>(stride * num_stored_elems_per_row);
}

inline size_type get_entry_storage() const
inline size_type get_storage_size() const
{
return get_num_nnz() * sizeof(value_type);
}
2 changes: 1 addition & 1 deletion core/solver/batch_bicgstab.cpp
Original file line number Diff line number Diff line change
@@ -56,7 +56,7 @@ GKO_REGISTER_OPERATION(apply, batch_bicgstab::apply);
template <typename ValueType>
void Bicgstab<ValueType>::solver_apply(
const MultiVector<ValueType>* b, MultiVector<ValueType>* x,
log::BatchLogData<remove_complex<ValueType>>* log_data) const
log::detail::log_data<remove_complex<ValueType>>* log_data) const
{
using MVec = MultiVector<ValueType>;
const kernels::batch_bicgstab::BicgstabSettings<remove_complex<ValueType>>
4 changes: 2 additions & 2 deletions core/solver/batch_bicgstab_kernels.hpp
Original file line number Diff line number Diff line change
@@ -56,7 +56,7 @@ template <typename RealType>
struct BicgstabSettings {
int max_iterations;
RealType residual_tol;
::gko::batch::stop::ToleranceType tol_type;
::gko::batch::stop::tolerance_type tol_type;
};


@@ -100,7 +100,7 @@ inline int local_memory_requirement(const int num_rows, const int num_rhs)
remove_complex<_type>>& options, \
const batch::BatchLinOp* a, const batch::BatchLinOp* preconditioner, \
const batch::MultiVector<_type>* b, batch::MultiVector<_type>* x, \
gko::batch::log::BatchLogData<remove_complex<_type>>& logdata)
gko::batch::log::detail::log_data<remove_complex<_type>>& logdata)


#define GKO_DECLARE_ALL_AS_TEMPLATES \
84 changes: 53 additions & 31 deletions core/solver/batch_dispatch.hpp
Original file line number Diff line number Diff line change
@@ -163,7 +163,7 @@ namespace solver {


template <typename DValueType>
class DummyKernelCaller {
class KernelCallerInterface {
public:
template <typename BatchMatrixType, typename PrecType, typename StopType,
typename LogType>
@@ -174,30 +174,42 @@ class DummyKernelCaller {
};


namespace log {
namespace detail {
/**
* Types of batch loggers available.
*
* The solver dispatcher compares its stored logger type against these values
* to decide which device-side logger to construct.
*
* - simple_convergence_completion: the dispatcher builds a SimpleFinalLogger
*   from the residual norms and iteration counts held in the log data.
*/
enum class log_type { simple_convergence_completion };


} // namespace detail
} // namespace log


/**
* Handles dispatching to the correct instantiation of a batched solver
* depending on runtime parameters.
*
* @tparam KernelCaller Class with an interface like DummyKernelCaller,
* @tparam ValueType The user-facing value type.
* @tparam KernelCaller Class with an interface like KernelCallerInterface,
* that is responsible for finally calling the templated backend-specific
* kernel.
* @tparam SettingsType Structure type of options for the particular solver to
* be used.
* @tparam ValueType The user-facing value type.
*/
template <typename KernelCaller, typename SettingsType, typename ValueType>
template <typename ValueType, typename KernelCaller, typename SettingsType>
class BatchSolverDispatch {
public:
using value_type = ValueType;
using device_value_type = DeviceValueType<ValueType>;
using real_type = remove_complex<value_type>;

BatchSolverDispatch(const KernelCaller& kernel_caller,
const SettingsType& settings,
const BatchLinOp* const matrix,
const BatchLinOp* const preconditioner,
const log::BatchLogType logger_type =
log::BatchLogType::simple_convergence_completion)
BatchSolverDispatch(
const KernelCaller& kernel_caller, const SettingsType& settings,
const BatchLinOp* const matrix, const BatchLinOp* const preconditioner,
const log::detail::log_type logger_type =
log::detail::log_type::simple_convergence_completion)
: caller_{kernel_caller},
settings_{settings},
mat_{matrix},
@@ -212,12 +224,12 @@ class BatchSolverDispatch {
const multi_vector::uniform_batch<const device_value_type>& b_item,
const multi_vector::uniform_batch<device_value_type>& x_item)
{
if (settings_.tol_type == stop::ToleranceType::absolute) {
if (settings_.tol_type == stop::tolerance_type::absolute) {
caller_.template call_kernel<
BatchMatrixType, PrecType,
device::batch_stop::SimpleAbsResidual<device_value_type>,
LogType>(logger, mat_item, precond, b_item, x_item);
} else if (settings_.tol_type == stop::ToleranceType::relative) {
} else if (settings_.tol_type == stop::tolerance_type::relative) {
caller_.template call_kernel<
BatchMatrixType, PrecType,
device::batch_stop::SimpleRelResidual<device_value_type>,
@@ -250,9 +262,10 @@ class BatchSolverDispatch {
const BatchMatrixType& amat,
const multi_vector::uniform_batch<const device_value_type>& b_item,
const multi_vector::uniform_batch<device_value_type>& x_item,
log::BatchLogData<real_type>& log_data)
batch::log::detail::log_data<real_type>& log_data)
{
if (logger_type_ == log::BatchLogType::simple_convergence_completion) {
if (logger_type_ ==
log::detail::log_type::simple_convergence_completion) {
device::batch_log::SimpleFinalLogger<real_type> logger(
log_data.res_norms.get_data(), log_data.iter_counts.get_data());
dispatch_on_preconditioner(logger, amat, b_item, x_item);
@@ -261,19 +274,11 @@ class BatchSolverDispatch {
}
}

/**
* Solves a linear system from the given data and kernel caller.
*
* @note The correct backend-specific get_batch_struct function needs to be
* available in the current scope.
*/
void apply(const MultiVector<ValueType>* const b,
MultiVector<ValueType>* const x,
log::BatchLogData<real_type>& log_data)
void dispatch_on_matrix(
const multi_vector::uniform_batch<const device_value_type>& b_item,
const multi_vector::uniform_batch<device_value_type>& x_item,
batch::log::detail::log_data<real_type>& log_data)
{
const auto x_item = device::get_batch_struct(x);
const auto b_item = device::get_batch_struct(b);

if (auto batch_mat =
dynamic_cast<const batch::matrix::Ell<ValueType, int32>*>(
mat_)) {
@@ -289,26 +294,42 @@ class BatchSolverDispatch {
}
}

/**
* Solves a linear system from the given data and kernel caller.
*
* Converts the solution and right-hand side multi-vectors into their batch
* structs and hands them, together with the log data, to the matrix-type
* dispatch.
*
* @param b  the right-hand side multi-vector
* @param x  the solution multi-vector
* @param log_data  the log data forwarded to the dispatch chain
*
* @note The correct backend-specific get_batch_struct function needs to be
* available in the current scope.
*/
void apply(const MultiVector<ValueType>* const b,
           MultiVector<ValueType>* const x,
           batch::log::detail::log_data<real_type>& log_data)
{
    // Same conversion order as before: solution first, then right-hand side.
    const auto solution_batch = device::get_batch_struct(x);
    const auto rhs_batch = device::get_batch_struct(b);

    dispatch_on_matrix(rhs_batch, solution_batch, log_data);
}

private:
const KernelCaller caller_;
const SettingsType settings_;
const BatchLinOp* mat_;
const BatchLinOp* precond_;
const log::BatchLogType logger_type_;
const log::detail::log_type logger_type_;
};


/**
* Convenient function to create a dispatcher. Infers most template arguments.
*/
template <typename ValueType, typename KernelCaller, typename SettingsType>
BatchSolverDispatch<KernelCaller, SettingsType, ValueType> create_dispatcher(
BatchSolverDispatch<ValueType, KernelCaller, SettingsType> create_dispatcher(
const KernelCaller& kernel_caller, const SettingsType& settings,
const BatchLinOp* const matrix, const BatchLinOp* const preconditioner,
const log::BatchLogType logger_type =
log::BatchLogType::simple_convergence_completion)
const log::detail::log_type logger_type =
log::detail::log_type::simple_convergence_completion)
{
return BatchSolverDispatch<KernelCaller, SettingsType, ValueType>(
return BatchSolverDispatch<ValueType, KernelCaller, SettingsType>(
kernel_caller, settings, matrix, preconditioner, logger_type);
}

@@ -317,4 +338,5 @@ BatchSolverDispatch<KernelCaller, SettingsType, ValueType> create_dispatcher(
} // namespace batch
} // namespace gko


#endif // GKO_CORE_SOLVER_BATCH_DISPATCH_HPP_
60 changes: 30 additions & 30 deletions core/test/solver/batch_bicgstab.cpp
Original file line number Diff line number Diff line change
@@ -60,8 +60,8 @@ class BatchBicgstab : public ::testing::Test {

BatchBicgstab()
: exec(gko::ReferenceExecutor::create()),
mtx(gko::test::generate_3pt_stencil_batch_matrix<Mtx>(
this->exec->get_master(), nrows, nbatch)),
mtx(gko::share(gko::test::generate_3pt_stencil_batch_matrix<Mtx>(
this->exec->get_master(), num_batch_items, num_rows))),
solver_factory(Solver::build()
.with_default_max_iterations(def_max_iters)
.with_default_tolerance(def_abs_res_tol)
@@ -71,14 +71,14 @@ class BatchBicgstab : public ::testing::Test {
{}

std::shared_ptr<const gko::Executor> exec;
const gko::size_type nbatch = 3;
const int nrows = 5;
std::shared_ptr<Mtx> mtx;
const gko::size_type num_batch_items = 3;
const int num_rows = 5;
std::shared_ptr<const Mtx> mtx;
std::unique_ptr<typename Solver::Factory> solver_factory;
const int def_max_iters = 100;
const real_type def_abs_res_tol = 1e-11;
const gko::batch::stop::ToleranceType def_tol_type =
gko::batch::stop::ToleranceType::absolute;
const gko::batch::stop::tolerance_type def_tol_type =
gko::batch::stop::tolerance_type::absolute;
std::unique_ptr<gko::batch::BatchLinOp> solver;
};

@@ -94,12 +94,10 @@ TYPED_TEST(BatchBicgstab, FactoryKnowsItsExecutor)
TYPED_TEST(BatchBicgstab, FactoryCreatesCorrectSolver)
{
using Solver = typename TestFixture::Solver;
for (size_t i = 0; i < this->nbatch; i++) {
ASSERT_EQ(this->solver->get_common_size(),
gko::dim<2>(this->nrows, this->nrows));
}
ASSERT_EQ(this->solver->get_common_size(),
gko::dim<2>(this->num_rows, this->num_rows));

auto solver = static_cast<Solver*>(this->solver.get());
auto solver = gko::as<Solver>(this->solver.get());

ASSERT_NE(solver->get_system_matrix(), nullptr);
ASSERT_EQ(solver->get_system_matrix(), this->mtx);
@@ -114,10 +112,11 @@ TYPED_TEST(BatchBicgstab, CanBeCopied)

copy->copy_from(this->solver.get());

ASSERT_EQ(copy->get_common_size(), gko::dim<2>(this->nrows, this->nrows));
ASSERT_EQ(copy->get_num_batch_items(), this->nbatch);
auto copy_mtx = static_cast<Solver*>(copy.get())->get_system_matrix();
const auto copy_batch_mtx = static_cast<const Mtx*>(copy_mtx.get());
ASSERT_EQ(copy->get_common_size(),
gko::dim<2>(this->num_rows, this->num_rows));
ASSERT_EQ(copy->get_num_batch_items(), this->num_batch_items);
auto copy_mtx = gko::as<Solver>(copy.get())->get_system_matrix();
const auto copy_batch_mtx = gko::as<const Mtx>(copy_mtx.get());
GKO_ASSERT_BATCH_MTX_NEAR(this->mtx.get(), copy_batch_mtx, 0.0);
}

@@ -130,10 +129,11 @@ TYPED_TEST(BatchBicgstab, CanBeMoved)

copy->move_from(this->solver);

ASSERT_EQ(copy->get_common_size(), gko::dim<2>(this->nrows, this->nrows));
ASSERT_EQ(copy->get_num_batch_items(), this->nbatch);
auto copy_mtx = static_cast<Solver*>(copy.get())->get_system_matrix();
const auto copy_batch_mtx = static_cast<const Mtx*>(copy_mtx.get());
ASSERT_EQ(copy->get_common_size(),
gko::dim<2>(this->num_rows, this->num_rows));
ASSERT_EQ(copy->get_num_batch_items(), this->num_batch_items);
auto copy_mtx = gko::as<Solver>(copy.get())->get_system_matrix();
const auto copy_batch_mtx = gko::as<const Mtx>(copy_mtx.get());
GKO_ASSERT_BATCH_MTX_NEAR(this->mtx.get(), copy_batch_mtx, 0.0);
}

@@ -145,10 +145,11 @@ TYPED_TEST(BatchBicgstab, CanBeCloned)

auto clone = this->solver->clone();

ASSERT_EQ(clone->get_common_size(), gko::dim<2>(this->nrows, this->nrows));
ASSERT_EQ(clone->get_num_batch_items(), this->nbatch);
auto clone_mtx = static_cast<Solver*>(clone.get())->get_system_matrix();
const auto clone_batch_mtx = static_cast<const Mtx*>(clone_mtx.get());
ASSERT_EQ(clone->get_common_size(),
gko::dim<2>(this->num_rows, this->num_rows));
ASSERT_EQ(clone->get_num_batch_items(), this->num_batch_items);
auto clone_mtx = gko::as<Solver>(clone.get())->get_system_matrix();
const auto clone_batch_mtx = gko::as<const Mtx>(clone_mtx.get());
GKO_ASSERT_BATCH_MTX_NEAR(this->mtx.get(), clone_batch_mtx, 0.0);
}

@@ -160,8 +161,7 @@ TYPED_TEST(BatchBicgstab, CanBeCleared)
this->solver->clear();

ASSERT_EQ(this->solver->get_num_batch_items(), 0);
auto solver_mtx =
static_cast<Solver*>(this->solver.get())->get_system_matrix();
auto solver_mtx = gko::as<Solver>(this->solver.get())->get_system_matrix();
ASSERT_EQ(solver_mtx, nullptr);
}

@@ -175,14 +175,14 @@ TYPED_TEST(BatchBicgstab, CanSetCriteriaInFactory)
Solver::build()
.with_default_max_iterations(22)
.with_default_tolerance(static_cast<real_type>(0.25))
.with_tolerance_type(gko::batch::stop::ToleranceType::relative)
.with_tolerance_type(gko::batch::stop::tolerance_type::relative)
.on(this->exec);

auto solver = solver_factory->generate(this->mtx);
ASSERT_EQ(solver->get_parameters().default_max_iterations, 22);
ASSERT_EQ(solver->get_parameters().default_tolerance, 0.25);
ASSERT_EQ(solver->get_parameters().tolerance_type,
gko::batch::stop::ToleranceType::relative);
gko::batch::stop::tolerance_type::relative);
}


@@ -194,7 +194,7 @@ TYPED_TEST(BatchBicgstab, CanSetResidualTol)
Solver::build()
.with_default_max_iterations(22)
.with_default_tolerance(static_cast<real_type>(0.25))
.with_tolerance_type(gko::batch::stop::ToleranceType::relative)
.with_tolerance_type(gko::batch::stop::tolerance_type::relative)
.on(this->exec);
auto solver = solver_factory->generate(this->mtx);

@@ -214,7 +214,7 @@ TYPED_TEST(BatchBicgstab, CanSetMaxIterations)
Solver::build()
.with_default_max_iterations(22)
.with_default_tolerance(static_cast<real_type>(0.25))
.with_tolerance_type(gko::batch::stop::ToleranceType::relative)
.with_tolerance_type(gko::batch::stop::tolerance_type::relative)
.on(this->exec);
auto solver = solver_factory->generate(this->mtx);

Loading

0 comments on commit e17e58d

Please sign in to comment.