Skip to content

Commit

Permalink
Review updates WIP
Browse files Browse the repository at this point in the history
Co-authored-by: Yu-Hsiang Tsai <[email protected]>
Co-authored-by: Marcel Koch <[email protected]>
  • Loading branch information
3 people committed Oct 25, 2023
1 parent d7e4535 commit 82712a3
Show file tree
Hide file tree
Showing 23 changed files with 225 additions and 169 deletions.
2 changes: 1 addition & 1 deletion core/base/batch_struct.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ struct uniform_batch {
int32 num_rows;
int32 num_rhs;

size_type get_entry_storage() const
size_type get_storage_size() const
{
return num_rows * stride * sizeof(value_type);
}
Expand Down
2 changes: 2 additions & 0 deletions core/log/logger.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ constexpr Logger::mask_type Logger::linop_factory_generate_completed_mask;
constexpr Logger::mask_type Logger::criterion_check_started_mask;
constexpr Logger::mask_type Logger::criterion_check_completed_mask;

constexpr Logger::mask_type Logger::batch_solver_completed_mask;

constexpr Logger::mask_type Logger::iteration_complete_mask;


Expand Down
6 changes: 2 additions & 4 deletions core/matrix/batch_struct.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ struct batch_item {
int32 stride;
int32 num_rows;
int32 num_cols;
int32 num_nnz = num_rows * stride;
};


Expand All @@ -73,14 +72,13 @@ struct uniform_batch {
int32 stride;
int32 num_rows;
int32 num_cols;
int32 num_nnz = num_rows * stride;

/**
 * Returns the stored-entry count, i.e. stride * num_rows, as a size_type.
 *
 * Both operands are widened to size_type before the multiplication. Casting
 * only the finished product would still evaluate stride * num_rows in 32-bit
 * signed arithmetic, where a large product overflows (undefined behaviour)
 * before the cast can take effect.
 */
inline size_type get_num_nnz() const
{
    return static_cast<size_type>(stride) * static_cast<size_type>(num_rows);
}

inline size_type get_entry_storage() const
inline size_type get_storage_size() const
{
return get_num_nnz() * sizeof(value_type);
}
Expand Down Expand Up @@ -132,7 +130,7 @@ struct uniform_batch {
return static_cast<size_type>(stride * num_stored_elems_per_row);
}

inline size_type get_entry_storage() const
inline size_type get_storage_size() const
{
return get_num_nnz() * sizeof(value_type);
}
Expand Down
2 changes: 1 addition & 1 deletion core/solver/batch_bicgstab.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ GKO_REGISTER_OPERATION(apply, batch_bicgstab::apply);
template <typename ValueType>
void Bicgstab<ValueType>::solver_apply(
const MultiVector<ValueType>* b, MultiVector<ValueType>* x,
log::BatchLogData<remove_complex<ValueType>>* log_data) const
log::detail::BatchLogData<remove_complex<ValueType>>* log_data) const
{
using MVec = MultiVector<ValueType>;
const kernels::batch_bicgstab::BicgstabSettings<remove_complex<ValueType>>
Expand Down
2 changes: 1 addition & 1 deletion core/solver/batch_bicgstab_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ inline int local_memory_requirement(const int num_rows, const int num_rhs)
remove_complex<_type>>& options, \
const batch::BatchLinOp* a, const batch::BatchLinOp* preconditioner, \
const batch::MultiVector<_type>* b, batch::MultiVector<_type>* x, \
gko::batch::log::BatchLogData<remove_complex<_type>>& logdata)
gko::batch::log::detail::BatchLogData<remove_complex<_type>>& logdata)


#define GKO_DECLARE_ALL_AS_TEMPLATES \
Expand Down
80 changes: 51 additions & 29 deletions core/solver/batch_dispatch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ namespace solver {


template <typename DValueType>
class DummyKernelCaller {
class KernelCallerInterface {
public:
template <typename BatchMatrixType, typename PrecType, typename StopType,
typename LogType>
Expand All @@ -174,30 +174,42 @@ class DummyKernelCaller {
};


namespace log {
namespace detail {
/**
 * Types of batch loggers available.
 *
 * The value acts as a runtime selector: the batch solver dispatcher compares
 * its stored logger type against simple_convergence_completion to decide
 * which device-side logger to construct. It is the only enumerator so far.
 */
enum class BatchLogType { simple_convergence_completion };


} // namespace detail
} // namespace log


/**
* Handles dispatching to the correct instantiation of a batched solver
* depending on runtime parameters.
*
* @tparam KernelCaller Class with an interface like DummyKernelCaller,
* @tparam ValueType The user-facing value type.
* @tparam KernelCaller Class with an interface like KernelCallerInterface,
* that is responsible for finally calling the templated backend-specific
* kernel.
* @tparam SettingsType Structure type of options for the particular solver to
* be used.
* @tparam ValueType The user-facing value type.
*/
template <typename KernelCaller, typename SettingsType, typename ValueType>
template <typename ValueType, typename KernelCaller, typename SettingsType>
class BatchSolverDispatch {
public:
using value_type = ValueType;
using device_value_type = DeviceValueType<ValueType>;
using real_type = remove_complex<value_type>;

BatchSolverDispatch(const KernelCaller& kernel_caller,
const SettingsType& settings,
const BatchLinOp* const matrix,
const BatchLinOp* const preconditioner,
const log::BatchLogType logger_type =
log::BatchLogType::simple_convergence_completion)
BatchSolverDispatch(
const KernelCaller& kernel_caller, const SettingsType& settings,
const BatchLinOp* const matrix, const BatchLinOp* const preconditioner,
const log::detail::BatchLogType logger_type =
log::detail::BatchLogType::simple_convergence_completion)
: caller_{kernel_caller},
settings_{settings},
mat_{matrix},
Expand Down Expand Up @@ -250,9 +262,10 @@ class BatchSolverDispatch {
const BatchMatrixType& amat,
const multi_vector::uniform_batch<const device_value_type>& b_item,
const multi_vector::uniform_batch<device_value_type>& x_item,
log::BatchLogData<real_type>& log_data)
batch::log::detail::BatchLogData<real_type>& log_data)
{
if (logger_type_ == log::BatchLogType::simple_convergence_completion) {
if (logger_type_ ==
log::detail::BatchLogType::simple_convergence_completion) {
device::batch_log::SimpleFinalLogger<real_type> logger(
log_data.res_norms.get_data(), log_data.iter_counts.get_data());
dispatch_on_preconditioner(logger, amat, b_item, x_item);
Expand All @@ -261,19 +274,11 @@ class BatchSolverDispatch {
}
}

/**
* Solves a linear system from the given data and kernel caller.
*
* @note The correct backend-specific get_batch_struct function needs to be
* available in the current scope.
*/
void apply(const MultiVector<ValueType>* const b,
MultiVector<ValueType>* const x,
log::BatchLogData<real_type>& log_data)
void dispatch_on_matrix(
const multi_vector::uniform_batch<const device_value_type>& b_item,
const multi_vector::uniform_batch<device_value_type>& x_item,
batch::log::detail::BatchLogData<real_type>& log_data)
{
const auto x_item = device::get_batch_struct(x);
const auto b_item = device::get_batch_struct(b);

if (auto batch_mat =
dynamic_cast<const batch::matrix::Ell<ValueType, int32>*>(
mat_)) {
Expand All @@ -289,26 +294,42 @@ class BatchSolverDispatch {
}
}

/**
 * Entry point of the dispatcher: solves the linear system with right-hand
 * side b, writing the solution into x.
 *
 * Both multi-vectors are converted to their batch-struct views first; the
 * views and the log data are then handed to the matrix-type dispatch.
 *
 * @note The correct backend-specific get_batch_struct function needs to be
 * available in the current scope.
 */
void apply(const MultiVector<ValueType>* const b,
           MultiVector<ValueType>* const x,
           batch::log::detail::BatchLogData<real_type>& log_data)
{
    const auto x_view = device::get_batch_struct(x);
    const auto b_view = device::get_batch_struct(b);
    dispatch_on_matrix(b_view, x_view, log_data);
}

private:
const KernelCaller caller_;
const SettingsType settings_;
const BatchLinOp* mat_;
const BatchLinOp* precond_;
const log::BatchLogType logger_type_;
const log::detail::BatchLogType logger_type_;
};


/**
* Convenience function to create a dispatcher. KernelCaller and SettingsType
* are deduced from the arguments; only ValueType has to be given explicitly.
*/
template <typename ValueType, typename KernelCaller, typename SettingsType>
BatchSolverDispatch<KernelCaller, SettingsType, ValueType> create_dispatcher(
BatchSolverDispatch<ValueType, KernelCaller, SettingsType> create_dispatcher(
const KernelCaller& kernel_caller, const SettingsType& settings,
const BatchLinOp* const matrix, const BatchLinOp* const preconditioner,
const log::BatchLogType logger_type =
log::BatchLogType::simple_convergence_completion)
const log::detail::BatchLogType logger_type =
log::detail::BatchLogType::simple_convergence_completion)
{
return BatchSolverDispatch<KernelCaller, SettingsType, ValueType>(
return BatchSolverDispatch<ValueType, KernelCaller, SettingsType>(
kernel_caller, settings, matrix, preconditioner, logger_type);
}

Expand All @@ -317,4 +338,5 @@ BatchSolverDispatch<KernelCaller, SettingsType, ValueType> create_dispatcher(
} // namespace batch
} // namespace gko


#endif // GKO_CORE_SOLVER_BATCH_DISPATCH_HPP_
44 changes: 22 additions & 22 deletions core/test/solver/batch_bicgstab.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class BatchBicgstab : public ::testing::Test {
BatchBicgstab()
: exec(gko::ReferenceExecutor::create()),
mtx(gko::test::generate_3pt_stencil_batch_matrix<Mtx>(
this->exec->get_master(), nrows, nbatch)),
this->exec->get_master(), num_rows, num_batch_items)),
solver_factory(Solver::build()
.with_default_max_iterations(def_max_iters)
.with_default_tolerance(def_abs_res_tol)
Expand All @@ -71,8 +71,8 @@ class BatchBicgstab : public ::testing::Test {
{}

std::shared_ptr<const gko::Executor> exec;
const gko::size_type nbatch = 3;
const int nrows = 5;
const gko::size_type num_batch_items = 3;
const int num_rows = 5;
std::shared_ptr<Mtx> mtx;
std::unique_ptr<typename Solver::Factory> solver_factory;
const int def_max_iters = 100;
Expand All @@ -94,12 +94,10 @@ TYPED_TEST(BatchBicgstab, FactoryKnowsItsExecutor)
TYPED_TEST(BatchBicgstab, FactoryCreatesCorrectSolver)
{
using Solver = typename TestFixture::Solver;
for (size_t i = 0; i < this->nbatch; i++) {
ASSERT_EQ(this->solver->get_common_size(),
gko::dim<2>(this->nrows, this->nrows));
}
ASSERT_EQ(this->solver->get_common_size(),
gko::dim<2>(this->num_rows, this->num_rows));

auto solver = static_cast<Solver*>(this->solver.get());
auto solver = gko::as<Solver>(this->solver.get());

ASSERT_NE(solver->get_system_matrix(), nullptr);
ASSERT_EQ(solver->get_system_matrix(), this->mtx);
Expand All @@ -114,10 +112,11 @@ TYPED_TEST(BatchBicgstab, CanBeCopied)

copy->copy_from(this->solver.get());

ASSERT_EQ(copy->get_common_size(), gko::dim<2>(this->nrows, this->nrows));
ASSERT_EQ(copy->get_num_batch_items(), this->nbatch);
auto copy_mtx = static_cast<Solver*>(copy.get())->get_system_matrix();
const auto copy_batch_mtx = static_cast<const Mtx*>(copy_mtx.get());
ASSERT_EQ(copy->get_common_size(),
gko::dim<2>(this->num_rows, this->num_rows));
ASSERT_EQ(copy->get_num_batch_items(), this->num_batch_items);
auto copy_mtx = gko::as<Solver>(copy.get())->get_system_matrix();
const auto copy_batch_mtx = gko::as<const Mtx>(copy_mtx.get());
GKO_ASSERT_BATCH_MTX_NEAR(this->mtx.get(), copy_batch_mtx, 0.0);
}

Expand All @@ -130,10 +129,11 @@ TYPED_TEST(BatchBicgstab, CanBeMoved)

copy->move_from(this->solver);

ASSERT_EQ(copy->get_common_size(), gko::dim<2>(this->nrows, this->nrows));
ASSERT_EQ(copy->get_num_batch_items(), this->nbatch);
auto copy_mtx = static_cast<Solver*>(copy.get())->get_system_matrix();
const auto copy_batch_mtx = static_cast<const Mtx*>(copy_mtx.get());
ASSERT_EQ(copy->get_common_size(),
gko::dim<2>(this->num_rows, this->num_rows));
ASSERT_EQ(copy->get_num_batch_items(), this->num_batch_items);
auto copy_mtx = gko::as<Solver>(copy.get())->get_system_matrix();
const auto copy_batch_mtx = gko::as<const Mtx>(copy_mtx.get());
GKO_ASSERT_BATCH_MTX_NEAR(this->mtx.get(), copy_batch_mtx, 0.0);
}

Expand All @@ -145,10 +145,11 @@ TYPED_TEST(BatchBicgstab, CanBeCloned)

auto clone = this->solver->clone();

ASSERT_EQ(clone->get_common_size(), gko::dim<2>(this->nrows, this->nrows));
ASSERT_EQ(clone->get_num_batch_items(), this->nbatch);
auto clone_mtx = static_cast<Solver*>(clone.get())->get_system_matrix();
const auto clone_batch_mtx = static_cast<const Mtx*>(clone_mtx.get());
ASSERT_EQ(clone->get_common_size(),
gko::dim<2>(this->num_rows, this->num_rows));
ASSERT_EQ(clone->get_num_batch_items(), this->num_batch_items);
auto clone_mtx = gko::as<Solver>(clone.get())->get_system_matrix();
const auto clone_batch_mtx = gko::as<const Mtx>(clone_mtx.get());
GKO_ASSERT_BATCH_MTX_NEAR(this->mtx.get(), clone_batch_mtx, 0.0);
}

Expand All @@ -160,8 +161,7 @@ TYPED_TEST(BatchBicgstab, CanBeCleared)
this->solver->clear();

ASSERT_EQ(this->solver->get_num_batch_items(), 0);
auto solver_mtx =
static_cast<Solver*>(this->solver.get())->get_system_matrix();
auto solver_mtx = gko::as<Solver>(this->solver.get())->get_system_matrix();
ASSERT_EQ(solver_mtx, nullptr);
}

Expand Down
34 changes: 19 additions & 15 deletions core/test/utils/batch_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,6 @@ compute_residual_norms(
using real_vec = batch::MultiVector<remove_complex<value_type>>;
auto exec = mtx->get_executor();
auto num_batch_items = x->get_num_batch_items();
auto num_rows = x->get_common_size()[0];
auto num_rhs = x->get_common_size()[1];
const gko::batch_dim<2> norm_dim(num_batch_items, gko::dim<2>(1, num_rhs));

Expand All @@ -221,7 +220,13 @@ struct Result {

std::shared_ptr<multi_vec> x;
std::shared_ptr<real_vec> res_norm;
std::unique_ptr<gko::batch::log::BatchLogData<remove_complex<ValueType>>>
};


/**
 * A Result that additionally owns the batch log data of the solve.
 *
 * @tparam ValueType  value type of the solved system; the log data itself is
 *                    templated on remove_complex<ValueType>.
 */
template <typename ValueType>
struct ResultWithLogData : public Result<ValueType> {
// Owned log data; solve_linear_system assigns its iter_counts and res_norms.
std::unique_ptr<
gko::batch::log::detail::BatchLogData<remove_complex<ValueType>>>
log_data;
};

Expand Down Expand Up @@ -255,9 +260,9 @@ Result<typename MatrixType::value_type> solve_linear_system(
}


template <typename MatrixType, typename SolveFunction, typename Settings>
Result<typename MatrixType::value_type> solve_linear_system(
std::shared_ptr<const Executor> exec, SolveFunction solve_function,
template <typename MatrixType, typename SolveLambda, typename Settings>
ResultWithLogData<typename MatrixType::value_type> solve_linear_system(
std::shared_ptr<const Executor> exec, SolveLambda solve_lambda,
const Settings settings, const LinearSystem<MatrixType>& sys,
std::shared_ptr<batch::BatchLinOpFactory> precond_factory = nullptr)
{
Expand All @@ -269,17 +274,15 @@ Result<typename MatrixType::value_type> solve_linear_system(
const size_type num_batch_items = sys.matrix->get_num_batch_items();
const int num_rows = sys.matrix->get_common_size()[0];
const int num_rhs = sys.rhs->get_common_size()[1];
const gko::batch_dim<2> vec_size(num_batch_items,
gko::dim<2>(num_rows, num_rhs));
const gko::batch_dim<2> norm_size(num_batch_items, gko::dim<2>(1, num_rhs));

Result<value_type> result;
// Initialize r to the original unscaled b
ResultWithLogData<value_type> result;
result.x = multi_vec::create_with_config_of(sys.rhs);
result.x->fill(zero<value_type>());

auto log_data = std::make_unique<batch::log::BatchLogData<real_type>>(
exec, num_batch_items);
auto log_data =
std::make_unique<batch::log::detail::BatchLogData<real_type>>(
exec, num_batch_items);

std::unique_ptr<gko::batch::BatchLinOp> precond;
if (precond_factory) {
Expand All @@ -288,11 +291,12 @@ Result<typename MatrixType::value_type> solve_linear_system(
precond = nullptr;
}

solve_function(settings, precond.get(), sys.matrix.get(), sys.rhs.get(),
result.x.get(), *log_data.get());
solve_lambda(settings, precond.get(), sys.matrix.get(), sys.rhs.get(),
result.x.get(), *log_data.get());

result.log_data = std::make_unique<batch::log::BatchLogData<real_type>>(
exec->get_master());
result.log_data =
std::make_unique<batch::log::detail::BatchLogData<real_type>>(
exec->get_master());
result.log_data->iter_counts = log_data->iter_counts;
result.log_data->res_norms = log_data->res_norms;

Expand Down
2 changes: 1 addition & 1 deletion cuda/solver/batch_bicgstab_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ void apply(std::shared_ptr<const DefaultExecutor> exec,
const batch::BatchLinOp* const precon,
const batch::MultiVector<ValueType>* const b,
batch::MultiVector<ValueType>* const x,
batch::log::BatchLogData<remove_complex<ValueType>>& logdata)
batch::log::detail::BatchLogData<remove_complex<ValueType>>& logdata)
GKO_NOT_IMPLEMENTED;

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL);
Expand Down
Loading

0 comments on commit 82712a3

Please sign in to comment.