Add STDiT model #11528

Draft · wants to merge 6 commits into base: main
83 changes: 83 additions & 0 deletions nemo/collections/diffusion/data/diffusion_fake_datamodule.py
@@ -144,6 +144,52 @@ def collate_fn(self, batch):
return self._collate_fn(batch)


class STDiTVideoLatentFakeDataset(DiTVideoLatentFakeDataset):
    """Fake dataset yielding constant video latents and random T5 text embeddings for STDiT."""

def __init__(
self,
n_frames,
max_h,
max_w,
patch_size,
in_channels,
crossattn_emb_size,
max_text_seqlen=512,
seq_length=8192,
):
super().__init__(
n_frames=n_frames,
max_h=max_h,
max_w=max_w,
patch_size=patch_size,
in_channels=in_channels,
crossattn_emb_size=crossattn_emb_size,
max_text_seqlen=max_text_seqlen,
seq_length=seq_length,
)

def __getitem__(self, idx):
t = self.max_t
h = self.max_height
w = self.max_width
p = self.patch_size
c = self.in_channels

video_latent = torch.ones((c, t, h, w), dtype=torch.bfloat16) * 0.5
text_embedding = torch.randn(self.text_seqlen, self.text_dim, dtype=torch.bfloat16)

        # sequence length after spatial patchification: T * (H // p) * (W // p)
        seq_length = t * (h // p) * (w // p)

return {
'video': video_latent,
't5_text_embeddings': text_embedding,
'seq_len_q': torch.tensor([seq_length], dtype=torch.int32).squeeze(),
'seq_len_kv': torch.tensor([self.text_seqlen], dtype=torch.int32).squeeze(),
'pos_ids': torch.zeros((seq_length, 3), dtype=torch.int32),
'loss_mask': torch.ones(seq_length, dtype=torch.bfloat16),
}
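
# --- Illustrative sketch, not part of the PR: inspect one fake sample. Assumes the
# parent DiTVideoLatentFakeDataset maps n_frames/max_h/max_w onto the max_t/max_height/
# max_width attributes used above; in_channels=4 is a made-up value.
def _sketch_inspect_sample():
    ds = STDiTVideoLatentFakeDataset(
        n_frames=24, max_h=128, max_w=128, patch_size=2,
        in_channels=4, crossattn_emb_size=1024,
    )
    sample = ds[0]
    print(sample['video'].shape)    # (C, T, H, W) -> torch.Size([4, 24, 128, 128])
    print(sample['pos_ids'].shape)  # (T * (H//p) * (W//p), 3) -> torch.Size([98304, 3])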


class VideoLatentFakeDataModule(pl.LightningDataModule):
"""A LightningDataModule for generating fake video latent data for training."""

@@ -216,3 +262,40 @@ def _create_dataloader(self, dataset, **kwargs) -> DataLoader:
collate_fn=dataset.collate_fn,
**kwargs,
)


class STDiTLatentFakeDataModule(VideoLatentFakeDataModule):
    """A LightningDataModule serving STDiTVideoLatentFakeDataset batches for training."""

def __init__(
self,
model_config: DiTConfig,
seq_length: int = 2048,
micro_batch_size: int = 1,
global_batch_size: int = 8,
num_workers: int = 1,
pin_memory: bool = True,
task_encoder=None,
use_train_split_for_val: bool = False,
):
super().__init__()
self.seq_length = seq_length
self.micro_batch_size = micro_batch_size
self.global_batch_size = global_batch_size
self.num_workers = num_workers
self.model_config = model_config

self.data_sampler = MegatronDataSampler(
seq_len=self.seq_length,
micro_batch_size=micro_batch_size,
global_batch_size=global_batch_size,
)

def setup(self, stage: str = "") -> None:
self._train_ds = STDiTVideoLatentFakeDataset(
n_frames=self.model_config.max_frames,
max_h=self.model_config.max_img_h,
max_w=self.model_config.max_img_w,
patch_size=self.model_config.patch_spatial,
in_channels=self.model_config.in_channels,
crossattn_emb_size=self.model_config.crossattn_emb_size,
seq_length=self.seq_length,
)
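
# --- Illustrative sketch, not part of the PR: wire the fake datamodule to a config.
# Assumes STDiTConfig (imported from nemo.collections.diffusion.models.model) with its
# defaults (max_frames=24, max_img_h=max_img_w=128, patch_spatial=2), and that the
# datamodule and sampler construct cleanly outside a distributed run.
def _sketch_build_datamodule():
    config = STDiTConfig()
    tokens = config.max_frames * (config.max_img_h // config.patch_spatial) * (
        config.max_img_w // config.patch_spatial
    )  # 24 * 64 * 64 = 98304 tokens per sample
    data = STDiTLatentFakeDataModule(model_config=config, seq_length=tokens)
    data.setup()
    return data._train_ds[0]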
117 changes: 117 additions & 0 deletions nemo/collections/diffusion/models/model.py
@@ -37,6 +37,7 @@
from nemo.lightning.pytorch.optim import OptimizerModule

from .dit.dit_model import DiTCrossAttentionModel
from .stdit.stdit_model import STDiTModel


def dit_forward_step(model, batch) -> torch.Tensor:
@@ -73,6 +74,47 @@ def dit_data_step(module, dataloader_iter):
return batch


def stdit_data_step(module, dataloader_iter):
    """Fetch one batch, shard it for context parallelism, and move tensors to the GPU."""
    batch = next(dataloader_iter)[0]
batch = stdit_get_batch_on_this_cp_rank(batch)
batch = {k: v.to(device='cuda', non_blocking=True) if torch.is_tensor(v) else v for k, v in batch.items()}

return batch


def stdit_get_batch_on_this_cp_rank(data: Dict):
"""Split the data for context parallelism."""
from megatron.core import mpu

cp_size = mpu.get_context_parallel_world_size()
cp_rank = mpu.get_context_parallel_rank()

if cp_size > 1:
num_valid_tokens_in_ub = None
if 'loss_mask' in data and data['loss_mask'] is not None:
num_valid_tokens_in_ub = data['loss_mask'].sum()

for key, value in data.items():
if (value is not None) and (key in ['video', 'video_latent', 'noise_latent', 'pos_ids']):
if len(value.shape) > 5:
value = value.squeeze(0)
                if len(value.shape) == 5:
                    # 5-D latent (B, C, T, H, W): chunk along dim=3 (height) across CP ranks
                    B, C, T, H, W = value.shape
                    data[key] = torch.chunk(value, cp_size, dim=3)[cp_rank].contiguous()
                else:
                    # TODO: a spatially aware split needs the T, H, W extents; for now shard
                    # the flattened sequence dimension evenly across CP ranks.
                    B, S, D = value.shape
                    data[key] = value.view(B, cp_size, S // cp_size, D)[:, cp_rank, ...].contiguous()
# TODO: sequence packing
loss_mask = data["loss_mask"]
data["loss_mask"] = loss_mask.view(loss_mask.shape[0], cp_size, loss_mask.shape[1] // cp_size)[
:, cp_rank, ...
].contiguous()
data['num_valid_tokens_in_ub'] = num_valid_tokens_in_ub
return data
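
# --- Illustrative sketch, not part of the PR: the chunking arithmetic above, worked for
# cp_size=2 and cp_rank=0 without initializing Megatron (tensor sizes are made up).
def _sketch_cp_split():
    video = torch.zeros(1, 4, 24, 128, 128)    # (B, C, T, H, W)
    shard = torch.chunk(video, 2, dim=3)[0]    # 5-D branch: chunk along dim=3 (height)
    assert shard.shape == (1, 4, 24, 64, 128)

    pos_ids = torch.zeros(1, 98304, 3, dtype=torch.int32)  # (B, S, D); S divisible by cp_size
    shard = pos_ids.view(1, 2, 98304 // 2, 3)[:, 0, ...].contiguous()  # 3-D branch
    assert shard.shape == (1, 49152, 3)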


def get_batch_on_this_cp_rank(data: Dict):
"""Split the data for context parallelism."""
from megatron.core import mpu
@@ -283,6 +325,81 @@ class ECDiTLlama1BConfig(DiTLlama1BConfig):
ffn_hidden_size: int = 1024


@dataclass
class STDiTConfig(DiTConfig):
    """Base configuration for the STDiT (spatial-temporal DiT) model."""

    # model architecture
num_layers: int = 28
hidden_size: int = 1152
num_attention_heads: int = 16
crossattn_emb_size: int = 1024
ffn_hidden_size: int = 4608

add_bias_linear: bool = True

    # video / latent geometry
max_img_h: int = 128
max_img_w: int = 128
max_frames: int = 24
patch_spatial: int = 2

dynamic_sequence_parallel: bool = False

@override
def configure_model(self, tokenizer=None) -> STDiTModel:
vp_size = self.virtual_pipeline_model_parallel_size
if vp_size:
p_size = self.pipeline_model_parallel_size
assert (
self.num_layers // p_size
) % vp_size == 0, "Make sure the number of model chunks is the same across all pipeline stages."

model = STDiTModel

return model(
self,
fp16_lm_cross_entropy=self.fp16_lm_cross_entropy,
parallel_output=self.parallel_output,
pre_process=parallel_state.is_pipeline_first_stage(),
post_process=parallel_state.is_pipeline_last_stage(),
max_img_h=self.max_img_h,
max_img_w=self.max_img_w,
max_frames=self.max_frames,
patch_spatial=self.patch_spatial,
dynamic_sequence_parallel=self.dynamic_sequence_parallel,
)
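
# --- Illustrative sketch, not part of the PR: the divisibility enforced by the assert in
# configure_model(), worked for this config's 28 layers (parallel sizes are made up).
def _sketch_vp_divisibility():
    num_layers, p_size = 28, 2    # 14 layers per pipeline stage
    for vp_size in (2, 7):
        assert (num_layers // p_size) % vp_size == 0  # 14 % 2 == 0 and 14 % 7 == 0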


@dataclass
class STDiTV3_XLConfig(STDiTConfig):
    """STDiT-v3 XL preset (identical dimensions to the base STDiTConfig)."""

num_layers: int = 28
hidden_size: int = 1152
num_attention_heads: int = 16
crossattn_emb_size: int = 1024
ffn_hidden_size: int = 4608


@dataclass
class STDiTXLConfig(STDiTConfig):
    """STDiT XL preset: 24 layers, hidden size 1536, 12 attention heads."""

num_layers: int = 24
hidden_size: int = 1536
num_attention_heads: int = 12
crossattn_emb_size: int = 1024
ffn_hidden_size: int = 6144


@dataclass
class STDiT3BConfig(STDiTConfig):
    """STDiT 3B preset: 24 layers, hidden size 2048, 16 attention heads."""

num_layers: int = 24
hidden_size: int = 2048
num_attention_heads: int = 16
crossattn_emb_size: int = 1024
ffn_hidden_size: int = 8192


class DiTModel(GPTModel):
"""
Diffusion Transformer Model
13 changes: 13 additions & 0 deletions nemo/collections/diffusion/models/stdit/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.