Skip to content

Commit

Permalink
timestep_spacing for FlaxDPMSolverMultistepScheduler (#5189)
Browse files Browse the repository at this point in the history
* timestep_spacing for FlaxDPMSolverMultistepScheduler

* Style
  • Loading branch information
pcuenca authored Sep 26, 2023
1 parent fd1c54a commit d9e7857
Showing 1 changed file with 27 additions and 6 deletions.
33 changes: 27 additions & 6 deletions src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
lower_order_final (`bool`, default `True`):
whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
timestep_spacing (`str`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
the `dtype` used for params and computation.
"""
Expand Down Expand Up @@ -163,6 +166,7 @@ def __init__(
algorithm_type: str = "dpmsolver++",
solver_type: str = "midpoint",
lower_order_final: bool = True,
timestep_spacing: str = "linspace",
dtype: jnp.dtype = jnp.float32,
):
self.dtype = dtype
Expand Down Expand Up @@ -210,12 +214,29 @@ def set_timesteps(
shape (`Tuple`):
the shape of the samples to be generated.
"""

timesteps = (
jnp.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
.round()[::-1][:-1]
.astype(jnp.int32)
)
last_timestep = self.config.num_train_timesteps
if self.config.timestep_spacing == "linspace":
timesteps = (
jnp.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].astype(jnp.int32)
)
elif self.config.timestep_spacing == "leading":
step_ratio = last_timestep // (num_inference_steps + 1)
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (
(jnp.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(jnp.int32)
)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = jnp.arange(last_timestep, 0, -step_ratio).round().copy().astype(jnp.int32)
timesteps -= 1
else:
raise ValueError(
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)

# initial running values

Expand Down

0 comments on commit d9e7857

Please sign in to comment.