Merge pull request #16 from florencejt/refactor/metricchoices
Refactor/metricchoices: User-defined validation metrics
florencejt authored Jan 11, 2024
2 parents 1411337 + 0efbe2e commit 7f099db
Showing 25 changed files with 1,004 additions and 191 deletions.
Empty file added docs/__init__.py
Binary file removed docs/computergif.gif
1 change: 1 addition & 0 deletions docs/conf.py
@@ -112,6 +112,7 @@
    [
        "examples/customising_behaviour",
        "examples/training_and_testing",
        "examples/model_comparison",
    ],
),
"within_subsection_order": FileNameSortKey,
@@ -1,12 +1,12 @@
"""
How to customise the training in Fusilli
#########################################
Customising Training
=========================================

This tutorial will show you how to customise the training of your fusion model.
This page will show you how to customise the training and evaluation of your fusion models.

We will cover the following topics:

* Early stopping
* Validation metrics
* Batch size
* Number of epochs
* Checkpoint suffix modification
@@ -53,6 +53,70 @@

-----

Choosing metrics
-----------------

By default, Fusilli uses the following metrics for each prediction task:

* Binary classification: `Area under the ROC curve <https://lightning.ai/docs/torchmetrics/stable/classification/auroc.html>`_ and `accuracy <https://lightning.ai/docs/torchmetrics/stable/classification/accuracy.html>`_
* Multiclass classification: `Area under the ROC curve <https://lightning.ai/docs/torchmetrics/stable/classification/auroc.html>`_ and `accuracy <https://lightning.ai/docs/torchmetrics/stable/classification/accuracy.html>`_
* Regression: `R2 score <https://lightning.ai/docs/torchmetrics/stable/regression/r2_score.html>`_ and `mean absolute error <https://lightning.ai/docs/torchmetrics/stable/regression/mean_absolute_error.html>`_

You can change the metrics used by passing a list of metrics to the ``metrics_list`` argument in the :func:`~.fusilli.train.train_and_save_models` function.
For example, if you wanted to change the metrics used for a binary classification task to precision, recall, and area under the precision-recall curve, you could do the following:

.. code-block:: python

    new_metrics_list = ["precision", "recall", "auprc"]

    trained_model = train_and_save_models(
        data_module=datamodule,
        fusion_model=example_model,
        metrics_list=new_metrics_list,
    )

Here are the supported metrics as of Fusilli v1.2.0:

**Regression**:

* `R2 score <https://lightning.ai/docs/torchmetrics/stable/regression/r2_score.html>`_: ``r2``
* `Mean absolute error <https://lightning.ai/docs/torchmetrics/stable/regression/mean_absolute_error.html>`_: ``mae``
* `Mean squared error <https://lightning.ai/docs/torchmetrics/stable/regression/mean_squared_error.html>`_: ``mse``

**Binary or multiclass classification**:

* `Area under the ROC curve <https://lightning.ai/docs/torchmetrics/stable/classification/auroc.html>`_: ``auroc``
* `Accuracy <https://lightning.ai/docs/torchmetrics/stable/classification/accuracy.html>`_: ``accuracy``
* `Recall <https://lightning.ai/docs/torchmetrics/stable/classification/recall.html>`_: ``recall``
* `Specificity <https://lightning.ai/docs/torchmetrics/stable/classification/specificity.html>`_: ``specificity``
* `Precision <https://lightning.ai/docs/torchmetrics/stable/classification/precision.html>`_: ``precision``
* `F1 score <https://lightning.ai/docs/torchmetrics/stable/classification/f1_score.html>`_: ``f1``
* `Area under the precision-recall curve <https://lightning.ai/docs/torchmetrics/stable/classification/average_precision.html>`_: ``auprc``
* `Balanced accuracy <https://lightning.ai/docs/torchmetrics/stable/classification/accuracy.html>`_: ``balanced_accuracy``

If you'd like to add more metrics to fusilli, then please open an issue on the `Fusilli GitHub repository issues page <https://github.com/florencejt/fusilli/issues>`_ or submit a pull request.
The metrics are calculated in :class:`~.fusilli.utils.metrics_utils.MetricsCalculator`, with a separate method for each metric.
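As a rough, hypothetical sketch of what one of these methods can look like (the ``self.num_classes`` attribute and the exact signature are assumptions, not Fusilli's actual code):

.. code-block:: python

    import torchmetrics

    def balanced_accuracy(self, preds, labels):
        # Hypothetical sketch: macro-averaged multiclass accuracy in
        # torchmetrics is equivalent to balanced accuracy.
        metric = torchmetrics.Accuracy(
            task="multiclass", num_classes=self.num_classes, average="macro"
        )
        return metric(preds, labels)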

**Using your own custom metric:**

If you'd like to use your own custom metric without adding it to fusilli, then you can calculate it using the validation labels and predictions/probabilities.
You can access the validation labels and validation predictions/probabilities from the trained model that is returned by the :func:`~.fusilli.train.train_and_save_models` function.
Look at :class:`~.fusilli.fusionmodels.base_model.BaseModel` for a list of attributes that are available to you to access.
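As a minimal sketch (the attribute names ``val_labels`` and ``val_preds`` are hypothetical placeholders -- check :class:`~.fusilli.fusionmodels.base_model.BaseModel` for the real ones), computing a custom metric such as scikit-learn's Matthews correlation coefficient might look like this:

.. code-block:: python

    from sklearn.metrics import matthews_corrcoef

    trained_model = train_and_save_models(
        data_module=datamodule,
        fusion_model=example_model,
    )[0]

    # hypothetical attribute names -- see BaseModel for the actual ones
    labels = trained_model.val_labels
    preds = trained_model.val_preds

    print("MCC:", matthews_corrcoef(labels, preds))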


.. note::

    The first metric in the metrics list is used to rank the models in the model comparison evaluation figures.
    Only the first two metrics will be shown in the model comparison figures.
    The rest of the metrics will be shown in the model evaluation dataframe and printed out to the console during training.
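For example (a minimal illustration of the note above):

.. code-block:: python

    # "auprc" ranks the models in the comparison figures,
    # "auprc" and "precision" appear in those figures,
    # and "recall" is shown in the evaluation dataframe and console output
    new_metrics_list = ["auprc", "precision", "recall"]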

.. warning::

    There must be at least two metrics in the metrics list.

-----


Batch size
----------

@@ -112,7 +176,7 @@

-----

Checkpoint suffix modification
Checkpoint file names
------------------------------

By default, Fusilli saves the model checkpoints in the following format:
@@ -156,5 +220,4 @@
.. note::

    The ``extra_log_string_dict`` argument is also used to modify the logging behaviour of the model. For more information, see :ref:`wandb`.
"""
# sphinx_gallery_thumbnail_path = '_static/pink_pasta_logo.png'

6 changes: 4 additions & 2 deletions docs/examples/customising_behaviour/README.rst
@@ -6,5 +6,7 @@ Customising Fusilli
These are examples showing how to get more in depth with Fusilli and customise its behaviour.

* Modify the fusion model structures
* Change the training behaviour: early stopping, batch size, test set size, etc.
* Use Fusilli for hyperparameter tuning by adding suffixes to training outputs

.. note::

    More examples to come throughout 2024.
4 changes: 4 additions & 0 deletions docs/examples/model_comparison/README.rst
@@ -0,0 +1,4 @@
.. _advanced-examples:

Comparing Models
=======================================================
@@ -1,8 +1,8 @@
"""
Training multiple models in a loop: k-fold regression
Comparing All Fusion Models
====================================================================
Welcome to the "Comparing Multiple K-Fold Trained Fusion Models" tutorial! In this tutorial, we'll explore how to train and compare multiple fusion models for a regression task using k-fold cross-validation with multimodal tabular data. This tutorial is designed to help you understand and implement key features, including:
Welcome to the "Comparing All Fusion Models" tutorial! In this tutorial, we'll explore how to train and compare multiple fusion models for a regression task using k-fold cross-validation with multimodal tabular data. This tutorial is designed to help you understand and implement key features, including:
- 📥 Importing fusion models based on modality types.
- 🚲 Setting training parameters for your models
6 changes: 5 additions & 1 deletion docs/examples/training_and_testing/README.rst
@@ -1,8 +1,12 @@
.. _train_test_examples:

Running Fusilli on your own data
Training and Testing
==========================================

These are examples of how to train and validate fusion models with Fusilli.

.. contents:: **Contents**
    :local:
    :depth: 1


@@ -1,5 +1,5 @@
"""
Binary Classification: Training a K-Fold Model
K-Fold Cross-Validation: Binary Classification
======================================================
🚀 In this tutorial, we'll explore binary classification using K-fold cross validation.
@@ -0,0 +1,148 @@
"""
Train/Test split: Regression
======================================================
🚀 In this tutorial, we'll explore regression using a train/test split.
Specifically, we're using the :class:`~.TabularCrossmodalMultiheadAttention` model.
Key Features:

- 📥 Importing a model based on its path.
- 🧪 Training and testing a model with a train/test split.
- 📈 Plotting the model's loss curves.
- 📊 Visualising the results of a single train/test model using the :class:`~.RealsVsPreds` class.
"""

import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import os

from docs.examples import generate_sklearn_simulated_data
from fusilli.data import prepare_fusion_data
from fusilli.eval import RealsVsPreds
from fusilli.train import train_and_save_models

# sphinx_gallery_thumbnail_number = -1

# %%
# 1. Import the fusion model 🔍
# --------------------------------
# We're importing only one model for this example, the :class:`~.TabularCrossmodalMultiheadAttention` model.
# Instead of using the :func:`~fusilli.utils.model_chooser.import_chosen_fusion_models` function, we're importing the model directly, as with any other Python import.


from fusilli.fusionmodels.tabularfusion.crossmodal_att import (
    TabularCrossmodalMultiheadAttention,
)

# %%
# 2. Set the training parameters 🎯
# -----------------------------------
# Now we're configuring our training parameters.
#
# For training and testing, the necessary parameters are:
# - Paths to the input data files.
# - Paths to the output directories.
# - ``prediction_task``: the type of prediction to be performed. This is either ``regression``, ``binary``, or ``multiclass``.
#
# Some optional parameters are:
#
# - ``kfold``: a boolean of whether to use k-fold cross-validation (True) or not (False). By default, this is set to False.
# - ``num_folds``: the number of folds to use. Must be at least 2.
# - ``wandb_logging``: a boolean of whether to log the results using Weights and Biases (True) or not (False). Default is False.
# - ``test_size``: the proportion of the dataset to include in the test split. Default is 0.2.
# - ``batch_size``: the batch size to use for training. Default is 8.
# - ``multiclass_dimensions``: the number of classes to use for multiclass classification. Default is None; it must be set when ``prediction_task`` is ``multiclass``.
# - ``max_epochs``: the maximum number of epochs to train for. Default is 1000.
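#
# As a hedged illustration (not used in this example), k-fold cross-validation
# might be requested like this -- assuming these optional keyword arguments
# keep the names listed above:
#
# .. code-block:: python
#
#     dm = prepare_fusion_data(prediction_task=prediction_task,
#                              fusion_model=fusion_model,
#                              data_paths=data_paths,
#                              output_paths=output_paths,
#                              kfold=True,
#                              num_folds=5)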

# Regression task
prediction_task = "regression"

# Set the batch size
batch_size = 32

# Setting output directories
output_paths = {
    "losses": "loss_logs/one_model_regression_traintest",
    "checkpoints": "checkpoints/one_model_regression_traintest",
    "figures": "figures/one_model_regression_traintest",
}

# Create the output directories if they don't exist
for path in output_paths.values():
    os.makedirs(path, exist_ok=True)

# Clearing the loss logs directory (only for the example notebooks)
for subdir in os.listdir(output_paths["losses"]):
    # remove files inside the subdirectory
    for file in os.listdir(os.path.join(output_paths["losses"], subdir)):
        os.remove(os.path.join(output_paths["losses"], subdir, file))
    # remove the now-empty subdirectory
    os.rmdir(os.path.join(output_paths["losses"], subdir))

# %%
# 3. Generating simulated data 🔮
# --------------------------------
# Time to create some simulated data for our models to work their wonders on.
# This function also simulates image data, which we aren't using here.

tabular1_path, tabular2_path = generate_sklearn_simulated_data(
    prediction_task,
    num_samples=500,
    num_tab1_features=10,
    num_tab2_features=20,
)

data_paths = {
    "tabular1": tabular1_path,
    "tabular2": tabular2_path,
    "image": "",
}

# %%
# 4. Training the fusion model 🏁
# --------------------------------------
# Now we're ready to train our model. We're using the :func:`~fusilli.train.train_and_save_models` function to train our model.
#
# First we need to create a data module using the :func:`~fusilli.data.prepare_fusion_data` function.
# This function takes the following parameters:
#
# - ``prediction_task``: the type of prediction to be performed.
# - ``fusion_model``: the fusion model to be trained.
# - ``data_paths``: the paths to the input data files.
# - ``output_paths``: the paths to the output directories.
#
# Then we pass the data module and the fusion model to the :func:`~fusilli.train.train_and_save_models` function.
# We're not using checkpointing for this example, so we set ``enable_checkpointing=False``. We're also setting ``show_loss_plot=True`` to plot the loss curve.


fusion_model = TabularCrossmodalMultiheadAttention

print("method_name:", fusion_model.method_name)
print("modality_type:", fusion_model.modality_type)
print("fusion_type:", fusion_model.fusion_type)

dm = prepare_fusion_data(
    prediction_task=prediction_task,
    fusion_model=fusion_model,
    data_paths=data_paths,
    output_paths=output_paths,
    batch_size=batch_size,
)

# train and test
single_model_list = train_and_save_models(
    data_module=dm,
    fusion_model=fusion_model,
    enable_checkpointing=False,  # False for the example notebooks
    show_loss_plot=True,
    metrics_list=["r2", "mae", "mse"],
)

# %%
# 5. Plotting the results 📊
# ----------------------------
# Now we're ready to plot the results of our model.
# We're using the :class:`~.RealsVsPreds` class to plot the real values against the predicted values.

reals_preds_fig = RealsVsPreds.from_final_val_data(single_model_list)
plt.show()
2 changes: 2 additions & 0 deletions docs/index.rst
@@ -53,6 +53,7 @@ Why would you want to use fusilli?

choosing_model
modifying_models
customising_training
logging_with_wandb
glossary

@@ -63,6 +64,7 @@
:caption: 🌸 Tutorials 🌸

auto_examples/training_and_testing/index
auto_examples/model_comparison/index
auto_examples/customising_behaviour/index

-----
3 changes: 3 additions & 0 deletions docs/installation.rst
@@ -1,3 +1,6 @@
.. _install_instructions:


How to Install
==============

12 changes: 9 additions & 3 deletions docs/quick_start.rst
@@ -9,6 +9,12 @@ This script provides a simple setup to train a model using ``fusilli`` on a single dataset.

This code showcases the necessary steps to execute Fusilli on a single dataset.

**Before you run this, you need to:**

1. Install ``fusilli`` (see :ref:`install_instructions`).
2. Prepare your data and specify the paths to your data (see :ref:`data-loading`).
3. Specify output file paths (see :ref:`experiment-set-up`).


Usage Example
-------------
@@ -22,7 +28,7 @@
import matplotlib.pyplot as plt
# Import the example fusion model
from fusilli.fusionmodels.tabularfusion.example_model import ExampleModel
from fusilli.fusionmodels.tabularfusion.concat_data import ConcatTabularData
data_paths = {
    "tabular1": "path/to/tabular_1.csv",  # Path to tabular dataset 1
@@ -38,13 +44,13 @@
# Get the data module (PyTorch Lightning-compatible data structure)
data_module = prepare_fusion_data(prediction_task="regression",
                                  fusion_model=ExampleModel,
                                  fusion_model=ConcatTabularData,
                                  data_paths=data_paths,
                                  output_paths=output_paths)
# Train the model and receive a list with the trained model
trained_model = train_and_save_models(data_module=data_module,
                                      fusion_model=ExampleModel)
                                      fusion_model=ConcatTabularData)
# Evaluate the model by plotting the real values vs. predicted values
RealsVsPreds_figure = RealsVsPreds.from_final_val_data(trained_model)