LCA-PyTorch (lcapt) provides the ability to flexibly build single- or multi-layer convolutional sparse coding networks in PyTorch with the Locally Competitive Algorithm (LCA). LCA-PyTorch currently supports 1D, 2D, and 3D convolutional LCA layers, which maintain all the functionality and behavior of PyTorch convolutional layers. We currently do not support Linear (a.k.a. fully-connected) layers, but it is possible to implement the equivalent of a Linear layer with convolutions.
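As a sketch of that equivalence in plain PyTorch (not LCA-PyTorch): a Linear layer with `in_features` inputs computes the same map as an `nn.Conv1d` with `in_features` input channels and `kernel_size=1`, applied to inputs reshaped to length 1. The helper `linear_as_conv1d` is hypothetical, and this example does not show how to configure an LCA-PyTorch layer this way.

```python
import torch
import torch.nn as nn


def linear_as_conv1d(linear: nn.Linear) -> nn.Conv1d:
    """Hypothetical helper: build a Conv1d computing the same map as `linear`."""
    conv = nn.Conv1d(
        in_channels=linear.in_features,
        out_channels=linear.out_features,
        kernel_size=1,
        bias=linear.bias is not None,
    )
    with torch.no_grad():
        # Linear weight is (out, in); Conv1d weight is (out, in, kernel_size)
        conv.weight.copy_(linear.weight.unsqueeze(-1))
        if linear.bias is not None:
            conv.bias.copy_(linear.bias)
    return conv


torch.manual_seed(0)
linear = nn.Linear(in_features=100, out_features=16)
conv = linear_as_conv1d(linear)

x = torch.randn(8, 100)                      # batch of 8 flat feature vectors
y_linear = linear(x)                         # shape (8, 16)
y_conv = conv(x.unsqueeze(-1)).squeeze(-1)   # features as channels, length-1 signal
print(torch.allclose(y_linear, y_conv, atol=1e-5))
```

The same idea extends to 2D: a convolution whose kernel covers the entire spatial extent of its input produces one output per output channel, just like a Linear layer over the flattened input.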
Required:
- Python (>= 3.8)
Recommended:
- GPU(s) with NVIDIA CUDA (>= 11.0) and NVIDIA cuDNN (>= v7)
Install with pip:
```bash
pip install git+https://github.com/lanl/lca-pytorch.git
```
Or install from source:
```bash
git clone git@github.com:lanl/lca-pytorch.git
cd lca-pytorch
pip install .
```
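After installing, the following sketch checks the environment against the requirements listed above. It uses only standard `sys`, `importlib`, and `torch` attributes and does not exercise LCA-PyTorch itself; note that `torch.version.cuda` reports the CUDA version PyTorch was built against, which may differ from the system toolkit.

```python
import importlib.util
import sys

import torch

python_ok = sys.version_info >= (3, 8)
lcapt_installed = importlib.util.find_spec("lcapt") is not None
cuda_available = torch.cuda.is_available()

print(f"Python >= 3.8: {python_ok}")
print(f"lcapt importable: {lcapt_installed}")
print(f"CUDA available: {cuda_available}")
if cuda_available:
    # CUDA version PyTorch was compiled with, and the cuDNN version it loaded
    print(f"CUDA (PyTorch build): {torch.version.cuda}")
    print(f"cuDNN: {torch.backends.cudnn.version()}")
```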
LCA-PyTorch layers inherit all functionality of standard PyTorch layers.
```python
import torch
import torch.nn as nn
from lcapt.lca import LCAConv2D

# create a dummy input
inputs = torch.zeros(1, 3, 32, 32)

# 2D conv layer in PyTorch
pt_conv = nn.Conv2d(
    in_channels=3,
    out_channels=64,
    kernel_size=7,
    stride=2,
    padding=3
)
pt_out = pt_conv(inputs)

# 2D conv layer in LCA-PyTorch
lcapt_conv = LCAConv2D(
    out_neurons=64,
    in_neurons=3,
    kernel_size=7,
    stride=2,
    pad='same'
)
lcapt_out = lcapt_conv(inputs)
```
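For the `nn.Conv2d` layer above, the output size follows PyTorch's standard convolution formula, `floor((size + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1`, so each 32-pixel side becomes (32 + 6 - 6 - 1) // 2 + 1 = 16. The hypothetical helper `conv_output_size` below checks this against PyTorch; it makes no claim about the exact shape `LCAConv2D` returns, which depends on its own `pad` setting and outputs.

```python
import torch
import torch.nn as nn


def conv_output_size(size: int, kernel_size: int, stride: int = 1,
                     padding: int = 0, dilation: int = 1) -> int:
    # Output-length formula used by torch.nn.Conv1d/Conv2d/Conv3d (per dimension)
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


inputs = torch.zeros(1, 3, 32, 32)
pt_conv = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3)
pt_out = pt_conv(inputs)

expected = conv_output_size(32, kernel_size=7, stride=2, padding=3)
print(expected)              # 16
print(tuple(pt_out.shape))   # (1, 64, 16, 16)
```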
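LCA solves the $\ell_1$-penalized sparse coding problem

$$\min_{a} \; \frac{1}{2} \lVert s - \Phi a \rVert_2^2 + \lambda \lVert a \rVert_1,$$

where $s$ is the input, $\Phi$ is the dictionary of features, $a$ is the sparse code, $\hat{s} = \Phi a$ is the reconstruction, and $\lambda$ sets the strength of the sparsity penalty. It does so by integrating the dynamics

$$\tau \dot{u}(t) = b - u(t) - \left(\Phi^\top \Phi - I\right) a(t), \qquad b = \Phi^\top s, \qquad a(t) = T_\lambda\big(u(t)\big),$$

in which each neuron's membrane potential, $u(t)$, is charged by its input drive $b$, leaks toward zero, and is inhibited by active neurons whose features overlap with its own. For the $\ell_1$ penalty, the transfer function is the soft threshold $T_\lambda(u) = \operatorname{sign}(u) \max(|u| - \lambda, 0)$.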
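The dynamics above can be sketched with a dense (non-convolutional) dictionary and forward-Euler integration. This is a minimal illustration assuming soft thresholding, unit-norm dictionary columns, and a step size of 1; `lca_dense` and `soft_threshold` are hypothetical names, not LCA-PyTorch's implementation, which works with convolutions.

```python
import torch


def soft_threshold(u: torch.Tensor, lambda_: float) -> torch.Tensor:
    # T_lambda(u) = sign(u) * max(|u| - lambda, 0)
    return torch.sign(u) * torch.clamp(u.abs() - lambda_, min=0.0)


def lca_dense(s, Phi, lambda_=0.1, tau=100.0, n_iters=1000):
    """Euler-integrate tau * du/dt = b - u - (Phi^T Phi - I) a (hypothetical sketch).

    s:   input vector, shape (n_inputs,)
    Phi: dictionary, shape (n_inputs, n_neurons), unit-norm columns assumed
    """
    b = Phi.T @ s                                 # input drive
    G = Phi.T @ Phi - torch.eye(Phi.shape[1])     # lateral inhibition (Gram minus identity)
    u = torch.zeros(Phi.shape[1])                 # membrane potentials
    for _ in range(n_iters):
        a = soft_threshold(u, lambda_)            # activations / sparse code
        u = u + (b - u - G @ a) / tau             # one Euler step with dt = 1
    return soft_threshold(u, lambda_)


# Identity dictionary: no competition, so the code is soft_threshold(s)
a = lca_dense(torch.tensor([1.0, -0.5, 0.05, 0.0]), torch.eye(4), lambda_=0.1)
print(a)     # approximately [0.9, -0.4, 0.0, 0.0]

# Overlapping features: neuron 0 explains the input and suppresses neuron 1
Phi = torch.tensor([[1.0, 0.7071],
                    [0.0, 0.7071]])
code = lca_dense(torch.tensor([1.0, 0.0]), Phi, lambda_=0.1, n_iters=5000)
print(code)  # approximately [0.9, 0.0]
```

In the second case, neuron 1 receives input drive $0.7071$ but is inhibited by $0.7071 \times 0.9$ from neuron 0, so its potential settles near $0.07$, below the threshold, and its activation is exactly zero.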
Below is a mapping between the variable names used in this implementation and those used in Rozell et al.'s formulation of LCA.
| LCA-PyTorch Variable | Rozell Variable | Description |
|---|---|---|
| input_drive | $b$ | Drive from the inputs/stimulus |
| states | $u(t)$ | Internal state/membrane potential |
| acts | $a(t)$ | Code/Representation/External Communication |
| lambda_ | $\lambda$ | Transfer function threshold value |
| weights | $\Phi$ | Dictionary/Features |
| inputs | $s$ | Input data |
| recons | $\hat{s}$ | Reconstruction of the input |
| tau | $\tau$ | LCA time constant |
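In a dense setting, the quantities in this table are related by `input_drive = weights.T @ inputs` and `recons = weights @ acts`, and LCA minimizes the energy computed below. This sketch reuses the table's variable names but plain matrix products; `lca_energy` is a hypothetical helper, not part of LCA-PyTorch, whose layers compute the corresponding quantities with convolutions.

```python
import torch


def lca_energy(inputs, weights, acts, lambda_):
    """0.5 * ||inputs - recons||_2^2 + lambda_ * ||acts||_1 for a dense dictionary."""
    recons = weights @ acts                  # Rozell's s_hat = Phi a
    recon_error = inputs - recons
    return 0.5 * recon_error.pow(2).sum() + lambda_ * acts.abs().sum()


weights = torch.eye(3)                       # Phi: unit-norm features as columns
inputs = torch.tensor([1.0, 2.0, 3.0])       # s
acts = torch.tensor([1.0, 2.0, 0.0])         # a
input_drive = weights.T @ inputs             # b = Phi^T s
energy = lca_energy(inputs, weights, acts, lambda_=0.5)
print(input_drive)    # tensor([1., 2., 3.])
print(energy.item())  # 6.0
```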
Examples:
- Dictionary Learning Using Built-In Update Method
- Dictionary Learning Using PyTorch Optimizer
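As a generic sketch of dictionary learning with a PyTorch optimizer, independent of LCA-PyTorch's own API: the sparse codes are held fixed, a gradient step is taken on the reconstruction loss with respect to the dictionary, and the features are renormalized to unit length. The codes here are synthetic stand-ins for LCA outputs, and the unit-norm projection and learning rate are assumptions chosen for illustration.

```python
import torch

torch.manual_seed(0)
n_inputs, n_neurons, batch = 16, 32, 64


def unit_norm_columns(w: torch.Tensor) -> torch.Tensor:
    # Scale each feature (column) to unit L2 norm
    return w / w.norm(dim=0, keepdim=True)


# Synthetic data: inputs built from a ground-truth dictionary and sparse codes
true_weights = unit_norm_columns(torch.randn(n_inputs, n_neurons))
acts = torch.randn(n_neurons, batch) * (torch.rand(n_neurons, batch) < 0.1)
inputs = true_weights @ acts

# Learned dictionary, initialized randomly and updated with a standard optimizer
weights = unit_norm_columns(torch.randn(n_inputs, n_neurons)).requires_grad_(True)
optimizer = torch.optim.SGD([weights], lr=1.0)

losses = []
for _ in range(100):
    optimizer.zero_grad()
    recons = weights @ acts                         # codes held fixed during the update
    loss = 0.5 * (inputs - recons).pow(2).sum() / batch
    loss.backward()
    optimizer.step()
    with torch.no_grad():
        weights.copy_(unit_norm_columns(weights))   # project features back to unit norm
    losses.append(loss.item())

print(losses[0], losses[-1])
```

In a full training loop, the codes would be recomputed by LCA inference on each batch before the dictionary update, alternating between the two steps.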
LCA-PyTorch is provided under a BSD license with a "modifications must be indicated" clause. See the LICENSE file for the full text. Internally, the LCA-PyTorch package is known as LA-CC-23-064.