This is the official implementation of the paper "What's in a Prior? Learned Proximal Networks for Inverse Problems" (ICLR 2024), by Zhenghan Fang, Sam Buchanan, and Jeremias Sulam.
[Paper] [Project] [Talk] [Slides] [Poster] [BibTeX]
We propose learned proximal networks (LPN), a new class of deep neural networks that exactly implement the proximal operator of a general learned function. Such an LPN implicitly learns a regularization function for inverse problems that can be characterized and evaluated, shedding light on what has been learned from data and improving the interpretability of learning-based solutions. In turn, we present a new training problem, dubbed proximal matching, that provably promotes the recovery of the correct regularization term (i.e., the log of the data distribution). Moreover, we show convergence for PnP reconstruction algorithms using LPN with minimal and verifiable assumptions.
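At a glance, an LPN parameterizes the proximal operator as the gradient of a scalar input-convex neural network (ICNN), which guarantees the network is the proximal operator of *some* learned function. The sketch below is a minimal, illustrative rendition of this idea; the layer widths, softplus activations, and the exact proximal matching kernel are placeholders for exposition, not the architecture or loss used in this repo:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyICNN(nn.Module):
    """Minimal input-convex network psi: R^d -> R. Convexity in the input
    holds because hidden-to-hidden weights are clamped nonnegative and
    softplus is convex and nondecreasing."""
    def __init__(self, dim, hidden=64):
        super().__init__()
        self.W0 = nn.Linear(dim, hidden)                 # unconstrained
        self.W1 = nn.Linear(hidden, hidden, bias=False)  # clamped >= 0 in forward
        self.A1 = nn.Linear(dim, hidden)                 # skip connection from input
        self.w2 = nn.Linear(hidden, 1, bias=False)       # clamped >= 0 in forward
        self.a2 = nn.Linear(dim, 1)

    def forward(self, x):
        z = F.softplus(self.W0(x))
        z = F.softplus(F.linear(z, self.W1.weight.clamp(min=0)) + self.A1(x))
        return F.linear(z, self.w2.weight.clamp(min=0)) + self.a2(x)

def lpn_prox(psi, y):
    """f_theta(y) = grad psi_theta(y). Gradients of convex functions are
    exactly the proximal operators of (possibly nonconvex) functions, so
    f_theta is guaranteed to be a prox of some learned regularizer."""
    y = y.requires_grad_(True)
    (f,) = torch.autograd.grad(psi(y).sum(), y, create_graph=True)
    return f

def proximal_matching_loss(fx, x, gamma=0.1):
    """Illustrative Gaussian-shaped proximal matching penalty: it approaches
    a 0/1 loss as gamma -> 0, which is what drives the learned operator toward
    the prox of the negative log-prior. See the paper for the exact kernel
    and the gamma schedule used in training."""
    d2 = ((fx - x) ** 2).flatten(1).sum(dim=1)
    return (1.0 - torch.exp(-d2 / gamma**2)).mean()
```

The gradient is taken with `create_graph=True` so that training losses can backpropagate through the prox evaluation.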
- Laplacian example: the proximal operator $f_\theta$ and log-prior $R_\theta$ learned by LPN for the Laplacian distribution, trained via the $\ell_2$, $\ell_1$, or proximal matching ($\mathcal{L}_{PM}$) loss. (The closed-form prox for this case is sketched after this list.)
- Deblurring on CelebA, $\sigma_{blur}=1.0$, $\sigma_{noise}=0.02$
- Deblurring on CelebA, $\sigma_{blur}=1.0$, $\sigma_{noise}=0.04$
- Sparse-view tomographic reconstruction on MayoCT
- Compressed sensing on MayoCT (compression rate = 1/16)
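For the Laplacian example, the ground-truth proximal operator is available in closed form: the prox of the $\ell_1$ log-prior is soft-thresholding, which is the operator the proximal-matching-trained LPN should recover. A small reference implementation:

```python
import numpy as np

def soft_threshold(y, tau):
    """prox_{tau * |.|}(y) = sign(y) * max(|y| - tau, 0), the exact
    proximal operator of the Laplacian log-prior R(x) = tau * |x|."""
    return np.sign(y) * np.maximum(np.abs(y) - tau, 0.0)

print(soft_threshold(np.linspace(-3, 3, 7), 1.0))
# [-2. -1. -0.  0.  0.  1.  2.]
```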
The code is implemented with Python 3.9.16 and PyTorch 1.12.0. Create the conda environment:

```bash
conda env create -f environment.yml
```

Install the `lpn` package:

```bash
pip install -e .
```
The datasets are placed in the `data/` folder.

The MNIST dataset is already included in `data/mnist`, with the following structure:

```
data/
└── mnist
    ├── labels.npy
    └── mnist.npy
```
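If you want to inspect the data directly, the arrays load with NumPy; the shapes noted in the comments are assumptions for illustration, not documented guarantees:

```python
import numpy as np

images = np.load("data/mnist/mnist.npy")   # assumed shape (N, 28, 28)
labels = np.load("data/mnist/labels.npy")  # assumed shape (N,), integer labels
print(images.shape, images.dtype, labels.shape)
```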
Download the files of the CelebA dataset, as defined in the file list in torchvision's `CelebA` class:

`img_align_celeba.zip`, `list_attr_celeba.txt`, `identity_CelebA.txt`, `list_bbox_celeba.txt`, `list_landmarks_align_celeba.txt`, `list_eval_partition.txt`

directly from the authors' Google Drive link, place them in `data/celeba/celeba`, and unzip `img_align_celeba.zip`. You may use the following commands to download the files automatically:

```bash
pip install gdown
mkdir -p data/celeba/celeba
cd data/celeba/celeba
gdown --id 0B7EVK8r0v71pZjFTYXZWM3FlRnM
gdown --id 0B7EVK8r0v71pY0NSMzRuSXJEVkk
gdown --id 0B7EVK8r0v71pTzJIdlJWdHczRlU
gdown --id 0B7EVK8r0v71pd0FJY3Blby1HUTQ
gdown --id 0B7EVK8r0v71pbThiMVRxWXZ4dU0
gdown --id 0B7EVK8r0v71pblRyaVFSWGxPY0U
gdown --id 1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS
unzip img_align_celeba.zip
```
The resulting directory should have the following structure:

```
data/
└── celeba
    └── celeba
        ├── img_align_celeba (unzipped from img_align_celeba.zip)
        ├── identity_CelebA.txt
        ├── list_attr_celeba.txt
        ├── list_bbox_celeba.txt
        ├── list_eval_partition.txt
        └── list_landmarks_align_celeba.txt
```
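Since this layout mirrors torchvision's `CelebA` class, one way to verify it is to load the dataset through torchvision. Note that `root="data/celeba"` because torchvision appends the inner `celeba` directory itself; the transform here is just illustrative:

```python
from torchvision import datasets, transforms

dataset = datasets.CelebA(
    root="data/celeba",   # torchvision looks inside data/celeba/celeba
    split="train",
    download=False,       # the files were downloaded manually above
    transform=transforms.ToTensor(),
)
img, attr = dataset[0]
print(len(dataset), img.shape)  # image tensor of shape (3, 218, 178)
```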
Download the dataset from the authors' Google Drive link and place it in `data/mayoct`. See the authors' GitHub repo and paper for more details.
The resulting directory should have the following structure:

```
data/
└── mayoct
    └── mayo_data_arranged_patientwise
        ├── test
        │   ├── FBP
        │   ├── Phantom
        │   └── Sinogram
        └── train
            ├── FBP
            ├── Phantom
            └── Sinogram
```
Code for the main functionalities of LPN is in the `lpn` folder. Code for reproducing the experiments in the paper is in the `exps` folder.
To reproduce the Laplacian experiment, use the code in `exps/laplacian/`.

- Train: `laplacian_train.ipynb`
- Test: `laplacian_test.ipynb`
- Visualize results
  - Plot Fig. 1 in the paper: `viz_compact.ipynb`
  - Plot Fig. 6 in the supplementary of the paper: `viz_supp.ipynb`

Outputs (figures, models, and results) will be saved in `exps/laplacian/experiments/`.
Code for reproducing the MNIST experiment is in `exps/mnist/`.

- Train: `bash exps/mnist/train_mnist.sh`
  - The model will be saved at `exps/mnist/experiments/mnist/model.pt`.
  - We also provide the pretrained model.
- Compute the prior: `bash exps/mnist/prior_mnist.sh`
  - Results will be saved in `exps/mnist/experiments/mnist/prior`.
- Visualize results (Figures 3 and 7 in the paper)
  - Learned prior at example images: `notebooks/viz_img_and_prior.ipynb`
  - Violin plot over all images: `notebooks/viz_violin.ipynb`
  - Set `perturb_mode` in the notebooks to `gaussian`, `convex`, or `blur` for the different perturbation modes (illustrated after this list).
  - Figures will be saved in `exps/mnist/experiments/mnist/figures`.
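The three `perturb_mode` options move an image away from the data distribution in different ways, and the learned prior $R_\theta$ should increase along each path. The exact parameterizations live in the notebooks; the sketch below (with a hypothetical `perturb` helper) only illustrates the three modes:

```python
import torch
import torchvision.transforms.functional as TF

def perturb(x, x_other, mode, t=0.5):
    """Move image x away from the data distribution by an amount t.
    Illustrative only; the notebooks' parameterizations may differ."""
    if mode == "gaussian":   # additive white Gaussian noise
        return x + t * torch.randn_like(x)
    if mode == "convex":     # convex combination with another image
        return (1 - t) * x + t * x_other
    if mode == "blur":       # Gaussian blur of increasing width
        return TF.gaussian_blur(x, kernel_size=5, sigma=0.1 + 3.0 * t)
    raise ValueError(f"unknown perturb_mode: {mode}")
```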
Code for reproducing the CelebA experiment is in `exps/celeba/`.

- Train: `bash exps/celeba/train.sh`
  - Two models will be trained with different noise levels (0.05 and 0.1), saved in `exps/celeba/models/lpn/s={0.05, 0.1}/model.pt`.
  - We also provide the pretrained models.
- Deblurring using the trained LPN: `python exps/celeba/test.py --sigma_blur [BLUR LEVEL] --sigma_noise [NOISE LEVEL]`
  - E.g., `python exps/celeba/test.py --sigma_blur 1.0 --sigma_noise 0.02` runs deblurring with a Gaussian kernel of standard deviation $\sigma_{blur}=1.0$ and noise standard deviation $\sigma_{noise}=0.02$.
  - `sigma_blur` can be chosen from {1.0, 2.0}; `sigma_noise` can be chosen from {0.02, 0.04}.
  - Results will be saved in `exps/celeba/results/inverse/deblur/blur=[BLUR LEVEL]_noise=[NOISE LEVEL]/admm/lpn/{x,y,xhat}`. `x` and `y` contain the clean images and blurred observations, respectively; `xhat` contains the deblurred images. (A schematic of the PnP-ADMM loop follows this list.)
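The deblurring script solves the inverse problem with plug-and-play ADMM (the `admm` component of the results path), with the LPN playing the role of the proximal step. The sketch below shows the overall loop under generic linear-operator assumptions; the initialization, step sizes, inner solver, and iteration counts are illustrative, not the script's settings:

```python
import torch

def pnp_admm_deblur(y, A, At, lpn, rho=1.0, outer_iters=50, inner_iters=10, lr=0.1):
    """Schematic PnP-ADMM for y = A(x) + noise.
    A / At: forward blur operator and its adjoint (callables on tensors).
    lpn:   the learned proximal network, used as the v-update."""
    x = At(y)                        # crude initialization from the adjoint
    v = x.clone()
    u = torch.zeros_like(x)
    for _ in range(outer_iters):
        # x-update: least-squares data fit, solved here by plain gradient
        # steps (a conjugate-gradient solver would be the typical choice)
        for _ in range(inner_iters):
            grad = At(A(x) - y) + rho * (x - (v - u))
            x = x - lr * grad
        v = lpn(x + u)               # v-update: LPN as the learned prox
        u = u + x - v                # dual update on the consensus constraint
    return v
```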
Code for reproducing the MayoCT experiment is in `exps/mayoct/`.

- Train: `bash exps/mayoct/train.sh`
  - The model will be saved in `exps/mayoct/models/lpn/s=0.1`.
  - We also provide the pretrained model.
- Tomographic reconstruction using LPN: `python exps/mayoct/inverse_mayoct_tomo.py`
  - Results will be saved in `exps/mayoct/results/inverse/mayoct/tomo/num_angles=200_det_shape=400_noise=2.0/lpn`. `x`: ground truth, `y`: measurements, `xhat`: LPN reconstruction, `fbp`: FBP reconstruction. Average PSNR and SSIM of the LPN reconstructions are saved in `xhat/recon_log.txt`. (The sparse-view geometry is sketched after this list.)
- Compressed sensing (CS) using LPN: `bash exps/mayoct/test_cs.sh`
  - Results will be saved in `exps/mayoct/results/inverse/mayoct/cs/M=[NUM OF MEASUREMENTS]_noise=0.001/lpn`. `x`: ground truth, `y`: measurements, `xhat`: LPN reconstruction, `ls`: least-squares reconstruction. Average PSNR and SSIM of the LPN reconstructions are saved in `xhat/recon_log.txt`. (The CS measurement model is sketched after this list.)
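The `num_angles=200_det_shape=400_noise=2.0` path encodes the tomography setup: a sparse-view Radon transform with 200 view angles, a 400-bin detector, and additive measurement noise. The following is a self-contained illustration of this kind of geometry using scikit-image, with a Shepp-Logan phantom standing in for the MayoCT data; it is not the operator the script builds:

```python
import numpy as np
from skimage.data import shepp_logan_phantom
from skimage.transform import radon, iradon, rescale

x = rescale(shepp_logan_phantom(), 0.5)                  # stand-in ground truth
angles = np.linspace(0.0, 180.0, 200, endpoint=False)    # 200 sparse views
y = radon(x, theta=angles)                               # sinogram: y = A x
y = y + 2.0 * np.random.randn(*y.shape)                  # noise level from the path
x_fbp = iradon(y, theta=angles)                          # FBP baseline ("fbp")
```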
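In the CS experiment, `M` is the number of linear measurements; a compression rate of 1/16 means $M = N/16$ for an image with $N$ pixels. Below is a schematic of the measurement model and the `ls` (least-squares) baseline, with a random Gaussian matrix standing in for the repo's sensing operator:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 64 * 64                                   # illustrative number of pixels N
m = n // 16                                   # M = N/16: compression rate 1/16
A = rng.standard_normal((m, n)) / np.sqrt(m)  # stand-in sensing matrix
x = rng.standard_normal(n)                    # stand-in for a flattened image
y = A @ x + 0.001 * rng.standard_normal(m)    # noise level from the results path
x_ls, *_ = np.linalg.lstsq(A, y, rcond=None)  # minimum-norm least squares ("ls")
```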
All checkpoints are provided in this Google Drive.
If you find the code useful, please consider citing:

```bibtex
@inproceedings{
fang2024whats,
title={What's in a Prior? Learned Proximal Networks for Inverse Problems},
author={Zhenghan Fang and Sam Buchanan and Jeremias Sulam},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024}
}
```