fix bug of Attention.head_to_batch_dim issue huggingface#10303
Dawn-LX committed Dec 20, 2024
1 parent 41ba8c0 commit 9b24fb5
Showing 1 changed file with 7 additions and 4 deletions.
src/diffusers/models/attention_processor.py (11 changes: 7 additions & 4 deletions)
@@ -612,8 +612,10 @@ def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
 
     def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
         r"""
-        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is
-        the number of heads initialized while constructing the `Attention` class.
+        Reshape the tensor from `[batch_size, seq_len, dim]` to
+        `[batch_size, seq_len, heads, dim // heads]` for out_dim==4
+        or `[batch_size * heads, seq_len, dim // heads]` for out_dim==3
+        where `heads` is the number of heads initialized while constructing the `Attention` class.
 
         Args:
             tensor (`torch.Tensor`): The tensor to reshape.
@@ -630,9 +632,10 @@ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
         else:
             batch_size, extra_dim, seq_len, dim = tensor.shape
         tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
-        tensor = tensor.permute(0, 2, 1, 3)
-
+
+        assert out_dim in [3,4]
         if out_dim == 3:
+            tensor = tensor.permute(0, 2, 1, 3)
             tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
 
         return tensor
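For reference, the patched reshaping can be checked on a concrete tensor. The snippet below is a minimal standalone sketch of the same logic, not the diffusers Attention class itself; the function name head_to_batch_dim_sketch, the head_size argument standing in for self.heads, and the example sizes are illustrative only.

import torch


def head_to_batch_dim_sketch(tensor: torch.Tensor, head_size: int, out_dim: int = 3) -> torch.Tensor:
    # Mirrors the patched logic: the permute now happens only on the out_dim == 3 path.
    if tensor.ndim == 3:
        batch_size, seq_len, dim = tensor.shape
        extra_dim = 1
    else:
        batch_size, extra_dim, seq_len, dim = tensor.shape
    tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)

    assert out_dim in [3, 4]
    if out_dim == 3:
        # Move the head axis next to the batch axis, then fold it into the batch dimension.
        tensor = tensor.permute(0, 2, 1, 3)
        tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)

    return tensor


x = torch.randn(2, 16, 64)  # [batch_size, seq_len, dim] with 8 heads of size 8
print(head_to_batch_dim_sketch(x, head_size=8, out_dim=3).shape)  # torch.Size([16, 16, 8])
print(head_to_batch_dim_sketch(x, head_size=8, out_dim=4).shape)  # torch.Size([2, 16, 8, 8])

Before this change the permute ran unconditionally, so the out_dim=4 path returned a [batch_size, heads, seq_len, dim // heads] tensor rather than the [batch_size, seq_len, heads, dim // heads] shape described in the docstring.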
