
Bug in SanaPipeline example? #10489

Closed
@geronimi73

Description


Describe the bug

I think there might be something wrong with the SanaPipeline example code at https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana#diffusers.SanaPipeline
It results in a shape mismatch (see detailed logs below): mat1 and mat2 shapes cannot be multiplied (600x256000 and 2304x1152)

I've noticed that the text_encoder model looks different depending on the way it is loaded.

  • If I load it with the official example code (i.e. the code under Reproduction), pipeline.text_encoder looks like this:
Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 2304, padding_idx=0)
    (layers): ModuleList(
      (0-25): 26 x Gemma2DecoderLayer(
        (self_attn): Gemma2Attention(
          (q_proj): Linear(in_features=2304, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (v_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2304, bias=False)
          (rotary_emb): Gemma2RotaryEmbedding()
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (up_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (down_proj): Linear(in_features=9216, out_features=2304, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (pre_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (post_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (post_attention_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
      )
    )
    (norm): Gemma2RMSNorm((2304,), eps=1e-06)
  )
  (lm_head): Linear(in_features=2304, out_features=256000, bias=False)
)

If, however, I don't load the components separately and instead use the code provided by @lawrence-cj here, it 1) works and 2) the text_encoder looks different:

Gemma2Model(
  (embed_tokens): Embedding(256000, 2304, padding_idx=0)
  (layers): ModuleList(
    (0-25): 26 x Gemma2DecoderLayer(
      (self_attn): Gemma2Attention(
        (q_proj): Linear(in_features=2304, out_features=2048, bias=False)
        (k_proj): Linear(in_features=2304, out_features=1024, bias=False)
        (v_proj): Linear(in_features=2304, out_features=1024, bias=False)
        (o_proj): Linear(in_features=2048, out_features=2304, bias=False)
        (rotary_emb): Gemma2RotaryEmbedding()
      )
      (mlp): Gemma2MLP(
        (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)
        (up_proj): Linear(in_features=2304, out_features=9216, bias=False)
        (down_proj): Linear(in_features=9216, out_features=2304, bias=False)
        (act_fn): PytorchGELUTanh()
      )
      (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
      (pre_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
      (post_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
      (post_attention_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
    )
  )
  (norm): Gemma2RMSNorm((2304,), eps=1e-06)
)

-> The language modeling head (lm_head) is gone. I guess that's all expected (?), but I haven't found any documentation of this behaviour or of where in the pipeline code this happens.
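One plausible explanation for the missing lm_head: when DiffusionPipeline.from_pretrained builds a component itself, it instantiates the class recorded for that component in the repo's model_index.json, where each entry is a [library, class name] pair. A component passed in explicitly (like text_encoder=text_encoder_8bit in the reproduction) is used as-is, so a Gemma2ForCausalLM loaded via AutoModelForCausalLM keeps its lm_head. The sketch below only parses such a file with the standard library. The JSON excerpt is hypothetical: its "Gemma2Model" entry is an assumption chosen to match the second printout, not a copy of the actual Sana repo file, and component_class is an illustrative helper, not a diffusers function.

```python
import json

# Hypothetical excerpt of a diffusers model_index.json (assumed, not copied
# from the Sana repo). Each component maps to a [library, class name] pair.
SAMPLE_MODEL_INDEX = """
{
  "_class_name": "SanaPipeline",
  "text_encoder": ["transformers", "Gemma2Model"],
  "transformer": ["diffusers", "SanaTransformer2DModel"]
}
"""


def component_class(model_index_json: str, component: str) -> tuple[str, str]:
    """Return the (library, class name) pair recorded for one pipeline component."""
    index = json.loads(model_index_json)
    library, class_name = index[component]
    return library, class_name


lib, cls = component_class(SAMPLE_MODEL_INDEX, "text_encoder")
print(f"text_encoder is loaded as {lib}.{cls}")
```

If that is what happens, loading the text encoder in the reproduction with transformers' AutoModel (which maps a gemma2 config to Gemma2Model) instead of AutoModelForCausalLM should presumably give the pipeline the headless model it expects.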

Reproduction

import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SanaTransformer2DModel, SanaPipeline
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quant_config = BitsAndBytesConfig(load_in_8bit=True)
text_encoder_8bit = AutoModelForCausalLM.from_pretrained(
    "Efficient-Large-Model/Sana_600M_1024px_diffusers",
    subfolder="text_encoder",
    # quantization_config=quant_config,
    torch_dtype=torch.float16,
)

quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = SanaTransformer2DModel.from_pretrained(
    "Efficient-Large-Model/Sana_600M_1024px_diffusers",
    subfolder="transformer",
    # quantization_config=quant_config,
    torch_dtype=torch.float16,
)

pipeline = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_600M_1024px_diffusers",
    text_encoder=text_encoder_8bit,
    transformer=transformer_8bit,
    torch_dtype=torch.float16,
    device_map="balanced",
)

prompt = "a tiny astronaut hatching from an egg on the moon"
image = pipeline(prompt).images[0]
image.save("sana.png")

I'm loading without quantization_config because, for some reason, it does not work on my Mac. I tried the same code on a 4090, and it fails there too.

Logs

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[5], line 30
     21 pipeline = SanaPipeline.from_pretrained(
     22     "Efficient-Large-Model/Sana_600M_1024px_diffusers",
     23     text_encoder=text_encoder_8bit,
   (...)
     26     device_map="balanced",
     27 )
     29 prompt = "a tiny astronaut hatching from an egg on the moon"
---> 30 image = pipeline(prompt).images[0]
     31 image.save("sana.png")

File /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/diffusers/pipelines/sana/pipeline_sana.py:829, in SanaPipeline.__call__(self, prompt, negative_prompt, num_inference_steps, timesteps, sigmas, guidance_scale, num_images_per_prompt, height, width, eta, generator, latents, prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask, output_type, return_dict, clean_caption, use_resolution_binning, attention_kwargs, callback_on_step_end, callback_on_step_end_tensor_inputs, max_sequence_length, complex_human_instruction)
    826 timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
    828 # predict noise model_output
--> 829 noise_pred = self.transformer(
    830     latent_model_input,
    831     encoder_hidden_states=prompt_embeds,
    832     encoder_attention_mask=prompt_attention_mask,
    833     timestep=timestep,
    834     return_dict=False,
    835     attention_kwargs=self.attention_kwargs,
    836 )[0]
    837 noise_pred = noise_pred.float()
    839 # perform guidance

File /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/diffusers/models/transformers/sana_transformer.py:420, in SanaTransformer2DModel.forward(self, hidden_states, encoder_hidden_states, timestep, encoder_attention_mask, attention_mask, attention_kwargs, return_dict)
    414 hidden_states = self.patch_embed(hidden_states)
    416 timestep, embedded_timestep = self.time_embed(
    417     timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
    418 )
--> 420 encoder_hidden_states = self.caption_projection(encoder_hidden_states)
    421 encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
    423 encoder_hidden_states = self.caption_norm(encoder_hidden_states)

File /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/diffusers/models/embeddings.py:2221, in PixArtAlphaTextProjection.forward(self, caption)
   2220 def forward(self, caption):
-> 2221     hidden_states = self.linear_1(caption)
   2222     hidden_states = self.act_1(hidden_states)
   2223     hidden_states = self.linear_2(hidden_states)

File /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/nn/modules/linear.py:125, in Linear.forward(self, input)
    124 def forward(self, input: Tensor) -> Tensor:
--> 125     return F.linear(input, self.weight, self.bias)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (600x256000 and 2304x1152)
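The numbers in this error are consistent with the text encoder returning vocabulary logits (256000 = lm_head out_features) instead of 2304-dimensional hidden states. A minimal pure-Python sketch of the shape check, assuming one prompt doubled to a batch of 2 by classifier-free guidance, the pipeline's default max_sequence_length of 300, and that the pipeline uses the first element of the text encoder's output (last_hidden_state for Gemma2Model, logits for Gemma2ForCausalLM). linear_output_shape is a hypothetical helper, not a torch function, and 1152 is the caption projection width implied by the error message.

```python
def linear_output_shape(input_shape, in_features, out_features):
    """Mimic the shape check behind torch's "mat1 and mat2" error for F.linear.

    Leading dimensions are folded into a single row count, which is how the
    error message reports a 3-D input such as (batch, seq_len, features).
    """
    *leading, features = input_shape
    rows = 1
    for dim in leading:
        rows *= dim
    if features != in_features:
        # F.linear computes input @ weight.T with weight of shape
        # (out_features, in_features), so "mat2" is (in_features, out_features).
        raise RuntimeError(
            "mat1 and mat2 shapes cannot be multiplied "
            f"({rows}x{features} and {in_features}x{out_features})"
        )
    return (*leading, out_features)


# Assumed values: 1 prompt x 2 (classifier-free guidance), 300 tokens.
BATCH, SEQ_LEN = 2, 300
VOCAB_SIZE = 256000  # lm_head out_features in the Gemma2ForCausalLM printout
HIDDEN_SIZE = 2304   # Gemma2 hidden size (embed_tokens / norm width)
INNER_DIM = 1152     # caption_projection output width implied by the error

# Hidden states from Gemma2Model fit the caption projection...
print(linear_output_shape((BATCH, SEQ_LEN, HIDDEN_SIZE), HIDDEN_SIZE, INNER_DIM))  # (2, 300, 1152)

# ...but logits from Gemma2ForCausalLM reproduce the reported message.
try:
    linear_output_shape((BATCH, SEQ_LEN, VOCAB_SIZE), HIDDEN_SIZE, INNER_DIM)
except RuntimeError as err:
    print(err)
```

Under these assumptions, 2 x 300 = 600 rows of 256000 features meet a projection expecting 2304, which is exactly the mismatch in the log.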

System Info

diffusers 0.32.1
torch 2.5.1
Python 3.13.0

M3 MacBook Air, macOS Sonoma 14.4.1

Who can help?

@sayakpaul @lawrence-cj


Labels: bug
