Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

UCX_MO: Add local MD support #76

Merged
merged 5 commits into from
Mar 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 21 additions & 38 deletions src/plugins/ucx/ucx_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -598,74 +598,57 @@ nixl_status_t nixlUcxEngine::getPublicData (const nixlBackendMD* meta,
return NIXL_SUCCESS;
}

nixl_status_t nixlUcxEngine::loadLocalMD (nixlBackendMD* input,
nixlBackendMD* &output) {

// To be cleaned up
nixl_status_t
nixlUcxEngine::internalMDHelper (const nixl_blob_t &blob,
const std::string &agent,
nixlBackendMD* &output) {
nixlUcxConnection conn;
nixlUcxPrivateMetadata* input_md = (nixlUcxPrivateMetadata*) input;
nixlUcxPublicMetadata *md = new nixlUcxPublicMetadata;
size_t size = blob.size();

//look up our own name
auto search = remoteConnMap.find(localAgent);
auto search = remoteConnMap.find(agent);

if(search == remoteConnMap.end()) {
//TODO: something wrong, local connection should have been established
//TODO: err: remote connection not found
return NIXL_ERR_NOT_FOUND;
}
conn = (nixlUcxConnection) search->second;

//directly copy underlying conn struct
md->conn = conn;

size_t size = input_md->rkeyStr.size();
char *addr = new char[size];
nixlSerDes::_stringToBytes(addr, input_md->rkeyStr, size);
nixlSerDes::_stringToBytes(addr, blob, size);

int ret = uw->rkeyImport(conn.ep, addr, size, md->rkey);
if (ret) {
// TODO: error out. Should we indicate which desc failed or unroll everything prior
return NIXL_ERR_BACKEND;
}

output = (nixlBackendMD*) md;

delete[] addr;

return NIXL_SUCCESS;
}

nixl_status_t
nixlUcxEngine::loadLocalMD (nixlBackendMD* input,
nixlBackendMD* &output)
{
nixlUcxPrivateMetadata* input_md = (nixlUcxPrivateMetadata*) input;
return internalMDHelper(input_md->rkeyStr, localAgent, output);
}

// To be cleaned up
nixl_status_t nixlUcxEngine::loadRemoteMD (const nixlBlobDesc &input,
const nixl_mem_t &nixl_mem,
const std::string &remote_agent,
nixlBackendMD* &output) {
size_t size = input.metaInfo.size();
char *addr = new char[size];
int ret;
nixlUcxConnection conn;

nixlUcxPublicMetadata *md = new nixlUcxPublicMetadata;

auto search = remoteConnMap.find(remote_agent);

if(search == remoteConnMap.end()) {
//TODO: err: remote connection not found
return NIXL_ERR_NOT_FOUND;
}
conn = (nixlUcxConnection) search->second;

nixlSerDes::_stringToBytes(addr, input.metaInfo, size);

md->conn = conn;
ret = uw->rkeyImport(conn.ep, addr, size, md->rkey);
if (ret) {
// TODO: error out. Should we indicate which desc failed or unroll everything prior
return NIXL_ERR_BACKEND;
}
output = (nixlBackendMD*) md;

delete[] addr;

return NIXL_SUCCESS;
nixlBackendMD* &output)
{
return internalMDHelper(input.metaInfo, remote_agent, output);
}

nixl_status_t nixlUcxEngine::unloadMD (nixlBackendMD* input) {
Expand Down
7 changes: 6 additions & 1 deletion src/plugins/ucx/ucx_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class nixlUcxConnection : public nixlBackendConnMD {
class nixlUcxPrivateMetadata : public nixlBackendMD {
private:
nixlUcxMem mem;
std::string rkeyStr;
nixl_blob_t rkeyStr;

public:
nixlUcxPrivateMetadata() : nixlBackendMD(true) {
Expand Down Expand Up @@ -175,6 +175,11 @@ class nixlUcxEngine : public nixlBackendEngine {
size_t length,
const ucp_am_recv_param_t *param);

// Memory management helpers
nixl_status_t internalMDHelper (const nixl_blob_t &blob,
const std::string &agent,
nixlBackendMD* &output);

// Notifications
static ucs_status_t notifAmCb(void *arg, const void *header,
size_t header_length, void *data,
Expand Down
48 changes: 30 additions & 18 deletions src/plugins/ucx_mo/ucx_mo_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,7 @@ nixlUcxMoEngine::registerMem (const nixlBlobDesc &mem,
return NIXL_ERR_INVALID_PARAM;
}

priv->memType = nixl_mem;
priv->eidx = eidx;
engines[eidx]->registerMem(mem, nixl_mem, priv->md);

Expand Down Expand Up @@ -340,57 +341,51 @@ nixlUcxMoEngine::deregisterMem (nixlBackendMD* meta)
return NIXL_SUCCESS;
}

// To be cleaned up
nixl_status_t
nixlUcxMoEngine::loadLocalMD(nixlBackendMD* input,
nixlBackendMD* &output)
{
// TODO
return NIXL_ERR_NOT_FOUND;
}

nixl_status_t
nixlUcxMoEngine::loadRemoteMD (const nixlBlobDesc &input,
const nixl_mem_t &nixl_mem,
const string &remote_agent,
nixlBackendMD* &output)
nixlUcxMoEngine::internalMDHelper (const nixl_blob_t &blob,
const nixl_mem_t &nixl_mem,
const std::string &agent,
nixlBackendMD* &output)
{
nixlUcxMoConnection conn;
nixlSerDes sd;
string rkeyStr;
nixl_blob_t ucx_blob;
nixl_status_t status;
nixlBlobDesc input_int;

nixlUcxMoPublicMetadata *md = new nixlUcxMoPublicMetadata;

auto search = remoteConnMap.find(remote_agent);
auto search = remoteConnMap.find(agent);

if(search == remoteConnMap.end()) {
//TODO: err: remote connection not found
return NIXL_ERR_NOT_FOUND;
}
conn = (nixlUcxMoConnection) search->second;

status = sd.importStr(input.metaInfo);
status = sd.importStr(blob);

ssize_t ret = sd.getBufLen("EngIdx");
if (ret != sizeof(md->eidx)) {
return NIXL_ERR_MISMATCH;
}

status = sd.getBuf("EngIdx", &md->eidx, ret);
if (status != NIXL_SUCCESS) {
return status;
}

rkeyStr = sd.getStr("RkeyStr");
ucx_blob = sd.getStr("RkeyStr");
if (status != NIXL_SUCCESS) {
return status;
}

for (auto &e : engines) {
nixlBackendMD *int_md;
input_int.metaInfo = rkeyStr;
input_int.metaInfo = ucx_blob;
status = e->loadRemoteMD(input_int, nixl_mem,
getEngName(remote_agent, md->eidx),
getEngName(agent, md->eidx),
int_md);
if (status != NIXL_SUCCESS) {
return status;
Expand All @@ -402,6 +397,23 @@ nixlUcxMoEngine::loadRemoteMD (const nixlBlobDesc &input,
return NIXL_SUCCESS;
}

nixl_status_t
nixlUcxMoEngine::loadLocalMD(nixlBackendMD* input,
nixlBackendMD* &output)
{
nixlUcxMoPrivateMetadata* input_md = (nixlUcxMoPrivateMetadata*) input;
return internalMDHelper(input_md->rkeyStr, input_md->memType, localAgent, output);
}

nixl_status_t
nixlUcxMoEngine::loadRemoteMD (const nixlBlobDesc &input,
const nixl_mem_t &nixl_mem,
const string &remote_agent,
nixlBackendMD* &output)
{
return internalMDHelper(input.metaInfo, nixl_mem, remote_agent, output);
}

nixl_status_t
nixlUcxMoEngine::unloadMD (nixlBackendMD* input)
{
Expand Down
10 changes: 8 additions & 2 deletions src/plugins/ucx_mo/ucx_mo_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ class nixlUcxMoPrivateMetadata : public nixlBackendMD
private:
uint32_t eidx;
nixlBackendMD *md;
std::string rkeyStr;
nixl_mem_t memType;
nixl_blob_t rkeyStr;
public:
nixlUcxMoPrivateMetadata() : nixlBackendMD(true) {
}
Expand Down Expand Up @@ -133,6 +134,11 @@ class nixlUcxMoEngine : public nixlBackendEngine {
}
};

// Memory helper
nixl_status_t internalMDHelper (const nixl_blob_t &blob,
const nixl_mem_t &nixl_mem,
const std::string &agent,
nixlBackendMD* &output);

// Data transfer
nixl_status_t retHelper(nixl_status_t ret, nixlBackendEngine *eng,
Expand All @@ -143,7 +149,7 @@ class nixlUcxMoEngine : public nixlBackendEngine {
~nixlUcxMoEngine();

bool supportsRemote () const { return true; }
bool supportsLocal () const { return false; }
bool supportsLocal () const { return true; }
bool supportsNotif () const { return true; }
bool supportsProgTh () const { return pthrOn; }

Expand Down
Loading