
FP8 training can't work with DeepSpeed #3360

Closed
2 of 4 tasks
XiaobingSuper opened this issue Jan 23, 2025 · 5 comments · Fixed by #3361 · May be fixed by #3385

Comments

@XiaobingSuper
Contributor

System Info

- `Accelerate` version: 1.3.0
- Platform: Linux-5.4.250-4-velinux1u1-amd64-x86_64-with-glibc2.35
- `accelerate` bash location: /usr/local/bin/accelerate
- Python version: 3.10.12
- Numpy version: 1.24.4
- PyTorch version (GPU?): 2.5.1+cu124 (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- PyTorch MLU available: False
- PyTorch MUSA available: False
- System RAM: 1928.86 GB
- GPU type: NVIDIA H20
- `Accelerate` default config:
        Not found

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

Create a DeepSpeed config JSON:

{
    "train_batch_size": 16,
    "train_micro_batch_size_per_gpu": 16,
    "gradient_accumulation_steps": 1,
    "zero_optimization": {
        "stage": 3,
        "stage3_gather_16bit_weights_on_model_save": false
    },
    "gradient_clipping": 1.0,
    "bf16": {"enabled": true},
    "fp16": {"enabled": false},
    "zero_allow_untested_optimizer": true
}

and an accelerate config YAML:

distributed_type: DEEPSPEED
deepspeed_config:
  deepspeed_config_file: "config.json"
  zero3_init_flag: true
num_processes: 1

Then run the following script:

from accelerate import Accelerator
from accelerate.utils import has_transformer_engine_layers, FP8RecipeKwargs
# fp8_utils.py is the helper module next to accelerate's benchmarks/fp8/transformer_engine scripts
from fp8_utils import get_training_utilities

MODEL_NAME = "bert-base-cased"

FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"}
kwargs_handlers = [FP8RecipeKwargs(backend="TE", **FP8_RECIPE_KWARGS)]

accelerator = Accelerator(
    mixed_precision="fp8", kwargs_handlers=kwargs_handlers
)

model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
    MODEL_NAME, accelerator=accelerator
)

model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)

assert has_transformer_engine_layers(model), "Model should have Transformer Engine layers"

Expected behavior

Running the script raises the assertion error `Model should have Transformer Engine layers`. Expected behavior is that `accelerator.prepare` swaps the model's layers for Transformer Engine layers when `mixed_precision="fp8"` is set, so the assertion should pass; with the DeepSpeed config above it does not.
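
For context, a minimal sketch of the kind of check the assertion relies on (assuming Transformer Engine's te.pytorch module classes; the actual has_transformer_engine_layers helper in accelerate may differ in detail):

import torch.nn as nn
import transformer_engine.pytorch as te

def has_te_layers(model: nn.Module) -> bool:
    # True as soon as any submodule is a Transformer Engine layer, i.e.
    # accelerator.prepare() swapped nn.Linear / nn.LayerNorm for the TE versions.
    return any(isinstance(module, (te.Linear, te.LayerNorm)) for module in model.modules())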

@XiaobingSuper
Contributor Author

You can also reproduce this issue by running https://github.com/huggingface/accelerate/blob/main/benchmarks/fp8/transformer_engine/distrib_deepspeed.py:

CUDA_VISIBLE_DEVICES=4,5 accelerate launch distrib_deepspeed.py

[screenshot]

@Superskyyy

Can confirm FP8 cannot work with DeepSpeed, but I hit a different problem: using fp8 as the mixed precision resets the ZeRO stage back to 0 even though I set stage 2/3. Additionally, if I supply a DeepSpeed JSON config, it won't let me set mixed_precision=fp8 in the accelerate config, and there is no way to specify FP8 in the DeepSpeed JSON either.
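
A minimal sketch of that second combination, expressed with accelerate's programmatic API instead of the YAML config (hypothetical illustration only; whether it is rejected depends on the accelerate version):

from accelerate import Accelerator, DeepSpeedPlugin

# Hypothetical illustration of the reported conflict: a DeepSpeed JSON config
# ("config.json", the file from the reproduction above) combined with
# mixed_precision="fp8". The JSON itself has no FP8 field, and the report is
# that this combination is rejected on the accelerate side.
ds_plugin = DeepSpeedPlugin(hf_ds_config="config.json")
accelerator = Accelerator(mixed_precision="fp8", deepspeed_plugin=ds_plugin)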

@XiaobingSuper
Contributor Author

XiaobingSuper commented Feb 7, 2025

@Superskyyy I ran my example above (the code in the Reproduction section), and the printed log is:

[2025-02-07 06:40:19,888] [INFO] [config.py:1002:print]   zero_config .................. stage=3 contiguous_gradients=True reduce_scatter=True reduce_bucket_size=500000000 use_multi_rank_bucket_allreduce=True allgather_partitions=True allgather_bucket_size=500000000 overlap_comm=True load_from_fp32_weights=True elastic_checkpoint=False offload_param=None offload_optimizer=None sub_group_size=1000000000 cpu_offload_param=None cpu_offload_use_pin_memory=None cpu_offload=None prefetch_bucket_size=50000000 param_persistence_threshold=100000 model_persistence_threshold=9223372036854775807 max_live_parameters=1000000000 max_reuse_distance=1000000000 gather_16bit_weights_on_model_save=False module_granularity_threshold=0 use_all_reduce_for_fetch_params=False stage3_gather_fp16_weights_on_model_save=False ignore_unused_parameters=True legacy_stage1=False round_robin_gradients=False zero_hpz_partition_size=1 zero_quantized_weights=False zero_quantized_nontrainable_weights=False zero_quantized_gradients=False mics_shard_size=-1 mics_hierarchical_params_gather=False memory_efficient_linear=True pipeline_loading_checkpoint=False override_module_apply=True
[2025-02-07 06:40:19,888] [INFO] [config.py:1002:print]   zero_enabled ................. True
[2025-02-07 06:40:19,888] [INFO] [config.py:1002:print]   zero_force_ds_cpu_optimizer .. True
[2025-02-07 06:40:19,888] [INFO] [config.py:1002:print]   zero_optimization_stage ...... 3
[2025-02-07 06:40:19,888] [INFO] [config.py:988:print_user_config]   json = {
    "train_batch_size": 16, 
    "train_micro_batch_size_per_gpu": 16, 
    "gradient_accumulation_steps": 1, 
    "zero_optimization": {
        "stage": 3, 
        "stage3_gather_16bit_weights_on_model_save": false
    }, 
    "gradient_clipping": 1.0, 
    "bf16": {
        "enabled": true
    }, 
    "fp16": {
        "enabled": false
    }, 
    "zero_allow_untested_optimizer": true, 
    "steps_per_print": inf
}

The zero_optimization stage is not reset to zero here. Could you also run the given example, or share a case that I can reproduce?

@XiaobingSuper
Contributor Author

> Can confirm FP8 cannot work with DeepSpeed, but I hit a different problem: using fp8 as the mixed precision resets the ZeRO stage back to 0 even though I set stage 2/3. Additionally, if I supply a DeepSpeed JSON config, it won't let me set mixed_precision=fp8 in the accelerate config, and there is no way to specify FP8 in the DeepSpeed JSON either.

For the second point ("if I supply a DeepSpeed JSON config, it won't let me set mixed_precision=fp8 in the accelerate config, and there is no way to specify FP8 in the DeepSpeed JSON either"):

Yes, I also hit this issue:

[screenshot]

@XiaobingSuper
Contributor Author


@Superskyyy this PR #3385 can work for that case (setting mixed_precision=fp8 when a DeepSpeed JSON config is supplied).
