fp8_model_init doesn't work with DDP #1135
Do you need both DDP and FP8 params for your use case? We haven't considered this combination so far, since optimizing FP8 params tends to have poor convergence. There are a few ways to proceed:
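For context on the convergence point above, the more common recipe avoids the combination entirely: parameters stay in higher precision (so DDP sees ordinary initialized tensors) and only the compute runs in FP8 under `fp8_autocast`. A minimal single-process sketch of that pattern (an illustration only, not one of the options referred to above; the layer sizes and single-rank process group are assumptions):

```python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

import transformer_engine as te

# Single-rank process group so DDP can be constructed (assumption for this sketch)
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "12355")
dist.init_process_group(backend="nccl", rank=0, world_size=1)
torch.cuda.set_device(0)

# No fp8_model_init: parameters are regular full-precision tensors, so DDP works as usual
model = DDP(te.pytorch.Linear(256, 1024))
optim = torch.optim.Adam(model.parameters())

x = torch.randn(512, 2, 256, device="cuda")
with te.pytorch.fp8_autocast(enabled=True):  # FP8 is used for the GEMM compute only
    out = model(x)
out.sum().backward()
optim.step()

dist.destroy_process_group()
```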
@MaciejBalaNV Transformer Engine modules that are initialized under `fp8_model_init` […]

For reference, here's a modified version of your DDP example that works correctly on my end:

```python
import os
import socket

import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

import transformer_engine as te


class BasicMLP(nn.Module):
    """Basic MLP block"""

    def __init__(self, hidden_size, ffn_hidden_size, **kwargs):
        super().__init__()
        tp_group = kwargs.pop("tp_group", None)
        parallel_mode = kwargs.pop("parallel_mode", None)
        fc1_parallel_mode = fc2_parallel_mode = parallel_mode
        if tp_group is not None:
            fc1_parallel_mode = "row"
            fc2_parallel_mode = "column"
        self.fc1 = te.pytorch.Linear(
            hidden_size, ffn_hidden_size, parallel_mode=fc1_parallel_mode, **kwargs
        )
        self.fc2 = te.pytorch.Linear(
            ffn_hidden_size, hidden_size, parallel_mode=fc2_parallel_mode, **kwargs
        )

    def forward(self, x):
        """Forward pass: FC2(act_fn(FC1(x)))"""
        return self.fc2(self.fc1(x))


def _ddp_main(rank, world_size, num_replicas):
    SEQ_LENGTH = 512
    BATCH_SIZE = 2
    HIDDEN_SIZE = 256
    FFN_HIDDEN_SIZE = 4 * HIDDEN_SIZE

    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = socket.gethostname()
    os.environ["MASTER_PORT"] = "12345"
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(rank)

    # Choose data-parallel and tensor-parallel process groups based on the replica layout
    if num_replicas == 1:
        dp_group = None
        tp_group = dist.new_group()
    elif num_replicas == world_size:
        dp_group = dist.new_group()
        tp_group = None
    else:
        assert num_replicas > 0 and num_replicas < world_size and world_size % num_replicas == 0
        replica_size = world_size // num_replicas
        mesh_2d = dist.init_device_mesh("cuda", (num_replicas, replica_size))
        dp_group, tp_group = mesh_2d.get_all_groups()

    # Construct the model with FP8 parameters, then wrap the data-parallel replicas in DDP
    with te.pytorch.fp8.fp8_model_init(enabled=True):
        model = BasicMLP(HIDDEN_SIZE, FFN_HIDDEN_SIZE, tp_group=tp_group)
    if dp_group is not None:
        model = DDP(model, process_group=dp_group)
    optim = torch.optim.Adam(model.parameters())

    # Training loop with FP8 compute via fp8_autocast
    for _ in range(10):
        optim.zero_grad()
        input_data = torch.randn((SEQ_LENGTH, BATCH_SIZE, HIDDEN_SIZE), device="cuda")
        with te.pytorch.fp8_autocast(enabled=True):
            output = model(input_data)
        loss = output.sum()
        loss.backward()
        optim.step()

    dist.destroy_process_group()


if __name__ == "__main__":
    NUM_REPLICAS = 2
    if "TORCHELASTIC_RUN_ID" in os.environ:
        # Using the `torchrun` utility
        WORLD_RANK = int(os.getenv("RANK", "0"))
        WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
        _ddp_main(WORLD_RANK, WORLD_SIZE, NUM_REPLICAS)
    else:
        WORLD_SIZE = 8
        mp.spawn(_ddp_main, args=(WORLD_SIZE, 2), nprocs=WORLD_SIZE, join=True)
```
@timmoon10 @denera
When I try to use the `fp8_model_init` feature, it doesn't seem compatible with DDP. It throws an error:

RuntimeError: Modules with uninitialized parameters can't be used with `DistributedDataParallel`. Run a dummy forward pass to correctly initialize the modules

Running a dummy forward pass doesn't help, and using `reset_parameters` doesn't help either. Using a separate stream for DDP also does not fix the issue. A simple reproducible case:
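The snippet referenced here is not preserved in this capture. A minimal sketch of the combination being described (not the author's original code; the single-rank process group and layer sizes are assumptions):

```python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

import transformer_engine as te

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "12355")
dist.init_process_group(backend="nccl", rank=0, world_size=1)
torch.cuda.set_device(0)

# Parameters created under fp8_model_init are allocated as FP8 tensors...
with te.pytorch.fp8.fp8_model_init(enabled=True):
    model = te.pytorch.Linear(256, 256)

# ...and wrapping the module directly in DDP is the step the issue reports as raising
# "Modules with uninitialized parameters can't be used with DistributedDataParallel"
model = DDP(model)

dist.destroy_process_group()
```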
@denera