
Commit

init commit
kazukiosawa committed Oct 10, 2019
1 parent c858530 commit 56b9725
Showing 84 changed files with 6,436 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -102,3 +102,6 @@ venv.bak/

# mypy
.mypy_cache/

# PyCharm
.idea
108 changes: 108 additions & 0 deletions README.md
@@ -0,0 +1,108 @@


# PyTorch-SSO

Scalable Second-Order methods in PyTorch.

- Open-source library for second-order optimization and Bayesian inference.

- An earlier iteration of this library ([chainerkfac](https://github.com/tyohei/chainerkfac)) holds the world record for large-batch training of ResNet-50 on ImageNet with [Kronecker-Factored Approximate Curvature (K-FAC)](https://arxiv.org/abs/1503.05671), scaling to a batch size of 131K.
  - Kazuki Osawa et al., “Large-Scale Distributed Second-Order Optimization Using Kronecker-Factored Approximate Curvature for Deep Convolutional Neural Networks”, **IEEE/CVF CVPR 2019**.
- [[paper](http://openaccess.thecvf.com/content_CVPR_2019/html/Osawa_Large-Scale_Distributed_Second-Order_Optimization_Using_Kronecker-Factored_Approximate_Curvature_for_Deep_CVPR_2019_paper.html)] [[poster](https://kazukiosawa.github.io/cvpr19_poster.pdf)]
- This library is the basis for natural-gradient methods for Bayesian inference (variational inference) on ImageNet.
  - Kazuki Osawa et al., “Practical Deep Learning with Bayesian Principles”, **NeurIPS 2019**.
- [[paper (preprint)](https://arxiv.org/abs/1906.02506)]

## Scalable Second-Order Optimization

### Optimizers

PyTorch-SSO provides the following optimizers; a usage sketch follows the list.

- Second-Order Optimization
- `torchsso.optim.SecondOrderOptimizer` [[source](https://github.com/cybertronai/pytorch-sso/blob/master/torchsso/optim/secondorder.py)]
- updates the parameters with the gradients pre-conditioned by the curvature of the loss function (`torch.nn.functional.cross_entropy`) for each `param_group`.
- Variational Inference (VI)
- `torchsso.optim.VIOptimizer` [[source](https://github.com/cybertronai/pytorch-sso/blob/master/torchsso/optim/vi.py)]
    - updates the posterior distribution (mean and covariance) of the parameters using the curvature for each `param_group`.
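
A minimal training-loop sketch with `SecondOrderOptimizer` is below. The constructor keyword arguments are assumptions inferred from the example configs under `examples/classification/configs` (the actual signature lives in `torchsso/optim/secondorder.py`), and `MyConvNet` and `loader` are hypothetical placeholders:

```python
import torch.nn.functional as F
import torchsso

model = MyConvNet()  # hypothetical model built only from supported Modules

# Keyword arguments mirror configs/cifar10/lenet_kfac.json; treat them as
# assumptions, not a documented signature.
optimizer = torchsso.optim.SecondOrderOptimizer(
    model,
    curv_type='Fisher',
    curv_shapes={'Conv2d': 'Kron', 'Linear': 'Kron'},
    lr=1e-3, momentum=0.9, momentum_type='raw', l2_reg=1e-3)

for x, t in loader:  # hypothetical DataLoader
    optimizer.zero_grad()
    loss = F.cross_entropy(model(x), t)  # the loss PyTorch-SSO preconditions
    loss.backward()
    optimizer.step()
```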

### Curvatures

You can choose the type of information matrix to use as the curvature from the following:

- Hessian [WIP]

- Fisher information matrix

- Covariance matrix (empirical Fisher)



Refer to [Information matrices and generalization](https://arxiv.org/abs/1906.07774) by Valentin Thomas et al. (2019) for the definitions and properties of these information matrices.



Refer to Section 6 of [Optimization Methods for Large-Scale Machine Learning](https://arxiv.org/abs/1606.04838) by Léon Bottou et al. (2018) for a clear explanation of second-order optimization using these matrices as curvature.

### Approximation Methods

![](docs/overview.png)

PyTorch-SSO calculates the curvature as a layer-wise block-diagonal matrix.

You can specify the approximation method for the curvature in each layer from the following (a numerical sketch of the K-FAC factorization follows the list):

1. Full (No approximation)
2. Diagonal approximation
3. [Kronecker-Factored Approximate Curvature (K-FAC)](https://arxiv.org/abs/1503.05671)
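
As an illustration of item 3, K-FAC approximates a layer's Fisher block by the Kronecker product of two small factors: the second moment of the layer's inputs and that of the gradients with respect to its outputs. A self-contained numerical sketch for a linear layer (not library code; the factor ordering depends on the vectorization convention):

```python
import torch

torch.manual_seed(0)
n, d_in, d_out = 32, 6, 4     # batch size, layer input/output dimensions
a = torch.randn(n, d_in)      # layer inputs (activations)
g = torch.randn(n, d_out)     # back-propagated gradients w.r.t. the outputs

A = a.t() @ a / n             # input factor           (d_in  x d_in)
G = g.t() @ g / n             # output-gradient factor (d_out x d_out)

# The full Kronecker product is never formed in practice: since
# kron(A, G)^-1 = kron(A^-1, G^-1), preconditioning only needs the
# two small inverses.
F_kfac = torch.kron(A, G)     # (d_in*d_out) x (d_in*d_out)
```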

PyTorch-SSO currently supports the following layers (Modules) in PyTorch:

| Layer (Module) | Full | Diagonal | K-FAC |
| ------------------------- | ------------------ | ------------------ | ------------------ |
| `torch.nn.Linear` | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| `torch.nn.Conv2d` | - | :heavy_check_mark: | :heavy_check_mark: |
| `torch.nn.BatchNorm1d/2d` | - | :heavy_check_mark: | - |

To apply PyTorch-SSO,
- Set `requires_grad` to `True` for each Module.
- The network you define cannot contain any Modules other than those listed above.
  - E.g., you need to use `torch.nn.functional.relu/max_pool2d` instead of `torch.nn.ReLU/MaxPool2d` to define a ConvNet (see the sketch below).
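
For example, a ConvNet that satisfies these constraints keeps only supported Modules as attributes and applies activations and pooling functionally. A sketch (not one of the bundled models):

```python
import torch.nn as nn
import torch.nn.functional as F

class SmallConvNet(nn.Module):
    """Only supported Modules (Conv2d, Linear) appear as submodules;
    ReLU and max-pooling are applied via torch.nn.functional."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=5)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5)
        self.fc = nn.Linear(32 * 5 * 5, num_classes)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)  # functional, not nn.MaxPool2d
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(x.size(0), -1)                   # flatten: 3x32x32 input -> 32*5*5
        return self.fc(x)
```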

### Distributed Training

PyTorch-SSO supports *data parallelism* and *MC-sample parallelism* (for VI)
for distributed training among multiple processes (GPUs).
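
A schematic of the data-parallel part: each process computes gradients on its own data shard, then the gradients are averaged across processes (for VI, per-process MC samples are combined analogously). This illustration uses mpi4py, which PyTorch-SSO relies on for communication (see Installation); it is not the library's actual code path:

```python
import torch
from mpi4py import MPI

comm = MPI.COMM_WORLD

def average_gradients(model):
    """Average each parameter's gradient across all MPI processes."""
    for p in model.parameters():
        if p.grad is None:
            continue
        buf = p.grad.detach().cpu().numpy()            # local gradient
        comm.Allreduce(MPI.IN_PLACE, buf, op=MPI.SUM)  # sum across ranks
        p.grad.copy_(torch.from_numpy(buf / comm.Get_size()))
```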

## Installation
To build PyTorch-SSO, run the following (in a Python 3 environment):
```bash
git clone [email protected]:cybertronai/pytorch-sso.git
cd pytorch-sso
python setup.py install
```

To use the library:
```python
import torchsso
```

### Additional requirements

PyTorch-SSO depends on [CuPy](https://cupy.chainer.org/) for fast GPU computation and [ChainerMN](https://github.com/chainer/chainermn) for communication. To use GPUs, you need to install the following requirements **before the installation of PyTorch-SSO**.

| Running environment | Requirements |
| ------------------- | ---------------------- |
| single GPU | CuPy |
| multiple GPUs       | CuPy with NCCL, mpi4py |

Refer to the [CuPy installation guide](https://docs-cupy.chainer.org/en/stable/install.html) and the [ChainerMN installation guide](https://docs.chainer.org/en/stable/chainermn/installation/guide.html#chainermn-installation) for details.
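
For example, on a machine with CUDA 10.0, the prerequisites might be installed as follows (the CuPy wheel name depends on your CUDA version, so check the guides above first):

```bash
# Wheel for CUDA 10.0; substitute the one matching your CUDA version
pip install cupy-cuda100
# Multi-GPU only: requires a working MPI installation
pip install mpi4py
```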

## Examples

- [Image classification with a single process](https://github.com/cybertronai/pytorch-sso/tree/master/examples/classification) (MNIST, CIFAR-10)
- [Image classification with multiple processes](https://github.com/cybertronai/pytorch-sso/tree/master/examples/distributed/classification) (CIFAR-10/100, ImageNet)

## Authors

Kazuki Osawa ([@kazukiosawa](https://github.com/kazukiosawa)) and Yaroslav Bulatov ([@yaroslavvb](https://github.com/yaroslavvb))
Binary file added docs/distributed_vi.png
Binary file added docs/overview.png
10 changes: 10 additions & 0 deletions examples/classification/README.md
@@ -0,0 +1,10 @@
To train LeNet-5 for CIFAR-10 classification, run:
```bash
python main.py --config <path/to/config> --download
```
| optimizer | dataset | architecture | config file path |
| --- | --- | --- | --- |
| [Adam](https://arxiv.org/abs/1412.6980) | CIFAR-10 | LeNet-5 | [configs/cifar10/lenet_adam.json](https://github.com/cybertronai/pytorch-sso/blob/master/examples/classification/configs/cifar10/lenet_adam.json) |
| [K-FAC](https://arxiv.org/abs/1503.05671)| CIFAR-10 | LeNet-5 | [configs/cifar10/lenet_kfac.json](https://github.com/cybertronai/pytorch-sso/blob/master/examples/classification/configs/cifar10/lenet_kfac.json) |
| [Noisy K-FAC](https://arxiv.org/abs/1712.02390)| CIFAR-10 | LeNet-5 | [configs/cifar10/lenet_noisykfac.json](https://github.com/cybertronai/pytorch-sso/blob/master/examples/classification/configs/cifar10/lenet_noisykfac.json) |
| [VOGN](https://arxiv.org/abs/1806.04854)| CIFAR-10 | LeNet-5 + BatchNorm | [configs/cifar10/lenet_vogn.json](https://github.com/cybertronai/pytorch-sso/blob/master/examples/classification/configs/cifar10/lenet_vogn.json) |
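
Each config pairs an `optim_name` with its `optim_args`. A minimal sketch of how such a file might be consumed; the `getattr` dispatch and constructor call are assumptions, and the bundled `main.py` may wire this up differently:

```python
import json
import torchsso

with open('configs/cifar10/lenet_kfac.json') as f:
    config = json.load(f)

# Hypothetical dispatch: resolve the class named by `optim_name` in
# torchsso.optim and pass `optim_args` through (signature assumed).
optim_class = getattr(torchsso.optim, config['optim_name'])
optimizer = optim_class(model, **config['optim_args'])  # `model` built elsewhere
```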
17 changes: 17 additions & 0 deletions examples/classification/configs/cifar10/lenet_adam.json
@@ -0,0 +1,17 @@
{
"dataset": "CIFAR-10",
"epochs": 100,
"batch_size": 128,
"val_batch_size": 128,
"random_crop": false,
"random_horizontal_flip": false,
"normalizing_data": true,
"arch_file": "models/lenet.py",
"arch_name": "LeNet5",
"optim_name": "Adam",
"optim_args": {
"lr": 1e-3,
"betas": [0.9, 0.999],
"weight_decay": 0.01
}
}
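
For reference, the `optim_args` above correspond to PyTorch's built-in Adam (assuming the example resolves `"Adam"` to `torch.optim.Adam`):

```python
import torch

# Same hyperparameters as "optim_args" above; `model` is built elsewhere.
optimizer = torch.optim.Adam(model.parameters(),
                             lr=1e-3, betas=(0.9, 0.999), weight_decay=0.01)
```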
41 changes: 41 additions & 0 deletions examples/classification/configs/cifar10/lenet_kfac.json
@@ -0,0 +1,41 @@
{
"dataset": "CIFAR-10",
"epochs": 50,
"batch_size": 128,
"val_batch_size": 5000,
"random_crop": true,
"random_horizontal_flip": true,
"normalizing_data": true,
"arch_file": "models/lenet.py",
"arch_name": "LeNet5",
"optim_name": "SecondOrderOptimizer",
"optim_args": {
"curv_type":"Fisher",
"curv_shapes": {
"Conv2d": "Kron",
"Linear": "Kron",
"BatchNorm1d": "Diag",
"BatchNorm2d": "Diag"
},
"lr": 1e-3,
"momentum": 0.9,
"momentum_type": "raw",
"l2_reg": 1e-3,
"acc_steps": 1
},
"curv_args": {
"damping": 1e-3,
"ema_decay": 0.999,
"pi_type": "tracenorm"
},
"fisher_args": {
"approx_type": "mc",
"num_mc": 1
},
"scheduler_name": "ExponentialLR",
"scheduler_args": {
"gamma": 0.9
},
"log_interval": 64,
"no_cuda": false
}
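
The `"fisher_args"` block above selects a Monte-Carlo Fisher estimate (`"approx_type": "mc"` with one sample): labels are drawn from the model's own predictive distribution, and squared log-likelihood gradients at those sampled labels estimate the Fisher. A standalone sketch of that estimator for one batch, not torchsso's implementation:

```python
import torch
import torch.nn.functional as F

def mc_fisher_diag(model, x, num_mc=1):
    """Monte-Carlo estimate of the diagonal Fisher for one batch."""
    params = list(model.parameters())
    fisher = [torch.zeros_like(p) for p in params]
    for _ in range(num_mc):
        logits = model(x)
        with torch.no_grad():
            # Sample labels from the model's own predictive distribution
            t = torch.multinomial(F.softmax(logits, dim=1), 1).squeeze(1)
        loss = F.cross_entropy(logits, t)
        grads = torch.autograd.grad(loss, params)
        for f_sum, g in zip(fisher, grads):
            f_sum += g ** 2 / num_mc
    return fisher
```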
41 changes: 41 additions & 0 deletions examples/classification/configs/cifar10/lenet_noisykfac.json
@@ -0,0 +1,41 @@
{
"dataset": "CIFAR-10",
"epochs": 15,
"batch_size": 64,
"val_batch_size": 128,
"random_crop": false,
"random_horizontal_flip": false,
"normalizing_data": false,
"arch_file": "models/lenet.py",
"arch_name": "LeNet5",
"optim_name": "VIOptimizer",
"optim_args": {
"curv_type": "Fisher",
"curv_shapes": {
"Conv2d": "Kron",
"Linear": "Kron"
},
"lr": 4e-3,
"momentum": 0.9,
"momentum_type": "preconditioned",
"weight_decay": 0.1,
"num_mc_samples": 4,
"val_num_mc_samples": 0,
"kl_weighting": 0.2,
"prior_variance": 1
},
"curv_args": {
"damping": 1e-4,
"ema_decay": 0.333,
"pi_type": "tracenorm"
},
"fisher_args": {
"approx_type": "mc",
"num_mc": 1
},
"scheduler_name": "ExponentialLR",
"scheduler_args": {
"gamma": 0.9
},
"no_cuda": false
}
42 changes: 42 additions & 0 deletions examples/classification/configs/cifar10/lenet_vogn.json
@@ -0,0 +1,42 @@
{
"dataset": "CIFAR-10",
"epochs": 30,
"batch_size": 128,
"val_batch_size": 128,
"random_crop": false,
"random_horizontal_flip": false,
"normalizing_data": true,
"arch_file": "models/lenet.py",
"arch_name": "LeNet5BatchNorm",
"arch_args": {
"affine": true
},
"optim_name": "VIOptimizer",
"optim_args": {
"curv_type": "Cov",
"curv_shapes": {
"Conv2d": "Diag",
"Linear": "Diag",
"BatchNorm1d": "Diag",
"BatchNorm2d": "Diag"
},
"lr": 0.01,
"grad_ema_decay": 0.1,
"grad_ema_type": "raw",
"num_mc_samples": 10,
"val_num_mc_samples": 0,
"kl_weighting": 1,
"init_precision": 8e-3,
"prior_variance": 1,
"acc_steps": 1
},
"curv_args": {
"damping": 0,
"ema_decay": 0.001
},
"scheduler_name": "ExponentialLR",
"scheduler_args": {
"gamma": 0.9
},
"no_cuda": false
}
