[BUG] Precision issue caused by different token dispatchers in MoE training #1327

qi7kuo opened this issue Dec 17, 2024 · 0 comments

Describe the bug
Using different token dispatchers (i.e., AllGather and AlltoAll) with otherwise identical settings leads to a training precision discrepancy.

To Reproduce
You can reproduce the issue with the following bash command.
My dataset is English Wikipedia.

export CUDA_DEVICE_MAX_CONNECTIONS=1

MASTER_ADDR=localhost
MASTER_PORT=6648
WORLD_SIZE=1
RANK=0
NPROC_PER_NODE=4

DATA_PATH=/mnt/gpt_data

DISTRIBUTED_ARGS="--nproc_per_node $NPROC_PER_NODE --nnodes $WORLD_SIZE --node_rank $RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"

torchrun $DISTRIBUTED_ARGS \
       pretrain_gpt.py \
       --num-experts 8 \
       --expert-model-parallel-size 4 \
       --optimizer adam \
       --distributed-backend nccl \
       --moe-token-dispatcher-type alltoall \
       --num-layers 12 \
       --hidden-size 256 \
       --num-attention-heads 32 \
       --swiglu \
       --ffn-hidden-size 512 \
       --disable-bias-linear \
       --attention-dropout 0 \
       --hidden-dropout 0 \
       --normalization RMSNorm \
       --untie-embeddings-and-output-weights \
       --use-rotary-position-embeddings \
       --no-position-embedding \
       --no-masked-softmax-fusion \
       --micro-batch-size 1 \
       --global-batch-size 4 \
       --seq-length 2048 \
       --max-position-embeddings 2048 \
       --data-path $DATA_PATH/wiki-gpt_text_document \
       --vocab-file $DATA_PATH/gpt2-vocab.json \
       --merge-file $DATA_PATH/gpt2-merges.txt \
       --split 100,0,0 \
       --train-iters 100 \
       --lr-decay-iters 10000 \
       --lr-warmup-iters 100 \
       --adam-beta1 0.9 \
       --adam-beta2 0.95 \
       --lr 1e-5 \
       --lr-decay-style cosine \
       --min-lr 1e-6 \
       --weight-decay 1e-2 \
       --clip-grad 1.0 \
       --log-interval 1  2>&1 | tee moe_std.log

I observe the difference at iteration 100: with the alltoall dispatcher the loss is 1.030484E+01, while with the allgather dispatcher it is 1.030483E+01.
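
The comparison can be made systematic by running the script twice, changing only --moe-token-dispatcher-type (alltoall vs. allgather), and diffing the logged losses. Below is a minimal sketch of that diff; the log file names and the regex over the per-iteration log line ("iteration N/ ... lm loss: X") are assumptions and may need adjusting for your log format.

import re

LOG_ALLTOALL = "moe_alltoall.log"    # assumed log file from the alltoall run
LOG_ALLGATHER = "moe_allgather.log"  # assumed log file from the allgather run

PATTERN = re.compile(r"iteration\s+(\d+)/.*?lm loss: ([0-9.E+-]+)")

def losses(path):
    # Map iteration -> lm loss parsed from a training log.
    out = {}
    with open(path) as f:
        for line in f:
            m = PATTERN.search(line)
            if m:
                out[int(m.group(1))] = float(m.group(2))
    return out

a, b = losses(LOG_ALLTOALL), losses(LOG_ALLGATHER)
for it in sorted(set(a) & set(b)):
    if a[it] != b[it]:
        print(f"first divergence at iteration {it}: alltoall={a[it]:.6E} allgather={b[it]:.6E}")
        break
else:
    print("losses identical at all logged iterations")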

Stack trace/logs
I have investigated the problem; my thoughts and trials are listed below:

I checked the intermediate activation tensors and noticed that the precision difference first appears after an iteration in which an expert is allocated 0 tokens (in my environment, iteration 37).
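
To pinpoint the iteration that first routes zero tokens to an expert, the per-expert token counts can be logged from the router's top-k indices. A minimal sketch follows; topk_indices and num_experts are placeholder names for whatever the router/dispatcher exposes in this Megatron commit.

import torch

def log_empty_experts(topk_indices: torch.Tensor, num_experts: int, iteration: int):
    # topk_indices: (num_tokens, k) expert ids chosen by the router for the local batch
    counts = torch.bincount(topk_indices.flatten(), minlength=num_experts)
    empty = (counts == 0).nonzero(as_tuple=True)[0]
    if empty.numel() > 0:
        print(f"iteration {iteration}: experts with 0 tokens -> {empty.tolist()}, "
              f"tokens per expert = {counts.tolist()}")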

Then, I checked the weight & grad. I noticed that in iteration 37, the forward intermediate activation and model weight are the same. However, in the backward pass of iteration 37, the module get different gradients when using different dispatcher, which is straightforward reason of precision issue.

Environment (please complete the following information):

  • Megatron-LM commit ID 81fee9b
  • PyTorch version: 2.4.0
  • CUDA version: 12.0
  • NCCL version: 24.07
  • Device: 4 * V100 (32GB)

Additional context

To summarize: I am looking forward to help in finding the root cause. To be honest, this bug confuses me because the AllGather and AlltoAll dispatchers should behave identically.

As mentioned above, the precision issue appears when an expert receives 0 tokens, which may be the core reason. The forward activation tensors are the same, but the backward weight-gradient tensors differ.
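
As a standalone sanity check of the zero-token path (not Megatron's actual expert code), the sketch below confirms that a linear layer fed 0 tokens still produces a well-defined, all-zero weight gradient, whereas an expert that is skipped entirely would have no gradient at all; whether the two dispatchers handle this case identically is what I have not been able to confirm.

import torch

torch.manual_seed(0)
expert = torch.nn.Linear(256, 512, bias=False)  # toy stand-in for one expert MLP

empty_tokens = torch.zeros(0, 256)              # 0 tokens routed to this expert
out = expert(empty_tokens)                      # shape (0, 512)
out.sum().backward()                            # backward on an empty output is valid

print(expert.weight.grad.abs().max().item())    # prints 0.0: grad exists but is all zeros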
