
# Diffusion Dictionary Learning

In short, we find that the feature maps of an img2img diffusion model (DM) can be decomposed, using a sparse autoencoder (SAE), into semantic features that align with ground-truth (GT) semantic masks in terms of IoU. See our presentation for the results: Link

*Examples of learnt features' activations. Columns: original image, ground-truth mask, top-1 feature activation, top-2 feature activation.*

This project builds on the code and data of the Label-Efficient Semantic Segmentation with Diffusion Models repository: here. We apply an SAE to a pixel-space img2img diffusion model and find features/latents that overlap with the GT segmentation masks of images for particular classes. We run experiments to find the best timestep and train the SAE on block=6 outputs. Data collection and processing are implemented in the `collect_features.py` script, while SAE training, visualization, and metric evaluation are done in the `train-sae.ipynb` notebook. Feel free to explore. The notebook with our best-trained SAE is `train-sae-BEST-10scale+big3batch+moreEp+t150.ipynb`; you can find visualizations for every class there.
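To make the pipeline concrete, below is a minimal, self-contained sketch of the idea. Random tensors stand in for the collected feature maps, and the architecture, shapes, and hyperparameters are illustrative assumptions rather than the settings used in `train-sae.ipynb`:

```python
import torch
import torch.nn as nn

# Minimal SAE: linear encoder with ReLU gives sparse non-negative codes,
# linear decoder reconstructs the input features.
class SparseAutoencoder(nn.Module):
    def __init__(self, d_in: int, d_hidden: int):
        super().__init__()
        self.enc = nn.Linear(d_in, d_hidden)
        self.dec = nn.Linear(d_hidden, d_in)

    def forward(self, x):
        z = torch.relu(self.enc(x))        # sparse codes ("features"/"latents")
        return self.dec(z), z

# Random stand-in for per-pixel UNet activations: (n_pixels, d_in).
feats = torch.randn(8192, 512)

sae = SparseAutoencoder(d_in=512, d_hidden=2048)
opt = torch.optim.Adam(sae.parameters(), lr=1e-3)
l1_coef = 1e-3                             # sparsity penalty weight (assumed)

for _ in range(200):                       # toy training loop
    recon, z = sae(feats)
    loss = ((recon - feats) ** 2).mean() + l1_coef * z.abs().mean()
    opt.zero_grad()
    loss.backward()
    opt.step()

# Score one latent against a GT mask via IoU, as in the project.
def iou(pred: torch.Tensor, mask: torch.Tensor) -> float:
    inter = (pred & mask).sum().item()
    union = (pred | mask).sum().item()
    return inter / max(union, 1)

act = (z[:, 0] > 0).reshape(64, 128)       # binarized activation of latent 0
gt = torch.zeros(64, 128, dtype=torch.bool)
gt[:, :40] = True                          # dummy ground-truth mask
print(f"IoU with the dummy mask: {iou(act, gt):.3f}")
```

In the real pipeline, `feats` would hold the per-pixel activations collected by `collect_features.py`, and each latent would be scored against the GT mask of every class to find the best-aligned features.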

The sections below are kept from the original codebase to make the code structure and data easier to understand.


## Datasets

The evaluation is performed on 6 collected datasets with a few annotated images in the training set: Bedroom-28, FFHQ-34, Cat-15, Horse-21, CelebA-19, and ADE-Bedroom-30. The number in each name is the number of semantic classes.

datasets.tar.gz (~47 MB)


## DDPM

### Pretrained DDPMs

The models trained on LSUN are adopted from guided-diffusion. We trained FFHQ-256 ourselves, using the same model parameters as for the LSUN models.

- LSUN-Bedroom: `lsun_bedroom.pt`
- FFHQ-256: `ffhq.pt` (updated 3/8/2022)
- LSUN-Cat: `lsun_cat.pt`
- LSUN-Horse: `lsun_horse.pt`
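For reference, here is a hedged sketch of loading one of these checkpoints with the guided-diffusion code. The hyperparameter overrides are assumptions; the full flag set must match the configuration the checkpoint was trained with (see the guided-diffusion README):

```python
import torch
from guided_diffusion.script_util import (
    model_and_diffusion_defaults,
    create_model_and_diffusion,
)

# Assumed LSUN-Bedroom settings; verify against the guided-diffusion README
# before relying on them.
opts = model_and_diffusion_defaults()
opts.update(
    image_size=256,
    num_channels=256,
    num_res_blocks=2,
    num_head_channels=64,
    attention_resolutions="32,16,8",
    resblock_updown=True,
    use_scale_shift_norm=True,
    learn_sigma=True,
)
model, diffusion = create_model_and_diffusion(**opts)

# The released checkpoints are plain state dicts.
state = torch.load("checkpoints/ddpm/lsun_bedroom.pt", map_location="cpu")
model.load_state_dict(state)
model.eval()
```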

### Run

1. Download the datasets: `bash datasets/download_datasets.sh`
2. Download the DDPM checkpoint: `bash checkpoints/ddpm/download_checkpoint.sh <checkpoint_name>`
3. Check the paths in `experiments/<dataset_name>/ddpm.json`
4. Run: `bash scripts/ddpm/train_interpreter.sh <dataset_name>`

Available checkpoint names: `lsun_bedroom`, `ffhq`, `lsun_cat`, `lsun_horse`

Available dataset names: `bedroom_28`, `ffhq_34`, `cat_15`, `horse_21`, `celeba_19`, `ade_bedroom_30`

Note: `train_interpreter.sh` is RAM-consuming, since it keeps all training pixel representations in memory. For example, it requires ~210 GB for 50 training images at 256×256 resolution. (See issue)
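As a back-of-the-envelope check of that figure (pure arithmetic; the per-pixel feature size is inferred from the numbers above, not read from the code):

```python
# Rough consistency check of the ~210 GB figure above.
n_images = 50
pixels = n_images * 256 * 256            # 3,276,800 pixel representations
bytes_total = 210e9                      # the reported ~210 GB
bytes_per_pixel = bytes_total / pixels   # ~64 kB per pixel
floats_per_pixel = bytes_per_pixel / 4   # ~16,000 float32 values per pixel
print(f"{bytes_per_pixel / 1e3:.0f} kB/pixel "
      f"≈ {floats_per_pixel:.0f} float32 values per pixel")
```

That is in the right ballpark for per-pixel activations concatenated across several UNet blocks and diffusion steps.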

Pretrained pixel classifiers and test predictions are here.

### How to improve the performance

- Tune which diffusion steps and UNet blocks to use for your particular task (a sweep scaffold is sketched below).
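A hypothetical scaffold for such a sweep is shown below. The feature extractor is a random stand-in for the real DDPM activations gathered by `collect_features.py`, and the score is a cheap nearest-class-mean proxy rather than the repo's pixel classifier:

```python
import itertools
import numpy as np

rng = np.random.default_rng(0)

def get_pixel_features(step: int, block: int):
    # Hypothetical stand-in for extracting per-pixel UNet activations at a
    # given diffusion step/block; random data keeps the scaffold runnable.
    X = rng.normal(size=(4096, 128)).astype(np.float32)  # pixels x dims
    y = rng.integers(0, 3, size=4096)                    # per-pixel labels
    return X, y

def quick_score(X: np.ndarray, y: np.ndarray) -> float:
    # Nearest-class-mean accuracy: a cheap proxy for segmentation quality.
    means = np.stack([X[y == c].mean(axis=0) for c in np.unique(y)])
    pred = ((X[:, None, :] - means[None]) ** 2).sum(-1).argmin(axis=1)
    return float((pred == y).mean())

for step, block in itertools.product([50, 150, 250], [5, 6, 7, 8]):
    X, y = get_pixel_features(step, block)
    print(f"step={step:3d} block={block}: score={quick_score(X, y):.3f}")
```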


## DatasetDDPM

### Synthetic datasets

To download the DDPM-produced synthetic datasets (50,000 samples, ~7 GB; updated 3/8/2022): `bash synthetic-datasets/ddpm/download_synthetic_dataset.sh <dataset_name>`

### Run | Option #1

1. Download the synthetic dataset: `bash synthetic-datasets/ddpm/download_synthetic_dataset.sh <dataset_name>`
2. Check the paths in `experiments/<dataset_name>/datasetDDPM.json`
3. Run: `bash scripts/datasetDDPM/train_deeplab.sh <dataset_name>`

### Run | Option #2

1. Download the datasets: `bash datasets/download_datasets.sh`
2. Download the DDPM checkpoint: `bash checkpoints/ddpm/download_checkpoint.sh <checkpoint_name>`
3. Check the paths in `experiments/<dataset_name>/datasetDDPM.json`
4. Train an interpreter on a few DDPM-produced annotated samples: `bash scripts/datasetDDPM/train_interpreter.sh <dataset_name>`
5. Generate a synthetic dataset: `bash scripts/datasetDDPM/generate_dataset.sh <dataset_name>`. Please adjust the hyperparameters in this script to the available resources. On 8×A100 80 GB, it takes about 12 hours to generate 10,000 samples (see the throughput estimate after this list).
6. Run: `bash scripts/datasetDDPM/train_deeplab.sh <dataset_name>`. One needs to specify the path to the generated data; see the comments in the script.
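As a back-of-the-envelope reading of the generation cost quoted in step 5 (pure arithmetic, no project code involved):

```python
# Throughput implied by "~12 h on 8xA100 for 10,000 samples" (step 5 above).
samples, hours, gpus = 10_000, 12, 8
sec_per_sample = hours * 3600 / samples      # ~4.3 s wall-clock per sample
gpu_sec_per_sample = sec_per_sample * gpus   # ~35 GPU-seconds per sample
print(f"{sec_per_sample:.1f} s/sample wall-clock, "
      f"{gpu_sec_per_sample:.1f} GPU-s/sample")
```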

Available checkpoint names: `lsun_bedroom`, `ffhq`, `lsun_cat`, `lsun_horse`

Available dataset names: `bedroom_28`, `ffhq_34`, `cat_15`, `horse_21`