diff --git a/include/tracing_impl/nvtx.h b/include/tracing_impl/nvtx.h index 1fbfe8b55..232923d3a 100644 --- a/include/tracing_impl/nvtx.h +++ b/include/tracing_impl/nvtx.h @@ -7,31 +7,36 @@ #if HAVE_NVTX_TRACING #include "nvToolsExt.h" +#include #define NCCL_OFI_N_NVTX_DOMAIN_PER_COMM 8 static inline void nvtx_mark_domain(nvtxDomainHandle_t domain, const char* name, uint32_t color) { - const nvtxEventAttributes_t eventAttrib = { - .version = NVTX_VERSION, - .size = NVTX_EVENT_ATTRIB_STRUCT_SIZE, - .colorType = NVTX_COLOR_ARGB, - .color = color, - .messageType = NVTX_MESSAGE_TYPE_ASCII, - .message = { .ascii = name }, - }; + nvtxEventAttributes_t eventAttrib; + memset(&eventAttrib, 0, sizeof(nvtxEventAttributes_t)); + + eventAttrib.version = NVTX_VERSION; + eventAttrib.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE; + eventAttrib.colorType = NVTX_COLOR_ARGB; + eventAttrib.color = color; + eventAttrib.messageType = NVTX_MESSAGE_TYPE_ASCII; + eventAttrib.message = { .ascii = name }; + nvtxDomainMarkEx(domain, &eventAttrib); } static inline nvtxRangeId_t nvtx_start_domain(bool have_domain, nvtxDomainHandle_t domain, const char* name, uint32_t color) { - const nvtxEventAttributes_t eventAttrib = { - .version = NVTX_VERSION, - .size = NVTX_EVENT_ATTRIB_STRUCT_SIZE, - .colorType = NVTX_COLOR_ARGB, - .color = color, - .messageType = NVTX_MESSAGE_TYPE_ASCII, - .message = { .ascii = name }, - }; + nvtxEventAttributes_t eventAttrib; + memset(&eventAttrib, 0, sizeof(nvtxEventAttributes_t)); + + eventAttrib.version = NVTX_VERSION; + eventAttrib.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE; + eventAttrib.colorType = NVTX_COLOR_ARGB; + eventAttrib.color = color; + eventAttrib.messageType = NVTX_MESSAGE_TYPE_ASCII; + eventAttrib.message = { .ascii = name }; + if (have_domain) return nvtxDomainRangeStartEx(domain, &eventAttrib); else