From ff1685dc12a59fbfb9f396a717274470717a3908 Mon Sep 17 00:00:00 2001 From: nabenabe0928 Date: Wed, 7 Aug 2024 06:04:38 +0200 Subject: [PATCH] Remove duplicated management --- basic_and_faq_usages/simple_artifact_store.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/basic_and_faq_usages/simple_artifact_store.py b/basic_and_faq_usages/simple_artifact_store.py index 54c39c7d..44168242 100644 --- a/basic_and_faq_usages/simple_artifact_store.py +++ b/basic_and_faq_usages/simple_artifact_store.py @@ -11,6 +11,7 @@ dataset_path = "demo-dataset.csv" +fig_name = "result-trial.png" def create_dataset(dataset_path): @@ -28,26 +29,23 @@ def create_dataset(dataset_path): def plot_predictions(a, b, trial): # Create an artifact, which is figure in this example, to upload. - os.makedirs("figs/", exist_ok=True) _, ax = plt.subplots() - fig_path = f"figs/result-trial{trial.number}.png" x = np.linspace(0, 1, 100) ax.scatter(dataset["X"], dataset["Y"], label="Dataset", color="blue") ax.plot(x, a * x + b, label="Prediction", color="darkred") ax.grid() ax.legend() - plt.savefig(fig_path) + plt.savefig(fig_name) plt.close() - return fig_path def objective(trial, artifact_store): a = trial.suggest_float("a", -5, 5) b = trial.suggest_float("b", -5, 5) - fig_path = plot_predictions(a, b, trial) + plot_predictions(a, b, trial) # Link the plotted figure with trial using artifact store API. - upload_artifact(artifact_store=artifact_store, file_path=fig_path, study_or_trial=trial) + upload_artifact(artifact_store=artifact_store, file_path=fig_name, study_or_trial=trial) return np.mean((a * dataset["X"] + b - dataset["Y"]) ** 2) @@ -56,7 +54,7 @@ def show_best_result(study, artifact_store): best_trial = study.best_trial # Get all the artifact information linked to best_trial. (Here we have only one.) artifact_meta = get_all_artifact_meta(best_trial, storage=study._storage) - fig_path = "figs/result-best-trial.png" + fig_path = "./result-best-trial.png" # Download the figure from the artifact store to fig_path. download_artifact( artifact_store=artifact_store,