Skip to content
This repository has been archived by the owner on Nov 25, 2020. It is now read-only.

Commit

Permalink
deprecate exection time in name (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
tokusumi authored Oct 10, 2020
1 parent d3604d6 commit a595f73
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 8 deletions.
12 changes: 5 additions & 7 deletions kerastuner_tensorboard_logger/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


def timedelta_to_hms(timedelta: timedelta) -> str:
"""convert datetime.timedelta to string like '01h:15m:30s'"""
"""(Deprecated) convert datetime.timedelta to string like '01h:15m:30s'"""
tot_seconds = int(timedelta.total_seconds())
hours = tot_seconds // 3600
minutes = (tot_seconds % 3600) // 60
Expand All @@ -30,7 +30,7 @@ class TensorBoardLogger(Logger):
def __init__(
self,
metrics: Union[str, List[str]] = ["acc"],
logdir: str = "logs/hparam_tuning",
logdir: str = "logs/",
overwrite: bool = False,
):
self.metrics = [metrics] if isinstance(metrics, str) else metrics
Expand All @@ -46,18 +46,16 @@ def register_tuner(self, tuner_state):

def register_trial(self, trial_id: str, trial_state: Dict[str, Any]):
"""Informs the logger that a new Trial is starting."""
self.times[trial_id] = datetime.now()
pass

def report_trial_state(self, trial_id: str, trial_state: Dict[str, Any]):
"""Gives the logger information about trial status."""
execution_time = timedelta_to_hms(datetime.now() - self.times.pop(trial_id))
name = f"{execution_time}-{trial_id}"
logdir = os.path.join(self.logdir, name)
logdir = os.path.join(self.logdir, trial_id, "hparams")

with tf.summary.create_file_writer(logdir).as_default():
hparams = self.parse_hparams(trial_state)
hp_board.hparams(
hparams, trial_id=name
hparams, trial_id=trial_id
) # record the values used in this trial

for target_metric, metric in self.parse_metrics(trial_state):
Expand Down
2 changes: 1 addition & 1 deletion scripts/local_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
set -e

pytest --disable-warnings tests/
tensorboard --logdir tests/logs/hparams
tensorboard --logdir tests/logs
30 changes: 30 additions & 0 deletions tests/test_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@


def test_timedelta_to_hms():
"""(Deprecated)"""
td = timedelta(minutes=10, hours=2, seconds=30, microseconds=111)
out = timedelta_to_hms(td)
assert out == "2h10m30s"
Expand Down Expand Up @@ -198,6 +199,35 @@ def test_initialize_manual():
tuner.search(train_data, epochs=3, validation_data=test_data)


def test_search_with_callbacks_manual():
"""test logging with TensorBoardCallbacks
manual test is required. log files for tensorboard,
then, run tensorboard server as:
```bash
tensorboard --logdir tests/logs/with-callbacks
```
"""
tuner = Hyperband(
build_model,
objective="val_acc",
max_epochs=3,
directory="tests/logs/with-callbacks/search",
project_name="initialize_manual",
overwrite=True,
logger=TensorBoardLogger(
metrics="val_acc",
logdir="tests/logs/with-callbacks",
overwrite=True,
),
)
setup_tb(tuner)
train_data, test_data = make_dataset()
callbacks = [tf.keras.callbacks.TensorBoard(log_dir="tests/logs/with-callbacks")]
tuner.search(train_data, epochs=3, validation_data=test_data, callbacks=callbacks)


def test_parse():
trained_trial_state = {
"trial_id": "bb0649bfdb92155d308f12dca83152e1",
Expand Down

0 comments on commit a595f73

Please sign in to comment.