This is the PyTorch implementation for the multiple instance learning model described in the paper Dual-stream Multiple Instance Learning Network for Whole Slide Image Classification with Self-supervised Contrastive Learning (CVPR 2021, accepted for oral presentation).
Install Anaconda/Miniconda
Install the required packages:
$ conda env create --name dsmil --file env.yml
$ conda activate dsmil
The MIL benchmark dataset can be downloaded via:
$ python download.py --dataset=mil
Precomputed features for TCGA Lung Cancer dataset can be downloaded via:
$ python download.py --dataset=tcga
This dataset requires 20 GB of free disk space.
Train DSMIL on standard MIL benchmark dataset:
$ python train_mil.py
To switch between the MIL benchmark datasets, use the option:
[--datasets] # musk1, musk2, elephant, fox, tiger
Other options are available for the learning rate (0.0002), the number of cross-validation folds (5), the weight decay (5e-3), and the number of epochs (40).
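For example, to train on the Musk1 dataset:
$ python train_mil.py --datasets=musk1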
Train DSMIL on TCGA Lung Cancer dataset (precomputed features):
$ python train_tcga.py --new_features=0
If you are processing WSIs from raw images, you will need to download the WSIs first.
- Download WSIs.
Navigate to './tcga-download/' and download the WSIs from the TCGA data portal using the manifest file and configuration file.
The example assumes a Windows operating system. The WSIs will be saved in './WSI/TCGA-lung/LUAD' and './WSI/TCGA-lung/LUSC'.
The raw WSIs take about 1 TB of disk space and may take several days to download. Open a command-line tool (Command Prompt on Windows), navigate to './tcga-download', and run:
$ cd tcga-download
$ gdc-client download -m gdc_manifest.2020-09-06-TCGA-LUAD.txt --config config-LUAD.dtt
$ gdc-client download -m gdc_manifest.2020-09-06-TCGA-LUSC.txt --config config-LUSC.dtt
- Prepare the patches.
We will be using OpenSlide, a C library with a Python API that provides a simple interface for reading WSI data. We refer users to the OpenSlide Python API documentation for details on using this tool.
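For orientation, here is a minimal sketch of reading a single patch with the OpenSlide Python API (the slide path and patch size are illustrative placeholders, not the exact parameters used by 'TCGA-pre-crop.py'):

import openslide

# Open a WSI (placeholder path).
slide = openslide.OpenSlide('WSI/TCGA-lung/LUAD/example.svs')
# Level 0 is the full-resolution image; higher levels are downsampled.
print(slide.level_dimensions)
# read_region takes level-0 coordinates, a level index, and a patch size,
# and returns an RGBA PIL image.
patch = slide.read_region((0, 0), 0, (224, 224)).convert('RGB')
patch.save('patch_0_0.jpeg')
slide.close()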
The patches will be saved in './WSI/TCGA-lung/pyramid' in a pyramidal structure covering magnifications of 20x and 5x. Navigate to './tcga-download/OpenSlide/bin' and run the script 'TCGA-pre-crop.py':
$ cd tcga-download/OpenSlide/bin
$ python TCGA-pre-crop.py --multiscale=1
Alternatively, the patches can be cropped at a single magnification of 10x and saved in './WSI/TCGA-lung/single' via:
$ python TCGA-pre-crop.py --multiscale=0
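Note that OpenSlide addresses resolutions by pyramid level rather than magnification, so a cropping script has to map a target magnification such as 20x, 10x, or 5x onto a level. A rough sketch of that mapping, assuming the scan's objective power is recorded in the slide properties:

import openslide

slide = openslide.OpenSlide('WSI/TCGA-lung/LUAD/example.svs')  # placeholder path
# Base magnification of level 0, e.g. 40.0 for a 40x scan.
base_mag = float(slide.properties[openslide.PROPERTY_NAME_OBJECTIVE_POWER])

def level_for(target_mag):
    # Pick the level whose downsample factor best matches base_mag / target_mag.
    return slide.get_best_level_for_downsample(base_mag / target_mag)

print(level_for(20), level_for(5))  # levels backing the 20x/5x pyramid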
- Train the embedder.
For training the embedder, we provide a script modified from this PyTorch implementation of SimCLR.
Navigate to './simclr' and edit the attributes in the configuration file 'config.yaml'. You will need to choose a batch size that fits your GPU(s); we recommend a batch size of at least 512 to obtain good SimCLR features. The trained model weights and the loss log are saved in the folder './simclr/runs'.
$ cd simclr
$ python run.py
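After training finishes, the checkpoint saved under './simclr/runs' serves as the frozen patch embedder. A minimal loading sketch, assuming a ResNet-18 backbone and a generic checkpoint path (both may differ in your run):

import torch
from torchvision import models

# Assumed checkpoint location and backbone; adjust to your run.
backbone = models.resnet18(weights=None)
state = torch.load('simclr/runs/<run-id>/checkpoints/model.pth', map_location='cpu')
backbone.load_state_dict(state, strict=False)  # skip projection-head keys, if any
backbone.fc = torch.nn.Identity()              # expose the 512-d pooled features
backbone.eval()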
- Compute the features.
Compute the features for 20x magnification:
$ python compute_feats.py --dataset=wsi-tcga-lung
Compute the features for 10x magnification:
$ python compute_feats.py --dataset=wsi-tcga-lung-single --magnification=10x
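Conceptually, feature computation passes every patch of a slide through the frozen embedder and stacks the outputs into one feature matrix per bag. A toy sketch of that idea (not 'compute_feats.py' itself; you supply the patch tensor and the embedder):

import torch

@torch.no_grad()
def embed_bag(patches, embedder, device='cuda'):
    # patches: (N, 3, H, W) float tensor holding all patches of one slide.
    embedder = embedder.eval().to(device)
    feats = embedder(patches.to(device))  # (N, D): one feature row per instance
    return feats.cpu().numpy()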
- Start training.
$ python train_tcga.py --new_features=1
- Testing.
We provide a testing pipeline for several sample slides. The slides can be downloaded via:
$ python download.py --dataset=tcga-test
To crop the WSIs into patches, navigate to './tcga-download/OpenSlide/bin' and run the script 'TCGA-test-10x.py':
$ cd tcga-download/OpenSlide/bin
$ python TCGA-test-10x.py
A folder containing all patches for each WSI will be created at './test/patches'.
After the WSIs are cropped, run the testing script:
$ python testing.py
The thumbnails of the WSIs will be saved in './test/thumbnails'.
The detection color maps will be saved in './test/output'.
The testing pipeline will process every WSI placed inside the './test/input' folder. Each slide will be classified as a LUAD, LUSC, or benign sample.
You can modify train_tcga.py to work with your own datasets. After you have trained your embedder, compute the features and organize them as follows (a concrete sketch appears after these steps):
- For each bag, generate a .csv file in which each row contains the features of one instance. Name the file "bagID.csv" and place it in a folder named "dataset-name".
- Generate a "dataset-name.csv" file with two columns where the first column contains the paths to all bagID.csv files, and the second column contains the bag labels.
- Replace the corresponding file path in the script with the path of "dataset-name.csv".
bags_path = pd.read_csv('path/to/dataset-name.csv')
- Configure the number of classes argument accordingly when creating the DSMIL model.
- Start training.
$ python train_tcga.py --new_features=1
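As a concrete illustration of the layout above, a minimal sketch that writes one feature CSV per bag plus the index CSV (the folder name and the my_bags iterable are placeholders for your own data):

import os
import pandas as pd

os.makedirs('dataset-name', exist_ok=True)
index_rows = []
# my_bags: your own iterable of (features, label) pairs,
# where features is an (N, D) array of instance features.
for bag_id, (features, label) in enumerate(my_bags):
    bag_csv = os.path.join('dataset-name', f'{bag_id}.csv')
    pd.DataFrame(features).to_csv(bag_csv, index=False)  # one row per instance
    index_rows.append((bag_csv, label))
# First column: paths to the bag CSVs; second column: bag labels.
pd.DataFrame(index_rows).to_csv('dataset-name.csv', index=False)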
If you use the code or results in your research, please use the following BibTeX entry.
@article{li2020dualstream,
    author  = {Bin Li and Yin Li and Kevin W. Eliceiri},
    title   = {Dual-stream Multiple Instance Learning Network for Whole Slide Image Classification with Self-supervised Contrastive Learning},
    journal = {arXiv preprint arXiv:2011.08939},
    year    = {2020}
}