diff --git a/nemo/collections/diffusion/data/diffusion_energon_datamodule.py b/nemo/collections/diffusion/data/diffusion_energon_datamodule.py
index f18c828d9d45..230b7e582c87 100644
--- a/nemo/collections/diffusion/data/diffusion_energon_datamodule.py
+++ b/nemo/collections/diffusion/data/diffusion_energon_datamodule.py
@@ -15,7 +15,8 @@
 import logging
 from typing import Any, Dict, Literal
 
-from megatron.energon import DefaultTaskEncoder, get_train_dataset
+from megatron.core import parallel_state
+from megatron.energon import DefaultTaskEncoder, WorkerConfig, get_savable_loader, get_train_dataset
 from pytorch_lightning.utilities.types import EVAL_DATALOADERS
 
 from nemo.collections.multimodal.data.energon.base import SimpleMultiModalDataModule
@@ -56,6 +57,9 @@ def __init__(
         pin_memory: bool = True,
         task_encoder: DefaultTaskEncoder = None,
         use_train_split_for_val: bool = False,
+        virtual_epoch_length: int = 1_000_000_000,  # a hack to avoid the energon end-of-epoch warning
+        packing_buffer_size: int | None = None,
+        max_samples_per_sequence: int | None = None,
     ) -> None:
         """
         Initialize the SimpleMultiModalDataModule.
@@ -82,6 +86,10 @@ def __init__(
             task_encoder=task_encoder,
         )
         self.use_train_split_for_val = use_train_split_for_val
+        self.virtual_epoch_length = virtual_epoch_length
+        self.num_workers_val = 1
+        self.packing_buffer_size = packing_buffer_size
+        self.max_samples_per_sequence = max_samples_per_sequence
 
     def datasets_provider(self, worker_config, split: Literal['train', 'val'] = 'val'):
         """
@@ -106,29 +114,52 @@ def datasets_provider(self, worker_config, split: Literal['train', 'val'] = 'val
             batch_size=self.micro_batch_size,
             task_encoder=self.task_encoder,
             worker_config=worker_config,
-            max_samples_per_sequence=None,
-            shuffle_buffer_size=100,
+            max_samples_per_sequence=self.max_samples_per_sequence,
+            shuffle_buffer_size=None,
             split_part=split,
-            batch_drop_last=True,
-            virtual_epoch_length=1_000_000_000,  # a hack to avoid energon end of epoch warning
+            virtual_epoch_length=self.virtual_epoch_length,
+            packing_buffer_size=self.packing_buffer_size,
         )
         return _dataset
 
     def val_dataloader(self) -> EVAL_DATALOADERS:
         """
-        Configure the validation DataLoader.
+        Initialize and return the validation DataLoader.
 
-        This method configures the DataLoader for validation data.
-
-        Parameters:
-        worker_config: Configuration for the data loader workers.
+        This method builds the DataLoader for the validation dataset, deriving the energon worker configuration
+        from the megatron parallel state when it is initialized and falling back to a default config otherwise.
 
         Returns:
-        DataLoader: The DataLoader for validation data.
+            EVAL_DATALOADERS: The DataLoader for the validation dataset.
""" if self.use_train_split_for_val: return self.train_dataloader() - return super().val_dataloader() + if self.val_dataloader_object: + return self.val_dataloader_object + + if not parallel_state.is_initialized(): + logging.info( + f"Muiltimodal val data loader parallel state is not initialized, using default worker config with no_workers {self.num_workers}" + ) + worker_config = WorkerConfig.default_worker_config(self.num_workers_val) + else: + rank = parallel_state.get_data_parallel_rank() + world_size = parallel_state.get_data_parallel_world_size() + data_parallel_group = parallel_state.get_data_parallel_group() + + logging.info(f"rank {rank} world_size {world_size} data_parallel_group {data_parallel_group}") + worker_config = WorkerConfig( + rank=rank, + world_size=world_size, + num_workers=self.num_workers_val, + data_parallel_group=data_parallel_group, + worker_debug_path=None, + worker_log_level=0, + ) + val_dataset = self.datasets_provider(worker_config, split='val') + energon_loader = get_savable_loader(val_dataset, worker_config=worker_config) + self.val_dataloader_object = energon_loader + return self.val_dataloader_object def load_state_dict(self, state_dict: Dict[str, Any]) -> None: """ diff --git a/nemo/collections/diffusion/data/diffusion_taskencoder.py b/nemo/collections/diffusion/data/diffusion_taskencoder.py index 57e4e4ec8673..b0ab2f2ac234 100644 --- a/nemo/collections/diffusion/data/diffusion_taskencoder.py +++ b/nemo/collections/diffusion/data/diffusion_taskencoder.py @@ -12,15 +12,74 @@ # See the License for the specific language governing permissions and # limitations under the License. -import warnings +import random +from dataclasses import dataclass +from typing import Any, List, Optional + import torch import torch.nn.functional as F from einops import rearrange -from megatron.core import parallel_state -from megatron.energon import DefaultTaskEncoder, SkipSample +from megatron.energon import DefaultTaskEncoder, Sample, SkipSample +from megatron.energon.task_encoder.base import stateless from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys from nemo.lightning.io.mixin import IOMixin +from nemo.utils.sequence_packing_utils import first_fit_decreasing + + +@dataclass +class DiffusionSample(Sample): + video: torch.Tensor # video latents (C T H W) + t5_text_embeddings: torch.Tensor # (S D) + t5_text_mask: torch.Tensor # 1 + loss_mask: torch.Tensor + image_size: Optional[torch.Tensor] = None + fps: Optional[torch.Tensor] = None + num_frames: Optional[torch.Tensor] = None + padding_mask: Optional[torch.Tensor] = None + seq_len_q: Optional[torch.Tensor] = None + seq_len_kv: Optional[torch.Tensor] = None + pos_ids: Optional[torch.Tensor] = None + latent_shape: Optional[torch.Tensor] = None + + def to_dict(self) -> dict: + return dict( + video=self.video, + t5_text_embeddings=self.t5_text_embeddings, + t5_text_mask=self.t5_text_mask, + loss_mask=self.loss_mask, + image_size=self.image_size, + fps=self.fps, + num_frames=self.num_frames, + padding_mask=self.padding_mask, + seq_len_q=self.seq_len_q, + seq_len_kv=self.seq_len_kv, + pos_ids=self.pos_ids, + latent_shape=self.latent_shape, + ) + + def __add__(self, other: Any) -> int: + if isinstance(other, DiffusionSample): + # Combine the values of the two instances + return self.seq_len_q.item() + other.seq_len_q.item() + elif isinstance(other, int): + # Add an integer to the value + return self.seq_len_q.item() + other + raise NotImplementedError + + def __radd__(self, other: Any) -> int: + # This is 
+        # e.g. in sum([sample_a, sample_b]) the initial 0 + sample_a calls __radd__.
+        if isinstance(other, int):
+            return self.seq_len_q.item() + other
+        raise NotImplementedError
+
+    def __lt__(self, other: Any) -> bool:
+        if isinstance(other, DiffusionSample):
+            return self.seq_len_q.item() < other.seq_len_q.item()
+        elif isinstance(other, int):
+            return self.seq_len_q.item() < other
+        raise NotImplementedError
 
 
 def cook(sample: dict) -> dict:
@@ -75,17 +134,22 @@ def __init__(
         max_frames: int = None,
         text_embedding_padding_size: int = 512,
         seq_length: int = None,
+        max_seq_length: int = None,
         patch_spatial: int = 2,
         patch_temporal: int = 1,
+        aesthetic_score: float = 0.0,
         **kwargs,
     ):
         super().__init__(*args, **kwargs)
         self.max_frames = max_frames
         self.text_embedding_padding_size = text_embedding_padding_size
         self.seq_length = seq_length
+        self.max_seq_length = max_seq_length
         self.patch_spatial = patch_spatial
         self.patch_temporal = patch_temporal
+        self.aesthetic_score = aesthetic_score
 
+    @stateless(restore_seeds=True)
     def encode_sample(self, sample: dict) -> dict:
         video_latent = sample['pth']
 
@@ -95,6 +159,9 @@ def encode_sample(self, sample: dict) -> dict:
             raise SkipSample()
 
         info = sample['json']
+        if info['aesthetic_score'] < self.aesthetic_score:
+            raise SkipSample()
+
         C, T, H, W = video_latent.shape
         seq_len = (
             video_latent.shape[-1]
@@ -105,19 +172,14 @@ def encode_sample(self, sample: dict) -> dict:
         )
         is_image = T == 1
 
-        if seq_len > self.seq_length:
+        if self.seq_length is not None and seq_len > self.seq_length:
+            raise SkipSample()
+        if self.max_seq_length is not None and seq_len > self.max_seq_length:
             raise SkipSample()
 
         if self.max_frames is not None:
             video_latent = video_latent[:, : self.max_frames, :, :]
 
-        tpcp_size = parallel_state.get_tensor_model_parallel_world_size()
-        if parallel_state.get_context_parallel_world_size() > 1:
-            tpcp_size *= parallel_state.get_context_parallel_world_size() * 2
-        if (T * H * W) % tpcp_size != 0:
-            warnings.warn(f'skipping {video_latent.shape=} not divisible by {tpcp_size=}')
-            raise SkipSample()
-
         video_latent = rearrange(
             video_latent,
             'C (T pt) (H ph) (W pw) -> (T H W) (ph pw pt C)',
@@ -161,7 +223,7 @@ def encode_sample(self, sample: dict) -> dict:
             'T H W d -> (T H W) d',
         )
 
-        if self.seq_length is not None:
+        if self.seq_length is not None and self.max_seq_length is None:
             pos_ids = F.pad(pos_ids, (0, 0, 0, self.seq_length - seq_len))
             loss_mask = torch.zeros(self.seq_length, dtype=torch.bfloat16)
             loss_mask[:seq_len] = 1
@@ -169,7 +231,11 @@ def encode_sample(self, sample: dict) -> dict:
         else:
             loss_mask = torch.ones(seq_len, dtype=torch.bfloat16)
 
-        return dict(
+        return DiffusionSample(
+            __key__=sample['__key__'],
+            __restore_key__=sample['__restore_key__'],
+            __subflavor__=None,
+            __subflavors__=sample['__subflavors__'],
             video=video_latent,
             t5_text_embeddings=t5_text_embeddings,
             t5_text_mask=t5_text_mask,
@@ -178,11 +244,71 @@ def encode_sample(self, sample: dict) -> dict:
             num_frames=num_frames,
             loss_mask=loss_mask,
             seq_len_q=torch.tensor(seq_len, dtype=torch.int32),
-            seq_len_kv=torch.tensor(t5_text_embeddings_seq_length, dtype=torch.int32),
+            seq_len_kv=torch.tensor(self.text_embedding_padding_size, dtype=torch.int32),
             pos_ids=pos_ids,
             latent_shape=torch.tensor([C, T, H, W], dtype=torch.int32),
         )
 
+    def select_samples_to_pack(self, samples: List[DiffusionSample]) -> List[List[DiffusionSample]]:
+        results = first_fit_decreasing(samples, self.max_seq_length)
+        random.shuffle(results)
+        return results
+
+    @stateless
+    def pack_selected_samples(self, samples: List[DiffusionSample]) -> DiffusionSample:
+        # Construct a new DiffusionSample by concatenating the selected samples and padding to max_seq_length
+
+        def stack(attr):
+            return torch.stack([getattr(sample, attr) for sample in samples], dim=0)
+
+        def cat(attr):
+            return torch.cat([getattr(sample, attr) for sample in samples], dim=0)
+
+        video = concat_pad([i.video for i in samples], self.max_seq_length)
+        loss_mask = concat_pad([i.loss_mask for i in samples], self.max_seq_length)
+        pos_ids = concat_pad([i.pos_ids for i in samples], self.max_seq_length)
+
+        return DiffusionSample(
+            __key__=",".join([s.__key__ for s in samples]),
+            __restore_key__=(),  # Will be set by energon based on `samples`
+            __subflavor__=None,
+            __subflavors__=samples[0].__subflavors__,
+            video=video,
+            t5_text_embeddings=cat('t5_text_embeddings'),
+            t5_text_mask=cat('t5_text_mask'),
+            # image_size=stack('image_size'),
+            # fps=stack('fps'),
+            # num_frames=stack('num_frames'),
+            loss_mask=loss_mask,
+            seq_len_q=stack('seq_len_q'),
+            seq_len_kv=stack('seq_len_kv'),
+            pos_ids=pos_ids,
+            latent_shape=stack('latent_shape'),
+        )
+
+    @stateless
+    def batch(self, samples: List[DiffusionSample]) -> dict:
+        if self.max_seq_length is None:
+            # no packing
+            return super().batch(samples).to_dict()
+
+        # packing
+        sample = samples[0]
+        return dict(
+            video=sample.video.unsqueeze_(0),
+            t5_text_embeddings=sample.t5_text_embeddings.unsqueeze_(0),
+            t5_text_mask=sample.t5_text_mask.unsqueeze_(0),
+            loss_mask=sample.loss_mask.unsqueeze_(0),
+            # image_size=sample.image_size,
+            # fps=sample.fps,
+            # num_frames=sample.num_frames,
+            # padding_mask=sample.padding_mask.unsqueeze_(0),
+            seq_len_q=sample.seq_len_q,
+            seq_len_kv=sample.seq_len_kv,
+            pos_ids=sample.pos_ids.unsqueeze_(0),
+            latent_shape=sample.latent_shape,
+        )
+
 
 class PosID3D:
     def __init__(self, *, max_t=32, max_h=128, max_w=128):
@@ -210,4 +336,60 @@ def get_pos_id_3d(self, *, t, h, w):
         return self.grid[:t, :h, :w]
 
 
+def pad_divisible(x, padding_value=0):
+    if padding_value == 0:
+        return x
+    # Get the size of the first dimension
+    n = x.size(0)
+
+    # Compute the padding needed to make the first dimension divisible by padding_value
+    padding_needed = (padding_value - n % padding_value) % padding_value
+
+    if padding_needed <= 0:
+        return x
+
+    # Create a new shape with the padded first dimension
+    new_shape = list(x.shape)
+    new_shape[0] += padding_needed
+
+    # Create a new tensor filled with zeros
+    x_padded = torch.zeros(new_shape, dtype=x.dtype, device=x.device)
+
+    # Assign the original tensor to the beginning of the new tensor
+    x_padded[:n] = x
+    return x_padded
+
+
+def concat_pad(tensor_list, max_seq_length):
+    """
+    Efficiently concatenates a list of tensors along the first dimension and pads with zeros
+    to reach max_seq_length.
+
+    Args:
+        tensor_list (list of torch.Tensor): List of tensors to concatenate and pad.
+        max_seq_length (int): The desired size of the first dimension of the output tensor.
+
+    Returns:
+        torch.Tensor: A tensor of shape [max_seq_length, ...], where ... represents the remaining dimensions.
+ """ + import torch + + # Get common properties from the first tensor + other_shape = tensor_list[0].shape[1:] + dtype = tensor_list[0].dtype + device = tensor_list[0].device + + # Initialize the result tensor with zeros + result = torch.zeros((max_seq_length, *other_shape), dtype=dtype, device=device) + + current_index = 0 + for tensor in tensor_list: + length = tensor.shape[0] + # Directly assign the tensor to the result tensor without checks + result[current_index : current_index + length] = tensor + current_index += length + + return result + + pos_id_3d = PosID3D() diff --git a/nemo/collections/diffusion/models/model.py b/nemo/collections/diffusion/models/model.py index 05f635a1def7..aa690d5203b7 100644 --- a/nemo/collections/diffusion/models/model.py +++ b/nemo/collections/diffusion/models/model.py @@ -58,12 +58,12 @@ def dit_data_step(module, dataloader_iter): 'self_attention': PackedSeqParams( cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens, - qkv_format='sbhd', + qkv_format=module.qkv_format, ), 'cross_attention': PackedSeqParams( cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens_kv, - qkv_format='sbhd', + qkv_format=module.qkv_format, ), } @@ -77,9 +77,7 @@ def get_batch_on_this_cp_rank(data: Dict): cp_size = mpu.get_context_parallel_world_size() cp_rank = mpu.get_context_parallel_rank() - t = 16 if cp_size > 1: - assert t % cp_size == 0, "t must divisibly by cp_size" num_valid_tokens_in_ub = None if 'loss_mask' in data and data['loss_mask'] is not None: num_valid_tokens_in_ub = data['loss_mask'].sum() @@ -88,9 +86,13 @@ def get_batch_on_this_cp_rank(data: Dict): if (value is not None) and (key in ['video', 'video_latent', 'noise_latent', 'pos_ids']): if len(value.shape) > 5: value = value.squeeze(0) - B, C, T, H, W = value.shape + if len(value.shape) == 5: + B, C, T, H, W = value.shape + data[key] = value.view(B, C, cp_size, T // cp_size, H, W)[:, :, cp_rank, ...].contiguous() + else: + B, S, D = value.shape + data[key] = value.view(B, cp_size, S // cp_size, D)[:, cp_rank, ...].contiguous() # TODO: sequence packing - data[key] = value.view(B, C, cp_size, T // cp_size, H, W)[:, :, cp_rank, ...].contiguous() loss_mask = data["loss_mask"] data["loss_mask"] = loss_mask.view(loss_mask.shape[0], cp_size, loss_mask.shape[1] // cp_size)[ :, cp_rank, ...