Skip to content

Commit b0512bf

Browse files
authored
Feat: Implement Wrappers of MPI_COMM_CREATE and MPI_GROUP_RANGE_INCL (#132)
* Feat: Implement Wrappers of MPI_COMM_CREATE and MPI_GROUP_RANGE_INCL * Update test * Fix: Support 2D ranges array * Apply code review
1 parent 9db0270 commit b0512bf

File tree

4 files changed

+139
-5
lines changed

4 files changed

+139
-5
lines changed

src/mpi.f90

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ module mpi
1717
integer, parameter :: MPI_SUCCESS = 0
1818

1919
integer, parameter :: MPI_COMM_WORLD = -1000
20+
integer, parameter :: MPI_COMM_NULL = -1001
2021
real(8), parameter :: MPI_IN_PLACE = -1002
2122
integer, parameter :: MPI_SUM = -2300
2223
integer, parameter :: MPI_MAX = -2301
@@ -49,6 +50,10 @@ module mpi
4950
module procedure MPI_Comm_Group_proc
5051
end interface MPI_Comm_Group
5152

53+
interface MPI_Comm_create
54+
module procedure MPI_Comm_create_proc
55+
end interface MPI_Comm_create
56+
5257
interface MPI_Group_free
5358
module procedure MPI_Group_free_proc
5459
end interface MPI_Group_free
@@ -57,6 +62,11 @@ module mpi
5762
module procedure MPI_Group_size_proc
5863
end interface MPI_Group_size
5964

65+
interface MPI_Group_range_incl
66+
module procedure MPI_Group_range_incl_proc
67+
end interface MPI_Group_range_incl
68+
69+
6070
interface MPI_Comm_dup
6171
module procedure MPI_Comm_dup_proc
6272
end interface MPI_Comm_dup
@@ -175,6 +185,16 @@ integer(kind=MPI_HANDLE_KIND) function handle_mpi_comm_f2c(comm_f) result(c_comm
175185
end if
176186
end function handle_mpi_comm_f2c
177187

188+
integer(kind=MPI_HANDLE_KIND) function handle_mpi_comm_c2f(comm_c) result(f_comm)
189+
use mpi_c_bindings, only: c_mpi_comm_c2f, c_mpi_comm_null
190+
integer(kind=mpi_handle_kind), intent(in) :: comm_c
191+
if (comm_c == c_mpi_comm_null) then
192+
f_comm = MPI_COMM_NULL
193+
else
194+
f_comm = c_mpi_comm_c2f(comm_c)
195+
end if
196+
end function handle_mpi_comm_c2f
197+
178198
integer(kind=MPI_HANDLE_KIND) function handle_mpi_info_f2c(info_f) result(c_info)
179199
use mpi_c_bindings, only: c_mpi_info_f2c, c_mpi_info_null
180200
integer, intent(in) :: info_f
@@ -350,6 +370,51 @@ subroutine MPI_Group_free_proc(group, ierror)
350370
end if
351371
end subroutine MPI_Group_free_proc
352372

373+
subroutine MPI_Group_range_incl_proc(group, n, ranks, newgroup, ierror)
374+
use mpi_c_bindings, only: c_mpi_group_range_incl, c_mpi_group_f2c, c_mpi_comm_c2f, c_mpi_group_c2f
375+
use iso_c_binding, only: c_int, c_ptr
376+
integer, intent(in) :: group
377+
integer, intent(in) :: n
378+
integer, dimension(:,:), intent(in) :: ranks
379+
integer, intent(out) :: newgroup
380+
integer, optional, intent(out) :: ierror
381+
integer(kind=MPI_HANDLE_KIND) :: c_group, c_newgroup
382+
integer(c_int) :: local_ierr
383+
384+
c_group = c_mpi_group_f2c(group)
385+
local_ierr = c_mpi_group_range_incl(c_group, n, ranks, c_newgroup)
386+
newgroup = c_mpi_group_c2f(c_newgroup)
387+
388+
if (present(ierror)) then
389+
ierror = local_ierr
390+
else if (local_ierr /= MPI_SUCCESS) then
391+
print *, "MPI_Group_incl failed with error code: ", local_ierr
392+
end if
393+
end subroutine MPI_Group_range_incl_proc
394+
395+
subroutine MPI_Comm_create_proc(comm, group, newcomm, ierror)
396+
use mpi_c_bindings, only: c_mpi_comm_create, c_mpi_comm_f2c, c_mpi_comm_c2f, c_mpi_group_f2c, c_mpi_comm_null
397+
use iso_c_binding, only: c_int, c_ptr
398+
integer, intent(in) :: comm
399+
integer, intent(in) :: group
400+
integer, intent(out) :: newcomm
401+
integer, optional, intent(out) :: ierror
402+
integer(kind=MPI_HANDLE_KIND) :: c_comm, c_group, c_newcomm
403+
integer(c_int) :: local_ierr
404+
405+
c_comm = handle_mpi_comm_f2c(comm)
406+
c_group = c_mpi_group_f2c(group)
407+
local_ierr = c_mpi_comm_create(c_comm, c_group, c_newcomm)
408+
409+
newcomm = handle_mpi_comm_c2f(c_newcomm)
410+
411+
if (present(ierror)) then
412+
ierror = local_ierr
413+
else if (local_ierr /= MPI_SUCCESS) then
414+
print *, "MPI_Comm_create failed with error code: ", local_ierr
415+
end if
416+
end subroutine MPI_Comm_create_proc
417+
353418
subroutine MPI_Comm_dup_proc(comm, newcomm, ierror)
354419
use mpi_c_bindings, only: c_mpi_comm_dup, c_mpi_comm_c2f
355420
integer, intent(in) :: comm

src/mpi_c_bindings.f90

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ module mpi_c_bindings
1717
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_REAL") :: c_mpi_real
1818
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_INT") :: c_mpi_int
1919
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_COMM_WORLD") :: c_mpi_comm_world
20+
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_COMM_NULL") :: c_mpi_comm_null
2021
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_SUM") :: c_mpi_sum
2122
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_MAX") :: c_mpi_max
2223
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_LOGICAL") :: c_mpi_logical
@@ -338,5 +339,22 @@ function c_mpi_group_free(group) bind(C, name="MPI_Group_free")
338339
integer(c_int) :: c_mpi_group_free
339340
end function c_mpi_group_free
340341

342+
function c_mpi_group_range_incl(group, n, ranges, c_newgroup) bind(C, name="MPI_Group_range_incl")
343+
use iso_c_binding, only: c_ptr, c_int
344+
integer(kind=MPI_HANDLE_KIND), value :: group
345+
integer(c_int), value :: n
346+
integer(c_int), dimension(*) :: ranges
347+
integer(kind=MPI_HANDLE_KIND) :: c_newgroup
348+
integer(c_int) :: c_mpi_group_range_incl
349+
end function c_mpi_group_range_incl
350+
351+
function c_mpi_comm_create(comm, group, newcomm) bind(C, name="MPI_Comm_create")
352+
use iso_c_binding, only: c_ptr, c_int
353+
integer(kind=MPI_HANDLE_KIND), value :: comm
354+
integer(kind=MPI_HANDLE_KIND), value :: group
355+
integer(kind=MPI_HANDLE_KIND), intent(out) :: newcomm
356+
integer(c_int) :: c_mpi_comm_create
357+
end function c_mpi_comm_create
358+
341359
end interface
342360
end module mpi_c_bindings

src/mpi_constants.c

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,30 @@ MPI_Status* c_MPI_STATUSES_IGNORE = MPI_STATUSES_IGNORE;
44

55
MPI_Info c_MPI_INFO_NULL = MPI_INFO_NULL;
66

7-
MPI_Comm c_MPI_COMM_WORLD = MPI_COMM_WORLD;
7+
void* c_MPI_IN_PLACE = MPI_IN_PLACE;
8+
9+
// DataType Declarations
810

911
MPI_Datatype c_MPI_DOUBLE = MPI_DOUBLE;
1012

1113
MPI_Datatype c_MPI_FLOAT = MPI_FLOAT;
1214

1315
MPI_Datatype c_MPI_INT = MPI_INT;
1416

15-
void* c_MPI_IN_PLACE = MPI_IN_PLACE;
17+
MPI_Datatype c_MPI_LOGICAL = MPI_LOGICAL;
18+
19+
MPI_Datatype c_MPI_CHARACTER = MPI_CHARACTER;
20+
21+
MPI_Datatype c_MPI_REAL = MPI_REAL;
22+
23+
// Operation Declarations
1624

1725
MPI_Op c_MPI_SUM = MPI_SUM;
1826

1927
MPI_Op c_MPI_MAX = MPI_MAX;
2028

21-
MPI_Datatype c_MPI_LOGICAL = MPI_LOGICAL;
29+
// Communicators Declarations
2230

23-
MPI_Datatype c_MPI_CHARACTER = MPI_CHARACTER;
31+
MPI_Comm c_MPI_COMM_NULL = MPI_COMM_NULL;
2432

25-
MPI_Datatype c_MPI_REAL = MPI_REAL;
33+
MPI_Comm c_MPI_COMM_WORLD = MPI_COMM_WORLD;

tests/comm_create_1.f90

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
program minimal_mre_range
2+
use mpi
3+
implicit none
4+
5+
integer :: ierr, rank, new_rank, size
6+
integer :: group_world, group_range, new_comm
7+
integer, dimension(1,3) :: range ! 1D array to define a single range
8+
integer :: i
9+
10+
call MPI_INIT(ierr)
11+
call MPI_COMM_RANK(MPI_COMM_WORLD, rank, ierr)
12+
call MPI_COMM_SIZE(MPI_COMM_WORLD, size, ierr)
13+
14+
! Get the group of MPI_COMM_WORLD
15+
call MPI_COMM_GROUP(MPI_COMM_WORLD, group_world, ierr)
16+
17+
! Define 1D range: start, end, stride
18+
range(1,1) = 0 ! start
19+
range(1,2) = size - 1 ! end
20+
range(1,3) = 1 ! stride
21+
22+
23+
! Create a new group that includes all ranks
24+
call MPI_GROUP_RANGE_INCL(group_world, 1, range, group_range, ierr)
25+
26+
! Create new communicator
27+
call MPI_COMM_CREATE(MPI_COMM_WORLD, group_range, new_comm, ierr)
28+
29+
! Print participation
30+
if (new_comm /= MPI_COMM_NULL) then
31+
call MPI_COMM_RANK(new_comm, new_rank, ierr)
32+
if (ierr /= MPI_SUCCESS) error stop "MPI_COMM_RANK on new_comm failed"
33+
print *, 'Global rank', rank, 'is in new_comm with local rank', new_rank
34+
else
35+
print *, 'Rank', rank, 'is NOT in the new communicator.'
36+
end if
37+
38+
! Free groups (no comm_free)
39+
call MPI_GROUP_FREE(group_range, ierr)
40+
call MPI_GROUP_FREE(group_world, ierr)
41+
42+
call MPI_FINALIZE(ierr)
43+
end program minimal_mre_range

0 commit comments

Comments
 (0)