Commit

add use_karras_sigmas to KDPM2DiscreteScheduler and `KDPM2AncestralDiscreteScheduler` (#5111)


---------

Co-authored-by: yiyixuxu <yixu310@gmail.com>
yiyixuxu and yiyixuxu authored Sep 21, 2023
1 parent 2badddf commit 80c00e5
Showing 3 changed files with 99 additions and 36 deletions.
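
For context, a minimal usage sketch of the new flag (the checkpoint id, prompt, and step count are illustrative and not part of this commit):

```python
# Sketch: enable the new Karras sigma schedule on KDPM2DiscreteScheduler.
# The checkpoint id and prompt are placeholders, not part of this change.
import torch
from diffusers import DiffusionPipeline, KDPM2DiscreteScheduler

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Rebuild the scheduler from the pipeline's existing config with Karras sigmas enabled.
pipe.scheduler = KDPM2DiscreteScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)

image = pipe("an astronaut riding a horse", num_inference_steps=25).images[0]
```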
@@ -609,6 +609,9 @@ def model_fn(x, t):
noise_sampler = BrownianTreeNoiseSampler(latents, min_sigma, max_sigma, noise_sampler_seed)
sampler_kwargs["noise_sampler"] = noise_sampler

if "generator" in inspect.signature(self.sampler).parameters:
sampler_kwargs["generator"] = generator

latents = self.sampler(model_fn, latents, sigmas, **sampler_kwargs)

if not output_type == "latent":
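
The pipeline hunk above only forwards `generator` when the wrapped k-diffusion sampler's signature accepts it. A standalone sketch of that guard, using a stand-in sampler rather than the library's own:

```python
# Forward `generator` only when the sampler's signature declares it.
# `toy_sampler` is a stand-in, not k-diffusion's API.
import inspect

import torch


def toy_sampler(model_fn, x, sigmas, generator=None):
    return x  # placeholder body


sampler_kwargs = {}
generator = torch.Generator().manual_seed(0)

if "generator" in inspect.signature(toy_sampler).parameters:
    sampler_kwargs["generator"] = generator
```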
51 changes: 41 additions & 10 deletions src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
@@ -89,6 +89,9 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
`linear` or `scaled_linear`.
trained_betas (`np.ndarray`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
the sigmas are determined according to a sequence of noise levels {σi}.
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
@@ -113,6 +116,7 @@ def __init__(
beta_end: float = 0.012,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
use_karras_sigmas: Optional[bool] = False,
prediction_type: str = "epsilon",
timestep_spacing: str = "linspace",
steps_offset: int = 0,
@@ -243,9 +247,15 @@ def set_timesteps(
)

sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device)
log_sigmas = np.log(sigmas)

sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

if self.config.use_karras_sigmas:
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()

self.log_sigmas = torch.from_numpy(log_sigmas).to(device)
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
sigmas = torch.from_numpy(sigmas).to(device=device)

@@ -269,7 +279,13 @@ def set_timesteps(
self.sigmas_down = torch.cat([sigmas_down[:1], sigmas_down[1:].repeat_interleave(2), sigmas_down[-1:]])

timesteps = torch.from_numpy(timesteps).to(device)
timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype)
sigmas_interpol = sigmas_interpol.cpu()
log_sigmas = self.log_sigmas.cpu()
timesteps_interpol = np.array(
[self._sigma_to_t(sigma_interpol, log_sigmas) for sigma_interpol in sigmas_interpol]
)

timesteps_interpol = torch.from_numpy(timesteps_interpol).to(device, dtype=timesteps.dtype)
interleaved_timesteps = torch.stack((timesteps_interpol[:-2, None], timesteps[1:, None]), dim=-1).flatten()

self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps])
@@ -282,29 +298,44 @@ def set_timesteps(

self._step_index = None

def sigma_to_t(self, sigma):
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
# get log sigma
log_sigma = sigma.log()
log_sigma = np.log(sigma)

# get distribution
dists = log_sigma - self.log_sigmas[:, None]
dists = log_sigma - log_sigmas[:, np.newaxis]

# get sigmas range
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
high_idx = low_idx + 1

low = self.log_sigmas[low_idx]
high = self.log_sigmas[high_idx]
low = log_sigmas[low_idx]
high = log_sigmas[high_idx]

# interpolate sigmas
w = (low - log_sigma) / (low - high)
w = w.clamp(0, 1)
w = np.clip(w, 0, 1)

# transform interpolation to time range
t = (1 - w) * low_idx + w * high_idx
t = t.view(sigma.shape)
t = t.reshape(sigma.shape)
return t

# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""

sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()

rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas

@property
def state_in_first_order(self):
return self.sample is None
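
Taken together: when `use_karras_sigmas=True`, the `set_timesteps` change above replaces the interpolated sigmas with the Karras et al. (2022) schedule and re-derives (now fractional) timesteps from those sigmas via `_sigma_to_t`. Restated as a formula, `_convert_to_karras` computes, for i = 0, …, N−1 with N = num_inference_steps:

$$\sigma_i = \Bigl(\sigma_{\max}^{1/\rho} + \tfrac{i}{N-1}\bigl(\sigma_{\min}^{1/\rho} - \sigma_{\max}^{1/\rho}\bigr)\Bigr)^{\rho}, \qquad \rho = 7$$

so the schedule runs from sigma_max down to sigma_min, with rho = 7 spacing the steps more densely at low noise levels.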
81 changes: 55 additions & 26 deletions src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
@@ -88,6 +88,9 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
`linear` or `scaled_linear`.
trained_betas (`np.ndarray`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
the sigmas are determined according to a sequence of noise levels {σi}.
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
@@ -112,6 +115,7 @@ def __init__(
beta_end: float = 0.012,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
use_karras_sigmas: Optional[bool] = False,
prediction_type: str = "epsilon",
timestep_spacing: str = "linspace",
steps_offset: int = 0,
@@ -243,9 +247,14 @@ def set_timesteps(
)

sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device)

log_sigmas = np.log(sigmas)
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

if self.config.use_karras_sigmas:
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()

self.log_sigmas = torch.from_numpy(log_sigmas).to(device=device)
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
sigmas = torch.from_numpy(sigmas).to(device=device)

@@ -260,7 +269,12 @@ def set_timesteps(
timesteps = torch.from_numpy(timesteps).to(device)

# interpolate timesteps
timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype)
sigmas_interpol = sigmas_interpol.cpu()
log_sigmas = self.log_sigmas.cpu()
timesteps_interpol = np.array(
[self._sigma_to_t(sigma_interpol, log_sigmas) for sigma_interpol in sigmas_interpol]
)
timesteps_interpol = torch.from_numpy(timesteps_interpol).to(device, dtype=timesteps.dtype)
interleaved_timesteps = torch.stack((timesteps_interpol[1:-1, None], timesteps[1:, None]), dim=-1).flatten()

self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps])
@@ -273,29 +287,6 @@ def set_timesteps(

self._step_index = None

def sigma_to_t(self, sigma):
# get log sigma
log_sigma = sigma.log()

# get distribution
dists = log_sigma - self.log_sigmas[:, None]

# get sigmas range
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
high_idx = low_idx + 1

low = self.log_sigmas[low_idx]
high = self.log_sigmas[high_idx]

# interpolate sigmas
w = (low - log_sigma) / (low - high)
w = w.clamp(0, 1)

# transform interpolation to time range
t = (1 - w) * low_idx + w * high_idx
t = t.view(sigma.shape)
return t

@property
def state_in_first_order(self):
return self.sample is None
@@ -318,6 +309,44 @@ def _init_step_index(self, timestep):

self._step_index = step_index.item()

# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
# get log sigma
log_sigma = np.log(sigma)

# get distribution
dists = log_sigma - log_sigmas[:, np.newaxis]

# get sigmas range
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
high_idx = low_idx + 1

low = log_sigmas[low_idx]
high = log_sigmas[high_idx]

# interpolate sigmas
w = (low - log_sigma) / (low - high)
w = np.clip(w, 0, 1)

# transform interpolation to time range
t = (1 - w) * low_idx + w * high_idx
t = t.reshape(sigma.shape)
return t

# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""

sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()

rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas

def step(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
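
For reference, `_sigma_to_t` (copied into both schedulers from `EulerDiscreteScheduler`) maps a sigma back to a fractional timestep by piecewise-linear interpolation in log-sigma space. A tiny self-contained sketch of that idea with toy values, not the scheduler's real schedule:

```python
# Toy check of the log-space interpolation used by _sigma_to_t (illustrative values only).
import numpy as np

sigmas = np.array([0.1, 0.5, 2.0, 8.0])  # sigma at integer timesteps 0..3
log_sigmas = np.log(sigmas)


def sigma_to_t(sigma, log_sigmas):
    log_sigma = np.log(sigma)
    dists = log_sigma - log_sigmas[:, np.newaxis]
    low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
    high_idx = low_idx + 1
    low, high = log_sigmas[low_idx], log_sigmas[high_idx]
    w = np.clip((low - log_sigma) / (low - high), 0, 1)
    return ((1 - w) * low_idx + w * high_idx).reshape(np.shape(sigma))


print(sigma_to_t(np.array([1.0]), log_sigmas))  # [1.5]: halfway between t=1 (0.5) and t=2 (2.0) in log space
```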
