diff --git a/src/platform-aws.c b/src/platform-aws.c index 787305fdf..d5f32e752 100644 --- a/src/platform-aws.c +++ b/src/platform-aws.c @@ -15,6 +15,7 @@ #ifdef HAVE_RDMA_FI_EXT_H #include #endif +#include #include #include "nccl_ofi.h" @@ -69,27 +70,7 @@ static struct ec2_platform_data platform_data_map[] = { .domain_per_thread = 0, }, { - .name = "p5.48xlarge", - .topology = NULL, - .default_dup_conns = 0, - .latency = 75.0, - .gdr_required = true, - .net_flush_required = false, - .default_protocol = "RDMA", - .domain_per_thread = 0, - }, - { - .name = "p5e.48xlarge", - .topology = NULL, - .default_dup_conns = 0, - .latency = 75.0, - .gdr_required = true, - .net_flush_required = false, - .default_protocol = "RDMA", - .domain_per_thread = 0, - }, - { - .name = "p5en.48xlarge", + .name = "p5.*", .topology = NULL, .default_dup_conns = 0, .latency = 75.0, @@ -109,7 +90,7 @@ static struct ec2_platform_data platform_data_map[] = { .domain_per_thread = 0, }, { - .name = "trn1.32xlarge", + .name = "trn1.*", .topology = NULL, .default_dup_conns = 0, .latency = 75.0, @@ -119,17 +100,7 @@ static struct ec2_platform_data platform_data_map[] = { .domain_per_thread = 1, }, { - .name = "trn1n.32xlarge", - .topology = NULL, - .default_dup_conns = 0, - .latency = 75.0, - .gdr_required = true, - .net_flush_required = true, - .default_protocol = "SENDRECV", - .domain_per_thread = 1, - }, - { - .name = "trn2.48xlarge", + .name = "trn2.*", .topology = NULL, .default_dup_conns = 0, .latency = 75.0, @@ -138,16 +109,6 @@ static struct ec2_platform_data platform_data_map[] = { .default_protocol = "RDMA", .domain_per_thread = 1, }, - { - .name = "trn2n.48xlarge", - .topology = NULL, - .default_dup_conns = 0, - .latency = 75.0, - .gdr_required = true, - .net_flush_required = true, - .default_protocol = "RDMA", - .domain_per_thread = 1, - } }; /* @@ -165,26 +126,47 @@ static struct ec2_platform_data *get_platform_data() static struct ec2_platform_data *platform_data = NULL; const size_t platform_n = sizeof(platform_data_map)/sizeof(platform_data_map[0]); const char* platform_type = NULL; + regex_t regex; + int ret; nccl_net_ofi_mutex_lock(&mutex); if (init) { - nccl_net_ofi_mutex_unlock(&mutex); - return platform_data; + goto done; } init = true; platform_type = nccl_net_ofi_get_product_name(); if (platform_type == NULL) { - nccl_net_ofi_mutex_unlock(&mutex); - return NULL; + platform_data = NULL; + goto done; } for (size_t idx = 0; idx < platform_n; idx++) { - if (strcmp(platform_type, platform_data_map[idx].name) == 0) + ret = regcomp(®ex, platform_data_map[idx].name, 0); + if (ret != 0) { + NCCL_OFI_WARN("Could not compile platform_type regex for %s", + platform_data_map[idx].name); + platform_data = NULL; + goto done; + } + + ret = regexec(®ex, platform_type, 0, NULL, 0); + + if (ret == 0) { platform_data = &platform_data_map[idx]; + } else if (ret != REG_NOMATCH) { + NCCL_OFI_WARN("Regex match failed"); + platform_data = NULL; + goto done; + } + + regfree(®ex); } + NCCL_OFI_TRACE(NCCL_NET | NCCL_INIT, "Using platform block %s for instance type %s", + platform_data->name, platform_type); +done: nccl_net_ofi_mutex_unlock(&mutex); return platform_data;