add boft support in stable-diffusion #1295
base: main
Conversation
Signed-off-by: Wang, Yi A <[email protected]>
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Signed-off-by: Wang, Yi A <[email protected]>
Hi @sywangyi
Same with the latest main: one test case fails:
FAILED tests/test_diffusers.py::GaudiStableDiffusionXLImg2ImgPipelineTests::test_stable_diffusion_xl_img2img_euler - AssertionError: 0.21911774845123289 not less than 0.01
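As a side note, here is a minimal sketch of re-running only that failing test through pytest's Python entry point. The test path is copied from the log above; it assumes the repository's test dependencies are installed and a Gaudi device is available.

# Minimal sketch: select and run only the failing test via pytest.main().
import pytest

pytest.main(
    [
        "tests/test_diffusers.py::GaudiStableDiffusionXLImg2ImgPipelineTests::test_stable_diffusion_xl_img2img_euler",
        "-v",
    ]
)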
Hi @sywangyi
I spent some time on this PR and did some testing:
- I've reworked the README file for a better read. Please apply the changes with the attached patch using git am < 000* (don't copy-paste the changes, apply the patch please): 0001-fea-dreambooth-reworked-the-readme.patch
- I've tested the PEFT example with both lora and boft. The lora example finishes in about 6 min (5m47.993s), but the boft one has been running for ~80 min and has only completed 24% (Steps: 24%|██▎ | 188/800 [1:21:18<3:53:56, 22.94s/it, loss=0.0225, lr=0.0001]). Any thoughts on why boft is so significantly slower than lora? Is this because of the lack of HPU graphs? Let's investigate this a bit more.
- I've provided the tested cmd below.
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="dog"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="out"
logfile=pr1295.$(date -u +%Y%m%d%H%M).$(hostname).log
time python ../../gaudi_spawn.py --world_size 8 --use_mpi train_dreambooth.py --pretrained_model_name_or_path=$MODEL_NAME --instance_data_dir=$INSTANCE_DIR --output_dir=$OUTPUT_DIR --class_data_dir=$CLASS_DIR --with_prior_preservation --prior_loss_weight=1.0 --instance_prompt="a photo of sks dog" --class_prompt="a photo of dog" --resolution=512 --train_batch_size=1 --num_class_images=200 --gradient_accumulation_steps=1 --learning_rate=1e-4 --lr_scheduler="constant" --lr_warmup_steps=0 --max_train_steps=800 --mixed_precision=bf16 --use_hpu_graphs_for_training --use_hpu_graphs_for_inference --gaudi_config_name Habana/stable-diffusion lora --unet_r 8 --unet_alpha 8 2>&1 | tee $logfile
time python ../../gaudi_spawn.py --world_size 8 --use_mpi train_dreambooth.py --pretrained_model_name_or_path=$MODEL_NAME --instance_data_dir=$INSTANCE_DIR --output_dir=$OUTPUT_DIR --class_data_dir=$CLASS_DIR --with_prior_preservation --prior_loss_weight=1.0 --instance_prompt="a photo of sks dog" --class_prompt="a photo of dog" --resolution=512 --train_batch_size=1 --num_class_images=200 --gradient_accumulation_steps=1 --learning_rate=1e-4 --lr_scheduler="constant" --lr_warmup_steps=0 --max_train_steps=800 --mixed_precision=bf16 --gaudi_config_name Habana/stable-diffusion boft 2>&1 | tee $logfile
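To make the comparison above concrete, here is a hedged sketch of the two adapter types being benchmarked, built directly with the peft library. The target_modules list and the BOFT block settings are illustrative assumptions, not necessarily what train_dreambooth.py passes; the script selects them from its lora/boft sub-commands and flags.

# Hedged sketch (not the script's actual code): the two PEFT adapter configs compared above.
# The target_modules and BOFT block settings below are assumptions for illustration only.
from peft import BOFTConfig, LoraConfig

# Roughly what "lora --unet_r 8 --unet_alpha 8" selects: a rank-8 low-rank update per layer.
lora_config = LoraConfig(
    r=8,
    lora_alpha=8,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],  # assumed UNet attention projections
)

# Roughly what "boft" selects: butterfly-factorized orthogonal updates. Each adapted layer
# multiplies its weight by a product of sparse block-diagonal orthogonal factors, which is
# where torch.block_diag enters the training graph (see the discussion further down).
boft_config = BOFTConfig(
    boft_block_size=4,          # assumed block size
    boft_n_butterfly_factor=2,  # assumed number of butterfly factors
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
)

If that reading is right, the slowdown is less about parameter count and more about how those block-diagonal factors are assembled at every step, which matches the compilation issue discussed below.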
Yes. I have filed a bug with the PyTorch training team about the perf issue and will cc you in the Jira ticket.
@sywangyi please test with the latest Synapse SW. If there is still an issue, we don't need to merge this change for the next Synapse release, as it's not functional.
Which version do you mean? I think the Habana PyTorch training team is still working on it.
@sywangyi do you have test results?
I have tested this with driver 1.18.0-460 and the corresponding Docker image. It still shows the same behavior.
@imangohari1, do we have an update on this?
I am not sure if the issue is resolved or not.
According to https://habana.atlassian.net/browse/HS-3208, it has not been resolved yet.
Just to update: the R&D team found the low-level issue that produces the slow compilation when the torch.block_diag operation runs. You can find all the details in the ticket.
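For context on that root cause, here is a minimal, self-contained illustration of torch.block_diag, the operation identified in the ticket. The shapes are made up for illustration; as discussed above, BOFT assembles its orthogonal factors from small blocks like this during training.

# Minimal illustration of the operation implicated above; shapes are illustrative only.
import torch

blocks = [torch.eye(4) for _ in range(8)]  # eight small 4x4 orthogonal blocks
factor = torch.block_diag(*blocks)         # one 32x32 block-diagonal butterfly-style factor
print(factor.shape)                        # torch.Size([32, 32])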