From 589cd8100bd89b6932980a29958ce376eb27fe00 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 25 Sep 2023 19:27:20 +0200 Subject: [PATCH] make style --- .../pipelines/wuerstchen/pipeline_wuerstchen_combined.py | 5 +++-- src/diffusers/schedulers/scheduling_utils.py | 2 +- src/diffusers/schedulers/scheduling_utils_flax.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py index bba476b5feab..6b5ce9530d4c 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py @@ -227,8 +227,9 @@ def __call__( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. prior_callback (`Callable`, *optional*): - A function that will be called every `prior_callback_steps` steps during inference. The function will be - called with the following arguments: `prior_callback(step: int, timestep: int, latents: torch.FloatTensor)`. + A function that will be called every `prior_callback_steps` steps during inference. The function will + be called with the following arguments: `prior_callback(step: int, timestep: int, latents: + torch.FloatTensor)`. prior_callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index 2e3147b80e60..9d9472a9063f 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -15,7 +15,7 @@ import os from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, Optional, Union +from typing import Optional, Union import torch diff --git a/src/diffusers/schedulers/scheduling_utils_flax.py b/src/diffusers/schedulers/scheduling_utils_flax.py index d0bed6ef5f91..ccec121d3094 100644 --- a/src/diffusers/schedulers/scheduling_utils_flax.py +++ b/src/diffusers/schedulers/scheduling_utils_flax.py @@ -16,7 +16,7 @@ import os from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, Optional, Tuple, Union +from typing import Optional, Tuple, Union import flax import jax.numpy as jnp