Skip to content

Commit

Permalink
Instantiate RDB
Browse files Browse the repository at this point in the history
  • Loading branch information
nabenabe0928 committed Aug 8, 2024
1 parent 19af58f commit 7737c46
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions basic_and_faq_usages/simple_artifact_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@
os.makedirs(base_path, exist_ok=True)
# Instantiate an artifact store.
artifact_store = FileSystemArtifactStore(base_path=base_path)
# Instantiate an RDB.
storage = optuna.storages.RDBStorage("sqlite:///simple-artifact-store-demo.db")


def create_dataset(dataset_path):
# The coefficients we would like to find by Optuna.
a_true, b_true = 2, -3
X = np.random.random(20)
Y = a_true * X + b_true
X = np.random.random(20) * 10 - 5
Y = a_true * X**2 + b_true
dataset = pd.DataFrame({"X": X, "Y": Y})
dataset.to_csv(dataset_path)
return dataset
Expand All @@ -37,9 +39,9 @@ def create_dataset(dataset_path):
def plot_predictions(a, b, trial):
# Create an artifact, which is figure in this example, to upload.
_, ax = plt.subplots()
x = np.linspace(0, 1, 100)
x = np.linspace(-5, 5, 100)
ax.scatter(dataset["X"], dataset["Y"], label="Dataset", color="blue")
ax.plot(x, a * x + b, label="Prediction", color="darkred")
ax.plot(x, a * x**2 + b, label="Prediction", color="darkred")
ax.set_title(f"a={a:.2f}, b={b:.2f}")
ax.grid()
ax.legend()
Expand All @@ -55,13 +57,13 @@ def objective(trial):
# Link the plotted figure with trial using artifact store API.
upload_artifact(artifact_store=artifact_store, file_path=fig_name, study_or_trial=trial)

return np.mean((a * dataset["X"] + b - dataset["Y"]) ** 2)
return np.mean((a * dataset["X"]**2 + b - dataset["Y"]) ** 2)


def show_best_result(study, artifact_store):
best_trial = study.best_trial
# Get all the artifact information linked to best_trial. (Here we have only one.)
artifact_meta = get_all_artifact_meta(best_trial, storage=study._storage)
artifact_meta = get_all_artifact_meta(best_trial, storage=storage)
fig_path = "./result-best-trial.png"
# Download the figure from the artifact store to fig_path.
download_artifact(
Expand All @@ -79,9 +81,7 @@ def show_best_result(study, artifact_store):

if __name__ == "__main__":
# Create a study with a SQLite storage.
study = optuna.create_study(
study_name="demo", storage="sqlite:///simple-artifact-store-demo.db", load_if_exists=True
)
study = optuna.create_study(study_name="demo", storage=storage, load_if_exists=True)
# Upload the dataset to use by artifact store API.
upload_artifact(artifact_store=artifact_store, file_path=dataset_path, study_or_trial=study)
study.optimize(objective, n_trials=30)
Expand Down

0 comments on commit 7737c46

Please sign in to comment.