We develop open foundation models from numerous public datasets by leveraging their heterogeneous expert annotations. Our Ark models outperform state-of-the-art (SOTA) fully supervised and self-supervised methods on various thoracic disease classification tasks and organ/bone segmentation tasks. Ark also provides embeddings of higher quality than those of Google's CXR Foundation Model.
Foundation Ark: Accruing and Reusing Knowledge for Superior and Robust Performance
DongAo Ma1, Jiaxuan Pang1, Michael B. Gotway2, Jianming Liang1
1 Arizona State University, 2 Mayo Clinic
International Conference on Medical Image Computing and Computer Assisted Intervention (MICCAI 2023) (Oral + Poster)
Paper (PDF, arXiv) | Code | Poster | Oral Presentation (YouTube, Bilibili)
Requirements:
- Python
- PyTorch (pytorch.org)
Create and activate a Python 3 conda environment:
$ conda create -n ark python=3
$ conda activate ark
Install PyTorch for your CUDA version (e.g., CUDA 11.6):
$ conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia
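After installation, a quick check like the following can confirm which PyTorch build and CUDA version were picked up. This is a minimal sketch, not part of the Ark repository; it relies only on `torch.__version__`, `torch.version.cuda`, and `torch.cuda.is_available()`, and it reports gracefully if PyTorch is not installed.

```python
import importlib.util


def describe_torch_env() -> str:
    """Summarize the installed PyTorch build, or report that it is missing."""
    if importlib.util.find_spec("torch") is None:
        return "PyTorch is not installed in this environment"
    import torch  # imported lazily so the check itself never crashes

    return (
        f"torch {torch.__version__}, "
        f"built for CUDA {torch.version.cuda}, "  # None for CPU-only builds
        f"GPU available: {torch.cuda.is_available()}"
    )


if __name__ == "__main__":
    print(describe_torch_env())
```

With the install command above, the CUDA field should report 11.6; `None` indicates a CPU-only build.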
Clone the repository and install the requirements:
$ git clone https://github.com/Mda233/Ark.git
$ cd Ark
$ pip install -r requirements
Set <PATH_TO_DATASET> in datasets_config.yaml to the local path of each dataset.
(To incorporate a new dataset, follow the examples in datasets_config.yaml, then create a corresponding dataloader for it in dataloader.py.)
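A new dataloader mainly needs to map an index to an image and its label vector. The sketch below assumes a hypothetical list-file format (an image path followed by space-separated 0/1 labels on each line) and hypothetical names `parse_label_lines` and `ListFileDataset`; it is not the interface used in dataloader.py. It implements the two methods a PyTorch map-style dataset requires, `__len__` and `__getitem__`, and leaves image decoding to a caller-supplied `loader` so it runs with the standard library alone.

```python
from pathlib import Path
from typing import Callable, Iterable, List, Optional, Tuple

Sample = Tuple[str, List[float]]


def parse_label_lines(lines: Iterable[str]) -> List[Sample]:
    """Parse '<image_path> <label_1> ... <label_K>' lines into (path, labels)."""
    samples: List[Sample] = []
    for line in lines:
        parts = line.split()
        if not parts:
            continue  # skip blank lines
        samples.append((parts[0], [float(x) for x in parts[1:]]))
    return samples


class ListFileDataset:
    """Hypothetical map-style dataset: index -> (image, multi-hot labels)."""

    def __init__(self, image_dir: str, samples: List[Sample],
                 loader: Optional[Callable[[Path], object]] = None):
        self.image_dir = Path(image_dir)
        self.samples = samples
        # Default loader returns the path itself; swap in e.g. PIL.Image.open.
        self.loader = loader or (lambda path: path)

    @classmethod
    def from_list_file(cls, image_dir: str, list_file: str, **kwargs):
        """Build the dataset from a list file on disk."""
        with open(list_file, encoding="utf-8") as f:
            return cls(image_dir, parse_label_lines(f), **kwargs)

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int):
        name, labels = self.samples[index]
        return self.loader(self.image_dir / name), labels
```

To use it with `torch.utils.data.DataLoader`, one option is to subclass `torch.utils.data.Dataset` and return the labels as a tensor, for example via `torch.tensor(labels)`.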
# Train Ark-6 with six public datasets
python main_ark.py --data_set MIMIC --data_set CheXpert \
  --data_set ChestXray14 --data_set RSNAPneumonia \
  --data_set VinDrCXR --data_set Shenzhen \
  --opt sgd --warmup-epochs 20 --lr 0.3 \
  --batch_size 200 --model swin_base --init imagenet \
  --pretrain_epochs 200 --test_epoch 10 \
  --pretrained_weights https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth \
  --momentum_teacher 0.9 --projector_features 1376
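The `--momentum_teacher 0.9` flag is assumed here (not confirmed by this README) to be the coefficient of an exponential moving average (EMA) that updates the teacher from the student, as in common teacher-student setups. The sketch below illustrates that update rule with plain Python lists standing in for model tensors; `ema_update` is a hypothetical helper, not a function from main_ark.py.

```python
from typing import Dict, List

# Toy stand-in for a model's parameters: parameter name -> flat list of values.
Params = Dict[str, List[float]]


def ema_update(teacher: Params, student: Params, momentum: float = 0.9) -> None:
    """Update teacher in place: teacher <- momentum * teacher + (1 - momentum) * student."""
    for name, student_values in student.items():
        teacher[name] = [
            momentum * t + (1.0 - momentum) * s
            for t, s in zip(teacher[name], student_values)
        ]


if __name__ == "__main__":
    teacher = {"w": [0.0, 1.0]}
    student = {"w": [1.0, 1.0]}
    for step in range(3):
        ema_update(teacher, student, momentum=0.9)
        print(step, teacher["w"])  # teacher drifts toward the student's values
```

With momentum 0.9, the teacher keeps 90% of its previous value at each update, so it changes more slowly than the student. With real PyTorch modules, the same rule is typically applied under `torch.no_grad()`, e.g. `t.mul_(m).add_(s, alpha=1 - m)` for each pair in `zip(teacher.parameters(), student.parameters())`.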
You can request the pretrained Ark-5 and Ark-6 models (the teacher models) from our paper through this Google Form or wjx.cn.
- Create a Swin Transformer Base model from the official implementation or from timm (v0.5.4):
import timm
model = timm.create_model('swin_base_patch4_window7_224', num_classes=args.num_class, pretrained=False)
- Load the weights:
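import torch

# Load the Ark checkpoint on CPU; <PATH_TO_MODEL> is where you saved it.
state_dict = torch.load('<PATH_TO_MODEL>/ark6_teacher_ep200_swinb_projector1376_mlp.pth.tar', map_location="cpu")
# Drop classifier-head weights so the model keeps its freshly initialized task head.
for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias']:
    if k in state_dict:
        print(f"Removing key {k} from pretrained checkpoint")
        del state_dict[k]
# strict=False tolerates keys present on only one side (e.g., the removed head).
model.load_state_dict(state_dict, strict=False)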
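Because `strict=False` silently skips mismatched keys, it can help to check which keys would actually be loaded. The dict-only sketch below (hypothetical helper `filter_checkpoint`, plain Python, no PyTorch) drops the same classifier-head keys as the loading snippet above and lists the keys the model expects but the checkpoint lacks (missing) and the keys the checkpoint has but the model lacks (unexpected).

```python
from typing import Dict, Iterable, List, Tuple

# Classifier-head keys removed by the loading snippet above.
HEAD_KEYS = ("head.weight", "head.bias", "head_dist.weight", "head_dist.bias")


def filter_checkpoint(checkpoint: Dict[str, object],
                      model_keys: Iterable[str],
                      drop: Iterable[str] = HEAD_KEYS
                      ) -> Tuple[Dict[str, object], List[str], List[str]]:
    """Return (filtered_state_dict, missing_keys, unexpected_keys)."""
    drop_set = set(drop)
    filtered = {k: v for k, v in checkpoint.items() if k not in drop_set}
    expected = set(model_keys)
    kept = set(filtered)
    missing = sorted(expected - kept)      # model keys left at their init
    unexpected = sorted(kept - expected)   # checkpoint keys the model ignores
    return filtered, missing, unexpected
```

In PyTorch, `model.load_state_dict(..., strict=False)` also returns an object with `missing_keys` and `unexpected_keys`; if anything other than the head keys shows up as missing, the backbone was probably not loaded as intended.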
We have integrated the pretrained Ark models into our BenchmarkTransformers GitHub repository:
$ git clone https://github.com/jlianglab/BenchmarkTransformers.git
$ cd BenchmarkTransformers
python main_classification.py --data_set ChestXray14 \
  --model swin_base \
  --init ark \
  --pretrained_weights [PATH_TO_MODEL]/ark6_teacher_ep200_swinb_projector1376_mlp.pth.tar \
  --data_dir [PATH_TO_DATASET] \
  --train_list dataset/Xray14_train_official.txt \
  --val_list dataset/Xray14_val_official.txt \
  --test_list dataset/Xray14_test_official.txt \
  --lr 0.01 --opt sgd --epochs 200 --warmup-epochs 0 --batch_size 64
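ChestXray14 is a multi-label task with 14 findings, and results on it are commonly reported as the mean AUC over those labels. Assuming that metric, the sketch below computes per-label ROC AUC with the rank-based (Mann-Whitney U) formulation and averages across labels. It is an illustrative standard-library implementation with hypothetical names `binary_auc` and `mean_auc`, not the evaluation code in BenchmarkTransformers; in practice `sklearn.metrics.roc_auc_score` is the usual choice.

```python
from typing import List, Sequence


def binary_auc(labels: Sequence[int], scores: Sequence[float]) -> float:
    """ROC AUC via Mann-Whitney U, with average ranks for tied scores."""
    order = sorted(range(len(scores)), key=lambda i: scores[i])
    ranks = [0.0] * len(scores)
    i = 0
    while i < len(order):
        j = i
        # Group indices sharing the same score and give them their mean rank.
        while j + 1 < len(order) and scores[order[j + 1]] == scores[order[i]]:
            j += 1
        mean_rank = (i + j) / 2 + 1  # ranks are 1-based
        for k in range(i, j + 1):
            ranks[order[k]] = mean_rank
        i = j + 1
    n_pos = sum(1 for y in labels if y == 1)
    n_neg = len(labels) - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("AUC is undefined when only one class is present")
    rank_sum = sum(r for r, y in zip(ranks, labels) if y == 1)
    return (rank_sum - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)


def mean_auc(labels: List[List[int]], scores: List[List[float]]) -> float:
    """Average per-column AUC; rows are images, columns are findings."""
    n_labels = len(labels[0])
    aucs = [
        binary_auc([row[c] for row in labels], [row[c] for row in scores])
        for c in range(n_labels)
    ]
    return sum(aucs) / n_labels
```

For example, `binary_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])` gives 0.75: of the four positive/negative pairs, three rank the positive above the negative.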
If you use this code or use our pre-trained weights for your research, please cite our paper:
@InProceedings{ma2023foundation,
author="Ma, DongAo and Pang, Jiaxuan and Gotway, Michael B. and Liang, Jianming",
title="Foundation Ark: Accruing and Reusing Knowledge for Superior and Robust Performance",
booktitle="Medical Image Computing and Computer Assisted Intervention -- MICCAI 2023",
year="2023",
publisher="Springer Nature Switzerland",
address="Cham",
pages="651--662",
isbn="978-3-031-43907-0"
}
This research has been supported in part by ASU and Mayo Clinic through a Seed Grant and an Innovation Grant, and in part by the NIH under Award Number R01HL128785. The content is solely the responsibility of the authors and does not necessarily represent the official views of the NIH. This work has utilized the GPUs provided in part by the ASU Research Computing and in part by the Bridges-2 at Pittsburgh Supercomputing Center through allocation BCS190015 and the Anvil at Purdue University through allocation MED220025 from the Advanced Cyberinfrastructure Coordination Ecosystem: Services & Support (ACCESS) program, which is supported by National Science Foundation grants #2138259, #2138286, #2138307, #2137603, and #2138296. We also acknowledge Google for granting us access to CXR Foundation API, which enabled us to generate the embeddings for the target datasets. The content of this paper is covered by patents pending.
Released under the ASU GitHub Project License.