nemo-ux MixedPrecision fix (NVIDIA#10080)
* NeMo-UX: Mcore mixed precision fix

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Restore parts of MegatronMixedPrecision

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Add code to catch Opt/Model config mismatches

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* remove amp_O2

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Apply isort and black reformatting

Signed-off-by: akoumpa <[email protected]>

* review

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* review

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* remove unused param

Signed-off-by: Alexandros Koumparoulis <[email protected]>

---------

Signed-off-by: Alexandros Koumparoulis <[email protected]>
Signed-off-by: akoumpa <[email protected]>
Co-authored-by: akoumpa <[email protected]>
akoumpa and akoumpa committed Aug 15, 2024
1 parent 9fda424 commit 7916269
Showing 3 changed files with 19 additions and 32 deletions.
2 changes: 1 addition & 1 deletion examples/llm/megatron_gpt_pretraining.py
@@ -92,7 +92,7 @@ def get_args():
        callbacks=callbacks,
        log_every_n_steps=1,
        limit_val_batches=2,
        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed", amp_O2=False),
        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
    )

    nemo_logger = NeMoLogger(
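For context, after this change the precision plugin is built from the precision string alone. A minimal usage sketch, assuming the example script's usual import nemo.lightning as nl; the strategy and callback arguments here are illustrative placeholders, not values taken from the diff:

import nemo.lightning as nl

# Sketch: construct the plugin without the removed amp_O2 flag and hand it to the
# trainer; passing amp_O2 after this change fails with an unexpected-keyword error.
trainer = nl.Trainer(
    strategy=nl.MegatronStrategy(),
    callbacks=[],
    log_every_n_steps=1,
    limit_val_batches=2,
    plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
)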
47 changes: 16 additions & 31 deletions nemo/lightning/pytorch/plugins/mixed_precision.py
@@ -26,11 +26,17 @@
AnyT = TypeVar("AnyT")


def get_optim_config(optimizer: Optimizer):
    try:
        return optimizer.mcore_optimizer.config
    except:
        raise ValueError("Failed to extract optimizer config from module.")


class MegatronMixedPrecision(MixedPrecision):
    def __init__(
        self,
        precision: Literal["16-mixed", "bf16-mixed"],
        amp_O2: bool = False,
        device="cuda",
    ) -> None:
        if precision == "bf16-mixed":
@@ -39,21 +45,6 @@ def __init__(
            scaler = GradScaler(init_scale=2**32, growth_interval=1000, hysteresis=2)

        super().__init__(precision, device, scaler)
        self.amp_O2 = amp_O2

    def connect(
        self, model: Module, optimizers: List[Optimizer], lr_schedulers: List[Any]
    ) -> Tuple[Module, List[Optimizer], List[Any]]:
        """Connects this plugin to the accelerator and the training process."""
        from nemo.core.optim import MainParamsOptimizerWrapper

        if not optimizers or not self.amp_O2 or isinstance(optimizers[0], MainParamsOptimizerWrapper):
            return model, optimizers, lr_schedulers

        _optimizers = [*optimizers]
        _optimizers[0] = self.convert_optimizer(_optimizers[0])

        return model, _optimizers, lr_schedulers

    def convert_module(self, module: Module) -> Module:
        """Convert the module parameters to the precision type this plugin handles.
@@ -68,11 +59,11 @@ def convert_module(self, module: Module) -> Module:
            config = get_model_config(module.module)
            config.fp16 = self.precision == "16-mixed"
            config.bf16 = self.precision == "bf16-mixed"
            if isinstance(module.module, Float16Module):
                new_float16_module = Float16Module(config, module.module.module)
                module.module = new_float16_module
            else:
                config.autocast = False
            if hasattr(module, 'module'):
                module.module = Float16Module(config, module.module)
            else:
                module = Float16Module(config, module)

        return module

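Read on its own, the convert_module change swaps an "is this already a Float16Module?" test for a "does this module wrap an inner .module?" test. A sketch of the new wrapping logic as a standalone function, assuming megatron-core is installed and that config is an MCore config object with fp16/bf16 flags; the function name is hypothetical:

from megatron.core.transformer.module import Float16Module

def wrap_in_float16_module(module, config, precision: str):
    # Flip the config flags from the plugin's precision string, then wrap once.
    config.fp16 = precision == "16-mixed"
    config.bf16 = precision == "bf16-mixed"
    if hasattr(module, "module"):
        # Wrapped module (e.g. DDP): wrap the inner module in place.
        module.module = Float16Module(config, module.module)
    else:
        # Bare MCore module: wrap it directly.
        module = Float16Module(config, module)
    return module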
@@ -82,16 +73,10 @@ def convert_optimizer(self, optimizer: Optimizer) -> Optimizer:
        This is optional and depends on the precision limitations during optimization.
        """
        from nemo.core.optim import MainParamsOptimizerWrapper

        if isinstance(optimizer, MainParamsOptimizerWrapper) or not self.amp_O2:
            return optimizer

        return MainParamsOptimizerWrapper(
            optimizer,
            fp32_grad_accum=True,
            contiguous_grad_bucket=True,
        )
        optim_config = get_optim_config(optimizer)
        assert optim_config.bf16 == (self.precision == "bf16-mixed"), "BF16 enabled on model but not on optimizer"
        assert optim_config.fp16 == (self.precision == "fp16-mixed"), "BF16 enabled on model but not on optimizer"
        return optimizer

    def convert_input(self, data: AnyT) -> AnyT:
        """Convert model inputs (forward) to the floating point precision type of this plugin.
@@ -120,7 +105,7 @@ def optimizer_step(
    ) -> None:
        from nemo.core.optim import MainParamsOptimizerWrapper

        if not self.amp_O2 and not isinstance(optimizer, MainParamsOptimizerWrapper):
        if not isinstance(optimizer, MainParamsOptimizerWrapper):
            return super().optimizer_step(optimizer, model, closure, **kwargs)

        if self.scaler is None:
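The new convert_optimizer path reads the Megatron Core optimizer config via get_optim_config() and asserts that its bf16/fp16 flags agree with the plugin's precision, so a model/optimizer precision mismatch now fails loudly. A self-contained illustration of that check, using SimpleNamespace stand-ins instead of real Megatron objects (so it runs without megatron-core installed) and a simplified except clause:

from types import SimpleNamespace

def get_optim_config(optimizer):
    # Same idea as the helper added above: the MCore optimizer exposes its
    # OptimizerConfig at optimizer.mcore_optimizer.config.
    try:
        return optimizer.mcore_optimizer.config
    except AttributeError:
        raise ValueError("Failed to extract optimizer config from module.")

precision = "bf16-mixed"
# Stand-in for an optimizer that was built with fp16 while the model runs bf16.
mismatched = SimpleNamespace(
    mcore_optimizer=SimpleNamespace(config=SimpleNamespace(bf16=False, fp16=True))
)

config = get_optim_config(mismatched)
try:
    assert config.bf16 == (precision == "bf16-mixed"), "BF16 enabled on model but not on optimizer"
except AssertionError as err:
    print(f"caught config mismatch: {err}")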
2 changes: 2 additions & 0 deletions nemo/lightning/pytorch/strategies.py
@@ -621,6 +621,8 @@ def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = Tr
        assert self.megatron_parallel is not None

        _strategy_lib.load_model_state_dict(self.megatron_parallel, checkpoint, strict=strict)
        for opt in self.optimizers:
            opt.reload_model_params()

    @property
    @override
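The strategies.py addition makes load_model_state_dict() refresh every optimizer after the model weights are restored: mixed-precision optimizers keep fp32 master copies of the bf16/fp16 parameters, and those copies would otherwise still hold the pre-checkpoint values. A toy, framework-free sketch of that idea in plain PyTorch (not NeMo or Megatron Core code):

import torch

# bf16 model plus fp32 "master" copies, the pattern a mixed-precision optimizer keeps.
model = torch.nn.Linear(4, 4).bfloat16()
master_params = [p.detach().float().clone() for p in model.parameters()]

# A checkpoint load overwrites the bf16 weights behind the optimizer's back.
model.load_state_dict({k: torch.zeros_like(v) for k, v in model.state_dict().items()})

# Refresh the masters from the just-loaded weights; conceptually this is what
# calling opt.reload_model_params() after load_model_state_dict() achieves.
for master, param in zip(master_params, model.parameters()):
    master.copy_(param.detach().float())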
