From 69cce2990cea330bca3b68740cd7de8442a784c7 Mon Sep 17 00:00:00 2001 From: zhaoting Date: Mon, 2 Sep 2024 15:01:04 +0800 Subject: [PATCH] ema update --- examples/dvt/dvt/__init__.py | 0 examples/dvt/dvt/dataset/__init__.py | 0 examples/dvt/dvt/models/__init__.py | 0 examples/dvt/dvt/pipelines/__init__.py | 0 examples/dvt/dvt/schedulers/__init__.py | 0 examples/dvt/dvt/schedulers/iddpm/__init__.py | 39 -- .../dvt/schedulers/iddpm/diffusion_utils.py | 236 ------- .../schedulers/iddpm/gaussian_diffusion.py | 613 ------------------ examples/dvt/dvt/schedulers/iddpm/respace.py | 120 ---- .../dvt/schedulers/iddpm/timestep_sampler.py | 144 ---- mindone/models/modules/parallel/__init__.py | 13 +- mindone/models/modules/parallel/conv.py | 8 +- mindone/models/modules/parallel/dense.py | 2 +- .../models/modules/parallel/param_wrapper.py | 2 +- mindone/trainers/zero.py | 73 ++- tests/st/test_zero.py | 38 +- 16 files changed, 82 insertions(+), 1206 deletions(-) delete mode 100644 examples/dvt/dvt/__init__.py delete mode 100644 examples/dvt/dvt/dataset/__init__.py delete mode 100644 examples/dvt/dvt/models/__init__.py delete mode 100644 examples/dvt/dvt/pipelines/__init__.py delete mode 100644 examples/dvt/dvt/schedulers/__init__.py delete mode 100644 examples/dvt/dvt/schedulers/iddpm/__init__.py delete mode 100644 examples/dvt/dvt/schedulers/iddpm/diffusion_utils.py delete mode 100644 examples/dvt/dvt/schedulers/iddpm/gaussian_diffusion.py delete mode 100644 examples/dvt/dvt/schedulers/iddpm/respace.py delete mode 100644 examples/dvt/dvt/schedulers/iddpm/timestep_sampler.py diff --git a/examples/dvt/dvt/__init__.py b/examples/dvt/dvt/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/examples/dvt/dvt/dataset/__init__.py b/examples/dvt/dvt/dataset/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/examples/dvt/dvt/models/__init__.py b/examples/dvt/dvt/models/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/examples/dvt/dvt/pipelines/__init__.py b/examples/dvt/dvt/pipelines/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/examples/dvt/dvt/schedulers/__init__.py b/examples/dvt/dvt/schedulers/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/examples/dvt/dvt/schedulers/iddpm/__init__.py b/examples/dvt/dvt/schedulers/iddpm/__init__.py deleted file mode 100644 index 78b9d7ae16..0000000000 --- a/examples/dvt/dvt/schedulers/iddpm/__init__.py +++ /dev/null @@ -1,39 +0,0 @@ -# Modified from OpenAI's diffusion repos -# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py -# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion -# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py - -from .diffusion_utils import LossType, ModelMeanType, ModelVarType, get_named_beta_schedule -from .respace import SpacedDiffusion, space_timesteps - - -def create_diffusion( - timestep_respacing, - noise_schedule="linear", - use_kl=False, - sigma_small=False, - predict_xstart=False, - learn_sigma=True, - rescale_learned_sigmas=False, - diffusion_steps=1000, -): - betas = get_named_beta_schedule(noise_schedule, diffusion_steps) - if use_kl: - loss_type = LossType.RESCALED_KL - elif rescale_learned_sigmas: - loss_type = LossType.RESCALED_MSE - else: - loss_type = LossType.MSE - if timestep_respacing is None or timestep_respacing == "": - timestep_respacing = [diffusion_steps] - 
return SpacedDiffusion( - use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), - betas=betas, - model_mean_type=(ModelMeanType.EPSILON if not predict_xstart else ModelMeanType.START_X), - model_var_type=( - (ModelVarType.FIXED_LARGE if not sigma_small else ModelVarType.FIXED_SMALL) - if not learn_sigma - else ModelVarType.LEARNED_RANGE - ), - loss_type=loss_type, - ) diff --git a/examples/dvt/dvt/schedulers/iddpm/diffusion_utils.py b/examples/dvt/dvt/schedulers/iddpm/diffusion_utils.py deleted file mode 100644 index 6d64f2490d..0000000000 --- a/examples/dvt/dvt/schedulers/iddpm/diffusion_utils.py +++ /dev/null @@ -1,236 +0,0 @@ -# Modified from OpenAI's diffusion repos -# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py -# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion -# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py - -import enum -import math -from typing import Optional - -import numpy as np - -import mindspore as ms -from mindspore import Tensor, ops - - -def _extract_into_tensor(a, t, x_shape): - """ - Extract values from a 1-D numpy array for a batch of indices. - :param a: the 1-D numpy array. - :param t: a tensor of indices into the array to extract. - :param x_shape: a larger shape of K dimensions with the batch - dimension equal to the length of timesteps. - :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. - """ - b = t.shape[0] - out = ops.GatherD()(a, -1, t) - return out.reshape(b, *((1,) * (len(x_shape) - 1))) - - -def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): - betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) - warmup_time = int(num_diffusion_timesteps * warmup_frac) - betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64) - return betas - - -def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): - """ - This is the deprecated API for creating beta schedules. - See get_named_beta_schedule() for the new library of schedules. - """ - if beta_schedule == "quad": - betas = ( - np.linspace( - beta_start**0.5, - beta_end**0.5, - num_diffusion_timesteps, - dtype=np.float64, - ) - ** 2 - ) - elif beta_schedule == "linear": - betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) - elif beta_schedule == "warmup10": - betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) - elif beta_schedule == "warmup50": - betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) - elif beta_schedule == "const": - betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) - elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 - betas = 1.0 / np.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64) - else: - raise NotImplementedError(beta_schedule) - assert betas.shape == (num_diffusion_timesteps,) - return betas - - -def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): - """ - Get a pre-defined beta schedule for the given name. - The beta schedule library consists of beta schedules which remain similar - in the limit of num_diffusion_timesteps. - Beta schedules may be added, but should not be removed or changed once - they are committed to maintain backwards compatibility. 
- """ - if schedule_name == "linear": - # Linear schedule from Ho et al, extended to work for any number of - # diffusion steps. - scale = 1000 / num_diffusion_timesteps - return get_beta_schedule( - "linear", - beta_start=scale * 0.0001, - beta_end=scale * 0.02, - num_diffusion_timesteps=num_diffusion_timesteps, - ) - elif schedule_name == "squaredcos_cap_v2": - return betas_for_alpha_bar( - num_diffusion_timesteps, - lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, - ) - else: - raise NotImplementedError(f"unknown beta schedule: {schedule_name}") - - -def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): - """ - Create a beta schedule that discretizes the given alpha_t_bar function, - which defines the cumulative product of (1-beta) over time from t = [0,1]. - :param num_diffusion_timesteps: the number of betas to produce. - :param alpha_bar: a lambda that takes an argument t from 0 to 1 and - produces the cumulative product of (1-beta) up to that - part of the diffusion process. - :param max_beta: the maximum beta to use; use values lower than 1 to - prevent singularities. - """ - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return np.array(betas) - - -def mean_flat(tensor: Tensor, frames_mask: Optional[Tensor] = None, patch_mask: Optional[Tensor] = None) -> Tensor: - """ - Take the mean over all non-batch dimensions. - """ - if frames_mask is None and patch_mask is None: - return tensor.mean(axis=list(range(1, len(tensor.shape)))) - elif patch_mask is None: - assert tensor.dim() == 5 - assert tensor.shape[2] == frames_mask.shape[1] - tensor = tensor.swapaxes(1, 2).reshape(tensor.shape[0], tensor.shape[2], -1) # b c t h w -> b t (c h w) - denom = frames_mask.sum(axis=1) * tensor.shape[-1] - loss = (tensor * frames_mask.unsqueeze(2)).sum(axis=(1, 2)) / denom - return loss - else: - mask = frames_mask[:, None, :, None, None] * patch_mask - tensor = tensor * mask - num = ops.clamp(mask.sum(axis=list(range(1, len(tensor.shape)))), min=1) - return tensor.sum(axis=list(range(1, len(tensor.shape)))) / num - - -class ModelMeanType(enum.Enum): - """ - Which type of output the model predicts. - """ - - PREVIOUS_X = enum.auto() # the model predicts x_{t-1} - START_X = enum.auto() # the model predicts x_0 - EPSILON = enum.auto() # the model predicts epsilon - - -class ModelVarType(enum.Enum): - """ - What is used as the model's output variance. - The LEARNED_RANGE option has been added to allow the model to predict - values between FIXED_SMALL and FIXED_LARGE, making its job easier. - """ - - LEARNED = enum.auto() - FIXED_SMALL = enum.auto() - FIXED_LARGE = enum.auto() - LEARNED_RANGE = enum.auto() - - -class LossType(enum.Enum): - MSE = enum.auto() # use raw MSE loss (and KL when learning variances) - RESCALED_MSE = enum.auto() # use raw MSE loss (with RESCALED_KL when learning variances) - KL = enum.auto() # use the variational lower-bound - RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB - - def is_vb(self): - return self == LossType.KL or self == LossType.RESCALED_KL - - -def normal_kl(mean1, logvar1, mean2, logvar2): - """ - Compute the KL divergence between two gaussians. - Shapes are automatically broadcasted, so batches can be compared to - scalars, among other use cases. 
- """ - tensor = None - for obj in (mean1, logvar1, mean2, logvar2): - if isinstance(obj, ms.Tensor): - tensor = obj - break - assert tensor is not None, "at least one argument must be a Tensor" - - # Force variances to be Tensors. Broadcasting helps convert scalars to - # Tensors, but it does not work for ops.exp(). - logvar1, logvar2 = [x if isinstance(x, ms.Tensor) else ms.Tensor(x) for x in (logvar1, logvar2)] - - return 0.5 * (-1.0 + logvar2 - logvar1 + ops.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * ops.exp(-logvar2)) - - -def approx_standard_normal_cdf(x): - """ - A fast approximation of the cumulative distribution function of the - standard normal. - """ - return 0.5 * (1.0 + ops.tanh(ms.numpy.sqrt(2.0 / ms.numpy.pi) * (x + 0.044715 * ops.pow(x, 3)))) - - -def continuous_gaussian_log_likelihood(x, *, means, log_scales): - """ - Compute the log-likelihood of a continuous Gaussian distribution. - :param x: the targets - :param means: the Gaussian mean Tensor. - :param log_scales: the Gaussian log stddev Tensor. - :return: a tensor like x of log probabilities (in nats). - """ - centered_x = x - means - inv_stdv = ops.exp(-log_scales) - normalized_x = centered_x * inv_stdv - log_probs = ms.nn.probability.Normal(ops.zeros_like(x), ops.ones_like(x)).log_prob(normalized_x) - return log_probs - - -def discretized_gaussian_log_likelihood(x, *, means, log_scales): - """ - Compute the log-likelihood of a Gaussian distribution discretizing to a - given image. - :param x: the target images. It is assumed that this was uint8 values, - rescaled to the range [-1, 1]. - :param means: the Gaussian mean Tensor. - :param log_scales: the Gaussian log stddev Tensor. - :return: a tensor like x of log probabilities (in nats). - """ - assert x.shape == means.shape and means.shape == log_scales.shape - centered_x = x - means - inv_stdv = ops.exp(-log_scales) - plus_in = inv_stdv * (centered_x + 1.0 / 255.0) - cdf_plus = approx_standard_normal_cdf(plus_in) - min_in = inv_stdv * (centered_x - 1.0 / 255.0) - cdf_min = approx_standard_normal_cdf(min_in) - log_cdf_plus = ops.log(cdf_plus.clamp(min=1e-12)) - log_one_minus_cdf_min = ops.log((1.0 - cdf_min).clamp(min=1e-12)) - cdf_delta = cdf_plus - cdf_min - log_probs = ops.where( - x < -0.999, - log_cdf_plus, - ops.where(x > 0.999, log_one_minus_cdf_min, ops.log(cdf_delta.clamp(min=1e-12))), - ) - assert log_probs.shape == x.shape - return log_probs diff --git a/examples/dvt/dvt/schedulers/iddpm/gaussian_diffusion.py b/examples/dvt/dvt/schedulers/iddpm/gaussian_diffusion.py deleted file mode 100644 index ce83f1c97f..0000000000 --- a/examples/dvt/dvt/schedulers/iddpm/gaussian_diffusion.py +++ /dev/null @@ -1,613 +0,0 @@ -# Modified from OpenAI's diffusion repos -# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py -# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion -# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py - - -from functools import partial -from typing import Optional - -import numpy as np - -import mindspore as ms -from mindspore import Tensor, ops - -from .diffusion_utils import ( - ModelMeanType, - ModelVarType, - _extract_into_tensor, - discretized_gaussian_log_likelihood, - mean_flat, - normal_kl, -) - - -@ms.jit_class -class GaussianDiffusion: - """ - Utilities for training and sampling diffusion models. 
- Original ported from this codebase: - https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 - :param betas: a 1-D numpy array of betas for each diffusion timestep, - starting at T and going to 1. - """ - - def __init__( - self, - *, - betas, - model_mean_type, - model_var_type, - loss_type, - ): - super().__init__() - self.model_mean_type = model_mean_type - self.model_var_type = model_var_type - self.loss_type = loss_type - - # 1. pre-compute scheduler vars in numpy using float64 for accuracy. - betas = np.array(betas, dtype=np.float64) - assert len(betas.shape) == 1, "betas must be 1-D" - assert (betas > 0).all() and (betas <= 1).all() - - self.num_timesteps = int(betas.shape[0]) - - alphas = 1.0 - betas - self.alphas_cumprod = np.cumprod(alphas, axis=0) - self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) - self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) - assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) - self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) - self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) - self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) - self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) - - # calculations for posterior q(x_{t-1} | x_t, x_0) - self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) - # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - self.posterior_log_variance_clipped = ( - np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:])) - if len(self.posterior_variance) > 1 - else np.array([]) - ) - - self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) - self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) - - # new - self.log_betas = np.log(betas) - - # 2. convert to ms tensors in float32 - to_mindspore = partial(Tensor, dtype=ms.float32) - - self.betas = betas - - self.alphas_cumprod = to_mindspore(self.alphas_cumprod) - self.alphas_cumprod_prev = to_mindspore(self.alphas_cumprod_prev) - self.alphas_cumprod_next = to_mindspore(self.alphas_cumprod_next) - - self.sqrt_alphas_cumprod = to_mindspore(self.sqrt_alphas_cumprod) - self.sqrt_one_minus_alphas_cumprod = to_mindspore(self.sqrt_one_minus_alphas_cumprod) - self.log_one_minus_alphas_cumprod = to_mindspore(self.log_one_minus_alphas_cumprod) - self.sqrt_recip_alphas_cumprod = to_mindspore(self.sqrt_recip_alphas_cumprod) - self.sqrt_recipm1_alphas_cumprod = to_mindspore(self.sqrt_recipm1_alphas_cumprod) - - self.posterior_variance = to_mindspore(self.posterior_variance) - self.posterior_log_variance_clipped = to_mindspore(self.posterior_log_variance_clipped) - self.posterior_mean_coef1 = to_mindspore(self.posterior_mean_coef1) - self.posterior_mean_coef2 = to_mindspore(self.posterior_mean_coef2) - - # new - self.log_betas = to_mindspore(self.log_betas) - - def q_mean_variance(self, x_start, t): - """ - Get the distribution q(x_t | x_0). - :param x_start: the [N x C x ...] tensor of noiseless inputs. - :param t: the number of diffusion steps (minus 1). Here, 0 means one step. - :return: A tuple (mean, variance, log_variance), all of x_start's shape. 
- """ - mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start - variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) - log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) - return mean, variance, log_variance - - def q_sample(self, x_start, t, noise=None): - """ - Diffuse the data for a given number of diffusion steps. - In other words, sample from q(x_t | x_0). - :param x_start: the initial data batch. - :param t: the number of diffusion steps (minus 1). Here, 0 means one step. - :param noise: if specified, the split-out normal noise. - :return: A noisy version of x_start. - """ - if noise is None: - noise = ops.randn_like(x_start) - assert noise.shape == x_start.shape - return ( - _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start - + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise - ) - - def q_posterior_mean_variance(self, x_start, x_t, t): - """ - Compute the mean and variance of the diffusion posterior: - q(x_{t-1} | x_t, x_0) - """ - assert x_start.shape == x_t.shape - posterior_mean = ( - _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start - + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t - ) - posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) - posterior_log_variance_clipped = _extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) - return posterior_mean, posterior_variance, posterior_log_variance_clipped - - def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None): - """ - Apply the model to get p(x_{t-1} | x_t), as well as a prediction of - the initial x, x_0. - :param model: the model, which takes a signal and a batch of timesteps - as input. - :param x: the [N x C x ...] tensor at time t. - :param t: a 1-D Tensor of timesteps. - :param clip_denoised: if True, clip the denoised signal into [-1, 1]. - :param denoised_fn: if not None, a function which applies to the - x_start prediction before it is used to sample. Applies before - clip_denoised. - :param model_kwargs: if not None, a dict of extra keyword arguments to - pass to the model. This can be used for conditioning. - :return: a dict with the following keys: - - 'mean': the model mean output. - - 'variance': the model variance output. - - 'log_variance': the log of 'variance'. - - 'pred_xstart': the prediction for x_0. - """ - if model_kwargs is None: - model_kwargs = {} - - B, C, F = x.shape[:3] - - assert t.shape == (B,) - model_output = model(x, t, **model_kwargs) - if isinstance(model_output, tuple): - model_output, extra = model_output - else: - extra = None - if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: - assert model_output.shape == (B, C * 2, F, *x.shape[3:]) - model_output, model_var_values = ops.split(model_output, C, axis=1) - - min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) - max_log = _extract_into_tensor(self.log_betas, t, x.shape) - # The model_var_values is [-1, 1] for [min_var, max_var]. - frac = (model_var_values + 1) / 2 - model_log_variance = frac * max_log + (1 - frac) * min_log - model_variance = ops.exp(model_log_variance) - else: - model_variance, model_log_variance = { - # for fixedlarge, we set the initial (log-)variance like so - # to get a better decoder log likelihood. 
- ModelVarType.FIXED_LARGE: ( - ops.cat([self.posterior_variance[1].unsqueeze(0), self.betas[1:]]), - ops.log(ops.cat([self.posterior_variance[1].unsqueeze(0), self.betas[1:]])), - ), - ModelVarType.FIXED_SMALL: ( - self.posterior_variance, - self.posterior_log_variance_clipped, - ), - }[self.model_var_type] - model_variance = _extract_into_tensor(model_variance, t, x.shape) - model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) - - def process_xstart(x): - if denoised_fn is not None: - x = denoised_fn(x) - if clip_denoised: - return x.clamp(-1, 1) - return x - - if self.model_mean_type == ModelMeanType.START_X: - pred_xstart = process_xstart(model_output) - else: - pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)) - model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) - - # assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape - return { - "mean": model_mean, - "variance": model_variance, - "log_variance": model_log_variance, - "pred_xstart": pred_xstart, - "extra": extra, - } - - def _predict_xstart_from_eps(self, x_t, t, eps): - assert x_t.shape == eps.shape - return ( - _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps - ) - - def predict_xstart_from_eps(self, x_t, t, eps): - return self._predict_xstart_from_eps(x_t, t, eps) - - def _predict_eps_from_xstart(self, x_t, t, pred_xstart): - return ( - _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart - ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) - - def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): - """ - Compute the mean for the previous step, given a function cond_fn that - computes the gradient of a conditional log probability with respect to - x. In particular, cond_fn computes grad(log(p(y|x))), and we want to - condition on y. - This uses the conditioning strategy from Sohl-Dickstein et al. (2015). - """ - gradient = cond_fn(x, t, **model_kwargs) - new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() - return new_mean - - def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): - """ - Compute what the p_mean_variance output would have been, should the - model's score function be conditioned by cond_fn. - See condition_mean() for details on cond_fn. - Unlike condition_mean(), this instead uses the conditioning strategy - from Song et al (2020). - """ - alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) - - eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) - eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs) - - out = p_mean_var.copy() - out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) - out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t) - return out - - def p_sample( - self, - model, - x: Tensor, - t, - clip_denoised=True, - denoised_fn=None, - cond_fn=None, - model_kwargs=None, - frames_mask: Optional[np.ndarray] = None, - ): - """ - Sample x_{t-1} from the model at the given timestep. - :param model: the model to sample from. - :param x: the current tensor at x_{t-1}. - :param t: the value of t, starting at 0 for the first diffusion step. - :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. 
- :param denoised_fn: if not None, a function which applies to the - x_start prediction before it is used to sample. - :param cond_fn: if not None, this is a gradient function that acts - similarly to the model. - :param model_kwargs: if not None, a dict of extra keyword arguments to - pass to the model. This can be used for conditioning. - :return: a dict containing the following keys: - - 'sample': a random sample from the model. - - 'pred_xstart': a prediction of x_0. - """ - if frames_mask is not None: - if frames_mask.shape[0] != x.shape[0]: - frames_mask = frames_mask.reshape(1, -1).repeat(2, axis=0) # HACK - mask_t = (frames_mask * len(self.betas)).astype(np.int32) - - # x0: copy unchanged x values - # x_noise: add noise to x values - x0 = x.copy() - x_noise = x0 * _extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) + ops.randn_like( - x - ) * _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) - - # active noise addition - mask_t_equall = (mask_t == t.unsqueeze(1))[:, None, :, None, None] - x = ops.where(Tensor(mask_t_equall), x_noise, x0) # FIXME: numpy - - # create frames_mask - mask_t_upper = (mask_t > t.unsqueeze(1))[:, None, :, None, None] - batch_size = x.shape[0] - model_kwargs["frames_mask"] = Tensor(mask_t_upper.reshape(batch_size, -1), dtype=ms.bool_) # FIXME: numpy - - out = self.p_mean_variance( - model, - x, - t, - clip_denoised=clip_denoised, - denoised_fn=denoised_fn, - model_kwargs=model_kwargs, - ) - noise = ops.randn_like(x) - nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0 - if cond_fn is not None: - out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) - sample = out["mean"] + nonzero_mask * ops.exp(0.5 * out["log_variance"]) * noise - - if frames_mask is not None: - mask_t_lower = (mask_t < t.unsqueeze(1))[:, None, :, None, None] - sample = ops.where(Tensor(mask_t_lower), x0, sample) # FIXME: numpy - - return {"sample": sample, "pred_xstart": out["pred_xstart"]} - - def p_sample_loop( - self, - model, - shape, - noise=None, - clip_denoised=True, - denoised_fn=None, - cond_fn=None, - model_kwargs=None, - progress=False, - frames_mask: Optional[Tensor] = None, - ): - """ - Generate samples from the model. - :param model: the model module. - :param shape: the shape of the samples, (N, C, H, W). - :param noise: if specified, the noise from the encoder to sample. - Should be of the same shape as `shape`. - :param clip_denoised: if True, clip x_start predictions to [-1, 1]. - :param denoised_fn: if not None, a function which applies to the - x_start prediction before it is used to sample. - :param cond_fn: if not None, this is a gradient function that acts - similarly to the model. - :param model_kwargs: if not None, a dict of extra keyword arguments to - pass to the model. This can be used for conditioning. - :param progress: if True, show a tqdm progress bar. - :return: a non-differentiable batch of samples. 
- """ - final = None - for sample in self.p_sample_loop_progressive( - model, - shape, - noise=noise, - clip_denoised=clip_denoised, - denoised_fn=denoised_fn, - cond_fn=cond_fn, - model_kwargs=model_kwargs, - progress=progress, - frames_mask=frames_mask, - ): - final = sample - return final["sample"] - - def p_sample_loop_progressive( - self, - model, - shape, - noise=None, - clip_denoised=True, - denoised_fn=None, - cond_fn=None, - model_kwargs=None, - progress=False, - frames_mask: Optional[Tensor] = None, - ): - """ - Generate samples from the model and yield intermediate samples from - each timestep of diffusion. - Arguments are the same as p_sample_loop(). - Returns a generator over dicts, where each dict is the return value of - p_sample(). - """ - assert isinstance(shape, (tuple, list)) - if noise is not None: - img = noise - else: - img = ops.randn(*shape) - indices = list(range(self.num_timesteps))[::-1] - - if progress: - # Lazy import so that we don't depend on tqdm. - from tqdm.auto import tqdm - - indices = tqdm(indices) - - for i in indices: - t = ms.Tensor([i] * shape[0]) - # no_grad - out = self.p_sample( - model, - img, - t, - clip_denoised=clip_denoised, - denoised_fn=denoised_fn, - cond_fn=cond_fn, - model_kwargs=model_kwargs, - frames_mask=frames_mask, - ) - yield out - img = out["sample"] - - def ddim_sample( - self, - model, - x, - t, - clip_denoised=True, - denoised_fn=None, - cond_fn=None, - model_kwargs=None, - eta=0.0, - ): - """ - Sample x_{t-1} from the model using DDIM. - Same usage as p_sample(). - """ - out = self.p_mean_variance( - model, - x, - t, - clip_denoised=clip_denoised, - denoised_fn=denoised_fn, - model_kwargs=model_kwargs, - ) - if cond_fn is not None: - out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) - - # Usually our model outputs epsilon, but we re-derive it - # in case we used x_start or x_prev prediction. - eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) - - alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) - alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) - sigma = eta * ops.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * ops.sqrt(1 - alpha_bar / alpha_bar_prev) - # Equation 12. - noise = ops.randn_like(x) - mean_pred = out["pred_xstart"] * ops.sqrt(alpha_bar_prev) + ops.sqrt(1 - alpha_bar_prev - sigma**2) * eps - nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0 - sample = mean_pred + nonzero_mask * sigma * noise - return {"sample": sample, "pred_xstart": out["pred_xstart"]} - - def ddim_reverse_sample( - self, - model, - x, - t, - clip_denoised=True, - denoised_fn=None, - cond_fn=None, - model_kwargs=None, - eta=0.0, - ): - """ - Sample x_{t+1} from the model using DDIM reverse ODE. - """ - assert eta == 0.0, "Reverse ODE only for deterministic path" - out = self.p_mean_variance( - model, - x, - t, - clip_denoised=clip_denoised, - denoised_fn=denoised_fn, - model_kwargs=model_kwargs, - ) - if cond_fn is not None: - out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) - # Usually our model outputs epsilon, but we re-derive it - # in case we used x_start or x_prev prediction. - eps = ( - _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"] - ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) - alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) - - # Equation 12. 
reversed - mean_pred = out["pred_xstart"] * ops.sqrt(alpha_bar_next) + ops.sqrt(1 - alpha_bar_next) * eps - - return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} - - def ddim_sample_loop( - self, - model, - shape, - noise=None, - clip_denoised=True, - denoised_fn=None, - cond_fn=None, - model_kwargs=None, - progress=False, - eta=0.0, - frames_mask: Optional[Tensor] = None, # TODO: integrate support - ): - """ - Generate samples from the model using DDIM. - Same usage as p_sample_loop(). - """ - final = None - for sample in self.ddim_sample_loop_progressive( - model, - shape, - noise=noise, - clip_denoised=clip_denoised, - denoised_fn=denoised_fn, - cond_fn=cond_fn, - model_kwargs=model_kwargs, - progress=progress, - eta=eta, - ): - final = sample - return final["sample"] - - def ddim_sample_loop_progressive( - self, - model, - shape, - noise=None, - clip_denoised=True, - denoised_fn=None, - cond_fn=None, - model_kwargs=None, - progress=False, - eta=0.0, - ): - """ - Use DDIM to sample from the model and yield intermediate samples from - each timestep of DDIM. - Same usage as p_sample_loop_progressive(). - """ - assert isinstance(shape, (tuple, list)) - if noise is not None: - img = noise - else: - img = ops.randn( - *shape, - ) - indices = list(range(self.num_timesteps))[::-1] - - if progress: - # Lazy import so that we don't depend on tqdm. - from tqdm.auto import tqdm - - indices = tqdm(indices) - - for i in indices: - t = ms.Tensor([i] * shape[0]) - # with th.no_grad(): - out = self.ddim_sample( - model, - img, - t, - clip_denoised=clip_denoised, - denoised_fn=denoised_fn, - cond_fn=cond_fn, - model_kwargs=model_kwargs, - eta=eta, - ) - yield out - img = out["sample"] - - def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None): - """ - Get a term for the variational lower-bound. - The resulting units are bits (rather than nats, as one might expect). - This allows for comparison to other papers. - :return: a dict with the following keys: - - 'output': a shape [N] tensor of NLLs or KLs. - - 'pred_xstart': the x_0 predictions. 
- """ - true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t) - out = self.p_mean_variance(model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs) - kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]) - kl = mean_flat(kl) / np.log(2.0) - - decoder_nll = -discretized_gaussian_log_likelihood( - x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] - ) - assert decoder_nll.shape == x_start.shape - decoder_nll = mean_flat(decoder_nll) / Tensor(np.log(2.0)) - decoder_nll = decoder_nll.to(kl.dtype) - - # At the first timestep return the decoder NLL, - # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) - output = ops.where((t == 0), decoder_nll, kl) - return {"output": output, "pred_xstart": out["pred_xstart"]} diff --git a/examples/dvt/dvt/schedulers/iddpm/respace.py b/examples/dvt/dvt/schedulers/iddpm/respace.py deleted file mode 100644 index 5a9f00e59c..0000000000 --- a/examples/dvt/dvt/schedulers/iddpm/respace.py +++ /dev/null @@ -1,120 +0,0 @@ -# Modified from OpenAI's diffusion repos -# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py -# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion -# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py - -import numpy as np - -import mindspore as ms - -from .gaussian_diffusion import GaussianDiffusion - - -def space_timesteps(num_timesteps, section_counts): - """ - Create a list of timesteps to use from an original diffusion process, - given the number of timesteps we want to take from equally-sized portions - of the original process. - For example, if there's 300 timesteps and the section counts are [10,15,20] - then the first 100 timesteps are strided to be 10 timesteps, the second 100 - are strided to be 15 timesteps, and the final 100 are strided to be 20. - If the stride is a string starting with "ddim", then the fixed striding - from the DDIM paper is used, and only one section is allowed. - :param num_timesteps: the number of diffusion steps in the original - process to divide up. - :param section_counts: either a list of numbers, or a string containing - comma-separated numbers, indicating the step count - per section. As a special case, use "ddimN" where N - is a number of steps to use the striding from the - DDIM paper. - :return: a set of diffusion steps from the original process to use. 
- """ - if isinstance(section_counts, str): - if section_counts.startswith("ddim"): - desired_count = int(section_counts[len("ddim") :]) - for i in range(1, num_timesteps): - if len(range(0, num_timesteps, i)) == desired_count: - return set(range(0, num_timesteps, i)) - raise ValueError(f"cannot create exactly {num_timesteps} steps with an integer stride") - section_counts = [int(x) for x in section_counts.split(",")] - size_per = num_timesteps // len(section_counts) - extra = num_timesteps % len(section_counts) - start_idx = 0 - all_steps = [] - for i, section_count in enumerate(section_counts): - size = size_per + (1 if i < extra else 0) - if size < section_count: - raise ValueError(f"cannot divide section of {size} steps into {section_count}") - if section_count <= 1: - frac_stride = 1 - else: - frac_stride = (size - 1) / (section_count - 1) - cur_idx = 0.0 - taken_steps = [] - for _ in range(section_count): - taken_steps.append(start_idx + round(cur_idx)) - cur_idx += frac_stride - all_steps += taken_steps - start_idx += size - return set(all_steps) - - -class SpacedDiffusion(GaussianDiffusion): - """ - A diffusion process which can skip steps in a base diffusion process. - :param use_timesteps: a collection (sequence or set) of timesteps from the - original diffusion process to retain. - :param kwargs: the kwargs to create the base diffusion process. - """ - - def __init__(self, use_timesteps, **kwargs): - self.use_timesteps = set(use_timesteps) - self.timestep_map = [] - self.original_num_steps = len(kwargs["betas"]) - - base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa - last_alpha_cumprod = 1.0 - new_betas = [] - for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): - if i in self.use_timesteps: - new_betas.append(1 - float(alpha_cumprod) / last_alpha_cumprod) - last_alpha_cumprod = alpha_cumprod - self.timestep_map.append(i) - kwargs["betas"] = np.array(new_betas) - super().__init__(**kwargs) - - def p_mean_variance(self, model, *args, **kwargs): # pylint: disable=signature-differs - return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) - - def training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs - return super().training_losses(self._wrap_model(model), *args, **kwargs) - - def condition_mean(self, cond_fn, *args, **kwargs): - return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) - - def condition_score(self, cond_fn, *args, **kwargs): - return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) - - def _wrap_model(self, model): - if isinstance(model, _WrappedModel): - return model - return _WrappedModel(model, self.timestep_map, self.original_num_steps) - - def _scale_timesteps(self, t): - # Scaling is done by the wrapped model. 
- return t - - -class _WrappedModel: - def __init__(self, model, timestep_map, original_num_steps): - self.model = model - self.timestep_map = timestep_map - # self.rescale_timesteps = rescale_timesteps - self.original_num_steps = original_num_steps - - def __call__(self, x, ts, **kwargs): - map_tensor = ms.Tensor(self.timestep_map, dtype=ts.dtype) - new_ts = map_tensor[ts] - # if self.rescale_timesteps: - # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) - return self.model(x, new_ts, **kwargs) diff --git a/examples/dvt/dvt/schedulers/iddpm/timestep_sampler.py b/examples/dvt/dvt/schedulers/iddpm/timestep_sampler.py deleted file mode 100644 index 993939dac4..0000000000 --- a/examples/dvt/dvt/schedulers/iddpm/timestep_sampler.py +++ /dev/null @@ -1,144 +0,0 @@ -# Modified from OpenAI's diffusion repos -# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py -# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion -# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py - -from abc import ABC, abstractmethod - -import numpy as np - -import mindspore as ms - -# from mindspore import ops - - -def create_named_schedule_sampler(name, diffusion): - """ - Create a ScheduleSampler from a library of pre-defined samplers. - :param name: the name of the sampler. - :param diffusion: the diffusion object to sample for. - """ - if name == "uniform": - return UniformSampler(diffusion) - elif name == "loss-second-moment": - return LossSecondMomentResampler(diffusion) - else: - raise NotImplementedError(f"unknown schedule sampler: {name}") - - -class ScheduleSampler(ABC): - """ - A distribution over timesteps in the diffusion process, intended to reduce - variance of the objective. - By default, samplers perform unbiased importance sampling, in which the - objective's mean is unchanged. - However, subclasses may override sample() to change how the resampled - terms are reweighted, allowing for actual changes in the objective. - """ - - @abstractmethod - def weights(self): - """ - Get a numpy array of weights, one per diffusion step. - The weights needn't be normalized, but must be positive. - """ - - def sample(self, batch_size): - """ - Importance-sample timesteps for a batch. - :param batch_size: the number of timesteps. - :return: a tuple (timesteps, weights): - - timesteps: a tensor of timestep indices. - - weights: a tensor of weights to scale the resulting losses. - """ - w = self.weights() - p = w / np.sum(w) - indices_np = np.random.choice(len(p), size=(batch_size,), p=p) - indices = ms.Tensor(indices_np).long() - weights_np = 1 / (len(p) * p[indices_np]) - weights = ms.Tensor(weights_np).float() - return indices, weights - - -class UniformSampler(ScheduleSampler): - def __init__(self, diffusion): - self.diffusion = diffusion - self._weights = np.ones([diffusion.num_timesteps]) - - def weights(self): - return self._weights - - -class LossAwareSampler(ScheduleSampler): - # def update_with_local_losses(self, local_ts, local_losses): - # """ - # Update the reweighting using losses from a model. - # Call this method from each rank with a batch of timesteps and the - # corresponding losses for each of those timesteps. - # This method will perform synchronization to make sure all of the ranks - # maintain the exact same reweighting. - # :param local_ts: an integer Tensor of timesteps. - # :param local_losses: a 1D Tensor of losses. 
- # """ - # batch_sizes = [ms.Tensor([0], dtype=ms.int32) for _ in range(dist.get_world_size())] - # dist.all_gather( - # batch_sizes, - # ms.Tensor([len(local_ts)], dtype=ms.int32), - # ) - - # # Pad all_gather batches to be the maximum batch size. - # batch_sizes = [x.item() for x in batch_sizes] - # max_bs = max(batch_sizes) - - # timestep_batches = [ops.zeros(max_bs) for bs in batch_sizes] - # loss_batches = [ops.zeros(max_bs) for bs in batch_sizes] - # dist.all_gather(timestep_batches, local_ts) - # dist.all_gather(loss_batches, local_losses) - # timesteps = [x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]] - # losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] - # self.update_with_all_losses(timesteps, losses) - - @abstractmethod - def update_with_all_losses(self, ts, losses): - """ - Update the reweighting using losses from a model. - Sub-classes should override this method to update the reweighting - using losses from the model. - This method directly updates the reweighting without synchronizing - between workers. It is called by update_with_local_losses from all - ranks with identical arguments. Thus, it should have deterministic - behavior to maintain state across workers. - :param ts: a list of int timesteps. - :param losses: a list of float losses, one per timestep. - """ - - -class LossSecondMomentResampler(LossAwareSampler): - def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): - self.diffusion = diffusion - self.history_per_term = history_per_term - self.uniform_prob = uniform_prob - self._loss_history = np.zeros([diffusion.num_timesteps, history_per_term], dtype=np.float64) - self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) - - def weights(self): - if not self._warmed_up(): - return np.ones([self.diffusion.num_timesteps], dtype=np.float64) - weights = np.sqrt(np.mean(self._loss_history**2, axis=-1)) - weights /= np.sum(weights) - weights *= 1 - self.uniform_prob - weights += self.uniform_prob / len(weights) - return weights - - def update_with_all_losses(self, ts, losses): - for t, loss in zip(ts, losses): - if self._loss_counts[t] == self.history_per_term: - # Shift out the oldest loss term. 
- self._loss_history[t, :-1] = self._loss_history[t, 1:] - self._loss_history[t, -1] = loss - else: - self._loss_history[t, self._loss_counts[t]] = loss - self._loss_counts[t] += 1 - - def _warmed_up(self): - return (self._loss_counts == self.history_per_term).all() diff --git a/mindone/models/modules/parallel/__init__.py b/mindone/models/modules/parallel/__init__.py index dc6ca85b90..912f825ba5 100644 --- a/mindone/models/modules/parallel/__init__.py +++ b/mindone/models/modules/parallel/__init__.py @@ -1,9 +1,12 @@ +from mindspore import nn + from .conv import Conv1d, Conv2d, Conv3d from .dense import Dense -from mindspore import nn -PARALLEL_MODULE = {nn.Conv1d: Conv1d, - nn.Conv2d: Conv2d, - nn.Conv3d: Conv3d, - nn.Dense: Dense,} +PARALLEL_MODULE = { + nn.Conv1d: Conv1d, + nn.Conv2d: Conv2d, + nn.Conv3d: Conv3d, + nn.Dense: Dense, +} __all__ = ["Conv1d", "Conv2d", "Conv3d", "Dense"] diff --git a/mindone/models/modules/parallel/conv.py b/mindone/models/modules/parallel/conv.py index 79639ed096..5cd7b763b7 100644 --- a/mindone/models/modules/parallel/conv.py +++ b/mindone/models/modules/parallel/conv.py @@ -1,10 +1,10 @@ -from mindspore import ops, nn -from mindspore.nn import Conv1d as MSConv1d -from mindspore.nn import Conv2d as MSConv2d -from mindspore.nn import Conv3d as MSConv3d +from mindspore import nn, ops from mindspore.communication import get_group_size, get_rank from mindspore.communication.management import GlobalComm from mindspore.context import ParallelMode +from mindspore.nn import Conv1d as MSConv1d +from mindspore.nn import Conv2d as MSConv2d +from mindspore.nn import Conv3d as MSConv3d from mindspore.parallel._utils import _get_parallel_mode from .param_wrapper import ZeroParamWrapper diff --git a/mindone/models/modules/parallel/dense.py b/mindone/models/modules/parallel/dense.py index 9eb5669f7f..19ab0d04e7 100644 --- a/mindone/models/modules/parallel/dense.py +++ b/mindone/models/modules/parallel/dense.py @@ -1,5 +1,5 @@ import mindspore as ms -from mindspore import ops, nn +from mindspore import nn, ops from mindspore.communication import get_group_size, get_rank from mindspore.communication.management import GlobalComm from mindspore.context import ParallelMode diff --git a/mindone/models/modules/parallel/param_wrapper.py b/mindone/models/modules/parallel/param_wrapper.py index 62b1143e2a..1ca8d753b7 100644 --- a/mindone/models/modules/parallel/param_wrapper.py +++ b/mindone/models/modules/parallel/param_wrapper.py @@ -1,5 +1,5 @@ import mindspore as ms -from mindspore import ops, nn +from mindspore import nn, ops from mindspore.communication import get_group_size from mindspore.communication.management import GlobalComm from mindspore.context import ParallelMode diff --git a/mindone/trainers/zero.py b/mindone/trainers/zero.py index 7183bfdbe1..a7719287c7 100644 --- a/mindone/trainers/zero.py +++ b/mindone/trainers/zero.py @@ -1,5 +1,5 @@ import logging -from typing import List, Optional + import mindspore as ms from mindspore import nn, ops from mindspore.communication import get_group_size, get_rank @@ -7,9 +7,10 @@ from mindspore.context import ParallelMode from mindspore.parallel._utils import _get_parallel_mode -from .train_step import TrainOneStepWrapper from mindone.models.modules.parallel import PARALLEL_MODULE +from .train_step import TrainOneStepWrapper + _logger = logging.getLogger(__name__) @@ -161,7 +162,7 @@ def set_comm_ops(self,): self.dp_group_size = ms.Tensor(get_group_size(group=self.dp_group), ms.float32) def update_comm_op_info(self, 
comm_op_info, bucket_size, param_size, param_name): - if comm_op_info[-1]["size"] + param_size <= bucket_size or len(comm_op_info) == 1: + if comm_op_info[-1]["size"] + param_size <= bucket_size or len(comm_op_info[-1]["params"]) == 0: comm_op_info[-1]["size"] += param_size comm_op_info[-1]["params"].append(param_name) else: @@ -174,15 +175,13 @@ def set_zero1_allreduce_fusion_comm_list(self, comm_fusion): self.max_fusion_id += 1 self.zero1_allreduce_list = [] for i, param in enumerate(self.ori_parameters): - param_size = param.itemsize + param_size = param.itemsize * param.size param_name = param.name - self.update_comm_op_info(allreduce_info, - comm_fusion["allreduce"]["bucket_size"], - param_size, - param_name) + self.update_comm_op_info(allreduce_info, comm_fusion["allreduce"]["bucket_size"], param_size, param_name) comm_op = ops.AllReduce(op=ops.ReduceOp.SUM, group=self.op_group) comm_op.add_prim_attr("fusion", allreduce_info[-1]["fusion_id"]) self.zero1_allreduce_list.append(comm_op) + _logger.info(f"zero1_allreduce_fusion: {allreduce_info}") def set_zero2_reduce_scatter_fusion_comm_list(self, comm_fusion): reduce_scatter_info = [{"size": 0, "fusion_id": self.max_fusion_id + 1, "params": []}] @@ -192,18 +191,16 @@ def set_zero2_reduce_scatter_fusion_comm_list(self, comm_fusion): self.zero2_reduce_scatter_list = [] self.zero2_allreduce_list = [] for i, param in enumerate(self.ori_parameters): - param_size = param.itemsize + param_size = param.itemsize * param.size param_name = param.name if self.need_parameter_split[i]: - self.update_comm_op_info(reduce_scatter_info, - comm_fusion["reduce_scatter"]["bucket_size"], - param_size, - param_name) + self.update_comm_op_info( + reduce_scatter_info, comm_fusion["reduce_scatter"]["bucket_size"], param_size, param_name + ) else: - self.update_comm_op_info(allreduce_info, - comm_fusion["allreduce"]["bucket_size"], - param_size, - param_name) + self.update_comm_op_info( + allreduce_info, comm_fusion["allreduce"]["bucket_size"], param_size, param_name + ) comm_op = ops.ReduceScatter(op=ops.ReduceOp.SUM, group=self.op_group) comm_op.add_prim_attr("fusion", reduce_scatter_info[-1]["fusion_id"]) self.zero2_reduce_scatter_list.append(comm_op) @@ -211,38 +208,40 @@ def set_zero2_reduce_scatter_fusion_comm_list(self, comm_fusion): comm_op = ops.AllReduce(op=ops.ReduceOp.SUM, group=self.op_group) comm_op.add_prim_attr("fusion", allreduce_info[-1]["fusion_id"]) self.zero2_allreduce_list.append(comm_op) - + _logger.info(f"zero2_reduce_scatter_fusion: {reduce_scatter_info}") + _logger.info(f"zero2_reduce_scatter_fusion: {allreduce_info}") + def set_optimizer_allgather_fusion_comm_list(self, comm_fusion): allgather_info = [{"size": 0, "fusion_id": self.max_fusion_id + 1, "params": []}] self.max_fusion_id += 1 self.optimizer_allgather_list = [] for i, param in enumerate(self.ori_parameters): - param_size = param.itemsize + param_size = param.itemsize * param.size param_name = param.name if self.need_parameter_split[i]: - self.update_comm_op_info(allgather_info, - comm_fusion["allgather"]["bucket_size"], - param_size, - param_name) + self.update_comm_op_info( + allgather_info, comm_fusion["allgather"]["bucket_size"], param_size, param_name + ) comm_op = ops.AllGather(group=self.op_group) comm_op.add_prim_attr("fusion", allgather_info[-1]["fusion_id"]) self.optimizer_allgather_list.append(comm_op) + _logger.info(f"optimizer_allgather_fusion: {allgather_info}") def set_dp_allreduce_comm_list(self, comm_fusion): dp_allreduce_info = [{"size": 0, "fusion_id": 
self.max_fusion_id + 1, "params": []}] self.max_fusion_id += 1 self.dp_allreduce_list = [] for i, param in enumerate(self.ori_parameters): - param_size = param.itemsize + param_size = param.itemsize * param.size param_name = param.name if self.need_parameter_split[i]: - self.update_comm_op_info(dp_allreduce_info, - comm_fusion["allreduce"]["bucket_size"], - param_size, - param_name) + self.update_comm_op_info( + dp_allreduce_info, comm_fusion["allreduce"]["bucket_size"], param_size, param_name + ) comm_op = ops.AllGather(group=self.op_group) comm_op.add_prim_attr("fusion", dp_allreduce_info[-1]["fusion_id"]) self.dp_allreduce_list.append(comm_op) + _logger.info(f"dp_allreduce_fusion: {dp_allreduce_info}") def split_param(self, param): return self.split_op(param)[self.op_rank_id] @@ -474,6 +473,22 @@ def prepare_network(network: nn.Cell, zero_stage: int = 0, op_group: str = None) return network +def prepare_ema(ema, zero_stage: int = 0, op_group: str = None): + is_parallel = _get_parallel_mode() == ParallelMode.DATA_PARALLEL + if not is_parallel or zero_stage != 3: + return ema + op_group_size = get_group_size(op_group) + op_rank_id = get_rank(op_group) + split_op = ops.Split(0, op_group_size) + _logger.info(f"Split EMA params: rank_id {op_rank_id}, rank_size {op_group_size}.") + for net_weight, ema_weight, swap_cache in zip(ema.net_weight, ema.ema_weight, ema.swap_cache): + if net_weight.shape == ema_weight.shape: + continue + ema_weight.set_data(split_op(ema_weight)[op_rank_id], slice_shape=True) + swap_cache.set_data(split_op(swap_cache)[op_rank_id], slice_shape=True) + return ema + + def prepare_train_network( network: nn.Cell, optimizer: nn.Optimizer, @@ -527,6 +542,8 @@ def prepare_train_network( new_network = prepare_network(network, zero_stage, op_group) zero_helper = ZeroHelper(optimizer, zero_stage, op_group, dp_group, optimizer_offload, comm_fusion) + if ema is not None: + ema = prepare_ema(ema, zero_stage, op_group) if isinstance(scale_sense, float): scale_sense = ms.Tensor(scale_sense, ms.float32) train_network = TrainOneStepWrapper( diff --git a/tests/st/test_zero.py b/tests/st/test_zero.py index 10e5a20895..c9c99742c0 100644 --- a/tests/st/test_zero.py +++ b/tests/st/test_zero.py @@ -7,13 +7,14 @@ from mindspore.communication import get_group_size, get_rank, init from mindspore.communication.management import GlobalComm +from mindone.trainers.ema import EMA from mindone.trainers.zero import prepare_train_network from mindone.utils.logger import set_logger _logger = logging.getLogger(__name__) -def init_env(mode, distribute, save_graph=True, comm_fusio=False): +def init_env(mode, distribute, save_graph=True, comm_fusion=False): ms.set_seed(1) ms.set_context(mode=mode) if save_graph: @@ -28,10 +29,12 @@ def init_env(mode, distribute, save_graph=True, comm_fusio=False): parallel_mode=ms.ParallelMode.DATA_PARALLEL, gradients_mean=True, ) - if comm_fusio: - comm_fusion_dict = {"allreduce": {"mode": "auto", "config": None}, - "reducescatter": {"mode": "auto", "config": None}, - "allgather": {"mode": "auto", "config": None},} + if comm_fusion: + comm_fusion_dict = { + "allreduce": {"mode": "auto", "config": None}, + "reducescatter": {"mode": "auto", "config": None}, + "allgather": {"mode": "auto", "config": None}, + } ms.set_auto_parallel_context(comm_fusion=comm_fusion_dict) return group_size, rank_id return 1, 0 @@ -66,13 +69,17 @@ def test_zero(x, y, zero_stage=0, comm_fusion=False): ms.set_seed(1) net = nn.WithLossCell(TestNet(), nn.MSELoss()) opt = 
nn.AdamWeightDecay(net.trainable_params(), learning_rate=1e-3) + ema = EMA(net) comm_fusion_dict = None if comm_fusion: - comm_fusion_dict = {"allreduce": {"openstate": True, "bucket_size": 5e8}, - "reduce_scatter": {"openstate": True, "bucket_size": 5e8}, - "allgather": {"openstate": False, "bucket_size": 5e8},} - train_net = prepare_train_network(net, opt, zero_stage=zero_stage, op_group=GlobalComm.WORLD_COMM_GROUP, - comm_fusion=comm_fusion_dict) + comm_fusion_dict = { + "allreduce": {"bucket_size": 64}, + "reduce_scatter": {"bucket_size": 64}, + "allgather": {"bucket_size": 64}, + } + train_net = prepare_train_network( + net, opt, ema=ema, zero_stage=zero_stage, op_group=GlobalComm.WORLD_COMM_GROUP, comm_fusion=comm_fusion_dict + ) for i in range(10): loss = train_net(x, y) @@ -80,11 +87,12 @@ def test_zero(x, y, zero_stage=0, comm_fusion=False): if __name__ == "__main__": - group_size, rank_id = init_env(mode=0, distribute=True, save_graph=True) + comm_fusion = False + group_size, rank_id = init_env(mode=0, distribute=True, save_graph=False, comm_fusion=comm_fusion) set_logger(name="", output_dir="logs", rank=rank_id, log_level="DEBUG") x = ms.Tensor(np.random.uniform(-1, 1, (1, 2, 5, 5)).astype(np.float32) * (get_rank() + 1)) y = ms.Tensor(np.random.uniform(-1, 1, (1, 2, 5, 5)).astype(np.float32) * (get_rank() + 1)) - test_zero(x, y, zero_stage=0) - test_zero(x, y, zero_stage=1) - test_zero(x, y, zero_stage=2) - test_zero(x, y, zero_stage=3) + test_zero(x, y, zero_stage=0, comm_fusion=comm_fusion) + test_zero(x, y, zero_stage=1, comm_fusion=comm_fusion) + test_zero(x, y, zero_stage=2, comm_fusion=comm_fusion) + test_zero(x, y, zero_stage=3, comm_fusion=comm_fusion)
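
The key addition in this patch is prepare_ema() in mindone/trainers/zero.py: under ZeRO stage 3 the operator-parallel group shards every splittable parameter along axis 0, so the EMA shadow weights and swap cache must be sliced the same way or the element-wise EMA update no longer lines up with the sharded network weights. The following is a minimal, framework-free sketch of that sharding rule, using NumPy in place of ops.Split(0, group_size) and Parameter.set_data(..., slice_shape=True); the function name, the toy shapes, and the driver are hypothetical and not part of the diff.

    import numpy as np

    def shard_ema_like_net(net_weights, ema_weights, rank_id, group_size):
        """Keep only this rank's slice of each EMA weight whose matching network
        weight has already been sharded along axis 0, mirroring prepare_ema()."""
        sharded = []
        for net_w, ema_w in zip(net_weights, ema_weights):
            if net_w.shape == ema_w.shape:
                # network weight was not split, so the EMA copy stays full size
                sharded.append(ema_w)
            else:
                # mirrors split_op(ema_weight)[op_rank_id] in the patch
                sharded.append(np.split(ema_w, group_size, axis=0)[rank_id])
        return sharded

    if __name__ == "__main__":
        group_size, rank_id = 2, 0
        full = np.arange(8.0).reshape(4, 2)          # hypothetical EMA shadow weight
        net = np.split(full, group_size, axis=0)[0]  # ZeRO-3 shard held by rank 0
        out = shard_ema_like_net([net], [full], rank_id, group_size)
        print(out[0].shape)  # (2, 2): the EMA shard now matches the network shard

The shape check mirrors the patch: any weight that ZeRO left unsplit keeps its full EMA copy, and prepare_train_network() only routes through prepare_ema() when an EMA object is supplied and zero_stage is 3 in data-parallel mode.

The other behavioural fix is in the communication-fusion bucketing: update_comm_op_info() now measures each parameter as param.itemsize * param.size (total bytes, not just element width) and opens a new bucket only when the current one is non-empty and would overflow. Below is a self-contained sketch of that rule; the new-bucket branch is an assumption about the else body not shown in the hunk (start a fresh entry with the next fusion id), and the parameter names and sizes are made up.

    def assign_fusion_buckets(params, bucket_size):
        """params: iterable of (name, itemsize, num_elements); bucket_size in bytes.
        Returns records shaped like the helper's comm_op_info entries."""
        buckets = [{"size": 0, "fusion_id": 1, "params": []}]
        for name, itemsize, numel in params:
            param_size = itemsize * numel  # the fix: bytes per parameter, not itemsize alone
            last = buckets[-1]
            if last["size"] + param_size <= bucket_size or len(last["params"]) == 0:
                last["size"] += param_size
                last["params"].append(name)
            else:
                # assumed new-bucket branch: bump the fusion id, start a fresh entry
                buckets.append({"size": param_size,
                                "fusion_id": last["fusion_id"] + 1,
                                "params": [name]})
        return buckets

    if __name__ == "__main__":
        toy_params = [("conv.weight", 4, 1_000_000),
                      ("conv.bias", 4, 256),
                      ("dense.weight", 4, 4_000_000)]
        for bucket in assign_fusion_buckets(toy_params, bucket_size=5e6):
            print(bucket["fusion_id"], bucket["size"], bucket["params"])

With the old byte count of param.itemsize, every parameter contributed only a few bytes, so all parameters collapsed into a single fusion group regardless of the configured bucket_size; counting itemsize * size makes the bucket boundaries, and therefore the fused AllReduce/ReduceScatter/AllGather groups, follow the intended byte budget.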