Skip to content

Commit 02a1ba6

Browse files
zpcorePei Zhang
and
Pei Zhang
authored
update DTensor usage with upstream (#9079)
Co-authored-by: Pei Zhang <[email protected]>
1 parent 02c0ed9 commit 02a1ba6

File tree

4 files changed

+15
-15
lines changed

4 files changed

+15
-15
lines changed

docs/source/perf/spmd_advanced.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,14 +80,14 @@ The main use case for `XLAShardedTensor` [[RFC](https://github.com/pytorch/xla/i
8080
There is also an ongoing effort to integrate <code>XLAShardedTensor</code> into <code>DistributedTensor</code> API to support XLA backend [[RFC](https://github.com/pytorch/pytorch/issues/92909)].
8181

8282
### DTensor Integration
83-
PyTorch has prototype-released [DTensor](https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/README.md) in 2.1.
83+
PyTorch has prototype-released [DTensor](https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/README.md) since 2.1.
8484
We are integrating PyTorch/XLA SPMD into DTensor API [RFC](https://github.com/pytorch/pytorch/issues/92909). We have a proof-of-concept integration for `distribute_tensor`, which calls `mark_sharding` annotation API to shard a tensor and its computation using XLA:
8585
```python
8686
import torch
87-
from torch.distributed import DeviceMesh, Shard, distribute_tensor
87+
from torch.distributed.tensor import init_device_mesh, Shard, distribute_tensor
8888

8989
# distribute_tensor now works with `xla` backend using PyTorch/XLA SPMD.
90-
mesh = DeviceMesh("xla", list(range(world_size)))
90+
mesh = init_device_mesh("xla", mesh_shape=(world_size,))
9191
big_tensor = torch.randn(100000, 88)
9292
my_dtensor = distribute_tensor(big_tensor, mesh, [Shard(0)])
9393
```
@@ -152,15 +152,15 @@ PyTorch/XLA auto-sharding can be enabled by one of the following:
152152
import torch_xla.runtime as xr
153153
xr.use_spmd(auto=True)
154154
```
155-
- Calling `pytorch.distributed._tensor.distribute_module` with `auto-policy` and `xla`:
155+
- Calling `pytorch.distributed.tensor.distribute_module` with `auto-policy` and `xla`:
156156

157157
```python
158158
import torch_xla.runtime as xr
159-
from torch.distributed._tensor import DeviceMesh, distribute_module
159+
from torch.distributed.tensor import init_device_mesh, distribute_module
160160
from torch_xla.distributed.spmd import auto_policy
161161

162162
device_count = xr.global_runtime_device_count()
163-
device_mesh = DeviceMesh("xla", list(range(device_count)))
163+
device_mesh = init_device_mesh("xla", mesh_shape=(device_count,))
164164

165165
# Currently, model should be loaded to xla device via distribute_module.
166166
model = MyModule() # nn.module

test/spmd/test_dtensor_integration.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
import torch
55
from torch import nn
66
import torch.optim as optim
7-
from torch.distributed._tensor import (DeviceMesh, Shard, distribute_tensor,
8-
distribute_module)
7+
from torch.distributed.tensor import init_device_mesh, Shard, distribute_tensor, distribute_module
8+
99
import torch_xla
1010
import torch_xla.debug.metrics as met
1111
import torch_xla.runtime as xr
@@ -25,7 +25,7 @@ def setUpClass(cls):
2525

2626
def test_xla_distribute_tensor(self):
2727
device_count = xr.global_runtime_device_count()
28-
device_mesh = DeviceMesh("xla", list(range(device_count)))
28+
device_mesh = init_device_mesh("xla", mesh_shape=(device_count,))
2929
shard_spec = [Shard(0)]
3030

3131
for requires_grad in [True, False]:
@@ -53,7 +53,7 @@ def test_optimizer_step_with_sharding(self):
5353

5454
# Running the same mark_sharding test with xla_distribute_tensor instead
5555
device_count = xr.global_runtime_device_count()
56-
device_mesh = DeviceMesh("xla", list(range(device_count)))
56+
device_mesh = init_device_mesh("xla", mesh_shape=(device_count,))
5757
shard_spec = [Shard(0)]
5858
distribute_tensor(model.fc1.weight, device_mesh, shard_spec)
5959
sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight)
@@ -79,7 +79,7 @@ def test_xla_distribute_module(self):
7979
model = self.SimpleLinear().to(xm.xla_device())
8080

8181
device_count = xr.global_runtime_device_count()
82-
device_mesh = DeviceMesh("xla", list(range(device_count)))
82+
device_mesh = init_device_mesh("xla", mesh_shape=(device_count,))
8383

8484
def shard_params(mod_name, mod, mesh):
8585
shard_spec = [Shard(0)]

test/spmd/test_dtensor_integration2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
import torch
55
from torch import nn
66
import torch.optim as optim
7-
from torch.distributed._tensor import (DeviceMesh, Shard, distribute_tensor,
8-
distribute_module)
7+
from torch.distributed.tensor import (DeviceMesh, Shard, distribute_tensor,
8+
distribute_module)
99
import torch_xla
1010
import torch_xla.debug.metrics as met
1111
import torch_xla.runtime as xr

torch_xla/distributed/spmd/api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
import torch
99

1010
import torch.nn as nn
11-
from torch.distributed._tensor.device_mesh import DeviceMesh
12-
from torch.distributed._tensor.placement_types import Placement, Replicate
11+
from torch.distributed import DeviceMesh
12+
from torch.distributed.tensor.placement_types import Placement, Replicate
1313

1414
import torch_xla.core.xla_model as xm # type:ignore[import] # noqa: F401
1515
import torch_xla.runtime as xr # type:ignore[import]

0 commit comments

Comments
 (0)