
How to use torch.float16 in a diffusers pipeline with PyTorch/XLA #8223

Open

fancy45daddy opened this issue Oct 6, 2024 · 0 comments
❓ Questions and Help

import diffusers, torch
import torch_xla.core.xla_model as xm

# Load Stable Diffusion v1.5 with float16 weights
pipeline = diffusers.DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    safety_checker=None,
    use_safetensors=True,
    torch_dtype=torch.float16,
)
# Move the model to the first TPU core
pipeline = pipeline.to(xm.xla_device())
image = pipeline("a cloud tpu winning a kaggle competition", num_inference_steps=20).images[0]
image

I run the above code on Kaggle and get:

RuntimeError                              Traceback (most recent call last)
Cell In[2], line 8
      6 # Move the model to the first TPU core
      7 pipeline = pipeline.to(xm.xla_device())
----> 8 image = pipeline("a cloud tpu winning a kaggle competition", num_inference_steps=20).images[0]
      9 image

File /usr/local/lib/python3.8/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File /usr/local/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py:1000, in StableDiffusionPipeline.__call__(self, prompt, height, width, num_inference_steps, timesteps, sigmas, guidance_scale, negative_prompt, num_images_per_prompt, eta, generator, latents, prompt_embeds, negative_prompt_embeds, ip_adapter_image, ip_adapter_image_embeds, output_type, return_dict, cross_attention_kwargs, guidance_rescale, clip_skip, callback_on_step_end, callback_on_step_end_tensor_inputs, **kwargs)
    997 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
    999 # predict the noise residual
-> 1000 noise_pred = self.unet(
   1001     latent_model_input,
   1002     t,
   1003     encoder_hidden_states=prompt_embeds,
   1004     timestep_cond=timestep_cond,
   1005     cross_attention_kwargs=self.cross_attention_kwargs,
   1006     added_cond_kwargs=added_cond_kwargs,
   1007     return_dict=False,
   1008 )[0]
   1010 # perform guidance
   1011 if self.do_classifier_free_guidance:

File /usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /usr/local/lib/python3.8/site-packages/diffusers/models/unets/unet_2d_condition.py:1169, in UNet2DConditionModel.forward(self, sample, timestep, encoder_hidden_states, class_labels, timestep_cond, attention_mask, cross_attention_kwargs, added_cond_kwargs, down_block_additional_residuals, mid_block_additional_residual, down_intrablock_additional_residuals, encoder_attention_mask, return_dict)
   1164 encoder_hidden_states = self.process_encoder_hidden_states(
   1165     encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
   1166 )
   1168 # 2. pre-process
-> 1169 sample = self.conv_in(sample)
   1171 # 2.5 GLIGEN position net
   1172 if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:

File /usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /usr/local/lib/python3.8/site-packages/torch/nn/modules/conv.py:463, in Conv2d.forward(self, input)
    462 def forward(self, input: Tensor) -> Tensor:
--> 463     return self._conv_forward(input, self.weight, self.bias)

File /usr/local/lib/python3.8/site-packages/torch/nn/modules/conv.py:459, in Conv2d._conv_forward(self, input, weight, bias)
    455 if self.padding_mode != 'zeros':
    456     return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
    457                     weight, bias, self.stride,
    458                     _pair(0), self.dilation, self.groups)
--> 459 return F.conv2d(input, weight, bias, self.stride,
    460                 self.padding, self.dilation, self.groups)

RuntimeError: Input type (c10::BFloat16) and bias type (c10::Half) should be the same
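
For context on the error: the tensor reaching conv_in is bfloat16 while the fp16 pipeline's weights are Half. On TPU, torch_xla can transparently represent float32 tensors as bfloat16 (for example when the XLA_USE_BF16=1 environment variable is set, which some Kaggle TPU images do), and the scheduler step may upcast latents to float32 along the way, so the UNet sees bf16 inputs against fp16 weights. A minimal sketch of one plausible workaround, assuming that is the cause, is to load the weights in torch.bfloat16 so their dtype matches the inputs; this is an untested suggestion, not a fix confirmed in this thread:

import diffusers, torch
import torch_xla.core.xla_model as xm

# Assumption: loading in bfloat16 (TPU's native low-precision dtype) keeps
# the weight dtype consistent with the bf16 tensors XLA produces, avoiding
# the Half-vs-BFloat16 mismatch in F.conv2d.
pipeline = diffusers.DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    safety_checker=None,
    use_safetensors=True,
    torch_dtype=torch.bfloat16,
)
pipeline = pipeline.to(xm.xla_device())
image = pipeline("a cloud tpu winning a kaggle competition", num_inference_steps=20).images[0]

Running everything in a single dtype end to end (bf16 here) sidesteps the mixed fp16/bf16 state rather than pinpointing which component introduces it.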