Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix loop bug in SlicedAttnProcessor #8836

Merged
merged 13 commits into from
Jul 20, 2024

Conversation

shinetzh
Copy link
Contributor

Fixes # (loop bug in SlicedAttnProcessor)

@sayakpaul @yiyixuxu @DN6

@sayakpaul
Copy link
Member

Thank you! Do you think our tests need to be updated to catch bugs like this?

def test_attention_slicing_forward_pass(self, expected_max_diff=1e-3):

@shinetzh
Copy link
Contributor Author

Bugs like this in test_attention_slicing_forward_pass have been fixed, include "SlicedAttnProcessor" and "SlicedAttnAddedKVProcessor". May be there is no need to catch bugs like this in test_attention_slicing_forward_pass. But maybe tests on other function or module are needed.

@shinetzh
Copy link
Contributor Author

Or, if it's needed to update the test "test_attention_slicing_forward_pass" to catch bugs like this, I am happy to do that

@shinetzh
Copy link
Contributor Author

I have update the test_attention_slicing_forward_pass to catch bugs like this.

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thanks a lot!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@shinetzh
Copy link
Contributor Author

space in blank line have been removed

@yiyixuxu
Copy link
Collaborator

thanks for fixing this for us!
can you run make style ?

@shinetzh
Copy link
Contributor Author

ok, wait a minute

@shinetzh
Copy link
Contributor Author

done

@shinetzh
Copy link
Contributor Author

sorry for tests bug, have fixed it

@yiyixuxu
Copy link
Collaborator

hey I think the failing tests are relevant here https://github.com/huggingface/diffusers/actions/runs/9970139809/job/27588409202?pr=8836#step:7:17944

can you look into them?

@shinetzh
Copy link
Contributor Author

I have look into them, and this is relevant with unet_2d_condition. I change unet_2d_conditon a little, I don't know if this is appropriate. Please check it.

@shinetzh
Copy link
Contributor Author

This commit: 0ede41c

inputs = self.get_dummy_inputs(generator_device)
output_with_slicing2 = pipe(**inputs)[0]

pipe.enable_attention_slicing(slice_size=3)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we remove the slice_size=3 test? I think the CI would pass without this, no?

@@ -815,7 +815,10 @@ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
size = slice_size[i]
dim = sliceable_head_dims[i]
if size is not None and size > dim:
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
slice_size[i] = dim
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's try not to make this update and change the test instead (we should try not to update the user inputs for user, we always prefer to be explicit and throw an error message, )

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, get it. if remove the slice_size=3, the CI will pass.

@shinetzh
Copy link
Contributor Author

I have remove slice_size=3, and make unet_2d_condition unchanged.

@shinetzh
Copy link
Contributor Author

It seems like failed on other test file.
image

@yiyixuxu yiyixuxu merged commit 3b04cdc into huggingface:main Jul 20, 2024
13 of 15 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants