# Decoupled Torch Network-Aware Training on Interlinked Online Nodes (DeToNATION)

This code currently implements the results described in *FlexDeMo: Decoupled Momentum Optimization for Fully and Hybrid Sharded Training*.

## Installation

Installation from PyPI:

```bash
pip install detonation
```

Installation from source:

```bash
git clone https://github.com/schneiderkamplab/DeToNATION
cd DeToNATION
pip install .
```

## Example

There is a full example of language model training using FlexDeMo in the `examples` folder. Please refer to the documentation:

[examples/t5/README.md](examples/t5/README.md)

This example demonstrates the use of the `prepare_detonation` function for obtaining a distributed model and optimizer.
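As a rough illustration, a call to `prepare_detonation` might look like the following sketch. The exact signature and any additional arguments are assumptions here; the only grounded fact is that it returns a distributed model and optimizer, so consult the example README for the authoritative usage:

```python
from detonation import prepare_detonation

# Hypothetical sketch: prepare_detonation is assumed to take the plain model
# and return the FSDP-wrapped model together with the FlexDeMo optimizer.
model, optim = prepare_detonation(model)
```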

## Usage

Using DeToNATION directly, without `prepare_detonation`, requires three steps, exemplified below for the FlexDeMo optimizer, i.e., DeToNATION with node-based hybrid sharding and DeMo replication.

First, you need to wrap your model with FSDP and the hybrid sharding strategy:

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

# Shard parameters within a node and replicate across nodes.
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
)
```

Second, import and instantiate the FlexDeMo optimizer:

```python
from detonation import DeMo

# The optimizer takes the sharded model parameters plus the FSDP process
# groups for intra-node sharding and inter-node replication.
optim = DeMo(
    model.parameters(),
    compression_topk=16,
    compression_chunk=128,
    sharding_parallel_group=model.process_group,
    replication_parallel_group=model._inter_node_pg,
)
```

Third and last, wrap the forward and backward pass in a `no_sync` context manager to avoid the automatic full gradient synchronization:

```python
with model.no_sync():  # disable gradient synchronization across FSDP instances
    loss = model(input_ids=batch["input_ids"], labels=batch["labels"])["loss"]
    loss.backward()
```
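For context, a complete training step could then look like the following sketch. The data loader, batch contents, and the placement of `optim.step()` and `optim.zero_grad()` around the `no_sync` block are assumptions for illustration, not prescribed by the library:

```python
for batch in dataloader:  # hypothetical DataLoader yielding tokenized batches
    optim.zero_grad()
    with model.no_sync():  # skip FSDP's full gradient synchronization
        loss = model(input_ids=batch["input_ids"], labels=batch["labels"])["loss"]
        loss.backward()
    # The DeMo optimizer performs the compressed gradient exchange across the
    # replication group as part of its step.
    optim.step()
```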