Skip to content

Commit

Permalink
Address reviewer comments
Browse files Browse the repository at this point in the history
Signed-off-by: Matthew Michel <[email protected]>
  • Loading branch information
mmichel11 committed Nov 6, 2024
1 parent 39b572f commit 33337f8
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 41 deletions.
10 changes: 4 additions & 6 deletions include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,12 @@
#include <cmath>
#include <limits>
#include <cstdint>
#include <tuple>

#include "../../iterator_impl.h"
#include "../../execution_impl.h"
#include "../../utils_ranges.h"
#include "../../utils.h"

#include "sycl_defs.h"
#include "parallel_backend_sycl_utils.h"
Expand Down Expand Up @@ -258,12 +260,8 @@ struct __parallel_for_large_submitter;
template <typename... _Name, typename... _RangeTypes>
struct __parallel_for_large_submitter<__internal::__optional_kernel_name<_Name...>, _RangeTypes...>
{
// Flatten the range as std::tuple value types in the range are likely coming from separate ranges in a zip
// iterator.
using _FlattenedRangesTuple = typename oneapi::dpl::__internal::__flatten_std_or_internal_tuple<
std::tuple<oneapi::dpl::__internal::__value_t<_RangeTypes>...>>::type;
static constexpr std::size_t __min_type_size =
oneapi::dpl::__internal::__min_tuple_type_size_v<_FlattenedRangesTuple>;
static constexpr std::size_t __min_type_size = oneapi::dpl::__internal::__min_nested_type_size<
std::tuple<oneapi::dpl::__internal::__value_t<_RangeTypes>...>>::value;
// __iters_per_work_item is set to 1, 2, 4, 8, or 16 depending on the smallest type in the
// flattened ranges. This allows us to launch enough work per item to saturate the device's memory
// bandwidth. This heuristic errs on the side of launching more work per item than what is needed to
Expand Down
19 changes: 0 additions & 19 deletions include/oneapi/dpl/pstl/tuple_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -793,25 +793,6 @@ struct __decay_with_tuple_specialization<::std::tuple<_Args...>>
template <typename... _Args>
using __decay_with_tuple_specialization_t = typename __decay_with_tuple_specialization<_Args...>::type;

// Flatten nested std::tuple or oneapi::dpl::__internal::tuple types into a single std::tuple.
template <typename _T>
struct __flatten_std_or_internal_tuple
{
using type = std::tuple<_T>;
};

template <typename... _Ts>
struct __flatten_std_or_internal_tuple<std::tuple<_Ts...>>
{
using type = decltype(std::tuple_cat(std::declval<typename __flatten_std_or_internal_tuple<_Ts>::type>()...));
};

template <typename... _Ts>
struct __flatten_std_or_internal_tuple<oneapi::dpl::__internal::tuple<_Ts...>>
{
using type = decltype(std::tuple_cat(std::declval<typename __flatten_std_or_internal_tuple<_Ts>::type>()...));
};

} // namespace __internal
} // namespace dpl
} // namespace oneapi
Expand Down
25 changes: 9 additions & 16 deletions include/oneapi/dpl/pstl/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <iterator>
#include <functional>
#include <type_traits>
#include <algorithm>

#if _ONEDPL_BACKEND_SYCL
# include "hetero/dpcpp/sycl_defs.h"
Expand Down Expand Up @@ -784,29 +785,21 @@ union __lazy_ctor_storage
}
};

// Utility that returns the smallest type size in a tuple.
template <typename _Tuple>
class __min_tuple_type_size;

// Returns the smallest type within a set of potentially nested template types.
// E.g. If we consider the type: T = tuple<float, tuple<short, long>, int, double>,
// then __min_nested_type_size<T>::value returns sizeof(short).
template <typename _T>
class __min_tuple_type_size<std::tuple<_T>>
struct __min_nested_type_size
{
public:
static constexpr std::size_t value = sizeof(_T);
constexpr static std::size_t value = sizeof(_T);
};

template <typename _T, typename... _Ts>
class __min_tuple_type_size<std::tuple<_T, _Ts...>>
template <template <typename...> typename _WrapperType, typename... _Ts>
struct __min_nested_type_size<_WrapperType<_Ts...>>
{
static constexpr std::size_t __min_type_value_ts = __min_tuple_type_size<std::tuple<_Ts...>>::value;

public:
static constexpr std::size_t value = std::min(sizeof(_T), __min_type_value_ts);
constexpr static std::size_t value = std::min({__min_nested_type_size<_Ts>::value...});
};

template <typename _Tuple>
inline constexpr std::size_t __min_tuple_type_size_v = __min_tuple_type_size<_Tuple>::value;

} // namespace __internal
} // namespace dpl
} // namespace oneapi
Expand Down

0 comments on commit 33337f8

Please sign in to comment.