Skip to content

Commit

Permalink
added veDeviceMesh (#32)
Browse files Browse the repository at this point in the history
This PR introduces veDeviceMesh, the device mesh API that integrates
the handling of submeshes and process groups when training with
DDP, TP/SP, the distributed optimizer, and checkpointing. It also adds
the fixes and patches related to the veDeviceMesh API that have
accumulated in the repository since the last PR.
  • Loading branch information
MackZackA authored Apr 23, 2024
1 parent 2f2daaa commit 97735b1
Show file tree
Hide file tree
Showing 145 changed files with 2,768 additions and 273 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ from HuggingFace without any model code modifications.

### Single Machine 8 cards
```
torchrun --nproc-per-node=8 --nnodes=1 --master-port=42516 -- python/example/mixtral_4D_benchmark/mixtral_train.py --num_hidden_layers=16
torchrun --nproc-per-node=8 --nnodes=1 --master-port=42516 -- examples/mixtral_4D_benchmark/mixtral_train.py --num_hidden_layers=16
```
This will start an 8-card MFU benchmark for Mixtral with veScale, using dp=1 and tp=8.

### Distributed Environment (2 Machine 16 cards example)
```
# You may need to pull up a suitable distributed cluster environment
torchrun --nproc-per-node=8 --nnodes=1 python/example/mixtral_4D_benchmark/mixtral_train.py --tp 8 --dp 2
torchrun --nproc-per-node=8 --nnodes=1 examples/mixtral_4D_benchmark/mixtral_train.py --tp 8 --dp 2
```
This will start a 16-card MFU benchmark for Mixtral with veScale, using dp=2 and tp=8.

Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@

import numpy as np
import torch
from torch.distributed import broadcast, all_reduce, barrier, init_process_group, destroy_process_group
from torch.distributed import broadcast, all_reduce, barrier, init_process_group, destroy_process_group, get_rank

from model import GPTConfig, GPT
from vescale.devicemesh_api.device_mesh_api import veDeviceMesh

from vescale.dtensor.device_mesh import init_device_mesh
from vescale import distribute_tensor
from vescale.dmodule.api import parallelize_module
from vescale.dtensor.placement_types import Replicate
Expand Down Expand Up @@ -113,8 +113,9 @@ def main():
device = f"cuda:{rank}"
torch.cuda.set_device(device)
init_process_group(backend=backend, world_size=world_size, rank=rank)
mesh = init_device_mesh(device, (dp_size, tp_size), mesh_dim_names=["DP", "TP"])
ddp_rank = mesh.get_rank() // tp_size

mesh = veDeviceMesh.init_device_mesh(device, (dp_size, tp_size), mesh_dim_names=["DP", "TP"])
ddp_rank = get_rank() // tp_size
else:
rank = 0
ddp_rank = 0
Expand Down Expand Up @@ -329,8 +330,7 @@ def get_lr(it):
# Load checkpoint
if load_checkpoint_path:
checkpoint_state = {"model": model, "optimizer": optimizer}
with mesh:
vescale.checkpoint.load(load_checkpoint_path, checkpoint_state)
vescale.checkpoint.load(load_checkpoint_path, checkpoint_state)
# training loop
X, Y = get_batch("train") # fetch the very first batch
t0 = time.time()
Expand Down Expand Up @@ -363,8 +363,7 @@ def get_lr(it):
# When iter_num == 0, training has not started yet, so the optimizer state is empty;
# don't save a checkpoint in that case.
checkpoint_state = {"model": model, "optimizer": optimizer}
with mesh:
vescale.checkpoint.save(os.path.join(save_checkpoint_path, f"iter_{iter_num}"), checkpoint_state)
vescale.checkpoint.save(os.path.join(save_checkpoint_path, f"iter_{iter_num}"), checkpoint_state)
if iter_num == 0 and eval_only:
break

Expand Down
File renamed without changes.
File renamed without changes.
3 changes: 1 addition & 2 deletions python/requirements.txt → requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,4 @@ tqdm
optree
accelerate
transformers==4.37.2
grpcio
grpcio-tools
flash_attn
10 changes: 6 additions & 4 deletions scripts/run_test.sh
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
#!/bin/bash

echo "run all tests (for open source)"

set -ex

SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"

pushd "$SCRIPT_DIR"/..

# install vescale
pushd python && pip3 install -r requirements.txt --cache-dir "${HOME}"/.cache/pip && pip3 install -e . && popd
pip3 install -r requirements.txt --cache-dir "${HOME}"/.cache/pip && pip3 install -e .

# jump to test folder
pushd test/

PYTHONPATH=$(pwd):$PYTHONPATH

export PYTHONPATH
export PYTHONPATH=$(pwd):$PYTHONPATH
export VESCALE_SINGLE_DEVICE_RAND="1"

# run test
while IFS= read -r -d '' file
Expand Down
File renamed without changes.
27 changes: 15 additions & 12 deletions test/checkpoint/common_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import math

from vescale.dtensor.placement_types import Replicate, Shard
from vescale.dtensor.device_mesh import init_device_mesh
from vescale.dmodule.api import parallelize_module
from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP
from vescale.optim.distributed_optimizer import DistributedOptimizer
Expand Down Expand Up @@ -121,9 +120,17 @@ def build_gpt_model_optimizer_and_dataset(init_method, dp_size=1, tp_size=1):
else:
gpt = GPT.from_pretrained(init_method, dict(dropout=0.0)).bfloat16()

device_mesh = init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("DP", "TP"))
device_mesh.__enter__()

open_source = False
try:
from vescale.devicemesh_api import veDeviceMesh
except ImportError:
open_source = True
device_mesh = veDeviceMesh.init_device_mesh(
device_type="cuda",
mesh_shape=(dp_size, tp_size),
mesh_dim_names=("DP", "TP"),
)

# Enable tensor Parallel
tp_gpt = parallelize_module(gpt, device_mesh["TP"], nanoGPT_plan)

Expand Down Expand Up @@ -273,15 +280,11 @@ def get_open_llama_model(layer_number=None):


def get_open_llama_model_optimizer(dp_size, tp_size, layer_number=None):
device_mesh = init_device_mesh(
"cuda",
(
dp_size,
tp_size,
),
mesh_dim_names=("DP", "TP"),
from vescale.devicemesh_api import veDeviceMesh

device_mesh = veDeviceMesh.init_device_mesh(
"cuda", (dp_size, tp_size), mesh_dim_names=("DP", "TP"), check_uniqueness=True
)
device_mesh.__enter__()
# Set 4 layers to avoid timeout on CI
# Use 32 layers when running on training platform
vescale_decoder, config = get_open_llama_model(layer_number=layer_number)
Expand Down
12 changes: 10 additions & 2 deletions test/checkpoint/nano_gpt/test_nano_gpt_load_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import torch.distributed as dist
from common_dtensor import DTensorTestBase, with_comms, skip_unless_torch_gpu
from torch.testing._internal.common_utils import run_tests
from vescale.dtensor.device_mesh import mesh_resources
from vescale.devicemesh_api.device_mesh_api import veDeviceMesh
import vescale
from vescale.dtensor.placement_types import Replicate

Expand Down Expand Up @@ -46,7 +46,15 @@ def test_save(self):
ddp_gpt, dist_optimizer, data_set = build_gpt_model_optimizer_and_dataset(
self.init_method, dp_size=2, tp_size=2
)
device_mesh = mesh_resources.get_current_mesh()

# turn off 'check_uniqueness' to allow multiple updates of global device mesh during runtime
device_mesh = veDeviceMesh.init_device_mesh(
device_type="cuda",
mesh_shape=(1, 2, 2),
mesh_dim_names=("PP", "DP", "TP"),
)
tp_sub_mesh = device_mesh["TP"]

# Do fwd+bwd+step on the first data
for X, Y in data_set[:1]:
input = vescale.distribute_tensor(X, device_mesh["TP"], [Replicate()])
Expand Down
17 changes: 4 additions & 13 deletions test/checkpoint/open_llama/test_open_llama_dp_reshard.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@
import torch
import torch.distributed as dist
from torch.testing._internal.common_utils import run_tests

from vescale.devicemesh_api import veDeviceMesh
from common_dtensor import DTensorTestBase, with_comms
from vescale.dtensor.device_mesh import mesh_resources
import vescale

from ..common_func import merge_optimizer_states, get_open_llama_model_optimizer
Expand Down Expand Up @@ -54,17 +53,12 @@ def test_open_llama2_with_ddp(self):

ckpt_state = {"model": ddp_decoder, "optimizer": ve_optimizer}
vescale.checkpoint.save(TMP_CKPT_DIR, ckpt_state)
device_mesh = mesh_resources.get_current_mesh()
dp_device_mesh = device_mesh["DP"]
dp_process_group = dp_device_mesh.get_dim_groups(0)
tp_device_mesh = device_mesh["TP"]
tp_process_group = tp_device_mesh.get_dim_groups(0)
# For processes with dp_rank = 0, dump model state_dict
if dist.get_rank(dp_process_group) == 0:
if veDeviceMesh.get_data_parallel_rank() == 0:
dumped_model_sd = {}
for k, v in ddp_decoder.state_dict().items():
dumped_model_sd[k] = v._local_tensor
torch.save(dumped_model_sd, f"open_llama_dp_reshard_model_tp_{dist.get_rank(tp_process_group)}.pt")
torch.save(dumped_model_sd, f"open_llama_dp_reshard_model_tp_{veDeviceMesh.get_tensor_parallel_rank()}.pt")

# Save merged optimizer state dict
optimizer_state = ve_optimizer.state_dict()
Expand All @@ -90,12 +84,9 @@ def test_open_llama2_with_ddp(self):

ckpt_state = {"model": ddp_decoder, "optimizer": ve_optimizer}
vescale.checkpoint.load(TMP_CKPT_DIR, ckpt_state)
device_mesh = mesh_resources.get_current_mesh()
tp_device_mesh = device_mesh["TP"]
tp_process_group = tp_device_mesh.get_dim_groups(0)
# Load model state dict and verify it
dumped_model_sd = torch.load(
f"open_llama_dp_reshard_model_tp_{dist.get_rank(tp_process_group)}.pt", map_location="cpu"
f"open_llama_dp_reshard_model_tp_{veDeviceMesh.get_tensor_parallel_rank()}.pt", map_location="cpu"
)

current_model_sd = ddp_decoder.state_dict()
Expand Down
9 changes: 7 additions & 2 deletions test/dmodule/test_dfactory.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from vescale.dmodule import _factory
from vescale.dmodule.api import parallelize_module
from vescale.dmodule.placements_interface import PlacementsInterface as PI
from vescale.dtensor.random import manual_seed

HIDDEN_SIZE = 4

Expand Down Expand Up @@ -111,8 +112,12 @@ def _match_factory_dfactory(self, factory, dfactory, global_shape, placements, d
actuals = (actual,)
goldens = (golden,)
elif factory in [torch.zeros, torch.ones, torch.empty, torch.randn]:
if factory == torch.randn:
manual_seed(0, device_mesh)
with _factory.FactoryDispatchMode(device_mesh=device_mesh, aten_dfactory_pi=aten_dfactory_pi):
actual = factory(global_shape, dtype=dtype, layout=layout, requires_grad=requires_grad)
if factory == torch.randn:
manual_seed(0, device_mesh)
golden = dfactory(
global_shape,
dtype=dtype,
Expand All @@ -129,7 +134,7 @@ def _match_factory_dfactory(self, factory, dfactory, global_shape, placements, d
for actual, golden in zip(actuals, goldens):
self.assertTrue(isinstance(actual, DTensor))
self.assertTrue(isinstance(golden, DTensor))
if factory in [torch.empty, torch.randn]: # TODO: fix torch.rand to equal
if factory in [torch.empty]: # TODO: fix torch.rand to equal
is_match = dtensor._utils._equal_meta_data(actual, golden, exact_device=True)
else:
is_match = dtensor.equal(actual, golden)
Expand All @@ -155,7 +160,7 @@ def test_match_factory_dfactory(self):

# self._seeding()
for factory, dfactory in factory_dfactory.items():
for global_shape in [(4, 4), (5, 4)]:
for global_shape in [(4, 4), (5, 4), (5, 7, 9)]:
for placements in ([Replicate()], [Shard(0)]):
self._match_factory_dfactory(factory, dfactory, global_shape, placements, device_mesh)

Expand Down
27 changes: 27 additions & 0 deletions test/dtensor/general/test_dtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed.fake_pg import FakeStore

from vescale.dtensor import rand as dtensor_rand
from vescale import DeviceMesh, DTensor, distribute_tensor
from vescale.dtensor.placement_types import Partial, Replicate, Shard
from vescale.dtensor.random import manual_seed


class DTensorTest(DTensorTestBase):
Expand Down Expand Up @@ -96,6 +98,14 @@ def test_dtensor_stride(self):
global_stride = (8 * self.world_size, 1, 32 * self.world_size)
self.assertEqual(dist_tensor.stride(), global_stride)

local_tensor = torch.randn(1, 0, 24, 128)
dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard1_spec)
self.assertEqual(dist_tensor.stride(), (24 * 128, 24 * 128, 128, 1))

local_tensor = torch.randn(1, 24, 1, 128)
dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard1_spec)
self.assertEqual(dist_tensor.stride(), (24 * 128 * self.world_size, 128, 128, 1))

@with_comms
def test_from_local_default(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
Expand Down Expand Up @@ -637,6 +647,23 @@ def test_default_value_sub_mesh(self):
[dt.to_local() for dt in dtensor_list],
)

@with_comms
def test_random_sub_mesh(self):
mesh = DeviceMesh(self.device_type, [0, 2])
global_shape = [7, 9]
placements = [Shard(0)]
torch.manual_seed(0)
torch.cuda.manual_seed(0)
expected_tensor = torch.rand(global_shape, device=self.device_type)
dist_expected = distribute_tensor(expected_tensor, mesh, placements)
print(f"rank {dist.get_rank()} expected_local {dist_expected.to_local()}")

# create DTensor
manual_seed(0, mesh)
ve_tensor = dtensor_rand(global_shape, device_mesh=mesh, placements=placements)

self.sub_mesh_assert_equal(mesh.mesh, dist_expected.to_local(), dist_expected.to_local(), ve_tensor.to_local())

@with_comms
def test_redistribute_sub_mesh(self):
mesh = DeviceMesh(self.device_type, [0, 2])
Expand Down
26 changes: 24 additions & 2 deletions test/dtensor/general/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,30 @@ def _rand_init_compare(self, init_op, dist_init_op, *args, **kwargs):
@skip_unless_torch_gpu
@with_comms
def test_randn_value(self):
self._rand_init_self_compare(dtensor.randn)
# self._rand_init_compare(torch.randn, dtensor.randn) # NOTE: Upstream doesn't match
device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
torch_op = torch.randn
dtensor_op = dtensor.randn
for global_shape in [
(8,),
(8, 9),
(9, 10, 11, 12),
(33, 11, 13),
]:
all_placements = [[Replicate()], [Partial()]] + [[Shard(d)] for d in range(len(global_shape))]
for placements in all_placements:
torch.manual_seed(0)
torch.cuda.manual_seed(0)
expected_tensor = torch_op(global_shape, device=self.device_type)
dist_expected = distribute_tensor(expected_tensor, device_mesh, placements)

# create DTensor
manual_seed(0, device_mesh)
ve_tensor = dtensor_op(global_shape, device_mesh=device_mesh, placements=placements)

self.assertEqual(ve_tensor.to_local(), dist_expected.to_local(), atol=0.0, rtol=0.0)
global_tensor = ve_tensor.full_tensor()
expected_tensor = dist_expected.full_tensor()
self.assertEqual(global_tensor, expected_tensor, atol=0.0, rtol=0.0)

@with_comms
def test_arange(self):
Expand Down
2 changes: 1 addition & 1 deletion test/dtensor/ops/test_basic_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def test_mm_2d_mesh(self):
([Shard(0), Replicate()], [Replicate(), Replicate()], [Shard(0), Replicate()]),
([Replicate(), Replicate()], [Shard(1), Replicate()], [Shard(1), Replicate()]),
([Replicate(), Replicate()], [Replicate(), Replicate()], [Replicate(), Replicate()]),
# TODO(cery.di) : support 2d/3d mesh strategy mapping
# TODO : support 2d/3d mesh strategy mapping
([Replicate(), Shard(1)], [Replicate(), Shard(0)], [Replicate(), Partial()]),
)

Expand Down
Loading

0 comments on commit 97735b1

Please sign in to comment.