From 662f30d05ba8a1c9eac06cd284c8eb42dd689174 Mon Sep 17 00:00:00 2001 From: Brian Barrett Date: Wed, 27 Nov 2024 00:17:07 +0000 Subject: [PATCH] aws: Use regex expansion for platform names Our platform map is getting a bit absurd as we duplicate a bunch of information for every product variant (I'm looking at you, P5, P5e, and P5en). Instead, support a wild card (and really, any valid regex) in the platform data name. This impleemntation is wildly inefficient, since we have to build a new regex engine every time through the loop, but we only do this at init, and the total number of platforms isn't that big, so let's not worry about that. Signed-off-by: Brian Barrett --- src/platform-aws.c | 78 ++++++++++++++++++---------------------------- 1 file changed, 30 insertions(+), 48 deletions(-) diff --git a/src/platform-aws.c b/src/platform-aws.c index 787305fdf..d5f32e752 100644 --- a/src/platform-aws.c +++ b/src/platform-aws.c @@ -15,6 +15,7 @@ #ifdef HAVE_RDMA_FI_EXT_H #include #endif +#include #include #include "nccl_ofi.h" @@ -69,27 +70,7 @@ static struct ec2_platform_data platform_data_map[] = { .domain_per_thread = 0, }, { - .name = "p5.48xlarge", - .topology = NULL, - .default_dup_conns = 0, - .latency = 75.0, - .gdr_required = true, - .net_flush_required = false, - .default_protocol = "RDMA", - .domain_per_thread = 0, - }, - { - .name = "p5e.48xlarge", - .topology = NULL, - .default_dup_conns = 0, - .latency = 75.0, - .gdr_required = true, - .net_flush_required = false, - .default_protocol = "RDMA", - .domain_per_thread = 0, - }, - { - .name = "p5en.48xlarge", + .name = "p5.*", .topology = NULL, .default_dup_conns = 0, .latency = 75.0, @@ -109,7 +90,7 @@ static struct ec2_platform_data platform_data_map[] = { .domain_per_thread = 0, }, { - .name = "trn1.32xlarge", + .name = "trn1.*", .topology = NULL, .default_dup_conns = 0, .latency = 75.0, @@ -119,17 +100,7 @@ static struct ec2_platform_data platform_data_map[] = { .domain_per_thread = 1, }, { - .name = "trn1n.32xlarge", - .topology = NULL, - .default_dup_conns = 0, - .latency = 75.0, - .gdr_required = true, - .net_flush_required = true, - .default_protocol = "SENDRECV", - .domain_per_thread = 1, - }, - { - .name = "trn2.48xlarge", + .name = "trn2.*", .topology = NULL, .default_dup_conns = 0, .latency = 75.0, @@ -138,16 +109,6 @@ static struct ec2_platform_data platform_data_map[] = { .default_protocol = "RDMA", .domain_per_thread = 1, }, - { - .name = "trn2n.48xlarge", - .topology = NULL, - .default_dup_conns = 0, - .latency = 75.0, - .gdr_required = true, - .net_flush_required = true, - .default_protocol = "RDMA", - .domain_per_thread = 1, - } }; /* @@ -165,26 +126,47 @@ static struct ec2_platform_data *get_platform_data() static struct ec2_platform_data *platform_data = NULL; const size_t platform_n = sizeof(platform_data_map)/sizeof(platform_data_map[0]); const char* platform_type = NULL; + regex_t regex; + int ret; nccl_net_ofi_mutex_lock(&mutex); if (init) { - nccl_net_ofi_mutex_unlock(&mutex); - return platform_data; + goto done; } init = true; platform_type = nccl_net_ofi_get_product_name(); if (platform_type == NULL) { - nccl_net_ofi_mutex_unlock(&mutex); - return NULL; + platform_data = NULL; + goto done; } for (size_t idx = 0; idx < platform_n; idx++) { - if (strcmp(platform_type, platform_data_map[idx].name) == 0) + ret = regcomp(®ex, platform_data_map[idx].name, 0); + if (ret != 0) { + NCCL_OFI_WARN("Could not compile platform_type regex for %s", + platform_data_map[idx].name); + platform_data = NULL; + goto done; + } + + ret = regexec(®ex, platform_type, 0, NULL, 0); + + if (ret == 0) { platform_data = &platform_data_map[idx]; + } else if (ret != REG_NOMATCH) { + NCCL_OFI_WARN("Regex match failed"); + platform_data = NULL; + goto done; + } + + regfree(®ex); } + NCCL_OFI_TRACE(NCCL_NET | NCCL_INIT, "Using platform block %s for instance type %s", + platform_data->name, platform_type); +done: nccl_net_ofi_mutex_unlock(&mutex); return platform_data;