diff --git a/.gitignore b/.gitignore index 894a44c..376f22a 100644 --- a/.gitignore +++ b/.gitignore @@ -102,3 +102,6 @@ venv.bak/ # mypy .mypy_cache/ + +# PyCharm +.idea diff --git a/README.md b/README.md new file mode 100644 index 0000000..e395a26 --- /dev/null +++ b/README.md @@ -0,0 +1,108 @@ + + +# PyTorch-SSO + +Scalable Second-Order methods in PyTorch. + +- Open-source library for second-order optimization and Bayesian inference. + +- An earlier iteration of this library ([chainerkfac](https://github.com/tyohei/chainerkfac)) holds the world record for large-batch training of ResNet-50 on ImageNet by [Kronecker-Factored Approximate Curvature (K-FAC)](https://arxiv.org/abs/1503.05671), scaling to batch sizes of 131K. + - Kazuki Osawa et al, “Large-Scale Distributed Second-Order Optimization Using Kronecker-Factored Approximate Curvature for Deep Convolutional Neural Networks”, **IEEE/CVF CVPR 2019**. + - [[paper](http://openaccess.thecvf.com/content_CVPR_2019/html/Osawa_Large-Scale_Distributed_Second-Order_Optimization_Using_Kronecker-Factored_Approximate_Curvature_for_Deep_CVPR_2019_paper.html)] [[poster](https://kazukiosawa.github.io/cvpr19_poster.pdf)] +- This library is basis for the Natural Gradient for Bayesian inference (Variational Inference) on ImageNet. + - Kazuki Osawa et al, “Practical Deep Learning with Bayesian Principles”, **NeurIPS 2019**. + - [[paper (preprint)](https://arxiv.org/abs/1906.02506)] + +## Scalable Second-Order Optimization + +### Optimizers + +PyTorch-SSO provides the following optimizers. + +- Second-Order Optimization + - `torchsso.optim.SecondOrderOptimizer` [[source](https://github.com/cybertronai/pytorch-sso/blob/master/torchsso/optim/secondorder.py)] + - updates the parameters with the gradients pre-conditioned by the curvature of the loss function (`torch.nn.functional.cross_entropy`) for each `param_group`. +- Variational Inference (VI) + - `torchsso.optim.VIOptimizer` [[source](https://github.com/cybertronai/pytorch-sso/blob/master/torchsso/optim/vi.py)] + - updates the posterior distribution (mean, covariance) of the parameters by using the curvature for each `param_group`. + +### Curvatures + +You can specify a type of the information matrix to be used as the curvature from the following. + +- Hessian [WIP] + +- Fisher information matrix + +- Covariance matrix (empirical Fisher) + + + +Refer [Information matrices and generalization](https://arxiv.org/abs/1906.07774) by Valentin Thomas et al. (2019) for the definitions and the properties of these information matrices. + + + +Refer Section 6 of [Optimization Methods for Large-Scale Machine Learning](https://arxiv.org/abs/1606.04838) by L´eon Bottou et al. (2018) for a clear explanation of the second-order optimzation using these matrices as curvature. + +### Approximation Methods + +![](docs/overview.png) + +PyTorch-SSO calculates the curvature as a layer-wise block-diagonal matrix. + +You can specify the approximation method for the curvatures in each layer from the follwing. + +1. Full (No approximation) +2. Diagonal approximation +3. [Kronecker-Factored Approximate Curvature (K-FAC)](https://arxiv.org/abs/1503.05671) + +PyTorch-SSO currently supports the following layers (Modules) in PyTorch: + +| Layer (Module) | Full | Diagonal | K-FAC | +| ------------------------- | ------------------ | ------------------ | ------------------ | +| `torch.nn.Linear` | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | +| `torch.nn.Conv2d` | - | :heavy_check_mark: | :heavy_check_mark: | +| `torch.nn.BatchNorm1d/2d` | - | :heavy_check_mark: | - | + +To apply PyTorch-SSO, +- Set`requires_grad` to `True` for each Module. +- The network you define cannot contain any other modules. +- E.g., You need to use `torch.nn.functional.relu/max_pool2d` instead of `torch.nn.ReLU/MaxPool2d` to define a ConvNet. + +### Distributed Training + +PyTorch-SSO supports *data parallelism* and *MC samples parallelism* (for VI) +for distributed training among multiple processes (GPUs). + +## Installation +To build PyTorch-SSO run (on a Python 3 environment) +```bash +git clone git@github.com:cybertronai/pytorch-sso.git +cd pytorch-sso +python setup.py install +``` + +To use the library +```python +import torchsso +``` + +### Additional requirements + +PyTorch-SSO depends on [CuPy](https://cupy.chainer.org/) for fast GPU computation and [ChainerMN](https://github.com/chainer/chainermn) for communication. To use GPUs, you need to install the following requirements **before the installation of PyTorch-SSO**. + +| Running environment | Requirements | +| ------------------- | ---------------------- | +| single GPU | CuPy | +| multiple GPUs | Cupy with NCCL, MPI4py | + +Refer [CuPy installation guide](https://docs-cupy.chainer.org/en/stable/install.html) and [ChainerMN installation guide](https://docs.chainer.org/en/stable/chainermn/installation/guide.html#chainermn-installation) for details. + +## Examples + +- [Image classification with a single process](https://github.com/cybertronai/pytorch-sso/tree/master/examples/classification) (MNIST, CIFAR-10) +- [Image classification with multiple processes](https://github.com/cybertronai/pytorch-sso/tree/master/examples/distributed/classification) (CIFAR-10/100, ImageNet) + +## Authors + +Kazuki Osawa ([@kazukiosawa](https://github.com/kazukiosawa)) and Yaroslav Bulatov ([@yaroslavvb](https://github.com/yaroslavvb)) diff --git a/docs/distributed_vi.png b/docs/distributed_vi.png new file mode 100644 index 0000000..2d9fb89 Binary files /dev/null and b/docs/distributed_vi.png differ diff --git a/docs/overview.png b/docs/overview.png new file mode 100644 index 0000000..0601e3c Binary files /dev/null and b/docs/overview.png differ diff --git a/examples/classification/README.md b/examples/classification/README.md new file mode 100644 index 0000000..08c86ef --- /dev/null +++ b/examples/classification/README.md @@ -0,0 +1,10 @@ +To run training LeNet-5 for CIFAR-10 classification +```bash +python main.py --config --download +``` +| optimizer | dataset | architecture | config file path | +| --- | --- | --- | --- | +| [Adam](https://arxiv.org/abs/1412.6980) | CIFAR-10 | LeNet-5 | [configs/cifar10/lenet_adam.json](https://github.com/cybertronai/pytorch-sso/blob/master/examples/classification/configs/cifar10/lenet_adam.json) | +| [K-FAC](https://arxiv.org/abs/1503.05671)| CIFAR-10 | LeNet-5 | [configs/cifar10/lenet_kfac.json](https://github.com/cybertronai/pytorch-sso/blob/master/examples/classification/configs/cifar10/lenet_kfac.json) | +| [Noisy K-FAC](https://arxiv.org/abs/1712.02390)| CIFAR-10 | LeNet-5 | [configs/cifar10/lenet_noisykfac.json](https://github.com/cybertronai/pytorch-sso/blob/master/examples/classification/configs/cifar10/lenet_noisykfac.json) | +| [VOGN](https://arxiv.org/abs/1806.04854)| CIFAR-10 | LeNet-5 + BatchNorm | [configs/cifar10/lenet_vogn.json](https://github.com/cybertronai/pytorch-sso/blob/master/examples/classification/configs/cifar10/lenet_vogn.json) | diff --git a/examples/classification/configs/cifar10/lenet_adam.json b/examples/classification/configs/cifar10/lenet_adam.json new file mode 100644 index 0000000..553dff3 --- /dev/null +++ b/examples/classification/configs/cifar10/lenet_adam.json @@ -0,0 +1,17 @@ +{ + "dataset": "CIFAR-10", + "epochs": 100, + "batch_size": 128, + "val_batch_size": 128, + "random_crop": false, + "random_horizontal_flip": false, + "normalizing_data": true, + "arch_file": "models/lenet.py", + "arch_name": "LeNet5", + "optim_name": "Adam", + "optim_args": { + "lr": 1e-3, + "betas": [0.9, 0.999], + "weight_decay": 0.01 + } +} \ No newline at end of file diff --git a/examples/classification/configs/cifar10/lenet_kfac.json b/examples/classification/configs/cifar10/lenet_kfac.json new file mode 100644 index 0000000..f4f08a0 --- /dev/null +++ b/examples/classification/configs/cifar10/lenet_kfac.json @@ -0,0 +1,41 @@ +{ + "dataset": "CIFAR-10", + "epochs": 50, + "batch_size": 128, + "val_batch_size": 5000, + "random_crop": true, + "random_horizontal_flip": true, + "normalizing_data": true, + "arch_file": "models/lenet.py", + "arch_name": "LeNet5", + "optim_name": "SecondOrderOptimizer", + "optim_args": { + "curv_type":"Fisher", + "curv_shapes": { + "Conv2d": "Kron", + "Linear": "Kron", + "BatchNorm1d": "Diag", + "BatchNorm2d": "Diag" + }, + "lr": 1e-3, + "momentum": 0.9, + "momentum_type": "raw", + "l2_reg": 1e-3, + "acc_steps": 1 + }, + "curv_args": { + "damping": 1e-3, + "ema_decay": 0.999, + "pi_type": "tracenorm" + }, + "fisher_args": { + "approx_type": "mc", + "num_mc": 1 + }, + "scheduler_name": "ExponentialLR", + "scheduler_args": { + "gamma": 0.9 + }, + "log_interval": 64, + "no_cuda": false +} \ No newline at end of file diff --git a/examples/classification/configs/cifar10/lenet_noisykfac.json b/examples/classification/configs/cifar10/lenet_noisykfac.json new file mode 100644 index 0000000..7429be9 --- /dev/null +++ b/examples/classification/configs/cifar10/lenet_noisykfac.json @@ -0,0 +1,41 @@ +{ + "dataset": "CIFAR-10", + "epochs": 15, + "batch_size": 64, + "val_batch_size": 128, + "random_crop": false, + "random_horizontal_flip": false, + "normalizing_data": false, + "arch_file": "models/lenet.py", + "arch_name": "LeNet5", + "optim_name": "VIOptimizer", + "optim_args": { + "curv_type": "Fisher", + "curv_shapes": { + "Conv2d": "Kron", + "Linear": "Kron" + }, + "lr": 4e-3, + "momentum": 0.9, + "momentum_type": "preconditioned", + "weight_decay": 0.1, + "num_mc_samples": 4, + "val_num_mc_samples": 0, + "kl_weighting": 0.2, + "prior_variance": 1 + }, + "curv_args": { + "damping": 1e-4, + "ema_decay": 0.333, + "pi_type": "tracenorm" + }, + "fisher_args": { + "approx_type": "mc", + "num_mc": 1 + }, + "scheduler_name": "ExponentialLR", + "scheduler_args": { + "gamma": 0.9 + }, + "no_cuda": false +} diff --git a/examples/classification/configs/cifar10/lenet_vogn.json b/examples/classification/configs/cifar10/lenet_vogn.json new file mode 100644 index 0000000..f78ef16 --- /dev/null +++ b/examples/classification/configs/cifar10/lenet_vogn.json @@ -0,0 +1,42 @@ +{ + "dataset": "CIFAR-10", + "epochs": 30, + "batch_size": 128, + "val_batch_size": 128, + "random_crop": false, + "random_horizontal_flip": false, + "normalizing_data": true, + "arch_file": "models/lenet.py", + "arch_name": "LeNet5BatchNorm", + "arch_args": { + "affine": true + }, + "optim_name": "VIOptimizer", + "optim_args": { + "curv_type": "Cov", + "curv_shapes": { + "Conv2d": "Diag", + "Linear": "Diag", + "BatchNorm1d": "Diag", + "BatchNorm2d": "Diag" + }, + "lr": 0.01, + "grad_ema_decay": 0.1, + "grad_ema_type": "raw", + "num_mc_samples": 10, + "val_num_mc_samples": 0, + "kl_weighting": 1, + "init_precision": 8e-3, + "prior_variance": 1, + "acc_steps": 1 + }, + "curv_args": { + "damping": 0, + "ema_decay": 0.001 + }, + "scheduler_name": "ExponentialLR", + "scheduler_args": { + "gamma": 0.9 + }, + "no_cuda": false +} diff --git a/examples/classification/main.py b/examples/classification/main.py new file mode 100644 index 0000000..b01e6fc --- /dev/null +++ b/examples/classification/main.py @@ -0,0 +1,385 @@ +import os +import argparse +from importlib import import_module +import shutil +import json + +import torch +import torch.nn.functional as F +from torchvision import datasets, transforms, models +import torchsso +from torchsso.optim import SecondOrderOptimizer, VIOptimizer +from torchsso.utils import Logger + +DATASET_CIFAR10 = 'CIFAR-10' +DATASET_CIFAR100 = 'CIFAR-100' +DATASET_MNIST = 'MNIST' + + +def main(): + parser = argparse.ArgumentParser() + # Data + parser.add_argument('--dataset', type=str, + choices=[DATASET_CIFAR10, DATASET_CIFAR100, DATASET_MNIST], default=DATASET_CIFAR10, + help='name of dataset') + parser.add_argument('--root', type=str, default='./data', + help='root of dataset') + parser.add_argument('--epochs', type=int, default=10, + help='number of epochs to train') + parser.add_argument('--batch_size', type=int, default=128, + help='input batch size for training') + parser.add_argument('--val_batch_size', type=int, default=128, + help='input batch size for valing') + parser.add_argument('--normalizing_data', action='store_true', + help='[data pre processing] normalizing data') + parser.add_argument('--random_crop', action='store_true', + help='[data augmentation] random crop') + parser.add_argument('--random_horizontal_flip', action='store_true', + help='[data augmentation] random horizontal flip') + # Training Settings + parser.add_argument('--arch_file', type=str, default=None, + help='name of file which defines the architecture') + parser.add_argument('--arch_name', type=str, default='LeNet5', + help='name of the architecture') + parser.add_argument('--arch_args', type=json.loads, default=None, + help='[JSON] arguments for the architecture') + parser.add_argument('--optim_name', type=str, default=SecondOrderOptimizer.__name__, + help='name of the optimizer') + parser.add_argument('--optim_args', type=json.loads, default=None, + help='[JSON] arguments for the optimizer') + parser.add_argument('--curv_args', type=json.loads, default=dict(), + help='[JSON] arguments for the curvature') + parser.add_argument('--fisher_args', type=json.loads, default=dict(), + help='[JSON] arguments for the fisher') + parser.add_argument('--scheduler_name', type=str, default=None, + help='name of the learning rate scheduler') + parser.add_argument('--scheduler_args', type=json.loads, default=None, + help='[JSON] arguments for the scheduler') + # Options + parser.add_argument('--download', action='store_true', default=False, + help='if True, downloads the dataset (CIFAR-10 or 100) from the internet') + parser.add_argument('--create_graph', action='store_true', default=False, + help='create graph of the derivative') + parser.add_argument('--no_cuda', action='store_true', default=False, + help='disables CUDA training') + parser.add_argument('--seed', type=int, default=1, + help='random seed') + parser.add_argument('--num_workers', type=int, default=0, + help='number of sub processes for data loading') + parser.add_argument('--log_interval', type=int, default=50, + help='how many batches to wait before logging training status') + parser.add_argument('--log_file_name', type=str, default='log', + help='log file name') + parser.add_argument('--checkpoint_interval', type=int, default=50, + help='how many epochs to wait before logging training status') + parser.add_argument('--resume', type=str, default=None, + help='checkpoint path for resume training') + parser.add_argument('--out', type=str, default='result', + help='dir to save output files') + parser.add_argument('--config', default='configs/cifar10/lenet_kfac.json', + help='config file path') + + args = parser.parse_args() + dict_args = vars(args) + + # Load config file + if args.config is not None: + with open(args.config) as f: + config = json.load(f) + dict_args.update(config) + + # Set device + use_cuda = not args.no_cuda and torch.cuda.is_available() + device = torch.device('cuda' if use_cuda else 'cpu') + + # Set random seed + torch.manual_seed(args.seed) + + # Setup data augmentation & data pre processing + train_transforms, val_transforms = [], [] + if args.random_crop: + train_transforms.append(transforms.RandomCrop(32, padding=4)) + + if args.random_horizontal_flip: + train_transforms.append(transforms.RandomHorizontalFlip()) + + train_transforms.append(transforms.ToTensor()) + val_transforms.append(transforms.ToTensor()) + + if args.normalizing_data: + normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) + train_transforms.append(normalize) + val_transforms.append(normalize) + + train_transform = transforms.Compose(train_transforms) + val_transform = transforms.Compose(val_transforms) + + # Setup data loader + if args.dataset == DATASET_CIFAR10: + # CIFAR-10 + num_classes = 10 + dataset_class = datasets.CIFAR10 + elif args.dataset == DATASET_CIFAR100: + # CIFAR-100 + num_classes = 100 + dataset_class = datasets.CIFAR100 + elif args.dataset == DATASET_MNIST: + num_classes = 10 + dataset_class = datasets.MNIST + else: + assert False, f'unknown dataset {args.dataset}' + + train_dataset = dataset_class( + root=args.root, train=True, download=args.download, transform=train_transform) + val_dataset = dataset_class( + root=args.root, train=False, download=args.download, transform=val_transform) + + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) + val_loader = torch.utils.data.DataLoader( + val_dataset, batch_size=args.val_batch_size, shuffle=False, num_workers=args.num_workers) + + # Setup model + if args.arch_file is None: + arch_class = getattr(models, args.arch_name) + else: + _, ext = os.path.splitext(args.arch_file) + dirname = os.path.dirname(args.arch_file) + + if dirname == '': + module_path = args.arch_file.replace(ext, '') + elif dirname == '.': + module_path = os.path.basename(args.arch_file).replace(ext, '') + else: + module_path = '.'.join(os.path.split(args.arch_file)).replace(ext, '') + + module = import_module(module_path) + arch_class = getattr(module, args.arch_name) + + arch_kwargs = {} if args.arch_args is None else args.arch_args + arch_kwargs['num_classes'] = num_classes + + model = arch_class(**arch_kwargs) + setattr(model, 'num_classes', num_classes) + model = model.to(device) + + optim_kwargs = {} if args.optim_args is None else args.optim_args + + # Setup optimizer + if args.optim_name == SecondOrderOptimizer.__name__: + optimizer = SecondOrderOptimizer(model, **optim_kwargs, curv_kwargs=args.curv_args) + elif args.optim_name == VIOptimizer.__name__: + optimizer = VIOptimizer(model, dataset_size=len(train_loader.dataset), seed=args.seed, + **optim_kwargs, curv_kwargs=args.curv_args) + else: + optim_class = getattr(torch.optim, args.optim_name) + optimizer = optim_class(model.parameters(), **optim_kwargs) + + # Setup lr scheduler + if args.scheduler_name is None: + scheduler = None + else: + scheduler_class = getattr(torchsso.optim.lr_scheduler, args.scheduler_name, None) + if scheduler_class is None: + scheduler_class = getattr(torch.optim.lr_scheduler, args.scheduler_name) + scheduler_kwargs = {} if args.scheduler_args is None else args.scheduler_args + scheduler = scheduler_class(optimizer, **scheduler_kwargs) + + start_epoch = 1 + + # Load checkpoint + if args.resume is not None: + print('==> Resuming from checkpoint..') + assert os.path.exists(args.resume), 'Error: no checkpoint file found' + checkpoint = torch.load(args.resume) + model.load_state_dict(checkpoint['model']) + start_epoch = checkpoint['epoch'] + + # All config + print('===========================') + for key, val in vars(args).items(): + if key == 'dataset': + print('{}: {}'.format(key, val)) + print('train data size: {}'.format(len(train_loader.dataset))) + print('val data size: {}'.format(len(val_loader.dataset))) + else: + print('{}: {}'.format(key, val)) + print('===========================') + + # Copy this file & config to args.out + if not os.path.isdir(args.out): + os.makedirs(args.out) + shutil.copy(os.path.realpath(__file__), args.out) + + if args.config is not None: + shutil.copy(args.config, args.out) + if args.arch_file is not None: + shutil.copy(args.arch_file, args.out) + + # Setup logger + logger = Logger(args.out, args.log_file_name) + logger.start() + + # Run training + for epoch in range(start_epoch, args.epochs + 1): + + # train + accuracy, loss, confidence = train(model, device, train_loader, optimizer, scheduler, epoch, args, logger) + + # val + val_accuracy, val_loss = validate(model, device, val_loader, optimizer) + + # save log + iteration = epoch * len(train_loader) + log = {'epoch': epoch, 'iteration': iteration, + 'accuracy': accuracy, 'loss': loss, 'confidence': confidence, + 'val_accuracy': val_accuracy, 'val_loss': val_loss, + 'lr': optimizer.param_groups[0]['lr'], + 'momentum': optimizer.param_groups[0].get('momentum', 0)} + logger.write(log) + + # save checkpoint + if epoch % args.checkpoint_interval == 0 or epoch == args.epochs: + path = os.path.join(args.out, 'epoch{}.ckpt'.format(epoch)) + data = { + 'model': model.state_dict(), + 'optimizer': optimizer.state_dict(), + 'epoch': epoch + } + torch.save(data, path) + + +def train(model, device, train_loader, optimizer, scheduler, epoch, args, logger): + + def scheduler_type(_scheduler): + if _scheduler is None: + return 'none' + return getattr(_scheduler, 'scheduler_type', 'epoch') + + if scheduler_type(scheduler) == 'epoch': + scheduler.step(epoch - 1) + + model.train() + + total_correct = 0 + loss = None + confidence = {'top1': 0, 'top1_true': 0, 'top1_false': 0, 'true': 0, 'false': 0} + total_data_size = 0 + epoch_size = len(train_loader.dataset) + num_iters_in_epoch = len(train_loader) + base_num_iter = (epoch - 1) * num_iters_in_epoch + + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(device), target.to(device) + + if scheduler_type(scheduler) == 'iter': + scheduler.step() + + for name, param in model.named_parameters(): + attr = 'p_pre_{}'.format(name) + setattr(model, attr, param.detach().clone()) + + # update params + def closure(): + optimizer.zero_grad() + output = model(data) + loss = F.cross_entropy(output, target) + loss.backward(create_graph=args.create_graph) + + return loss, output + + if isinstance(optimizer, SecondOrderOptimizer) and optimizer.curv_type == 'Fisher': + closure = torchsso.get_closure_for_fisher(optimizer, model, data, target, **args.fisher_args) + + loss, output = optimizer.step(closure=closure) + + pred = output.argmax(dim=1, keepdim=True) + correct = pred.eq(target.view_as(pred)).sum().item() + + loss = loss.item() + total_correct += correct + + prob = F.softmax(output, dim=1) + for p, idx in zip(prob, target): + confidence['top1'] += torch.max(p).item() + top1 = torch.argmax(p).item() + if top1 == idx: + confidence['top1_true'] += p[top1].item() + else: + confidence['top1_false'] += p[top1].item() + confidence['true'] += p[idx].item() + confidence['false'] += (1 - p[idx].item()) + + iteration = base_num_iter + batch_idx + 1 + total_data_size += len(data) + + if batch_idx % args.log_interval == 0: + accuracy = 100. * total_correct / total_data_size + elapsed_time = logger.elapsed_time + print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}, ' + 'Accuracy: {:.0f}/{} ({:.2f}%), ' + 'Elapsed Time: {:.1f}s'.format( + epoch, total_data_size, epoch_size, 100. * (batch_idx + 1) / num_iters_in_epoch, + loss, total_correct, total_data_size, accuracy, elapsed_time)) + + # save log + lr = optimizer.param_groups[0]['lr'] + log = {'epoch': epoch, 'iteration': iteration, 'elapsed_time': elapsed_time, + 'accuracy': accuracy, 'loss': loss, 'lr': lr} + + for name, param in model.named_parameters(): + attr = 'p_pre_{}'.format(name) + p_pre = getattr(model, attr) + p_norm = param.norm().item() + p_shape = list(param.size()) + p_pre_norm = p_pre.norm().item() + g_norm = param.grad.norm().item() + upd_norm = param.sub(p_pre).norm().item() + noise_scale = getattr(param, 'noise_scale', 0) + + p_log = {'p_shape': p_shape, 'p_norm': p_norm, 'p_pre_norm': p_pre_norm, + 'g_norm': g_norm, 'upd_norm': upd_norm, 'noise_scale': noise_scale} + log[name] = p_log + + logger.write(log) + + accuracy = 100. * total_correct / epoch_size + confidence['top1'] /= epoch_size + confidence['top1_true'] /= total_correct + confidence['top1_false'] /= (epoch_size - total_correct) + confidence['true'] /= epoch_size + confidence['false'] /= (epoch_size * (model.num_classes - 1)) + + return accuracy, loss, confidence + + +def validate(model, device, val_loader, optimizer): + model.eval() + val_loss = 0 + correct = 0 + + with torch.no_grad(): + for data, target in val_loader: + + data, target = data.to(device), target.to(device) + + if isinstance(optimizer, VIOptimizer): + output = optimizer.prediction(data) + else: + output = model(data) + + val_loss += F.cross_entropy(output, target, reduction='sum').item() # sum up batch loss + pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability + correct += pred.eq(target.view_as(pred)).sum().item() + + val_loss /= len(val_loader.dataset) + val_accuracy = 100. * correct / len(val_loader.dataset) + + print('\nEval: Average loss: {:.4f}, Accuracy: {:.0f}/{} ({:.2f}%)\n'.format( + val_loss, correct, len(val_loader.dataset), val_accuracy)) + + return val_accuracy, val_loss + + +if __name__ == '__main__': + main() diff --git a/examples/classification/models/__init__.py b/examples/classification/models/__init__.py new file mode 100644 index 0000000..4f1b169 --- /dev/null +++ b/examples/classification/models/__init__.py @@ -0,0 +1,5 @@ +from .vgg import * +from .lenet import * +from .resnet import * +from .alexnet import * +from .mlp import * diff --git a/examples/classification/models/alexnet.py b/examples/classification/models/alexnet.py new file mode 100644 index 0000000..f3b1daf --- /dev/null +++ b/examples/classification/models/alexnet.py @@ -0,0 +1,87 @@ +'''AlexNet for CIFAR10. FC layers are removed. Paddings are adjusted. +Without BN, the start learning rate should be 0.01 +(c) YANG, Wei +''' +import torch.nn as nn +import torch.nn.functional as F +from torchsso.utils.accumulator import TensorAccumulator + + +__all__ = ['alexnet', 'alexnet_mcdropout'] + + +class AlexNet(nn.Module): + + def __init__(self, num_classes=10): + super().__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5) + self.conv2 = nn.Conv2d(64, 192, kernel_size=5, padding=2) + self.conv3 = nn.Conv2d(192, 384, kernel_size=3, padding=1) + self.conv4 = nn.Conv2d(384, 256, kernel_size=3, padding=1) + self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1) + self.fc = nn.Linear(256, num_classes) + + def forward(self, x): + x = F.relu(self.conv1(x)) + x = F.max_pool2d(x, kernel_size=2, stride=2) + x = F.relu(self.conv2(x)) + x = F.max_pool2d(x, kernel_size=2, stride=2) + x = F.relu(self.conv3(x)) + x = F.relu(self.conv4(x)) + x = F.relu(self.conv5(x)) + x = F.max_pool2d(x, kernel_size=2, stride=2) + x = x.view(x.size(0), -1) + x = self.fc(x) + return x + + +class AlexNetMCDropout(AlexNet): + + mc_dropout = True + + def __init__(self, num_classes=10, dropout_ratio=0.5, val_mc=10): + super(AlexNetMCDropout, self).__init__(num_classes) + self.dropout_ratio = dropout_ratio + self.val_mc = val_mc + + def forward(self, x): + dropout_ratio = self.dropout_ratio + x = F.relu(F.dropout(self.conv1(x), p=dropout_ratio)) + x = F.max_pool2d(x, kernel_size=2, stride=2) + x = F.relu(F.dropout(self.conv2(x), p=dropout_ratio)) + x = F.max_pool2d(x, kernel_size=2, stride=2) + x = F.relu(F.dropout(self.conv3(x), p=dropout_ratio)) + x = F.relu(F.dropout(self.conv4(x), p=dropout_ratio)) + x = F.relu(F.dropout(self.conv5(x), p=dropout_ratio)) + x = F.max_pool2d(x, kernel_size=2, stride=2) + x = x.view(x.size(0), -1) + x = self.fc(x) + return x + + def prediction(self, x): + + acc_prob = TensorAccumulator() + m = self.val_mc + + for _ in range(m): + output = self.forward(x) + prob = F.softmax(output, dim=1) + acc_prob.update(prob, scale=1/m) + + prob = acc_prob.get() + + return prob + + +def alexnet(**kwargs): + r"""AlexNet model architecture from the + `"One weird trick..." `_ paper. + """ + model = AlexNet(**kwargs) + return model + + +def alexnet_mcdropout(**kwargs): + model = AlexNetMCDropout(**kwargs) + return model + diff --git a/examples/classification/models/lenet.py b/examples/classification/models/lenet.py new file mode 100644 index 0000000..e0fd600 --- /dev/null +++ b/examples/classification/models/lenet.py @@ -0,0 +1,84 @@ +import torch.nn as nn +import torch.nn.functional as F +from torchsso.utils.accumulator import TensorAccumulator + + +class LeNet5(nn.Module): + + def __init__(self, num_classes=10): + super().__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, num_classes) + + def forward(self, x): + out = F.relu(self.conv1(x)) + out = F.max_pool2d(out, 2) + out = F.relu(self.conv2(out)) + out = F.max_pool2d(out, 2) + out = out.view(out.size(0), -1) + out = F.relu(self.fc1(out)) + out = F.relu(self.fc2(out)) + out = self.fc3(out) + return out + + +class LeNet5MCDropout(LeNet5): + + def __init__(self, num_classes=10, dropout_ratio=0.1, val_mc=10): + super(LeNet5MCDropout, self).__init__(num_classes=num_classes) + self.dropout_ratio = dropout_ratio + self.val_mc = val_mc + + def forward(self, x): + p = self.dropout_ratio + out = F.relu(F.dropout(self.conv1(x), p)) + out = F.max_pool2d(out, 2) + out = F.relu(F.dropout(self.conv2(out), p)) + out = F.max_pool2d(out, 2) + out = out.view(out.size(0), -1) + out = F.relu(F.dropout(self.fc1(out), p)) + out = F.relu(F.dropout(self.fc2(out), p)) + out = F.dropout(self.fc2(out), p) + return out + + def mc_prediction(self, x): + + acc_prob = TensorAccumulator() + m = self.val_mc + + for _ in range(m): + output = self.forward(x) + prob = F.softmax(output, dim=1) + acc_prob.update(prob, scale=1/m) + + prob = acc_prob.get() + + return prob + + +class LeNet5BatchNorm(nn.Module): + def __init__(self, num_classes=10, affine=False): + super().__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.bn1 = nn.BatchNorm2d(6, affine=affine) + self.conv2 = nn.Conv2d(6, 16, 5) + self.bn2 = nn.BatchNorm2d(16, affine=affine) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.bn3 = nn.BatchNorm1d(120, affine=affine) + self.fc2 = nn.Linear(120, 84) + self.bn4 = nn.BatchNorm1d(84, affine=affine) + self.fc3 = nn.Linear(84, num_classes) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = F.max_pool2d(out, 2) + out = F.relu(self.bn2(self.conv2(out))) + out = F.max_pool2d(out, 2) + out = out.view(out.size(0), -1) + out = F.relu(self.bn3(self.fc1(out))) + out = F.relu(self.bn4(self.fc2(out))) + out = self.fc3(out) + return out diff --git a/examples/classification/models/mlp.py b/examples/classification/models/mlp.py new file mode 100644 index 0000000..1ca64da --- /dev/null +++ b/examples/classification/models/mlp.py @@ -0,0 +1,29 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +__all__ = ['mlp'] + + +class MLP(nn.Module): + def __init__(self, num_classes=10): + super().__init__() + n_hid = 1000 + n_out = 10 + self.l1 = nn.Linear(28*28, n_hid) + self.l2 = nn.Linear(n_hid, n_hid) + self.l3 = nn.Linear(n_hid, n_out) + + def forward(self, x: torch.Tensor): + x = x.view([-1, 28*28]) + x = F.relu(self.l1(x)) + x = F.relu(self.l2(x)) + x = self.l3(x) + return x + + +def mlp(**kwargs): + model = MLP(**kwargs) + return model + diff --git a/examples/classification/models/resnet.py b/examples/classification/models/resnet.py new file mode 100644 index 0000000..8262eae --- /dev/null +++ b/examples/classification/models/resnet.py @@ -0,0 +1,121 @@ +'''ResNet in PyTorch. + +For Pre-activation ResNet, see 'preact_resnet.py'. + +Reference: +[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun + Deep Residual Learning for Image Recognition. arXiv:1512.03385 +''' +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, in_planes, planes, stride=1): + super(BasicBlock, self).__init__() + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes,track_running_stats=False) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes,track_running_stats=False) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion*planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion*planes) + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, in_planes, planes, stride=1): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes,track_running_stats=False) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes,track_running_stats=False) + self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(self.expansion*planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion*planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion*planes) + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = F.relu(self.bn2(self.conv2(out))) + out = self.bn3(self.conv3(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + + +class ResNet(nn.Module): + def __init__(self, block, num_blocks, num_classes=10): + super(ResNet, self).__init__() + self.in_planes = 64 + + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) + self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) + self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) + self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) + self.linear = nn.Linear(512*block.expansion, num_classes) + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1]*(num_blocks-1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + out = F.avg_pool2d(out, 4) + out = out.view(out.size(0), -1) + out = self.linear(out) + return out + + +def ResNet18(num_classes=10): + return ResNet(BasicBlock, [2,2,2,2], num_classes) + +def ResNet34(num_classes=10): + return ResNet(BasicBlock, [3,4,6,3], num_classes) + +def ResNet50(num_classes=10): + return ResNet(Bottleneck, [3,4,6,3], num_classes) + +def ResNet101(num_classes=10): + return ResNet(Bottleneck, [3,4,23,3], num_classes) + +def ResNet152(num_classes=10): + return ResNet(Bottleneck, [3,8,36,3], num_classes) + + +def test(): + net = ResNet18() + y = net(torch.randn(1,3,32,32)) + print(y.size()) + +# test() diff --git a/examples/classification/models/vgg.py b/examples/classification/models/vgg.py new file mode 100644 index 0000000..15271d4 --- /dev/null +++ b/examples/classification/models/vgg.py @@ -0,0 +1,47 @@ +'''VGG11/13/16/19 in Pytorch.''' +import torch +import torch.nn as nn + + +cfg = { + 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], + 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], +} + + +class VGG(nn.Module): + def __init__(self, vgg_name='VGG19'): + super(VGG, self).__init__() + self.features = self._make_layers(cfg[vgg_name]) + self.classifier = nn.Linear(512, 10) + + def forward(self, x): + out = self.features(x) + out = out.view(out.size(0), -1) + out = self.classifier(out) + return out + + def _make_layers(self, cfg): + layers = [] + in_channels = 3 + for x in cfg: + if x == 'M': + layers += [nn.MaxPool2d(kernel_size=2, stride=2)] + else: + layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1), + nn.BatchNorm2d(x), + nn.ReLU(inplace=True)] + in_channels = x + layers += [nn.AvgPool2d(kernel_size=1, stride=1)] + return nn.Sequential(*layers) + + +def test(): + net = VGG('VGG11') + x = torch.randn(2,3,32,32) + y = net(x) + print(y.size()) + +# test() diff --git a/examples/distributed/README.md b/examples/distributed/README.md new file mode 100644 index 0000000..6b8fdde --- /dev/null +++ b/examples/distributed/README.md @@ -0,0 +1,7 @@ +# Distributed training +PyTorch-SSO supports data parallelism and MC samples parallelism (for VI) for distributed training among multiple processes (GPUs). + +![](../../docs/distributed_vi.png) + +## Applications +- [Image classification](https://github.com/cybertronai/pytorch-sso/blob/master/examples/distributed/classification) diff --git a/examples/distributed/classification/README.md b/examples/distributed/classification/README.md new file mode 100644 index 0000000..77a7fc1 --- /dev/null +++ b/examples/distributed/classification/README.md @@ -0,0 +1,37 @@ +To run training on CIFAR-10/100 with multiple GPUs + +```bash +mpirun -np python main.py \ +--dist_init_method \ +--download \ +--config +``` + +To run training on ImageNet with multiple GPUs + +```bash +mpirun -np python main.py \ +--train_root \ +--val_root \ +--dist_init_method \ +--config +``` +For `init_method`, refer the [PyTorch tutorial for distirubted applications](https://pytorch.org/tutorials/intermediate/dist_tuto.html). + +| optimizer | dataset | architecture | GPUs | config file path | +| --- | --- | --- | --- | --- | +| [Adam](https://arxiv.org/abs/1412.6980) | ImageNet | ResNet-18 | 4 | [configs/imagenet/resnet18_adam_bs4k_4gpu.json](https://github.com/cybertronai/pytorch-sso/blob/master/examples/distributed/classification/configs/imagenet/resnet18_adam_bs4k_4gpu.json) | +| [Adam](https://arxiv.org/abs/1412.6980) | ImageNet | ResNet-18 | 128 | [configs/imagenet/resnet18_adam_bs4k_128gpu.json](https://github.com/cybertronai/pytorch-sso/blob/master/examples/distributed/classification/configs/imagenet/resnet18_adam_bs4k_128gpu.json) | +| [K-FAC](https://arxiv.org/abs/1503.05671) | ImageNet | ResNet-18 | 4 | [configs/imagenet/resnet18_kfac_bs4k_4gpu.json](https://github.com/cybertronai/pytorch-sso/blob/master/examples/distributed/classification/configs/imagenet/resnet18_kfac_bs4k_4gpu.json) | +| [K-FAC](https://arxiv.org/abs/1503.05671)| ImageNet | ResNet-18 | 128 | [configs/imagenet/resnet18_kfac_bs4k_128gpu.json](https://github.com/cybertronai/pytorch-sso/blob/master/examples/distributed/classification/configs/imagenet/resnet18_kfac_bs4k_128gpu.json) | +| [Noisy K-FAC](https://arxiv.org/abs/1712.02390)| ImageNet | ResNet-18 | 128 | [configs/imagenet/resnet18_noisykfac_bs4k_128gpu.json](https://github.com/cybertronai/pytorch-sso/blob/master/examples/distributed/classification/configs/imagenet/resnet18_noisykfac_bs4k_128gpu.json) | +| [VOGN](https://arxiv.org/abs/1806.04854)| ImageNet | ResNet-18 | 128 | [configs/imagenet/resnet18_vogn_bs4k_128gpu.json](https://github.com/cybertronai/pytorch-sso/blob/master/examples/distributed/classification/configs/imagenet/resnet18_vogn_bs4k_128gpu.json) | + +- NOTE: + - You need to run with `N` GPUs when you use `*{N}gpu.json` config file. + - You need to set `--acc_steps` (or `"acc_steps"` in json config) to run with limited number of GPUs as below: + - Mini-batch size (bs) = {examples per GPU} x {# GPUs} x {acc_steps} + - Ex) 4096 (4k) = 32 x 8 x 16 + - The gradients of loss and the curvature are accumulated for `acc_steps` to build pseudo mini-batch size. + +Visit [configs](https://github.com/cybertronai/pytorch-sso/blob/master/examples/distributed/classification/configs) for other architecture, dataset, optimizer, number of GPUs. diff --git a/examples/distributed/classification/configs/cifar10/alexnet_adam_bs256_8gpu.json b/examples/distributed/classification/configs/cifar10/alexnet_adam_bs256_8gpu.json new file mode 100644 index 0000000..afc3ece --- /dev/null +++ b/examples/distributed/classification/configs/cifar10/alexnet_adam_bs256_8gpu.json @@ -0,0 +1,22 @@ +{ + "dataset": "CIFAR-10", + "epochs": 161, + "batch_size": 32, + "val_batch_size": 128, + "random_crop": true, + "random_horizontal_flip": true, + "normalizing_data": true, + "arch_file": "models/alexnet.py", + "arch_name": "AlexNet", + "optim_name": "Adam", + "optim_args": { + "lr": 0.001, + "betas": [0.9, 0.999], + "weight_decay": 1e-4 + }, + "scheduler_name": "MultiStepLR", + "scheduler_args": { + "milestones": [80, 120], + "gamma": 0.1 + } +} \ No newline at end of file diff --git a/examples/distributed/classification/configs/cifar10/alexnet_kfac_bs256_8gpu.json b/examples/distributed/classification/configs/cifar10/alexnet_kfac_bs256_8gpu.json new file mode 100644 index 0000000..8dce839 --- /dev/null +++ b/examples/distributed/classification/configs/cifar10/alexnet_kfac_bs256_8gpu.json @@ -0,0 +1,39 @@ +{ + "dataset": "CIFAR-10", + "epochs": 61, + "batch_size": 32, + "val_batch_size": 1250, + "random_crop": true, + "random_horizontal_flip": true, + "normalizing_data": true, + "arch_file": "models/alexnet.py", + "arch_name": "AlexNet", + "optim_name": "DistributedSecondOrderOptimizer", + "optim_args": { + "curv_type": "Fisher", + "curv_shapes": { + "Conv2d": "Kron", + "Linear": "Kron", + "BatchNorm1d": "Diag", + "BatchNorm2d": "Diag" + }, + "lr": 1e-2, + "l2_reg": 1e-3, + "momentum": 0.9, + "momentum_type": "raw" + }, + "curv_args": { + "damping": 1e-3, + "ema_decay": 0.999, + "pi_type": "tracenorm" + }, + "fisher_args": { + "approx_type": "mc", + "num_mc": 1 + }, + "scheduler_name": "MultiStepLR", + "scheduler_args": { + "milestones": [30, 50], + "gamma": 0.1 + } +} \ No newline at end of file diff --git a/examples/distributed/classification/configs/cifar10/alexnet_noisykfac_bs256_8gpu.json b/examples/distributed/classification/configs/cifar10/alexnet_noisykfac_bs256_8gpu.json new file mode 100644 index 0000000..14b9253 --- /dev/null +++ b/examples/distributed/classification/configs/cifar10/alexnet_noisykfac_bs256_8gpu.json @@ -0,0 +1,45 @@ +{ + "dataset": "CIFAR-10", + "epochs": 81, + "batch_size": 32, + "val_batch_size": 128, + "random_crop": true, + "random_horizontal_flip": true, + "normalizing_data": true, + "arch_file": "models/alexnet.py", + "arch_name": "AlexNet", + "optim_name": "DistributedVIOptimizer", + "optim_args": { + "curv_type": "Fisher", + "curv_shapes": { + "Conv2d": "Kron", + "Linear": "Kron", + "BatchNorm1d": "Diag", + "BatchNorm2d": "Diag" + }, + "lr": 1e-3, + "momentum": 0.9, + "momentum_type": "raw", + "num_mc_samples": 3, + "val_num_mc_samples": 10, + "kl_weighting": 1, + "warmup_kl_weighting_init": 0.01, + "warmup_kl_weighting_steps": 15821, + "prior_variance": 2 + }, + "curv_args": { + "damping": 1e-3, + "ema_decay": 0.999, + "pi_type": "tracenorm" + }, + "fisher_args": { + "approx_type": "mc", + "num_mc": 1 + }, + "scheduler_name": "MultiStepLR", + "scheduler_args": { + "milestones": [30, 50], + "gamma": 0.1 + }, + "num_mc_groups": 8 +} \ No newline at end of file diff --git a/examples/distributed/classification/configs/cifar10/alexnet_sgd_bs256_8gpu.json b/examples/distributed/classification/configs/cifar10/alexnet_sgd_bs256_8gpu.json new file mode 100644 index 0000000..6d025ad --- /dev/null +++ b/examples/distributed/classification/configs/cifar10/alexnet_sgd_bs256_8gpu.json @@ -0,0 +1,22 @@ +{ + "dataset": "CIFAR-10", + "epochs": 161, + "batch_size": 32, + "val_batch_size": 128, + "random_crop": true, + "random_horizontal_flip": true, + "normalizing_data": true, + "arch_file": "models/alexnet.py", + "arch_name": "AlexNet", + "optim_name": "SGD", + "optim_args": { + "lr": 0.1, + "momentum": 0.9, + "weight_decay": 1e-4 + }, + "scheduler_name": "MultiStepLR", + "scheduler_args": { + "milestones": [80, 120], + "gamma": 0.1 + } +} \ No newline at end of file diff --git a/examples/distributed/classification/configs/cifar10/alexnet_vogn_bs256_8gpu.json b/examples/distributed/classification/configs/cifar10/alexnet_vogn_bs256_8gpu.json new file mode 100644 index 0000000..ad412f5 --- /dev/null +++ b/examples/distributed/classification/configs/cifar10/alexnet_vogn_bs256_8gpu.json @@ -0,0 +1,41 @@ +{ + "dataset": "CIFAR-10", + "epochs": 161, + "batch_size": 32, + "val_batch_size": 128, + "random_crop": false, + "random_horizontal_flip": false, + "normalizing_data": true, + "dataset_size_scale": 10, + "arch_file": "models/alexnet.py", + "arch_name": "AlexNet", + "optim_name": "DistributedVIOptimizer", + "optim_args": { + "curv_type": "Cov", + "curv_shapes": { + "Conv2d": "Diag", + "Linear": "Kron", + "BatchNorm1d": "Diag", + "BatchNorm2d": "Diag" + }, + "lr": 1e-4, + "momentum": 0.9, + "momentum_type": "raw", + "num_mc_samples": 3, + "val_num_mc_samples": 10, + "kl_weighting": 1, + "warmup_kl_weighting_init": 0.5, + "warmup_kl_weighting_steps": 1954, + "prior_variance": 2 + }, + "curv_args": { + "damping": 1e-3, + "ema_decay": 0.999 + }, + "scheduler_name": "MultiStepLR", + "scheduler_args": { + "milestones": [80, 120], + "gamma": 0.1 + }, + "num_mc_groups": 8 +} \ No newline at end of file diff --git a/examples/distributed/classification/configs/cifar10/lenet_adam_bs128_4gpu.json b/examples/distributed/classification/configs/cifar10/lenet_adam_bs128_4gpu.json new file mode 100644 index 0000000..6b30d79 --- /dev/null +++ b/examples/distributed/classification/configs/cifar10/lenet_adam_bs128_4gpu.json @@ -0,0 +1,17 @@ +{ + "dataset": "CIFAR-10", + "epochs": 150, + "batch_size": 32, + "val_batch_size": 1250, + "random_crop": true, + "random_horizontal_flip": true, + "normalizing_data": true, + "arch_file": "models/lenet.py", + "arch_name": "LeNet5", + "optim_name": "Adam", + "optim_args": { + "lr": 1e-3, + "betas": [0.9, 0.999], + "weight_decay": 1e-2 + } +} \ No newline at end of file diff --git a/examples/distributed/classification/configs/cifar10/lenet_kfac_bs128_4gpu.json b/examples/distributed/classification/configs/cifar10/lenet_kfac_bs128_4gpu.json new file mode 100644 index 0000000..1f11827 --- /dev/null +++ b/examples/distributed/classification/configs/cifar10/lenet_kfac_bs128_4gpu.json @@ -0,0 +1,41 @@ +{ + "dataset": "CIFAR-10", + "epochs": 50, + "batch_size": 32, + "val_batch_size": 1250, + "random_crop": true, + "random_horizontal_flip": true, + "normalizing_data": true, + "arch_file": "models/lenet.py", + "arch_name": "LeNet5", + "optim_name": "DistributedSecondOrderOptimizer", + "optim_args": { + "lr": 1e-3, + "curv_type": "Fisher", + "curv_shapes": { + "Conv2d": "Kron", + "Linear": "Kron", + "BatchNorm1d": "Diag", + "BatchNorm2d": "Diag" + }, + "momentum": 0.9, + "momentum_type": "raw", + "l2_reg": 1e-3, + "acc_steps": 1 + }, + "curv_args": { + "damping": 1e-3, + "ema_decay": 0.999, + "pi_type": "tracenorm" + }, + "fisher_args": { + "approx_type": "mc", + "num_mc": 1 + }, + "scheduler_name": "ExponentialLR", + "scheduler_args": { + "gamma": 0.9 + }, + "log_interval": 64, + "no_cuda": false +} \ No newline at end of file diff --git a/examples/distributed/classification/configs/cifar10/lenet_sgd_bs128_4gpu.json b/examples/distributed/classification/configs/cifar10/lenet_sgd_bs128_4gpu.json new file mode 100644 index 0000000..58fed2d --- /dev/null +++ b/examples/distributed/classification/configs/cifar10/lenet_sgd_bs128_4gpu.json @@ -0,0 +1,17 @@ +{ + "dataset": "CIFAR-10", + "epochs": 150, + "batch_size": 32, + "val_batch_size": 1250, + "random_crop": true, + "random_horizontal_flip": true, + "normalizing_data": true, + "arch_file": "models/lenet.py", + "arch_name": "LeNet5", + "optim_name": "SGD", + "optim_args": { + "lr": 1e-3, + "momentum": 0.9, + "weight_decay": 1e-2 + } +} \ No newline at end of file diff --git a/examples/distributed/classification/configs/cifar10/lenet_vogn_bs128_4gpu.json b/examples/distributed/classification/configs/cifar10/lenet_vogn_bs128_4gpu.json new file mode 100644 index 0000000..f69d7ba --- /dev/null +++ b/examples/distributed/classification/configs/cifar10/lenet_vogn_bs128_4gpu.json @@ -0,0 +1,36 @@ +{ + "dataset": "CIFAR-10", + "epochs": 211, + "batch_size": 32, + "val_batch_size": 128, + "random_crop": false, + "random_horizontal_flip": false, + "normalizing_data": true, + "dataset_size_scale": 1, + "arch_file": "models/lenet.py", + "arch_name": "LeNet5", + "optim_name": "DistributedVIOptimizer", + "optim_args": { + "curv_type": "Cov", + "curv_shapes": { + "Conv2d": "Diag", + "Linear": "Diag", + "BatchNorm1d": "Diag", + "BatchNorm2d": "Diag" + }, + "lr": 1e-4, + "momentum": 0.9, + "momentum_type": "raw", + "num_mc_samples": 6, + "val_num_mc_samples": 10, + "kl_weighting": 1, + "warmup_kl_weighting_init": 0.1, + "warmup_kl_weighting_steps": 11719, + "prior_variance": 1e-2 + }, + "curv_args": { + "damping": 1e-3, + "ema_decay": 0.999 + }, + "num_mc_groups": 4 +} diff --git a/examples/distributed/classification/configs/cifar10/resnet18_adam_bs256_8gpu.json b/examples/distributed/classification/configs/cifar10/resnet18_adam_bs256_8gpu.json new file mode 100644 index 0000000..76f883c --- /dev/null +++ b/examples/distributed/classification/configs/cifar10/resnet18_adam_bs256_8gpu.json @@ -0,0 +1,27 @@ +{ + "dataset": "CIFAR-10", + "epochs": 161, + "batch_size": 32, + "val_batch_size": 128, + "random_crop": true, + "random_horizontal_flip": true, + "normalizing_data": true, + "arch_file": "models/resnet_b.py", + "arch_name": "resnet18", + "arch_args": { + "zero_init_residual": true, + "norm_stat_momentum": 0.1 + }, + "optim_name": "Adam", + "optim_args": { + "lr": 1e-3, + "betas": [0.9, 0.999], + "weight_decay": 5e-4 + }, + "non_wd_for_bn": true, + "scheduler_name": "MultiStepLR", + "scheduler_args": { + "milestones": [80, 120], + "gamma": 0.1 + } +} \ No newline at end of file diff --git a/examples/distributed/classification/configs/cifar10/resnet18_sgd_bs256_8gpu.json b/examples/distributed/classification/configs/cifar10/resnet18_sgd_bs256_8gpu.json new file mode 100644 index 0000000..7719ad1 --- /dev/null +++ b/examples/distributed/classification/configs/cifar10/resnet18_sgd_bs256_8gpu.json @@ -0,0 +1,28 @@ +{ + "dataset": "CIFAR-10", + "epochs": 161, + "batch_size": 32, + "val_batch_size": 128, + "random_crop": true, + "random_horizontal_flip": true, + "normalizing_data": true, + "arch_file": "models/resnet_b.py", + "arch_name": "resnet18", + "arch_args": { + "zero_init_residual": true, + "norm_stat_momentum": 0.1 + }, + "optim_name": "SGD", + "optim_args": { + "lr": 0.1, + "momentum": 0.9, + "weight_decay": 1e-4 + }, + "momentum_correction": true, + "non_wd_for_bn": true, + "scheduler_name": "MultiStepLR", + "scheduler_args": { + "milestones": [80, 120], + "gamma": 0.1 + } +} \ No newline at end of file diff --git a/examples/distributed/classification/configs/cifar10/resnet18_vogn_bs256_8gpu.json b/examples/distributed/classification/configs/cifar10/resnet18_vogn_bs256_8gpu.json new file mode 100644 index 0000000..a7ffde6 --- /dev/null +++ b/examples/distributed/classification/configs/cifar10/resnet18_vogn_bs256_8gpu.json @@ -0,0 +1,45 @@ +{ + "dataset": "CIFAR-10", + "epochs": 161, + "batch_size": 32, + "val_batch_size": 128, + "random_crop": true, + "random_horizontal_flip": true, + "normalizing_data": true, + "dataset_size_scale": 10, + "arch_file": "models/resnet_b.py", + "arch_name": "resnet18", + "arch_args": { + "zero_init_residual": false, + "norm_stat_momentum": 0.1 + }, + "optim_name": "DistributedVIOptimizer", + "optim_args": { + "curv_type": "Cov", + "curv_shapes": { + "Conv2d": "Diag", + "Linear": "Diag", + "BatchNorm1d": "Diag", + "BatchNorm2d": "Diag" + }, + "lr": 1e-4, + "momentum": 0.9, + "momentum_type": "raw", + "num_mc_samples": 5, + "val_num_mc_samples": 20, + "kl_weighting": 1, + "prior_variance": 0.02, + "non_reg_for_bn": true + }, + "curv_args": { + "damping": 1e-3, + "ema_decay": 0.999 + }, + "momentum_correction": true, + "scheduler_name": "MultiStepLR", + "scheduler_args": { + "milestones": [80, 120], + "gamma": 0.1 + }, + "num_mc_groups": 8 +} \ No newline at end of file diff --git a/examples/distributed/classification/configs/cifar10/vgg19_adam_bs256_8gpu.json b/examples/distributed/classification/configs/cifar10/vgg19_adam_bs256_8gpu.json new file mode 100644 index 0000000..56ea28f --- /dev/null +++ b/examples/distributed/classification/configs/cifar10/vgg19_adam_bs256_8gpu.json @@ -0,0 +1,22 @@ +{ + "dataset": "CIFAR-10", + "epochs": 161, + "batch_size": 32, + "val_batch_size": 128, + "random_crop": true, + "random_horizontal_flip": true, + "normalizing_data": true, + "arch_file": "models/vgg.py", + "arch_name": "VGG19", + "optim_name": "Adam", + "optim_args": { + "lr": 1e-3, + "betas": [0.9, 0.999], + "weight_decay": 1e-4 + }, + "scheduler_name": "MultiStepLR", + "scheduler_args": { + "milestones": [80, 120], + "gamma": 0.1 + } +} \ No newline at end of file diff --git a/examples/distributed/classification/configs/cifar10/vgg19_vogn_bs256_8gpu.json b/examples/distributed/classification/configs/cifar10/vgg19_vogn_bs256_8gpu.json new file mode 100644 index 0000000..43c9881 --- /dev/null +++ b/examples/distributed/classification/configs/cifar10/vgg19_vogn_bs256_8gpu.json @@ -0,0 +1,39 @@ +{ + "dataset": "CIFAR-10", + "epochs": 161, + "batch_size": 32, + "val_batch_size": 128, + "random_crop": true, + "random_horizontal_flip": true, + "normalizing_data": true, + "dataset_size_scale": 10, + "arch_file": "models/vgg.py", + "arch_name": "VGG19", + "optim_name": "DistributedVIOptimizer", + "optim_args": { + "curv_type": "Cov", + "curv_shapes": { + "Conv2d": "Diag", + "Linear": "Diag", + "BatchNorm1d": "Diag", + "BatchNorm2d": "Diag" + }, + "lr": 1e-4, + "momentum": 0.9, + "momentum_type": "raw", + "num_mc_samples": 5, + "val_num_mc_samples": 10, + "kl_weighting": 1, + "prior_variance": 2 + }, + "curv_args": { + "damping": 1e-3, + "ema_decay": 0.999 + }, + "scheduler_name": "MultiStepLR", + "scheduler_args": { + "milestones": [80, 120], + "gamma": 0.1 + }, + "num_mc_groups": 8 +} \ No newline at end of file diff --git a/examples/distributed/classification/configs/cifar100/alexnet_adam_bs256_8gpu.json b/examples/distributed/classification/configs/cifar100/alexnet_adam_bs256_8gpu.json new file mode 100644 index 0000000..d726b57 --- /dev/null +++ b/examples/distributed/classification/configs/cifar100/alexnet_adam_bs256_8gpu.json @@ -0,0 +1,22 @@ +{ + "dataset": "CIFAR-100", + "epochs": 161, + "batch_size": 32, + "val_batch_size": 128, + "random_crop": true, + "random_horizontal_flip": true, + "normalizing_data": true, + "arch_file": "models/alexnet.py", + "arch_name": "AlexNet", + "optim_name": "Adam", + "optim_args": { + "lr": 0.001, + "betas": [0.9, 0.999], + "weight_decay": 1e-2 + }, + "scheduler_name": "MultiStepLR", + "scheduler_args": { + "milestones": [80, 120], + "gamma": 0.1 + } +} \ No newline at end of file diff --git a/examples/distributed/classification/configs/cifar100/alexnet_sgd_bs256_8gpu.json b/examples/distributed/classification/configs/cifar100/alexnet_sgd_bs256_8gpu.json new file mode 100644 index 0000000..ed60718 --- /dev/null +++ b/examples/distributed/classification/configs/cifar100/alexnet_sgd_bs256_8gpu.json @@ -0,0 +1,22 @@ +{ + "dataset": "CIFAR-100", + "epochs": 161, + "batch_size": 32, + "val_batch_size": 128, + "random_crop": true, + "random_horizontal_flip": true, + "normalizing_data": true, + "arch_file": "models/alexnet.py", + "arch_name": "AlexNet", + "optim_name": "SGD", + "optim_args": { + "lr": 0.1, + "momentum": 0.9, + "weight_decay": 1e-4 + }, + "scheduler_name": "MultiStepLR", + "scheduler_args": { + "milestones": [80, 120], + "gamma": 0.1 + } +} \ No newline at end of file diff --git a/examples/distributed/classification/configs/cifar100/alexnet_vogn_bs256_8gpu.json b/examples/distributed/classification/configs/cifar100/alexnet_vogn_bs256_8gpu.json new file mode 100644 index 0000000..4aff422 --- /dev/null +++ b/examples/distributed/classification/configs/cifar100/alexnet_vogn_bs256_8gpu.json @@ -0,0 +1,41 @@ +{ + "dataset": "CIFAR-100", + "epochs": 161, + "batch_size": 32, + "val_batch_size": 128, + "random_crop": true, + "random_horizontal_flip": true, + "normalizing_data": true, + "dataset_size_scale": 10, + "arch_file": "models/alexnet.py", + "arch_name": "AlexNet", + "optim_name": "DistributedVIOptimizer", + "optim_args": { + "curv_type": "Cov", + "curv_shapes": { + "Conv2d": "Diag", + "Linear": "Diag", + "BatchNorm1d": "Diag", + "BatchNorm2d": "Diag" + }, + "lr": 1e-4, + "momentum": 0.9, + "momentum_type": "raw", + "num_mc_samples": 10, + "val_num_mc_samples": 100, + "kl_weighting": 1, + "warmup_kl_weighting_init": 0.5, + "warmup_kl_weighting_steps": 1954, + "prior_variance": 0.02 + }, + "curv_args": { + "damping": 1e-3, + "ema_decay": 0.999 + }, + "scheduler_name": "MultiStepLR", + "scheduler_args": { + "milestones": [80, 120], + "gamma": 0.1 + }, + "num_mc_groups": 8 +} \ No newline at end of file diff --git a/examples/distributed/classification/configs/cifar100/resnet18_adam_bs256_8gpu.json b/examples/distributed/classification/configs/cifar100/resnet18_adam_bs256_8gpu.json new file mode 100644 index 0000000..f38c64d --- /dev/null +++ b/examples/distributed/classification/configs/cifar100/resnet18_adam_bs256_8gpu.json @@ -0,0 +1,27 @@ +{ + "dataset": "CIFAR-100", + "epochs": 161, + "batch_size": 32, + "val_batch_size": 128, + "random_crop": true, + "random_horizontal_flip": true, + "normalizing_data": true, + "arch_file": "models/resnet_b.py", + "arch_name": "resnet18", + "arch_args": { + "zero_init_residual": true, + "norm_stat_momentum": 0.1 + }, + "optim_name": "Adam", + "optim_args": { + "lr": 1e-3, + "betas": [0.9, 0.999], + "weight_decay": 1e-2 + }, + "non_wd_for_bn": true, + "scheduler_name": "MultiStepLR", + "scheduler_args": { + "milestones": [80, 120], + "gamma": 0.1 + } +} \ No newline at end of file diff --git a/examples/distributed/classification/configs/cifar100/resnet18_vogn_bs256_8gpu.json b/examples/distributed/classification/configs/cifar100/resnet18_vogn_bs256_8gpu.json new file mode 100644 index 0000000..553950d --- /dev/null +++ b/examples/distributed/classification/configs/cifar100/resnet18_vogn_bs256_8gpu.json @@ -0,0 +1,47 @@ +{ + "dataset": "CIFAR-100", + "epochs": 161, + "batch_size": 32, + "val_batch_size": 128, + "random_crop": true, + "random_horizontal_flip": true, + "normalizing_data": true, + "dataset_size_scale": 10, + "arch_file": "models/resnet_b.py", + "arch_name": "resnet18", + "arch_args": { + "zero_init_residual": false, + "norm_stat_momentum": 0.1 + }, + "optim_name": "DistributedVIOptimizer", + "optim_args": { + "curv_type": "Cov", + "curv_shapes": { + "Conv2d": "Diag", + "Linear": "Diag", + "BatchNorm1d": "Diag", + "BatchNorm2d": "Diag" + }, + "lr": 1e-4, + "momentum": 0.9, + "momentum_type": "raw", + "num_mc_samples": 5, + "val_num_mc_samples": 20, + "kl_weighting": 1, + "warmup_kl_weighting_init": 0.5, + "warmup_kl_weighting_steps": 1954, + "prior_variance": 0.02, + "non_reg_for_bn": true + }, + "curv_args": { + "damping": 1e-3, + "ema_decay": 0.999 + }, + "momentum_correction": true, + "scheduler_name": "MultiStepLR", + "scheduler_args": { + "milestones": [80, 120], + "gamma": 0.1 + }, + "num_mc_groups": 8 +} \ No newline at end of file diff --git a/examples/distributed/classification/configs/imagenet/resnet18_adam_bs4k_128gpu.json b/examples/distributed/classification/configs/imagenet/resnet18_adam_bs4k_128gpu.json new file mode 100644 index 0000000..af29943 --- /dev/null +++ b/examples/distributed/classification/configs/imagenet/resnet18_adam_bs4k_128gpu.json @@ -0,0 +1,34 @@ +{ + "dataset": "ImageNet", + "epochs": 91, + "batch_size": 32, + "val_batch_size": 391, + "random_resized_crop": false, + "random_crop": true, + "random_horizontal_flip": true, + "normalizing_data": true, + "arch_file": "models/resnet_b.py", + "arch_name": "resnet18", + "arch_args": { + "zero_init_residual": true, + "norm_stat_momentum": 0.1 + }, + "optim_name": "Adam", + "optim_args": { + "lr": 1.6e-3, + "weight_decay": 1e-4, + "betas": [0.9, 0.999] + }, + "non_wd_for_bn": true, + "scheduler_name": "MultiStepLR", + "scheduler_args": { + "milestones": [30, 60, 80], + "gamma": 0.1 + }, + "warmup_epochs": 5, + "warmup_scheduler_name": "GradualWarmupIterLR", + "warmup_scheduler_args": { + "initial_lr": 1.25e-5, + "max_count": 1565 + } +} diff --git a/examples/distributed/classification/configs/imagenet/resnet18_kfac_bs4k_128gpu.json b/examples/distributed/classification/configs/imagenet/resnet18_kfac_bs4k_128gpu.json new file mode 100644 index 0000000..2b5778f --- /dev/null +++ b/examples/distributed/classification/configs/imagenet/resnet18_kfac_bs4k_128gpu.json @@ -0,0 +1,53 @@ +{ + "dataset": "ImageNet", + "epochs": 61, + "batch_size": 32, + "val_batch_size": 32, + "random_resized_crop": false, + "random_crop": true, + "random_horizontal_flip": true, + "normalizing_data": true, + "arch_file": "models/resnet_b.py", + "arch_name": "resnet18", + "arch_args": { + "zero_init_residual": false, + "norm_stat_momentum": 0.1 + }, + "optim_name": "DistributedSecondOrderOptimizer", + "optim_args": { + "curv_type": "Fisher", + "curv_shapes": { + "Conv2d": "Kron", + "Linear": "Kron", + "BatchNorm1d": "Diag", + "BatchNorm2d": "Diag" + }, + "lr": 1.6e-3, + "l2_reg": 1e-4, + "momentum": 0.9, + "momentum_type": "raw", + "non_reg_for_bn": true, + "acc_steps": 1 + }, + "curv_args": { + "damping": 1e-4, + "ema_decay": 1, + "pi_type": "tracenorm" + }, + "fisher_args": { + "approx_type": "mc", + "num_mc": 1 + }, + "momentum_correction": true, + "scheduler_name": "MultiStepLR", + "scheduler_args": { + "milestones": [15, 30, 45], + "gamma": 0.1 + }, + "warmup_epochs": 5, + "warmup_scheduler_name": "GradualWarmupIterLR", + "warmup_scheduler_args": { + "initial_lr": 1.25e-5, + "max_count": 1565 + } +} diff --git a/examples/distributed/classification/configs/imagenet/resnet18_kfac_bs4k_4gpu.json b/examples/distributed/classification/configs/imagenet/resnet18_kfac_bs4k_4gpu.json new file mode 100644 index 0000000..e0685f3 --- /dev/null +++ b/examples/distributed/classification/configs/imagenet/resnet18_kfac_bs4k_4gpu.json @@ -0,0 +1,41 @@ +{ + "dataset": "ImageNet10", + "epochs": 30, + "batch_size": 32, + "val_batch_size": 32, + "random_resized_crop": false, + "random_crop": true, + "random_horizontal_flip": true, + "normalizing_data": true, + "arch_file": "models/resnet_b.py", + "arch_name": "resnet18", + "arch_args": { + "zero_init_residual": false, + "norm_stat_momentum": 0.1 + }, + "optim_name": "DistributedSecondOrderOptimizer", + "optim_args": { + "curv_type": "Fisher", + "curv_shapes": { + "Conv2d": "Kron", + "Linear": "Kron", + "BatchNorm1d": "Diag", + "BatchNorm2d": "Diag" + }, + "lr": 1.6e-3, + "momentum": 0.9, + "momentum_type": "raw", + "l2_reg": 1e-4, + "non_reg_for_bn": true, + "acc_steps": 32 + }, + "curv_args": { + "damping": 1e-4, + "ema_decay": 1, + "pi_type": "tracenorm" + }, + "fisher_args": { + "approx_type": "mc", + "num_mc": 1 + } +} diff --git a/examples/distributed/classification/configs/imagenet/resnet18_noisykfac_bs4k_128gpu.json b/examples/distributed/classification/configs/imagenet/resnet18_noisykfac_bs4k_128gpu.json new file mode 100644 index 0000000..464ba1d --- /dev/null +++ b/examples/distributed/classification/configs/imagenet/resnet18_noisykfac_bs4k_128gpu.json @@ -0,0 +1,58 @@ +{ + "dataset": "ImageNet", + "epochs": 61, + "batch_size": 32, + "val_batch_size": 32, + "random_resized_crop": false, + "random_crop": true, + "random_horizontal_flip": true, + "dataset_size_scale": 5, + "normalizing_data": true, + "arch_file": "models/resnet_b.py", + "arch_name": "resnet18", + "arch_args": { + "zero_init_residual": false, + "norm_stat_momentum": 0.1 + }, + "optim_name": "DistributedVIOptimizer", + "optim_args": { + "curv_type": "Fisher", + "curv_shapes": { + "Conv2d": "Kron", + "Linear": "Kron", + "BatchNorm1d": "Diag", + "BatchNorm2d": "Diag" + }, + "lr": 1.6e-3, + "momentum": 0.9, + "momentum_type": "raw", + "num_mc_samples": 1, + "val_num_mc_samples": 10, + "kl_weighting": 1, + "prior_variance": 7.5e-3, + "non_reg_for_bn": true, + "acc_steps": 1 + }, + "curv_args": { + "damping": 1e-4, + "ema_decay": 0.9, + "pi_type": "tracenorm" + }, + "fisher_args": { + "approx_type": "mc", + "num_mc": 1 + }, + "momentum_correction": true, + "scheduler_name": "MultiStepLR", + "scheduler_args": { + "milestones": [15, 30, 45], + "gamma": 0.1 + }, + "warmup_epochs": 5, + "warmup_scheduler_name": "GradualWarmupIterLR", + "warmup_scheduler_args": { + "initial_lr": 1.25e-5, + "max_count": 1565 + }, + "num_mc_groups": 128 +} diff --git a/examples/distributed/classification/configs/imagenet/resnet18_noisykfac_bs4k_4gpu.json b/examples/distributed/classification/configs/imagenet/resnet18_noisykfac_bs4k_4gpu.json new file mode 100644 index 0000000..5505bd3 --- /dev/null +++ b/examples/distributed/classification/configs/imagenet/resnet18_noisykfac_bs4k_4gpu.json @@ -0,0 +1,45 @@ +{ + "dataset": "ImageNet10", + "epochs": 30, + "batch_size": 32, + "val_batch_size": 32, + "random_resized_crop": false, + "random_crop": true, + "random_horizontal_flip": true, + "normalizing_data": true, + "arch_file": "models/resnet_b.py", + "arch_name": "resnet18", + "arch_args": { + "zero_init_residual": false, + "norm_stat_momentum": 0.1 + }, + "optim_name": "DistributedVIOptimizer", + "optim_args": { + "curv_type": "Fisher", + "curv_shapes": { + "Conv2d": "Kron", + "Linear": "Kron", + "BatchNorm1d": "Diag", + "BatchNorm2d": "Diag" + }, + "lr": 1.6e-3, + "momentum": 0.9, + "momentum_type": "raw", + "num_mc_samples": 1, + "val_num_mc_samples": 0, + "kl_weighting": 1, + "prior_variance": 0.75, + "non_reg_for_bn": true, + "acc_steps": 32 + }, + "curv_args": { + "damping": 1e-4, + "ema_decay": 1, + "pi_type": "tracenorm" + }, + "fisher_args": { + "approx_type": "mc", + "num_mc": 1 + }, + "num_mc_groups": 1 +} diff --git a/examples/distributed/classification/configs/imagenet/resnet18_sgd_bs4k_128gpu.json b/examples/distributed/classification/configs/imagenet/resnet18_sgd_bs4k_128gpu.json new file mode 100644 index 0000000..d6ef624 --- /dev/null +++ b/examples/distributed/classification/configs/imagenet/resnet18_sgd_bs4k_128gpu.json @@ -0,0 +1,35 @@ +{ + "dataset": "ImageNet", + "epochs": 91, + "batch_size": 32, + "val_batch_size": 391, + "random_resized_crop": false, + "random_crop": true, + "random_horizontal_flip": true, + "normalizing_data": true, + "arch_file": "models/resnet_b.py", + "arch_name": "resnet18", + "arch_args": { + "zero_init_residual": true, + "norm_stat_momentum": 0.1 + }, + "optim_name": "SGD", + "optim_args": { + "lr": 1.6e-1, + "momentum": 0.9, + "weight_decay": 1e-4 + }, + "momentum_correction": true, + "non_wd_for_bn": true, + "scheduler_name": "MultiStepLR", + "scheduler_args": { + "milestones": [30, 60, 80], + "gamma": 0.1 + }, + "warmup_epochs": 5, + "warmup_scheduler_name": "GradualWarmupIterLR", + "warmup_scheduler_args": { + "initial_lr": 1.25e-3, + "max_count": 1565 + } +} diff --git a/examples/distributed/classification/configs/imagenet/resnet18_vogn_bs4k_128gpu.json b/examples/distributed/classification/configs/imagenet/resnet18_vogn_bs4k_128gpu.json new file mode 100644 index 0000000..2c3fb2f --- /dev/null +++ b/examples/distributed/classification/configs/imagenet/resnet18_vogn_bs4k_128gpu.json @@ -0,0 +1,52 @@ +{ + "dataset": "ImageNet", + "epochs": 91, + "batch_size": 32, + "val_batch_size": 32, + "random_resized_crop": false, + "random_crop": false, + "random_horizontal_flip": false, + "dataset_size_scale": 5, + "normalizing_data": true, + "arch_file": "models/resnet_b.py", + "arch_name": "resnet18", + "arch_args": { + "zero_init_residual": false, + "norm_stat_momentum": 0.1 + }, + "optim_name": "DistributedVIOptimizer", + "optim_args": { + "curv_type": "Cov", + "curv_shapes": { + "Conv2d": "Diag", + "Linear": "Diag", + "BatchNorm1d": "Diag", + "BatchNorm2d": "Diag" + }, + "lr": 1.6e-3, + "momentum": 0.9, + "momentum_type": "raw", + "num_mc_samples": 1, + "val_num_mc_samples": 10, + "kl_weighting": 1, + "prior_variance": 7.5e-3, + "non_reg_for_bn": true + }, + "curv_args": { + "damping": 1e-4, + "ema_decay": 0.9 + }, + "momentum_correction": true, + "scheduler_name": "MultiStepLR", + "scheduler_args": { + "milestones": [30, 60, 80], + "gamma": 0.1 + }, + "warmup_epochs": 5, + "warmup_scheduler_name": "GradualWarmupIterLR", + "warmup_scheduler_args": { + "initial_lr": 1.25e-5, + "max_count": 1565 + }, + "num_mc_groups": 128 +} diff --git a/examples/distributed/classification/configs/imagenet/resnet18_vogn_bs4k_4gpu.json b/examples/distributed/classification/configs/imagenet/resnet18_vogn_bs4k_4gpu.json new file mode 100644 index 0000000..bb1a8a6 --- /dev/null +++ b/examples/distributed/classification/configs/imagenet/resnet18_vogn_bs4k_4gpu.json @@ -0,0 +1,54 @@ +{ + "dataset": "ImageNet", + "epochs": 90, + "batch_size": 32, + "val_batch_size": 32, + "random_resized_crop": false, + "random_crop": true, + "random_horizontal_flip": true, + "normalizing_data": true, + "arch_file": "models/resnet_b.py", + "arch_name": "resnet18", + "arch_args": { + "zero_init_residual": true, + "norm_stat_momentum": 0.1 + }, + "optim_name": "DistributedVIOptimizer", + "optim_args": { + "curv_type": "Cov", + "curv_shapes": { + "Conv2d": "Diag", + "Linear": "Diag", + "BatchNorm1d": "Diag", + "BatchNorm2d": "Diag" + }, + "lr": 1.6e-2, + "grad_ema_decay": 0.1, + "grad_ema_type": "raw", + "bias_correction": true, + "non_reg_for_bn": true, + "num_mc_samples": 1, + "val_num_mc_samples": 10, + "kl_weighting": 1, + "prior_variance": 7.5e-4, + "weight_decay": 0, + "lars": false + }, + "curv_args": { + "damping": 1e-4, + "ema_decay": 1e-3 + }, + "momentum_correction": true, + "scheduler_name": "MultiStepLR", + "scheduler_args": { + "milestones": [30, 60, 80], + "gamma": 0.1 + }, + "warmup_epochs": 5, + "warmup_scheduler_name": "GradualWarmupIterLR", + "warmup_scheduler_args": { + "initial_lr": 1.25e-4, + "max_count": 1565 + }, + "num_mc_groups": 4 +} diff --git a/examples/distributed/classification/main.py b/examples/distributed/classification/main.py new file mode 100644 index 0000000..cdf87e7 --- /dev/null +++ b/examples/distributed/classification/main.py @@ -0,0 +1,573 @@ +import os +import argparse +from importlib import import_module +import shutil +import json +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import datasets, transforms, models +import torchsso +from torchsso.optim import DistributedFirstOrderOptimizer, DistributedSecondOrderOptimizer, DistributedVIOptimizer +from torchsso.optim.lr_scheduler import MomentumCorrectionLR +from torchsso.utils import Logger + +from mpi4py import MPI +import torch.distributed as dist + +DATASET_CIFAR10 = 'CIFAR-10' +DATASET_CIFAR100 = 'CIFAR-100' +DATASET_IMAGENET = 'ImageNet' + + +def main(): + parser = argparse.ArgumentParser() + # Data + parser.add_argument('--dataset', type=str, + choices=[DATASET_CIFAR10, DATASET_CIFAR100, DATASET_IMAGENET], + default=DATASET_CIFAR10, + help='name of dataset') + parser.add_argument('--root', type=str, default='./data', + help='root of dataset') + parser.add_argument('--train_root', type=str, default=None, + help='root of train dataset') + parser.add_argument('--val_root', type=str, default=None, + help='root of validate dataset') + parser.add_argument('--epochs', type=int, default=10, + help='number of epochs to train') + parser.add_argument('--batch_size', type=int, default=128, + help='input batch size for training') + parser.add_argument('--val_batch_size', type=int, default=128, + help='input batch size for valing') + parser.add_argument('--normalizing_data', action='store_true', + help='[data pre processing] normalizing data') + parser.add_argument('--random_crop', action='store_true', + help='[data augmentation] random crop') + parser.add_argument('--random_resized_crop', action='store_true', + help='[data augmentation] random resised crop') + parser.add_argument('--random_horizontal_flip', action='store_true', + help='[data augmentation] random horizontal flip') + parser.add_argument('--dataset_size_scale', type=float, default=1., + help='ratio multiplied to the actual dataset size') + # Training Settings + parser.add_argument('--arch_file', type=str, default=None, + help='name of file which defines the architecture') + parser.add_argument('--arch_name', type=str, default='LeNet5', + help='name of the architecture') + parser.add_argument('--arch_args', type=json.loads, default=None, + help='[JSON] arguments for the architecture') + parser.add_argument('--optim_name', type=str, default=DistributedSecondOrderOptimizer.__name__, + help='name of the optimizer') + parser.add_argument('--optim_args', type=json.loads, default=None, + help='[JSON] arguments for the optimizer') + parser.add_argument('--curv_args', type=json.loads, default=dict(), + help='[JSON] arguments for the curvature') + parser.add_argument('--fisher_args', type=json.loads, default=dict(), + help='[JSON] arguments for the fisher') + parser.add_argument('--scheduler_name', type=str, default=None, + help='name of the learning rate scheduler') + parser.add_argument('--scheduler_args', type=json.loads, default=None, + help='[JSON] arguments for the scheduler') + parser.add_argument('--warmup_epochs', type=int, default=0, + help='number of epochs for warmup lr') + parser.add_argument('--warmup_scheduler_name', type=str, default=None, + help='name of the learning rate scheduler') + parser.add_argument('--warmup_scheduler_args', type=json.loads, default=None, + help='[JSON] arguments for the wamup scheduler') + parser.add_argument('--momentum_correction', action='store_true', + help='if True, momentum/LR ratio is kept to be constant') + parser.add_argument('--non_wd_for_bn', action='store_true', + help='(FirstOrderOptimizer only) if True, weight decay is not applied for BatchNorm') + parser.add_argument('--lars', action='store_true', + help='if True, LARS is applied for first-order optimizer') + # Options + parser.add_argument('--download', action='store_true', default=False, + help='if True, downloads the dataset (CIFAR-10 or 100) from the internet') + parser.add_argument('--seed', type=int, default=1, + help='random seed') + parser.add_argument('--num_workers', type=int, default=0, + help='number of sub processes for data loading') + parser.add_argument('--log_interval', type=int, default=50, + help='how many batches to wait before logging training status') + parser.add_argument('--log_file_name', type=str, default='log', + help='log file name') + parser.add_argument('--checkpoint_interval', type=int, default=50, + help='how many epochs to wait before logging training status') + parser.add_argument('--resume', type=str, default=None, + help='checkpoint path for resume training') + parser.add_argument('--out', type=str, default='result', + help='dir to save output files') + parser.add_argument('--config', default=None, + help='config file path') + # [COMM] + parser.add_argument('--dist_init_method', type=str, + help='torch.distributed init_method') + parser.add_argument('--size_data_group', type=int, default=1, + help='size of the process groups in which input data are shared') + parser.add_argument('--num_mc_groups', type=int, default=1, + help='number of the process groups in which mc sampled params are shared') + + args = parser.parse_args() + dict_args = vars(args) + + # Load config file + if args.config is not None: + with open(args.config) as f: + config = json.load(f) + dict_args.update(config) + + # Set random seed + torch.manual_seed(args.seed) + + # [COMM] Initialize process group + comm = MPI.COMM_WORLD + size = comm.Get_size() + ranks = list(range(size)) + rank = comm.Get_rank() + n_per_node = torch.cuda.device_count() + device = rank % n_per_node + torch.cuda.set_device(device) + init_method = 'tcp://{}:23456'.format(args.dist_init_method) + dist.init_process_group('nccl', init_method=init_method, world_size=size, rank=rank) + + # [COMM] Setup process group for MC sample parallel + size_data_group = args.size_data_group + assert size % size_data_group == 0 + num_mc_groups = args.num_mc_groups + assert size % num_mc_groups == 0 + + if size_data_group > 1: + num_data_group = size / size_data_group + data_group_id = rank % num_data_group + data_group_ranks = ranks[data_group_id:size:num_data_group] + data_group = dist.new_group(data_group_ranks) + + master_ranks = ranks[0:num_data_group] + master_group = dist.new_group(master_ranks) + else: + num_data_group = size + data_group_id = rank + data_group = None + master_group = dist.new_group(ranks) + + if num_mc_groups > 1: + size_mc_group = int(size / num_mc_groups) + mc_group_id = int(rank/size_mc_group) + else: + size_mc_group = size + mc_group_id = 0 + + # Setup data augmentation & data pre processing + train_transforms, val_transforms = [], [] + + if args.dataset in [DATASET_CIFAR10, DATASET_CIFAR100]: + # CIFAR-10/100 + if args.random_crop: + train_transforms.append(transforms.RandomCrop(32, padding=4)) + + normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) + else: + # ImageNet + if args.random_resized_crop: + train_transforms.append(transforms.RandomResizedCrop(224)) + else: + train_transforms.append(transforms.Resize(256)) + if args.random_crop: + train_transforms.append(transforms.RandomCrop(224)) + else: + train_transforms.append(transforms.CenterCrop(224)) + + normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) + + val_transforms.append(transforms.Resize(256)) + val_transforms.append(transforms.CenterCrop(224)) + + if args.random_horizontal_flip: + train_transforms.append(transforms.RandomHorizontalFlip()) + + train_transforms.append(transforms.ToTensor()) + val_transforms.append(transforms.ToTensor()) + + if args.normalizing_data: + train_transforms.append(normalize) + val_transforms.append(normalize) + + train_transform = transforms.Compose(train_transforms) + val_transform = transforms.Compose(val_transforms) + + # Setup data loader + if args.dataset == DATASET_IMAGENET: + # ImageNet + num_classes = 1000 + + train_root = args.root if args.train_root is None else args.train_root + val_root = args.root if args.val_root is None else args.val_root + train_dataset = datasets.ImageFolder(root=train_root, transform=train_transform) + val_dataset = datasets.ImageFolder(root=val_root, transform=val_transform) + else: + if args.dataset == DATASET_CIFAR10: + # CIFAR-10 + num_classes = 10 + dataset_class = datasets.CIFAR10 + else: + # CIFAR-100 + num_classes = 100 + dataset_class = datasets.CIFAR100 + + train_dataset = dataset_class( + root=args.root, train=True, download=args.download, transform=train_transform) + val_dataset = dataset_class( + root=args.root, train=False, download=args.download, transform=val_transform) + + # [COMM] Setup distributed sampler for data parallel & MC sample parallel + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_dataset, num_replicas=num_data_group, rank=data_group_id) + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), + pin_memory=True, sampler=train_sampler, num_workers=args.num_workers) + + # [COMM] Setup distributed sampler for data parallel + val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset) + val_loader = torch.utils.data.DataLoader( + val_dataset, batch_size=args.val_batch_size, shuffle=False, + sampler=val_sampler, num_workers=args.num_workers) + + # Setup model + if args.arch_file is None: + arch_class = getattr(models, args.arch_name) + else: + _, ext = os.path.splitext(args.arch_file) + dirname = os.path.dirname(args.arch_file) + + if dirname == '': + module_path = args.arch_file.replace(ext, '') + elif dirname == '.': + module_path = os.path.basename(args.arch_file).replace(ext, '') + else: + module_path = '.'.join(os.path.split(args.arch_file)).replace(ext, '') + + module = import_module(module_path) + arch_class = getattr(module, args.arch_name) + + arch_kwargs = {} if args.arch_args is None else args.arch_args + arch_kwargs['num_classes'] = num_classes + + model = arch_class(**arch_kwargs) + setattr(model, 'num_classes', num_classes) + model = model.to(device) + + # [COMM] Broadcast model parameters + for param in list(model.parameters()): + dist.broadcast(param.data, src=0) + + # Setup optimizer + optim_kwargs = {} if args.optim_args is None else args.optim_args + acc_steps = optim_kwargs.get('acc_steps', 1) + global_batch_size = num_data_group * args.batch_size * acc_steps + total_steps = math.ceil(args.epochs * len(train_loader.dataset) / global_batch_size) + + # Setup optimizer + if args.optim_name == DistributedVIOptimizer.__name__: + optimizer = DistributedVIOptimizer(model, + mc_group_id=mc_group_id, + dataset_size=len(train_loader.dataset) * args.dataset_size_scale, + total_steps=total_steps, + seed=args.seed, + **optim_kwargs, curv_kwargs=args.curv_args) + else: + assert args.num_mc_groups == 1, 'You cannot use MC sample groups with non-VI optimizers.' + if args.optim_name == DistributedSecondOrderOptimizer.__name__: + optimizer = DistributedSecondOrderOptimizer(model, **optim_kwargs, curv_kwargs=args.curv_args) + else: + if args.non_wd_for_bn: + group, group_non_wd = {'params': []}, {'params': [], 'non_wd': True} + for m in model.children(): + if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): + group_non_wd['params'].extend(m.parameters()) + else: + group['params'].extend(m.parameters()) + + params = [group, group_non_wd] + else: + params = model.parameters() + + optim_class = getattr(torch.optim, args.optim_name) + optimizer = optim_class(params, **optim_kwargs) + + for group in optimizer.param_groups: + if group.get('non_wd', False): + group['weight_decay'] = 0 + + optimizer = DistributedFirstOrderOptimizer(optimizer, model, dist, lars=args.lars) + + # Setup lr scheduler + def get_scheduler(name, kwargs): + scheduler_class = getattr(torchsso.optim.lr_scheduler, name, None) + if scheduler_class is None: + scheduler_class = getattr(torch.optim.lr_scheduler, name) + scheduler_kwargs = {} if kwargs is None else kwargs + _scheduler = scheduler_class(optimizer, **scheduler_kwargs) + if args.momentum_correction: + _scheduler = MomentumCorrectionLR(_scheduler) + return _scheduler + + if args.scheduler_name is None: + main_scheduler = None + else: + main_scheduler = get_scheduler(args.scheduler_name, args.scheduler_args) + + if args.warmup_scheduler_name is None: + warmup_scheduler = main_scheduler + else: + warmup_scheduler = get_scheduler(args.warmup_scheduler_name, args.warmup_scheduler_args) + + logger = None + start_epoch = 1 + + # Load checkpoint + if args.resume is not None: + print('==> Resuming from checkpoint..') + assert os.path.exists(args.resume), 'Error: no checkpoint file found' + checkpoint = torch.load(args.resume) + model.load_state_dict(checkpoint['model']) + optimizer.load_state_dict(checkpoint['optimizer']) + start_epoch = checkpoint['epoch'] + + if rank == 0: + + # All config + print('===========================') + print('dataset: {}'.format(vars(args)['dataset'])) + print('train data size: {}'.format(len(train_loader.dataset))) + print('val data size: {}'.format(len(val_loader.dataset))) + + print('MPI.COMM_WORLD size: {}'.format(size)) + print('global mini-batch size: {}'.format(global_batch_size)) + print('steps/epoch: {}'.format(math.ceil(len(train_loader.dataset) / global_batch_size))) + + num_mc_samples = optim_kwargs.get('num_mc_samples', None) + if num_mc_samples is not None: + print('global num MC samples: {}'.format(num_mc_groups * num_mc_samples)) + print('MC sample group: {} processes/group x {} group'.format(size_mc_group, num_mc_groups)) + print('data group: {} processes/group x {} group'.format(size_data_group, num_data_group)) + + if hasattr(optimizer, 'indices'): + print('layer assignment: {}'.format(optimizer.indices)) + + print('---------------------------') + + for key, val in vars(args).items(): + if key == 'dataset': + continue + else: + print('{}: {}'.format(key, val)) + print('===========================') + + # Copy this file & config to args.out + if not os.path.isdir(args.out): + os.makedirs(args.out) + try: + shutil.copy(os.path.realpath(__file__), args.out) + except shutil.SameFileError: + pass + if args.config is not None: + try: + shutil.copy(args.config, args.out) + except shutil.SameFileError: + pass + if args.arch_file is not None: + try: + shutil.copy(args.arch_file, args.out) + except shutil.SameFileError: + pass + + # Setup logger + logger = Logger(args.out, args.log_file_name) + logger.start() + + # Run training + for epoch in range(start_epoch, args.epochs + 1): + + scheduler = main_scheduler if epoch > args.warmup_epochs else warmup_scheduler + + # train + accuracy, loss = train(rank, epoch, model, device, train_loader, optimizer, scheduler, + args, master_group, data_group_id, data_group, logger) + # val + val_accuracy, val_loss = validate(rank, model, val_loader, device, optimizer) + + if rank == 0: + # write to log + iteration = epoch * len(train_loader) + elapsed_time = logger.elapsed_time + log = {'epoch': epoch, 'iteration': iteration, 'elapsed_time': elapsed_time, + 'accuracy': accuracy, 'loss': loss, + 'val_accuracy': val_accuracy, 'val_loss': val_loss, + 'lr': optimizer.param_groups[0]['lr'], + 'momentum': optimizer.param_groups[0].get('momentum', 0), + } + logger.write(log) + + # save checkpoint + if epoch % args.checkpoint_interval == 0 or epoch > args.epochs - 3: + path = os.path.join(args.out, 'epoch{}.ckpt'.format(epoch)) + data = { + 'model': model.state_dict(), + 'optimizer': optimizer.state_dict(), + 'epoch': epoch + } + torch.save(data, path) + + +def train(rank, epoch, model, device, train_loader, optimizer, scheduler, + args, master_group, data_group_id=0, data_group=None, logger=None): + + def scheduler_type(_scheduler): + if _scheduler is None: + return 'none' + return getattr(_scheduler, 'scheduler_type', 'epoch') + + if scheduler_type(scheduler) == 'epoch': + scheduler.step(epoch - 1) + + model.train() + + total_correct = 0 + loss = None + total_data_size = 0 + epoch_size = len(train_loader.dataset) + num_iters_in_epoch = len(train_loader) + base_num_iter = (epoch - 1) * num_iters_in_epoch + + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(device), target.to(device) + + if scheduler_type(scheduler) == 'iter': + scheduler.step() + + for name, param in model.named_parameters(): + attr = 'p_pre_{}'.format(name) + setattr(model, attr, param.detach().clone()) + + # update params + def closure(): + optimizer.zero_grad() + output = model(data) + loss = F.cross_entropy(output, target) + loss.backward() + + return loss, output + + if isinstance(optimizer, DistributedSecondOrderOptimizer) \ + and optimizer.curv_type == 'Fisher': + closure = torchsso.get_closure_for_fisher(optimizer, model, data, target, **args.fisher_args) + + loss, output = optimizer.step(closure=closure) + data_size = torch.tensor(len(data)).to(device) + + # [COMM] reduce across the all processes + dist.reduce(loss, dst=0) + + # [COMM] reduce across the processes in a data group + if data_group is not None: + dist.reduce(output, dst=data_group_id, group=data_group) + + pred = output.argmax(dim=1, keepdim=True) + correct = pred.eq(target.view_as(pred)).sum().data + + # [COMM] reduce across the processes in the master MC sample group + if dist.get_world_size(master_group) > 1: + dist.reduce(correct, dst=0, group=master_group) + dist.reduce(data_size, dst=0, group=master_group) + + # refresh results + if rank == 0: + loss = loss.item() / dist.get_world_size() + + correct = correct.item() + data_size = data_size.item() + + total_correct += correct + + iteration = base_num_iter + batch_idx + 1 + total_data_size += data_size + + is_log_timing = (epoch == 1 and batch_idx == 0) or \ + (batch_idx + 1) % args.log_interval == 0 + + # save log + if logger is not None and is_log_timing: + accuracy = 100. * total_correct / total_data_size + elapsed_time = logger.elapsed_time + print('epoch: {} [{}/{} ({:.0f}%)]\tloss: {:.6f}, ' + 'accuracy: {:.0f}/{} ({:.2f}%), ' + 'elapsed: {:.1f}s, iters/sec: {:.2f}'.format( + epoch, total_data_size, epoch_size, 100. * (batch_idx + 1) / num_iters_in_epoch, + loss, total_correct, total_data_size, accuracy, elapsed_time, iteration/elapsed_time)) + + lr = optimizer.param_groups[0]['lr'] + m = optimizer.param_groups[0].get('momentum', 0) + log = {'epoch': epoch, 'iteration': iteration, 'elapsed_time': elapsed_time, + 'accuracy': accuracy, 'loss': loss, 'lr': lr, 'momentum': m} + + for name, param in model.named_parameters(): + attr = 'p_pre_{}'.format(name) + p_pre = getattr(model, attr) + p_norm = param.norm().item() + p_shape = list(param.size()) + p_pre_norm = p_pre.norm().item() + g_norm = param.grad.norm().item() + upd_norm = param.sub(p_pre).norm().item() + noise_scale = getattr(param, 'noise_scale', 0) + + p_log = {'p_shape': p_shape, 'p_norm': p_norm, 'p_pre_norm': p_pre_norm, + 'g_norm': g_norm, 'upd_norm': upd_norm, 'noise_scale': noise_scale} + log[name] = p_log + + logger.write(log) + + accuracy = 100. * total_correct / epoch_size + + return accuracy, loss + + +def validate(rank, model, val_loader, device, optimizer): + model.eval() + val_loss = 0 + correct = 0 + + with torch.no_grad(): + for data, target in val_loader: + data, target = data.to(device), target.to(device) + if isinstance(optimizer, DistributedVIOptimizer): + prob = optimizer.prediction(data) + val_loss += F.nll_loss(torch.log(prob), target, reduction='sum') + pred = prob.argmax(dim=1, keepdim=True) # get the index of the max log-probability + elif hasattr(model, 'mc_prediction'): + prob = model.mc_prediction(data) + val_loss += F.nll_loss(torch.log(prob), target, reduction='sum') + pred = prob.argmax(dim=1, keepdim=True) # get the index of the max log-probability + else: + output = model(data) + val_loss += F.cross_entropy(output, target, reduction='sum') # sum up batch loss + pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability + + correct += pred.eq(target.view_as(pred)).sum() + + dist.reduce(val_loss, dst=0) + dist.reduce(correct, dst=0) + + val_loss = val_loss.item() / len(val_loader.dataset) + val_accuracy = 100. * correct.item() / len(val_loader.dataset) + + if rank == 0: + print('\nEval: average loss: {:.4f}, accuracy: {:.0f}/{} ({:.2f}%)\n'.format( + val_loss, correct, len(val_loader.dataset), val_accuracy)) + + return val_accuracy, val_loss + + +if __name__ == '__main__': + main() diff --git a/examples/distributed/classification/models/__init__.py b/examples/distributed/classification/models/__init__.py new file mode 100644 index 0000000..158135e --- /dev/null +++ b/examples/distributed/classification/models/__init__.py @@ -0,0 +1,5 @@ +from .vgg import * +from .lenet import * +from .resnet import * +from .resnext import * +from .alexnet import * diff --git a/examples/distributed/classification/models/alexnet.py b/examples/distributed/classification/models/alexnet.py new file mode 100644 index 0000000..32fbd72 --- /dev/null +++ b/examples/distributed/classification/models/alexnet.py @@ -0,0 +1,113 @@ +'''AlexNet for CIFAR10. FC layers are removed. Paddings are adjusted. +Without BN, the start learning rate should be 0.01 +(c) YANG, Wei +''' +import torch.nn as nn +import torch.nn.functional as F +from torchsso.utils.accumulator import TensorAccumulator + + +__all__ = ['alexnet', 'alexnet_mcdropout'] + + +class AlexNet(nn.Module): + + def __init__(self, num_classes=10): + super().__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5) + self.conv2 = nn.Conv2d(64, 192, kernel_size=5, padding=2) + self.conv3 = nn.Conv2d(192, 384, kernel_size=3, padding=1) + self.conv4 = nn.Conv2d(384, 256, kernel_size=3, padding=1) + self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1) + self.fc = nn.Linear(256, num_classes) + + def forward(self, x): + x = F.relu(self.conv1(x), inplace=True) + x = F.max_pool2d(x, kernel_size=2, stride=2) + x = F.relu(self.conv2(x), inplace=True) + x = F.max_pool2d(x, kernel_size=2, stride=2) + x = F.relu(self.conv3(x), inplace=True) + x = F.relu(self.conv4(x), inplace=True) + x = F.relu(self.conv5(x), inplace=True) + x = F.max_pool2d(x, kernel_size=2, stride=2) + x = x.view(x.size(0), -1) + x = self.fc(x) + return x + + +class AlexNet2(nn.Module): + + def __init__(self, num_classes=10): + super(AlexNet2, self).__init__() + self.features = nn.Sequential( + nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.Conv2d(64, 192, kernel_size=5, padding=2), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.Conv2d(192, 384, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(384, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=2, stride=2), + ) + self.classifier = nn.Linear(256, num_classes) + + def forward(self, x): + x = self.features(x) + x = x.view(x.size(0), -1) + x = self.classifier(x) + return x + + +class AlexNetMCDropout(AlexNet): + + def __init__(self, num_classes=10, dropout_ratio=0.5, val_mc=10): + super(AlexNetMCDropout, self).__init__(num_classes) + self.dropout_ratio = dropout_ratio + self.val_mc = val_mc + + def forward(self, x): + dropout_ratio = self.dropout_ratio + x = F.relu(F.dropout(self.conv1(x), p=dropout_ratio), inplace=True) + x = F.max_pool2d(x, kernel_size=2, stride=2) + x = F.relu(F.dropout(self.conv2(x), p=dropout_ratio), inplace=True) + x = F.max_pool2d(x, kernel_size=2, stride=2) + x = F.relu(F.dropout(self.conv3(x), p=dropout_ratio), inplace=True) + x = F.relu(F.dropout(self.conv4(x), p=dropout_ratio), inplace=True) + x = F.relu(F.dropout(self.conv5(x), p=dropout_ratio), inplace=True) + x = F.max_pool2d(x, kernel_size=2, stride=2) + x = x.view(x.size(0), -1) + x = self.fc(x) + return x + + def mc_prediction(self, x): + + acc_prob = TensorAccumulator() + m = self.val_mc + + for _ in range(m): + output = self.forward(x) + prob = F.softmax(output, dim=1) + acc_prob.update(prob, scale=1/m) + + prob = acc_prob.get() + + return prob + + +def alexnet(**kwargs): + r"""AlexNet model architecture from the + `"One weird trick..." `_ paper. + """ + model = AlexNet(**kwargs) + return model + + +def alexnet_mcdropout(**kwargs): + model = AlexNetMCDropout(**kwargs) + return model + diff --git a/examples/distributed/classification/models/lenet.py b/examples/distributed/classification/models/lenet.py new file mode 100644 index 0000000..27fd02d --- /dev/null +++ b/examples/distributed/classification/models/lenet.py @@ -0,0 +1,84 @@ +import torch.nn as nn +import torch.nn.functional as F +from torchsso.utils.accumulator import TensorAccumulator + + +class LeNet5(nn.Module): + + def __init__(self, num_classes=10): + super().__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, num_classes) + + def forward(self, x): + out = F.relu(self.conv1(x)) + out = F.max_pool2d(out, 2) + out = F.relu(self.conv2(out)) + out = F.max_pool2d(out, 2) + out = out.view(out.size(0), -1) + out = F.relu(self.fc1(out)) + out = F.relu(self.fc2(out)) + out = self.fc3(out) + return out + + +class LeNet5MCDropout(LeNet5): + + def __init__(self, num_classes=10, dropout_ratio=0.1, val_mc=10): + super(LeNet5MCDropout, self).__init__(num_classes=num_classes) + self.dropout_ratio = dropout_ratio + self.val_mc = val_mc + + def forward(self, x): + p = self.dropout_ratio + out = F.relu(F.dropout(self.conv1(x), p)) + out = F.max_pool2d(out, 2) + out = F.relu(F.dropout(self.conv2(out), p)) + out = F.max_pool2d(out, 2) + out = out.view(out.size(0), -1) + out = F.relu(F.dropout(self.fc1(out), p)) + out = F.relu(F.dropout(self.fc2(out), p)) + out = F.dropout(self.fc3(out), p) + return out + + def mc_prediction(self, x): + + acc_prob = TensorAccumulator() + m = self.val_mc + + for _ in range(m): + output = self.forward(x) + prob = F.softmax(output, dim=1) + acc_prob.update(prob, scale=1/m) + + prob = acc_prob.get() + + return prob + + +class LeNet5BatchNorm(nn.Module): + def __init__(self, num_classes=10, affine=False): + super().__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.bn1 = nn.BatchNorm2d(6, affine=affine) + self.conv2 = nn.Conv2d(6, 16, 5) + self.bn2 = nn.BatchNorm2d(16, affine=affine) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.bn3 = nn.BatchNorm1d(120, affine=affine) + self.fc2 = nn.Linear(120, 84) + self.bn4 = nn.BatchNorm1d(84, affine=affine) + self.fc3 = nn.Linear(84, num_classes) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = F.max_pool2d(out, 2) + out = F.relu(self.bn2(self.conv2(out))) + out = F.max_pool2d(out, 2) + out = out.view(out.size(0), -1) + out = F.relu(self.bn3(self.fc1(out))) + out = F.relu(self.bn4(self.fc2(out))) + out = self.fc3(out) + return out diff --git a/examples/distributed/classification/models/resnet.py b/examples/distributed/classification/models/resnet.py new file mode 100644 index 0000000..6a6629b --- /dev/null +++ b/examples/distributed/classification/models/resnet.py @@ -0,0 +1,121 @@ +'''ResNet in PyTorch. + +For Pre-activation ResNet, see 'preact_resnet.py'. + +Reference: +[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun + Deep Residual Learning for Image Recognition. arXiv:1512.03385 +''' +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, in_planes, planes, stride=1): + super(BasicBlock, self).__init__() + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes,track_running_stats=False) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes,track_running_stats=False) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion*planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion*planes) + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, in_planes, planes, stride=1): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes,track_running_stats=False) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes,track_running_stats=False) + self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(self.expansion*planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion*planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion*planes) + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = F.relu(self.bn2(self.conv2(out))) + out = self.bn3(self.conv3(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + + +class ResNet(nn.Module): + def __init__(self, block, num_blocks, num_classes=10): + super(ResNet, self).__init__() + self.in_planes = 64 + + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) + self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) + self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) + self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) + self.linear = nn.Linear(512*block.expansion, num_classes) + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1]*(num_blocks-1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + out = F.avg_pool2d(out, 4) + out = out.view(out.size(0), -1) + out = self.linear(out) + return out + + +def ResNet18(num_classes=10): + return ResNet(BasicBlock, [2,2,2,2], num_classes) + +def ResNet34(num_classes=10): + return ResNet(BasicBlock, [3,4,6,3], num_classes) + +def ResNet50(num_classes=10): + return ResNet(Bottleneck, [3,4,6,3], num_classes) + +def ResNet101(num_classes=10): + return ResNet(Bottleneck, [3,4,23,3], num_classes) + +def ResNet152(num_classe=10): + return ResNet(Bottleneck, [3,8,36,3], num_classes) + + +def test(): + net = ResNet18() + y = net(torch.randn(1,3,32,32)) + print(y.size()) + +# test() diff --git a/examples/distributed/classification/models/resnet_b.py b/examples/distributed/classification/models/resnet_b.py new file mode 100644 index 0000000..267850c --- /dev/null +++ b/examples/distributed/classification/models/resnet_b.py @@ -0,0 +1,264 @@ +import torch.nn as nn +import torch.utils.model_zoo as model_zoo + + +__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', + 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d'] + + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1, groups=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, groups=groups, bias=False) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, norm_layer=None, norm_stat_momentum=0.1): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes, momentum=norm_stat_momentum) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes, momentum=norm_stat_momentum) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, norm_layer=None, norm_stat_momentum=0.1): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width, momentum=norm_stat_momentum) + self.conv2 = conv3x3(width, width, stride, groups) + self.bn2 = norm_layer(width, momentum=norm_stat_momentum) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion, momentum=norm_stat_momentum) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, + groups=1, width_per_group=64, norm_layer=None, norm_stat_momentum=0.1): + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + + self.inplanes = 64 + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = norm_layer(self.inplanes, momentum=norm_stat_momentum) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0], + norm_layer=norm_layer, norm_stat_momentum=norm_stat_momentum) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + norm_layer=norm_layer, norm_stat_momentum=norm_stat_momentum) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + norm_layer=norm_layer, norm_stat_momentum=norm_stat_momentum) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + norm_layer=norm_layer, norm_stat_momentum=norm_stat_momentum) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, norm_layer=None, norm_stat_momentum=0.1): + if norm_layer is None: + norm_layer = nn.BatchNorm2d + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion, momentum=norm_stat_momentum), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, norm_layer, norm_stat_momentum)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, + norm_layer=norm_layer, norm_stat_momentum=norm_stat_momentum)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x + + +def resnet18(pretrained=False, **kwargs): + """Constructs a ResNet-18 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) + return model + + +def resnet34(pretrained=False, **kwargs): + """Constructs a ResNet-34 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) + return model + + +def resnet50(pretrained=False, **kwargs): + """Constructs a ResNet-50 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) + return model + + +def resnet101(pretrained=False, **kwargs): + """Constructs a ResNet-101 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) + return model + + +def resnet152(pretrained=False, **kwargs): + """Constructs a ResNet-152 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) + return model + + +def resnext50_32x4d(pretrained=False, **kwargs): + model = ResNet(Bottleneck, [3, 4, 6, 3], groups=32, width_per_group=4, **kwargs) + # if pretrained: + # model.load_state_dict(model_zoo.load_url(model_urls['resnext50_32x4d'])) + return model + + +def resnext101_32x8d(pretrained=False, **kwargs): + model = ResNet(Bottleneck, [3, 4, 23, 3], groups=32, width_per_group=8, **kwargs) + # if pretrained: + # model.load_state_dict(model_zoo.load_url(model_urls['resnext101_32x8d'])) + return model diff --git a/examples/distributed/classification/models/resnext.py b/examples/distributed/classification/models/resnext.py new file mode 100644 index 0000000..e4f7233 --- /dev/null +++ b/examples/distributed/classification/models/resnext.py @@ -0,0 +1,95 @@ +'''ResNeXt in PyTorch. + +See the paper "Aggregated Residual Transformations for Deep Neural Networks" for more details. +''' +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Block(nn.Module): + '''Grouped convolution block.''' + expansion = 2 + + def __init__(self, in_planes, cardinality=32, bottleneck_width=4, stride=1): + super(Block, self).__init__() + group_width = cardinality * bottleneck_width + self.conv1 = nn.Conv2d(in_planes, group_width, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(group_width) + self.conv2 = nn.Conv2d(group_width, group_width, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False) + self.bn2 = nn.BatchNorm2d(group_width) + self.conv3 = nn.Conv2d(group_width, self.expansion*group_width, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(self.expansion*group_width) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion*group_width: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion*group_width, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion*group_width) + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = F.relu(self.bn2(self.conv2(out))) + out = self.bn3(self.conv3(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + + +class ResNeXt(nn.Module): + def __init__(self, num_blocks, cardinality, bottleneck_width, num_classes=10): + super(ResNeXt, self).__init__() + self.cardinality = cardinality + self.bottleneck_width = bottleneck_width + self.in_planes = 64 + + self.conv1 = nn.Conv2d(3, 64, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.layer1 = self._make_layer(num_blocks[0], 1) + self.layer2 = self._make_layer(num_blocks[1], 2) + self.layer3 = self._make_layer(num_blocks[2], 2) + # self.layer4 = self._make_layer(num_blocks[3], 2) + self.linear = nn.Linear(cardinality*bottleneck_width*8, num_classes) + + def _make_layer(self, num_blocks, stride): + strides = [stride] + [1]*(num_blocks-1) + layers = [] + for stride in strides: + layers.append(Block(self.in_planes, self.cardinality, self.bottleneck_width, stride)) + self.in_planes = Block.expansion * self.cardinality * self.bottleneck_width + # Increase bottleneck_width by 2 after each stage. + self.bottleneck_width *= 2 + return nn.Sequential(*layers) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + # out = self.layer4(out) + out = F.avg_pool2d(out, 8) + out = out.view(out.size(0), -1) + out = self.linear(out) + return out + + +def ResNeXt29_2x64d(num_classes=10): + return ResNeXt(num_blocks=[3,3,3], cardinality=2, bottleneck_width=64, num_classes=num_classes) + +def ResNeXt29_4x64d(num_classes=10): + return ResNeXt(num_blocks=[3,3,3], cardinality=4, bottleneck_width=64, num_classes=num_classes) + +def ResNeXt29_8x64d(num_classes=10): + return ResNeXt(num_blocks=[3,3,3], cardinality=8, bottleneck_width=64, num_classes=num_classes) + +def ResNeXt29_32x4d(num_classes=10): + return ResNeXt(num_blocks=[3,3,3], cardinality=32, bottleneck_width=4, num_classes=num_classes) + +def test_resnext(): + net = ResNeXt29_2x64d() + x = torch.randn(1,3,32,32) + y = net(x) + print(y.size()) + +# test_resnext() diff --git a/examples/distributed/classification/models/vgg.py b/examples/distributed/classification/models/vgg.py new file mode 100644 index 0000000..280285f --- /dev/null +++ b/examples/distributed/classification/models/vgg.py @@ -0,0 +1,153 @@ +'''VGG11/13/16/19 in Pytorch.''' +import torch.nn as nn +import torch.nn.functional as F +from torchsso.utils.accumulator import TensorAccumulator + + +cfg = { + 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], + 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], +} + + +class VGG(nn.Module): + def __init__(self, num_classes=10, vgg_name='VGG19'): + super(VGG, self).__init__() + self.features = self._make_layers(cfg[vgg_name]) + self.classifier = nn.Linear(512, num_classes) + + def forward(self, x): + out = self.features(x) + out = out.view(out.size(0), -1) + out = self.classifier(out) + return out + + def _make_layers(self, cfg): + layers = [] + in_channels = 3 + for x in cfg: + if x == 'M': + layers += [nn.MaxPool2d(kernel_size=2, stride=2)] + else: + layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1), + nn.BatchNorm2d(x), + nn.ReLU(inplace=True)] + in_channels = x + layers += [nn.AvgPool2d(kernel_size=1, stride=1)] + return nn.Sequential(*layers) + + +class VGG19(nn.Module): + + def __init__(self, num_classes=10): + super(VGG19, self).__init__() + self.conv1_1 = nn.Conv2d(3, 64, 3, padding=1) + self.bn1_1 = nn.BatchNorm2d(64) + self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1) + self.bn1_2 = nn.BatchNorm2d(64) + self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1) + self.bn2_1 = nn.BatchNorm2d(128) + self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1) + self.bn2_2 = nn.BatchNorm2d(128) + self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1) + self.bn3_1 = nn.BatchNorm2d(256) + self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1) + self.bn3_2 = nn.BatchNorm2d(256) + self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1) + self.bn3_3 = nn.BatchNorm2d(256) + self.conv3_4 = nn.Conv2d(256, 256, 3, padding=1) + self.bn3_4 = nn.BatchNorm2d(256) + self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1) + self.bn4_1 = nn.BatchNorm2d(512) + self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1) + self.bn4_2 = nn.BatchNorm2d(512) + self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1) + self.bn4_3 = nn.BatchNorm2d(512) + self.conv4_4 = nn.Conv2d(512, 512, 3, padding=1) + self.bn4_4 = nn.BatchNorm2d(512) + self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1) + self.bn5_1 = nn.BatchNorm2d(512) + self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1) + self.bn5_2 = nn.BatchNorm2d(512) + self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1) + self.bn5_3 = nn.BatchNorm2d(512) + self.conv5_4 = nn.Conv2d(512, 512, 3, padding=1) + self.bn5_4 = nn.BatchNorm2d(512) + self.fc = nn.Linear(512, num_classes) + + def forward(self, x): + h = F.relu(self.bn1_1(self.conv1_1(x)), inplace=True) + h = F.relu(self.bn1_2(self.conv1_2(h)), inplace=True) + h = F.max_pool2d(h, kernel_size=2, stride=2) + h = F.relu(self.bn2_1(self.conv2_1(h)), inplace=True) + h = F.relu(self.bn2_2(self.conv2_2(h)), inplace=True) + h = F.max_pool2d(h, kernel_size=2, stride=2) + h = F.relu(self.bn3_1(self.conv3_1(h)), inplace=True) + h = F.relu(self.bn3_2(self.conv3_2(h)), inplace=True) + h = F.relu(self.bn3_3(self.conv3_3(h)), inplace=True) + h = F.relu(self.bn3_4(self.conv3_4(h)), inplace=True) + h = F.max_pool2d(h, kernel_size=2, stride=2) + h = F.relu(self.bn4_1(self.conv4_1(h)), inplace=True) + h = F.relu(self.bn4_2(self.conv4_2(h)), inplace=True) + h = F.relu(self.bn4_3(self.conv4_3(h)), inplace=True) + h = F.relu(self.bn4_4(self.conv4_4(h)), inplace=True) + h = F.max_pool2d(h, kernel_size=2, stride=2) + h = F.relu(self.bn5_1(self.conv5_1(h)), inplace=True) + h = F.relu(self.bn5_2(self.conv5_2(h)), inplace=True) + h = F.relu(self.bn5_3(self.conv5_3(h)), inplace=True) + h = F.relu(self.bn5_4(self.conv5_4(h)), inplace=True) + h = F.max_pool2d(h, kernel_size=2, stride=2) + h = h.view(h.size(0), -1) + out = self.fc(h) + return out + + +class VGG19MCDropout(VGG19): + + def __init__(self, num_classes=10, dropout_ratio=0.1, val_mc=10): + super(VGG19MCDropout, self).__init__(num_classes) + self.dropout_ratio = dropout_ratio + self.val_mc = val_mc + + def forward(self, x): + p = self.dropout_ratio + h = F.relu(self.bn1_1(F.dropout(self.conv1_1(x), p)), inplace=True) + h = F.relu(self.bn1_2(F.dropout(self.conv1_2(h), p)), inplace=True) + h = F.max_pool2d(h, kernel_size=2, stride=2) + h = F.relu(self.bn2_1(F.dropout(self.conv2_1(h), p)), inplace=True) + h = F.relu(self.bn2_2(F.dropout(self.conv2_2(h), p)), inplace=True) + h = F.max_pool2d(h, kernel_size=2, stride=2) + h = F.relu(self.bn3_1(F.dropout(self.conv3_1(h), p)), inplace=True) + h = F.relu(self.bn3_2(F.dropout(self.conv3_2(h), p)), inplace=True) + h = F.relu(self.bn3_3(F.dropout(self.conv3_3(h), p)), inplace=True) + h = F.relu(self.bn3_4(F.dropout(self.conv3_4(h), p)), inplace=True) + h = F.max_pool2d(h, kernel_size=2, stride=2) + h = F.relu(self.bn4_1(F.dropout(self.conv4_1(h), p)), inplace=True) + h = F.relu(self.bn4_2(F.dropout(self.conv4_2(h), p)), inplace=True) + h = F.relu(self.bn4_3(F.dropout(self.conv4_3(h), p)), inplace=True) + h = F.relu(self.bn4_4(F.dropout(self.conv4_4(h), p)), inplace=True) + h = F.max_pool2d(h, kernel_size=2, stride=2) + h = F.relu(self.bn5_1(F.dropout(self.conv5_1(h), p)), inplace=True) + h = F.relu(self.bn5_2(F.dropout(self.conv5_2(h), p)), inplace=True) + h = F.relu(self.bn5_3(F.dropout(self.conv5_3(h), p)), inplace=True) + h = F.relu(self.bn5_4(F.dropout(self.conv5_4(h), p)), inplace=True) + h = F.max_pool2d(h, kernel_size=2, stride=2) + h = h.view(h.size(0), -1) + out = F.dropout(self.fc(h), p) + return out + + def mc_prediction(self, x): + + acc_prob = TensorAccumulator() + m = self.val_mc + + for _ in range(m): + output = self.forward(x) + prob = F.softmax(output, dim=1) + acc_prob.update(prob, scale=1/m) + + prob = acc_prob.get() + + return prob diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..db52820 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,22 @@ +[metadata] +name = torchsso +version = 0.1.1 +url = https://github.com/cybertronai/pytorch-sso +author = Kazuki Osawa +author_email = osawa1021@gmail.com +license_file = LICENSE +description = PyTorch-SSO: Scalable Second-Order Optimization Methods in PyTorch. +long_description = file: README.md +classifiers = + Programming Language :: Python :: 3 + +[options] +zip_safe = False +packages = find: +install_requires = + torch + torchvision + chainer + Pillow + numpy + scipy diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..8bf1ba9 --- /dev/null +++ b/setup.py @@ -0,0 +1,2 @@ +from setuptools import setup +setup() diff --git a/torchsso/__init__.py b/torchsso/__init__.py new file mode 100644 index 0000000..ea313da --- /dev/null +++ b/torchsso/__init__.py @@ -0,0 +1,17 @@ +from torchsso import optim # NOQA +from torchsso import utils # NOQA + +from torchsso.curv.curvature import Curvature, DiagCurvature, KronCurvature # NOQA +from torchsso.curv.cov.linear import CovLinear, DiagCovLinear, KronCovLinear # NOQA +from torchsso.curv.cov.conv import CovConv2d, DiagCovConv2d, KronCovConv2d # NOQA +from torchsso.curv.cov.batchnorm import CovBatchNorm1d, DiagCovBatchNorm1d, CovBatchNorm2d, DiagCovBatchNorm2d # NOQA + +from torchsso.curv.hessian import KronHessian # NOQA +from torchsso.curv.hessian.linear import KronHessianLinear # NOQA +from torchsso.curv.hessian.conv import KronHessianConv2d # NOQA + +from torchsso.curv.fisher import get_closure_for_fisher # NOQA +from torchsso.curv.fisher import Fisher # NOQA +from torchsso.curv.fisher.linear import DiagFisherLinear, KronFisherLinear # NOQA +from torchsso.curv.fisher.conv import DiagFisherConv2d, KronFisherConv2d # NOQA +from torchsso.curv.fisher.batchnorm import DiagFisherBatchNorm2d # NOQA diff --git a/torchsso/curv/__init__.py b/torchsso/curv/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/torchsso/curv/cov/__init__.py b/torchsso/curv/cov/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/torchsso/curv/cov/batchnorm.py b/torchsso/curv/cov/batchnorm.py new file mode 100644 index 0000000..38eb41a --- /dev/null +++ b/torchsso/curv/cov/batchnorm.py @@ -0,0 +1,49 @@ +from torchsso import Curvature, DiagCurvature + + +class CovBatchNorm1d(Curvature): + + def update_in_backward(self, grad_output_data): + pass + + +class DiagCovBatchNorm1d(DiagCurvature): + + def update_in_backward(self, grad_output): + data_input = getattr(self._module, 'data_input', None) # n x f + assert data_input is not None + + in_in = data_input.mul(data_input) # n x f + grad_grad = grad_output.mul(grad_output) # n x f + + data_w = in_in.mul(grad_grad).mean(dim=0) # f x 1 + + self._data = [data_w] + + if self.bias: + data_b = grad_grad.mean(dim=0) # f x 1 + self._data.append(data_b) + + +class CovBatchNorm2d(Curvature): + + def update_in_backward(self, grad_output): + pass + + +class DiagCovBatchNorm2d(DiagCurvature): + + def update_in_backward(self, grad_out): + data_input = getattr(self._module, 'data_input', None) # n x c x h x w + assert data_input is not None + + in_in = data_input.mul(data_input).sum(dim=(2, 3)) # n x c + grad_grad = grad_out.mul(grad_out).sum(dim=(2, 3)) # n x c + + data_w = in_in.mul(grad_grad).mean(dim=0) # c x 1 + + self._data = [data_w] + + if self.bias: + data_b = grad_grad.mean(dim=0) # c x 1 + self._data.append(data_b) diff --git a/torchsso/curv/cov/conv.py b/torchsso/curv/cov/conv.py new file mode 100644 index 0000000..42621e0 --- /dev/null +++ b/torchsso/curv/cov/conv.py @@ -0,0 +1,124 @@ +from torchsso import Curvature, DiagCurvature, KronCurvature +import torch +import torch.nn.functional as F + + +class CovConv2d(Curvature): + + def update_in_backward(self, grad_output): + pass + + def precgrad(self, params): + pass + + +class DiagCovConv2d(DiagCurvature): + + def update_in_backward(self, grad_output): + conv2d = self._module + data_input = getattr(conv2d, 'data_input', None) # n x c_in x h_in x w_in + assert data_input is not None + + # n x (c_in)(k_h)(k_w) x (h_out)(w_out) + input2d = F.unfold(data_input, + kernel_size=conv2d.kernel_size, stride=conv2d.stride, + padding=conv2d.padding, dilation=conv2d.dilation) + + # n x c_out x h_out x w_out + n, c_out, h, w = grad_output.shape + # n x c_out x (h_out)(w_out) + grad_output2d = grad_output.reshape(n, c_out, -1) + + grad_in = torch.einsum('bik,bjk->bij', + grad_output2d, input2d) # n x c_out x (c_in)(k_h)(k_w) + + data_w = grad_in.mul(grad_in).mean(dim=0) # c_out x (c_in)(k_h)(k_w) + data_w = data_w.reshape((c_out, -1, *conv2d.kernel_size)) # c_out x c_in x k_h x k_w + self._data = [data_w] + + if self.bias: + grad_grad = grad_output2d.mul(grad_output2d) # n x c_out x (h_out)(w_out) + data_b = grad_grad.sum(dim=2).mean(dim=0) # c_out + self._data.append(data_b) + + +class KronCovConv2d(KronCurvature): + + def update_in_forward(self, data_input): + conv2d = self._module + + # n x (c_in)(k_h)(k_w) x (h_out)(w_out) + input2d = F.unfold(data_input, + kernel_size=conv2d.kernel_size, stride=conv2d.stride, + padding=conv2d.padding, dilation=conv2d.dilation) + + n, a, _ = input2d.shape + + # (c_in)(k_h)(k_w) x n(h_out)(w_out) + m = input2d.transpose(0, 1).reshape(a, -1) + a, b = m.shape + if self.bias: + # {(c_in)(k_h)(k_w) + 1} x n(h_out)(w_out) + m = torch.cat((m, m.new_ones((1, b))), 0) + + # (c_in)(k_h)(k_w) x (c_in)(k_h)(k_w) or + # {(c_in)(k_h)(k_w) + 1} x {(c_in)(k_h)(k_w) + 1} + A = torch.einsum('ik,jk->ij', m, m).div(n) + self._A = A + + def update_in_backward(self, grad_output): + n, c, h, w = grad_output.shape # n x c_out x h_out x w_out + m = grad_output.transpose(0, 1).reshape(c, -1) # c_out x n(h_out)(w_out) + + G = torch.einsum('ik,jk->ij', m, m).div(n*h*w) # c_out x c_out + self._G = G + + def precondition_grad(self, params): + A_inv, G_inv = self.inv + + # todo check params == list? + oc, _, _, _ = params[0].shape + if self.bias: + grad2d = torch.cat( + (params[0].grad.reshape(oc, -1), params[1].grad.view(-1, 1)), 1) + preconditioned_grad2d = G_inv.mm(grad2d).mm(A_inv) + + params[0].grad.copy_(preconditioned_grad2d[:, 0:-1].reshape_as(params[0])) + params[1].grad.copy_(preconditioned_grad2d[:, -1]) + else: + grad2d = params[0].grad.reshape(oc, -1) + preconditioned_grad2d = G_inv.mm(grad2d).mm(A_inv) + + params[0].grad.copy_(preconditioned_grad2d.reshape_as(params[0])) + + def sample_params(self, params, mean, std_scale): + A_ic, G_ic = self.std + oc, ic, h, w = mean[0].shape + if self.bias: + m = torch.cat( + (mean[0].reshape(oc, -1), mean[1].view(-1, 1)), 1) + param = m.add(std_scale, G_ic.mm( + torch.randn_like(m)).mm(A_ic)) + params[0].data.copy_(param[:, 0:-1].reshape(oc, ic, h, w)) + params[1].data.copy_(param[:, -1]) + else: + m = mean[0].reshape(oc, -1) + param = m.add(std_scale, G_ic.mm( + torch.randn_like(m)).mm(A_ic)) + params[0].data = param.reshape(oc, ic, h, w) + + def _get_shape(self): + linear = self._module + w = getattr(linear, 'weight') + c_out, c_in, k_h, k_w = w.shape + + G_shape = (c_out, c_out) + + dim = c_in * k_h * k_w + if self.bias: + A_shape = (dim + 1, dim + 1) + else: + A_shape = (dim, dim) + + return A_shape, G_shape + diff --git a/torchsso/curv/cov/linear.py b/torchsso/curv/cov/linear.py new file mode 100644 index 0000000..7c0fdcc --- /dev/null +++ b/torchsso/curv/cov/linear.py @@ -0,0 +1,115 @@ +import torch +from torchsso import Curvature, DiagCurvature, KronCurvature + + +class CovLinear(Curvature): + + def update_in_backward(self, grad_output): + data_input = getattr(self._module, 'data_input', None) # n x f_in + assert data_input is not None + + n = data_input.shape[0] + + if self.bias: + ones = torch.ones((n, 1), device=data_input.device, dtype=data_input.dtype) + data_input = torch.cat((data_input, ones), 1) # n x (f_in+1) + + grad = torch.einsum('bi,bj->bij', grad_output, data_input) # n x f_out x f_in + grad = grad.reshape((n, -1)) # n x (f_out)(f_in) + + data = torch.einsum('bi,bj->ij', grad, grad) + + self._data = [data] + + def precondition_grad(self, params): + pass + + +class DiagCovLinear(DiagCurvature): + + def update_in_backward(self, grad_output): + data_input = getattr(self._module, 'data_input', None) # n x f_in + assert data_input is not None + + n = data_input.shape[0] + + in_in = data_input.mul(data_input) # n x f_in + grad_grad = grad_output.mul(grad_output) # n x f_out + + data_w = torch.einsum('ki,kj->ij', grad_grad, + in_in).div(n) # f_out x f_in + self._data = [data_w] + + if self.bias: + data_b = grad_grad.mean(dim=0) # f_out x 1 + self._data.append(data_b) + + +class KronCovLinear(KronCurvature): + + def update_in_forward(self, input_data): + n = input_data.shape[0] # n x f_in + if self.bias: + ones = input_data.new_ones((n, 1)) + # shape: n x (f_in+1) + input_data = torch.cat((input_data, ones), 1) + + # f_in x f_in or (f_in+1) x (f_in+1) + A = torch.einsum('ki,kj->ij', input_data, input_data).div(n) + self._A = A + + def update_in_backward(self, grad_output): + n = grad_output.shape[0] # n x f_out + + # f_out x f_out + G = torch.einsum( + 'ki,kj->ij', grad_output, grad_output).div(n) + self._G = G + + def precondition_grad(self, params): + A_inv, G_inv = self.inv + + # todo check params == list? + if self.bias: + grad = torch.cat( + (params[0].grad, params[1].grad.view(-1, 1)), 1) + preconditioned_grad = G_inv.mm(grad).mm(A_inv) + + params[0].grad.copy_(preconditioned_grad[:, :-1]) + params[1].grad.copy_(preconditioned_grad[:, -1]) + else: + grad = params[0].grad + preconditioned_grad = G_inv.mm(grad).mm(A_inv) + + params[0].grad.copy_(preconditioned_grad) + + def sample_params(self, params, mean, std_scale): + A_ic, G_ic = self.std + + if self.bias: + m = torch.cat( + (mean[0], mean[1].view(-1, 1)), 1) + param = m.add(std_scale, G_ic.mm( + torch.randn_like(m)).mm(A_ic)) + params[0].data.copy_(param[:, 0:-1]) + params[1].data.copy_(param[:, -1]) + else: + m = mean[0] + param = mean.add(std_scale, G_ic.mm( + torch.randn_like(m)).mm(A_ic)) + params[0].data = param + + def _get_shape(self): + linear = self._module + w = getattr(linear, 'weight') + f_out, f_in = w.shape + + G_shape = (f_out, f_out) + + if self.bias: + A_shape = (f_in + 1, f_in + 1) + else: + A_shape = (f_in, f_in) + + return A_shape, G_shape + diff --git a/torchsso/curv/curvature.py b/torchsso/curv/curvature.py new file mode 100644 index 0000000..ddfcbc2 --- /dev/null +++ b/torchsso/curv/curvature.py @@ -0,0 +1,327 @@ +import math + +import torch +import torch.nn as nn +import torchsso + +PI_TYPE_TRACENORM = 'tracenorm' + + +class Curvature(object): + r"""Base implementation of the curvatures for each layer. + + This class computes/maintains curvature (data) and EMA/inverse of it for a given layer (module) + which are used for torchsso.optim.SecondOrderOptimizer. + Standard deviation (std) is calculated for torchsso.optim.VIOptimizer based on the inverse. + IE, data -> ema -> inv (-> std) + + Args: + module (torch.nn.Module): a layer with trainable params for which the curvature is computed + ema_decay (float, optional): decay rate for EMA of curvature + damping (float, optional): value to be added to the diagonal of EMA before inverting it + use_max_ema (bool, optional): whether to use the maximum value as EMA + use_sqrt_ema (bool, optional): whether to take the squre root of EMA + """ + + def __init__(self, module: nn.Module, ema_decay=1., damping=1e-7, + use_max_ema=False, use_sqrt_ema=False, + pi_type=PI_TYPE_TRACENORM): + + if ema_decay < 0 or 1 < ema_decay: + raise ValueError("Invalid ema_decay: {}".format(ema_decay)) + if damping < 0: + raise ValueError("Invalid damping: {}".format(damping)) + if pi_type not in [PI_TYPE_TRACENORM]: + raise ValueError("Invalid pi_type: {}".format(pi_type)) + + self._module = module + self.ema_decay = ema_decay + self._damping = damping + self._l2_reg = 0 + self._l2_reg_ema = 0 + + self._data = None + self._acc_data = None + self.ema = None + self.ema_max = None + self.inv = None + self.std = None + + self.use_sqrt_ema = use_sqrt_ema + self.use_max_ema = use_max_ema + + self.pi_type = pi_type + + module.register_forward_hook(self.forward_postprocess) + module.register_backward_hook(self.backward_postprocess) + + @property + def data(self): + return self._data + + @data.setter + def data(self, value): + self._data = value + + @property + def shape(self): + if self._data is None: + return self._get_shape() + + return tuple([d.shape for d in self._data]) + + @property + def device(self): + return next(self._module.parameters()).device + + def _get_shape(self): + size = 0 + for p in self._module.parameters(): + size += p.view(-1).shape[0] + + return tuple((size, size)) + + def element_wise_init(self, value): + init_data = [] + for s in self.shape: + diag = torch.ones(s[0], device=self.device).mul(value) # 1d + diag = torch.diag(diag) # 1d -> 2d + init_data.append(diag) + + self._data = init_data + + @property + def module(self): + return self._module + + @property + def bias(self): + bias = getattr(self._module, 'bias', None) + return False if bias is None else True + + @property + def damping(self): + return self._damping + self._l2_reg_ema + + @property + def l2_reg(self): + return self._l2_reg + + @l2_reg.setter + def l2_reg(self, value): + self._l2_reg = value + + @property + def l2_reg_ema(self): + return self._l2_reg_ema + + def forward_postprocess(self, module, input, output): + assert self._module == module + + data_input = input[0].detach() + + setattr(module, 'data_input', data_input) + setattr(module, 'data_output', output) + + self.update_in_forward(data_input) + + def backward_postprocess(self, module, grad_input, grad_output): + assert self._module == module + + index = 1 if self.bias else 0 + grad_input = None if grad_input[index] is None else grad_input[index].detach() + grad_output = grad_output[0] + + setattr(module, 'grad_input', grad_input) + setattr(module, 'grad_output', grad_output) + + self.update_in_backward(grad_output) + + # adjust grad scale along with 'reduction' in loss function + batch_size = grad_output.shape[0] + self.adjust_data_scale(batch_size**2) + + def adjust_data_scale(self, scale): + self._data = [d.mul(scale) for d in self._data] + + def update_in_forward(self, data_input): + pass + + def update_in_backward(self, grad_output): + raise NotImplementedError + + def step(self, update_std=False, update_inv=True): + # TODO(oosawak): Add check for ema/inv timing + self.update_ema() + if update_inv: + self.update_inv() + if update_std: + self.update_std() + + def update_ema(self): + data = self.data + ema = self.ema + ema_max = self.ema_max + beta = self.ema_decay + if ema is None or beta == 1: + self.ema = [d.clone() for d in data] + if self.use_max_ema and ema_max is None: + self.ema_max = [e.clone() for e in self.ema] + self._l2_reg_ema = self._l2_reg + else: + self.ema = [d.mul(beta).add(1 - beta, e) + for d, e in zip(data, ema)] + self._l2_reg_ema = self._l2_reg * beta + self._l2_reg_ema * (1 - beta) + + if self.use_max_ema: + for e, e_max in zip(self.ema, self.ema_max): + torch.max(e, e_max, out=e_max) + + def update_inv(self): + ema = self.ema if not self.use_max_ema else self.ema_max + self.inv = [self._inv(e) for e in ema] + + def _inv(self, X): + X_damp = add_value_to_diagonal(X, self.damping) + + return torchsso.utils.inv(X_damp) + + def precondition_grad(self, params): + raise NotImplementedError + + def update_std(self): + raise NotImplementedError + + def sample_params(self, params, mean, std_scale): + raise NotImplementedError + + def std_norm(self): + raise NotImplementedError + + +class DiagCurvature(Curvature): + + def _get_shape(self): + return tuple(p.shape for p in self.module.parameters()) + + def element_wise_init(self, value): + self._data = [torch.ones(s, device=self.device).mul(value) for s in self.shape] + + def update_in_backward(self, grad_output_data): + raise NotImplementedError + + def _inv(self, X): + if self.use_sqrt_ema: + X = X.sqrt() + + X_damp = X.add(X.new_ones(X.shape).mul(self.damping)) + + return 1 / X_damp + + def precondition_grad(self, params): + for p, inv in zip(params, self.inv): + preconditioned_grad = inv.mul(p.grad) + + p.grad.copy_(preconditioned_grad) + + def update_std(self): + self.std = [inv.sqrt() for inv in self.inv] + + def sample_params(self, params, mean, std_scale): + for p, m, std in zip(params, mean, self.std): + noise = torch.randn_like(m) + p.data.copy_(torch.addcmul(m, std_scale, noise, std)) + + def std_norm(self): + if self.std is None: + return 0 + + return sum(std.norm().item() for std in self.std) + + +class KronCurvature(Curvature): + + def __init__(self, *args, **kwargs): + super(KronCurvature, self).__init__(*args, **kwargs) + + self._A = None + self._G = None + + @property + def data(self): + return [self._A, self._G] + + @data.setter + def data(self, value): + self._A, self._G = value + + @property + def shape(self): + if self._A is None or self._G is None: + return self._get_shape() + + return self._A.shape, self._G.shape + + def _get_shape(self): + raise NotImplementedError + + def element_wise_init(self, value): + super(KronCurvature, self).element_wise_init(math.sqrt(value)) + self._A, self._G = self._data + + @property + def A(self): + return self._A + + @property + def G(self): + return self._G + + def update_in_forward(self, input_data): + raise NotImplementedError + + def update_in_backward(self, grad_output_data): + raise NotImplementedError + + def adjust_data_scale(self, scale): + self._G.mul_(scale) + + def update_inv(self): + A, G = self.ema + + if self.pi_type == PI_TYPE_TRACENORM: + pi = torch.sqrt((A.trace()/A.shape[0])/(G.trace()/G.shape[0])) + else: + pi = 1. + + r = self.damping**0.5 + self.inv = [torchsso.utils.inv(add_value_to_diagonal(X, value)) + for X, value in zip([A, G], [r*pi, r/pi])] + + def precondition_grad(self, params): + raise NotImplementedError + + def update_std(self): + A_inv, G_inv = self.inv + + self.std = [torchsso.utils.cholesky(X) + for X in [A_inv, G_inv]] + + def sample_params(self, params, mean, std_scale): + raise NotImplementedError + + def std_norm(self): + if self.std is None: + return 0 + + A_ic, G_ic = self.std + return A_ic.norm().item() * G_ic.norm().item() + + +def add_value_to_diagonal(X, value): + if torch.cuda.is_available(): + indices = torch.cuda.LongTensor([[i, i] for i in range(X.shape[0])]) + else: + indices = torch.LongTensor([[i, i] for i in range(X.shape[0])]) + values = X.new_ones(X.shape[0]).mul(value) + return X.index_put(tuple(indices.t()), values, accumulate=True) diff --git a/torchsso/curv/fisher/__init__.py b/torchsso/curv/fisher/__init__.py new file mode 100644 index 0000000..77277ac --- /dev/null +++ b/torchsso/curv/fisher/__init__.py @@ -0,0 +1,90 @@ +import torch +import torch.nn.functional as F + +from torchsso.utils import TensorAccumulator + + +class Fisher(object): + + def __init__(self): + self.prob = None + self._do_backward = True + self._acc_cov = TensorAccumulator() + + @property + def do_backward(self): + return self._do_backward + + def turn_on_backward(self): + self._do_backward = True + + def turn_off_backward(self): + self._do_backward = False + + def accumulate_cov(self, cov): + self._acc_cov.update(cov) + + def finalize(self): + return self._acc_cov.get() + + def update_as_presoftmax(self, prob): + raise NotImplementedError('This method supports only torchsso.KronFisherLinear.') + + +def get_closure_for_fisher(optimizer, model, data, target, approx_type=None, num_mc=1): + + _APPROX_TYPE_MC = 'mc' + + def turn_off_param_grad(): + for group in optimizer.param_groups: + group['curv'].turn_on_backward() + for param in group['params']: + param.requires_grad = False + + def turn_on_param_grad(): + for group in optimizer.param_groups: + group['curv'].turn_off_backward() + for param in group['params']: + param.requires_grad = True + + def closure(): + + for group in optimizer.param_groups: + assert isinstance(group['curv'], Fisher), f"Invalid Curvature type: {type(group['curv'])}." + + optimizer.zero_grad() + output = model(data) + prob = F.softmax(output, dim=1) + + is_sampling = approx_type is None or approx_type == _APPROX_TYPE_MC + + if is_sampling: + turn_off_param_grad() + + if approx_type == _APPROX_TYPE_MC: + dist = torch.distributions.Categorical(prob) + _target = dist.sample((num_mc,)) + for group in optimizer.param_groups: + group['curv'].prob = torch.ones_like(prob[:, 0]).div(num_mc) + + for i in range(num_mc): + loss = F.cross_entropy(output, _target[i]) + loss.backward(retain_graph=True) + else: + for i in range(model.num_classes): + for group in optimizer.param_groups: + group['curv'].prob = prob[:, i] + loss = F.cross_entropy(output, torch.ones_like(target).mul(i)) + loss.backward(retain_graph=True) + + turn_on_param_grad() + + else: + raise ValueError('Invalid approx type: {}'.format(approx_type)) + + loss = F.cross_entropy(output, target) + loss.backward() + + return loss, output + + return closure diff --git a/torchsso/curv/fisher/batchnorm.py b/torchsso/curv/fisher/batchnorm.py new file mode 100644 index 0000000..b0ae2fb --- /dev/null +++ b/torchsso/curv/fisher/batchnorm.py @@ -0,0 +1,32 @@ +import torch +from torchsso import DiagCovBatchNorm2d, Fisher + + +class DiagFisherBatchNorm2d(DiagCovBatchNorm2d, Fisher): + + def __init__(self, *args, **kwargs): + DiagCovBatchNorm2d.__init__(self, *args, **kwargs) + Fisher.__init__(self) + + def update_in_backward(self, grad_out): + if self.do_backward: + assert self.prob is not None + data_input = getattr(self._module, 'data_input', None) # n x c x h x w + assert data_input is not None + + n = grad_out.shape[0] # n x c x h x w + pg = torch.mul(grad_out, self.prob.reshape(n, 1, 1, 1)) + + grad_grad = pg.mul(grad_out).sum(dim=(2, 3)) # n x c + in_in = data_input.mul(data_input).sum(dim=(2, 3)) # n x c + + data_w = in_in.mul(grad_grad).mean(dim=0) # c x 1 + + self._data = [data_w] + + if self.bias: + data_b = grad_grad.mean(dim=0) # c x 1 + self._data.append(data_b) + self.accumulate_cov(self._data) + else: + self._data = self.finalize() diff --git a/torchsso/curv/fisher/conv.py b/torchsso/curv/fisher/conv.py new file mode 100644 index 0000000..3f0fa86 --- /dev/null +++ b/torchsso/curv/fisher/conv.py @@ -0,0 +1,69 @@ +from torchsso import DiagCovConv2d, KronCovConv2d, Fisher +import torch +import torch.nn.functional as F + + +class DiagFisherConv2d(DiagCovConv2d, Fisher): + + def __init__(self, *args, **kwargs): + DiagCovConv2d.__init__(self, *args, **kwargs) + Fisher.__init__(self) + + def update_in_backward(self, grad_output): + + if self.do_backward: + assert self.prob is not None + + conv2d = self._module + data_input = getattr(conv2d, 'data_input', None) # n x c_in x h_in x w_in + assert data_input is not None + + # n x (c_in)(k_h)(k_w) x (h_out)(w_out) + input2d = F.unfold(data_input, + kernel_size=conv2d.kernel_size, stride=conv2d.stride, + padding=conv2d.padding, dilation=conv2d.dilation) + + # n x c_out x h_out x w_out + n, c_out, h, w = grad_output.shape + # n x c_out x (h_out)(w_out) + grad_output2d = grad_output.reshape(n, c_out, -1) + + grad_in = torch.einsum('bik,bjk->bij', + grad_output2d, input2d) # n x c_out x (c_in)(k_h)(k_w) + + pgi = torch.mul(grad_in, self.prob.reshape(n, 1, 1)) + data_w = pgi.mul(grad_in).mean(dim=0) # c_out x (c_in)(k_h)(k_w) + data_w = data_w.reshape((c_out, -1, *conv2d.kernel_size)) # c_out x c_in x k_h x k_w + self._data = [data_w] + + if self.bias: + pg = torch.mul(grad_output2d, self.prob.reshape(n, 1, 1)) + grad_grad = pg.mul(grad_output2d) # n x c_out x (h_out)(w_out) + data_b = grad_grad.sum(dim=2).mean(dim=0) # c_out + self._data.append(data_b) + + self.accumulate_cov(self._data) + else: + self._data = self.finalize() + + +class KronFisherConv2d(KronCovConv2d, Fisher): + + def __init__(self, *args, **kwargs): + KronCovConv2d.__init__(self, *args, **kwargs) + Fisher.__init__(self) + + def update_in_backward(self, grad_output): + if self.do_backward: + assert self.prob is not None + n, c, h, w = grad_output.shape # n x c_out x h_out x w_out + + pg = torch.mul(grad_output, self.prob.reshape(n, 1, 1, 1)) + pm = pg.transpose(0, 1).reshape(c, -1) # c_out x n(h_out)(w_out) + m = grad_output.transpose(0, 1).reshape(c, -1) # c_out x n(h_out)(w_out) + + G = torch.einsum('ik,jk->ij', pm, m).div(n*h*w) # c_out x c_out + self._G = G + self.accumulate_cov(G) + else: + self._G = self.finalize() diff --git a/torchsso/curv/fisher/linear.py b/torchsso/curv/fisher/linear.py new file mode 100644 index 0000000..1e81a52 --- /dev/null +++ b/torchsso/curv/fisher/linear.py @@ -0,0 +1,64 @@ +import torch +from torchsso import DiagCovLinear, KronCovLinear, Fisher + + +class DiagFisherLinear(DiagCovLinear, Fisher): + + def __init__(self, *args, **kwargs): + DiagCovLinear.__init__(self, *args, **kwargs) + Fisher.__init__(self) + + def update_in_backward(self, grad_output): + if self.do_backward: + assert self.prob is not None + + data_input = getattr(self._module, 'data_input', None) # n x f_in + assert data_input is not None + + n = data_input.shape[0] + + in_in = data_input.mul(data_input) # n x f_in + + pg = torch.mul(grad_output, self.prob.reshape(n, 1)) + grad_grad = pg.mul(grad_output) # n x f_out + + data_w = torch.einsum('ki,kj->ij', grad_grad, + in_in).div(n) # f_out x f_in + self._data = [data_w] + + if self.bias: + data_b = grad_grad.mean(dim=0) # f_out x 1 + self._data.append(data_b) + + self.accumulate_cov(self._data) + else: + self._data = self.finalize() + + +class KronFisherLinear(KronCovLinear, Fisher): + + def __init__(self, *args, **kwargs): + KronCovLinear.__init__(self, *args, **kwargs) + Fisher.__init__(self) + + def update_in_backward(self, grad_output): + if self.do_backward: + assert self.prob is not None + n = grad_output.shape[0] # n x f_out + + pg = torch.mul(grad_output, self.prob.reshape(n, 1)) + + # f_out x f_out + G = torch.einsum( + 'ki,kj->ij', pg, grad_output).div(n) + self._G = G + self.accumulate_cov(G) + else: + self._G = self.finalize() + + def update_as_presoftmax(self, prob): + n, dim = prob.shape + cov = torch.einsum('ki,kj->ij', prob, prob).div(n) + fisher_presoftmax = (torch.diag(prob.sum(dim=0)) - cov).div(n) + self._G = fisher_presoftmax + diff --git a/torchsso/curv/hessian/__init__.py b/torchsso/curv/hessian/__init__.py new file mode 100644 index 0000000..decf154 --- /dev/null +++ b/torchsso/curv/hessian/__init__.py @@ -0,0 +1,135 @@ +import torch +from torchsso import KronCurvature + + +class KronHessian(KronCurvature): + + def update_in_forward(self, input_data): + raise NotImplementedError + + def update_in_backward(self, grad_output): + output = getattr(self._module, 'data_output') + + device = grad_output.device + n = grad_output.shape[0] + dim = grad_output.shape[1] + + post_curv = self.post_curv + + if post_curv is not None: + post_module = post_curv.module + + import time + + print('-----------------') + start = time.time() + + print(self.module) + + if post_curv is not None: + post_module = post_curv.module + print(post_module) + + post_output = getattr(post_module, 'data_output') + post_dim = post_output.shape[1] + + post_out_grad_out = torch.zeros((n, post_dim, dim)) # n x post_dim x dim + if post_dim <= dim: + post_output = reshape_4d_to_2d(post_output) + print('n: {}, dim: {}'.format(len(post_output), post_dim)) + for i in range(post_dim): + outputs = tuple(po[i] for po in post_output) + grad = torch.autograd.grad(outputs, output, create_graph=True) + post_out_grad_out[:, i, :] = reshape_4d_to_2d(grad[0], reduce=True) # n x dim + else: + post_grad_output = getattr(post_module, 'grad_output') + grad_output = reshape_4d_to_2d(grad_output) + print('n: {}, dim: {}'.format(len(grad_output), dim)) + for i in range(dim): + outputs = tuple(g[i] for g in grad_output) + grad = torch.autograd.grad(outputs, post_grad_output, create_graph=True) + post_out_grad_out[:, :, i] = reshape_4d_to_2d(grad[0], reduce=True) # n x post_dim + + post_out_grad_out = post_out_grad_out.to(device) + + recursive_approx = getattr(post_curv, 'recursive_approx', False) + if recursive_approx: + equation = 'bij,ik,bkl->bjl' + post_hessian_output = post_curv.G # post_dim x post_dim + else: + equation = 'bij,bik,bkl->bjl' + post_hessian_output = getattr(post_module, 'hessian_output', None) # n x post_dim x post_dim + + msg = 'hessian of loss w.r.t. outputs of post layer' \ + ' have to be computed beforehand.' + assert post_hessian_output is not None, msg + + # compute sample hessian_output based on hessian_output of post module + hessian_output = torch.einsum(equation, + post_out_grad_out, # n x post_dim x dim + post_hessian_output, # n x post_dim x post_dim + post_out_grad_out) # n x post_dim x dim + + del post_module.hessian_output + del post_out_grad_out + + else: + # compute sample hessian_output from scratch + hessian_output = torch.zeros((n, dim, dim)) + print('n: {}, dim: {}'.format(len(grad_output), dim)) + for i in range(dim): + outputs = tuple(g[i] for g in reshape_4d_to_2d(grad_output)) + grad = torch.autograd.grad(outputs, output, create_graph=True) + hessian_output[:, i, :] = reshape_4d_to_2d(grad[0], reduce=True) + + hessian_output = hessian_output.to(device) + setattr(self._module, 'hessian_output', hessian_output) + + # refresh hessian_output + self._G = hessian_output.sum((0,)) # dim x dim + + elapsed = time.time() - start + print('{}s'.format(elapsed)) + + def precondition_grad(self, params): + raise NotImplementedError + + def sample_params(self, params, mean, std_scale): + raise NotImplementedError + + def backward_postprocess(self, module, grad_input, grad_output): + # skip hook for higher order derivative + order = getattr(module, 'derivative_order', 1) + if order > 1: + return + + super(KronHessian, self).backward_postprocess(module, grad_input, grad_output) + + # skip hook for higher order derivative + setattr(module, 'derivative_order', 2) + + def reset_derivative_order(self): + module = self._module + setattr(module, 'derivative_order', 1) + + def step(self, update_std=False): + super(KronHessian, self).step(update_std) + self.reset_derivative_order() + + +def reshape_4d_to_2d(data, reduce=False): + ndim = len(data.shape) + if ndim == 2: + return data + + assert ndim == 4, 'number of dimension of data is expected to be 4, got {}.'.format(ndim) + + if reduce: + # n x c x h x w -> n x c + return data.sum((2, 3)) + else: + n, c, h, w = data.shape + # n x c x h x w -> n x h x w x c -> n*h*w x c + data = data.transpose(1, 2).transpose(2, 3).contiguous().view(n*h*w, c) + return data + diff --git a/torchsso/curv/hessian/conv.py b/torchsso/curv/hessian/conv.py new file mode 100644 index 0000000..2611ec3 --- /dev/null +++ b/torchsso/curv/hessian/conv.py @@ -0,0 +1,10 @@ +from torchsso import KronCovConv2d, KronHessian + + +class KronHessianConv2d(KronCovConv2d, KronHessian): + + def __init__(self, module, ema_decay=1., damping=0, post_curv=None, recursive_approx=False): + KronHessian.__init__(self, module, ema_decay, damping, post_curv, recursive_approx) + + def update_in_backward(self, grad_output): + KronHessian.update_in_backward(self, grad_output) diff --git a/torchsso/curv/hessian/linear.py b/torchsso/curv/hessian/linear.py new file mode 100644 index 0000000..0b78e61 --- /dev/null +++ b/torchsso/curv/hessian/linear.py @@ -0,0 +1,10 @@ +from torchsso import KronCovLinear, KronHessian + + +class KronHessianLinear(KronCovLinear, KronHessian): + + def __init__(self, module, ema_decay=1., damping=0, post_curv=None, recursive_approx=False): + KronHessian.__init__(self, module, ema_decay, damping, post_curv, recursive_approx) + + def update_in_backward(self, grad_output): + KronHessian.update_in_backward(self, grad_output) diff --git a/torchsso/optim/__init__.py b/torchsso/optim/__init__.py new file mode 100644 index 0000000..2f8db32 --- /dev/null +++ b/torchsso/optim/__init__.py @@ -0,0 +1,4 @@ +from torchsso.optim.firstorder import DistributedFirstOrderOptimizer # NOQA +from torchsso.optim.secondorder import SecondOrderOptimizer, DistributedSecondOrderOptimizer # NOQA +from torchsso.optim.vi import VIOptimizer, DistributedVIOptimizer # NOQA +from torchsso.optim import lr_scheduler # NOQA diff --git a/torchsso/optim/firstorder.py b/torchsso/optim/firstorder.py new file mode 100644 index 0000000..fb627ee --- /dev/null +++ b/torchsso/optim/firstorder.py @@ -0,0 +1,58 @@ +from torch.optim import Optimizer +from torch.nn.utils import parameters_to_vector, vector_to_parameters + + +class DistributedFirstOrderOptimizer(Optimizer): + + def __init__(self, optimizer, model, dist, lars=False): + super(DistributedFirstOrderOptimizer, self).__setattr__( + 'actual_optimizer', optimizer + ) + super(DistributedFirstOrderOptimizer, self).__setattr__( + 'model', model + ) + super(DistributedFirstOrderOptimizer, self).__setattr__( + 'dist', dist + ) + super(DistributedFirstOrderOptimizer, self).__setattr__( + 'lars', lars + ) + + def step(self, closure=None, thr=1e-2, eps=1e-9): + loss = None + if closure is not None: + loss = closure() + world_size = self.dist.get_world_size() + grads = [p.grad for p in self.model.parameters()] + # pack + packed_tensor = parameters_to_vector(grads) + # all reduce + self.dist.all_reduce(packed_tensor) + # unpack + vector_to_parameters(packed_tensor.div_(world_size), grads) + + if self.lars: + for group in self.param_groups: + for p in group['params']: + setattr(p, 'data_pre', p.data.detach().clone()) + + self.actual_optimizer.step(closure=None) + + if self.lars: + for group in self.param_groups: + for p in group['params']: + d_norm_pre = p.data_pre.norm() + if d_norm_pre > thr: + upd = p.data - p.data_pre + upd_norm = upd.norm() + rate = group['lr'] * d_norm_pre / (upd_norm + eps) + p.data = p.data_pre.add(rate, upd) + + return loss + + def __getattr__(self, item): + return getattr(self.actual_optimizer, item) + + def __setattr__(self, key, value): + setattr(self.actual_optimizer, key, value) + diff --git a/torchsso/optim/lr_scheduler.py b/torchsso/optim/lr_scheduler.py new file mode 100644 index 0000000..ce51087 --- /dev/null +++ b/torchsso/optim/lr_scheduler.py @@ -0,0 +1,133 @@ +from torch.optim import Optimizer + + +class _IterLRScheduler(object): + def __init__(self, optimizer, last_iter=-1): + if not isinstance(optimizer, Optimizer): + raise TypeError('{} is not an Optimizer'.format( + type(optimizer).__name__)) + self.optimizer = optimizer + if last_iter == -1: + for group in optimizer.param_groups: + group.setdefault('initial_lr', group['lr']) + else: + for i, group in enumerate(optimizer.param_groups): + if 'initial_lr' not in group: + raise KeyError("param 'initial_lr' is not specified " + "in param_groups[{}] when resuming an optimizer".format(i)) + self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups)) + self.step(last_iter + 1) + self.last_iter = last_iter + self.scheduler_type = 'iter' + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + """ + return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + Arguments: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def get_lr(self): + raise NotImplementedError + + def step(self, iter=None): + if iter is None: + iter = self.last_iter + 1 + self.last_iter = iter + for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): + param_group['lr'] = lr + + +class PolynomialDecayIterLR(_IterLRScheduler): + """Set the learning rate of each parameter group to the initial lr decayed + by gamma every iter. When last_iter=-1, sets initial lr as lr. + Args: + optimizer (Optimizer): Wrapped optimizer. + gamma (float): Multiplicative factor of learning rate decay. + last_iter (int): The index of last iter. Default: -1. + """ + + def __init__(self, optimizer, rate, max_count, target=None, start_iter=0, last_iter=-1): + self.rate = rate + self.max_count = max_count + self.target = target + self.start_iter = start_iter + super(PolynomialDecayIterLR, self).__init__(optimizer, last_iter) + + def get_lr(self): + if self.last_iter < self.start_iter: + return [param_group['lr'] + for param_group in self.optimizer.param_groups] + decay = max(1-(self.last_iter-self.start_iter) / (self.max_count-self.start_iter), 0) + if self.target is not None: + if self.rate > 0: + return [self.target if self.target / (base_lr * decay ** self.rate) > 1 + else base_lr * decay ** self.rate + for base_lr in self.base_lrs] + else: + return [self.target if self.target / (base_lr * decay ** self.rate) < 1 + else base_lr * decay ** self.rate + for base_lr in self.base_lrs] + return [base_lr * decay ** self.rate + for base_lr in self.base_lrs] + + +class GradualWarmupIterLR(_IterLRScheduler): + """Set the learning rate of each parameter group to the initial lr decayed + by gamma every iter. When last_iter=-1, sets initial lr as lr. + Args: + optimizer (Optimizer): Wrapped optimizer. + gamma (float): Multiplicative factor of learning rate decay. + last_iter (int): The index of last iter. Default: -1. + """ + + def __init__(self, optimizer, initial_lr, max_count, last_iter=-1): + self.initial_lr = initial_lr + self.max_count = max_count + super(GradualWarmupIterLR, self).__init__(optimizer, last_iter) + + def get_lr(self): + if self.last_iter > self.max_count: + return [param_group['lr'] + for param_group in self.optimizer.param_groups] + else: + alpha = self.last_iter / self.max_count + return [self.initial_lr*(1-alpha) + base_lr*alpha + for base_lr in self.base_lrs] + + +class MomentumCorrectionLR(object): + + def __init__(self, scheduler): + super(MomentumCorrectionLR, self).__setattr__( + 'scheduler', scheduler) + + for group in self.optimizer.param_groups: + group['init_momentum'] = group['momentum'] + + def step(self, count=None): + self.scheduler.step(count) + + for group in self.optimizer.param_groups: + lr = group['lr'] + lr_pre = group.get('lr_pre', None) + + if lr_pre is not None: + m = group.get('init_momentum', 0) + group['momentum'] = m * lr / lr_pre + + group['lr_pre'] = group['lr'] + + def __getattr__(self, item): + return getattr(self.scheduler, item) + + def __setattr__(self, key, value): + setattr(self.scheduler, key, value) diff --git a/torchsso/optim/secondorder.py b/torchsso/optim/secondorder.py new file mode 100644 index 0000000..fe97a7f --- /dev/null +++ b/torchsso/optim/secondorder.py @@ -0,0 +1,391 @@ +from collections import defaultdict +import math + +import numpy as np + +import torch +import torch.nn as nn +from torch.optim import Optimizer +import torchsso +from torchsso.utils import TensorAccumulator +from torchsso.utils.chainer_communicators import create_communicator +from torchsso.utils.chainer_communicators import _utility + + +class SecondOrderOptimizer(Optimizer): + r"""An optimizer for Second-Order Optimization. + + This optimizer manages the curvatures for each layer as a collection + of torchsso.Curvature instance. + This optimizer updates the params with the gradients pre-conditioned + by the inverse of the curvature for each layer. + + Args: + model (torch.nn.Module): model with parameters to be trained + curv_type (str): type of the curvature ('Hessian', 'Fisher', or 'Cov') + curv_shapes (dict): shape the curvatures for each type of layer + curv_kwargs (dict): arguments (with keys) to be passed to torchsso.Curvature.__init__() + lr (float, optional): learning rate + momentum (float, optional): momentum factor + momentum_type (str, optional): type of gradients of which momentum + is calculated ('raw' or 'preconditioned') + grad_ema_decay (float, optional): decay rate for EMA of gradients + grad_ema_type (str, optional): type of gradients of which EMA + is calculated ('raw' or 'preconditioned') + l2_reg (float, optional): L2 penalty + weight_decay (float, optional): weight decay + normalizing_weights (bool, optional): whether the scale of the params + are normalized after each step + weight_scale (float, optional): the scale of the params for normalizing weights + acc_steps (int, optional): number of steps for which gradients and curvatures + are accumulated before each step + non_reg_for_bn (bool, optional): whether the regularization is applied to BatchNorm params + bias_correction (bool, optional): whether the bias correction (refer torch.optim.Adam) is applied + lars (bool, optional): whether LARS (https://arxiv.org/abs/1708.03888) is applied + lars_type (str, optional): type of gradients of which LARS + is applied ('raw' or 'preconditioned') + update_inv (bool, optional): whether to update curvature inverses at each step + precondition_grad (bool, optional): whether to apply preconditioning + (if False, this optimizer works as SGD) + + Example: + >>> curv_shapes = {"Conv2d": "Kron", "Linear": "Diag"} + >>> curv_kwargs = {"damping": 1e-3, "ema_decay": 0.999} + >>> optimizer = torchsso.optim.SecondOrderOptimizer(model, "Cov", curv_shapes, curv_kwargs) + >>> + >>> def closure(): + >>> optimizer.zero_grad() + >>> output = model(data) + >>> loss = F.cross_entropy(output, target) + >>> loss.backward(create_graph=args.create_graph) + >>> return loss, output + >>> + >>> optimizer.step(closure=closure) + """ + + def __init__(self, model: nn.Module, curv_type: str, curv_shapes: dict, curv_kwargs: dict, + lr=0.01, momentum=0., momentum_type='preconditioned', + grad_ema_decay=1., grad_ema_type='raw', l2_reg=0., weight_decay=0., + normalizing_weights=False, weight_scale=None, + acc_steps=1, non_reg_for_bn=False, bias_correction=False, + lars=False, lars_type='preconditioned', update_inv=True, precondition_grad=True): + + if lr < 0: + raise ValueError("Invalid learning rate: {}".format(lr)) + if momentum < 0: + raise ValueError("Invalid momentum: {}".format(momentum)) + if momentum > 0 and momentum_type not in ['raw', 'preconditioned']: + raise ValueError("Invalid momentum type: {}".format(momentum_type)) + if grad_ema_decay < 0 or 1 < grad_ema_decay: + raise ValueError("Invalid grad_ema value: {}".format(grad_ema_decay)) + if grad_ema_decay > 0 and grad_ema_type not in ['raw', 'preconditioned']: + raise ValueError("Invalid grad_ema type: {}".format(grad_ema_type)) + if l2_reg < 0: + raise ValueError("Invalid l2_reg value: {}".format(l2_reg)) + if weight_decay < 0: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if acc_steps < 1: + raise ValueError("Invalid acc_steps: {}".format(acc_steps)) + if lars and lars_type not in ['raw', 'preconditioned']: + raise ValueError("Invalid LARS type: {}".format(lars_type)) + if normalizing_weights and weight_scale is not None and weight_scale <= 0: + raise ValueError("Invalid weight scale for LARS: {}".format(weight_scale)) + + self.model = model + defaults = {'lr': lr, 'momentum': momentum, 'momentum_type': momentum_type, + 'grad_ema_decay': grad_ema_decay, 'grad_ema_type': grad_ema_type, + 'l2_reg': l2_reg, 'weight_decay': weight_decay, + 'normalizing_weights': normalizing_weights, 'weight_scale': weight_scale, + 'acc_steps': acc_steps, 'bias_correction': bias_correction, + 'lars': lars, 'lars_type': lars_type} + defaults.update(curv_kwargs) + self.defaults = defaults + self.state = defaultdict(dict) + self.optim_state = {'step': 0, 'acc_step': 0} + + self.param_groups = [] + self.curv_type = curv_type + self.curv_shapes = {} if curv_shapes is None else curv_shapes + self.update_inv = update_inv + self.precondition_grad = precondition_grad + + for module in model.modules(): + if len(list(module.children())) > 0: + continue + params = list(module.parameters()) + + curv_class = self.get_curv_class(module) + curvature = curv_class(module, **curv_kwargs) + + group = { + 'params': params, + 'curv': curvature, + 'acc_curv': TensorAccumulator(), + 'acc_grads': TensorAccumulator() + } + + self.add_param_group(group) + self.init_buffer(params) + + if non_reg_for_bn and \ + isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): + group['l2_reg'] = 0 + group['weight_decay'] = 0 + group['normalizing_weights'] = False + + def init_buffer(self, params): + for p in params: + state = self.state[p] + state['momentum_buffer'] = torch.zeros_like(p.data) + state['grad_ema_buffer'] = torch.zeros_like(p.data) + + @property + def local_param_groups(self): + return self.param_groups + + def get_curv_class(self, module): + module_name = module.__class__.__name__ + curv_shape = self.curv_shapes.get(module_name, '') + curv_name = curv_shape + self.curv_type + module_name + curv_class = getattr(torchsso, curv_name, None) + + assert curv_class is not None, f"Failed to lookup Curvature class {curv_name} for {module}." + + return curv_class + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + + n = self.defaults['acc_steps'] + loss = None + + if closure is not None: + # forward and backward + loss = closure() + + # accumulate + for group in self.param_groups: + params = group['params'] + + grads = [p.grad.data for p in params] + group['acc_grads'].update(grads, scale=1/n) + + curv = group['curv'] + if curv is not None: + group['acc_curv'].update(curv.data, scale=1/n) + + # update acc step + self.optim_state['acc_step'] += 1 + if self.optim_state['acc_step'] < n: + return loss + else: + self.optim_state['acc_step'] = 0 + + self.backward_postprocess() + + self.optim_state['step'] += 1 + + for group in self.local_param_groups: + + self.update_preprocess(group, grad_type='raw') + + # update curvature + params, curv = group['params'], group['curv'] + if curv is not None: + curv.step(update_inv=self.update_inv) + if self.precondition_grad: + curv.precondition_grad(params) + + # update params + self.update_preprocess(group, grad_type='preconditioned') + self.update(group) + self.update_postprocess(group) + + return loss + + def backward_postprocess(self, target='params'): + for group in self.param_groups: + params = group[target] + + acc_grads = group['acc_grads'].get() + for p, acc_grad in zip(params, acc_grads): + p.grad = acc_grad.clone() + + curv = group['curv'] + if curv is not None: + curv.data = group['acc_curv'].get() + + def update(self, group, target='params'): + params = group[target] + for p in params: + grad = p.grad + if grad is None: + continue + p.data.add_(-group['lr'], grad) + + def update_preprocess(self, group, target='params', grad_type='raw'): + assert grad_type in ['raw', 'preconditioned'], 'Invalid grad type: {}.'.format(grad_type) + params = group[target] + state = self.state + + def apply_l2_reg(p, grad): + if group['l2_reg'] != 0: + if grad.is_sparse: + raise RuntimeError( + "l2 regularization option is not compatible with sparse gradients") + grad.add_(group['l2_reg'], p.data) + curv = group['curv'] + if curv is not None: + curv.l2_reg = group['l2_reg'] + + def apply_weight_decay(p, grad): + if group['weight_decay'] != 0: + if hasattr(grad, 'is_sparse') and grad.is_sparse: + raise RuntimeError( + "weight_decay option is not compatible with sparse gradients") + grad.add_(group['weight_decay'], p.data) + + def apply_momentum(p, grad): + momentum = group['momentum'] + + if momentum != 0: + buf = state[p]['momentum_buffer'] + buf.mul_(momentum).add_(grad) + grad.copy_(buf) + + def apply_grad_ema_decay(p, grad): + grad_ema_decay = group['grad_ema_decay'] + if grad_ema_decay != 1: + buf = state[p]['grad_ema_buffer'] + buf.mul_(1 - grad_ema_decay).add_(grad.mul(grad_ema_decay)) + grad.copy_(buf) + + def apply_bias_correction(grad): + curv = group['curv'] + beta1 = 1 - group['grad_ema_decay'] + beta2 = 1 - curv.ema_decay + + bias_correction1 = 1 - beta1 ** self.optim_state['step'] + bias_correction2 = 1 - beta2 ** self.optim_state['step'] + if getattr(curv, 'use_sqrt_ema', False): + bias_correction2 = math.sqrt(bias_correction2) + + grad.mul_(bias_correction2 / bias_correction1) + + def apply_lars(p, grad, thr=1e-2, eps=1e-9): + d_norm = p.data.norm() + if d_norm > thr: + g_norm = grad.norm() + rate = d_norm / (g_norm + eps) + grad.mul_(rate) + + for p in params: + + grad = p.grad + + if grad is None: + continue + + if grad_type == 'raw': + apply_l2_reg(p, grad) + + if grad_type == 'preconditioned': + apply_weight_decay(p, grad) + + if group['momentum_type'] == grad_type: + apply_momentum(p, grad) + + if group['grad_ema_type'] == grad_type: + apply_grad_ema_decay(p, grad) + + if grad_type == 'preconditioned' and group['bias_correction']: + apply_bias_correction(grad) + + if group['lars_type'] == grad_type and group['lars']: + apply_lars(p, grad) + + def update_postprocess(self, group, target='params'): + params = group[target] + curv = group['curv'] + + def apply_normalizing_weights(p, thr=1e-2, eps=1e-9): + d_norm = p.data.norm() + if d_norm > thr: + scale = group['weight_scale'] + if scale is None: + scale = np.sqrt(2.0 * w.data.shape[0]) + p.data.div_(d_norm + eps).mul_(scale) + + if group['normalizing_weights']: + for p, _p in zip(params, group['params']): + w = getattr(curv.module, 'weight', None) + if w is not None and w is _p: + apply_normalizing_weights(p) + + +class DistributedSecondOrderOptimizer(SecondOrderOptimizer): + + def __init__(self, *args, **kwargs): + + self.actual_optimizer.__init__(self, *args, **kwargs) + + self.comm = create_communicator() + + local_size = self.comm.size + local_rank = self.comm.rank + indices = np.array_split(np.arange(len(self.param_groups)), local_size) + indices = [local_indices.tolist() for local_indices in indices] + local_indices = indices[local_rank] + local_param_groups = [self.param_groups[i] for i in local_indices] + + self.indices = indices + self.local_indices = local_indices + self._local_param_groups = local_param_groups + setattr(self.comm, 'indices', indices) + + @property + def actual_optimizer(self): + return SecondOrderOptimizer + + @property + def local_param_groups(self): + return self._local_param_groups + + def extractors_for_rsv(self): + extractors = [_utility.extract_attr_from_params('grad'), + _utility.extract_attr_from_curv('data', True)] + return extractors + + def extractors_for_agv(self): + extractors = [_utility.extract_attr_from_params('data')] + return extractors + + def backward_postprocess(self, target='params'): + self.actual_optimizer.backward_postprocess(self, target) + # reduce_scatter_v + self.comm.reduce_scatterv_data(self.param_groups, self.extractors_for_rsv()) + + def is_updated(self): + return self.optim_state['acc_step'] == 0 + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + + ret = self.actual_optimizer.step(self, closure) + + if self.is_updated(): + # all_gather_v + self.comm.allgatherv_data(self.param_groups, self.extractors_for_agv()) + + return ret + diff --git a/torchsso/optim/vi.py b/torchsso/optim/vi.py new file mode 100644 index 0000000..eacc19a --- /dev/null +++ b/torchsso/optim/vi.py @@ -0,0 +1,314 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchsso.optim import SecondOrderOptimizer, DistributedSecondOrderOptimizer +from torchsso.utils import TensorAccumulator +from torchsso.utils.chainer_communicators import _utility + + +class VIOptimizer(SecondOrderOptimizer): + r"""An optimizer for Variational Inference (VI) based on torch.optim.SecondOrderOptimizer. + + This optimizer manages the posterior distribution (mean and covariance of multivariate Gaussian) + of params for each layer. + + Args: + model (torch.nn.Module): model with parameters to be trained + model (float): dataset size + curv_type (str): type of the curvature ('Hessian', 'Fisher', or 'Cov') + curv_shapes (dict): shape the curvatures for each type of layer + curv_kwargs (dict): arguments (with keys) to be passed to torchsso.Curvature.__init__() + lr (float, optional): learning rate + momentum (float, optional): momentum factor + momentum_type (str, optional): type of gradients of which momentum + is calculated ('raw' or 'preconditioned') + grad_ema_decay (float, optional): decay rate for EMA of gradients + grad_ema_type (str, optional): type of gradients of which EMA + is calculated ('raw' or 'preconditioned') + weight_decay (float, optional): weight decay + normalizing_weights (bool, optional): whether the scale of the params + are normalized after each step + weight_scale (float, optional): the scale of the params for normalizing weights + acc_steps (int, optional): number of steps for which gradients and curvatures + are accumulated before each step + non_reg_for_bn (bool, optional): whether the regularization is applied to BatchNorm params + bias_correction (bool, optional): whether the bias correction (refer torch.optim.Adam) is applied + lars (bool, optional): whether LARS (https://arxiv.org/abs/1708.03888) is applied + lars_type (str, optional): type of gradients of which LARS + is applied ('raw' or 'preconditioned') + num_mc_samples (int, optional): number of MC samples taken from the posterior in each step + val_num_mc_samples (int, optional): number of MC samples taken from the posterior for evaluation + kl_weighting (float, optional): KL weighting (https://arxiv.org/abs/1712.02390) + warmup_kl_weighting_init (float, optional): initial KL weighting for warming up the value + warmup_kl_weighting_steps (float, optional): number of steps until the value reaches the kl_weighting + prior_variance (float, optional): variance of the prior distribution (Gaussian) of each param + init_precision (float, optional): initial (diagonal) precision of the posterior of params + """ + + def __init__(self, model: nn.Module, dataset_size: float, curv_type: str, curv_shapes: dict, curv_kwargs: dict, + lr=0.01, momentum=0., momentum_type='preconditioned', + grad_ema_decay=1., grad_ema_type='raw', weight_decay=0., + normalizing_weights=False, weight_scale=None, + acc_steps=1, non_reg_for_bn=False, bias_correction=False, + lars=False, lars_type='preconditioned', + num_mc_samples=10, val_num_mc_samples=10, + kl_weighting=1, warmup_kl_weighting_init=0.01, warmup_kl_weighting_steps=None, + prior_variance=1, init_precision=None, + seed=1, total_steps=1000): + + if dataset_size < 0: + raise ValueError("Invalid dataset size: {}".format(dataset_size)) + if num_mc_samples < 1: + raise ValueError("Invalid number of MC samples: {}".format(num_mc_samples)) + if val_num_mc_samples < 0: + raise ValueError("Invalid number of MC samples for validation: {}".format(val_num_mc_samples)) + if kl_weighting < 0: + raise ValueError("Invalid KL weighting: {}".format(kl_weighting)) + if warmup_kl_weighting_steps is not None and warmup_kl_weighting_init < 0: + raise ValueError("Invalid initial KL weighting: {}".format(warmup_kl_weighting_init)) + if prior_variance < 0: + raise ValueError("Invalid prior variance: {}".format(prior_variance)) + if init_precision is not None and init_precision < 0: + raise ValueError("Invalid initial precision: {}".format(init_precision)) + + init_kl_weighting = kl_weighting if warmup_kl_weighting_steps is None else warmup_kl_weighting_init + l2_reg = init_kl_weighting / dataset_size / prior_variance if prior_variance != 0 else 0 + std_scale = math.sqrt(init_kl_weighting / dataset_size) + + super(VIOptimizer, self).__init__(model, curv_type, curv_shapes, curv_kwargs, + lr=lr, momentum=momentum, momentum_type=momentum_type, + grad_ema_decay=grad_ema_decay, grad_ema_type=grad_ema_type, + l2_reg=l2_reg, weight_decay=weight_decay, + normalizing_weights=normalizing_weights, weight_scale=weight_scale, + acc_steps=acc_steps, non_reg_for_bn=non_reg_for_bn, + bias_correction=bias_correction, + lars=lars, lars_type=lars_type) + + self.defaults['std_scale'] = std_scale + self.defaults['kl_weighting'] = kl_weighting + self.defaults['warmup_kl_weighting_init'] = warmup_kl_weighting_init + self.defaults['warmup_kl_weighting_steps'] = warmup_kl_weighting_steps + self.defaults['num_mc_samples'] = num_mc_samples + self.defaults['val_num_mc_samples'] = val_num_mc_samples + self.defaults['total_steps'] = total_steps + self.defaults['seed_base'] = seed + + for group in self.param_groups: + group['std_scale'] = 0 if group['l2_reg'] == 0 else std_scale + group['mean'] = [p.data.detach().clone() for p in group['params']] + self.init_buffer(group['mean']) + + if init_precision is not None: + curv = group['curv'] + curv.element_wise_init(init_precision) + curv.step(update_std=(group['std_scale'] > 0)) + + def zero_grad(self): + r"""Clears the gradients of all optimized :class:`torch.Tensor` s.""" + for group in self.param_groups: + for m in group['mean']: + if m.grad is not None: + m.grad.detach_() + m.grad.zero_() + + super(VIOptimizer, self).zero_grad() + + @property + def seed(self): + return self.optim_state['step'] + self.defaults['seed_base'] + + def set_random_seed(self, seed=None): + if seed is None: + seed = self.seed + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + def sample_params(self): + + for group in self.param_groups: + params, mean = group['params'], group['mean'] + curv = group['curv'] + if curv is not None and curv.std is not None: + # sample from posterior + curv.sample_params(params, mean, group['std_scale']) + else: + for p, m in zip(params, mean): + p.data.copy_(m.data) + + def copy_mean_to_params(self): + for group in self.param_groups: + params, mean = group['params'], group['mean'] + for p, m in zip(params, mean): + p.data.copy_(m.data) + if getattr(p, 'grad', None) is not None \ + and getattr(m, 'grad', None) is not None: + p.grad.copy_(m.grad) + + def adjust_kl_weighting(self): + warmup_steps = self.defaults['warmup_kl_weighting_steps'] + if warmup_steps is None: + return + + current_step = self.optim_state['step'] + if warmup_steps < current_step: + return + + target_kl = self.defaults['kl_weighting'] + init_kl = self.defaults['warmup_kl_weighting_init'] + + rate = current_step / warmup_steps + kl_weighting = init_kl + rate * (target_kl - init_kl) + + rate = kl_weighting / init_kl + l2_reg = rate * self.defaults['l2_reg'] + std_scale = math.sqrt(rate) * self.defaults['std_scale'] + for group in self.param_groups: + if group['l2_reg'] > 0: + group['l2_reg'] = l2_reg + if group['std_scale'] > 0: + group['std_scale'] = std_scale + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + + def closure(): + # forward/backward + return loss, output + """ + + m = self.defaults['num_mc_samples'] + n = self.defaults['acc_steps'] + + acc_loss = TensorAccumulator() + acc_prob = TensorAccumulator() + + self.set_random_seed() + + for _ in range(m): + + # sampling + self.sample_params() + + # forward and backward + loss, output = closure() + + acc_loss.update(loss, scale=1/m) + prob = F.softmax(output, dim=1) + acc_prob.update(prob, scale=1/n) + + # accumulate + for group in self.param_groups: + params = group['params'] + + grads = [p.grad.data for p in params] + group['acc_grads'].update(grads, scale=1/m/n) + + curv = group['curv'] + if curv is not None: + group['acc_curv'].update(curv.data, scale=1/m/n) + + loss, prob = acc_loss.get(), acc_prob.get() + + # update acc step + self.optim_state['acc_step'] += 1 + if self.optim_state['acc_step'] < n: + return loss, prob + else: + self.optim_state['acc_step'] = 0 + + self.backward_postprocess(target='mean') + self.optim_state['step'] += 1 + + # update distribution + for group in self.local_param_groups: + + self.update_preprocess(group, target='mean', grad_type='raw') + + # update covariance + mean, curv = group['mean'], group['curv'] + if curv is not None: + curv.step(update_std=(group['std_scale'] > 0)) + curv.precondition_grad(mean) + + # update mean + self.update_preprocess(group, target='mean', grad_type='preconditioned') + self.update(group, target='mean') + self.update_postprocess(group, target='mean') + + # copy mean to param + params = group['params'] + for p, m in zip(params, mean): + p.data.copy_(m.data) + p.grad.copy_(m.grad) + + self.adjust_kl_weighting() + + return loss, prob + + def prediction(self, data): + + self.set_random_seed(self.optim_state['step']) + + acc_prob = TensorAccumulator() + mc_samples = self.defaults['val_num_mc_samples'] + + use_mean = mc_samples == 0 + n = 1 if use_mean else mc_samples + + for _ in range(n): + + if use_mean: + self.copy_mean_to_params() + else: + # sampling + self.sample_params() + + output = self.model(data) + prob = F.softmax(output, dim=1) + acc_prob.update(prob, scale=1/n) + + self.copy_mean_to_params() + + prob = acc_prob.get() + + return prob + + +class DistributedVIOptimizer(DistributedSecondOrderOptimizer, VIOptimizer): + + def __init__(self, *args, mc_group_id=0, **kwargs): + super(DistributedVIOptimizer, self).__init__(*args, **kwargs) + self.defaults['seed_base'] += mc_group_id * self.defaults['total_steps'] + + @property + def actual_optimizer(self): + return VIOptimizer + + def zero_grad(self): + self.actual_optimizer.zero_grad(self) + + def extractors_for_rsv(self): + extractors = [_utility.extract_attr_from_params('grad', target='mean'), + _utility.extract_attr_from_curv('data', True)] + return extractors + + def extractors_for_agv(self): + extractors = [_utility.extract_attr_from_params('data', target='mean'), + _utility.extract_attr_from_curv('std', True)] + return extractors + + def step(self, closure=None): + ret = super(DistributedVIOptimizer, self).step(closure) + + if self.is_updated(): + self.copy_mean_to_params() + + return ret + diff --git a/torchsso/utils/__init__.py b/torchsso/utils/__init__.py new file mode 100644 index 0000000..94ae857 --- /dev/null +++ b/torchsso/utils/__init__.py @@ -0,0 +1,4 @@ +from torchsso.utils.logger import Logger # NOQA +from torchsso.utils.inv_cupy import inv # NOQA +from torchsso.utils.cholesky_cupy import cholesky # NOQA +from torchsso.utils.accumulator import TensorAccumulator # NOQA diff --git a/torchsso/utils/accumulator.py b/torchsso/utils/accumulator.py new file mode 100644 index 0000000..3b6aa5f --- /dev/null +++ b/torchsso/utils/accumulator.py @@ -0,0 +1,57 @@ +from torch import Tensor + + +class TensorAccumulator(object): + + def __init__(self): + self._accumulation = None + + def check_type(self, data): + accumulation = self._accumulation + + if isinstance(data, list): + assert type(data[0]) == Tensor, 'the type of data has to be list of torch.Tensor or torch.Tensor' + else: + assert type(data) == Tensor, 'the type of data has to be list of torch.Tensor or torch.Tensor' + + if accumulation is not None: + assert type(data) == type(accumulation), \ + 'the type of data ({}) is different from ' \ + 'the type of the accumulation ({})'.format( + type(data), type(accumulation)) + + def update(self, data, scale=1.): + self.check_type(data) + + accumulation = self._accumulation + + if isinstance(data, list): + if accumulation is None: + self._accumulation = [d.mul(scale) for d in data] + else: + self._accumulation = [acc.add(scale, d) + for acc, d in zip(accumulation, data)] + else: + if accumulation is None: + self._accumulation = data.mul(scale) + else: + self._accumulation = accumulation.add(scale, data) + + def get(self, clear=True): + accumulation = self._accumulation + if accumulation is None: + return + + if isinstance(accumulation, list): + data = [d.clone() for d in self._accumulation] + else: + data = accumulation.clone() + + if clear: + self.clear() + + return data + + def clear(self): + self._accumulation = None + diff --git a/torchsso/utils/chainer_communicators/__init__.py b/torchsso/utils/chainer_communicators/__init__.py new file mode 100644 index 0000000..becc9b3 --- /dev/null +++ b/torchsso/utils/chainer_communicators/__init__.py @@ -0,0 +1,38 @@ +import numpy as np + + +def create_communicator(communicator_name='pure_nccl', + mpi_comm=None, + rsv_comm_dtype=np.float32, + agv_comm_dtype=np.float32, + use_hiercoll=False, + dims=None, + ): + if mpi_comm is None: + import mpi4py.MPI + mpi_comm = mpi4py.MPI.COMM_WORLD + + if communicator_name != 'pure_nccl' and rsv_comm_dtype != np.float32: + raise ValueError( + 'rsv_comm_dtype is only available at \'pure_nccl\' communicator') + + if communicator_name != 'pure_nccl' and agv_comm_dtype != np.float32: + raise ValueError( + 'agv_comm_dtype is only available at \'pure_nccl\' communicator') + + if communicator_name != 'pure_nccl' and dims is not None: + raise ValueError( + 'dims is only available at \'pure_nccl\' communicator') + + if communicator_name == 'pure_nccl': + from torchsso.utils.chainer_communicators.pure_nccl_communicator \ + import PureNCCLCommunicator + return PureNCCLCommunicator(mpi_comm, + rsv_comm_dtype=rsv_comm_dtype, + agv_comm_dtype=agv_comm_dtype, + use_hiercoll=use_hiercoll, + dims=dims + ) + else: + raise ValueError( + 'Unrecognized communicator_name: {}'.format(communicator_name)) diff --git a/torchsso/utils/chainer_communicators/_utility.py b/torchsso/utils/chainer_communicators/_utility.py new file mode 100644 index 0000000..0e2d1c1 --- /dev/null +++ b/torchsso/utils/chainer_communicators/_utility.py @@ -0,0 +1,224 @@ +import warnings + +import numpy +try: + import cupy + from torchsso.utils.cupy import to_cupy +except: + pass + # print("No cupy detected") + +from chainer.backends import cuda + +import torch + + +class Packer(object): + + def __init__(self): + self.unpack_kernel = cupy.ElementwiseKernel( + 'raw T vec, int32 matrix_size', + 'raw T mat', + """ + int x = i % matrix_size; + int y = i / matrix_size; + if( x < y ) { + int tmp = y; + y = x; + x = tmp; + } + mat[i] = vec[matrix_size * y - y * (y + 1) / 2 + x]; + """, + 'unpack' + ) + + def pack(self, arrays, gpu_buf, sizeof_dtype, stream, offset=0): + buf_offset = offset * sizeof_dtype + for local_arrays in arrays: + for array, triangular in local_arrays: + if triangular: + nbytes = self._put_triangular_matrix_to_device_memory( + array, gpu_buf, buf_offset, stream) + else: + nbytes = array.size * sizeof_dtype + gpu_buf.from_device(array, nbytes, buf_offset, stream) + buf_offset += nbytes + + def unpack(self, arrays, gpu_buf, sizeof_dtype, stream, offset=0): + buf_offset = offset * sizeof_dtype + for local_arrays in arrays: + for array, triangular in local_arrays: + if triangular: + nbytes = self._get_triangular_matrix_from_device_memory( + array, gpu_buf, buf_offset, stream) + else: + nbytes = array.size * sizeof_dtype + gpu_buf.to_device(array, nbytes, buf_offset, stream) + buf_offset += nbytes + + def _put_triangular_matrix_to_device_memory( + self, array, memory, offset, stream): + """Puts a triangular matrix to ``DeviceMemory`` + """ + if array.dtype.char == 'f' or array.dtype.char == 'd': + dtype = array.dtype.char + else: + dtype = numpy.find_common_type((array.dtype.char, 'f'), ()).char + + cublas_handle = cupy.cuda.device.get_cublas_handle() + + if array.shape[0] != array.shape[1]: + raise RuntimeError('non square matrix') + + n = array.shape[0] + nelems = n * (n + 1) // 2 + nbytes = nelems * array.dtype.itemsize + + if dtype == 'f': + trttp = cupy.cuda.cublas.strttp + else: + trttp = cupy.cuda.cublas.dtrttp + + with stream: + trttp(cublas_handle, cupy.cuda.cublas.CUBLAS_FILL_MODE_LOWER, n, + array.data.ptr, n, memory.ptr() + offset) + + return nbytes + + def _get_triangular_matrix_from_device_memory( + self, array, memory, offset, stream): + """Gets a triangular matrix from ``DeviceMemory`` + """ + if array.shape[0] != array.shape[1]: + raise RuntimeError('non square matrix') + + n = array.shape[0] + nelems = n * (n + 1) // 2 + nbytes = nelems * array.dtype.itemsize + + with stream: + self.unpack_kernel( + memory.array(nelems, offset=offset, dtype=array.dtype), + n, array, size=n * n) + + return nbytes + + +def _check_array(array, name): + xp = cuda.get_array_module(array) + with cuda.get_device_from_array(array): + if not array.dtype == xp.float32: + warnings.warn('non FP32 dtype detected in {}'.format(name)) + array = array.astype(xp.float32) + if not (array.flags.c_contiguous or array.flags.f_contiguous): + warnings.warn('non contiguous array detected in {}'.format(name)) + array = xp.ascontiguousarray(array) + return array + + +def extract(param_groups, indices, extractors): + """Extracts arrays from given fisher blocks using indices and extractors + + Args: + fblocks: List of ``FisherBlock`` instances + indices: List of ``int``s + extractors: Callable that extract arrays from a given ``FisherBlock`` + + Return: + List of tuple(array, bool). Second item indicates triangular flag. + """ + arrays = [] + for local_indices in indices: + local_arrays = [] + for index in local_indices: + for extractor in extractors: + for array in extractor(param_groups[index]): + local_arrays.append(array) + arrays.append(local_arrays) + return arrays + + +def extract_attr_from_params(attr, target='params', triangular=False): + """Extracts arrays from all ``Parameter``s in a given ``FisherBlock`` + """ + + def _extract_attr_from_params(group): + arrays = [] + for param in group[target]: + x = getattr(param, attr, None) + if x is not None: + #x = _check_array(x, fblock.linkname) + #setattr(param, attr, x) + x_ten = x.data + x_cp = to_cupy(x_ten) + arrays.append((x_cp, triangular)) + return arrays + + return _extract_attr_from_params + + +def extract_attr_from_curv(attr, triangular=False): + """Extracts arrays from all ``Parameter``s in a given ``FisherBlock`` + """ + + def _extract_attr_from_curv(group): + arrays = [] + + curv = group['curv'] + if curv is None: + return arrays + + target = getattr(curv, attr, None) + if target is None: + if curv.data is not None: + zeros = [] + for x in curv.data: + zeros.append(torch.zeros_like(x)) + setattr(curv, attr, zeros) + target = getattr(curv, attr) + else: + return arrays + + for x in target: + #x = _check_array(x, fblock.linkname) + #setattr(param, attr, x) + x_ten = x.data + x_cp = to_cupy(x_ten) + _triangular = triangular and x_cp.ndim == 2 and x_cp.shape[0] == x_cp.shape[1] + arrays.append((x_cp, _triangular)) + + return arrays + + return _extract_attr_from_curv + + +def get_nelems(arrays): + """Computes number of elements from given arrays using the triangular flag. + """ + nelems = 0 + for local_arrays in arrays: + for array, triangular in local_arrays: + if triangular: + if array.shape[0] != array.shape[1]: + raise RuntimeError('get_nelems: not a square matrix') + nelems += array.shape[0] * (array.shape[0] + 1) // 2 + else: + nelems += array.size + return nelems + + +def assign(gpu_buf, nbytes): + if nbytes > gpu_buf.size: + gpu_buf.assign(nbytes) + return True + return False + + +def allocate_asgrad(fblocks, attr): + for fblock in fblocks: + for _, param in sorted(fblock.link.namedparams()): + if not hasattr(param, attr): + # We need to allocate memory space for recieving data + _grad = param.grad.copy() + _grad.fill(0.) + setattr(param, attr, _grad) diff --git a/torchsso/utils/chainer_communicators/base.py b/torchsso/utils/chainer_communicators/base.py new file mode 100644 index 0000000..f6dffcc --- /dev/null +++ b/torchsso/utils/chainer_communicators/base.py @@ -0,0 +1,22 @@ +from chainermn.communicators import mpi_communicator_base +import warnings + +from torchsso.utils.chainer_communicators import _utility + + +class KFACCommunicatorBase(mpi_communicator_base.MpiCommunicatorBase): + + def __init__(self, mpi_comm): + super(KFACCommunicatorBase, self).__init__(mpi_comm) + self.indices = None + self.packer = _utility.Packer() + + def allreduce_grad(self): + # We don't use AllReduce for training K-FAC + warnings.warn('AllReduce called, skipping...') + + def reduce_scatterv_data(self, fblocks, extractors): + raise NotImplementedError + + def allgatherv_data(self, fblocks, extractors): + raise NotImplementedError diff --git a/torchsso/utils/chainer_communicators/pure_nccl_communicator.py b/torchsso/utils/chainer_communicators/pure_nccl_communicator.py new file mode 100644 index 0000000..1a9e592 --- /dev/null +++ b/torchsso/utils/chainer_communicators/pure_nccl_communicator.py @@ -0,0 +1,475 @@ +import itertools +import math + +from mpi4py import MPI + +import numpy as np +import cupy + +import chainer + +from chainermn.communicators import _memory_utility +from chainermn.communicators import _communication_utility +from chainermn import nccl + +try: + from hiercoll.hiernccl import HierNcclCommunicator + + _hiercoll_available = True +except ImportError as e: + _hiercoll_available = False + _hiercoll_available_exception = e + +from torchsso.utils.chainer_communicators import base +from torchsso.utils.chainer_communicators import _utility + +NUM_STREAMS = 12 + + +class PureNCCLCommunicator(base.KFACCommunicatorBase): + + def __init__(self, + mpi_comm, + rsv_comm_dtype=np.float32, + agv_comm_dtype=np.float32, + use_hiercoll=False, + dims=None + ): + super(PureNCCLCommunicator, self).__init__(mpi_comm) + + if use_hiercoll: + if not _hiercoll_available: + raise ValueError('use_hiercoll is True,' + 'but hiercoll.hiernccl is not available.') + + if dims is None: + dims = [] + + if dims is not None and not use_hiercoll: + raise ValueError('dim is not None,' + 'but use_hiercoll is False.') + + if use_hiercoll and mpi_comm.size != MPI.COMM_WORLD.size: + raise ValueError( + 'HierColl with non-WORLD MPI Comm is not supported.') + + # None -> Non-hierarchical / pure NCCL + # [] -> auto hierarchical selection (envv or optimizer) + # [int]-> manual hierarchical selection + self.dims = dims + + # We have to delay the initialization of communicators. This is because + # NCCL's communicators use the current CUDA devices at the time of + # initialization. Therefore, we have to initialize NCCL communicators + # after users set the devices to use. + self.nccl_comm = None + + # GPU buffers + self.gpu_buf_a = _memory_utility.DeviceMemory() + self.gpu_buf_b = _memory_utility.DeviceMemory() + self.gpu_buf_c = _memory_utility.DeviceMemory() + + # Assume FP32 for data type + self._arrs_dtype = np.dtype(np.float32) + + # Data type used in communications + self._rsv_comm_dtype = np.dtype(rsv_comm_dtype) + if self._rsv_comm_dtype.kind != 'f': + raise ValueError('rsv_comm_dtype must be numpy.float16,' + 'numpy.float32 or numpy.float64.') + + self._agv_comm_dtype = np.dtype(agv_comm_dtype) + if self._agv_comm_dtype.kind != 'f': + raise ValueError('agv_comm_dtype must be numpy.float16,' + 'numpy.float32 or numpy.float64.') + + # GPU kernels. We don't generate here due to the same reason above + self._cast_rsv_kernels = None + self._cast_agv_kernels = None + self._mean_kernel = None + self._max_kernel = None + self._memset_kernel = None + + # Packer to pack/unpack arrays + self._packer = _utility.Packer() + + # For scaling in FP16 + self._scaled = False + self._scaling_factors = None + self._streams = None + + def _init_comms(self): + if self.nccl_comm is not None: + return + + if self.dims is None: + self.nccl_comm = _communication_utility.init_nccl_comm( + self.mpi_comm) + else: + if len(self.dims) == 0: + self.nccl_comm = HierNcclCommunicator() + else: + self.nccl_comm = HierNcclCommunicator(self.dims) + + def reduce_scatterv_data(self, param_groups, extractors): + """Executes Reduce+ScatterV. + + Flow(no cast): pack(A) -> send(A) -> recv(B) -> mean(B->A) + -> unpack(A) + + Flow(casting): pack(C) -> cast(C->A) -> send(A) -> recv(B) + -> mean(B->A) -> cast(A->C) + -> unpack(C) + """ + + # CUDA default stream + stream = chainer.cuda.Stream.null + + # Initialize NCCL communicator if not + self._init_comms() + # Target NCCL communicator + nccl_comm = self.nccl_comm + + # This processes's assigned array index in arrays + local_rank = self.rank + + # Extract arrays from param_groups + arrays = _utility.extract(param_groups, self.indices, extractors) + + # Get total number of elements, local number of elements, and local + # number of elements' offset + nelems = _get_divideable_nelems(nccl_comm, _utility.get_nelems(arrays)) + nelems_local = _utility.get_nelems([arrays[local_rank]]) + nelems_offset = _utility.get_nelems(arrays[:local_rank]) + + # Allocate memory if not + needs_sync_a = _utility.assign(self.gpu_buf_a, + nelems * self._rsv_comm_dtype.itemsize) + needs_sync_b = _utility.assign(self.gpu_buf_b, + nelems * self._rsv_comm_dtype.itemsize) + needs_sync_c = _utility.assign(self.gpu_buf_c, + nelems * self._arrs_dtype.itemsize) \ + if self._arrs_dtype != self._rsv_comm_dtype else False + + # Pack elements in a buffer + # Data type casting will occur here if necessesary + if self._arrs_dtype != self._rsv_comm_dtype: + # Casting required + if self._cast_rsv_kernels is None or \ + self._cast_rsv_kernels.src_dtype != self._arrs_dtype or \ + self._cast_rsv_kernels.dst_dtype != self._rsv_comm_dtype: + self._cast_rsv_kernels = _CastingKernels(self._arrs_dtype, + self._rsv_comm_dtype) + self._packcast(arrays, arrays, nelems, self.gpu_buf_c, + self.gpu_buf_a, self._cast_rsv_kernels, stream) + else: + # Casting unnecessesary + self.packer.pack(arrays, self.gpu_buf_a, self._arrs_dtype.itemsize, + stream) + + # Buffers for AllReduce + sendbuf = self.gpu_buf_a.ptr() + recvbuf = self.gpu_buf_b.ptr() + + # Synchronize if necessesary + if needs_sync_a or needs_sync_b or needs_sync_c: + chainer.cuda.Stream.null.synchronize() + + # Communication + nccl_dtype = _get_nccl_dtype(self._rsv_comm_dtype) + nccl_comm.allReduce(sendbuf, recvbuf, nelems, nccl_dtype, + nccl.NCCL_SUM, stream.ptr) + + # Generate mean computing kernel if necessesary + if self._mean_kernel is None: + self._mean_kernel = _get_mean_kernel(self._rsv_comm_dtype, + self.size) + + # Compute the mean (divide by the number of processes) + # TODO: compute mean and cast dtype simultaneously. + self._mean_kernel( + self.gpu_buf_b.array( + nelems_local, + offset=nelems_offset * self._rsv_comm_dtype.itemsize, + dtype=self._rsv_comm_dtype), + self.gpu_buf_a.array( + nelems_local, + offset=nelems_offset * self._rsv_comm_dtype.itemsize, + dtype=self._rsv_comm_dtype), + stream=stream) + + # Unpack elements from a buffer + # Data type casting will occur here if necessesary + if self._arrs_dtype != self._rsv_comm_dtype: + # Casting required + self._castunpack(arrays, [arrays[local_rank]], nelems_local, + self.gpu_buf_c, self.gpu_buf_a, + self._cast_rsv_kernels, stream, + offset=nelems_offset) + else: + # Casting unnecessesary + self._packer.unpack([arrays[local_rank]], self.gpu_buf_a, + self._arrs_dtype.itemsize, stream, + offset=nelems_offset) + + def allgatherv_data(self, param_groups, extractors): + """Executes AllGatherV. + + Flow(no cast): pack(A) -> send(A) -> recv(B) -> unpack(B) + + Flow(casting): pack(C) -> cast(C->A) -> send(A) -> recv(B) + -> cast(B->C) -> unpack(C) + """ + + # CUDA default stream + stream = chainer.cuda.Stream.null + + # This processes's assigned array index in arrays + local_rank = self.rank + + # Allocate memory space for recieving + # TODO + #_utility.allocate_asgrad(param_groups, 'kfgrad') + + # Initialize NCCL communicator if not + self._init_comms() + + # Target NCCL communicator + nccl_comm = self.nccl_comm + + # Extract arrays from param_groups + arrays = _utility.extract(param_groups, self.indices, extractors) + + # Get total number of elements, local number of elements, and local + # number of elements' offset + nelems = _get_divideable_nelems(nccl_comm, _utility.get_nelems(arrays)) + nelems_local = _utility.get_nelems([arrays[local_rank]]) + nelems_offset = _utility.get_nelems(arrays[:local_rank]) + + # Allocate memory if not + needs_sync_a = _utility.assign(self.gpu_buf_a, + nelems * self._agv_comm_dtype.itemsize) + needs_sync_b = _utility.assign(self.gpu_buf_b, + nelems * self._agv_comm_dtype.itemsize) + needs_sync_c = _utility.assign(self.gpu_buf_c, + nelems * self._arrs_dtype.itemsize) \ + if self._arrs_dtype != self._agv_comm_dtype else False + + # Generate memset kernel if necessesary + if self._memset_kernel is None: + self._memset_kernel = _get_memset_kernel(self._agv_comm_dtype) + + # Memset 0 + self._memset_kernel( + self.gpu_buf_a.array(nelems, dtype=self._agv_comm_dtype), + stream=stream) + + # Pack elements in a buffer + # Data type casting will occur here if necessesary + if self._arrs_dtype != self._agv_comm_dtype: + # Casting required + if self._cast_agv_kernels is None or \ + self._cast_agv_kernels.src_dtype != self._arrs_dtype or \ + self._cast_agv_kernels.dst_dtype != self._agv_comm_dtype: + self._cast_agv_kernels = _CastingKernels(self._arrs_dtype, + self._agv_comm_dtype) + self._packcast(arrays, [arrays[local_rank]], nelems_local, + self.gpu_buf_c, self.gpu_buf_a, + self._cast_agv_kernels, stream, + offset=nelems_offset) + else: + # Casting unnecessesary + self._packer.pack([arrays[local_rank]], self.gpu_buf_a, + self._arrs_dtype.itemsize, + stream, offset=nelems_offset) + + # Buffers for AllReduce + sendbuf = self.gpu_buf_a.ptr() + recvbuf = self.gpu_buf_b.ptr() + + # Synchronize if necessesary + if needs_sync_a or needs_sync_b or needs_sync_c: + chainer.cuda.Stream.null.synchronize() + + # Communication + nccl_dtype = _get_nccl_dtype(self._agv_comm_dtype) + nccl_comm.allReduce(sendbuf, recvbuf, nelems, nccl_dtype, + nccl.NCCL_SUM, stream.ptr) + + # Unpack elements from a buffer + # Data type casting will occur here if necessesary + if self._arrs_dtype != self._agv_comm_dtype: + # Casting required + self._castunpack(arrays, arrays, nelems, self.gpu_buf_c, + self.gpu_buf_b, self._cast_agv_kernels, stream, + offset=0) + else: + # Casting unnecessesary + self._packer.unpack(arrays, self.gpu_buf_b, + self._arrs_dtype.itemsize, stream) + + def _packcast(self, global_arrays, arrays, nelems, src_gpu_buf, + dst_gpu_buf, casting_kernels, stream, offset=0): + """Scale, pack, and cast using the given array and GPU buffers + """ + + # Scaling + if casting_kernels.dst_dtype == np.dtype(np.float16): + self._communication_scale(global_arrays, stream) + + # Pack elements to the buffer + self.packer.pack(arrays, src_gpu_buf, + casting_kernels.src_dtype.itemsize, stream, + offset=offset) + + # Cast data type: SRC -> DST + casting_kernels.src_to_dst_kernel( + src_gpu_buf.array( + nelems, dtype=casting_kernels.src_dtype, + offset=offset * casting_kernels.src_dtype.itemsize), + dst_gpu_buf.array( + nelems, dtype=casting_kernels.dst_dtype, + offset=offset * casting_kernels.dst_dtype.itemsize), + stream=stream) + + def _castunpack(self, global_arrays, arrays, nelems, src_gpu_buf, + dst_gpu_buf, casting_kernels, stream, offset=0): + """Cast, unpack, and scale using the given array and GPU buffers + """ + + # Cast data type: DST -> SRC + casting_kernels.dst_to_src_kernel( + dst_gpu_buf.array( + nelems, dtype=casting_kernels.dst_dtype, + offset=offset * casting_kernels.dst_dtype.itemsize), + src_gpu_buf.array( + nelems, dtype=casting_kernels.src_dtype, + offset=offset * casting_kernels.src_dtype.itemsize), + stream=stream) + + # Unpack elements from the buffer + self.packer.unpack(arrays, src_gpu_buf, + casting_kernels.src_dtype.itemsize, stream, + offset=offset) + + # Scaling + if self._scaled: + self._rescale(global_arrays, stream) + + def _communication_scale(self, arrays, default_stream): + if self._streams is None: + self._streams = [cupy.cuda.Stream() for _ in range(NUM_STREAMS)] + + if self._max_kernel is None: + self._max_kernel = _get_max_kernel() + + arrays = list(itertools.chain.from_iterable(arrays)) + arrays = [array for array, _ in arrays] + arrays = sorted(arrays, key=lambda x: x.size) + nelems = _get_divideable_nelems(self.nccl_comm, len(arrays)) + + send_arr = cupy.empty(nelems, dtype=cupy.float32) + recv_arr = cupy.empty(nelems, dtype=cupy.float32) + + used_stream_indices = np.zeros(NUM_STREAMS) + for i, array in enumerate(arrays): + stream_index = used_stream_indices.argmin() + stream = self._streams[stream_index] + + self._max_kernel(array, send_arr[i], stream=stream) + + # NOTE(y1r): order of computation time is O(n). + # So we count n (bin-packing problem heuristics) + used_stream_indices[stream_index] += np.prod(array.shape) + + # NOTE(y1r): Assume that stream is default stream. + # Therefore, stream.synchronize() is not necessary. + self.nccl_comm.allReduce(send_arr.data.ptr, + recv_arr.data.ptr, + nelems, + _get_nccl_dtype(cupy.dtype(cupy.float32)), + nccl.NCCL_SUM, + default_stream.ptr) + + with default_stream: + scaling_factors = 65000 / recv_arr + + for scaling_factor, array in zip(scaling_factors, arrays): + array *= scaling_factor + + self._scaled = True + self._scaling_factors = scaling_factors + + def _rescale(self, arrays, default_stream): + arrays = list(itertools.chain.from_iterable(arrays)) + arrays = [array for array, _ in arrays] + arrays = sorted(arrays, key=lambda x: x.size) + + with default_stream: + for i, array in enumerate(arrays): + array *= (1 / self._scaling_factors[i]) + + self._scaled = False + + +class _CastingKernels(object): + + def __init__(self, src_dtype, dst_dtype): + self.src_dtype = src_dtype + self.dst_dtype = dst_dtype + self.src_to_dst_kernel = chainer.cuda.cupy.ElementwiseKernel( + '{} x'.format(src_dtype.name), + '{} y'.format(dst_dtype.name), + 'y = x', "{}_to_{}".format(src_dtype.name, dst_dtype.name)) + self.dst_to_src_kernel = chainer.cuda.cupy.ElementwiseKernel( + '{} x'.format(dst_dtype.name), + '{} y'.format(src_dtype.name), + 'y = x', "{}_to_{}".format(dst_dtype.name, src_dtype.name)) + + +def _get_mean_kernel(dtype, size): + return chainer.cuda.cupy.ElementwiseKernel( + '{} x'.format(dtype.name), + '{} y'.format(dtype.name), + 'y = x * (1.0 / {})'.format(size), + 'my_mean') + + +def _get_max_kernel(): + return chainer.cuda.cupy.ReductionKernel( + 'float32 x', + 'float32 y', + 'fabsf(x)', + 'fmaxf(a, b)', + 'y = a', + '0', + 'my_max') + + +def _get_memset_kernel(dtype): + return chainer.cuda.cupy.ElementwiseKernel( + '', + '{} x'.format(dtype.name), + 'x = 0.0', + 'my_memset') + + +def _get_divideable_nelems(nccl_comm, nelems): + if hasattr(nccl_comm, 'getCountRequirement'): + requirement = nccl_comm.getCountRequirement() + return int(math.ceil(nelems / requirement)) * requirement + else: + return nelems + + +def _get_nccl_dtype(dtype): + if dtype == np.float16: + return nccl.NCCL_FLOAT16 + elif dtype == np.float32: + return nccl.NCCL_FLOAT32 + elif dtype == np.float64: + return nccl.NCCL_FLOAT64 + else: + raise ValueError( + 'dtype must be numpy.float16, numpy.float32 or numpy.float64,' + 'not {}'.format(dtype)) diff --git a/torchsso/utils/cholesky_cupy.py b/torchsso/utils/cholesky_cupy.py new file mode 100644 index 0000000..1ea8112 --- /dev/null +++ b/torchsso/utils/cholesky_cupy.py @@ -0,0 +1,14 @@ +try: + import cupy + from torchsso.utils.cupy import to_cupy, from_cupy +except: + # print("No cupy detected") + pass + + +def cholesky(m, upper=True): + m_cp = to_cupy(m) + m_chl_cp = cupy.linalg.decomposition.cholesky(m_cp) + if upper: + m_chl_cp = m_chl_cp.transpose() + return from_cupy(m_chl_cp) diff --git a/torchsso/utils/cupy.py b/torchsso/utils/cupy.py new file mode 100644 index 0000000..0e05efc --- /dev/null +++ b/torchsso/utils/cupy.py @@ -0,0 +1,15 @@ +try: + import cupy +except: + # print("No cupy detected") + pass + +from torch.utils.dlpack import to_dlpack, from_dlpack + + +def to_cupy(m_tensor): + return cupy.fromDlpack(to_dlpack(m_tensor)) + + +def from_cupy(m_cp): + return from_dlpack(m_cp.toDlpack()) diff --git a/torchsso/utils/inv_cupy.py b/torchsso/utils/inv_cupy.py new file mode 100644 index 0000000..d453b23 --- /dev/null +++ b/torchsso/utils/inv_cupy.py @@ -0,0 +1,132 @@ +import numpy +import scipy +import torch + +try: + import cupy + from cupy import cuda + from cupy.cuda import cublas + from cupy.cuda import device + from cupy.linalg import util + if cuda.cusolver_enabled: + from cupy.cuda import cusolver + from torchsso.utils.cupy import to_cupy, from_cupy +except: + pass + # print("No cupy detected") + + +import warnings + + +use_cholesky = True + +# Based cupy (cupy/cupy/linalg/solve.py) @ 067f830 + + +def inv(m): + if torch.cuda.is_available(): + m_cp = to_cupy(m) + m_inv_cp = inv_core(m_cp, use_cholesky) + return from_cupy(m_inv_cp) + else: + result = torch.from_numpy(scipy.linalg.inv(m.cpu().numpy())) + return result + + +def inv_core(a, cholesky=False): + """Computes the inverse of a matrix. + This function computes matrix ``a_inv`` from n-dimensional regular matrix + ``a`` such that ``dot(a, a_inv) == eye(n)``. + Args: + a (cupy.ndarray): The regular matrix + b (Boolean): Use cholesky decomposition + Returns: + cupy.ndarray: The inverse of a matrix. + .. seealso:: :func:`numpy.linalg.inv` + """ + + xp = cupy.get_array_module(a) + if xp == numpy: + if cholesky: + warnings.warn( + "Current fast-inv using cholesky doesn't support numpy.ndarray.") + return numpy.linalg.inv(a) + + if not cuda.cusolver_enabled: + raise RuntimeError('Current cupy only supports cusolver in CUDA 8.0') + + # to prevent `a` to be overwritten + a = a.copy() + + util._assert_cupy_array(a) + util._assert_rank2(a) + util._assert_nd_squareness(a) + + if a.dtype.char == 'f' or a.dtype.char == 'd': + dtype = a.dtype.char + else: + dtype = numpy.find_common_type((a.dtype.char, 'f'), ()).char + + cusolver_handle = device.get_cusolver_handle() + dev_info = cupy.empty(1, dtype=cupy.int) + m = a.shape[0] + + b = cupy.eye(m, dtype=dtype) + + if not cholesky: + if dtype == 'f': + getrf = cusolver.sgetrf + getrf_bufferSize = cusolver.sgetrf_bufferSize + getrs = cusolver.sgetrs + else: # dtype == 'd' + getrf = cusolver.dgetrf + getrf_bufferSize = cusolver.dgetrf_bufferSize + getrs = cusolver.dgetrs + + buffersize = getrf_bufferSize(cusolver_handle, m, m, a.data.ptr, m) + + # TODO(y1r): cache buffer to avoid malloc + workspace = cupy.empty(buffersize, dtype=dtype) + ipiv = cupy.empty((a.shape[0], 1), dtype=dtype) + + # LU Decomposition + getrf(cusolver_handle, m, m, a.data.ptr, m, + workspace.data.ptr, ipiv.data.ptr, dev_info.data.ptr) + + # TODO(y1r): check dev_info status + + # solve for the inverse + getrs(cusolver_handle, 0, m, m, a.data.ptr, m, + ipiv.data.ptr, b.data.ptr, m, dev_info.data.ptr) + + # TODO(y1r): check dev_info status + else: + if dtype == 'f': + potrf = cusolver.spotrf + potrf_bufferSize = cusolver.spotrf_bufferSize + potrs = cusolver.spotrs + else: # dtype == 'd' + potrf = cusolver.dpotrf + potrf_bufferSize = cusolver.dpotrf_bufferSize + potrs = cusolver.dpotrs + + buffersize = potrf_bufferSize( + cusolver_handle, cublas.CUBLAS_FILL_MODE_UPPER, m, a.data.ptr, m) + + # TODO(y1r): cache buffer to avoid malloc + workspace = cupy.empty(buffersize, dtype=dtype) + + # Cholesky Decomposition + potrf(cusolver_handle, cublas.CUBLAS_FILL_MODE_UPPER, m, + a.data.ptr, m, workspace.data.ptr, buffersize, dev_info.data.ptr) + + # TODO(y1r): check dev_info status + + # solve for the inverse + potrs(cusolver_handle, cublas.CUBLAS_FILL_MODE_UPPER, m, + m, a.data.ptr, m, b.data.ptr, m, dev_info.data.ptr) + + # TODO(y1r): check dev_info status + + return b diff --git a/torchsso/utils/logger.py b/torchsso/utils/logger.py new file mode 100644 index 0000000..82a1354 --- /dev/null +++ b/torchsso/utils/logger.py @@ -0,0 +1,45 @@ +import os +import time +import json +import shutil + + +# Select the best-resolution timer function +try: + _get_time = time.perf_counter +except AttributeError: + if os.name == 'nt': + _get_time = time.clock + else: + _get_time = time.time + + +class Logger(object): + + def __init__(self, out, logname): + self.out = out + self.logname = logname + self._log = [] + self._start_at = None + + if not os.path.isdir(self.out): + os.makedirs(self.out) + + def start(self): + self._start_at = _get_time() + + @property + def elapsed_time(self): + if self._start_at is None: + raise RuntimeError('training has not been started yet') + return _get_time() - self._start_at + + def write(self, log): + self._log.append(log) + tmp_path = os.path.join(self.out, 'tmp') + with open(tmp_path, 'w') as f: + json.dump(self._log, f, indent=4) + + path = os.path.join(self.out, self.logname) + shutil.move(tmp_path, path) +