Merge pull request #280 from kAIto47802/pytorch-checkpoint-artifact
Introduce `optuna.artifacts` to the PyTorch checkpoint example
not522 authored Oct 25, 2024
2 parents c039f99 + 3fc41e8 commit 8158f46
Showing 1 changed file with 32 additions and 20 deletions.
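
In outline, the change replaces the local checkpoint directory with Optuna's artifact store: the objective saves its checkpoint to a temporary file, uploads it with `upload_artifact`, records the returned `artifact_id` as a user attribute on the trial, and a retried trial looks that id up via `RetryFailedTrialCallback.retry_history` and restores the file with `download_artifact`. A minimal sketch of that upload/download round trip, using only the API calls visible in the diff (the file names, toy payload, and in-memory study are illustrative, not part of the commit):

    import os

    import optuna
    from optuna.artifacts import download_artifact
    from optuna.artifacts import FileSystemArtifactStore
    from optuna.artifacts import upload_artifact

    # Artifacts live on the local file system under ./artifacts.
    base_path = "./artifacts"
    os.makedirs(base_path, exist_ok=True)
    artifact_store = FileSystemArtifactStore(base_path=base_path)

    study = optuna.create_study()
    trial = study.ask()

    # Upload: write a payload to a temporary file, push it to the store,
    # and remember the returned id on the trial so a retry can find it.
    with open("tmp_checkpoint.pt", "wb") as f:
        f.write(b"checkpoint payload")  # stands in for torch.save(...)
    artifact_id = upload_artifact(
        artifact_store=artifact_store,
        file_path="tmp_checkpoint.pt",
        study_or_trial=trial,
    )
    trial.set_user_attr("artifact_id", artifact_id)
    os.remove("tmp_checkpoint.pt")

    # Download: any later process can restore the file from the id alone.
    download_artifact(
        artifact_store=artifact_store,
        file_path="restored_checkpoint.pt",
        artifact_id=artifact_id,
    )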
pytorch/pytorch_checkpoint.py (52 changes: 32 additions & 20 deletions)
@@ -14,9 +14,11 @@
"""

import os
import shutil

import optuna
from optuna.artifacts import download_artifact
from optuna.artifacts import FileSystemArtifactStore
from optuna.artifacts import upload_artifact
from optuna.storages import RetryFailedTrialCallback
import torch
import torch.nn as nn
@@ -37,6 +39,10 @@
 N_VALID_EXAMPLES = BATCHSIZE * 10
 CHECKPOINT_DIR = "pytorch_checkpoint"
 
+base_path = "./artifacts"
+os.makedirs(base_path, exist_ok=True)
+artifact_store = FileSystemArtifactStore(base_path=base_path)
+
 
 def define_model(trial):
     # We optimize the number of layers, hidden units and dropout ratio in each layer.
@@ -83,36 +89,36 @@ def objective(trial):
     lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
     optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
 
-    trial_number = RetryFailedTrialCallback.retried_trial_number(trial)
-    trial_checkpoint_dir = os.path.join(CHECKPOINT_DIR, str(trial_number))
-    checkpoint_path = os.path.join(trial_checkpoint_dir, "model.pt")
-    checkpoint_exists = os.path.isfile(checkpoint_path)
-
-    if trial_number is not None and checkpoint_exists:
-        checkpoint = torch.load(checkpoint_path)
+    artifact_id = None
+    retry_history = RetryFailedTrialCallback.retry_history(trial)
+    for trial_number in reversed(retry_history):
+        artifact_id = trial.study.trials[trial_number].user_attrs.get("artifact_id")
+        if artifact_id is not None:
+            retry_trial_number = trial_number
+            break
+
+    if artifact_id is not None:
+        download_artifact(
+            artifact_store=artifact_store,
+            file_path=f"./tmp_model_{trial.number}.pt",
+            artifact_id=artifact_id,
+        )
+        checkpoint = torch.load(f"./tmp_model_{trial.number}.pt")
+        os.remove(f"./tmp_model_{trial.number}.pt")
         epoch = checkpoint["epoch"]
         epoch_begin = epoch + 1
 
-        print(f"Loading a checkpoint from trial {trial_number} in epoch {epoch}.")
+        print(f"Loading a checkpoint from trial {retry_trial_number} in epoch {epoch}.")
 
         model.load_state_dict(checkpoint["model_state_dict"])
         optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
         accuracy = checkpoint["accuracy"]
     else:
-        trial_checkpoint_dir = os.path.join(CHECKPOINT_DIR, str(trial.number))
-        checkpoint_path = os.path.join(trial_checkpoint_dir, "model.pt")
         epoch_begin = 0
 
     # Get the FashionMNIST dataset.
     train_loader, valid_loader = get_mnist()
 
-    os.makedirs(trial_checkpoint_dir, exist_ok=True)
-    # A checkpoint may be corrupted when the process is killed during `torch.save`.
-    # Reduce the risk by first calling `torch.save` to a temporary file, then copy.
-    tmp_checkpoint_path = os.path.join(trial_checkpoint_dir, "tmp_model.pt")
-
-    print(f"Checkpoint path for trial is '{checkpoint_path}'.")
-
     # Training of the model.
     for epoch in range(epoch_begin, EPOCHS):
         model.train()
@@ -159,9 +165,15 @@ def objective(trial):
"optimizer_state_dict": optimizer.state_dict(),
"accuracy": accuracy,
},
tmp_checkpoint_path,
f"./tmp_model_{trial.number}.pt",
)
artifact_id = upload_artifact(
artifact_store=artifact_store,
file_path=f"./tmp_model_{trial.number}.pt",
study_or_trial=trial,
)
shutil.move(tmp_checkpoint_path, checkpoint_path)
trial.set_user_attr("artifact_id", artifact_id)
os.remove(f"./tmp_model_{trial.number}.pt")

# Handle pruning based on the intermediate value.
if trial.should_prune():
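For the retry path above to trigger, the study must be created with a storage that detects failed trials and re-enqueues them; that wiring sits outside the hunks shown here. A minimal sketch of such a driver, assuming heartbeat-based failure detection over SQLite (the URL, intervals, and study name are illustrative, not taken from this diff):

    import optuna
    from optuna.storages import RetryFailedTrialCallback

    # A trial whose process dies stops sending heartbeats, gets marked FAILED,
    # and is re-enqueued by the callback, so `retry_history` is non-empty.
    storage = optuna.storages.RDBStorage(
        "sqlite:///example.db",
        heartbeat_interval=1,
        grace_period=3,
        failed_trial_callback=RetryFailedTrialCallback(),
    )

    study = optuna.create_study(
        storage=storage,
        study_name="pytorch_checkpoint",
        direction="maximize",
        load_if_exists=True,
    )
    study.optimize(objective, n_trials=10)  # `objective` as defined in the example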
