Merge xm all_gather patch (#3416)
* Set proper shard_count for all_gather when replica groups are non-empty.

* Update test_mp_all_gather.py
yeounoh authored Mar 9, 2022
1 parent c91b766 commit 3b12115
Showing 2 changed files with 33 additions and 7 deletions.
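
Context for the change: before this commit, `xm.all_gather` set `shard_count` to `None` whenever explicit replica groups were passed, so the lowering did not receive the per-group shard count. The snippet below is a plain-PyTorch sketch of the intended semantics, not torch_xla code; the world size and the even/odd group split are illustrative assumptions.

import torch

# Illustrative setup (not from the commit): 8 replicas split into two groups.
world_size = 8
groups = [[0, 2, 4, 6], [1, 3, 5, 7]]

# The fix in xla_model.py: shard_count is the size of one replica group when
# groups are given, and only falls back to the world size otherwise.
shard_count = len(groups[0]) if groups else world_size
assert all(len(g) == shard_count for g in groups)

# Model of a grouped all-gather: each replica contributes a 1-element tensor
# holding its ordinal, and only replicas in the same group are concatenated.
for group in groups:
  gathered = torch.cat([torch.tensor([float(n)]) for n in group])
  print(group, gathered)  # e.g. [0, 2, 4, 6] -> tensor([0., 2., 4., 6.])
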
test/test_mp_all_gather.py: 26 additions, 6 deletions
@@ -8,21 +8,41 @@
 
 def _mp_fn(index):
   device = xm.xla_device()
-  world_size = xm.xrt_world_size()
   if xm.xla_device_hw(device) in ('TPU', 'GPU'):
+    world_size = xm.xrt_world_size()
+    # Testing with a single replica group
     ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
-    result = xm.all_gather(ordinal_tensor)
+    result = xm.all_gather(ordinal_tensor, dim=0)
 
     cpu_result = result.cpu()
     expected = torch.arange(0, world_size, dtype=torch.float)
     if not cpu_result.allclose(expected):
       print('xm.all_gather() produced wrong reductions', file=sys.stderr)
-      print('[{}] {}'.format(index, cpu_result), file=sys.stderr)
+      print(f'[{index}] {cpu_result}', file=sys.stderr)
       sys.exit(1)
+
+    # Testing with two replica groups
+    if world_size % 2 == 0 and world_size > 1:
+      mp_groups = [[n for n in range(world_size) if n % 2 == 0],
+                   [n for n in range(world_size) if n % 2 == 1]]
+      group_size = len(mp_groups[0])
+      replica_id = int(index % 2 == 1)
+
+      result = xm.all_gather(ordinal_tensor, dim=0, groups=mp_groups)
+
+      cpu_result = result.cpu()
+      expected = torch.arange(replica_id, world_size, step=2, dtype=torch.float)
+      if not cpu_result.allclose(expected):
+        print('xm.all_gather() produced wrong reductions', file=sys.stderr)
+        print(f'[{index}] {cpu_result}', file=sys.stderr)
+        sys.exit(1)
+    else:
+      print(
+          f'Failed to create two replica groups with {world_size} replicas',
+          file=sys.stderr)
+
   else:
-    print(
-        'Default device {} is not a TPU or GPU device'.format(device),
-        file=sys.stderr)
+    print(f'{device} is not a TPU or GPU device', file=sys.stderr)
 
 
 if __name__ == '__main__':
torch_xla/core/xla_model.py: 7 additions, 1 deletion
@@ -603,7 +603,13 @@ def all_gather(value, dim=0, groups=None, output=None):
   if dim < 0:
     dim = value.dim() + dim
   token, devctx = _get_all_reduce_token()
-  shard_count = None if groups else xrt_world_size()
+  if groups:
+    shard_count = len(groups[0])
+    assert all(len(group) == shard_count for group in groups), \
+      "Replica groups must have the same number of replicas/shards."
+  else:
+    # All replicas belong to a single group
+    shard_count = xrt_world_size()
   if output != None:
     # Call the out of place version of the all_gather
     new_token = torch_xla._XLAC._xla_all_gather_out(output, value, token, dim,
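
For reference, a minimal usage sketch of the patched call path, assuming a TPU or GPU host with an even number of devices; the worker name and the even/odd group layout are illustrative and not part of the commit.

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _worker(index):
  device = xm.xla_device()
  world_size = xm.xrt_world_size()
  # Even/odd split into two replica groups (illustrative).
  groups = [[n for n in range(world_size) if n % 2 == 0],
            [n for n in range(world_size) if n % 2 == 1]]
  value = torch.tensor([index], dtype=torch.float).to(device)
  # With this commit, shard_count resolves to len(groups[0]), so the result
  # holds one entry per device in this replica's group rather than world_size.
  gathered = xm.all_gather(value, dim=0, groups=groups)
  print(index, gathered.cpu())


if __name__ == '__main__':
  xmp.spawn(_worker)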
