diff --git a/examples/community/rerender_a_video.py b/examples/community/rerender_a_video.py index 706b22bbb88d..a2830d8b0e12 100644 --- a/examples/community/rerender_a_video.py +++ b/examples/community/rerender_a_video.py @@ -632,7 +632,7 @@ def __call__( The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. frames (`List[np.ndarray]` or `torch.Tensor`): The input images to be used as the starting point for the image generation process. - control_frames (`List[np.ndarray]` or `torch.Tensor`): The ControlNet input images condition to provide guidance to the `unet` for generation. + control_frames (`List[np.ndarray]` or `torch.Tensor` or `Callable`): The ControlNet input images condition to provide guidance to the `unet` for generation or any callable object to convert frame to control_frame. strength ('float'): SDEdit strength. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -789,7 +789,7 @@ def __call__( # Currently we only support single control if isinstance(controlnet, ControlNetModel): control_image = self.prepare_control_image( - image=control_frames[0], + image=control_frames(frames[0]) if callable(control_frames) else control_frames[0], width=width, height=height, batch_size=batch_size, @@ -924,7 +924,7 @@ def __call__( for idx in range(1, len(frames)): image = frames[idx] prev_image = frames[idx - 1] - control_image = control_frames[idx] + control_image = control_frames(image) if callable(control_frames) else control_frames[idx] # 5.1 prepare frames image = self.image_processor.preprocess(image).to(dtype=self.dtype) prev_image = self.image_processor.preprocess(prev_image).to(dtype=self.dtype)