Commit
Merge pull request #16149 from AndreyRGW/devpatch1
Add Normal and DDIM Schedulers
AUTOMATIC1111 committed Jul 6, 2024
2 parents c02e3a5 + f864066 commit b282b47
Showing 1 changed file with 29 additions and 0 deletions.

modules/sd_schedulers.py (+29, -0)
@@ -76,6 +76,33 @@ def kl_optimal(n, sigma_min, sigma_max, device):
    sigmas = torch.tan(step_indices / n * alpha_min + (1.0 - step_indices / n) * alpha_max)
    return sigmas

def normal_scheduler(n, sigma_min, sigma_max, inner_model, device, sgm=False, floor=False):
    # Space n steps uniformly in the model's timestep domain, then map each
    # timestep back to a sigma via the model's noise schedule. (floor is unused.)
    start = inner_model.sigma_to_t(torch.tensor(sigma_max))
    end = inner_model.sigma_to_t(torch.tensor(sigma_min))

    if sgm:
        # SGM-style spacing: take n+1 points and drop the final one
        timesteps = torch.linspace(start, end, n + 1)[:-1]
    else:
        timesteps = torch.linspace(start, end, n)

    sigs = []
    for x in range(len(timesteps)):
        ts = timesteps[x]
        sigs.append(inner_model.t_to_sigma(ts))
    sigs += [0.0]  # final step denoises all the way to sigma = 0
    return torch.FloatTensor(sigs).to(device)

def ddim_scheduler(n, sigma_min, sigma_max, inner_model, device):
    # Stride over the model's trained discrete sigmas, DDIM-style:
    # pick every ss-th entry so roughly n sigmas survive.
    sigs = []
    ss = max(len(inner_model.sigmas) // n, 1)
    x = 1
    while x < len(inner_model.sigmas):
        sigs += [float(inner_model.sigmas[x])]
        x += ss
    sigs = sigs[::-1]  # inner_model.sigmas ascend; sampling runs high -> low
    sigs += [0.0]
    return torch.FloatTensor(sigs).to(device)


schedulers = [
    Scheduler('automatic', 'Automatic', None),
@@ -86,6 +113,8 @@ def kl_optimal(n, sigma_min, sigma_max, device):
    Scheduler('sgm_uniform', 'SGM Uniform', sgm_uniform, need_inner_model=True, aliases=["SGMUniform"]),
    Scheduler('kl_optimal', 'KL Optimal', kl_optimal),
    Scheduler('align_your_steps', 'Align Your Steps', get_align_your_steps_sigmas),
    Scheduler('normal', 'Normal', normal_scheduler, need_inner_model=True),
    Scheduler('ddim', 'DDIM', ddim_scheduler, need_inner_model=True),
]

schedulers_map = {**{x.name: x for x in schedulers}, **{x.label: x for x in schedulers}}
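For context (not part of the commit): both new functions expect the wrapped k-diffusion denoiser as inner_model, which exposes a discrete sigmas table plus sigma_to_t/t_to_sigma conversions; the need_inner_model=True flags in the registry are what cause the webui to pass it in. Below is a minimal sketch of exercising the two schedulers in isolation, assuming the functions above are in scope (or imported from modules.sd_schedulers) and using a hypothetical ToyModel with a linear t-to-sigma mapping in place of the real denoiser:

import torch

class ToyModel:
    """Hypothetical stand-in for the wrapped denoiser; linear t<->sigma mapping."""
    def __init__(self, num_timesteps=1000, sigma_min=0.03, sigma_max=14.6):
        self.sigmas = torch.linspace(sigma_min, sigma_max, num_timesteps)

    def sigma_to_t(self, sigma):
        # fraction of the sigma range, scaled to the timestep index range
        span = self.sigmas[-1] - self.sigmas[0]
        return float((sigma - self.sigmas[0]) / span * (len(self.sigmas) - 1))

    def t_to_sigma(self, t):
        span = self.sigmas[-1] - self.sigmas[0]
        return self.sigmas[0] + t / (len(self.sigmas) - 1) * span

model = ToyModel()
print(normal_scheduler(10, 0.03, 14.6, model, "cpu"))  # 11 sigmas, 14.6 down to 0.0
print(ddim_scheduler(10, 0.03, 14.6, model, "cpu"))    # every 100th sigma, descending, then 0.0

Note that ddim_scheduler never reads its sigma_min/sigma_max arguments: it strides directly over the model's trained discrete sigmas, while normal_scheduler interpolates uniformly in timestep space between the two bounds.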
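Since schedulers_map indexes every entry under both its internal name and its UI label, the new schedulers resolve either way. A quick check, assuming the module is importable (in the webui, as modules.sd_schedulers):

assert sd_schedulers.schedulers_map['ddim'] is sd_schedulers.schedulers_map['DDIM']
assert sd_schedulers.schedulers_map['normal'].need_inner_model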
