Commit: Initial commit
ilmarkov committed Nov 17, 2022
0 parents commit fb71bab
Showing 57 changed files with 8,723 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -0,0 +1,3 @@
/.idea/*
/build/*
/dist/*
11 changes: 11 additions & 0 deletions Dockerfile
@@ -0,0 +1,11 @@
FROM nvcr.io/nvidia/pytorch:22.10-py3

RUN apt-get update && apt-get install -y --no-install-recommends openssh-client openssh-server && \
mkdir -p /var/run/sshd

ENV MPI_HOME=/opt/hpcx/ompi/
ENV NCCL_INCLUDE=/usr/include
ENV NCCL_LIB=/usr/lib/x86_64-linux-gnu/

RUN git clone https://github.com/IST-DASLab/torch_cgx /torch_cgx &&\
cd /torch_cgx && python setup.py install
617 changes: 617 additions & 0 deletions LICENSE.txt

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions MANIFEST.in
@@ -0,0 +1 @@
graft src
90 changes: 90 additions & 0 deletions README.md
@@ -0,0 +1,90 @@
# CGX

CGX is a pytorch extension that adds a backend for pytorch distributed with support for allreduce of quantized buffers.
It supports quantization of float16 and float32 buffers to 1-8 bits.

CGX is based on the MPI `torch.distributed` backend. The extension essentially only replaces the allreduce primitive.

## Quick Start

### Prerequisites
CGX, as a pytorch extension, requires `pytorch>=1.10.0`.

For a faster build, we recommend installing `ninja` (`pip install ninja`).

Compression is only supported for GPU-based buffers, so either CUDA or ROCm is required.
If CUDA or ROCm is not installed in the standard paths, set `[CUDA|ROCM]_HOME` or `[CUDA|ROCM]_PATH` accordingly.

Because the extension is based on MPI, it requires OpenMPI with GPU support (other MPI implementations have not been tested).
The library also supports NCCL-based communication, so the NVIDIA NCCL library is required.
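
An illustrative way to verify these prerequisites from Python (this check is not part of CGX; the `packaging` package is assumed to be available):

``` python
# Illustrative prerequisite check, not part of CGX.
import torch
from packaging import version

# pytorch>=1.10.0
assert version.parse(torch.__version__) >= version.parse("1.10.0")
# A GPU build of PyTorch (CUDA or ROCm) with a visible device.
assert torch.cuda.is_available()
```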

### Build from source
Set the `MPI_HOME` environment variable to your MPI installation path. For AMD GPUs, set `CGX_CUDA` to 0.
Set the `NCCL_HOME` environment variable to your NCCL installation path, or set `NCCL_INCLUDE` and `NCCL_LIB`.
Set `QSGD_DETERMINISTIC=0` if you want the stochastic version of QSGD.

```bash
git clone https://github.com/IST-DASLab/torch_cgx
export MPI_HOME=/path/to/mpi
export NCCL_HOME=/path/to/nccl
python setup.py install
```

### Usage
The only changes required in a training script that uses pytorch distributed
are importing the built extension and specifying `cgx` as the `torch.distributed.init_process_group` backend parameter.

Example:
``` python
import torch
import torch.distributed as dist
import torch_cgx

dist.init_process_group('cgx', init_method='env://', rank=args.local_rank)
```
Also, in order to perform layerwise compression and to filter out small layers that are sensitive to gradient
compression (typically batch norm layers and biases), `cgx` needs information about the model.
For that, users need to register the communication hook. The minimal size of the layers to compress can be
controlled with the `layer_min_size` parameter.

``` python
from cgx_utils import cgx_hook, CGXState

# register_comm_hook is available on models wrapped in DistributedDataParallel
model = torch.nn.parallel.DistributedDataParallel(model)
state = CGXState(torch.distributed.group.WORLD, layer_min_size=1024,
                 compression_params={"bits": args.quantization_bits,
                                     "bucket_size": args.quantization_bucket_size})
model.register_comm_hook(state, cgx_hook)
```

Because the extension is based on the MPI backend, it requires an MPI-compliant launcher (`torch.distributed.launch` won't work):
`$ mpirun -np 2 python train.py`

Also, if your training script was previously run with the `torch.distributed.launch` utility, the MPI launcher requires you to set a few environment variables yourself (see cifar_train.py in examples):
``` python
import os

if "OMPI_COMM_WORLD_SIZE" in os.environ:
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '4040'
    os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
    os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
```
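
Putting these pieces together, a minimal sketch of a training script launched with `mpirun` might look like the following (the toy model, bit width, and bucket size are illustrative placeholders, not defaults):

``` python
import os
import torch
import torch.distributed as dist
import torch_cgx
from cgx_utils import cgx_hook, CGXState

# Translate Open MPI environment variables for torch.distributed (see above).
if "OMPI_COMM_WORLD_SIZE" in os.environ:
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '4040'
    os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
    os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]

rank = int(os.environ["RANK"])
dist.init_process_group('cgx', init_method='env://', rank=rank)

# Toy model; any DistributedDataParallel-wrapped model works the same way.
torch.cuda.set_device(rank % torch.cuda.device_count())
model = torch.nn.Linear(1024, 1024).cuda()
model = torch.nn.parallel.DistributedDataParallel(model)

state = CGXState(dist.group.WORLD, layer_min_size=1024,
                 compression_params={"bits": 4, "bucket_size": 128})
model.register_comm_hook(state, cgx_hook)
# ... regular DDP training loop follows ...
```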

## Tuning
CGX can be tuned with the following environment variables (a usage sketch follows the list):

- `CGX_COMPRESSION_QUANTIZATION_BITS` - the number of bits each value of the buffer is quantized to (from 1 to 8). The default is 32, which means no quantization is applied. This variable must be used if the `cgx_hook` communication hook is not registered.
- `CGX_COMPRESSION_BUCKET_SIZE` - the size of the subarrays into which the buffer is split before quantization. The default is 512.
- `CGX_COMPRESSION_SKIP_INCOMPLETE_BUCKETS` - a boolean variable (0 or 1). After splitting the buffer into buckets, some values may be left over in an incomplete bucket. This variable tells the quantization algorithm whether to compress those remaining values. The default is 0.
- `CGX_COMPRESSION_MINIMAL_SIZE` - the minimal buffer size (number of elements) to compress. The default is 0, but in practice the minimal size is forced to be at least 16.
- `CGX_FUSION_BUFFER_SIZE_MB`. CGX leverages [Tensor Fusion](https://github.com/horovod/horovod#tensor-fusion), a performance feature introduced in Horovod. This feature batches small allreduce operations, which decreases latency in data-parallel training. The environment variable controls the maximal buffer size (in MB) that is communicated within one iteration of the allreduce algorithm. The default is 64. The variable must be set **before** loading the module.
- `CGX_INNER_COMMUNICATOR_TYPE`. Specifies which library to use as the communication backend for intra-node communication (MPI, SHM, NCCL).
- `CGX_CROSS_COMMUNICATOR_TYPE`. Specifies which library to use as the communication backend for inter-node communication (MPI, NCCL).
- `CGX_INTRA_BROADCAST`. A parameter for multi-node training. When enabled, inter-node communication is performed by only one GPU per node.
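
For example, a hypothetical way to set some of these variables from a Python script; note that they must be set before `torch_cgx` is imported (in particular `CGX_FUSION_BUFFER_SIZE_MB`):

``` python
import os

# Values below are illustrative, not recommendations.
os.environ["CGX_COMPRESSION_QUANTIZATION_BITS"] = "4"   # quantize gradients to 4 bits
os.environ["CGX_COMPRESSION_BUCKET_SIZE"] = "512"       # bucket size used by the quantizer
os.environ["CGX_FUSION_BUFFER_SIZE_MB"] = "64"          # must be set before loading the module
os.environ["CGX_INNER_COMMUNICATOR_TYPE"] = "SHM"       # intra-node backend: MPI, SHM, or NCCL
os.environ["CGX_CROSS_COMMUNICATOR_TYPE"] = "MPI"       # inter-node backend: MPI or NCCL

import torch_cgx  # imported only after the variables are set
```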

## Examples

Basic examples are provided under the [example](examples) folder.

## Notes
- As the compression method, a basic max-min uniform quantization function is used. In order to use max-min with random rounding, as in QSGD, compile the library with `QSGD_DETERMINISTIC=0` (an illustrative sketch of the deterministic variant follows this list).
- Reduction algorithm: Scatter-Reduce-AllGather.
- Part of the source code is based on the [Horovod](https://github.com/horovod/horovod) and [NCCL](https://github.com/NVIDIA/nccl) sources.
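
For intuition, the sketch below shows deterministic max-min uniform quantization of a single bucket in plain PyTorch. It is purely illustrative and is not the kernel CGX uses:

``` python
import torch

def maxmin_quantize(bucket: torch.Tensor, bits: int):
    """Deterministic max-min uniform quantization of one bucket (illustrative only)."""
    levels = 2 ** bits - 1
    lo, hi = bucket.min(), bucket.max()
    scale = (hi - lo) / levels if hi > lo else torch.ones(())
    # Map each value to one of 2**bits levels between the bucket's min and max.
    q = torch.round((bucket - lo) / scale).to(torch.uint8)
    return q, lo, scale

def maxmin_dequantize(q: torch.Tensor, lo: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return lo + q.to(torch.float32) * scale
```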
1 change: 1 addition & 0 deletions cgx_utils/__init__.py
@@ -0,0 +1 @@
from .allreduce_hooks import CGXState, cgx_hook
73 changes: 73 additions & 0 deletions cgx_utils/allreduce_hooks.py
@@ -0,0 +1,73 @@
# pytorch-cgx
#
# Copyright (C) 2022 IST Austria
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from typing import Dict
import torch
import torch.distributed as dist
import torch_cgx
import os

COMPRESSION_QUANTIZATION_BITS = "CGX_COMPRESSION_QUANTIZATION_BITS"
COMPRESSION_BUCKET_SIZE = "CGX_COMPRESSION_BUCKET_SIZE"
COMPRESSION_MINIMAL_SIZE = "CGX_COMPRESSION_MINIMAL_SIZE"
VALUE_NO_COMPRESS=32


class CGXState(object):
    def __init__(self, process_group: dist.ProcessGroup, layer_min_size: int = 1024,
                 compression_params: Dict[str, int] = None):
        self.process_group = process_group if process_group is not None else dist.group.WORLD
        min_size_to_compress = int(os.getenv(COMPRESSION_MINIMAL_SIZE, "16"))
        self.layer_min_size = max(layer_min_size, min_size_to_compress)
        self.quantization_bits = int(os.getenv(COMPRESSION_QUANTIZATION_BITS, str(VALUE_NO_COMPRESS)))
        self.quantization_bucket_size = int(os.getenv(COMPRESSION_BUCKET_SIZE, "1024"))
        self.step = 0
        if compression_params is not None:
            self.quantization_bits = compression_params.get("bits", self.quantization_bits)
            self.quantization_bucket_size = compression_params.get("bucket_size", self.quantization_bucket_size)

    def should_compress_(self, tensor: torch.Tensor):
        if tensor.dim() <= 1 or tensor.numel() < self.layer_min_size:
            return False
        return True


def _allreduce_fut(
    process_group: dist.ProcessGroup, tensor: torch.Tensor
) -> torch.futures.Future[torch.Tensor]:
    "Averages the input gradient tensor by allreduce and returns a future."
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    # Apply the division first to avoid overflow, especially for FP16.
    tensor.div_(group_to_use.size())
    return (
        dist.all_reduce(tensor, group=group_to_use, async_op=True)
        .get_future()
        .then(lambda fut: fut.value()[0])
    )


def cgx_hook(
    state: CGXState, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    if state.step == 2:
        for layer_idx, tensor in enumerate(bucket.gradients()):
            bits = state.quantization_bits if state.should_compress_(tensor) else VALUE_NO_COMPRESS
            torch_cgx.register_layer(bucket.index(), layer_idx, tensor.numel(),
                                     bits, state.quantization_bucket_size)
    if bucket.is_last():
        state.step += 1
        state.layer_idx = 0
    return _allreduce_fut(state.process_group, bucket.buffer())
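
As a small illustration of the hook state above, the snippet below shows how `CGXState.should_compress_` filters tensors; it assumes `torch.distributed` has already been initialized and that the `bits`/`bucket_size` values are arbitrary examples:

``` python
import torch
import torch.distributed as dist
from cgx_utils import CGXState

# Assumes dist.init_process_group(...) has already been called.
state = CGXState(dist.group.WORLD, layer_min_size=1024,
                 compression_params={"bits": 4, "bucket_size": 128})

bias = torch.zeros(256)          # 1-D and small: kept at full precision
weight = torch.zeros(512, 512)   # large 2-D tensor: quantized to 4 bits
print(state.should_compress_(bias))    # False
print(state.should_compress_(weight))  # True
```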