GobletNet: Wavelet-Based High Frequency Fusion Network for Semantic Segmentation of Electron Microscopy Images
This is the official code of GobletNet: Wavelet-Based High Frequency Fusion Network for Semantic Segmentation of Electron Microscopy Images (TMI 2024.10).
⭐⭐⭐ Letting the characteristics of the images to be segmented drive the architecture design is the simplest but most effective approach!
We quantitatively analyze and summarize two characteristics of electron microscopy (EM) images:
- Characteristic 1: Compared with other image types, the HF components of EM images obtained by the wavelet transform have richer texture details and clearer object contours, but also contain more noise.
- Characteristic 2: For EM images, adding LF components to the HF image at an appropriate ratio alleviates noise interference while preserving sufficient HF detail.
Qualitative comparison of HF characteristics among natural, medical, microscopic and EM images. (a) Raw images. (b) Wavelet transform results. (c) HF images. (d) Information richness heatmaps. (e) Noise intensity heatmaps. (f) Detailed distribution heatmaps. (g) Detailed distribution heatmaps (overlaid on raw images). (h) Ground truth.
- For Characteristic 1, we take the HF image as an extra input and use an additional encoder to extract its rich HF information.
- For Characteristic 2, we add LF components to the HF image at a fixed ratio to reduce the negative impact of excessive noise on model training (see the sketch after this list).
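A minimal sketch of this HF/LF fusion idea, using PyWavelets (pinned in the requirements below). The helper name `make_hf_image`, the abs-sum combination of the HF subbands, and the rescaling are our illustrative assumptions; the db2 wavelet and the 0.1 ratio are inferred from the `H_0.1_db2` naming used later. The repo's /tools/wavelet.py is the authoritative implementation.

```python
import numpy as np
import pywt  # PyWavelets, pinned in the requirements below


def make_hf_image(img, wavelet="db2", lf_ratio=0.1):
    """Hypothetical sketch: build an HF image with LF blended in at a fixed ratio.

    `img` is a 2-D float array in [0, 1]. The db2 wavelet and the 0.1 ratio are
    inferred from the `H_0.1_db2` directory naming; see /tools/wavelet.py for
    the transform actually used in the paper.
    """
    # Single-level 2-D DWT: LL is the LF subband; LH, HL, HH are HF subbands
    # (each at half the input resolution).
    ll, (lh, hl, hh) = pywt.dwt2(img, wavelet)
    # Combine the three HF subbands into one HF image (illustrative choice).
    hf = np.abs(lh) + np.abs(hl) + np.abs(hh)
    # Characteristic 2: add a small fraction of LF to suppress noise
    # while keeping the HF details.
    fused = hf + lf_ratio * ll
    # Rescale to [0, 1] so the result can serve as an extra network input.
    return (fused - fused.min()) / (fused.max() - fused.min() + 1e-8)
```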
Qualitative comparison of segmentation results. (a) Raw images. (b) Ground truth. (c) SAM. (d) Deeplab V3+. (e) UNet 3+. (f) FusionNet. (g) WaveSNet. (h) UNet. (i) nnUNet. (j) GobletNet.
We have also reimplemented semantic segmentation models from different application scenarios, including natural, medical, wavelet-based and EM models:
| Scenario | Model | Code |
| --- | --- | --- |
| Natural | Deeplab V3+ | models/deeplabv3.py |
| | Res-UNet | models/resunet.py |
| | U2-Net | models/u2net.py |
| Medical | UNet | models/unet.py |
| | UNet++ | models/unet_plusplus.py |
| | Att-UNet | models/unet.py |
| | UNet 3+ | models/unet_3plus.py |
| | SwinUNet | models/swinunet.py |
| | XNet | models/xnet.py |
| Wavelet | ALNet | models/aerial_lanenet.py |
| | MWCNN | models/mwcnn.py |
| | WaveSNet | models/wavesnet.py |
| | WDS | models/wds.py |
| EM | DCR | models/dcr.py |
| | FusionNet | models/fusionnet.py |
| | GobletNet (Ours) | models/GobletNet.py |
- Requirements
albumentations==1.2.1
einops==0.4.1
matplotlib==3.1.0
MedPy==0.4.0
numpy==1.21.6
opencv_python_headless==4.5.4.60
Pillow==10.4.0
PyWavelets==1.3.0
scikit_image==0.19.3
scikit_learn==1.5.1
scipy==1.7.3
SimpleITK==2.4.0
thop==0.1.1.post2209072238
timm==0.6.7
torch==1.8.0+cu111
torchio==0.18.84
torchvision==0.9.0+cu111
tqdm==4.64.0
tqdm_pathos==0.4
visdom==0.1.8.9
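Assuming the pinned list above is saved as requirements.txt, install with:

```
pip install -r requirements.txt
```

Note that the `+cu111` builds of torch and torchvision are not hosted on PyPI; they typically come from the PyTorch wheel index (e.g. `-f https://download.pytorch.org/whl/torch_stable.html`).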
- Dataset preparation
Use /tools/wavelet.py to generate the wavelet transform results. Build your own dataset; its directory tree should look like this (a batch-generation sketch follows the tree):
dataset
├── train
│   ├── image
│   │   ├── 1.tif
│   │   ├── 2.tif
│   │   └── ...
│   ├── H_0.1_db2
│   │   ├── 1.tif
│   │   ├── 2.tif
│   │   └── ...
│   └── mask
│       ├── 1.tif
│       ├── 2.tif
│       └── ...
└── val
    ├── image
    └── mask
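A hedged sketch of how the `H_0.1_db2` folder could be populated from `image` (it reuses the hypothetical `make_hf_image` helper from the sketch above; paths and file handling are illustrative, and /tools/wavelet.py remains the authoritative script):

```python
import os

import cv2  # opencv_python_headless from the requirements
import numpy as np
import pywt


def make_hf_image(img, wavelet="db2", lf_ratio=0.1):
    # Same hypothetical helper as in the sketch above.
    ll, (lh, hl, hh) = pywt.dwt2(img, wavelet)
    fused = np.abs(lh) + np.abs(hl) + np.abs(hh) + lf_ratio * ll
    return (fused - fused.min()) / (fused.max() - fused.min() + 1e-8)


src_dir, dst_dir = "dataset/train/image", "dataset/train/H_0.1_db2"
os.makedirs(dst_dir, exist_ok=True)
for name in sorted(os.listdir(src_dir)):
    img = cv2.imread(os.path.join(src_dir, name), cv2.IMREAD_GRAYSCALE)
    hf = make_hf_image(img.astype(np.float32) / 255.0)
    # dwt2 halves the resolution; resize back so HF and raw inputs align.
    hf = cv2.resize(hf, (img.shape[1], img.shape[0]))
    cv2.imwrite(os.path.join(dst_dir, name), (hf * 255).astype(np.uint8))
```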
- Configure dataset parameters
Add a configuration to /config/dataset_config/dataset_config.py. It should look as follows (a sketch for computing the MEAN/STD statistics appears after the block):
'CREMI':
{
'IN_CHANNELS': 1,
'NUM_CLASSES': 2,
'SIZE': (128, 128),
'MEAN': [0.503902],
'STD': [0.110739],
'MEAN_H_0.1_db2': [0.515329],
'STD_H_0.1_db2': [0.118728],
'PALETTE': list(np.array([
[255, 255, 255],
[0, 0, 0],
]).flatten())
},
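The MEAN/STD pairs are per-channel statistics of the training images (and of the HF images) scaled to [0, 1]. A quick sketch for computing them, assuming the directory layout above; the exact procedure behind the paper's values is our assumption:

```python
import os

import cv2
import numpy as np


def channel_stats(folder):
    """Mean/std over all pixels of all grayscale images in `folder`, in [0, 1]."""
    pixels = np.concatenate([
        cv2.imread(os.path.join(folder, f), cv2.IMREAD_GRAYSCALE).ravel()
        for f in sorted(os.listdir(folder))
    ]).astype(np.float64) / 255.0
    return [round(pixels.mean(), 6)], [round(pixels.std(), 6)]


mean, std = channel_stats("dataset/train/image")          # -> 'MEAN', 'STD'
mean_h, std_h = channel_stats("dataset/train/H_0.1_db2")  # -> 'MEAN_H_0.1_db2', 'STD_H_0.1_db2'
print(mean, std, mean_h, std_h)
```

Note that the `PALETTE` entry uses `np`, so the config file needs `import numpy as np` at the top.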
- Training
python -m torch.distributed.launch --nproc_per_node=4 train_GobletNet.py
- Testing
python -m torch.distributed.launch --nproc_per_node=4 test_GobletNet.py
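`torch.distributed.launch` spawns one process per GPU (four here; adjust `--nproc_per_node` to your GPU count) and passes each process a `--local_rank` argument. A generic sketch of the DDP setup a launched script typically performs (not the repo's exact code):

```python
import argparse

import torch
import torch.distributed as dist

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)  # injected by the launcher
args = parser.parse_args()

# Bind this process to its GPU and join the process group
# (MASTER_ADDR/MASTER_PORT env vars are set by the launcher).
torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend="nccl", init_method="env://")

model = torch.nn.Conv2d(1, 2, 3).cuda()  # stand-in for GobletNet
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
```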
If our work is useful for your research, please cite our paper: