This is the PyTorch library for training Submanifold Sparse Convolutional Networks.
This library brings spatially-sparse convolutional networks to PyTorch. Moreover, it introduces Submanifold Sparse Convolutions, which can be used to build computationally efficient sparse VGG/ResNet/DenseNet-style networks.
With regular 3x3 convolutions, the set of active (non-zero) sites grows rapidly.
With Submanifold Sparse Convolutions, the set of active sites is unchanged. Active sites look at their active neighbors (green); non-active sites (red) have no computational overhead.
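To make the contrast concrete, here is a minimal pure-Python sketch that tracks only the *set* of active sites (no features or weights). It assumes a 2D grid of integer `(x, y)` sites, stride 1, and that a regular convolution produces an output wherever its 3x3 receptive field touches an active input. The helper names are hypothetical and are not part of SparseConvNet.

```python
def neighbors(site):
    """All sites in the 3x3 neighborhood of `site`, including the site itself."""
    x, y = site
    return {(x + dx, y + dy) for dx in (-1, 0, 1) for dy in (-1, 0, 1)}


def regular_conv_active(active):
    """Regular 3x3 convolution: an output site is active if any input site in
    its receptive field is active, so the active set dilates every layer."""
    out = set()
    for site in active:
        out |= neighbors(site)
    return out


def submanifold_conv_active(active):
    """Submanifold 3x3 convolution: outputs are computed only at the input's
    active sites, so the active set is unchanged."""
    return set(active)


# A short diagonal line of active sites.
line = {(i, i) for i in range(5)}

regular, submanifold = set(line), set(line)
for layer in range(3):
    regular = regular_conv_active(regular)
    submanifold = submanifold_conv_active(submanifold)
    print(f"after {layer + 1} layer(s): regular={len(regular)}, submanifold={len(submanifold)}")
```

The regular count grows with every layer, while the submanifold count stays at the five sites of the original line.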
When Submanifold Sparse Convolutions are stacked to build VGG- and ResNet-type ConvNets, information can flow along lines or surfaces of active points.
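The reach of stacked submanifold convolutions can be sketched as a breadth-first search. Each 3x3 submanifold layer lets an active site read from its active 3x3 neighbors, so information needs one layer per hop along a chain of active sites and never crosses between disconnected components. This is an illustrative model with hypothetical helper names, not library code.

```python
from collections import deque


def neighbors(site):
    """The 8 sites surrounding `site` in a 3x3 neighborhood."""
    x, y = site
    return [(x + dx, y + dy) for dx in (-1, 0, 1) for dy in (-1, 0, 1) if (dx, dy) != (0, 0)]


def hops_from(source, active):
    """Breadth-first search over active sites, stepping between 3x3 neighbors.
    dist[s] is the number of stacked 3x3 submanifold convolutions needed before
    information from `source` can reach the active site s."""
    dist = {source: 0}
    queue = deque([source])
    while queue:
        site = queue.popleft()
        for nb in neighbors(site):
            if nb in active and nb not in dist:
                dist[nb] = dist[site] + 1
                queue.append(nb)
    return dist


# Two disjoint components: a horizontal line and a separate short vertical line.
line_a = {(x, 0) for x in range(6)}
line_b = {(10, y) for y in range(3)}
active = line_a | line_b

dist = hops_from((0, 0), active)
print("sites reachable from (0, 0):", len(dist))        # 6
print("layers needed to reach (5, 0):", dist[(5, 0)])   # 5
print("(10, 0) reachable:", (10, 0) in dist)            # False
```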
Disconnected components don't communicate at first, although they will merge due to the effect of strided operations, either pooling or convolutions. Additionally, adding ConvolutionWithStride2-SubmanifoldConvolution-DeconvolutionWithStride2 paths to the network allows disjoint active sites to communicate; see the 'VGG+' networks in the paper.
From left: (i) an active point is highlighted; a convolution with stride 2 sees the green active sites (ii) and produces output (iii), in which the 'children' of the highlighted active point from (i) are highlighted; a submanifold sparse convolution sees the green active sites (iv) and produces output (v); a deconvolution operation sees the green active sites (vi) and produces output (vii).
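The ConvolutionWithStride2-SubmanifoldConvolution-DeconvolutionWithStride2 path can be sketched in the same set-of-sites style. The sketch assumes that the strided convolution has filter size 2 (so each coarse site covers a 2x2 block of fine sites) and that the deconvolution writes back onto the original fine-grid active sites. Both are simplifying assumptions, and the function names are hypothetical.

```python
def parent(site, stride=2):
    """Coarse-grid site covering `site` for an assumed size-2, stride-2 convolution."""
    return (site[0] // stride, site[1] // stride)


def chebyshev(a, b):
    """Chebyshev distance; two sites are 3x3 neighbors when it is at most 1."""
    return max(abs(a[0] - b[0]), abs(a[1] - b[1]))


def conv_sub_deconv_reach(src, fine_active, stride=2):
    """Fine active sites that can receive information from `src` through one
    strided conv -> 3x3 submanifold conv -> deconv path."""
    coarse_active = {parent(s, stride) for s in fine_active}
    # The 3x3 submanifold convolution on the coarse grid mixes parent(src)
    # with its active coarse neighbors.
    reached_coarse = {c for c in coarse_active if chebyshev(c, parent(src, stride)) <= 1}
    # The deconvolution sends each coarse site's features back to its fine children.
    return {s for s in fine_active if parent(s, stride) in reached_coarse}


p, q = (0, 0), (3, 3)
active = {p, q}
print("3x3 neighbors on the fine grid:", chebyshev(p, q) <= 1)              # False
print("q reachable from p via the path:", q in conv_sub_deconv_reach(p, active))  # True
```

On the fine grid, `p` and `q` are isolated, so submanifold convolutions alone never connect them. Their coarse parents `(0, 0)` and `(1, 1)` are 3x3 neighbors, however, so one pass through the path lets them communicate.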
SparseConvNet supports input with different numbers of spatial/temporal dimensions.
Higher-dimensional input is more likely to be sparse because of the 'curse of dimensionality'.
Dimension | Name in 'torch.nn' | Use cases |
---|---|---|
1 | Conv1d | Text, audio |
2 | Conv2d | Lines in 2D space, e.g. handwriting |
3 | Conv3d | Lines and surfaces in 3D space or (2+1)D space-time |
4 | - | Lines, etc, in (3+1)D space-time |
We use the term 'submanifold' to refer to input data that is sparse because it has a lower effective dimension than the space in which it lives, for example a one-dimensional curve in 2+ dimensional space, or a two-dimensional surface in 3+ dimensional space.
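The 'curse of dimensionality' argument can be quantified with a small calculation. A flat k-dimensional submanifold with side length N occupies N^k of the N^d sites of a d-dimensional grid, so its active fraction is N^(k-d). The sketch assumes a cube-shaped grid and an axis-aligned submanifold. Real curves and surfaces are not axis-aligned, so treat the numbers as rough estimates.

```python
from fractions import Fraction


def active_fraction(side, manifold_dim, space_dim):
    """Exact fraction of grid sites occupied by a flat manifold_dim-dimensional
    submanifold in a space_dim-dimensional grid with `side` sites per axis."""
    return Fraction(side ** manifold_dim, side ** space_dim)


side = 100
for space_dim in range(1, 5):
    curve = active_fraction(side, 1, space_dim)
    print(f"curve in {space_dim}D: {float(curve):.0e} of sites active")
for space_dim in range(2, 5):
    surface = active_fraction(side, 2, space_dim)
    print(f"surface in {space_dim}D: {float(surface):.0e} of sites active")
```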
In theory, the library supports up to 10 dimensions. In practice, ConvNets with size-3 SVC convolutions in dimension 5+ may be impractical, as the number of parameters per convolution grows exponentially with the dimension. Possible solutions include factorizing the convolutions (e.g. 3x1x1x..., 1x3x1x..., etc.) or switching to a hyper-tetrahedral lattice (see Sparse 3D convolutional neural networks).
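The exponential growth follows from simple counting. A convolution with filter size 3 in d dimensions has 3^d weights per pair of input and output planes, whereas factorizing it into d one-axis convolutions (3x1x1x..., 1x3x1x..., ...) needs only 3d. The sketch assumes the same number of planes in and out and ignores biases. It is arithmetic only and does not call SparseConvNet.

```python
def full_conv_params(dim, n_planes, filter_size=3):
    """Weights in one convolution with a dense filter_size**dim kernel."""
    return filter_size ** dim * n_planes * n_planes


def factorized_conv_params(dim, n_planes, filter_size=3):
    """Weights in `dim` one-axis convolutions (3x1x1..., 1x3x1..., ...),
    each mapping n_planes -> n_planes (a simplifying assumption)."""
    return dim * filter_size * n_planes * n_planes


for dim in range(1, 11):  # the library supports up to 10 dimensions
    print(dim, full_conv_params(dim, 32), factorized_conv_params(dim, 32))
```

With 32 planes, a full size-3 convolution in 5D already has 248,832 weights, compared with 15,360 for the factorized version.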
SparseConvNets can be built either by defining a class that inherits from torch.nn.Module or by stacking modules in a sparseconvnet.Sequential:
```python
import torch
import sparseconvnet as scn

# Use the GPU if there is one, otherwise CPU
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

model = scn.Sequential().add(
    scn.SparseVggNet(2, 1,
                     [['C', 8], ['C', 8], ['MP', 3, 2],
                      ['C', 16], ['C', 16], ['MP', 3, 2],
                      ['C', 24], ['C', 24], ['MP', 3, 2]])
).add(
    scn.SubmanifoldConvolution(2, 24, 32, 3, False)
).add(
    scn.BatchNormReLU(32)
).add(
    scn.SparseToDense(2, 32)
).to(device)

# output will be 10x10
inputSpatialSize = model.input_spatial_size(torch.LongTensor([10, 10]))
input_layer = scn.InputLayer(2, inputSpatialSize)

msgs = [[" X X XXX X X XX X X XX XXX X XXX ",
         " X X X X X X X X X X X X X X X X ",
         " XXXXX XX X X X X X X X X X XXX X X X ",
         " X X X X X X X X X X X X X X X X X X ",
         " X X XXX XXX XXX XX X X XX X X XXX XXX "],
        [" XXX XXXXX x x x xxxxx xxx ",
         " X X X XXX X x x x x x x x ",
         " XXX X x xxxx x xxxx xxx ",
         " X X XXX X x x x x x ",
         " X X XXXX x x x x xxxx x ",]]

# Create Nx3 and Nx1 vectors to encode the messages above:
locations = []
features = []
for batchIdx, msg in enumerate(msgs):
    for y, line in enumerate(msg):
        for x, c in enumerate(line):
            if c == 'X':
                locations.append([y, x, batchIdx])
                features.append([1])
locations = torch.LongTensor(locations)
features = torch.FloatTensor(features).to(device)

input = input_layer([locations, features])
print('Input SparseConvNetTensor:', input)
output = model(input)

# Output is 2x32x10x10: our minibatch has 2 samples, the network has 32 output
# feature planes, and 10x10 is the spatial size of the output.
print('Output SparseConvNetTensor:', output)
```
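The `model.input_spatial_size` call works backwards from the requested 10x10 output. The hand calculation below rests on three assumptions: the 'C' entries of `SparseVggNet` and the final `SubmanifoldConvolution` preserve spatial size, and each `['MP', 3, 2]` max-pooling layer maps an input of size s to (s - 3) // 2 + 1. Under those assumptions, the input size can be recomputed by hand. The helpers are hypothetical, not part of the library.

```python
def max_pool_output_size(in_size, pool_size=3, stride=2):
    # "Valid" pooling: no padding, so each window must fit inside the input.
    return (in_size - pool_size) // stride + 1


def max_pool_input_size(out_size, pool_size=3, stride=2):
    # Smallest input size that yields `out_size` pooled outputs.
    return (out_size - 1) * stride + pool_size


size = 10  # requested output size per axis
for _ in range(3):  # three ['MP', 3, 2] layers, walked in reverse
    size = max_pool_input_size(size)
print(size)  # 87
```

Under these assumptions, each axis of the input grid should be 87 sites wide, which comfortably contains the short message strings above.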
Examples in the examples folder include:
- Assamese handwriting recognition
- Chinese handwriting recognition
- 3D segmentation using ShapeNet Core-55
- ScanNet 3D semantic label benchmark
For example:
```bash
cd examples/Assamese_handwriting
python VGGplus.py
```
Tested with PyTorch 1.3, CUDA 10.0, and Python 3.3 with Conda.
```bash
conda install pytorch torchvision cudatoolkit=10.0 -c pytorch # See https://pytorch.org/get-started/locally/
git clone git@github.com:facebookresearch/SparseConvNet.git
cd SparseConvNet/
bash develop.sh
```
To run the examples, you may also need to install unrar:
```bash
apt-get install unrar
```
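After running `bash develop.sh`, a quick way to see whether the pieces above are in place is to check that Python can find the packages and that `unrar` is on the `PATH`. This standard-library sketch assumes the build makes the package importable as `sparseconvnet`, as the usage example's `import sparseconvnet as scn` suggests. It does not check the CUDA toolkit, and `check_environment` is a hypothetical helper.

```python
import importlib.util
import shutil


def check_environment():
    """Report whether the Python packages and the unrar tool can be found."""
    return {
        "torch": importlib.util.find_spec("torch") is not None,
        "sparseconvnet": importlib.util.find_spec("sparseconvnet") is not None,
        "unrar": shutil.which("unrar") is not None,
    }


if __name__ == "__main__":
    for name, ok in check_environment().items():
        print(f"{name}: {'found' if ok else 'missing'}")
```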
SparseConvNet is BSD licensed, as found in the LICENSE file.
Copyright © Meta Platforms, Inc
- ICDAR 2013 Chinese Handwriting Recognition Competition, 2013: First place in task 3, with a test error of 2.61%. Human performance on the test set was 4.81%. (Report)
- Spatially-sparse convolutional neural networks, 2014: SparseConvNets for Chinese handwriting recognition.
- Fractional max-pooling, 2014: A SparseConvNet with fractional max-pooling achieves an error rate of 3.47% for CIFAR-10.
- Sparse 3D convolutional neural networks, BMVC 2015: SparseConvNets for 3D object recognition and (2+1)D video action recognition.
- Kaggle plankton recognition competition, 2015: Third place. The competition solution is being adapted for research purposes in EcoTaxa.
- Kaggle Diabetic Retinopathy Detection, 2015: First place in the Kaggle Diabetic Retinopathy Detection competition.
- SparseConvNet 'classic' version
- Submanifold Sparse Convolutional Networks, 2017: Introduces deep 'submanifold' SparseConvNets.
- Workshop on Learning to See from 3D Data, 2017: First place in the semantic segmentation competition. (Report)
- 3D Semantic Segmentation with Submanifold Sparse Convolutional Networks, 2017: Semantic segmentation for the ShapeNet Core55 and NYU-DepthV2 datasets, CVPR 2018.
- Unsupervised learning with sparse space-and-time autoencoders: (3+1)D space-time autoencoders.
- ScanNet 3D semantic label benchmark, 2018: 0.726 average IOU for 3D semantic segmentation.
- MinkowskiEngine is an alternative implementation of SparseConvNet; 0.736 average IOU for ScanNet.
- SpConv: PyTorch Spatially Sparse Convolution Library is an alternative implementation of SparseConvNet.
- Live Semantic 3D Perception for Immersive Augmented Reality describes a way to optimize memory access for SparseConvNet.
- OccuSeg: real-time object detection using SparseConvNets.
- TorchSparse implements 3D submanifold convolutions.
- TensorFlow 3D implements submanifold convolutions.
- VoTr implements submanifold voxel transformers using SpConv.
- Mix3D brings MixUp to the sparse setting; 0.781 average IOU for ScanNet 3D semantic segmentation.
- Point Transformer V3 uses sparse convolutions as an enhanced conditional positional encoding (xCPE); 0.794 average IOU for ScanNet 3D semantic segmentation.
If you find this code useful in your research, please cite:
3D Semantic Segmentation with Submanifold Sparse Convolutional Networks, CVPR 2018
Benjamin Graham, Martin Engelcke, Laurens van der Maaten
```
@article{3DSemanticSegmentationWithSubmanifoldSparseConvNet,
  title={3D Semantic Segmentation with Submanifold Sparse Convolutional Networks},
  author={Graham, Benjamin and Engelcke, Martin and van der Maaten, Laurens},
  journal={CVPR},
  year={2018}
}
```
and/or
Submanifold Sparse Convolutional Networks, https://arxiv.org/abs/1706.01307
Benjamin Graham, Laurens van der Maaten
```
@article{SubmanifoldSparseConvNet,
  title={Submanifold Sparse Convolutional Networks},
  author={Graham, Benjamin and van der Maaten, Laurens},
  journal={arXiv preprint arXiv:1706.01307},
  year={2017}
}
```