Skip to content

Commit

Permalink
Faster Implementations (#2)
Browse files Browse the repository at this point in the history
* dev and patterns for supporting lots of numpy types

* updated binary logic and runs

* rust formatter

* bumping version
  • Loading branch information
zachcoleman authored May 22, 2022
1 parent ee3f688 commit d50aed9
Show file tree
Hide file tree
Showing 7 changed files with 598 additions and 45 deletions.
5 changes: 4 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,7 @@ crate-type = ["cdylib"]
[dependencies]
pyo3 = { version = "0.16.3", features = ["extension-module"] }
numpy = "0.16.2"
ndarray = "0.15.4"
ndarray = {version = "0.15.4" }
# ndarray = {version = "0.15.4", features = ["rayon"]}
num-traits = "0.2.15"
# rayon = "1.5.3"
85 changes: 72 additions & 13 deletions benchmarks/timeit.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@
"import numpy as np"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Compared to scikit-learn"
]
},
{
"cell_type": "code",
"execution_count": 2,
Expand All @@ -30,7 +37,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"600 ms ± 3.48 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
"606 ms ± 11.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
Expand All @@ -48,7 +55,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"596 ms ± 1.44 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
"608 ms ± 12.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
Expand All @@ -66,7 +73,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"603 ms ± 3.91 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
"601 ms ± 1.67 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
Expand Down Expand Up @@ -95,13 +102,13 @@
"name": "stdout",
"output_type": "stream",
"text": [
"24.1 ms ± 204 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
"7.79 ms ± 49.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"_ = fast_stats.precision(actual, pred)"
"_ = fast_stats.binary_precision(actual, pred)"
]
},
{
Expand All @@ -113,13 +120,13 @@
"name": "stdout",
"output_type": "stream",
"text": [
"24.3 ms ± 254 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
"7.27 ms ± 85 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"_ = fast_stats.recall(actual, pred)"
"_ = fast_stats.binary_recall(actual, pred)"
]
},
{
Expand All @@ -131,34 +138,86 @@
"name": "stdout",
"output_type": "stream",
"text": [
"24.2 ms ± 388 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
"9.89 ms ± 239 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"_ = fast_stats.f1_score(actual, pred)"
"_ = fast_stats.binary_f1_score(actual, pred)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"assert np.allclose(\n",
" fast_stats.precision(actual.flatten(), pred.flatten()),\n",
" fast_stats.binary_precision(actual.flatten(), pred.flatten()),\n",
" precision_score(actual.flatten(), pred.flatten())\n",
")\n",
"assert np.allclose(\n",
" fast_stats.recall(actual.flatten(), pred.flatten()),\n",
" fast_stats.binary_recall(actual.flatten(), pred.flatten()),\n",
" recall_score(actual.flatten(), pred.flatten())\n",
")\n",
"assert np.allclose(\n",
" fast_stats.f1_score(actual.flatten(), pred.flatten()),\n",
" fast_stats.binary_f1_score(actual.flatten(), pred.flatten()),\n",
" f1_score(actual.flatten(), pred.flatten())\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Compared to numpy"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"pred, actual = pred.astype(bool), actual.astype(bool)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.14 ms ± 136 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"_ = fast_stats.binary_precision(pred, actual)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"5.53 ms ± 14.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"_ =np.logical_and(actual, pred).sum() / pred.sum()"
]
}
],
"metadata": {
Expand Down
8 changes: 7 additions & 1 deletion fast_stats/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,7 @@
from .stats import f1_score, precision, recall
from .fast_stats import (
_binary_f1_score_reqs,
_binary_precision_reqs,
_binary_recall_reqs,
_tp_fp_fn_tn,
)
from .stats import binary_f1_score, binary_precision, binary_recall
41 changes: 25 additions & 16 deletions fast_stats/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,61 +2,70 @@

import numpy as np

from .fast_stats import _tp_fp_fn_tn
from .fast_stats import (
_binary_f1_score_reqs,
_binary_precision_reqs,
_binary_recall_reqs,
)

# from math import isnan # for Rust returning float nan


Result = Union[None, float]


def _precision(tp: int, fp: int, zero_division: str = "none") -> Result:
if tp + fp == 0:
def _precision(tp: int, tp_fp: int, zero_division: str = "none") -> Result:
if tp_fp == 0:
if zero_division == "none":
return None
elif zero_division == "zero":
return 0.0
return tp / (tp + fp)
return tp / tp_fp


def _recall(tp: int, fn: int, zero_division: str = "none") -> Result:
if tp + fn == 0:
def _recall(tp: int, tp_fn: int, zero_division: str = "none") -> Result:
if tp_fn == 0:
if zero_division == "none":
return None
elif zero_division == "zero":
return 0.0
return tp / (tp + fn)
return tp / tp_fn


def binary_precision(
    y_true: np.ndarray, y_pred: np.ndarray, zero_division: str = "none"
) -> "Result":
    """Compute precision for a pair of equally shaped label arrays.

    The counting is delegated to ``_binary_precision_reqs`` (imported from
    the package's ``fast_stats`` extension module), which returns the true
    positive count and the precision denominator. The ratio and the
    zero-denominator fallback are handled by ``_precision``.

    Args:
        y_true: Ground-truth labels as a numpy array.
        y_pred: Predicted labels as a numpy array with the same shape as
            ``y_true``.
        zero_division: Forwarded to ``_precision``. ``"none"`` returns
            ``None`` and ``"zero"`` returns ``0.0`` when there is nothing
            to divide by.

    Returns:
        The precision as a float, or the ``zero_division`` fallback.

    Raises:
        AssertionError: If either input is not a numpy array, or the
            shapes differ.
    """
    # Check the types first. Comparing `.shape` on a non-array (e.g. a list)
    # would raise AttributeError before the intended message could be shown.
    assert isinstance(y_pred, np.ndarray) and isinstance(
        y_true, np.ndarray
    ), "y_true and y_pred must be numpy arrays"
    assert y_true.shape == y_pred.shape, "y_true and y_pred must be same shape"

    tp, tp_fp = _binary_precision_reqs(y_true, y_pred)
    return _precision(tp, tp_fp, zero_division)


def binary_recall(
    y_true: np.ndarray, y_pred: np.ndarray, zero_division: str = "none"
) -> "Result":
    """Compute recall for a pair of equally shaped label arrays.

    The counting is delegated to ``_binary_recall_reqs`` (imported from the
    package's ``fast_stats`` extension module), which returns the true
    positive count and the recall denominator. The ratio and the
    zero-denominator fallback are handled by ``_recall``.

    Args:
        y_true: Ground-truth labels as a numpy array.
        y_pred: Predicted labels as a numpy array with the same shape as
            ``y_true``.
        zero_division: Forwarded to ``_recall``. ``"none"`` returns
            ``None`` and ``"zero"`` returns ``0.0`` when there is nothing
            to divide by.

    Returns:
        The recall as a float, or the ``zero_division`` fallback.

    Raises:
        AssertionError: If either input is not a numpy array, or the
            shapes differ.
    """
    # Check the types first. Comparing `.shape` on a non-array (e.g. a list)
    # would raise AttributeError before the intended message could be shown.
    assert isinstance(y_pred, np.ndarray) and isinstance(
        y_true, np.ndarray
    ), "y_true and y_pred must be numpy arrays"
    assert y_true.shape == y_pred.shape, "y_true and y_pred must be same shape"

    tp, tp_fn = _binary_recall_reqs(y_true, y_pred)
    return _recall(tp, tp_fn, zero_division)


def f1_score(y_true: np.ndarray, y_pred: np.ndarray, zero_division: str = "none"):
def binary_f1_score(
y_true: np.ndarray, y_pred: np.ndarray, zero_division: str = "none"
):
assert y_true.shape == y_pred.shape, "y_true and y_pred must be same shape"
assert isinstance(y_pred, np.ndarray) and isinstance(
y_true, np.ndarray
), "y_true and y_pred must be numpy arrays"

tp, fp, fn, _ = _tp_fp_fn_tn(y_true, y_pred)
p, r = _precision(tp, fp, "0"), _recall(tp, fn, "0")
tp, tp_fp, tp_fn = _binary_f1_score_reqs(y_true, y_pred)
p, r = _precision(tp, tp_fp, "zero"), _recall(tp, tp_fn, "zero")

if p + r == 0:
if zero_division == "none":
Expand Down
11 changes: 8 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "fast-stats"
version = "0.0.1"
version = "0.0.2"
description = "A fast simple library for calculating basic statistics"
readme = "README.md"
license = {file = "LICENSE"}
Expand All @@ -18,8 +18,8 @@ dependencies = [

[project.optional-dependencies]
test = [
# "pytest",
# "pytest-cov[all]"
"pytest",
"pytest-cov[all]"
]

[project.urls]
Expand All @@ -29,3 +29,8 @@ repository = "https://github.com/zachcoleman/fast-stats"
requires = ["maturin>=0.12,<0.13"]
build-backend = "maturin"

[tool.maturin]
strip = true

[tool.isort]
profile="black"
Loading

0 comments on commit d50aed9

Please sign in to comment.