Skip to content

Commit

Permalink
[core] remove unique ptr from run
Browse files Browse the repository at this point in the history
Co-authored-by: Yu-Hsiang M. Tsai <[email protected]>
  • Loading branch information
MarcelKoch and yhmtsai committed May 14, 2024
1 parent 1ebd740 commit 456f509
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 119 deletions.
210 changes: 95 additions & 115 deletions core/base/dispatch_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ using with_same_constness_t = std::conditional_t<
/**
*
* @copydoc run<typename ReturnType, typename K, typename... Types, typename T,
* typename Func, typename... Args>
* typename Func, typename... Args>(T*, Func&&, Args&&...)
*
* @note this is the end case
*/
Expand All @@ -36,8 +36,8 @@ ReturnType run_impl(T* obj, Func&&, Args&&...)
}

/**
* @copydoc run<typename K, typename... Types, typename T, typename Func,
* typename... Args>
* @copydoc run<typename ReturnType, typename K, typename... Types, typename T,
* typename Func, typename... Args>(T*, Func&&, Args&&...)
*
* @note This has additionally the return type encoded.
*/
Expand All @@ -55,71 +55,78 @@ ReturnType run_impl(T* obj, Func&& f, Args&&... args)


/**
* @copydoc run<typename K, typename... Types, typename T, typename Func,
* typename... Args>
* @copydoc run<template <typename> class Base, typename T, typename Func,
* typename... Args>(T, Func&&, Args&&... )
*
* @note This handle the shared_ptr cases
* @note This is the end case for the smart pointer cases
*/
template <typename ReturnType, typename K, typename... Types, typename T,
typename Func, typename... Args>
ReturnType run_impl(std::shared_ptr<T> obj, Func&& f, Args&&... args)
{
if (auto dobj = std::dynamic_pointer_cast<K>(obj)) {
return f(dobj, std::forward<Args>(args)...);
} else {
return run_impl<ReturnType, Types...>(obj, std::forward<Func>(f),
std::forward<Args>(args)...);
}
}


/**
* run uses template to go through the list and select the valid
* template and run it.
*
* @tparam Base the Base class with one template
* @tparam T the type of input object waiting converted
* @tparam Func the validation
* @tparam ...Args the variadic arguments.
*
* @note this is the end case
*/
template <typename ReturnType, template <typename> class Base, typename T,
typename Func, typename... Args>
template <typename ReturnType, typename T, typename Func, typename... Args>
ReturnType run_impl(T obj, Func, Args...)
{
GKO_NOT_SUPPORTED(obj);
}


/**
* run uses template to go through the list and select the valid
* template and run it.
* @copydoc run<template <typename> class Base, typename T, typename Func,
* typename... Args>(T, Func&&, Args&&... )
*
* @tparam Base the Base class with one template
* @tparam K the current template type of B. pointer of const Base<K> is tried
* in the conversion.
* @tparam ...Types the other types will be tried in the conversion if K fails
* @tparam T the type of input object waiting converted
* @tparam Func the function type that is invoked if the object can be
* converted to pointer of const Base<K>
* @tparam ...Args the additional arguments for the Func
*
* @param obj the input object that should be converted
* @param f the function will get invoked if obj can be converted successfully
* @param args the additional arguments for the function
* @note This handles the shared pointer case
*/
template <typename ReturnType, template <typename> class Base, typename K,
typename... Types, typename T, typename Func, typename... Args>
ReturnType run_impl(T obj, Func&& f, Args&&... args)
template <typename ReturnType, typename K, typename... Types, typename T,
typename Func, typename... Args>
ReturnType run_impl(std::shared_ptr<T> obj, Func&& f, Args&&... args)
{
if (auto dobj = std::dynamic_pointer_cast<const Base<K>>(obj)) {
if (auto dobj =
std::dynamic_pointer_cast<with_same_constness_t<K, T>>(obj)) {
return f(dobj, args...);
} else {
return run_impl<ReturnType, Base, Types...>(
obj, std::forward<Func>(f), std::forward<Args>(args)...);
return run_impl<ReturnType, Types...>(obj, std::forward<Func>(f),
std::forward<Args>(args)...);
}
}

/**
* Helper struct to get the result type of a function.
*
* @tparam T Blueprint type for the function. This determines the
* const-qualifier for K, as well as the pointer type
* (either T*, or shared_ptr<T>) for K.
* @tparam K The actual type to be used in the function.
* @tparam Func The function to get the result from.
* @tparam Args Additional arguments to the function.
*/
template <typename T, typename K, typename Func, typename... Args>
struct result_of;

template <typename T, typename K, typename Func, typename... Args>
struct result_of<T*, K, Func, Args...> {
#if __cplusplus < 201703L
// result_of_t is deprecated in C++17
using type =
std::result_of_t<Func(detail::with_same_constness_t<K, T>*, Args...)>;
#else
using type =
std::invoke_result_t<Func, detail::with_same_constness_t<K, T>*,
Args...>;
#endif
};

template <typename T, typename K, typename Func, typename... Args>
struct result_of<std::shared_ptr<T>, K, Func, Args...> {
#if __cplusplus < 201703L
// result_of_t is deprecated in C++17
using type = std::result_of_t<Func(
std::shared_ptr<detail::with_same_constness_t<K, T>>, Args...)>;
#else
using type = std::invoke_result_t<
Func, std::shared_ptr<detail::with_same_constness_t<K, T>>, Args...>;
#endif
};

template <typename T, typename K, typename Func, typename... Args>
using result_of_t = typename result_of<T, K, Func, Args...>::type;


} // namespace detail

Expand Down Expand Up @@ -148,15 +155,7 @@ template <typename K, typename... Types, typename T, typename Func,
typename... Args>
auto run(T* obj, Func&& f, Args&&... args)
{
#if __cplusplus < 201703L
// result_of_t is deprecated in C++17
using ReturnType =
std::result_of_t<Func(detail::with_same_constness_t<K, T>*, Args...)>;
#else
using ReturnType =
std::invoke_result_t<Func, detail::with_same_constness_t<K, T>*,
Args...>;
#endif
using ReturnType = detail::result_of_t<T*, K, Func, Args...>;
return detail::run_impl<ReturnType, K, Types...>(
obj, std::forward<Func>(f), std::forward<Args>(args)...);
}
Expand All @@ -166,62 +165,56 @@ auto run(T* obj, Func&& f, Args&&... args)
* run uses template to go through the list and select the valid
* template and run it.
*
* @tparam K the current shared_ptr type tried in the conversion
* @tparam ...Types the other types will be tried in the conversion if K fails
* @tparam T the type of input object
* @tparam Func the function type that is invoked if the object can be
* converted to std::shared_ptr<K>
* @tparam Base the Base class with one template
* @tparam ...Types the types that will be tried with Base, i.e. Base<Types>...
* @tparam T the type of input object waiting converted
* @tparam Func the function will run if the object can be converted to pointer
* of const Base<K>
* @tparam ...Args the additional arguments for the Func
*
* @param obj the input object that should be converted
* @param f the function will get invoked if obj can be converted successfully
* @param obj the input object waiting converted
* @param f the function will run if obj can be converted successfully
* @param args the additional arguments for the function
*
* @note This assumes that the return type of f is independent of the input
* types (std::shared_ptr<K>, std::shared_ptr<Types>...)
*
* @return the result of f invoked with obj cast to the first matching type
*/
template <typename K, typename... Types, typename T, typename Func,
typename... Args>
auto run(std::shared_ptr<T> obj, Func&& f, Args&&... args)
template <template <class> class Base, typename... Types, typename T,
typename Func, typename... Args>
auto run(T* obj, Func&& f, Args&&... args)
{
#if __cplusplus < 201703L
// result_of_t is deprecated in C++17
using ReturnType = std::result_of_t<Func(std::shared_ptr<K>, Args...)>;
#else
using ReturnType = std::invoke_result_t<Func, std::shared_ptr<K>, Args...>;
#endif
return detail::run_impl<ReturnType, K, Types...>(
obj, std::forward<Func>(f), std::forward<Args>(args)...);
return run<Base<Types>...>(obj, std::forward<Func>(f),
std::forward<Args>(args)...);
}


/**
* run uses template to go through the list and select the valid
* template and run it.
*
* @tparam Base the Base class with one template
* @tparam K the current template type of B. pointer of const Base<K> is tried
* in the conversion.
* @tparam K the current type to try in the conversion
* @tparam ...Types the other types will be tried in the conversion if K fails
* @tparam T the type of input object waiting converted
* @tparam Func the function will run if the object can be converted to pointer
* of const Base<K>
* @tparam T the element type of input object waiting converted
* @tparam Func the function type that is invoked if the object can be
* converted to pointer of Base<K>
* @tparam ...Args the additional arguments for the Func
*
* @param obj the input object waiting converted
* @param f the function will run if obj can be converted successfully
* @param obj the input object that should be converted
* @param f the function will get invoked if obj can be converted successfully
* @param args the additional arguments for the function
*
* @note This assumes that the return type of f is independent of the input
* types (smart_ptr<K>, smart_ptr<Types>...)
*
* @return the result of f invoked with obj cast to the first matching type
*/
template <template <class> class K, typename... Types, typename T,
typename Func, typename... Args>
auto run(T* obj, Func&& f, Args&&... args)
template <typename K, typename... Types, typename T, typename Func,
typename... Args>
auto run(std::shared_ptr<T> obj, Func&& f, Args&&... args)
{
return run<K<Types>...>(obj, std::forward<Func>(f),
std::forward<Args>(args)...);
using ReturnType =
detail::result_of_t<std::shared_ptr<T>, K, Func, Args...>;
return detail::run_impl<ReturnType, K, Types...>(
obj, std::forward<Func>(f), std::forward<Args>(args)...);
}


Expand All @@ -230,10 +223,8 @@ auto run(T* obj, Func&& f, Args&&... args)
* template and run it.
*
* @tparam Base the Base class with one template
* @tparam K the current template type of B. pointer of const Base<K> is tried
* in the conversion.
* @tparam ...Types the other types will be tried in the conversion if K fails
* @tparam T the type of input object waiting converted
* @tparam ...Types the types that will be tried with Base, i.e. Base<Types>...
* @tparam T the element type of input object waiting converted
* @tparam Func the function type that is invoked if the object can be
* converted to pointer of const Base<K>
* @tparam ...Args the additional arguments for the Func
Expand All @@ -247,23 +238,12 @@ auto run(T* obj, Func&& f, Args&&... args)
*
* @return the result of f invoked with obj cast to the first matching type
*/
template <template <typename> class Base, typename K, typename... Types,
typename T, typename Func, typename... Args>
auto run(T obj, Func&& f, Args&&... args)
template <template <typename> class Base, typename... Types, typename T,
typename Func, typename... Args>
auto run(std::shared_ptr<T> obj, Func&& f, Args&&... args)
{
// Since T is a smart pointer, the type used to invoke f also has to be a
// smart pointer. unique_ptr is used because it can be converted into a
// shared_ptr, but not the other way around.
#if __cplusplus < 201703L
// result_of_t is deprecated in C++17
using ReturnType =
std::result_of_t<Func(std::unique_ptr<Base<K>>, Args...)>;
#else
using ReturnType =
std::invoke_result_t<Func, std::unique_ptr<Base<K>>, Args...>;
#endif
return detail::run_impl<ReturnType, Base, K, Types...>(
obj, std::forward<Func>(f), std::forward<Args>(args)...);
return run<Base<Types>...>(obj, std::forward<Func>(f),
std::forward<Args>(args)...);
}


Expand Down
7 changes: 3 additions & 4 deletions core/multigrid/pgm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -406,8 +406,8 @@ void Pgm<ValueType, IndexType>::generate()
setup_fine_op(obj);
} else {
// handle other ValueTypes.
run<const ConvertibleTo<fst_mtx_type>,
const ConvertibleTo<snd_mtx_type>>(obj, convert_fine_op);
run<ConvertibleTo, fst_mtx_type, snd_mtx_type>(obj,
convert_fine_op);
}

auto distributed_setup = [&](auto matrix) {
Expand Down Expand Up @@ -490,8 +490,7 @@ void Pgm<ValueType, IndexType>::generate()
};

// the fine op is using csr with the current ValueType
run<const fst_mtx_type, const snd_mtx_type>(this->get_fine_op(),
distributed_setup);
run<fst_mtx_type, snd_mtx_type>(this->get_fine_op(), distributed_setup);
} else
#endif // GINKGO_BUILD_MPI
{
Expand Down

0 comments on commit 456f509

Please sign in to comment.