Skip to content

Commit

Permalink
add callable object to convert frame into control_frame to reduce cpu…
Browse files Browse the repository at this point in the history
… memory usage. (#10501)

* Update rerender_a_video.py

* Update rerender_a_video.py

* Update examples/community/rerender_a_video.py

Co-authored-by: hlky <[email protected]>

---------

Co-authored-by: hlky <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
  • Loading branch information
3 people authored Jan 9, 2025
1 parent f0c6d97 commit 7bc8b92
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions examples/community/rerender_a_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,7 @@ def __call__(
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
frames (`List[np.ndarray]` or `torch.Tensor`): The input images to be used as the starting point for the image generation process.
control_frames (`List[np.ndarray]` or `torch.Tensor`): The ControlNet input images condition to provide guidance to the `unet` for generation.
control_frames (`List[np.ndarray]` or `torch.Tensor` or `Callable`): The ControlNet input images condition to provide guidance to the `unet` for generation or any callable object to convert frame to control_frame.
strength ('float'): SDEdit strength.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
Expand Down Expand Up @@ -789,7 +789,7 @@ def __call__(
# Currently we only support single control
if isinstance(controlnet, ControlNetModel):
control_image = self.prepare_control_image(
image=control_frames[0],
image=control_frames(frames[0]) if callable(control_frames) else control_frames[0],
width=width,
height=height,
batch_size=batch_size,
Expand Down Expand Up @@ -924,7 +924,7 @@ def __call__(
for idx in range(1, len(frames)):
image = frames[idx]
prev_image = frames[idx - 1]
control_image = control_frames[idx]
control_image = control_frames(image) if callable(control_frames) else control_frames[idx]
# 5.1 prepare frames
image = self.image_processor.preprocess(image).to(dtype=self.dtype)
prev_image = self.image_processor.preprocess(prev_image).to(dtype=self.dtype)
Expand Down

0 comments on commit 7bc8b92

Please sign in to comment.