Prevent unexpected unpacking error when calling lr_finder.plot() with `suggest_lr=True` (#98)

* MAINT: always return 2 values when `suggest_lr` is True

As mentioned in #88, the suggested lr would not be returned along
with `ax` (`matplotlib.Axes`) if there were not enough data points
to calculate the gradient of the lr-loss curve.

Though it would warn the user about this problem [1], users might
still be confused by another error caused by unpacking the returned
value. This is because users would usually expect it to work as
below:
```python
ax, lr = lr_finder.plot(..., suggest_lr=True)
```

But the second return value `lr` does not exist when no suggested
lr can be found; the returned value is then a single value instead.
Therefore, the unpacking syntax `ax, lr = ...` fails and results in
the error reported in #88.
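
For example, when no suggestion is available the old code returned
only a single `Axes`, so the two-name unpacking raises (a minimal
sketch; the exact exception message depends on the matplotlib class
name):

```python
result = lr_finder.plot(suggest_lr=True)  # old behavior: may return only `ax`
ax, lr = result  # TypeError: cannot unpack non-iterable Axes object
```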

So we fix it by always returning both `ax` and `suggested_lr` when
the flag `suggest_lr` is True to meet this expectation, and leave
the responsibility of checking whether `lr` is None to the user.

[1]: https://github.com/davidtvs/pytorch-lr-finder/blob/fd9e949/torch_lr_finder/lr_finder.py#L539-L542
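
Under this first change alone, a caller can always unpack 2 values
and check for the degenerate case explicitly (a minimal sketch of
the intended usage):

```python
ax, lr = lr_finder.plot(suggest_lr=True)
if lr is None:
    # No suggested lr could be computed from the available data points
    print("no suggested lr available")
```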

* MAINT: raise error earlier if there are not enough data points to suggest LR

Now LR finder raises a RuntimeError if there are not enough data
points to calculate the gradient for a suggested LR when
`lr_finder.plot(..., suggest_lr=True)` is called.

The error message clarifies the details of the failure, so users can
fix the issue earlier as well.
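
With both changes, a caller can rely on the 2-value return and
handle the degenerate case through the exception instead (a sketch,
assuming `lr_finder.range_test(...)` has already been run):

```python
try:
    ax, suggested_lr = lr_finder.plot(skip_start=10, skip_end=5, suggest_lr=True)
except RuntimeError as err:
    # e.g. "Need at least 17 iterations to suggest a learning rate. Got 12"
    print(f"Cannot suggest a learning rate: {err}")
```
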
NaleRaphael authored Aug 26, 2024
1 parent fd9e949 commit 5b9d92c
Showing 2 changed files with 85 additions and 34 deletions.
74 changes: 61 additions & 13 deletions tests/test_lr_finder.py
@@ -8,6 +8,7 @@
import task as mod_task
import dataset as mod_dataset

import numpy as np
import matplotlib.pyplot as plt

# Check available backends for mixed precision training
@@ -400,21 +401,68 @@ def test_plot_with_skip_and_suggest_lr(suggest_lr, skip_start, skip_end):
)

fig, ax = plt.subplots()
results = lr_finder.plot(
skip_start=skip_start, skip_end=skip_end, suggest_lr=suggest_lr, ax=ax
)

if num_iter - skip_start - skip_end <= 1:
# handle data with one or zero lr
assert len(ax.lines) == 1
assert results is ax
results = None
if suggest_lr and num_iter < (skip_start + skip_end + 2):
# Not enough data points to calculate the gradient, so this call should fail
with pytest.raises(RuntimeError, match="Need at least"):
results = lr_finder.plot(
skip_start=skip_start, skip_end=skip_end, suggest_lr=suggest_lr, ax=ax
)

# No need to proceed then
return
else:
# handle different suggest_lr
# for 'steepest': the point with steepest gradient (minimal gradient)
assert len(ax.lines) == 1
assert len(ax.collections) == int(suggest_lr)
if results is not ax:
assert len(results) == 2
results = lr_finder.plot(
skip_start=skip_start, skip_end=skip_end, suggest_lr=suggest_lr, ax=ax
)

# NOTE:
# - ax.lines[0]: the lr-loss curve. It should always be available once
# `ax.plot(lrs, losses)` is called. But when there are not enough data
# points (num_iter <= skip_start + skip_end), the coordinates will be
# 2 empty arrays.
# - ax.collections[0]: the point of the suggested lr (type: <PathCollection>).
# It's available only when there are sufficient data points to calculate
# the gradient of the lr-loss curve.
assert len(ax.lines) == 1

if suggest_lr:
assert isinstance(results, tuple) and len(results) == 2

ret_ax, ret_lr = results
assert ret_ax is ax

# XXX: Currently the suggested lr is selected according to the gradient
# of the lr-loss curve, so there should be at least 2 valid data points
# (after filtering by `skip_start` and `skip_end`). If not, the returned
# lr will be None.
# But we would need to rework this if more suggestion methods are
# supported in the future.
if num_iter - skip_start - skip_end <= 1:
assert ret_lr is None
assert len(ax.collections) == 0
else:
assert len(ax.collections) == 1
else:
# Not suggesting lr, so it just plots a lr-loss curve.
assert results is ax
assert len(ax.collections) == 0

# Check whether the data of plotted line is the same as the one filtered
# according to `skip_start` and `skip_end`.
lrs = np.array(lr_finder.history["lr"])
losses = np.array(lr_finder.history["loss"])
x, y = ax.lines[0].get_data()

# If skip_end is 0, we should replace it with None. Otherwise, it
# would create a slice like `x[0:-0]`, which selects an empty range.
_slice = slice(skip_start, -skip_end if skip_end != 0 else None, None)
assert np.allclose(x, lrs[_slice])
assert np.allclose(y, losses[_slice])

# Close figure to release memory
plt.close()


def test_suggest_lr():
45 changes: 24 additions & 21 deletions torch_lr_finder/lr_finder.py
@@ -510,6 +510,13 @@ def plot(
if show_lr is not None and not isinstance(show_lr, float):
raise ValueError("show_lr must be float")

# Make sure there are enough data points to suggest a learning rate
if suggest_lr and len(self.history["lr"]) < (skip_start + skip_end + 2):
raise RuntimeError(
f"Need at least {skip_start + skip_end + 2} iterations to suggest a "
f"learning rate. Got {len(self.history['lr'])}"
)

# Get the data to plot from the history dictionary. Also, handle skip_end=0
# properly so the behaviour is the expected one
lrs = self.history["lr"]
@@ -533,25 +540,19 @@
if suggest_lr:
# 'steepest': the point with steepest gradient (minimal gradient)
print("LR suggestion: steepest gradient")
min_grad_idx = None
try:
min_grad_idx = (np.gradient(np.array(losses))).argmin()
except ValueError:
print(
"Failed to compute the gradients, there might not be enough points."
)
if min_grad_idx is not None:
print("Suggested LR: {:.2E}".format(lrs[min_grad_idx]))
ax.scatter(
lrs[min_grad_idx],
losses[min_grad_idx],
s=75,
marker="o",
color="red",
zorder=3,
label="steepest gradient",
)
ax.legend()
min_grad_idx = (np.gradient(np.array(losses))).argmin()

print("Suggested LR: {:.2E}".format(lrs[min_grad_idx]))
ax.scatter(
lrs[min_grad_idx],
losses[min_grad_idx],
s=75,
marker="o",
color="red",
zorder=3,
label="steepest gradient",
)
ax.legend()

if log_lr:
ax.set_xscale("log")
@@ -565,8 +566,10 @@
if fig is not None:
plt.show()

if suggest_lr and min_grad_idx is not None:
return ax, lrs[min_grad_idx]
if suggest_lr:
# If suggest_lr is set, then we should always return 2 values.
suggest_lr = lrs[min_grad_idx]
return ax, suggest_lr
else:
return ax

