FSDP finetuned model inference question #634
Comments
I also noticed that, because of the all-reduce before the forward pass, it's not recommended to use FSDP for inference. Does this mean FSDP inference isn't supported by PyTorch so far, or is it just not recommended because the all-reduce makes FSDP inference inefficient? If it's the second case, is there an example I can follow to use FSDP checkpoints for inference?
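For reference, plain PyTorch does support running inference on an FSDP-wrapped model: wrap the model, switch it to eval mode, and run forward passes under torch.no_grad(). The caveat is that every forward pass all-gathers the sharded parameters across ranks, which is why serving from a consolidated checkpoint is usually faster. A minimal sketch (not code from this repo; it assumes a single node launched with torchrun, one process per GPU, and the model path used elsewhere in this issue):

```python
# Minimal FSDP inference sketch; assumes `torchrun --nproc_per_node <gpus>`.
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import AutoModelForCausalLM, AutoTokenizer

dist.init_process_group("nccl")
local_rank = dist.get_rank()  # single-node assumption: rank == GPU index
torch.cuda.set_device(local_rank)

model_path = "/tmp/llama-recipes/Meta-Llama-3.1-8B"
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16)
model = FSDP(model, device_id=local_rank)  # parameters are sharded across ranks
model.eval()

tokenizer = AutoTokenizer.from_pretrained(model_path)
inputs = tokenizer("What can I cook for dinner?", return_tensors="pt").to(local_rank)
with torch.no_grad():
    # Each forward pass all-gathers the sharded weights before computing,
    # so this works but is not the fastest way to serve a model.
    logits = model(**inputs).logits
print(logits.shape)
```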
Hi @mathmax12, thanks for reporting this. I was able to reproduce it and will have a look.
@mreso Thank you for looking into this issue. In the meantime, is there a workaround for using finetuned FSDP checkpoints for inference? Thanks.
@mreso Is there an update on this? Thanks
Sorry @mathmax12, I did not yet get the chance to look deeper into this.
Could we prioritize this? If the checkpoints don't work, how can we use the fine-tuned FSDP checkpoints for inference?
Facing a similar issue. Is there a solution for this?
Hey @mreso, I found this only happens for Llama 3 and 3.1 models. Inference with checkpoints from FSDP Llama 2 is OK.
@mreso @HamidShojanazeri I noticed that the alpaca dataset has not been updated, as the label token is still -1 instead of -100, as shown here. Edit: it is because we did not add the Llama 3 special tokens.
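For context on why the -100 matters: it is the default ignore_index of torch.nn.CrossEntropyLoss (which the HF causal-LM loss uses), so positions labeled -100 are excluded from the loss, while -1 would be treated as a real target. A minimal sketch of the intended masking, not the repo's actual dataset code:

```python
import copy

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

def build_example(prompt_ids, answer_ids):
    """Compute loss only on the answer tokens by masking the prompt with -100."""
    input_ids = prompt_ids + answer_ids
    labels = copy.deepcopy(input_ids)
    labels[: len(prompt_ids)] = [IGNORE_INDEX] * len(prompt_ids)
    return {"input_ids": input_ids, "labels": labels}

# Only the two answer tokens contribute to the loss:
print(build_example([101, 102, 103], [7, 8]))
# {'input_ids': [101, 102, 103, 7, 8], 'labels': [-100, -100, -100, 7, 8]}
```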
Is there a fix for this?
@mreso any updates on this? I am also facing a similar issue.
@aishwaryap @mathmax12 Thanks for reporting this bug. We just added a PR to add Llama 3 support for alpaca dataset fine-tuning. Please give it a try and let me know if this helps.
Thanks @wukaixingxp for fixing this. Here is what I did:
@mathmax12 please use
@wukaixingxp I am also curious what the difference is between the
@mathmax12
@wukaixingxp Thanks for the updates. I can't find the
Yes,
🚀 The feature, motivation and pitch
The fine-tuning with only FSDP works well, and the sharded checkpoints are saved as __0_*.distcp, .metadata, and train_params.yaml. I can see the loss drop reasonably. Here is the training command:
torchrun --nnodes 1 --nproc_per_node 8 ./recipes/quickstart/finetuning/finetuning.py --model_name /tmp/llama-recipes/Meta-Llama-3.1-8B --output_dir ./fsdp_fine_tune_results/output_model_1_8 --dist_checkpoint_root_folder ./fsdp_fine_tune_results/fsdp_model_finetuned_1_8 --enable_fsdp --num_epochs 1 --batch_size_training 2 --dataset alpaca_dataset
Then I tried to run inference with the FSDP checkpoints:
python ./src/llama_recipes/inference/checkpoint_converter_fsdp_hf.py --fsdp_checkpoint_path ./fsdp_fine_tune_results/fsdp_model_finetuned_1_8/fine-tuned-/tmp/llama-recipes/Meta-Llama-3.1-8B --consolidated_model_path ./fsdp_fine_tune_results/fsdp_model_finetuned_1_8_hf --HF_model_path_or_name /tmp/llama-recipes/Meta-Llama-3.1-8B
python ./recipes/quickstart/inference/local_inference/inference.py --model_name ./fsdp_fine_tune_results/fsdp_model_finetuned_1_8_hf --prompt_file prompt_for_test.txt
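(For context, here is a rough sketch of what the converter step amounts to, as I understand the torch.distributed.checkpoint APIs; the actual checkpoint_converter_fsdp_hf.py script may differ in details. It rebuilds the base architecture, loads the sharded .distcp checkpoint into a single consolidated state dict, and saves the result in HF format.)

```python
# Hedged sketch of the FSDP -> HF conversion (not the exact script code).
import torch.distributed.checkpoint as dist_cp
from transformers import AutoModelForCausalLM, AutoTokenizer

hf_model_path = "/tmp/llama-recipes/Meta-Llama-3.1-8B"
fsdp_ckpt_dir = "./fsdp_fine_tune_results/fsdp_model_finetuned_1_8/fine-tuned-/tmp/llama-recipes/Meta-Llama-3.1-8B"
out_dir = "./fsdp_fine_tune_results/fsdp_model_finetuned_1_8_hf"

# Start from the base architecture, then overwrite its weights with the
# sharded (.distcp) checkpoint, consolidated onto a single process.
model = AutoModelForCausalLM.from_pretrained(hf_model_path)
state_dict = {"model": model.state_dict()}
dist_cp.load_state_dict(
    state_dict=state_dict,
    storage_reader=dist_cp.FileSystemReader(fsdp_ckpt_dir),
    no_dist=True,
)
model.load_state_dict(state_dict["model"])

# Save in HF format so the regular inference script can load it.
model.save_pretrained(out_dir)
AutoTokenizer.from_pretrained(hf_model_path).save_pretrained(out_dir)
```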
But I got empty output (the model only echoes the prompt):
"llama-recipes# python ./recipes/quickstart/inference/local_inference/inference.py --model_name ./fsdp_fine_tune_results/fsdp_model_finetuned_1_8_hf --prompt_file prompt_for_test.txt
/root/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning:
torch.distributed._shard.checkpoint
will be deprecated, usetorch.distributed.checkpoint
insteadfrom torch.distributed._shard.checkpoint import (
use_fast_kernelsFalse
Using the
SDPA
attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends.Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:14<00:00, 2.13s/it]
User prompt deemed safe.
User prompt:
I have tomatoes, basil and cheese at home. What can I cook for dinner?\n
Setting
pad_token_id
toeos_token_id
:128001 for open-end generation.the inference time is 286928.2311500283 ms
User input and model output deemed safe.
Model output:
I have tomatoes, basil and cheese at home. What can I cook for dinner?\n`
"
If I use the original Meta-Llama-3.1-8B model for inference, the output is fine. Also, when using the checkpoint from fine-tuning with FSDP + PEFT LoRA, the inference looks fine. Could someone let me know if I missed anything? Or is there a way/tool to check whether the FSDP-to-HF checkpoint conversion went well?
Thanks!
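One unofficial way to sanity-check the conversion (a sketch, not a tool shipped with the repo): load the converted folder and the original model with transformers and confirm that the fine-tuned weights actually differ from the base weights, and that the embedding size matches the tokenizer's vocabulary. If almost no tensors differ, the .distcp weights probably never made it into the HF checkpoint; if they all differ but generation is still empty, the problem may be elsewhere (e.g. the tokenizer/special tokens mentioned earlier in the thread).

```python
# Hypothetical sanity check for a converted FSDP -> HF checkpoint.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

base_path = "/tmp/llama-recipes/Meta-Llama-3.1-8B"
converted_path = "./fsdp_fine_tune_results/fsdp_model_finetuned_1_8_hf"

base = AutoModelForCausalLM.from_pretrained(base_path, torch_dtype=torch.bfloat16)
converted = AutoModelForCausalLM.from_pretrained(converted_path, torch_dtype=torch.bfloat16)

changed, total = 0, 0
for (name, p_base), (_, p_conv) in zip(base.named_parameters(), converted.named_parameters()):
    total += 1
    if not torch.equal(p_base, p_conv):
        changed += 1
print(f"{changed}/{total} parameter tensors differ from the base model")

# The converted embedding matrix should match the tokenizer's vocabulary size
# (assumes the same tokenizer is used at inference time).
tokenizer = AutoTokenizer.from_pretrained(base_path)
print("tokenizer vocab:", len(tokenizer),
      "embedding rows:", converted.get_input_embeddings().weight.shape[0])
```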
Alternatives
No response
Additional context
No response