support dist.broadcast #7956
base: master
Conversation
    XLATensorPtr xmask = bridge::GetXlaTensor(mask);
    auto masked_input = tensor_methods::mul(xinput, xmask);
    auto result = tensor_methods::all_reduce(masked_input, AllReduceType::kSum,
                                             1.0, {}, true);
nit: name the non-obvious arguments at the end here. Assuming these two are scale and replica groups, /*scale=*/1, /*groups=*/{}
(double check the names).
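A minimal sketch of the suggested annotation, assuming the trailing arguments really are the scale, the replica groups, and a pin-layout flag (names guessed from the reviewer's comment, not checked against the tensor_methods::all_reduce signature):

    // Illustrative only: argument names are assumptions, verify against the header.
    auto result = tensor_methods::all_reduce(masked_input, AllReduceType::kSum,
                                             /*scale=*/1.0, /*groups=*/{},
                                             /*pin_layout=*/true);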
@@ -139,7 +140,7 @@ def test_all_to_all(self, pin_layout):
                          list(range(world_size))]])

-  @absltest.skipIf(lambda: tpu.num_logical_cores_per_chip() >= 2,
+  @absltest.skipIf(tpu.num_logical_cores_per_chip() >= 2,
🤦 thanks
# "broadcast(Tensor self, int src, str tag, int[] ranks, int group_size) -> Tensor", | ||
@torch.library.impl("_c10d_functional::broadcast", "XLA") |
@JackCaoG FYI
  at::Tensor mask;
  const torch::lazy::BackendDevice& device = xinput->GetDevice();
  if (device.ordinal() == src) {
    mask = at::ones_like(input);
Is there an equivalent to torch.no_grad() in C++? That's the only difference I see between the original Python version and this one.
Searched the docs; we can use the following scope for tensor operations without grad:
{
at::NoGradGuard no_grad;
// tensor operations
}
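Applied to the snippet under review, the mask construction could sit inside such a scope. A minimal sketch, assuming the non-source branch fills the mask with zeros (only the ones_like branch is visible in the quoted diff):

    at::Tensor mask;
    {
      at::NoGradGuard no_grad;  // mirrors torch.no_grad() in the original Python version
      if (device.ordinal() == src) {
        mask = at::ones_like(input);
      } else {
        mask = at::zeros_like(input);  // assumption: non-source ranks contribute zeros
      }
    }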
Support torch.distributed.broadcast for both dynamo and non-dynamo.
This PR needs pytorch/pytorch#135171 to be merged first.