This repository aims to enable MPI-backed DistributedDataParallel (DDP) training in PyTorch. PyTorch's DDP supports only the nccl and gloo backends.
You can enable distributed PyTorch training on the MPI backend by changing only two lines:
- add a DistributedSampler to your DataLoader
- pass your model to DistributedDataParallel
The usage is exactly the same as torch.nn.parallel.DistributedDataParallel(). See the ImageNet example here: https://github.com/pytorch/examples/blob/master/imagenet/main.py#L88
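To make the role of the sampler concrete, here is a minimal pure-Python sketch of how a distributed sampler can split a dataset so that each rank sees a different, equally sized shard. The function name `shard_indices` and the strided, wrap-around scheme are assumptions for illustration; this is not the code used by this repository or by PyTorch, and it leaves out per-epoch shuffling.

```python
def shard_indices(dataset_len, world_size, rank):
    """Return the sample indices one rank would process in an epoch.

    Hypothetical helper: illustrates sharding only, with no shuffling.
    """
    # Ceiling division: every rank gets the same number of samples.
    per_rank = -(-dataset_len // world_size)
    total = per_rank * world_size
    # Strided split; indices past the end wrap around to pad the last ranks.
    return [i % dataset_len for i in range(rank, total, world_size)]


if __name__ == "__main__":
    # 10 samples over 3 ranks: each rank receives 4 indices.
    for rank in range(3):
        print(rank, shard_indices(10, 3, rank))
```

Because the shards have equal length, every rank runs the same number of iterations per epoch, which keeps the collective gradient exchange in step across ranks.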
- PyTorch: build from source (v0.3.1 is recommended)
bash run.sh
This repository implements strong scaling for MNIST, which means the global batch size stays fixed no matter how many nodes are used. See the wiki for more on strong vs. weak scaling.
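As a rough illustration of the difference, the sketch below computes batch sizes under both schemes. The helper names and the batch size of 64 are hypothetical and are not taken from run.sh or the training script.

```python
def strong_scaling_local_batch(global_batch, world_size):
    """Strong scaling: the global batch is fixed and split across workers."""
    if global_batch % world_size != 0:
        raise ValueError("global batch must divide evenly across workers")
    return global_batch // world_size


def weak_scaling_global_batch(local_batch, world_size):
    """Weak scaling: the per-worker batch is fixed, so the global batch grows."""
    return local_batch * world_size


if __name__ == "__main__":
    for n in (1, 2, 4, 8):
        print(
            f"workers={n} "
            f"strong: local={strong_scaling_local_batch(64, n)} global=64 | "
            f"weak: local=64 global={weak_scaling_global_batch(64, n)}"
        )
```

Under strong scaling, adding workers shrinks each worker's share of a fixed problem; under weak scaling, the total work grows with the worker count.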
Because this is a strong scaling example, the gradients are averaged after the all_reduce, which matches the behavior of torch.nn.parallel.DistributedDataParallel.
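The following pure-Python sketch shows why the averaging step gives the right result when the global batch is split evenly: the sum of per-worker mean gradients divided by the number of workers equals the mean gradient over the whole global batch. The all_reduce here is simulated with a plain sum, the function names are hypothetical, and the scalar "gradients" are made-up values; no MPI or PyTorch calls are involved.

```python
from fractions import Fraction


def local_mean(per_sample_grads):
    """Mean gradient over one worker's local batch (exact arithmetic)."""
    return sum(per_sample_grads, Fraction(0)) / len(per_sample_grads)


def simulated_all_reduce_sum(values):
    """Stand-in for an all_reduce with a sum op: every worker gets the total."""
    total = sum(values, Fraction(0))
    return [total for _ in values]


def averaged_gradient(shards):
    """Sum the local mean gradients across workers, then divide by their count."""
    local = [local_mean(shard) for shard in shards]
    reduced = simulated_all_reduce_sum(local)
    return [g / len(shards) for g in reduced]


if __name__ == "__main__":
    # A global batch of 8 per-sample gradients, split evenly over 4 workers.
    shards = [[1, 2], [3, 4], [5, 6], [7, 8]]
    global_mean = local_mean([g for shard in shards for g in shard])
    print(averaged_gradient(shards), global_mean)
```

The equality relies on every worker holding the same number of samples; with uneven shards, a plain average of local means would weight samples unequally.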