Skip to content

Commit

Permalink
[core] deduce constness in run
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcelKoch committed May 7, 2024
1 parent 6f0eacd commit 8d03dbf
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 18 deletions.
5 changes: 2 additions & 3 deletions core/base/block_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@ namespace {
template <typename Fn>
auto dispatch_dense(Fn&& fn, LinOp* v)
{
return run<matrix::Dense<float>*, matrix::Dense<double>*,
matrix::Dense<std::complex<float>>*,
matrix::Dense<std::complex<double>>*>(v, std::forward<Fn>(fn));
return run<matrix::Dense, float, double, std::complex<float>,
std::complex<double>>(v, std::forward<Fn>(fn));
}


Expand Down
49 changes: 43 additions & 6 deletions core/base/dispatch_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ namespace gko {
namespace detail {


template <typename T, typename MaybeConstU>
using with_same_constness_t = std::conditional_t<
std::is_const<typename std::remove_reference_t<MaybeConstU>>::value,
const T, T>;


/**
*
* @copydoc run<typename ReturnType, typename K, typename... Types, typename T,
Expand All @@ -24,7 +30,7 @@ namespace detail {
* @note this is the end case
*/
template <typename ReturnType, typename T, typename Func, typename... Args>
ReturnType run_impl(T obj, Func&&, Args&&...)
ReturnType run_impl(T* obj, Func&&, Args&&...)
{
GKO_NOT_SUPPORTED(obj);
}
Expand All @@ -37,9 +43,9 @@ ReturnType run_impl(T obj, Func&&, Args&&...)
*/
template <typename ReturnType, typename K, typename... Types, typename T,
typename Func, typename... Args>
ReturnType run_impl(T obj, Func&& f, Args&&... args)
ReturnType run_impl(T* obj, Func&& f, Args&&... args)
{
if (auto dobj = dynamic_cast<K>(obj)) {
if (auto dobj = dynamic_cast<with_same_constness_t<K, T>*>(obj)) {
return f(dobj, std::forward<Args>(args)...);
} else {
return run_impl<ReturnType, Types...>(obj, std::forward<Func>(f),
Expand Down Expand Up @@ -120,18 +126,49 @@ ReturnType run_impl(T obj, Func&& f, Args&&... args)
*/
template <typename K, typename... Types, typename T, typename Func,
typename... Args>
auto run(T obj, Func&& f, Args&&... args)
auto run(T* obj, Func&& f, Args&&... args)
{
#if __cplusplus < 201703L
using ReturnType = std::result_of_t<Func(K, Args...)>;
using ReturnType =
std::result_of_t<Func(detail::with_same_constness_t<K, T>*, Args...)>;
#else
using ReturnType = std::invoke_result_t<Func, K, Args...>;
using ReturnType =
std::invoke_result_t<Func, detail::with_same_constness_t<K, T>*,
Args...>;
#endif
return detail::run_impl<ReturnType, K, Types...>(
obj, std::forward<Func>(f), std::forward<Args>(args)...);
}


/**
* run uses template to go through the list and select the valid
* template and run it.
*
* @tparam Base the Base class with one template
* @tparam K the current template type of B. pointer of const Base<K> is tried
* in the conversion.
* @tparam ...Types the other types will be tried in the conversion if K fails
* @tparam T the type of input object waiting converted
* @tparam Func the function will run if the object can be converted to pointer
* of const Base<K>
* @tparam ...Args the additional arguments for the Func
*
* @param obj the input object waiting converted
* @param f the function will run if obj can be converted successfully
* @param args the additional arguments for the function
*
* @return the result of f invoked with obj cast to the first matching type
*/
template <template <class> class K, typename... Types, typename T,
typename Func, typename... Args>
auto run(T* obj, Func&& f, Args&&... args)
{
return run<K<Types>...>(obj, std::forward<Func>(f),
std::forward<Args>(args)...);
}


/**
* run uses template to go through the list and select the valid
* template and run it.
Expand Down
4 changes: 1 addition & 3 deletions core/matrix/dense.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1522,9 +1522,7 @@ template <typename ValueType, typename Function>
void gather_mixed_real_complex(Function fn, LinOp* out)
{
#ifdef GINKGO_MIXED_PRECISION
using fst_type = matrix::Dense<ValueType>;
using snd_type = matrix::Dense<next_precision<ValueType>>;
run<fst_type*, snd_type*>(out, fn);
run<matrix::Dense, ValueType, next_precision<ValueType>>(out, fn);
#else
precision_dispatch<ValueType>(fn, out);
#endif
Expand Down
4 changes: 2 additions & 2 deletions core/matrix/permutation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,8 @@ void dispatch_dense(const LinOp* op, Functor fn)
{
using matrix::Dense;
using std::complex;
run<const Dense<double>*, const Dense<float>*,
const Dense<complex<double>>*, const Dense<complex<float>>*>(op, fn);
run<Dense, double, float, std::complex<double>, std::complex<float>>(op,
fn);
}


Expand Down
6 changes: 2 additions & 4 deletions core/matrix/row_gatherer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,17 +66,15 @@ RowGatherer<IndexType>::create_const(
template <typename IndexType>
void RowGatherer<IndexType>::apply_impl(const LinOp* in, LinOp* out) const
{
run<const Dense<float>*, const Dense<double>*,
const Dense<std::complex<float>>*, const Dense<std::complex<double>>*>(
run<Dense, float, double, std::complex<float>, std::complex<double>>(
in, [&](auto gather) { gather->row_gather(&row_idxs_, out); });
}

template <typename IndexType>
void RowGatherer<IndexType>::apply_impl(const LinOp* alpha, const LinOp* in,
const LinOp* beta, LinOp* out) const
{
run<const Dense<float>*, const Dense<double>*,
const Dense<std::complex<float>>*, const Dense<std::complex<double>>*>(
run<Dense, float, double, std::complex<float>, std::complex<double>>(
in,
[&](auto gather) { gather->row_gather(alpha, &row_idxs_, beta, out); });
}
Expand Down

0 comments on commit 8d03dbf

Please sign in to comment.