Cleanup-1 [termination condition and init file] (deepchem#4109)
* minimizer and init file

* added test

* fixed formatting
sudo-rsingh authored Sep 6, 2024
1 parent 3c010b6 commit 5cd9e6f
Showing 3 changed files with 27 additions and 4 deletions.
Empty file.
16 changes: 12 additions & 4 deletions deepchem/utils/differentiation_utils/optimize/minimizer.py
@@ -18,6 +18,7 @@ def gd(
         x_rtol: float = 1e-8,
         # misc parameters
         verbose=False,
+        terminate=False,
         **unused):
     r"""
     Vanilla gradient descent with momentum. The stopping conditions use OR criteria.
@@ -61,6 +62,8 @@ def gd(
         The absolute tolerance of the norm of the input ``x``.
     x_rtol: float or None
         The relative tolerance of the norm of the input ``x``.
+    terminate: bool (default False)
+        Whether to use the termination condition or keep running the minimizer.
     """

@@ -79,11 +82,12 @@
         # check the stopping conditions
         to_stop = stop_cond.to_stop(i, x, xprev, f, fprev)

-        if to_stop:
+        if to_stop and terminate:
            break

        fprev = f
-    x = stop_cond.get_best_x(x)
+    if terminate:
+        x = stop_cond.get_best_x(x)
     return x


@@ -104,6 +108,7 @@ def adam(
         x_rtol: float = 1e-8,
         # misc parameters
         verbose=False,
+        terminate=False,
         **unused):
     r"""
     Adam optimizer by Kingma & Ba (2015). The stopping conditions use OR criteria.
@@ -148,6 +153,8 @@ def adam(
         The absolute tolerance of the norm of the input ``x``.
     x_rtol: float or None
         The relative tolerance of the norm of the input ``x``.
+    terminate: bool (default False)
+        Whether to use the termination condition or keep running the minimizer.
     """

     x = x0.clone()
@@ -175,11 +182,12 @@
         # check the stopping conditions
         to_stop = stop_cond.to_stop(i, x, xprev, f, fprev)

-        if to_stop:
+        if to_stop and terminate:
            break

        fprev = f
-    x = stop_cond.get_best_x(x)
+    if terminate:
+        x = stop_cond.get_best_x(x)
     return x


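For context, a minimal usage sketch of the new flag (not part of the commit). It assumes gd is importable from deepchem.utils.differentiation_utils, as in the test below, and that the objective returns a (value, gradient) pair, which is how the new test calls it; quadratic and the variable names here are illustrative only.

# Usage sketch of the new ``terminate`` flag. Assumptions: ``gd`` is
# importable from deepchem.utils.differentiation_utils (as in the new test)
# and the objective returns a (value, gradient) pair; ``quadratic`` is a
# hypothetical example function.
import torch
from deepchem.utils.differentiation_utils import gd

def quadratic(x):
    # f(x) = x^2, with analytic gradient 2x
    return x ** 2, 2 * x

x0 = torch.tensor(3.0, requires_grad=True)
x0.grad = torch.tensor(1.0)  # mirrors the setup in the new test

# terminate=True: the loop breaks once a stopping condition fires and the
# best iterate tracked by stop_cond is returned.
x_early = gd(quadratic, x0, [], terminate=True)

# terminate=False (the default): every iteration runs and the final iterate
# is returned as-is.
x_full = gd(quadratic, x0, [], terminate=False)

Note from the diff that terminate=False is the new default, so the minimizers now run to maxiter and return the raw final iterate unless terminate=True is passed to restore the early-stopping behaviour.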
15 changes: 15 additions & 0 deletions deepchem/utils/test/test_differentiation_utils.py
@@ -1228,3 +1228,18 @@ def lotka_volterra(t, z, *params):
     assert torch.allclose(sol[-1][0],
                           torch.tensor(sol_scipy.y[0][-1], dtype=torch.float),
                           0.01, 0.001)
+
+
+@pytest.mark.torch
+def test_terminate_param():
+    from deepchem.utils.differentiation_utils import gd
+    import torch
+
+    def fun(x):
+        return torch.tan(x), (1 / torch.cos(x))**2
+
+    x0 = torch.tensor(0.0, requires_grad=True)
+    x0.grad = torch.tensor(1.0)
+    x1 = gd(fun, x0, [], terminate=True)
+    x2 = gd(fun, x0, [], terminate=False)
+    assert not torch.allclose(x1, x2)
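The assertion is expected to hold because both calls start from the same x0: with terminate=True, gd breaks as soon as the stopping condition fires and returns stop_cond.get_best_x(x), whereas with terminate=False it runs all iterations and returns the final iterate unchanged, so the two results diverge.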
