From 9b7dad01c1d4afa96d9c39347916a5d176f2ecb4 Mon Sep 17 00:00:00 2001
From: townwish4git
Date: Sat, 27 Jul 2024 11:56:04 +0800
Subject: [PATCH] refactor: refactored interfaces for compatibility

---
 .../text_to_image/train_text_to_image_sdxl.py |   2 +-
 mindone/diffusers/models/layers_compat.py     | 292 ++++++++++++++++++
 mindone/diffusers/models/normalization.py     |  19 +-
 mindone/diffusers/models/resnet.py            |  28 +-
 .../diffusers/models/unets/unet_1d_blocks.py  | 105 +------
 mindone/diffusers/models/upsampling.py        |  63 +---
 mindone/diffusers/training_utils.py           |  40 ---
 7 files changed, 322 insertions(+), 227 deletions(-)
 create mode 100644 mindone/diffusers/models/layers_compat.py

diff --git a/examples/diffusers/text_to_image/train_text_to_image_sdxl.py b/examples/diffusers/text_to_image/train_text_to_image_sdxl.py
index 53fb82dcfb..5a3ddc9b2d 100644
--- a/examples/diffusers/text_to_image/train_text_to_image_sdxl.py
+++ b/examples/diffusers/text_to_image/train_text_to_image_sdxl.py
@@ -38,6 +38,7 @@
 from mindspore.dataset import GeneratorDataset, transforms, vision

 from mindone.diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionXLPipeline, UNet2DConditionModel
+from mindone.diffusers.models.layers_compat import multinomial
 from mindone.diffusers.optimization import get_scheduler
 from mindone.diffusers.training_utils import (
     AttrJitWrapper,
@@ -46,7 +47,6 @@
     init_distributed_device,
     is_master,
     maybe_compile,
-    multinomial,
     set_seed,
 )

diff --git a/mindone/diffusers/models/layers_compat.py b/mindone/diffusers/models/layers_compat.py
new file mode 100644
index 0000000000..412de4bc82
--- /dev/null
+++ b/mindone/diffusers/models/layers_compat.py
@@ -0,0 +1,292 @@
+# -*- coding: utf-8 -*-
+"""Custom MindSpore Operators Suite
+
+This module encapsulates custom implementations for a curated set of operators that are either unsupported in
+MindSpore or only introduced after a specific MindSpore version. Because the framework is still evolving, this
+suite ensures compatibility across MindSpore versions, covering both operators that lack native support in every
+version and operators that require a manual fallback for versions prior to a specific release.
+
+Key Features:
+    - **Conditional Implementations**:
+        Detects MindSpore's version at runtime to switch between native functions and custom equivalents.
+    - **Operator Coverage**:
+        [2024/07/26]
+        - **conv_transpose1d**: Always custom due to framework limitations.
+        - **conv_transpose2d**: Native post 2.3.0; custom for earlier versions.
+        - **group_norm**: Native post 2.3.0; custom for earlier versions.
+        - **multinomial**: Native post 2.3.0; custom for earlier versions.
+        - **pad**: Native post 2.3.0; custom for earlier versions.
+
+Example:
+    Import this module and use the operators as you would with native MindSpore functions, with the assurance of cross-version compatibility.
+
+    >>> from mindone.diffusers.models.layers_compat import conv_transpose2d, pad
+    >>> # Depending on the MindSpore version, the correct implementation will be utilized.
+
+Todo:
+    - Monitor MindSpore updates for potential native support inclusion.
+    - ...
+""" + +from packaging.version import parse + +import mindspore as ms +from mindspore import ops +from mindspore.common.api import _function_forbid_reuse + +__all__ = [ + "conv_transpose1d", + "conv_transpose2d", + "group_norm", + "multinomial", + "pad", +] + +MINDSPORE_VERSION = parse(ms.__version__) + + +# ================================================================================ +# conv_transpose1d +# ================================================================================ +def _conv_transpose1d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): + # Equivalence of torch.nn.functional.conv_transpose1d + assert output_padding == 0, "Only support output_padding == 0 so far." + + if isinstance(stride, int): + stride = (1, stride) + elif isinstance(stride, tuple): + stride = (1, stride[0]) + + if isinstance(dilation, int): + dilation = (dilation, dilation) + elif isinstance(dilation, tuple): + dilation = (dilation[0], dilation[0]) + + if isinstance(padding, int): + padding = (0, 0, padding, padding) + elif isinstance(padding, tuple): + padding = (0, 0, padding[0], padding[0]) + + # InferShape manually + # Format adapted from https://pytorch.org/docs/stable/generated/torch.nn.functional.conv_transpose1d.html + input = input.unsqueeze(2) + weight = weight.unsqueeze(2) + batch_size, in_channels, iH, iW = input.shape + _, out_channels_divide_groups, kH, kW = weight.shape + + out_channels = out_channels_divide_groups * groups + outH = (iH - 1) * stride[0] - (padding[0] + padding[1]) + dilation[0] * (kH - 1) + 1 + outW = (iW - 1) * stride[1] - (padding[2] + padding[3]) + dilation[1] * (kW - 1) + 1 + + op_conv_transpose2d = ops.Conv2DTranspose( + out_channel=out_channels, + kernel_size=(kH, kW), + pad_mode="pad", + pad=padding, + stride=stride, + dilation=dilation, + group=groups, + ) + outputs = op_conv_transpose2d(input, weight.to(input.dtype), (batch_size, out_channels, outH, outW)).squeeze(2) + + if bias is not None: + assert isinstance(bias, ms.Tensor) and bias.ndim == 1 + bias = bias.reshape(1, -1, 1) + outputs += bias + + return outputs + + +conv_transpose1d = _conv_transpose1d + + +# ================================================================================ +# conv_transpose2d +# ================================================================================ +def _conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): + # Equivalence of torch.nn.functional.conv_transpose2d + assert output_padding == 0, "Only support output_padding == 0 so far." 
+ + if isinstance(stride, int): + stride = (stride, stride) + if isinstance(dilation, int): + dilation = (dilation, dilation) + if isinstance(padding, int): + padding = (padding, padding, padding, padding) + elif len(padding) == 2: + padding = ( + padding[0], + padding[0], + padding[1], + padding[1], + ) + + # InferShape manually + # Format adapted from https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html#torch.nn.ConvTranspose2d + batch_size, in_channels, iH, iW = input.shape + _, out_channels_divide_groups, kH, kW = weight.shape + + out_channels = out_channels_divide_groups * groups + outH = (iH - 1) * stride[0] - (padding[0] + padding[1]) + dilation[0] * (kH - 1) + 1 + outW = (iW - 1) * stride[1] - (padding[2] + padding[3]) + dilation[1] * (kW - 1) + 1 + + op_conv_transpose2d = ops.Conv2DTranspose( + out_channel=out_channels, + kernel_size=(kH, kW), + pad_mode="pad", + pad=padding, + stride=stride, + dilation=dilation, + group=groups, + ) + outputs = op_conv_transpose2d(input, weight.to(input.dtype), (batch_size, out_channels, outH, outW)) + + if bias is not None: + assert isinstance(bias, ms.Tensor) and bias.ndim == 1 + bias = bias.reshape(1, -1, 1, 1) + outputs += bias + + return outputs + + +if MINDSPORE_VERSION >= parse("2.3.0"): + conv_transpose2d = ms.mint.nn.functional.conv_transpose2d +else: + conv_transpose2d = _conv_transpose2d + + +# ================================================================================ +# group_norm +# ================================================================================ +def _group_norm(x, num_groups, weight, bias, eps): + x_shape = x.shape + x = x.reshape(x_shape[0], num_groups, -1) + var, mean = ops.var_mean(x, axis=-1, keepdims=True) + x = (x - mean) / ops.sqrt(var + eps) + x = x.reshape(x_shape) + + if weight is not None and bias is not None: + expanded_shape = (1, -1) + (1,) * len(x_shape[2:]) + x = x * weight.reshape(expanded_shape) + bias.reshape(expanded_shape) + + return x + + +if MINDSPORE_VERSION >= parse("2.3.0"): + group_norm = ms.mint.nn.functional.group_norm +else: + group_norm = _group_norm + + +# ================================================================================ +# multinomial +# ================================================================================ +@_function_forbid_reuse +def _multinomial(input, num_samples, replacement=True, **kwargs): + assert isinstance(input, ms.Tensor) and input.ndim in ( + 1, + 2, + ), "argument input should be a MindSpore Tensor with 1 or 2 dim." + assert ( + replacement or num_samples <= input.shape[-1] + ), "cannot sample n_sample > prob_dist.size(-1) samples without replacement." + + input = input.float() + input /= input.sum(-1, keepdims=True) + + if num_samples == 1 or not replacement: + # The algorithm is from gumbel softmax. + # s = argmax( logp - log(-log(eps)) ) where eps ~ U(0, 1) + # Here we can apply exp to the formula which will not affect result of + # argmax or topk. Then we have + # s = argmax( p / (-log(eps)) ) where eps ~ U(0, 1). 
# We can also simplify the formula above by
+        # s = argmax( p / q ) where q ~ Exp(1)
+        # No proper Exp generator op in MindSpore,
+        # so we still generate it by -log(eps)
+        q = -ops.log(ops.rand_like(input))
+        if num_samples == 1:
+            result = (input / q).argmax(-1, keepdim=True)
+        else:
+            _, result = ops.topk(input / q, k=num_samples, dim=-1)
+    else:
+        # To generate scalar random variable X with cumulative distribution F(x)
+        # just let X = F^(-1)(U) where U ~ U(0, 1)
+        input = input.cumsum(-1).expand_dims(-1)
+        rshape = (1, num_samples) if input.ndim == 2 else (input.shape[0], 1, num_samples)
+        rand = ops.rand(*rshape, dtype=input.dtype)
+        result = ops.ge(rand, input).long().sum(-2)
+
+    return result.long()
+
+
+if MINDSPORE_VERSION >= parse("2.3.0"):
+    multinomial = ops.multinomial
+else:
+    multinomial = _multinomial
+
+
+# ================================================================================
+# pad
+# ================================================================================
+def _pad(input, pad, mode="constant", value=0):
+    assert mode in ["constant", "replicate", "reflect"], "Unsupported padding mode"
+
+    padding = [0, 0, 0, 0]
+    if isinstance(pad, tuple):
+        assert len(pad) <= 4, "Only supports padding for the last 2 dimensions."
+        pad = list(pad)
+        padding[: len(pad)] = pad
+
+    left, right, top, bottom = padding
+
+    height, width = input.shape[-2:]
+    other_dimensions = input.shape[:-2]
+    input = input.reshape(-1, height, width)
+    batch_size = input.shape[0]
+
+    padded_height = height + top + bottom
+    padded_width = width + left + right
+
+    output = ops.full((batch_size, padded_height, padded_width), value, dtype=input.dtype)
+    output[:, top : top + height, left : left + width] = input
+
+    if mode == "replicate":
+        if top > 0:
+            output[:, :top, left : left + width] = input[:, 0:1, :].broadcast_to((batch_size, top, width))
+        if bottom > 0:
+            output[:, top + height :, left : left + width] = input[:, -1:, :].broadcast_to((batch_size, bottom, width))
+        if left > 0:
+            output[:, :, :left] = output[:, :, left : left + 1].broadcast_to((batch_size, padded_height, left))
+        if right > 0:
+            output[:, :, left + width :] = output[:, :, left + width - 1 : left + width].broadcast_to(
+                (batch_size, padded_height, right)
+            )
+    elif mode == "reflect":
+        if top > 0:
+            output[:, :top, left : left + width] = (
+                input[:, 1 : top + 1, :].flip(dims=[1]).broadcast_to((batch_size, top, width))
+            )
+        if bottom > 0:
+            output[:, top + height :, left : left + width] = (
+                input[:, -bottom - 1 : -1, :].flip(dims=[1]).broadcast_to((batch_size, bottom, width))
+            )
+        if left > 0:
+            output[:, :, :left] = (
+                output[:, :, left + 1 : 2 * left + 1].flip(dims=[2]).broadcast_to((batch_size, padded_height, left))
+            )
+        if right > 0:
+            right_edge = max(0, left + width - right - 2)
+            output[:, :, left + width :] = output[:, :, left + width - 2 : right_edge : -1].broadcast_to(
+                (batch_size, padded_height, right)
+            )
+
+    target_shape = tuple(other_dimensions) + (padded_height, padded_width)
+    output = output.reshape(*target_shape)
+    return output
+
+
+if MINDSPORE_VERSION >= parse("2.3.0"):
+    pad = ms.mint.nn.functional.pad
+else:
+    pad = _pad
diff --git a/mindone/diffusers/models/normalization.py b/mindone/diffusers/models/normalization.py
index db5f768acc..e30b07c41f 100644
--- a/mindone/diffusers/models/normalization.py
+++ b/mindone/diffusers/models/normalization.py
@@ -23,6 +23,7 @@
 from .activations import SiLU, get_activation
 from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
+from .layers_compat import group_norm


 class AdaLayerNorm(nn.Cell):
@@ -152,7 +153,7 @@ def construct(self, x: ms.Tensor, emb: ms.Tensor) -> ms.Tensor:
         emb = emb[:, :, None, None]
         scale, shift = emb.chunk(2, axis=1)

-        x = _group_norm(x, self.num_groups, None, None, self.eps)
+        x = group_norm(x, self.num_groups, None, None, self.eps)
         x = x * (1 + scale) + shift
         return x

@@ -362,7 +363,7 @@ def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5, affine
             self.bias = None

     def construct(self, x: Tensor):
-        x = _group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+        x = group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
         return x


@@ -410,17 +411,3 @@ def construct(self, x):
         nx = gx / (gx.mean(axis=-1, keep_dims=True) + 1e-6)
         out = (self.gamma * (x * nx) + self.beta + x).to(x.dtype)
         return out
-
-
-def _group_norm(x, num_groups, weight, bias, eps):
-    x_shape = x.shape
-    x = x.reshape(x_shape[0], num_groups, -1)
-    var, mean = ops.var_mean(x, axis=-1, keepdims=True)
-    x = (x - mean) / ops.sqrt(var + eps)
-    x = x.reshape(x_shape)
-
-    if weight is not None and bias is not None:
-        expanded_shape = (1, -1) + (1,) * len(x_shape[2:])
-        x = x * weight.reshape(expanded_shape) + bias.reshape(expanded_shape)
-
-    return x
diff --git a/mindone/diffusers/models/resnet.py b/mindone/diffusers/models/resnet.py
index f3a6efd0ce..5cca25b67f 100644
--- a/mindone/diffusers/models/resnet.py
+++ b/mindone/diffusers/models/resnet.py
@@ -22,15 +22,7 @@
 from .attention_processor import SpatialNorm
 from .downsampling import Downsample1D, Downsample2D, FirDownsample2D, KDownsample2D, downsample_2d  # noqa
 from .normalization import AdaGroupNorm, GroupNorm
-from .upsampling import (  # noqa
-    FirUpsample2D,
-    KUpsample2D,
-    SdeVpUpsample2D,
-    Upsample1D,
-    Upsample2D,
-    upfirdn2d_native,
-    upsample_2d,
-)
+from .upsampling import FirUpsample2D, KUpsample2D, Upsample1D, Upsample2D, upfirdn2d_native, upsample_2d  # noqa


 class ResnetBlockCondNorm2D(nn.Cell):
@@ -176,6 +168,24 @@ def construct(self, input_tensor: ms.Tensor, temb: ms.Tensor) -> ms.Tensor:
         return output_tensor


+class SdeVpUpsample2D(nn.Cell):
+    """
+    Equivalent of partial(F.interpolate, scale_factor=2.0, mode="nearest") used in ResnetBlock2D.__init__()
+    when self.up and kernel == "sde_vp". We wrap ops.interpolate in our implementation because the `scale_factor`
+    argument cannot be directly utilized in certain modes and partial is not fully supported in GRAPH MODE.
+    """
+
+    def __init__(self, scale_factor=2.0, mode="nearest"):
+        super().__init__()
+        self.scale_factor = scale_factor
+        self.mode = mode
+
+    def construct(self, x):
+        _, _, h, w = x.shape
+        x = ops.interpolate(x, size=(int(self.scale_factor * h), int(self.scale_factor * w)), mode=self.mode)
+        return x
+
+
 class ResnetBlock2D(nn.Cell):
     r"""
     A Resnet block.
diff --git a/mindone/diffusers/models/unets/unet_1d_blocks.py b/mindone/diffusers/models/unets/unet_1d_blocks.py index 08045f2c06..8916bbc695 100644 --- a/mindone/diffusers/models/unets/unet_1d_blocks.py +++ b/mindone/diffusers/models/unets/unet_1d_blocks.py @@ -17,6 +17,7 @@ from mindspore import nn, ops from ..activations import get_activation +from ..layers_compat import conv_transpose1d, pad from ..normalization import GroupNorm from ..resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims @@ -283,7 +284,7 @@ def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"): def construct(self, hidden_states: ms.Tensor) -> ms.Tensor: dtype = hidden_states.dtype - hidden_states = _pad(hidden_states, (self.pad,) * 2, self.pad_mode) + hidden_states = pad(hidden_states, (self.pad,) * 2, self.pad_mode) weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) indices = ops.arange(hidden_states.shape[1]) kernel = self.kernel.to(weight.dtype)[None, :].broadcast_to((hidden_states.shape[1], -1)) @@ -300,12 +301,12 @@ def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"): self.kernel = ms.Parameter(kernel_1d, name="kernel") def construct(self, hidden_states: ms.Tensor, temb: Optional[ms.Tensor] = None) -> ms.Tensor: - hidden_states = _pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode) + hidden_states = pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode) weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) indices = ops.arange(hidden_states.shape[1]) kernel = self.kernel.to(weight.dtype)[None, :].broadcast_to((hidden_states.shape[1], -1)) weight[indices, indices] = kernel - return _conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1) + return conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1) class SelfAttention1d(nn.Cell): @@ -694,101 +695,3 @@ def get_out_block( elif out_block_type == "ValueFunction": return OutValueFunctionBlock(fc_dim, embed_dim, act_fn) return None - - -def _conv_transpose1d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): - # Equivalence of torch.nn.functional.conv_transpose1d - assert output_padding == 0, "Only support output_padding == 0 so far." 
- - if isinstance(stride, int): - stride = (1, stride) - elif isinstance(stride, tuple): - stride = (1, stride[0]) - - if isinstance(dilation, int): - dilation = (dilation, dilation) - elif isinstance(dilation, tuple): - dilation = (dilation[0], dilation[0]) - - if isinstance(padding, int): - padding = (0, 0, padding, padding) - elif isinstance(padding, tuple): - padding = (0, 0, padding[0], padding[0]) - - # InferShape manually - # Format adapted from https://pytorch.org/docs/stable/generated/torch.nn.functional.conv_transpose1d.html - input = input.unsqueeze(2) - weight = weight.unsqueeze(2) - batch_size, in_channels, iH, iW = input.shape - _, out_channels_divide_groups, kH, kW = weight.shape - - out_channels = out_channels_divide_groups * groups - outH = (iH - 1) * stride[0] - (padding[0] + padding[1]) + dilation[0] * (kH - 1) + 1 - outW = (iW - 1) * stride[1] - (padding[2] + padding[3]) + dilation[1] * (kW - 1) + 1 - - op_conv_transpose2d = ops.Conv2DTranspose( - out_channel=out_channels, - kernel_size=(kH, kW), - pad_mode="pad", - pad=padding, - stride=stride, - dilation=dilation, - group=groups, - ) - outputs = op_conv_transpose2d(input, weight.to(input.dtype), (batch_size, out_channels, outH, outW)).squeeze(2) - - if bias is not None: - assert isinstance(bias, ms.Tensor) and bias.ndim == 1 - bias = bias.reshape(1, -1, 1) - outputs += bias - - return outputs - - -def _pad(input, pad, mode="constant", value=0): - assert mode in ["constant", "replicate", "reflect"], "Unsupported padding mode" - - padding = [0, 0, 0, 0] - if isinstance(pad, tuple): - pad = list(pad) - padding[: len(pad)] = pad - - left, right, top, bottom = padding - batch_size, height, width = input.shape - - padded_height = height + top + bottom - padded_width = width + left + right - - output = ops.full((batch_size, padded_height, padded_width), value, dtype=input.dtype) - output[:, top : top + height, left : left + width] = input - - if mode == "replicate": - if top > 0: - output[:, :top, left : left + width] = input[:, 0:1, :].broadcast_to((batch_size, top, width)) - if bottom > 0: - output[:, top + height :, left : left + width] = input[:, -1:, :].broadcast_to((batch_size, bottom, width)) - if left > 0: - output[:, :, :left] = output[:, :, left : left + 1].broadcast_to((batch_size, padded_height, left)) - if right > 0: - output[:, :, left + width :] = output[:, :, left + width - 1 : left + width].broadcast_to( - (batch_size, padded_height, right) - ) - elif mode == "reflect": - if top > 0: - output[:, :top, left : left + width] = ( - input[:, 1 : top + 1, :].flip(dims=[1]).broadcast_to((batch_size, top, width)) - ) - if bottom > 0: - output[:, top + height :, left : left + width] = ( - input[:, -bottom - 1 : -1, :].flip(dims=[1]).broadcast_to((batch_size, bottom, width)) - ) - if left > 0: - output[:, :, :left] = ( - output[:, :, left + 1 : 2 * left + 1].flip(dims=[2]).broadcast_to((batch_size, padded_height, left)) - ) - if right > 0: - right_edge = max(0, left + width - right - 2) - output[:, :, left + width :] = output[:, :, left + width - 2 : right_edge : -1].broadcast_to( - (batch_size, padded_height, right) - ) - return output diff --git a/mindone/diffusers/models/upsampling.py b/mindone/diffusers/models/upsampling.py index e29026dd95..271ed5cef8 100644 --- a/mindone/diffusers/models/upsampling.py +++ b/mindone/diffusers/models/upsampling.py @@ -16,6 +16,7 @@ import mindspore as ms from mindspore import nn, ops +from .layers_compat import conv_transpose2d from .normalization import LayerNorm, RMSNorm @@ 
-292,7 +293,7 @@ def _upsample_2d( weight = ops.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4) weight = ops.reshape(weight, (num_groups * inC, -1, convH, convW)) - inverse_conv = _conv_transpose2d( + inverse_conv = conv_transpose2d( hidden_states, weight, stride=stride, @@ -353,19 +354,7 @@ def construct(self, inputs: ms.Tensor) -> ms.Tensor: indices = ops.arange(inputs.shape[1]) kernel = self.kernel.to(weight.dtype)[None, :].broadcast_to((inputs.shape[1], -1, -1)) weight[indices, indices] = kernel - return _conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1) - - -class SdeVpUpsample2D(nn.Cell): - def __init__(self, scale_factor=2.0, mode="nearest"): - super().__init__() - self.scale_factor = scale_factor - self.mode = mode - - def construct(self, x): - _, _, h, w = x.shape - x = ops.interpolate(x, size=(int(self.scale_factor * h), int(self.scale_factor * w)), mode=self.mode) - return x + return conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1) def upfirdn2d_native( @@ -462,49 +451,3 @@ def upsample_2d( pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), ) return output - - -def _conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): - # Equivalence of torch.nn.functional.conv_transpose2d - assert output_padding == 0, "Only support output_padding == 0 so far." - - if isinstance(stride, int): - stride = (stride, stride) - if isinstance(dilation, int): - dilation = (dilation, dilation) - if isinstance(padding, int): - padding = (padding, padding, padding, padding) - elif len(padding) == 2: - padding = ( - padding[0], - padding[0], - padding[1], - padding[1], - ) - - # InferShape manually - # Format adapted from https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html#torch.nn.ConvTranspose2d - batch_size, in_channels, iH, iW = input.shape - _, out_channels_divide_groups, kH, kW = weight.shape - - out_channels = out_channels_divide_groups * groups - outH = (iH - 1) * stride[0] - (padding[0] + padding[1]) + dilation[0] * (kH - 1) + 1 - outW = (iW - 1) * stride[1] - (padding[2] + padding[3]) + dilation[1] * (kW - 1) + 1 - - op_conv_transpose2d = ops.Conv2DTranspose( - out_channel=out_channels, - kernel_size=(kH, kW), - pad_mode="pad", - pad=padding, - stride=stride, - dilation=dilation, - group=groups, - ) - outputs = op_conv_transpose2d(input, weight.to(input.dtype), (batch_size, out_channels, outH, outW)) - - if bias is not None: - assert isinstance(bias, ms.Tensor) and bias.ndim == 1 - bias = bias.reshape(1, -1, 1, 1) - outputs += bias - - return outputs diff --git a/mindone/diffusers/training_utils.py b/mindone/diffusers/training_utils.py index 90ded4c20f..f6b792c097 100644 --- a/mindone/diffusers/training_utils.py +++ b/mindone/diffusers/training_utils.py @@ -13,7 +13,6 @@ import mindspore as ms from mindspore import context, nn, ops from mindspore.amp import DynamicLossScaler, StaticLossScaler, all_finite -from mindspore.common.api import _function_forbid_reuse from mindspore.communication import get_group_size, get_local_rank, get_rank, init from mindone.diffusers._peft import set_peft_model_state_dict @@ -267,45 +266,6 @@ def load_state_dict(self, state_dict: dict) -> None: raise ValueError("shadow_params must all be Tensors") -@_function_forbid_reuse -def multinomial(input, num_samples, replacement=True, **kwargs): - assert isinstance(input, ms.Tensor) and input.ndim in ( - 1, - 2, - ), "argument input should be a MindSpore Tensor with 1 or 2 dim." 
- assert ( - replacement or num_samples <= input.shape[-1] - ), "cannot sample n_sample > prob_dist.size(-1) samples without replacement." - - input = input.float() - input /= input.sum(-1, keepdims=True) - - if num_samples == 1 or not replacement: - # The algorithm is from gumbel softmax. - # s = argmax( logp - log(-log(eps)) ) where eps ~ U(0, 1) - # Here we can apply exp to the formula which will not affect result of - # argmax or topk. Then we have - # s = argmax( p / (-log(eps)) ) where eps ~ U(0, 1). - # We can also simplify the formula above by - # s = argmax( p / q ) where q ~ Exp(1) - # No proper Exp generator op in MindSpore, - # so we still generate it by -log(eps) - q = -ops.log(ops.rand_like(input)) - if num_samples == 1: - result = (input / q).argmax(-1, keepdim=True) - else: - _, result = ops.topk(input / q, k=num_samples, dim=-1) - else: - # To generate scalar random variable X with cumulative distribution F(x) - # just let X = F^(-1)(U) where U ~ U(0, 1) - input = input.cumsum(-1).expand_dims(-1) - rshape = (1, num_samples) if input.ndim == 2 else (input.shape[0], 1, num_samples) - rand = ops.rand(*rshape, dtype=input.dtype) - result = ops.ge(rand, input).long().sum(-2) - - return result.long() - - def is_master(args): return args.rank == 0
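
A minimal usage sketch of the compatibility layer introduced above, assuming only the public names exported in
layers_compat.__all__; the shapes and probabilities are illustrative, not taken from the patch. On MindSpore 2.3.0
and later these names resolve to the native operators, on earlier versions to the custom fallbacks defined in
layers_compat.py.

    # Sketch: exercise the version-gated operators; shapes/values are illustrative assumptions.
    import mindspore as ms
    from mindspore import ops

    from mindone.diffusers.models.layers_compat import group_norm, multinomial, pad

    x = ops.rand(2, 8, 16, 16)                      # NCHW feature map
    y = group_norm(x, 4, None, None, 1e-5)          # normalize over 4 channel groups
    y = pad(y, (1, 1, 1, 1), mode="reflect")        # pad the last two dims: (left, right, top, bottom)

    probs = ms.Tensor([0.1, 0.2, 0.3, 0.4], ms.float32)
    idx = multinomial(probs, 2, replacement=False)  # two distinct indices sampled from probs

Whichever branch is taken at import time, callers see the same call signature, which is what lets this patch delete
the per-file copies of _pad, _conv_transpose1d/2d, _group_norm, and multinomial in favor of a single import.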