support dist.broadcast #7956

Open · wants to merge 3 commits into master
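For context, here is a minimal usage sketch of the dist.broadcast path this PR wires up, mirroring the new test below. It is illustrative only: the _mp_fn entry point and the xmp.spawn launch are assumptions about how such a script is typically run, not part of this change.

import torch
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.runtime as xr


def _mp_fn(index):
  # One process per XLA device; the xla:// init_method wires up the group.
  dist.init_process_group("xla", init_method="xla://")
  device = xm.xla_device()

  # Every rank starts with a different value; after the broadcast, all ranks
  # should hold the value that rank 0 started with.
  t = torch.tensor([float(xr.global_ordinal())], device=device)
  dist.broadcast(t, src=0)
  torch_xla.sync()
  print(index, t.cpu())


if __name__ == "__main__":
  xmp.spawn(_mp_fn)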
42 changes: 41 additions & 1 deletion test/pjrt/test_collective_ops_tpu.py
@@ -1,4 +1,5 @@
import numpy as np
import random
from typing import List
import torch
import torch.nn as nn
@@ -139,7 +140,7 @@ def test_all_to_all(self, pin_layout):
list(range(world_size))]])


-@absltest.skipIf(lambda: tpu.num_logical_cores_per_chip() >= 2,
+@absltest.skipIf(tpu.num_logical_cores_per_chip() >= 2,
Collaborator: 🤦 thanks

"Dynamo not supported on TPU v2/v3")
class TestDistCollectiveOpsTpu(parameterized.TestCase):
"""Test for collective ops from torch.distributed"""
@@ -246,6 +247,29 @@ def callable(output, input):
assert 'xla::reduce_scatter_tensor' in met.counter_names()
return output.cpu()

  @staticmethod
  def _broadcast(src: int, random_fill: int, use_dynamo: bool):
    met.clear_all()
    dist.init_process_group("xla", init_method='xla://')
    device = xm.xla_device()

    def callable(input, src):
      dist.broadcast(input, src)
      return input

    tensor_in = torch.tensor([xr.global_ordinal(), random_fill],
                             dtype=torch.float,
                             device=device)
    f = torch.compile(callable, backend='openxla') if use_dynamo else callable
    output = f(tensor_in, src)
    torch_xla.sync()
    if not use_dynamo:
      assert 'xla::AllReduceInPlace' in met.counter_names(
      ) or 'xla::AllReduce' in met.counter_names()
    else:
      assert 'xla::collective_broadcast' in met.counter_names()
    return output.cpu()

  @parameterized.named_parameters(('dynamo', True), ('nondynamo', False))
  def test_all_reduce(self, use_dynamo):
    results = pjrt.run_multiprocess(self._all_reduce, use_dynamo=use_dynamo)
@@ -287,6 +311,22 @@ def test_reduce_scatter(self, use_dynamo):
    for index, val in results.items():
      torch.testing.assert_close(val, expected[index])

  @parameterized.named_parameters(('dynamo', True), ('nondynamo', False))
  def test_broadcast(self, use_dynamo):
    src = random.randrange(0, tpu.num_expected_global_devices())
    random_fill = random.randint(-100, 100)
    results = pjrt.run_multiprocess(
        self._broadcast,
        src=src,
        random_fill=random_fill,
        use_dynamo=use_dynamo)
    expected = torch.tensor([
        src,
        random_fill,
    ], dtype=torch.float)
    for index, val in results.items():
      torch.testing.assert_close(val, expected)


if __name__ == '__main__':
absltest.main()
2 changes: 1 addition & 1 deletion torch_xla/__init__.py
@@ -249,7 +249,7 @@ def _init_xla_lazy_backend():
from .torch_xla import *

# register all custom kenels and decomp by default
-from ._internal import custom_kernel, decomp_registration, c10d_registration
+from ._internal import custom_kernel, decomp_registration

# select default PJRT_DEVICE before any execution
runtime._maybe_select_default_device()
22 changes: 0 additions & 22 deletions torch_xla/_internal/c10d_registration.py

This file was deleted.
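The deleted torch_xla/_internal/c10d_registration.py carried the Python-level registration that the C++ lowering below replaces. As a reference for the review thread that mentions 'the original python version' and torch.no_grad(), here is a rough sketch of the mask-and-all-reduce pattern; this is a hypothetical reconstruction, not the deleted file verbatim.

import torch
import torch.distributed as dist
import torch_xla.runtime as xr


def broadcast_via_all_reduce(tensor: torch.Tensor, src: int) -> torch.Tensor:
  # Zero the tensor on every rank except `src`, then sum across ranks: every
  # rank ends up with the values that `src` started with.
  with torch.no_grad():
    mask = (torch.ones_like(tensor) if xr.global_ordinal() == src
            else torch.zeros_like(tensor))
    tensor.mul_(mask)
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
  return tensor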

22 changes: 22 additions & 0 deletions torch_xla/csrc/cross_replica_reduces.cpp
@@ -309,6 +309,28 @@ AllGatherResultCoalesced BuildAllGatherCoalesced(
return {result, token_handler.GetNewToken(result[0])};
}

// Broadcast is lowered as: zero the tensor on every replica except `src`,
// then sum across replicas with an all-reduce, so every replica ends up
// with src's values.
at::Tensor collective_broadcast(const at::Tensor& input, int64_t src,
                                std::string) {
  XLATensorPtr xinput = bridge::GetXlaTensor(input);
  TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
  at::Tensor mask;
  const torch::lazy::BackendDevice& device = xinput->GetDevice();
  if (device.ordinal() == src) {
    mask = at::ones_like(input);
Collaborator: Is there an equivalent to torch.no_grad() in C++? That's the only difference I see between the original Python version and this one.

Collaborator (Author): Searched the docs; we can use the following scope for tensor operations without grad:

  {
    at::NoGradGuard no_grad;
    // tensor operations
  }

Collaborator (Author) @zpcore, Sep 19, 2024: Does anyone know why we set no grad here?

  with torch.no_grad():

@JackCaoG
  } else {
    mask = at::zeros_like(input);
  }
  XLATensorPtr xmask = bridge::GetXlaTensor(mask);
  auto masked_input = tensor_methods::mul(xinput, xmask);
  auto result = tensor_methods::all_reduce(masked_input, AllReduceType::kSum,
                                           1.0, {}, true);
Collaborator: nit: name the non-obvious arguments at the end here. Assuming these two are scale and replica groups, /*scale=*/1, /*groups=*/{} (double check the names).

  return bridge::AtenFromXlaTensor(result);
}

TORCH_LIBRARY_IMPL(_c10d_functional, XLA, m) {
  m.impl("broadcast", collective_broadcast);
}
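The TORCH_LIBRARY_IMPL block above registers collective_broadcast as the XLA kernel behind the _c10d_functional broadcast op. Below is a minimal sketch of the dynamo path that reaches it, mirroring the _broadcast test helper; it would run inside each spawned worker process, as in the sketch near the top. In eager mode the call takes the all-reduce path instead, which is why the test checks different counters.

import torch
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm

dist.init_process_group("xla", init_method="xla://")
t = torch.ones(4, device=xm.xla_device())


def fn(tensor, src):
  dist.broadcast(tensor, src)
  return tensor


# torch.compile with the openxla backend traces dist.broadcast into the
# functional collective, which dispatches to xla::collective_broadcast.
compiled = torch.compile(fn, backend="openxla")
out = compiled(t, 0)
torch_xla.sync()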

CollectivePermuteResult BuildCollectivePermute(
    xla::XlaOp input, xla::XlaOp token,
    const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs) {