Skip to content

Commit

Permalink
Merge pull request #283 from crtrott/submdspan_updates
Browse files Browse the repository at this point in the history
Submdspan bring in line with latest C++ draft (hidden friends)
  • Loading branch information
crtrott authored Oct 19, 2023
2 parents 6d31a92 + 38c15d9 commit 33a0b93
Show file tree
Hide file tree
Showing 11 changed files with 127 additions and 118 deletions.
54 changes: 26 additions & 28 deletions compilation_tests/ctest_constexpr_submdspan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

#include <mdspan/mdspan.hpp>

namespace KokkosEx = MDSPAN_IMPL_STANDARD_NAMESPACE::MDSPAN_IMPL_PROPOSED_NAMESPACE;

// Only works with newer constexpr
#if defined(_MDSPAN_USE_CONSTEXPR_14) && _MDSPAN_USE_CONSTEXPR_14

Expand All @@ -32,7 +30,7 @@ dynamic_extent_1d() {
auto s = Kokkos::mdspan<int, Kokkos::dextents<size_t,1>, Layout>(data, 5);
int result = 0;
for (size_t i = 0; i < s.extent(0); ++i) {
auto ss = KokkosEx::submdspan(s, i);
auto ss = Kokkos::submdspan(s, i);
result += __MDSPAN_OP0(ss);
}
// 1 + 2 + 3 + 4 + 5
Expand All @@ -57,7 +55,7 @@ dynamic_extent_1d_all_slice() {
auto s = Kokkos::mdspan<
int, Kokkos::extents<size_t,Kokkos::dynamic_extent>, Layout>(data, 5);
int result = 0;
auto ss = KokkosEx::submdspan(s, Kokkos::full_extent);
auto ss = Kokkos::submdspan(s, Kokkos::full_extent);
for (size_t i = 0; i < s.extent(0); ++i) {
result += __MDSPAN_OP(ss, i);
}
Expand All @@ -82,7 +80,7 @@ dynamic_extent_1d_pair_full() {
auto s = Kokkos::mdspan<
int, Kokkos::extents<size_t,Kokkos::dynamic_extent>, Layout>(data, 5);
int result = 0;
auto ss = KokkosEx::submdspan(s, std::pair<std::ptrdiff_t, std::ptrdiff_t>{0, 5});
auto ss = Kokkos::submdspan(s, std::pair<std::ptrdiff_t, std::ptrdiff_t>{0, 5});
for (size_t i = 0; i < s.extent(0); ++i) {
result += __MDSPAN_OP(ss, i);
}
Expand All @@ -101,7 +99,7 @@ dynamic_extent_1d_pair_each() {
int, Kokkos::extents<size_t,Kokkos::dynamic_extent>, Layout>(data, 5);
int result = 0;
for (size_t i = 0; i < s.extent(0); ++i) {
auto ss = KokkosEx::submdspan(s,
auto ss = Kokkos::submdspan(s,
std::pair<std::ptrdiff_t, std::ptrdiff_t>{i, i+1});
result += __MDSPAN_OP(ss, 0);
}
Expand All @@ -127,11 +125,11 @@ dynamic_extent_1d_all_three() {
int data[] = {1, 2, 3, 4, 5};
auto s = Kokkos::mdspan<
int, Kokkos::extents<size_t,Kokkos::dynamic_extent>, Layout>(data, 5);
auto s1 = KokkosEx::submdspan(s, std::pair<std::ptrdiff_t, std::ptrdiff_t>{0, 5});
auto s2 = KokkosEx::submdspan(s1, Kokkos::full_extent);
auto s1 = Kokkos::submdspan(s, std::pair<std::ptrdiff_t, std::ptrdiff_t>{0, 5});
auto s2 = Kokkos::submdspan(s1, Kokkos::full_extent);
int result = 0;
for (size_t i = 0; i < s.extent(0); ++i) {
auto ss = KokkosEx::submdspan(s2, i);
auto ss = Kokkos::submdspan(s2, i);
result += __MDSPAN_OP0(ss);
}
constexpr_assert_equal(15, result);
Expand All @@ -157,7 +155,7 @@ dynamic_extent_2d_idx_idx() {
int result = 0;
for(size_t row = 0; row < s.extent(0); ++row) {
for(size_t col = 0; col < s.extent(1); ++col) {
auto ss = KokkosEx::submdspan(s, row, col);
auto ss = Kokkos::submdspan(s, row, col);
result += __MDSPAN_OP0(ss);
}
}
Expand All @@ -176,9 +174,9 @@ dynamic_extent_2d_idx_all_idx() {
data, 2, 3);
int result = 0;
for(size_t row = 0; row < s.extent(0); ++row) {
auto srow = KokkosEx::submdspan(s, row, Kokkos::full_extent);
auto srow = Kokkos::submdspan(s, row, Kokkos::full_extent);
for(size_t col = 0; col < s.extent(1); ++col) {
auto scol = KokkosEx::submdspan(srow, col);
auto scol = Kokkos::submdspan(srow, col);
constexpr_assert_equal(__MDSPAN_OP0(scol), __MDSPAN_OP(srow, col));
result += __MDSPAN_OP0(scol);
}
Expand All @@ -205,9 +203,9 @@ simple_static_submdspan_test_1(int add_to_row) {
auto s = Kokkos::mdspan<int, Kokkos::extents<size_t,3, 3>>(data);
int result = 0;
for(int col = 0; col < 3; ++col) {
auto scol = KokkosEx::submdspan(s, Kokkos::full_extent, col);
auto scol = Kokkos::submdspan(s, Kokkos::full_extent, col);
for(int row = 0; row < 3; ++row) {
auto srow = KokkosEx::submdspan(scol, row);
auto srow = Kokkos::submdspan(scol, row);
result += __MDSPAN_OP0(srow) * (row + add_to_row);
}
}
Expand Down Expand Up @@ -247,18 +245,18 @@ mixed_submdspan_left_test_2() {
Kokkos::extents<size_t,3, Kokkos::dynamic_extent>, Kokkos::layout_left>(data, 5);
int result = 0;
for(int col = 0; col < 5; ++col) {
auto scol = KokkosEx::submdspan(s, Kokkos::full_extent, col);
auto scol = Kokkos::submdspan(s, Kokkos::full_extent, col);
for(int row = 0; row < 3; ++row) {
auto srow = KokkosEx::submdspan(scol, row);
auto srow = Kokkos::submdspan(scol, row);
result += __MDSPAN_OP0(srow) * (row + 1);
}
}
// 1 + 2 + 3 + 2*(4 + 5 + 6) + 3*(7 + 8 + 9)= 108
constexpr_assert_equal(108, result);
for(int row = 0; row < 3; ++row) {
auto srow = KokkosEx::submdspan(s, row, Kokkos::full_extent);
auto srow = Kokkos::submdspan(s, row, Kokkos::full_extent);
for(int col = 0; col < 5; ++col) {
auto scol = KokkosEx::submdspan(srow, col);
auto scol = Kokkos::submdspan(srow, col);
result += __MDSPAN_OP0(scol) * (row + 1);
}
}
Expand Down Expand Up @@ -290,17 +288,17 @@ mixed_submdspan_test_3() {
int, Kokkos::extents<size_t,3, Kokkos::dynamic_extent>, Layout>(data, 5);
int result = 0;
for(int col = 0; col < 5; ++col) {
auto scol = KokkosEx::submdspan(s, Kokkos::full_extent, col);
auto scol = Kokkos::submdspan(s, Kokkos::full_extent, col);
for(int row = 0; row < 3; ++row) {
auto srow = KokkosEx::submdspan(scol, row);
auto srow = Kokkos::submdspan(scol, row);
result += __MDSPAN_OP0(srow) * (row + 1);
}
}
constexpr_assert_equal(71, result);
for(int row = 0; row < 3; ++row) {
auto srow = KokkosEx::submdspan(s, row, Kokkos::full_extent);
auto srow = Kokkos::submdspan(s, row, Kokkos::full_extent);
for(int col = 0; col < 5; ++col) {
auto scol = KokkosEx::submdspan(srow, col);
auto scol = Kokkos::submdspan(srow, col);
result += __MDSPAN_OP0(scol) * (row + 1);
}
}
Expand Down Expand Up @@ -338,14 +336,14 @@ submdspan_single_element_stress_test_impl_2(
int data[] = { 42 };
auto s = mdspan_t(data);
auto s_dyn = dyn_mdspan_t(data, _repeated_ptrdiff_t<1, Idxs>...);
auto ss = KokkosEx::submdspan(s, _repeated_ptrdiff_t<0, Idxs>...);
auto ss_dyn = KokkosEx::submdspan(s_dyn, _repeated_ptrdiff_t<0, Idxs>...);
auto ss_all = KokkosEx::submdspan(s, _repeated_with_idxs_t<Kokkos::full_extent_t, Idxs>{}...);
auto ss_all_dyn = KokkosEx::submdspan(s_dyn, _repeated_with_idxs_t<Kokkos::full_extent_t, Idxs>{}...);
auto ss = Kokkos::submdspan(s, _repeated_ptrdiff_t<0, Idxs>...);
auto ss_dyn = Kokkos::submdspan(s_dyn, _repeated_ptrdiff_t<0, Idxs>...);
auto ss_all = Kokkos::submdspan(s, _repeated_with_idxs_t<Kokkos::full_extent_t, Idxs>{}...);
auto ss_all_dyn = Kokkos::submdspan(s_dyn, _repeated_with_idxs_t<Kokkos::full_extent_t, Idxs>{}...);
auto val = __MDSPAN_OP(ss_all, (_repeated_ptrdiff_t<0, Idxs>...));
auto val_dyn = __MDSPAN_OP(ss_all_dyn, (_repeated_ptrdiff_t<0, Idxs>...));
auto ss_pair = KokkosEx::submdspan(s, _repeated_with_idxs_t<std::pair<ptrdiff_t, ptrdiff_t>, Idxs>{0, 1}...);
auto ss_pair_dyn = KokkosEx::submdspan(s_dyn, _repeated_with_idxs_t<std::pair<ptrdiff_t, ptrdiff_t>, Idxs>{0, 1}...);
auto ss_pair = Kokkos::submdspan(s, _repeated_with_idxs_t<std::pair<ptrdiff_t, ptrdiff_t>, Idxs>{0, 1}...);
auto ss_pair_dyn = Kokkos::submdspan(s_dyn, _repeated_with_idxs_t<std::pair<ptrdiff_t, ptrdiff_t>, Idxs>{0, 1}...);
auto val_pair = __MDSPAN_OP(ss_pair, (_repeated_ptrdiff_t<0, Idxs>...));
auto val_pair_dyn = __MDSPAN_OP(ss_pair_dyn, (_repeated_ptrdiff_t<0, Idxs>...));
constexpr_assert_equal(42, ss());
Expand Down
10 changes: 10 additions & 0 deletions include/experimental/__p0009_bits/layout_left.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,16 @@ class layout_left::mapping {
private:
_MDSPAN_NO_UNIQUE_ADDRESS extents_type __extents{};

// [mdspan.submdspan.mapping], submdspan mapping specialization
template<class... SliceSpecifiers>
constexpr auto submdspan_mapping_impl(
SliceSpecifiers... slices) const;

template<class... SliceSpecifiers>
friend constexpr auto submdspan_mapping(
const mapping& src, SliceSpecifiers... slices) {
return src.submdspan_mapping_impl(slices...);
}
};


Expand Down
10 changes: 10 additions & 0 deletions include/experimental/__p0009_bits/layout_right.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,16 @@ class layout_right::mapping {
private:
_MDSPAN_NO_UNIQUE_ADDRESS extents_type __extents{};

// [mdspan.submdspan.mapping], submdspan mapping specialization
template<class... SliceSpecifiers>
constexpr auto submdspan_mapping_impl(
SliceSpecifiers... slices) const;

template<class... SliceSpecifiers>
friend constexpr auto submdspan_mapping(
const mapping& src, SliceSpecifiers... slices) {
return src.submdspan_mapping_impl(slices...);
}
};

} // end namespace MDSPAN_IMPL_STANDARD_NAMESPACE
Expand Down
10 changes: 10 additions & 0 deletions include/experimental/__p0009_bits/layout_stride.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,16 @@ struct layout_stride {
}
#endif

// [mdspan.submdspan.mapping], submdspan mapping specialization
template<class... SliceSpecifiers>
constexpr auto submdspan_mapping_impl(
SliceSpecifiers... slices) const;

template<class... SliceSpecifiers>
friend constexpr auto submdspan_mapping(
const mapping& src, SliceSpecifiers... slices) {
return src.submdspan_mapping_impl(slices...);
}
};
};

Expand Down
9 changes: 4 additions & 5 deletions include/experimental/__p2630_bits/strided_slice.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
#include <type_traits>

namespace MDSPAN_IMPL_STANDARD_NAMESPACE {
namespace MDSPAN_IMPL_PROPOSED_NAMESPACE {

namespace {
template<class T>
Expand All @@ -29,21 +28,21 @@ namespace {
template<class T, T val>
struct __mdspan_is_integral_constant<std::integral_constant<T,val>>: std::true_type {};
}

// Slice Specifier allowing for strides and compile time extent
template <class OffsetType, class ExtentType, class StrideType>
struct strided_slice {
using offset_type = OffsetType;
using extent_type = ExtentType;
using stride_type = StrideType;

OffsetType offset;
ExtentType extent;
StrideType stride;
_MDSPAN_NO_UNIQUE_ADDRESS OffsetType offset{};
_MDSPAN_NO_UNIQUE_ADDRESS ExtentType extent{};
_MDSPAN_NO_UNIQUE_ADDRESS StrideType stride{};

static_assert(std::is_integral_v<OffsetType> || __mdspan_is_integral_constant<OffsetType>::value);
static_assert(std::is_integral_v<ExtentType> || __mdspan_is_integral_constant<ExtentType>::value);
static_assert(std::is_integral_v<StrideType> || __mdspan_is_integral_constant<StrideType>::value);
};

} // MDSPAN_IMPL_PROPOSED_NAMESPACE
} // MDSPAN_IMPL_STANDARD_NAMESPACE
10 changes: 4 additions & 6 deletions include/experimental/__p2630_bits/submdspan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,21 @@
#include "submdspan_mapping.hpp"

namespace MDSPAN_IMPL_STANDARD_NAMESPACE {
namespace MDSPAN_IMPL_PROPOSED_NAMESPACE {
template <class ElementType, class Extents, class LayoutPolicy,
class AccessorPolicy, class... SliceSpecifiers>
MDSPAN_INLINE_FUNCTION
constexpr auto
submdspan(const mdspan<ElementType, Extents, LayoutPolicy, AccessorPolicy> &src,
SliceSpecifiers... slices) {
const auto sub_mapping_offset = submdspan_mapping(src.mapping(), slices...);
const auto sub_submdspan_mapping_result = submdspan_mapping(src.mapping(), slices...);
// NVCC has a problem with the deduction so lets figure out the type
using sub_mapping_t = std::remove_cv_t<decltype(sub_mapping_offset.mapping)>;
using sub_mapping_t = std::remove_cv_t<decltype(sub_submdspan_mapping_result.mapping)>;
using sub_extents_t = typename sub_mapping_t::extents_type;
using sub_layout_t = typename sub_mapping_t::layout_type;
using sub_accessor_t = typename AccessorPolicy::offset_policy;
return mdspan<ElementType, sub_extents_t, sub_layout_t, sub_accessor_t>(
src.accessor().offset(src.data_handle(), sub_mapping_offset.offset),
sub_mapping_offset.mapping,
src.accessor().offset(src.data_handle(), sub_submdspan_mapping_result.offset),
sub_submdspan_mapping_result.mapping,
sub_accessor_t(src.accessor()));
}
} // namespace MDSPAN_IMPL_PROPOSED_NAMESPACE
} // namespace MDSPAN_IMPL_STANDARD_NAMESPACE
2 changes: 0 additions & 2 deletions include/experimental/__p2630_bits/submdspan_extents.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

#include "strided_slice.hpp"
namespace MDSPAN_IMPL_STANDARD_NAMESPACE {
namespace MDSPAN_IMPL_PROPOSED_NAMESPACE {
namespace detail {

// Mapping from submapping ranks to srcmapping ranks
Expand Down Expand Up @@ -319,5 +318,4 @@ constexpr auto submdspan_extents(const extents<IndexType, Extents...> &src_exts,
return detail::extents_constructor<ext_t::rank(), ext_t>::next_extent(
src_exts, slices...);
}
} // namespace MDSPAN_IMPL_PROPOSED_NAMESPACE
} // namespace MDSPAN_IMPL_STANDARD_NAMESPACE
Loading

0 comments on commit 33a0b93

Please sign in to comment.