Update more non-core code to use config objects.
jaredcasper committed Jun 7, 2023
1 parent 51c6f47 commit 305b390
Showing 14 changed files with 45 additions and 40 deletions.
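
The pattern repeated across these files: each model provider builds a TransformerConfig up front via core_transformer_config_from_args and hands it to the model, and the modules take that config instead of pulling initialization settings out of get_args() themselves. Below is a minimal sketch of the resulting provider shape; it mirrors the pretrain_vision_dino.py hunk later in this diff and adds nothing beyond what that hunk shows.

# Sketch of the calling convention this commit converges on; mirrors the
# pretrain_vision_dino.py change below rather than introducing anything new.
from megatron import get_args
from megatron.arguments import core_transformer_config_from_args
from megatron.model.vision.dino import DINOPretrainModel

def model_provider(pre_process=True, post_process=True):
    """Build the model from a config object instead of reading args inside it."""
    config = core_transformer_config_from_args(get_args())
    return DINOPretrainModel(config,
                             pre_process=pre_process,
                             post_process=post_process)
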
4 changes: 4 additions & 0 deletions megatron/arguments.py
@@ -413,6 +413,10 @@ def core_transformer_config_from_args(args):
kw_args['activation_func'] = F.silu
kw_args['gated_linear_unit'] = True
kw_args['bias_gelu_fusion'] = False
if args.init_method_xavier_uniform:
kw_args['init_method'] = torch.nn.init.xavier_uniform_
kw_args['scaled_init_method'] = torch.nn.init.xavier_uniform_

return TransformerConfig(**kw_args)

def _add_transformer_engine_args(parser):
3 changes: 1 addition & 2 deletions megatron/model/bert_model.py
@@ -54,10 +54,9 @@ class BertLMHead(MegatronModule):
"""

def __init__(self, mpu_vocab_size, hidden_size, config, parallel_output):
super(BertLMHead, self).__init__()
super().__init__(config=config)

args = get_args()
self.config = config
self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
tensor_parallel.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
self.parallel_output = parallel_output
8 changes: 3 additions & 5 deletions megatron/model/classification.py
@@ -17,25 +17,23 @@
class Classification(MegatronModule):

def __init__(self,
config,
num_classes,
num_tokentypes=2,
pre_process=True,
post_process=True):
super(Classification, self).__init__(share_embeddings_and_output_weights=False)
super().__init__(config=config, share_embeddings_and_output_weights=False)
args = get_args()

self.num_classes = num_classes
self.pre_process = pre_process
self.post_process = post_process
init_method = init_method_normal(args.init_method_std)

self.language_model, self._language_model_key = get_language_model(
config=config,
num_tokentypes=num_tokentypes,
add_pooler=True,
encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers),
pre_process=self.pre_process,
post_process=self.post_process)

5 changes: 2 additions & 3 deletions megatron/model/language_model.py
@@ -412,10 +412,9 @@ def __init__(self,
self.output_layer = tensor_parallel.ColumnParallelLinear(
args.hidden_size,
args.padded_vocab_size,
bias=False, # Setting bias to False always to keep it consistent with embedding tying that also does not have a bias.
config=config,
init_method=self.init_method,
use_cpu_initialization=args.use_cpu_initialization,
perform_initialization=args.perform_initialization)
bias=False) # Setting bias to False always to keep it consistent with embedding tying that also does not have a bias.
self._output_layer_key = 'output_layer'

def set_input_tensor(self, input_tensor):
6 changes: 2 additions & 4 deletions megatron/model/multiple_choice.py
@@ -17,23 +17,21 @@
class MultipleChoice(MegatronModule):

def __init__(self,
config,
num_tokentypes=2,
pre_process=True,
post_process=True):
super(MultipleChoice, self).__init__(share_embeddings_and_output_weights=False)
args = get_args()

init_method = init_method_normal(args.init_method_std)
self.pre_process = pre_process
self.post_process = post_process

self.language_model, self._language_model_key = get_language_model(
config=config,
num_tokentypes=num_tokentypes,
add_pooler=True,
encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers),
pre_process=self.pre_process,
post_process=self.post_process)

3 changes: 2 additions & 1 deletion megatron/model/vision/classification.py
@@ -13,7 +13,7 @@
class VitClassificationModel(MegatronModule):
"""Vision Transformer Model."""

def __init__(self, num_classes, finetune=False,
def __init__(self, config, num_classes, finetune=False,
pre_process=True, post_process=True):
super(VitClassificationModel, self).__init__()
args = get_args()
@@ -24,6 +24,7 @@ def __init__(self, num_classes, finetune=False,
self.pre_process = pre_process
self.post_process = post_process
self.backbone = VitBackbone(
config=config,
pre_process=self.pre_process,
post_process=self.post_process,
single_token_output=True
16 changes: 9 additions & 7 deletions megatron/model/vision/dino.py
@@ -173,11 +173,12 @@ def cosine_scheduler(base_value, final_value, epochs, niter_per_ep,
return schedule


def get_student_backbone_and_num_features(pre_process=True, post_process=True):
def get_student_backbone_and_num_features(config, pre_process=True, post_process=True):
args = get_args()

if args.vision_backbone_type == 'vit':
student = VitBackbone(pre_process=pre_process,
student = VitBackbone(config,
pre_process=pre_process,
post_process=post_process,
drop_path_rate=0.1,
single_token_output=True)
@@ -194,11 +195,12 @@ def get_student_backbone_and_num_features(pre_process=True, post_process=True):

return student, num_features

def get_teacher_backbone_and_num_features(pre_process=True, post_process=True):
def get_teacher_backbone_and_num_features(config, pre_process=True, post_process=True):
args = get_args()

if args.vision_backbone_type == 'vit':
teacher = VitBackbone(pre_process=pre_process,
teacher = VitBackbone(config,
pre_process=pre_process,
post_process=post_process,
single_token_output=True)
num_features = args.hidden_size
@@ -215,7 +217,7 @@ def get_teacher_backbone_and_num_features(pre_process=True, post_process=True):


class DINOPretrainModel(MegatronModule):
def __init__(self, pre_process=True, post_process=True):
def __init__(self, config, pre_process=True, post_process=True):
super(DINOPretrainModel, self).__init__()
args = get_args()
self.out_dim = 65536
@@ -234,7 +236,7 @@ def __init__(self, pre_process=True, post_process=True):
self.momentum_teacher = 0.996

student_backbone, num_features = \
get_student_backbone_and_num_features(pre_process, post_process)
get_student_backbone_and_num_features(config, pre_process, post_process)

self.student = MultiCropWrapper(
student_backbone,
@@ -249,7 +251,7 @@ def __init__(self, pre_process=True, post_process=True):
)

teacher_backbone, num_features = \
get_teacher_backbone_and_num_features(pre_process, post_process)
get_teacher_backbone_and_num_features(config, pre_process, post_process)
self.teacher = MultiCropWrapper(
teacher_backbone,
DINOHead(num_features, self.out_dim)
3 changes: 2 additions & 1 deletion megatron/model/vision/inpainting.py
@@ -18,14 +18,15 @@

class VitInpaintingModel(MegatronModule):

def __init__(self, pre_process=True, post_process=True):
def __init__(self, config, pre_process=True, post_process=True):
super(VitInpaintingModel, self).__init__()
args = get_args()

self.pre_process = pre_process
self.post_process = post_process
self.hidden_size = args.hidden_size
self.backbone = VitBackbone(
config=config,
pre_process=self.pre_process,
post_process=self.post_process,
class_token=False,
12 changes: 2 additions & 10 deletions megatron/model/vision/vit_backbone.py
@@ -130,6 +130,7 @@ class VitBackbone(MegatronModule):
"""Vision Transformer Model."""

def __init__(self,
config,
pre_process=True,
post_process=True,
class_token=True,
@@ -140,14 +141,6 @@ def __init__(self,
args = get_args()

self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
if args.init_method_xavier_uniform:
self.init_method = torch.nn.init.xavier_uniform_
self.scaled_init_method = torch.nn.init.xavier_uniform_
else:
self.init_method = init_method_normal(args.init_method_std)
self.scaled_init_method = scaled_init_method_normal(
args.init_method_std, args.num_layers
)

self.pre_process = pre_process
self.post_process = post_process
@@ -202,8 +195,7 @@ def __init__(self,

# Transformer
self.transformer = ParallelTransformer(
self.init_method,
self.scaled_init_method,
config,
pre_process=self.pre_process,
post_process=self.post_process,
post_layer_norm=self.post_layer_norm,
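
The Xavier-vs-normal initializer selection deleted from VitBackbone above is the same branch that the megatron/arguments.py hunk at the top of this diff adds to core_transformer_config_from_args, so the choice now travels inside the config object passed to ParallelTransformer. A minimal sketch of the effective selection, assuming only what the two hunks show; the helper name below is illustrative, not repository code.

# Effective initializer selection after this commit, stitched together from the
# added arguments.py branch and the block removed from VitBackbone; the helper
# itself is illustrative and does not exist under this name in the repository.
import torch
from megatron.model.utils import init_method_normal, scaled_init_method_normal

def select_init_methods(args):
    """Xavier-uniform when requested, otherwise std-scaled normal initializers."""
    if args.init_method_xavier_uniform:
        init_method = torch.nn.init.xavier_uniform_
        scaled_init_method = torch.nn.init.xavier_uniform_
    else:
        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(
            args.init_method_std, args.num_layers)
    return init_method, scaled_init_method
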
6 changes: 4 additions & 2 deletions pretrain_vision_classify.py
@@ -12,16 +12,18 @@
from megatron.model.vision.classification import MitClassificationModel
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
from megatron.arguments import core_transformer_config_from_args


def model_provider(pre_process=True, post_process=True):
"""Build the model."""

args = get_args()

config = core_transformer_config_from_args(args)
if args.vision_backbone_type == 'vit':
print_rank_0("building VIT model ...")
model = VitClassificationModel(num_classes=args.num_classes,
model = VitClassificationModel(config=config,
num_classes=args.num_classes,
pre_process=pre_process,
post_process=post_process)
elif args.vision_backbone_type == 'mit':
4 changes: 3 additions & 1 deletion pretrain_vision_dino.py
@@ -16,10 +16,12 @@
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.arguments import core_transformer_config_from_args

def model_provider(pre_process=True, post_process=True):
"""Build the model."""
return DINOPretrainModel(pre_process=pre_process, post_process=post_process)
config = core_transformer_config_from_args(get_args())
return DINOPretrainModel(config, pre_process=pre_process, post_process=post_process)

def get_batch(data_iterator):
"""Build the batch."""
5 changes: 4 additions & 1 deletion pretrain_vision_inpaint.py
@@ -13,12 +13,15 @@
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
from tasks.vision.metrics import SSIM, PSNR
from megatron.arguments import core_transformer_config_from_args

def model_provider(pre_process=True, post_process=True):
"""Build the model."""
args = get_args()
config = core_transformer_config_from_args(args)
if args.vision_backbone_type == 'vit':
model = VitInpaintingModel(pre_process=pre_process,
model = VitInpaintingModel(config,
pre_process=pre_process,
post_process=post_process)
elif args.vision_backbone_type == 'mit':
model = MitInpaintingModel(pre_process=pre_process,
4 changes: 3 additions & 1 deletion tasks/glue/finetune.py
@@ -8,6 +8,7 @@
from megatron.model.classification import Classification
from tasks.eval_utils import accuracy_func_provider
from tasks.finetune_utils import finetune
from megatron.arguments import core_transformer_config_from_args


def glue_classification(num_classes, Dataset,
@@ -28,10 +29,11 @@ def train_valid_datasets_provider():
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
args = get_args()
config = core_transformer_config_from_args(args)

print_rank_0('building classification model for {} ...'.format(
args.task))
model = Classification(num_classes=num_classes, num_tokentypes=2,
model = Classification(config=config, num_classes=num_classes, num_tokentypes=2,
pre_process=pre_process, post_process=post_process)

return model
6 changes: 4 additions & 2 deletions tasks/race/finetune.py
@@ -9,6 +9,7 @@
from tasks.eval_utils import accuracy_func_provider
from tasks.finetune_utils import finetune
from tasks.race.data import RaceDataset
from megatron.arguments import core_transformer_config_from_args


def train_valid_datasets_provider():
@@ -26,9 +27,10 @@ def train_valid_datasets_provider():

def model_provider(pre_process=True, post_process=True):
"""Build the model."""

config = core_transformer_config_from_args(get_args())
print_rank_0('building multichoice model for RACE ...')
model = MultipleChoice(num_tokentypes=2,
model = MultipleChoice(config=config,
num_tokentypes=2,
pre_process=pre_process,
post_process=post_process)

