Skip to content

Commit

Permalink
Merge a fix for is_communicator_group
Browse files Browse the repository at this point in the history
The default for `is_communicator_group` is now set to `std::false_type`, so
only actual communicator groups are specialized to `std::true_type`.
Tests are also added to prevent the same mistake from happening in the future.

Related PR: #1645
  • Loading branch information
thoasm authored Jul 18, 2024
2 parents e34e43f + 14f0e24 commit 6ed7108
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 7 deletions.
2 changes: 1 addition & 1 deletion cuda/components/cooperative_groups.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ struct is_synchronizable_group_impl : std::false_type {};


template <typename T>
struct is_communicator_group_impl : std::true_type {};
struct is_communicator_group_impl : std::false_type {};

} // namespace detail

Expand Down
36 changes: 36 additions & 0 deletions cuda/test/components/cooperative_groups.cu
Original file line number Diff line number Diff line change
Expand Up @@ -223,4 +223,40 @@ TEST_F(CooperativeGroups, SubwarpBallot) { test(cg_subwarp_ballot); }
TEST_F(CooperativeGroups, SubwarpBallot2) { test_subwarp(cg_subwarp_ballot); }
__global__ void cg_communicator_categorization(bool*)
{
auto this_block = group::this_thread_block();
auto tiled_partition =
group::tiled_partition<config::warp_size>(this_block);
auto subwarp_partition = group::tiled_partition<subwarp_size>(this_block);
using not_group = int;
using this_block_t = decltype(this_block);
using tiled_partition_t = decltype(tiled_partition);
using subwarp_partition_t = decltype(subwarp_partition);
static_assert(!group::is_group<not_group>::value &&
group::is_group<this_block_t>::value &&
group::is_group<tiled_partition_t>::value &&
group::is_group<subwarp_partition_t>::value,
"Group check doesn't work.");
static_assert(
!group::is_synchronizable_group<not_group>::value &&
group::is_synchronizable_group<this_block_t>::value &&
group::is_synchronizable_group<tiled_partition_t>::value &&
group::is_synchronizable_group<subwarp_partition_t>::value,
"Synchronizable group check doesn't work.");
static_assert(!group::is_communicator_group<not_group>::value &&
!group::is_communicator_group<this_block_t>::value &&
group::is_communicator_group<tiled_partition_t>::value &&
group::is_communicator_group<subwarp_partition_t>::value,
"Communicator group check doesn't work.");
}
TEST_F(CooperativeGroups, CorrectCategorization)
{
test(cg_communicator_categorization);
}
} // namespace
2 changes: 1 addition & 1 deletion dpcpp/components/cooperative_groups.dp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ struct is_synchronizable_group_impl : std::false_type {};


template <typename T>
struct is_communicator_group_impl : std::true_type {};
struct is_communicator_group_impl : std::false_type {};


} // namespace detail
Expand Down
43 changes: 43 additions & 0 deletions dpcpp/test/components/cooperative_groups.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,49 @@ GKO_ENABLE_DEFAULT_CONFIG_CALL(cg_ballot_call, cg_ballot, default_config_list)
TEST_P(CooperativeGroups, Ballot) { test_all_subgroup(cg_ballot_call<bool*>); }


template <typename cfg>
void cg_communicator_categorization(bool* s, sycl::nd_item<3> item_ct1)
{
auto this_block = group::this_thread_block(item_ct1);
auto tiled_partition =
group::tiled_partition<cfg::subgroup_size>(this_block);

using not_group = int;
using this_block_t = decltype(this_block);
using tiled_partition_t = decltype(tiled_partition);

static_assert(!group::is_group<not_group>::value &&
group::is_group<this_block_t>::value &&
group::is_group<tiled_partition_t>::value,
"Group check doesn't work.");
static_assert(!group::is_synchronizable_group<not_group>::value &&
group::is_synchronizable_group<this_block_t>::value &&
group::is_synchronizable_group<tiled_partition_t>::value,
"Synchronizable group check doesn't work.");
static_assert(!group::is_communicator_group<not_group>::value &&
!group::is_communicator_group<this_block_t>::value &&
group::is_communicator_group<tiled_partition_t>::value,
"Communicator group check doesn't work.");
// Make it work with the test framework, which performs 3 tests
s[this_block.thread_rank()] = true;
s[this_block.thread_rank() + cfg::subgroup_size] = true;
s[this_block.thread_rank() + 2 * cfg::subgroup_size] = true;
}

GKO_ENABLE_DEFAULT_HOST_CONFIG_TYPE(cg_communicator_categorization,
cg_communicator_categorization)
GKO_ENABLE_IMPLEMENTATION_CONFIG_SELECTION_TOTYPE(
cg_communicator_categorization, cg_communicator_categorization, DCFG_1D)
GKO_ENABLE_DEFAULT_CONFIG_CALL(cg_communicator_categorization_call,
cg_communicator_categorization,
default_config_list)

TEST_P(CooperativeGroups, CorrectCategorization)
{
test_all_subgroup(cg_communicator_categorization_call<bool*>);
}


INSTANTIATE_TEST_SUITE_P(DifferentSubgroup, CooperativeGroups,
testing::Values(4, 8, 16, 32, 64),
testing::PrintToStringParamName());
Expand Down
11 changes: 6 additions & 5 deletions hip/components/cooperative_groups.hip.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ struct is_synchronizable_group_impl : std::false_type {};


template <typename T>
struct is_communicator_group_impl : std::true_type {};
struct is_communicator_group_impl : std::false_type {};

} // namespace detail

Expand Down Expand Up @@ -370,12 +370,13 @@ namespace detail {


template <unsigned Size>
struct is_group_impl<thread_block_tile<Size>> : std::true_type {};
struct is_group_impl<group::thread_block_tile<Size>> : std::true_type {};
template <unsigned Size>
struct is_synchronizable_group_impl<thread_block_tile<Size>> : std::true_type {
};
struct is_synchronizable_group_impl<group::thread_block_tile<Size>>
: std::true_type {};
template <unsigned Size>
struct is_communicator_group_impl<thread_block_tile<Size>> : std::true_type {};
struct is_communicator_group_impl<group::thread_block_tile<Size>>
: std::true_type {};


} // namespace detail
Expand Down
36 changes: 36 additions & 0 deletions hip/test/components/cooperative_groups.hip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,42 @@ TEST_F(CooperativeGroups, SubwarpBallot) { test(cg_subwarp_ballot); }
TEST_F(CooperativeGroups, SubwarpBallot2) { test_subwarp(cg_subwarp_ballot); }


__global__ void cg_communicator_categorization(bool*)
{
auto this_block = group::this_thread_block();
auto tiled_partition =
group::tiled_partition<config::warp_size>(this_block);
auto subwarp_partition = group::tiled_partition<subwarp_size>(this_block);

using not_group = int;
using this_block_t = decltype(this_block);
using tiled_partition_t = decltype(tiled_partition);
using subwarp_partition_t = decltype(subwarp_partition);

static_assert(!group::is_group<not_group>::value &&
group::is_group<this_block_t>::value &&
group::is_group<tiled_partition_t>::value &&
group::is_group<subwarp_partition_t>::value,
"Group check doesn't work.");
static_assert(
!group::is_synchronizable_group<not_group>::value &&
group::is_synchronizable_group<this_block_t>::value &&
group::is_synchronizable_group<tiled_partition_t>::value &&
group::is_synchronizable_group<subwarp_partition_t>::value,
"Synchronizable group check doesn't work.");
static_assert(!group::is_communicator_group<not_group>::value &&
!group::is_communicator_group<this_block_t>::value &&
group::is_communicator_group<tiled_partition_t>::value &&
group::is_communicator_group<subwarp_partition_t>::value,
"Communicator group check doesn't work.");
}

TEST_F(CooperativeGroups, CorrectCategorization)
{
test(cg_communicator_categorization);
}


template <typename ValueType>
__global__ void cg_shuffle_sum(const int num, ValueType* __restrict__ value)
{
Expand Down

0 comments on commit 6ed7108

Please sign in to comment.