diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index a5fcf50465..534174380f 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -234,6 +234,7 @@ def initialize_ub( ranks_per_domain_list, backend=bootstrap_backend ) local_rank = torch.distributed.get_rank(intra_domain_group) + intra_domain_ranks = torch.distributed.get_process_group_ranks(intra_domain_group) inter_domain_group, _ = torch.distributed.new_subgroups_by_enumeration( [list(ranks) for ranks in zip(*ranks_per_domain_list)],