
Commit

SwinMM/Initialize the SwinMM project (#296)
Zi-hao-Wei authored Aug 11, 2023
1 parent d02865b commit 0cd69f2
Showing 50 changed files with 3,477 additions and 0 deletions.
101 changes: 101 additions & 0 deletions SwinMM/INSTALL.md
@@ -0,0 +1,101 @@
# Installation

We provide installation instructions here.

## Setup

### Using Docker

The simplest way to use SwinMM is with our docker image [`swinmm`](https://drive.google.com/file/d/1EGSoqN-HphyMV_gKUq-g7_BSwTTg35oA/view?usp=sharing), which contains all the required dependencies. Download `swinmm.tar` into the `SwinMM` directory and run the following commands:

```bash
cd SwinMM
docker import - swinmm < swinmm.tar
# Name the container "swinmm" so that `docker exec` below can attach to it.
docker run --runtime=nvidia --gpus=all -m="800g" --shm-size="32g" -itd --name swinmm -v "$(pwd)":/volume swinmm /bin/bash
docker exec -it swinmm /bin/bash
conda activate SwinMM
```

To use docker, make sure you have installed `docker` and `nvidia-docker`.
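
To quickly check that the NVIDIA container runtime is working, you can run `nvidia-smi` from inside a throwaway container (the CUDA image tag below is only an example):

```bash
# Should list the host GPUs from inside a container; swap the tag for any CUDA image you have locally.
docker run --rm --gpus all nvidia/cuda:12.2.0-base-ubuntu22.04 nvidia-smi
```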

### Manual

For fast dataset loading, users need to install the Redis database; for example, on Ubuntu: `sudo apt-get install redis`.

We also recommend installing PyTorch following the instructions on the official website.

Because of their complicated dependencies, two packages should be installed manually: [bagua==0.9.2](https://github.com/BaguaSys/bagua) and [monai==0.9.0](https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies).

The remaining packages can be installed with `pip install -r requirements.txt`.
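
Putting the manual steps together, a minimal sketch might look like the following (the exact PyTorch command depends on your CUDA version, and the `bagua`/`monai` commands should follow each project's install guide):

```bash
# Redis for fast dataset loading (Ubuntu).
sudo apt-get install redis
# Install PyTorch first, following the official website for your CUDA version.
pip install torch torchvision
# The two packages with more involved dependencies, installed manually.
pip install bagua==0.9.2
pip install "monai[all]==0.9.0"
# Everything else.
pip install -r requirements.txt
```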

## Datasets

Our pre-training dataset includes 5833 volumes from 8 public datasets:

- [AbdomenCT-1K](https://github.com/JunMa11/AbdomenCT-1K)
- [BTCV](https://www.synapse.org/#!Synapse:syn3193805/wiki/217789)
- [MSD](http://medicaldecathlon.com/)
- [TCIACovid19](https://wiki.cancerimagingarchive.net/display/Public/CT+Images+in+COVID-19/)
- [WORD](https://github.com/HiLab-git/WORD)
- [TCIA-Colon](https://wiki.cancerimagingarchive.net/display/Public/CT+COLONOGRAPHY/)
- [LiDC](https://wiki.cancerimagingarchive.net/display/Public/LIDC-IDRI/)
- [HNSCC](https://wiki.cancerimagingarchive.net/display/Public/HNSCC)

We use two popular datasets to evaluate downstream segmentation performance:

- [WORD](https://github.com/HiLab-git/WORD) (The Whole abdominal Organ Dataset)
- [ACDC](https://www.creatis.insa-lyon.fr/Challenge/acdc/#challenge/584e75606a3c77492fe91bba) (Automated Cardiac Diagnosis Challenge)

The JSON files can be downloaded from [pretrain_jsons](https://drive.google.com/file/d/1gJThxBvnJnc2_N1nFX7xywjFWFw7DSEY/view?usp=sharing) and [word_jsons](https://drive.google.com/file/d/1Td4T_k2QlEcTETz9TERGsVdOyebD5ULv/view?usp=sharing).

The datasets should be organized as follows:

```text
SwinMM
├── WORD
│   └── dataset
│       └── dataset12_WORD
│           ├── imagesTr
│           ├── imagesTs
│           ├── imagesVal
│           ├── labelsTr
│           ├── labelsTs
│           ├── labelsVal
│           └── dataset12_WORD.json
└── Pretrain
    ├── dataset
    │   ├── dataset00_BTCV
    │   ├── dataset02_Heart
    │   ├── dataset03_Liver
    │   ├── dataset04_Hippocampus
    │   ├── dataset06_Lung
    │   ├── dataset07_Pancreas
    │   ├── dataset08_HepaticVessel
    │   ├── dataset09_Spleen
    │   ├── dataset10_Colon
    │   ├── dataset11_TCIAcovid19
    │   ├── dataset12_WORD
    │   ├── dataset13_AbdomenCT-1K
    │   ├── dataset_HNSCC
    │   ├── dataset_TCIAcolon
    │   └── dataset_LIDC
    └── jsons
        ├── dataset00_BTCV.json
        ├── dataset01_BrainTumour.json
        ├── dataset02_Heart.json
        ├── dataset03_Liver.json
        ├── dataset04_Hippocampus.json
        ├── dataset05_Prostate.json
        ├── dataset06_Lung.json
        ├── dataset07_Pancreas.json
        ├── dataset08_HepaticVessel.json
        ├── dataset09_Spleen.json
        ├── dataset10_Colon.json
        ├── dataset11_TCIAcovid19.json
        ├── dataset12_WORD.json
        ├── dataset13_AbdomenCT-1K.json
        ├── dataset_HNSCC.json
        ├── dataset_TCIAcolon.json
        └── dataset_LIDC.json
```
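
After downloading the JSON archives and datasets, a quick sanity check (run from the `SwinMM` directory; the paths simply follow the tree above) might be:

```bash
# Spot-check that the pre-training jsons and the WORD data are where the layout above expects them.
ls Pretrain/jsons/dataset00_BTCV.json
ls WORD/dataset/dataset12_WORD/dataset12_WORD.json
```
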
95 changes: 95 additions & 0 deletions SwinMM/Pretrain/losses/loss.py
@@ -0,0 +1,95 @@
# Copyright 2020 - 2022 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch.nn import functional as F


class ContrastLoss(torch.nn.Module):
    def __init__(self, args, batch_size, temperature=0.5):
        super().__init__()
        device = torch.device(f"cuda:{args.local_rank}")
        self.batch_size = batch_size
        self.register_buffer("temp", torch.tensor(temperature).to(device))
        self.register_buffer("neg_mask", (~torch.eye(batch_size * 2, batch_size * 2, dtype=bool).to(device)).float())

    def forward(self, x_i, x_j):
        # Normalize both views and compute the full pairwise cosine-similarity matrix.
        z_i = F.normalize(x_i, dim=1)
        z_j = F.normalize(x_j, dim=1)
        z = torch.cat([z_i, z_j], dim=0)
        sim = F.cosine_similarity(z.unsqueeze(1), z.unsqueeze(0), dim=2)
        # Positive pairs sit on the diagonals offset by batch_size.
        sim_ij = torch.diag(sim, self.batch_size)
        sim_ji = torch.diag(sim, -self.batch_size)
        pos = torch.cat([sim_ij, sim_ji], dim=0)
        nom = torch.exp(pos / self.temp)
        # neg_mask removes the self-similarity terms from the denominator.
        denom = self.neg_mask * torch.exp(sim / self.temp)
        return torch.sum(-torch.log(nom / torch.sum(denom, dim=1))) / (2 * self.batch_size)


class MutualLoss(torch.nn.Module):
    def __init__(self, args):
        super().__init__()
        self.alpha = 1.0
        self.mask_ratio = args.mask_ratio
        self.recon_loss_2 = torch.nn.MSELoss().cuda()

    def __call__(self, rec1, rec2, mask):
        # Compare the two reconstructions only on the masked voxels.
        mask = mask.to(dtype=rec1.dtype)
        rec1, rec2 = [val * mask for val in [rec1, rec2]]

        recon_loss = self.recon_loss_2(rec1, rec2) / self.mask_ratio
        return self.alpha * recon_loss


class Loss(torch.nn.Module):
    def __init__(self, batch_size, args):
        super().__init__()
        self.rot_loss = torch.nn.CrossEntropyLoss().cuda()
        self.recon_loss = torch.nn.L1Loss().cuda()
        self.recon_loss_2 = torch.nn.MSELoss().cuda()
        self.contrast_loss = ContrastLoss(args, batch_size).cuda()
        self.alpha1 = 1.0
        self.alpha2 = 1.0
        self.alpha3 = 1.0
        self.norm_pix_loss = args.norm_pix_loss
        self.mask_ratio = args.mask_ratio

    def __call__(
        self,
        output_rot,
        target_rot,
        output_contrastive,
        target_contrastive,
        output_recons,
        target_recons,
        mask,
        only_mae=False,
    ):
        B, C, H, W, D = output_recons.shape
        target_recons = target_recons.reshape(B, C, -1)

        if self.norm_pix_loss:
            # Normalize the reconstruction target per channel (MAE-style norm-pixel loss).
            mean = target_recons.mean(dim=-1, keepdim=True)
            var = target_recons.var(dim=-1, keepdim=True)
            target_recons = (target_recons - mean) / (var + 1.0e-6) ** 0.5
        target_recons = target_recons.reshape(B, C, H, W, D)
        # Restrict the reconstruction loss to masked voxels.
        mask = mask.to(dtype=target_recons.dtype)[None, ...]
        target_recons, output_recons = [val * mask for val in [target_recons, output_recons]]
        recon_loss = self.recon_loss_2(output_recons, target_recons) / self.mask_ratio
        recon_loss = self.alpha3 * recon_loss
        if only_mae:
            return recon_loss
        contrast_loss = self.alpha2 * self.contrast_loss(output_contrastive, target_contrastive)
        rot_loss = self.alpha1 * self.rot_loss(output_rot, target_rot)
        total_loss = rot_loss + contrast_loss + recon_loss

        return total_loss, (rot_loss, contrast_loss, recon_loss)