Skip to content

Commit

Permalink
Apply yanase's comment
Browse files Browse the repository at this point in the history
  • Loading branch information
nabenabe0928 committed Aug 7, 2024
1 parent ff1685d commit 36c3209
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions basic_and_faq_usages/simple_artifact_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@

dataset_path = "demo-dataset.csv"
fig_name = "result-trial.png"
# NOTE: The uploaded artifacts can be viewed in Optuna Dashboard with the following command:
# $ optuna-dashboard sqlite:///artifact-demo.db --artifact-dir ./save-artifact-here
base_path = os.path.join("./save-artifact-here")
# Make the directory used for artifact store.
os.makedirs(base_path, exist_ok=True)
# Instantiate an artifact store.
artifact_store = FileSystemArtifactStore(base_path=base_path)


def create_dataset(dataset_path):
Expand All @@ -33,13 +40,14 @@ def plot_predictions(a, b, trial):
x = np.linspace(0, 1, 100)
ax.scatter(dataset["X"], dataset["Y"], label="Dataset", color="blue")
ax.plot(x, a * x + b, label="Prediction", color="darkred")
ax.set_title(f"a={a:.2f}, b={b:.2f}")
ax.grid()
ax.legend()
plt.savefig(fig_name)
plt.close()


def objective(trial, artifact_store):
def objective(trial):
a = trial.suggest_float("a", -5, 5)
b = trial.suggest_float("b", -5, 5)
plot_predictions(a, b, trial)
Expand Down Expand Up @@ -70,18 +78,11 @@ def show_best_result(study, artifact_store):


if __name__ == "__main__":
# NOTE: The uploaded artifacts can be viewed in Optuna Dashboard with the following command:
# $ optuna-dashboard sqlite:///artifact-demo.db --artifact-dir ./save-artifact-here
# Create a study with a SQLite storage.
study = optuna.create_study(
study_name="demo", storage="sqlite:///simple-artifact-store-demo.db", load_if_exists=True
)
base_path = os.path.join("./save-artifact-here")
# Make the directory used for artifact store.
os.makedirs(base_path, exist_ok=True)
# Instantiate an artifact store.
artifact_store = FileSystemArtifactStore(base_path=base_path)
# Upload the dataset to use by artifact store API.
upload_artifact(artifact_store=artifact_store, file_path=dataset_path, study_or_trial=study)
study.optimize(lambda trial: objective(trial, artifact_store), n_trials=30)
study.optimize(objective, n_trials=30)
show_best_result(study, artifact_store)

0 comments on commit 36c3209

Please sign in to comment.