diff --git a/examples/llm/megatron_gpt_pretraining.py b/examples/llm/megatron_gpt_pretraining.py
index d3d049e4296e..73e96a23bf81 100644
--- a/examples/llm/megatron_gpt_pretraining.py
+++ b/examples/llm/megatron_gpt_pretraining.py
@@ -92,7 +92,7 @@ def get_args():
         callbacks=callbacks,
         log_every_n_steps=1,
         limit_val_batches=2,
-        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed", amp_O2=False),
+        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
     )
 
     nemo_logger = NeMoLogger(
diff --git a/nemo/lightning/pytorch/plugins/mixed_precision.py b/nemo/lightning/pytorch/plugins/mixed_precision.py
index 5e43e09c0420..65b7c6292249 100644
--- a/nemo/lightning/pytorch/plugins/mixed_precision.py
+++ b/nemo/lightning/pytorch/plugins/mixed_precision.py
@@ -26,11 +26,17 @@
 AnyT = TypeVar("AnyT")
 
 
+def get_optim_config(optimizer: Optimizer):
+    try:
+        return optimizer.mcore_optimizer.config
+    except AttributeError:
+        raise ValueError("Failed to extract optimizer config from optimizer.")
+
+
 class MegatronMixedPrecision(MixedPrecision):
     def __init__(
         self,
         precision: Literal["16-mixed", "bf16-mixed"],
-        amp_O2: bool = False,
         device="cuda",
     ) -> None:
         if precision == "bf16-mixed":
@@ -39,21 +45,6 @@ def __init__(
             scaler = GradScaler(init_scale=2**32, growth_interval=1000, hysteresis=2)
 
         super().__init__(precision, device, scaler)
-        self.amp_O2 = amp_O2
-
-    def connect(
-        self, model: Module, optimizers: List[Optimizer], lr_schedulers: List[Any]
-    ) -> Tuple[Module, List[Optimizer], List[Any]]:
-        """Connects this plugin to the accelerator and the training process."""
-        from nemo.core.optim import MainParamsOptimizerWrapper
-
-        if not optimizers or not self.amp_O2 or isinstance(optimizers[0], MainParamsOptimizerWrapper):
-            return model, optimizers, lr_schedulers
-
-        _optimizers = [*optimizers]
-        _optimizers[0] = self.convert_optimizer(_optimizers[0])
-
-        return model, _optimizers, lr_schedulers
 
     def convert_module(self, module: Module) -> Module:
         """Convert the module parameters to the precision type this plugin handles.
@@ -68,11 +59,11 @@ def convert_module(self, module: Module) -> Module:
         config = get_model_config(module.module)
         config.fp16 = self.precision == "16-mixed"
         config.bf16 = self.precision == "bf16-mixed"
-        if isinstance(module.module, Float16Module):
-            new_float16_module = Float16Module(config, module.module.module)
-            module.module = new_float16_module
-        else:
+        config.autocast = False
+        if hasattr(module, 'module'):
             module.module = Float16Module(config, module.module)
+        else:
+            module = Float16Module(config, module)
 
         return module
 
@@ -82,16 +73,10 @@ def convert_optimizer(self, optimizer: Optimizer) -> Optimizer:
         This is optional and depends on the precision limitations during optimization.
 
         """
-        from nemo.core.optim import MainParamsOptimizerWrapper
-
-        if isinstance(optimizer, MainParamsOptimizerWrapper) or not self.amp_O2:
-            return optimizer
-
-        return MainParamsOptimizerWrapper(
-            optimizer,
-            fp32_grad_accum=True,
-            contiguous_grad_bucket=True,
-        )
+        optim_config = get_optim_config(optimizer)
+        assert optim_config.bf16 == (self.precision == "bf16-mixed"), "BF16 enabled on model but not on optimizer"
+        assert optim_config.fp16 == (self.precision == "16-mixed"), "FP16 enabled on model but not on optimizer"
+        return optimizer
 
     def convert_input(self, data: AnyT) -> AnyT:
         """Convert model inputs (forward) to the floating point precision type of this plugin.
@@ -120,7 +105,7 @@ def optimizer_step(
     ) -> None:
         from nemo.core.optim import MainParamsOptimizerWrapper
 
-        if not self.amp_O2 and not isinstance(optimizer, MainParamsOptimizerWrapper):
+        if not isinstance(optimizer, MainParamsOptimizerWrapper):
             return super().optimizer_step(optimizer, model, closure, **kwargs)
 
         if self.scaler is None:
diff --git a/nemo/lightning/pytorch/strategies.py b/nemo/lightning/pytorch/strategies.py
index a208fac4017f..f22ed5b40a20 100644
--- a/nemo/lightning/pytorch/strategies.py
+++ b/nemo/lightning/pytorch/strategies.py
@@ -621,6 +621,8 @@ def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = Tr
         assert self.megatron_parallel is not None
 
         _strategy_lib.load_model_state_dict(self.megatron_parallel, checkpoint, strict=strict)
+        for opt in self.optimizers:
+            opt.reload_model_params()
 
     @property
     @override
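
Usage note: with amp_O2 removed, mixed precision is configured through the precision string alone. The sketch below shows the updated call site in the spirit of examples/llm/megatron_gpt_pretraining.py; it is a minimal illustration, and the Trainer/MegatronStrategy arguments other than plugins are assumed values, not taken from this diff.

    import nemo.lightning as nl

    # The plugin now takes only the precision literal; passing amp_O2 raises a TypeError.
    precision = nl.MegatronMixedPrecision(precision="bf16-mixed")

    trainer = nl.Trainer(
        devices=2,  # assumed value for illustration
        strategy=nl.MegatronStrategy(tensor_model_parallel_size=2),  # assumed value
        plugins=precision,
        log_every_n_steps=1,
        limit_val_batches=2,
    )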