diff --git a/basic_and_faq_usages/simple_artifact_store.py b/basic_and_faq_usages/simple_artifact_store.py index 85af235d..5d9bc13d 100644 --- a/basic_and_faq_usages/simple_artifact_store.py +++ b/basic_and_faq_usages/simple_artifact_store.py @@ -1,4 +1,5 @@ import os +import tempfile import matplotlib.pyplot as plt import numpy as np @@ -36,7 +37,7 @@ def create_dataset(dataset_path): dataset = create_dataset(dataset_path) -def plot_predictions(a, b, trial): +def plot_predictions(a, b, trial, tmp_dir): # Create an artifact, which is figure in this example, to upload. _, ax = plt.subplots() x = np.linspace(-5, 5, 100) @@ -45,17 +46,20 @@ def plot_predictions(a, b, trial): ax.set_title(f"a={a:.2f}, b={b:.2f}") ax.grid() ax.legend() - plt.savefig(fig_name) + plt.savefig(os.path.join(tmp_dir, fig_name)) plt.close() def objective(trial): a = trial.suggest_float("a", -5, 5) b = trial.suggest_float("b", -5, 5) - plot_predictions(a, b, trial) - # Link the plotted figure with trial using artifact store API. - upload_artifact(artifact_store=artifact_store, file_path=fig_name, study_or_trial=trial) + with tempfile.TemporaryDirectory() as tmp_dir: + print(tmp_dir) + plot_predictions(a, b, trial, tmp_dir) + fig_path = os.path.join(tmp_dir, fig_name) + # Link the plotted figure with trial using artifact store API. + upload_artifact(artifact_store=artifact_store, file_path=fig_path, study_or_trial=trial) return np.mean((a * dataset["X"]**2 + b - dataset["Y"]) ** 2)