In this work, we propose a generative adversarial network (GAN) called EHR-M-GAN, which simultaneously synthesizes mixed-type timeseries EHR data (i.e., both continuous-valued and discrete-valued timeseries). EHR-M-GAN is capable of capturing the multidimensional, heterogeneous, and correlated temporal dynamics in patient trajectories.
This repository contains a TensorFlow implementation of EHR-M-GAN. For details, please see "Generating Synthetic Mixed-type Longitudinal Electronic Health Records for Artificial Intelligent Applications" [arXiv paper link].
The code requires:
- Python 3.6 or higher
- TensorFlow 1.14.0 or higher
- NumPy
- scikit-learn (imported as sklearn)
- pickle (part of the Python standard library)
- Matplotlib
- Seaborn
- pandas
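Before training, it can help to confirm that the listed packages are importable in the current environment. The following is a minimal sketch and is not part of the repository: the helper names `missing_modules` and `python_ok` are hypothetical, and the script only checks that each module can be found, not that its version meets the minimum stated above.

```python
import importlib.util
import sys

# Import names for the packages listed above. scikit-learn is imported as
# "sklearn"; pickle ships with the Python standard library.
REQUIRED_MODULES = ["tensorflow", "numpy", "sklearn", "pickle",
                    "matplotlib", "seaborn", "pandas"]


def missing_modules(names):
    """Return the module names in `names` that cannot be located."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def python_ok(minimum=(3, 6)):
    """Return True if the running interpreter is at least `minimum`."""
    return sys.version_info >= minimum


if __name__ == "__main__":
    if not python_ok():
        print("Python 3.6 or higher is required.")
    missing = missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing modules: " + ", ".join(missing))
    else:
        print("All listed modules were found.")
```

Version constraints (for example TensorFlow 1.14.0 or higher) still need to be checked separately, e.g. by inspecting `tensorflow.__version__` after import.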
All datasets are publicly available from PhysioNet, and can be downloaded from the following links:
In order to preprocess the datasets for running EHR-M-GAN, please refer to the following repository.
- main_train.py: Use mixed-type timeseries EHR data as the training set to generate synthetic data
- networks.py: Components of the model (generators and discriminators in the sequentially coupled GANs, encoders and decoders in the dual-VAE)
- m3gan.py: Pretrain the latent representations and optimize the adversarial learning networks
- Constrastivelosslayer.py: The contrastive loss function used in learning the shared VAE representations
- Bilateral_lstm_cell.py: The proposed Bilateral LSTM cell (single-layer)
- Bilateral_lstm_class.py: The proposed Bilateral LSTM network with multiple layers
- init_state.py: Initial state function for recurrent neural networks
- utils.py: Other utility functions for adversarial training
To train the model(s) in the paper, simply run this command:
python main_train.py --dataset mimic --num_pre_epochs 500 --num_epochs 800 --epoch_ckpt_freq 100
For training the conditional extension of EHR-M-GAN in the paper, run this command:
python main_train.py --dataset mimic --conditional True --num_labels 1 --num_pre_epochs 500 --num_epochs 800 --epoch_ckpt_freq 100
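The two commands above differ only in the `--conditional` and `--num_labels` flags. As a rough illustration of how such flags could be parsed, here is a minimal sketch using `argparse`. It is not the parser in `main_train.py`: the function names `str2bool` and `build_parser` and all default values are assumptions made for the example. One detail worth noting is that a flag passed as `--conditional True` needs an explicit string-to-boolean converter, because `type=bool` in `argparse` would treat any non-empty string, including "False", as true.

```python
import argparse


def str2bool(value):
    """Convert strings such as 'True' or 'False' to a bool."""
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean, got %r" % value)


def build_parser():
    """Hypothetical parser mirroring the flags in the commands above.

    The defaults here are assumptions, not values taken from the repository.
    """
    parser = argparse.ArgumentParser(description="EHR-M-GAN training (sketch)")
    parser.add_argument("--dataset", type=str, default="mimic")
    parser.add_argument("--conditional", type=str2bool, default=False)
    parser.add_argument("--num_labels", type=int, default=1)
    parser.add_argument("--num_pre_epochs", type=int, default=500)
    parser.add_argument("--num_epochs", type=int, default=800)
    parser.add_argument("--epoch_ckpt_freq", type=int, default=100)
    return parser


# Parse the conditional-training command line shown above.
example_args = build_parser().parse_args(
    ["--dataset", "mimic", "--conditional", "True", "--num_labels", "1",
     "--num_pre_epochs", "500", "--num_epochs", "800",
     "--epoch_ckpt_freq", "100"]
)
print(example_args.conditional, example_args.num_epochs)
```

Passing the argument list explicitly, as in the last lines, keeps the sketch independent of `sys.argv` so it can be run or tested on its own.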