Commit

reconstruct

zhaoting committed Sep 2, 2024
1 parent 2f78b18 commit 24e728f
Showing 16 changed files with 1,415 additions and 127 deletions.
Empty file added examples/dvt/dvt/__init__.py
Empty file.
39 changes: 39 additions & 0 deletions examples/dvt/dvt/schedulers/iddpm/__init__.py
@@ -0,0 +1,39 @@
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py

from .diffusion_utils import LossType, ModelMeanType, ModelVarType, get_named_beta_schedule
from .respace import SpacedDiffusion, space_timesteps


def create_diffusion(
    timestep_respacing,
    noise_schedule="linear",
    use_kl=False,
    sigma_small=False,
    predict_xstart=False,
    learn_sigma=True,
    rescale_learned_sigmas=False,
    diffusion_steps=1000,
):
    betas = get_named_beta_schedule(noise_schedule, diffusion_steps)
    if use_kl:
        loss_type = LossType.RESCALED_KL
    elif rescale_learned_sigmas:
        loss_type = LossType.RESCALED_MSE
    else:
        loss_type = LossType.MSE
    if timestep_respacing is None or timestep_respacing == "":
        timestep_respacing = [diffusion_steps]
    return SpacedDiffusion(
        use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
        betas=betas,
        model_mean_type=(ModelMeanType.EPSILON if not predict_xstart else ModelMeanType.START_X),
        model_var_type=(
            (ModelVarType.FIXED_LARGE if not sigma_small else ModelVarType.FIXED_SMALL)
            if not learn_sigma
            else ModelVarType.LEARNED_RANGE
        ),
        loss_type=loss_type,
    )
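
A minimal usage sketch for the factory above (the import path assumes examples/dvt is on sys.path so the package resolves as dvt.schedulers.iddpm; the "250" respacing string is parsed by space_timesteps):

# Hypothetical usage of create_diffusion as defined in this commit.
from dvt.schedulers.iddpm import create_diffusion

# Full 1000-step chain, e.g. for training.
train_diffusion = create_diffusion(timestep_respacing="")
# The same chain respaced onto 250 steps for faster sampling.
sample_diffusion = create_diffusion(timestep_respacing="250")
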
236 changes: 236 additions & 0 deletions examples/dvt/dvt/schedulers/iddpm/diffusion_utils.py
@@ -0,0 +1,236 @@
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py

import enum
import math
from typing import Optional

import numpy as np

import mindspore as ms
from mindspore import Tensor, ops


def _extract_into_tensor(a, t, x_shape):
    """
    Extract values from a 1-D Tensor for a batch of indices.
    :param a: the 1-D Tensor of per-timestep values (e.g., a numpy schedule
        array already converted to a MindSpore Tensor; GatherD requires a Tensor).
    :param t: a tensor of indices into the array to extract.
    :param x_shape: a larger shape of K dimensions with the batch
        dimension equal to the length of timesteps.
    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
    """
    b = t.shape[0]
    out = ops.GatherD()(a, -1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
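
A quick sketch of what _extract_into_tensor produces (the coefficient table here is hypothetical; any per-timestep 1-D Tensor works):

import numpy as np
import mindspore as ms

# One scalar per timestep, e.g. a cumulative-alpha table.
alphas_cumprod = ms.Tensor(np.linspace(1.0, 0.01, 1000), ms.float32)
t = ms.Tensor([0, 499, 999], ms.int32)                 # one timestep per sample
coef = _extract_into_tensor(alphas_cumprod, t, (3, 4, 32, 32))
print(coef.shape)                                      # (3, 1, 1, 1), broadcastable over images
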


def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
    betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
    warmup_time = int(num_diffusion_timesteps * warmup_frac)
    betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
    return betas


def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
    """
    This is the deprecated API for creating beta schedules.
    See get_named_beta_schedule() for the new library of schedules.
    """
    if beta_schedule == "quad":
        betas = (
            np.linspace(
                beta_start**0.5,
                beta_end**0.5,
                num_diffusion_timesteps,
                dtype=np.float64,
            )
            ** 2
        )
    elif beta_schedule == "linear":
        betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
    elif beta_schedule == "warmup10":
        betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
    elif beta_schedule == "warmup50":
        betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
    elif beta_schedule == "const":
        betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
    elif beta_schedule == "jsd":  # 1/T, 1/(T-1), 1/(T-2), ..., 1
        betas = 1.0 / np.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64)
    else:
        raise NotImplementedError(beta_schedule)
    assert betas.shape == (num_diffusion_timesteps,)
    return betas


def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
    """
    Get a pre-defined beta schedule for the given name.
    The beta schedule library consists of beta schedules which remain similar
    in the limit of num_diffusion_timesteps.
    Beta schedules may be added, but should not be removed or changed once
    they are committed, to maintain backwards compatibility.
    """
    if schedule_name == "linear":
        # Linear schedule from Ho et al, extended to work for any number of
        # diffusion steps.
        scale = 1000 / num_diffusion_timesteps
        return get_beta_schedule(
            "linear",
            beta_start=scale * 0.0001,
            beta_end=scale * 0.02,
            num_diffusion_timesteps=num_diffusion_timesteps,
        )
    elif schedule_name == "squaredcos_cap_v2":
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
        )
    else:
        raise NotImplementedError(f"unknown beta schedule: {schedule_name}")


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].
    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
        produces the cumulative product of (1-beta) up to that
        part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
        prevent singularities.
    """
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)
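
By telescoping, cumprod(1 - betas)[i] equals alpha_bar((i + 1) / T) / alpha_bar(0) wherever max_beta did not clip. A short check with the cosine alpha_bar used above (only the final step clips, since alpha_bar(1.0) is exactly zero there):

import math
import numpy as np

cosine_bar = lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
T = 10
betas = betas_for_alpha_bar(T, cosine_bar)
bars = np.cumprod(1.0 - betas)
expected = np.array([cosine_bar((i + 1) / T) for i in range(T)]) / cosine_bar(0.0)
assert np.allclose(bars[:-1], expected[:-1])   # all steps except the clipped last one
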


def mean_flat(tensor: Tensor, frames_mask: Optional[Tensor] = None, patch_mask: Optional[Tensor] = None) -> Tensor:
    """
    Take the mean over all non-batch dimensions, optionally restricted to the
    frames selected by `frames_mask` and the patches selected by `patch_mask`.
    """
    if frames_mask is None and patch_mask is None:
        return tensor.mean(axis=list(range(1, len(tensor.shape))))
    elif patch_mask is None:
        assert tensor.dim() == 5
        assert tensor.shape[2] == frames_mask.shape[1]
        tensor = tensor.swapaxes(1, 2).reshape(tensor.shape[0], tensor.shape[2], -1)  # b c t h w -> b t (c h w)
        denom = frames_mask.sum(axis=1) * tensor.shape[-1]
        loss = (tensor * frames_mask.unsqueeze(2)).sum(axis=(1, 2)) / denom
        return loss
    else:
        mask = frames_mask[:, None, :, None, None] * patch_mask
        tensor = tensor * mask
        num = ops.clamp(mask.sum(axis=list(range(1, len(tensor.shape)))), min=1)
        return tensor.sum(axis=list(range(1, len(tensor.shape)))) / num
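
A sketch of the frames_mask branch (toy shapes; the all-ones loss makes the expected result obvious):

import numpy as np
import mindspore as ms

loss = ms.Tensor(np.ones((2, 4, 8, 16, 16)), ms.float32)   # b c t h w
mask = np.ones((2, 8), np.float32)
mask[:, 6:] = 0                                            # last two frames are padding
per_sample = mean_flat(loss, frames_mask=ms.Tensor(mask))
print(per_sample)                                          # [1. 1.]: padded frames are excluded from the mean
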


class ModelMeanType(enum.Enum):
    """
    Which type of output the model predicts.
    """

    PREVIOUS_X = enum.auto()  # the model predicts x_{t-1}
    START_X = enum.auto()  # the model predicts x_0
    EPSILON = enum.auto()  # the model predicts epsilon


class ModelVarType(enum.Enum):
    """
    What is used as the model's output variance.
    The LEARNED_RANGE option has been added to allow the model to predict
    values between FIXED_SMALL and FIXED_LARGE, making its job easier.
    """

    LEARNED = enum.auto()
    FIXED_SMALL = enum.auto()
    FIXED_LARGE = enum.auto()
    LEARNED_RANGE = enum.auto()


class LossType(enum.Enum):
    MSE = enum.auto()  # use raw MSE loss (and KL when learning variances)
    RESCALED_MSE = enum.auto()  # use raw MSE loss (with RESCALED_KL when learning variances)
    KL = enum.auto()  # use the variational lower-bound
    RESCALED_KL = enum.auto()  # like KL, but rescale to estimate the full VLB

    def is_vb(self):
        return self == LossType.KL or self == LossType.RESCALED_KL


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, ms.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for ops.exp().
    logvar1, logvar2 = [x if isinstance(x, ms.Tensor) else ms.Tensor(x) for x in (logvar1, logvar2)]

    return 0.5 * (-1.0 + logvar2 - logvar1 + ops.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * ops.exp(-logvar2))
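
A quick check of normal_kl (scalar log-variances are promoted to Tensors internally, and the KL of a distribution with itself is zero):

import mindspore as ms
from mindspore import ops

mean = ops.zeros((4,), ms.float32)
kl = normal_kl(mean, 0.0, mean, 0.0)
print(kl)   # all zeros
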


def approx_standard_normal_cdf(x):
    """
    A fast approximation of the cumulative distribution function of the
    standard normal.
    """
    # math.sqrt is used for the scalar constant; mindspore.numpy.sqrt expects Tensor input.
    return 0.5 * (1.0 + ops.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * ops.pow(x, 3))))


def continuous_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Compute the log-likelihood of a continuous Gaussian distribution.
    :param x: the targets.
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    centered_x = x - means
    inv_stdv = ops.exp(-log_scales)
    normalized_x = centered_x * inv_stdv
    log_probs = ms.nn.probability.distribution.Normal(ops.zeros_like(x), ops.ones_like(x)).log_prob(normalized_x)
    return log_probs


def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Compute the log-likelihood of a Gaussian distribution discretizing to a
    given image.
    :param x: the target images. It is assumed that these were uint8 values,
        rescaled to the range [-1, 1].
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    assert x.shape == means.shape and means.shape == log_scales.shape
    centered_x = x - means
    inv_stdv = ops.exp(-log_scales)
    plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
    cdf_plus = approx_standard_normal_cdf(plus_in)
    min_in = inv_stdv * (centered_x - 1.0 / 255.0)
    cdf_min = approx_standard_normal_cdf(min_in)
    log_cdf_plus = ops.log(cdf_plus.clamp(min=1e-12))
    log_one_minus_cdf_min = ops.log((1.0 - cdf_min).clamp(min=1e-12))
    cdf_delta = cdf_plus - cdf_min
    log_probs = ops.where(
        x < -0.999,
        log_cdf_plus,
        ops.where(x > 0.999, log_one_minus_cdf_min, ops.log(cdf_delta.clamp(min=1e-12))),
    )
    assert log_probs.shape == x.shape
    return log_probs
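
A toy evaluation of the discretized likelihood (all three inputs must share one shape, and x is expected in [-1, 1]):

import numpy as np
import mindspore as ms

shape = (2, 3, 8, 8)
x = ms.Tensor(np.random.uniform(-1, 1, shape), ms.float32)
means = ms.Tensor(np.zeros(shape), ms.float32)
log_scales = ms.Tensor(np.zeros(shape), ms.float32)
log_probs = discretized_gaussian_log_likelihood(x, means=means, log_scales=log_scales)
print(log_probs.shape)   # (2, 3, 8, 8), per-pixel log-likelihood in nats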