From 80c00e5451e0ced32043fbb0ed06eb6f3c427f82 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Thu, 21 Sep 2023 13:50:41 -1000 Subject: [PATCH] add `use_karras_sigmas` to `KDPM2DiscreteScheduler` and `KDPM2AncestralDiscreteScheduler` (#5111) --------- Co-authored-by: yiyixuxu --- .../pipeline_stable_diffusion_k_diffusion.py | 3 + .../scheduling_k_dpm_2_ancestral_discrete.py | 51 +++++++++--- .../schedulers/scheduling_k_dpm_2_discrete.py | 81 +++++++++++++------ 3 files changed, 99 insertions(+), 36 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py index ce074b1e6d1d..2cddf819781d 100755 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py @@ -609,6 +609,9 @@ def model_fn(x, t): noise_sampler = BrownianTreeNoiseSampler(latents, min_sigma, max_sigma, noise_sampler_seed) sampler_kwargs["noise_sampler"] = noise_sampler + if "generator" in inspect.signature(self.sampler).parameters: + sampler_kwargs["generator"] = generator + latents = self.sampler(model_fn, latents, sigmas, **sampler_kwargs) if not output_type == "latent": diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py index b44ff31379ad..a0137b83fda1 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py @@ -89,6 +89,9 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): `linear` or `scaled_linear`. trained_betas (`np.ndarray`, *optional*): Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. 
prediction_type (`str`, defaults to `epsilon`, *optional*): Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen @@ -113,6 +116,7 @@ def __init__( beta_end: float = 0.012, beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + use_karras_sigmas: Optional[bool] = False, prediction_type: str = "epsilon", timestep_spacing: str = "linspace", steps_offset: int = 0, @@ -243,9 +247,15 @@ def set_timesteps( ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) - self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device) + log_sigmas = np.log(sigmas) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + + if self.config.use_karras_sigmas: + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + + self.log_sigmas = torch.from_numpy(log_sigmas).to(device) sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) sigmas = torch.from_numpy(sigmas).to(device=device) @@ -269,7 +279,13 @@ def set_timesteps( self.sigmas_down = torch.cat([sigmas_down[:1], sigmas_down[1:].repeat_interleave(2), sigmas_down[-1:]]) timesteps = torch.from_numpy(timesteps).to(device) - timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype) + sigmas_interpol = sigmas_interpol.cpu() + log_sigmas = self.log_sigmas.cpu() + timesteps_interpol = np.array( + [self._sigma_to_t(sigma_interpol, log_sigmas) for sigma_interpol in sigmas_interpol] + ) + + timesteps_interpol = torch.from_numpy(timesteps_interpol).to(device, dtype=timesteps.dtype) interleaved_timesteps = torch.stack((timesteps_interpol[:-2, None], timesteps[1:, None]), dim=-1).flatten() self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps]) @@ -282,29 +298,44 @@ def set_timesteps( self._step_index = None - def sigma_to_t(self, sigma): + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): # get log sigma - log_sigma = sigma.log() + log_sigma = np.log(sigma) # get distribution - dists = log_sigma - self.log_sigmas[:, None] + dists = log_sigma - log_sigmas[:, np.newaxis] # get sigmas range - low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2) + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) high_idx = low_idx + 1 - low = self.log_sigmas[low_idx] - high = self.log_sigmas[high_idx] + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] # interpolate sigmas w = (low - log_sigma) / (low - high) - w = w.clamp(0, 1) + w = np.clip(w, 0, 1) # transform interpolation to time range t = (1 - w) * low_idx + w * high_idx - t = t.view(sigma.shape) + t = t.reshape(sigma.shape) return t + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: + """Constructs the noise schedule of Karras et al. 
(2022).""" + + sigma_min: float = in_sigmas[-1].item() + sigma_max: float = in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + @property def state_in_first_order(self): return self.sample is None diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py index d8f54a21e246..ddea57e8c167 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py @@ -88,6 +88,9 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin): `linear` or `scaled_linear`. trained_betas (`np.ndarray`, *optional*): Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. prediction_type (`str`, defaults to `epsilon`, *optional*): Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen @@ -112,6 +115,7 @@ def __init__( beta_end: float = 0.012, beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + use_karras_sigmas: Optional[bool] = False, prediction_type: str = "epsilon", timestep_spacing: str = "linspace", steps_offset: int = 0, @@ -243,9 +247,14 @@ def set_timesteps( ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) - self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device) - + log_sigmas = np.log(sigmas) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + + if self.config.use_karras_sigmas: + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + + self.log_sigmas = torch.from_numpy(log_sigmas).to(device=device) sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) sigmas = torch.from_numpy(sigmas).to(device=device) @@ -260,7 +269,12 @@ def set_timesteps( timesteps = torch.from_numpy(timesteps).to(device) # interpolate timesteps - timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype) + sigmas_interpol = sigmas_interpol.cpu() + log_sigmas = self.log_sigmas.cpu() + timesteps_interpol = np.array( + [self._sigma_to_t(sigma_interpol, log_sigmas) for sigma_interpol in sigmas_interpol] + ) + timesteps_interpol = torch.from_numpy(timesteps_interpol).to(device, dtype=timesteps.dtype) interleaved_timesteps = torch.stack((timesteps_interpol[1:-1, None], timesteps[1:, None]), dim=-1).flatten() self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps]) @@ -273,29 +287,6 @@ def set_timesteps( self._step_index = None - def sigma_to_t(self, sigma): - # get log sigma - log_sigma = sigma.log() - - # get distribution - dists = log_sigma - self.log_sigmas[:, None] - - # get sigmas range - low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2) - high_idx = low_idx + 1 - - low = self.log_sigmas[low_idx] - high = self.log_sigmas[high_idx] - - # interpolate 
sigmas - w = (low - log_sigma) / (low - high) - w = w.clamp(0, 1) - - # transform interpolation to time range - t = (1 - w) * low_idx + w * high_idx - t = t.view(sigma.shape) - return t - @property def state_in_first_order(self): return self.sample is None @@ -318,6 +309,44 @@ def _init_step_index(self, timestep): self._step_index = step_index.item() + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(sigma) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + sigma_min: float = in_sigmas[-1].item() + sigma_max: float = in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + def step( self, model_output: Union[torch.FloatTensor, np.ndarray],
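
---

Notes on usage (not part of the patch itself). The pipeline hunk above simply forwards the pipeline's `generator` argument to the underlying k-diffusion sampler whenever the sampler's signature accepts one, so seeded sampling stays deterministic. The scheduler hunks expose the new `use_karras_sigmas` config flag on both `KDPM2DiscreteScheduler` and `KDPM2AncestralDiscreteScheduler`.

A minimal usage sketch for the new flag, assuming the common diffusers pattern of overriding a config value via `from_config(..., use_karras_sigmas=True)`; the model id and prompt are placeholders:

```python
# Hypothetical usage sketch (not part of the patch): enabling the new
# `use_karras_sigmas` flag on an existing Stable Diffusion pipeline.
# Model id and prompt are placeholders; the `from_config(...)` keyword
# override mirrors how other diffusers schedulers are reconfigured.
import torch
from diffusers import StableDiffusionPipeline, KDPM2DiscreteScheduler

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Swap in the KDPM2 scheduler with Karras sigmas enabled; the same flag is
# available on KDPM2AncestralDiscreteScheduler.
pipe.scheduler = KDPM2DiscreteScheduler.from_config(
    pipe.scheduler.config, use_karras_sigmas=True
)

image = pipe("an astronaut riding a horse", num_inference_steps=25).images[0]
image.save("kdpm2_karras.png")
```

For reference, a standalone sketch of the schedule that `_convert_to_karras` computes: sigmas are interpolated between `sigma_max` and `sigma_min` in rho-th-root space (rho = 7.0, as in Karras et al. 2022), which concentrates steps at low noise levels. The `sigma_min`/`sigma_max` values below are illustrative only:

```python
# Standalone sketch of the Karras et al. (2022) sigma schedule implemented
# by `_convert_to_karras` in the patch; the example bounds are illustrative.
import numpy as np

def karras_sigmas(sigma_min: float, sigma_max: float, num_steps: int, rho: float = 7.0) -> np.ndarray:
    # Interpolate linearly in rho-th-root space, then raise back to the
    # rho-th power; this spaces steps more densely near the low-noise end.
    ramp = np.linspace(0, 1, num_steps)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

print(karras_sigmas(sigma_min=0.03, sigma_max=14.6, num_steps=10))
# -> 10 monotonically decreasing sigmas from ~14.6 down to ~0.03
```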