fix training smoothing bug #145

Merged
merged 2 commits on Feb 28, 2024
77 changes: 44 additions & 33 deletions src/bmi/estimators/neural/_training_log.py
@@ -22,17 +22,22 @@ def __init__(
         Args:
             max_n_steps: maximum number of training steps allowed
             early_stopping: whether early stopping is turned on
-            train_smooth_factor: TODO(Frederic, Pawel): Add description.
+            train_smooth_factor: fraction of the training history length used
+                as the smoothing window for convergence checks. E.g. when
+                `train_smooth_factor=0.1` and the current training history
+                has length 1000, averages over 100 steps will be used. Max 0.5.
             verbose: whether to print information during the training
             enable_tqdm: whether to use tqdm's progress bar during training
             history_in_additional_information: whether the generated additional
                 information should contain training history (evaluated loss on
                 training and test populations). We recommend keeping this flag
                 turned on.
         """
+        assert train_smooth_factor <= 0.5, "train_smooth_factor can be at most 0.5"
+
         self.max_n_steps = max_n_steps
         self.early_stopping = early_stopping
-        self.train_smooth_window = int(max_n_steps * train_smooth_factor)
+        self.train_smooth_factor = train_smooth_factor
         self.verbose = verbose
 
         self._train_history_in_additional_information = train_history_in_additional_information
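
To illustrate the new docstring (a minimal sketch, not code from the PR; `smoothing_window` is a hypothetical helper): the window is now derived from the history that was actually recorded, not from `max_n_steps`, so stopping early can no longer yield a window longer than the history itself.

    def smoothing_window(history_len: int, train_smooth_factor: float) -> int:
        # Mirrors the patched logic: the window is a fraction of the
        # recorded history, and the factor is capped at one half.
        assert train_smooth_factor <= 0.5, "train_smooth_factor can be at most 0.5"
        return int(train_smooth_factor * history_len)

    print(smoothing_window(1000, 0.1))  # 100, the docstring's own example
    print(smoothing_window(5, 0.1))     # 0 -> too short, diagnostics are skipped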
@@ -114,40 +119,46 @@ def detect_warnings(self): # noqa: C901
             if self.verbose:
                 print("WARNING: Early stopping enabled but max_n_steps reached.")
 
-        # analyze training
+        # get train MI history
         train_mi = jnp.array([mi for _step, mi in self._mi_train_history])
-        w = self.train_smooth_window
-        cs = jnp.cumsum(train_mi)
-        # TODO(Pawel, Frederic): If training smooth window is too
-        # long we will have an error that subtraction between (n,)
-        # and (0,) arrays cannot be performed.
+        w = int(self.train_smooth_factor * len(train_mi))
+
+        # check if training long enough to compute diagnostics
+        if w < 1:
+            self._additional_information["training_too_short_for_diagnostics"] = True
+            if self.verbose:
+                print("WARNING: Training too short to compute diagnostics.")
+            return
+
+        # compute smoothed mi
+        cs = jnp.cumsum(jnp.concatenate([jnp.zeros(1), train_mi]))
         train_mi_smooth = (cs[w:] - cs[:-w]) / w
 
-        if len(train_mi_smooth) > 0:
-            train_mi_smooth_max = float(train_mi_smooth.max())
-            train_mi_smooth_fin = float(train_mi_smooth[-1])
-            if train_mi_smooth_max > 1.05 * train_mi_smooth_fin:
-                self._additional_information["max_training_mi_decreased"] = True
-                if self.verbose:
-                    print(
-                        f"WARNING: Smoothed training MI fell compared to highest value: "
-                        f"max={train_mi_smooth_max:.3f} vs "
-                        f"final={train_mi_smooth_fin:.3f}"
-                    )
-
-        w = self.train_smooth_window
-        if len(train_mi_smooth) >= w:
-            train_mi_smooth_fin = float(train_mi_smooth[-1])
-            train_mi_smooth_prv = float(train_mi_smooth[-w])
-            if train_mi_smooth_fin > 1.05 * train_mi_smooth_prv:
-                self._additional_information["training_mi_still_increasing"] = True
-                if self.verbose:
-                    print(
-                        f"WARNING: Smoothed training MI was still "
-                        f"increasing when training stopped: "
-                        f"final={train_mi_smooth_fin:.3f} vs "
-                        f"{w} step(s) ago={train_mi_smooth_prv:.3f}"
-                    )
+        # n + 1 - w >= w + 1 since w <= int(0.5 * n)
+        assert len(train_mi_smooth) >= w + 1
+
+        train_mi_smooth_max = float(train_mi_smooth.max())
+        train_mi_smooth_fin = float(train_mi_smooth[-1])
+        if train_mi_smooth_max > 1.05 * train_mi_smooth_fin:
+            self._additional_information["max_training_mi_decreased"] = True
+            if self.verbose:
+                print(
+                    f"WARNING: Smoothed training MI fell compared to highest value: "
+                    f"max={train_mi_smooth_max:.3f} vs "
+                    f"final={train_mi_smooth_fin:.3f}"
+                )
+
+        train_mi_smooth_fin = float(train_mi_smooth[-1])
+        train_mi_smooth_prv = float(train_mi_smooth[-w])
+        if train_mi_smooth_fin > 1.05 * train_mi_smooth_prv:
+            self._additional_information["training_mi_still_increasing"] = True
+            if self.verbose:
+                print(
+                    f"WARNING: Smoothed training MI was still "
+                    f"increasing when training stopped: "
+                    f"final={train_mi_smooth_fin:.3f} vs "
+                    f"{w} step(s) ago={train_mi_smooth_prv:.3f}"
+                )

     def _tqdm_init(self):
         self._tqdm = tqdm.tqdm(
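
For context on the bug being fixed (a sketch, not part of the diff): the old window came from `max_n_steps` rather than from the recorded history, so `w` could end up as 0. The removed TODO describes a subtraction between (n,) and (0,) arrays, and that shape mismatch occurs precisely when `w == 0`, because `cs[:-0]` means `cs[:0]`, an empty slice, while `cs[0:]` is the full array. The patch guards this with the `if w < 1: return` early exit, and it prepends a zero to the cumulative sum, which also restores the first window average and yields `n + 1 - w` averages instead of `n - w`. A small self-contained demonstration, assuming plain JAX arrays (`moving_average` is a hypothetical helper mirroring the patched code):

    import jax.numpy as jnp

    train_mi = jnp.arange(6.0)  # toy training history, n = 6

    # Old formulation: with w == 0, cs[:-w] is cs[:0] -> shape (0,).
    cs = jnp.cumsum(train_mi)
    # cs[0:] - cs[:-0]  # would raise: incompatible shapes (6,) vs (0,)

    # New formulation, guarded by `if w < 1: return` in the patch:
    def moving_average(x, w):
        cs = jnp.cumsum(jnp.concatenate([jnp.zeros(1), x]))
        return (cs[w:] - cs[:-w]) / w  # n + 1 - w window averages

    print(moving_average(train_mi, 2))  # [0.5 1.5 2.5 3.5 4.5]

The new assert is justified because `w <= int(0.5 * n)` implies the smoothed series has `n + 1 - w >= w + 1` entries, so indexing `train_mi_smooth[-w]` in the still-increasing check is always in bounds.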