Add support for tf32 and set precision to bf16-mixed if available
fcogidi committed Sep 24, 2024
1 parent 0651a73 commit ce2a59d
Showing 1 changed file with 6 additions and 1 deletion.
mmlearn/cli/run.py: 6 additions & 1 deletion
@@ -12,6 +12,7 @@
 from omegaconf import OmegaConf
 from pytorch_lightning.utilities import rank_zero_only
 from torch.utils.data import DataLoader
+from transformers.utils.import_utils import is_torch_tf32_available

 from mmlearn.cli._instantiators import (
     instantiate_callbacks,
@@ -41,7 +42,11 @@ def main(cfg: MMLearnConf) -> None: # noqa: PLR0912
     cfg_copy = copy.deepcopy(cfg) # copy of the config for logging

     L.seed_everything(cfg.seed, workers=True)
-    torch.set_float32_matmul_precision("high")
+
+    if is_torch_tf32_available():
+        torch.backends.cuda.matmul.allow_tf32 = True
+        if "16-mixed" in cfg.trainer.precision:
+            cfg.trainer.precision = "bf16-mixed"

     # setup trainer first so that we can get some variables for distributed training
     callbacks = instantiate_callbacks(cfg.trainer.get("callbacks"))
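For context, below is a minimal, self-contained sketch of the pattern this change applies: enable TF32 matmuls when the GPU supports them (Ampere or newer) and prefer bf16 over fp16 mixed precision on the same hardware. The `precision` variable is a hypothetical stand-in for `cfg.trainer.precision`; it does not reproduce the project's actual config handling.

import torch
from transformers.utils.import_utils import is_torch_tf32_available

precision = "16-mixed"  # hypothetical stand-in for cfg.trainer.precision

if is_torch_tf32_available():
    # TF32 is available on Ampere (and newer) GPUs; enabling it speeds up
    # float32 matmuls at a small cost in precision.
    torch.backends.cuda.matmul.allow_tf32 = True
    # The same hardware supports bfloat16, which keeps float32's dynamic
    # range, so prefer it over fp16 for mixed-precision training.
    if "16-mixed" in precision:
        precision = "bf16-mixed"

print(precision)  # "bf16-mixed" on TF32-capable hardware, "16-mixed" otherwise

Because bf16 keeps float32's dynamic range, it generally avoids the loss-scaling issues that fp16 mixed precision can run into.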
