diff --git a/README.md b/README.md index 85f2c54..6c04d98 100644 --- a/README.md +++ b/README.md @@ -8,10 +8,15 @@ ## Description -A Warmup Scheduler for Pytorch to achieve the warmup learning rate at the beginning of training. +A Warmup Scheduler for Pytorch to make the warmup learning rate change at the beginning of training. ## setup +Notice: need to install pytorch>=1.1.0 manually. \ +The official website of pytorch is: https://pytorch.org/ + +Then install as follows: + ``` pip install warmup_scheduler_pytorch ``` @@ -26,7 +31,7 @@ import torch from torch.optim import SGD # example from torch.optim.lr_scheduler import CosineAnnealingLR # example -from warmup_scheduler_pytorch.warmup_module import WarmUpScheduler +from warmup_scheduler_pytorch import WarmUpScheduler model = Model() optimizer = SGD(model.parameters(), lr=0.1) diff --git a/example.py b/example.py index ca44c63..141193b 100644 --- a/example.py +++ b/example.py @@ -45,7 +45,7 @@ def run(): epoch_lr[1].append(get_lr(optimizer)) # output = model(...) - # loss = loss_fn(output, ...) + # loss = loss_fn(output, label) # loss.backward() optimizer.step() optimizer.zero_grad() diff --git a/pyproject.toml b/pyproject.toml index fa7093a..04d1b04 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,3 @@ [build-system] -requires = ["setuptools>=42"] +requires = ["setuptools>=42.0.0"] build-backend = "setuptools.build_meta" \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index e577714..5a748ad 100644 --- a/setup.cfg +++ b/setup.cfg @@ -13,7 +13,10 @@ classifiers = Intended Audience :: Science/Research License :: OSI Approved :: MIT License Operating System :: OS Independent - Programming Language :: Python :: 3 + Programming Language :: Python :: 3.6 + Programming Language :: Python :: 3.7 + Programming Language :: Python :: 3.8 + Programming Language :: Python :: 3.9 Topic :: Scientific/Engineering :: Artificial Intelligence @@ -22,8 +25,8 @@ package_dir = = src packages = find: python_requires = >=3.6 -install_requires= - torch >= 1.7.1 +# install_requires= +# torch >= 1.1.0 [options.packages.find] diff --git a/src/warmup_scheduler_pytorch/__init__.py b/src/warmup_scheduler_pytorch/__init__.py index 99aef39..266584d 100644 --- a/src/warmup_scheduler_pytorch/__init__.py +++ b/src/warmup_scheduler_pytorch/__init__.py @@ -1,6 +1,5 @@ from .warmup_module import WarmUpScheduler, VERSION -__all__ = ['__version__', 'WarmUpScheduler'] - __version__ = VERSION +__all__ = ['__version__', 'WarmUpScheduler'] diff --git a/src/warmup_scheduler_pytorch/warmup_module.py b/src/warmup_scheduler_pytorch/warmup_module.py index ec917db..b8e4b72 100644 --- a/src/warmup_scheduler_pytorch/warmup_module.py +++ b/src/warmup_scheduler_pytorch/warmup_module.py @@ -4,11 +4,11 @@ """ from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler # ignore its error +from torch.optim.lr_scheduler import _LRScheduler __all__ = ['VERSION', 'WarmUpScheduler'] -VERSION = '0.1.0' +VERSION = '0.1.1' class WarmUpScheduler(object): diff --git a/tests/test_base.py b/tests/test_base.py index cac8b6f..0fe325b 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -1,10 +1,3 @@ -from src.warmup_scheduler_pytorch import WarmUpScheduler, __version__ -from src.warmup_scheduler_pytorch.warmup_module import VERSION - - -def test_version(): - assert VERSION == __version__ - - def test_import(): - assert isinstance(WarmUpScheduler, object) + from src.warmup_scheduler_pytorch import WarmUpScheduler, __version__ + from src.warmup_scheduler_pytorch.warmup_module import VERSION diff --git a/tests/test_warmup.py b/tests/test_warmup.py index 0b78ae6..e6e022d 100644 --- a/tests/test_warmup.py +++ b/tests/test_warmup.py @@ -77,8 +77,8 @@ def test_warmup_init(self): pass def test_warmup_state_dict(self): - sd = self.warmup_scheduler.state_dict() - self.warmup_scheduler.load_state_dict(sd) + state_dict = self.warmup_scheduler.state_dict() + self.warmup_scheduler.load_state_dict(state_dict) def test_warmup_get(self): self.warmup_scheduler.get_last_lr()