diff --git a/src/unet_spatio_temporal_condition.py b/src/unet_spatio_temporal_condition.py new file mode 100644 index 0000000..f40352a --- /dev/null +++ b/src/unet_spatio_temporal_condition.py @@ -0,0 +1,490 @@ +from dataclasses import dataclass +from typing import Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import UNet2DConditionLoadersMixin, PeftAdapterMixin +from diffusers.utils import BaseOutput, logging +from diffusers.models.attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.unets.unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNetSpatioTemporalConditionOutput(BaseOutput): + """ + The output of [`UNetSpatioTemporalConditionModel`]. + + Args: + sample (`torch.Tensor` of shape `(batch_size, num_frames, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.Tensor = None + + +class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin): + r""" + A conditional Spatio-Temporal UNet model that takes a noisy video frames, conditional state, and a timestep and + returns a sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`): + The tuple of downsample blocks to use. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`): + The tuple of upsample blocks to use. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + addition_time_embed_dim: (`int`, defaults to 256): + Dimension to to encode the additional time ids. + projection_class_embeddings_input_dim (`int`, defaults to 768): + The dimension of the projection of encoded `added_time_ids`. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unets.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], + [`~models.unets.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`], + [`~models.unets.unet_3d_blocks.UNetMidBlockSpatioTemporal`]. + num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`): + The number of attention heads. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 8, + out_channels: int = 4, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlockSpatioTemporal", + "CrossAttnDownBlockSpatioTemporal", + "CrossAttnDownBlockSpatioTemporal", + "DownBlockSpatioTemporal", + ), + up_block_types: Tuple[str] = ( + "UpBlockSpatioTemporal", + "CrossAttnUpBlockSpatioTemporal", + "CrossAttnUpBlockSpatioTemporal", + "CrossAttnUpBlockSpatioTemporal", + ), + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + addition_time_embed_dim: int = 256, + projection_class_embeddings_input_dim: int = 768, + layers_per_block: Union[int, Tuple[int]] = 2, + cross_attention_dim: Union[int, Tuple[int]] = 1024, + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, + num_attention_heads: Union[int, Tuple[int]] = (5, 10, 20, 20), + num_frames: int = 25, + ): + super().__init__() + + self.sample_size = sample_size + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + + # input + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[0], + kernel_size=3, + padding=1, + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=1e-5, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + resnet_act_fn="silu", + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlockSpatioTemporal( + block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + transformer_layers_per_block=transformer_layers_per_block[-1], + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=1e-5, + resolution_idx=i, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + resnet_act_fn="silu", + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5) + self.conv_act = nn.SiLU() + + self.conv_out = nn.Conv2d( + block_out_channels[0], + out_channels, + kernel_size=3, + padding=1, + ) + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors( + name: str, + module: torch.nn.Module, + processors: Dict[str, AttentionProcessor], + ): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking + def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + """ + Sets the attention processor to use [feed forward + chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). + + Parameters: + chunk_size (`int`, *optional*): + The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually + over each tensor of dim=`dim`. + dim (`int`, *optional*, defaults to `0`): + The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) + or dim=1 (sequence length). + """ + if dim not in [0, 1]: + raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") + + # By default chunk size is 1 + chunk_size = chunk_size or 1 + + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, chunk_size, dim) + + def forward( + self, + sample: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + added_time_ids: torch.Tensor, + return_dict: bool = True, + ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]: + r""" + The [`UNetSpatioTemporalConditionModel`] forward method. + + Args: + sample (`torch.Tensor`): + The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`. + timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`. + added_time_ids: (`torch.Tensor`): + The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal + embeddings and added to the time embeddings. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead + of a plain tuple. + Returns: + [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is + returned, otherwise a `tuple` is returned where the first element is the sample tensor. + """ + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + batch_size, num_frames = sample.shape[:2] + timesteps = timesteps.expand(batch_size) + + t_emb = self.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb) + + time_embeds = self.add_time_proj(added_time_ids.flatten()) + time_embeds = time_embeds.reshape((batch_size, -1)) + time_embeds = time_embeds.to(emb.dtype) + aug_emb = self.add_embedding(time_embeds) + emb = emb + aug_emb + + # Flatten the batch and frames dimensions + # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width] + sample = sample.flatten(0, 1) + # Repeat the embeddings num_video_frames times + # emb: [batch, channels] -> [batch * frames, channels] + emb = emb.repeat_interleave(num_frames, dim=0) + # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels] + encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0) + + # 2. pre-process + sample = self.conv_in(sample) + + image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device) + + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + ) + else: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + image_only_indicator=image_only_indicator, + ) + + down_block_res_samples += res_samples + + # 4. mid + sample = self.mid_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + ) + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + image_only_indicator=image_only_indicator, + ) + + # 6. post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + # 7. Reshape back to original shape + sample = sample.reshape(batch_size, num_frames, *sample.shape[1:]) + + if not return_dict: + return (sample,) + + return UNetSpatioTemporalConditionOutput(sample=sample) diff --git a/train_svd.py b/train_svd.py index 0f72f3b..eb399a7 100644 --- a/train_svd.py +++ b/train_svd.py @@ -61,58 +61,11 @@ logger = get_logger(__name__, log_level="INFO") # copy from https://github.com/crowsonkb/k-diffusion.git -def stratified_uniform(shape, group=0, groups=1, dtype=None, device=None): - """Draws stratified samples from a uniform distribution.""" - if groups <= 0: - raise ValueError(f"groups must be positive, got {groups}") - if group < 0 or group >= groups: - raise ValueError(f"group must be in [0, {groups})") - n = shape[-1] * groups - offsets = torch.arange(group, n, groups, dtype=dtype, device=device) - u = torch.rand(shape, dtype=dtype, device=device) - return (offsets + u) / n - - -def rand_cosine_interpolated(shape, image_d, noise_d_low, noise_d_high, sigma_data=1., min_value=1e-3, max_value=1e3, device='cpu', dtype=torch.float32): - """Draws samples from an interpolated cosine timestep distribution (from simple diffusion).""" - - def logsnr_schedule_cosine(t, logsnr_min, logsnr_max): - t_min = math.atan(math.exp(-0.5 * logsnr_max)) - t_max = math.atan(math.exp(-0.5 * logsnr_min)) - return -2 * torch.log(torch.tan(t_min + t * (t_max - t_min))) - - def logsnr_schedule_cosine_shifted(t, image_d, noise_d, logsnr_min, logsnr_max): - shift = 2 * math.log(noise_d / image_d) - return logsnr_schedule_cosine(t, logsnr_min - shift, logsnr_max - shift) + shift - - def logsnr_schedule_cosine_interpolated(t, image_d, noise_d_low, noise_d_high, logsnr_min, logsnr_max): - logsnr_low = logsnr_schedule_cosine_shifted( - t, image_d, noise_d_low, logsnr_min, logsnr_max) - logsnr_high = logsnr_schedule_cosine_shifted( - t, image_d, noise_d_high, logsnr_min, logsnr_max) - return torch.lerp(logsnr_low, logsnr_high, t) - - logsnr_min = -2 * math.log(min_value / sigma_data) - logsnr_max = -2 * math.log(max_value / sigma_data) - u = stratified_uniform( - shape, group=0, groups=1, dtype=dtype, device=device - ) - logsnr = logsnr_schedule_cosine_interpolated( - u, image_d, noise_d_low, noise_d_high, logsnr_min, logsnr_max) - return torch.exp(-logsnr / 2) * sigma_data - def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32): """Draws samples from an lognormal distribution.""" u = torch.rand(shape, dtype=dtype, device=device) * (1 - 2e-7) + 1e-7 return torch.distributions.Normal(loc, scale).icdf(u).exp() -# min_value = 0.002 -# max_value = 700 -# image_d = 64 -# noise_d_low = 32 -# noise_d_high = 64 -# sigma_data = 0.5 - class DummyDataset(Dataset): def __init__(self, base_folder: str, num_samples=100000, width=1024, height=576, sample_frames=25): @@ -340,7 +293,7 @@ def tensor_to_vae_latent(t, vae): def parse_args(): parser = argparse.ArgumentParser( - description="Script to train Stable Diffusion XL for InstructPix2Pix." + description="Script to train Stable Video Diffusion." ) parser.add_argument( "--base_folder", @@ -361,12 +314,6 @@ def parse_args(): required=False, help="Revision of pretrained model identifier from huggingface.co/models.", ) - parser.add_argument( - "--validation_prompt", - type=str, - default=None, - help="A prompt that is sampled during training for inference.", - ) parser.add_argument( "--num_frames", type=int, @@ -609,12 +556,6 @@ def parse_args(): default=None, help="use weight for unet block", ) - parser.add_argument( - "--rank", - type=int, - default=128, - help=("The dimension of the LoRA update matrices."), - ) args = parser.parse_args() env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) @@ -698,14 +639,12 @@ def main(): repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token ).repo_id - # Load scheduler, tokenizer and models. - noise_scheduler = EulerDiscreteScheduler.from_pretrained( - args.pretrained_model_name_or_path, subfolder="scheduler") + # Load img encoder, tokenizer and models. feature_extractor = CLIPImageProcessor.from_pretrained( args.pretrained_model_name_or_path, subfolder="feature_extractor", revision=args.revision ) image_encoder = CLIPVisionModelWithProjection.from_pretrained( - args.pretrained_model_name_or_path, subfolder="image_encoder", revision=args.revision, variant="fp16" + args.pretrained_model_name_or_path, subfolder="image_encoder", revision=args.revision ) vae = AutoencoderKLTemporalDecoder.from_pretrained( args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant="fp16") @@ -713,13 +652,9 @@ def main(): args.pretrained_model_name_or_path if args.pretrain_unet is None else args.pretrain_unet, subfolder="unet", low_cpu_mem_usage=True, - variant="fp16", + variant="fp16" ) - # attribute handling for models using DDP - if isinstance(unet, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)): - unet = unet.module - # Freeze vae and image_encoder vae.requires_grad_(False) image_encoder.requires_grad_(False) @@ -736,7 +671,7 @@ def main(): # Move image_encoder and vae to gpu and cast to weight_dtype image_encoder.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=weight_dtype) - # unet.to(accelerator.device, dtype=weight_dtype) + # Create EMA for the unet. if args.use_ema: @@ -820,17 +755,15 @@ def load_model_hook(models, input_dir): else: optimizer_cls = torch.optim.AdamW - unet.requires_grad_(True) parameters_list = [] # Customize the parameters that need to be trained; if necessary, you can uncomment them yourself. - - for name, para in unet.named_parameters(): + for name, param in unet.named_parameters(): if 'temporal_transformer_block' in name: - parameters_list.append(para) - para.requires_grad = True + parameters_list.append(param) + param.requires_grad = True else: - para.requires_grad = False + param.requires_grad = False optimizer = optimizer_cls( parameters_list, lr=args.learning_rate, @@ -839,18 +772,10 @@ def load_model_hook(models, input_dir): eps=args.adam_epsilon, ) - # optimizer = optimizer_cls( - # unet.parameters(), - # lr=args.learning_rate, - # betas=(args.adam_beta1, args.adam_beta2), - # weight_decay=args.adam_weight_decay, - # eps=args.adam_epsilon, - # ) - # check parameters if accelerator.is_main_process: - rec_txt1 = open('rec_para.txt', 'w') - rec_txt2 = open('rec_para_train.txt', 'w') + rec_txt1 = open('params_freeze.txt', 'w') + rec_txt2 = open('params_train.txt', 'w') for name, para in unet.named_parameters(): if para.requires_grad is False: rec_txt1.write(f'{name}\n') @@ -893,6 +818,10 @@ def load_model_hook(models, input_dir): if args.use_ema: ema_unet.to(accelerator.device) + + # attribute handling for models using DDP + if isinstance(unet, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)): + unet = unet.module # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil( diff --git a/train_svd_lora.py b/train_svd_lora.py new file mode 100644 index 0000000..000077d --- /dev/null +++ b/train_svd_lora.py @@ -0,0 +1,1166 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Fine-tuning script for Stable Video Diffusion with support for LoRA.""" +import argparse +import random +import logging +import math +import os +import cv2 +import shutil +from pathlib import Path +from urllib.parse import urlparse + +import accelerate +import numpy as np +import PIL +from PIL import Image, ImageDraw +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch.utils.data import RandomSampler +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from huggingface_hub import create_repo, upload_folder +from packaging import version +from tqdm.auto import tqdm +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection +from einops import rearrange + +import diffusers +from diffusers import StableVideoDiffusionPipeline +from diffusers.models.lora import LoRALinearLayer +from diffusers import AutoencoderKLTemporalDecoder +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available, load_image +from diffusers.utils.import_utils import is_xformers_available + +from torch.utils.data import Dataset + +from peft import LoraConfig +from peft.utils import get_peft_model_state_dict +from diffusers.training_utils import cast_training_params + +from src.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.29.1") + +logger = get_logger(__name__, log_level="INFO") + +# copy from https://github.com/crowsonkb/k-diffusion.git +def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32): + """Draws samples from an lognormal distribution.""" + u = torch.rand(shape, dtype=dtype, device=device) * (1 - 2e-7) + 1e-7 + return torch.distributions.Normal(loc, scale).icdf(u).exp() + + +class DummyDataset(Dataset): + def __init__(self, base_folder: str, num_samples=100000, width=1024, height=576, sample_frames=25): + """ + Args: + num_samples (int): Number of samples in the dataset. + channels (int): Number of channels, default is 3 for RGB. + """ + self.num_samples = num_samples + # Define the path to the folder containing video frames + self.base_folder = base_folder + self.folders = os.listdir(self.base_folder) + self.channels = 3 + self.width = width + self.height = height + self.sample_frames = sample_frames + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + """ + Args: + idx (int): Index of the sample to return. + + Returns: + dict: A dictionary containing the 'pixel_values' tensor of shape (16, channels, 320, 512). + """ + # Randomly select a folder (representing a video) from the base folder + chosen_folder = random.choice(self.folders) + folder_path = os.path.join(self.base_folder, chosen_folder) + frames = os.listdir(folder_path) + # Sort the frames by name + frames.sort() + + # Ensure the selected folder has at least `sample_frames`` frames + if len(frames) < self.sample_frames: + raise ValueError( + f"The selected folder '{chosen_folder}' contains fewer than `{self.sample_frames}` frames.") + + # Randomly select a start index for frame sequence + start_idx = random.randint(0, len(frames) - self.sample_frames) + selected_frames = frames[start_idx:start_idx + self.sample_frames] + + # Initialize a tensor to store the pixel values + pixel_values = torch.empty((self.sample_frames, self.channels, self.height, self.width)) + + # Load and process each frame + for i, frame_name in enumerate(selected_frames): + frame_path = os.path.join(folder_path, frame_name) + with Image.open(frame_path) as img: + # Resize the image and convert it to a tensor + img_resized = img.resize((self.width, self.height)) + img_tensor = torch.from_numpy(np.array(img_resized)).float() + + # Normalize the image by scaling pixel values to [-1, 1] + img_normalized = img_tensor / 127.5 - 1 + + # Rearrange channels if necessary + if self.channels == 3: + img_normalized = img_normalized.permute( + 2, 0, 1) # For RGB images + elif self.channels == 1: + img_normalized = img_normalized.mean( + dim=2, keepdim=True) # For grayscale images + + pixel_values[i] = img_normalized + return {'pixel_values': pixel_values} + +# resizing utils +# TODO: clean up later +def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True): + h, w = input.shape[-2:] + factors = (h / size[0], w / size[1]) + + # First, we have to determine sigma + # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171 + sigmas = ( + max((factors[0] - 1.0) / 2.0, 0.001), + max((factors[1] - 1.0) / 2.0, 0.001), + ) + + # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma + # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206 + # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now + ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3)) + + # Make sure it is odd + if (ks[0] % 2) == 0: + ks = ks[0] + 1, ks[1] + + if (ks[1] % 2) == 0: + ks = ks[0], ks[1] + 1 + + input = _gaussian_blur2d(input, ks, sigmas) + + output = torch.nn.functional.interpolate( + input, size=size, mode=interpolation, align_corners=align_corners) + return output + + +def _compute_padding(kernel_size): + """Compute padding tuple.""" + # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom) + # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad + if len(kernel_size) < 2: + raise AssertionError(kernel_size) + computed = [k - 1 for k in kernel_size] + + # for even kernels we need to do asymmetric padding :( + out_padding = 2 * len(kernel_size) * [0] + + for i in range(len(kernel_size)): + computed_tmp = computed[-(i + 1)] + + pad_front = computed_tmp // 2 + pad_rear = computed_tmp - pad_front + + out_padding[2 * i + 0] = pad_front + out_padding[2 * i + 1] = pad_rear + + return out_padding + + +def _filter2d(input, kernel): + # prepare kernel + b, c, h, w = input.shape + tmp_kernel = kernel[:, None, ...].to( + device=input.device, dtype=input.dtype) + + tmp_kernel = tmp_kernel.expand(-1, c, -1, -1) + + height, width = tmp_kernel.shape[-2:] + + padding_shape: list[int] = _compute_padding([height, width]) + input = torch.nn.functional.pad(input, padding_shape, mode="reflect") + + # kernel and input tensor reshape to align element-wise or batch-wise params + tmp_kernel = tmp_kernel.reshape(-1, 1, height, width) + input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1)) + + # convolve the tensor with the kernel. + output = torch.nn.functional.conv2d( + input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1) + + out = output.view(b, c, h, w) + return out + + +def _gaussian(window_size: int, sigma): + if isinstance(sigma, float): + sigma = torch.tensor([[sigma]]) + + batch_size = sigma.shape[0] + + x = (torch.arange(window_size, device=sigma.device, + dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1) + + if window_size % 2 == 0: + x = x + 0.5 + + gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0))) + + return gauss / gauss.sum(-1, keepdim=True) + + +def _gaussian_blur2d(input, kernel_size, sigma): + if isinstance(sigma, tuple): + sigma = torch.tensor([sigma], dtype=input.dtype) + else: + sigma = sigma.to(dtype=input.dtype) + + ky, kx = int(kernel_size[0]), int(kernel_size[1]) + bs = sigma.shape[0] + kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1)) + kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1)) + out_x = _filter2d(input, kernel_x[..., None, :]) + out = _filter2d(out_x, kernel_y[..., None]) + + return out + + +def export_to_video(video_frames, output_video_path, fps): + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + h, w, _ = video_frames[0].shape + video_writer = cv2.VideoWriter( + output_video_path, fourcc, fps=fps, frameSize=(w, h)) + for i in range(len(video_frames)): + img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR) + video_writer.write(img) + + +def export_to_gif(frames, output_gif_path, fps): + """ + Export a list of frames to a GIF. + + Args: + - frames (list): List of frames (as numpy arrays or PIL Image objects). + - output_gif_path (str): Path to save the output GIF. + - duration_ms (int): Duration of each frame in milliseconds. + + """ + # Convert numpy arrays to PIL Images if needed + pil_frames = [Image.fromarray(frame) if isinstance( + frame, np.ndarray) else frame for frame in frames] + + pil_frames[0].save(output_gif_path.replace('.mp4', '.gif'), + format='GIF', + append_images=pil_frames[1:], + save_all=True, + duration=500, + loop=0) + + +def tensor_to_vae_latent(t, vae): + video_length = t.shape[1] + + t = rearrange(t, "b f c h w -> (b f) c h w") + latents = vae.encode(t).latent_dist.sample() + latents = rearrange(latents, "(b f) c h w -> b f c h w", f=video_length) + latents = latents * vae.config.scaling_factor + + return latents + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Script to train Stable Video Diffusion." + ) + parser.add_argument( + "--base_folder", + required=True, + type=str, + ) + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--num_frames", + type=int, + default=25, + ) + parser.add_argument( + "--width", + type=int, + default=1024, + ) + parser.add_argument( + "--height", + type=int, + default=576, + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=1, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=500, + help=( + "Run fine-tuning validation every X epochs. The validation process consists of running the text/image prompt" + " multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="./outputs", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--seed", type=int, default=None, help="A seed for reproducible training." + ) + parser.add_argument( + "--per_gpu_batch_size", + type=int, + default=1, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", + type=int, + default=500, + help="Number of steps for the warmup in the lr scheduler.", + ) + parser.add_argument( + "--conditioning_dropout_prob", + type=float, + default=0.1, + help="Conditioning dropout probability. Drops out the conditionings (image and edit prompt) used in training InstructPix2Pix. See section 3.2.1 in the paper: https://arxiv.org/abs/2211.09800.", + ) + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes.", + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--num_workers", + type=int, + default=8, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--adam_beta1", + type=float, + default=0.9, + help="The beta1 parameter for the Adam optimizer.", + ) + parser.add_argument( + "--adam_beta2", + type=float, + default=0.999, + help="The beta2 parameter for the Adam optimizer.", + ) + parser.add_argument( + "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use." + ) + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer", + ) + parser.add_argument( + "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." + ) + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether or not to push the model to the Hub.", + ) + parser.add_argument( + "--hub_token", + type=str, + default=None, + help="The token to use to push to the Model Hub.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--local_rank", + type=int, + default=-1, + help="For distributed training: local_rank", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=2, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", + action="store_true", + help="Whether or not to use xformers.", + ) + + parser.add_argument( + "--pretrain_unet", + type=str, + default=None, + help="use weight for unet block", + ) + parser.add_argument( + "--rank", + type=int, + default=4, + help=("The dimension of the LoRA update matrices."), + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + return args + + +def download_image(url): + original_image = ( + lambda image_url_or_path: load_image(image_url_or_path) + if urlparse(image_url_or_path).scheme + else PIL.Image.open(image_url_or_path).convert("RGB") + )(url) + return original_image + + +def main(): + args = parse_args() + + logging_dir = os.path.join(args.output_dir, args.logging_dir) + accelerator_project_config = ProjectConfiguration( + project_dir=args.output_dir, logging_dir=logging_dir) + # ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + # kwargs_handlers=[ddp_kwargs] + ) + + generator = torch.Generator( + device=accelerator.device).manual_seed(args.seed) + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError( + "Make sure to install wandb if you want to use it for logging during training.") + import wandb + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load img encoder, tokenizer and models. + feature_extractor = CLIPImageProcessor.from_pretrained( + args.pretrained_model_name_or_path, subfolder="feature_extractor", revision=args.revision + ) + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + args.pretrained_model_name_or_path, subfolder="image_encoder", revision=args.revision + ) + vae = AutoencoderKLTemporalDecoder.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant="fp16") + unet = UNetSpatioTemporalConditionModel.from_pretrained( + args.pretrained_model_name_or_path if args.pretrain_unet is None else args.pretrain_unet, + subfolder="unet", + low_cpu_mem_usage=True, + variant="fp16", + ) + + # Freeze vae and image_encoder + vae.requires_grad_(False) + image_encoder.requires_grad_(False) + unet.requires_grad_(False) + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Freeze the unet parameters before adding adapters + for param in unet.parameters(): + param.requires_grad_(False) + + unet_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + ) + + # Move image_encoder and vae to gpu and cast to weight_dtype + image_encoder.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + unet.to(accelerator.device, dtype=weight_dtype) + + unet.add_adapter(unet_lora_config) + if args.mixed_precision == "fp16": + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(unet, dtype=torch.float32) + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError( + "xformers is not available. Make sure it is installed correctly") + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + + for i, model in enumerate(models): + model.save_pretrained(os.path.join(output_dir, "unet")) + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = UNetSpatioTemporalConditionModel.from_pretrained( + input_dir, subfolder="unet") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * + args.per_gpu_batch_size * accelerator.num_processes + ) + + # Initialize the optimizer + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) + + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + lora_layers = filter(lambda p: p.requires_grad, unet.parameters()) + optimizer = optimizer_cls( + lora_layers, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # check parameters + if accelerator.is_main_process: + rec_txt1 = open('params_freeze.txt', 'w') + rec_txt2 = open('params_train.txt', 'w') + for name, para in unet.named_parameters(): + if para.requires_grad is False: + rec_txt1.write(f'{name}\n') + else: + rec_txt2.write(f'{name}\n') + rec_txt1.close() + rec_txt2.close() + + # DataLoaders creation: + args.global_batch_size = args.per_gpu_batch_size * accelerator.num_processes + + train_dataset = DummyDataset(args.base_folder, width=args.width, height=args.height, sample_frames=args.num_frames) + sampler = RandomSampler(train_dataset) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + sampler=sampler, + batch_size=args.per_gpu_batch_size, + num_workers=args.num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + ) + + # Prepare everything with our `accelerator`. + unet, optimizer, lr_scheduler, train_dataloader = accelerator.prepare( + unet, optimizer, lr_scheduler, train_dataloader + ) + + # attribute handling for models using DDP + if isinstance(unet, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)): + unet = unet.module + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil( + args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("SVDXtend", config=vars(args)) + + # Train! + total_batch_size = args.per_gpu_batch_size * \ + accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info( + f" Instantaneous batch size per device = {args.per_gpu_batch_size}") + logger.info( + f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info( + f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + def encode_image(pixel_values): + # pixel: [-1, 1] + pixel_values = _resize_with_antialiasing(pixel_values, (224, 224)) + # We unnormalize it after resizing. + pixel_values = (pixel_values + 1.0) / 2.0 + + # Normalize the image with for CLIP input + pixel_values = feature_extractor( + images=pixel_values, + do_normalize=True, + do_center_crop=False, + do_resize=False, + do_rescale=False, + return_tensors="pt", + ).pixel_values + + pixel_values = pixel_values.to( + device=accelerator.device, dtype=weight_dtype) + image_embeddings = image_encoder(pixel_values).image_embeds + return image_embeddings + + def _get_add_time_ids( + fps, + motion_bucket_id, + noise_aug_strength, + dtype, + batch_size, + ): + add_time_ids = [fps, motion_bucket_id, noise_aug_strength] + + passed_add_embed_dim = unet.config.addition_time_embed_dim * \ + len(add_time_ids) + expected_add_embed_dim = unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_time_ids = add_time_ids.repeat(batch_size, 1) + return add_time_ids + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = global_step // num_update_steps_per_epoch + resume_step = resume_global_step % ( + num_update_steps_per_epoch * args.gradient_accumulation_steps) + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(global_step, args.max_train_steps), + disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") + + for epoch in range(first_epoch, args.num_train_epochs): + unet.train() + train_loss = 0.0 + for step, batch in enumerate(train_dataloader): + # Skip steps until we reach the resumed step + if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: + if step % args.gradient_accumulation_steps == 0: + progress_bar.update(1) + continue + + with accelerator.accumulate(unet): + # first, convert images to latent space. + pixel_values = batch["pixel_values"].to(weight_dtype).to( + accelerator.device, non_blocking=True + ) + conditional_pixel_values = pixel_values[:, 0:1, :, :, :] + + latents = tensor_to_vae_latent(pixel_values, vae) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + cond_sigmas = rand_log_normal(shape=[bsz,], loc=-3.0, scale=0.5).to(latents) + noise_aug_strength = cond_sigmas[0] # TODO: support batch > 1 + cond_sigmas = cond_sigmas[:, None, None, None, None] + conditional_pixel_values = \ + torch.randn_like(conditional_pixel_values) * cond_sigmas + conditional_pixel_values + conditional_latents = tensor_to_vae_latent(conditional_pixel_values, vae)[:, 0, :, :, :] + conditional_latents = conditional_latents / vae.config.scaling_factor + + # Sample a random timestep for each image + # P_mean=0.7 P_std=1.6 + sigmas = rand_log_normal(shape=[bsz,], loc=0.7, scale=1.6).to(latents.device) + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + sigmas = sigmas[:, None, None, None, None] + noisy_latents = latents + noise * sigmas + timesteps = torch.Tensor( + [0.25 * sigma.log() for sigma in sigmas]).to(accelerator.device) + + inp_noisy_latents = noisy_latents / ((sigmas**2 + 1) ** 0.5) + + # Get the text embedding for conditioning. + encoder_hidden_states = encode_image( + pixel_values[:, 0, :, :, :].float()) + + # Here I input a fixed numerical value for 'motion_bucket_id', which is not reasonable. + # However, I am unable to fully align with the calculation method of the motion score, + # so I adopted this approach. The same applies to the 'fps' (frames per second). + added_time_ids = _get_add_time_ids( + 7, # fixed + 127, # motion_bucket_id = 127, fixed + noise_aug_strength, # noise_aug_strength == cond_sigmas + encoder_hidden_states.dtype, + bsz, + ) + added_time_ids = added_time_ids.to(latents.device) + + # Conditioning dropout to support classifier-free guidance during inference. For more details + # check out the section 3.2.1 of the original paper https://arxiv.org/abs/2211.09800. + if args.conditioning_dropout_prob is not None: + random_p = torch.rand( + bsz, device=latents.device, generator=generator) + # Sample masks for the edit prompts. + prompt_mask = random_p < 2 * args.conditioning_dropout_prob + prompt_mask = prompt_mask.reshape(bsz, 1, 1) + # Final text conditioning. + null_conditioning = torch.zeros_like(encoder_hidden_states) + encoder_hidden_states = torch.where( + prompt_mask, null_conditioning.unsqueeze(1), encoder_hidden_states.unsqueeze(1)) + # Sample masks for the original images. + image_mask_dtype = conditional_latents.dtype + image_mask = 1 - ( + (random_p >= args.conditioning_dropout_prob).to( + image_mask_dtype) + * (random_p < 3 * args.conditioning_dropout_prob).to(image_mask_dtype) + ) + image_mask = image_mask.reshape(bsz, 1, 1, 1) + # Final image conditioning. + conditional_latents = image_mask * conditional_latents + + # Concatenate the `conditional_latents` with the `noisy_latents`. + conditional_latents = conditional_latents.unsqueeze( + 1).repeat(1, noisy_latents.shape[1], 1, 1, 1) + inp_noisy_latents = torch.cat( + [inp_noisy_latents, conditional_latents], dim=2) + + # check https://arxiv.org/abs/2206.00364(the EDM-framework) for more details. + target = latents + model_pred = unet( + inp_noisy_latents, timesteps, encoder_hidden_states, added_time_ids=added_time_ids).sample + + # Denoise the latents + c_out = -sigmas / ((sigmas**2 + 1)**0.5) + c_skip = 1 / (sigmas**2 + 1) + denoised_latents = model_pred * c_out + c_skip * noisy_latents + weighing = (1 + sigmas ** 2) * (sigmas**-2.0) + + # MSE loss + loss = torch.mean( + (weighing.float() * (denoised_latents.float() - + target.float()) ** 2).reshape(target.shape[0], -1), + dim=1, + ) + loss = loss.mean() + + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = accelerator.gather( + loss.repeat(args.per_gpu_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + # if accelerator.sync_gradients: + # accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if accelerator.is_main_process: + # save checkpoints! + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [ + d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted( + checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len( + checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info( + f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join( + args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join( + args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + unwrapped_unet = accelerator.unwrap_model(unet) + unet_lora_state_dict = convert_state_dict_to_diffusers( + get_peft_model_state_dict(unwrapped_unet) + ) + + StableVideoDiffusionPipeline.save_lora_weights( + save_directory=save_path, + unet_lora_layers=unet_lora_state_dict, + safe_serialization=True, + ) + + logger.info(f"Saved state to {save_path}") + # sample images! + if ( + (global_step % args.validation_steps == 0) + or (global_step == 1) + ): + logger.info( + f"Running validation... \n Generating {args.num_validation_images} videos." + ) + # The models need unwrapping because for compatibility in distributed training mode. + pipeline = StableVideoDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet), + image_encoder=accelerator.unwrap_model( + image_encoder), + vae=accelerator.unwrap_model(vae), + revision=args.revision, + torch_dtype=weight_dtype, + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + val_save_dir = os.path.join( + args.output_dir, "validation_images") + + if not os.path.exists(val_save_dir): + os.makedirs(val_save_dir) + + with torch.autocast( + str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16" + ): + for val_img_idx in range(args.num_validation_images): + num_frames = args.num_frames + video_frames = pipeline( + load_image('demo.jpg').resize((args.width, args.height)), + height=args.height, + width=args.width, + num_frames=num_frames, + decode_chunk_size=8, + motion_bucket_id=127, + fps=7, + noise_aug_strength=0.02, + # generator=generator, + ).frames[0] + + out_file = os.path.join( + val_save_dir, + f"step_{global_step}_val_img_{val_img_idx}.mp4", + ) + + for i in range(num_frames): + img = video_frames[i] + video_frames[i] = np.array(img) + export_to_gif(video_frames, out_file, 8) + + del pipeline + torch.cuda.empty_cache() + + logs = {"step_loss": loss.detach().item( + ), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + # Create the pipeline using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unet = unet.to(torch.float32) + + unwrapped_unet = accelerator.unwrap_model(unet) + unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_unet)) + StableVideoDiffusionPipeline.save_lora_weights( + save_directory=args.output_dir, + unet_lora_layers=unet_lora_state_dict, + safe_serialization=True, + ) + + if args.push_to_hub: + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + accelerator.end_training() + + +if __name__ == "__main__": + main()