-
Notifications
You must be signed in to change notification settings - Fork 177
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Import examples from github.com/optuna/optuna-dashboard at e0c89e2f34…
…cfb4dc5302b0d910165c5fe0f228f1.
- Loading branch information
1 parent
cc0bc79
commit b47793f
Showing
5 changed files
with
295 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
import os | ||
import textwrap | ||
import time | ||
from typing import NoReturn | ||
|
||
import optuna | ||
from optuna.artifacts import FileSystemArtifactStore | ||
from optuna.artifacts import upload_artifact | ||
from optuna.trial import TrialState | ||
from optuna_dashboard import ChoiceWidget | ||
from optuna_dashboard import register_objective_form_widgets | ||
from optuna_dashboard import save_note | ||
from optuna_dashboard.artifact import get_artifact_path | ||
from PIL import Image | ||
|
||
|
||
def suggest_and_generate_image( | ||
study: optuna.Study, artifact_store: FileSystemArtifactStore | ||
) -> None: | ||
# 1. Ask new parameters | ||
trial = study.ask() | ||
r = trial.suggest_int("r", 0, 255) | ||
g = trial.suggest_int("g", 0, 255) | ||
b = trial.suggest_int("b", 0, 255) | ||
|
||
# 2. Generate image | ||
image_path = f"tmp/sample-{trial.number}.png" | ||
image = Image.new("RGB", (320, 240), color=(r, g, b)) | ||
image.save(image_path) | ||
|
||
# 3. Upload Artifact | ||
artifact_id = upload_artifact(trial, image_path, artifact_store) | ||
artifact_path = get_artifact_path(trial, artifact_id) | ||
|
||
# 4. Save Note | ||
note = textwrap.dedent( | ||
f"""\ | ||
## Trial {trial.number} | ||
![generated-image]({artifact_path}) | ||
""" | ||
) | ||
save_note(trial, note) | ||
|
||
|
||
def start_optimization(artifact_store: FileSystemArtifactStore) -> NoReturn: | ||
# 1. Create Study | ||
study = optuna.create_study( | ||
study_name="Human-in-the-loop Optimization", | ||
storage="sqlite:///db.sqlite3", | ||
sampler=optuna.samplers.TPESampler(constant_liar=True, n_startup_trials=5), | ||
load_if_exists=True, | ||
) | ||
|
||
# 2. Set an objective name | ||
study.set_metric_names(["Looks like sunset color?"]) | ||
|
||
# 3. Register ChoiceWidget | ||
register_objective_form_widgets( | ||
study, | ||
widgets=[ | ||
ChoiceWidget( | ||
choices=["Good 👍", "So-so👌", "Bad 👎"], | ||
values=[-1, 0, 1], | ||
description="Please input your score!", | ||
), | ||
], | ||
) | ||
|
||
# 4. Start Human-in-the-loop Optimization | ||
n_batch = 4 | ||
while True: | ||
running_trials = study.get_trials(deepcopy=False, states=(TrialState.RUNNING,)) | ||
if len(running_trials) >= n_batch: | ||
time.sleep(1) # Avoid busy-loop | ||
continue | ||
suggest_and_generate_image(study, artifact_store) | ||
|
||
|
||
def main() -> NoReturn: | ||
tmp_path = os.path.join(os.path.dirname(__file__), "tmp") | ||
|
||
# 1. Create Artifact Store | ||
artifact_path = os.path.join(os.path.dirname(__file__), "artifact") | ||
artifact_store = FileSystemArtifactStore(artifact_path) | ||
|
||
if not os.path.exists(artifact_path): | ||
os.mkdir(artifact_path) | ||
|
||
if not os.path.exists(tmp_path): | ||
os.mkdir(tmp_path) | ||
|
||
# 2. Run optimize loop | ||
start_optimization(artifact_store) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
#!/usr/bin/env sh | ||
optuna-dashboard sqlite:///example.db --artifact-dir ./artifact |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
from __future__ import annotations | ||
|
||
import os | ||
import tempfile | ||
import time | ||
from typing import NoReturn | ||
|
||
from optuna.artifacts import FileSystemArtifactStore | ||
from optuna.artifacts import upload_artifact | ||
from optuna_dashboard import register_preference_feedback_component | ||
from optuna_dashboard.preferential import create_study | ||
from optuna_dashboard.preferential.samplers.gp import PreferentialGPSampler | ||
from PIL import Image | ||
|
||
|
||
STORAGE_URL = "sqlite:///example.db" | ||
artifact_path = os.path.join(os.path.dirname(__file__), "artifact") | ||
artifact_store = FileSystemArtifactStore(base_path=artifact_path) | ||
os.makedirs(artifact_path, exist_ok=True) | ||
|
||
|
||
def main() -> NoReturn: | ||
study = create_study( | ||
n_generate=4, | ||
study_name="Preferential Optimization", | ||
storage=STORAGE_URL, | ||
sampler=PreferentialGPSampler(), | ||
load_if_exists=True, | ||
) | ||
# Change the component, displayed on the human feedback pages. | ||
# By default (component_type="note"), the Trial's Markdown note is displayed. | ||
user_attr_key = "rgb_image" | ||
register_preference_feedback_component(study, "artifact", user_attr_key) | ||
|
||
with tempfile.TemporaryDirectory() as tmpdir: | ||
while True: | ||
# If study.should_generate() returns False, | ||
# the generator waits for human evaluation. | ||
if not study.should_generate(): | ||
time.sleep(0.1) # Avoid busy-loop | ||
continue | ||
|
||
trial = study.ask() | ||
# 1. Ask new parameters | ||
r = trial.suggest_int("r", 0, 255) | ||
g = trial.suggest_int("g", 0, 255) | ||
b = trial.suggest_int("b", 0, 255) | ||
|
||
# 2. Generate image | ||
image_path = os.path.join(tmpdir, f"sample-{trial.number}.png") | ||
image = Image.new("RGB", (320, 240), color=(r, g, b)) | ||
image.save(image_path) | ||
|
||
# 3. Upload Artifact and set artifact_id to trial.user_attrs["rgb_image"]. | ||
artifact_id = upload_artifact(trial, image_path, artifact_store) | ||
trial.set_user_attr(user_attr_key, artifact_id) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from __future__ import annotations | ||
|
||
import os | ||
import shutil | ||
import tempfile | ||
import uuid | ||
|
||
import optuna | ||
from optuna.trial import TrialState | ||
from optuna_dashboard.artifact.file_system import FileSystemBackend | ||
from optuna_dashboard.streamlit import render_objective_form_widgets | ||
from optuna_dashboard.streamlit import render_trial_note | ||
|
||
import streamlit as st | ||
|
||
|
||
artifact_path = os.path.join(os.path.dirname(__file__), "artifact") | ||
artifact_backend = FileSystemBackend(base_path=artifact_path) | ||
|
||
|
||
def get_tmp_dir() -> str: | ||
if "tmp_dir" not in st.session_state: | ||
tmp_dir_name = str(uuid.uuid4()) | ||
tmp_dir_path = os.path.join(tempfile.gettempdir(), tmp_dir_name) | ||
os.makedirs(tmp_dir_path, exist_ok=True) | ||
st.session_state.tmp_dir = tmp_dir_path | ||
|
||
return st.session_state.tmp_dir | ||
|
||
|
||
def start_streamlit() -> None: | ||
tmpdir = get_tmp_dir() | ||
study = optuna.load_study( | ||
storage="sqlite:///streamlit-db.sqlite3", study_name="Human-in-the-loop Optimization" | ||
) | ||
selected_trial = st.sidebar.selectbox("Trial", study.trials, format_func=lambda t: t.number) | ||
|
||
if selected_trial is None: | ||
return | ||
render_trial_note(study, selected_trial) | ||
artifact_id = selected_trial.user_attrs.get("artifact_id") | ||
if artifact_id: | ||
with artifact_backend.open(artifact_id) as fsrc: | ||
tmp_img_path = os.path.join(tmpdir, artifact_id + ".png") | ||
with open(tmp_img_path, "wb") as fdst: | ||
shutil.copyfileobj(fsrc, fdst) | ||
st.image(tmp_img_path, caption="Image") | ||
|
||
if selected_trial.state == TrialState.RUNNING: | ||
render_objective_form_widgets(study, selected_trial) | ||
|
||
|
||
if __name__ == "__main__": | ||
start_streamlit() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
from __future__ import annotations | ||
|
||
import os | ||
import tempfile | ||
import time | ||
from typing import NoReturn | ||
|
||
import optuna | ||
from optuna.trial import TrialState | ||
from optuna_dashboard import ChoiceWidget | ||
from optuna_dashboard import register_objective_form_widgets | ||
from optuna_dashboard import save_note | ||
from optuna_dashboard.artifact import upload_artifact | ||
from optuna_dashboard.artifact.file_system import FileSystemBackend | ||
from PIL import Image | ||
|
||
|
||
def suggest_and_generate_image( | ||
study: optuna.Study, artifact_backend: FileSystemBackend, tmpdir: str | ||
) -> None: | ||
# 1. Ask new parameters | ||
trial = study.ask() | ||
r = trial.suggest_int("r", 0, 255) | ||
g = trial.suggest_int("g", 0, 255) | ||
b = trial.suggest_int("b", 0, 255) | ||
|
||
# 2. Generate image | ||
image_path = os.path.join(tmpdir, f"sample-{trial.number}.png") | ||
image = Image.new("RGB", (320, 240), color=(r, g, b)) | ||
image.save(image_path) | ||
|
||
# 3. Upload Artifact | ||
artifact_id = upload_artifact(artifact_backend, trial, image_path) | ||
trial.set_user_attr("artifact_id", artifact_id) | ||
|
||
# 4. Save Note | ||
save_note(trial, f"## Trial {trial.number}") | ||
|
||
|
||
def main() -> NoReturn: | ||
# 1. Create Artifact Store | ||
artifact_path = os.path.join(os.path.dirname(__file__), "artifact") | ||
artifact_backend = FileSystemBackend(base_path=artifact_path) | ||
|
||
if not os.path.exists(artifact_path): | ||
os.mkdir(artifact_path) | ||
|
||
# 2. Create Study | ||
study = optuna.create_study( | ||
study_name="Human-in-the-loop Optimization", | ||
storage="sqlite:///streamlit-db.sqlite3", | ||
sampler=optuna.samplers.TPESampler(constant_liar=True, n_startup_trials=5), | ||
load_if_exists=True, | ||
) | ||
study.set_metric_names(["Looks like sunset color?"]) | ||
|
||
# 4. Register ChoiceWidget | ||
register_objective_form_widgets( | ||
study, | ||
widgets=[ | ||
ChoiceWidget( | ||
choices=["Good 👍", "So-so👌", "Bad 👎"], | ||
values=[-1, 0, 1], | ||
description="Please input your score!", | ||
), | ||
], | ||
) | ||
|
||
# 5. Start Human-in-the-loop Optimization | ||
n_batch = 4 | ||
with tempfile.TemporaryDirectory() as tmpdir: | ||
while True: | ||
running_trials = study.get_trials(deepcopy=False, states=(TrialState.RUNNING,)) | ||
if len(running_trials) >= n_batch: | ||
time.sleep(1) # Avoid busy-loop | ||
continue | ||
suggest_and_generate_image(study, artifact_backend, tmpdir) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |