Skip to content

Commit

Permalink
Improve submdspan testing (#342)
Browse files Browse the repository at this point in the history
This now tests that the elements of the submdspan point to the correct
element of the src. That caught a mistake in the layout_foo test layout.
Disable bracket operator for icpx due to compiler crash just in the test config.
Note: for icpx 2024 we could turn it on again.
  • Loading branch information
crtrott authored Jun 10, 2024
1 parent 46f8270 commit 4d54ab0
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 21 deletions.
1 change: 1 addition & 0 deletions .github/workflows/cmake.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ jobs:
# To get new URL, look here:
# https://www.intel.com/content/www/us/en/developer/articles/tool/oneapi-standalone-components.html#inpage-nav-6-undefined
compiler_url: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/ebf5d9aa-17a7-46a4-b5df-ace004227c0e/l_dpcpp-cpp-compiler_p_2023.2.1.8_offline.sh
cxx_flags_extra: "-DMDSPAN_USE_BRACKET_OPERATOR=0"
- enable_benchmark: ON
- stdcxx: 14
enable_benchmark: OFF
Expand Down
2 changes: 1 addition & 1 deletion tests/foo_customizations.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ class layout_foo::mapping {
template<class Indx0, class Indx1>
MDSPAN_INLINE_FUNCTION
constexpr index_type operator()(Indx0 idx0, Indx1 idx1) const noexcept {
return static_cast<index_type>(idx0 * __extents.extent(0) + idx1);
return static_cast<index_type>(idx0 * __extents.extent(1) + idx1);
}

MDSPAN_INLINE_FUNCTION static constexpr bool is_always_unique() noexcept { return true; }
Expand Down
60 changes: 40 additions & 20 deletions tests/test_submdspan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,38 +224,58 @@ struct TestSubMDSpan<
return Kokkos::full_extent;
}

template<class SrcExtents, class SubExtents, class ... SliceArgs>
template<class SrcMDSpan, class SubMDSpan, size_t ... SrcIdx, size_t ... SubIdx, class ... SliceArgs>
MDSPAN_INLINE_FUNCTION
static bool match_expected_extents(int src_idx, int sub_idx, SrcExtents src_ext, SubExtents sub_ext, int, SliceArgs ... slices) {
return match_expected_extents(++src_idx, sub_idx, src_ext, sub_ext, slices...);
static bool check_submdspan_match(int src_idx, int sub_idx, SrcMDSpan src_mds, SubMDSpan sub_mds, std::index_sequence<SrcIdx...>, std::index_sequence<SubIdx...>, int, SliceArgs ... slices) {
return check_submdspan_match(++src_idx, sub_idx, src_mds, sub_mds, std::index_sequence<SrcIdx...,2>(), std::index_sequence<SubIdx...>(), slices...);
}
template<class SrcExtents, class SubExtents, class ... SliceArgs>
template<class SrcMDSpan, class SubMDSpan, size_t ... SrcIdx, size_t ... SubIdx, class ... SliceArgs>
MDSPAN_INLINE_FUNCTION
static bool match_expected_extents(int src_idx, int sub_idx, SrcExtents src_ext, SubExtents sub_ext, std::pair<int,int> p, SliceArgs ... slices) {
using idx_t = typename SubExtents::index_type;
return (sub_ext.extent(sub_idx)==static_cast<idx_t>(p.second-p.first)) && match_expected_extents(++src_idx, ++sub_idx, src_ext, sub_ext, slices...);
static bool check_submdspan_match(int src_idx, int sub_idx, SrcMDSpan src_mds, SubMDSpan sub_mds, std::index_sequence<SrcIdx...>, std::index_sequence<SubIdx...>, std::pair<int,int> p, SliceArgs ... slices) {
using idx_t = typename SubMDSpan::index_type;
return (sub_mds.extent(sub_idx)==static_cast<idx_t>(p.second-p.first)) && check_submdspan_match(++src_idx, ++sub_idx, src_mds, sub_mds, std::index_sequence<SrcIdx...,2>(), std::index_sequence<SubIdx...,1>(), slices...);
}
template<class SrcExtents, class SubExtents, class ... SliceArgs>
template<class SrcMDSpan, class SubMDSpan, size_t ... SrcIdx, size_t ... SubIdx, class ... SliceArgs>
MDSPAN_INLINE_FUNCTION
static bool match_expected_extents(int src_idx, int sub_idx, SrcExtents src_ext, SubExtents sub_ext,
static bool check_submdspan_match(int src_idx, int sub_idx, SrcMDSpan src_mds, SubMDSpan sub_mds, std::index_sequence<SrcIdx...>, std::index_sequence<SubIdx...>,
Kokkos::strided_slice<int,int,int> p, SliceArgs ... slices) {
using idx_t = typename SubExtents::index_type;
return (sub_ext.extent(sub_idx)==static_cast<idx_t>((p.extent+p.stride-1)/p.stride)) && match_expected_extents(++src_idx, ++sub_idx, src_ext, sub_ext, slices...);
using idx_t = typename SubMDSpan::index_type;
return (sub_mds.extent(sub_idx)==static_cast<idx_t>((p.extent+p.stride-1)/p.stride)) && check_submdspan_match(++src_idx, ++sub_idx, src_mds, sub_mds, std::index_sequence<SrcIdx...,3>(), std::index_sequence<SubIdx...,1>(), slices...);
}
template<class SrcExtents, class SubExtents, class ... SliceArgs>
template<class SrcMDSpan, class SubMDSpan, size_t ... SrcIdx, size_t ... SubIdx, class ... SliceArgs>
MDSPAN_INLINE_FUNCTION
static bool match_expected_extents(int src_idx, int sub_idx, SrcExtents src_ext, SubExtents sub_ext,
static bool check_submdspan_match(int src_idx, int sub_idx, SrcMDSpan src_mds, SubMDSpan sub_mds, std::index_sequence<SrcIdx...>, std::index_sequence<SubIdx...>,
Kokkos::strided_slice<int,std::integral_constant<int, 0>,std::integral_constant<int,0>>, SliceArgs ... slices) {
return (sub_ext.extent(sub_idx)==0) && match_expected_extents(++src_idx, ++sub_idx, src_ext, sub_ext, slices...);
return (sub_mds.extent(sub_idx)==0) && check_submdspan_match(++src_idx, ++sub_idx, src_mds, sub_mds, std::index_sequence<SrcIdx...,1>(), std::index_sequence<SubIdx...,0>(), slices...);
}
template<class SrcExtents, class SubExtents, class ... SliceArgs>
template<class SrcMDSpan, class SubMDSpan, size_t ... SrcIdx, size_t ... SubIdx, class ... SliceArgs>
MDSPAN_INLINE_FUNCTION
static bool match_expected_extents(int src_idx, int sub_idx, SrcExtents src_ext, SubExtents sub_ext, Kokkos::full_extent_t, SliceArgs ... slices) {
return (sub_ext.extent(sub_idx)==src_ext.extent(src_idx)) && match_expected_extents(++src_idx, ++sub_idx, src_ext, sub_ext, slices...);
static bool check_submdspan_match(int src_idx, int sub_idx, SrcMDSpan src_mds, SubMDSpan sub_mds, std::index_sequence<SrcIdx...>, std::index_sequence<SubIdx...>, Kokkos::full_extent_t, SliceArgs ... slices) {
return (sub_mds.extent(sub_idx)==src_mds.extent(src_idx)) && check_submdspan_match(++src_idx, ++sub_idx, src_mds, sub_mds, std::index_sequence<SrcIdx...,1>(), std::index_sequence<SubIdx...,1>(), slices...);
}
template<class SrcExtents, class SubExtents>
template<class SrcMDSpan, class SubMDSpan, size_t ... SrcIdx, size_t ... SubIdx>
MDSPAN_INLINE_FUNCTION
static bool match_expected_extents(int, int, SrcExtents, SubExtents) { return true; }
static bool check_submdspan_match(int, int, SrcMDSpan src_mds, SubMDSpan sub_mds, std::index_sequence<SrcIdx...>, std::index_sequence<SubIdx...>) {
#if MDSPAN_USE_BRACKET_OPERATOR
if constexpr (SrcMDSpan::rank() == 0) {
return (&src_mds[]==&sub_mds[]);
} else if constexpr (SubMDSpan::rank() == 0) {
return (&src_mds[SrcIdx...]==&sub_mds[]);
} else {
if(sub_mds.size() == 0) return true;
return (&src_mds[SrcIdx...]==&sub_mds[SubIdx...]);
}
#else
if constexpr (SrcMDSpan::rank() == 0) {
return (&src_mds()==&sub_mds());
} else if constexpr (SubMDSpan::rank() == 0) {
return (&src_mds(SrcIdx...)==&sub_mds());
} else {
if(sub_mds.size() == 0) return true;
return (&src_mds(SrcIdx...)==&sub_mds(SubIdx...));
}
#endif
}

static void run() {
typename mds_org_t::mapping_type map(typename mds_org_t::extents_type(ConstrArgs...));
Expand All @@ -265,7 +285,7 @@ struct TestSubMDSpan<

dispatch([=] _MDSPAN_HOST_DEVICE () {
auto sub = Kokkos::submdspan(src, create_slice_arg(SubArgs())...);
bool match = match_expected_extents(0, 0, src.extents(), sub.extents(), create_slice_arg(SubArgs())...);
bool match = check_submdspan_match(0, 0, src, sub, std::index_sequence<>(), std::index_sequence<>(), create_slice_arg(SubArgs())...);
result[0] = match?1:0;
});
EXPECT_EQ(result[0], 1);
Expand Down

0 comments on commit 4d54ab0

Please sign in to comment.