- Anas Himmi ([email protected])
- Zeineb Mellouli ([email protected])
- Antoine Cornaz ([email protected])
Department of Computer Science, EPFL, Switzerland
This project implements deep learning models to detect ant coloration using images from the AntWeb dataset, in collaboration with the DEE - Group Bertelsmeier (Insect Ecology).
The project focuses on developing accurate color detection models for ant specimens, with the ultimate goal of analyzing colors of specific body parts (head, thorax, and abdomen). We implemented and compared several approaches:
- Baseline statistical methods (mean, median, mode)
- Custom CNN architectures (3-channel and 4-channel inputs)
- Pre-trained ResNet18 model
- Color-based segmentation approach
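The 4-channel variant above can be understood as stacking a body-part mask onto the RGB image as a fourth input channel. A minimal sketch of that input construction, with illustrative shapes and names (not the project's actual code):

```python
import numpy as np

# Illustrative 64x64 RGB image in [0, 1] and a binary body-part mask.
rgb = np.random.rand(64, 64, 3).astype(np.float32)
mask = (np.random.rand(64, 64) > 0.5).astype(np.float32)

# Stack the mask as a fourth channel alongside R, G, B.
x = np.concatenate([rgb, mask[..., None]], axis=-1)
assert x.shape == (64, 64, 4)
```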
Key achievements:
- Best performance achieved by ResNet18, with a validation loss of 4.60
- Successful implementation of a custom loss function in the CIELAB color space
- Good results with the CNN model on the small dataset of ant body parts
├── data/ # Dataset directory
├── saves/ # Cached objects to avoid recomputing them in results.ipynb
├── src/
│ ├── cnn/ # All CNN implementations
│ │ ├── cnn_3channels/ # 3-channel CNN implementation
│ │ ├── cnn_4channels/ # 4-channel CNN implementation
│ │ └── cnn_segmented/ # CNN for segmented body parts
│ ├── resnet/ # ResNet18 implementation
│ ├── Baseline/ # Baseline methods (mean, median, mode)
│ ├── color_segmentation/ # Precomputing color-based segmentation
│ ├── utils.py # Utility functions (custom loss function, etc.)
│ └── final_models/ # Saved model weights (.pth) and predictions except best_model_resnet_final.pth (too large)
├── requirements.txt # Python dependencies
└── results.ipynb # Results visualization and analysis
`best_model_resnet_final.pth` is too large to be uploaded to GitHub. It can be found here.
- Clone the repository:
git clone [repository-url]
- Install the dependencies:
pip install -r requirements.txt
Each model can be trained using its respective script in the corresponding directory. Example:
# Training (~6 hours on a V100 GPU for 28,000 images)
python src/cnn/cnn_3channels/train_CNN_hpc_3channel.py \
    --data_path data/colour_digitalization.xlsx \
    --img_dir path_to_images/without_background/ \
    --save_path outputs/best_model_CNN_3channel_final.pth \
    --batch_size 16
# Inference (<5 minutes on a V100 GPU for 5,000 images)
python src/cnn/infer_CNN_3channel.py \
    --data_path data/colour_digitalization.xlsx \
    --img_dir path_to_images/without_background/ \
    --model_path outputs/best_model_CNN_3channel_final.pth \
    --output_path outputs/predictions_CNN_3channel_final.csv \
    --batch_size 16
For practical use, we provide a script that runs predictions on new, unlabelled images given a pretrained model, outputting the predicted color values for each body part. Place the background-removed images in one directory and the body-part masks in a separate directory; files should be named the same way as in the segmentation team's training dataset. Example:
# ~0.16 s/image on CPU (i7-8550U)
python src/body_part_predict_cnn_4channel.py \
    --img_dir path_to_images/without_background/ \
    --mask_dir path_to_masks/eyes_corrected/ \
    --model_path src/final_models/best_model_segmented_final.pth \
    --output_path outputs/predictions_CNN_segmented_final.csv \
    --batch_size 16
This script outputs a CSV file with the predicted color values for each body part (RGB values in the [0, 1] range).
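Downstream tools often expect 8-bit colors, so the [0, 1] predictions may need rescaling. A small sketch of reading such a CSV and converting to [0, 255]; the column names here are illustrative, not necessarily those of the actual output file:

```python
import csv
import io

# Hypothetical sample mimicking the output CSV (column names assumed).
sample = "image,r,g,b\nant_001.jpg,0.5,0.25,1.0\n"

reader = csv.DictReader(io.StringIO(sample))
for row in reader:
    # Scale each predicted channel from [0, 1] to 8-bit [0, 255].
    rgb255 = tuple(round(float(row[c]) * 255) for c in ("r", "g", "b"))

assert rgb255 == (128, 64, 255)
```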
If a modified model is trained, it suffices to change the classes `AntColorDataset` and `AntColorCNN` in the script to match the new model. For `AntColorDataset`, `self.colors` (which stores the true colors for training and validation) should be removed from `__init__` and from the return value of `__getitem__`.
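As a minimal sketch of that change, an inference-only dataset simply drops the stored colors and returns only the image. Class and attribute names here follow the description above but are otherwise illustrative:

```python
class AntColorInferenceDataset:
    """Inference-only variant: no self.colors, no labels returned."""

    def __init__(self, images):
        self.images = images  # labels are not needed for inference

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Return only the image; the training dataset would also
        # return the true colors here.
        return self.images[idx]

ds = AntColorInferenceDataset(["img0", "img1"])
assert len(ds) == 2 and ds[1] == "img1"
```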
Open `results.ipynb` in Jupyter Notebook/Lab to view comprehensive results and visualizations.
Model performance on validation set:
- ResNet18: 4.60 (LAB color space loss)
- 3-Channel CNN: 5.15
- 4-Channel CNN: 5.07
- Baseline methods: 15-17
- Body Parts CNN: 8.83
Special thanks to:
- DEE - Group Bertelsmeier (Insect Ecology)
This project uses data from AntWeb under Creative Commons Attribution License. Please ensure proper attribution when using this code or the associated data.