Implement CustomDiffusionAttnProcessor2_0 #4588

Closed
eliphatfs opened this issue Aug 13, 2023 · 6 comments · Fixed by #4604
Labels
stale (Issues that haven't received updates)

Comments

@eliphatfs
Contributor

The default AttnProcessor and the LoRA processors can use the torch 2.0 SDP operators, but the Custom Diffusion processors cannot.
I can start a PR if you think it is a good idea.
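
For reference, here is a minimal, self-contained sketch (not the diffusers implementation; shapes and names are illustrative only) of the difference between the manual attention path and the fused torch 2.0 SDP operator that the default processor already uses:

import torch
import torch.nn.functional as F

# Illustrative shapes only: (batch, heads, seq_len, head_dim)
batch, heads, seq_len, head_dim = 2, 8, 64, 40
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

# Manual path, roughly what the non-2.0 processors compute explicitly.
scores = q @ k.transpose(-2, -1) / head_dim**0.5
manual = scores.softmax(dim=-1) @ v

# Fused path available since torch 2.0; it dispatches to the flash,
# memory-efficient, or math kernel depending on hardware and inputs.
fused = F.scaled_dot_product_attention(q, k, v)

assert torch.allclose(manual, fused, atol=1e-5)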

@sayakpaul
Member

Sure. Cc: @nupurkmr9

But have you seen any significant difference when using it compared to the xformers variant of Custom Diffusion?

@eliphatfs
Contributor Author

eliphatfs commented Aug 15, 2023

I am using the very new H100 card, and xformers somehow crashes on it; torch 2 stable also crashes, but the issue has been fixed in the nightly builds and it no longer crashes. Since torch 2 has built-in support for automatically selecting the heuristically best operator for SDP, I think supporting it is worthwhile: it is one fewer dependency to set up and is more convenient for users.
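
(For anyone debugging backend-specific crashes like the above, here is a minimal sketch of pinning the SDP backend with the torch 2.0 `torch.backends.cuda.sdp_kernel` context manager; it assumes a CUDA build of PyTorch and is not diffusers-specific.)

import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 8, 64, 40, device="cuda", dtype=torch.float16)

# By default torch picks the best available kernel
# (flash, memory-efficient, or math).
out = F.scaled_dot_product_attention(q, k, v)

# Restrict the allowed backends, e.g. to rule out the flash kernel on a given card.
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=True):
    out = F.scaled_dot_product_attention(q, k, v)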

I think it is easy to implement, and I have already done so in my recent code. Additional comment: I think the 2_0/XFormers/Default variants of the Added/Sliced/Default/CustomDiffusion/LoRA processors share a lot of logic and would be better designed in a composable manner: the latter only care about preparing the QKV matrices, while the former only care about computing attention from the QKV (a rough sketch of this split follows the class below).

import torch
from torch import nn
import torch.nn.functional as F

from diffusers.models.attention_processor import Attention


class CustomDiffusionAttnProcessor2_0(nn.Module):
    r"""
    Processor for implementing attention for the Custom Diffusion method.

    Args:
        train_kv (`bool`, defaults to `True`):
            Whether to newly train the key and value matrices corresponding to the text features.
        train_q_out (`bool`, defaults to `True`):
            Whether to newly train query matrices corresponding to the latent image features.
        hidden_size (`int`, *optional*, defaults to `None`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`, *optional*, defaults to `None`):
            The number of channels in the `encoder_hidden_states`.
        out_bias (`bool`, defaults to `True`):
            Whether to include the bias parameter in `train_q_out`.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability to use.
    """

    def __init__(
        self,
        train_kv=True,
        train_q_out=True,
        hidden_size=None,
        cross_attention_dim=None,
        out_bias=True,
        dropout=0.0,
    ):
        super().__init__()
        self.train_kv = train_kv
        self.train_q_out = train_q_out

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim

        # `_custom_diffusion` id for easy serialization and loading.
        if self.train_kv:
            self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
            self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        if self.train_q_out:
            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
            self.to_out_custom_diffusion = nn.ModuleList([])
            self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
            self.to_out_custom_diffusion.append(nn.Dropout(dropout))

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        if self.train_q_out:
            query = self.to_q_custom_diffusion(hidden_states)
        else:
            query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            crossattn = False
            encoder_hidden_states = hidden_states
        else:
            crossattn = True
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        if self.train_kv:
            key = self.to_k_custom_diffusion(encoder_hidden_states)
            value = self.to_v_custom_diffusion(encoder_hidden_states)
        else:
            key = attn.to_k(encoder_hidden_states)
            value = attn.to_v(encoder_hidden_states)

        if crossattn:
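            # Detach the first text token's key/value so it receives no gradient
            # during Custom Diffusion cross-attention fine-tuning; all other
            # tokens keep their gradients.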
            detach = torch.ones_like(key)
            detach[:, :1, :] = detach[:, :1, :] * 0.0
            key = detach * key + (1 - detach) * key.detach()
            value = detach * value + (1 - detach) * value.detach()

        # The head_to_batch_dim reshapes used by the non-SDP processor are
        # replaced by the (batch, heads, seq_len, head_dim) views below.
        inner_dim = hidden_states.shape[-1]

        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # The manual attention path of the default processor
        # (get_attention_scores + bmm + batch_to_head_dim) is replaced by
        # F.scaled_dot_product_attention above.

        if self.train_q_out:
            # linear proj
            hidden_states = self.to_out_custom_diffusion[0](hidden_states)
            # dropout
            hidden_states = self.to_out_custom_diffusion[1](hidden_states)
        else:
            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
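
To make the composability point above concrete, here is a rough, hypothetical sketch (none of these names exist in diffusers; it only illustrates the separation of concerns): the QKV-preparation side and the attention-computation side could be written as two small callables and mixed freely.

from typing import Callable, Tuple

import torch
import torch.nn.functional as F

Tensor = torch.Tensor
# Hypothetical building blocks, not the diffusers API.
QKVPreparer = Callable[[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]]
AttnBackend = Callable[[Tensor, Tensor, Tensor], Tensor]


def sdp_backend(q: Tensor, k: Tensor, v: Tensor) -> Tensor:
    # The "2_0" concern: only computes attention from already-prepared QKV.
    return F.scaled_dot_product_attention(q, k, v)


def make_processor(prepare_qkv: QKVPreparer, backend: AttnBackend):
    # A processor is just the composition of the two concerns, so e.g. a
    # CustomDiffusion-style QKV preparer could be paired with either the SDP
    # backend or an xformers backend without duplicating code.
    def processor(hidden_states: Tensor, encoder_hidden_states: Tensor) -> Tensor:
        q, k, v = prepare_qkv(hidden_states, encoder_hidden_states)
        return backend(q, k, v)

    return processor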

@sayakpaul
Member

Awesome. Thanks so much for providing the details and sharing your thoughts. Let's jam on #4604.

@sayakpaul
Member

Additional comment: I think the 2_0/XFormers/Default variants of the Added/Sliced/Default/CustomDiffusion/LoRA processors share a lot of logic and would be better designed in a composable manner: the latter only care about preparing the QKV matrices, while the former only care about computing attention from the QKV.

You're right. We keep them disentangled to promote a better separation of concerns, and I agree that they could be unified in a composable manner.

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions bot added the stale label on Sep 12, 2023
@eliphatfs
Contributor Author

The request is still open.
