Skip to content

Commit

Permalink
aws: Use regex expansion for platform names
Browse files Browse the repository at this point in the history
Our platform map is getting a bit absurd as we duplicate a bunch
of information for every product variant (I'm looking at you,
P5, P5e, and P5en).  Instead, support a wild card (and really,
any valid regex) in the platform data name.  This impleemntation is
wildly inefficient, since we have to build a new regex engine every
time through the loop, but we only do this at init, and the total
number of platforms isn't that big, so let's not worry about that.

Signed-off-by: Brian Barrett <[email protected]>
  • Loading branch information
bwbarrett committed Nov 27, 2024
1 parent bcb2e96 commit 662f30d
Showing 1 changed file with 30 additions and 48 deletions.
78 changes: 30 additions & 48 deletions src/platform-aws.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#ifdef HAVE_RDMA_FI_EXT_H
#include <rdma/fi_ext.h>
#endif
#include <regex.h>
#include <dlfcn.h>

#include "nccl_ofi.h"
Expand Down Expand Up @@ -69,27 +70,7 @@ static struct ec2_platform_data platform_data_map[] = {
.domain_per_thread = 0,
},
{
.name = "p5.48xlarge",
.topology = NULL,
.default_dup_conns = 0,
.latency = 75.0,
.gdr_required = true,
.net_flush_required = false,
.default_protocol = "RDMA",
.domain_per_thread = 0,
},
{
.name = "p5e.48xlarge",
.topology = NULL,
.default_dup_conns = 0,
.latency = 75.0,
.gdr_required = true,
.net_flush_required = false,
.default_protocol = "RDMA",
.domain_per_thread = 0,
},
{
.name = "p5en.48xlarge",
.name = "p5.*",
.topology = NULL,
.default_dup_conns = 0,
.latency = 75.0,
Expand All @@ -109,7 +90,7 @@ static struct ec2_platform_data platform_data_map[] = {
.domain_per_thread = 0,
},
{
.name = "trn1.32xlarge",
.name = "trn1.*",
.topology = NULL,
.default_dup_conns = 0,
.latency = 75.0,
Expand All @@ -119,17 +100,7 @@ static struct ec2_platform_data platform_data_map[] = {
.domain_per_thread = 1,
},
{
.name = "trn1n.32xlarge",
.topology = NULL,
.default_dup_conns = 0,
.latency = 75.0,
.gdr_required = true,
.net_flush_required = true,
.default_protocol = "SENDRECV",
.domain_per_thread = 1,
},
{
.name = "trn2.48xlarge",
.name = "trn2.*",
.topology = NULL,
.default_dup_conns = 0,
.latency = 75.0,
Expand All @@ -138,16 +109,6 @@ static struct ec2_platform_data platform_data_map[] = {
.default_protocol = "RDMA",
.domain_per_thread = 1,
},
{
.name = "trn2n.48xlarge",
.topology = NULL,
.default_dup_conns = 0,
.latency = 75.0,
.gdr_required = true,
.net_flush_required = true,
.default_protocol = "RDMA",
.domain_per_thread = 1,
}
};

/*
Expand All @@ -165,26 +126,47 @@ static struct ec2_platform_data *get_platform_data()
static struct ec2_platform_data *platform_data = NULL;
const size_t platform_n = sizeof(platform_data_map)/sizeof(platform_data_map[0]);
const char* platform_type = NULL;
regex_t regex;
int ret;

nccl_net_ofi_mutex_lock(&mutex);

if (init) {
nccl_net_ofi_mutex_unlock(&mutex);
return platform_data;
goto done;
}
init = true;

platform_type = nccl_net_ofi_get_product_name();
if (platform_type == NULL) {
nccl_net_ofi_mutex_unlock(&mutex);
return NULL;
platform_data = NULL;
goto done;
}

for (size_t idx = 0; idx < platform_n; idx++) {
if (strcmp(platform_type, platform_data_map[idx].name) == 0)
ret = regcomp(&regex, platform_data_map[idx].name, 0);
if (ret != 0) {
NCCL_OFI_WARN("Could not compile platform_type regex for %s",
platform_data_map[idx].name);
platform_data = NULL;
goto done;
}

ret = regexec(&regex, platform_type, 0, NULL, 0);

if (ret == 0) {
platform_data = &platform_data_map[idx];
} else if (ret != REG_NOMATCH) {
NCCL_OFI_WARN("Regex match failed");
platform_data = NULL;
goto done;
}

regfree(&regex);
}

NCCL_OFI_TRACE(NCCL_NET | NCCL_INIT, "Using platform block %s for instance type %s",
platform_data->name, platform_type);
done:
nccl_net_ofi_mutex_unlock(&mutex);

return platform_data;
Expand Down

0 comments on commit 662f30d

Please sign in to comment.