Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support dist.broadcast #7956

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open

support dist.broadcast #7956

wants to merge 3 commits into from

Conversation

zpcore
Copy link
Collaborator

@zpcore zpcore commented Sep 5, 2024

Support torch.distributed.broadcast for both dynamo and nondynamo.

This PR needs pytorch/pytorch#135171 to be merged first.

@zpcore zpcore added usability Bugs/features related to improving the usability of PyTorch/XLA tpuci labels Sep 5, 2024
@zpcore zpcore marked this pull request as ready for review September 5, 2024 02:04
XLATensorPtr xmask = bridge::GetXlaTensor(mask);
auto masked_input = tensor_methods::mul(xinput, xmask);
auto result = tensor_methods::all_reduce(masked_input, AllReduceType::kSum,
1.0, {}, true);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: name the non-obvious arguments at the end here. Assuming these two are scale and replica groups, /*scale=*/1, /*groups=*/{} (double check the names).

@@ -139,7 +140,7 @@ def test_all_to_all(self, pin_layout):
list(range(world_size))]])


@absltest.skipIf(lambda: tpu.num_logical_cores_per_chip() >= 2,
@absltest.skipIf(tpu.num_logical_cores_per_chip() >= 2,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤦 thanks



# "broadcast(Tensor self, int src, str tag, int[] ranks, int group_size) -> Tensor",
@torch.library.impl("_c10d_functional::broadcast", "XLA")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JackCaoG FYI

at::Tensor mask;
const torch::lazy::BackendDevice& device = xinput->GetDevice();
if (device.ordinal() == src) {
mask = at::ones_like(input);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there an equivalent to torch.no_grad() in C++? That's the only difference I see between the original python version and this one

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Searched the doc and we can use the following scope for tensor operation without grad:

  {
    at::NoGradGuard no_grad;
    // tensor operations
   }

Copy link
Collaborator Author

@zpcore zpcore Sep 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Anyone knows why we set no grad here:

with torch.no_grad():
@JackCaoG

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
tpuci usability Bugs/features related to improving the usability of PyTorch/XLA
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants