Skip to content

Commit

Permalink
Make submdspan_mapping hidden friend
Browse files Browse the repository at this point in the history
  • Loading branch information
crtrott committed Oct 17, 2023
1 parent 51b9db0 commit 042cbbc
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 20 deletions.
10 changes: 10 additions & 0 deletions include/experimental/__p0009_bits/layout_left.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,16 @@ class layout_left::mapping {
private:
_MDSPAN_NO_UNIQUE_ADDRESS extents_type __extents{};

// [mdspan.submdspan.mapping], submdspan mapping specialization
template<class... SliceSpecifiers>
constexpr auto submdspan_mapping_impl( // exposition only
SliceSpecifiers... slices) const;

template<class... SliceSpecifiers>
friend constexpr auto submdspan_mapping(
const mapping& src, SliceSpecifiers... slices) {
return src.submdspan_mapping_impl(slices...);
}
};


Expand Down
10 changes: 10 additions & 0 deletions include/experimental/__p0009_bits/layout_right.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,16 @@ class layout_right::mapping {
private:
_MDSPAN_NO_UNIQUE_ADDRESS extents_type __extents{};

// [mdspan.submdspan.mapping], submdspan mapping specialization
template<class... SliceSpecifiers>
constexpr auto submdspan_mapping_impl( // exposition only
SliceSpecifiers... slices) const;

template<class... SliceSpecifiers>
friend constexpr auto submdspan_mapping(
const mapping& src, SliceSpecifiers... slices) {
return src.submdspan_mapping_impl(slices...);
}
};

} // end namespace MDSPAN_IMPL_STANDARD_NAMESPACE
Expand Down
10 changes: 10 additions & 0 deletions include/experimental/__p0009_bits/layout_stride.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,16 @@ struct layout_stride {
}
#endif

// [mdspan.submdspan.mapping], submdspan mapping specialization
template<class... SliceSpecifiers>
constexpr auto submdspan_mapping_impl( // exposition only
SliceSpecifiers... slices) const;

template<class... SliceSpecifiers>
friend constexpr auto submdspan_mapping(
const mapping& src, SliceSpecifiers... slices) {
return src.submdspan_mapping_impl(slices...);
}
};
};

Expand Down
43 changes: 23 additions & 20 deletions include/experimental/__p2630_bits/submdspan_mapping.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,15 @@ struct preserve_layout_left_mapping<std::index_sequence<Idx...>, SubRank,
#pragma diag_suppress = implicit_return_from_non_void_function
#endif
// Actual submdspan mapping call
template <class Extents, class... SliceSpecifiers>
template <class Extents>
template <class... SliceSpecifiers>
MDSPAN_INLINE_FUNCTION
constexpr auto
submdspan_mapping(const layout_left::mapping<Extents> &src_mapping,
SliceSpecifiers... slices) {
layout_left::mapping<Extents>::submdspan_mapping_impl(SliceSpecifiers... slices) const {

// compute sub extents
using src_ext_t = Extents;
auto dst_ext = submdspan_extents(src_mapping.extents(), slices...);
auto dst_ext = submdspan_extents(extents(), slices...);
using dst_ext_t = decltype(dst_ext);

// figure out sub layout type
Expand All @@ -119,7 +119,7 @@ submdspan_mapping(const layout_left::mapping<Extents> &src_mapping,
// layout_left case
return submdspan_mapping_result<dst_mapping_t>{
dst_mapping_t(dst_ext),
static_cast<size_t>(src_mapping(detail::first_of(slices)...))};
static_cast<size_t>(this->operator()(detail::first_of(slices)...))};
} else {
// layout_stride case
auto inv_map = detail::inv_map_rank(
Expand All @@ -128,15 +128,15 @@ submdspan_mapping(const layout_left::mapping<Extents> &src_mapping,
slices...);
return submdspan_mapping_result<dst_mapping_t>{
dst_mapping_t(dst_ext, detail::construct_sub_strides(
src_mapping, inv_map,
*this, inv_map,
// HIP needs deduction guides to have markups so we need to be explicit
// NVCC 11.0 has a bug with deduction guide here, tested that 11.2 does not have the issue
#if defined(_MDSPAN_HAS_HIP) || (defined(__NVCC__) && (__CUDACC_VER_MAJOR__ * 100 + __CUDACC_VER_MINOR__ * 10) < 1120)
std::tuple<decltype(detail::stride_of(slices))...>{detail::stride_of(slices)...})),
#else
std::tuple{detail::stride_of(slices)...})),
#endif
static_cast<size_t>(src_mapping(detail::first_of(slices)...))};
static_cast<size_t>(this->operator()(detail::first_of(slices)...))};
}
#if defined(__NVCC__) && !defined(__CUDA_ARCH__) && defined(__GNUC__)
__builtin_unreachable();
Expand Down Expand Up @@ -203,14 +203,15 @@ struct preserve_layout_right_mapping<std::index_sequence<Idx...>, SubRank,
#pragma diagnostic push
#pragma diag_suppress = implicit_return_from_non_void_function
#endif
template <class Extents, class... SliceSpecifiers>
template <class Extents>
template <class... SliceSpecifiers>
MDSPAN_INLINE_FUNCTION
constexpr auto
submdspan_mapping(const layout_right::mapping<Extents> &src_mapping,
SliceSpecifiers... slices) {
layout_right::mapping<Extents>::submdspan_mapping_impl(
SliceSpecifiers... slices) const {
// get sub extents
using src_ext_t = Extents;
auto dst_ext = submdspan_extents(src_mapping.extents(), slices...);
auto dst_ext = submdspan_extents(extents(), slices...);
using dst_ext_t = decltype(dst_ext);

// determine new layout type
Expand All @@ -225,7 +226,7 @@ submdspan_mapping(const layout_right::mapping<Extents> &src_mapping,
// layout_right case
return submdspan_mapping_result<dst_mapping_t>{
dst_mapping_t(dst_ext),
static_cast<size_t>(src_mapping(detail::first_of(slices)...))};
static_cast<size_t>(this->operator()(detail::first_of(slices)...))};
} else {
// layout_stride case
auto inv_map = detail::inv_map_rank(
Expand All @@ -234,15 +235,15 @@ submdspan_mapping(const layout_right::mapping<Extents> &src_mapping,
slices...);
return submdspan_mapping_result<dst_mapping_t>{
dst_mapping_t(dst_ext, detail::construct_sub_strides(
src_mapping, inv_map,
*this, inv_map,
// HIP needs deduction guides to have markups so we need to be explicit
// NVCC 11.0 has a bug with deduction guide here, tested that 11.2 does not have the issue
#if defined(_MDSPAN_HAS_HIP) || (defined(__NVCC__) && (__CUDACC_VER_MAJOR__ * 100 + __CUDACC_VER_MINOR__ * 10) < 1120)
std::tuple<decltype(detail::stride_of(slices))...>{detail::stride_of(slices)...})),
#else
std::tuple{detail::stride_of(slices)...})),
#endif
static_cast<size_t>(src_mapping(detail::first_of(slices)...))};
static_cast<size_t>(this->operator()(detail::first_of(slices)...))};
}
#if defined(__NVCC__) && !defined(__CUDA_ARCH__) && defined(__GNUC__)
__builtin_unreachable();
Expand All @@ -263,12 +264,13 @@ submdspan_mapping(const layout_right::mapping<Extents> &src_mapping,
//**********************************
// layout_stride submdspan_mapping
//*********************************
template <class Extents, class... SliceSpecifiers>
template <class Extents>
template <class... SliceSpecifiers>
MDSPAN_INLINE_FUNCTION
constexpr auto
submdspan_mapping(const layout_stride::mapping<Extents> &src_mapping,
SliceSpecifiers... slices) {
auto dst_ext = submdspan_extents(src_mapping.extents(), slices...);
layout_stride::mapping<Extents>::submdspan_mapping_impl(
SliceSpecifiers... slices) const {
auto dst_ext = submdspan_extents(extents(), slices...);
using dst_ext_t = decltype(dst_ext);
auto inv_map = detail::inv_map_rank(
std::integral_constant<size_t,0>(),
Expand All @@ -277,14 +279,15 @@ submdspan_mapping(const layout_stride::mapping<Extents> &src_mapping,
using dst_mapping_t = typename layout_stride::template mapping<dst_ext_t>;
return submdspan_mapping_result<dst_mapping_t>{
dst_mapping_t(dst_ext, detail::construct_sub_strides(
src_mapping, inv_map,
*this, inv_map,
// HIP needs deduction guides to have markups so we need to be explicit
// NVCC 11.0 has a bug with deduction guide here, tested that 11.2 does not have the issue
#if defined(_MDSPAN_HAS_HIP) || (defined(__NVCC__) && (__CUDACC_VER_MAJOR__ * 100 + __CUDACC_VER_MINOR__ * 10) < 1120)
std::tuple<decltype(detail::stride_of(slices))...>(detail::stride_of(slices)...))),
#else
std::tuple(detail::stride_of(slices)...))),
#endif
static_cast<size_t>(src_mapping(detail::first_of(slices)...))};
static_cast<size_t>(this->operator()(detail::first_of(slices)...))};
}

} // namespace MDSPAN_IMPL_STANDARD_NAMESPACE

0 comments on commit 042cbbc

Please sign in to comment.