This repository contains the PyTorch code for self-supervised pre-training, fine-tuning, and evaluation of the compact self-supervised vision transformer (cSiT), with a main focus on histopathology image classification. The SiT model uses three pretext tasks for self-supervised learning:
- Reconstruction
- Rotation Prediction
- Contrastive Learning
This repo is mainly adapted from the SiT repository, with some modifications and improvements. Some features were added, and the backbone of the model was replaced with CCT-14/7x2 (with some changes). The main contributions are:
- Gathering a large and diverse dataset of unlabeled histopathology images (about 600k images from well-known histopathology datasets)
- Using the Compact Convolutional Transformer (CCT), a compact vision transformer, as the backbone
- Achieving results competitive with state-of-the-art self-supervised models on two widely used histopathology image classification datasets, NCT-CRC and BreakHis
- Testing the model in a semi-supervised setting on the NCT-CRC dataset
- Adding code for evaluating the results of the model
First, we pre-trained the SiT-Compact model on a large, diverse unlabeled histopathology dataset of 600k images gathered from well-known histopathology datasets (e.g., PatchCamelyon, ICIAR2018, TUPAC2016 Mitosis, and …). After the pre-training phase, we used the SSL pre-trained model in several scenarios and ran experiments on two widely used datasets (NCT-CRC and BreakHis) to demonstrate the capability of self-supervised learning on histopathology images.
To pre-train the model in a self-supervised manner, we need a collection of unlabeled images that are similar to the images of the downstream task. Because no suitable unlabeled dataset of histopathology images was available, we created our own. To find related datasets, we used the dataset tables in this and this paper.
Fig. 1 - Sample of the unlabeled dataset
The list of reference datasets is available in this markdown.
As mentioned before, the pre-training consists of three pretext tasks.
- In the reconstruction pretext task, the original image is corrupted with random drop, random replace, color distortion, blurring, and gray-scale conversion. The corrupted image is then split into patches and fed to the Transformer. After the forward pass, the outputs of the last encoder layer are combined into a reconstructed image, and a loss function measures the difference between the reconstructed image and the original image.
- The second pretext task rotates the corrupted image by a randomly chosen angle out of four (0°, 90°, 180°, 270°) and adds an extra token (like the cls token) to predict the rotation. A loss function is used to learn the rotation of the images in the pre-training dataset.
- The third task is contrastive learning. The main focus of this method is to learn image embeddings that are invariant to different augmented views of the same image while being discriminative among different images. In each mini-batch, two augmented views are generated from each image, and a dedicated loss function is defined for contrastive learning. For more information, check this paper.
The original SiT paper explains all three pretext tasks in detail.
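To make the contrastive objective more concrete, here is a minimal sketch of a normalized temperature-scaled cross-entropy (NT-Xent) loss in plain Python. This README does not spell out the exact loss, so treating it as NT-Xent is an assumption; the function names and the temperature value are illustrative only.

```python
import math

def cosine(u, v):
    """Cosine similarity between two embedding vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def nt_xent(view_a, view_b, temperature=0.5):
    """Mean NT-Xent loss; view_a[i] and view_b[i] embed two views of image i."""
    z = view_a + view_b              # 2N embeddings in one list
    n = len(view_a)
    total = 0.0
    for i in range(2 * n):
        positive = (i + n) % (2 * n)  # index of the other view of the same image
        # Denominator: similarities to every other embedding in the batch.
        denom = sum(math.exp(cosine(z[i], z[j]) / temperature)
                    for j in range(2 * n) if j != i)
        total += math.log(denom) - cosine(z[i], z[positive]) / temperature
    return total / (2 * n)

# Toy embeddings: two images, two views each (values are made up).
view_a = [[1.0, 0.0], [0.0, 1.0]]
view_b = [[1.0, 0.1], [0.1, 1.0]]
aligned = nt_xent(view_a, view_b)
mismatched = nt_xent(view_a, view_b[::-1])  # views paired with the wrong image
print(aligned, mismatched)
```

The loss is lower when the two views of the same image have similar embeddings and higher when they are paired with views of other images, which is the invariance and discrimination property described above.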
We used CCT-14/7×2, with some modifications, as the backbone of the model. By reducing the embedding size in the transformer, we obtained a compact vision transformer with around 6 million parameters; for comparison, ViT-Base has around 86 million. The CCT code was adapted from this repository. In CCT, the final representations of the encoder are passed to a sequence pooling module, so the prediction is produced from all tokens. A convolutional embedder is also used to generate the patch embeddings.
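As a rough illustration of sequence pooling, the sketch below scores each output token with a linear layer, turns the scores into attention weights with a softmax, and returns the weighted sum of the tokens. It is plain Python for readability, and the `weight` and `bias` arguments stand in for learned parameters; it is not the code used in this repository.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def sequence_pool(tokens, weight, bias=0.0):
    """Pool n token vectors of size d into a single vector of size d."""
    # One scalar score per token (a linear layer with a single output).
    scores = [sum(w * t for w, t in zip(weight, token)) + bias for token in tokens]
    attention = softmax(scores)
    dim = len(tokens[0])
    # Attention-weighted sum over tokens, computed per feature.
    return [sum(a * token[k] for a, token in zip(attention, tokens))
            for k in range(dim)]

tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
uniform = sequence_pool(tokens, weight=[0.0, 0.0])    # equal scores: mean pooling
peaked = sequence_pool(tokens, weight=[10.0, 10.0])   # last token dominates
print(uniform, peaked)
```

With zero weights every token gets the same attention, so the result reduces to mean pooling; with learned weights the model can emphasize the most informative tokens instead of relying on a single class token.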
After pre-training, the model is fine-tuned and tested on well-known histopathology datasets such as NCT-CRC-HE-100K and BreakHis. The main focus of these tests is to investigate the effect of pre-training on histopathology datasets and to compare the cSiT model with other self-supervised methods used for histopathology image classification.
Mode | Accuracy | Macro Recall | Macro Precision | Macro F1 | Weighted F1 | Kappa Score | Macro AUC |
---|---|---|---|---|---|---|---|
FT | 0.84 | 0.79 | 0.82 | 0.78 | 0.84 | 0.818 | 0.980 |
SSL+LE | 0.91 | 0.87 | 0.89 | 0.86 | 0.9 | 0.890 | 0.985 |
SSL+FT | 0.94 | 0.92 | 0.92 | 0.91 | 0.94 | 0.928 | 0.993 |
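For readers unfamiliar with the reported metrics, the following sketch shows how accuracy, the macro-averaged scores, weighted F1, and Cohen's kappa can be derived from a confusion matrix. The function name and the example counts are invented for illustration; they are not results or code from this repository.

```python
def classification_metrics(confusion):
    """confusion[i][j] counts samples of true class i predicted as class j."""
    n = len(confusion)
    total = sum(sum(row) for row in confusion)
    support = [sum(row) for row in confusion]                        # true counts
    predicted = [sum(confusion[i][j] for i in range(n)) for j in range(n)]

    recalls, precisions, f1s = [], [], []
    for k in range(n):
        tp = confusion[k][k]
        recall = tp / support[k] if support[k] else 0.0
        precision = tp / predicted[k] if predicted[k] else 0.0
        denom = precision + recall
        recalls.append(recall)
        precisions.append(precision)
        f1s.append(2 * precision * recall / denom if denom else 0.0)

    accuracy = sum(confusion[k][k] for k in range(n)) / total
    # Agreement expected by chance, from the row and column marginals.
    expected = sum(s * p for s, p in zip(support, predicted)) / total ** 2
    kappa = (accuracy - expected) / (1 - expected) if expected < 1 else 0.0
    return {
        "accuracy": accuracy,
        "macro_recall": sum(recalls) / n,
        "macro_precision": sum(precisions) / n,
        "macro_f1": sum(f1s) / n,
        "weighted_f1": sum(f * s for f, s in zip(f1s, support)) / total,
        "kappa": kappa,
    }

# Made-up two-class confusion matrix, for illustration only.
metrics = classification_metrics([[8, 2],
                                  [1, 9]])
print(metrics)
```

Macro averages weight every class equally, while weighted F1 weights each class by its number of samples, so the two can diverge on imbalanced datasets.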
Results of the model in the three modes on the BreakHis dataset (5-fold stratified cross-validation)
Mode | Accuracy(Mean) | Recall(Mean) | Macro F1(Mean) | Weighted F1(Mean) | Kappa Score(Mean) | Precision(Mean) |
---|---|---|---|---|---|---|
FT | 0.857 | 0.93 | 0.82 | 0.852 | 0.642 | 0.866 |
SSL+LE | 0.848 | 0.932 | 0.812 | 0.845 | 0.625 | 0.858 |
SSL+FT | 0.937 | 0.952 | 0.928 | 0.938 | 0.8546 | 0.955 |
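The 5-fold stratified cross-validation mentioned for the BreakHis results can be sketched as follows. This is a simplified, dependency-free illustration with a hypothetical helper `stratified_folds`; the repository may build its folds differently, and the label counts are invented.

```python
import random
from collections import defaultdict

def stratified_folds(labels, n_folds=5, seed=0):
    """Split sample indices into n_folds lists with similar class proportions."""
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)

    rng = random.Random(seed)
    folds = [[] for _ in range(n_folds)]
    for indices in by_class.values():
        rng.shuffle(indices)
        # Deal the samples of each class round-robin across the folds.
        for position, index in enumerate(indices):
            folds[position % n_folds].append(index)
    return folds

# Made-up labels: 10 benign and 20 malignant samples.
labels = ["benign"] * 10 + ["malignant"] * 20
folds = stratified_folds(labels)

for k, validation in enumerate(folds):
    # Fold k is held out for validation; the other folds form the training set.
    training = [i for j, fold in enumerate(folds) if j != k for i in fold]
    print(k, len(training), len(validation))
```

Each fold serves as the validation set once, and the per-fold metrics are then averaged, which is what the "(Mean)" columns in the table report.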
Accuracy | 40X | 100X | 200X | 400X | Mean |
---|---|---|---|---|---|
cSiT | 93.78 | 93.12 | 94.28 | 94.39 | 93.89 |
Results of cSiT compared to basic SSL methods (ref)
Models | Macro F1 on NCT-CRC | Macro F1 on BreakHis | Backbone |
---|---|---|---|
Autoencoder | 37.0 | 36.0 | ResNet50 |
Colorization | 80.2 | 72.4 | ResNet50 |
CPCv2 | 80.1 | 71.1 | ResNet50 |
SSL Contrastive Learning | 86.2 | 78.2 | ResNet50 |
SSL Contrastive Learning - Best Model | 91.4 | 80.2 | ResNet34 |
cSiT | 93.0 | 92.8 ** | CCT |
pip install --upgrade -r requirements.txt || true
gdown --id 1BqoJ_IJWjOwqueCZstch-XQERcUto3dt # 600k_30epochs
python main.py --batch-size 72 --epochs 10 --min-lr 5e-6 --lr 1e-4 --training-mode 'SSL' --dataset 'UH_main' --output 'output' --validate-every 1
Fine-tuning with a prepared dataset
python main.py --batch-size 120 --epochs 50 --min-lr 5e-6 --training-mode 'finetune' --dataset 'NCT' --finetune '<<path/to/pretrained_model>>' --output 'output' --validate-every 1
Fine-tuning with a custom dataset
python main.py --batch-size 120 --epochs 50 --min-lr 5e-6 --training-mode 'finetune' --dataset 'Custom' --custom_train_dataset_path '<<path/to/train_dataset>>' --custom_val_dataset_path '<<path/to/val_dataset>>' --finetune 'output/checkpoint.pth' --output 'output' --validate-every 1
Linear projection head
python main.py --batch-size 120 --epochs 50 --min-lr 5e-6 --training-mode 'finetune' --dataset 'NCT' --finetune '<<path/to/pretrained_model>>' --output 'output' --validate-every 1 --SiT_LinearEvaluation 1
2-layer MLP projection head
python main.py --batch-size 120 --epochs 10 --lr 1e-3 --weight-decay 5e-4 --min-lr 5e-6 --training-mode 'finetune' --dataset 'NCT' --finetune '<<path/to/pretrained_model>>' --output 'output' --validate-every 1 --SiT_LinearEvaluation 1 --representation-size 1024
Note: set the --dataset_location parameter to the location of the downloaded dataset.
python evaluate.py --batch-size 180 --dataset 'Custom' --custom_test_dataset_path '<<path/to/test_dataset>>' --model-path '<<path/to/finetuned_model>>'