From b49249803e3e89abc5da2860e906e6c6d17fb3c1 Mon Sep 17 00:00:00 2001
From: Jon Barker
Date: Mon, 11 Sep 2023 21:07:33 -0700
Subject: [PATCH] Fixes errors in vision model pipelines

---
 .gitignore                              |  1 +
 examples/pretrain_vision_classify.sh    | 64 +++++++++++++++++++++++
 examples/pretrain_vision_dino.sh        | 67 +++++++++++++++++++++++++
 examples/pretrain_vision_inpaint.sh     | 65 ++++++++++++++++++++++++
 megatron/data/autoaugment.py            |  2 +-
 megatron/model/vision/classification.py |  5 +-
 megatron/model/vision/dino.py           |  3 +-
 megatron/model/vision/inpainting.py     | 15 +++---
 megatron/model/vision/vit_backbone.py   |  7 ++-
 megatron/tokenizer/tokenizer.py         |  2 +-
 pretrain_vision_dino.py                 |  3 +-
 pretrain_vision_inpaint.py              | 11 ++--
 12 files changed, 225 insertions(+), 20 deletions(-)
 create mode 100755 examples/pretrain_vision_classify.sh
 create mode 100755 examples/pretrain_vision_dino.sh
 create mode 100755 examples/pretrain_vision_inpaint.sh

diff --git a/.gitignore b/.gitignore
index cac3499524..5955b349f1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,3 +6,4 @@ build
 *~
 slurm*
 logs
+.vscode
diff --git a/examples/pretrain_vision_classify.sh b/examples/pretrain_vision_classify.sh
new file mode 100755
index 0000000000..5fcdd6e6ef
--- /dev/null
+++ b/examples/pretrain_vision_classify.sh
@@ -0,0 +1,64 @@
+#! /bin/bash
+
+# Pre-trains ViT based image classification model
+
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+export NCCL_IB_SL=1
+
+# Training and validation paths should each point to a folder where each
+# sub-folder contains a collection of images in jpg or png format
+# e.g. If using imagenet, one train image might be, train_data/n01688243/n01688243_11301.JPEG
+DATA_PATH_TRAIN=
+DATA_PATH_VAL=
+
+CHECKPOINT_PATH=
+
+CLASSIFIER_ARGS="
+    --tensor-model-parallel-size 1 \
+    --num-layers 12 \
+    --hidden-size 768 \
+    --num-attention-heads 12 \
+    --patch-dim 4 \
+    --seq-length 3136 \
+    --max-position-embeddings 3136 \
+    --img-h 224 \
+    --img-w 224 \
+    --mask-factor 1.0 \
+    --fp16 \
+    --train-iters 750000 \
+    --lr-decay-style cosine \
+    --micro-batch-size 4 \
+    --global-batch-size 1024 \
+    --lr 0.0005 \
+    --min-lr 0.00001 \
+    --attention-dropout 0.0 \
+    --weight-decay 0.05 \
+    --lr-warmup-iters 12500 \
+    --clip-grad 1.0 \
+    --no-gradient-accumulation-fusion \
+    --num-workers 4 \
+    --DDP-impl torch "
+
+DATA_ARGS="
+    --tokenizer-type NullTokenizer \
+    --vocab-size 0 \
+    --data-path $DATA_PATH_TRAIN $DATA_PATH_VAL \
+    --no-data-sharding \
+    --split 949,50,1 \
+"
+
+OUTPUT_ARGS="
+    --log-interval 32 \
+    --save-interval 10000 \
+    --eval-interval 2500 \
+    --eval-iters 100 \
+    --tensorboard-dir ${CHECKPOINT_PATH} \
+"
+
+torchrun pretrain_vision_classification.py \
+    $CLASSIFIER_ARGS \
+    $DATA_ARGS \
+    $OUTPUT_ARGS \
+    --save $CHECKPOINT_PATH \
+    --load $CHECKPOINT_PATH
+
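A note on the data layout these example scripts expect: --data-path receives the train and validation folders, and each folder holds one sub-folder per class of images, as in the torchvision ImageFolder convention. The snippet below is only a standalone sanity check of that layout; torchvision and the placeholder paths are assumptions for illustration, not part of the patch or of Megatron's own data loader:

import torchvision.datasets as datasets

# Placeholder paths; point them at the same folders used for
# DATA_PATH_TRAIN and DATA_PATH_VAL in the scripts.
train_data = datasets.ImageFolder(root="train_data")
val_data = datasets.ImageFolder(root="val_data")

# Each class sub-folder (e.g. n01688243) becomes one integer label.
print(len(train_data.classes), "classes,", len(train_data), "train images")
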
diff --git a/examples/pretrain_vision_dino.sh b/examples/pretrain_vision_dino.sh
new file mode 100755
index 0000000000..b047e4e340
--- /dev/null
+++ b/examples/pretrain_vision_dino.sh
@@ -0,0 +1,67 @@
+#! /bin/bash
+
+# Pre-trains Dino V1 model
+# For model details: https://arxiv.org/abs/2104.14294
+# For original author implementation: https://github.com/facebookresearch/dino/tree/main
+
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+export NCCL_IB_SL=1
+
+# Training and validation paths should each point to a folder where each
+# sub-folder contains a collection of images in jpg or png format
+# e.g. If using imagenet, one train image might be, train_data/n01688243/n01688243_11301.JPEG
+DATA_PATH_TRAIN=
+DATA_PATH_VAL=
+
+CHECKPOINT_PATH=
+
+DINO_ARGS="
+    --vision-pretraining-type dino \
+    --tensor-model-parallel-size 1 \
+    --num-layers 12 \
+    --hidden-size 768 \
+    --num-attention-heads 12 \
+    --patch-dim 4 \
+    --seq-length 3136 \
+    --max-position-embeddings 3136 \
+    --img-h 224 \
+    --img-w 224 \
+    --mask-factor 1.0 \
+    --fp16 \
+    --train-iters 750000 \
+    --lr-decay-style cosine \
+    --micro-batch-size 4 \
+    --global-batch-size 1024 \
+    --lr 0.0005 \
+    --min-lr 0.00001 \
+    --attention-dropout 0.0 \
+    --weight-decay 0.05 \
+    --lr-warmup-iters 12500 \
+    --clip-grad 1.0 \
+    --no-gradient-accumulation-fusion \
+    --num-workers 4 \
+    --DDP-impl torch "
+
+DATA_ARGS="
+    --tokenizer-type NullTokenizer \
+    --vocab-size 0 \
+    --data-path $DATA_PATH_TRAIN $DATA_PATH_VAL \
+    --no-data-sharding \
+    --split 949,50,1 \
+"
+
+OUTPUT_ARGS="
+    --log-interval 32 \
+    --save-interval 10000 \
+    --eval-interval 2500 \
+    --eval-iters 100 \
+    --tensorboard-dir ${CHECKPOINT_PATH} \
+"
+
+torchrun pretrain_vision_dino.py \
+    $DINO_ARGS \
+    $DATA_ARGS \
+    $OUTPUT_ARGS \
+    --save $CHECKPOINT_PATH \
+    --load $CHECKPOINT_PATH
+
diff --git a/examples/pretrain_vision_inpaint.sh b/examples/pretrain_vision_inpaint.sh
new file mode 100755
index 0000000000..01c7e71a9e
--- /dev/null
+++ b/examples/pretrain_vision_inpaint.sh
@@ -0,0 +1,65 @@
+#! /bin/bash
+
+# Pre-trains ViT based image inpainting model
+
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+export NCCL_IB_SL=1
+
+# Training and validation paths should each point to a folder where each
+# sub-folder contains a collection of images in jpg or png format
+# e.g. If using imagenet, one train image might be, train_data/n01688243/n01688243_11301.JPEG
+DATA_PATH_TRAIN=
+DATA_PATH_VAL=
+
+CHECKPOINT_PATH=
+
+INPAINT_ARGS="
+    --vision-pretraining-type inpaint \
+    --tensor-model-parallel-size 1 \
+    --num-layers 12 \
+    --hidden-size 768 \
+    --num-attention-heads 12 \
+    --patch-dim 4 \
+    --seq-length 3136 \
+    --max-position-embeddings 3136 \
+    --img-h 224 \
+    --img-w 224 \
+    --mask-factor 1.0 \
+    --fp16 \
+    --train-iters 750000 \
+    --lr-decay-style cosine \
+    --micro-batch-size 4 \
+    --global-batch-size 1024 \
+    --lr 0.0005 \
+    --min-lr 0.00001 \
+    --attention-dropout 0.0 \
+    --weight-decay 0.05 \
+    --lr-warmup-iters 12500 \
+    --clip-grad 1.0 \
+    --no-gradient-accumulation-fusion \
+    --num-workers 4 \
+    --DDP-impl torch "
+
+DATA_ARGS="
+    --tokenizer-type NullTokenizer \
+    --vocab-size 0 \
+    --data-path $DATA_PATH_TRAIN $DATA_PATH_VAL \
+    --no-data-sharding \
+    --split 949,50,1 \
+"
+
+OUTPUT_ARGS="
+    --log-interval 32 \
+    --save-interval 10000 \
+    --eval-interval 2500 \
+    --eval-iters 100 \
+    --tensorboard-dir ${CHECKPOINT_PATH} \
+"
+
+torchrun pretrain_vision_inpaint.py \
+    $INPAINT_ARGS \
+    $DATA_ARGS \
+    $OUTPUT_ARGS \
+    --save $CHECKPOINT_PATH \
+    --load $CHECKPOINT_PATH
+
diff --git a/megatron/data/autoaugment.py b/megatron/data/autoaugment.py
index 585a4fa6a5..7f988c5f04 100644
--- a/megatron/data/autoaugment.py
+++ b/megatron/data/autoaugment.py
@@ -193,7 +193,7 @@ def __init__(
             "rotate": np.linspace(0, 30, num_levels),
             "color": np.linspace(0.0, 0.9, num_levels),
             "posterize": np.round(np.linspace(8, 4, num_levels), 0).astype(
-                np.int
+                np.int32
             ),
             "solarize": np.linspace(256, 0, num_levels),  # range [0, 256]
             "contrast": np.linspace(0.0, 0.9, num_levels),
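For context on the autoaugment.py hunk above: NumPy deprecated the np.int alias in 1.20 and removed it in 1.24, so .astype(np.int) now raises AttributeError on current NumPy; the patch switches to the concrete np.int32 dtype. A standalone sketch of the fixed cast (num_levels is just an example value here, not the one autoaugment.py computes):

import numpy as np

num_levels = 10  # example value only

# Old: .astype(np.int) -> AttributeError on NumPy >= 1.24 (alias removed).
# New: cast to an explicit fixed-width dtype.
posterize_levels = np.round(np.linspace(8, 4, num_levels), 0).astype(np.int32)

print(posterize_levels)        # [8 8 7 7 6 6 5 5 4 4]
print(posterize_levels.dtype)  # int32
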
diff --git a/megatron/model/vision/classification.py b/megatron/model/vision/classification.py
index 4d1a4e9021..3d5c823df4 100644
--- a/megatron/model/vision/classification.py
+++ b/megatron/model/vision/classification.py
@@ -17,6 +17,7 @@ def __init__(self, config, num_classes, finetune=False,
                  pre_process=True, post_process=True):
         super(VitClassificationModel, self).__init__()
         args = get_args()
+        self.config = config

         self.hidden_size = args.hidden_size
         self.num_classes = num_classes
@@ -29,10 +30,10 @@ def __init__(self, config, num_classes, finetune=False,
             post_process=self.post_process,
             single_token_output=True
         )
-
+
         if self.post_process:
             if not self.finetune:
-                self.head = VitMlpHead(self.hidden_size, self.num_classes)
+                self.head = VitMlpHead(config, self.hidden_size, self.num_classes)
             else:
                 self.head = get_linear_layer(
                     self.hidden_size,
diff --git a/megatron/model/vision/dino.py b/megatron/model/vision/dino.py
index 1c577d2e19..151ec26647 100644
--- a/megatron/model/vision/dino.py
+++ b/megatron/model/vision/dino.py
@@ -192,7 +192,7 @@ def get_student_backbone_and_num_features(config, pre_process=True, post_process
     else:
         raise Exception('{} vision backbone is not supported.'.format(
             args.vision_backbone_type))
-
+
     return student, num_features

 def get_teacher_backbone_and_num_features(config, pre_process=True, post_process=True):
@@ -220,6 +220,7 @@ class DINOPretrainModel(MegatronModule):
     def __init__(self, config, pre_process=True, post_process=True):
         super(DINOPretrainModel, self).__init__()
         args = get_args()
+        self.config = config
         self.out_dim = 65536

         self.dino_loss = DINOLoss(
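The classification.py change above matches the new VitMlpHead signature introduced further down in vit_backbone.py: the head now receives the transformer config as its first argument. A minimal sketch of the updated call site, assuming Megatron has already been initialized so get_args() is populated; num_classes=1000 is an arbitrary example, not taken from the patch:

from megatron import get_args
from megatron.arguments import core_transformer_config_from_args
from megatron.model.vision.vit_backbone import VitMlpHead

args = get_args()
config = core_transformer_config_from_args(args)

# Old signature: VitMlpHead(hidden_size, num_classes)
# New signature: VitMlpHead(config, hidden_size, num_classes)
head = VitMlpHead(config, args.hidden_size, num_classes=1000)
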
diff --git a/megatron/model/vision/inpainting.py b/megatron/model/vision/inpainting.py
index cda03315be..6aae9658bc 100644
--- a/megatron/model/vision/inpainting.py
+++ b/megatron/model/vision/inpainting.py
@@ -1,8 +1,8 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
 #
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
-i
+
 import math
 import apex
 import einops
@@ -13,7 +13,7 @@
 from megatron.model.vision.vit_backbone import VitBackbone
 from megatron.model.module import MegatronModule
 from megatron.model.vision.mit_backbone import mit_b3
-from megatron.model.vision.utils import resize_
+from megatron.model.vision.utils import resize


 class VitInpaintingModel(MegatronModule):
@@ -22,6 +22,7 @@ def __init__(self, config, pre_process=True, post_process=True):
         super(VitInpaintingModel, self).__init__()
         args = get_args()
+        self.config = config

         self.pre_process = pre_process
         self.post_process = post_process
         self.hidden_size = config.hidden_size
@@ -108,9 +109,9 @@ def __init__(self, pre_process=True, post_process=True):
         self.conv_fuse = torch.nn.Conv2d(self.embedding_dim*4, self.embedding_dim, 1, 1, bias=False)
         self.norm = apex.parallel.SyncBatchNorm(self.embedding_dim)
         self.dropout = torch.nn.Dropout2d(0.1)
-
+
         self.linear_pred = torch.nn.Conv2d(self.embedding_dim, self.flatten_dim, kernel_size=1)
-
+
     def set_input_tensor(self, input_tensor):
         """See megatron.model.transformer.set_input_tensor()"""
         pass
@@ -121,7 +122,7 @@ def forward(self, input):
         n, _, h, w = c4.shape
         _c4 = self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1, c4.shape[2], c4.shape[3])
         _c4 = resize(_c4, size=c1.size()[2:], mode='bilinear', align_corners=False)
-
+
         _c3 = self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1, c3.shape[2], c3.shape[3])
         _c3 = resize(_c3, size=c1.size()[2:], mode='bilinear', align_corners=False)

@@ -132,7 +133,7 @@ def forward(self, input):

         _c = torch.cat([_c4, _c3, _c2, _c1], dim=1)
         _c = self.conv_fuse(_c)
-
+
         x = self.norm(_c)
         x = F.relu(x, inplace=True)
         x = self.dropout(x)
diff --git a/megatron/model/vision/vit_backbone.py b/megatron/model/vision/vit_backbone.py
index 1efef9c17a..15cf75affc 100644
--- a/megatron/model/vision/vit_backbone.py
+++ b/megatron/model/vision/vit_backbone.py
@@ -30,8 +30,9 @@ class VitMlpHead(MegatronModule):
     bias is set to zero.
     """

-    def __init__(self, hidden_size, num_classes):
+    def __init__(self, config, hidden_size, num_classes):
         super(VitMlpHead, self).__init__()
+        self.config = config
         self.dense_in = torch.nn.Linear(hidden_size, hidden_size)
         self.relu = torch.nn.ReLU()
         self.dense_out = torch.nn.Linear(hidden_size, num_classes)
@@ -139,6 +140,7 @@ def __init__(self,
                  drop_path_rate=0.0):
         super(VitBackbone, self).__init__(share_embeddings_and_output_weights=False)
         args = get_args()
+        self.config = config

         self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy

@@ -172,7 +174,7 @@ def __init__(self,
         )
         torch.nn.init.zeros_(self.cls_token)
         self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()
-
+
         # Linear encoder
         self.linear_encoder = torch.nn.Linear(
             self.flatten_dim, self.hidden_size
@@ -196,6 +198,7 @@ def __init__(self,
         # Transformer
         self.transformer = ParallelTransformer(
             config,
+            model_type=args.model_type,
             pre_process=self.pre_process,
             post_process=self.post_process,
             post_layer_norm=self.post_layer_norm,
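On the inpainting.py import fix above: the decoder's forward pass calls resize(...) with size, mode and align_corners arguments, so the old import of the non-existent resize_ name would fail at import time. The sketch below reproduces that upsampling step with torch.nn.functional.interpolate directly; treating resize as a thin wrapper over interpolate is an assumption based on the arguments used here, and the tensor shapes are made up for illustration:

import torch
import torch.nn.functional as F

# Illustrative shapes: c1 is the finest feature map, _c4 the coarsest one
# after its linear projection, as in MitInpaintingModel.forward above.
c1 = torch.randn(2, 64, 56, 56)
_c4 = torch.randn(2, 64, 7, 7)

# Same argument pattern as
# resize(_c4, size=c1.size()[2:], mode='bilinear', align_corners=False).
_c4_up = F.interpolate(_c4, size=c1.size()[2:], mode='bilinear', align_corners=False)
print(_c4_up.shape)  # torch.Size([2, 64, 56, 56])
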
diff --git a/megatron/tokenizer/tokenizer.py b/megatron/tokenizer/tokenizer.py
index 39a9e33215..98643343c5 100644
--- a/megatron/tokenizer/tokenizer.py
+++ b/megatron/tokenizer/tokenizer.py
@@ -44,7 +44,7 @@ def build_tokenizer(args):
     else:
         raise NotImplementedError('{} tokenizer is not '
                                   'implemented.'.format(args.tokenizer_type))
-
+
     # Add vocab size (if not already set from a checkpoint).
     if getattr(args, "padded_vocab_size", None) is None:
         args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size,
diff --git a/pretrain_vision_dino.py b/pretrain_vision_dino.py
index 3c75b6160a..01efeab2b1 100644
--- a/pretrain_vision_dino.py
+++ b/pretrain_vision_dino.py
@@ -36,7 +36,7 @@ def get_batch(data_iterator):

 def loss_func(model, labels, output_tensor, collect_data=False):
     args = get_args()
-
+
     model = unwrap_model(model)
     if model.training:
         student_output, teacher_output = output_tensor
@@ -94,6 +94,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):


 if __name__ == "__main__":
+
     pretrain(
         train_valid_test_datasets_provider,
         model_provider,
diff --git a/pretrain_vision_inpaint.py b/pretrain_vision_inpaint.py
index 509a38d2af..1947a47faf 100644
--- a/pretrain_vision_inpaint.py
+++ b/pretrain_vision_inpaint.py
@@ -12,7 +12,7 @@
 from megatron.model.vision.inpainting import MitInpaintingModel
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group
-from tasks.vision.metrics import SSIM, PSNR
+from tasks.vision.segmentation.metrics import SSIM, PSNR
 from megatron.arguments import core_transformer_config_from_args

 def model_provider(pre_process=True, post_process=True):
@@ -20,11 +20,12 @@ def model_provider(pre_process=True, post_process=True):
     args = get_args()
     config = core_transformer_config_from_args(args)
     if args.vision_backbone_type == 'vit':
-        model = VitInpaintingModel(config,
+        model = VitInpaintingModel(config=config,
                                    pre_process=pre_process,
                                    post_process=post_process)
     elif args.vision_backbone_type == 'mit':
-        model = MitInpaintingModel(pre_process=pre_process,
+        model = MitInpaintingModel(config=config,
+                                   pre_process=pre_process,
                                    post_process=post_process)
     else:
         raise Exception('{} vision backbone is not supported.'.format(
@@ -42,7 +43,7 @@ def get_batch(data_iterator):
     return images, masks


-def loss_func(images, masks, masked_images, outputs, collect_data=False):
+def loss_func(images, masks, masked_images, outputs, non_loss_data=False):
     outputs = outputs.contiguous().float()
     masks_flip = 1-masks
     flip_masked_outputs = outputs.masked_fill(masks_flip.bool(), 0)
@@ -51,7 +52,7 @@ def loss_func(images, masks, masked_images, outputs, collect_data=False):
     ssim_fun = SSIM()
     psnr_fun = PSNR()

-    if not collect_data:
+    if not non_loss_data:
         mask_count = torch.count_nonzero(masks)
         loss = F.mse_loss(
             flip_masked_outputs,