Commit

fix training smoothing bug (#145)

grfrederic authored Feb 28, 2024
1 parent 9fb83f8 commit 007cfcf
Showing 1 changed file with 44 additions and 33 deletions.
src/bmi/estimators/neural/_training_log.py (77 changes: 44 additions & 33 deletions)
@@ -22,17 +22,22 @@ def __init__(
         Args:
             max_n_steps: maximum number of training steps allowed
             early_stopping: whether early stopping is turned on
-            train_smooth_factor: TODO(Frederic, Pawel): Add description.
+            train_smooth_factor: fraction of the training history length used
+                as the smoothing window for convergence checks. E.g. when
+                `train_smooth_factor=0.1` and the current training history
+                has length 1000, averages over 100 steps will be used. Max 0.5.
             verbose: whether to print information during the training
             enable_tqdm: whether to use tqdm's progress bar during training
             train_history_in_additional_information: whether the generated additional
                 information should contain training history (evaluated loss on
                 training and test populations). We recommend keeping this flag
                 turned on.
         """
+        assert train_smooth_factor <= 0.5, "train_smooth_factor can be at most 0.5"
+
         self.max_n_steps = max_n_steps
         self.early_stopping = early_stopping
-        self.train_smooth_window = int(max_n_steps * train_smooth_factor)
+        self.train_smooth_factor = train_smooth_factor
         self.verbose = verbose

         self._train_history_in_additional_information = train_history_in_additional_information
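
For context on why the commit stores the factor rather than a precomputed window: with early stopping, the realized training history can be far shorter than max_n_steps, so a window derived from max_n_steps may be as long as, or longer than, the history itself. A minimal sketch with made-up numbers (history_length is a hypothetical name, not one from the module):

    max_n_steps = 10_000
    train_smooth_factor = 0.1
    history_length = 300  # hypothetical: early stopping ended training here

    old_window = int(max_n_steps * train_smooth_factor)     # 1000, longer than the history
    new_window = int(train_smooth_factor * history_length)  # 30, at most half the history
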
@@ -114,40 +119,46 @@ def detect_warnings(self):  # noqa: C901
             if self.verbose:
                 print("WARNING: Early stopping enabled but max_n_steps reached.")

-        # analyze training
+        # get train MI history
         train_mi = jnp.array([mi for _step, mi in self._mi_train_history])
-        w = self.train_smooth_window
-        cs = jnp.cumsum(train_mi)
-        # TODO(Pawel, Frederic): If training smooth window is too
-        # long we will have an error that subtraction between (n,)
-        # and (0,) arrays cannot be performed.
+        w = int(self.train_smooth_factor * len(train_mi))
+
+        # check if training long enough to compute diagnostics
+        if w < 1:
+            self._additional_information["training_too_short_for_diagnostics"] = True
+            if self.verbose:
+                print("WARNING: Training too short to compute diagnostics.")
+            return
+
+        # compute smoothed mi
+        cs = jnp.cumsum(jnp.concatenate([jnp.zeros(1), train_mi]))
         train_mi_smooth = (cs[w:] - cs[:-w]) / w

-        if len(train_mi_smooth) > 0:
-            train_mi_smooth_max = float(train_mi_smooth.max())
-            train_mi_smooth_fin = float(train_mi_smooth[-1])
-            if train_mi_smooth_max > 1.05 * train_mi_smooth_fin:
-                self._additional_information["max_training_mi_decreased"] = True
-                if self.verbose:
-                    print(
-                        f"WARNING: Smoothed training MI fell compared to highest value: "
-                        f"max={train_mi_smooth_max:.3f} vs "
-                        f"final={train_mi_smooth_fin:.3f}"
-                    )
-
-        w = self.train_smooth_window
-        if len(train_mi_smooth) >= w:
-            train_mi_smooth_fin = float(train_mi_smooth[-1])
-            train_mi_smooth_prv = float(train_mi_smooth[-w])
-            if train_mi_smooth_fin > 1.05 * train_mi_smooth_prv:
-                self._additional_information["training_mi_still_increasing"] = True
-                if self.verbose:
-                    print(
-                        f"WARNING: Smoothed training MI was still "
-                        f"increasing when training stopped: "
-                        f"final={train_mi_smooth_fin:.3f} vs "
-                        f"{w} step(s) ago={train_mi_smooth_prv:.3f}"
-                    )
+        # n + 1 - w >= w + 1 since w <= int(0.5 * n)
+        assert len(train_mi_smooth) >= w + 1
+
+        train_mi_smooth_max = float(train_mi_smooth.max())
+        train_mi_smooth_fin = float(train_mi_smooth[-1])
+        if train_mi_smooth_max > 1.05 * train_mi_smooth_fin:
+            self._additional_information["max_training_mi_decreased"] = True
+            if self.verbose:
+                print(
+                    f"WARNING: Smoothed training MI fell compared to highest value: "
+                    f"max={train_mi_smooth_max:.3f} vs "
+                    f"final={train_mi_smooth_fin:.3f}"
+                )
+
+        train_mi_smooth_fin = float(train_mi_smooth[-1])
+        train_mi_smooth_prv = float(train_mi_smooth[-w])
+        if train_mi_smooth_fin > 1.05 * train_mi_smooth_prv:
+            self._additional_information["training_mi_still_increasing"] = True
+            if self.verbose:
+                print(
+                    f"WARNING: Smoothed training MI was still "
+                    f"increasing when training stopped: "
+                    f"final={train_mi_smooth_fin:.3f} vs "
+                    f"{w} step(s) ago={train_mi_smooth_prv:.3f}"
+                )

     def _tqdm_init(self):
         self._tqdm = tqdm.tqdm(
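The smoothing itself is a moving average computed from a zero-padded cumulative sum. A standalone sketch of that technique as it appears in the fixed code (the data values here are made up):

    import jax.numpy as jnp

    # Hypothetical history of per-step MI estimates.
    train_mi = jnp.array([0.1, 0.4, 0.5, 0.7, 0.8, 0.8])
    w = int(0.5 * len(train_mi))  # window of 3 steps

    # The leading zero makes cs[i] the sum of the first i entries,
    # so cs[w:] - cs[:-w] gives the sum of every length-w window.
    cs = jnp.cumsum(jnp.concatenate([jnp.zeros(1), train_mi]))
    train_mi_smooth = (cs[w:] - cs[:-w]) / w

    print(train_mi_smooth)  # ~[0.333 0.533 0.667 0.767], length n - w + 1 = 4

Because w = int(train_smooth_factor * n) with train_smooth_factor <= 0.5, the smoothed series has n - w + 1 >= w + 1 entries, so both train_mi_smooth[-1] and train_mi_smooth[-w] are guaranteed to exist; that is what the new assert records. The old code instead used a window derived from max_n_steps and an unpadded cumsum, which is the failure mode the deleted TODO describes.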
