-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
5 changed files
with
104 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
[package] | ||
name = "fast-stats" | ||
version = "1.1.0" | ||
version = "1.2.0" | ||
edition = "2021" | ||
|
||
[lib] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
from enum import Enum | ||
from typing import Union | ||
|
||
import numpy as np | ||
|
||
from ._fast_stats_ext import _binary_f1_score_reqs | ||
|
||
Result = Union[None, float] | ||
|
||
|
||
class ZeroDivision(Enum): | ||
ZERO = "zero" | ||
NONE = "none" | ||
|
||
|
||
def _iou( | ||
tp: int, fp: int, fn: int, zero_division: ZeroDivision = ZeroDivision.NONE | ||
) -> Result: | ||
if tp + fp + fn == 0: | ||
if zero_division == ZeroDivision.NONE: | ||
return None | ||
elif zero_division == ZeroDivision.ZERO: | ||
return 0.0 | ||
return tp / (tp + fp + fn) | ||
|
||
|
||
def iou( | ||
array1: np.ndarray, | ||
array2: np.ndarray, | ||
zero_division: ZeroDivision = ZeroDivision.NONE, | ||
) -> Result: | ||
"""Calculation for IoU | ||
Args: | ||
array1 (np.ndarray): array of 0/1 values (must be bool or int types) | ||
array2 (np.ndarray): array of 0/1 values (must be bool or int types) | ||
zero_division (str): determines how to handle division by zero | ||
Returns: | ||
Result: None or float depending on values and zero division | ||
""" | ||
assert array1.shape == array2.shape, "y_true and y_pred must be same shape" | ||
assert all( | ||
[ | ||
isinstance(array1, np.ndarray), | ||
isinstance(array2, np.ndarray), | ||
] | ||
), "y_true and y_pred must be numpy arrays" | ||
zero_division = ZeroDivision(zero_division) | ||
|
||
tp, tp_fp, tp_fn = _binary_f1_score_reqs(array1, array2) | ||
fp, fn = tp_fp - tp, tp_fn - tp | ||
return _iou(tp, fp, fn, zero_division) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import numpy as np | ||
import pytest | ||
|
||
import fast_stats | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"arr1,arr2,zero_division,expected", | ||
[ | ||
( | ||
np.zeros(4, dtype=np.uint64), | ||
np.ones(4, dtype=np.uint64), | ||
"zero", | ||
0.0, | ||
), | ||
( | ||
np.ones(4, dtype=np.uint64), | ||
np.ones(4, dtype=np.uint64), | ||
"zero", | ||
1.0, | ||
), | ||
( | ||
np.zeros(4, dtype=np.uint64), | ||
np.zeros(4, dtype=np.uint64), | ||
"zero", | ||
0.0, | ||
), | ||
( | ||
np.zeros(4, dtype=np.uint64), | ||
np.zeros(4, dtype=np.uint64), | ||
"none", | ||
None, | ||
), | ||
( | ||
np.ones(4, dtype=np.uint64), | ||
np.array([1, 0, 0, 0], dtype=np.uint64), | ||
"none", | ||
0.25, | ||
), | ||
( | ||
np.array([1, 1, 0, 0], dtype=np.uint64), | ||
np.array([1, 0, 1, 0], dtype=np.uint64), | ||
"none", | ||
1.0 / 3.0, | ||
), | ||
], | ||
) | ||
def test_iou(arr1, arr2, zero_division, expected): | ||
assert fast_stats.iou(arr1, arr2, zero_division) == expected |