Skip to content

Commit

Permalink
Add quantile_tdigest
Browse files Browse the repository at this point in the history
Co-authored-by: Florian Jetter <[email protected]>
  • Loading branch information
dcherian and fjetter committed Nov 7, 2023
1 parent d4ec17c commit 8e485bc
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 1 deletion.
13 changes: 12 additions & 1 deletion flox/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np
from numpy.typing import DTypeLike

from . import aggregate_flox, aggregate_npg, xrutils
from . import aggregate_flox, aggregate_npg, sketches, xrutils
from . import xrdtypes as dtypes

if TYPE_CHECKING:
Expand Down Expand Up @@ -505,6 +505,16 @@ def _pick_second(*x):
mode = Aggregation(name="mode", fill_value=dtypes.NA, chunk=None, combine=None)
nanmode = Aggregation(name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None)


quantile_tdigest = Aggregation(
"quantile_tdigest",
numpy=(sketches.tdigest_aggregate,),
chunk=(sketches.tdigest_chunk,),
combine=(sketches.tdigest_combine,),
finalize=sketches.tdigest_aggregate,
)


aggregations = {
"any": any_,
"all": all_,
Expand Down Expand Up @@ -537,6 +547,7 @@ def _pick_second(*x):
"nanquantile": nanquantile,
"mode": mode,
"nanmode": nanmode,
"quantile_tdigest": quantile_tdigest,
}


Expand Down
35 changes: 35 additions & 0 deletions flox/sketches.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import numpy as np
import numpy_groupies as npg


def tdigest_chunk(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None, **kwargs):
from crick import TDigest

def _(arr):
digest = TDigest()
# we receive object arrays from numpy_groupies
digest.update(arr.astype(array.dtype, copy=False))
return digest

result = npg.aggregate_numpy.aggregate(group_idx, array, func=_, axis=axis, dtype=object)
return result


def tdigest_combine(digests, axis=-1, keepdims=True):
from crick import TDigest

def _(arr):
t = TDigest()
t.merge(*arr)
return np.array([t], dtype=object)

(axis,) = axis
result = np.apply_along_axis(_, axis, digests)

return result


def tdigest_aggregate(digests, q, axis=-1, keepdims=True):
for idx in np.ndindex(digests.shape):
digests[idx] = digests[idx].quantile(q)
return digests

0 comments on commit 8e485bc

Please sign in to comment.