-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #16 from florencejt/refactor/metricchoices
Refactor/metricchoices: User-defined validation metrics
- Loading branch information
Showing
25 changed files
with
1,004 additions
and
191 deletions.
There are no files selected for viewing
Empty file.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
.. _advanced-examples: | ||
|
||
Comparing Models | ||
======================================================= |
4 changes: 2 additions & 2 deletions
4
...sting/plot_model_comparison_loop_kfold.py → ...rison/plot_model_comparison_loop_kfold.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,12 @@ | ||
.. _train_test_examples: | ||
|
||
Running Fusilli on your own data | ||
Training and Testing | ||
========================================== | ||
|
||
These are examples of how to train and validate fusion models with Fusilli. | ||
|
||
.. contents:: **Contents** | ||
:local: | ||
:depth: 1 | ||
|
||
|
2 changes: 1 addition & 1 deletion
2
docs/examples/training_and_testing/plot_one_model_binary_kfold.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
148 changes: 148 additions & 0 deletions
148
docs/examples/training_and_testing/plot_one_model_regression_traintest.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
""" | ||
Train/Test split: Regression | ||
====================================================== | ||
🚀 In this tutorial, we'll explore regression using a train/test split. | ||
Specifically, we're using the :class:`~.TabularCrossmodalMultiheadAttention` model. | ||
Key Features: | ||
- 📥 Importing a model based on its path. | ||
- 🧪 Training and testing a model with train/test split. | ||
- 📈 Plotting the loss curves of each fold. | ||
- 📊 Visualising the results of a single train/test model using the :class:`~.RealsVsPreds` class. | ||
""" | ||
|
||
import matplotlib.pyplot as plt | ||
from tqdm.auto import tqdm | ||
import os | ||
|
||
from docs.examples import generate_sklearn_simulated_data | ||
from fusilli.data import prepare_fusion_data | ||
from fusilli.eval import RealsVsPreds | ||
from fusilli.train import train_and_save_models | ||
|
||
# sphinx_gallery_thumbnail_number = -1 | ||
|
||
# %% | ||
# 1. Import the fusion model 🔍 | ||
# -------------------------------- | ||
# We're importing only one model for this example, the :class:`~.TabularCrossmodalMultiheadAttention` model. | ||
# Instead of using the :func:`~fusilli.utils.model_chooser.import_chosen_fusion_models` function, we're importing the model directly like with any other library method. | ||
|
||
|
||
from fusilli.fusionmodels.tabularfusion.crossmodal_att import ( | ||
TabularCrossmodalMultiheadAttention, | ||
) | ||
|
||
# %% | ||
# 2. Set the training parameters 🎯 | ||
# ----------------------------------- | ||
# Now we're configuring our training parameters. | ||
# | ||
# For training and testing, the necessary parameters are: | ||
# - Paths to the input data files. | ||
# - Paths to the output directories. | ||
# - ``prediction_task``: the type of prediction to be performed. This is either ``regression``, ``binary``, or ``classification``. | ||
# | ||
# Some optional parameters are: | ||
# | ||
# - ``kfold``: a boolean of whether to use k-fold cross-validation (True) or not (False). By default, this is set to False. | ||
# - ``num_folds``: the number of folds to use. It can't be ``k=1``. | ||
# - ``wandb_logging``: a boolean of whether to log the results using Weights and Biases (True) or not (False). Default is False. | ||
# - ``test_size``: the proportion of the dataset to include in the test split. Default is 0.2. | ||
# - ``batch_size``: the batch size to use for training. Default is 8. | ||
# - ``multiclass_dimensions``: the number of classes to use for multiclass classification. Default is None unless ``prediction_task`` is ``multiclass``. | ||
# - ``max_epochs``: the maximum number of epochs to train for. Default is 1000. | ||
|
||
# Regression task | ||
prediction_task = "regression" | ||
|
||
# Set the batch size | ||
batch_size = 32 | ||
|
||
# Setting output directories | ||
output_paths = { | ||
"losses": "loss_logs/one_model_regression_traintest", | ||
"checkpoints": "checkpoints/one_model_regression_traintest", | ||
"figures": "figures/one_model_regression_traintest", | ||
} | ||
|
||
# Create the output directories if they don't exist | ||
for path in output_paths.values(): | ||
os.makedirs(path, exist_ok=True) | ||
|
||
# Clearing the loss logs directory (only for the example notebooks) | ||
for dir in os.listdir(output_paths["losses"]): | ||
# remove files | ||
for file in os.listdir(os.path.join(output_paths["losses"], dir)): | ||
os.remove(os.path.join(output_paths["losses"], dir, file)) | ||
# remove dir | ||
os.rmdir(os.path.join(output_paths["losses"], dir)) | ||
|
||
# %% | ||
# 3. Generating simulated data 🔮 | ||
# -------------------------------- | ||
# Time to create some simulated data for our models to work their wonders on. | ||
# This function also simulated image data which we aren't using here. | ||
|
||
tabular1_path, tabular2_path = generate_sklearn_simulated_data(prediction_task, | ||
num_samples=500, | ||
num_tab1_features=10, | ||
num_tab2_features=20) | ||
|
||
data_paths = { | ||
"tabular1": tabular1_path, | ||
"tabular2": tabular2_path, | ||
"image": "", | ||
} | ||
|
||
# %% | ||
# 4. Training the fusion model 🏁 | ||
# -------------------------------------- | ||
# Now we're ready to train our model. We're using the :func:`~fusilli.train.train_and_save_models` function to train our model. | ||
# | ||
# First we need to create a data module using the :func:`~fusilli.data.prepare_fusion_data` function. | ||
# This function takes the following parameters: | ||
# | ||
# - ``prediction_task``: the type of prediction to be performed. | ||
# - ``fusion_model``: the fusion model to be trained. | ||
# - ``data_paths``: the paths to the input data files. | ||
# - ``output_paths``: the paths to the output directories. | ||
# | ||
# Then we pass the data module and the fusion model to the :func:`~fusilli.train.train_and_save_models` function. | ||
# We're not using checkpointing for this example, so we set ``enable_checkpointing=False``. We're also setting ``show_loss_plot=True`` to plot the loss curve. | ||
|
||
|
||
fusion_model = TabularCrossmodalMultiheadAttention | ||
|
||
print("method_name:", fusion_model.method_name) | ||
print("modality_type:", fusion_model.modality_type) | ||
print("fusion_type:", fusion_model.fusion_type) | ||
|
||
dm = prepare_fusion_data(prediction_task=prediction_task, | ||
fusion_model=fusion_model, | ||
data_paths=data_paths, | ||
output_paths=output_paths, | ||
batch_size=batch_size) | ||
|
||
# train and test | ||
single_model_list = train_and_save_models( | ||
data_module=dm, | ||
fusion_model=fusion_model, | ||
enable_checkpointing=False, # False for the example notebooks | ||
show_loss_plot=True, | ||
metrics_list=["r2", "mae", "mse"] | ||
) | ||
|
||
# %% | ||
# 6. Plotting the results 📊 | ||
# ---------------------------- | ||
# Now we're ready to plot the results of our model. | ||
# We're using the :class:`~.RealsVsPreds` class to plot the confusion matrix. | ||
|
||
reals_preds_fig = RealsVsPreds.from_final_val_data( | ||
single_model_list | ||
) | ||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,6 @@ | ||
.. _install_instructions: | ||
|
||
|
||
How to Install | ||
============== | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.