[ACM Transactions on Computing for Healthcare, 2023] SCouT: Synthetic Counterfactuals via Spatiotemporal Transformers for Actionable Healthcare
This is the official repository for the paper SCouT: Synthetic Counterfactuals via Spatiotemporal Transformers for Actionable Healthcare.
- Environment Setup
- Dataset Setup
- Experiment Config Setup
- SCouT Library
- Demo
- Third Party Code
- Citations
The following shell script creates an Anaconda environment called `scout` and installs all the required packages.

```bash
source env_setup.sh
```
Data is set up as an N x T x D NumPy matrix `data.npy`, where:
- N is the number of units
- T is the total time interval
- D is the total number of covariates

Along with the data matrix, provide a binary matrix `mask.npy` of shape N x T that marks a missing measurement with 1. See `synthetic_data/synthetic_data_noise_1` for an example.
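As a minimal sketch of how these two files can be produced with NumPy (all dimensions and values below are illustrative, not the shipped synthetic data):

```python
import numpy as np

# Illustrative dimensions: N units, T time steps, D covariates.
N, T, D = 50, 2000, 5

# N x T x D data matrix; random noise stands in for real measurements here.
data = np.random.randn(N, T, D).astype(np.float32)

# N x T binary mask: 1 marks a missing measurement, 0 an observed one.
mask = np.zeros((N, T), dtype=np.int64)
mask[0, 100:110] = 1  # e.g., unit 0 is unobserved for time steps 100-109

np.save("data.npy", data)
np.save("mask.npy", mask)
```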
Experiment configurations in SCouT are defined using YAML files. These config files specify the model architecture, modeling parameters, and training settings. An example configuration is provided in `experiment_config/synthetic.yaml`.
Model architecture:
- `feature_dim`: Number of covariates in the dataset
- `cont_dim`: Number of continuous covariates
- `discrete_dim`: Number of discrete covariates
- `hidden_size`: Hidden size of the transformer model
- `n_layers`: Number of layers in the SCouT model
- `n_heads`: Number of attention heads in the transformer

Modeling parameters:
- `K`: Number of donor units to use
- `seq_range`: Total units (donor + target)
- `pre_int_len`: Length of the pre-intervention interval to model
- `post_int_len`: Length of the post-intervention interval to model
- `time_range`: Total time interval period
- `interv_time`: Time point of the intervention
- `target_id`: ID of the target unit
- `lowrank`: Whether to use a low-rank approximation
- `rank`: Rank of the low-rank approximation

Training settings:
- `batch_size`: Batch size for training
- `lr`: Learning rate
- `weight_decay`: Weight decay for regularization
- `warmup_steps`: Number of warmup steps for the learning rate scheduler
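Putting these keys together, a config file has the following shape (the values here are illustrative placeholders, not the contents of the shipped `synthetic.yaml`; only `interv_time` is taken from the demo below):

```yaml
# Model architecture
feature_dim: 5
cont_dim: 5
discrete_dim: 0
hidden_size: 128
n_layers: 4
n_heads: 8

# Modeling parameters
K: 49
seq_range: 50
pre_int_len: 100
post_int_len: 25
time_range: 2000
interv_time: 1600
target_id: 0
lowrank: false
rank: 8

# Training settings
batch_size: 32
lr: 1.0e-4
weight_decay: 0.01
warmup_steps: 1000
```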
The core of the SCouT library is the `SCOUT` class, which handles model initialization, training, and prediction. Here's how to use the key functionality:
```python
import torch
from scout import SCOUT  # import path assumed; adjust to how the library is installed

scout = SCOUT(config_path="./experiment_config/synthetic.yaml",
              op_dir="./logs/",
              random_seed=42,
              datapath="./synthetic_data/synthetic_data_noise_1/",
              device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
```
The `fit()` method handles both pretraining on donor units and finetuning on the target unit:
```python
# Pretrain on the donor units, then finetune on the target unit.
scout = scout.fit(pretrain=True, pretrain_iters=10000, finetune_iters=1000)
```
After training, `.predict()` generates counterfactual predictions for the post-intervention period:
```python
counterfactual_predictions = scout.predict()
```
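The returned predictions can be kept for downstream analysis, e.g. `np.save("counterfactuals.npy", counterfactual_predictions)`, and compared against held-out ground truth as in the demo notebook below.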
Extract donor attention weights to analyze which donor units influenced the predictions:
```python
attention_weights = scout.return_attention(interv_time=1600)
```
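To surface the most influential donors, the weights can be ranked. This is a minimal sketch that assumes `attention_weights` can be reduced to one score per donor unit; the actual shape returned by `return_attention` may differ:

```python
import numpy as np

# Assumption: averaging over all remaining axes yields one score per donor.
# Adapt the reduction to the actual returned shape.
weights = np.asarray(attention_weights)
scores = weights.reshape(weights.shape[0], -1).mean(axis=1)

# Indices of the five most-attended donor units, highest first.
top_donors = np.argsort(scores)[::-1][:5]
print("Most influential donor units:", top_donors)
```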
Load a previously trained model:
```python
scout.load_model_from_checkpoint("./logs/finetune/checkpoint.pt")
```
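This restores the saved weights into the `SCOUT` instance, so that, for example, `scout.predict()` can be called again without retraining.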
The repository includes `demo.ipynb`, a Jupyter notebook that demonstrates SCouT's core functionality using synthetic data. The notebook shows:
- Setting up model configuration and paths
- Initializing and training the SCouT model
- Generating counterfactual predictions
- Visualizing results comparing observed data, ground truth, and counterfactual predictions (a plotting sketch in this spirit follows the list)
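A comparison plot along the notebook's lines can be sketched as follows. The slicing assumes `counterfactual_predictions` has shape (post-intervention length, D) and reuses the `data.npy` layout from above; both are assumptions, so adapt the indexing to the actual outputs:

```python
import numpy as np
import matplotlib.pyplot as plt

data = np.load("./synthetic_data/synthetic_data_noise_1/data.npy")
target_id, covariate, interv_time = 0, 0, 1600  # illustrative indices

# Observed pre-intervention trajectory and post-intervention ground truth
# for one covariate of the target unit.
observed = data[target_id, :interv_time, covariate]
ground_truth = data[target_id, interv_time:, covariate]
t_post = np.arange(interv_time, data.shape[1])

plt.plot(np.arange(interv_time), observed, label="Observed (pre-intervention)")
plt.plot(t_post, ground_truth, label="Ground truth (post-intervention)")
plt.plot(t_post, counterfactual_predictions[:len(t_post), covariate],
         label="SCouT counterfactual")
plt.axvline(interv_time, linestyle="--", color="gray")
plt.xlabel("Time")
plt.legend()
plt.show()
```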
We are grateful to Hugging Face for their Transformers library, which our SCouT model builds upon. Our architecture leverages their BERT implementation as the foundation for our spatiotemporal transformer models.
Please cite the paper and star this repo if you find it useful, thanks! Feel free to contact [email protected] or open an issue if you have any questions. Cite our work using the following BibTeX entry:
```bibtex
@misc{dedhia2022scoutsyntheticcounterfactualsspatiotemporal,
      title={SCouT: Synthetic Counterfactuals via Spatiotemporal Transformers for Actionable Healthcare},
      author={Bhishma Dedhia and Roshini Balasubramanian and Niraj K. Jha},
      year={2022},
      eprint={2207.04208},
      archivePrefix={arXiv},
      primaryClass={cs.AI},
      url={https://arxiv.org/abs/2207.04208},
}
```