Skip to content

Commit

Permalink
Release 1.2.0 (#22)
Browse files Browse the repository at this point in the history
* Adding IoU

* bumping version
  • Loading branch information
zachcoleman authored Nov 9, 2022
1 parent 410b5e1 commit 42bc367
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "fast-stats"
version = "1.1.0"
version = "1.2.0"
edition = "2021"

[lib]
Expand Down
1 change: 1 addition & 0 deletions fast_stats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@
binary_tp_fp_fn,
)
from .confusion_matrix import confusion_matrix
from .iou import iou
from .multiclass import f1_score, precision, recall, stats
52 changes: 52 additions & 0 deletions fast_stats/iou.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from enum import Enum
from typing import Union

import numpy as np

from ._fast_stats_ext import _binary_f1_score_reqs

Result = Union[None, float]


class ZeroDivision(Enum):
ZERO = "zero"
NONE = "none"


def _iou(
tp: int, fp: int, fn: int, zero_division: ZeroDivision = ZeroDivision.NONE
) -> Result:
if tp + fp + fn == 0:
if zero_division == ZeroDivision.NONE:
return None
elif zero_division == ZeroDivision.ZERO:
return 0.0
return tp / (tp + fp + fn)


def iou(
array1: np.ndarray,
array2: np.ndarray,
zero_division: ZeroDivision = ZeroDivision.NONE,
) -> Result:
"""Calculation for IoU
Args:
array1 (np.ndarray): array of 0/1 values (must be bool or int types)
array2 (np.ndarray): array of 0/1 values (must be bool or int types)
zero_division (str): determines how to handle division by zero
Returns:
Result: None or float depending on values and zero division
"""
assert array1.shape == array2.shape, "y_true and y_pred must be same shape"
assert all(
[
isinstance(array1, np.ndarray),
isinstance(array2, np.ndarray),
]
), "y_true and y_pred must be numpy arrays"
zero_division = ZeroDivision(zero_division)

tp, tp_fp, tp_fn = _binary_f1_score_reqs(array1, array2)
fp, fn = tp_fp - tp, tp_fn - tp
return _iou(tp, fp, fn, zero_division)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "fast-stats"
version = "1.1.1"
version = "1.2.0"
description = "A fast and simple library for calculating basic statistics"
readme = "README.md"
license = {text="Apache 2.0"}
Expand Down
49 changes: 49 additions & 0 deletions tests/test_iou.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import numpy as np
import pytest

import fast_stats


@pytest.mark.parametrize(
"arr1,arr2,zero_division,expected",
[
(
np.zeros(4, dtype=np.uint64),
np.ones(4, dtype=np.uint64),
"zero",
0.0,
),
(
np.ones(4, dtype=np.uint64),
np.ones(4, dtype=np.uint64),
"zero",
1.0,
),
(
np.zeros(4, dtype=np.uint64),
np.zeros(4, dtype=np.uint64),
"zero",
0.0,
),
(
np.zeros(4, dtype=np.uint64),
np.zeros(4, dtype=np.uint64),
"none",
None,
),
(
np.ones(4, dtype=np.uint64),
np.array([1, 0, 0, 0], dtype=np.uint64),
"none",
0.25,
),
(
np.array([1, 1, 0, 0], dtype=np.uint64),
np.array([1, 0, 1, 0], dtype=np.uint64),
"none",
1.0 / 3.0,
),
],
)
def test_iou(arr1, arr2, zero_division, expected):
assert fast_stats.iou(arr1, arr2, zero_division) == expected

0 comments on commit 42bc367

Please sign in to comment.