Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion src/mpi.f90
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ module mpi
module procedure MPI_Recv_StatusIgnore_proc
end interface

interface MPI_Sendrecv
module procedure MPI_Sendrecv_proc
end interface

interface MPI_Waitall
module procedure MPI_Waitall_proc
end interface
Expand Down Expand Up @@ -668,7 +672,49 @@ subroutine MPI_Irecv_proc(buf, count, datatype, source, tag, comm, request, ierr
print *, "MPI_Irecv failed with error code: ", local_ierr
end if
end if
end subroutine
end subroutine MPI_Irecv_proc

subroutine MPI_Sendrecv_proc (sendbuf, sendcount, sendtype, dest, sendtag, &
recvbuf, recvcount, recvtype, source, recvtag, comm, status, ierror)
use iso_c_binding, only: c_int, c_ptr, c_loc
use mpi_c_bindings, only: c_mpi_sendrecv, c_mpi_status_c2f
real(8), dimension(:,:), target, intent(in) :: sendbuf
integer, intent(in) :: sendcount, dest, sendtag
real(8), dimension(:,:), target, intent(out) :: recvbuf
integer, intent(in) :: recvcount, source, recvtag
integer, intent(in) :: comm
integer, intent(in) :: sendtype, recvtype
integer(kind=MPI_HANDLE_KIND) :: c_comm
integer, intent(out) :: status(MPI_STATUS_SIZE)
integer, optional, intent(out) :: ierror
integer(c_int) :: local_ierr, status_ierr
integer(kind=MPI_HANDLE_KIND) :: c_sendtype, c_recvtype
type(c_ptr) :: sendbuf_ptr, recvbuf_ptr, c_status
integer(c_int), dimension(MPI_STATUS_SIZE), target :: tmp_status

c_comm = handle_mpi_comm_f2c(comm)

c_sendtype = handle_mpi_datatype_f2c(sendtype)
c_recvtype = handle_mpi_datatype_f2c(recvtype)
sendbuf_ptr = c_loc(sendbuf)
recvbuf_ptr = c_loc(recvbuf)
c_status = c_loc(tmp_status)

local_ierr = c_mpi_sendrecv(sendbuf_ptr, sendcount, c_sendtype, dest, sendtag, &
recvbuf_ptr, recvcount, c_recvtype, source, recvtag, &
c_comm, c_status)

if (local_ierr == MPI_SUCCESS) then
! status_ierr = c_mpi_status_c2f(c_status, status)
end if

if (local_ierr /= MPI_SUCCESS) then
print *, "MPI_Sendrecv failed with error code: ", local_ierr
if (present(ierror)) then
ierror = local_ierr
end if
end if
end subroutine MPI_Sendrecv_proc

subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ierror)
use iso_c_binding, only: c_int, c_ptr, c_loc
Expand Down
16 changes: 16 additions & 0 deletions src/mpi_c_bindings.f90
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,22 @@ function c_mpi_recv(buf, count, c_dtype, source, tag, c_comm, status) bind(C, na
integer(c_int) :: c_mpi_recv
end function c_mpi_recv

function c_mpi_sendrecv (sendbuf, sendcount, sendtype, dest, sendtag, &
recvbuf, recvcount, recvtype, source, recvtag, comm, status) bind(C, name="MPI_Sendrecv")
use iso_c_binding, only: c_int, c_ptr
type(c_ptr), value :: sendbuf
integer(c_int), value :: sendcount
integer(kind=MPI_HANDLE_KIND), value :: sendtype
integer(c_int), value :: dest, sendtag
type(c_ptr), value :: recvbuf
integer(c_int), value :: recvcount
integer(kind=MPI_HANDLE_KIND), value :: recvtype
integer(c_int), value :: source, recvtag
integer(kind=MPI_HANDLE_KIND), value :: comm
type(c_ptr), value :: status
integer(c_int) :: c_mpi_sendrecv
end function c_mpi_sendrecv

function c_mpi_waitall(count, requests, statuses) bind(C, name="MPI_Waitall")
use iso_c_binding, only: c_int, c_ptr
integer(c_int), value :: count
Expand Down
52 changes: 52 additions & 0 deletions tests/sendrecv_1.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
program sendrecv_1
use mpi
implicit none
integer :: ierr, rank, size, next, prev
real(8), allocatable :: sendbuf(:,:), recvbuf(:,:)
integer :: status(MPI_STATUS_SIZE)
logical :: error
integer :: i, j, n1, n2

n1 = 2
n2 = 3

! Initialize MPI
call MPI_Init(ierr)
call MPI_Comm_rank(MPI_COMM_WORLD, rank, ierr)
call MPI_Comm_size(MPI_COMM_WORLD, size, ierr)

! Set up ring communication
next = mod(rank + 1, size) ! Send to next process
prev = mod(rank - 1 + size, size) ! Receive from previous process

! Allocate and initialize send/recv buffers
allocate(sendbuf(n1, n2))
allocate(recvbuf(n1, n2))
sendbuf = rank
recvbuf = -1.0d0

! Perform sendrecv
call MPI_Sendrecv(sendbuf, n1*n2, MPI_REAL8, next, 0, &
recvbuf, n1*n2, MPI_REAL8, prev, 0, &
MPI_COMM_WORLD, status, ierr)

! Verify result
error = .false.
do i = 1, n1
do j = 1, n2
if (recvbuf(i,j) /= real(prev,8)) then
print *, "Rank ", rank, ": Error at (",i,",",j,"): Expected ", prev, ", got ", recvbuf(i,j)
error = .true.
end if
end do
end do

if (.not. error .and. rank == 0) then
print *, "MPI_Sendrecv test passed: rank ", rank, " received correct data"
end if

! Clean up
call MPI_Finalize(ierr)

if (error) error stop 1
end program sendrecv_1