Neva ETP EPP support #12154

Open
wants to merge 103 commits into base: main

Changes from all commits · 103 commits
15439bf
api updates and fixes
yaoyu-33 Dec 12, 2024
6bfd873
Apply isort and black reformatting
yaoyu-33 Dec 12, 2024
773b4c9
fix
yaoyu-33 Dec 12, 2024
3a1a017
fix arg
yaoyu-33 Dec 12, 2024
e3e87b7
update seq packing in mock ds
yaoyu-33 Dec 16, 2024
4ee633c
Merge branch 'main' into yuya/neva2_seq_packing
yaoyu-33 Dec 16, 2024
ecc813d
Merge branch 'main' into yuya/neva2_seq_packing
yaoyu-33 Dec 17, 2024
c10157c
save
yaoyu-33 Dec 17, 2024
84eb7cc
update preprocess_data
yaoyu-33 Dec 17, 2024
3bf6442
update seq packing
yaoyu-33 Dec 17, 2024
c8a26af
Apply isort and black reformatting
yaoyu-33 Dec 17, 2024
48b5261
fix sp
yaoyu-33 Dec 17, 2024
365c051
Merge remote-tracking branch 'origin/yuya/neva2_seq_packing' into yuy…
yaoyu-33 Dec 17, 2024
7da82ed
save
yaoyu-33 Dec 18, 2024
4127c40
Merge branch 'main' into yuya/neva2_seq_packing
yaoyu-33 Dec 18, 2024
c5d26c3
fix seq packing
yaoyu-33 Dec 18, 2024
ecd461f
add truncation and padding
yaoyu-33 Dec 19, 2024
5e0a168
Apply isort and black reformatting
yaoyu-33 Dec 19, 2024
9240a79
Fix issues
yaoyu-33 Dec 19, 2024
4808999
change LLaVATemplateConfig variables to class variables
yashaswikarnati Dec 19, 2024
c4d92f9
change to use field with default attributes
yashaswikarnati Dec 19, 2024
ad44132
Apply isort and black reformatting
yashaswikarnati Dec 19, 2024
7db8e52
Merge remote-tracking branch 'origin/yuya/neva2_seq_packing' into yuy…
yaoyu-33 Dec 19, 2024
e705afe
Apply isort and black reformatting
yaoyu-33 Dec 19, 2024
f0a9cb5
Merge remote-tracking branch 'origin/yash/fix_template_dataclass' int…
yaoyu-33 Dec 19, 2024
f508f8b
Initial support for CP
parthmannan Dec 31, 2024
7415036
Add seq packing option in energon
yaoyu-33 Dec 31, 2024
af1f32a
Fix energon conversation
yaoyu-33 Dec 31, 2024
568f9aa
add energon option in neva training script
yaoyu-33 Dec 31, 2024
01fd6cf
Apply isort and black reformatting
yaoyu-33 Dec 31, 2024
2a38eb6
Apply isort and black reformatting
parthmannan Dec 31, 2024
bd0179c
Improvements
parthmannan Jan 3, 2025
e491d3e
Merge branch 'pmannan/neva_cp_seq_packing' of https://github.com/NVID…
parthmannan Jan 3, 2025
c69b9e8
Merge branch 'yuya/neva2_seq_packing' of https://github.com/NVIDIA/Ne…
parthmannan Jan 3, 2025
094ef9a
add ci test for packed seq
yaoyu-33 Jan 3, 2025
325a9a1
Fix for PP+CP
parthmannan Jan 6, 2025
8b67987
Max seq len fix
parthmannan Jan 6, 2025
626bbc3
fix mock dataset seq packing
yaoyu-33 Jan 7, 2025
18aa644
Apply isort and black reformatting
yaoyu-33 Jan 7, 2025
b4f7e8b
fix mock dataset seq packing
yaoyu-33 Jan 7, 2025
2ccea79
Merge remote-tracking branch 'origin/yuya/neva2_seq_packing' into yuy…
yaoyu-33 Jan 7, 2025
a2a4000
Apply isort and black reformatting
yaoyu-33 Jan 7, 2025
0599b5a
Merge branch 'main' into yuya/neva2_seq_packing
yaoyu-33 Jan 7, 2025
90778e1
fix lint and update seq pack func
yaoyu-33 Jan 7, 2025
38a6c49
Merge remote-tracking branch 'origin/yuya/neva2_seq_packing' into yuy…
yaoyu-33 Jan 7, 2025
f0ec5f1
fix energon module
yaoyu-33 Jan 7, 2025
ff45f7e
Apply isort and black reformatting
yaoyu-33 Jan 7, 2025
b1c6af9
Merge branch 'refs/heads/yuya/neva2_seq_packing' into pmannan/neva_cp
yaoyu-33 Jan 8, 2025
38e42a2
fix comments
yaoyu-33 Jan 8, 2025
eadc665
Apply isort and black reformatting
yaoyu-33 Jan 8, 2025
846252f
address lightning issues
yaoyu-33 Jan 8, 2025
d70e432
Merge remote-tracking branch 'origin/yuya/neva2_seq_packing' into yuy…
yaoyu-33 Jan 8, 2025
a2290de
Apply isort and black reformatting
yaoyu-33 Jan 8, 2025
d9b7520
reformat
yaoyu-33 Jan 8, 2025
12023b9
fix import
yaoyu-33 Jan 8, 2025
6db7636
rename to base.py
yaoyu-33 Jan 8, 2025
fc8e6da
fix few issues from importing
yaoyu-33 Jan 9, 2025
2385eaa
temp save for intern vit
yaoyu-33 Jan 9, 2025
3fdfe3e
Update sequence_packing.py
yaoyu-33 Jan 9, 2025
96fd7be
save for partially working internvit
yaoyu-33 Jan 9, 2025
04c6796
added support for importing clip vit
yashaswikarnati Jan 10, 2025
5bd3723
Move projector
yaoyu-33 Jan 10, 2025
c8f11fb
fix intern_vit
yaoyu-33 Jan 10, 2025
9249af5
fix intern_vit conversion
yaoyu-33 Jan 10, 2025
d24cd3b
Merge branch 'main' into yuya/neva2_seq_packing
yaoyu-33 Jan 13, 2025
105e455
update intern vit
yaoyu-33 Jan 13, 2025
a68e41f
update energon requirements
yaoyu-33 Jan 13, 2025
c04f1ed
Fix for energon update
yaoyu-33 Jan 13, 2025
3a855bc
update tp logic
yaoyu-33 Jan 14, 2025
41ab130
fix for test
yaoyu-33 Jan 14, 2025
826a9ab
Apply isort and black reformatting
yaoyu-33 Jan 14, 2025
12b3521
Apply isort and black reformatting
yaoyu-33 Jan 14, 2025
421c789
Merge branch 'yuya/neva2_seq_packing' into pmannan/neva_cp
yaoyu-33 Jan 14, 2025
2e674c0
Merge remote-tracking branch 'origin/pmannan/neva_cp' into pmannan/ne…
yaoyu-33 Jan 14, 2025
4a4a91b
Merge branch 'main' into yuya/refactor_vlm_vision_module
yaoyu-33 Jan 15, 2025
ae6f5f7
Merge remote-tracking branch 'origin/main' into pmannan/neva_cp
yaoyu-33 Jan 15, 2025
65233cb
revert overlap config change
yaoyu-33 Jan 15, 2025
efbbdbf
fix neva generate
yaoyu-33 Jan 21, 2025
5517510
remove not used module
yaoyu-33 Jan 21, 2025
49854a0
update encoder / decoder seq len settings
yaoyu-33 Jan 21, 2025
5d094a6
temp save
yaoyu-33 Jan 21, 2025
bc52537
Merge branch 'main' into yuya/refactor_vlm_vision_module
yaoyu-33 Jan 21, 2025
309461f
update logging
yaoyu-33 Jan 21, 2025
a610a88
Apply isort and black reformatting
yaoyu-33 Jan 21, 2025
e4fe166
update init / fix unused layer bug
yaoyu-33 Jan 22, 2025
c33bf50
Apply isort and black reformatting
yaoyu-33 Jan 22, 2025
edd6aab
remove not used import
yaoyu-33 Jan 22, 2025
4fad475
Merge remote-tracking branch 'origin/yuya/refactor_vlm_vision_module'…
yaoyu-33 Jan 22, 2025
1c47a2d
Update for Siglip
yaoyu-33 Jan 22, 2025
25f1247
Apply isort and black reformatting
yaoyu-33 Jan 22, 2025
e7600e3
update init
yaoyu-33 Jan 22, 2025
01e4cd7
Merge remote-tracking branch 'origin/yuya/refactor_vlm_vision_module'…
yaoyu-33 Jan 22, 2025
46b2332
Apply isort and black reformatting
yaoyu-33 Jan 22, 2025
c364196
Fix logging
yaoyu-33 Jan 23, 2025
cfff3e0
Apply isort and black reformatting
yaoyu-33 Jan 23, 2025
8e9d8ed
Merge branch 'yuya/refactor_vlm_vision_module' into yuya/neva_epp_etp
yaoyu-33 Jan 29, 2025
789dfe8
fix hang
yaoyu-33 Feb 4, 2025
e5d85f2
fix
yaoyu-33 Feb 4, 2025
4932205
fix hang again
yaoyu-33 Feb 5, 2025
cfe337f
fix few issues
yaoyu-33 Feb 7, 2025
3296be1
turn off overlap
yaoyu-33 Feb 7, 2025
462fd23
Merge branch 'refs/heads/main' into yuya/neva_epp_etp
yaoyu-33 Feb 12, 2025
8698719
Apply isort and black reformatting
yaoyu-33 Feb 12, 2025
18 changes: 14 additions & 4 deletions nemo/collections/llm/api.py
@@ -908,10 +908,17 @@ def _validate_config(
assert trainer.strategy.pipeline_model_parallel_size > 0
assert trainer.strategy.context_parallel_size > 0

encoder_tensor_model_parallel_size = trainer.strategy.encoder_tensor_model_parallel_size
# By default, encoder has the same TP size as decoder
if encoder_tensor_model_parallel_size == 0:
encoder_tensor_model_parallel_size = trainer.strategy.tensor_model_parallel_size

# DP validation
assert (trainer.num_devices * trainer.num_nodes) % (
trainer.strategy.tensor_model_parallel_size
* trainer.strategy.pipeline_model_parallel_size
(
trainer.strategy.tensor_model_parallel_size * trainer.strategy.pipeline_model_parallel_size
+ encoder_tensor_model_parallel_size * trainer.strategy.encoder_pipeline_model_parallel_size
)
* trainer.strategy.context_parallel_size
) == 0, "Number of GPUs must be divisible by the product of all parallelism sizes for data parallel."

@@ -922,8 +929,11 @@ def _validate_config(
* (
(trainer.num_devices * trainer.num_nodes)
/ (
trainer.strategy.tensor_model_parallel_size
* trainer.strategy.pipeline_model_parallel_size
(
trainer.strategy.tensor_model_parallel_size * trainer.strategy.pipeline_model_parallel_size
+ encoder_tensor_model_parallel_size
* trainer.strategy.encoder_pipeline_model_parallel_size
)
* trainer.strategy.context_parallel_size
)
)
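For reference, here is a minimal sketch (illustrative function and variable names, not part of the PR) of how the updated divisibility check composes decoder and encoder parallelism before deriving the data-parallel size:

# Sketch of the updated data-parallel divisibility check (illustrative names).
def data_parallel_size(num_gpus, tp, pp, cp, encoder_tp=0, encoder_pp=0):
    # By default the encoder inherits the decoder's TP size, mirroring the PR logic.
    if encoder_tp == 0:
        encoder_tp = tp
    # Encoder ranks (encoder_tp * encoder_pp) are added to decoder ranks (tp * pp)
    # before multiplying by context parallelism.
    model_parallel_gpus = (tp * pp + encoder_tp * encoder_pp) * cp
    assert num_gpus % model_parallel_gpus == 0, (
        "Number of GPUs must be divisible by the product of all parallelism sizes"
    )
    return num_gpus // model_parallel_gpus

# Example: 16 GPUs, TP=4, PP=1, CP=1, encoder PP=1 (encoder TP defaults to 4) -> DP = 2.
print(data_parallel_size(16, tp=4, pp=1, cp=1, encoder_pp=1))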
17 changes: 0 additions & 17 deletions nemo/collections/multimodal/data/energon/base.py
@@ -413,20 +413,3 @@ def transform_dataloader(self, dataloader: DataLoader) -> DataLoader:
DataLoader: The transformed DataLoader.
"""
return dataloader

@property
def megatron_data_kwargs(self) -> Dict[str, Any]:
"""
Return the keyword arguments required for Megatron data handling.

This property provides the necessary arguments that Megatron uses to handle data, including sequence length,
micro-batch size, and the number of micro-batches.

Returns:
Dict[str, Any]: A dictionary containing the Megatron data handling arguments.
"""
return {
"seq_length": self.seq_len,
"micro_batch_size": self.micro_batch_size,
"num_microbatches": self.num_microbatches,
}
@@ -943,7 +943,6 @@ def fwd_bwd_step(self, dataloader_iter, forward_only, first_val_step=None):
# run forward and backwards passes for an entire global batch
# we do this inside training_step to support pipeline parallelism
fwd_bwd_function = get_forward_backward_func()
# print(f"{torch.distributed.get_rank()}: {parallel_state.is_pipeline_last_stage()} {fwd_bwd_function}")

# TODO @akhattar: add num_micro_batches_with_partial_activation_checkpoints when ready
losses_reduced_per_micro_batch = fwd_bwd_function(
2 changes: 1 addition & 1 deletion nemo/collections/vlm/neva/data/mock.py
@@ -74,7 +74,7 @@ def __init__(
)

def setup(self, stage: str = "") -> None:
seq_length = self.seq_length
seq_length = self.decoder_seq_len or self.seq_length
if self.packed_sequence and self.micro_batch_size > 1:
seq_length = seq_length // self.micro_batch_size
logging.warning(
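A standalone sketch of the sequence-length selection this change introduces (hypothetical helper; in the PR the logic lives inside `setup()`):

def effective_seq_length(seq_length, decoder_seq_len=None,
                         packed_sequence=False, micro_batch_size=1):
    # Prefer the decoder sequence length when one is provided.
    length = decoder_seq_len or seq_length
    # With packed sequences, one packed sample stands in for micro_batch_size
    # samples, so the per-sample length is reduced accordingly.
    if packed_sequence and micro_batch_size > 1:
        length = length // micro_batch_size
    return length

print(effective_seq_length(4096, decoder_seq_len=8192,
                           packed_sequence=True, micro_batch_size=2))  # -> 4096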
25 changes: 22 additions & 3 deletions nemo/collections/vlm/neva/model/base.py
@@ -225,7 +225,8 @@ def configure_model(self, tokenizer) -> "MCoreNevaModel":
model = MCoreNevaModel(
config=self,
tokenizer=tokenizer,
pre_process=ps.is_pipeline_first_stage(),
pre_process=ps.is_pipeline_first_stage()
or ps.get_pipeline_model_parallel_rank() == self.encoder_pipeline_model_parallel_size,
post_process=ps.is_pipeline_last_stage(),
add_encoder=ps.is_pipeline_first_stage(),
add_decoder=ps.is_pipeline_last_stage()
@@ -361,14 +362,15 @@ def __init__(
freeze_vision_projection=config.freeze_vision_projection,
)

self.model_type = ModelType.encoder_or_decoder
self.model_type = ModelType.encoder_and_decoder
# This attribute is needed to check if an all-reduce is required
# on the word embeddings inside `finalize_model_grads._allreduce_word_embedding_grads`.

self.vision_model_from_hf = hasattr(vision_transformer_config, "image_size")
self._img_seq_len = vision_transformer_config.num_image_embeddings_per_tile
if drop_vision_class_token and vision_transformer_config.add_class_token:
self._img_seq_len -= vision_transformer_config.class_token_len
self._language_hidden_size = language_transformer_config.hidden_size

def forward(
self,
@@ -426,7 +428,9 @@ def forward(
elif self.add_encoder and not has_images:
vision_param = next(self.vision_model.parameters())
# If no images provided, use an empty image embeddings tensor.
image_embeddings = torch.tensor([], dtype=vision_param.dtype, device=vision_param.device).reshape(0, 0, 0)
image_embeddings = torch.tensor([], dtype=vision_param.dtype, device=vision_param.device).reshape(
self._img_seq_len, 0, self._language_hidden_size
)
elif self.add_encoder and has_images:
# images is in shape of (num_images_in_mbs, c, h, w)
# note num_images_in_mbs is not mbs but total images in this mbs.
@@ -459,8 +463,23 @@ def forward(
)
else:
image_embeddings = self.encoder_hidden_state
if self.config.encoder_pipeline_model_parallel_size > 0:
num_images = len(images) if images is not None else 0
image_embeddings = image_embeddings[:, :num_images]

if not self.add_decoder:
if self.config.encoder_pipeline_model_parallel_size > 0:
_, num_images, _ = image_embeddings.shape
pad_amount = max(input_ids.size(0) - num_images, 0)
if pad_amount > 0:
pad_tensor = torch.zeros(
self._img_seq_len,
pad_amount,
self._language_hidden_size,
dtype=image_embeddings.dtype,
device=image_embeddings.device,
)
image_embeddings = torch.cat([image_embeddings, pad_tensor], dim=1)
return image_embeddings

language_embeddings = None
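A self-contained sketch of the trim-and-pad behaviour added for encoder pipeline stages (shapes, names, and example sizes are illustrative):

import torch

def pad_image_embeddings(image_embeddings, num_images, num_samples):
    """image_embeddings: (img_seq_len, images_in_batch, hidden_size)."""
    img_seq_len, _, hidden_size = image_embeddings.shape
    # Keep only the embeddings for images that are actually present.
    image_embeddings = image_embeddings[:, :num_images]
    # Zero-pad the image dimension so the encoder stage always hands a
    # fixed-size tensor to the next pipeline stage.
    pad_amount = max(num_samples - num_images, 0)
    if pad_amount > 0:
        pad = torch.zeros(img_seq_len, pad_amount, hidden_size,
                          dtype=image_embeddings.dtype,
                          device=image_embeddings.device)
        image_embeddings = torch.cat([image_embeddings, pad], dim=1)
    return image_embeddings

emb = torch.randn(576, 1, 4096)  # one image, 576 image tokens, hidden size 4096
print(pad_image_embeddings(emb, num_images=1, num_samples=2).shape)  # (576, 2, 4096)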
1 change: 1 addition & 0 deletions nemo/collections/vlm/vision/base.py
@@ -85,6 +85,7 @@ def configure_model(self) -> "MCoreMultimodalProjector":
if self.projector_type.startswith("mcore") and self.layer_spec is None:
if self.projector_type == "mcore_mlp":
self.projector_type = "mlp" # strip "mcore_" for mcore init
self.add_bias_linear = self.bias
self.layer_spec = ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
3 changes: 3 additions & 0 deletions nemo/collections/vlm/vision/intern_vit.py
@@ -168,6 +168,7 @@ def get_mlp_module_spec(use_te: bool = True) -> ModuleSpec:
)


# Handle InternViT's layer scaling.
def _bias_dropout_add_func_internvit(ls, x_with_bias, residual, prob, training):
"""Handle InternViT's layer scaling."""
x, bias = x_with_bias # unpack
@@ -199,6 +200,7 @@ def get_bias_dropout_add_internvit(ls, training, fused):
return bias_dropout_add_unfused_internvit(ls, training)


# Add InternViT specialties to our default TransformerLayer.
class InternViTTransformerLayer(TransformerLayer):
"""Add InternViT specialties to our default TransformerLayer."""

@@ -212,6 +214,7 @@ def __init__(self, *args, **kwargs):
self.mlp_bda = partial(self.mlp_bda, self.ls2)


# Override a few things that are special in InternViT and not supported by the SelfAttention class.
class InternViTSelfAttention(SelfAttention):
"""Override a few things that are special in InternViT and not supported by the SelfAttention class."""

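A simplified, standalone sketch of the layer-scaled bias-dropout-add that these InternViT hooks wire in (illustrative; the real code plugs `ls1`/`ls2` into Megatron's fused and unfused bias-dropout-add paths):

import torch
import torch.nn.functional as F

def bias_dropout_add_with_layer_scale(ls, x_with_bias, residual, prob, training):
    x, bias = x_with_bias  # unpack, as in _bias_dropout_add_func_internvit
    if bias is not None:
        x = x + bias
    out = F.dropout(x, p=prob, training=training)
    # InternViT scales each branch output by a learnable per-channel vector
    # before adding it back onto the residual stream.
    return residual + ls * out

hidden = torch.randn(4, 2, 1024)                  # (seq, batch, hidden)
ls1 = torch.nn.Parameter(0.1 * torch.ones(1024))  # layer-scale parameter
out = bias_dropout_add_with_layer_scale(ls1, (hidden, None), hidden, 0.1, True)
print(out.shape)  # torch.Size([4, 2, 1024])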
22 changes: 17 additions & 5 deletions nemo/lightning/_strategy_lib.py
@@ -620,12 +620,24 @@ def _sync_from_last_pipeline_stage(value: torch.Tensor, broadcast: bool = False)
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
src_rank = parallel_state.get_pipeline_model_parallel_last_rank()

if not isinstance(src_rank, list):
src_rank = [src_rank]

if not broadcast:
pp_ranks = torch.distributed.get_process_group_ranks(parallel_state.get_pipeline_model_parallel_group())
if torch.distributed.get_rank() == src_rank and 0 in pp_ranks:
torch.distributed.send(value, 0)
elif torch.distributed.get_rank() == 0:
torch.distributed.recv(value, src_rank)
group = parallel_state.get_pipeline_model_parallel_group()
if isinstance(group, list):
pp_ranks = []
for g in group:
pp_ranks.append(torch.distributed.get_process_group_ranks(g))
else:
pp_ranks = torch.distributed.get_process_group_ranks(group)

for src_rank_idx in src_rank:
if torch.distributed.get_rank() == 0:
torch.distributed.recv(value, src_rank_idx)
elif torch.distributed.get_rank() == src_rank_idx and 0 in pp_ranks:
torch.distributed.send(value, 0)

else:
torch.distributed.broadcast(
value,
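A simplified sketch of the point-to-point pattern the updated `_sync_from_last_pipeline_stage` follows when there can be more than one last-stage rank (hypothetical helper; assumes `torch.distributed` is already initialized and `value` is pre-allocated on every rank):

import torch.distributed as dist

def sync_value_to_rank_zero(value, last_stage_ranks, pp_group_ranks):
    # Normalize a single source rank into a list, as the PR now does.
    if not isinstance(last_stage_ranks, list):
        last_stage_ranks = [last_stage_ranks]
    for src in last_stage_ranks:
        if dist.get_rank() == 0:
            dist.recv(value, src)              # rank 0 collects the value
        elif dist.get_rank() == src and 0 in pp_group_ranks:
            dist.send(value, 0)                # each last stage sends to rank 0
    return value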
4 changes: 2 additions & 2 deletions nemo/lightning/pytorch/callbacks/megatron_comm_overlap.py
@@ -184,8 +184,8 @@ def _get_optimizer_overlap_cfgs(self, parallelism_cfg: ParallelismConfig) -> _Co

if data_parallel_size > 1:
comm_overlap_cfg.bucket_size = 128 * 1024 * 1024
comm_overlap_cfg.overlap_grad_reduce = True
comm_overlap_cfg.overlap_param_gather = True
comm_overlap_cfg.overlap_grad_reduce = False
comm_overlap_cfg.overlap_param_gather = False
if parallelism_cfg.pipeline_model_parallel_size > 1 and vp_size > 1:
# Currently disabled due to an issue with checkpointing
# comm_overlap_cfg.overlap_param_gather_with_optimizer_step = True
25 changes: 18 additions & 7 deletions scripts/vlm/neva_finetune.py
@@ -43,6 +43,7 @@
from nemo.collections import llm, vlm
from nemo.collections.multimodal.data.energon.task_encoder import MultiModalTaskEncoder
from nemo.collections.vlm import ImageDataConfig
from nemo.lightning.pytorch.callbacks import NsysCallback

Check notice — Code scanning / CodeQL

Unused import (Note): Import of 'NsysCallback' is not used.

Copilot Autofix: to fix the problem, remove the unused import statement; this cleans up the code and eliminates the unnecessary dependency. Specifically, delete the line that imports NsysCallback from nemo.lightning.pytorch.callbacks.

Suggested changeset 1 — scripts/vlm/neva_finetune.py (apply locally with `git apply`):

diff --git a/scripts/vlm/neva_finetune.py b/scripts/vlm/neva_finetune.py
--- a/scripts/vlm/neva_finetune.py
+++ b/scripts/vlm/neva_finetune.py
@@ -45,3 +45,3 @@
 from nemo.collections.vlm import ImageDataConfig
-from nemo.lightning.pytorch.callbacks import NsysCallback
+
 from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
from nemo.lightning.pytorch.optim import CosineAnnealingScheduler
from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule
@@ -73,6 +74,8 @@
input_size=vision_transformer_config.hidden_size,
hidden_size=language_transformer_config.hidden_size,
ffn_hidden_size=language_transformer_config.hidden_size,
bias=False,
bias_activation_fusion=False,
)

# NEVA model configuration
@@ -84,7 +87,11 @@
freeze_language_model=False,
freeze_vision_model=True,
)
num_image_embeddings_per_tile = vision_transformer_config.num_image_embeddings_per_tile
num_image_embeddings_per_tile = vision_transformer_config.num_image_embeddings_per_tile - int(
neva_config.drop_vision_class_token and vision_transformer_config.add_class_token
)

seq_length = num_image_embeddings_per_tile

if args.data_type == "llava":
# Data configuration
@@ -97,8 +104,8 @@
data = vlm.NevaPreloadedDataModule(
paths=args.data_path,
data_config=data_config,
seq_length=decoder_seq_length,
decoder_seq_length=None,
seq_length=seq_length,
decoder_seq_length=decoder_seq_length,
global_batch_size=gbs,
micro_batch_size=mbs,
tokenizer=None,
@@ -133,7 +140,8 @@
path=args.data_path,
tokenizer=tokenizer,
image_processor=image_processor,
seq_length=decoder_seq_length,
seq_length=seq_length,
decoder_seq_length=decoder_seq_length,
micro_batch_size=mbs,
global_batch_size=gbs,
num_workers=0,
@@ -143,14 +151,15 @@
image_processor=image_processor,
multimodal_sample_config=config,
packed_sequence=args.use_packed_sequence,
packed_sequence_size=decoder_seq_length,
packed_sequence_size=seq_length,
num_image_embeddings_per_tile=num_image_embeddings_per_tile,
),
packing_buffer_size=200 if args.use_packed_sequence else None,
)
elif args.data_type == "mock":
data = vlm.NevaMockDataModule(
seq_length=decoder_seq_length,
seq_length=seq_length,
decoder_seq_length=decoder_seq_length,
global_batch_size=gbs,
micro_batch_size=mbs,
tokenizer=None,
@@ -168,9 +177,10 @@
tensor_model_parallel_size=args.tp_size,
pipeline_model_parallel_size=args.pp_size,
encoder_pipeline_model_parallel_size=args.encoder_pp_size,
encoder_tensor_model_parallel_size=args.encoder_tp_size,
context_parallel_size=args.cp_size,
pipeline_dtype=torch.bfloat16,
sequence_parallel=True,
sequence_parallel=False,
ddp=DistributedDataParallelConfig(
check_for_nan_in_grad=True,
grad_reduce_in_fp32=True,
@@ -290,6 +300,7 @@
parser.add_argument("--pp_size", type=int, required=False, default=1)
parser.add_argument("--cp_size", type=int, required=False, default=1)
parser.add_argument("--encoder_pp_size", type=int, required=False, default=0)
parser.add_argument("--encoder_tp_size", type=int, required=False, default=0)
parser.add_argument("--projector_type", type=str, required=False, default="mcore_mlp")
parser.add_argument("--name", type=str, required=False, default="neva_pretrain")
parser.add_argument("--peft", type=str, default='none', help="none | lora")
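For orientation, a small sketch of the encoder/decoder sequence-length split the script now passes to the data modules (the concrete numbers are illustrative; the real values come from the vision and language configs):

# Encoder sequence length: image tokens the vision tower emits per tile,
# minus the class token when it is dropped before the language model.
num_image_embeddings_per_tile = 577      # e.g. a 336px ViT-L/14: 576 patches + 1 class token
drop_vision_class_token = True
add_class_token = True

encoder_seq_length = num_image_embeddings_per_tile - int(
    drop_vision_class_token and add_class_token
)                                        # -> 576
decoder_seq_length = 4096                # language-side maximum length

# The data modules now receive both: seq_length=encoder_seq_length and
# decoder_seq_length=decoder_seq_length; packed_sequence_size also uses seq_length.
print(encoder_seq_length, decoder_seq_length)

With the new arguments, an invocation would also pass --encoder_tp_size and --encoder_pp_size alongside the existing --tp_size, --pp_size, and --cp_size flags.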
22 changes: 11 additions & 11 deletions scripts/vlm/neva_generate.py
@@ -78,16 +78,6 @@ def main(args) -> None:
if raw_image is None:
return # Exit if the image can't be loaded

inputs = processor(prompt, raw_image, return_tensors='pt').to(0, torch.float16)
input_ids = hf_tokenizer(prompt, return_tensors='pt')['input_ids'].cuda()
input_ids[input_ids == 32000] = -200
media = inputs['pixel_values'].cuda()
media = media.reshape(media.size(0), 3, 336, 336)

position_ids = (
torch.arange(input_ids.size(1), dtype=torch.long, device=input_ids.device).unsqueeze(0).expand_as(input_ids)
)

fabric = trainer.to_fabric()

# Decide whether to import or load the model based on the input arguments
@@ -99,13 +89,23 @@ def main(args) -> None:

model = model.module.cuda()
model.eval()

inputs = processor(prompt, raw_image, return_tensors='pt').to(0, torch.float16)
input_ids = hf_tokenizer(prompt, return_tensors='pt')['input_ids'].to(model.device)
input_ids[input_ids == 32000] = -200
images = inputs['pixel_values'].to(model.device)
images = images.reshape(images.size(0), 3, 336, 336)

position_ids = (
torch.arange(input_ids.size(1), dtype=torch.long, device=input_ids.device).unsqueeze(0).expand_as(input_ids)
)
generated_ids = input_ids.clone()

# Greedy generation loop
for _ in range(20):
with torch.no_grad():
output = model(
media=media,
images=images,
input_ids=input_ids,
position_ids=position_ids,
attention_mask=None,
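For context, a minimal self-contained sketch of the greedy loop the script builds around the model call (the toy `model` below stands in for the NeMo forward, which additionally takes `images` and `attention_mask`):

import torch

@torch.no_grad()
def greedy_generate(model, input_ids, max_new_tokens=20):
    generated_ids = input_ids.clone()
    for _ in range(max_new_tokens):
        # Recompute position ids for the current length, as in the script.
        position_ids = (
            torch.arange(generated_ids.size(1), device=generated_ids.device)
            .unsqueeze(0)
            .expand_as(generated_ids)
        )
        logits = model(input_ids=generated_ids, position_ids=position_ids)
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)  # greedy pick
        generated_ids = torch.cat([generated_ids, next_token], dim=-1)
    return generated_ids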