From 16d56c4b4ff49790fcaa3b98ce5c9c5a00fd1c41 Mon Sep 17 00:00:00 2001
From: Juan Acevedo
Date: Tue, 26 Sep 2023 11:08:30 -0700
Subject: [PATCH] F/flax split head dim (#5181)

* split_head_dim flax attn

* Make split_head_dim non default

* make style and make quality

* add description for split_head_dim flag

* Update src/diffusers/models/attention_flax.py

Co-authored-by: Patrick von Platen

---------

Co-authored-by: Juan Acevedo
Co-authored-by: Patrick von Platen
---
 src/diffusers/models/attention_flax.py | 32 ++++++++++++++++++++------
 1 file changed, 25 insertions(+), 7 deletions(-)

diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py
index 0b160d238431..588b99ec240f 100644
--- a/src/diffusers/models/attention_flax.py
+++ b/src/diffusers/models/attention_flax.py
@@ -131,6 +131,8 @@ class FlaxAttention(nn.Module):
             Dropout rate
         use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
             enable memory efficient attention https://arxiv.org/abs/2112.05682
+        split_head_dim (`bool`, *optional*, defaults to `False`):
+            Whether to split the head dimension into a new axis for the self-attention computation. In most cases, enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`

@@ -140,6 +142,7 @@ class FlaxAttention(nn.Module):
     dim_head: int = 64
     dropout: float = 0.0
     use_memory_efficient_attention: bool = False
+    split_head_dim: bool = False
     dtype: jnp.dtype = jnp.float32

     def setup(self):
@@ -177,9 +180,15 @@ def __call__(self, hidden_states, context=None, deterministic=True):
         key_proj = self.key(context)
         value_proj = self.value(context)

-        query_states = self.reshape_heads_to_batch_dim(query_proj)
-        key_states = self.reshape_heads_to_batch_dim(key_proj)
-        value_states = self.reshape_heads_to_batch_dim(value_proj)
+        if self.split_head_dim:
+            b = hidden_states.shape[0]
+            query_states = jnp.reshape(query_proj, (b, -1, self.heads, self.dim_head))
+            key_states = jnp.reshape(key_proj, (b, -1, self.heads, self.dim_head))
+            value_states = jnp.reshape(value_proj, (b, -1, self.heads, self.dim_head))
+        else:
+            query_states = self.reshape_heads_to_batch_dim(query_proj)
+            key_states = self.reshape_heads_to_batch_dim(key_proj)
+            value_states = self.reshape_heads_to_batch_dim(value_proj)

         if self.use_memory_efficient_attention:
             query_states = query_states.transpose(1, 0, 2)
@@ -206,14 +215,23 @@ def __call__(self, hidden_states, context=None, deterministic=True):
             hidden_states = hidden_states.transpose(1, 0, 2)
         else:
             # compute attentions
-            attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
+            if self.split_head_dim:
+                attention_scores = jnp.einsum("b t n h, b f n h -> b n f t", key_states, query_states)
+            else:
+                attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
+
             attention_scores = attention_scores * self.scale
-            attention_probs = nn.softmax(attention_scores, axis=2)
+            attention_probs = nn.softmax(attention_scores, axis=-1 if self.split_head_dim else 2)

             # attend to values
-            hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
+            if self.split_head_dim:
+                hidden_states = jnp.einsum("b n f t, b t n h -> b f n h", attention_probs, value_states)
+                b = hidden_states.shape[0]
+                hidden_states = jnp.reshape(hidden_states, (b, -1, self.heads * self.dim_head))
+            else:
+                hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
+                hidden_states = self.reshape_batch_dim_to_heads(hidden_states)

-        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         hidden_states = self.proj_attn(hidden_states)
         return self.dropout_layer(hidden_states, deterministic=deterministic)
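
Note (not part of the patch above): the sketch below is a minimal, hypothetical usage example of the new `split_head_dim` flag, assuming a diffusers checkout that already contains this change plus a working jax/flax install. The standalone use of `FlaxAttention`, the 8-head / 40-dim-per-head configuration, and the (2, 64, 320) input shape are illustrative assumptions, not taken from the patch.

# Hedged usage sketch: exercise FlaxAttention with the split_head_dim path enabled.
import jax
import jax.numpy as jnp

from diffusers.models.attention_flax import FlaxAttention

# Inner dimension is heads * dim_head (8 * 40 = 320); chosen for illustration only.
attn = FlaxAttention(
    query_dim=320,
    heads=8,
    dim_head=40,
    split_head_dim=True,  # the flag introduced by this patch; it defaults to False
    dtype=jnp.float32,
)

hidden_states = jnp.ones((2, 64, 320))  # (batch, sequence, query_dim), assumed shape
params = attn.init(jax.random.PRNGKey(0), hidden_states)

# context=None takes the self-attention path; deterministic=True needs no dropout rng.
out = attn.apply(params, hidden_states)
print(out.shape)  # (2, 64, 320) -- same output shape as with split_head_dim=False

Both settings of the flag should produce the same result up to numerical precision; the split-head path only changes how the intermediate tensors are laid out for the einsum calls.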