Table of contents
  1. Overview
  2. Installation
  3. Data
  4. Pretrained weights
  5. Train
  6. Evaluation
  7. Implementation details
  8. Acknowledgments
  9. Contacts

Official PyTorch implementation of "DiMSUM: Diffusion Mamba - A Scalable and Unified Spatial-Frequency Method for Image Generation" (NeurIPS'24)

Hao Phung¹,³ *† · Quan Dao¹,² *† · Trung Dao¹

Hoang Phan⁴ · Dimitris N. Metaxas² · Anh Tran¹

¹VinAI Research   ²Rutgers University   ³Cornell University   ⁴New York University

[Page]    [Paper]   

*Equal contribution   †Work done while at VinAI Research

Overview

We propose DiMSUM, a hybrid Mamba-Transformer diffusion model that synergistically leverages both spatial and frequency information for high-quality image synthesis. In extensive experiments on standard benchmarks, our method achieves state-of-the-art results, with FIDs of 4.62 on CelebA-HQ 256, 3.76 on LSUN Church, and 2.11 on ImageNet-1K 256. It also converges faster in training than ZigMa and other diffusion methods: on ImageNet-1K 256, it outperforms both DiT and SiT while requiring fewer than a third of their training iterations, achieving the best FID of 2.11.

Details of the model architecture and experimental results can be found in our paper:

@inproceedings{phung2024dimsum,
   title={DiMSUM: Diffusion Mamba - A Scalable and Unified Spatial-Frequency Method for Image Generation},
   author={Phung, Hao and Dao, Quan and Dao, Trung and Phan, Hoang and Metaxas, Dimitris and Tran, Anh},
   booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
   year= {2024},
}

Please CITE our paper and give us a ⭐ whenever this repository is used to help produce published results or incorporated into other software.

Installation

  • Python 3.10.13

    • conda create -n dimsum python=3.10.13 && conda activate dimsum
  • torch 2.1.1 + cu118

    • pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
  • Requirements:

    • pip install -r requirements.txt
  • Install causal_conv1d and mamba

    • conda install conda-forge::cudatoolkit-dev
    • cd causal_conv1d && pip install -e . && cd ..
    • cd mamba && pip install -e . && cd ..
  • Add DiMSUM to PYTHONPATH (run from the repository root): export PYTHONPATH=$PYTHONPATH:$(pwd)
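After installation, a quick sanity check can confirm that the compiled packages are importable. The sketch below is not part of the repository. It assumes the vendored packages install under their upstream module names (mamba_ssm for mamba, causal_conv1d for causal_conv1d), and it uses importlib.util.find_spec so that modules are located without being imported.

```python
import importlib.util
import sys

# Assumption: the vendored packages keep their upstream import names.
REQUIRED_MODULES = ["torch", "torchvision", "causal_conv1d", "mamba_ssm"]
EXPECTED_PYTHON = (3, 10)  # the README pins Python 3.10.13


def check_environment(modules=REQUIRED_MODULES):
    """Return {module_name: importable?} plus a 'python_ok' flag.

    find_spec only locates a top-level module without running its code,
    so a compiled extension that fails at import time is not caught here.
    """
    report = {name: importlib.util.find_spec(name) is not None for name in modules}
    report["python_ok"] = sys.version_info[:2] == EXPECTED_PYTHON
    return report


if __name__ == "__main__":
    report = check_environment()
    for name, ok in report.items():
        print(f"{name:15s} {'OK' if ok else 'MISSING'}")
    if report.get("torch"):
        import torch

        # The README expects torch 2.1.1 built against CUDA 11.8.
        print("torch", torch.__version__, "cuda", torch.version.cuda,
              "gpu available:", torch.cuda.is_available())
```

Run it from the activated dimsum environment. A MISSING entry for causal_conv1d or mamba_ssm usually means the corresponding pip install -e step did not complete.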

Data

Training

For CelebA HQ (256) and LSUN, please follow this repo for dataset preparation.

Evaluation

For evaluation, first resize the dataset images and extract them as JPEG files.

For LMDB data (like celeba_256 and lsun_church), run this command:

python eval_toolbox/resize_lmdb.py --dataset celeba_256 --datadir ./data/celeba_256/celeba-lmdb/ --image_size 256 --save_dir real_samples/

For a folder of JPEG/PNG images, run this command instead (input_data_dir and dataname are placeholders for your own paths):

python eval_toolbox/resize.py main input_data_dir real_samples/dataname
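The exact preprocessing performed by eval_toolbox/resize.py is not described here. As a hedged illustration only, the sketch below assumes a center square crop followed by a bicubic resize to 256×256 and JPEG output; the helper names (center_crop_box, plan_outputs, resize_folder) are hypothetical, and Pillow is needed only by the last function. To compare FID against the numbers in this README, use the repository's script instead, since resizing choices can change FID.

```python
from pathlib import Path

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def center_crop_box(width, height):
    """(left, top, right, bottom) of the largest centered square crop."""
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    return (left, top, left + side, top + side)


def plan_outputs(input_dir, save_dir):
    """Pair each JPEG/PNG in input_dir with a .jpg path under save_dir.

    Note: a.png and a.jpg in the same folder would map to the same output.
    """
    input_dir, save_dir = Path(input_dir), Path(save_dir)
    sources = sorted(p for p in input_dir.iterdir() if p.suffix.lower() in IMAGE_SUFFIXES)
    return [(src, save_dir / (src.stem + ".jpg")) for src in sources]


def resize_folder(input_dir, save_dir, image_size=256):
    """Center-crop, resize, and save every image as JPEG (hypothetical helper)."""
    from PIL import Image  # Pillow, imported lazily so the helpers above need no deps

    Path(save_dir).mkdir(parents=True, exist_ok=True)
    for src, dst in plan_outputs(input_dir, save_dir):
        with Image.open(src) as img:
            square = img.convert("RGB").crop(center_crop_box(*img.size))
            square.resize((image_size, image_size), Image.BICUBIC).save(dst, quality=95)
```

For example, resize_folder("input_data_dir", "real_samples/dataname") follows the same argument order as the command above.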

Pretrained Weights

| Exp | #Params | FID | Checkpoint |
|---|---|---|---|
| CelebA-HQ 256 | 460M | 4.62 | celeb256_225ep.pt |
| LSUN Church 256 | 460M | 3.76 | church_395ep.pt |
| ImageNet-1K 256 (CFG) | 460M | 2.11 | imnet256_510ep.pt |

Train

Comment or uncomment the command lines in scripts/train.sh for the desired dataset, then run:

bash scripts/train.sh

Evaluation

To sample images from the pretrained checkpoints, run:

bash scripts/sample.sh
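The ImageNet-1K 256 result is reported with classifier-free guidance (CFG). How scripts/sample.sh sets the guidance scale is not documented here; the sketch below only illustrates the standard CFG combination, pred_uncond + w * (pred_cond - pred_uncond), using the common trick of running the conditional and unconditional passes in one doubled batch. It uses NumPy to stay self-contained (torch.cat and torch.chunk play the same roles in PyTorch). The model callable and the null-label index 1000 (the DiT convention for ImageNet-1K) are assumptions, not this repository's API.

```python
import numpy as np

NULL_LABEL = 1000  # assumption: DiT-style "no class" index for ImageNet-1K


def guided_prediction(model, x, labels, guidance_scale):
    """Classifier-free guidance with a single doubled batch.

    model(x, labels) is a hypothetical callable returning a prediction
    (noise or velocity) with the same shape as x.
    guidance_scale == 1 reproduces the purely conditional prediction.
    """
    x_in = np.concatenate([x, x], axis=0)
    y_in = np.concatenate([labels, np.full_like(labels, NULL_LABEL)], axis=0)
    cond, uncond = np.split(model(x_in, y_in), 2, axis=0)
    # Extrapolate from the unconditional prediction toward the conditional one.
    return uncond + guidance_scale * (cond - uncond)


def toy_model(x, y):
    """Toy stand-in: conditional rows predict x, unconditional rows predict 0."""
    return x * (y != NULL_LABEL)[:, None]


out = guided_prediction(toy_model, np.ones((2, 3)), np.array([7, 42]), guidance_scale=4.0)
print(out)  # every entry is 4.0
```

Larger guidance scales push samples toward the class condition, typically trading diversity for fidelity; the value used for the reported FID should be taken from the repository's scripts.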

To evaluate, select the relevant command in scripts/eval.sh and run:

bash scripts/eval.sh
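FID measures the Fréchet distance between Gaussian fits of Inception features of real and generated images: FID = ||μ_r - μ_g||² + Tr(Σ_r + Σ_g - 2(Σ_r Σ_g)^{1/2}). The sketch below is not the repository's evaluation code; it computes this quantity from precomputed feature arrays and assumes feature extraction with an Inception network happens elsewhere. Common implementations such as pytorch-fid use scipy.linalg.sqrtm; this sketch instead obtains Tr((Σ_r Σ_g)^{1/2}) from the eigenvalues of Σ_r Σ_g, which are real and non-negative for covariance matrices up to round-off.

```python
import numpy as np


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 (sigma1 sigma2)^{1/2})."""
    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)
    diff = mu1 - mu2
    # Tr of the matrix square root = sum of sqrt(eigenvalues of sigma1 @ sigma2).
    # Clip tiny negative or imaginary parts caused by floating-point round-off.
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_covmean = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)


def fid_from_features(feats_real, feats_fake):
    """feats_*: (N, D) arrays of Inception pool features (assumed precomputed)."""
    mu_r, mu_f = feats_real.mean(axis=0), feats_fake.mean(axis=0)
    sigma_r = np.cov(feats_real, rowvar=False)
    sigma_f = np.cov(feats_fake, rowvar=False)
    return frechet_distance(mu_r, sigma_r, mu_f, sigma_f)
```

Scores from such a sketch are only comparable to the table above when the features come from the same Inception weights and the same number of samples that scripts/eval.sh uses, because FID is biased with respect to sample count.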

Implementation details

Acknowledgments

This project builds on Vim, LFM, SiT, DiT, and ZigMa. We thank their authors for releasing their code.

Contacts

If you have any problems, please open an issue in this repository or send an email to [email protected] and [email protected].