Skip to content

Commit

Permalink
Add callbacks to WuerstchenDecoderPipeline and `WuerstchenCombinedP…
Browse files Browse the repository at this point in the history
…ipeline` (#5154)
  • Loading branch information
carson-katri authored Sep 25, 2023
1 parent 28254c7 commit 6281d20
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 2 deletions.
15 changes: 13 additions & 2 deletions src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Union
from typing import Callable, List, Optional, Union

import numpy as np
import torch
Expand Down Expand Up @@ -202,6 +202,8 @@ def __call__(
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
):
"""
Function invoked when calling the pipeline for generation.
Expand Down Expand Up @@ -240,6 +242,12 @@ def __call__(
(`np.array`) or `"pt"` (`torch.Tensor`).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Examples:
Expand Down Expand Up @@ -315,7 +323,7 @@ def __call__(
latents = self.prepare_latents(latent_features_shape, dtype, device, generator, latents, self.scheduler)

# 6. Run denoising loop
for t in self.progress_bar(timesteps[:-1]):
for i, t in enumerate(self.progress_bar(timesteps[:-1])):
ratio = t.expand(latents.size(0)).to(dtype)
effnet = (
torch.cat([image_embeddings, torch.zeros_like(image_embeddings)])
Expand Down Expand Up @@ -343,6 +351,9 @@ def __call__(
generator=generator,
).prev_sample

if callback is not None and i % callback_steps == 0:
callback(i, t, latents)

# 10. Scale and decode the image latents with vq-vae
latents = self.vqgan.config.scale_factor * latents
images = self.vqgan.decode(latents).sample.clamp(0, 1)
Expand Down
20 changes: 20 additions & 0 deletions src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ def __call__(
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
prior_callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
prior_callback_steps: int = 1,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
):
"""
Function invoked when calling the pipeline for generation.
Expand Down Expand Up @@ -222,6 +226,18 @@ def __call__(
(`np.array`) or `"pt"` (`torch.Tensor`).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
prior_callback (`Callable`, *optional*):
A function that will be called every `prior_callback_steps` steps during inference. The function will be
called with the following arguments: `prior_callback(step: int, timestep: int, latents: torch.FloatTensor)`.
prior_callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Examples:
Expand All @@ -244,6 +260,8 @@ def __call__(
latents=latents,
output_type="pt",
return_dict=False,
callback=prior_callback,
callback_steps=prior_callback_steps,
)
image_embeddings = prior_outputs[0]

Expand All @@ -257,6 +275,8 @@ def __call__(
generator=generator,
output_type=output_type,
return_dict=return_dict,
callback=callback,
callback_steps=callback_steps,
)

return outputs

0 comments on commit 6281d20

Please sign in to comment.