diff --git a/Cargo.toml b/Cargo.toml index 52cc153..596ec6e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,4 +11,7 @@ crate-type = ["cdylib"] [dependencies] pyo3 = { version = "0.16.3", features = ["extension-module"] } numpy = "0.16.2" -ndarray = "0.15.4" \ No newline at end of file +ndarray = {version = "0.15.4" } +# ndarray = {version = "0.15.4", features = ["rayon"]} +num-traits = "0.2.15" +# rayon = "1.5.3" \ No newline at end of file diff --git a/benchmarks/timeit.ipynb b/benchmarks/timeit.ipynb index 7c89f3b..38ddd69 100644 --- a/benchmarks/timeit.ipynb +++ b/benchmarks/timeit.ipynb @@ -11,6 +11,13 @@ "import numpy as np" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Compared to scikit-learn" + ] + }, { "cell_type": "code", "execution_count": 2, @@ -30,7 +37,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "600 ms ± 3.48 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + "606 ms ± 11.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" ] } ], @@ -48,7 +55,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "596 ms ± 1.44 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + "608 ms ± 12.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" ] } ], @@ -66,7 +73,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "603 ms ± 3.91 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + "601 ms ± 1.67 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" ] } ], @@ -95,13 +102,13 @@ "name": "stdout", "output_type": "stream", "text": [ - "24.1 ms ± 204 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + "7.79 ms ± 49.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], "source": [ "%%timeit\n", - "_ = fast_stats.precision(actual, pred)" + "_ = fast_stats.binary_precision(actual, pred)" ] }, { @@ -113,13 +120,13 @@ "name": "stdout", "output_type": "stream", "text": [ - "24.3 ms ± 254 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + "7.27 ms ± 85 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], "source": [ "%%timeit\n", - "_ = fast_stats.recall(actual, pred)" + "_ = fast_stats.binary_recall(actual, pred)" ] }, { @@ -131,34 +138,86 @@ "name": "stdout", "output_type": "stream", "text": [ - "24.2 ms ± 388 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + "9.89 ms ± 239 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], "source": [ "%%timeit\n", - "_ = fast_stats.f1_score(actual, pred)" + "_ = fast_stats.binary_f1_score(actual, pred)" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "assert np.allclose(\n", - " fast_stats.precision(actual.flatten(), pred.flatten()),\n", + " fast_stats.binary_precision(actual.flatten(), pred.flatten()),\n", " precision_score(actual.flatten(), pred.flatten())\n", ")\n", "assert np.allclose(\n", - " fast_stats.recall(actual.flatten(), pred.flatten()),\n", + " fast_stats.binary_recall(actual.flatten(), pred.flatten()),\n", " recall_score(actual.flatten(), pred.flatten())\n", ")\n", "assert np.allclose(\n", - " fast_stats.f1_score(actual.flatten(), pred.flatten()),\n", + " fast_stats.binary_f1_score(actual.flatten(), pred.flatten()),\n", " f1_score(actual.flatten(), pred.flatten())\n", ")" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Compared to numpy" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "pred, actual = pred.astype(bool), actual.astype(bool)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3.14 ms ± 136 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + ] + } + ], + "source": [ + "%%timeit\n", + "_ = fast_stats.binary_precision(pred, actual)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "5.53 ms ± 14.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + ] + } + ], + "source": [ + "%%timeit\n", + "_ =np.logical_and(actual, pred).sum() / pred.sum()" + ] } ], "metadata": { diff --git a/fast_stats/__init__.py b/fast_stats/__init__.py index bf46469..683633a 100644 --- a/fast_stats/__init__.py +++ b/fast_stats/__init__.py @@ -1 +1,7 @@ -from .stats import f1_score, precision, recall +from .fast_stats import ( + _binary_f1_score_reqs, + _binary_precision_reqs, + _binary_recall_reqs, + _tp_fp_fn_tn, +) +from .stats import binary_f1_score, binary_precision, binary_recall diff --git a/fast_stats/stats.py b/fast_stats/stats.py index 11e5950..7f7985b 100644 --- a/fast_stats/stats.py +++ b/fast_stats/stats.py @@ -2,30 +2,37 @@ import numpy as np -from .fast_stats import _tp_fp_fn_tn +from .fast_stats import ( + _binary_f1_score_reqs, + _binary_precision_reqs, + _binary_recall_reqs, +) + +# from math import isnan # for Rust returning float nan + Result = Union[None, float] -def _precision(tp: int, fp: int, zero_division: str = "none") -> Result: - if tp + fp == 0: +def _precision(tp: int, tp_fp: int, zero_division: str = "none") -> Result: + if tp_fp == 0: if zero_division == "none": return None elif zero_division == "zero": return 0.0 - return tp / (tp + fp) + return tp / tp_fp -def _recall(tp: int, fn: int, zero_division: str = "none") -> Result: - if tp + fn == 0: +def _recall(tp: int, tp_fn: int, zero_division: str = "none") -> Result: + if tp_fn == 0: if zero_division == "none": return None elif zero_division == "zero": return 0.0 - return tp / (tp + fn) + return tp / tp_fn -def precision( +def binary_precision( y_true: np.ndarray, y_pred: np.ndarray, zero_division: str = "none" ) -> Result: assert y_true.shape == y_pred.shape, "y_true and y_pred must be same shape" @@ -33,11 +40,11 @@ def precision( y_true, np.ndarray ), "y_true and y_pred must be numpy arrays" - tp, fp, _, _ = _tp_fp_fn_tn(y_true, y_pred) - return _precision(tp, fp, zero_division) + tp, tp_fp = _binary_precision_reqs(y_true, y_pred) + return _precision(tp, tp_fp, zero_division) -def recall( +def binary_recall( y_true: np.ndarray, y_pred: np.ndarray, zero_division: str = "none" ) -> Result: assert y_true.shape == y_pred.shape, "y_true and y_pred must be same shape" @@ -45,18 +52,20 @@ def recall( y_true, np.ndarray ), "y_true and y_pred must be numpy arrays" - tp, _, fn, _ = _tp_fp_fn_tn(y_true, y_pred) - return _recall(tp, fn, zero_division) + tp, tp_fn = _binary_recall_reqs(y_true, y_pred) + return _recall(tp, tp_fn, zero_division) -def f1_score(y_true: np.ndarray, y_pred: np.ndarray, zero_division: str = "none"): +def binary_f1_score( + y_true: np.ndarray, y_pred: np.ndarray, zero_division: str = "none" +): assert y_true.shape == y_pred.shape, "y_true and y_pred must be same shape" assert isinstance(y_pred, np.ndarray) and isinstance( y_true, np.ndarray ), "y_true and y_pred must be numpy arrays" - tp, fp, fn, _ = _tp_fp_fn_tn(y_true, y_pred) - p, r = _precision(tp, fp, "0"), _recall(tp, fn, "0") + tp, tp_fp, tp_fn = _binary_f1_score_reqs(y_true, y_pred) + p, r = _precision(tp, tp_fp, "zero"), _recall(tp, tp_fn, "zero") if p + r == 0: if zero_division == "none": diff --git a/pyproject.toml b/pyproject.toml index 859a77d..ff05830 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "fast-stats" -version = "0.0.1" +version = "0.0.2" description = "A fast simple library for calculating basic statistics" readme = "README.md" license = {file = "LICENSE"} @@ -18,8 +18,8 @@ dependencies = [ [project.optional-dependencies] test = [ -# "pytest", -# "pytest-cov[all]" + "pytest", + "pytest-cov[all]" ] [project.urls] @@ -29,3 +29,8 @@ repository = "https://github.com/zachcoleman/fast-stats" requires = ["maturin>=0.12,<0.13"] build-backend = "maturin" +[tool.maturin] +strip = true + +[tool.isort] +profile="black" diff --git a/src/lib.rs b/src/lib.rs index 023509b..b2fd95b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,33 +1,398 @@ -use ndarray::*; use numpy::*; -use pyo3::prelude::*; +use pyo3::{exceptions, prelude::*}; use std::iter::zip; -/// TP and FP +fn sum(arr: ndarray::ArrayD) -> i128 +where + T: Clone + std::ops::Add + num_traits::Num + Into, +{ + let mut sum = 0; + for row in arr.rows() { + sum = sum + row.iter().fold(0, |acc, elt| acc + elt.clone().into()); + } + sum +} + +/// Get tp, fp, fn, tn counts by looping #[pyfunction] -fn _tp_fp_fn_tn(_py: Python<'_>, actual: &PyArrayDyn, pred: &PyArrayDyn) -> (usize, usize, usize, usize) { +#[pyo3(name = "_tp_fp_fn_tn")] +#[pyo3(text_signature = "(actual: np.ndarray, pred: np.ndarray, /)")] +fn tp_fp_fn_tn( + _py: Python<'_>, + actual: &PyArrayDyn, + pred: &PyArrayDyn, +) -> (usize, usize, usize, usize) { let mut _tp = 0; let mut _fp: usize = 0; let mut _fn: usize = 0; let mut _tn: usize = 0; - for (y_pred, y_actual) in zip(pred.readonly().as_array().iter(), actual.readonly().as_array().iter()){ - if *y_pred == 1 && *y_actual == 1{ + for (y_pred, y_actual) in zip( + pred.readonly().as_array().iter(), + actual.readonly().as_array().iter(), + ) { + if *y_pred == 1 && *y_actual == 1 { _tp = _tp + 1; - } else if *y_pred == 1 && *y_actual == 0{ + } else if *y_pred == 1 && *y_actual == 0 { _fp = _fp + 1; - } else if *y_pred == 0 && *y_actual == 1{ + } else if *y_pred == 0 && *y_actual == 1 { _fn = _fn + 1; - } else{ + } else { _tn = _tn + 1; } } (_tp, _fp, _fn, _tn) } +/// Array-based binary precision req calculating +#[pyfunction] +#[pyo3(name = "_binary_precision_reqs")] +#[pyo3(text_signature = "(actual: np.ndarray, pred: np.ndarray, /)")] +fn py_binary_precision_reqs( + _py: Python<'_>, + actual: &PyAny, + pred: &PyAny, +) -> PyResult<(i128, i128)> { + // TODO macro this out + // bool + if let (Ok(i), Ok(j)) = ( + actual.extract::>(), + pred.extract::>(), + ) { + return Ok(binary_precision_reqs::( + i.to_owned_array().mapv(|e| e as u8), + j.to_owned_array().mapv(|e| e as u8), + )); + } + // i8 + if let (Ok(i), Ok(j)) = ( + actual.extract::>(), + pred.extract::>(), + ) { + return Ok(binary_precision_reqs::( + i.to_owned_array(), + j.to_owned_array(), + )); + } + // i16 + if let (Ok(i), Ok(j)) = ( + actual.extract::>(), + pred.extract::>(), + ) { + return Ok(binary_precision_reqs::( + i.to_owned_array(), + j.to_owned_array(), + )); + } + // i32 + if let (Ok(i), Ok(j)) = ( + actual.extract::>(), + pred.extract::>(), + ) { + return Ok(binary_precision_reqs::( + i.to_owned_array(), + j.to_owned_array(), + )); + } + // i64 + if let (Ok(i), Ok(j)) = ( + actual.extract::>(), + pred.extract::>(), + ) { + return Ok(binary_precision_reqs::( + i.to_owned_array(), + j.to_owned_array(), + )); + } + // u8 + if let (Ok(i), Ok(j)) = ( + actual.extract::>(), + pred.extract::>(), + ) { + return Ok(binary_precision_reqs::( + i.to_owned_array(), + j.to_owned_array(), + )); + } + // u16 + if let (Ok(i), Ok(j)) = ( + actual.extract::>(), + pred.extract::>(), + ) { + return Ok(binary_precision_reqs::( + i.to_owned_array(), + j.to_owned_array(), + )); + } + // u32 + if let (Ok(i), Ok(j)) = ( + actual.extract::>(), + pred.extract::>(), + ) { + return Ok(binary_precision_reqs::( + i.to_owned_array(), + j.to_owned_array(), + )); + } + // u64 + if let (Ok(i), Ok(j)) = ( + actual.extract::>(), + pred.extract::>(), + ) { + return Ok(binary_precision_reqs::( + i.to_owned_array(), + j.to_owned_array(), + )); + } + + Err(PyErr::new::( + "Unsupport numpy dtype", + )) +} + +fn binary_precision_reqs(actual: ndarray::ArrayD, pred: ndarray::ArrayD) -> (i128, i128) +where + T: Clone + std::ops::Add + num_traits::Num + Into, +{ + // TP, TP + FP + (sum(actual * &pred), sum(pred)) +} + +/// Array-based binary recall req calculating +#[pyfunction] +#[pyo3(name = "_binary_recall_reqs")] +#[pyo3(text_signature = "(actual: np.ndarray, pred: np.ndarray, /)")] +fn py_binary_recall_reqs(_py: Python<'_>, actual: &PyAny, pred: &PyAny) -> PyResult<(i128, i128)> { + // TODO macro this out + // bool + if let (Ok(i), Ok(j)) = ( + actual.extract::>(), + pred.extract::>(), + ) { + return Ok(binary_recall_reqs::( + i.to_owned_array().mapv(|e| e as u8), + j.to_owned_array().mapv(|e| e as u8), + )); + } + // i8 + if let (Ok(i), Ok(j)) = ( + actual.extract::>(), + pred.extract::>(), + ) { + return Ok(binary_recall_reqs::( + i.to_owned_array(), + j.to_owned_array(), + )); + } + // i16 + if let (Ok(i), Ok(j)) = ( + actual.extract::>(), + pred.extract::>(), + ) { + return Ok(binary_recall_reqs::( + i.to_owned_array(), + j.to_owned_array(), + )); + } + // i32 + if let (Ok(i), Ok(j)) = ( + actual.extract::>(), + pred.extract::>(), + ) { + return Ok(binary_recall_reqs::( + i.to_owned_array(), + j.to_owned_array(), + )); + } + // i64 + if let (Ok(i), Ok(j)) = ( + actual.extract::>(), + pred.extract::>(), + ) { + return Ok(binary_recall_reqs::( + i.to_owned_array(), + j.to_owned_array(), + )); + } + // u8 + if let (Ok(i), Ok(j)) = ( + actual.extract::>(), + pred.extract::>(), + ) { + return Ok(binary_recall_reqs::( + i.to_owned_array(), + j.to_owned_array(), + )); + } + // u16 + if let (Ok(i), Ok(j)) = ( + actual.extract::>(), + pred.extract::>(), + ) { + return Ok(binary_recall_reqs::( + i.to_owned_array(), + j.to_owned_array(), + )); + } + // u32 + if let (Ok(i), Ok(j)) = ( + actual.extract::>(), + pred.extract::>(), + ) { + return Ok(binary_recall_reqs::( + i.to_owned_array(), + j.to_owned_array(), + )); + } + // u64 + if let (Ok(i), Ok(j)) = ( + actual.extract::>(), + pred.extract::>(), + ) { + return Ok(binary_recall_reqs::( + i.to_owned_array(), + j.to_owned_array(), + )); + } + + Err(PyErr::new::( + "Unsupport numpy dtype", + )) +} + +fn binary_recall_reqs(actual: ndarray::ArrayD, pred: ndarray::ArrayD) -> (i128, i128) +where + T: Clone + std::ops::Add + num_traits::Num + Into, +{ + // TP, TP + FN + (sum(&actual * pred), sum(actual)) +} + +/// Array-based binary recall req calculating +#[pyfunction] +#[pyo3(name = "_binary_f1_score_reqs")] +#[pyo3(text_signature = "(actual: np.ndarray, pred: np.ndarray, /)")] +fn py_binary_f1_score_reqs( + _py: Python<'_>, + actual: &PyAny, + pred: &PyAny, +) -> PyResult<(i128, i128, i128)> { + // TODO macro this out + + // bool + if let (Ok(i), Ok(j)) = ( + actual.extract::>(), + pred.extract::>(), + ) { + return Ok(binary_f1_score_reqs::( + i.to_owned_array().mapv(|e| e as u8), + j.to_owned_array().mapv(|e| e as u8), + )); + } + + // i8 + if let (Ok(i), Ok(j)) = ( + actual.extract::>(), + pred.extract::>(), + ) { + return Ok(binary_f1_score_reqs::( + i.to_owned_array(), + j.to_owned_array(), + )); + } + // i16 + if let (Ok(i), Ok(j)) = ( + actual.extract::>(), + pred.extract::>(), + ) { + return Ok(binary_f1_score_reqs::( + i.to_owned_array(), + j.to_owned_array(), + )); + } + // i32 + if let (Ok(i), Ok(j)) = ( + actual.extract::>(), + pred.extract::>(), + ) { + return Ok(binary_f1_score_reqs::( + i.to_owned_array(), + j.to_owned_array(), + )); + } + // i64 + if let (Ok(i), Ok(j)) = ( + actual.extract::>(), + pred.extract::>(), + ) { + return Ok(binary_f1_score_reqs::( + i.to_owned_array(), + j.to_owned_array(), + )); + } + // u8 + if let (Ok(i), Ok(j)) = ( + actual.extract::>(), + pred.extract::>(), + ) { + return Ok(binary_f1_score_reqs::( + i.to_owned_array(), + j.to_owned_array(), + )); + } + // u16 + if let (Ok(i), Ok(j)) = ( + actual.extract::>(), + pred.extract::>(), + ) { + return Ok(binary_f1_score_reqs::( + i.to_owned_array(), + j.to_owned_array(), + )); + } + // u32 + if let (Ok(i), Ok(j)) = ( + actual.extract::>(), + pred.extract::>(), + ) { + return Ok(binary_f1_score_reqs::( + i.to_owned_array(), + j.to_owned_array(), + )); + } + // u64 + if let (Ok(i), Ok(j)) = ( + actual.extract::>(), + pred.extract::>(), + ) { + return Ok(binary_f1_score_reqs::( + i.to_owned_array(), + j.to_owned_array(), + )); + } + + Err(PyErr::new::( + "Unsupport numpy dtype", + )) +} + +fn binary_f1_score_reqs( + actual: ndarray::ArrayD, + pred: ndarray::ArrayD, +) -> (i128, i128, i128) +where + T: Clone + std::ops::Add + num_traits::Num + Into, +{ + // TP, TP + FP, TP + FN + (sum(&actual * &pred), sum(pred), sum(actual)) +} + /// A Python module implemented in Rust. #[pymodule] fn fast_stats(_py: Python, m: &PyModule) -> PyResult<()> { - m.add_function(wrap_pyfunction!(_tp_fp_fn_tn, m)?)?; + m.add_function(wrap_pyfunction!(tp_fp_fn_tn, m)?)?; + m.add_function(wrap_pyfunction!(py_binary_precision_reqs, m)?)?; + m.add_function(wrap_pyfunction!(py_binary_recall_reqs, m)?)?; + m.add_function(wrap_pyfunction!(py_binary_f1_score_reqs, m)?)?; Ok(()) -} \ No newline at end of file +} diff --git a/tests/test_stats.py b/tests/test_stats.py new file mode 100644 index 0000000..5eac344 --- /dev/null +++ b/tests/test_stats.py @@ -0,0 +1,106 @@ +import numpy as np +import pytest + +import fast_stats + + +@pytest.mark.parametrize( + "y_true,y_pred,zero_division,expected", + [ + ( + np.zeros(4, dtype=np.uint64), + np.ones(4, dtype=np.uint64), + "zero", + 0, + ), # all FP + ( + np.ones(4, dtype=np.uint64), + np.ones(4, dtype=np.uint64), + "zero", + 1.0, + ), # all TP + ( + np.zeros(4, dtype=np.uint64), + np.zeros(4, dtype=np.uint64), + "zero", + 0.0, + ), # No TP & No FP + ( + np.zeros(4, dtype=np.uint64), + np.zeros(4, dtype=np.uint64), + "none", + None, + ), # No TP & No FP + ( + np.ones(4, dtype=np.uint64), + np.array([1, 0, 0, 0], dtype=np.uint64), + "none", + 1.0, + ), # 1 TP & 0 FP + ( + np.array([1, 0, 0, 0], dtype=np.uint64), + np.ones(4, dtype=np.uint64), + "none", + 0.25, + ), # 1 TP & 3 FP + ( + np.zeros(4, dtype=np.uint64), + np.array([1, 0, 0, 0], dtype=np.uint64), + "none", + 0.0, + ), # 0 TP & 1 TP + ], +) +def test_precision(y_true, y_pred, zero_division, expected): + assert fast_stats.binary_precision(y_true, y_pred, zero_division) == expected + + +@pytest.mark.parametrize( + "y_true,y_pred,zero_division,expected", + [ + ( + np.ones(4, dtype=np.uint64), + np.zeros(4, dtype=np.uint64), + "zero", + 0, + ), # all FN + ( + np.ones(4, dtype=np.uint64), + np.ones(4, dtype=np.uint64), + "zero", + 1.0, + ), # all TP + ( + np.zeros(4, dtype=np.uint64), + np.zeros(4, dtype=np.uint64), + "zero", + 0.0, + ), # No TP & No FN + ( + np.zeros(4, dtype=np.uint64), + np.zeros(4, dtype=np.uint64), + "none", + None, + ), # No TP & No FN + ( + np.ones(4, dtype=np.uint64), + np.array([1, 0, 0, 0], dtype=np.uint64), + "none", + 0.25, + ), # 1 TP & 3 FN + ( + np.array([1, 0, 0, 0], dtype=np.uint64), + np.ones(4, dtype=np.uint64), + "none", + 1.0, + ), # 1 TP & 0 FN + ( + np.array([1, 0, 0, 0], dtype=np.uint64), + np.zeros(4, dtype=np.uint64), + "none", + 0.0, + ), # 0 TP & 1 FN + ], +) +def test_recall(y_true, y_pred, zero_division, expected): + assert fast_stats.binary_recall(y_true, y_pred, zero_division) == expected