A cell classifier that uses only images of DAPI-stained nuclei.
This napari plugin was generated with Cookiecutter using @napari's cookiecutter-napari-plugin template.
This project provides a flexible interface for training and predicting with various model types: neural networks (nn), logistic regression (logreg), and random forests (rf). The interface uses argparse for handling a wide array of customizable options, allowing tailored configurations for both training and prediction.
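As a rough illustration of how an argparse interface of this kind can be laid out, here is a minimal sketch. It is not the project's actual parser: the function name `build_parser` is hypothetical, and only the `--model_type`, `--command`, and `--data` flags and their values are taken from this README.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the flags described in this README."""
    parser = argparse.ArgumentParser(description="Sketch of a train/predict CLI")
    # The three model types named in the README.
    parser.add_argument("--model_type", choices=["nn", "logreg", "rf"], required=True)
    # The two commands used in the README examples.
    parser.add_argument("--command", choices=["train", "predict"], required=True)
    # A setup JSON for nn, or a CSV file for logreg/rf (per the examples).
    parser.add_argument("--data", help="Path to a setup JSON or CSV file")
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(
        ["--model_type", "rf", "--command", "predict", "--data", "test_data.csv"]
    )
    print(args.model_type, args.command, args.data)
```

Restricting `--model_type` and `--command` with `choices` makes argparse reject unsupported values before any model code runs.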
This project was built with Python 3.12.2 and may not support other Python versions. To install:
- Clone the repository and navigate into it:
git clone <repository-url>
cd cell-classification
- Install PyTorch v2.3.1. The installation process depends on your OS and on GPU availability; refer to the PyTorch installation guide for details. For installation on macOS:
pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1
For installation on Linux and Windows with CUDA or CPU-only, choose the appropriate option below:
  - CUDA 11.8:
  pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu118
  - CUDA 12.1:
  pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu121
  - CPU only:
  pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cpu
- Install cell-classification via pip:
pip install .
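After installation, it can be useful to confirm which PyTorch build is active and whether CUDA is visible. The helper below is a small sketch for that purpose; the function name `describe_torch_install` is hypothetical and not part of this project. It returns `None` when PyTorch is not importable.

```python
import importlib.util


def describe_torch_install():
    """Return the PyTorch version and CUDA availability, or None if torch is missing."""
    if importlib.util.find_spec("torch") is None:
        return None
    import torch  # imported lazily so the check also works without PyTorch

    return {
        "version": torch.__version__,
        "cuda_available": torch.cuda.is_available(),
    }


if __name__ == "__main__":
    info = describe_torch_install()
    if info is None:
        print("PyTorch is not installed in this environment.")
    else:
        print(f"PyTorch {info['version']}, CUDA available: {info['cuda_available']}")
```

With the CPU-only wheel, `cuda_available` is expected to be `False`.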
Below is an overview of the project structure, outlining the main directories and their purpose:
cell-classification/
├── Analysis/ # Analysis files: plotting, validation, and test scoring
├── src/
│   └── nucleus_3d_classification/
│       ├── preprocess/ # Scripts for dataset generation and ground-truth creation
│       ├── baseline/ # Script to fit baseline models on 2D as well as 3D extracted features, with setup files
│       ├── utils/ # Dataset preparation, transformations, and loss functions
│       ├── models/ # Model architectures (ResNet support)
│       └── main.py # Main entry point for model training
└── setup_json/ # Example JSON configurations for NN custom datamodule setup as well as model_runner setup
- Analysis Folder
Contains all files for data analysis, including scripts for plotting and evaluating model scores on validation/test data.
- Source Code
Located in src/nucleus_3d_classification/, this folder contains the core codebase:
  - baseline/: Script to train and evaluate baseline models, independently of the CLI tool.
    - model_fit.py: Baseline training and evaluation script.
    - Configuration (JSON) files used for the fitting and evaluation of the baseline models.
  - preprocess/: Scripts to generate and curate the ground-truth dataset. For example:
    - napari_curation_and_labelling.py: A napari plug-in tool used to label and curate nuclei segmentation masks.
    - Additional scripts for feature extraction and crop generation.
  - utils/: Contains essential scripts for preparing the dataset for NN training and testing, including:
    - datamodule.py: Dataset setup.
    - transforms.py: Custom transformations.
    - Loss function configurations.
  - models/: Defines the available model architectures. Currently, only ResNet models are supported, with all configurations in ResNet.py.
- Setup JSON
The setup_json directory provides examples of how JSON files should be formatted for the --data argument during NN training. Note that the batch size provided in the JSON can be overridden by passing --batch_size on the command line. These examples demonstrate the data structure that the DataModule requires; do not add arguments beyond those shown in the examples. The model_runner_example.py script contains a sample configuration file illustrating the setup for the model_runner --params argument, as used during model training.
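The precedence between the JSON batch size and --batch_size can be pictured with the sketch below. It is an assumption-laden illustration, not the project's code: the function name `load_data_config` and the JSON key `batch_size` are hypothetical, and the real key names should be taken from the files in setup_json.

```python
import json


def load_data_config(json_text, batch_size=None):
    """Parse a setup JSON string; a CLI batch size, if given, takes precedence."""
    config = json.loads(json_text)
    if batch_size is not None:
        # Hypothetical key name; check the examples in setup_json for the real one.
        config["batch_size"] = batch_size
    return config


if __name__ == "__main__":
    setup = '{"batch_size": 32}'
    print(load_data_config(setup))                 # value from the JSON is kept
    print(load_data_config(setup, batch_size=8))   # CLI value overrides the JSON
```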
Run the script by specifying the model type (--model_type) and command (--command), along with the necessary parameters for each specific model.
The main entry point for training models is main.py. To view usage options, run:
python main.py --h
To view more options for a specific model type and command, for example a neural network and train, run:
python main.py --model_type nn --command train --h
Training a Neural Network:
python main.py \
--model_type nn \
--command train \
--data /path/to/cd41_setup.json \
--data_module CustomDataModule \
--model_class ResNet50 \
--max_epochs 20
Testing a trained Neural Network:
To run predictions on a trained neural network model using the test dataset specified in the DataModule, set --stage test. For predictions on the validation dataset, set --stage validate. By default, the output is a CSV file with the logged metrics.
python main.py \
--model_type nn \
--command predict \
--data /path/to/Sca_setup.json \
--model_class ResNet50 \
--enable_progress_bar \
--data_module CustomDataModule \
--model_file /path/to/Sca1_best-f1_score-epoch=87-val_f1=0.34.ckpt \
--stage test \
--save_dir /cluster/project/schroeder/AG/CD41/results/predictions/sca1/
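The metrics CSV written by the prediction run can be inspected with the standard library. The sketch below assumes only that the file has a header row; the column names in the sample (`test_f1`, `test_loss`) are hypothetical placeholders, since the actual columns depend on which metrics are logged.

```python
import csv
import io


def read_metrics(handle):
    """Parse a metrics CSV into a list of dicts, converting numeric cells to float."""
    rows = []
    for row in csv.DictReader(handle):
        parsed = {}
        for key, value in row.items():
            try:
                parsed[key] = float(value)
            except (TypeError, ValueError):
                parsed[key] = value  # keep non-numeric or empty cells unchanged
        rows.append(parsed)
    return rows


if __name__ == "__main__":
    # Hypothetical file contents; replace with open("<metrics file>", newline="").
    sample = io.StringIO("test_f1,test_loss\n0.34,1.25\n")
    for row in read_metrics(sample):
        print(row)
```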
Training a Logistic Regression Model:
python main.py \
--model_type logreg \
--command train \
--data train_data.csv
Predicting with a Random Forest Model:
python main.py \
--model_type rf \
--command predict \
--model_file rf_model.pkl \
--data test_data.csv
Contributions are very welcome. Tests can be run with tox; please ensure that coverage at least stays the same before you submit a pull request.
Distributed under the terms of the BSD-3 license, "cell-classification" is free and open source software.
If you encounter any problems, please file an issue along with a detailed description.