From 3af38527b40e3309dae5c405b0528245d06f9f61 Mon Sep 17 00:00:00 2001 From: Henrique Morimitsu Date: Tue, 3 Dec 2024 17:35:17 +0800 Subject: [PATCH] Add lightning custom files --- .../lightning/ptlflow_checkpoint_connector.py | 196 ++++++++ ptlflow/utils/lightning/ptlflow_cli.py | 172 +++++++ ptlflow/utils/lightning/ptlflow_trainer.py | 436 ++++++++++++++++++ 3 files changed, 804 insertions(+) create mode 100644 ptlflow/utils/lightning/ptlflow_checkpoint_connector.py create mode 100644 ptlflow/utils/lightning/ptlflow_cli.py create mode 100644 ptlflow/utils/lightning/ptlflow_trainer.py diff --git a/ptlflow/utils/lightning/ptlflow_checkpoint_connector.py b/ptlflow/utils/lightning/ptlflow_checkpoint_connector.py new file mode 100644 index 0000000..933ea75 --- /dev/null +++ b/ptlflow/utils/lightning/ptlflow_checkpoint_connector.py @@ -0,0 +1,196 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +from pathlib import Path +import re +from typing import Optional + +import torch +from torch import hub +from fsspec.core import url_to_fs + +import lightning.pytorch as pl +from lightning.fabric.utilities.types import _PATH +from lightning.pytorch.trainer.connectors.checkpoint_connector import ( + _CheckpointConnector, +) +from lightning.pytorch.trainer.states import TrainerFn +from lightning.pytorch.utilities.migration import pl_legacy_patch +from lightning.pytorch.utilities.migration.utils import _pl_migrate_checkpoint +from lightning.pytorch.utilities.rank_zero import rank_zero_info + +log = logging.getLogger(__name__) + + +class _PTLFlowCheckpointConnector(_CheckpointConnector): + def __init__(self, trainer: pl.Trainer) -> None: + super().__init__(trainer) + + def resume_start( + self, + checkpoint_path: Optional[_PATH] = None, + model: Optional[pl.LightningModule] = None, + ) -> None: + """Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority: + + 1. from HPC weights if `checkpoint_path` is ``None`` and on SLURM or passed keyword `"hpc"`. + 2. from fault-tolerant auto-saved checkpoint if found + 3. from `checkpoint_path` file if provided + 4. don't restore + + """ + self._ckpt_path = checkpoint_path + if not checkpoint_path: + log.debug("`checkpoint_path` not specified. Skipping checkpoint loading.") + return + + if not Path(checkpoint_path).exists(): + if model is not None: + model_ref = model.__class__ + if hasattr(model_ref, "pretrained_checkpoints"): + checkpoint_path = model_ref.pretrained_checkpoints.get( + checkpoint_path + ) + if checkpoint_path is None: + raise ValueError( + f"Invalid checkpoint name {checkpoint_path}. 
" + f'Choose one from {{{",".join(model_ref.pretrained_checkpoints.keys())}}}' + ) + + cache_path = ( + Path(hub.get_dir()) + / "checkpoints" + / checkpoint_path.split("/")[-1] + ) + if cache_path.exists(): + checkpoint_path = cache_path + else: + raise ValueError( + f"Cannot find checkpoint {checkpoint_path} for model {model.__class__.__name__}" + ) + + rank_zero_info( + f"Restoring states from the checkpoint path at {checkpoint_path}" + ) + with pl_legacy_patch(): + loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path) + if not "pytorch-lightning_version" in loaded_checkpoint: + loaded_checkpoint["pytorch-lightning_version"] = "1.9.5" + self._loaded_checkpoint = _pl_migrate_checkpoint( + loaded_checkpoint, checkpoint_path + ) + + def resume_end(self) -> None: + """Signal the connector that all states have resumed and memory for the checkpoint object can be released.""" + assert self.trainer.state.fn is not None + if self._ckpt_path: + message = ( + "Restored all states" + if self.trainer.state.fn == TrainerFn.FITTING + else "Loaded model weights" + ) + rank_zero_info(f"{message} from the checkpoint at {self._ckpt_path}") + + # free memory + self._loaded_checkpoint = {} + torch.cuda.empty_cache() + + # wait for all to catch up + self.trainer.strategy.barrier("_PTLFlowCheckpointConnector.resume_end") + + def restore_training_state(self) -> None: + """Restore the trainer state from the pre-loaded checkpoint. + + This includes the precision settings, loop progress, optimizer states and learning rate scheduler states. + + Modifications by Henrique Morimitsu: + - restore_optimizers_and_schedulers before other states to raise an error earlier in case the ckpt does not have training states + """ + if not self._loaded_checkpoint: + return + + assert self.trainer.state.fn is not None + if self.trainer.state.fn == TrainerFn.FITTING: + # restore optimizers and schedulers state + self.restore_optimizers_and_schedulers() + + # restore precision plugin (scaler etc.) + self.restore_precision_plugin_state() + + # restore loops and their progress + self.restore_loops() + + def _restore_modules_and_callbacks( + self, + checkpoint_path: Optional[_PATH] = None, + model: Optional[pl.LightningModule] = None, + ) -> None: + # restore modules after setup + self.resume_start(checkpoint_path, model) + self.restore_model() + self.restore_datamodule() + if self.trainer.state.fn == TrainerFn.FITTING: + # restore callback states + self.restore_callbacks() + + @staticmethod + def __max_ckpt_version_in_folder( + dir_path: _PATH, name_key: str = "ckpt_" + ) -> Optional[int]: + """List up files in `dir_path` with `name_key`, then yield maximum suffix number. 
+ + Args: + dir_path: path of directory which may contain files whose name include `name_key` + name_key: file name prefix + Returns: + None if no-corresponding-file else maximum suffix number + + """ + # check directory existence + fs, uri = url_to_fs(str(dir_path)) + if not fs.exists(dir_path): + return None + + # check corresponding file existence + files = [os.path.basename(f["name"]) for f in fs.listdir(uri)] + files = [x for x in files if name_key in x] + if len(files) == 0: + return None + + # extract suffix number + ckpt_vs = [] + for name in files: + name = name.split(name_key)[-1] + name = re.sub("[^0-9]", "", name) + ckpt_vs.append(int(name)) + + return max(ckpt_vs) + + @staticmethod + def __get_max_ckpt_path_from_folder(folder_path: _PATH) -> str: + """Get path of maximum-epoch checkpoint in the folder.""" + max_suffix = _PTLFlowCheckpointConnector.__max_ckpt_version_in_folder( + folder_path + ) + ckpt_number = max_suffix if max_suffix is not None else 0 + return f"{folder_path}/hpc_ckpt_{ckpt_number}.ckpt" + + @staticmethod + def hpc_save_path(folderpath: _PATH) -> str: + max_suffix = _PTLFlowCheckpointConnector.__max_ckpt_version_in_folder( + folderpath + ) + ckpt_number = (max_suffix if max_suffix is not None else 0) + 1 + return os.path.join(folderpath, f"hpc_ckpt_{ckpt_number}.ckpt") diff --git a/ptlflow/utils/lightning/ptlflow_cli.py b/ptlflow/utils/lightning/ptlflow_cli.py new file mode 100644 index 0000000..98564f0 --- /dev/null +++ b/ptlflow/utils/lightning/ptlflow_cli.py @@ -0,0 +1,172 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import sys +from typing import Any, Callable, Dict, Optional, Type, Union + +from jsonargparse import Namespace +from lightning import LightningDataModule, LightningModule, Trainer +from lightning.pytorch.cli import ( + ArgsType, + LightningArgumentParser, + LightningCLI, + SaveConfigCallback, +) +from lightning.pytorch.utilities.rank_zero import rank_zero_warn + + +class PTLFlowCLI(LightningCLI): + def __init__( + self, + model_class: Optional[ + Union[Type[LightningModule], Callable[..., LightningModule]] + ] = None, + datamodule_class: Optional[ + Union[Type[LightningDataModule], Callable[..., LightningDataModule]] + ] = None, + save_config_callback: Optional[Type[SaveConfigCallback]] = SaveConfigCallback, + save_config_kwargs: Optional[Dict[str, Any]] = None, + trainer_class: Optional[Union[Type[Trainer], Callable[..., Trainer]]] = None, + trainer_defaults: Optional[Dict[str, Any]] = None, + seed_everything_default: Union[bool, int] = True, + parser_kwargs: Optional[ + Union[Dict[str, Any], Dict[str, Dict[str, Any]]] + ] = None, + subclass_mode_model: bool = False, + subclass_mode_data: bool = False, + args: ArgsType = None, + run: bool = True, + auto_configure_optimizers: bool = True, + parse_only: bool = False, + ignore_sys_argv: bool = False, + ) -> None: + """Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which are + called / instantiated using a parsed configuration file and / or command line args. + + Parsing of configuration from environment variables can be enabled by setting ``parser_kwargs={"default_env": + True}``. A full configuration yaml would be parsed from ``PL_CONFIG`` if set. Individual settings are so parsed + from variables named for example ``PL_TRAINER__MAX_EPOCHS``. + + For more info, read :ref:`the CLI docs `. + + Args: + model_class: An optional :class:`~lightning.pytorch.core.LightningModule` class to train on or a + callable which returns a :class:`~lightning.pytorch.core.LightningModule` instance when + called. If ``None``, you can pass a registered model with ``--model=MyModel``. + datamodule_class: An optional :class:`~lightning.pytorch.core.datamodule.LightningDataModule` class or a + callable which returns a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` instance when + called. If ``None``, you can pass a registered datamodule with ``--data=MyDataModule``. + save_config_callback: A callback class to save the config. + save_config_kwargs: Parameters that will be used to instantiate the save_config_callback. + trainer_class: An optional subclass of the :class:`~lightning.pytorch.trainer.trainer.Trainer` class or a + callable which returns a :class:`~lightning.pytorch.trainer.trainer.Trainer` instance when called. + trainer_defaults: Set to override Trainer defaults or add persistent callbacks. The callbacks added through + this argument will not be configurable from a configuration file and will always be present for + this particular CLI. Alternatively, configurable callbacks can be added as explained in + :ref:`the CLI docs `. + seed_everything_default: Number for the :func:`~lightning.fabric.utilities.seed.seed_everything` + seed value. Set to True to automatically choose a seed value. + Setting it to False will avoid calling ``seed_everything``. + parser_kwargs: Additional arguments to instantiate each ``LightningArgumentParser``. + subclass_mode_model: Whether model can be any `subclass + `_ + of the given class. 
+ subclass_mode_data: Whether datamodule can be any `subclass + `_ + of the given class. + args: Arguments to parse. If ``None`` the arguments are taken from ``sys.argv``. Command line style + arguments can be given in a ``list``. Alternatively, structured config options can be given in a + ``dict`` or ``jsonargparse.Namespace``. + run: Whether subcommands should be added to run a :class:`~lightning.pytorch.trainer.trainer.Trainer` + method. If set to ``False``, the trainer and model classes will be instantiated only. + parse_only: If set to ``True``, the CLI acts as a parser only. The classes are not instantiated. + """ + self.save_config_callback = save_config_callback + self.save_config_kwargs = save_config_kwargs or {} + self.trainer_class = trainer_class + self.trainer_defaults = trainer_defaults or {} + self.seed_everything_default = seed_everything_default + self.parser_kwargs = parser_kwargs or {} + self.auto_configure_optimizers = auto_configure_optimizers + + self.model_class = model_class + # used to differentiate between the original value and the processed value + self._model_class = model_class + self.subclass_mode_model = subclass_mode_model + + self.datamodule_class = datamodule_class + # used to differentiate between the original value and the processed value + self._datamodule_class = datamodule_class + self.subclass_mode_data = (datamodule_class is None) or subclass_mode_data + + main_kwargs, subparser_kwargs = self._setup_parser_kwargs(self.parser_kwargs) + self.setup_parser(run, main_kwargs, subparser_kwargs) + self.parse_arguments(self.parser, args, ignore_sys_argv) + + if not parse_only: + self.subcommand = self.config["subcommand"] if run else None + + self._set_seed() + + self._add_instantiators() + self.before_instantiate_classes() + self.instantiate_classes() + + if self.subcommand is not None: + self._run_subcommand(self.subcommand) + + def add_core_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + """Adds arguments from the core classes to the parser.""" + if self.trainer_class is not None: + parser.add_lightning_class_args(self.trainer_class, "trainer") + trainer_defaults = { + "trainer." + k: v + for k, v in self.trainer_defaults.items() + if k != "callbacks" + } + parser.set_defaults(trainer_defaults) + + if self._model_class is not None: + parser.add_lightning_class_args( + self._model_class, "model", subclass_mode=self.subclass_mode_model + ) + + if self.datamodule_class is not None: + parser.add_lightning_class_args( + self._datamodule_class, "data", subclass_mode=self.subclass_mode_data + ) + + def parse_arguments( + self, parser: LightningArgumentParser, args: ArgsType, ignore_sys_argv: bool + ) -> None: + """Parses command line arguments and stores it in ``self.config``.""" + if args is not None and len(sys.argv) > 1 and not ignore_sys_argv: + rank_zero_warn( + "LightningCLI's args parameter is intended to run from within Python like if it were from the command " + "line. To prevent mistakes it is not recommended to provide both args and command line arguments, got: " + f"sys.argv[1:]={sys.argv[1:]}, args={args}." 
+ ) + if isinstance(args, (dict, Namespace)): + self.config = parser.parse_object(args) + else: + self.config = parser.parse_args(args) + + def instantiate_classes(self) -> None: + """Instantiates the classes and sets their attributes.""" + self.config_init = self.parser.instantiate_classes(self.config) + self.datamodule = self._get(self.config_init, "data") + self.model = self._get(self.config_init, "model") + self._add_configure_optimizers_method_to_model(self.subcommand) + + if self.trainer_class is not None: + self.trainer = self.instantiate_trainer() diff --git a/ptlflow/utils/lightning/ptlflow_trainer.py b/ptlflow/utils/lightning/ptlflow_trainer.py new file mode 100644 index 0000000..66ad80e --- /dev/null +++ b/ptlflow/utils/lightning/ptlflow_trainer.py @@ -0,0 +1,436 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# THIS FILE MUST READ EASILY, FOR UNDERSTANDING AND DEBUGGING PURPOSES. +# DO NOT OBSCURE THE TRAINING LOOP +# THIS IS A HARD REQUIREMENT TO CONTRIBUTING TO LIGHTNING +# WE FAVOR READABILITY OVER ENGINEERING-CONSTRUCTS BY DESIGN +# DO NOT REMOVE THIS NOTICE +# - WILLIAM FALCON +"""Trainer to automate the training.""" + +import logging +from datetime import timedelta +from typing import Dict, Iterable, List, Optional, Union +from weakref import proxy + +import lightning.pytorch as pl +from lightning.fabric.utilities.types import _PATH +from lightning.pytorch.accelerators import Accelerator +from lightning.pytorch.callbacks import Callback +from lightning.pytorch.loggers import Logger +from lightning.pytorch.loggers.utilities import _log_hyperparams +from lightning.pytorch.loops.utilities import _parse_loop_limits +from lightning.pytorch.plugins import _PLUGIN_INPUT +from lightning.pytorch.profilers import Profiler +from lightning.pytorch.strategies import Strategy +from lightning.pytorch.trainer import call +from lightning.pytorch.trainer.configuration_validator import ( + _verify_loop_configurations, +) +from lightning.pytorch.trainer.connectors.accelerator_connector import ( + _LITERAL_WARN, + _PRECISION_INPUT, +) +from lightning.pytorch.trainer.states import ( + TrainerFn, + TrainerStatus, +) +from lightning.pytorch.utilities import parsing +from lightning.pytorch.utilities.argparse import _defaults_from_env_vars +from lightning.pytorch.utilities.rank_zero import rank_zero_info +from lightning.pytorch.utilities.types import ( + _EVALUATE_OUTPUT, + _PREDICT_OUTPUT, +) + +from .ptlflow_checkpoint_connector import ( + _PTLFlowCheckpointConnector, +) + +log = logging.getLogger(__name__) + + +class PTLFlowTrainer(pl.Trainer): + @_defaults_from_env_vars + def __init__( + self, + *, + accelerator: Union[str, Accelerator] = "auto", + strategy: Union[str, Strategy] = "auto", + devices: Union[List[int], str, int] = "auto", + num_nodes: int = 1, + precision: Optional[_PRECISION_INPUT] = None, + logger: Optional[Union[Logger, Iterable[Logger], bool]] = None, + callbacks: Optional[Union[List[Callback], Callback]] = None, + fast_dev_run: 
Union[int, bool] = False, + max_epochs: Optional[int] = None, + min_epochs: Optional[int] = None, + max_steps: int = -1, + min_steps: Optional[int] = None, + max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None, + limit_train_batches: Optional[Union[int, float]] = None, + limit_val_batches: Optional[Union[int, float]] = None, + limit_test_batches: Optional[Union[int, float]] = None, + limit_predict_batches: Optional[Union[int, float]] = None, + overfit_batches: Union[int, float] = 0.0, + val_check_interval: Optional[Union[int, float]] = None, + check_val_every_n_epoch: Optional[int] = 1, + num_sanity_val_steps: Optional[int] = None, + log_every_n_steps: Optional[int] = None, + enable_checkpointing: Optional[bool] = None, + enable_progress_bar: Optional[bool] = None, + enable_model_summary: Optional[bool] = None, + accumulate_grad_batches: int = 1, + gradient_clip_val: Optional[Union[int, float]] = None, + gradient_clip_algorithm: Optional[str] = None, + deterministic: Optional[Union[bool, _LITERAL_WARN]] = None, + benchmark: Optional[bool] = None, + inference_mode: bool = True, + use_distributed_sampler: bool = True, + profiler: Optional[Union[Profiler, str]] = None, + detect_anomaly: bool = False, + barebones: bool = False, + plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None, + sync_batchnorm: bool = False, + reload_dataloaders_every_n_epochs: int = 0, + default_root_dir: Optional[_PATH] = None, + ) -> None: + r"""Customize every aspect of training via flags. + + Args: + accelerator: Supports passing different accelerator types ("cpu", "gpu", "tpu", "hpu", "mps", "auto") + as well as custom accelerator instances. + + strategy: Supports different training strategies with aliases as well custom strategies. + Default: ``"auto"``. + + devices: The devices to use. Can be set to a positive number (int or str), a sequence of device indices + (list or str), the value ``-1`` to indicate all available devices should be used, or ``"auto"`` for + automatic selection based on the chosen accelerator. Default: ``"auto"``. + + num_nodes: Number of GPU nodes for distributed training. + Default: ``1``. + + precision: Double precision (64, '64' or '64-true'), full precision (32, '32' or '32-true'), + 16bit mixed precision (16, '16', '16-mixed') or bfloat16 mixed precision ('bf16', 'bf16-mixed'). + Can be used on CPU, GPU, TPUs, or HPUs. + Default: ``'32-true'``. + + logger: Logger (or iterable collection of loggers) for experiment tracking. A ``True`` value uses + the default ``TensorBoardLogger`` if it is installed, otherwise ``CSVLogger``. + ``False`` will disable logging. If multiple loggers are provided, local files + (checkpoints, profiler traces, etc.) are saved in the ``log_dir`` of the first logger. + Default: ``True``. + + callbacks: Add a callback or list of callbacks. + Default: ``None``. + + fast_dev_run: Runs n if set to ``n`` (int) else 1 if set to ``True`` batch(es) + of train, val and test to find any bugs (ie: a sort of unit test). + Default: ``False``. + + max_epochs: Stop training once this number of epochs is reached. Disabled by default (None). + If both max_epochs and max_steps are not specified, defaults to ``max_epochs = 1000``. + To enable infinite training, set ``max_epochs = -1``. + + min_epochs: Force training for at least these many epochs. Disabled by default (None). + + max_steps: Stop training after this number of steps. Disabled by default (-1). If ``max_steps = -1`` + and ``max_epochs = None``, will default to ``max_epochs = 1000``. 
To enable infinite training, set + ``max_epochs`` to ``-1``. + + min_steps: Force training for at least these number of steps. Disabled by default (``None``). + + max_time: Stop training after this amount of time has passed. Disabled by default (``None``). + The time duration can be specified in the format DD:HH:MM:SS (days, hours, minutes seconds), as a + :class:`datetime.timedelta`, or a dictionary with keys that will be passed to + :class:`datetime.timedelta`. + + limit_train_batches: How much of training dataset to check (float = fraction, int = num_batches). + Default: ``1.0``. + + limit_val_batches: How much of validation dataset to check (float = fraction, int = num_batches). + Default: ``1.0``. + + limit_test_batches: How much of test dataset to check (float = fraction, int = num_batches). + Default: ``1.0``. + + limit_predict_batches: How much of prediction dataset to check (float = fraction, int = num_batches). + Default: ``1.0``. + + overfit_batches: Overfit a fraction of training/validation data (float) or a set number of batches (int). + Default: ``0.0``. + + val_check_interval: How often to check the validation set. Pass a ``float`` in the range [0.0, 1.0] to check + after a fraction of the training epoch. Pass an ``int`` to check after a fixed number of training + batches. An ``int`` value can only be higher than the number of training batches when + ``check_val_every_n_epoch=None``, which validates after every ``N`` training batches + across epochs or during iteration-based training. + Default: ``1.0``. + + check_val_every_n_epoch: Perform a validation loop after every `N` training epochs. If ``None``, + validation will be done solely based on the number of training batches, requiring ``val_check_interval`` + to be an integer value. + Default: ``1``. + + num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine. + Set it to `-1` to run all batches in all validation dataloaders. + Default: ``2``. + + log_every_n_steps: How often to log within steps. + Default: ``50``. + + enable_checkpointing: If ``True``, enable checkpointing. + It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in + :paramref:`~lightning.pytorch.trainer.trainer.Trainer.callbacks`. + Default: ``True``. + + enable_progress_bar: Whether to enable to progress bar by default. + Default: ``True``. + + enable_model_summary: Whether to enable model summarization by default. + Default: ``True``. + + accumulate_grad_batches: Accumulates gradients over k batches before stepping the optimizer. + Default: 1. + + gradient_clip_val: The value at which to clip gradients. Passing ``gradient_clip_val=None`` disables + gradient clipping. If using Automatic Mixed Precision (AMP), the gradients will be unscaled before. + Default: ``None``. + + gradient_clip_algorithm: The gradient clipping algorithm to use. Pass ``gradient_clip_algorithm="value"`` + to clip by value, and ``gradient_clip_algorithm="norm"`` to clip by norm. By default it will + be set to ``"norm"``. + + deterministic: If ``True``, sets whether PyTorch operations must use deterministic algorithms. + Set to ``"warn"`` to use deterministic algorithms whenever possible, throwing warnings on operations + that don't support deterministic mode. If not set, defaults to ``False``. Default: ``None``. + + benchmark: The value (``True`` or ``False``) to set ``torch.backends.cudnn.benchmark`` to. 
+ The value for ``torch.backends.cudnn.benchmark`` set in the current session will be used + (``False`` if not manually set). If :paramref:`~lightning.pytorch.trainer.trainer.Trainer.deterministic` + is set to ``True``, this will default to ``False``. Override to manually set a different value. + Default: ``None``. + + inference_mode: Whether to use :func:`torch.inference_mode` or :func:`torch.no_grad` during + evaluation (``validate``/``test``/``predict``). + + use_distributed_sampler: Whether to wrap the DataLoader's sampler with + :class:`torch.utils.data.DistributedSampler`. If not specified this is toggled automatically for + strategies that require it. By default, it will add ``shuffle=True`` for the train sampler and + ``shuffle=False`` for validation/test/predict samplers. If you want to disable this logic, you can pass + ``False`` and add your own distributed sampler in the dataloader hooks. If ``True`` and a distributed + sampler was already added, Lightning will not replace the existing one. For iterable-style datasets, + we don't do this automatically. + + profiler: To profile individual steps during training and assist in identifying bottlenecks. + Default: ``None``. + + detect_anomaly: Enable anomaly detection for the autograd engine. + Default: ``False``. + + barebones: Whether to run in "barebones mode", where all features that may impact raw speed are + disabled. This is meant for analyzing the Trainer overhead and is discouraged during regular training + runs. The following features are deactivated: + :paramref:`~lightning.pytorch.trainer.trainer.Trainer.enable_checkpointing`, + :paramref:`~lightning.pytorch.trainer.trainer.Trainer.logger`, + :paramref:`~lightning.pytorch.trainer.trainer.Trainer.enable_progress_bar`, + :paramref:`~lightning.pytorch.trainer.trainer.Trainer.log_every_n_steps`, + :paramref:`~lightning.pytorch.trainer.trainer.Trainer.enable_model_summary`, + :paramref:`~lightning.pytorch.trainer.trainer.Trainer.num_sanity_val_steps`, + :paramref:`~lightning.pytorch.trainer.trainer.Trainer.fast_dev_run`, + :paramref:`~lightning.pytorch.trainer.trainer.Trainer.detect_anomaly`, + :paramref:`~lightning.pytorch.trainer.trainer.Trainer.profiler`, + :meth:`~lightning.pytorch.core.LightningModule.log`, + :meth:`~lightning.pytorch.core.LightningModule.log_dict`. + plugins: Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins. + Default: ``None``. + + sync_batchnorm: Synchronize batch norm layers between process groups/whole world. + Default: ``False``. + + reload_dataloaders_every_n_epochs: Set to a positive integer to reload dataloaders every n epochs. + Default: ``0``. + + default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed. + Default: ``os.getcwd()``. + Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/' + + Raises: + TypeError: + If ``gradient_clip_val`` is not an int or float. + + MisconfigurationException: + If ``gradient_clip_algorithm`` is invalid. 
+ + """ + super().__init__( + accelerator=accelerator, + strategy=strategy, + devices=devices, + num_nodes=num_nodes, + precision=precision, + logger=logger, + callbacks=callbacks, + fast_dev_run=fast_dev_run, + max_epochs=max_epochs, + min_epochs=min_epochs, + max_steps=max_steps, + min_steps=min_steps, + max_time=max_time, + limit_train_batches=limit_train_batches, + limit_val_batches=limit_val_batches, + limit_test_batches=limit_test_batches, + limit_predict_batches=limit_predict_batches, + overfit_batches=overfit_batches, + val_check_interval=val_check_interval, + check_val_every_n_epoch=check_val_every_n_epoch, + num_sanity_val_steps=num_sanity_val_steps, + log_every_n_steps=log_every_n_steps, + enable_checkpointing=enable_checkpointing, + enable_progress_bar=enable_progress_bar, + enable_model_summary=enable_model_summary, + accumulate_grad_batches=accumulate_grad_batches, + gradient_clip_val=gradient_clip_val, + gradient_clip_algorithm=gradient_clip_algorithm, + deterministic=deterministic, + benchmark=benchmark, + inference_mode=inference_mode, + use_distributed_sampler=use_distributed_sampler, + profiler=profiler, + detect_anomaly=detect_anomaly, + barebones=barebones, + plugins=plugins, + sync_batchnorm=sync_batchnorm, + reload_dataloaders_every_n_epochs=reload_dataloaders_every_n_epochs, + default_root_dir=default_root_dir, + ) + self._checkpoint_connector = _PTLFlowCheckpointConnector(self) + + def _run( + self, model: "pl.LightningModule", ckpt_path: Optional[_PATH] = None + ) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: + if self.state.fn == TrainerFn.FITTING: + min_epochs, max_epochs = _parse_loop_limits( + self.min_steps, self.max_steps, self.min_epochs, self.max_epochs, self + ) + self.fit_loop.min_epochs = min_epochs + self.fit_loop.max_epochs = max_epochs + + if self.barebones: + # no progress bar in barebones can make it look like the Trainer hung + rank_zero_info( + "`Trainer(barebones=True)` started running. The progress bar is disabled so you might want to" + " manually print the progress in your model." 
+ ) + + # clean hparams + if hasattr(model, "hparams"): + parsing.clean_namespace(model.hparams) + + # attach model to the strategy + self.strategy.connect(model) + + self._callback_connector._attach_model_callbacks() + self._callback_connector._attach_model_logging_functions() + + _verify_loop_configurations(self) + + # ---------------------------- + # SET UP THE TRAINER + # ---------------------------- + log.debug(f"{self.__class__.__name__}: setting up strategy environment") + self.strategy.setup_environment() + self.__setup_profiler() + + log.debug(f"{self.__class__.__name__}: preparing data") + self._data_connector.prepare_data() + + call._call_setup_hook( + self + ) # allow user to set up LightningModule in accelerator environment + log.debug(f"{self.__class__.__name__}: configuring model") + call._call_configure_model(self) + + # check if we should delay restoring checkpoint till later + if not self.strategy.restore_checkpoint_after_setup: + log.debug( + f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}" + ) + self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path, model) + + # reset logger connector + self._logger_connector.reset_results() + self._logger_connector.reset_metrics() + + # strategy will configure model and move it to the device + self.strategy.setup(self) + + # hook + if self.state.fn == TrainerFn.FITTING: + call._call_callback_hooks(self, "on_fit_start") + call._call_lightning_module_hook(self, "on_fit_start") + + _log_hyperparams(self) + + if self.strategy.restore_checkpoint_after_setup: + log.debug( + f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}" + ) + self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path) + + # restore optimizers, etc. + try: + log.debug(f"{self.__class__.__name__}: restoring training state") + self._checkpoint_connector.restore_training_state() + except KeyError: + log.info( + "The provided checkpoint does not contain the training state. Only the model weights will be loaded." + ) + + self._checkpoint_connector.resume_end() + + self._signal_connector.register_signal_handlers() + + # ---------------------------- + # RUN THE TRAINER + # ---------------------------- + results = self._run_stage() + + # ---------------------------- + # POST-Training CLEAN UP + # ---------------------------- + log.debug(f"{self.__class__.__name__}: trainer tearing down") + self._teardown() + + if self.state.fn == TrainerFn.FITTING: + call._call_callback_hooks(self, "on_fit_end") + call._call_lightning_module_hook(self, "on_fit_end") + + log.debug(f"{self.__class__.__name__}: calling teardown hooks") + call._call_teardown_hook(self) + + self.state.status = TrainerStatus.FINISHED + self.state.stage = None + + return results + + def __setup_profiler(self) -> None: + assert self.state.fn is not None + local_rank = self.local_rank if self.world_size > 1 else None + self.profiler._lightning_module = proxy(self.lightning_module) + self.profiler.setup( + stage=self.state.fn, local_rank=local_rank, log_dir=self.log_dir + )