diff --git a/pytorch/pytorch_checkpoint.py b/pytorch/pytorch_checkpoint.py index e7f2b5d6..70b458ce 100644 --- a/pytorch/pytorch_checkpoint.py +++ b/pytorch/pytorch_checkpoint.py @@ -104,6 +104,7 @@ def objective(trial): artifact_id=artifact_id, ) checkpoint = torch.load(f"./tmp_model_{trial.number}.pt") + os.remove(f"./tmp_model_{trial.number}.pt") epoch = checkpoint["epoch"] epoch_begin = epoch + 1