refactor: refactored interfaces for compatibility
townwish4git committed Aug 16, 2024
1 parent 89aff8c commit 9b7dad0
Showing 7 changed files with 322 additions and 227 deletions.
@@ -38,6 +38,7 @@
from mindspore.dataset import GeneratorDataset, transforms, vision

from mindone.diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionXLPipeline, UNet2DConditionModel
from mindone.diffusers.models.layers_compat import multinomial
from mindone.diffusers.optimization import get_scheduler
from mindone.diffusers.training_utils import (
    AttrJitWrapper,
@@ -46,7 +47,6 @@
    init_distributed_device,
    is_master,
    maybe_compile,
    multinomial,
    set_seed,
)

292 changes: 292 additions & 0 deletions mindone/diffusers/models/layers_compat.py
@@ -0,0 +1,292 @@
# -*- coding: utf-8 -*-
"""Custom MindSpore Operators Suite
This module encapsulates custom implementations for a curated set of operators that are either unsupported or
introduced post specific MindSpore version. Recognizing the evolving nature of the framework, this suite ensures
compatibility across different MindSpore versions, particularly catering to scenarios where native support is
lacking across all versions, and require manual intervention for versions prior to specific one.
Key Features:
- **Conditional Implementations**:
Detects MindSpore's version at runtime to switch between native functions and custom equivalents.
- **Operator Coverage**:
[2024/07/26]
- **conv_transpose1d**: Always custom due to framework limitations.
- **conv_transpose2d**: Native post 2.3.0; custom for earlier versions.
- **group_norm**: Native post 2.3.0; custom for earlier versions.
- **multinomial**: Native post 2.3.0; custom for earlier versions.
- **pad**: Native post 2.3.0; custom for earlier versions.
Example:
Import this module and use the operators as you would with native MindSpore functions, with the assurance of cross-version compatibility.
>>> from mindone.diffusers.models.layers_compat import conv_transpose2d, interpolate
>>> # Depending on the MindSpore version, the correct implementation will be utilized.
Todo:
- Monitor MindSpore updates for potential native support inclusion.
- ...
"""

from packaging.version import parse

import mindspore as ms
from mindspore import ops
from mindspore.common.api import _function_forbid_reuse

__all__ = [
"conv_transpose1d",
"conv_transpose2d",
"group_norm",
"multinomial",
"pad",
]

MINDSPORE_VERSION = parse(ms.__version__)


# ================================================================================
# conv_transpose1d
# ================================================================================
def _conv_transpose1d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
    # Equivalence of torch.nn.functional.conv_transpose1d
    assert output_padding == 0, "Only support output_padding == 0 so far."

    if isinstance(stride, int):
        stride = (1, stride)
    elif isinstance(stride, tuple):
        stride = (1, stride[0])

    if isinstance(dilation, int):
        dilation = (dilation, dilation)
    elif isinstance(dilation, tuple):
        dilation = (dilation[0], dilation[0])

    if isinstance(padding, int):
        padding = (0, 0, padding, padding)
    elif isinstance(padding, tuple):
        padding = (0, 0, padding[0], padding[0])

    # InferShape manually
    # Format adapted from https://pytorch.org/docs/stable/generated/torch.nn.functional.conv_transpose1d.html
    input = input.unsqueeze(2)
    weight = weight.unsqueeze(2)
    batch_size, in_channels, iH, iW = input.shape
    _, out_channels_divide_groups, kH, kW = weight.shape

    out_channels = out_channels_divide_groups * groups
    outH = (iH - 1) * stride[0] - (padding[0] + padding[1]) + dilation[0] * (kH - 1) + 1
    outW = (iW - 1) * stride[1] - (padding[2] + padding[3]) + dilation[1] * (kW - 1) + 1

    op_conv_transpose2d = ops.Conv2DTranspose(
        out_channel=out_channels,
        kernel_size=(kH, kW),
        pad_mode="pad",
        pad=padding,
        stride=stride,
        dilation=dilation,
        group=groups,
    )
    outputs = op_conv_transpose2d(input, weight.to(input.dtype), (batch_size, out_channels, outH, outW)).squeeze(2)

    if bias is not None:
        assert isinstance(bias, ms.Tensor) and bias.ndim == 1
        bias = bias.reshape(1, -1, 1)
        outputs += bias

    return outputs


conv_transpose1d = _conv_transpose1d


# ================================================================================
# conv_transpose2d
# ================================================================================
def _conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
    # Equivalence of torch.nn.functional.conv_transpose2d
    assert output_padding == 0, "Only support output_padding == 0 so far."

    if isinstance(stride, int):
        stride = (stride, stride)
    if isinstance(dilation, int):
        dilation = (dilation, dilation)
    if isinstance(padding, int):
        padding = (padding, padding, padding, padding)
    elif len(padding) == 2:
        padding = (
            padding[0],
            padding[0],
            padding[1],
            padding[1],
        )

    # InferShape manually
    # Format adapted from https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html#torch.nn.ConvTranspose2d
    batch_size, in_channels, iH, iW = input.shape
    _, out_channels_divide_groups, kH, kW = weight.shape

    out_channels = out_channels_divide_groups * groups
    outH = (iH - 1) * stride[0] - (padding[0] + padding[1]) + dilation[0] * (kH - 1) + 1
    outW = (iW - 1) * stride[1] - (padding[2] + padding[3]) + dilation[1] * (kW - 1) + 1

    op_conv_transpose2d = ops.Conv2DTranspose(
        out_channel=out_channels,
        kernel_size=(kH, kW),
        pad_mode="pad",
        pad=padding,
        stride=stride,
        dilation=dilation,
        group=groups,
    )
    outputs = op_conv_transpose2d(input, weight.to(input.dtype), (batch_size, out_channels, outH, outW))

    if bias is not None:
        assert isinstance(bias, ms.Tensor) and bias.ndim == 1
        bias = bias.reshape(1, -1, 1, 1)
        outputs += bias

    return outputs


if MINDSPORE_VERSION >= parse("2.3.0"):
    conv_transpose2d = ms.mint.nn.functional.conv_transpose2d
else:
    conv_transpose2d = _conv_transpose2d


# ================================================================================
# group_norm
# ================================================================================
def _group_norm(x, num_groups, weight, bias, eps):
    x_shape = x.shape
    x = x.reshape(x_shape[0], num_groups, -1)
    var, mean = ops.var_mean(x, axis=-1, keepdims=True)
    x = (x - mean) / ops.sqrt(var + eps)
    x = x.reshape(x_shape)

    if weight is not None and bias is not None:
        expanded_shape = (1, -1) + (1,) * len(x_shape[2:])
        x = x * weight.reshape(expanded_shape) + bias.reshape(expanded_shape)

    return x


if MINDSPORE_VERSION >= parse("2.3.0"):
    group_norm = ms.mint.nn.functional.group_norm
else:
    group_norm = _group_norm


# ================================================================================
# multinomial
# ================================================================================
@_function_forbid_reuse
def _multinomial(input, num_samples, replacement=True, **kwargs):
    assert isinstance(input, ms.Tensor) and input.ndim in (
        1,
        2,
    ), "argument input should be a MindSpore Tensor with 1 or 2 dim."
    assert (
        replacement or num_samples <= input.shape[-1]
    ), "cannot sample n_sample > prob_dist.size(-1) samples without replacement."

    input = input.float()
    input /= input.sum(-1, keepdims=True)

    if num_samples == 1 or not replacement:
        # The algorithm is from gumbel softmax.
        # s = argmax( logp - log(-log(eps)) ) where eps ~ U(0, 1)
        # Here we can apply exp to the formula which will not affect result of
        # argmax or topk. Then we have
        # s = argmax( p / (-log(eps)) ) where eps ~ U(0, 1).
        # We can also simplify the formula above by
        # s = argmax( p / q ) where q ~ Exp(1)
        # No proper Exp generator op in MindSpore,
        # so we still generate it by -log(eps)
        q = -ops.log(ops.rand_like(input))
        if num_samples == 1:
            result = (input / q).argmax(-1, keepdim=True)
        else:
            _, result = ops.topk(input / q, k=num_samples, dim=-1)
    else:
        # To generate a scalar random variable X with cumulative distribution F(x),
        # just let X = F^(-1)(U) where U ~ U(0, 1)
        input = input.cumsum(-1).expand_dims(-1)
        rshape = (1, num_samples) if input.ndim == 2 else (input.shape[0], 1, num_samples)
        rand = ops.rand(*rshape, dtype=input.dtype)
        result = ops.ge(rand, input).long().sum(-2)

    return result.long()


if MINDSPORE_VERSION >= parse("2.3.0"):
    multinomial = ops.multinomial
else:
    multinomial = _multinomial


# ================================================================================
# pad
# ================================================================================
def _pad(input, pad, mode="constant", value=0):
    assert mode in ["constant", "replicate", "reflect"], "Unsupported padding mode"

    padding = [0, 0, 0, 0]
    if isinstance(pad, tuple):
        assert len(pad) <= 4, "Only support padding for the last 2 dimensions."
        pad = list(pad)
        padding[: len(pad)] = pad

    left, right, top, bottom = padding

    height, width = input.shape[-2:]
    other_dimensions = input.shape[:-2]
    input = input.reshape(-1, height, width)
    batch_size = input.shape[0]

    padded_height = height + top + bottom
    padded_width = width + left + right

    output = ops.full((batch_size, padded_height, padded_width), value, dtype=input.dtype)
    output[:, top : top + height, left : left + width] = input

    if mode == "replicate":
        if top > 0:
            output[:, :top, left : left + width] = input[:, 0:1, :].broadcast_to((batch_size, top, width))
        if bottom > 0:
            output[:, top + height :, left : left + width] = input[:, -1:, :].broadcast_to((batch_size, bottom, width))
        if left > 0:
            output[:, :, :left] = output[:, :, left : left + 1].broadcast_to((batch_size, padded_height, left))
        if right > 0:
            output[:, :, left + width :] = output[:, :, left + width - 1 : left + width].broadcast_to(
                (batch_size, padded_height, right)
            )
    elif mode == "reflect":
        if top > 0:
            output[:, :top, left : left + width] = (
                input[:, 1 : top + 1, :].flip(dims=[1]).broadcast_to((batch_size, top, width))
            )
        if bottom > 0:
            output[:, top + height :, left : left + width] = (
                input[:, -bottom - 1 : -1, :].flip(dims=[1]).broadcast_to((batch_size, bottom, width))
            )
        if left > 0:
            output[:, :, :left] = (
                output[:, :, left + 1 : 2 * left + 1].flip(dims=[2]).broadcast_to((batch_size, padded_height, left))
            )
        if right > 0:
            right_edge = max(0, left + width - right - 2)
            output[:, :, left + width :] = output[:, :, left + width - 2 : right_edge : -1].broadcast_to(
                (batch_size, padded_height, right)
            )

    target_shape = tuple(other_dimensions) + (padded_height, padded_width)
    output = output.reshape(*target_shape)
    return output


if MINDSPORE_VERSION >= parse("2.3.0"):
    pad = ms.mint.nn.functional.pad
else:
    pad = _pad
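
A minimal usage sketch of the new compatibility layer (illustrative shapes and values chosen here, assuming a working mindone installation; not code taken from the commit):

import mindspore as ms
from mindspore import ops

from mindone.diffusers.models.layers_compat import conv_transpose1d, multinomial, pad

# Sampling without replacement falls back to the Gumbel-trick implementation on MindSpore < 2.3.0.
probs = ms.Tensor([0.1, 0.2, 0.3, 0.4], ms.float32)
indices = multinomial(probs, 2, replacement=False)

# conv_transpose1d is always the custom implementation; the weight layout follows
# torch.nn.functional.conv_transpose1d: (in_channels, out_channels // groups, kW).
x = ops.ones((1, 3, 8), ms.float32)
w = ops.ones((3, 4, 3), ms.float32)
y = conv_transpose1d(x, w, stride=2)  # output width: (8 - 1) * 2 + (3 - 1) + 1 = 17

# pad takes the padding of the last two dimensions as (left, right, top, bottom).
img = ops.ones((1, 1, 4, 4), ms.float32)
padded = pad(img, (1, 1, 2, 2), mode="reflect")  # -> (1, 1, 8, 6)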
19 changes: 3 additions & 16 deletions mindone/diffusers/models/normalization.py
@@ -23,6 +23,7 @@

from .activations import SiLU, get_activation
from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
from .layers_compat import group_norm


class AdaLayerNorm(nn.Cell):
@@ -152,7 +153,7 @@ def construct(self, x: ms.Tensor, emb: ms.Tensor) -> ms.Tensor:
        emb = emb[:, :, None, None]
        scale, shift = emb.chunk(2, axis=1)

        x = _group_norm(x, self.num_groups, None, None, self.eps)
        x = group_norm(x, self.num_groups, None, None, self.eps)
        x = x * (1 + scale) + shift
        return x

@@ -362,7 +363,7 @@ def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5, affine
        self.bias = None

    def construct(self, x: Tensor):
        x = _group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        x = group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        return x


@@ -410,17 +411,3 @@ def construct(self, x):
        nx = gx / (gx.mean(axis=-1, keep_dims=True) + 1e-6)
        out = (self.gamma * (x * nx) + self.beta + x).to(x.dtype)
        return out


def _group_norm(x, num_groups, weight, bias, eps):
    x_shape = x.shape
    x = x.reshape(x_shape[0], num_groups, -1)
    var, mean = ops.var_mean(x, axis=-1, keepdims=True)
    x = (x - mean) / ops.sqrt(var + eps)
    x = x.reshape(x_shape)

    if weight is not None and bias is not None:
        expanded_shape = (1, -1) + (1,) * len(x_shape[2:])
        x = x * weight.reshape(expanded_shape) + bias.reshape(expanded_shape)

    return x
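
For reference, the shared helper can now be called directly from the compat module with the same signature as the removed private function (hypothetical shapes, a sketch rather than code from this commit):

import mindspore as ms
from mindspore import ops

from mindone.diffusers.models.layers_compat import group_norm

# Signature: group_norm(x, num_groups, weight, bias, eps); num_groups must divide the
# channel dimension, and weight/bias hold one value per channel.
x = ops.ones((2, 8, 4, 4), ms.float32)
weight = ops.ones(8, ms.float32)
bias = ops.zeros(8, ms.float32)
y = group_norm(x, 4, weight, bias, 1e-5)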
28 changes: 19 additions & 9 deletions mindone/diffusers/models/resnet.py
@@ -22,15 +22,7 @@
from .attention_processor import SpatialNorm
from .downsampling import Downsample1D, Downsample2D, FirDownsample2D, KDownsample2D, downsample_2d # noqa
from .normalization import AdaGroupNorm, GroupNorm
from .upsampling import ( # noqa
    FirUpsample2D,
    KUpsample2D,
    SdeVpUpsample2D,
    Upsample1D,
    Upsample2D,
    upfirdn2d_native,
    upsample_2d,
)
from .upsampling import FirUpsample2D, KUpsample2D, Upsample1D, Upsample2D, upfirdn2d_native, upsample_2d # noqa


class ResnetBlockCondNorm2D(nn.Cell):
@@ -176,6 +168,24 @@ def construct(self, input_tensor: ms.Tensor, temb: ms.Tensor) -> ms.Tensor:
        return output_tensor


class SdeVpUpsample2D(nn.Cell):
    """
    Equivalence of partial(F.interpolate, scale_factor=2.0, mode="nearest") used in ResnetBlock2D.__init__()
    when self.up and kernel == "sde_vp". We wrap ops.interpolate in our own implementation because the
    `scale_factor` argument cannot be used directly in certain modes and partial is not fully supported in graph mode.
    """

    def __init__(self, scale_factor=2.0, mode="nearest"):
        super().__init__()
        self.scale_factor = scale_factor
        self.mode = mode

    def construct(self, x):
        _, _, h, w = x.shape
        x = ops.interpolate(x, size=(int(self.scale_factor * h), int(self.scale_factor * w)), mode=self.mode)
        return x


class ResnetBlock2D(nn.Cell):
r"""
A Resnet block.
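
A short sketch of the new upsampling cell (illustrative input shape, not taken from the commit):

import mindspore as ms
from mindspore import ops

from mindone.diffusers.models.resnet import SdeVpUpsample2D

# Graph-mode-friendly replacement for partial(ops.interpolate, scale_factor=2.0, mode="nearest"):
# the target size is computed from the input shape instead of passing scale_factor.
up = SdeVpUpsample2D(scale_factor=2.0, mode="nearest")
x = ops.ones((1, 4, 8, 8), ms.float32)
y = up(x)
print(y.shape)  # (1, 4, 16, 16)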