From 9d91451647ef52e76d48b45a2a373e17266c32be Mon Sep 17 00:00:00 2001 From: Sheng Fu Date: Mon, 15 Apr 2024 18:30:56 -0700 Subject: [PATCH] Fixed comm parser issues Summary: This diff fixes the following two comm parser issues: 1. process_group:init renamed backend_id to uid 2. record_param_comms changed its number of inputs from 8 to 10 Differential Revision: D56091619 --- train/comms/pt/commsTraceParser.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train/comms/pt/commsTraceParser.py b/train/comms/pt/commsTraceParser.py index 5e1f3505..46d3adbe 100644 --- a/train/comms/pt/commsTraceParser.py +++ b/train/comms/pt/commsTraceParser.py @@ -233,7 +233,7 @@ def _parseExecutionTrace( break for pg in pgObj: - backendId = pg["backend_id"] + backendId = pg["uid"] if "uid" in pg else pg["backend_id"] ranks = pg["ranks"] if isinstance(ranks, list): pgId = int(pg["pg_name"]) @@ -256,7 +256,7 @@ def _parseExecutionTrace( for node in in_trace.nodes.values(): if node.name == "record_param_comms": shift = ( - 0 if len(node.inputs) == 8 else 1 + 0 if len(node.inputs) == 8 or len(node.inputs) == 10 else 1 ) # wait/barrier ops do not have an input tensor (len=7), shift index one over newComm = commsArgs() newComm.id = node.id