diff --git a/mindone/diffusers/__init__.py b/mindone/diffusers/__init__.py index 59c21c15d6..772c1f3611 100644 --- a/mindone/diffusers/__init__.py +++ b/mindone/diffusers/__init__.py @@ -59,6 +59,7 @@ "DDIMPipeline", "DDPMPipeline", "DiffusionPipeline", + "DiTPipeline", "I2VGenXLPipeline", "IFImg2ImgPipeline", "IFImg2ImgSuperResolutionPipeline", @@ -87,6 +88,8 @@ "Kandinsky3Pipeline", "LatentConsistencyModelImg2ImgPipeline", "LatentConsistencyModelPipeline", + "LDMSuperResolutionPipeline", + "LDMTextToImagePipeline", "PixArtAlphaPipeline", "ShapEImg2ImgPipeline", "ShapEPipeline", @@ -99,6 +102,7 @@ "StableDiffusionControlNetInpaintPipeline", "StableDiffusionControlNetPipeline", "StableDiffusionDepth2ImgPipeline", + "StableDiffusionDiffEditPipeline", "StableDiffusionGLIGENPipeline", "StableDiffusionGLIGENTextImagePipeline", "StableDiffusionImageVariationPipeline", @@ -206,6 +210,7 @@ DDIMPipeline, DDPMPipeline, DiffusionPipeline, + DiTPipeline, I2VGenXLPipeline, IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, @@ -234,6 +239,8 @@ KandinskyV22PriorPipeline, LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline, + LDMSuperResolutionPipeline, + LDMTextToImagePipeline, PixArtAlphaPipeline, ShapEImg2ImgPipeline, ShapEPipeline, @@ -246,6 +253,7 @@ StableDiffusionControlNetInpaintPipeline, StableDiffusionControlNetPipeline, StableDiffusionDepth2ImgPipeline, + StableDiffusionDiffEditPipeline, StableDiffusionGLIGENPipeline, StableDiffusionGLIGENTextImagePipeline, StableDiffusionImageVariationPipeline, diff --git a/mindone/diffusers/pipelines/__init__.py b/mindone/diffusers/pipelines/__init__.py index 0a6d5fdc33..7be46685f9 100644 --- a/mindone/diffusers/pipelines/__init__.py +++ b/mindone/diffusers/pipelines/__init__.py @@ -30,7 +30,9 @@ "IFPipeline", "IFSuperResolutionPipeline", ], + "dit": ["DiTPipeline"], "i2vgen_xl": ["I2VGenXLPipeline"], + "latent_diffusion": ["LDMSuperResolutionPipeline", "LDMTextToImagePipeline"], "kandinsky": [ 
"KandinskyCombinedPipeline", "KandinskyImg2ImgCombinedPipeline", @@ -91,6 +93,7 @@ "StableDiffusionXLInstructPix2PixPipeline", "StableDiffusionXLPipeline", ], + "stable_diffusion_diffedit": ["StableDiffusionDiffEditPipeline"], "stable_video_diffusion": ["StableVideoDiffusionPipeline"], "t2i_adapter": [ "StableDiffusionAdapterPipeline", @@ -131,6 +134,7 @@ IFPipeline, IFSuperResolutionPipeline, ) + from .dit import DiTPipeline from .i2vgen_xl import I2VGenXLPipeline from .kandinsky import ( KandinskyCombinedPipeline, @@ -155,6 +159,7 @@ ) from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline + from .latent_diffusion import LDMSuperResolutionPipeline, LDMTextToImagePipeline from .pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .pixart_alpha import PixArtAlphaPipeline from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline @@ -171,6 +176,7 @@ StableDiffusionUpscalePipeline, ) from .stable_diffusion_3 import StableDiffusion3Pipeline + from .stable_diffusion_diffedit import StableDiffusionDiffEditPipeline from .stable_diffusion_gligen import StableDiffusionGLIGENPipeline, StableDiffusionGLIGENTextImagePipeline from .stable_diffusion_xl import ( StableDiffusionXLImg2ImgPipeline, diff --git a/mindone/diffusers/pipelines/dit/__init__.py b/mindone/diffusers/pipelines/dit/__init__.py new file mode 100644 index 0000000000..14a8a546f2 --- /dev/null +++ b/mindone/diffusers/pipelines/dit/__init__.py @@ -0,0 +1,18 @@ +from typing import TYPE_CHECKING + +from ...utils import _LazyModule + +_import_structure = {"pipeline_dit": ["DiTPipeline"]} + +if TYPE_CHECKING: + from .pipeline_dit import DiTPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/mindone/diffusers/pipelines/dit/pipeline_dit.py 
class DiTPipeline(DiffusionPipeline):
    r"""
    Pipeline for image generation based on a Transformer backbone instead of a UNet.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Parameters:
        transformer ([`Transformer2DModel`]):
            A class conditioned `Transformer2DModel` to denoise the encoded image latents.
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        scheduler ([`DDIMScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
        id2label (`Dict[int, str]`, *optional*):
            Mapping from class id to a comma-separated string of label names. Every individual name becomes a key
            of `self.labels`, mapped to the integer class id.
    """

    model_cpu_offload_seq = "transformer->vae"

    def __init__(
        self,
        transformer: Transformer2DModel,
        vae: AutoencoderKL,
        scheduler: KarrasDiffusionSchedulers,
        id2label: Optional[Dict[int, str]] = None,
    ):
        super().__init__()
        self.register_modules(transformer=transformer, vae=vae, scheduler=scheduler)

        # create a imagenet label -> id dictionary for easier use
        self.labels = {}
        if id2label is not None:
            for key, value in id2label.items():
                for label in value.split(","):
                    self.labels[label.strip()] = int(key)
            self.labels = dict(sorted(self.labels.items()))

    def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
        r"""

        Map label strings from ImageNet to corresponding class ids.

        Parameters:
            label (`str` or `list` of `str`):
                Label strings to be mapped to class ids. A single string is treated as one label.

        Returns:
            `list` of `int`:
                Class ids to be processed by pipeline.

        Raises:
            `ValueError`: If any label is not a key of `self.labels`.
        """

        if isinstance(label, str):
            # `list("white shark")` would split the label into characters, so wrap it instead.
            label = [label]
        elif not isinstance(label, list):
            label = list(label)

        for i in label:
            if i not in self.labels:
                raise ValueError(
                    f"{i} does not exist. Please make sure to select one of the following labels: \n {self.labels}."
                )

        return [self.labels[i] for i in label]

    def __call__(
        self,
        class_labels: List[int],
        guidance_scale: float = 4.0,
        generator: Optional[Union[np.random.Generator, List[np.random.Generator]]] = None,
        num_inference_steps: int = 50,
        output_type: Optional[str] = "pil",
        return_dict: bool = False,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        The call function to the pipeline for generation.

        Args:
            class_labels (List[int]):
                List of ImageNet class labels for the images to be generated.
            guidance_scale (`float`, *optional*, defaults to 4.0):
                A higher guidance scale value encourages the model to generate images closely linked to the
                requested class labels at the expense of lower image quality. Guidance scale is enabled when
                `guidance_scale > 1`.
            generator (`np.random.Generator`, *optional*):
                A [`np.random.Generator`](https://numpy.org/doc/stable/reference/random/generator.html) to make
                generation deterministic.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple.

        Examples:

        ```py
        >>> from mindone.diffusers import DiTPipeline, DPMSolverMultistepScheduler
        >>> import mindspore as ms

        >>> import numpy as np

        >>> pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", mindspore_dtype=ms.float16)
        >>> pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

        >>> # pick words from Imagenet class labels
        >>> pipe.labels  # to print all available words

        >>> # pick words that exist in ImageNet
        >>> words = ["white shark", "umbrella"]

        >>> class_ids = pipe.get_label_ids(words)

        >>> generator = np.random.default_rng(33)
        >>> output = pipe(class_labels=class_ids, num_inference_steps=25, generator=generator)

        >>> image = output[0][0]  # label 'white shark'
        ```

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated images
        """

        batch_size = len(class_labels)
        latent_size = self.transformer.config.sample_size
        latent_channels = self.transformer.config.in_channels

        latents = randn_tensor(
            shape=(batch_size, latent_channels, latent_size, latent_size),
            generator=generator,
            dtype=self.transformer.dtype,
        )
        # With guidance the batch is doubled: first half conditional, second half unconditional.
        latent_model_input = ops.cat([latents] * 2) if guidance_scale > 1 else latents

        class_labels = ms.Tensor(class_labels).reshape(-1)
        # Class id 1000 is used as the "null" class for the unconditional half of the batch.
        class_null = ms.Tensor([1000] * batch_size)
        class_labels_input = ops.cat([class_labels, class_null], 0) if guidance_scale > 1 else class_labels

        # set step values
        self.scheduler.set_timesteps(num_inference_steps)
        for t in self.progress_bar(self.scheduler.timesteps):
            if guidance_scale > 1:
                # Both halves must start each step from the same latents.
                half = latent_model_input[: len(latent_model_input) // 2]
                latent_model_input = ops.cat([half, half], axis=0)
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            timesteps = t
            if not ops.is_tensor(timesteps):
                # Python scalars are promoted to a 1-element tensor: float -> float64, otherwise int64.
                dtype = ms.float64 if isinstance(timesteps, float) else ms.int64
                timesteps = ms.Tensor([timesteps], dtype=dtype)
            elif len(timesteps.shape) == 0:
                timesteps = timesteps[None]
            # broadcast to batch dimension
            timesteps = timesteps.broadcast_to((latent_model_input.shape[0],))
            # predict noise model_output
            noise_pred = self.transformer(latent_model_input, timestep=timesteps, class_labels=class_labels_input)[0]

            # perform guidance
            if guidance_scale > 1:
                eps, rest = noise_pred[:, :latent_channels], noise_pred[:, latent_channels:]
                cond_eps, uncond_eps = ops.split(eps, len(eps) // 2, axis=0)

                half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
                eps = ops.cat([half_eps, half_eps], axis=0)

                noise_pred = ops.cat([eps, rest], axis=1)

            # learned sigma: when the transformer emits 2x channels, keep only the first `latent_channels`
            if self.transformer.config.out_channels // 2 == latent_channels:
                model_output, _ = ops.split(noise_pred, latent_channels, axis=1)
            else:
                model_output = noise_pred

            # compute previous image: x_t -> x_t-1
            latent_model_input = self.scheduler.step(model_output, t, latent_model_input)[0]

        if guidance_scale > 1:
            latents, _ = latent_model_input.chunk(2, axis=0)
        else:
            latents = latent_model_input

        latents = 1 / self.vae.config.scaling_factor * latents
        samples = self.vae.decode(latents)[0]

        samples = (samples / 2 + 0.5).clamp(0, 1)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        samples = samples.permute(0, 2, 3, 1).float().asnumpy()

        if output_type == "pil":
            samples = self.numpy_to_pil(samples)

        if not return_dict:
            return (samples,)

        return ImagePipelineOutput(images=samples)
# Smallest / largest finite value of each supported floating-point dtype, pre-built as tensors.
_MIN_FP16 = ms.tensor(np.finfo(np.float16).min, dtype=ms.float16)
_MIN_FP32 = ms.tensor(np.finfo(np.float32).min, dtype=ms.float32)
_MIN_FP64 = ms.tensor(np.finfo(np.float64).min, dtype=ms.float64)
_MIN_BF16 = ms.tensor(float.fromhex("-0x1.fe00000000000p+127"), dtype=ms.bfloat16)

_MAX_FP16 = ms.tensor(np.finfo(np.float16).max, dtype=ms.float16)
_MAX_FP32 = ms.tensor(np.finfo(np.float32).max, dtype=ms.float32)
_MAX_FP64 = ms.tensor(np.finfo(np.float64).max, dtype=ms.float64)
_MAX_BF16 = ms.tensor(float.fromhex("0x1.fe00000000000p+127"), dtype=ms.bfloat16)


def dtype_to_min(dtype):
    """Return the pre-built tensor holding the minimum finite value of `dtype`.

    Raises:
        ValueError: If `dtype` is not one of float16, float32, float64 or bfloat16.
    """
    if dtype == ms.float16:
        return _MIN_FP16
    if dtype == ms.float32:
        return _MIN_FP32
    if dtype == ms.float64:
        return _MIN_FP64
    if dtype == ms.bfloat16:
        return _MIN_BF16
    raise ValueError(
        f"Only support get minimum value of (float16, float32, float64, bfloat16), but got {dtype}"
    )


def dtype_to_max(dtype):
    """Return the pre-built tensor holding the maximum finite value of `dtype`.

    Raises:
        ValueError: If `dtype` is not one of float16, float32, float64 or bfloat16.
    """
    if dtype == ms.float16:
        return _MAX_FP16
    if dtype == ms.float32:
        return _MAX_FP32
    if dtype == ms.float64:
        return _MAX_FP64
    if dtype == ms.bfloat16:
        return _MAX_BF16
    raise ValueError(
        f"Only support get maximum value of (float16, float32, float64, bfloat16), but got {dtype}"
    )


class LDMTextToImagePipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using latent diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Parameters:
        vqvae ([`VQModel`]):
            Vector-quantized (VQ) model to encode and decode images to and from latent representations.
        bert ([`LDMBertModel`]):
            Text-encoder model based on [`~transformers.BERT`].
        tokenizer ([`~transformers.BertTokenizer`]):
            A `BertTokenizer` to tokenize text.
        unet ([`UNet2DConditionModel`]):
            A `UNet2DConditionModel` to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
    """

    model_cpu_offload_seq = "bert->unet->vqvae"

    def __init__(
        self,
        vqvae: Union[VQModel, AutoencoderKL],
        bert: MSPreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        unet: Union[UNet2DModel, UNet2DConditionModel],
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
    ):
        super().__init__()
        self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
        self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)

    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 1.0,
        eta: Optional[float] = 0.0,
        generator: Optional[Union[np.random.Generator, List[np.random.Generator]]] = None,
        latents: Optional[ms.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = False,
        **kwargs,
    ) -> Union[Tuple, ImagePipelineOutput]:
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image. Must be divisible by 8.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image. Must be divisible by 8.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 1.0):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance is applied whenever
                `guidance_scale != 1.0`.
            eta (`float`, *optional*, defaults to 0.0):
                Forwarded to `scheduler.step` only when that method's signature has an `eta` parameter; ignored
                otherwise.
            generator (`np.random.Generator`, *optional*):
                A [`np.random.Generator`](https://numpy.org/doc/stable/reference/random/generator.html) to make
                generation deterministic.
            latents (`ms.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple.

        Example:

        ```py
        >>> from mindone.diffusers import DiffusionPipeline

        >>> # load model and scheduler
        >>> ldm = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")

        >>> # run pipeline in inference (sample random noise and denoise)
        >>> prompt = "A painting of a squirrel eating a burger"
        >>> images = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6)[0]

        >>> # save images
        >>> for idx, image in enumerate(images):
        ...     image.save(f"squirrel-{idx}.png")
        ```

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated images.

        Raises:
            `ValueError`: If `prompt` is neither `str` nor `list`, if `height`/`width` are not divisible by 8, if
                a generator list does not match the batch size, or if `latents` has an unexpected shape.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # get unconditional embeddings for classifier free guidance
        if guidance_scale != 1.0:
            uncond_input = self.tokenizer(
                [""] * batch_size, padding="max_length", max_length=77, truncation=True, return_tensors="np"
            )
            uncond_input_ids = ms.Tensor(uncond_input.input_ids)
            negative_prompt_embeds = self.bert(uncond_input_ids)[0]

        # get prompt text embeddings
        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="np")
        text_input_ids = ms.Tensor(text_input.input_ids)
        prompt_embeds = self.bert(text_input_ids)[0]

        # get the initial random noise unless the user supplied it
        latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(latents_shape, generator=generator, dtype=prompt_embeds.dtype)
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")

        self.scheduler.set_timesteps(num_inference_steps)

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())

        extra_kwargs = {}
        if accepts_eta:
            extra_kwargs["eta"] = eta

        for t in self.progress_bar(self.scheduler.timesteps):
            if guidance_scale == 1.0:
                # guidance_scale of 1 means no guidance
                latents_input = latents
                context = prompt_embeds
            else:
                # For classifier free guidance, we need to do two forward passes.
                # Here we concatenate the unconditional and text embeddings into a single batch
                # to avoid doing two forward passes
                latents_input = ops.cat([latents] * 2)
                context = ops.cat([negative_prompt_embeds, prompt_embeds])

            # predict the noise residual
            noise_pred = self.unet(latents_input, t, encoder_hidden_states=context)[0]
            # perform guidance
            if guidance_scale != 1.0:
                noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs)[0]

        # scale and decode the image latents with vae
        latents = 1 / self.vqvae.config.scaling_factor * latents
        image = self.vqvae.decode(latents)[0]

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.permute(0, 2, 3, 1).asnumpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
model +################################################################################ +""" MindSpore LDMBERT model.""" + + +logger = logging.get_logger(__name__) + +LDMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "ldm-bert", + # See all LDMBert models at https://huggingface.co/models?filter=ldmbert +] + + +LDMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "ldm-bert": "https://huggingface.co/valhalla/ldm-bert/blob/main/config.json", +} + + +""" LDMBERT model configuration""" + + +class LDMBertConfig(PretrainedConfig): + model_type = "ldmbert" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"} + + def __init__( + self, + vocab_size=30522, + max_position_embeddings=77, + encoder_layers=32, + encoder_ffn_dim=5120, + encoder_attention_heads=8, + head_dim=64, + encoder_layerdrop=0.0, + activation_function="gelu", + d_model=1280, + dropout=0.1, + attention_dropout=0.0, + activation_dropout=0.0, + init_std=0.02, + classifier_dropout=0.0, + scale_embedding=False, + use_cache=True, + pad_token_id=0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.d_model = d_model + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.head_dim = head_dim + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.init_std = init_std + self.encoder_layerdrop = encoder_layerdrop + self.classifier_dropout = classifier_dropout + self.use_cache = use_cache + self.num_hidden_layers = encoder_layers + self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True + + super().__init__(pad_token_id=pad_token_id, **kwargs) + + +def _expand_mask(mask: ms.Tensor, dtype: ms.dtype, tgt_len: Optional[int] = None): + """ + Expands 
attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.shape + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(ms.bool_), dtype_to_min(dtype)) + + +# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->LDMBert +class LDMBertAttention(nn.Cell): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + head_dim: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = False, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = head_dim + self.inner_dim = head_dim * num_heads + + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Dense(embed_dim, self.inner_dim, has_bias=bias) + self.v_proj = nn.Dense(embed_dim, self.inner_dim, has_bias=bias) + self.q_proj = nn.Dense(embed_dim, self.inner_dim, has_bias=bias) + self.out_proj = nn.Dense(self.inner_dim, embed_dim) + + def _shape(self, tensor: ms.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).swapaxes(1, 2) + + def construct( + self, + hidden_states: ms.Tensor, + key_value_states: Optional[ms.Tensor] = None, + past_key_value: Optional[Tuple[ms.Tensor]] = None, + attention_mask: Optional[ms.Tensor] = None, + layer_head_mask: Optional[ms.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[ms.Tensor, Optional[ms.Tensor], Optional[Tuple[ms.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.shape + + # get query proj 
+ query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = ops.cat([past_key_value[0], key_states], axis=2) + value_states = ops.cat([past_key_value[1], value_states], axis=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(ms.Tensor, ms.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(ms.Tensor, ms.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.shape[1] + attn_weights = ops.bmm(query_states, key_states.swapaxes(1, 2)) + + if attn_weights.shape != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.shape}" + ) + + if attention_mask is not None: + if attention_mask.shape != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.shape}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = ops.softmax(attn_weights, axis=-1) + + if layer_head_mask is not None: + if layer_head_mask.shape != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.shape}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. 
class LDMBertEncoderLayer(nn.Cell):
    """
    Single pre-LayerNorm transformer encoder layer: self-attention followed by a two-layer feed-forward
    network, each wrapped in a residual connection.
    """

    def __init__(self, config: LDMBertConfig):
        super().__init__()
        self.embed_dim = config.d_model
        self.self_attn = LDMBertAttention(
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            head_dim=config.head_dim,
            dropout=config.attention_dropout,
        )
        self.self_attn_layer_norm = nn.LayerNorm((self.embed_dim,))
        # Dropout probability applied after the attention block and after the second feed-forward projection.
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        # Dropout probability applied between the two feed-forward projections.
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Dense(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Dense(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm((self.embed_dim,))

    def construct(
        self,
        hidden_states: ms.Tensor,
        attention_mask: Optional[ms.Tensor],
        layer_head_mask: Optional[ms.Tensor],
        output_attentions: Optional[bool] = False,
    ) -> Tuple[ms.Tensor, Optional[ms.Tensor]]:
        """
        Run the layer on `hidden_states`.

        Args:
            hidden_states (`ms.Tensor`): input to the layer. NOTE(review): the attention output is reshaped to
                `(batch, tgt_len, inner_dim)` before being added back to the input, so this is presumably
                batch-first `(batch, seq_len, embed_dim)` -- confirm against `LDMBertAttention`.
            attention_mask (`ms.Tensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`ms.Tensor`, *optional*): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.

        Returns:
            `tuple`: `(hidden_states,)`, with the attention weights appended as a second element when
            `output_attentions` is truthy.
        """
        # Self-attention sub-block: normalize first, then attend, then add the residual.
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states, attn_weights, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = ops.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        # Feed-forward sub-block: normalize, expand with fc1 + activation, project back with fc2, add the residual.
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = ops.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = ops.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        # In float16 only: if any element overflowed to inf (or is nan), clamp the activations into a symmetric
        # range. NOTE(review): `dtype_to_max` presumably returns the largest finite value of the dtype -- confirm.
        if hidden_states.dtype == ms.float16 and (ops.isinf(hidden_states).any() or ops.isnan(hidden_states).any()):
            clamp_value = dtype_to_max(hidden_states.dtype) - 1000
            hidden_states = ops.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs
class LDMBertEncoder(LDMBertPreTrainedModel):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`LDMBertEncoderLayer`].

    Token embeddings and learned position embeddings are summed, passed through the layer stack and
    finished with a final layer norm.

    Args:
        config: LDMBertConfig
    """

    def __init__(self, config: LDMBertConfig):
        super().__init__(config)

        self.dropout = config.dropout

        embed_dim = config.d_model
        self.padding_idx = config.pad_token_id
        self.max_source_positions = config.max_position_embeddings

        self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim)
        self.embed_positions = nn.Embedding(config.max_position_embeddings, embed_dim)
        self.layers = nn.CellList([LDMBertEncoderLayer(config) for _ in range(config.encoder_layers)])
        self.layer_norm = nn.LayerNorm((embed_dim,))

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

        # Cache the config defaults as plain attributes so `construct` reads them from `self`
        # instead of going through the config object.
        self.output_attentions = self.config.output_attentions
        self.output_hidden_states = self.config.output_hidden_states
        self.use_return_dict = self.config.use_return_dict

    def get_input_embeddings(self):
        """Return the token embedding layer."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding layer with `value`."""
        self.embed_tokens = value

    def construct(
        self,
        input_ids: Optional[ms.Tensor] = None,
        attention_mask: Optional[ms.Tensor] = None,
        position_ids: Optional[ms.Tensor] = None,
        head_mask: Optional[ms.Tensor] = None,
        inputs_embeds: Optional[ms.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = False,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            input_ids (`ms.Tensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            position_ids (`ms.Tensor`, *optional*):
                Indices into the position embedding table. When omitted, positions `0 .. sequence_length - 1`
                are used and broadcast over the batch.
            head_mask (`ms.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            inputs_embeds (`ms.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~utils.BaseModelOutput`] instead of a plain tuple. Passing `None`
                falls back to the value stored from the config.

        Returns:
            [`~utils.BaseModelOutput`] or `tuple`: with `return_dict` falsy, a tuple of the non-`None` entries among
            `(last_hidden_state, hidden_states, attentions)`.

        Raises:
            ValueError: if both or neither of `input_ids` / `inputs_embeds` are given, or if `head_mask` does not
                have one row per encoder layer.
            NotImplementedError: if gradient checkpointing is enabled while training.
        """
        output_attentions = output_attentions if output_attentions is not None else self.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.shape
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.shape[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        seq_len = input_shape[1]
        if position_ids is None:
            # strict graph mode do not support broadcast_to((1, -1))
            position_ids = ops.arange(seq_len, dtype=ms.int32).broadcast_to((1, seq_len))
        embed_pos = self.embed_positions(position_ids)

        hidden_states = inputs_embeds + embed_pos
        hidden_states = ops.dropout(hidden_states, p=self.dropout, training=self.training)

        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            if head_mask.shape[0] != (len(self.layers)):
                raise ValueError(
                    f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
                    f" {head_mask.shape[0]}."
                )

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                # Recomputation is not wired up in this port; fail loudly rather than silently
                # training without it.
                raise NotImplementedError("Gradient checkpointing is not yet supported.")
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        hidden_states = self.layer_norm(hidden_states)

        # The final entry is the post-layer-norm output, so `hidden_states` has `encoder_layers + 1` items.
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions)
class LDMSuperResolutionPipeline(DiffusionPipeline):
    r"""
    A pipeline for image super-resolution using latent diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Parameters:
        vqvae ([`VQModel`]):
            Vector-quantized (VQ) model to encode and decode images to and from latent representations.
        unet ([`UNet2DModel`]):
            A `UNet2DModel` to denoise the encoded image.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`],
            [`EulerAncestralDiscreteScheduler`], [`DPMSolverMultistepScheduler`], or [`PNDMScheduler`].
    """

    def __init__(
        self,
        vqvae: VQModel,
        unet: UNet2DModel,
        scheduler: Union[
            DDIMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
        ],
    ):
        super().__init__()
        self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)

    def __call__(
        self,
        image: Union[ms.Tensor, PIL.Image.Image] = None,
        batch_size: Optional[int] = 1,
        num_inference_steps: Optional[int] = 100,
        eta: Optional[float] = 0.0,
        generator: Optional[Union[np.random.Generator, List[np.random.Generator]]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = False,
    ) -> Union[Tuple, ImagePipelineOutput]:
        r"""
        The call function to the pipeline for generation.

        Args:
            image (`ms.Tensor` or `PIL.Image.Image`):
                `Image` or tensor representing an image batch to be used as the starting point for the process.
                A PIL image is preprocessed internally; a tensor is used as given.
            batch_size (`int`, *optional*, defaults to 1):
                Kept for signature compatibility. The value is always overwritten: 1 for a PIL image, or the
                leading dimension of a tensor `image`.
            num_inference_steps (`int`, *optional*, defaults to 100):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only
                forwarded when the scheduler's `step` accepts an `eta` argument.
            generator (`np.random.Generator` or `List[np.random.Generator]`, *optional*):
                A [`np.random.Generator`](https://numpy.org/doc/stable/reference/random/generator.html) to make
                generation deterministic.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. `"pil"` converts to PIL images; any other value leaves
                the result as a numpy array.
            return_dict (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple.

        Example:

        ```py
        >>> import requests
        >>> from PIL import Image
        >>> from io import BytesIO
        >>> from mindone.diffusers import LDMSuperResolutionPipeline
        >>> import mindspore as ms

        >>> # load model and scheduler
        >>> pipeline = LDMSuperResolutionPipeline.from_pretrained("CompVis/ldm-super-resolution-4x-openimages")

        >>> # let's download an image
        >>> url = (
        ...     "https://user-images.githubusercontent.com/38061659/199705896-b48e17b8-b231-47cd-a270-4ffa5a93fa3e.png"
        ... )
        >>> response = requests.get(url)
        >>> low_res_img = Image.open(BytesIO(response.content)).convert("RGB")
        >>> low_res_img = low_res_img.resize((128, 128))

        >>> # run pipeline in inference (sample random noise and denoise)
        >>> upscaled_image = pipeline(low_res_img, num_inference_steps=100, eta=1)[0][0]
        >>> # save image
        >>> upscaled_image.save("ldm_generated_image.png")
        ```

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated images
        """
        # Normalize the conditioning image and derive the batch size from it.
        if isinstance(image, PIL.Image.Image):
            batch_size = 1
            image = preprocess(image)
        elif isinstance(image, ms.Tensor):
            batch_size = image.shape[0]
        else:
            raise ValueError(f"`image` has to be of type `PIL.Image.Image` or `ms.Tensor` but is {type(image)}")

        height, width = image.shape[-2:]

        # in_channels should be 6: 3 for latents, 3 for low resolution image
        noise_shape = (batch_size, self.unet.config.in_channels // 2, height, width)
        unet_dtype = next(self.unet.get_parameters()).dtype

        latents = randn_tensor(noise_shape, generator=generator, dtype=unet_dtype)
        low_res = image.to(dtype=unet_dtype)

        self.scheduler.set_timesteps(num_inference_steps)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma

        # Not every scheduler's `step` takes `eta` (η from the DDIM paper, https://arxiv.org/abs/2010.02502,
        # expected in [0, 1]); only pass it along when the signature has it.
        step_params = inspect.signature(self.scheduler.step).parameters
        step_kwargs = {"eta": eta} if "eta" in step_params else {}

        for t in self.progress_bar(self.scheduler.timesteps):
            # condition by stacking the low resolution image onto the latents along channels
            unet_input = self.scheduler.scale_model_input(ops.cat([latents, low_res], axis=1), t)
            # predict the noise residual
            noise_pred = self.unet(unet_input, t)[0]
            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **step_kwargs)[0]

        # decode with the VQVAE, then map from [-1, 1] to [0, 1] in channels-last numpy layout
        decoded = ops.clamp(self.vqvae.decode(latents)[0], -1.0, 1.0)
        images = (decoded / 2 + 0.5).permute(0, 2, 3, 1).asnumpy()

        if output_type == "pil":
            images = self.numpy_to_pil(images)

        return ImagePipelineOutput(images=images) if return_dict else (images,)
TYPE_CHECKING: + from .pipeline_stable_diffusion_diffedit import StableDiffusionDiffEditPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/mindone/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py b/mindone/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py new file mode 100644 index 0000000000..761f26603c --- /dev/null +++ b/mindone/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py @@ -0,0 +1,1446 @@ +# Copyright 2024 DiffEdit Authors and Pix2Pix Zero Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# Usage example for the pipeline. The snippet consistently uses the name `pipe` and imports every
# scheduler class it references so that it can be pasted and run as-is.
EXAMPLE_DOC_STRING = """

        ```py
        >>> import PIL
        >>> import requests
        >>> import mindspore as ms
        >>> from io import BytesIO

        >>> from mindone.diffusers import DDIMScheduler, StableDiffusionDiffEditPipeline
        >>> from mindone.diffusers.schedulers import DDIMInverseScheduler


        >>> def download_image(url):
        ...     response = requests.get(url)
        ...     return PIL.Image.open(BytesIO(response.content)).convert("RGB")


        >>> img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png"

        >>> init_image = download_image(img_url).resize((768, 768))

        >>> pipe = StableDiffusionDiffEditPipeline.from_pretrained(
        ...     "stabilityai/stable-diffusion-2-1", mindspore_dtype=ms.float16
        ... )

        >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        >>> pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)

        >>> mask_prompt = "A bowl of fruits"
        >>> prompt = "A bowl of pears"

        >>> mask_image = pipe.generate_mask(image=init_image, source_prompt=prompt, target_prompt=mask_prompt)
        >>> image_latents = pipe.invert(image=init_image, prompt=mask_prompt)[0]
        >>> image = pipe(prompt=prompt, mask_image=mask_image, image_latents=image_latents)[0][0]
        ```
"""
def auto_corr_loss(hidden_states, generator=None):
    """
    Accumulate a squared auto-correlation penalty over a multi-scale pyramid of `hidden_states`.

    For every `(batch, channel)` slice, the slice is multiplied with a copy of itself rolled by a random
    amount along axis 2 and along axis 3; the squared means of both products are added to the loss. The
    slice is then 2x2 average-pooled and the process repeats until axis 2 has size 8 or less.

    Args:
        hidden_states: 4-D tensor; the first two axes are iterated, the last two are rolled and pooled.
        generator (`np.random.Generator`, *optional*):
            Source of the random roll amounts. A fresh unseeded NumPy generator is used when omitted.

    Returns:
        The accumulated loss; the float `0.0` when there are no slices to visit.
    """
    # MindSpore's `ops.randint` takes `(low, high, size, seed)` and has no `generator` keyword, so the
    # roll amount is drawn from a NumPy generator instead.
    rng = generator if generator is not None else np.random.default_rng()

    reg_loss = 0.0
    for i in range(hidden_states.shape[0]):
        for j in range(hidden_states.shape[1]):
            noise = hidden_states[i : i + 1, j : j + 1, :, :]
            while True:
                # Uniform in [0, size // 2); a size-1 axis leaves nothing to shift by.
                max_shift = noise.shape[2] // 2
                roll_amount = int(rng.integers(max_shift)) if max_shift > 0 else 0
                reg_loss += (noise * ops.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2
                reg_loss += (noise * ops.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2

                if noise.shape[2] <= 8:
                    break
                noise = ops.avg_pool2d(noise, kernel_size=2, stride=2)
    return reg_loss
def preprocess_mask(mask, batch_size: int = 1):
    """
    Convert a mask (or list of masks) into a binarized 4-D tensor with a single channel.

    Args:
        mask: an `ms.Tensor`, a `PIL.Image.Image`, an `np.ndarray`, or a list of one of those types.
            PIL images are converted to grayscale and scaled by 1/255.
        batch_size (`int`, defaults to 1):
            When greater than 1, a single mask is repeated along the first axis to this size.

    Returns:
        `ms.Tensor`: mask with 0 where the input was below 0.5 and 1 elsewhere, in the input's dtype.
        The caller's tensor is never modified.

    Raises:
        ValueError: if the mask batch cannot be matched to `batch_size`, if axis 1 is not of size 1, or if
            values fall outside `[0, 1]`.
    """
    if not isinstance(mask, ms.Tensor):
        # preprocess mask
        if isinstance(mask, (PIL.Image.Image, np.ndarray)):
            mask = [mask]

        if isinstance(mask, list):
            if isinstance(mask[0], PIL.Image.Image):
                mask = [np.array(m.convert("L")).astype(np.float32) / 255.0 for m in mask]
            # Entries with fewer than 3 dims are stacked into a new leading axis; otherwise they are
            # joined along the existing one.
            if isinstance(mask[0], np.ndarray):
                mask = np.stack(mask, axis=0) if mask[0].ndim < 3 else np.concatenate(mask, axis=0)
                mask = ms.Tensor(mask)
            elif isinstance(mask[0], ms.Tensor):
                mask = ops.stack(mask, axis=0) if mask[0].ndim < 3 else ops.cat(mask, axis=0)

    # Batch and add channel dim for single mask
    if mask.ndim == 2:
        mask = mask.unsqueeze(0).unsqueeze(0)

    # Batch single mask or add channel dim
    if mask.ndim == 3:
        # Single batched mask, no channel dim or single mask not batched but channel dim
        if mask.shape[0] == 1:
            mask = mask.unsqueeze(0)

        # Batched masks no channel dim
        else:
            mask = mask.unsqueeze(1)

    # Check mask shape
    if batch_size > 1:
        if mask.shape[0] == 1:
            mask = ops.cat([mask] * batch_size)
        elif mask.shape[0] > 1 and mask.shape[0] != batch_size:
            raise ValueError(
                f"`mask_image` with batch size {mask.shape[0]} cannot be broadcasted to batch size {batch_size} "
                f"inferred by prompt inputs"
            )

    if mask.shape[1] != 1:
        raise ValueError(f"`mask_image` must have 1 channel, but has {mask.shape[1]} channels")

    # Check mask is in [0, 1]
    if mask.min() < 0 or mask.max() > 1:
        raise ValueError("`mask_image` should be in [0, 1] range")

    # Binarize into a new tensor. Writing through `mask[...] = ...` would modify the caller's tensor
    # whenever an already 4-D `ms.Tensor` was passed in, because no copy is made on that path.
    mask = (mask >= 0.5).astype(mask.dtype)

    return mask
+ safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "inverse_scheduler"] + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + inverse_scheduler: DDIMInverseScheduler, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration" + " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make" + " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to" + " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face" + " Hub, it would be very nice if you could open a Pull request for the" + " `scheduler/scheduler_config.json` file" + ) + deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["skip_prk_steps"] = True + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." 
+ ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. 
If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + inverse_scheduler=inverse_scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[ms.Tensor] = None, + negative_prompt_embeds: Optional[ms.Tensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = ( + "`_encode_prompt()` is deprecated and it will be removed in a future version. Use" + "`encode_prompt()` instead. Also, be aware that the output format changed from a" + "concatenated tensor to a tuple." 
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
def encode_prompt(
    self,
    prompt,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[ms.Tensor] = None,
    negative_prompt_embeds: Optional[ms.Tensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        prompt_embeds (`ms.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`ms.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        lora_scale (`float`, *optional*):
            A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.

    Returns:
        `tuple`: `(prompt_embeds, negative_prompt_embeds)`, in that order. `negative_prompt_embeds` is returned
        as passed in (possibly `None`) when `do_classifier_free_guidance` is `False`.

    Raises:
        TypeError: If `negative_prompt` and `prompt` are of different types.
        ValueError: If `negative_prompt` is a list whose length differs from the prompt batch size.
    """
    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, LoraLoaderMixin):
        self._lora_scale = lora_scale

        scale_lora_layers(self.text_encoder, lora_scale)

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="np",
        )
        text_input_ids_np = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="np").input_ids

        # Detect truncation on the NumPy ids returned by the tokenizer. `ops.equal` is an
        # element-wise comparison that yields a tensor (and needs broadcastable shapes), so it
        # cannot answer "are these two id arrays identical?"; `np.array_equal` returns a single
        # bool and simply reports False when the shapes differ.
        if untruncated_ids.shape[-1] >= text_input_ids_np.shape[-1] and not np.array_equal(
            text_input_ids_np, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        text_input_ids = ms.Tensor(text_input_ids_np)

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            # Wrap in a tensor, matching how the negative-prompt mask is built further down.
            attention_mask = ms.Tensor(text_inputs.attention_mask)
        else:
            attention_mask = None

        if clip_skip is None:
            prompt_embeds = self.text_encoder(text_input_ids, attention_mask=attention_mask)
            prompt_embeds = prompt_embeds[0]
        else:
            prompt_embeds = self.text_encoder(
                text_input_ids, attention_mask=attention_mask, output_hidden_states=True
            )
            # Access the `hidden_states` first, that contains a tuple of
            # all the hidden states from the encoder layers. Then index into
            # the tuple to access the hidden states from the desired layer.
            prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
            # We also need to apply the final LayerNorm here to not mess with the
            # representations. The `last_hidden_states` that we typically use for
            # obtaining the final prompt representations passes through the LayerNorm
            # layer.
            prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

    # Pick the dtype the embeddings are cast to: text encoder first, then UNet, else keep as is.
    if self.text_encoder is not None:
        prompt_embeds_dtype = self.text_encoder.dtype
    elif self.unet is not None:
        prompt_embeds_dtype = self.unet.dtype
    else:
        prompt_embeds_dtype = prompt_embeds.dtype

    prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt
    prompt_embeds = prompt_embeds.tile((1, num_images_per_prompt, 1))
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

        # Pad the unconditional tokens to the same sequence length as the prompt embeddings.
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="np",
        )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = ms.Tensor(uncond_input.attention_mask)
        else:
            attention_mask = None

        negative_prompt_embeds = self.text_encoder(
            ms.Tensor(uncond_input.input_ids),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype)

        negative_prompt_embeds = negative_prompt_embeds.tile((1, num_images_per_prompt, 1))
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    if isinstance(self, LoraLoaderMixin):
        # Retrieve the original scale by scaling back the LoRA layers
        unscale_lora_layers(self.text_encoder, lora_scale)

    return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
    """Collect the optional keyword arguments that `self.scheduler.step` accepts.

    Scheduler `step` signatures differ, so only the arguments that actually
    appear in the signature are forwarded.

    Args:
        generator: Value forwarded as `generator` when `step` has such a parameter.
        eta: Value forwarded as `eta` when `step` has such a parameter. eta (η) corresponds
            to η in the DDIM paper (https://arxiv.org/abs/2010.02502) and should be
            between [0, 1].

    Returns:
        `dict`: Contains `"eta"` and/or `"generator"` (in that order), or is empty.
    """
    step_parameters = inspect.signature(self.scheduler.step).parameters
    optional_arguments = (("eta", eta), ("generator", generator))
    return {name: value for name, value in optional_arguments if name in step_parameters}
def check_inputs(
    self,
    prompt,
    strength,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """Validate the user-facing arguments shared by the pipeline entry points.

    Args:
        prompt: A `str` or `list` prompt; mutually exclusive with `prompt_embeds`.
        strength: Number that must lie in `[0.0, 1.0]`.
        callback_steps: Positive `int`.
        negative_prompt: Optional negative prompt; mutually exclusive with
            `negative_prompt_embeds`.
        prompt_embeds: Optional pre-computed embeddings (only `.shape` is read here).
        negative_prompt_embeds: Optional pre-computed negative embeddings (only `.shape`
            is read here).

    Raises:
        ValueError: If `strength` is `None` or outside `[0.0, 1.0]`; if `callback_steps` is
            not a positive `int`; if both or neither of `prompt` / `prompt_embeds` are given;
            if `prompt` is neither `str` nor `list`; if both `negative_prompt` and
            `negative_prompt_embeds` are given; or if the two embedding shapes differ.
    """
    if strength is None or strength < 0 or strength > 1:
        raise ValueError(
            f"The value of `strength` should be in [0.0, 1.0] but is {strength} of type {type(strength)}."
        )

    # `None` is not an `int`, so the isinstance test also rejects a missing value.
    if not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )
def get_timesteps(self, num_inference_steps, strength):
    """Return the tail of `self.scheduler.timesteps` selected by `strength`.

    `int(num_inference_steps * strength)` steps (capped at `num_inference_steps`) are
    kept; the steps before them are skipped, with the slice offset multiplied by
    `self.scheduler.order`.

    Returns:
        `tuple`: `(timesteps, number_of_kept_steps)`.
    """
    kept_steps = min(int(num_inference_steps * strength), num_inference_steps)
    skipped_steps = max(num_inference_steps - kept_steps, 0)
    first_index = skipped_steps * self.scheduler.order
    return self.scheduler.timesteps[first_index:], num_inference_steps - skipped_steps


def get_inverse_timesteps(self, num_inference_steps, strength):
    """Return the head of `self.inverse_scheduler.timesteps` selected by `strength`.

    The number of skipped steps is computed as in `get_timesteps`, but they are
    dropped from the end of the inverse schedule.

    Returns:
        `tuple`: `(timesteps, number_of_kept_steps)`.
    """
    kept_steps = min(int(num_inference_steps * strength), num_inference_steps)
    skipped_steps = max(num_inference_steps - kept_steps, 0)

    schedule = self.inverse_scheduler.timesteps
    # `[:-0]` would yield an empty slice, so the full schedule is returned untouched
    # when nothing is skipped.
    selected = schedule[:-skipped_steps] if skipped_steps else schedule
    return selected, num_inference_steps - skipped_steps
def prepare_image_latents(self, image, batch_size, dtype, generator=None):
    """Convert `image` into VAE latents and match them to `batch_size`.

    When the second dimension of `image` has size 4 the input is used directly as
    latents (no encoding, no scaling). Otherwise the image is encoded with `self.vae`,
    a latent is sampled through `self.vae.diag_gauss_dist.sample`, and the result is
    multiplied by `self.vae.config.scaling_factor`.

    Args:
        image: Image batch; cast to `dtype` before use.
        batch_size: Effective batch size the returned latents must have.
        dtype: Target dtype for `image`.
        generator: A single generator, or a list with one generator per sample.

    Returns:
        The latents, repeated along the batch axis when `batch_size` is a multiple of
        the number of provided images.

    Raises:
        ValueError: If `image` has an unsupported type, if a generator list does not have
            `batch_size` entries, or if the latents cannot be evenly duplicated to
            `batch_size`.
    """
    if not isinstance(image, (ms.Tensor, PIL.Image.Image, list)):
        raise ValueError(f"`image` has to be of type `ms.Tensor`, `PIL.Image.Image` or list but is {type(image)}")

    image = image.to(dtype=dtype)

    if image.shape[1] == 4:
        latents = image

    else:
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if isinstance(generator, list):
            # Sample every image with its own generator, using the same encode/sample call
            # as the single-generator branch below. The previous code handed `generator[i]`
            # positionally *and* the whole list as `generator=`, and relied on a
            # `.latent_dist` attribute that the other branch does not use.
            latents = [
                self.vae.diag_gauss_dist.sample(self.vae.encode(image[i : i + 1])[0], generator=generator[i])
                for i in range(batch_size)
            ]
            latents = ops.cat(latents, axis=0)
        else:
            latents = self.vae.diag_gauss_dist.sample(self.vae.encode(image)[0], generator=generator)

        latents = self.vae.config.scaling_factor * latents

    if batch_size != latents.shape[0]:
        if batch_size % latents.shape[0] == 0:
            # expand image_latents for batch_size
            deprecation_message = (
                f"You have passed {batch_size} text prompts (`prompt`), but only {latents.shape[0]} initial"
                " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
                " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
                " your script to pass as many initial images as text prompts to suppress this warning."
            )
            deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
            additional_latents_per_image = batch_size // latents.shape[0]
            latents = ops.cat([latents] * additional_latents_per_image, axis=0)
        else:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {latents.shape[0]} to {batch_size} text prompts."
            )
    else:
        latents = ops.cat([latents], axis=0)

    return latents
None, + ): + r""" + Generate a latent mask given a mask prompt, a target prompt, and an image. + + Args: + image (`PIL.Image.Image`): + `Image` or tensor representing an image batch to be used for computing the mask. + target_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide semantic mask generation. If not defined, you need to pass + `prompt_embeds`. + target_negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + target_prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + target_negative_prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + source_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide semantic mask generation using DiffEdit. If not defined, you need to + pass `source_prompt_embeds` or `source_image` instead. + source_negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide semantic mask generation away from using DiffEdit. If not defined, you + need to pass `source_negative_prompt_embeds` or `source_image` instead. + source_prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated text embeddings to guide the semantic mask generation. Can be used to easily tweak text + inputs (prompt weighting). If not provided, text embeddings are generated from `source_prompt` input + argument. + source_negative_prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated text embeddings to negatively guide the semantic mask generation. 
Can be used to easily + tweak text inputs (prompt weighting). If not provided, text embeddings are generated from + `source_negative_prompt` input argument. + num_maps_per_mask (`int`, *optional*, defaults to 10): + The number of noise maps sampled to generate the semantic mask using DiffEdit. + mask_encode_strength (`float`, *optional*, defaults to 0.5): + The strength of the noise maps sampled to generate the semantic mask using DiffEdit. Must be between 0 + and 1. + mask_thresholding_ratio (`float`, *optional*, defaults to 3.0): + The maximum multiple of the mean absolute difference used to clamp the semantic guidance map before + mask binarization. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + generator (`np.random.Generator` or `List[np.random.Generator]`, *optional*): + A [`np.random.Generator`](https://numpy.org/doc/stable/reference/random/generator.html) to make + generation deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the + [`~models.attention_processor.AttnProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 
+ + Examples: + + Returns: + `List[PIL.Image.Image]` or `np.array`: + When returning a `List[PIL.Image.Image]`, the list consists of a batch of single-channel binary images + with dimensions `(height // self.vae_scale_factor, width // self.vae_scale_factor)`. If it's + `np.array`, the shape is `(batch_size, height // self.vae_scale_factor, width // + self.vae_scale_factor)`. + """ + + # 1. Check inputs (Provide dummy argument for callback_steps) + self.check_inputs( + target_prompt, + mask_encode_strength, + 1, + target_negative_prompt, + target_prompt_embeds, + target_negative_prompt_embeds, + ) + + self.check_source_inputs( + source_prompt, + source_negative_prompt, + source_prompt_embeds, + source_negative_prompt_embeds, + ) + + if (num_maps_per_mask is None) or ( + num_maps_per_mask is not None and (not isinstance(num_maps_per_mask, int) or num_maps_per_mask <= 0) + ): + raise ValueError( + f"`num_maps_per_mask` has to be a positive integer but is {num_maps_per_mask} of type" + f" {type(num_maps_per_mask)}." + ) + + if mask_thresholding_ratio is None or mask_thresholding_ratio <= 0: + raise ValueError( + f"`mask_thresholding_ratio` has to be positive but is {mask_thresholding_ratio} of type" + f" {type(mask_thresholding_ratio)}." + ) + + # 2. Define call parameters + if target_prompt is not None and isinstance(target_prompt, str): + batch_size = 1 + elif target_prompt is not None and isinstance(target_prompt, list): + batch_size = len(target_prompt) + else: + batch_size = target_prompt_embeds.shape[0] + if cross_attention_kwargs is None: + cross_attention_kwargs = {} + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. 
Encode input prompts + (cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None) + target_negative_prompt_embeds, target_prompt_embeds = self.encode_prompt( + target_prompt, + num_maps_per_mask, + do_classifier_free_guidance, + target_negative_prompt, + prompt_embeds=target_prompt_embeds, + negative_prompt_embeds=target_negative_prompt_embeds, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + target_prompt_embeds = ops.cat([target_negative_prompt_embeds, target_prompt_embeds]) + + source_negative_prompt_embeds, source_prompt_embeds = self.encode_prompt( + source_prompt, + num_maps_per_mask, + do_classifier_free_guidance, + source_negative_prompt, + prompt_embeds=source_prompt_embeds, + negative_prompt_embeds=source_negative_prompt_embeds, + ) + if do_classifier_free_guidance: + source_prompt_embeds = ops.cat([source_negative_prompt_embeds, source_prompt_embeds]) + + # 4. Preprocess image + image = self.image_processor.preprocess(image).repeat_interleave(num_maps_per_mask, dim=0) + + # 5. Set timesteps + self.scheduler.set_timesteps(num_inference_steps) + timesteps, _ = self.get_timesteps(num_inference_steps, mask_encode_strength) + encode_timestep = timesteps[0] + + # 6. Prepare image latents and add noise with specified strength + image_latents = self.prepare_image_latents(image, batch_size * num_maps_per_mask, self.vae.dtype, generator) + noise = randn_tensor(image_latents.shape, generator=generator, dtype=self.vae.dtype) + image_latents = self.scheduler.add_noise(image_latents, noise, encode_timestep) + + latent_model_input = ops.cat([image_latents] * (4 if do_classifier_free_guidance else 2)) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, encode_timestep) + + # 7. 
Predict the noise residual + prompt_embeds = ops.cat([source_prompt_embeds, target_prompt_embeds]) + noise_pred = self.unet( + latent_model_input, + encode_timestep, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + )[0] + + if do_classifier_free_guidance: + noise_pred_neg_src, noise_pred_source, noise_pred_uncond, noise_pred_target = noise_pred.chunk(4) + noise_pred_source = noise_pred_neg_src + guidance_scale * (noise_pred_source - noise_pred_neg_src) + noise_pred_target = noise_pred_uncond + guidance_scale * (noise_pred_target - noise_pred_uncond) + else: + noise_pred_source, noise_pred_target = noise_pred.chunk(2) + + # 8. Compute the mask from the absolute difference of predicted noise residuals + # TODO: Consider smoothing mask guidance map + mask_guidance_map = ( + ops.abs(noise_pred_target - noise_pred_source) + .reshape(batch_size, num_maps_per_mask, *noise_pred_target.shape[-3:]) + .mean([1, 2]) + ) + clamp_magnitude = mask_guidance_map.mean() * mask_thresholding_ratio + semantic_mask_image = mask_guidance_map.clamp(ms.Tensor(0), clamp_magnitude) / clamp_magnitude + semantic_mask_image = ops.where(semantic_mask_image <= 0.5, ms.Tensor(0), ms.Tensor(1)) + mask_image = semantic_mask_image.asnumpy() + + # 9. Convert to Numpy array or PIL. 
+        if output_type == "pil":
+            mask_image = self.image_processor.numpy_to_pil(mask_image)
+
+        return mask_image
+
+    def invert(
+        self,
+        prompt: Optional[Union[str, List[str]]] = None,
+        image: Union[ms.Tensor, PIL.Image.Image] = None,
+        num_inference_steps: int = 50,
+        inpaint_strength: float = 0.8,
+        guidance_scale: float = 7.5,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        generator: Optional[Union[np.random.Generator, List[np.random.Generator]]] = None,
+        prompt_embeds: Optional[ms.Tensor] = None,
+        negative_prompt_embeds: Optional[ms.Tensor] = None,
+        decode_latents: bool = False,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = False,
+        callback: Optional[Callable[[int, int, ms.Tensor], None]] = None,
+        callback_steps: Optional[int] = 1,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        lambda_auto_corr: float = 20.0,
+        lambda_kl: float = 20.0,
+        num_reg_steps: int = 0,
+        num_auto_corr_rolls: int = 5,
+    ):
+        r"""
+        Generate inverted latents given a prompt and image.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+            image (`PIL.Image.Image`):
+                `Image` or tensor representing an image batch to produce the inverted latents guided by `prompt`.
+            inpaint_strength (`float`, *optional*, defaults to 0.8):
+                Indicates extent of the noising process to run latent inversion. Must be between 0 and 1. When
+                `inpaint_strength` is 1, the inversion process is run for the full number of iterations specified in
+                `num_inference_steps`. `image` is used as a reference for the inversion process, and adding more noise
+                increases `inpaint_strength`. If `inpaint_strength` is 0, no inpainting occurs.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                A higher guidance scale value encourages the model to generate images closely linked to the text
+                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
+            generator (`np.random.Generator`, *optional*):
+                A [`np.random.Generator`](https://numpy.org/doc/stable/reference/random/generator.html) to make
+                generation deterministic.
+            prompt_embeds (`ms.Tensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+                provided, text embeddings are generated from the `prompt` input argument.
+            negative_prompt_embeds (`ms.Tensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+            decode_latents (`bool`, *optional*, defaults to `False`):
+                Whether or not to decode the inverted latents into a generated image. Setting this argument to `True`
+                decodes all inverted latents for each timestep into a list of generated images.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `False`):
+                Whether or not to return a [`~pipelines.stable_diffusion.DiffEditInversionPipelineOutput`] instead of a
+                plain tuple.
+            callback (`Callable`, *optional*):
+                A function that calls every `callback_steps` steps during inference. The function is called with the
+                following arguments: `callback(step: int, timestep: int, latents: ms.Tensor)`.
+            callback_steps (`int`, *optional*, defaults to 1):
+                The frequency at which the `callback` function is called. If not specified, the callback is called at
+                every step.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the
+                [`~models.attention_processor.AttnProcessor`] as defined in
+                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            lambda_auto_corr (`float`, *optional*, defaults to 20.0):
+                Lambda parameter to control auto correction.
+            lambda_kl (`float`, *optional*, defaults to 20.0):
+                Lambda parameter to control Kullback-Leibler divergence output.
+            num_reg_steps (`int`, *optional*, defaults to 0):
+                Number of regularization loss steps. Only `0` is supported: any positive value raises
+                `NotImplementedError`, so `lambda_auto_corr`, `lambda_kl` and `num_auto_corr_rolls` are currently
+                not read by this method.
+            num_auto_corr_rolls (`int`, *optional*, defaults to 5):
+                Number of auto correction roll steps.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.stable_diffusion.pipeline_stable_diffusion_diffedit.DiffEditInversionPipelineOutput`] or
+            `tuple`:
+                If `return_dict` is `True`,
+                [`~pipelines.stable_diffusion.pipeline_stable_diffusion_diffedit.DiffEditInversionPipelineOutput`] is
+                returned, otherwise a `tuple` is returned where the first element is the inverted latents tensors
+                ordered by increasing noise, and the second is the corresponding decoded images if `decode_latents` is
+                `True`, otherwise `None`.
+
+        Raises:
+            ValueError: If `image` is `None`.
+            NotImplementedError: If `num_reg_steps` is greater than 0.
+        """
+
+        # 1. Check inputs
+        self.check_inputs(
+            prompt,
+            inpaint_strength,
+            callback_steps,
+            negative_prompt,
+            prompt_embeds,
+            negative_prompt_embeds,
+        )
+
+        if image is None:
+            raise ValueError("`image` input cannot be undefined.")
+
+        # regularization of the noise prediction (not in original code or paper but borrowed from Pix2PixZero)
+        # is rejected up front so that no UNet forward pass is spent before failing.
+        if num_reg_steps > 0:
+            raise NotImplementedError("Regularization of the noise prediction is not yet supported.")
+
+        # 2. Define call parameters
+        if isinstance(prompt, str):
+            batch_size = 1
+        elif isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+        if cross_attention_kwargs is None:
+            cross_attention_kwargs = {}
+
+        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        # 3. Preprocess image
+        image = self.image_processor.preprocess(image)
+
+        # 4. Prepare latent variables
+        num_images_per_prompt = 1
+        latents = self.prepare_image_latents(image, batch_size * num_images_per_prompt, self.vae.dtype, generator)
+
+        # 5. Encode input prompt
+        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+            prompt,
+            num_images_per_prompt,
+            do_classifier_free_guidance,
+            negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+        )
+        # For classifier free guidance, we need to do two forward passes.
+        # Here we concatenate the unconditional and text embeddings into a single batch
+        # to avoid doing two forward passes
+        if do_classifier_free_guidance:
+            prompt_embeds = ops.cat([negative_prompt_embeds, prompt_embeds])
+
+        # 6. Prepare timesteps
+        self.inverse_scheduler.set_timesteps(num_inference_steps)
+        timesteps, num_inference_steps = self.get_inverse_timesteps(num_inference_steps, inpaint_strength)
+
+        # 7. Noising loop where we obtain the intermediate noised latent image for each timestep.
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order
+        inverted_latents = []
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = ops.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t)
+
+                # predict the noise residual
+                noise_pred = self.unet(
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=prompt_embeds,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                )[0]
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.inverse_scheduler.step(noise_pred, t, latents)[0]
+                inverted_latents.append(latents)
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or (
+                    (i + 1) > num_warmup_steps and (i + 1) % self.inverse_scheduler.order == 0
+                ):
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        # this loop is driven by the inverse scheduler, so its order defines the step index
+                        step_idx = i // getattr(self.inverse_scheduler, "order", 1)
+                        callback(step_idx, t, latents)
+
+        assert len(inverted_latents) == len(timesteps)
+        # stack along a new axis 1, most-noised latents last in the loop become first in the stack
+        latents = ops.stack(list(reversed(inverted_latents)), 1)
+
+        # 8. Post-processing
+        image = None
+        if decode_latents:
+            # merge the first two axes; keyword arguments are used because the first positional
+            # parameter of MindSpore's `Tensor.flatten` is `order`, not `start_dim`.
+            image = self.decode_latents(latents.flatten(start_dim=0, end_dim=1))
+
+        # 9. Convert to PIL.
+        if decode_latents and output_type == "pil":
+            image = self.image_processor.numpy_to_pil(image)
+
+        if not return_dict:
+            return (latents, image)
+
+        return DiffEditInversionPipelineOutput(latents=latents, images=image)
+
+    def __call__(
+        self,
+        prompt: Optional[Union[str, List[str]]] = None,
+        mask_image: Union[ms.Tensor, PIL.Image.Image] = None,
+        image_latents: Union[ms.Tensor, PIL.Image.Image] = None,
+        inpaint_strength: Optional[float] = 0.8,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 7.5,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
+        generator: Optional[Union[np.random.Generator, List[np.random.Generator]]] = None,
+        latents: Optional[ms.Tensor] = None,
+        prompt_embeds: Optional[ms.Tensor] = None,
+        negative_prompt_embeds: Optional[ms.Tensor] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = False,
+        callback: Optional[Callable[[int, int, ms.Tensor], None]] = None,
+        callback_steps: int = 1,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        clip_ckip: int = None,
+        clip_skip: Optional[int] = None,
+    ):
+        r"""
+        The call function to the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+            mask_image (`PIL.Image.Image`):
+                `Image` or tensor representing an image batch to mask the generated image. White pixels in the mask are
+                repainted, while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
+                single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+                instead of 3, so the expected shape would be `(B, 1, H, W)`.
+            image_latents (`PIL.Image.Image` or `ms.Tensor`):
+                Partially noised image latents from the inversion process to be used as inputs for image generation.
+            inpaint_strength (`float`, *optional*, defaults to 0.8):
+                Indicates extent to inpaint the masked area. Must be between 0 and 1. When `inpaint_strength` is 1, the
+                denoising process is run on the masked area for the full number of iterations specified in
+                `num_inference_steps`. `image_latents` is used as a reference for the masked area, and adding more
+                noise to a region increases `inpaint_strength`. If `inpaint_strength` is 0, no inpainting occurs.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                A higher guidance scale value encourages the model to generate images closely linked to the text
+                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+            generator (`np.random.Generator`, *optional*):
+                A [`np.random.Generator`](https://numpy.org/doc/stable/reference/random/generator.html) to make
+                generation deterministic.
+            latents (`ms.Tensor`, *optional*):
+                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor is generated by sampling using the supplied random `generator`.
+            prompt_embeds (`ms.Tensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+                provided, text embeddings are generated from the `prompt` input argument.
+            negative_prompt_embeds (`ms.Tensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `False`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+            callback (`Callable`, *optional*):
+                A function that calls every `callback_steps` steps during inference. The function is called with the
+                following arguments: `callback(step: int, timestep: int, latents: ms.Tensor)`.
+            callback_steps (`int`, *optional*, defaults to 1):
+                The frequency at which the `callback` function is called. If not specified, the callback is called at
+                every step.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            clip_ckip (`int`, *optional*):
+                Misspelled legacy name of `clip_skip`, kept so existing callers keep working. Only used when
+                `clip_skip` is `None`.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
+        Examples:
+
+        Returns:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+                otherwise a `tuple` is returned where the first element is a list with the generated images and the
+                second element is a list of `bool`s indicating whether the corresponding generated image contains
+                "not-safe-for-work" (nsfw) content.
+
+        Raises:
+            ValueError: If `mask_image` or `image_latents` is `None`, or if `image_latents` does not match the
+                latent shape, batch size, or number of timesteps expected by this call.
+        """
+
+        # 1. Check inputs
+        self.check_inputs(
+            prompt,
+            inpaint_strength,
+            callback_steps,
+            negative_prompt,
+            prompt_embeds,
+            negative_prompt_embeds,
+        )
+
+        if mask_image is None:
+            raise ValueError(
+                "`mask_image` input cannot be undefined. Use `generate_mask()` to compute `mask_image` from text prompts."
+            )
+        if image_latents is None:
+            raise ValueError(
+                "`image_latents` input cannot be undefined. Use `invert()` to compute `image_latents` from input images."
+            )
+
+        # 2. Define call parameters
+        if isinstance(prompt, str):
+            batch_size = 1
+        elif isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+        if cross_attention_kwargs is None:
+            cross_attention_kwargs = {}
+
+        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        # 3. Encode input prompt
+        # `cross_attention_kwargs` is always a dict at this point, so no `None` guard is needed.
+        text_encoder_lora_scale = cross_attention_kwargs.get("scale", None)
+        # the correctly spelled keyword wins; the misspelled one is only a fallback
+        if clip_skip is None:
+            clip_skip = clip_ckip
+        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+            prompt,
+            num_images_per_prompt,
+            do_classifier_free_guidance,
+            negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            lora_scale=text_encoder_lora_scale,
+            clip_skip=clip_skip,
+        )
+        # For classifier free guidance, we need to do two forward passes.
+        # Here we concatenate the unconditional and text embeddings into a single batch
+        # to avoid doing two forward passes
+        if do_classifier_free_guidance:
+            prompt_embeds = ops.cat([negative_prompt_embeds, prompt_embeds])
+
+        # 4. Preprocess mask
+        mask_image = preprocess_mask(mask_image, batch_size)
+        latent_height, latent_width = mask_image.shape[-2:]
+        mask_image = ops.cat([mask_image] * num_images_per_prompt)
+        mask_image = mask_image.to(prompt_embeds.dtype)
+
+        # 5. Set timesteps
+        self.scheduler.set_timesteps(num_inference_steps)
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, inpaint_strength)
+
+        # 6. Preprocess image latents
+        if isinstance(image_latents, list) and any(
+            isinstance(lat, ms.Tensor) and lat.ndim == 5 for lat in image_latents
+        ):
+            image_latents = ops.cat(image_latents)
+        elif not (isinstance(image_latents, ms.Tensor) and image_latents.ndim == 5):
+            # a 5-D tensor is used as-is; every other input goes through the image processor
+            image_latents = self.image_processor.preprocess(image_latents)
+
+        latent_shape = (self.vae.config.latent_channels, latent_height, latent_width)
+        if image_latents.shape[-3:] != latent_shape:
+            raise ValueError(
+                f"Each latent image in `image_latents` must have shape {latent_shape}, "
+                f"but has shape {image_latents.shape[-3:]}"
+            )
+        if image_latents.ndim == 4:
+            image_latents = image_latents.reshape(batch_size, len(timesteps), *latent_shape)
+        if image_latents.shape[:2] != (batch_size, len(timesteps)):
+            raise ValueError(
+                f"`image_latents` must have batch size {batch_size} with latent images from {len(timesteps)}"
+                f" timesteps, but has batch size {image_latents.shape[0]} with latent images from"
+                f" {image_latents.shape[1]} timesteps."
+            )
+        # put the timestep axis first so that `image_latents[i]` is the batch for loop step `i`
+        image_latents = image_latents.swapaxes(0, 1).repeat_interleave(num_images_per_prompt, dim=1)
+        image_latents = image_latents.to(prompt_embeds.dtype)
+
+        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 8. Denoising loop
+        latents = image_latents[0]
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = ops.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                # predict the noise residual
+                noise_pred = self.unet(
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=prompt_embeds,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                )[0]
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)[0]
+
+                # mask with inverted latents from appropriate timestep - use original image latent for last step
+                latents = latents * mask_image + image_latents[i] * (1 - mask_image)
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        step_idx = i // getattr(self.scheduler, "order", 1)
+                        callback(step_idx, t, latents)
+
+        if output_type != "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            image, has_nsfw_concept = self.run_safety_checker(image, prompt_embeds.dtype)
+        else:
+            image = latents
+            has_nsfw_concept = None
+
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+        if not return_dict:
+            return (image, has_nsfw_concept)
+
+        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/mindone/diffusers/schedulers/scheduling_ddim.py b/mindone/diffusers/schedulers/scheduling_ddim.py
index 84077d1b4e..02652bec34 100644
--- a/mindone/diffusers/schedulers/scheduling_ddim.py
+++ b/mindone/diffusers/schedulers/scheduling_ddim.py
@@ -494,14 +494,16 @@ def add_noise(
         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
         # while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
         #     sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
-        sqrt_alpha_prod = ops.reshape(sqrt_alpha_prod, (timesteps.shape[0],) + (1,) * (len(broadcast_shape) - 1))
+        sqrt_alpha_prod = ops.reshape(
+            sqrt_alpha_prod, (timesteps.reshape((-1,)).shape[0],) + (1,) * (len(broadcast_shape) - 1)
+        )
 
         sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
         sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
         # while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
         #     sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
         sqrt_one_minus_alpha_prod = ops.reshape(
-            sqrt_one_minus_alpha_prod, (timesteps.shape[0],) + (1,) * (len(broadcast_shape) - 1)
+            sqrt_one_minus_alpha_prod, (timesteps.reshape((-1,)).shape[0],) + (1,) * (len(broadcast_shape) - 1)
         )
 
         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise