From 9d63c5c4d02ffa5d512dc4f3aff1bac75fd68131 Mon Sep 17 00:00:00 2001
From: kAIto47802 <115693559+kAIto47802@users.noreply.github.com>
Date: Tue, 24 Sep 2024 17:32:21 +0900
Subject: [PATCH 1/9] Introduce artifacts

---
 pytorch/pytorch_checkpoint.py | 42 ++++++++++++++++++++++-------------
 1 file changed, 26 insertions(+), 16 deletions(-)

diff --git a/pytorch/pytorch_checkpoint.py b/pytorch/pytorch_checkpoint.py
index bc2e9808..eebb1db5 100644
--- a/pytorch/pytorch_checkpoint.py
+++ b/pytorch/pytorch_checkpoint.py
@@ -25,6 +25,11 @@
 import torch.utils.data
 from torchvision import datasets
 from torchvision import transforms
+from optuna.artifacts import FileSystemArtifactStore
+from optuna.artifacts import upload_artifact
+from optuna.artifacts import download_artifact
+
+import pandas as pd
 
 
 DEVICE = torch.device("cpu")
@@ -37,6 +42,10 @@
 N_VALID_EXAMPLES = BATCHSIZE * 10
 CHECKPOINT_DIR = "pytorch_checkpoint"
 
+base_path = "./artifacts"
+os.makedirs(base_path, exist_ok=True)
+artifact_store = FileSystemArtifactStore(base_path=base_path)
+
 
 def define_model(trial):
     # We optimize the number of layers, hidden units and dropout ratio in each layer.
@@ -84,12 +93,14 @@ def objective(trial):
     optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
 
     trial_number = RetryFailedTrialCallback.retried_trial_number(trial)
-    trial_checkpoint_dir = os.path.join(CHECKPOINT_DIR, str(trial_number))
-    checkpoint_path = os.path.join(trial_checkpoint_dir, "model.pt")
-    checkpoint_exists = os.path.isfile(checkpoint_path)
 
-    if trial_number is not None and checkpoint_exists:
-        checkpoint = torch.load(checkpoint_path)
+    print(f"Retrieved trial number: {trial_number}")
+
+    if trial_number is not None:
+        study = optuna.load_study(study_name="pytorch_checkpoint", storage="sqlite:///example.db")
+        artifact_id = study.trials[trial_number].user_attrs["artifact_id"]
+        download_artifact(artifact_store=artifact_store, file_path="./tmp_model.pt", artifact_id=artifact_id)
+        checkpoint = torch.load("./tmp_model.pt")
         epoch = checkpoint["epoch"]
         epoch_begin = epoch + 1
 
@@ -99,20 +110,11 @@
         optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
         accuracy = checkpoint["accuracy"]
     else:
-        trial_checkpoint_dir = os.path.join(CHECKPOINT_DIR, str(trial.number))
-        checkpoint_path = os.path.join(trial_checkpoint_dir, "model.pt")
         epoch_begin = 0
 
     # Get the FashionMNIST dataset.
     train_loader, valid_loader = get_mnist()
 
-    os.makedirs(trial_checkpoint_dir, exist_ok=True)
-    # A checkpoint may be corrupted when the process is killed during `torch.save`.
-    # Reduce the risk by first calling `torch.save` to a temporary file, then copy.
-    tmp_checkpoint_path = os.path.join(trial_checkpoint_dir, "tmp_model.pt")
-
-    print(f"Checkpoint path for trial is '{checkpoint_path}'.")
-
     # Training of the model.
     for epoch in range(epoch_begin, EPOCHS):
         model.train()
@@ -159,9 +161,17 @@
                 "optimizer_state_dict": optimizer.state_dict(),
                 "accuracy": accuracy,
             },
-            tmp_checkpoint_path,
+            "./tmp_model.pt"
+        )
+        artifact_id = upload_artifact(
+            artifact_store=artifact_store,
+            file_path="./tmp_model.pt",
+            study_or_trial=trial,
+        )
+        trial.set_user_attr(
+            "artifact_id", artifact_id
         )
-        shutil.move(tmp_checkpoint_path, checkpoint_path)
+        os.remove("./tmp_model.pt")
 
         # Handle pruning based on the intermediate value.
         if trial.should_prune():

From 79130db8e797bd15ed03117123c9bfebd3eeab53 Mon Sep 17 00:00:00 2001
From: kAIto47802 <115693559+kAIto47802@users.noreply.github.com>
Date: Tue, 24 Sep 2024 17:46:48 +0900
Subject: [PATCH 2/9] Apply formatter

---
 pytorch/pytorch_checkpoint.py | 21 +++++++++------------
 1 file changed, 9 insertions(+), 12 deletions(-)

diff --git a/pytorch/pytorch_checkpoint.py b/pytorch/pytorch_checkpoint.py
index eebb1db5..8c92f4bb 100644
--- a/pytorch/pytorch_checkpoint.py
+++ b/pytorch/pytorch_checkpoint.py
@@ -14,9 +14,11 @@
 """
 
 import os
-import shutil
 
 import optuna
+from optuna.artifacts import download_artifact
+from optuna.artifacts import FileSystemArtifactStore
+from optuna.artifacts import upload_artifact
 from optuna.storages import RetryFailedTrialCallback
 import torch
 import torch.nn as nn
@@ -25,11 +27,6 @@
 import torch.utils.data
 from torchvision import datasets
 from torchvision import transforms
-from optuna.artifacts import FileSystemArtifactStore
-from optuna.artifacts import upload_artifact
-from optuna.artifacts import download_artifact
-
-import pandas as pd
 
 
 DEVICE = torch.device("cpu")
@@ -95,11 +92,13 @@ def objective(trial):
     trial_number = RetryFailedTrialCallback.retried_trial_number(trial)
 
     print(f"Retrieved trial number: {trial_number}")
-    
+
     if trial_number is not None:
         study = optuna.load_study(study_name="pytorch_checkpoint", storage="sqlite:///example.db")
         artifact_id = study.trials[trial_number].user_attrs["artifact_id"]
-        download_artifact(artifact_store=artifact_store, file_path="./tmp_model.pt", artifact_id=artifact_id)
+        download_artifact(
+            artifact_store=artifact_store, file_path="./tmp_model.pt", artifact_id=artifact_id
+        )
         checkpoint = torch.load("./tmp_model.pt")
         epoch = checkpoint["epoch"]
         epoch_begin = epoch + 1
@@ -161,16 +160,14 @@
                 "optimizer_state_dict": optimizer.state_dict(),
                 "accuracy": accuracy,
             },
-            "./tmp_model.pt"
+            "./tmp_model.pt",
         )
         artifact_id = upload_artifact(
             artifact_store=artifact_store,
             file_path="./tmp_model.pt",
             study_or_trial=trial,
         )
-        trial.set_user_attr(
-            "artifact_id", artifact_id
-        )
+        trial.set_user_attr("artifact_id", artifact_id)
         os.remove("./tmp_model.pt")
 
         # Handle pruning based on the intermediate value.
From e7b117c3b9e36641bab9349b53daa6ed49ce7078 Mon Sep 17 00:00:00 2001
From: kAIto47802 <115693559+kAIto47802@users.noreply.github.com>
Date: Tue, 24 Sep 2024 17:49:00 +0900
Subject: [PATCH 3/9] Remove debug print

---
 pytorch/pytorch_checkpoint.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/pytorch/pytorch_checkpoint.py b/pytorch/pytorch_checkpoint.py
index 8c92f4bb..eb3c0045 100644
--- a/pytorch/pytorch_checkpoint.py
+++ b/pytorch/pytorch_checkpoint.py
@@ -91,8 +91,6 @@ def objective(trial):
 
     trial_number = RetryFailedTrialCallback.retried_trial_number(trial)
 
-    print(f"Retrieved trial number: {trial_number}")
-
     if trial_number is not None:
         study = optuna.load_study(study_name="pytorch_checkpoint", storage="sqlite:///example.db")
         artifact_id = study.trials[trial_number].user_attrs["artifact_id"]

From 5f329a60be168d22d8509ed9cd8effc4eff1c6d8 Mon Sep 17 00:00:00 2001
From: kAIto47802 <115693559+kAIto47802@users.noreply.github.com>
Date: Fri, 18 Oct 2024 18:00:10 +0900
Subject: [PATCH 4/9] Update pytorch/pytorch_checkpoint.py

Co-authored-by: Naoto Mizuno
---
 pytorch/pytorch_checkpoint.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch/pytorch_checkpoint.py b/pytorch/pytorch_checkpoint.py
index eb3c0045..f93c9eec 100644
--- a/pytorch/pytorch_checkpoint.py
+++ b/pytorch/pytorch_checkpoint.py
@@ -92,7 +92,7 @@ def objective(trial):
     trial_number = RetryFailedTrialCallback.retried_trial_number(trial)
 
     if trial_number is not None:
-        study = optuna.load_study(study_name="pytorch_checkpoint", storage="sqlite:///example.db")
+        study = trial.study
         artifact_id = study.trials[trial_number].user_attrs["artifact_id"]
         download_artifact(
             artifact_store=artifact_store, file_path="./tmp_model.pt", artifact_id=artifact_id

From 0443f6b0a06c44fe133a7751ed52c5a1c993091b Mon Sep 17 00:00:00 2001
From: kAIto47802 <115693559+kAIto47802@users.noreply.github.com>
Date: Fri, 18 Oct 2024 18:59:03 +0900
Subject: [PATCH 5/9] Update based on review

---
 pytorch/pytorch_checkpoint.py | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/pytorch/pytorch_checkpoint.py b/pytorch/pytorch_checkpoint.py
index f93c9eec..b35d8d34 100644
--- a/pytorch/pytorch_checkpoint.py
+++ b/pytorch/pytorch_checkpoint.py
@@ -91,13 +91,12 @@ def objective(trial):
 
     trial_number = RetryFailedTrialCallback.retried_trial_number(trial)
 
-    if trial_number is not None:
-        study = trial.study
-        artifact_id = study.trials[trial_number].user_attrs["artifact_id"]
+    artifact_id = trial.study.trials[trial_number].user_attrs.get("artifact_id")
+    if trial_number is not None and artifact_id is not None:
         download_artifact(
-            artifact_store=artifact_store, file_path="./tmp_model.pt", artifact_id=artifact_id
+            artifact_store=artifact_store, file_path=f"./tmp_model_{trial_number}.pt", artifact_id=artifact_id
         )
-        checkpoint = torch.load("./tmp_model.pt")
+        checkpoint = torch.load(f"./tmp_model_{trial_number}.pt")
         epoch = checkpoint["epoch"]
         epoch_begin = epoch + 1
 
@@ -158,15 +157,15 @@
                 "optimizer_state_dict": optimizer.state_dict(),
                 "accuracy": accuracy,
             },
-            "./tmp_model.pt",
+            f"./tmp_model_{trial_number}.pt",
         )
         artifact_id = upload_artifact(
             artifact_store=artifact_store,
-            file_path="./tmp_model.pt",
+            file_path=f"./tmp_model_{trial_number}.pt",
             study_or_trial=trial,
         )
         trial.set_user_attr("artifact_id", artifact_id)
-        os.remove("./tmp_model.pt")
+        os.remove(f"./tmp_model_{trial_number}.pt")
 
         # Handle pruning based on the intermediate value.
         if trial.should_prune():

From a1a595315fa0145cc581f854a00266f644061e70 Mon Sep 17 00:00:00 2001
From: kAIto47802 <115693559+kAIto47802@users.noreply.github.com>
Date: Fri, 18 Oct 2024 19:01:17 +0900
Subject: [PATCH 6/9] Apply formatter

---
 pytorch/pytorch_checkpoint.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/pytorch/pytorch_checkpoint.py b/pytorch/pytorch_checkpoint.py
index b35d8d34..f500ff1c 100644
--- a/pytorch/pytorch_checkpoint.py
+++ b/pytorch/pytorch_checkpoint.py
@@ -94,7 +94,9 @@
     artifact_id = trial.study.trials[trial_number].user_attrs.get("artifact_id")
     if trial_number is not None and artifact_id is not None:
         download_artifact(
-            artifact_store=artifact_store, file_path=f"./tmp_model_{trial_number}.pt", artifact_id=artifact_id
+            artifact_store=artifact_store,
+            file_path=f"./tmp_model_{trial_number}.pt",
+            artifact_id=artifact_id,
         )
         checkpoint = torch.load(f"./tmp_model_{trial_number}.pt")
         epoch = checkpoint["epoch"]

From 262c48d2fcfe0326929f2a01ce40f1eca4f40082 Mon Sep 17 00:00:00 2001
From: kAIto47802 <115693559+kAIto47802@users.noreply.github.com>
Date: Fri, 18 Oct 2024 19:12:10 +0900
Subject: [PATCH 7/9] Fix

---
 pytorch/pytorch_checkpoint.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch/pytorch_checkpoint.py b/pytorch/pytorch_checkpoint.py
index f500ff1c..35b697d3 100644
--- a/pytorch/pytorch_checkpoint.py
+++ b/pytorch/pytorch_checkpoint.py
@@ -91,7 +91,7 @@ def objective(trial):
 
     trial_number = RetryFailedTrialCallback.retried_trial_number(trial)
 
-    artifact_id = trial.study.trials[trial_number].user_attrs.get("artifact_id")
+    artifact_id = trial_number and trial.study.trials[trial_number].user_attrs.get("artifact_id")
     if trial_number is not None and artifact_id is not None:
         download_artifact(
             artifact_store=artifact_store,

From 4852e9d0de2fdaf301da818f8fc2d9e71eea3ebf Mon Sep 17 00:00:00 2001
From: kAIto47802 <115693559+kAIto47802@users.noreply.github.com>
Date: Tue, 22 Oct 2024 16:09:21 +0900
Subject: [PATCH 8/9] Update to check the entire retry history

---
 pytorch/pytorch_checkpoint.py | 25 +++++++++++++++----------
 1 file changed, 15 insertions(+), 10 deletions(-)

diff --git a/pytorch/pytorch_checkpoint.py b/pytorch/pytorch_checkpoint.py
index 35b697d3..e7f2b5d6 100644
--- a/pytorch/pytorch_checkpoint.py
+++ b/pytorch/pytorch_checkpoint.py
@@ -89,20 +89,25 @@
     lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
     optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
 
-    trial_number = RetryFailedTrialCallback.retried_trial_number(trial)
-
-    artifact_id = trial_number and trial.study.trials[trial_number].user_attrs.get("artifact_id")
-    if trial_number is not None and artifact_id is not None:
+    artifact_id = None
+    retry_history = RetryFailedTrialCallback.retry_history(trial)
+    for trial_number in reversed(retry_history):
+        artifact_id = trial.study.trials[trial_number].user_attrs.get("artifact_id")
+        if artifact_id is not None:
+            retry_trial_number = trial_number
+            break
+
+    if artifact_id is not None:
         download_artifact(
             artifact_store=artifact_store,
-            file_path=f"./tmp_model_{trial_number}.pt",
+            file_path=f"./tmp_model_{trial.number}.pt",
            artifact_id=artifact_id,
         )
-        checkpoint = torch.load(f"./tmp_model_{trial_number}.pt")
+        checkpoint = torch.load(f"./tmp_model_{trial.number}.pt")
         epoch = checkpoint["epoch"]
         epoch_begin = epoch + 1
 
-        print(f"Loading a checkpoint from trial {trial_number} in epoch {epoch}.")
+        print(f"Loading a checkpoint from trial {retry_trial_number} in epoch {epoch}.")
 
         model.load_state_dict(checkpoint["model_state_dict"])
         optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
@@ -159,15 +164,15 @@
                 "optimizer_state_dict": optimizer.state_dict(),
                 "accuracy": accuracy,
             },
-            f"./tmp_model_{trial_number}.pt",
+            f"./tmp_model_{trial.number}.pt",
         )
         artifact_id = upload_artifact(
             artifact_store=artifact_store,
-            file_path=f"./tmp_model_{trial_number}.pt",
+            file_path=f"./tmp_model_{trial.number}.pt",
             study_or_trial=trial,
         )
         trial.set_user_attr("artifact_id", artifact_id)
-        os.remove(f"./tmp_model_{trial_number}.pt")
+        os.remove(f"./tmp_model_{trial.number}.pt")
 
         # Handle pruning based on the intermediate value.
         if trial.should_prune():

From 3fc41e827cd61813596d85b2985a9668a0e9998e Mon Sep 17 00:00:00 2001
From: kAIto47802 <115693559+kAIto47802@users.noreply.github.com>
Date: Fri, 25 Oct 2024 15:51:43 +0900
Subject: [PATCH 9/9] Add removal of temporal file

---
 pytorch/pytorch_checkpoint.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pytorch/pytorch_checkpoint.py b/pytorch/pytorch_checkpoint.py
index e7f2b5d6..70b458ce 100644
--- a/pytorch/pytorch_checkpoint.py
+++ b/pytorch/pytorch_checkpoint.py
@@ -104,6 +104,7 @@
             artifact_id=artifact_id,
         )
         checkpoint = torch.load(f"./tmp_model_{trial.number}.pt")
+        os.remove(f"./tmp_model_{trial.number}.pt")
         epoch = checkpoint["epoch"]
         epoch_begin = epoch + 1
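
Note for readers applying this series: the patched objective only restores a checkpoint when the study's storage records retried trials, that is, when the storage is configured with a RetryFailedTrialCallback so that RetryFailedTrialCallback.retry_history(trial) is non-empty. The sketch below shows one way such a study could be created. It is not part of the patches; the study name, SQLite URL, direction, and heartbeat/retry settings are illustrative assumptions.

    import optuna
    from optuna.storages import RDBStorage
    from optuna.storages import RetryFailedTrialCallback

    # The heartbeat lets Optuna mark trials whose process was killed as failed;
    # failed_trial_callback records them so they can be retried and later found
    # via RetryFailedTrialCallback.retry_history() inside the objective.
    storage = RDBStorage(
        "sqlite:///example.db",
        heartbeat_interval=60,
        grace_period=120,
        failed_trial_callback=RetryFailedTrialCallback(max_retry=3),
    )

    study = optuna.create_study(
        study_name="pytorch_checkpoint",  # assumed name, matching the URL used in patch 1
        storage=storage,
        direction="maximize",
        load_if_exists=True,
    )
    study.optimize(objective, n_trials=10, timeout=600)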