Skip to content

Commit

Permalink
use new factory setup, move crit to base
Browse files Browse the repository at this point in the history
  • Loading branch information
pratikvn committed Oct 25, 2023
1 parent c22a6c9 commit 2611c7a
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 51 deletions.
14 changes: 7 additions & 7 deletions core/test/solver/batch_bicgstab.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class BatchBicgstab : public ::testing::Test {
this->exec->get_master(), nrows, nbatch)),
solver_factory(Solver::build()
.with_default_max_iterations(def_max_iters)
.with_default_residual_tol(def_abs_res_tol)
.with_default_tolerance(def_abs_res_tol)
.with_tolerance_type(def_tol_type)
.on(exec)),
solver(solver_factory->generate(mtx))
Expand Down Expand Up @@ -174,13 +174,13 @@ TYPED_TEST(BatchBicgstab, CanSetCriteriaInFactory)
auto solver_factory =
Solver::build()
.with_default_max_iterations(22)
.with_default_residual_tol(static_cast<real_type>(0.25))
.with_default_tolerance(static_cast<real_type>(0.25))
.with_tolerance_type(gko::batch::stop::ToleranceType::relative)
.on(this->exec);

auto solver = solver_factory->generate(this->mtx);
ASSERT_EQ(solver->get_parameters().default_max_iterations, 22);
ASSERT_EQ(solver->get_parameters().default_residual_tol, 0.25);
ASSERT_EQ(solver->get_parameters().default_tolerance, 0.25);
ASSERT_EQ(solver->get_parameters().tolerance_type,
gko::batch::stop::ToleranceType::relative);
}
Expand All @@ -193,15 +193,15 @@ TYPED_TEST(BatchBicgstab, CanSetResidualTol)
auto solver_factory =
Solver::build()
.with_default_max_iterations(22)
.with_default_residual_tol(static_cast<real_type>(0.25))
.with_default_tolerance(static_cast<real_type>(0.25))
.with_tolerance_type(gko::batch::stop::ToleranceType::relative)
.on(this->exec);
auto solver = solver_factory->generate(this->mtx);

solver->set_residual_tolerance(0.5);

ASSERT_EQ(solver->get_parameters().default_max_iterations, 22);
ASSERT_EQ(solver->get_parameters().default_residual_tol, 0.25);
ASSERT_EQ(solver->get_parameters().default_tolerance, 0.25);
ASSERT_EQ(solver->get_residual_tolerance(), 0.5);
}

Expand All @@ -213,14 +213,14 @@ TYPED_TEST(BatchBicgstab, CanSetMaxIterations)
auto solver_factory =
Solver::build()
.with_default_max_iterations(22)
.with_default_residual_tol(static_cast<real_type>(0.25))
.with_default_tolerance(static_cast<real_type>(0.25))
.with_tolerance_type(gko::batch::stop::ToleranceType::relative)
.on(this->exec);
auto solver = solver_factory->generate(this->mtx);

solver->set_max_iterations(10);

ASSERT_EQ(solver->get_parameters().default_residual_tol, 0.25);
ASSERT_EQ(solver->get_parameters().default_tolerance, 0.25);
ASSERT_EQ(solver->get_parameters().default_max_iterations, 22);
ASSERT_EQ(solver->get_max_iterations(), 10);
}
Expand Down
42 changes: 5 additions & 37 deletions include/ginkgo/core/solver/batch_bicgstab.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,43 +79,11 @@ class Bicgstab final
using value_type = ValueType;
using real_type = gko::remove_complex<ValueType>;

GKO_CREATE_FACTORY_PARAMETERS(parameters, Factory)
{
/**
* Preconditioner factory.
*/
std::shared_ptr<const BatchLinOpFactory> GKO_FACTORY_PARAMETER_SCALAR(
preconditioner, nullptr);

/**
* Already generated preconditioner. If one is provided, the factory
* `preconditioner` will be ignored.
*/
std::shared_ptr<const BatchLinOp> GKO_FACTORY_PARAMETER_SCALAR(
generated_preconditioner, nullptr);

/**
* Default maximum number iterations allowed.
*
* Generated solvers are initialized with this value for their maximum
* iterations.
*/
int GKO_FACTORY_PARAMETER_SCALAR(default_max_iterations, 100);

/**
* Default residual tolerance.
*
* Generated solvers are initialized with this value for their residual
* tolerance.
*/
real_type GKO_FACTORY_PARAMETER_SCALAR(default_residual_tol, 1e-11);

/**
* To specify which tolerance is to be considered.
*/
::gko::batch::stop::ToleranceType GKO_FACTORY_PARAMETER_SCALAR(
tolerance_type, ::gko::batch::stop::ToleranceType::absolute);
};
class Factory;

struct parameters_type
: enable_preconditioned_iterative_solver_factory_parameters<
parameters_type, Factory> {};
GKO_ENABLE_BATCH_LIN_OP_FACTORY(Bicgstab, parameters, Factory);
GKO_ENABLE_BUILD_METHOD(Factory);

Expand Down
94 changes: 93 additions & 1 deletion include/ginkgo/core/solver/batch_solver_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,13 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#define GKO_PUBLIC_CORE_SOLVER_BATCH_SOLVER_HPP_


#include <ginkgo/core/base/abstract_factory.hpp>
#include <ginkgo/core/base/batch_lin_op.hpp>
#include <ginkgo/core/base/batch_multi_vector.hpp>
#include <ginkgo/core/base/utils_helper.hpp>
#include <ginkgo/core/log/batch_logger.hpp>
#include <ginkgo/core/matrix/batch_identity.hpp>
#include <ginkgo/core/stop/batch_stop_enum.hpp>


namespace gko {
Expand Down Expand Up @@ -143,14 +145,104 @@ template <typename ParamsType>
common_batch_params extract_common_batch_params(ParamsType& params)
{
return {params.preconditioner, params.generated_preconditioner,
params.default_residual_tol, params.default_max_iterations};
params.default_tolerance, params.default_max_iterations};
}


} // namespace detail


/**
* The parameter type shared between all preconditioned iterative solvers,
* excluding the parameters available in iterative_solver_factory_parameters.
* @see GKO_CREATE_FACTORY_PARAMETERS
*/
struct preconditioned_iterative_solver_factory_parameters {
/**
* The preconditioner to be used by the iterative solver. By default, no
* preconditioner is used.
*/
std::shared_ptr<const BatchLinOpFactory> preconditioner{nullptr};

/**
* Already generated preconditioner. If one is provided, the factory
* `preconditioner` will be ignored.
*/
std::shared_ptr<const BatchLinOp> generated_preconditioner{nullptr};
};


template <typename Parameters, typename Factory>
struct enable_preconditioned_iterative_solver_factory_parameters
: enable_parameters_type<Parameters, Factory>,
preconditioned_iterative_solver_factory_parameters {
/**
* Default maximum number iterations allowed.
*
* Generated solvers are initialized with this value for their maximum
* iterations.
*/
int GKO_FACTORY_PARAMETER_SCALAR(default_max_iterations, 100);

/**
* Default residual tolerance.
*
* Generated solvers are initialized with this value for their residual
* tolerance.
*/
double GKO_FACTORY_PARAMETER_SCALAR(default_tolerance, 1e-11);

/**
* To specify which type of tolerance check is to be considered, absolute or
* relative (to the rhs l2 norm)
*/
::gko::batch::stop::ToleranceType GKO_FACTORY_PARAMETER_SCALAR(
tolerance_type, ::gko::batch::stop::ToleranceType::absolute);

/**
* Provides a preconditioner factory to be used by the iterative solver in a
* fluent interface.
* @see preconditioned_iterative_solver_factory_parameters::preconditioner
*/
Parameters& with_preconditioner(
deferred_factory_parameter<BatchLinOpFactory> preconditioner)
{
this->preconditioner_generator = std::move(preconditioner);
this->deferred_factories["preconditioner"] = [](const auto& exec,
auto& params) {
if (!params.preconditioner_generator.is_empty()) {
params.preconditioner =
params.preconditioner_generator.on(exec);
}
};
return *self();
}

/**
* Provides a concrete preconditioner to be used by the iterative solver in
* a fluent interface.
* @see preconditioned_iterative_solver_factory_parameters::preconditioner
*/
Parameters& with_generated_preconditioner(
std::shared_ptr<const BatchLinOp> generated_preconditioner)
{
this->generated_preconditioner = std::move(generated_preconditioner);
return *self();
}

private:
GKO_ENABLE_SELF(Parameters);

deferred_factory_parameter<BatchLinOpFactory> preconditioner_generator;
};


/**
* This mixin provides apply and common iterative solver functionality to all
* the batched solvers.
*
* @tparam ConcreteSolver The concrete solver class.
* @tparam ValueType The value type of the multivectors.
* @tparam PolymorphicBase The base class; must be a subclass of BatchLinOp.
*/
template <typename ConcreteSolver,
Expand Down
8 changes: 4 additions & 4 deletions reference/test/solver/batch_bicgstab_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ TYPED_TEST(BatchBicgstab, CanSolveDenseSystem)
auto solver_factory =
Solver::build()
.with_default_max_iterations(max_iters)
.with_default_residual_tol(tol)
.with_default_tolerance(tol)
.with_tolerance_type(gko::batch::stop::ToleranceType::relative)
.on(this->exec);
const int num_rows = 13;
Expand Down Expand Up @@ -233,7 +233,7 @@ TYPED_TEST(BatchBicgstab, ApplyLogsResAndIters)
auto solver_factory =
Solver::build()
.with_default_max_iterations(max_iters)
.with_default_residual_tol(tol)
.with_default_tolerance(tol)
.with_tolerance_type(gko::batch::stop::ToleranceType::relative)
.on(this->exec);
const int num_rows = 13;
Expand Down Expand Up @@ -273,7 +273,7 @@ TYPED_TEST(BatchBicgstab, CanSolveEllSystem)
auto solver_factory =
Solver::build()
.with_default_max_iterations(max_iters)
.with_default_residual_tol(tol)
.with_default_tolerance(tol)
.with_tolerance_type(gko::batch::stop::ToleranceType::relative)
.on(this->exec);
const int num_rows = 13;
Expand Down Expand Up @@ -306,7 +306,7 @@ TYPED_TEST(BatchBicgstab, CanSolveDenseHpdSystem)
auto solver_factory =
Solver::build()
.with_default_max_iterations(max_iters)
.with_default_residual_tol(tol)
.with_default_tolerance(tol)
.with_tolerance_type(gko::batch::stop::ToleranceType::absolute)
.on(this->exec);
const int num_rows = 65;
Expand Down
4 changes: 2 additions & 2 deletions test/solver/batch_bicgstab_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class BatchBicgstab : public CommonTestFixture {
solver_factory =
solver_type::build()
.with_default_max_iterations(max_iters)
.with_default_residual_tol(tol)
.with_default_tolerance(tol)
.with_tolerance_type(gko::batch::stop::ToleranceType::relative)
.on(exec);
return gko::test::generate_3pt_stencil_batch_problem<MatrixType>(
Expand Down Expand Up @@ -204,7 +204,7 @@ TEST_F(BatchBicgstab, CanSolveLargeHpdSystem)
auto solver_factory =
solver_type::build()
.with_default_max_iterations(max_iters)
.with_default_residual_tol(tol)
.with_default_tolerance(tol)
.with_tolerance_type(gko::batch::stop::ToleranceType::absolute)
.on(exec);
std::shared_ptr<Logger> logger = Logger::create(exec);
Expand Down

0 comments on commit 2611c7a

Please sign in to comment.