fix the add_noise function for dpm-multi et al (#5158)
* remove .to(device) for sigmas

* update add_noise to use sigmas

---------

Co-authored-by: yiyixuxu <yixu310@gmail.com>
yiyixuxu and yiyixuxu authored Sep 23, 2023
1 parent 310cf32 commit 5b11c5d
Showing 5 changed files with 27 additions and 23 deletions.
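
These schedulers parameterize the noise level as sigma = sqrt((1 - alpha_cumprod) / alpha_cumprod) (see the sigma_last lines in the hunks below), so the Euler-style noising original_samples + noise * sigma that was copied in previously over-scales the clean sample. The fix converts each sigma back to a variance-preserving pair (alpha_t, sigma_t) before mixing. Below is a minimal sketch of that conversion, written as a standalone helper because the scheduler's own _sigma_to_alpha_sigma_t body is not shown in this diff; for this sigma parameterization the algebra works out to alpha_t = sqrt(alpha_cumprod) and sigma_t = sqrt(1 - alpha_cumprod).

import torch

def sigma_to_alpha_sigma_t(sigma: torch.Tensor):
    # With sigma**2 = (1 - alpha_cumprod) / alpha_cumprod:
    #   alpha_t = 1 / sqrt(sigma**2 + 1) = sqrt(alpha_cumprod)
    #   sigma_t = sigma * alpha_t        = sqrt(1 - alpha_cumprod)
    alpha_t = 1 / (sigma**2 + 1) ** 0.5
    sigma_t = sigma * alpha_t
    return alpha_t, sigma_t

x0 = torch.randn(2, 3, 8, 8)      # clean samples
noise = torch.randn_like(x0)
sigma = torch.tensor(3.0)         # an arbitrary noise level

alpha_t, sigma_t = sigma_to_alpha_sigma_t(sigma)
# Old (variance-exploding, Euler-style): x0 + sigma * noise
# New (variance-preserving):             alpha_t * x0 + sigma_t * noise
noisy = alpha_t * x0 + sigma_t * noise
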
11 changes: 6 additions & 5 deletions src/diffusers/schedulers/scheduling_deis_multistep.py
@@ -243,8 +243,8 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

- self.sigmas = torch.from_numpy(sigmas).to(device=device)
- self.timesteps = torch.from_numpy(timesteps).to(device)
+ self.sigmas = torch.from_numpy(sigmas)
+ self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)

self.num_inference_steps = len(timesteps)

@@ -707,12 +707,12 @@ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch
"""
return sample

- # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
- timesteps: torch.FloatTensor,
+ timesteps: torch.IntTensor,
) -> torch.FloatTensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
@@ -730,7 +730,8 @@ def add_noise(
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)

- noisy_samples = original_samples + noise * sigma
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
+ noisy_samples = alpha_t * original_samples + sigma_t * noise
return noisy_samples

def __len__(self):
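
A related detail of the set_timesteps change above: self.sigmas now stays on the CPU as float32 and self.timesteps is stored as int64; add_noise then casts sigmas to the incoming sample's device and dtype (the unchanged sigmas = self.sigmas.to(...) line in the hunk). A rough sketch of that flow follows; the timestep-to-index lookup is collapsed out of the diff, so the exact-match version shown here is an assumption, and the schedule values are made up for illustration.

import numpy as np
import torch

# Stored once by set_timesteps: int64 timesteps, float32 sigmas kept on CPU.
schedule_timesteps = torch.from_numpy(np.array([999, 749, 499, 249, 0])).to(dtype=torch.int64)
sigmas = torch.from_numpy(np.array([14.6, 6.2, 2.6, 0.9, 0.03], dtype=np.float32))

sample = torch.randn(1, 4, 64, 64)       # in practice this might be fp16 on a GPU
noise = torch.randn_like(sample)
timesteps = schedule_timesteps[[1]]      # must be values that appear in the schedule

# add_noise-style handling: adopt the sample's device/dtype, find each timestep's
# sigma by exact match, then broadcast sigma over the sample's trailing dims.
sigmas = sigmas.to(device=sample.device, dtype=sample.dtype)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(sample.shape):
    sigma = sigma.unsqueeze(-1)

alpha_t = 1 / (sigma**2 + 1) ** 0.5
sigma_t = sigma * alpha_t
noisy = alpha_t * sample + sigma_t * noise
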
10 changes: 5 additions & 5 deletions src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -263,8 +263,8 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

- self.sigmas = torch.from_numpy(sigmas).to(device=device)
- self.timesteps = torch.from_numpy(timesteps).to(device)
+ self.sigmas = torch.from_numpy(sigmas)
+ self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)

self.num_inference_steps = len(timesteps)

@@ -840,12 +840,11 @@ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch
"""
return sample

- # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
- timesteps: torch.FloatTensor,
+ timesteps: torch.IntTensor,
) -> torch.FloatTensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
@@ -863,7 +862,8 @@ def add_noise(
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)

- noisy_samples = original_samples + noise * sigma
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
+ noisy_samples = alpha_t * original_samples + sigma_t * noise
return noisy_samples

def __len__(self):
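
A short usage sketch of the fixed method on DPMSolverMultistepScheduler, the implementation the other schedulers now copy from. The shapes, step count, and chosen timestep below are arbitrary examples, not values from the commit.

import torch
from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler()
scheduler.set_timesteps(num_inference_steps=25)

sample = torch.randn(1, 4, 64, 64)        # e.g. a latent to be partially re-noised
noise = torch.randn_like(sample)
t = scheduler.timesteps[10:11]            # pick a timestep from the sampling schedule

# After this commit the result follows alpha_t * sample + sigma_t * noise
# (variance-preserving) instead of sample + sigma * noise.
noisy = scheduler.add_noise(sample, noise, t)
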
@@ -274,7 +274,7 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
_, unique_indices = np.unique(timesteps, return_index=True)
timesteps = timesteps[np.sort(unique_indices)]

- self.timesteps = torch.from_numpy(timesteps).to(device)
+ self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)

self.num_inference_steps = len(timesteps)

@@ -858,12 +858,12 @@ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch
"""
return sample

- # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
- timesteps: torch.FloatTensor,
+ timesteps: torch.IntTensor,
) -> torch.FloatTensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
@@ -881,7 +881,8 @@ def add_noise(
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)

- noisy_samples = original_samples + noise * sigma
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
+ noisy_samples = alpha_t * original_samples + sigma_t * noise
return noisy_samples

def __len__(self):
9 changes: 5 additions & 4 deletions src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
@@ -275,7 +275,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic

self.sigmas = torch.from_numpy(sigmas).to(device=device)

- self.timesteps = torch.from_numpy(timesteps).to(device)
+ self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
self.model_outputs = [None] * self.config.solver_order
self.sample = None

@@ -870,12 +870,12 @@ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch
"""
return sample

- # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
- timesteps: torch.FloatTensor,
+ timesteps: torch.IntTensor,
) -> torch.FloatTensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
@@ -893,7 +893,8 @@ def add_noise(
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)

- noisy_samples = original_samples + noise * sigma
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
+ noisy_samples = alpha_t * original_samples + sigma_t * noise
return noisy_samples

def __len__(self):
11 changes: 6 additions & 5 deletions src/diffusers/schedulers/scheduling_unipc_multistep.py
@@ -254,8 +254,8 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

- self.sigmas = torch.from_numpy(sigmas).to(device=device)
- self.timesteps = torch.from_numpy(timesteps).to(device)
+ self.sigmas = torch.from_numpy(sigmas)
+ self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)

self.num_inference_steps = len(timesteps)

@@ -801,12 +801,12 @@ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch
"""
return sample

- # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
- timesteps: torch.FloatTensor,
+ timesteps: torch.IntTensor,
) -> torch.FloatTensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
@@ -824,7 +824,8 @@ def add_noise(
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)

- noisy_samples = original_samples + noise * sigma
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
+ noisy_samples = alpha_t * original_samples + sigma_t * noise
return noisy_samples

def __len__(self):
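
As a rough sanity check, not part of the commit, the corrected add_noise should now track DDPMScheduler.add_noise, which applies sqrt(alpha_cumprod) * x0 + sqrt(1 - alpha_cumprod) * noise directly. This assumes both schedulers are left at their default linear beta schedules so they share the same alphas_cumprod:

import torch
from diffusers import DDPMScheduler, DPMSolverMultistepScheduler

dpm = DPMSolverMultistepScheduler()
ddpm = DDPMScheduler()
dpm.set_timesteps(num_inference_steps=50)

x0 = torch.randn(1, 4, 32, 32)
noise = torch.randn_like(x0)
t = dpm.timesteps[:1]   # a timestep that exists on the DPM sampling schedule

a = dpm.add_noise(x0, noise, t)
b = ddpm.add_noise(x0, noise, t)
# With matching beta schedules the two should agree up to float32 rounding.
print(torch.max(torch.abs(a - b)))
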
