Describe the bug
I think there might be something wrong with the `SanaPipeline` example code at https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana#diffusers.SanaPipeline. It fails with a shape mismatch (see detailed logs below): `mat1 and mat2 shapes cannot be multiplied (600x256000 and 2304x1152)`.

I've noticed that the `text_encoder` model looks different depending on how it is loaded.

- If I load it with the official example code (= the code in Reproduction below), `pipeline.text_encoder` looks like this:
Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 2304, padding_idx=0)
    (layers): ModuleList(
      (0-25): 26 x Gemma2DecoderLayer(
        (self_attn): Gemma2Attention(
          (q_proj): Linear(in_features=2304, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (v_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2304, bias=False)
          (rotary_emb): Gemma2RotaryEmbedding()
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (up_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (down_proj): Linear(in_features=9216, out_features=2304, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (pre_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (post_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (post_attention_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
      )
    )
    (norm): Gemma2RMSNorm((2304,), eps=1e-06)
  )
  (lm_head): Linear(in_features=2304, out_features=256000, bias=False)
)
- If, however, I don't load the components separately but load the whole pipeline with the code provided by @lawrence-cj here, it 1) works and 2) the `text_encoder` looks different:
Gemma2Model(
  (embed_tokens): Embedding(256000, 2304, padding_idx=0)
  (layers): ModuleList(
    (0-25): 26 x Gemma2DecoderLayer(
      (self_attn): Gemma2Attention(
        (q_proj): Linear(in_features=2304, out_features=2048, bias=False)
        (k_proj): Linear(in_features=2304, out_features=1024, bias=False)
        (v_proj): Linear(in_features=2304, out_features=1024, bias=False)
        (o_proj): Linear(in_features=2048, out_features=2304, bias=False)
        (rotary_emb): Gemma2RotaryEmbedding()
      )
      (mlp): Gemma2MLP(
        (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)
        (up_proj): Linear(in_features=2304, out_features=9216, bias=False)
        (down_proj): Linear(in_features=9216, out_features=2304, bias=False)
        (act_fn): PytorchGELUTanh()
      )
      (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
      (pre_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
      (post_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
      (post_attention_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
    )
  )
  (norm): Gemma2RMSNorm((2304,), eps=1e-06)
)
-> the language modeling head `lm_head` is gone. I guess that's all expected (?), but I haven't found any documentation of this behaviour or where in the pipeline code it happens. The numbers in the error do line up with this: 256000 matches the `lm_head` vocabulary size, so it looks like the caption projection is being fed LM logits instead of the 2304-dim hidden states.
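For comparison, the loading path that works for me is to let `SanaPipeline.from_pretrained` build all components itself. I don't have @lawrence-cj's exact snippet at hand, so this is just a minimal sketch of that approach:

```python
import torch
from diffusers import SanaPipeline

# Let the pipeline load all components itself instead of passing in a
# separately loaded text encoder.
pipeline = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_600M_1024px_diffusers",
    torch_dtype=torch.float16,
)

# With this path, the text encoder comes up as a bare Gemma2Model
# (no lm_head) and generation succeeds.
print(type(pipeline.text_encoder).__name__)  # -> Gemma2Model
```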
Reproduction
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SanaTransformer2DModel, SanaPipeline
from transformers import BitsAndBytesConfig, AutoModelForCausalLM

quant_config = BitsAndBytesConfig(load_in_8bit=True)
text_encoder_8bit = AutoModelForCausalLM.from_pretrained(
    "Efficient-Large-Model/Sana_600M_1024px_diffusers",
    subfolder="text_encoder",
    # quantization_config=quant_config,  # disabled, see note below
    torch_dtype=torch.float16,
)

quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = SanaTransformer2DModel.from_pretrained(
    "Efficient-Large-Model/Sana_600M_1024px_diffusers",
    subfolder="transformer",
    # quantization_config=quant_config,  # disabled, see note below
    torch_dtype=torch.float16,
)

pipeline = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_600M_1024px_diffusers",
    text_encoder=text_encoder_8bit,
    transformer=transformer_8bit,
    torch_dtype=torch.float16,
    device_map="balanced",
)

prompt = "a tiny astronaut hatching from an egg on the moon"
image = pipeline(prompt).images[0]
image.save("sana.png")
I'm loading without the `quantization_config` because for some reason it does not work on my Mac, but I tried the same code on a 4090 and it fails there too.
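To illustrate where I think the 256000 comes from, here is a small check; this snippet is hypothetical (in particular, that `AutoModel` picks up the bare `Gemma2Model` from the `text_encoder` subfolder is my assumption):

```python
import torch
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer

repo = "Efficient-Large-Model/Sana_600M_1024px_diffusers"

tokenizer = AutoTokenizer.from_pretrained(repo, subfolder="tokenizer")
input_ids = tokenizer("a tiny astronaut", return_tensors="pt").input_ids

# What the official example loads: a causal LM whose first output is
# logits over the 256000-token vocabulary.
causal_lm = AutoModelForCausalLM.from_pretrained(
    repo, subfolder="text_encoder", torch_dtype=torch.float16
)
print(causal_lm(input_ids)[0].shape)  # (1, seq_len, 256000) - logits

# The bare encoder: the first output is the 2304-dim hidden states,
# which is what the transformer's caption projection expects.
encoder = AutoModel.from_pretrained(
    repo, subfolder="text_encoder", torch_dtype=torch.float16
)
print(encoder(input_ids)[0].shape)  # (1, seq_len, 2304) - hidden states
```

If that reading is right, the 600x256000 `mat1` in the error is presumably the classifier-free-guidance prompt batch (2 x 300 tokens) carrying logits instead of hidden states.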
Logs
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[5], line 30
21 pipeline = SanaPipeline.from_pretrained(
22 "Efficient-Large-Model/Sana_600M_1024px_diffusers",
23 text_encoder=text_encoder_8bit,
(...)
26 device_map="balanced",
27 )
29 prompt = "a tiny astronaut hatching from an egg on the moon"
---> 30 image = pipeline(prompt).images[0]
31 image.save("sana.png")
File /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
113 @functools.wraps(func)
114 def decorate_context(*args, **kwargs):
115 with ctx_factory():
--> 116 return func(*args, **kwargs)
File /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/diffusers/pipelines/sana/pipeline_sana.py:829, in SanaPipeline.__call__(self, prompt, negative_prompt, num_inference_steps, timesteps, sigmas, guidance_scale, num_images_per_prompt, height, width, eta, generator, latents, prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask, output_type, return_dict, clean_caption, use_resolution_binning, attention_kwargs, callback_on_step_end, callback_on_step_end_tensor_inputs, max_sequence_length, complex_human_instruction)
826 timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
828 # predict noise model_output
--> 829 noise_pred = self.transformer(
830 latent_model_input,
831 encoder_hidden_states=prompt_embeds,
832 encoder_attention_mask=prompt_attention_mask,
833 timestep=timestep,
834 return_dict=False,
835 attention_kwargs=self.attention_kwargs,
836 )[0]
837 noise_pred = noise_pred.float()
839 # perform guidance
File /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/diffusers/models/transformers/sana_transformer.py:420, in SanaTransformer2DModel.forward(self, hidden_states, encoder_hidden_states, timestep, encoder_attention_mask, attention_mask, attention_kwargs, return_dict)
414 hidden_states = self.patch_embed(hidden_states)
416 timestep, embedded_timestep = self.time_embed(
417 timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
418 )
--> 420 encoder_hidden_states = self.caption_projection(encoder_hidden_states)
421 encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
423 encoder_hidden_states = self.caption_norm(encoder_hidden_states)
File /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/diffusers/models/embeddings.py:2221, in PixArtAlphaTextProjection.forward(self, caption)
2220 def forward(self, caption):
-> 2221 hidden_states = self.linear_1(caption)
2222 hidden_states = self.act_1(hidden_states)
2223 hidden_states = self.linear_2(hidden_states)
File /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/nn/modules/linear.py:125, in Linear.forward(self, input)
124 def forward(self, input: Tensor) -> Tensor:
--> 125 return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (600x256000 and 2304x1152)
System Info
diffusers 0.32.1
torch 2.5.1
Python 3.13.0
M3 MacBook Air, macOS Sonoma 14.4.1