Bug in SanaPipeline example? #10489

Closed
geronimi73 opened this issue Jan 7, 2025 · 2 comments · Fixed by #10507

geronimi73 (Contributor) commented on Jan 7, 2025:

Describe the bug

I think there might be something wrong with the SanaPipeline example code at https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana#diffusers.SanaPipeline
It results in a shape mismatch (see detailed logs below): mat1 and mat2 shapes cannot be multiplied (600x256000 and 2304x1152)

I've noticed that the text_encoder model looks different depending on the way it is loaded.

If I load it with the official example code (i.e., the code under Reproduction below), pipeline.text_encoder looks like this:
Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 2304, padding_idx=0)
    (layers): ModuleList(
      (0-25): 26 x Gemma2DecoderLayer(
        (self_attn): Gemma2Attention(
          (q_proj): Linear(in_features=2304, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (v_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2304, bias=False)
          (rotary_emb): Gemma2RotaryEmbedding()
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (up_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (down_proj): Linear(in_features=9216, out_features=2304, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (pre_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (post_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (post_attention_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
      )
    )
    (norm): Gemma2RMSNorm((2304,), eps=1e-06)
  )
  (lm_head): Linear(in_features=2304, out_features=256000, bias=False)
)

If, however, I don't load the components separately but instead use the code provided by @lawrence-cj here, it 1) works and 2) the text_encoder looks different:

Gemma2Model(
  (embed_tokens): Embedding(256000, 2304, padding_idx=0)
  (layers): ModuleList(
    (0-25): 26 x Gemma2DecoderLayer(
      (self_attn): Gemma2Attention(
        (q_proj): Linear(in_features=2304, out_features=2048, bias=False)
        (k_proj): Linear(in_features=2304, out_features=1024, bias=False)
        (v_proj): Linear(in_features=2304, out_features=1024, bias=False)
        (o_proj): Linear(in_features=2048, out_features=2304, bias=False)
        (rotary_emb): Gemma2RotaryEmbedding()
      )
      (mlp): Gemma2MLP(
        (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)
        (up_proj): Linear(in_features=2304, out_features=9216, bias=False)
        (down_proj): Linear(in_features=9216, out_features=2304, bias=False)
        (act_fn): PytorchGELUTanh()
      )
      (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
      (pre_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
      (post_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
      (post_attention_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
    )
  )
  (norm): Gemma2RMSNorm((2304,), eps=1e-06)
)

-> the language modeling head lm_head is gone. I guess that's all expected (?), but I haven't found any documentation of this behaviour or of where in the pipeline code this happens.
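
To illustrate what I think is happening (a minimal sketch, not the actual pipeline code; it assumes the pipeline feeds outputs[0] of the text encoder into the transformer's caption_projection): Gemma2ForCausalLM returns vocabulary logits (last dimension 256000) as its first output, whereas Gemma2Model returns the 2304-dimensional hidden states that caption_projection expects.

import torch

# Hypothetical illustration; dimensions taken from the error message and model printouts above.
hidden_dim, vocab_size, proj_out = 2304, 256000, 1152

caption_projection = torch.nn.Linear(hidden_dim, proj_out)  # mat2: 2304x1152

hidden_states = torch.randn(600, hidden_dim)  # what Gemma2Model(...)[0] would provide
logits = torch.randn(600, vocab_size)         # what Gemma2ForCausalLM(...)[0] provides

caption_projection(hidden_states)  # works
try:
    caption_projection(logits)
except RuntimeError as e:
    print(e)  # mat1 and mat2 shapes cannot be multiplied (600x256000 and 2304x1152)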

Reproduction

import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SanaTransformer2DModel, SanaPipeline
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, AutoModelForCausalLM

quant_config = BitsAndBytesConfig(load_in_8bit=True)
text_encoder_8bit = AutoModelForCausalLM.from_pretrained(
    "Efficient-Large-Model/Sana_600M_1024px_diffusers",
    subfolder="text_encoder",
    # quantization_config=quant_config,
    torch_dtype=torch.float16,
)

quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = SanaTransformer2DModel.from_pretrained(
    "Efficient-Large-Model/Sana_600M_1024px_diffusers",
    subfolder="transformer",
    # quantization_config=quant_config,
    torch_dtype=torch.float16,
)

pipeline = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_600M_1024px_diffusers",
    text_encoder=text_encoder_8bit,
    transformer=transformer_8bit,
    torch_dtype=torch.float16,
    device_map="balanced",
)

prompt = "a tiny astronaut hatching from an egg on the moon"
image = pipeline(prompt).images[0]
image.save("sana.png")

I'm loading without quantization_config because, for some reason, it does not work on my Mac, but I tried the same code on a 4090 and it fails there too.

Logs

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[5], line 30
     21 pipeline = SanaPipeline.from_pretrained(
     22     "Efficient-Large-Model/Sana_600M_1024px_diffusers",
     23     text_encoder=text_encoder_8bit,
   (...)
     26     device_map="balanced",
     27 )
     29 prompt = "a tiny astronaut hatching from an egg on the moon"
---> 30 image = pipeline(prompt).images[0]
     31 image.save("sana.png")

File /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/diffusers/pipelines/sana/pipeline_sana.py:829, in SanaPipeline.__call__(self, prompt, negative_prompt, num_inference_steps, timesteps, sigmas, guidance_scale, num_images_per_prompt, height, width, eta, generator, latents, prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask, output_type, return_dict, clean_caption, use_resolution_binning, attention_kwargs, callback_on_step_end, callback_on_step_end_tensor_inputs, max_sequence_length, complex_human_instruction)
    826 timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
    828 # predict noise model_output
--> 829 noise_pred = self.transformer(
    830     latent_model_input,
    831     encoder_hidden_states=prompt_embeds,
    832     encoder_attention_mask=prompt_attention_mask,
    833     timestep=timestep,
    834     return_dict=False,
    835     attention_kwargs=self.attention_kwargs,
    836 )[0]
    837 noise_pred = noise_pred.float()
    839 # perform guidance

File /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/diffusers/models/transformers/sana_transformer.py:420, in SanaTransformer2DModel.forward(self, hidden_states, encoder_hidden_states, timestep, encoder_attention_mask, attention_mask, attention_kwargs, return_dict)
    414 hidden_states = self.patch_embed(hidden_states)
    416 timestep, embedded_timestep = self.time_embed(
    417     timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
    418 )
--> 420 encoder_hidden_states = self.caption_projection(encoder_hidden_states)
    421 encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
    423 encoder_hidden_states = self.caption_norm(encoder_hidden_states)

File /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/diffusers/models/embeddings.py:2221, in PixArtAlphaTextProjection.forward(self, caption)
   2220 def forward(self, caption):
-> 2221     hidden_states = self.linear_1(caption)
   2222     hidden_states = self.act_1(hidden_states)
   2223     hidden_states = self.linear_2(hidden_states)

File /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/nn/modules/linear.py:125, in Linear.forward(self, input)
    124 def forward(self, input: Tensor) -> Tensor:
--> 125     return F.linear(input, self.weight, self.bias)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (600x256000 and 2304x1152)

System Info

diffusers 0.32.1
torch 2.5.1
Python 3.13.0

M3 MacBook Air, MacOS Sonoma 14.4.1

Who can help?

@sayakpaul @lawrence-cj

geronimi73 added the bug label on Jan 7, 2025
lawrence-cj (Contributor) commented:

@geronimi73
The reason is that the model_index.json for the pipeline shows that the text_encoder we use is Gemma2Model, so the following code snippet would fix your problem:

import torch
from transformers import Gemma2Model

text_encoder = Gemma2Model.from_pretrained(
    "Efficient-Large-Model/Sana_600M_1024px_diffusers",
    subfolder="text_encoder",
    # quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)

output:

Gemma2Model(
  (embed_tokens): Embedding(256000, 2304, padding_idx=0)
  (layers): ModuleList(
    (0-25): 26 x Gemma2DecoderLayer(
      (self_attn): Gemma2Attention(
        (q_proj): Linear(in_features=2304, out_features=2048, bias=False)
        (k_proj): Linear(in_features=2304, out_features=1024, bias=False)
        (v_proj): Linear(in_features=2304, out_features=1024, bias=False)
        (o_proj): Linear(in_features=2048, out_features=2304, bias=False)
        (rotary_emb): Gemma2RotaryEmbedding()
      )
      (mlp): Gemma2MLP(
        (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)
        (up_proj): Linear(in_features=2304, out_features=9216, bias=False)
        (down_proj): Linear(in_features=9216, out_features=2304, bias=False)
        (act_fn): PytorchGELUTanh()
      )
      (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
      (pre_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
      (post_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
      (post_attention_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
    )
  )
  (norm): Gemma2RMSNorm((2304,), eps=1e-06)
)

BTW, the precision for Gemma is BF16 according to the official repo.
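
Putting it together with your reproduction above, something like this should work (an untested sketch based on the suggestion; I left the quantization configs out, they can be re-added the same way as in your snippet):

import torch
from diffusers import SanaTransformer2DModel, SanaPipeline
from transformers import Gemma2Model

# Load the text encoder as Gemma2Model (the class listed in model_index.json), in bf16.
text_encoder = Gemma2Model.from_pretrained(
    "Efficient-Large-Model/Sana_600M_1024px_diffusers",
    subfolder="text_encoder",
    torch_dtype=torch.bfloat16,
)

transformer = SanaTransformer2DModel.from_pretrained(
    "Efficient-Large-Model/Sana_600M_1024px_diffusers",
    subfolder="transformer",
    torch_dtype=torch.float16,
)

pipeline = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_600M_1024px_diffusers",
    text_encoder=text_encoder,
    transformer=transformer,
    torch_dtype=torch.float16,
    device_map="balanced",
)

prompt = "a tiny astronaut hatching from an egg on the moon"
image = pipeline(prompt).images[0]
image.save("sana.png")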

geronimi73 (Contributor, Author) commented:

@lawrence-cj Got it! Thank you!!
