Support passing kwargs to the SD3 custom attention processor (#9818)
* Support passing kwargs to the SD3 custom attention processor

This adds an optional `joint_attention_kwargs` argument to `JointTransformerBlock.forward` and forwards it to the block's attention processors (`attn` and, for dual attention, `attn2`), so custom SD3 attention processors can receive extra keyword arguments.


---------

Co-authored-by: hlky <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
3 people authored Dec 18, 2024
1 parent 88b015d commit 8eb73c8
Showing 2 changed files with 15 additions and 4 deletions.
13 changes: 10 additions & 3 deletions src/diffusers/models/attention.py
@@ -188,8 +188,13 @@ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
         self._chunk_dim = dim
 
     def forward(
-        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
+        joint_attention_kwargs = joint_attention_kwargs or {}
         if self.use_dual_attention:
             norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
                 hidden_states, emb=temb
@@ -206,15 +211,17 @@ def forward(
 
         # Attention.
         attn_output, context_attn_output = self.attn(
-            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            **joint_attention_kwargs,
         )
 
         # Process attention outputs for the `hidden_states`.
         attn_output = gate_msa.unsqueeze(1) * attn_output
         hidden_states = hidden_states + attn_output
 
         if self.use_dual_attention:
-            attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
+            attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
             attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
             hidden_states = hidden_states + attn_output2
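With this change, any extra keys in `joint_attention_kwargs` are forwarded from the block into its attention processors. Below is a minimal sketch of a custom processor that consumes such a kwarg; the class name, the `attn_scale` parameter, and the rescaling behavior are illustrative assumptions, not part of this commit:

import torch

from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0


class ScaledJointAttnProcessor(JointAttnProcessor2_0):
    """Hypothetical SD3 processor: rescales hidden states before joint attention."""

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: torch.FloatTensor = None,
        attn_scale: float = 1.0,  # custom kwarg, delivered via **joint_attention_kwargs
        **kwargs,
    ):
        # Delegate to the stock SD3 joint-attention implementation. For `attn`
        # (encoder_hidden_states given) this returns a (hidden_states,
        # encoder_hidden_states) pair; for `attn2` (encoder_hidden_states=None)
        # a single tensor, matching what the block's forward unpacks.
        return super().__call__(
            attn,
            hidden_states * attn_scale,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
        )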
6 changes: 5 additions & 1 deletion src/diffusers/models/transformers/transformer_sd3.py
@@ -411,11 +411,15 @@ def custom_forward(*inputs):
                     hidden_states,
                     encoder_hidden_states,
                     temb,
+                    joint_attention_kwargs,
                     **ckpt_kwargs,
                 )
             elif not is_skip:
                 encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=temb,
+                    joint_attention_kwargs=joint_attention_kwargs,
                 )
 
             # controlnet residual
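A hypothetical end-to-end usage sketch: `StableDiffusion3Pipeline` already exposes a `joint_attention_kwargs` argument and hands it to the transformer, which after this commit forwards it into every block. The model id, prompt, and `attn_scale` value are assumptions for illustration:

import torch

from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

# Install the hypothetical processor from the sketch above on every
# attention layer of the SD3 transformer.
pipe.transformer.set_attn_processor(ScaledJointAttnProcessor())

image = pipe(
    "a photo of an astronaut riding a horse",
    joint_attention_kwargs={"attn_scale": 1.1},  # now reaches each block's processors
).images[0]
image.save("astronaut.png")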
