This is the code to train a CNN that transforms a slice or a slab of T1/T2/PD MRI image(s) into PDD, FA, or MRA.
Run using this Docker image:
docker pull
Datasets (pass one as --dataroot):
- MRA: /data/mri/data/multi-pix2pix-pytorch/t123_mra
- FA/ColorFA: /data/mri/data/color_fa_sliced
- PDD: /data/mri/data/pdd_sliced
Run to train the model. For example (
CUDA_VISIBLE_DEVICES=4 python --dataroot /data/mri/data/color_fa_sliced --name t1_fa_L1_resnet9_T3_3d_tmp --which_model_netG resnet_9blocks_3d --content_only --T 3 --predict_idx_type middle --output_nc 1 --norm batch_3d --conv_type 3d --fineSize 128 --valid_folder val --use_L1 --input_nc 1 --input_channels 0 --validate_freq 1000 --niter 10 --niter_decay 30 --target_type fa
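The command above passes --T 3 together with --predict_idx_type middle. As a rough illustration only, the sketch below assumes that --T is the number of adjacent slices in an input slab and that "middle" means the target is the center slice of that slab; neither meaning is confirmed here, and slab_windows is a hypothetical helper, not a function from this code base.

```python
def slab_windows(num_slices, T=3):
    """Yield (input_indices, target_index) for every slab of T adjacent slices.

    Assumption: the slab is a sliding window over the slice axis and the
    prediction target is the center slice of the window.
    """
    if T < 1 or T > num_slices:
        raise ValueError("T must be between 1 and num_slices")
    for start in range(num_slices - T + 1):
        input_indices = list(range(start, start + T))
        target_index = input_indices[T // 2]  # center slice of the slab
        yield input_indices, target_index


# A volume with 5 slices and T=3 gives three slabs.
windows = list(slab_windows(num_slices=5, T=3))
for input_indices, target_index in windows:
    print(input_indices, "->", target_index)
```

Under these assumptions the first and last slices of a volume are never prediction targets when T is 3, since no slab is centered on them.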
More examples are in scripts/. For example, one script trains a model using a conditional GAN plus the L1 loss, and another trains a model using perceptual loss. For perceptual loss, the code only supports inputs with 3 channels because it uses a pretrained VGG16.
See all options, e.g. which models are supported, with:
python -h
Run to test the model. For example (scripts/
CUDA_VISIBLE_DEVICES=4 python --dataroot /data/mri/data/color_fa_sliced --name t1_fa_L1_resnet9_T3_3d --which_model_netG resnet_9blocks_3d --content_only --T 3 --predict_idx_type middle --output_nc 1 --norm batch_3d --conv_type 3d --fineSize 128 --valid_folder val --input_nc 1 --input_channels 0 --display_type single --which_epoch lowest_val --phase test --target_type fa
More examples are in scripts/. For example, scripts/ tests the model on images blurred with a Gaussian filter with a blur radius of 5 pixels. See all options with:
python -h
The weights of the trained models are stored under /data/mri/convrnn/checkpoints. "_G" denotes the generator. "_D" denotes the GAN discriminator and is not present when no GAN is used. lowest_val_net_G.pth holds the generator weights at the lowest validation point. 20_net_G.pth holds the generator weights at the end of epoch 20.
All results for the test set are stored under /data/mri/convrnn/results.
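The two file names given above suggest a pattern of the form <epoch label>_net_<network>.pth. The sketch below builds such names. It is an illustration under assumptions: checkpoint_filename and checkpoint_path are hypothetical helpers, the "_D" files are assumed to follow the same pattern as the "_G" files, and the directory is passed in because the exact layout under the checkpoints folder is not described here.

```python
import os


def checkpoint_filename(which_epoch, network="G"):
    """Build a weight file name such as 'lowest_val_net_G.pth'.

    which_epoch: an epoch number (e.g. 20) or a label (e.g. 'lowest_val').
    network: 'G' for the generator, 'D' for the discriminator (assumed
             to follow the same naming pattern).
    """
    return f"{which_epoch}_net_{network}.pth"


def checkpoint_path(checkpoint_dir, which_epoch, network="G"):
    """Join a caller-supplied directory with the weight file name."""
    return os.path.join(checkpoint_dir, checkpoint_filename(which_epoch, network))


print(checkpoint_filename("lowest_val", "G"))  # lowest_val_net_G.pth
print(checkpoint_filename(20, "G"))            # 20_net_G.pth
```

The label passed as which_epoch mirrors the value given to --which_epoch in the test command, for example lowest_val.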