Skip to content

Commit

Permalink
more minor fixes
Browse files (browse the repository at this point in the history)
  • Loading branch information
kvareddy committed Feb 1, 2022
1 parent e1f9c3a commit 3f1a728
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 19 deletions.
13 changes: 2 additions & 11 deletions megatron/data/vit_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,9 @@ def __init__(self, image_size, train=True):
normalize
])
# transformation for the local small crops
- self.local_crops_number = args.local_crops_number
+ self.local_crops_number = args.dino_local_crops_number
self.local_transform = T.Compose([
- T.RandomResizedCrop(args.local_img_size,
+ T.RandomResizedCrop(args.dino_local_img_size,
scale=(0.05, scale_const),
interpolation=Image.BICUBIC),
flip_and_color_jitter,
Expand All @@ -218,12 +218,6 @@ def __init__(self, image_size, train=True):

def __call__(self, image):
crops = []
- args = get_args()
-
- if args.street_data:
- crop_transform = T.RandomCrop(300)
- image = crop_transform(image)
-
crops.append(self.global_transform1(image))
crops.append(self.global_transform2(image))
for _ in range(self.local_crops_number):
Expand All @@ -247,9 +241,6 @@ def build_train_valid_datasets(data_path, image_size=224):
raise Exception('{} vit pretraining type is not supported.'.format(
args.vit_pretraining_type))

- train_transform = ClassificationTransform(image_size)
- val_transform = ClassificationTransform(image_size, train=False)

# training dataset
train_data_path = data_path[0] if len(data_path) <= 2 else data_path[2]
train_data = ImageFolder(
Expand Down
2 changes: 0 additions & 2 deletions megatron/model/vision/dino.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,9 @@
from megatron.model.utils import get_linear_layer
from megatron.model.vision.vit_backbone import VitBackbone
from megatron.model.module import MegatronModule
- from megatron.utils import print_tensor_min_max_norm as pt
from megatron.model.vision.utils import trunc_normal_
from megatron.model.vision.mit_backbone import mit_b5_avg
from megatron.model.vision.esvit_swin_backbone import get_swin
- from megatron.model.vision.av_cam_trunk import get_av_cam_trunk


class DINOLoss(torch.nn.Module):
Expand Down
7 changes: 4 additions & 3 deletions megatron/model/vision/esvit_swin_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
import torch.nn.functional as F
from functools import partial
import torch.distributed as dist
- from megatron.model.vision.utils import DropPath, trunc_normal_
+ from megatron.model.vision.utils import trunc_normal_
+ from megatron.model.transformer import DropPath
from megatron import get_args
from megatron.model import LayerNorm
import numpy as np
Expand Down Expand Up @@ -809,12 +810,12 @@ def freeze_pretrained_layers(self, frozen_layers=[]):
def get_swin(is_teacher=False):
args = get_args()

- if args.swin_type == "tiny":
+ if args.swin_backbone_type == "tiny":
embed_dim = 96
depths = [2, 2, 6, 2]
num_heads = [3, 6, 12, 24]
drop_path_rate = 0.1
- elif args.swin_type == 'h3':
+ elif args.swin_backbone_type == 'h3':
embed_dim = 384
depths = [2, 2, 18, 2]
num_heads = [6, 12, 24, 48]
Expand Down
5 changes: 4 additions & 1 deletion megatron/model/vision/vit_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@ def __init__(self,
pre_process=True,
post_process=True,
class_token=True,
- single_token_output=False):
+ single_token_output=False,
+ drop_path_rate=0.0):
super(VitBackbone, self).__init__(share_word_embeddings=False)
args = get_args()

Expand All @@ -170,6 +171,7 @@ def __init__(self,
self.img_w = args.img_w
self.micro_batch_size = args.micro_batch_size
self.single_token_output = single_token_output
+ self.drop_path_rate = drop_path_rate

assert self.img_h % self.patch_dim == 0
assert self.img_w % self.patch_dim == 0
Expand Down Expand Up @@ -216,6 +218,7 @@ def __init__(self,
self.scaled_init_method,
pre_process=self.pre_process,
post_process=self.post_process,
+ drop_path_rate=self.drop_path_rate
)

def set_input_tensor(self, input_tensor):
Expand Down
4 changes: 3 additions & 1 deletion pretrain_vision_dino.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,14 @@
from functools import partial
from megatron import get_args, get_timers, mpu, print_rank_0
from megatron.data.vit_dataset import build_train_valid_datasets
- from megatron.model.vision.contrastive import DINOPretrainModel
+ from megatron.model.vision.dino import DINOPretrainModel
from megatron.model.vision.knn_monitor import knn_predict
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group, unwrap_model
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
+ from megatron.model import ModelType

def model_provider(pre_process=True, post_process=True):
"""Build the model."""
Expand Down Expand Up @@ -116,6 +117,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
pretrain(
train_valid_test_datasets_provider,
model_provider,
+ ModelType.encoder_or_decoder,
forward_step,
args_defaults={'dataloader_type': 'cyclic'}
)
Expand Down
3 changes: 2 additions & 1 deletion pretrain_vision_inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
from tasks.vision.metrics import SSIM, PSNR

+ from megatron.model import ModelType

def model_provider(pre_process=True, post_process=True):
"""Build the model."""
Expand Down Expand Up @@ -143,6 +143,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
pretrain(
train_valid_test_datasets_provider,
model_provider,
+ ModelType.encoder_or_decoder,
forward_step,
process_non_loss_data,
args_defaults={'dataloader_type': 'cyclic'}
Expand Down

0 comments on commit 3f1a728

Please sign in to comment.