Commit 24e728f
zhaoting committed Sep 2, 2024
1 parent: 2f78b18
Showing 16 changed files with 1,415 additions and 127 deletions.
5 of the changed files are empty.
@@ -0,0 +1,39 @@
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py

from .diffusion_utils import LossType, ModelMeanType, ModelVarType, get_named_beta_schedule
from .respace import SpacedDiffusion, space_timesteps


def create_diffusion(
    timestep_respacing,
    noise_schedule="linear",
    use_kl=False,
    sigma_small=False,
    predict_xstart=False,
    learn_sigma=True,
    rescale_learned_sigmas=False,
    diffusion_steps=1000,
):
    """Build a SpacedDiffusion whose loss, mean, and variance types follow the flags above."""
    betas = get_named_beta_schedule(noise_schedule, diffusion_steps)
    if use_kl:
        loss_type = LossType.RESCALED_KL
    elif rescale_learned_sigmas:
        loss_type = LossType.RESCALED_MSE
    else:
        loss_type = LossType.MSE
    if timestep_respacing is None or timestep_respacing == "":
        timestep_respacing = [diffusion_steps]
    return SpacedDiffusion(
        use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
        betas=betas,
        model_mean_type=(ModelMeanType.EPSILON if not predict_xstart else ModelMeanType.START_X),
        model_var_type=(
            (ModelVarType.FIXED_LARGE if not sigma_small else ModelVarType.FIXED_SMALL)
            if not learn_sigma
            else ModelVarType.LEARNED_RANGE
        ),
        loss_type=loss_type,
    )
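For context, a minimal usage sketch, not part of the commit; it only calls the factory defined above with its documented defaults:

    # Hypothetical usage: a 250-step respaced sampler built from the default
    # 1000-step linear schedule. With the defaults above this selects epsilon
    # prediction, LEARNED_RANGE variance, and plain MSE loss.
    diffusion = create_diffusion(timestep_respacing="250")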
@@ -0,0 +1,236 @@
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py

import enum
import math
from typing import Optional

import numpy as np

import mindspore as ms
from mindspore import Tensor, ops


def _extract_into_tensor(a, t, x_shape):
    """
    Extract values from a 1-D Tensor for a batch of indices.
    :param a: the 1-D Tensor of values (GatherD operates on Tensors, not numpy arrays).
    :param t: a tensor of indices into the array to extract.
    :param x_shape: a larger shape of K dimensions with the batch
                    dimension equal to the length of timesteps.
    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
    """
    b = t.shape[0]
    out = ops.GatherD()(a, -1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
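A quick sketch of how this helper is typically used; the values are hypothetical, and `a` must already be a MindSpore Tensor:

    betas = ms.Tensor(np.linspace(1e-4, 0.02, 1000), ms.float32)  # (T,)
    t = ms.Tensor(np.array([0, 499, 999]), ms.int32)              # (B,) timestep indices
    # Per-sample betas broadcastable over a (B, C, H, W) batch -> shape (3, 1, 1, 1)
    out = _extract_into_tensor(betas, t, (3, 3, 32, 32))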


def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
    betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
    warmup_time = int(num_diffusion_timesteps * warmup_frac)
    betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
    return betas


def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
    """
    This is the deprecated API for creating beta schedules.
    See get_named_beta_schedule() for the new library of schedules.
    """
    if beta_schedule == "quad":
        betas = (
            np.linspace(
                beta_start**0.5,
                beta_end**0.5,
                num_diffusion_timesteps,
                dtype=np.float64,
            )
            ** 2
        )
    elif beta_schedule == "linear":
        betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
    elif beta_schedule == "warmup10":
        betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
    elif beta_schedule == "warmup50":
        betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
    elif beta_schedule == "const":
        betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
    elif beta_schedule == "jsd":  # 1/T, 1/(T-1), 1/(T-2), ..., 1
        betas = 1.0 / np.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64)
    else:
        raise NotImplementedError(beta_schedule)
    assert betas.shape == (num_diffusion_timesteps,)
    return betas


def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
    """
    Get a pre-defined beta schedule for the given name.
    The beta schedule library consists of beta schedules which remain similar
    in the limit of num_diffusion_timesteps.
    Beta schedules may be added, but should not be removed or changed once
    they are committed to maintain backwards compatibility.
    """
    if schedule_name == "linear":
        # Linear schedule from Ho et al, extended to work for any number of
        # diffusion steps.
        scale = 1000 / num_diffusion_timesteps
        return get_beta_schedule(
            "linear",
            beta_start=scale * 0.0001,
            beta_end=scale * 0.02,
            num_diffusion_timesteps=num_diffusion_timesteps,
        )
    elif schedule_name == "squaredcos_cap_v2":
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
        )
    else:
        raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
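A short numeric check of the two named schedules; plain NumPy (np as imported above), calling only the functions defined in this file:

    betas_lin = get_named_beta_schedule("linear", 100)
    betas_cos = get_named_beta_schedule("squaredcos_cap_v2", 100)
    assert betas_lin.shape == betas_cos.shape == (100,)
    # The linear endpoints rescale with step count: for 100 steps the schedule
    # spans [0.001, 0.2] rather than the 1000-step [0.0001, 0.02].
    assert np.isclose(betas_lin[0], 0.001) and np.isclose(betas_lin[-1], 0.2)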


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].
    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)
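Because each beta is 1 - alpha_bar(t2)/alpha_bar(t1), the product of (1 - beta_i) telescopes to alpha_bar(t)/alpha_bar(0), which gives a direct sanity check (hypothetical values; the max_beta clipping only bites at the final step of this cosine schedule):

    T = 1000
    f = lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
    betas = betas_for_alpha_bar(T, f)
    approx = np.cumprod(1.0 - betas)  # telescoping product of f(t2) / f(t1)
    assert np.isclose(approx[499], f(500 / T) / f(0.0))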


def mean_flat(tensor: Tensor, frames_mask: Optional[Tensor] = None, patch_mask: Optional[Tensor] = None) -> Tensor:
    """
    Take the mean over all non-batch dimensions.
    """
    if frames_mask is None and patch_mask is None:
        return tensor.mean(axis=list(range(1, len(tensor.shape))))
    elif patch_mask is None:
        assert tensor.dim() == 5
        assert tensor.shape[2] == frames_mask.shape[1]
        tensor = tensor.swapaxes(1, 2).reshape(tensor.shape[0], tensor.shape[2], -1)  # b c t h w -> b t (c h w)
        denom = frames_mask.sum(axis=1) * tensor.shape[-1]
        loss = (tensor * frames_mask.unsqueeze(2)).sum(axis=(1, 2)) / denom
        return loss
    else:
        mask = frames_mask[:, None, :, None, None] * patch_mask
        tensor = tensor * mask
        num = ops.clamp(mask.sum(axis=list(range(1, len(tensor.shape)))), min=1)
        return tensor.sum(axis=list(range(1, len(tensor.shape)))) / num
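A shape-level sketch of the frames-masked branch, with hypothetical sizes; whole frames are masked out, so the result is the mean over valid frames only:

    loss_map = ops.ones((2, 4, 8, 16, 16), ms.float32)  # (B, C, T, H, W) per-element loss
    frames_mask = ms.Tensor(
        [[1, 1, 1, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1]], ms.float32
    )  # (B, T): sample 0 keeps 3 of 8 frames
    loss = mean_flat(loss_map, frames_mask=frames_mask)  # shape (2,)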


class ModelMeanType(enum.Enum):
    """
    Which type of output the model predicts.
    """

    PREVIOUS_X = enum.auto()  # the model predicts x_{t-1}
    START_X = enum.auto()  # the model predicts x_0
    EPSILON = enum.auto()  # the model predicts epsilon


class ModelVarType(enum.Enum):
    """
    What is used as the model's output variance.
    The LEARNED_RANGE option has been added to allow the model to predict
    values between FIXED_SMALL and FIXED_LARGE, making its job easier.
    """

    LEARNED = enum.auto()
    FIXED_SMALL = enum.auto()
    FIXED_LARGE = enum.auto()
    LEARNED_RANGE = enum.auto()


class LossType(enum.Enum):
    MSE = enum.auto()  # use raw MSE loss (and KL when learning variances)
    RESCALED_MSE = enum.auto()  # use raw MSE loss (with RESCALED_KL when learning variances)
    KL = enum.auto()  # use the variational lower-bound
    RESCALED_KL = enum.auto()  # like KL, but rescale to estimate the full VLB

    def is_vb(self):
        return self == LossType.KL or self == LossType.RESCALED_KL


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, ms.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for ops.exp().
    logvar1, logvar2 = [x if isinstance(x, ms.Tensor) else ms.Tensor(x) for x in (logvar1, logvar2)]

    return 0.5 * (-1.0 + logvar2 - logvar1 + ops.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * ops.exp(-logvar2))
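This is the standard closed form for KL(N(mean1, e^logvar1) || N(mean2, e^logvar2)); a quick hand check with hypothetical scalar inputs:

    # KL(N(0, 1) || N(1, e)) = 0.5 * (-1 + 1 - 0 + e^{-1} + 1 * e^{-1}) = e^{-1}
    kl = normal_kl(0.0, ms.Tensor(0.0), 1.0, ms.Tensor(1.0))  # ≈ 0.3679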


def approx_standard_normal_cdf(x):
    """
    A fast approximation of the cumulative distribution function of the
    standard normal.
    """
    return 0.5 * (1.0 + ops.tanh(ms.numpy.sqrt(2.0 / ms.numpy.pi) * (x + 0.044715 * ops.pow(x, 3))))


def continuous_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Compute the log-likelihood of a continuous Gaussian distribution.
    :param x: the targets.
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    centered_x = x - means
    inv_stdv = ops.exp(-log_scales)
    normalized_x = centered_x * inv_stdv
    # Normal lives in mindspore.nn.probability.distribution.
    log_probs = ms.nn.probability.distribution.Normal(ops.zeros_like(x), ops.ones_like(x)).log_prob(normalized_x)
    return log_probs


def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Compute the log-likelihood of a Gaussian distribution discretizing to a
    given image.
    :param x: the target images. It is assumed that this was uint8 values,
              rescaled to the range [-1, 1].
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    assert x.shape == means.shape and means.shape == log_scales.shape
    centered_x = x - means
    inv_stdv = ops.exp(-log_scales)
    plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
    cdf_plus = approx_standard_normal_cdf(plus_in)
    min_in = inv_stdv * (centered_x - 1.0 / 255.0)
    cdf_min = approx_standard_normal_cdf(min_in)
    log_cdf_plus = ops.log(cdf_plus.clamp(min=1e-12))
    log_one_minus_cdf_min = ops.log((1.0 - cdf_min).clamp(min=1e-12))
    cdf_delta = cdf_plus - cdf_min
    log_probs = ops.where(
        x < -0.999,
        log_cdf_plus,
        ops.where(x > 0.999, log_one_minus_cdf_min, ops.log(cdf_delta.clamp(min=1e-12))),
    )
    assert log_probs.shape == x.shape
    return log_probs
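Each pixel's likelihood is the Gaussian mass in its 1/255-radius bin, with open-ended tails at the extremes ±1; a shape-preserving sketch with hypothetical inputs:

    x = ops.zeros((2, 3, 8, 8), ms.float32)  # images rescaled to [-1, 1]
    ll = discretized_gaussian_log_likelihood(
        x, means=ops.zeros_like(x), log_scales=ops.zeros_like(x)  # stddev = 1
    )
    # ll.shape == x.shape; each entry ≈ log of the mass in a 2/255-wide bin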