Skip to content

Commit

Permalink
Update to check the entire retry history
Browse files Browse the repository at this point in the history
  • Loading branch information
kAIto47802 committed Oct 22, 2024
1 parent 262c48d commit 4852e9d
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions pytorch/pytorch_checkpoint.py
Original file line number Diff line number Diff line change
@@ -89,20 +89,25 @@ def objective(trial):
lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)

- trial_number = RetryFailedTrialCallback.retried_trial_number(trial)
-
- artifact_id = trial_number and trial.study.trials[trial_number].user_attrs.get("artifact_id")
- if trial_number is not None and artifact_id is not None:
+ artifact_id = None
+ retry_history = RetryFailedTrialCallback.retry_history(trial)
+ for trial_number in reversed(retry_history):
+ artifact_id = trial.study.trials[trial_number].user_attrs.get("artifact_id")
+ if artifact_id is not None:
+ retry_trial_number = trial_number
+ break
+
+ if artifact_id is not None:
download_artifact(
artifact_store=artifact_store,
- file_path=f"./tmp_model_{trial_number}.pt",
+ file_path=f"./tmp_model_{trial.number}.pt",
artifact_id=artifact_id,
)
- checkpoint = torch.load(f"./tmp_model_{trial_number}.pt")
+ checkpoint = torch.load(f"./tmp_model_{trial.number}.pt")
epoch = checkpoint["epoch"]
epoch_begin = epoch + 1

- print(f"Loading a checkpoint from trial {trial_number} in epoch {epoch}.")
+ print(f"Loading a checkpoint from trial {retry_trial_number} in epoch {epoch}.")

model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
@@ -159,15 +164,15 @@ def objective(trial):
"optimizer_state_dict": optimizer.state_dict(),
"accuracy": accuracy,
},
- f"./tmp_model_{trial_number}.pt",
+ f"./tmp_model_{trial.number}.pt",
)
artifact_id = upload_artifact(
artifact_store=artifact_store,
- file_path=f"./tmp_model_{trial_number}.pt",
+ file_path=f"./tmp_model_{trial.number}.pt",
study_or_trial=trial,
)
trial.set_user_attr("artifact_id", artifact_id)
- os.remove(f"./tmp_model_{trial_number}.pt")
+ os.remove(f"./tmp_model_{trial.number}.pt")

# Handle pruning based on the intermediate value.
if trial.should_prune():
Expand Down

0 comments on commit 4852e9d

Please sign in to comment.