F/flax split head dim (#5181)
* split_head_dim flax attn

* Make split_head_dim non default

* make style and make quality

* add description for split_head_dim flag

* Update src/diffusers/models/attention_flax.py

Co-authored-by: Patrick von Platen <[email protected]>

---------

Co-authored-by: Juan Acevedo <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
3 people authored Sep 26, 2023
1 parent c82f7ba commit 16d56c4
Showing 1 changed file with 25 additions and 7 deletions.
32 changes: 25 additions & 7 deletions src/diffusers/models/attention_flax.py
@@ -131,6 +131,8 @@ class FlaxAttention(nn.Module):
Dropout rate
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
enable memory efficient attention https://arxiv.org/abs/2112.05682
split_head_dim (`bool`, *optional*, defaults to `False`):
Whether to split the head dimension into a new axis for the self-attention computation. In most cases, enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
@@ -140,6 +142,7 @@ class FlaxAttention(nn.Module):
dim_head: int = 64
dropout: float = 0.0
use_memory_efficient_attention: bool = False
split_head_dim: bool = False
dtype: jnp.dtype = jnp.float32

def setup(self):
@@ -177,9 +180,15 @@ def __call__(self, hidden_states, context=None, deterministic=True):
key_proj = self.key(context)
value_proj = self.value(context)

query_states = self.reshape_heads_to_batch_dim(query_proj)
key_states = self.reshape_heads_to_batch_dim(key_proj)
value_states = self.reshape_heads_to_batch_dim(value_proj)
if self.split_head_dim:
b = hidden_states.shape[0]
query_states = jnp.reshape(query_proj, (b, -1, self.heads, self.dim_head))
key_states = jnp.reshape(key_proj, (b, -1, self.heads, self.dim_head))
value_states = jnp.reshape(value_proj, (b, -1, self.heads, self.dim_head))
else:
query_states = self.reshape_heads_to_batch_dim(query_proj)
key_states = self.reshape_heads_to_batch_dim(key_proj)
value_states = self.reshape_heads_to_batch_dim(value_proj)

if self.use_memory_efficient_attention:
query_states = query_states.transpose(1, 0, 2)
@@ -206,14 +215,23 @@ def __call__(self, hidden_states, context=None, deterministic=True):
hidden_states = hidden_states.transpose(1, 0, 2)
else:
# compute attentions
attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
if self.split_head_dim:
attention_scores = jnp.einsum("b t n h, b f n h -> b n f t", key_states, query_states)
else:
attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)

attention_scores = attention_scores * self.scale
attention_probs = nn.softmax(attention_scores, axis=2)
attention_probs = nn.softmax(attention_scores, axis=-1 if self.split_head_dim else 2)

# attend to values
hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
if self.split_head_dim:
hidden_states = jnp.einsum("b n f t, b t n h -> b f n h", attention_probs, value_states)
b = hidden_states.shape[0]
hidden_states = jnp.reshape(hidden_states, (b, -1, self.heads * self.dim_head))
else:
hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)

hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
hidden_states = self.proj_attn(hidden_states)
return self.dropout_layer(hidden_states, deterministic=deterministic)

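To make the new code path easier to follow outside the diff context, here is a small, self-contained JAX sketch of the two attention layouts: the pre-existing path that folds the heads into the batch axis, and the new `split_head_dim` path that keeps the heads as their own axis and uses the `b t n h, b f n h -> b n f t` einsums added above. The shapes and helper functions below are illustrative assumptions, not the diffusers code itself; the final check simply confirms that both layouts compute the same attention output.

```python
# Standalone sketch (assumed shapes/names) comparing the two layouts from this diff.
import jax
import jax.numpy as jnp

batch, seq, heads, dim_head = 2, 16, 4, 8
scale = dim_head ** -0.5

rng = jax.random.PRNGKey(0)
kq, kk, kv = jax.random.split(rng, 3)
q = jax.random.normal(kq, (batch, seq, heads * dim_head))
k = jax.random.normal(kk, (batch, seq, heads * dim_head))
v = jax.random.normal(kv, (batch, seq, heads * dim_head))


def folded_batch_attention(q, k, v):
    # Pre-existing path: fold the heads into the batch axis
    # (what reshape_heads_to_batch_dim does), then 3-D einsums.
    def to_batch(x):
        x = x.reshape(batch, seq, heads, dim_head).transpose(0, 2, 1, 3)
        return x.reshape(batch * heads, seq, dim_head)

    qs, ks, vs = (to_batch(x) for x in (q, k, v))
    scores = jnp.einsum("b i d, b j d -> b i j", qs, ks) * scale
    probs = jax.nn.softmax(scores, axis=2)
    out = jnp.einsum("b i j, b j d -> b i d", probs, vs)
    out = out.reshape(batch, heads, seq, dim_head).transpose(0, 2, 1, 3)
    return out.reshape(batch, seq, heads * dim_head)


def split_head_attention(q, k, v):
    # New split_head_dim path: keep the head dimension as its own axis
    # and use the 4-D einsums introduced in this commit.
    qs = q.reshape(batch, -1, heads, dim_head)
    ks = k.reshape(batch, -1, heads, dim_head)
    vs = v.reshape(batch, -1, heads, dim_head)
    scores = jnp.einsum("b t n h, b f n h -> b n f t", ks, qs) * scale
    probs = jax.nn.softmax(scores, axis=-1)
    out = jnp.einsum("b n f t, b t n h -> b f n h", probs, vs)
    return out.reshape(batch, -1, heads * dim_head)


assert jnp.allclose(folded_batch_attention(q, k, v), split_head_attention(q, k, v), atol=1e-5)
```

The split layout avoids the transpose/reshape round trip of `reshape_heads_to_batch_dim` / `reshape_batch_dim_to_heads` around the attention einsums, which is presumably where the speedup mentioned in the new docstring comes from.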

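Since the commit makes the flag non-default, here is a hedged usage sketch showing how it could be enabled when constructing the layer directly. The constructor's first field (`query_dim`), the `heads` default, and the concrete sizes are assumptions based on the rest of the file; this diff only shows the attributes from `dim_head` downward.

```python
# Usage sketch with assumed constructor arguments; only dim_head, dropout,
# use_memory_efficient_attention, split_head_dim and dtype appear in this diff.
import jax
import jax.numpy as jnp
from diffusers.models.attention_flax import FlaxAttention

attn = FlaxAttention(
    query_dim=320,        # assumed required field, not shown in the hunks above
    heads=8,
    dim_head=40,
    split_head_dim=True,  # the new flag; defaults to False
)
hidden_states = jnp.ones((1, 64, 320))  # (batch, tokens, channels)
params = attn.init(jax.random.PRNGKey(0), hidden_states)
out = attn.apply(params, hidden_states)  # same shape as the input
```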
1 comment on commit 16d56c4


@daniva commented on commit 16d56c4 on May 7, 2024


Hey there -- I think this change may introduce a bug: removing the unconditional `self.reshape_batch_dim_to_heads` call at line 216 means it is no longer called when `self.use_memory_efficient_attention` is True and `self.split_head_dim` is False.
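For context, a minimal stand-alone sketch of the shape concern. The helper below is a simplified stand-in for `reshape_batch_dim_to_heads`, reconstructed from how the shapes are used elsewhere in this file, and the sizes are arbitrary: the memory-efficient branch still yields a tensor with the heads folded into the batch axis, so if the reshape only happens in the non-memory-efficient, non-split branch, that tensor would reach `proj_attn` with the wrong shape.

```python
# Illustrative shape check, not the diffusers code: shows what the reshape is
# needed for after the memory-efficient attention branch.
import jax.numpy as jnp

batch, seq, heads, dim_head = 2, 16, 4, 8

def reshape_batch_dim_to_heads(x):
    # Assumed behaviour: (batch * heads, seq, dim_head) -> (batch, seq, heads * dim_head)
    b = x.shape[0] // heads
    x = x.reshape(b, heads, seq, dim_head).transpose(0, 2, 1, 3)
    return x.reshape(b, seq, heads * dim_head)

# The memory-efficient path's output still has heads folded into the batch axis:
mem_efficient_out = jnp.ones((batch * heads, seq, dim_head))
print(mem_efficient_out.shape)                              # (8, 16, 8)
print(reshape_batch_dim_to_heads(mem_efficient_out).shape)  # (2, 16, 32)
```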
