Skip to content

Commit

Permalink
feat: Use multiclass datamodule for segmentation generic example (#73)
Browse files Browse the repository at this point in the history
Update oxford pet segmentation example to multiclass segmentation task (#73)

* feat: update oxford segmentation example

* fix: update parameter name for the model

* feat: update analysis logs

Approved-By: @lorenzomammana
  • Loading branch information
rcmalli authored Oct 8, 2023
1 parent c568794 commit 5677cb3
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 17 deletions.
2 changes: 1 addition & 1 deletion docs/tutorials/examples/segmentation.md
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ export:
backbone:
model:
classes: 4 # The total number of classes (background + foreground)
num_classes: 4 # The total number of classes (background + foreground)
task:
run_test: true # run test after training is completed
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
_target_: quadra.datamodules.generic.oxford_pet.OxfordPetSegmentationDataModule
idx_to_class:
1: cat_or_dog
data_path: ${oc.env:HOME}/.quadra/datasets/oxford-pet
test_size: 0.2
val_size: 0.2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,18 @@
defaults:
- base/segmentation/smp # use smp file as default
- override /datamodule: generic/oxford_pet/segmentation/base # update datamodule
- override /loss: smp_dice_multiclass
- override /model: smp_multiclass
- _self_ # use this file as final config

trainer:
devices: [0]
max_epochs: 10

backbone:
model:
num_classes: 2 # The total number of classes (background + foreground)

task:
report: true
evaluate:
Expand Down
32 changes: 18 additions & 14 deletions quadra/datamodules/generic/oxford_pet.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,32 @@
import os
from typing import Any, Optional, Type
from typing import Any, Dict, Optional, Type

import albumentations
import cv2
import numpy as np
import pandas as pd
from torchvision.datasets.utils import download_and_extract_archive

from quadra.datamodules import SegmentationDataModule
from quadra.datasets.segmentation import SegmentationDataset
from quadra.datamodules import SegmentationMulticlassDataModule
from quadra.datasets.segmentation import SegmentationDatasetMulticlass
from quadra.utils import utils

log = utils.get_logger(__name__)


class OxfordPetSegmentationDataModule(SegmentationDataModule):
class OxfordPetSegmentationDataModule(SegmentationMulticlassDataModule):
"""OxfordPetSegmentationDataModule.
Args:
data_path: path to the oxford pet dataset
test_size: Defaults to 0.3.
val_size: Defaults to 0.3.
seed: Defaults to 42.
idx_to_class: dict with corrispondence btw mask index and classes: {1: class_1, 2: class_2, ..., N: class_N}
except background class which is 0.
name: Defaults to "oxford_pet_segmentation_datamodule".
dataset: Defaults to SegmentationDataset.
batch_size: batch size for training. Defaults to 32.
test_size: Defaults to 0.3.
val_size: Defaults to 0.3.
seed: Defaults to 42.
num_workers: number of workers for data loading. Defaults to 6.
train_transform: Train transform. Defaults to None.
test_transform: Test transform. Defaults to None.
Expand All @@ -34,12 +36,13 @@ class OxfordPetSegmentationDataModule(SegmentationDataModule):
def __init__(
self,
data_path: str,
idx_to_class: Dict,
name: str = "oxford_pet_segmentation_datamodule",
dataset: Type[SegmentationDatasetMulticlass] = SegmentationDatasetMulticlass,
batch_size: int = 32,
test_size: float = 0.3,
val_size: float = 0.3,
seed: int = 42,
name: str = "oxford_pet_segmentation_datamodule",
dataset: Type[SegmentationDataset] = SegmentationDataset,
batch_size: int = 32,
num_workers: int = 6,
train_transform: Optional[albumentations.Compose] = None,
test_transform: Optional[albumentations.Compose] = None,
Expand All @@ -48,16 +51,17 @@ def __init__(
):
super().__init__(
data_path=data_path,
idx_to_class=idx_to_class,
name=name,
dataset=dataset,
batch_size=batch_size,
test_size=test_size,
val_size=val_size,
seed=seed,
name=name,
dataset=dataset,
num_workers=num_workers,
train_transform=train_transform,
test_transform=test_transform,
val_transform=val_transform,
batch_size=batch_size,
num_workers=num_workers,
**kwargs,
)

Expand Down
3 changes: 2 additions & 1 deletion quadra/tasks/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def prepare(self) -> None:
@automatic_datamodule_batch_size(batch_size_attribute_name="batch_size")
def test(self) -> None:
"""Run testing."""
log.info("Starting testing")
log.info("Starting inference for analysis.")

stages: List[str] = []
dataloaders: List[torch.utils.data.DataLoader] = []
Expand All @@ -336,6 +336,7 @@ def test(self) -> None:
stages.append("test")
dataloaders.append(self.datamodule.test_dataloader())
for stage, dataloader in zip(stages, dataloaders):
log.info("Running inference on %s set with batch size: %d", stage, dataloader.batch_size)
image_list, mask_list, mask_pred_list, label_list = [], [], [], []
for batch in dataloader:
images, masks, labels = batch
Expand Down
7 changes: 6 additions & 1 deletion quadra/utils/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,12 @@ def create_mask_report(
if len(area_graph["Defect Area Percentage"]) > 0:
fn_area_path = os.path.join(report_path, f"{stage}_acc_area.png")
fn_area_df = pd.DataFrame(area_graph)
ax = sns.boxplot(x="Defect Area Percentage", y="Accuracy", data=fn_area_df)
ax = sns.boxplot(
x="Defect Area Percentage",
y="Accuracy",
data=fn_area_df,
order=["Very Small <1%", "Small <10%", "Medium <25%", "Large >25%"],
)
ax.set_facecolor("white")
fig = ax.get_figure()
fig.savefig(fn_area_path)
Expand Down

0 comments on commit 5677cb3

Please sign in to comment.