-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c858530
commit 56b9725
Showing
84 changed files
with
6,436 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -102,3 +102,6 @@ venv.bak/ | |
|
||
# mypy | ||
.mypy_cache/ | ||
|
||
# PyCharm | ||
.idea |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
|
||
|
||
# PyTorch-SSO | ||
|
||
Scalable Second-Order methods in PyTorch. | ||
|
||
- Open-source library for second-order optimization and Bayesian inference. | ||
|
||
- An earlier iteration of this library ([chainerkfac](https://github.com/tyohei/chainerkfac)) holds the world record for large-batch training of ResNet-50 on ImageNet by [Kronecker-Factored Approximate Curvature (K-FAC)](https://arxiv.org/abs/1503.05671), scaling to batch sizes of 131K. | ||
- Kazuki Osawa et al, “Large-Scale Distributed Second-Order Optimization Using Kronecker-Factored Approximate Curvature for Deep Convolutional Neural Networks”, **IEEE/CVF CVPR 2019**. | ||
- [[paper](http://openaccess.thecvf.com/content_CVPR_2019/html/Osawa_Large-Scale_Distributed_Second-Order_Optimization_Using_Kronecker-Factored_Approximate_Curvature_for_Deep_CVPR_2019_paper.html)] [[poster](https://kazukiosawa.github.io/cvpr19_poster.pdf)] | ||
- This library is basis for the Natural Gradient for Bayesian inference (Variational Inference) on ImageNet. | ||
- Kazuki Osawa et al, “Practical Deep Learning with Bayesian Principles”, **NeurIPS 2019**. | ||
- [[paper (preprint)](https://arxiv.org/abs/1906.02506)] | ||
|
||
## Scalable Second-Order Optimization | ||
|
||
### Optimizers | ||
|
||
PyTorch-SSO provides the following optimizers. | ||
|
||
- Second-Order Optimization | ||
- `torchsso.optim.SecondOrderOptimizer` [[source](https://github.com/cybertronai/pytorch-sso/blob/master/torchsso/optim/secondorder.py)] | ||
- updates the parameters with the gradients pre-conditioned by the curvature of the loss function (`torch.nn.functional.cross_entropy`) for each `param_group`. | ||
- Variational Inference (VI) | ||
- `torchsso.optim.VIOptimizer` [[source](https://github.com/cybertronai/pytorch-sso/blob/master/torchsso/optim/vi.py)] | ||
- updates the posterior distribution (mean, covariance) of the parameters by using the curvature for each `param_group`. | ||
|
||
### Curvatures | ||
|
||
You can specify a type of the information matrix to be used as the curvature from the following. | ||
|
||
- Hessian [WIP] | ||
|
||
- Fisher information matrix | ||
|
||
- Covariance matrix (empirical Fisher) | ||
|
||
|
||
|
||
Refer [Information matrices and generalization](https://arxiv.org/abs/1906.07774) by Valentin Thomas et al. (2019) for the definitions and the properties of these information matrices. | ||
|
||
|
||
|
||
Refer Section 6 of [Optimization Methods for Large-Scale Machine Learning](https://arxiv.org/abs/1606.04838) by L´eon Bottou et al. (2018) for a clear explanation of the second-order optimzation using these matrices as curvature. | ||
|
||
### Approximation Methods | ||
|
||
![](docs/overview.png) | ||
|
||
PyTorch-SSO calculates the curvature as a layer-wise block-diagonal matrix. | ||
|
||
You can specify the approximation method for the curvatures in each layer from the follwing. | ||
|
||
1. Full (No approximation) | ||
2. Diagonal approximation | ||
3. [Kronecker-Factored Approximate Curvature (K-FAC)](https://arxiv.org/abs/1503.05671) | ||
|
||
PyTorch-SSO currently supports the following layers (Modules) in PyTorch: | ||
|
||
| Layer (Module) | Full | Diagonal | K-FAC | | ||
| ------------------------- | ------------------ | ------------------ | ------------------ | | ||
| `torch.nn.Linear` | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | ||
| `torch.nn.Conv2d` | - | :heavy_check_mark: | :heavy_check_mark: | | ||
| `torch.nn.BatchNorm1d/2d` | - | :heavy_check_mark: | - | | ||
|
||
To apply PyTorch-SSO, | ||
- Set`requires_grad` to `True` for each Module. | ||
- The network you define cannot contain any other modules. | ||
- E.g., You need to use `torch.nn.functional.relu/max_pool2d` instead of `torch.nn.ReLU/MaxPool2d` to define a ConvNet. | ||
|
||
### Distributed Training | ||
|
||
PyTorch-SSO supports *data parallelism* and *MC samples parallelism* (for VI) | ||
for distributed training among multiple processes (GPUs). | ||
|
||
## Installation | ||
To build PyTorch-SSO run (on a Python 3 environment) | ||
```bash | ||
git clone [email protected]:cybertronai/pytorch-sso.git | ||
cd pytorch-sso | ||
python setup.py install | ||
``` | ||
|
||
To use the library | ||
```python | ||
import torchsso | ||
``` | ||
|
||
### Additional requirements | ||
|
||
PyTorch-SSO depends on [CuPy](https://cupy.chainer.org/) for fast GPU computation and [ChainerMN](https://github.com/chainer/chainermn) for communication. To use GPUs, you need to install the following requirements **before the installation of PyTorch-SSO**. | ||
|
||
| Running environment | Requirements | | ||
| ------------------- | ---------------------- | | ||
| single GPU | CuPy | | ||
| multiple GPUs | Cupy with NCCL, MPI4py | | ||
|
||
Refer [CuPy installation guide](https://docs-cupy.chainer.org/en/stable/install.html) and [ChainerMN installation guide](https://docs.chainer.org/en/stable/chainermn/installation/guide.html#chainermn-installation) for details. | ||
|
||
## Examples | ||
|
||
- [Image classification with a single process](https://github.com/cybertronai/pytorch-sso/tree/master/examples/classification) (MNIST, CIFAR-10) | ||
- [Image classification with multiple processes](https://github.com/cybertronai/pytorch-sso/tree/master/examples/distributed/classification) (CIFAR-10/100, ImageNet) | ||
|
||
## Authors | ||
|
||
Kazuki Osawa ([@kazukiosawa](https://github.com/kazukiosawa)) and Yaroslav Bulatov ([@yaroslavvb](https://github.com/yaroslavvb)) |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
To run training LeNet-5 for CIFAR-10 classification | ||
```bash | ||
python main.py --config <path/to/config> --download | ||
``` | ||
| optimizer | dataset | architecture | config file path | | ||
| --- | --- | --- | --- | | ||
| [Adam](https://arxiv.org/abs/1412.6980) | CIFAR-10 | LeNet-5 | [configs/cifar10/lenet_adam.json](https://github.com/cybertronai/pytorch-sso/blob/master/examples/classification/configs/cifar10/lenet_adam.json) | | ||
| [K-FAC](https://arxiv.org/abs/1503.05671)| CIFAR-10 | LeNet-5 | [configs/cifar10/lenet_kfac.json](https://github.com/cybertronai/pytorch-sso/blob/master/examples/classification/configs/cifar10/lenet_kfac.json) | | ||
| [Noisy K-FAC](https://arxiv.org/abs/1712.02390)| CIFAR-10 | LeNet-5 | [configs/cifar10/lenet_noisykfac.json](https://github.com/cybertronai/pytorch-sso/blob/master/examples/classification/configs/cifar10/lenet_noisykfac.json) | | ||
| [VOGN](https://arxiv.org/abs/1806.04854)| CIFAR-10 | LeNet-5 + BatchNorm | [configs/cifar10/lenet_vogn.json](https://github.com/cybertronai/pytorch-sso/blob/master/examples/classification/configs/cifar10/lenet_vogn.json) | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
{ | ||
"dataset": "CIFAR-10", | ||
"epochs": 100, | ||
"batch_size": 128, | ||
"val_batch_size": 128, | ||
"random_crop": false, | ||
"random_horizontal_flip": false, | ||
"normalizing_data": true, | ||
"arch_file": "models/lenet.py", | ||
"arch_name": "LeNet5", | ||
"optim_name": "Adam", | ||
"optim_args": { | ||
"lr": 1e-3, | ||
"betas": [0.9, 0.999], | ||
"weight_decay": 0.01 | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
{ | ||
"dataset": "CIFAR-10", | ||
"epochs": 50, | ||
"batch_size": 128, | ||
"val_batch_size": 5000, | ||
"random_crop": true, | ||
"random_horizontal_flip": true, | ||
"normalizing_data": true, | ||
"arch_file": "models/lenet.py", | ||
"arch_name": "LeNet5", | ||
"optim_name": "SecondOrderOptimizer", | ||
"optim_args": { | ||
"curv_type":"Fisher", | ||
"curv_shapes": { | ||
"Conv2d": "Kron", | ||
"Linear": "Kron", | ||
"BatchNorm1d": "Diag", | ||
"BatchNorm2d": "Diag" | ||
}, | ||
"lr": 1e-3, | ||
"momentum": 0.9, | ||
"momentum_type": "raw", | ||
"l2_reg": 1e-3, | ||
"acc_steps": 1 | ||
}, | ||
"curv_args": { | ||
"damping": 1e-3, | ||
"ema_decay": 0.999, | ||
"pi_type": "tracenorm" | ||
}, | ||
"fisher_args": { | ||
"approx_type": "mc", | ||
"num_mc": 1 | ||
}, | ||
"scheduler_name": "ExponentialLR", | ||
"scheduler_args": { | ||
"gamma": 0.9 | ||
}, | ||
"log_interval": 64, | ||
"no_cuda": false | ||
} |
41 changes: 41 additions & 0 deletions
41
examples/classification/configs/cifar10/lenet_noisykfac.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
{ | ||
"dataset": "CIFAR-10", | ||
"epochs": 15, | ||
"batch_size": 64, | ||
"val_batch_size": 128, | ||
"random_crop": false, | ||
"random_horizontal_flip": false, | ||
"normalizing_data": false, | ||
"arch_file": "models/lenet.py", | ||
"arch_name": "LeNet5", | ||
"optim_name": "VIOptimizer", | ||
"optim_args": { | ||
"curv_type": "Fisher", | ||
"curv_shapes": { | ||
"Conv2d": "Kron", | ||
"Linear": "Kron" | ||
}, | ||
"lr": 4e-3, | ||
"momentum": 0.9, | ||
"momentum_type": "preconditioned", | ||
"weight_decay": 0.1, | ||
"num_mc_samples": 4, | ||
"val_num_mc_samples": 0, | ||
"kl_weighting": 0.2, | ||
"prior_variance": 1 | ||
}, | ||
"curv_args": { | ||
"damping": 1e-4, | ||
"ema_decay": 0.333, | ||
"pi_type": "tracenorm" | ||
}, | ||
"fisher_args": { | ||
"approx_type": "mc", | ||
"num_mc": 1 | ||
}, | ||
"scheduler_name": "ExponentialLR", | ||
"scheduler_args": { | ||
"gamma": 0.9 | ||
}, | ||
"no_cuda": false | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
{ | ||
"dataset": "CIFAR-10", | ||
"epochs": 30, | ||
"batch_size": 128, | ||
"val_batch_size": 128, | ||
"random_crop": false, | ||
"random_horizontal_flip": false, | ||
"normalizing_data": true, | ||
"arch_file": "models/lenet.py", | ||
"arch_name": "LeNet5BatchNorm", | ||
"arch_args": { | ||
"affine": true | ||
}, | ||
"optim_name": "VIOptimizer", | ||
"optim_args": { | ||
"curv_type": "Cov", | ||
"curv_shapes": { | ||
"Conv2d": "Diag", | ||
"Linear": "Diag", | ||
"BatchNorm1d": "Diag", | ||
"BatchNorm2d": "Diag" | ||
}, | ||
"lr": 0.01, | ||
"grad_ema_decay": 0.1, | ||
"grad_ema_type": "raw", | ||
"num_mc_samples": 10, | ||
"val_num_mc_samples": 0, | ||
"kl_weighting": 1, | ||
"init_precision": 8e-3, | ||
"prior_variance": 1, | ||
"acc_steps": 1 | ||
}, | ||
"curv_args": { | ||
"damping": 0, | ||
"ema_decay": 0.001 | ||
}, | ||
"scheduler_name": "ExponentialLR", | ||
"scheduler_args": { | ||
"gamma": 0.9 | ||
}, | ||
"no_cuda": false | ||
} |
Oops, something went wrong.