Extraction and Recovery of Spatio-Temporal Structure in Latent Dynamics Alignment with Diffusion Models [NeurIPS'2023 Spotlight]
A new tag, v1.0.1, has been created with the following changes:
- Initialized linear probing layers with an identity matrix to enhance alignment stability.
- Improved diffusion model stability using data augmentation and cosine_beta_schedule.
- Resolved NaN issues for better numerical stability.
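For reference, two of the stability changes above can be sketched in plain Python. This is a minimal illustration, not the repository's code. cosine_beta_schedule here follows the standard cosine noise schedule of Nichol & Dhariwal (2021), and the repository's version may use different clipping bounds. identity_probe and apply_probe are hypothetical helpers that show what an identity-initialized linear probing layer computes. In PyTorch, the same initialization can be obtained with torch.nn.init.eye_ on the layer weight and torch.nn.init.zeros_ on its bias.

```python
import math


def cosine_beta_schedule(timesteps, s=0.008, max_beta=0.999):
    """Standard cosine noise schedule (sketch; the repo's version may differ).

    alpha_bar(t) = cos^2(((t / T + s) / (1 + s)) * pi / 2), normalized so that
    alpha_bar(0) = 1, and beta_t = 1 - alpha_bar(t) / alpha_bar(t - 1).
    """
    def f(t):
        return math.cos(((t / timesteps + s) / (1 + s)) * math.pi / 2) ** 2

    betas = []
    for t in range(1, timesteps + 1):
        # Normalizing by alpha_bar(0) cancels in the ratio, so f is used directly.
        beta = 1.0 - f(t) / f(t - 1)
        # alpha_bar(T) is ~0, so the last beta is ~1; clipping avoids a singular step.
        betas.append(min(beta, max_beta))
    return betas


def identity_probe(dim):
    """Hypothetical helper: a dim x dim linear probe y = W x + b with W = I, b = 0."""
    weight = [[1.0 if i == j else 0.0 for j in range(dim)] for i in range(dim)]
    bias = [0.0] * dim
    return weight, bias


def apply_probe(weight, bias, x):
    """Apply the linear probe to a single latent vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


betas = cosine_beta_schedule(1000)
print(len(betas), betas[-1])  # → 1000 0.999

W, b = identity_probe(3)
print(apply_probe(W, b, [0.5, -1.0, 2.0]))  # → [0.5, -1.0, 2.0]
```

With the identity initialization, the probe starts as a no-op, so alignment begins from the unaligned target latents rather than from a random projection. This is a plausible reason the change improves stability.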
To install the required dependencies with conda, run:
conda create --name erdiff --file requirements.txt
To install the required dependencies in a Python virtual environment, run:
python3 -m venv erdiff
source erdiff/bin/activate
python3 -m pip install --upgrade pip
python3 -m pip install -e .
To train the diffusion model on the source session, run:
cd scripts/ && sbatch run_diffusion_train.sh
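As background for this step, the sketch below shows the forward noising and noise-prediction loss of a standard DDPM. It assumes the usual ε-prediction objective and is not taken from the repository, whose training script may parameterize the model differently. q_sample and eps_mse are hypothetical names, and x0 stands for a flattened latent trajectory from the source session.

```python
import math
import random


def q_sample(x0, alpha_bar_t, rng):
    """Draw x_t ~ q(x_t | x_0) under the standard DDPM forward process (sketch).

    x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, eps ~ N(0, I).
    Returns (x_t, eps); a denoiser is trained to recover eps from (x_t, t).
    """
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    a, c = math.sqrt(alpha_bar_t), math.sqrt(1.0 - alpha_bar_t)
    x_t = [a * xi + c * ei for xi, ei in zip(x0, eps)]
    return x_t, eps


def eps_mse(eps_pred, eps):
    """Mean squared error between predicted and true noise (DDPM 'simple' loss)."""
    return sum((p - e) ** 2 for p, e in zip(eps_pred, eps)) / len(eps)


rng = random.Random(0)
x0 = [0.3, -1.2, 0.8]  # hypothetical flattened latent values
x_t, eps = q_sample(x0, alpha_bar_t=0.5, rng=rng)
# A trivial "always predict zero" denoiser, just to show how the loss is formed.
loss = eps_mse([0.0] * len(eps), eps)
```

At alpha_bar_t = 1 the sample equals x0, and at alpha_bar_t = 0 it is pure noise; a schedule such as cosine_beta_schedule determines how alpha_bar_t moves between these extremes over the diffusion steps.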
To perform the diffusion-guided maximum likelihood alignment, run:
cd scripts/ && sbatch run_mla.sh
The alignment progress across epochs can be viewed in scripts/mla_erdiff_398637.out.
If you find the code useful for your research, please consider citing our work:
@article{wang2024extraction,
title={Extraction and recovery of spatio-temporal structure in latent dynamics alignment with diffusion model},
author={Wang, Yule and Wu, Zijing and Li, Chengrui and Wu, Anqi},
journal={Advances in Neural Information Processing Systems},
volume={36},
year={2024}
}