diff --git a/train/comms/pt/commsTraceParser.py b/train/comms/pt/commsTraceParser.py index 5e1f3505..46d3adbe 100644 --- a/train/comms/pt/commsTraceParser.py +++ b/train/comms/pt/commsTraceParser.py @@ -233,7 +233,7 @@ def _parseExecutionTrace( break for pg in pgObj: - backendId = pg["backend_id"] + backendId = pg["uid"] if "uid" in pg else pg["backend_id"] ranks = pg["ranks"] if isinstance(ranks, list): pgId = int(pg["pg_name"]) @@ -256,7 +256,7 @@ def _parseExecutionTrace( for node in in_trace.nodes.values(): if node.name == "record_param_comms": shift = ( - 0 if len(node.inputs) == 8 else 1 + 0 if len(node.inputs) == 8 or len(node.inputs) == 10 else 1 ) # wait/barrier ops do not have an input tensor (len=7), shift index one over newComm = commsArgs() newComm.id = node.id