Commit

Adding checkpoint_path for resume training (#182)

* adding ckpt_path to fit to resume training

* option to resume or fine-tune

* small changes

* add checkpoint option in the guide

* cleaned up guide

* cleaned up

* trying to rename the ckpt

* cleaned up after rebase

* some changes in the guide

* run pre-commit

* fixed test

* passing the config to the model instance during fine-tuning

* small changes on guide

* changes based on the review

* small changes

* Update crabs/detection_tracking/train_model.py

Co-authored-by: sfmig <[email protected]>
Signed-off-by: nikk-nikaznan <[email protected]>

* Update crabs/detection_tracking/train_model.py

Co-authored-by: sfmig <[email protected]>
Signed-off-by: nikk-nikaznan <[email protected]>

* Update crabs/detection_tracking/train_model.py

Co-authored-by: sfmig <[email protected]>
Signed-off-by: nikk-nikaznan <[email protected]>

* cleaned up pre-commit

---------

Signed-off-by: nikk-nikaznan <[email protected]>
Co-authored-by: sfmig <[email protected]>
nikk-nikaznan and sfmig authored Jun 26, 2024

1 parent d7a5cef commit 4e70c2c
Showing 3 changed files with 87 additions and 12 deletions.
73 changes: 64 additions & 9 deletions crabs/detection_tracking/train_model.py
@@ -1,4 +1,5 @@
 import argparse
+import logging
 import os
 import sys
 from pathlib import Path
@@ -56,6 +57,8 @@ def __init__(self, args):
         self.fast_dev_run = args.fast_dev_run
         self.limit_train_batches = args.limit_train_batches

+        self.checkpoint_path = args.checkpoint_path
+
     def load_config_yaml(self):
         with open(self.config_file, "r") as f:
             self.config = yaml.safe_load(f)
@@ -77,16 +80,16 @@ def setup_trainer(self):
         )

         # Define checkpointing callback for trainer
-        config = self.config.get("checkpoint_saving")
-        if config:
+        config_ckpt = self.config.get("checkpoint_saving")
+        if config_ckpt:
             checkpoint_callback = ModelCheckpoint(
                 filename="checkpoint-{epoch}",
-                every_n_epochs=config["every_n_epochs"],
-                save_top_k=config["keep_last_n_ckpts"],
+                every_n_epochs=config_ckpt["every_n_epochs"],
+                save_top_k=config_ckpt["keep_last_n_ckpts"],
                 monitor="epoch",  # monitor the metric "epoch" for selecting which checkpoints to save
                 mode="max",  # get the max of the monitored metric
-                save_last=config["save_last"],
-                save_weights_only=config["save_weights_only"],
+                save_last=config_ckpt["save_last"],
+                save_weights_only=config_ckpt["save_weights_only"],
             )
             enable_checkpointing = True
         else:
@@ -161,12 +164,58 @@ def core_training(self) -> lightning.Trainer:
             self.seed_n,
         )

+        # Get checkpoint type
+        if self.checkpoint_path and os.path.exists(self.checkpoint_path):
+            checkpoint = torch.load(self.checkpoint_path)
+            if all(
+                [
+                    param in checkpoint
+                    for param in ["optimizer_states", "lr_schedulers"]
+                ]
+            ):
+                checkpoint_type = "full"  # for resuming training
+                logging.info(
+                    f"Resuming training from checkpoint at: {self.checkpoint_path}"
+                )
+            else:
+                checkpoint_type = "weights"  # for fine-tuning
+                logging.info(
+                    f"Fine-tuning training from checkpoint at: {self.checkpoint_path}"
+                )
+        else:
+            checkpoint_type = None
+
         # Get model
-        lightning_model = FasterRCNN(self.config)
+        if checkpoint_type == "weights":
+            # Note: the weights-only checkpoint contains hyperparameters, see
+            # https://lightning.ai/docs/pytorch/stable/common/checkpointing_basic.html#save-hyperparameters
+            lightning_model = FasterRCNN.load_from_checkpoint(
+                self.checkpoint_path,
+                config=self.config,
+                # overwrite checkpoint hyperparameters with config ones,
+                # otherwise ckpt hyperparameters are logged to MLflow but yaml hyperparameters are used
+            )
+        else:
+            lightning_model = FasterRCNN(self.config)

-        # Run training
+        # Get trainer
         trainer = self.setup_trainer()
-        trainer.fit(lightning_model, data_module)

+        # Run training
+        # Resume from a full checkpoint if available
+        # (automatically restores model, epoch, step, LR schedulers, etc.)
+        # https://lightning.ai/docs/pytorch/stable/common/checkpointing_basic.html#save-hyperparameters
+        if checkpoint_type == "full":
+            trainer.fit(
+                lightning_model,
+                data_module,
+                ckpt_path=self.checkpoint_path,  # needs to have been saved with `save_weights_only=False`
+            )
+        else:  # for "weights" or no checkpoint
+            trainer.fit(
+                lightning_model,
+                data_module,
+            )

         return trainer

@@ -281,6 +330,12 @@ def train_parse_args(args):
         default="./ml-runs",
         help=("Path to MLflow directory. Default: ./ml-runs"),
     )
+    parser.add_argument(
+        "--checkpoint_path",
+        type=str,
+        default=None,
+        help=("Path to checkpoint for resuming training"),
+    )
    parser.add_argument(
        "--optuna",
        action="store_true",
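
The new `core_training` logic decides between resuming and fine-tuning by checking whether the saved checkpoint dictionary contains optimizer and LR-scheduler state. Below is a minimal sketch (not part of this commit) of the same key check applied to a standalone checkpoint file, which can help confirm ahead of time which branch a given checkpoint will trigger; the path is a hypothetical placeholder:

```python
import torch

# Hypothetical path to a checkpoint written by Lightning's ModelCheckpoint callback.
ckpt_path = "checkpoints/last.ckpt"

# Load the checkpoint dictionary onto the CPU just to inspect its keys.
checkpoint = torch.load(ckpt_path, map_location="cpu")

# A "full" checkpoint (saved with save_weights_only=False) also stores optimizer
# and LR-scheduler states, which Lightning needs in order to resume training.
is_full = all(
    key in checkpoint for key in ("optimizer_states", "lr_schedulers")
)

print(
    "full checkpoint: training will resume"
    if is_full
    else "weights-only checkpoint: fine-tuning"
)
```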
25 changes: 22 additions & 3 deletions guides/TrainingModelsHPC.md
@@ -92,7 +92,26 @@
 >
 > If we launch a job and then modify the config file _before_ the job has been able to read it, we may be using an undesired version of the config in our job! To avoid this, it is best to wait until you can verify in MLflow that the job has the expected config parameters (and then edit the file to launch a new job if needed).

-6. **Optional argument - Optuna**
+6. **Restarting training from a checkpoint**
+
+   The `--checkpoint_path` argument lets you restart training from a previously saved checkpoint. There are two ways to use it:
+
+   - Resume training
+     - This option is useful for interrupted training sessions or for extending the training duration.
+     - If training is disrupted and stops mid-way, you can resume it by adding `--checkpoint_path $CKPT_PATH \` to your bash script.
+     - Training picks up from the last saved epoch and continues until the specified `n_epoch`.
+     - Similarly, if training completes but you want to extend it based on metric evaluations, you can increase the `n_epoch` value (e.g. from `n` to `n + y`). If `n_epoch` is unchanged, no further training runs, as the maximum number of epochs has already been reached. Again, add `--checkpoint_path $CKPT_PATH \` to your bash script, and training will resume from epoch `n` and continue to `n + y`.
+     - To resume training, ensure the `save_weights_only` parameter under `checkpoint_saving` in the config file is set to `False`, as this option requires loading both the weights and the training state (see the sketch below).
+   - Fine-tuning
+     - This option is useful for fine-tuning a pre-trained model on a different dataset.
+     - It loads only the weights from a checkpoint, allowing you to leverage pre-trained weights from another dataset.
+     - Add `--checkpoint_path $CKPT_PATH \` to your bash script to use this option.
+     - Set the `save_weights_only` parameter under `checkpoint_saving` in the config file to `True`, as only the weights are needed for fine-tuning.
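
To make the `save_weights_only` setting concrete, here is a minimal sketch (not part of this commit) of how the `checkpoint_saving` config section maps onto Lightning's `ModelCheckpoint` callback; the import path assumes a Lightning 2.x install, and the values shown are illustrative placeholders rather than the project defaults:

```python
from lightning.pytorch.callbacks import ModelCheckpoint

# Sketch of the callback built from the `checkpoint_saving` config section.
checkpoint_callback = ModelCheckpoint(
    filename="checkpoint-{epoch}",
    every_n_epochs=1,   # save every epoch (placeholder)
    save_top_k=-1,      # keep all saved checkpoints (placeholder)
    monitor="epoch",    # rank checkpoints by epoch number
    mode="max",         # keep the highest (most recent) epochs
    save_last=True,     # always keep `last.ckpt`
    # False -> "full" checkpoint (weights + optimizer + LR schedulers): required to resume training.
    # True  -> weights-only checkpoint: sufficient for fine-tuning on another dataset.
    save_weights_only=False,
)
```

Passing a checkpoint produced this way to `--checkpoint_path` then selects either the resume or the fine-tuning branch shown in the `train_model.py` diff above.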
+7. **Optional argument - Optuna**

    We have the option to run [Optuna](https://optuna.org), which is a hyperparameter optimization framework that allows us to find the best hyperparameters for our model.
@@ -117,15 +136,15 @@
    --optuna
    ```

-7. **Run the training job using the SLURM scheduler**
+8. **Run the training job using the SLURM scheduler**

    To launch a job, use the `sbatch` command with the relevant training script:

    ```
    sbatch <path-to-training-bash-script>
    ```

-8. **Check the status of the training job**
+9. **Check the status of the training job**

    To do this, we can:
1 change: 1 addition & 0 deletions tests/test_unit/test_optuna.py
@@ -34,6 +34,7 @@ def args():
         mlflow_folder="/tmp/mlflow",
         fast_dev_run=True,
         limit_train_batches=False,
+        checkpoint_path=None,
     )

