-
Notifications
You must be signed in to change notification settings - Fork 5.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix loop bug in SlicedAttnProcessor #8836
Merged
yiyixuxu
merged 13 commits into
huggingface:main
from
shinetzh:fix_bug_of_SlicedAttnProcessor
Jul 20, 2024
Merged
Changes from 11 commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
d0b472c
fix loop bug in SlicedAttnProcessor
de81050
same loop bug in SlicedAttnAddedKVProcessor
e4506f4
update test_attention_slicing_forward_pass
f140280
fix check quality
1fdb140
Merge branch 'main' into fix_bug_of_SlicedAttnProcessor
shinetzh 439f2e3
make quality
a309b96
Merge branch 'main' of https://github.com/huggingface/diffusers into …
5742c34
fix tests bug
ebdaef1
Merge branch 'main' of https://github.com/huggingface/diffusers into …
0ede41c
make slice_size <= dim, and warning users
fcfa319
make quality
69a4d5f
Merge branch 'main' of https://github.com/huggingface/diffusers into …
e77b460
remove slize_size=3
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1351,14 +1351,30 @@ def _test_attention_slicing_forward_pass( | |
|
||
pipe.enable_attention_slicing(slice_size=1) | ||
inputs = self.get_dummy_inputs(generator_device) | ||
output_with_slicing = pipe(**inputs)[0] | ||
output_with_slicing1 = pipe(**inputs)[0] | ||
|
||
pipe.enable_attention_slicing(slice_size=2) | ||
inputs = self.get_dummy_inputs(generator_device) | ||
output_with_slicing2 = pipe(**inputs)[0] | ||
|
||
pipe.enable_attention_slicing(slice_size=3) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we remove the |
||
inputs = self.get_dummy_inputs(generator_device) | ||
output_with_slicing3 = pipe(**inputs)[0] | ||
|
||
if test_max_difference: | ||
max_diff = np.abs(to_np(output_with_slicing) - to_np(output_without_slicing)).max() | ||
self.assertLess(max_diff, expected_max_diff, "Attention slicing should not affect the inference results") | ||
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() | ||
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() | ||
max_diff3 = np.abs(to_np(output_with_slicing3) - to_np(output_without_slicing)).max() | ||
self.assertLess( | ||
max(max_diff1, max_diff2, max_diff3), | ||
expected_max_diff, | ||
"Attention slicing should not affect the inference results", | ||
) | ||
|
||
if test_mean_pixel_difference: | ||
assert_mean_pixel_difference(to_np(output_with_slicing[0]), to_np(output_without_slicing[0])) | ||
assert_mean_pixel_difference(to_np(output_with_slicing1[0]), to_np(output_without_slicing[0])) | ||
assert_mean_pixel_difference(to_np(output_with_slicing2[0]), to_np(output_without_slicing[0])) | ||
assert_mean_pixel_difference(to_np(output_with_slicing3[0]), to_np(output_without_slicing[0])) | ||
|
||
@unittest.skipIf( | ||
torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"), | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's try not to make this update and change the test instead (we should try not to update the user inputs for user, we always prefer to be explicit and throw an error message, )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok, get it. if remove the slice_size=3, the CI will pass.