Skip to content

Commit

Permalink
✅ Improve classification cov. & fix logits
Browse files Browse the repository at this point in the history
  • Loading branch information
o-laurent committed Aug 16, 2023
1 parent 3063f14 commit 2e76869
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 26 deletions.
20 changes: 12 additions & 8 deletions tests/_dummies/baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,20 +43,22 @@ def __new__(
**kwargs,
)
elif baseline_type == "ensemble":
kwargs["num_estimators"] = 2
return ClassificationEnsemble(
num_classes=num_classes,
num_estimators=2,
model=model,
loss=loss,
optimization_procedure=optimization_procedure,
**kwargs,
)

@staticmethod
@classmethod
def add_model_specific_args(
parent_parser: ArgumentParser,
cls,
parser: ArgumentParser,
) -> ArgumentParser:
return parent_parser
parser = ClassificationEnsemble.add_model_specific_args(parser)
return parser


class DummyRegressionBaseline:
Expand Down Expand Up @@ -86,8 +88,8 @@ def __new__(
**kwargs,
)
elif baseline_type == "ensemble":
kwargs["num_estimators"] = 2
return RegressionEnsemble(
num_estimators=2,
model=model,
loss=loss,
optimization_procedure=optimization_procedure,
Expand All @@ -97,8 +99,10 @@ def __new__(
**kwargs,
)

@staticmethod
@classmethod
def add_model_specific_args(
parent_parser: ArgumentParser,
cls,
parser: ArgumentParser,
) -> ArgumentParser:
return parent_parser
parser = ClassificationEnsemble.add_model_specific_args(parser)
return parser
27 changes: 22 additions & 5 deletions tests/_dummies/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class DummyClassificationDataModule(LightningDataModule):
def __init__(
self,
root: Union[str, Path],
ood_detection: bool,
batch_size: int,
num_classes: int = 10,
num_workers: int = 1,
Expand All @@ -31,6 +32,7 @@ def __init__(
root = Path(root)

self.root: Path = root
self.ood_detection = ood_detection
self.batch_size = batch_size
self.num_classes = num_classes
self.num_workers = num_workers
Expand Down Expand Up @@ -84,8 +86,11 @@ def train_dataloader(self) -> DataLoader:
def val_dataloader(self) -> DataLoader:
return self._data_loader(self.val)

def test_dataloader(self) -> List[DataLoader]:
return [self._data_loader(self.test), self._data_loader(self.ood)]
def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
dataloader = [self._data_loader(self.test)]
if self.ood_detection:
dataloader.append(self._data_loader(self.ood))
return dataloader

def _data_loader(
self, dataset: Dataset, shuffle: bool = False
Expand All @@ -109,6 +114,9 @@ def add_argparse_args(
p.add_argument("--root", type=str, default="./data/")
p.add_argument("--batch_size", type=int, default=2)
p.add_argument("--num_workers", type=int, default=1)
p.add_argument(
"--evaluate_ood", dest="ood_detection", action="store_true"
)
return parent_parser


Expand All @@ -119,6 +127,7 @@ class DummyRegressionDataModule(LightningDataModule):
def __init__(
self,
root: Union[str, Path],
ood_detection: bool,
batch_size: int,
out_features: int = 2,
num_workers: int = 1,
Expand All @@ -128,9 +137,10 @@ def __init__(
) -> None:
super().__init__()

root = Path(root)

if isinstance(root, str):
root = Path(root)
self.root: Path = root
self.ood_detection = ood_detection
self.batch_size = batch_size
self.out_features = out_features
self.num_workers = num_workers
Expand Down Expand Up @@ -164,6 +174,7 @@ def setup(self, stage: Optional[str] = None) -> None:
out_features=self.out_features,
transform=self.transform_test,
)
if self.ood_detection:
self.ood = self.ood_dataset(
self.root,
out_features=self.out_features,
Expand All @@ -177,7 +188,10 @@ def val_dataloader(self) -> DataLoader:
return self._data_loader(self.val)

def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
return self._data_loader(self.test)
dataloader = [self._data_loader(self.test)]
if self.ood_detection:
dataloader.append(self._data_loader(self.ood))
return dataloader

def _data_loader(
self, dataset: Dataset, shuffle: bool = False
Expand All @@ -201,4 +215,7 @@ def add_argparse_args(
p.add_argument("--root", type=str, default="./data/")
p.add_argument("--batch_size", type=int, default=2)
p.add_argument("--num_workers", type=int, default=1)
p.add_argument(
"--evaluate_ood", dest="ood_detection", action="store_true"
)
return parent_parser
3 changes: 2 additions & 1 deletion tests/_dummies/dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# fmt:off
from pathlib import Path
from typing import Any, Callable, Tuple

import torch
Expand All @@ -12,7 +13,7 @@
class DummyClassificationDataset(data.Dataset):
def __init__(
self,
root: str,
root: Path,
train: bool = True,
transform: Callable[..., Any] | None = None,
target_transform: Callable[..., Any] | None = None,
Expand Down
35 changes: 29 additions & 6 deletions tests/routines/test_classification.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
# fmt:off
from pathlib import Path

import pytest
import torch.nn as nn
from cli_test_helpers import ArgvContext

from torch_uncertainty import cli_main, init_args
from torch_uncertainty.optimization_procedures import optim_cifar10_resnet18
from torch_uncertainty.routines.classification import (
ClassificationEnsemble,
ClassificationSingle,
)

from .._dummies import (
DummyClassificationBaseline,
Expand All @@ -19,7 +24,7 @@ class TestClassificationSingle:

def test_cli_main_dummy_binary(self):
root = Path(__file__).parent.absolute().parents[0]
with ArgvContext(""):
with ArgvContext("file.py", "--logits"):
args = init_args(
DummyClassificationBaseline, DummyClassificationDataModule
)
Expand All @@ -31,7 +36,7 @@ def test_cli_main_dummy_binary(self):
model = DummyClassificationBaseline(
num_classes=dm.num_classes,
in_channels=dm.num_channels,
loss=nn.CrossEntropyLoss,
loss=nn.BCEWithLogitsLoss,
optimization_procedure=optim_cifar10_resnet18,
baseline_type="single",
**vars(args),
Expand All @@ -41,7 +46,7 @@ def test_cli_main_dummy_binary(self):

def test_cli_main_dummy_ood(self):
root = Path(__file__).parent.absolute().parents[0]
with ArgvContext("--evaluate_ood"):
with ArgvContext("file.py", "--evaluate_ood", "--entropy"):
args = init_args(
DummyClassificationBaseline, DummyClassificationDataModule
)
Expand All @@ -61,13 +66,19 @@ def test_cli_main_dummy_ood(self):

cli_main(model, dm, root, "dummy", args)

def test_classification_failures(self):
with pytest.raises(ValueError):
ClassificationSingle(
10, nn.Module(), None, None, use_entropy=True, use_logits=True
)


class TestClassificationEnsemble:
"""Testing the classification routine with an ensemble model."""

def test_cli_main_dummy_binary(self):
root = Path(__file__).parent.absolute().parents[0]
with ArgvContext(""):
with ArgvContext("file.py", "--mutual_information"):
args = init_args(
DummyClassificationBaseline, DummyClassificationDataModule
)
Expand All @@ -79,7 +90,7 @@ def test_cli_main_dummy_binary(self):
model = DummyClassificationBaseline(
num_classes=dm.num_classes,
in_channels=dm.num_channels,
loss=nn.CrossEntropyLoss,
loss=nn.BCEWithLogitsLoss,
optimization_procedure=optim_cifar10_resnet18,
baseline_type="ensemble",
**vars(args),
Expand All @@ -89,7 +100,7 @@ def test_cli_main_dummy_binary(self):

def test_cli_main_dummy_ood(self):
root = Path(__file__).parent.absolute().parents[0]
with ArgvContext("--evaluate_ood"):
with ArgvContext("file.py", "--evaluate_ood", "--variation_ratio"):
args = init_args(
DummyClassificationBaseline, DummyClassificationDataModule
)
Expand All @@ -108,3 +119,15 @@ def test_cli_main_dummy_ood(self):
)

cli_main(model, dm, root, "dummy", args)

def test_classification_failures(self):
with pytest.raises(ValueError):
ClassificationEnsemble(
10,
nn.Module(),
None,
None,
2,
use_entropy=True,
use_logits=True,
)
12 changes: 6 additions & 6 deletions torch_uncertainty/routines/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,8 @@ def training_step(
logits = self.forward(inputs)

# BCEWithLogitsLoss expects float targets
if self.loss == nn.BCEWithLogitsLoss:
if self.binary_cls and self.loss == nn.BCEWithLogitsLoss:
logits = logits.squeeze(-1)
targets = targets.float()
loss = self.criterion(logits, targets)
self.log("train_loss", loss)
Expand Down Expand Up @@ -205,10 +206,10 @@ def test_step(
probs = torch.sigmoid(logits).squeeze(-1)
else:
probs = F.softmax(logits, dim=-1)
confs, _ = probs.max(dim=-1)
confs = probs.max(dim=-1)[0]

if self.use_logits:
ood_values, _ = -logits.max(dim=-1)
ood_values = -logits.max(dim=-1)[0]
elif self.use_entropy:
ood_values = torch.special.entr(probs).sum(dim=-1)
else:
Expand Down Expand Up @@ -257,7 +258,6 @@ def add_model_specific_args(
) -> ArgumentParser:
"""Defines the routine's attributes via command-line options:
- ``--evaluate_ood``: sets :attr:`ood_detection` to ``True``.
- ``--entropy``: sets :attr:`use_entropy` to ``True``.
- ``--logits``: sets :attr:`use_logits` to ``True``.
"""
Expand Down Expand Up @@ -413,10 +413,10 @@ def test_step(
probs_per_est = F.softmax(logits, dim=-1)

probs = probs_per_est.mean(dim=1)
confs, _ = probs.max(-1)
confs = probs.max(-1)[0]

if self.use_logits:
ood_values, _ = -logits.mean(dim=1).max(dim=-1)
ood_values = -logits.mean(dim=1).max(dim=-1)[0]
elif self.use_entropy:
ood_values = torch.special.entr(probs).sum(dim=-1).mean(dim=1)
elif self.use_mi:
Expand Down

0 comments on commit 2e76869

Please sign in to comment.