Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 128 additions & 3 deletions flagcx/adaptor/include/ib_common.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
/*************************************************************************
* Copyright (c) 2024, FlagCX Inc.
* All rights reserved.
* Copyright (c) 2023 BAAI. All rights reserved.
*
* This file contains common InfiniBand structures and constants
* shared between IBRC and UCX adaptors.
Expand Down Expand Up @@ -125,9 +124,106 @@ struct flagcxIbMrHandle {
#define FLAGCX_NET_IB_REQ_SEND 1
#define FLAGCX_NET_IB_REQ_RECV 2
#define FLAGCX_NET_IB_REQ_FLUSH 3
#define FLAGCX_NET_IB_REQ_ACK 4

extern const char *reqTypeStr[];

#define FLAGCX_IB_RETRANS_MAX_INFLIGHT 2048
#define FLAGCX_IB_RETRANS_BUFFER_SIZE 1024
#define FLAGCX_IB_RETRANS_MAX_CHUNK_SIZE (8 * 1024 * 1024)
#define FLAGCX_IB_SRQ_SIZE 1024


#define FLAGCX_IB_ACK_BUF_PADDING 40
#define FLAGCX_IB_ACK_BUF_COUNT 64

struct flagcxIbRetransHdr {
uint32_t magic;
uint32_t seq;
uint32_t size;
uint32_t rkey;
uint64_t remoteAddr;
uint32_t immData;
uint32_t padding;
} __attribute__((packed));

struct flagcxIbAckMsg {
uint16_t peerId;
uint16_t flowId;
uint16_t path;
uint16_t ackSeq;
uint16_t sackBitmapCount;
uint16_t padding;
uint64_t timestampUs;
uint64_t sackBitmap;
} __attribute__((packed));

struct flagcxIbCtrlQp {
struct ibv_qp *qp;
struct ibv_cq *cq;
struct ibv_ah *ah;
uint32_t remoteQpn;
uint32_t remoteQkey;
};

struct flagcxIbRetransRecvBuf {
void *buffer;
struct ibv_mr *mr;
size_t size;
int inUse;
};

struct flagcxIbSrqMgr {
void *srq;
struct ibv_cq *cq;
struct flagcxIbRetransRecvBuf bufs[FLAGCX_IB_SRQ_SIZE];
int bufCount;
// Buffer management for SRQ (similar to UCCL)
int freeBufIndices[FLAGCX_IB_SRQ_SIZE]; // Stack of free buffer indices
int freeBufCount; // Number of free buffers available
int postSrqCount; // Number of recv WRs that need to be posted to SRQ
};

struct flagcxIbRetransEntry {
uint32_t seq;
uint32_t size;
uint64_t sendTimeUs;
uint64_t remoteAddr;
void *data;
uint32_t lkeys[FLAGCX_IB_MAX_DEVS_PER_NIC];
uint32_t rkeys[FLAGCX_IB_MAX_DEVS_PER_NIC];
int retry_count;
int valid;
};

struct flagcxIbRetransState {
uint32_t sendSeq;
uint32_t sendUna;
uint32_t recvSeq;

struct flagcxIbRetransEntry buffer[FLAGCX_IB_RETRANS_MAX_INFLIGHT];
int bufferHead;
int bufferTail;
int bufferCount;

uint64_t lastAckTimeUs;
uint64_t rtoUs;
uint64_t srttUs;
uint64_t rttvarUs;

uint64_t totalSent;
uint64_t totalRetrans;
uint64_t totalAcked;
uint64_t totalTimeout;

int enabled;
int maxRetry;
int ackInterval;
uint32_t minRtoUs;
uint32_t maxRtoUs;
int retransQPIndex;
};

struct flagcxIbQp {
struct ibv_qp *qp;
int devIndex;
Expand Down Expand Up @@ -176,6 +272,11 @@ struct flagcxIbConnectionMetadata {
char devName[MAX_MERGED_DEV_NAME];
uint64_t fifoAddr;
int ndevs;

uint32_t ctrlQpn[FLAGCX_IB_MAX_DEVS_PER_NIC];
union ibv_gid ctrlGid[FLAGCX_IB_MAX_DEVS_PER_NIC];
uint16_t ctrlLid[FLAGCX_IB_MAX_DEVS_PER_NIC];
int retransEnabled;
};

struct flagcxIbNetCommDevBase {
Expand All @@ -199,6 +300,10 @@ struct flagcxIbRemSizesFifo {
struct flagcxIbSendCommDev {
struct flagcxIbNetCommDevBase base;
struct ibv_mr *fifoMr;

struct flagcxIbCtrlQp ctrlQp;
struct ibv_mr *ackMr;
void *ackBuffer;
};

struct alignas(32) flagcxIbNetCommBase {
Expand Down Expand Up @@ -226,7 +331,17 @@ struct flagcxIbSendComm {
struct ibv_send_wr wrs[FLAGCX_NET_IB_MAX_RECVS + 1];
struct flagcxIbRemSizesFifo remSizesFifo;
uint64_t fifoHead;
int ar; // Use adaptive routing when all merged devices have it enabled
int ar;

struct flagcxIbRetransState retrans;
uint64_t last_timeout_check_us;

int outstanding_sends;
int outstanding_retrans;
int max_outstanding;

struct flagcxIbRetransHdr retrans_hdr_pool[32];
struct ibv_mr *retrans_hdr_mr;
};

struct flagcxIbGpuFlush {
Expand All @@ -249,6 +364,13 @@ struct alignas(16) flagcxIbRecvCommDev {
struct ibv_mr *fifoMr;
struct ibv_sge fifoSge;
struct ibv_mr *sizesFifoMr;
struct flagcxIbCtrlQp ctrlQp;
struct ibv_mr *ackMr;
void *ackBuffer;

void *retransRecvBufs[32];
struct ibv_mr *retransRecvMr;
int retransRecvBufCount;
};

struct alignas(32) flagcxIbRecvComm {
Expand All @@ -258,6 +380,9 @@ struct alignas(32) flagcxIbRecvComm {
int sizesFifo[MAX_REQUESTS][FLAGCX_NET_IB_MAX_RECVS];
int gpuFlushHostMem;
int flushEnabled;

struct flagcxIbRetransState retrans;
struct flagcxIbSrqMgr srqMgr;
};

// Global arrays (declared as extern, defined in adaptor files)
Expand Down
113 changes: 113 additions & 0 deletions flagcx/adaptor/include/ibuc_retrans.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*************************************************************************
* Copyright (c) 2023 BAAI. All rights reserved.
* All rights reserved.
*
* IBUC Retransmission Support - Header
************************************************************************/

#ifndef FLAGCX_IBUC_RETRANS_H_
#define FLAGCX_IBUC_RETRANS_H_

#include "flagcx_common.h"
#include "ib_common.h"
#include <stdint.h>
#include <time.h>

// Retransmission constants
#define FLAGCX_RETRANS_MAGIC 0xDEADBEEF // Magic number for retransmission header
#define FLAGCX_RETRANS_WR_ID 0xFFFFFFFEULL // WR ID for retransmission completions

extern int64_t flagcxParamIbRetransEnable(void);
extern int64_t flagcxParamIbRetransTimeout(void);
extern int64_t flagcxParamIbRetransMaxRetry(void);
extern int64_t flagcxParamIbRetransAckInterval(void);
extern int64_t flagcxParamIbMaxOutstanding(void);

static inline uint64_t flagcxIbGetTimeUs(void) {
struct timespec ts;
clock_gettime(CLOCK_MONOTONIC, &ts);
return (uint64_t)ts.tv_sec * 1000000ULL + (uint64_t)ts.tv_nsec / 1000ULL;
}

static inline int flagcxIbSeqLess(uint32_t a, uint32_t b) {
uint16_t a16 = a & 0xFFFF;
uint16_t b16 = b & 0xFFFF;
return (int16_t)(a16 - b16) < 0;
}

static inline int flagcxIbSeqLeq(uint32_t a, uint32_t b) {
uint16_t a16 = a & 0xFFFF;
uint16_t b16 = b & 0xFFFF;
return (int16_t)(a16 - b16) <= 0;
}

flagcxResult_t flagcxIbRetransInit(struct flagcxIbRetransState *state);

flagcxResult_t flagcxIbRetransDestroy(struct flagcxIbRetransState *state);

flagcxResult_t flagcxIbRetransAddPacket(struct flagcxIbRetransState *state,
uint32_t seq, uint32_t size, void *data,
uint64_t remote_addr, uint32_t *lkeys,
uint32_t *rkeys);

flagcxResult_t flagcxIbRetransProcessAck(struct flagcxIbRetransState *state,
struct flagcxIbAckMsg *ack_msg);

flagcxResult_t flagcxIbRetransCheckTimeout(struct flagcxIbRetransState *state,
struct flagcxIbSendComm *comm);

flagcxResult_t flagcxIbRetransRecvPacket(struct flagcxIbRetransState *state,
uint32_t seq,
struct flagcxIbAckMsg *ack_msg,
int *should_ack);

flagcxResult_t flagcxIbRetransPiggybackAck(struct flagcxIbSendFifo *fifo_elem,
struct flagcxIbAckMsg *ack_msg);

flagcxResult_t flagcxIbRetransExtractAck(struct flagcxIbSendFifo *fifo_elem,
struct flagcxIbAckMsg *ack_msg);

static inline uint32_t flagcxIbEncodeImmData(uint32_t seq, uint32_t size) {
return ((seq & 0xFFFF) << 16) | (size & 0xFFFF);
}

static inline void flagcxIbDecodeImmData(uint32_t imm_data, uint32_t *seq,
uint32_t *size) {
*seq = (imm_data >> 16) & 0xFFFF;
*size = imm_data & 0xFFFF;
}

void flagcxIbRetransPrintStats(struct flagcxIbRetransState *state,
const char *prefix);

flagcxResult_t flagcxIbCreateCtrlQp(struct ibv_context *context,
struct ibv_pd *pd, uint8_t port_num,
struct flagcxIbCtrlQp *ctrlQp);

flagcxResult_t flagcxIbDestroyCtrlQp(struct flagcxIbCtrlQp *ctrlQp);

flagcxResult_t
flagcxIbSetupCtrlQpConnection(struct ibv_context *context, struct ibv_pd *pd,
struct flagcxIbCtrlQp *ctrlQp,
uint32_t remote_qpn, union ibv_gid *remote_gid,
uint16_t remote_lid, uint8_t port_num,
uint8_t link_layer, uint8_t local_gid_index);

flagcxResult_t flagcxIbRetransSendAckViaUd(struct flagcxIbRecvComm *comm,
struct flagcxIbAckMsg *ack_msg,
int devIndex);

flagcxResult_t flagcxIbRetransRecvAckViaUd(struct flagcxIbSendComm *comm,
int devIndex);

flagcxResult_t flagcxIbRetransResendViaSend(struct flagcxIbSendComm *comm,
uint32_t seq);

flagcxResult_t flagcxIbCreateSrq(struct ibv_context *context, struct ibv_pd *pd,
struct flagcxIbSrqMgr *srqMgr);

flagcxResult_t flagcxIbDestroySrq(struct flagcxIbSrqMgr *srqMgr);

flagcxResult_t flagcxIbSrqPostRecv(struct flagcxIbSrqMgr *srqMgr, int count);

#endif // FLAGCX_IBUC_RETRANS_H_
8 changes: 4 additions & 4 deletions flagcx/adaptor/net/ibrc_adaptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1101,15 +1101,15 @@ flagcxResult_t flagcxIbAccept(void *listenComm, void **recvComm) {

// Stripe QP creation across merged devs
// Make sure to get correct remote peer dev and QP info
int remDevIndex;
int remDevIdx;
int devIndex;
devIndex = 0;
for (int q = 0; q < rComm->base.nqps; q++) {
remDevIndex = remMeta.qpInfo[q].devIndex;
remDevInfo = remMeta.devs + remDevIndex;
remDevIdx = remMeta.qpInfo[q].devIndex;
remDevInfo = remMeta.devs + remDevIdx;
qp = rComm->base.qps + q;
rCommDev = rComm->devs + devIndex;
qp->remDevIdx = remDevIndex;
qp->remDevIdx = remDevIdx;

// Local ibDevN
ibDevN = rComm->devs[devIndex].base.ibDevN;
Expand Down
Loading