Skip to content

Commit

Permalink
upgrade in install dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
LEFTeyes committed May 1, 2022
1 parent d3fc966 commit 344ce24
Show file tree
Hide file tree
Showing 8 changed files with 22 additions and 22 deletions.
9 changes: 7 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,15 @@

## Description

A Warmup Scheduler for Pytorch to achieve the warmup learning rate at the beginning of training.
A Warmup Scheduler for Pytorch to make the warmup learning rate change at the beginning of training.

## setup

Notice: need to install pytorch>=1.1.0 manually. \
The official website of pytorch is: https://pytorch.org/

Then install as follows:

```
pip install warmup_scheduler_pytorch
```
Expand All @@ -26,7 +31,7 @@ import torch
from torch.optim import SGD # example
from torch.optim.lr_scheduler import CosineAnnealingLR # example

from warmup_scheduler_pytorch.warmup_module import WarmUpScheduler
from warmup_scheduler_pytorch import WarmUpScheduler

model = Model()
optimizer = SGD(model.parameters(), lr=0.1)
Expand Down
2 changes: 1 addition & 1 deletion example.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def run():
epoch_lr[1].append(get_lr(optimizer))

# output = model(...)
# loss = loss_fn(output, ...)
# loss = loss_fn(output, label)
# loss.backward()
optimizer.step()
optimizer.zero_grad()
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[build-system]
requires = ["setuptools>=42"]
requires = ["setuptools>=42.0.0"]
build-backend = "setuptools.build_meta"
9 changes: 6 additions & 3 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@ classifiers =
Intended Audience :: Science/Research
License :: OSI Approved :: MIT License
Operating System :: OS Independent
Programming Language :: Python :: 3
Programming Language :: Python :: 3.6
Programming Language :: Python :: 3.7
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
Topic :: Scientific/Engineering :: Artificial Intelligence


Expand All @@ -22,8 +25,8 @@ package_dir =
= src
packages = find:
python_requires = >=3.6
install_requires=
torch >= 1.7.1
# install_requires=
# torch >= 1.1.0


[options.packages.find]
Expand Down
3 changes: 1 addition & 2 deletions src/warmup_scheduler_pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from .warmup_module import WarmUpScheduler, VERSION

__all__ = ['__version__', 'WarmUpScheduler']

__version__ = VERSION

__all__ = ['__version__', 'WarmUpScheduler']
4 changes: 2 additions & 2 deletions src/warmup_scheduler_pytorch/warmup_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
"""

from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler # ignore its error
from torch.optim.lr_scheduler import _LRScheduler

__all__ = ['VERSION', 'WarmUpScheduler']

VERSION = '0.1.0'
VERSION = '0.1.1'


class WarmUpScheduler(object):
Expand Down
11 changes: 2 additions & 9 deletions tests/test_base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,3 @@
from src.warmup_scheduler_pytorch import WarmUpScheduler, __version__
from src.warmup_scheduler_pytorch.warmup_module import VERSION


def test_version():
assert VERSION == __version__


def test_import():
assert isinstance(WarmUpScheduler, object)
from src.warmup_scheduler_pytorch import WarmUpScheduler, __version__
from src.warmup_scheduler_pytorch.warmup_module import VERSION
4 changes: 2 additions & 2 deletions tests/test_warmup.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ def test_warmup_init(self):
pass

def test_warmup_state_dict(self):
sd = self.warmup_scheduler.state_dict()
self.warmup_scheduler.load_state_dict(sd)
state_dict = self.warmup_scheduler.state_dict()
self.warmup_scheduler.load_state_dict(state_dict)

def test_warmup_get(self):
self.warmup_scheduler.get_last_lr()
Expand Down

0 comments on commit 344ce24

Please sign in to comment.