
Commit

Add lightning custom files
hmorimitsu committed Dec 3, 2024
1 parent 04bbd2f commit 3af3852
Showing 3 changed files with 804 additions and 0 deletions.
196 changes: 196 additions & 0 deletions ptlflow/utils/lightning/ptlflow_checkpoint_connector.py
@@ -0,0 +1,196 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from pathlib import Path
import re
from typing import Optional

import torch
from torch import hub
from fsspec.core import url_to_fs

import lightning.pytorch as pl
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.trainer.connectors.checkpoint_connector import (
_CheckpointConnector,
)
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.migration import pl_legacy_patch
from lightning.pytorch.utilities.migration.utils import _pl_migrate_checkpoint
from lightning.pytorch.utilities.rank_zero import rank_zero_info

log = logging.getLogger(__name__)


class _PTLFlowCheckpointConnector(_CheckpointConnector):
def __init__(self, trainer: pl.Trainer) -> None:
super().__init__(trainer)

def resume_start(
self,
checkpoint_path: Optional[_PATH] = None,
model: Optional[pl.LightningModule] = None,
) -> None:
"""Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority:
1. from HPC weights if `checkpoint_path` is ``None`` and on SLURM or passed keyword `"hpc"`.
2. from fault-tolerant auto-saved checkpoint if found
3. from `checkpoint_path` file if provided
4. don't restore
"""
self._ckpt_path = checkpoint_path
if not checkpoint_path:
log.debug("`checkpoint_path` not specified. Skipping checkpoint loading.")
return

if not Path(checkpoint_path).exists():
if model is not None:
model_ref = model.__class__
                if hasattr(model_ref, "pretrained_checkpoints"):
                    checkpoint_name = checkpoint_path
                    checkpoint_path = model_ref.pretrained_checkpoints.get(
                        checkpoint_name
                    )
                    if checkpoint_path is None:
                        raise ValueError(
                            f"Invalid checkpoint name {checkpoint_name}. "
                            f'Choose one from {{{",".join(model_ref.pretrained_checkpoints.keys())}}}'
                        )

cache_path = (
Path(hub.get_dir())
/ "checkpoints"
/ checkpoint_path.split("/")[-1]
)
if cache_path.exists():
checkpoint_path = cache_path
else:
raise ValueError(
f"Cannot find checkpoint {checkpoint_path} for model {model.__class__.__name__}"
)

rank_zero_info(
f"Restoring states from the checkpoint path at {checkpoint_path}"
)
with pl_legacy_patch():
loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path)
if not "pytorch-lightning_version" in loaded_checkpoint:
loaded_checkpoint["pytorch-lightning_version"] = "1.9.5"
self._loaded_checkpoint = _pl_migrate_checkpoint(
loaded_checkpoint, checkpoint_path
)

def resume_end(self) -> None:
"""Signal the connector that all states have resumed and memory for the checkpoint object can be released."""
assert self.trainer.state.fn is not None
if self._ckpt_path:
message = (
"Restored all states"
if self.trainer.state.fn == TrainerFn.FITTING
else "Loaded model weights"
)
rank_zero_info(f"{message} from the checkpoint at {self._ckpt_path}")

# free memory
self._loaded_checkpoint = {}
torch.cuda.empty_cache()

# wait for all to catch up
self.trainer.strategy.barrier("_PTLFlowCheckpointConnector.resume_end")

def restore_training_state(self) -> None:
"""Restore the trainer state from the pre-loaded checkpoint.
This includes the precision settings, loop progress, optimizer states and learning rate scheduler states.
Modifications by Henrique Morimitsu:
- restore_optimizers_and_schedulers before other states to raise an error earlier in case the ckpt does not have training states
"""
if not self._loaded_checkpoint:
return

assert self.trainer.state.fn is not None
if self.trainer.state.fn == TrainerFn.FITTING:
# restore optimizers and schedulers state
self.restore_optimizers_and_schedulers()

# restore precision plugin (scaler etc.)
self.restore_precision_plugin_state()

# restore loops and their progress
self.restore_loops()

def _restore_modules_and_callbacks(
self,
checkpoint_path: Optional[_PATH] = None,
model: Optional[pl.LightningModule] = None,
) -> None:
# restore modules after setup
self.resume_start(checkpoint_path, model)
self.restore_model()
self.restore_datamodule()
if self.trainer.state.fn == TrainerFn.FITTING:
# restore callback states
self.restore_callbacks()

@staticmethod
def __max_ckpt_version_in_folder(
dir_path: _PATH, name_key: str = "ckpt_"
) -> Optional[int]:
"""List up files in `dir_path` with `name_key`, then yield maximum suffix number.
Args:
dir_path: path of directory which may contain files whose name include `name_key`
name_key: file name prefix
Returns:
None if no-corresponding-file else maximum suffix number
"""
# check directory existence
fs, uri = url_to_fs(str(dir_path))
if not fs.exists(dir_path):
return None

# check corresponding file existence
files = [os.path.basename(f["name"]) for f in fs.listdir(uri)]
files = [x for x in files if name_key in x]
if len(files) == 0:
return None

# extract suffix number
ckpt_vs = []
for name in files:
name = name.split(name_key)[-1]
name = re.sub("[^0-9]", "", name)
ckpt_vs.append(int(name))

return max(ckpt_vs)

@staticmethod
def __get_max_ckpt_path_from_folder(folder_path: _PATH) -> str:
"""Get path of maximum-epoch checkpoint in the folder."""
max_suffix = _PTLFlowCheckpointConnector.__max_ckpt_version_in_folder(
folder_path
)
ckpt_number = max_suffix if max_suffix is not None else 0
return f"{folder_path}/hpc_ckpt_{ckpt_number}.ckpt"

@staticmethod
def hpc_save_path(folderpath: _PATH) -> str:
max_suffix = _PTLFlowCheckpointConnector.__max_ckpt_version_in_folder(
folderpath
)
ckpt_number = (max_suffix if max_suffix is not None else 0) + 1
return os.path.join(folderpath, f"hpc_ckpt_{ckpt_number}.ckpt")
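
For reference, the connector above resolves a checkpoint argument in this order: an existing local path is used directly; otherwise the name is looked up in the model class's `pretrained_checkpoints` registry and the corresponding file is searched for in the torch.hub cache. The standalone sketch below mirrors that lookup for illustration only; the registry contents and the URL are assumptions, not part of this commit.

from pathlib import Path

from torch import hub

# Hypothetical registry in the style of a model class's
# `pretrained_checkpoints` attribute (the URL below is illustrative).
PRETRAINED = {"things": "https://example.com/raft-things.ckpt"}


def resolve_checkpoint(name_or_path: str) -> Path:
    """Mirror the lookup order used in resume_start above: an existing local
    path wins, then a registered name whose file is already in the torch.hub
    cache; anything else is rejected."""
    if Path(name_or_path).exists():
        return Path(name_or_path)
    url = PRETRAINED.get(name_or_path)
    if url is None:
        raise ValueError(
            f"Invalid checkpoint name {name_or_path}. "
            f"Choose one from {set(PRETRAINED)}"
        )
    cache_path = Path(hub.get_dir()) / "checkpoints" / url.split("/")[-1]
    if not cache_path.exists():
        raise ValueError(f"Cannot find checkpoint {name_or_path} in the torch.hub cache")
    return cache_path
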
172 changes: 172 additions & 0 deletions ptlflow/utils/lightning/ptlflow_cli.py
@@ -0,0 +1,172 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from typing import Any, Callable, Dict, Optional, Type, Union

from jsonargparse import Namespace
from lightning import LightningDataModule, LightningModule, Trainer
from lightning.pytorch.cli import (
ArgsType,
LightningArgumentParser,
LightningCLI,
SaveConfigCallback,
)
from lightning.pytorch.utilities.rank_zero import rank_zero_warn


class PTLFlowCLI(LightningCLI):
def __init__(
self,
model_class: Optional[
Union[Type[LightningModule], Callable[..., LightningModule]]
] = None,
datamodule_class: Optional[
Union[Type[LightningDataModule], Callable[..., LightningDataModule]]
] = None,
save_config_callback: Optional[Type[SaveConfigCallback]] = SaveConfigCallback,
save_config_kwargs: Optional[Dict[str, Any]] = None,
trainer_class: Optional[Union[Type[Trainer], Callable[..., Trainer]]] = None,
trainer_defaults: Optional[Dict[str, Any]] = None,
seed_everything_default: Union[bool, int] = True,
parser_kwargs: Optional[
Union[Dict[str, Any], Dict[str, Dict[str, Any]]]
] = None,
subclass_mode_model: bool = False,
subclass_mode_data: bool = False,
args: ArgsType = None,
run: bool = True,
auto_configure_optimizers: bool = True,
parse_only: bool = False,
ignore_sys_argv: bool = False,
) -> None:
"""Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which are
called / instantiated using a parsed configuration file and / or command line args.
Parsing of configuration from environment variables can be enabled by setting ``parser_kwargs={"default_env":
True}``. A full configuration yaml would be parsed from ``PL_CONFIG`` if set. Individual settings are so parsed
from variables named for example ``PL_TRAINER__MAX_EPOCHS``.
For more info, read :ref:`the CLI docs <lightning-cli>`.
Args:
model_class: An optional :class:`~lightning.pytorch.core.LightningModule` class to train on or a
callable which returns a :class:`~lightning.pytorch.core.LightningModule` instance when
called. If ``None``, you can pass a registered model with ``--model=MyModel``.
datamodule_class: An optional :class:`~lightning.pytorch.core.datamodule.LightningDataModule` class or a
callable which returns a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` instance when
called. If ``None``, you can pass a registered datamodule with ``--data=MyDataModule``.
save_config_callback: A callback class to save the config.
save_config_kwargs: Parameters that will be used to instantiate the save_config_callback.
trainer_class: An optional subclass of the :class:`~lightning.pytorch.trainer.trainer.Trainer` class or a
callable which returns a :class:`~lightning.pytorch.trainer.trainer.Trainer` instance when called.
trainer_defaults: Set to override Trainer defaults or add persistent callbacks. The callbacks added through
this argument will not be configurable from a configuration file and will always be present for
this particular CLI. Alternatively, configurable callbacks can be added as explained in
:ref:`the CLI docs <lightning-cli>`.
seed_everything_default: Number for the :func:`~lightning.fabric.utilities.seed.seed_everything`
seed value. Set to True to automatically choose a seed value.
Setting it to False will avoid calling ``seed_everything``.
parser_kwargs: Additional arguments to instantiate each ``LightningArgumentParser``.
subclass_mode_model: Whether model can be any `subclass
<https://jsonargparse.readthedocs.io/en/stable/#class-type-and-sub-classes>`_
of the given class.
subclass_mode_data: Whether datamodule can be any `subclass
<https://jsonargparse.readthedocs.io/en/stable/#class-type-and-sub-classes>`_
of the given class.
args: Arguments to parse. If ``None`` the arguments are taken from ``sys.argv``. Command line style
arguments can be given in a ``list``. Alternatively, structured config options can be given in a
``dict`` or ``jsonargparse.Namespace``.
run: Whether subcommands should be added to run a :class:`~lightning.pytorch.trainer.trainer.Trainer`
method. If set to ``False``, the trainer and model classes will be instantiated only.
parse_only: If set to ``True``, the CLI acts as a parser only. The classes are not instantiated.
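            ignore_sys_argv: If set to ``True``, suppress the warning that is normally raised when ``args`` is
                provided and command line arguments are also present in ``sys.argv``.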
"""
self.save_config_callback = save_config_callback
self.save_config_kwargs = save_config_kwargs or {}
self.trainer_class = trainer_class
self.trainer_defaults = trainer_defaults or {}
self.seed_everything_default = seed_everything_default
self.parser_kwargs = parser_kwargs or {}
self.auto_configure_optimizers = auto_configure_optimizers

self.model_class = model_class
# used to differentiate between the original value and the processed value
self._model_class = model_class
self.subclass_mode_model = subclass_mode_model

self.datamodule_class = datamodule_class
# used to differentiate between the original value and the processed value
self._datamodule_class = datamodule_class
self.subclass_mode_data = (datamodule_class is None) or subclass_mode_data

main_kwargs, subparser_kwargs = self._setup_parser_kwargs(self.parser_kwargs)
self.setup_parser(run, main_kwargs, subparser_kwargs)
self.parse_arguments(self.parser, args, ignore_sys_argv)

if not parse_only:
self.subcommand = self.config["subcommand"] if run else None

self._set_seed()

self._add_instantiators()
self.before_instantiate_classes()
self.instantiate_classes()

if self.subcommand is not None:
self._run_subcommand(self.subcommand)

def add_core_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
"""Adds arguments from the core classes to the parser."""
if self.trainer_class is not None:
parser.add_lightning_class_args(self.trainer_class, "trainer")
trainer_defaults = {
"trainer." + k: v
for k, v in self.trainer_defaults.items()
if k != "callbacks"
}
parser.set_defaults(trainer_defaults)

if self._model_class is not None:
parser.add_lightning_class_args(
self._model_class, "model", subclass_mode=self.subclass_mode_model
)

if self.datamodule_class is not None:
parser.add_lightning_class_args(
self._datamodule_class, "data", subclass_mode=self.subclass_mode_data
)

def parse_arguments(
self, parser: LightningArgumentParser, args: ArgsType, ignore_sys_argv: bool
) -> None:
"""Parses command line arguments and stores it in ``self.config``."""
if args is not None and len(sys.argv) > 1 and not ignore_sys_argv:
rank_zero_warn(
"LightningCLI's args parameter is intended to run from within Python like if it were from the command "
"line. To prevent mistakes it is not recommended to provide both args and command line arguments, got: "
f"sys.argv[1:]={sys.argv[1:]}, args={args}."
)
if isinstance(args, (dict, Namespace)):
self.config = parser.parse_object(args)
else:
self.config = parser.parse_args(args)

def instantiate_classes(self) -> None:
"""Instantiates the classes and sets their attributes."""
self.config_init = self.parser.instantiate_classes(self.config)
self.datamodule = self._get(self.config_init, "data")
self.model = self._get(self.config_init, "model")
self._add_configure_optimizers_method_to_model(self.subcommand)

if self.trainer_class is not None:
self.trainer = self.instantiate_trainer()
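
The `parse_only` and `ignore_sys_argv` parameters handled above appear to be the PTLFlow-specific additions to the constructor. A minimal, hypothetical driver is sketched below; `MyModel` and the argument values are assumptions for illustration and are not part of this commit.

import torch
from lightning import LightningModule

from ptlflow.utils.lightning.ptlflow_cli import PTLFlowCLI


# Hypothetical module used only to exercise the CLI.
class MyModel(LightningModule):
    def __init__(self, lr: float = 1e-4):
        super().__init__()
        self.lr = lr
        self.layer = torch.nn.Linear(2, 2)


cli = PTLFlowCLI(
    model_class=MyModel,
    args=["--model.lr", "0.001"],
    run=False,             # do not add trainer subcommands
    parse_only=True,       # stop after parsing; nothing is instantiated or run
    ignore_sys_argv=True,  # don't warn if sys.argv also contains arguments
)
print(cli.config.model.lr)
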
