
fix loop bug in SlicedAttnProcessor #8836

Merged (13 commits) on Jul 20, 2024
4 changes: 2 additions & 2 deletions src/diffusers/models/attention_processor.py
@@ -2190,7 +2190,7 @@ def __call__(
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
)

-for i in range(batch_size_attention // self.slice_size):
+for i in range((batch_size_attention - 1) // self.slice_size + 1):
start_idx = i * self.slice_size
end_idx = (i + 1) * self.slice_size

@@ -2287,7 +2287,7 @@ def __call__(
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
)

-for i in range(batch_size_attention // self.slice_size):
+for i in range((batch_size_attention - 1) // self.slice_size + 1):
start_idx = i * self.slice_size
end_idx = (i + 1) * self.slice_size

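For context, here is a minimal standalone sketch (hypothetical values, not the actual SlicedAttnProcessor code) of why the old floor-division bound skips the trailing rows whenever batch_size_attention is not a multiple of slice_size, and why the ceiling-division bound covers them:

batch_size_attention, slice_size = 8, 3  # assumed example values

# Old bound: floor division gives 2 iterations, so rows 6 and 7 are never processed.
old_slices = [(i * slice_size, (i + 1) * slice_size) for i in range(batch_size_attention // slice_size)]
print(old_slices)  # [(0, 3), (3, 6)]

# New bound: ceiling division gives 3 iterations; the last slice (6, 9) is clipped
# by tensor indexing, so rows 6 and 7 are included.
new_slices = [(i * slice_size, (i + 1) * slice_size) for i in range((batch_size_attention - 1) // slice_size + 1)]
print(new_slices)  # [(0, 3), (3, 6), (6, 9)]

With the old bound, any rows of the pre-allocated zeros buffer past the last full slice were left untouched, silently corrupting the attention output.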
5 changes: 4 additions & 1 deletion src/diffusers/models/unets/unet_2d_condition.py
@@ -815,7 +815,10 @@ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
size = slice_size[i]
dim = sliceable_head_dims[i]
if size is not None and size > dim:
-raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+slice_size[i] = dim
Collaborator:

let's try not to make this change and instead update the test (we should try not to modify the user's inputs on their behalf; we always prefer to be explicit and throw an error message)

Contributor Author:

OK, got it. If we remove the slice_size=3 case, the CI will pass.

+logger.warning(
+    f"size {size} has to be smaller or equal to {dim}, and slice_size {size} has been set to {dim}"
+)

# Recursively walk through all the children.
# Any children which exposes the set_attention_slice method
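To make the trade-off discussed in the review thread concrete, here is a small hedged sketch (a hypothetical helper, not the actual set_attention_slice code) contrasting the two behaviours: clamping an oversized slice size with a warning, as in the diff above, versus raising an explicit error, as the reviewer prefers:

import logging

logger = logging.getLogger(__name__)

def resolve_slice_size(size, dim, strict=True):
    # Hypothetical helper illustrating the two options discussed in the review.
    if size <= dim:
        return size
    if strict:
        # Reviewer-preferred behaviour: reject invalid user input explicitly.
        raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
    # Behaviour shown in the diff: clamp to the sliceable dimension and warn.
    logger.warning(f"size {size} has to be smaller or equal to {dim}, and slice_size {size} has been set to {dim}")
    return dim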
24 changes: 20 additions & 4 deletions tests/pipelines/test_pipelines_common.py
@@ -1351,14 +1351,30 @@ def _test_attention_slicing_forward_pass(

pipe.enable_attention_slicing(slice_size=1)
inputs = self.get_dummy_inputs(generator_device)
-output_with_slicing = pipe(**inputs)[0]
+output_with_slicing1 = pipe(**inputs)[0]

+pipe.enable_attention_slicing(slice_size=2)
+inputs = self.get_dummy_inputs(generator_device)
+output_with_slicing2 = pipe(**inputs)[0]

+pipe.enable_attention_slicing(slice_size=3)
Collaborator:

can we remove the slice_size=3 test? I think the CI would pass without this, no?

+inputs = self.get_dummy_inputs(generator_device)
+output_with_slicing3 = pipe(**inputs)[0]

if test_max_difference:
-max_diff = np.abs(to_np(output_with_slicing) - to_np(output_without_slicing)).max()
-self.assertLess(max_diff, expected_max_diff, "Attention slicing should not affect the inference results")
+max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+max_diff3 = np.abs(to_np(output_with_slicing3) - to_np(output_without_slicing)).max()
+self.assertLess(
+    max(max_diff1, max_diff2, max_diff3),
+    expected_max_diff,
+    "Attention slicing should not affect the inference results",
+)

if test_mean_pixel_difference:
-assert_mean_pixel_difference(to_np(output_with_slicing[0]), to_np(output_without_slicing[0]))
+assert_mean_pixel_difference(to_np(output_with_slicing1[0]), to_np(output_without_slicing[0]))
+assert_mean_pixel_difference(to_np(output_with_slicing2[0]), to_np(output_without_slicing[0]))
+assert_mean_pixel_difference(to_np(output_with_slicing3[0]), to_np(output_without_slicing[0]))

@unittest.skipIf(
torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"),
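For reference, a minimal usage sketch of the user-facing API this fix affects; the checkpoint name and prompt are illustrative placeholders, and any pipeline exposing enable_attention_slicing would work the same way:

import torch
from diffusers import StableDiffusionPipeline

# Illustrative checkpoint; any diffusers pipeline with attention slicing will do.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# With the loop fix, slice sizes that do not evenly divide the number of
# attention heads (e.g. 3 with 8 heads) no longer drop the trailing slices.
pipe.enable_attention_slicing(slice_size=3)

image = pipe("a photo of an astronaut riding a horse").images[0]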