Commit: Add optimizers to optim init and test (#17)
Showing 2 changed files with 109 additions and 0 deletions.
File 1 of 2:

@@ -0,0 +1 @@
from edugrad.optim.optimizer import SGD, Adam, AdamW
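This one-line re-export is what lets callers (including the test below) import the optimizers directly from edugrad.optim rather than from edugrad.optim.optimizer. A minimal usage sketch, assuming only the Tensor and optimizer interfaces that the test file exercises (the parameter shape and learning rate are illustrative, not part of this diff):

import numpy as np
from edugrad import Tensor
from edugrad.optim import SGD  # re-exported by this commit

# One trainable parameter and a single SGD update (usage sketch only).
W = Tensor(np.random.randn(4, 4).astype(np.float32), requires_grad=True)
opt = SGD([W], lr=0.01)
loss = (W * W).sum()  # scalar loss, as in TeenyNet.forward below
opt.zero_grad()
loss.backward()
opt.step()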
File 2 of 2:

@@ -0,0 +1,108 @@
import numpy as np
import torch
import unittest

from edugrad import Tensor
from edugrad.optim import Adam, SGD, AdamW

# Shared initial values so the edugrad and torch runs start from identical parameters.
np.random.seed(1337)
x_init = np.random.randn(1, 4).astype(np.float32)
W_init = np.random.randn(4, 4).astype(np.float32)
m_init = np.random.randn(1, 4).astype(np.float32)


class TeenyNet:
    def __init__(self, tensor):
        self.x = tensor(x_init.copy(), requires_grad=True)
        self.W = tensor(W_init.copy(), requires_grad=True)

    def forward(self):
        return (self.x * self.W).sum()


class TinyNet:
    def __init__(self, tensor):
        self.x = tensor(x_init.copy(), requires_grad=True)
        self.W = tensor(W_init.copy(), requires_grad=True)
        self.m = tensor(m_init.copy())

    def forward(self):
        out = self.x.matmul(self.W).relu()
        # print(out.detach().numpy())
        out = out.log_softmax(1)
        out = out.mul(self.m).add(self.m).sum()
        return out


def step(tensor, optim, steps=1, teeny=False, **kwargs):
    # Run `steps` optimizer updates on a fresh net; `tensor` and `optim` select
    # either the edugrad or the torch implementation.
    net = TeenyNet(tensor) if teeny else TinyNet(tensor)
    optim = optim([net.x, net.W], **kwargs)
    for _ in range(steps):
        out = net.forward()
        optim.zero_grad()
        out.backward()
        optim.step()
    return net.x.detach().numpy(), net.W.detach().numpy()


class TestOptim(unittest.TestCase):

    def _test_optim(self, edugrad_optim, torch_optim, steps, opts, atol, rtol):
        # Train with edugrad and with torch, then compare the resulting parameters.
        for x, y in zip(step(Tensor, edugrad_optim, steps, **opts),
                        step(torch.tensor, torch_optim, steps, **opts)):
            np.testing.assert_allclose(x, y, atol=atol, rtol=rtol)

    def _test_sgd(self, steps, opts, atol, rtol): self._test_optim(SGD, torch.optim.SGD, steps, opts, atol, rtol)
    def _test_adam(self, steps, opts, atol, rtol): self._test_optim(Adam, torch.optim.Adam, steps, opts, atol, rtol)
    def _test_adamw(self, steps, opts, atol, rtol): self._test_optim(AdamW, torch.optim.AdamW, steps, opts, atol, rtol)

    def test_multistep_sgd_high_lr_teeny(self): self._test_sgd(2, {'lr': 1.1, 'teeny': True}, 1e-6, 1e-5)
    def test_multistep_adam_high_lr_teeny(self): self._test_adam(2, {'lr': 1.1, 'teeny': True}, 2e-4, 5e-4)

    def test_sgd(self): self._test_sgd(1, {'lr': 0.001}, 1e-6, 0)
    def test_sgd_high_lr(self): self._test_sgd(1, {'lr': 10}, 1e-6, 1e-5)
    def test_sgd_wd(self): self._test_sgd(1, {'lr': 0.001, 'weight_decay': 0.1}, 1e-6, 0)
    def test_sgd_high_lr_wd(self): self._test_sgd(1, {'lr': 10, 'weight_decay': 0.1}, 1e-6, 1e-5)

    def test_multistep_sgd(self): self._test_sgd(10, {'lr': 0.001}, 1e-6, 0)
    def test_multistep_sgd_high_lr(self): self._test_sgd(10, {'lr': 10}, 1e-6, 3e-4)
    def test_multistep_sgd_wd(self): self._test_sgd(10, {'lr': 0.001, 'weight_decay': 0.1}, 1e-6, 0)
    def test_multistep_sgd_high_lr_wd(self): self._test_sgd(10, {'lr': 9, 'weight_decay': 0.1}, 1e-6, 3e-4)

    def test_multistep_sgd_momentum(self): self._test_sgd(10, {'lr': 0.001, 'momentum': 0.9}, 1e-6, 0)
    def test_multistep_sgd_high_lr_momentum(self): self._test_sgd(10, {'lr': 10, 'momentum': 0.9}, 1e-5, 3e-4)
    def test_multistep_sgd_momentum_wd(self): self._test_sgd(10, {'lr': 0.001, 'momentum': 0.9, 'weight_decay': 0.1}, 1e-6, 0)
    def test_multistep_sgd_high_lr_momentum_wd(self): self._test_sgd(10, {'lr': 10, 'momentum': 0.9, 'weight_decay': 0.1}, 1e-5, 3e-4)

    def test_multistep_sgd_nesterov_momentum(self): self._test_sgd(10, {'lr': 0.001, 'momentum': 0.9, 'nesterov': True}, 1e-5, 0)
    def test_multistep_sgd_high_lr_nesterov_momentum(self): self._test_sgd(10, {'lr': 10, 'momentum': 0.9, 'nesterov': True}, 1e-5, 3e-4)
    def test_multistep_sgd_nesterov_momentum_wd(self):
        self._test_sgd(10, {'lr': 0.001, 'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.1}, 1e-5, 0)
    def test_multistep_sgd_high_lr_nesterov_momentum_wd(self):
        self._test_sgd(10, {'lr': 9, 'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.1}, 1e-5, 3e-4)

    def test_adam(self): self._test_adam(1, {'lr': 0.001}, 1e-5, 0)
    def test_adam_high_lr(self): self._test_adam(1, {'lr': 10}, 1e-4, 1e-4)
    def test_adamw(self): self._test_adamw(1, {'lr': 0.001}, 1e-5, 0)
    def test_adamw_high_lr(self): self._test_adamw(1, {'lr': 10}, 1e-4, 1e-4)

    def test_multistep_adam(self): self._test_adam(10, {'lr': 0.001}, 1e-5, 0)
    def test_multistep_adam_high_lr(self): self._test_adam(10, {'lr': 10}, 2e-4, 5e-4)

    def test_multistep_adamw(self): self._test_adamw(10, {'lr': 0.001}, 1e-5, 0)
    def test_multistep_adamw_high_lr(self): self._test_adamw(10, {'lr': 10}, 5e-4, 2e-3)

    def test_duped_weights(self):
        # Passing the same parameter twice to an optimizer must give the same
        # final loss as passing it once.
        for Opt in [Adam, AdamW, SGD]:
            losses = []
            for i in range(2):
                w = Tensor(x_init.copy())
                opt = Opt([w], lr=0.1) if i == 0 else Opt([w, w], lr=0.1)

                loss = None
                for _ in range(3):
                    loss = w.sum()
                    opt.zero_grad()
                    loss.backward()
                    opt.step()
                losses.append(loss.numpy())

            np.testing.assert_allclose(losses[0], losses[1], atol=1e-4, rtol=0)


if __name__ == '__main__':
    unittest.main()
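Since the file ends with unittest.main(), the suite can be run directly with Python; for example (the filename below is hypothetical, not part of this diff):

python test_optim.py -v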