diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 9ea6f979c239..e4d976504c6d 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Union +from typing import Callable, List, Optional, Union import numpy as np import torch @@ -202,6 +202,8 @@ def __call__( latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, ): """ Function invoked when calling the pipeline for generation. @@ -240,6 +242,12 @@ def __call__( (`np.array`) or `"pt"` (`torch.Tensor`). return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. Examples: @@ -315,7 +323,7 @@ def __call__( latents = self.prepare_latents(latent_features_shape, dtype, device, generator, latents, self.scheduler) # 6. Run denoising loop - for t in self.progress_bar(timesteps[:-1]): + for i, t in enumerate(self.progress_bar(timesteps[:-1])): ratio = t.expand(latents.size(0)).to(dtype) effnet = ( torch.cat([image_embeddings, torch.zeros_like(image_embeddings)]) @@ -343,6 +351,9 @@ def __call__( generator=generator, ).prev_sample + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + # 10. Scale and decode the image latents with vq-vae latents = self.vqgan.config.scale_factor * latents images = self.vqgan.decode(latents).sample.clamp(0, 1) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py index 2a8614b21e15..bba476b5feab 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py @@ -161,6 +161,10 @@ def __call__( latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, + prior_callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + prior_callback_steps: int = 1, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, ): """ Function invoked when calling the pipeline for generation. @@ -222,6 +226,18 @@ def __call__( (`np.array`) or `"pt"` (`torch.Tensor`). return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + prior_callback (`Callable`, *optional*): + A function that will be called every `prior_callback_steps` steps during inference. The function will be + called with the following arguments: `prior_callback(step: int, timestep: int, latents: torch.FloatTensor)`. + prior_callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. Examples: @@ -244,6 +260,8 @@ def __call__( latents=latents, output_type="pt", return_dict=False, + callback=prior_callback, + callback_steps=prior_callback_steps, ) image_embeddings = prior_outputs[0] @@ -257,6 +275,8 @@ def __call__( generator=generator, output_type=output_type, return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, ) return outputs