packed sequence support
Signed-off-by: Zeeshan Patel <[email protected]>
Zeeshan Patel committed Nov 8, 2024
1 parent 628ecc6 commit 86d59ad
Showing 3 changed files with 247 additions and 32 deletions.
55 changes: 43 additions & 12 deletions nemo/collections/diffusion/data/diffusion_energon_datamodule.py
@@ -15,7 +15,8 @@
import logging
from typing import Any, Dict, Literal

-from megatron.energon import DefaultTaskEncoder, get_train_dataset
+from megatron.core import parallel_state
+from megatron.energon import DefaultTaskEncoder, WorkerConfig, get_savable_loader, get_train_dataset
from pytorch_lightning.utilities.types import EVAL_DATALOADERS

from nemo.collections.multimodal.data.energon.base import SimpleMultiModalDataModule
@@ -56,6 +57,9 @@ def __init__(
pin_memory: bool = True,
task_encoder: DefaultTaskEncoder = None,
use_train_split_for_val: bool = False,
virtual_epoch_length: int = 1_000_000_000, # a hack to avoid energon end of epoch warning
packing_buffer_size: int | None = None,
max_samples_per_sequence: int | None = None,
) -> None:
"""
Initialize the SimpleMultiModalDataModule.
@@ -82,6 +86,10 @@
task_encoder=task_encoder,
)
self.use_train_split_for_val = use_train_split_for_val
self.virtual_epoch_length = virtual_epoch_length
self.num_workers_val = 1
self.packing_buffer_size = packing_buffer_size
self.max_samples_per_sequence = max_samples_per_sequence

def datasets_provider(self, worker_config, split: Literal['train', 'val'] = 'val'):
"""
@@ -106,29 +114,52 @@ def datasets_provider(self, worker_config, split: Literal['train', 'val'] = 'val
batch_size=self.micro_batch_size,
task_encoder=self.task_encoder,
worker_config=worker_config,
-max_samples_per_sequence=None,
-shuffle_buffer_size=100,
+max_samples_per_sequence=self.max_samples_per_sequence,
+shuffle_buffer_size=None,
split_part=split,
batch_drop_last=True,
-virtual_epoch_length=1_000_000_000, # a hack to avoid energon end of epoch warning
+virtual_epoch_length=self.virtual_epoch_length,
+packing_buffer_size=self.packing_buffer_size,
)
return _dataset

def val_dataloader(self) -> EVAL_DATALOADERS:
"""
-Configure the validation DataLoader.
+Initialize and return the validation DataLoader.
-This method configures the DataLoader for validation data.
-Parameters:
-worker_config: Configuration for the data loader workers.
+This method initializes the DataLoader for the validation dataset. It ensures that the parallel state
+is initialized correctly for distributed training and returns a configured DataLoader object.
Returns:
-DataLoader: The DataLoader for validation data.
+EVAL_DATALOADERS: The DataLoader for the validation dataset.
"""
if self.use_train_split_for_val:
return self.train_dataloader()
-return super().val_dataloader()
+if self.val_dataloader_object:
+return self.val_dataloader_object

if not parallel_state.is_initialized():
logging.info(
f"Multimodal val data loader: parallel state is not initialized, using default worker config with num_workers {self.num_workers_val}"
)
worker_config = WorkerConfig.default_worker_config(self.num_workers_val)
else:
rank = parallel_state.get_data_parallel_rank()
world_size = parallel_state.get_data_parallel_world_size()
data_parallel_group = parallel_state.get_data_parallel_group()

logging.info(f"rank {rank} world_size {world_size} data_parallel_group {data_parallel_group}")
worker_config = WorkerConfig(
rank=rank,
world_size=world_size,
num_workers=self.num_workers_val,
data_parallel_group=data_parallel_group,
worker_debug_path=None,
worker_log_level=0,
)
val_dataset = self.datasets_provider(worker_config, split='val')
energon_loader = get_savable_loader(val_dataset, worker_config=worker_config)
self.val_dataloader_object = energon_loader
return self.val_dataloader_object

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""
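
Taken together, the datamodule changes route the two new packing knobs (packing_buffer_size, max_samples_per_sequence) into energon's get_train_dataset and give validation its own savable loader built from a data-parallel WorkerConfig. A minimal wiring sketch follows; the class name DiffusionDataModule, the path argument, and the task-encoder name are illustrative assumptions, not names confirmed by this diff:

# Hypothetical usage sketch -- the imported names are assumptions, not taken from this commit.
from nemo.collections.diffusion.data.diffusion_energon_datamodule import DiffusionDataModule
from nemo.collections.diffusion.data.diffusion_taskencoder import BasicDiffusionTaskEncoder

data = DiffusionDataModule(
    path="/data/energon_dataset",  # assumed: location of the energon-formatted dataset
    micro_batch_size=1,  # a packed sequence already holds several samples
    num_workers=8,
    task_encoder=BasicDiffusionTaskEncoder(max_seq_length=8192),  # max_seq_length turns packing on
    packing_buffer_size=100,  # how many samples energon buffers before bin-packing
    max_samples_per_sequence=None,
)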
210 changes: 196 additions & 14 deletions nemo/collections/diffusion/data/diffusion_taskencoder.py
@@ -12,15 +12,74 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-import warnings
+import random
+from dataclasses import dataclass
+from typing import Any, List, Optional

import torch
import torch.nn.functional as F
from einops import rearrange
from megatron.core import parallel_state
-from megatron.energon import DefaultTaskEncoder, SkipSample
+from megatron.energon import DefaultTaskEncoder, Sample, SkipSample
+from megatron.energon.task_encoder.base import stateless
+from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys

from nemo.lightning.io.mixin import IOMixin
from nemo.utils.sequence_packing_utils import first_fit_decreasing


@dataclass
class DiffusionSample(Sample):
video: torch.Tensor # video latents (C T H W)
t5_text_embeddings: torch.Tensor # (S D)
t5_text_mask: torch.Tensor # 1
loss_mask: torch.Tensor
image_size: Optional[torch.Tensor] = None
fps: Optional[torch.Tensor] = None
num_frames: Optional[torch.Tensor] = None
padding_mask: Optional[torch.Tensor] = None
seq_len_q: Optional[torch.Tensor] = None
seq_len_kv: Optional[torch.Tensor] = None
pos_ids: Optional[torch.Tensor] = None
latent_shape: Optional[torch.Tensor] = None

def to_dict(self) -> dict:
return dict(
video=self.video,
t5_text_embeddings=self.t5_text_embeddings,
t5_text_mask=self.t5_text_mask,
loss_mask=self.loss_mask,
image_size=self.image_size,
fps=self.fps,
num_frames=self.num_frames,
padding_mask=self.padding_mask,
seq_len_q=self.seq_len_q,
seq_len_kv=self.seq_len_kv,
pos_ids=self.pos_ids,
latent_shape=self.latent_shape,
)

def __add__(self, other: Any) -> int:
if isinstance(other, DiffusionSample):
# Adding two samples yields the sum of their query sequence lengths
return self.seq_len_q.item() + other.seq_len_q.item()
elif isinstance(other, int):
# Adding an integer offsets this sample's query sequence length
return self.seq_len_q.item() + other
raise NotImplementedError

def __radd__(self, other: Any) -> int:
# Called when an operation starts from a non-DiffusionSample object:
# e.g., sum(samples) begins with 0 + samples[0], which invokes __radd__.
if isinstance(other, int):
return self.seq_len_q.item() + other
raise NotImplementedError

def __lt__(self, other: Any) -> bool:
if isinstance(other, DiffusionSample):
return self.seq_len_q.item() < other.seq_len_q.item()
elif isinstance(other, int):
return self.seq_len_q.item() < other
raise NotImplementedError
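
The __add__/__radd__/__lt__ overloads above make a DiffusionSample interchangeable with its own seq_len_q in integer arithmetic and comparisons, which is what lets a generic bin-packing routine sort samples and total up bin occupancy without knowing the type. A self-contained toy illustration (assuming the import path below; not code from this commit):

import torch
from nemo.collections.diffusion.data.diffusion_taskencoder import DiffusionSample

def toy(seq_len: int) -> DiffusionSample:
    # Minimal sample: only seq_len_q matters for the packing arithmetic.
    return DiffusionSample(
        __key__="toy", __restore_key__=(), __subflavor__=None, __subflavors__={},
        video=torch.zeros(seq_len, 8), t5_text_embeddings=torch.zeros(4, 8),
        t5_text_mask=torch.ones(4), loss_mask=torch.ones(seq_len),
        seq_len_q=torch.tensor(seq_len, dtype=torch.int32),
    )

samples = [toy(30), toy(10), toy(20)]
samples.sort(reverse=True)  # __lt__: order by sequence length, longest first
print(sum(samples))  # __radd__/__add__: 0 + 30 + 20 + 10 -> 60
print(samples[-1] < 15)  # compare directly against an int -> True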


def cook(sample: dict) -> dict:
@@ -75,17 +134,22 @@ def __init__(
max_frames: int = None,
text_embedding_padding_size: int = 512,
seq_length: int = None,
max_seq_length: int = None,
patch_spatial: int = 2,
patch_temporal: int = 1,
aesthetic_score: float = 0.0,
**kwargs,
):
super().__init__(*args, **kwargs)
self.max_frames = max_frames
self.text_embedding_padding_size = text_embedding_padding_size
self.seq_length = seq_length
self.max_seq_length = max_seq_length
self.patch_spatial = patch_spatial
self.patch_temporal = patch_temporal
self.aesthetic_score = aesthetic_score

@stateless(restore_seeds=True)
def encode_sample(self, sample: dict) -> dict:
video_latent = sample['pth']

@@ -95,6 +159,9 @@ def encode_sample(self, sample: dict) -> dict:
raise SkipSample()

info = sample['json']
if info['aesthetic_score'] < self.aesthetic_score:
raise SkipSample()

C, T, H, W = video_latent.shape
seq_len = (
video_latent.shape[-1]
@@ -105,19 +172,14 @@ def encode_sample(self, sample: dict) -> dict:
)
is_image = T == 1

-if seq_len > self.seq_length:
+if self.seq_length is not None and seq_len > self.seq_length:
raise SkipSample()
+if self.max_seq_length is not None and seq_len > self.max_seq_length:
+raise SkipSample()

if self.max_frames is not None:
video_latent = video_latent[:, : self.max_frames, :, :]

-tpcp_size = parallel_state.get_tensor_model_parallel_world_size()
-if parallel_state.get_context_parallel_world_size() > 1:
-tpcp_size *= parallel_state.get_context_parallel_world_size() * 2
-if (T * H * W) % tpcp_size != 0:
-warnings.warn(f'skipping {video_latent.shape=} not divisible by {tpcp_size=}')
-raise SkipSample()

video_latent = rearrange(
video_latent,
'C (T pt) (H ph) (W pw) -> (T H W) (ph pw pt C)',
@@ -161,15 +223,19 @@ def encode_sample(self, sample: dict) -> dict:
'T H W d -> (T H W) d',
)

-if self.seq_length is not None:
+if self.seq_length is not None and self.max_seq_length is None:
pos_ids = F.pad(pos_ids, (0, 0, 0, self.seq_length - seq_len))
loss_mask = torch.zeros(self.seq_length, dtype=torch.bfloat16)
loss_mask[:seq_len] = 1
video_latent = F.pad(video_latent, (0, 0, 0, self.seq_length - seq_len))
else:
loss_mask = torch.ones(seq_len, dtype=torch.bfloat16)

-return dict(
+return DiffusionSample(
+__key__=sample['__key__'],
+__restore_key__=sample['__restore_key__'],
+__subflavor__=None,
+__subflavors__=sample['__subflavors__'],
video=video_latent,
t5_text_embeddings=t5_text_embeddings,
t5_text_mask=t5_text_mask,
@@ -178,11 +244,71 @@ def encode_sample(self, sample: dict) -> dict:
num_frames=num_frames,
loss_mask=loss_mask,
seq_len_q=torch.tensor(seq_len, dtype=torch.int32),
-seq_len_kv=torch.tensor(t5_text_embeddings_seq_length, dtype=torch.int32),
+seq_len_kv=torch.tensor(self.text_embedding_padding_size, dtype=torch.int32),
pos_ids=pos_ids,
latent_shape=torch.tensor([C, T, H, W], dtype=torch.int32),
)

def select_samples_to_pack(self, samples: List[DiffusionSample]) -> List[List[DiffusionSample]]:
results = first_fit_decreasing(samples, self.max_seq_length)
random.shuffle(results)
return results
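
select_samples_to_pack hands the buffered samples to NeMo's first_fit_decreasing helper, which (via the operator overloads above) bins samples by summed sequence length, and then shuffles the bins so that bin order does not correlate with length. With plain ints standing in for samples, the grouping behaves roughly like this:

from nemo.utils.sequence_packing_utils import first_fit_decreasing

# max_seq_length = 1024; each inner list becomes one packed sequence
bins = first_fit_decreasing([700, 500, 400, 300], 1024)
print(bins)  # [[700, 300], [500, 400]] -- both bins fit under 1024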

@stateless
def pack_selected_samples(self, samples: List[DiffusionSample]) -> DiffusionSample:
# Construct a new DiffusionSample by packing the selected samples into a single sequence

def stack(attr):
return torch.stack([getattr(sample, attr) for sample in samples], dim=0)

def cat(attr):
return torch.cat([getattr(sample, attr) for sample in samples], dim=0)

video = concat_pad([i.video for i in samples], self.max_seq_length)
loss_mask = concat_pad([i.loss_mask for i in samples], self.max_seq_length)
pos_ids = concat_pad([i.pos_ids for i in samples], self.max_seq_length)

return DiffusionSample(
__key__=",".join([s.__key__ for s in samples]),
__restore_key__=(), # Will be set by energon based on `samples`
__subflavor__=None,
__subflavors__=samples[0].__subflavors__,
video=video,
t5_text_embeddings=cat('t5_text_embeddings'),
t5_text_mask=cat('t5_text_mask'),
# image_size=stack('image_size'),
# fps=stack('fps'),
# num_frames=stack('num_frames'),
loss_mask=loss_mask,
seq_len_q=stack('seq_len_q'),
seq_len_kv=stack('seq_len_kv'),
pos_ids=pos_ids,
latent_shape=stack('latent_shape'),
)

@stateless
def batch(self, samples: List[DiffusionSample]) -> dict:
if self.max_seq_length is None:
# no packing
return super().batch(samples).to_dict()

# packing
sample = samples[0]
return dict(
video=sample.video.unsqueeze_(0),
t5_text_embeddings=sample.t5_text_embeddings.unsqueeze_(0),
t5_text_mask=sample.t5_text_mask.unsqueeze_(0),
loss_mask=sample.loss_mask.unsqueeze_(0),
# image_size=sample.image_size,
# fps=sample.fps,
# num_frames=sample.num_frames,
# padding_mask=sample.padding_mask.unsqueeze_(0),
seq_len_q=sample.seq_len_q,
seq_len_kv=sample.seq_len_kv,
pos_ids=sample.pos_ids.unsqueeze_(0),
latent_shape=sample.latent_shape,
)
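
With packing enabled, batch receives exactly one pre-packed sample per micro-batch, so it only has to add the leading batch dimension (the unsqueeze_ calls) while the stacked per-subsample seq_len_q/seq_len_kv carry the subsequence boundaries into the model. A sketch of how a downstream consumer might turn those lengths into cumulative offsets for variable-length attention (this consumer step is an assumption, not shown in the commit):

import torch

seq_len_q = torch.tensor([300, 700], dtype=torch.int32)  # stacked lengths from a packed batch
cu_seqlens = torch.zeros(len(seq_len_q) + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seq_len_q, dim=0)  # -> tensor([0, 300, 1000])
# cu_seqlens delimits each original sample inside the flattened packed sequence.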


class PosID3D:
def __init__(self, *, max_t=32, max_h=128, max_w=128):
@@ -210,4 +336,60 @@ def get_pos_id_3d(self, *, t, h, w):
return self.grid[:t, :h, :w]


def pad_divisible(x, padding_value=0):
if padding_value == 0:
return x
# Get the size of the first dimension
n = x.size(0)

# Compute the padding needed to make the first dimension divisible by padding_value
padding_needed = (padding_value - n % padding_value) % padding_value

if padding_needed <= 0:
return x

# Create a new shape with the padded first dimension
new_shape = list(x.shape)
new_shape[0] += padding_needed

# Create a new tensor filled with zeros
x_padded = torch.zeros(new_shape, dtype=x.dtype, device=x.device)

# Assign the original tensor to the beginning of the new tensor
x_padded[:n] = x
return x_padded
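
pad_divisible zero-pads a tensor's first dimension up to the next multiple of padding_value, returning the input untouched when padding_value is 0 (which also sidesteps a modulo-by-zero). For example:

import torch
from nemo.collections.diffusion.data.diffusion_taskencoder import pad_divisible

x = torch.randn(10, 4)
print(pad_divisible(x, 16).shape)  # torch.Size([16, 4]): six zero rows appended
print(pad_divisible(x, 5).shape)  # torch.Size([10, 4]): already divisible, returned as-is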


def concat_pad(tensor_list, max_seq_length):
"""
Efficiently concatenates a list of tensors along the first dimension and pads with zeros
to reach max_seq_length.
Args:
tensor_list (list of torch.Tensor): List of tensors to concatenate and pad.
max_seq_length (int): The desired size of the first dimension of the output tensor.
Returns:
torch.Tensor: A tensor of shape [max_seq_length, ...], where ... represents the remaining dimensions.
"""

# Get common properties from the first tensor
other_shape = tensor_list[0].shape[1:]
dtype = tensor_list[0].dtype
device = tensor_list[0].device

# Initialize the result tensor with zeros
result = torch.zeros((max_seq_length, *other_shape), dtype=dtype, device=device)

current_index = 0
for tensor in tensor_list:
length = tensor.shape[0]
# Copy each tensor into its slot; assumes the summed lengths never exceed max_seq_length
result[current_index : current_index + length] = tensor
current_index += length

return result
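
concat_pad is the workhorse of pack_selected_samples: it lays the samples' tensors back-to-back along the first dimension and zero-fills the remainder up to max_seq_length, so every packed sequence comes out with an identical shape. A quick check of the layout:

import torch
from nemo.collections.diffusion.data.diffusion_taskencoder import concat_pad

a = torch.ones(3, 2)
b = torch.full((2, 2), 2.0)
packed = concat_pad([a, b], max_seq_length=8)
print(packed[:, 0])  # tensor([1., 1., 1., 2., 2., 0., 0., 0.]) -- a, then b, then zero padding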


pos_id_3d = PosID3D()
[diff for the third changed file did not render]