From 0749914220c15d32f3ab2d134b6ce1fda71dd74d Mon Sep 17 00:00:00 2001
From: Hui Zhou <hzhou321@anl.gov>
Date: Mon, 12 Aug 2024 17:27:49 -0500
Subject: [PATCH] coll: MPIR_THREADCOMM_RANK_SIZE to check coll_attr

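Enhance the macro MPIR_THREADCOMM_RANK_SIZE to take a coll_attr argument.
For non-threadcomm communicators, rank and size are now obtained via
MPIR_COLL_INTRA_RANK_SIZE(comm, coll_attr, rank, size) instead of being
read directly from comm->rank and comm->local_size. For threadcomm
communicators, the macro asserts that the subgroup in coll_attr is 0
(for now). All callers are updated to pass coll_attr.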
---
 src/include/mpir_threadcomm.h                         | 11 +++++------
 src/mpi/coll/allgather/allgather_intra_brucks.c       |  2 +-
 src/mpi/coll/allgatherv/allgatherv_intra_brucks.c     |  2 +-
 .../allreduce/allreduce_intra_recursive_doubling.c    |  2 +-
 src/mpi/coll/alltoall/alltoall_intra_brucks.c         |  2 +-
 src/mpi/coll/alltoallv/alltoallv_intra_scattered.c    |  2 +-
 src/mpi/coll/alltoallw/alltoallw_intra_scattered.c    |  2 +-
 src/mpi/coll/barrier/barrier_intra_k_dissemination.c  |  2 +-
 src/mpi/coll/bcast/bcast_intra_binomial.c             |  2 +-
 src/mpi/coll/exscan/exscan_intra_recursive_doubling.c |  2 +-
 src/mpi/coll/gather/gather_intra_binomial.c           |  2 +-
 src/mpi/coll/gatherv/gatherv_allcomm_linear.c         |  2 +-
 src/mpi/coll/reduce/reduce_intra_binomial.c           |  2 +-
 .../reduce_scatter_intra_recursive_halving.c          |  2 +-
 .../reduce_scatter_block_intra_recursive_halving.c    |  2 +-
 src/mpi/coll/scan/scan_intra_recursive_doubling.c     |  2 +-
 src/mpi/coll/scatter/scatter_intra_binomial.c         |  2 +-
 src/mpi/coll/scatterv/scatterv_allcomm_linear.c       |  2 +-
 18 files changed, 22 insertions(+), 23 deletions(-)

diff --git a/src/include/mpir_threadcomm.h b/src/include/mpir_threadcomm.h
index cda298f1f9e..02c974795ef 100644
--- a/src/include/mpir_threadcomm.h
+++ b/src/include/mpir_threadcomm.h
@@ -111,23 +111,22 @@ MPL_STATIC_INLINE_PREFIX
 }
 
 #ifdef ENABLE_THREADCOMM
-#define MPIR_THREADCOMM_RANK_SIZE(comm, rank_, size_) do { \
+#define MPIR_THREADCOMM_RANK_SIZE(comm, coll_attr, rank_, size_) do { \
         MPIR_Threadcomm *threadcomm = (comm)->threadcomm; \
         if (threadcomm) { \
+            MPIR_Assert(MPIR_COLL_ATTR_GET_SUBGROUP(coll_attr) == 0); /* for now */ \
             int intracomm_size = (comm)->local_size; \
             size_ = threadcomm->rank_offset_table[intracomm_size - 1]; \
             rank_ = MPIR_THREADCOMM_TID_TO_RANK(threadcomm, MPIR_threadcomm_get_tid(threadcomm)); \
         } else { \
-            rank_ = (comm)->rank; \
-            size_ = (comm)->local_size; \
+            MPIR_COLL_INTRA_RANK_SIZE(comm, coll_attr, rank_, size_); \
         } \
     } while (0)
 
 #else
-#define MPIR_THREADCOMM_RANK_SIZE(comm, rank_, size_) do { \
+#define MPIR_THREADCOMM_RANK_SIZE(comm, coll_attr, rank_, size_) do { \
         MPIR_Assert((comm)->threadcomm == NULL); \
-        rank_ = (comm)->rank; \
-        size_ = (comm)->local_size; \
+        MPIR_COLL_INTRA_RANK_SIZE(comm, coll_attr, rank_, size_); \
     } while (0)
 
 #endif
diff --git a/src/mpi/coll/allgather/allgather_intra_brucks.c b/src/mpi/coll/allgather/allgather_intra_brucks.c
index e4f4bc85537..d96647a72fc 100644
--- a/src/mpi/coll/allgather/allgather_intra_brucks.c
+++ b/src/mpi/coll/allgather/allgather_intra_brucks.c
@@ -34,7 +34,7 @@ int MPIR_Allgather_intra_brucks(const void *sendbuf,
     if (((sendcount == 0) && (sendbuf != MPI_IN_PLACE)) || (recvcount == 0))
         goto fn_exit;
 
-    MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size);
+    MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size);
 
     MPIR_Datatype_get_extent_macro(recvtype, recvtype_extent);
     MPIR_Datatype_get_size_macro(recvtype, recvtype_sz);
diff --git a/src/mpi/coll/allgatherv/allgatherv_intra_brucks.c b/src/mpi/coll/allgatherv/allgatherv_intra_brucks.c
index f9a1f85de57..c61b2d54905 100644
--- a/src/mpi/coll/allgatherv/allgatherv_intra_brucks.c
+++ b/src/mpi/coll/allgatherv/allgatherv_intra_brucks.c
@@ -34,7 +34,7 @@ int MPIR_Allgatherv_intra_brucks(const void *sendbuf,
     void *tmp_buf;
     MPIR_CHKLMEM_DECL(1);
 
-    MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size);
+    MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size);
 
     total_count = 0;
     for (i = 0; i < comm_size; i++)
diff --git a/src/mpi/coll/allreduce/allreduce_intra_recursive_doubling.c b/src/mpi/coll/allreduce/allreduce_intra_recursive_doubling.c
index 7cb1d6461e5..5679ee1155d 100644
--- a/src/mpi/coll/allreduce/allreduce_intra_recursive_doubling.c
+++ b/src/mpi/coll/allreduce/allreduce_intra_recursive_doubling.c
@@ -32,7 +32,7 @@ int MPIR_Allreduce_intra_recursive_doubling(const void *sendbuf,
     MPI_Aint true_extent, true_lb, extent;
     void *tmp_buf;
 
-    MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size);
+    MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size);
 
     is_commutative = MPIR_Op_is_commutative(op);
 
diff --git a/src/mpi/coll/alltoall/alltoall_intra_brucks.c b/src/mpi/coll/alltoall/alltoall_intra_brucks.c
index 815ae7ced9c..e5eeeaba9ac 100644
--- a/src/mpi/coll/alltoall/alltoall_intra_brucks.c
+++ b/src/mpi/coll/alltoall/alltoall_intra_brucks.c
@@ -37,7 +37,7 @@ int MPIR_Alltoall_intra_brucks(const void *sendbuf,
     void *tmp_buf;
     MPIR_CHKLMEM_DECL(6);
 
-    MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size);
+    MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size);
 
 #ifdef HAVE_ERROR_CHECKING
     MPIR_Assert(sendbuf != MPI_IN_PLACE);
diff --git a/src/mpi/coll/alltoallv/alltoallv_intra_scattered.c b/src/mpi/coll/alltoallv/alltoallv_intra_scattered.c
index 475caf0bb5a..c7b2d0d4cbf 100644
--- a/src/mpi/coll/alltoallv/alltoallv_intra_scattered.c
+++ b/src/mpi/coll/alltoallv/alltoallv_intra_scattered.c
@@ -37,7 +37,7 @@ int MPIR_Alltoallv_intra_scattered(const void *sendbuf, const MPI_Aint * sendcou
 
     MPIR_CHKLMEM_DECL(2);
 
-    MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size);
+    MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size);
 
     /* Get extent of recv type, but send type is only valid if (sendbuf!=MPI_IN_PLACE) */
     MPIR_Datatype_get_extent_macro(recvtype, recv_extent);
diff --git a/src/mpi/coll/alltoallw/alltoallw_intra_scattered.c b/src/mpi/coll/alltoallw/alltoallw_intra_scattered.c
index 77443fb8a8b..4233ccc6c42 100644
--- a/src/mpi/coll/alltoallw/alltoallw_intra_scattered.c
+++ b/src/mpi/coll/alltoallw/alltoallw_intra_scattered.c
@@ -36,7 +36,7 @@ int MPIR_Alltoallw_intra_scattered(const void *sendbuf, const MPI_Aint sendcount
     MPI_Aint type_size;
     MPIR_CHKLMEM_DECL(2);
 
-    MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size);
+    MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size);
 
 #ifdef HAVE_ERROR_CHECKING
     /* When MPI_IN_PLACE, we use pair-wise sendrecv_replace in order to conserve memory usage,
diff --git a/src/mpi/coll/barrier/barrier_intra_k_dissemination.c b/src/mpi/coll/barrier/barrier_intra_k_dissemination.c
index 859c3c961de..2ef13ce7894 100644
--- a/src/mpi/coll/barrier/barrier_intra_k_dissemination.c
+++ b/src/mpi/coll/barrier/barrier_intra_k_dissemination.c
@@ -21,7 +21,7 @@ int MPIR_Barrier_intra_dissemination(MPIR_Comm * comm_ptr, int coll_attr)
     int size, rank, src, dst, mask, mpi_errno = MPI_SUCCESS;
     int mpi_errno_ret = MPI_SUCCESS;
 
-    MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, size);
+    MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, size);
 
     mask = 0x1;
     while (mask < size) {
diff --git a/src/mpi/coll/bcast/bcast_intra_binomial.c b/src/mpi/coll/bcast/bcast_intra_binomial.c
index 01e780dc0e7..ba46680963f 100644
--- a/src/mpi/coll/bcast/bcast_intra_binomial.c
+++ b/src/mpi/coll/bcast/bcast_intra_binomial.c
@@ -32,7 +32,7 @@ int MPIR_Bcast_intra_binomial(void *buffer,
     void *tmp_buf = NULL;
     MPIR_CHKLMEM_DECL(1);
 
-    MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size);
+    MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size);
 
     if (HANDLE_IS_BUILTIN(datatype))
         is_contig = 1;
diff --git a/src/mpi/coll/exscan/exscan_intra_recursive_doubling.c b/src/mpi/coll/exscan/exscan_intra_recursive_doubling.c
index 4aa11821e2a..517146c7e25 100644
--- a/src/mpi/coll/exscan/exscan_intra_recursive_doubling.c
+++ b/src/mpi/coll/exscan/exscan_intra_recursive_doubling.c
@@ -59,7 +59,7 @@ int MPIR_Exscan_intra_recursive_doubling(const void *sendbuf,
     void *partial_scan, *tmp_buf;
     MPIR_CHKLMEM_DECL(2);
 
-    MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size);
+    MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size);
 
     is_commutative = MPIR_Op_is_commutative(op);
 
diff --git a/src/mpi/coll/gather/gather_intra_binomial.c b/src/mpi/coll/gather/gather_intra_binomial.c
index fd9bbab5297..aa9b8a07a10 100644
--- a/src/mpi/coll/gather/gather_intra_binomial.c
+++ b/src/mpi/coll/gather/gather_intra_binomial.c
@@ -58,7 +58,7 @@ int MPIR_Gather_intra_binomial(const void *sendbuf, MPI_Aint sendcount, MPI_Data
     MPIR_CHKLMEM_DECL(1);
 
 
-    MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size);
+    MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size);
 
     /* Use binomial tree algorithm. */
 
diff --git a/src/mpi/coll/gatherv/gatherv_allcomm_linear.c b/src/mpi/coll/gatherv/gatherv_allcomm_linear.c
index 0a3b3fdd1ab..7fe5cc13455 100644
--- a/src/mpi/coll/gatherv/gatherv_allcomm_linear.c
+++ b/src/mpi/coll/gatherv/gatherv_allcomm_linear.c
@@ -54,7 +54,7 @@ int MPIR_Gatherv_allcomm_linear(const void *sendbuf,
     MPI_Status *starray;
     MPIR_CHKLMEM_DECL(2);
 
-    MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size);
+    MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size);
 
     /* If rank == root, then I recv lots, otherwise I send */
     if (((comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) && (root == rank)) ||
diff --git a/src/mpi/coll/reduce/reduce_intra_binomial.c b/src/mpi/coll/reduce/reduce_intra_binomial.c
index 9c10e483115..d29310d95b6 100644
--- a/src/mpi/coll/reduce/reduce_intra_binomial.c
+++ b/src/mpi/coll/reduce/reduce_intra_binomial.c
@@ -24,7 +24,7 @@ int MPIR_Reduce_intra_binomial(const void *sendbuf,
     void *tmp_buf;
     MPIR_CHKLMEM_DECL(2);
 
-    MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size);
+    MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size);
 
     /* Create a temporary buffer */
 
diff --git a/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_halving.c b/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_halving.c
index 956ab8cfdfd..fdf7711939e 100644
--- a/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_halving.c
+++ b/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_halving.c
@@ -50,7 +50,7 @@ int MPIR_Reduce_scatter_intra_recursive_halving(const void *sendbuf, void *recvb
     int pof2, old_i, newrank;
     MPIR_CHKLMEM_DECL(5);
 
-    MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size);
+    MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size);
 
 #ifdef HAVE_ERROR_CHECKING
     {
diff --git a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_halving.c b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_halving.c
index 1669f1b5657..b7a1918528b 100644
--- a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_halving.c
+++ b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_halving.c
@@ -53,7 +53,7 @@ int MPIR_Reduce_scatter_block_intra_recursive_halving(const void *sendbuf,
     int pof2, old_i, newrank;
     MPIR_CHKLMEM_DECL(5);
 
-    MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size);
+    MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size);
 
 #ifdef HAVE_ERROR_CHECKING
     {
diff --git a/src/mpi/coll/scan/scan_intra_recursive_doubling.c b/src/mpi/coll/scan/scan_intra_recursive_doubling.c
index 55064face49..5d4c9cbf111 100644
--- a/src/mpi/coll/scan/scan_intra_recursive_doubling.c
+++ b/src/mpi/coll/scan/scan_intra_recursive_doubling.c
@@ -55,7 +55,7 @@ int MPIR_Scan_intra_recursive_doubling(const void *sendbuf,
     void *partial_scan, *tmp_buf;
     MPIR_CHKLMEM_DECL(2);
 
-    MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size);
+    MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size);
 
     is_commutative = MPIR_Op_is_commutative(op);
 
diff --git a/src/mpi/coll/scatter/scatter_intra_binomial.c b/src/mpi/coll/scatter/scatter_intra_binomial.c
index de95ab55dd6..940df5de1f7 100644
--- a/src/mpi/coll/scatter/scatter_intra_binomial.c
+++ b/src/mpi/coll/scatter/scatter_intra_binomial.c
@@ -42,7 +42,7 @@ int MPIR_Scatter_intra_binomial(const void *sendbuf, MPI_Aint sendcount, MPI_Dat
     int mpi_errno_ret = MPI_SUCCESS;
     MPIR_CHKLMEM_DECL(4);
 
-    MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size);
+    MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size);
 
     if (rank == root)
         MPIR_Datatype_get_extent_macro(sendtype, extent);
diff --git a/src/mpi/coll/scatterv/scatterv_allcomm_linear.c b/src/mpi/coll/scatterv/scatterv_allcomm_linear.c
index b0ca7f3184e..6031f21b87d 100644
--- a/src/mpi/coll/scatterv/scatterv_allcomm_linear.c
+++ b/src/mpi/coll/scatterv/scatterv_allcomm_linear.c
@@ -30,7 +30,7 @@ int MPIR_Scatterv_allcomm_linear(const void *sendbuf, const MPI_Aint * sendcount
     MPI_Status *starray;
     MPIR_CHKLMEM_DECL(2);
 
-    MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size);
+    MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size);
 
     /* If I'm the root, then scatter */
     if (((comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) && (root == rank)) ||