Commit 3af3852 (1 parent: 04bbd2f)
Showing 3 changed files with 804 additions and 0 deletions.
196 changes: 196 additions & 0 deletions
ptlflow/utils/lightning/ptlflow_checkpoint_connector.py
@@ -0,0 +1,196 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from pathlib import Path
import re
from typing import Optional

import torch
from torch import hub
from fsspec.core import url_to_fs

import lightning.pytorch as pl
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.trainer.connectors.checkpoint_connector import (
    _CheckpointConnector,
)
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.migration import pl_legacy_patch
from lightning.pytorch.utilities.migration.utils import _pl_migrate_checkpoint
from lightning.pytorch.utilities.rank_zero import rank_zero_info

log = logging.getLogger(__name__)

class _PTLFlowCheckpointConnector(_CheckpointConnector):
    def __init__(self, trainer: pl.Trainer) -> None:
        super().__init__(trainer)

    def resume_start(
        self,
        checkpoint_path: Optional[_PATH] = None,
        model: Optional[pl.LightningModule] = None,
    ) -> None:
        """Attempt to pre-load the checkpoint file into memory, with the source path determined in this priority:

        1. from HPC weights, if `checkpoint_path` is ``None`` and on SLURM, or if the keyword `"hpc"` is passed
        2. from a fault-tolerant auto-saved checkpoint, if found
        3. from the `checkpoint_path` file, if provided
        4. don't restore
        """
        self._ckpt_path = checkpoint_path
        if not checkpoint_path:
            log.debug("`checkpoint_path` not specified. Skipping checkpoint loading.")
            return

        if not Path(checkpoint_path).exists():
            if model is not None:
                model_ref = model.__class__
                if hasattr(model_ref, "pretrained_checkpoints"):
                    # keep the original name for the error message, since the lookup
                    # below overwrites checkpoint_path (possibly with None)
                    checkpoint_name = checkpoint_path
                    checkpoint_path = model_ref.pretrained_checkpoints.get(
                        checkpoint_path
                    )
                    if checkpoint_path is None:
                        raise ValueError(
                            f"Invalid checkpoint name {checkpoint_name}. "
                            f'Choose one from {{{",".join(model_ref.pretrained_checkpoints.keys())}}}'
                        )

            cache_path = (
                Path(hub.get_dir()) / "checkpoints" / checkpoint_path.split("/")[-1]
            )
            if cache_path.exists():
                checkpoint_path = cache_path
            else:
                raise ValueError(
                    f"Cannot find checkpoint {checkpoint_path} for model {model.__class__.__name__}"
                )

        rank_zero_info(
            f"Restoring states from the checkpoint path at {checkpoint_path}"
        )
        with pl_legacy_patch():
            loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path)
        if "pytorch-lightning_version" not in loaded_checkpoint:
            loaded_checkpoint["pytorch-lightning_version"] = "1.9.5"
        self._loaded_checkpoint = _pl_migrate_checkpoint(
            loaded_checkpoint, checkpoint_path
        )

    def resume_end(self) -> None:
        """Signal the connector that all states have resumed and that the memory for the checkpoint object can be released."""
        assert self.trainer.state.fn is not None
        if self._ckpt_path:
            message = (
                "Restored all states"
                if self.trainer.state.fn == TrainerFn.FITTING
                else "Loaded model weights"
            )
            rank_zero_info(f"{message} from the checkpoint at {self._ckpt_path}")

        # free memory
        self._loaded_checkpoint = {}
        torch.cuda.empty_cache()

        # wait for all processes to catch up
        self.trainer.strategy.barrier("_PTLFlowCheckpointConnector.resume_end")

    def restore_training_state(self) -> None:
        """Restore the trainer state from the pre-loaded checkpoint.

        This includes the precision settings, loop progress, optimizer states, and learning rate scheduler states.

        Modifications by Henrique Morimitsu:
        - call restore_optimizers_and_schedulers before restoring the other states, to raise an error earlier
          if the checkpoint does not contain training states
        """
        if not self._loaded_checkpoint:
            return

        assert self.trainer.state.fn is not None
        if self.trainer.state.fn == TrainerFn.FITTING:
            # restore optimizers and schedulers state
            self.restore_optimizers_and_schedulers()

        # restore precision plugin (scaler etc.)
        self.restore_precision_plugin_state()

        # restore loops and their progress
        self.restore_loops()

    def _restore_modules_and_callbacks(
        self,
        checkpoint_path: Optional[_PATH] = None,
        model: Optional[pl.LightningModule] = None,
    ) -> None:
        # restore modules after setup
        self.resume_start(checkpoint_path, model)
        self.restore_model()
        self.restore_datamodule()
        if self.trainer.state.fn == TrainerFn.FITTING:
            # restore callback states
            self.restore_callbacks()

    @staticmethod
    def __max_ckpt_version_in_folder(
        dir_path: _PATH, name_key: str = "ckpt_"
    ) -> Optional[int]:
        """List the files in `dir_path` whose names contain `name_key`, then return the maximum suffix number.

        Args:
            dir_path: path of a directory that may contain files whose names include `name_key`
            name_key: file name prefix

        Returns:
            ``None`` if there is no matching file, else the maximum suffix number
        """
        # check directory existence
        fs, uri = url_to_fs(str(dir_path))
        if not fs.exists(dir_path):
            return None

        # check corresponding file existence
        files = [os.path.basename(f["name"]) for f in fs.listdir(uri)]
        files = [x for x in files if name_key in x]
        if len(files) == 0:
            return None

        # extract suffix number
        ckpt_vs = []
        for name in files:
            name = name.split(name_key)[-1]
            name = re.sub("[^0-9]", "", name)
            ckpt_vs.append(int(name))

        return max(ckpt_vs)

    @staticmethod
    def __get_max_ckpt_path_from_folder(folder_path: _PATH) -> str:
        """Get the path of the maximum-epoch checkpoint in the folder."""
        max_suffix = _PTLFlowCheckpointConnector.__max_ckpt_version_in_folder(
            folder_path
        )
        ckpt_number = max_suffix if max_suffix is not None else 0
        return f"{folder_path}/hpc_ckpt_{ckpt_number}.ckpt"

    @staticmethod
    def hpc_save_path(folderpath: _PATH) -> str:
        max_suffix = _PTLFlowCheckpointConnector.__max_ckpt_version_in_folder(
            folderpath
        )
        ckpt_number = (max_suffix if max_suffix is not None else 0) + 1
        return os.path.join(folderpath, f"hpc_ckpt_{ckpt_number}.ckpt")
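
The checkpoint resolution in resume_start above falls back from a literal file path to a name registered in the model's pretrained_checkpoints mapping, and then to a previously downloaded copy in the local torch hub cache. Below is a minimal standalone sketch of that chain; FakeModel and its URL are hypothetical, and only the lookup order mirrors the connector:

# Sketch of the fallback chain in resume_start: literal path -> entry in the
# model's pretrained_checkpoints mapping -> local torch hub cache.
# FakeModel and the URL are hypothetical placeholders.
from pathlib import Path

from torch import hub


class FakeModel:
    pretrained_checkpoints = {
        "things": "https://example.com/ckpts/fake-things.ckpt",
    }


def resolve(checkpoint_path: str) -> Path:
    if Path(checkpoint_path).exists():
        return Path(checkpoint_path)
    url = FakeModel.pretrained_checkpoints.get(checkpoint_path)
    if url is None:
        raise ValueError(
            f"Invalid checkpoint name {checkpoint_path}. "
            f"Choose one from {set(FakeModel.pretrained_checkpoints)}"
        )
    # the connector only reuses an already-downloaded copy; it raises otherwise
    cache_path = Path(hub.get_dir()) / "checkpoints" / url.split("/")[-1]
    if cache_path.exists():
        return cache_path
    raise ValueError(f"Cannot find checkpoint {checkpoint_path} for model FakeModel")


try:
    print(resolve("things"))
except ValueError as err:
    # raised unless fake-things.ckpt is already in the torch hub cache
    print(err)

Note that the connector never downloads the checkpoint itself here; it only reuses a file that torch hub has already cached.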
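
Similarly, the three static helpers at the end implement the hpc_ckpt_N.ckpt numbering scheme: the private helpers find the highest existing suffix, and hpc_save_path returns a path with that suffix plus one. A small self-contained sketch of the same suffix arithmetic, using plain os/re instead of fsspec (only the file-name pattern comes from the code above):

# Sketch: the hpc_ckpt_N.ckpt numbering used by the static helpers above.
import os
import re
import tempfile

with tempfile.TemporaryDirectory() as folder:
    # pretend two HPC checkpoints already exist
    for n in (1, 3):
        open(os.path.join(folder, f"hpc_ckpt_{n}.ckpt"), "w").close()

    # extract the numeric suffix of every file whose name contains "ckpt_"
    suffixes = [
        int(re.sub("[^0-9]", "", name.split("ckpt_")[-1]))
        for name in os.listdir(folder)
        if "ckpt_" in name
    ]
    max_suffix = max(suffixes) if suffixes else None  # -> 3

    # hpc_save_path uses max_suffix + 1, so the next save lands here:
    next_path = os.path.join(
        folder, f"hpc_ckpt_{(max_suffix if max_suffix is not None else 0) + 1}.ckpt"
    )
    print(next_path)  # .../hpc_ckpt_4.ckpt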
172 changes: 172 additions & 0 deletions
@@ -0,0 +1,172 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from typing import Any, Callable, Dict, Optional, Type, Union

from jsonargparse import Namespace
from lightning import LightningDataModule, LightningModule, Trainer
from lightning.pytorch.cli import (
    ArgsType,
    LightningArgumentParser,
    LightningCLI,
    SaveConfigCallback,
)
from lightning.pytorch.utilities.rank_zero import rank_zero_warn

class PTLFlowCLI(LightningCLI):
    def __init__(
        self,
        model_class: Optional[
            Union[Type[LightningModule], Callable[..., LightningModule]]
        ] = None,
        datamodule_class: Optional[
            Union[Type[LightningDataModule], Callable[..., LightningDataModule]]
        ] = None,
        save_config_callback: Optional[Type[SaveConfigCallback]] = SaveConfigCallback,
        save_config_kwargs: Optional[Dict[str, Any]] = None,
        trainer_class: Optional[Union[Type[Trainer], Callable[..., Trainer]]] = None,
        trainer_defaults: Optional[Dict[str, Any]] = None,
        seed_everything_default: Union[bool, int] = True,
        parser_kwargs: Optional[
            Union[Dict[str, Any], Dict[str, Dict[str, Any]]]
        ] = None,
        subclass_mode_model: bool = False,
        subclass_mode_data: bool = False,
        args: ArgsType = None,
        run: bool = True,
        auto_configure_optimizers: bool = True,
        parse_only: bool = False,
        ignore_sys_argv: bool = False,
    ) -> None:
        """Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which
        are called / instantiated using a parsed configuration file and / or command line args.

        Parsing of configuration from environment variables can be enabled by setting ``parser_kwargs={"default_env":
        True}``. A full configuration yaml would be parsed from ``PL_CONFIG`` if set. Individual settings are likewise
        parsed from variables named, for example, ``PL_TRAINER__MAX_EPOCHS``.

        For more info, read :ref:`the CLI docs <lightning-cli>`.

        Args:
            model_class: An optional :class:`~lightning.pytorch.core.LightningModule` class to train on or a
                callable which returns a :class:`~lightning.pytorch.core.LightningModule` instance when
                called. If ``None``, you can pass a registered model with ``--model=MyModel``.
            datamodule_class: An optional :class:`~lightning.pytorch.core.datamodule.LightningDataModule` class or a
                callable which returns a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` instance when
                called. If ``None``, you can pass a registered datamodule with ``--data=MyDataModule``.
            save_config_callback: A callback class to save the config.
            save_config_kwargs: Parameters that will be used to instantiate the save_config_callback.
            trainer_class: An optional subclass of the :class:`~lightning.pytorch.trainer.trainer.Trainer` class or a
                callable which returns a :class:`~lightning.pytorch.trainer.trainer.Trainer` instance when called.
            trainer_defaults: Set to override Trainer defaults or add persistent callbacks. The callbacks added through
                this argument will not be configurable from a configuration file and will always be present for
                this particular CLI. Alternatively, configurable callbacks can be added as explained in
                :ref:`the CLI docs <lightning-cli>`.
            seed_everything_default: Number for the :func:`~lightning.fabric.utilities.seed.seed_everything`
                seed value. Set to ``True`` to automatically choose a seed value.
                Setting it to ``False`` will avoid calling ``seed_everything``.
            parser_kwargs: Additional arguments to instantiate each ``LightningArgumentParser``.
            subclass_mode_model: Whether model can be any `subclass
                <https://jsonargparse.readthedocs.io/en/stable/#class-type-and-sub-classes>`_
                of the given class.
            subclass_mode_data: Whether datamodule can be any `subclass
                <https://jsonargparse.readthedocs.io/en/stable/#class-type-and-sub-classes>`_
                of the given class.
            args: Arguments to parse. If ``None`` the arguments are taken from ``sys.argv``. Command line style
                arguments can be given in a ``list``. Alternatively, structured config options can be given in a
                ``dict`` or ``jsonargparse.Namespace``.
            run: Whether subcommands should be added to run a :class:`~lightning.pytorch.trainer.trainer.Trainer`
                method. If set to ``False``, the trainer and model classes will be instantiated only.
            parse_only: If set to ``True``, the CLI acts as a parser only. The classes are not instantiated.
            ignore_sys_argv: If set to ``True``, do not warn when both ``args`` and command line arguments in
                ``sys.argv`` are provided.
        """
        self.save_config_callback = save_config_callback
        self.save_config_kwargs = save_config_kwargs or {}
        self.trainer_class = trainer_class
        self.trainer_defaults = trainer_defaults or {}
        self.seed_everything_default = seed_everything_default
        self.parser_kwargs = parser_kwargs or {}
        self.auto_configure_optimizers = auto_configure_optimizers

        self.model_class = model_class
        # used to differentiate between the original value and the processed value
        self._model_class = model_class
        self.subclass_mode_model = subclass_mode_model

        self.datamodule_class = datamodule_class
        # used to differentiate between the original value and the processed value
        self._datamodule_class = datamodule_class
        self.subclass_mode_data = (datamodule_class is None) or subclass_mode_data

        main_kwargs, subparser_kwargs = self._setup_parser_kwargs(self.parser_kwargs)
        self.setup_parser(run, main_kwargs, subparser_kwargs)
        self.parse_arguments(self.parser, args, ignore_sys_argv)

        if not parse_only:
            self.subcommand = self.config["subcommand"] if run else None

            self._set_seed()

            self._add_instantiators()
            self.before_instantiate_classes()
            self.instantiate_classes()

            if self.subcommand is not None:
                self._run_subcommand(self.subcommand)

    def add_core_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
        """Adds arguments from the core classes to the parser."""
        if self.trainer_class is not None:
            parser.add_lightning_class_args(self.trainer_class, "trainer")
            trainer_defaults = {
                "trainer." + k: v
                for k, v in self.trainer_defaults.items()
                if k != "callbacks"
            }
            parser.set_defaults(trainer_defaults)

        if self._model_class is not None:
            parser.add_lightning_class_args(
                self._model_class, "model", subclass_mode=self.subclass_mode_model
            )

        if self.datamodule_class is not None:
            parser.add_lightning_class_args(
                self._datamodule_class, "data", subclass_mode=self.subclass_mode_data
            )

    def parse_arguments(
        self, parser: LightningArgumentParser, args: ArgsType, ignore_sys_argv: bool
    ) -> None:
        """Parses command line arguments and stores them in ``self.config``."""
        if args is not None and len(sys.argv) > 1 and not ignore_sys_argv:
            rank_zero_warn(
                "LightningCLI's args parameter is intended to run from within Python like if it were from the command "
                "line. To prevent mistakes it is not recommended to provide both args and command line arguments, got: "
                f"sys.argv[1:]={sys.argv[1:]}, args={args}."
            )
        if isinstance(args, (dict, Namespace)):
            self.config = parser.parse_object(args)
        else:
            self.config = parser.parse_args(args)

    def instantiate_classes(self) -> None:
        """Instantiates the classes and sets their attributes."""
        self.config_init = self.parser.instantiate_classes(self.config)
        self.datamodule = self._get(self.config_init, "data")
        self.model = self._get(self.config_init, "model")
        self._add_configure_optimizers_method_to_model(self.subcommand)

        if self.trainer_class is not None:
            self.trainer = self.instantiate_trainer()
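
Since run, parse_only, and ignore_sys_argv together make this CLI usable as a plain config parser, a minimal usage sketch may help. Everything here beyond PTLFlowCLI's own parameters is an assumption: the import path is guessed from the sibling file above (this diff does not show the new file's location), and MyModel is a hypothetical stand-in LightningModule.

# Sketch: driving PTLFlowCLI from Python instead of the command line.
# The import path is an assumption; MyModel is a hypothetical stand-in.
import torch
from lightning import LightningModule, Trainer

from ptlflow.utils.lightning.ptlflow_cli import PTLFlowCLI  # assumed path


class MyModel(LightningModule):
    def __init__(self, lr: float = 1e-3) -> None:
        super().__init__()
        self.lr = lr
        self.layer = torch.nn.Linear(2, 1)


cli = PTLFlowCLI(
    model_class=MyModel,
    trainer_class=Trainer,
    args={"model": {"lr": 0.01}, "trainer": {"max_epochs": 1}},  # structured config
    run=False,             # do not add/dispatch Trainer subcommands
    parse_only=True,       # stop after parsing: nothing is instantiated
    ignore_sys_argv=True,  # suppress the warning about pre-existing sys.argv
)
print(cli.config.model.lr)  # 0.01

With parser_kwargs={"default_env": True}, the same settings could instead come from the PL_CONFIG or PL_TRAINER__MAX_EPOCHS environment variables, as described in the class docstring.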