From 1f73f71ea313a485127013cbc2630a5466224f0b Mon Sep 17 00:00:00 2001 From: Sagi Polaczek <56922146+SagiPolaczek@users.noreply.github.com> Date: Wed, 8 Mar 2023 19:03:09 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=A7=B9=20Integrate=20`nb-clean`=20(#287)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * cleaned notebooks * added nb-clean to GitHub Actions * added nb-clean as a pre-commit hook * fixed typo --------- Co-authored-by: Sagi Polaczek --- .github/workflows/lint.yaml | 10 + .pre-commit-config.yaml | 4 + fuse/requirements_dev.txt | 1 + .../imaging/hello_world/hello_world.ipynb | 1222 ++++++++--------- .../multimodality_image_clinical.ipynb | 7 +- fuseimg/datasets/kits21_example.ipynb | 7 +- 6 files changed, 592 insertions(+), 659 deletions(-) diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index c2f3cde09..f7dd7e3bc 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -32,3 +32,13 @@ jobs: - run: pip install --upgrade pip - run: pip install mypy==0.950 - run: python -m mypy . --show-error-codes + nb-clean: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + with: + python-version: "3.7" + - run: pip install --upgrade pip + - run: pip install nb-clean + - run: nb-clean check . diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 78961ff90..cf0db8a71 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,3 +18,7 @@ repos: rev: v0.950 hooks: - id: mypy + - repo: https://github.com/srstevenson/nb-clean + rev: "2.4.0" + hooks: + - id: nb-clean diff --git a/fuse/requirements_dev.txt b/fuse/requirements_dev.txt index 6646e3f8e..28e4dce89 100644 --- a/fuse/requirements_dev.txt +++ b/fuse/requirements_dev.txt @@ -5,3 +5,4 @@ mypy==0.950 flake8 black==22.3.0 pre-commit +nb-clean \ No newline at end of file diff --git a/fuse_examples/imaging/hello_world/hello_world.ipynb b/fuse_examples/imaging/hello_world/hello_world.ipynb index 2c6cf8633..90f0fb13f 100644 --- a/fuse_examples/imaging/hello_world/hello_world.ipynb +++ b/fuse_examples/imaging/hello_world/hello_world.ipynb @@ -1,651 +1,575 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "DuNXhtjF_J6b" - }, - "source": [ - "# FuseMedML - Hello World\n", - "[![Github repo](https://img.shields.io/static/v1?label=GitHub&message=FuseMedML&color=brightgreen)](https://github.com/BiomedSciAI/fuse-med-ml)\n", - "\n", - "[![PyPI version](https://badge.fury.io/py/fuse-med-ml.svg)](https://badge.fury.io/py/fuse-med-ml)\n", - "\n", - "[![Slack channel](https://img.shields.io/badge/support-slack-slack.svg?logo=slack)](https://join.slack.com/t/fusemedml/shared_invite/zt-xr1jaj29-h7IMsSc0Lq4qpVNxW97Phw)\n", - "\n", - "[![Open Source](https://badges.frapsoft.com/os/v1/open-source.svg)](https://github.com/BiomedSciAI/fuse-med-ml)\n", - "\n", - "\n", - "**Welcome to FuseMedML's 'hello world' hands-on notebook!**\n", - "\n", - "In this notebook we'll examine a FuseMedML's basic use case: MNIST multiclass classification - incluing training, inference and evaluation.\n", - "\n", - "By the end of the session we hope you'll be familiar with basic Fuse's workflow and acknowledge it's potential.\n", - "\n", - "Open and run this notebook in [Google Colab](https://colab.research.google.com/github/BiomedSciAI/fuse-med-ml/blob/master/fuse_examples/imaging/hello_world/hello_world.ipynb)\n", - "\n", - "ENJOY" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": { - "id": "3aSfvNXy_J6d" - }, - "source": [ - "------------\n", - "## **Installation Details - Google Colab**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "gVAmstRJ_J6d" - }, - "outputs": [], - "source": [ - "# @title 1. Install FuseMedML\n", - "\n", - "# @markdown Please choose whether or not to install FuseMedML and execute this cell by pressing the *Play* button on the left.\n", - "\n", - "\n", - "install_fuse = False # @param {type:\"boolean\"}\n", - "use_gpu = True # @param {type:\"boolean\"}\n", - "\n", - "# @markdown ### **Warning!**\n", - "# @markdown If you wish to install FuseMedML -- as a workaround for\n", - "# @markdown [this](https://stackoverflow.com/questions/57831187/need-to-restart-runtime-before-import-an-installed-package-in-colab)\n", - "# @markdown issue please follow those steps:
\n", - "# @markdown 1. Execute this cell by pressing the ▶️ button on the left.\n", - "# @markdown 2. Restart runtime\n", - "# @markdown 3. Execute it once again\n", - "# @markdown 4. Enjoy\n", - "if install_fuse:\n", - " !git clone https://github.com/BiomedSciAI/fuse-med-ml.git\n", - " %cd fuse-med-ml\n", - " %pip install -e .[all,examples]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Tt3DFDcZ_J6e" - }, - "source": [ - "\n", - "## **Setup environment**" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "k7-RZ9s9_J6f" - }, - "source": [ - "##### **Imports**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "fhdyx0TH_J6f" - }, - "outputs": [], - "source": [ - "# @title 1. Imports\n", - "\n", - "# @markdown Please execute this cell by pressing the *Play* button on the left.\n", - "\n", - "import os\n", - "import copy\n", - "from typing import OrderedDict\n", - "\n", - "import torch.nn.functional as F\n", - "import torch.optim as optim\n", - "import pytorch_lightning as pl\n", - "from torch.utils.data.dataloader import DataLoader\n", - "\n", - "from fuse.eval.evaluator import EvaluatorDefault\n", - "from fuse.dl.losses.loss_default import LossDefault\n", - "from fuse.eval.metrics.classification.metrics_classification_common import MetricAccuracy, MetricAUCROC, MetricROCCurve\n", - "from fuse.eval.metrics.classification.metrics_thresholding_common import MetricApplyThresholds\n", - "from fuse.dl.models.model_wrapper import ModelWrapSeqToDict\n", - "from fuse.data.utils.samplers import BatchSamplerDefault\n", - "from fuse.data.utils.collates import CollateDefault\n", - "from fuse.dl.lightning.pl_module import LightningModuleDefault\n", - "from fuse.dl.lightning.pl_funcs import convert_predictions_to_dataframe\n", - "from fuse.utils.file_io.file_io import create_dir, save_dataframe\n", - "from fuseimg.datasets.mnist import MNIST\n", - "\n", - "from fuse_examples.imaging.hello_world.hello_world_utils import LeNet, perform_softmax" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "XIqail6k_J6f" - }, - "source": [ - "##### **Output paths**\n", - "The user is able to easily customize the directories paths." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "pBmmAveF_J6g" - }, - "outputs": [], - "source": [ - "ROOT = \"_examples/mnist\"\n", - "model_dir = os.path.join(ROOT, \"model_dir\")\n", - "PATHS = {\n", - " \"model_dir\": model_dir,\n", - " \"cache_dir\": os.path.join(ROOT, \"cache_dir\"),\n", - " \"inference_dir\": os.path.join(model_dir, \"infer_dir\"),\n", - " \"eval_dir\": os.path.join(model_dir, \"eval_dir\"),\n", - "}\n", - "\n", - "paths = PATHS" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "VmknNXJI_J6g" - }, - "source": [ - "##### **Training Parameters**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "u0MDbU_2_J6g" - }, - "outputs": [], - "source": [ - "TRAIN_COMMON_PARAMS = {}\n", - "\n", - "### Data ###\n", - "TRAIN_COMMON_PARAMS[\"data.batch_size\"] = 100\n", - "TRAIN_COMMON_PARAMS[\"data.train_num_workers\"] = 8\n", - "TRAIN_COMMON_PARAMS[\"data.validation_num_workers\"] = 8\n", - "\n", - "### PL Trainer ###\n", - "TRAIN_COMMON_PARAMS[\"trainer.num_epochs\"] = 2\n", - "TRAIN_COMMON_PARAMS[\"trainer.num_devices\"] = 1\n", - "TRAIN_COMMON_PARAMS[\"trainer.accelerator\"] = \"gpu\" if use_gpu else \"cpu\"\n", - "TRAIN_COMMON_PARAMS[\"trainer.ckpt_path\"] = None # path to the checkpoint you wish continue the training from\n", - "\n", - "### Optimizer ###\n", - "TRAIN_COMMON_PARAMS[\"opt.lr\"] = 1e-4\n", - "TRAIN_COMMON_PARAMS[\"opt.weight_decay\"] = 0.001\n", - "\n", - "train_params = TRAIN_COMMON_PARAMS" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "j-xEdrD3_J6h" - }, - "source": [ - "## **Training the model**" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Zh4jZrk7_J6h" - }, - "source": [ - "##### **Data**\n", - "Downloading the MNIST dataset and building dataloaders (torch.utils.data.DataLoader) for both train and validation.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "xiq8kQH0_J6h" - }, - "outputs": [], - "source": [ - "## Training Data\n", - "# Create dataset\n", - "train_dataset = MNIST.dataset(paths[\"cache_dir\"], train=True)\n", - "\n", - "# Create Fuse's custom sampler\n", - "sampler = BatchSamplerDefault(\n", - " dataset=train_dataset,\n", - " balanced_class_name=\"data.label\",\n", - " num_balanced_classes=10,\n", - " batch_size=train_params[\"data.batch_size\"],\n", - " balanced_class_weights=None,\n", - ")\n", - "\n", - "# Create dataloader\n", - "train_dataloader = DataLoader(\n", - " dataset=train_dataset,\n", - " batch_sampler=sampler,\n", - " collate_fn=CollateDefault(),\n", - " num_workers=train_params[\"data.train_num_workers\"],\n", - ")\n", - "\n", - "## Validation data\n", - "# Create dataset\n", - "validation_dataset = MNIST.dataset(paths[\"cache_dir\"], train=False)\n", - "\n", - "# dataloader\n", - "validation_dataloader = DataLoader(\n", - " dataset=validation_dataset,\n", - " batch_size=train_params[\"data.batch_size\"],\n", - " collate_fn=CollateDefault(),\n", - " num_workers=train_params[\"data.validation_num_workers\"],\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "DkZEpVvO_J6h" - }, - "source": [ - "##### **Model**\n", - "Building a LeNet model using \"pure\" PyTorch and wrapping it with Fuse's component. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "wzPFRv7x_J6i" - }, - "outputs": [], - "source": [ - "def create_model():\n", - " torch_model = LeNet()\n", - " # wrap basic torch model to automatically read inputs from batch_dict and save its outputs to batch_dict\n", - " model = ModelWrapSeqToDict(\n", - " model=torch_model,\n", - " model_inputs=[\"data.image\"],\n", - " post_forward_processing_function=perform_softmax,\n", - " model_outputs=[\"model.logits.classification\", \"model.output.classification\"],\n", - " )\n", - " return model\n", - "\n", - "\n", - "model = create_model()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ImJ4E9hE_J6i" - }, - "source": [ - "##### **Loss function**\n", - "Dictionary of loss elements such that each element is a sub-class of LossBase. The total loss will be the weighted sum of all the elements." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "9JeRh9Bl_J6i" - }, - "outputs": [], - "source": [ - "losses = {\n", - " \"cls_loss\": LossDefault(\n", - " pred=\"model.logits.classification\", target=\"data.label\", callable=F.cross_entropy, weight=1.0\n", - " ),\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "p1upw3_d_J6i" - }, - "source": [ - "##### **Metrics**\n", - "Dictionary of metric elements such that each element is a sub-class of MetricBase.\n", - "\n", - "The metrics will be calculated per epoch for both the validation and train." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "d9Svdy2o_J6i" - }, - "outputs": [], - "source": [ - "train_metrics = OrderedDict(\n", - " [\n", - " (\"operation_point\", MetricApplyThresholds(pred=\"model.output.classification\")), # will apply argmax\n", - " (\"accuracy\", MetricAccuracy(pred=\"results:metrics.operation_point.cls_pred\", target=\"data.label\")),\n", - " ]\n", - ")\n", - "validation_metrics = copy.deepcopy(train_metrics) # use the same metrics in validation as well" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "cggT9L_i_J6j" - }, - "source": [ - "##### **Best Epoch Source**\n", - "Defining what will be considered as 'the best epoch' so the model will be saved according to it." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "HXqoizBR_J6j" - }, - "outputs": [], - "source": [ - "best_epoch_source = dict(monitor=\"validation.metrics.accuracy\", mode=\"max\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "o-0FtWtH_J6j" - }, - "source": [ - "##### **Train**\n", - "Training session using PyTorch Lightning's trainer." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "bScADcdo_J6j" - }, - "outputs": [], - "source": [ - "# create optimizer\n", - "optimizer = optim.Adam(model.parameters(), lr=train_params[\"opt.lr\"], weight_decay=train_params[\"opt.weight_decay\"])\n", - "\n", - "# create scheduler\n", - "lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)\n", - "lr_sch_config = dict(scheduler=lr_scheduler, monitor=\"validation.losses.total_loss\")\n", - "\n", - "# optimizer and lr sch - see pl.LightningModule.configure_optimizers return value for all options\n", - "optimizers_and_lr_schs = dict(optimizer=optimizer, lr_scheduler=lr_sch_config)\n", - "\n", - "# create instance of PL module - FuseMedML generic version\n", - "pl_module = LightningModuleDefault(\n", - " model_dir=paths[\"model_dir\"],\n", - " model=model,\n", - " losses=losses,\n", - " train_metrics=train_metrics,\n", - " validation_metrics=validation_metrics,\n", - " best_epoch_source=best_epoch_source,\n", - " optimizers_and_lr_schs=optimizers_and_lr_schs,\n", - ")\n", - "\n", - "# create lightning trainer\n", - "pl_trainer = pl.Trainer(\n", - " default_root_dir=paths[\"model_dir\"],\n", - " max_epochs=train_params[\"trainer.num_epochs\"],\n", - " accelerator=train_params[\"trainer.accelerator\"],\n", - " devices=train_params[\"trainer.num_devices\"],\n", - " auto_select_gpus=True,\n", - ")\n", - "\n", - "# train\n", - "pl_trainer.fit(pl_module, train_dataloader, validation_dataloader, ckpt_path=train_params[\"trainer.ckpt_path\"])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vNi1gK86_J6j" - }, - "source": [ - "## **Infer**" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "FYOQKWaf_J6k" - }, - "source": [ - "##### **Define Infer Common Params**\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "O50FcBaV_J6k" - }, - "outputs": [], - "source": [ - "INFER_COMMON_PARAMS = {}\n", - "INFER_COMMON_PARAMS[\"infer_filename\"] = \"infer_file.gz\"\n", - "INFER_COMMON_PARAMS[\"checkpoint\"] = \"best_epoch.ckpt\"\n", - "INFER_COMMON_PARAMS[\"trainer.num_devices\"] = TRAIN_COMMON_PARAMS[\"trainer.num_devices\"]\n", - "INFER_COMMON_PARAMS[\"trainer.accelerator\"] = TRAIN_COMMON_PARAMS[\"trainer.accelerator\"]\n", - "\n", - "infer_common_params = INFER_COMMON_PARAMS" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "mAyCwNqx_J6k" - }, - "source": [ - "##### **Infer**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "2RuJ8EkG_J6k" - }, - "outputs": [], - "source": [ - "# setting dir and paths\n", - "create_dir(paths[\"inference_dir\"])\n", - "infer_file = os.path.join(paths[\"inference_dir\"], infer_common_params[\"infer_filename\"])\n", - "checkpoint_file = os.path.join(paths[\"model_dir\"], infer_common_params[\"checkpoint\"])\n", - "\n", - "# creating a dataloader\n", - "validation_dataloader = DataLoader(dataset=validation_dataset, collate_fn=CollateDefault(), batch_size=2, num_workers=2)\n", - "\n", - "# load pytorch lightning module\n", - "model = create_model()\n", - "pl_module = LightningModuleDefault.load_from_checkpoint(\n", - " checkpoint_file, model_dir=paths[\"model_dir\"], model=model, map_location=\"cpu\", strict=True\n", - ")\n", - "\n", - "# set the prediction keys to extract (the ones used be the evaluation function).\n", - "pl_module.set_predictions_keys(\n", - " [\"model.output.classification\", \"data.label\"]\n", - ") # which keys to extract and dump into file\n", - "\n", - "# create a trainer instance\n", - "pl_trainer = pl.Trainer(\n", - " default_root_dir=paths[\"model_dir\"],\n", - " accelerator=infer_common_params[\"trainer.accelerator\"],\n", - " devices=infer_common_params[\"trainer.num_devices\"],\n", - " auto_select_gpus=True,\n", - ")\n", - "\n", - "# predict\n", - "predictions = pl_trainer.predict(pl_module, validation_dataloader, return_predictions=True)\n", - "\n", - "# convert list of batch outputs into a dataframe\n", - "infer_df = convert_predictions_to_dataframe(predictions)\n", - "save_dataframe(infer_df, infer_file)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "yw8Yeb59_J6k" - }, - "source": [ - "## **Evaluation**\n", - "Using the `EvaluatorDefault` from the evaluation package of FuseMedML (fuse.eval) which is a standalone library for evaluating ML models that not necessarily trained with FuseMedML.\n", - "\n", - "More details and examples for the evaluation package can be found [here](https://github.com/BiomedSciAI/fuse-med-ml/blob/master/fuse/eval/README.md).\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "kjJs_As0_J6k" - }, - "source": [ - "##### **Define EVAL Common Params**\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "RiREWr9t_J6l" - }, - "outputs": [], - "source": [ - "EVAL_COMMON_PARAMS = {}\n", - "EVAL_COMMON_PARAMS[\"infer_filename\"] = INFER_COMMON_PARAMS[\"infer_filename\"]\n", - "\n", - "eval_common_params = EVAL_COMMON_PARAMS" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AsvtwRPs_J6l" - }, - "source": [ - "##### **Define metrics**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "frozeFWa_J6l" - }, - "outputs": [], - "source": [ - "class_names = [str(i) for i in range(10)]\n", - "\n", - "# metrics\n", - "metrics = OrderedDict(\n", - " [\n", - " (\"operation_point\", MetricApplyThresholds(pred=\"model.output.classification\")), # will apply argmax\n", - " (\"accuracy\", MetricAccuracy(pred=\"results:metrics.operation_point.cls_pred\", target=\"data.label\")),\n", - " (\n", - " \"roc\",\n", - " MetricROCCurve(\n", - " pred=\"model.output.classification\",\n", - " target=\"data.label\",\n", - " class_names=class_names,\n", - " output_filename=os.path.join(paths[\"inference_dir\"], \"roc_curve.png\"),\n", - " ),\n", - " ),\n", - " (\"auc\", MetricAUCROC(pred=\"model.output.classification\", target=\"data.label\", class_names=class_names)),\n", - " ]\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8nbda_2R_J6l" - }, - "source": [ - "##### **Evaluate**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "EgP63Gpc_J6l" - }, - "outputs": [], - "source": [ - "# create evaluator\n", - "evaluator = EvaluatorDefault()\n", - "\n", - "# run eval\n", - "results = evaluator.eval(\n", - " ids=None,\n", - " data=os.path.join(paths[\"inference_dir\"], eval_common_params[\"infer_filename\"]),\n", - " metrics=metrics,\n", - " output_dir=paths[\"eval_dir\"],\n", - " silent=False,\n", - ")\n", - "\n", - "print(\"Done!\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ytSA3_uX_J6l" - }, - "outputs": [], - "source": [ - "# For testing purposes\n", - "test_result_acc = results[\"metrics.accuracy\"]" - ] - } - ], - "metadata": { - "colab": { - "provenance": [] - }, - "kernelspec": { - "display_name": "fuse_repo_3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.13" - }, - "orig_nbformat": 4, - "vscode": { - "interpreter": { - "hash": "603ef3845b02b84c8b743302232442c478ebfea21d9503b404b4c0a993eb87a9" - } - } - }, - "nbformat": 4, - "nbformat_minor": 0 + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# FuseMedML - Hello World\n", + "[![Github repo](https://img.shields.io/static/v1?label=GitHub&message=FuseMedML&color=brightgreen)](https://github.com/BiomedSciAI/fuse-med-ml)\n", + "\n", + "[![PyPI version](https://badge.fury.io/py/fuse-med-ml.svg)](https://badge.fury.io/py/fuse-med-ml)\n", + "\n", + "[![Slack channel](https://img.shields.io/badge/support-slack-slack.svg?logo=slack)](https://join.slack.com/t/fusemedml/shared_invite/zt-xr1jaj29-h7IMsSc0Lq4qpVNxW97Phw)\n", + "\n", + "[![Open Source](https://badges.frapsoft.com/os/v1/open-source.svg)](https://github.com/BiomedSciAI/fuse-med-ml)\n", + "\n", + "\n", + "**Welcome to FuseMedML's 'hello world' hands-on notebook!**\n", + "\n", + "In this notebook we'll examine a FuseMedML's basic use case: MNIST multiclass classification - incluing training, inference and evaluation.\n", + "\n", + "By the end of the session we hope you'll be familiar with basic Fuse's workflow and acknowledge it's potential.\n", + "\n", + "Open and run this notebook in [Google Colab](https://colab.research.google.com/github/BiomedSciAI/fuse-med-ml/blob/master/fuse_examples/imaging/hello_world/hello_world.ipynb)\n", + "\n", + "ENJOY" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "------------\n", + "## **Installation Details - Google Colab**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# @title 1. Install FuseMedML\n", + "\n", + "# @markdown Please choose whether or not to install FuseMedML and execute this cell by pressing the *Play* button on the left.\n", + "\n", + "\n", + "install_fuse = False # @param {type:\"boolean\"}\n", + "use_gpu = True # @param {type:\"boolean\"}\n", + "\n", + "# @markdown ### **Warning!**\n", + "# @markdown If you wish to install FuseMedML -- as a workaround for\n", + "# @markdown [this](https://stackoverflow.com/questions/57831187/need-to-restart-runtime-before-import-an-installed-package-in-colab)\n", + "# @markdown issue please follow those steps:
\n", + "# @markdown 1. Execute this cell by pressing the ▶️ button on the left.\n", + "# @markdown 2. Restart runtime\n", + "# @markdown 3. Execute it once again\n", + "# @markdown 4. Enjoy\n", + "if install_fuse:\n", + " !git clone https://github.com/BiomedSciAI/fuse-med-ml.git\n", + " %cd fuse-med-ml\n", + " %pip install -e .[all,examples]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "## **Setup environment**" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### **Imports**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# @title 1. Imports\n", + "\n", + "# @markdown Please execute this cell by pressing the *Play* button on the left.\n", + "\n", + "import os\n", + "import copy\n", + "from typing import OrderedDict\n", + "\n", + "import torch.nn.functional as F\n", + "import torch.optim as optim\n", + "import pytorch_lightning as pl\n", + "from torch.utils.data.dataloader import DataLoader\n", + "\n", + "from fuse.eval.evaluator import EvaluatorDefault\n", + "from fuse.dl.losses.loss_default import LossDefault\n", + "from fuse.eval.metrics.classification.metrics_classification_common import MetricAccuracy, MetricAUCROC, MetricROCCurve\n", + "from fuse.eval.metrics.classification.metrics_thresholding_common import MetricApplyThresholds\n", + "from fuse.dl.models.model_wrapper import ModelWrapSeqToDict\n", + "from fuse.data.utils.samplers import BatchSamplerDefault\n", + "from fuse.data.utils.collates import CollateDefault\n", + "from fuse.dl.lightning.pl_module import LightningModuleDefault\n", + "from fuse.dl.lightning.pl_funcs import convert_predictions_to_dataframe\n", + "from fuse.utils.file_io.file_io import create_dir, save_dataframe\n", + "from fuseimg.datasets.mnist import MNIST\n", + "\n", + "from fuse_examples.imaging.hello_world.hello_world_utils import LeNet, perform_softmax" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### **Output paths**\n", + "The user is able to easily customize the directories paths." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ROOT = \"_examples/mnist\"\n", + "model_dir = os.path.join(ROOT, \"model_dir\")\n", + "PATHS = {\n", + " \"model_dir\": model_dir,\n", + " \"cache_dir\": os.path.join(ROOT, \"cache_dir\"),\n", + " \"inference_dir\": os.path.join(model_dir, \"infer_dir\"),\n", + " \"eval_dir\": os.path.join(model_dir, \"eval_dir\"),\n", + "}\n", + "\n", + "paths = PATHS" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### **Training Parameters**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "TRAIN_COMMON_PARAMS = {}\n", + "\n", + "### Data ###\n", + "TRAIN_COMMON_PARAMS[\"data.batch_size\"] = 100\n", + "TRAIN_COMMON_PARAMS[\"data.train_num_workers\"] = 8\n", + "TRAIN_COMMON_PARAMS[\"data.validation_num_workers\"] = 8\n", + "\n", + "### PL Trainer ###\n", + "TRAIN_COMMON_PARAMS[\"trainer.num_epochs\"] = 2\n", + "TRAIN_COMMON_PARAMS[\"trainer.num_devices\"] = 1\n", + "TRAIN_COMMON_PARAMS[\"trainer.accelerator\"] = \"gpu\" if use_gpu else \"cpu\"\n", + "TRAIN_COMMON_PARAMS[\"trainer.ckpt_path\"] = None # path to the checkpoint you wish continue the training from\n", + "\n", + "### Optimizer ###\n", + "TRAIN_COMMON_PARAMS[\"opt.lr\"] = 1e-4\n", + "TRAIN_COMMON_PARAMS[\"opt.weight_decay\"] = 0.001\n", + "\n", + "train_params = TRAIN_COMMON_PARAMS" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## **Training the model**" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### **Data**\n", + "Downloading the MNIST dataset and building dataloaders (torch.utils.data.DataLoader) for both train and validation.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "## Training Data\n", + "# Create dataset\n", + "train_dataset = MNIST.dataset(paths[\"cache_dir\"], train=True)\n", + "\n", + "# Create Fuse's custom sampler\n", + "sampler = BatchSamplerDefault(\n", + " dataset=train_dataset,\n", + " balanced_class_name=\"data.label\",\n", + " num_balanced_classes=10,\n", + " batch_size=train_params[\"data.batch_size\"],\n", + " balanced_class_weights=None,\n", + ")\n", + "\n", + "# Create dataloader\n", + "train_dataloader = DataLoader(\n", + " dataset=train_dataset,\n", + " batch_sampler=sampler,\n", + " collate_fn=CollateDefault(),\n", + " num_workers=train_params[\"data.train_num_workers\"],\n", + ")\n", + "\n", + "## Validation data\n", + "# Create dataset\n", + "validation_dataset = MNIST.dataset(paths[\"cache_dir\"], train=False)\n", + "\n", + "# dataloader\n", + "validation_dataloader = DataLoader(\n", + " dataset=validation_dataset,\n", + " batch_size=train_params[\"data.batch_size\"],\n", + " collate_fn=CollateDefault(),\n", + " num_workers=train_params[\"data.validation_num_workers\"],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### **Model**\n", + "Building a LeNet model using \"pure\" PyTorch and wrapping it with Fuse's component. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def create_model():\n", + " torch_model = LeNet()\n", + " # wrap basic torch model to automatically read inputs from batch_dict and save its outputs to batch_dict\n", + " model = ModelWrapSeqToDict(\n", + " model=torch_model,\n", + " model_inputs=[\"data.image\"],\n", + " post_forward_processing_function=perform_softmax,\n", + " model_outputs=[\"model.logits.classification\", \"model.output.classification\"],\n", + " )\n", + " return model\n", + "\n", + "\n", + "model = create_model()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### **Loss function**\n", + "Dictionary of loss elements such that each element is a sub-class of LossBase. The total loss will be the weighted sum of all the elements." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "losses = {\n", + " \"cls_loss\": LossDefault(\n", + " pred=\"model.logits.classification\", target=\"data.label\", callable=F.cross_entropy, weight=1.0\n", + " ),\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### **Metrics**\n", + "Dictionary of metric elements such that each element is a sub-class of MetricBase.\n", + "\n", + "The metrics will be calculated per epoch for both the validation and train." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_metrics = OrderedDict(\n", + " [\n", + " (\"operation_point\", MetricApplyThresholds(pred=\"model.output.classification\")), # will apply argmax\n", + " (\"accuracy\", MetricAccuracy(pred=\"results:metrics.operation_point.cls_pred\", target=\"data.label\")),\n", + " ]\n", + ")\n", + "validation_metrics = copy.deepcopy(train_metrics) # use the same metrics in validation as well" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### **Best Epoch Source**\n", + "Defining what will be considered as 'the best epoch' so the model will be saved according to it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "best_epoch_source = dict(monitor=\"validation.metrics.accuracy\", mode=\"max\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### **Train**\n", + "Training session using PyTorch Lightning's trainer." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# create optimizer\n", + "optimizer = optim.Adam(model.parameters(), lr=train_params[\"opt.lr\"], weight_decay=train_params[\"opt.weight_decay\"])\n", + "\n", + "# create scheduler\n", + "lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)\n", + "lr_sch_config = dict(scheduler=lr_scheduler, monitor=\"validation.losses.total_loss\")\n", + "\n", + "# optimizer and lr sch - see pl.LightningModule.configure_optimizers return value for all options\n", + "optimizers_and_lr_schs = dict(optimizer=optimizer, lr_scheduler=lr_sch_config)\n", + "\n", + "# create instance of PL module - FuseMedML generic version\n", + "pl_module = LightningModuleDefault(\n", + " model_dir=paths[\"model_dir\"],\n", + " model=model,\n", + " losses=losses,\n", + " train_metrics=train_metrics,\n", + " validation_metrics=validation_metrics,\n", + " best_epoch_source=best_epoch_source,\n", + " optimizers_and_lr_schs=optimizers_and_lr_schs,\n", + ")\n", + "\n", + "# create lightning trainer\n", + "pl_trainer = pl.Trainer(\n", + " default_root_dir=paths[\"model_dir\"],\n", + " max_epochs=train_params[\"trainer.num_epochs\"],\n", + " accelerator=train_params[\"trainer.accelerator\"],\n", + " devices=train_params[\"trainer.num_devices\"],\n", + " auto_select_gpus=True,\n", + ")\n", + "\n", + "# train\n", + "pl_trainer.fit(pl_module, train_dataloader, validation_dataloader, ckpt_path=train_params[\"trainer.ckpt_path\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## **Infer**" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### **Define Infer Common Params**\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "INFER_COMMON_PARAMS = {}\n", + "INFER_COMMON_PARAMS[\"infer_filename\"] = \"infer_file.gz\"\n", + "INFER_COMMON_PARAMS[\"checkpoint\"] = \"best_epoch.ckpt\"\n", + "INFER_COMMON_PARAMS[\"trainer.num_devices\"] = TRAIN_COMMON_PARAMS[\"trainer.num_devices\"]\n", + "INFER_COMMON_PARAMS[\"trainer.accelerator\"] = TRAIN_COMMON_PARAMS[\"trainer.accelerator\"]\n", + "\n", + "infer_common_params = INFER_COMMON_PARAMS" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### **Infer**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# setting dir and paths\n", + "create_dir(paths[\"inference_dir\"])\n", + "infer_file = os.path.join(paths[\"inference_dir\"], infer_common_params[\"infer_filename\"])\n", + "checkpoint_file = os.path.join(paths[\"model_dir\"], infer_common_params[\"checkpoint\"])\n", + "\n", + "# creating a dataloader\n", + "validation_dataloader = DataLoader(dataset=validation_dataset, collate_fn=CollateDefault(), batch_size=2, num_workers=2)\n", + "\n", + "# load pytorch lightning module\n", + "model = create_model()\n", + "pl_module = LightningModuleDefault.load_from_checkpoint(\n", + " checkpoint_file, model_dir=paths[\"model_dir\"], model=model, map_location=\"cpu\", strict=True\n", + ")\n", + "\n", + "# set the prediction keys to extract (the ones used be the evaluation function).\n", + "pl_module.set_predictions_keys(\n", + " [\"model.output.classification\", \"data.label\"]\n", + ") # which keys to extract and dump into file\n", + "\n", + "# create a trainer instance\n", + "pl_trainer = pl.Trainer(\n", + " default_root_dir=paths[\"model_dir\"],\n", + " accelerator=infer_common_params[\"trainer.accelerator\"],\n", + " devices=infer_common_params[\"trainer.num_devices\"],\n", + " auto_select_gpus=True,\n", + ")\n", + "\n", + "# predict\n", + "predictions = pl_trainer.predict(pl_module, validation_dataloader, return_predictions=True)\n", + "\n", + "# convert list of batch outputs into a dataframe\n", + "infer_df = convert_predictions_to_dataframe(predictions)\n", + "save_dataframe(infer_df, infer_file)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## **Evaluation**\n", + "Using the `EvaluatorDefault` from the evaluation package of FuseMedML (fuse.eval) which is a standalone library for evaluating ML models that not necessarily trained with FuseMedML.\n", + "\n", + "More details and examples for the evaluation package can be found [here](https://github.com/BiomedSciAI/fuse-med-ml/blob/master/fuse/eval/README.md).\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### **Define EVAL Common Params**\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "EVAL_COMMON_PARAMS = {}\n", + "EVAL_COMMON_PARAMS[\"infer_filename\"] = INFER_COMMON_PARAMS[\"infer_filename\"]\n", + "\n", + "eval_common_params = EVAL_COMMON_PARAMS" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### **Define metrics**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class_names = [str(i) for i in range(10)]\n", + "\n", + "# metrics\n", + "metrics = OrderedDict(\n", + " [\n", + " (\"operation_point\", MetricApplyThresholds(pred=\"model.output.classification\")), # will apply argmax\n", + " (\"accuracy\", MetricAccuracy(pred=\"results:metrics.operation_point.cls_pred\", target=\"data.label\")),\n", + " (\n", + " \"roc\",\n", + " MetricROCCurve(\n", + " pred=\"model.output.classification\",\n", + " target=\"data.label\",\n", + " class_names=class_names,\n", + " output_filename=os.path.join(paths[\"inference_dir\"], \"roc_curve.png\"),\n", + " ),\n", + " ),\n", + " (\"auc\", MetricAUCROC(pred=\"model.output.classification\", target=\"data.label\", class_names=class_names)),\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### **Evaluate**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# create evaluator\n", + "evaluator = EvaluatorDefault()\n", + "\n", + "# run eval\n", + "results = evaluator.eval(\n", + " ids=None,\n", + " data=os.path.join(paths[\"inference_dir\"], eval_common_params[\"infer_filename\"]),\n", + " metrics=metrics,\n", + " output_dir=paths[\"eval_dir\"],\n", + " silent=False,\n", + ")\n", + "\n", + "print(\"Done!\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# For testing purposes\n", + "test_result_acc = results[\"metrics.accuracy\"]" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "fuse_repo_3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + }, + "vscode": { + "interpreter": { + "hash": "603ef3845b02b84c8b743302232442c478ebfea21d9503b404b4c0a993eb87a9" + } + } + }, + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/fuse_examples/multimodality/image_clinical/multimodality_image_clinical.ipynb b/fuse_examples/multimodality/image_clinical/multimodality_image_clinical.ipynb index d9bea2a96..f80b366db 100644 --- a/fuse_examples/multimodality/image_clinical/multimodality_image_clinical.ipynb +++ b/fuse_examples/multimodality/image_clinical/multimodality_image_clinical.ipynb @@ -177,9 +177,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": true - }, + "metadata": {}, "outputs": [], "source": [ "print(\"Download data: it might take few miuntes\")\n", @@ -559,8 +557,7 @@ "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.15" + "pygments_lexer": "ipython3" }, "vscode": { "interpreter": { diff --git a/fuseimg/datasets/kits21_example.ipynb b/fuseimg/datasets/kits21_example.ipynb index 91a658738..66eabd71d 100644 --- a/fuseimg/datasets/kits21_example.ipynb +++ b/fuseimg/datasets/kits21_example.ipynb @@ -81,9 +81,7 @@ "cell_type": "code", "execution_count": null, "id": "e9d12c6d", - "metadata": { - "scrolled": true - }, + "metadata": {}, "outputs": [], "source": [ "num_samples = 2\n", @@ -497,8 +495,7 @@ "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.13" + "pygments_lexer": "ipython3" } }, "nbformat": 4,