Skip to content

Commit

Permalink
Merge pull request #39 from SFI-Visual-Intelligence/Jan/accuracy
Browse files Browse the repository at this point in the history
Added accuracy and tests for it and Jan model, no clashes merging myself
  • Loading branch information
hzavadil98 authored Feb 5, 2025
2 parents d742fe6 + 46798d2 commit ed0eaf2
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 8 deletions.
19 changes: 16 additions & 3 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@

from utils.metrics import F1Score, Precision, Recall

from utils.metrics import Accuracy, F1Score, Precision, Recall


def test_recall():
Expand Down Expand Up @@ -84,3 +82,18 @@ def test_for_zero_denominator():
assert precision4.allclose(torch.tensor(0.0), atol=1e-5), (
f"Precision Score: {precision4.item()}"
)


def test_accuracy():
import torch

accuracy = Accuracy()

y_true = torch.tensor([0, 3, 2, 3, 4])
y_pred = torch.tensor([0, 1, 2, 3, 4])

accuracy_score = accuracy(y_true, y_pred)

assert (torch.abs(torch.tensor(accuracy_score - 0.8)) < 1e-5), (
f"Accuracy Score: {accuracy_score.item()}"
)
17 changes: 16 additions & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
import torch

from utils.models import ChristianModel
from utils.models import ChristianModel, JanModel


@pytest.mark.parametrize(
Expand All @@ -20,3 +20,18 @@ def test_christian_model(image_shape, num_classes):
assert y.sum(dim=1).allclose(torch.ones(n), atol=1e-5), (
f"Softmax output should sum to 1, but got: {y.sum()}"
)


@pytest.mark.parametrize(
"image_shape, num_classes",
[((1, 28, 28), 4), ((3, 16, 16), 10)],
)
def test_jan_model(image_shape, num_classes):
n, c, h, w = 5, *image_shape

model = JanModel(image_shape, num_classes)

x = torch.randn(n, c, h, w)
y = model(x)

assert y.shape == (n, num_classes), f"Shape: {y.shape}"
6 changes: 3 additions & 3 deletions utils/load_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import torch.nn as nn

from .metrics import EntropyPrediction, F1Score, precision
from .metrics import Accuracy, EntropyPrediction, F1Score, Precision


class MetricWrapper(nn.Module):
Expand Down Expand Up @@ -39,9 +39,9 @@ def _get_metric(self, key):
case "recall":
raise NotImplementedError("Recall score not implemented yet")
case "precision":
return precision()
return Precision()
case "accuracy":
raise NotImplementedError("Accuracy score not implemented yet")
return Accuracy()
case _:
raise ValueError(f"Metric {key} not supported")

Expand Down
3 changes: 2 additions & 1 deletion utils/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
__all__ = ["EntropyPrediction", "Recall", "F1Score", "Precision"]
__all__ = ["EntropyPrediction", "Recall", "F1Score", "Precision", "Accuracy"]

from .accuracy import Accuracy
from .EntropyPred import EntropyPrediction
from .F1 import F1Score
from .precision import Precision
Expand Down
33 changes: 33 additions & 0 deletions utils/metrics/accuracy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import torch
from torch import nn


class Accuracy(nn.Module):
def __init__(self):
super().__init__()

def forward(self, y_true, y_pred):
"""
Compute the accuracy of the model.
Parameters
----------
y_true : torch.Tensor
True labels.
y_pred : torch.Tensor
Predicted labels.
Returns
-------
float
Accuracy score.
"""
return (y_true == y_pred).float().mean().item()


if __name__ == "__main__":
y_true = torch.tensor([0, 3, 2, 3, 4])
y_pred = torch.tensor([0, 1, 2, 3, 4])

accuracy = Accuracy()
print(accuracy(y_true, y_pred))

0 comments on commit ed0eaf2

Please sign in to comment.