[Inference Strategy] Will inconsistent sampling steps cause suboptimal performance? #14
Thanks for the reminder. This might be caused by the num_ddim_timesteps setting in the training configs. Using num_ddim_timesteps=50 should make it consistent.
@G-U-N num_ddim_timesteps=50 does indeed cause the problem. You can check this with the following code:
import numpy as np
import torch
ddim_timesteps = 50
multiphase = 4
# Training side: DDIM timesteps and the first DDIM index of each phase
step_ratio = 1000 // ddim_timesteps
ddim_timesteps = (
    np.arange(1, ddim_timesteps + 1) * step_ratio
).round().astype(np.int64) - 1
ddim_timesteps = torch.from_numpy(ddim_timesteps).long()
inference_indices = np.linspace(
    0, len(ddim_timesteps), num=multiphase, endpoint=False
)
inference_indices = np.floor(inference_indices).astype(np.int64)
inference_indices = (
    torch.from_numpy(inference_indices).long().to(ddim_timesteps.device)
)
print(ddim_timesteps)  # tensor([..., 719, 739, 759, 779, ...])
print(inference_indices)  # tensor([ 0, 12, 25, 37])
print(ddim_timesteps[inference_indices])  # tensor([ 19, 259, 519, 759])
# Inference side: trailing timestep spacing with 4 steps
step_ratio = 1000 / 4
timesteps = np.round(np.arange(1000, 0, -step_ratio)).astype(np.int64)
timesteps -= 1
print(timesteps)  # [999 749 499 249]
In training, the DDIM timestep preceding 759 is 739, so the model learns to jump to step 739 from every timestep after 739. In inference, however, the scheduler jumps from 999 to 749, which causes an inconsistency.
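The mismatch can also be reproduced without torch. The NumPy-only sketch below mirrors the training-side index logic and the trailing-spacing formula above, assuming 1000 training timesteps; the helper names `training_phase_ends` and `trailing_inference_timesteps` are hypothetical, not functions from the repository. In the cases checked here, the training phase ends line up with the inference jumps for 40 or 100 DDIM steps but not for 50.

```python
import numpy as np


def training_phase_ends(num_train_timesteps, num_ddim_timesteps, multiphase):
    """Hypothetical helper mirroring the training logic: for each phase, the
    DDIM timestep preceding the phase's first DDIM index (0 for the first phase)."""
    step_ratio = num_train_timesteps // num_ddim_timesteps
    ddim_timesteps = np.arange(1, num_ddim_timesteps + 1) * step_ratio - 1
    # Previous DDIM timestep for every index, with 0 for index 0
    ddim_timesteps_prev = np.concatenate(([0], ddim_timesteps[:-1]))
    # First DDIM index of each phase, via the same linspace/floor rule as above
    phase_starts = np.floor(
        np.linspace(0, num_ddim_timesteps, num=multiphase, endpoint=False)
    ).astype(np.int64)
    return ddim_timesteps_prev[phase_starts]


def trailing_inference_timesteps(num_train_timesteps, num_inference_steps):
    """Hypothetical helper: trailing spacing, same formula as the snippet above."""
    step_ratio = num_train_timesteps / num_inference_steps
    timesteps = np.round(
        np.arange(num_train_timesteps, 0, -step_ratio)
    ).astype(np.int64)
    return timesteps - 1


print(trailing_inference_timesteps(1000, 4).tolist())  # [999, 749, 499, 249]
for n in (50, 40, 100):
    print(n, "DDIM steps ->", training_phase_ends(1000, n, 4).tolist())
# 50 DDIM steps -> [0, 239, 499, 739]
# 40 DDIM steps -> [0, 249, 499, 749]
# 100 DDIM steps -> [0, 249, 499, 749]
```

With 50 steps the phase starts are floor(12.5) = 12 and floor(37.5) = 37, so the ends become 12 x 20 - 1 = 239 and 37 x 20 - 1 = 739 instead of 249 and 749.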
Did you print the end_timesteps when training? I remember it printing [0 249 499 749].
You can open a PR if you find that it is indeed wrong.
end_timesteps is [0, 239, 499, 739] according to the minimal reproducible code below:
import numpy as np
import torch
from diffusers import DDPMScheduler
def extract_into_tensor(a, t, x_shape):
    # Gather a[t] for each batch element and reshape for broadcasting over x_shape
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
class DDIMSolver:
    def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
        # DDIM timesteps: step_ratio - 1, 2 * step_ratio - 1, ..., timesteps - 1
        self.step_ratio = timesteps // ddim_timesteps
        self.ddim_timesteps = (
            np.arange(1, ddim_timesteps + 1) * self.step_ratio
        ).round().astype(np.int64) - 1
        self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
        # Previous DDIM timestep (and its alpha_cumprod) for each entry; 0 for the first one
        self.ddim_timesteps_prev = np.asarray([0] + self.ddim_timesteps[:-1].tolist())
        self.ddim_alpha_cumprods_prev = np.asarray(
            [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
        )
        self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
        self.ddim_timesteps_prev = torch.from_numpy(self.ddim_timesteps_prev).long()
        self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
        self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)
    def to(self, device):
        self.ddim_timesteps = self.ddim_timesteps.to(device)
        self.ddim_timesteps_prev = self.ddim_timesteps_prev.to(device)
        self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
        self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)
        return self
    def ddim_step(self, pred_x0, pred_noise, timestep_index):
        # Deterministic DDIM update towards the previous DDIM timestep
        alpha_cumprod_prev = extract_into_tensor(
            self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape
        )
        dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
        x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
        return x_prev
    def ddim_style_multiphase_pred(self, timestep_index, multiphase):
        # First DDIM index of each phase, e.g. [0, 12, 25, 37] for 50 steps and 4 phases
        inference_indices = np.linspace(
            0, len(self.ddim_timesteps), num=multiphase, endpoint=False
        )
        inference_indices = np.floor(inference_indices).astype(np.int64)
        inference_indices = (
            torch.from_numpy(inference_indices).long().to(self.ddim_timesteps.device)
        )
        # For each timestep_index, find the last phase start that is <= timestep_index
        expanded_timestep_index = timestep_index.unsqueeze(1).expand(
            -1, inference_indices.size(0)
        )
        valid_indices_mask = expanded_timestep_index >= inference_indices
        last_valid_index = valid_indices_mask.flip(dims=[1]).long().argmax(dim=1)
        last_valid_index = inference_indices.size(0) - 1 - last_valid_index
        timestep_index = inference_indices[last_valid_index]
        # The phase end is the DDIM timestep preceding that phase start
        return self.ddim_timesteps_prev[timestep_index]
if __name__ == '__main__':
    ddim_timesteps = 50
    multiphase = 4
    # 1. Create the noise scheduler and the desired noise schedule.
    noise_scheduler = DDPMScheduler.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="scheduler"
    )
    # The scheduler calculates the alpha and sigma schedule for us
    alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
    sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
    solver = DDIMSolver(
        noise_scheduler.alphas_cumprod.numpy(),
        timesteps=noise_scheduler.config.num_train_timesteps,
        ddim_timesteps=ddim_timesteps,
    )
    index = torch.arange(ddim_timesteps)
    end_timesteps = solver.ddim_style_multiphase_pred(index, multiphase)
    print(end_timesteps)
    # tensor([ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 239, 239,
    # 239, 239, 239, 239, 239, 239, 239, 239, 239, 239, 239, 499, 499, 499,
    # 499, 499, 499, 499, 499, 499, 499, 499, 499, 739, 739, 739, 739, 739,
    # 739, 739, 739, 739, 739, 739, 739, 739])
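The flip/argmax trick in ddim_style_multiphase_pred looks up, for each DDIM index, the last phase start that is not larger than it. The NumPy sketch below performs the same lookup with np.searchsorted (our substitution, not the repository's code; `end_timesteps_per_index` is a hypothetical name), which makes the per-phase counts in the printed tensor easy to check: 12, 13, 12 and 13 indices for 50 DDIM steps and 4 phases.

```python
import numpy as np


def end_timesteps_per_index(num_train_timesteps=1000, num_ddim_timesteps=50, multiphase=4):
    """Hypothetical helper: for every DDIM index, the timestep the model is trained to jump to."""
    step_ratio = num_train_timesteps // num_ddim_timesteps
    ddim_timesteps = np.arange(1, num_ddim_timesteps + 1) * step_ratio - 1
    ddim_timesteps_prev = np.concatenate(([0], ddim_timesteps[:-1]))
    inference_indices = np.floor(
        np.linspace(0, num_ddim_timesteps, num=multiphase, endpoint=False)
    ).astype(np.int64)
    index = np.arange(num_ddim_timesteps)
    # side="right" returns the position just after the last phase start <= index,
    # so subtracting 1 selects that phase start (what the flip/argmax trick computes).
    last_valid = np.searchsorted(inference_indices, index, side="right") - 1
    return ddim_timesteps_prev[inference_indices[last_valid]]


ends = end_timesteps_per_index()
values, counts = np.unique(ends, return_counts=True)
print(dict(zip(values.tolist(), counts.tolist())))
# {0: 12, 239: 13, 499: 12, 739: 13}
```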
I'm happy to make a PR once we agree on a proper way to do it. I think a suitable division of the time interval would be:
endpoints = np.linspace(
    -1, noise_scheduler.config.num_train_timesteps - 1, num=num_phase + 1, endpoint=True
)  # [-1, 249, 499, 749, 999]
# endpoints[0] = 0  # we can also let the start point be 0; that is fine.
endpoints = np.floor(endpoints).astype(np.int64)
endpoints = (
    torch.from_numpy(endpoints).long().to(start_timesteps.device)
)
In training, for every timestep in
DDIMScheduler(
    timestep_spacing="trailing",
    set_alpha_to_one=True,
)
Furthermore, a proper condition is that endpoints should be a subset of solver.ddim_timesteps (so that we learn from the endpoints directly). To satisfy the condition
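The proposed division and the subset condition can be checked numerically. In the NumPy sketch below the helper names are hypothetical, and the mapping rule (a timestep t in the interval (endpoints[k], endpoints[k+1]] is trained to jump to endpoints[k]) is our reading of the truncated comment, not a confirmed implementation. It also lists which num_ddim_timesteps values make every endpoint except the leading -1 a member of the DDIM timesteps built by DDIMSolver, assuming 1000 training timesteps and 4 phases.

```python
import numpy as np


def phase_endpoints(num_train_timesteps=1000, num_phase=4):
    """Endpoints as in the proposal: [-1, 249, 499, 749, 999] for 1000 steps, 4 phases."""
    endpoints = np.linspace(
        -1, num_train_timesteps - 1, num=num_phase + 1, endpoint=True
    )
    return np.floor(endpoints).astype(np.int64)


def jump_target(timesteps, endpoints):
    """Assumed rule: t in (endpoints[k], endpoints[k + 1]] jumps to endpoints[k]."""
    k = np.searchsorted(endpoints, timesteps, side="left") - 1
    return endpoints[k]


def endpoints_in_ddim_grid(num_train_timesteps, num_ddim_timesteps, num_phase):
    """True if every endpoint except the leading -1 is one of DDIMSolver's timesteps."""
    step_ratio = num_train_timesteps // num_ddim_timesteps
    ddim_timesteps = np.arange(1, num_ddim_timesteps + 1) * step_ratio - 1
    endpoints = phase_endpoints(num_train_timesteps, num_phase)
    return bool(np.isin(endpoints[1:], ddim_timesteps).all())


endpoints = phase_endpoints()
print(endpoints.tolist())  # [-1, 249, 499, 749, 999]
# 999 -> 749 and 749 -> 499, as in the trailing inference schedule
print(jump_target(np.array([999, 750, 749, 250, 249, 0]), endpoints).tolist())
# [749, 749, 499, 249, -1, -1]
print([n for n in (4, 8, 20, 25, 40, 50, 100, 200) if endpoints_in_ddim_grid(1000, n, 4)])
# [4, 8, 20, 40, 100, 200]
```

Under these assumptions, keeping num_ddim_timesteps=50 would violate the subset condition (250 is not a multiple of its step ratio of 20), while 40 or 100 would satisfy it.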
I think you're right. I ran the same test as you and got the same results. The inconsistency between the timesteps used in training and at test time may therefore lead to sub-optimal results. Have you tried your revised version, and how were the results? Thanks.
@JaySimple I have tested my revised version, and it functions as intended. However, I'm currently working on a different configuration, so a direct comparison with the original implementation is not available. I'm confident that my implementation at least maintains the existing performance level, if not improves it.
Great, I will check it in time.
Hi, I noticed that you use the built-in DDIM scheduler for inference; however, its discretization differs from the one used in the training stage.
Specifically, in training the endpoints that split the time interval into sub-trajectories are [0, 239, 499, 739, 999], whereas the DDIM scheduler uses [0, 249, 499, 749, 999] at inference.