diff --git a/prov/efa/src/efa_msg.c b/prov/efa/src/efa_msg.c index 7920afbf531..fbd4adb2bd9 100644 --- a/prov/efa/src/efa_msg.c +++ b/prov/efa/src/efa_msg.c @@ -67,10 +67,12 @@ static inline ssize_t efa_post_recv(struct efa_base_ep *base_ep, const struct fi struct ibv_recv_wr *wr; uintptr_t addr; ssize_t err, post_recv_err; - size_t i, wr_index = base_ep->recv_wr_index; + size_t i, wr_index; efa_tracepoint(recv_begin_msg_context, (size_t) msg->context, (size_t) msg->addr); + ofi_genlock_lock(&base_ep->util_ep.lock); + wr_index = base_ep->recv_wr_index; if (wr_index >= base_ep->info->rx_attr->size) { EFA_INFO(FI_LOG_EP_DATA, "recv_wr_index exceeds the rx limit, " @@ -118,8 +120,10 @@ static inline ssize_t efa_post_recv(struct efa_base_ep *base_ep, const struct fi base_ep->recv_wr_index++; - if (flags & FI_MORE) - return 0; + if (flags & FI_MORE) { + err = 0; + goto out; + } efa_tracepoint(post_recv, wr->wr_id, (uintptr_t)msg->context); @@ -134,6 +138,9 @@ static inline ssize_t efa_post_recv(struct efa_base_ep *base_ep, const struct fi base_ep->recv_wr_index = 0; +out: + ofi_genlock_unlock(&base_ep->util_ep.lock); + return err; out_err: @@ -148,6 +155,8 @@ static inline ssize_t efa_post_recv(struct efa_base_ep *base_ep, const struct fi base_ep->recv_wr_index = 0; + ofi_genlock_unlock(&base_ep->util_ep.lock); + return err; } @@ -209,6 +218,7 @@ static inline ssize_t efa_post_send(struct efa_base_ep *base_ep, const struct fi assert(len <= base_ep->info->ep_attr->max_msg_size); + ofi_genlock_lock(&base_ep->util_ep.lock); if (!base_ep->is_wr_started) { ibv_wr_start(qp->ibv_qp_ex); base_ep->is_wr_started = true; @@ -260,10 +270,9 @@ static inline ssize_t efa_post_send(struct efa_base_ep *base_ep, const struct fi ret = ibv_wr_complete(qp->ibv_qp_ex); base_ep->is_wr_started = false; } - if (OFI_UNLIKELY(ret)) - return ret; - return 0; + ofi_genlock_unlock(&base_ep->util_ep.lock); + return ret; } static ssize_t efa_ep_sendmsg(struct fid_ep *ep_fid, const struct fi_msg *msg, uint64_t flags) diff --git a/prov/efa/src/efa_rma.c b/prov/efa/src/efa_rma.c index a7bad7d3877..052e2aa89d7 100644 --- a/prov/efa/src/efa_rma.c +++ b/prov/efa/src/efa_rma.c @@ -83,6 +83,9 @@ static inline ssize_t efa_rma_post_read(struct efa_base_ep *base_ep, base_ep->domain->device->max_rdma_size); qp = base_ep->qp; + + ofi_genlock_lock(&base_ep->util_ep.lock); + if (!base_ep->is_wr_started) { ibv_wr_start(qp->ibv_qp_ex); base_ep->is_wr_started = true; @@ -113,10 +116,9 @@ static inline ssize_t efa_rma_post_read(struct efa_base_ep *base_ep, err = ibv_wr_complete(qp->ibv_qp_ex); base_ep->is_wr_started = false; } - if (OFI_UNLIKELY(err)) - return err; - return 0; + ofi_genlock_unlock(&base_ep->util_ep.lock); + return err; } static @@ -212,6 +214,9 @@ static inline ssize_t efa_rma_post_write(struct efa_base_ep *base_ep, efa_tracepoint(write_begin_msg_context, (size_t) msg->context, (size_t) msg->addr); qp = base_ep->qp; + + ofi_genlock_lock(&base_ep->util_ep.lock); + if (!base_ep->is_wr_started) { ibv_wr_start(qp->ibv_qp_ex); base_ep->is_wr_started = true; @@ -256,10 +261,8 @@ static inline ssize_t efa_rma_post_write(struct efa_base_ep *base_ep, base_ep->is_wr_started = false; } - if (OFI_UNLIKELY(err)) - return err; - - return 0; + ofi_genlock_unlock(&base_ep->util_ep.lock); + return err; } ssize_t efa_rma_writemsg(struct fid_ep *ep_fid, const struct fi_msg_rma *msg,