-
Notifications
You must be signed in to change notification settings - Fork 3.2k
Add Weighted Quantile and Percentile Support to jax.numpy #32737
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
6a60ed4
Add weighted quantile and percentile support with tests
Aniketsy f5d2177
Add weighted quantile and percentile support with tests
Aniketsy f7ab683
Add weighted quantile and percentile support with tests
Aniketsy 5f881bf
Add weighted quantile and percentile support with tests
Aniketsy 7b967cb
Add weighted quantile and percentile support with tests
Aniketsy 4f522e6
Add weighted quantile and percentile support with tests
Aniketsy cef1731
Add weighted quantile and percentile support with tests
Aniketsy a230e01
Add weighted quantile and percentile support with tests
Aniketsy 7d50a32
Add weighted quantile and percentile support with tests
Aniketsy ca1d95b
Add weighted quantile and percentile support with tests
Aniketsy 4ebbf21
Add weighted quantile and percentile support with tests
Aniketsy 43cdb5c
Add weighted quantile and percentile support with tests
Aniketsy d9c4614
Merge branch 'main' into support-32647
Aniketsy File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,7 +27,7 @@ | |
| from jax._src import core | ||
| from jax._src import dtypes | ||
| from jax._src.numpy.util import ( | ||
| _broadcast_to, ensure_arraylike, | ||
| _broadcast_to, check_arraylike, ensure_arraylike, _complex_elem_type, | ||
| promote_dtypes_inexact, promote_dtypes_numeric, _where) | ||
| from jax._src.lax import control_flow | ||
| from jax._src.lax import lax as lax | ||
|
|
@@ -2376,7 +2376,8 @@ def cumulative_prod( | |
| @api.jit(static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) | ||
| def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, | ||
| out: None = None, overwrite_input: bool = False, method: str = "linear", | ||
| keepdims: bool = False, *, interpolation: DeprecatedArg = DeprecatedArg()) -> Array: | ||
| keepdims: bool = False, *, weights: ArrayLike | None = None, | ||
| interpolation: DeprecatedArg | str = DeprecatedArg()) -> Array: | ||
| """Compute the quantile of the data along the specified axis. | ||
|
|
||
| JAX implementation of :func:`numpy.quantile`. | ||
|
|
@@ -2414,22 +2415,26 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No | |
| >>> jnp.quantile(x, q, method='nearest') | ||
| Array([2., 4., 7.], dtype=float32) | ||
| """ | ||
| a, q = ensure_arraylike("quantile", a, q) | ||
| if weights is None: | ||
| a, q = ensure_arraylike("quantile", a, q) | ||
| else: | ||
| a, q, weights = ensure_arraylike("quantile", a, q, weights) | ||
| if overwrite_input or out is not None: | ||
| raise ValueError("jax.numpy.quantile does not support overwrite_input=True " | ||
| "or out != None") | ||
| # TODO(jakevdp): remove the interpolation argument in JAX v0.9.0 | ||
| if not isinstance(interpolation, DeprecatedArg): | ||
| raise TypeError("quantile() argument interpolation was removed in JAX" | ||
| " v0.8.0. Use method instead.") | ||
| return _quantile(lax.asarray(a), lax.asarray(q), axis, method, keepdims, False) | ||
| return _quantile(a, q, axis, method, keepdims, False, weights) | ||
|
|
||
| # TODO(jakevdp): interpolation argument deprecated 2024-05-16 | ||
| @export | ||
| @api.jit(static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) | ||
| def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, | ||
| out: None = None, overwrite_input: bool = False, method: str = "linear", | ||
| keepdims: bool = False, *, interpolation: DeprecatedArg = DeprecatedArg()) -> Array: | ||
| keepdims: bool = False, *, weights: ArrayLike | None = None, | ||
| interpolation: DeprecatedArg | str = DeprecatedArg()) -> Array: | ||
| """Compute the quantile of the data along the specified axis, ignoring NaNs. | ||
|
|
||
| JAX implementation of :func:`numpy.nanquantile`. | ||
|
|
@@ -2468,7 +2473,10 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = | |
| >>> jnp.nanquantile(x, q) | ||
| Array([1.5, 3. , 4.5], dtype=float32) | ||
| """ | ||
| a, q = ensure_arraylike("nanquantile", a, q) | ||
| if weights is None: | ||
| a, q = ensure_arraylike("nanquantile", a, q) | ||
| else: | ||
| a, q, weights = ensure_arraylike("nanquantile", a, q, weights) | ||
| if overwrite_input or out is not None: | ||
| msg = ("jax.numpy.nanquantile does not support overwrite_input=True or " | ||
| "out != None") | ||
|
|
@@ -2477,13 +2485,12 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = | |
| if not isinstance(interpolation, DeprecatedArg): | ||
| raise TypeError("nanquantile() argument interpolation was removed in JAX" | ||
| " v0.8.0. Use method instead.") | ||
| return _quantile(lax.asarray(a), lax.asarray(q), axis, method, keepdims, True) | ||
| return _quantile(a, q, axis, method, keepdims, True, weights) | ||
|
|
||
| def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, | ||
| method: str, keepdims: bool, squash_nans: bool) -> Array: | ||
| if method not in ["linear", "lower", "higher", "midpoint", "nearest"]: | ||
| raise ValueError("method can only be 'linear', 'lower', 'higher', 'midpoint', or 'nearest'") | ||
| a, = promote_dtypes_inexact(a) | ||
| method: str, keepdims: bool, squash_nans: bool, weights: Array | None = None) -> Array: | ||
| if method not in ["linear", "lower", "higher", "midpoint", "nearest", "inverted_cdf"]: | ||
| raise ValueError("method can only be 'linear', 'lower', 'higher', 'midpoint', 'nearest', or 'inverted_cdf'") | ||
| keepdim = [] | ||
| if dtypes.issubdtype(a.dtype, np.complexfloating): | ||
| raise ValueError("quantile does not support complex input, as the operation is poorly defined.") | ||
|
|
@@ -2513,12 +2520,77 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, | |
| else: | ||
| axis = canonicalize_axis(axis, a.ndim) | ||
|
|
||
| q, = promote_dtypes_inexact(q) | ||
| q_was_scalar = q.ndim == 0 | ||
| if q_was_scalar: | ||
| q = lax.expand_dims(q, (0,)) | ||
| q_shape = q.shape | ||
| q_ndim = q.ndim | ||
| if q_ndim > 1: | ||
| raise ValueError(f"q must be have rank <= 1, got shape {q.shape}") | ||
|
|
||
| a_shape = a.shape | ||
| # Handle weights | ||
| if weights is None: | ||
| a, = promote_dtypes_inexact(a) | ||
| else: | ||
| if method != "inverted_cdf": | ||
| raise ValueError("Weighted quantiles are only supported for method='inverted_cdf'") | ||
| if axis is None: | ||
| raise TypeError("Axis must be specified when shapes of a and weights differ.") | ||
| axis_tuple = canonicalize_axis_tuple(axis, a.ndim) | ||
|
|
||
| a, q, weights = promote_dtypes_inexact(a, q, weights) | ||
| a_shape = a.shape | ||
| w_shape = np.shape(weights) | ||
| if np.ndim(weights) == 0: | ||
| weights = lax.broadcast_in_dim(weights, a_shape, ()) | ||
| w_shape = a_shape | ||
| if w_shape != a_shape: | ||
| expected_shape = tuple(a_shape[i] for i in axis_tuple) | ||
| if w_shape != expected_shape: | ||
| raise ValueError(f"Shape of weights must match the shape of the axes being reduced. " | ||
| f"Expected {expected_shape}, got {w_shape}") | ||
| weights = lax.broadcast_in_dim( | ||
| weights, | ||
| shape=a_shape, | ||
| broadcast_dimensions=axis_tuple | ||
| ) | ||
|
|
||
| if squash_nans: | ||
| nan_mask = ~lax_internal._isnan(a) | ||
| weights = _where(nan_mask, weights, 0) | ||
| else: | ||
| with config.debug_nans(False): | ||
| has_nan_data = any(lax_internal._isnan(a), axis=axis, keepdims=True) | ||
| has_nan_weights = any(lax_internal._isnan(weights), axis=axis, keepdims=True) | ||
| a = _where(has_nan_data | has_nan_weights, np.nan, a) | ||
|
|
||
| total_weight = sum(weights, axis=axis, keepdims=True) | ||
| a_sorted, weights_sorted = lax_internal.sort_key_val(a, weights, dimension=axis) | ||
| cum_weights = cumsum(weights_sorted, axis=axis) | ||
| cum_weights_norm = lax.div(cum_weights, total_weight) | ||
|
|
||
| def _weighted_quantile(qi): | ||
| qi = lax.convert_element_type(qi, cum_weights_norm.dtype) | ||
| index_dtype = dtypes.default_int_dtype() | ||
| idx = _reduce_sum(lax.lt(cum_weights_norm, qi), axis=axis, dtype=index_dtype, keepdims=keepdims) | ||
| idx = lax.clamp(_lax_const(idx, 0), idx, _lax_const(idx, a_sorted.shape[axis] - 1)) | ||
|
|
||
| idx_expanded = lax.expand_dims(idx, (axis,)) if not keepdims else idx | ||
| return jnp.take_along_axis(a_sorted, idx_expanded, axis=axis).squeeze(axis=axis) | ||
| result = api.vmap(_weighted_quantile)(q) | ||
| shape_after = list(a_shape) | ||
| if keepdims: | ||
| shape_after[axis] = 1 | ||
| else: | ||
| del shape_after[axis] | ||
| if not q_was_scalar: | ||
| result = result.reshape((q_shape[0], *shape_after)) | ||
| else: | ||
| if result.ndim > 0 and result.shape[0] == 1: | ||
| result = result.reshape(tuple(shape_after)) | ||
| return result | ||
|
|
||
| if squash_nans: | ||
| a = _where(lax._isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end. | ||
|
|
@@ -2593,14 +2665,18 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, | |
| pred = lax.le(high_weight, lax._const(high_weight, 0.5)) | ||
| result = lax.select(pred, low_value, high_value) | ||
| elif method == "midpoint": | ||
| result = lax.mul(lax.add(low_value, high_value), lax._const(low_value, 0.5)) | ||
| result = lax.mul(lax.add(low_value, high_value), _lax_const(low_value, 0.5)) | ||
| elif method == "inverted_cdf": | ||
| result = high_value | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here: this is not consistent with numpy. |
||
| else: | ||
| raise ValueError(f"{method=!r} not recognized") | ||
| if keepdims and keepdim: | ||
| if q_ndim > 0: | ||
| keepdim = [np.shape(q)[0], *keepdim] | ||
| result = result.reshape(keepdim) | ||
| return lax.convert_element_type(result, a.dtype) | ||
| keepdim_out = list(keepdim) | ||
| if not q_was_scalar: | ||
| keepdim_out = [q_shape[0], *keepdim_out] | ||
| result = result.reshape(tuple(keepdim_out)) | ||
| elif q_was_scalar and result.ndim > 0 and result.shape[0] == 1: | ||
| result = result.squeeze(axis=0) | ||
| return result | ||
|
|
||
|
|
||
| # TODO(jakevdp): interpolation argument deprecated 2024-05-16 | ||
|
|
@@ -2609,7 +2685,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, | |
| def percentile(a: ArrayLike, q: ArrayLike, | ||
| axis: int | tuple[int, ...] | None = None, | ||
| out: None = None, overwrite_input: bool = False, method: str = "linear", | ||
| keepdims: bool = False, *, interpolation: DeprecatedArg = DeprecatedArg()) -> Array: | ||
| keepdims: bool = False, *, weights: ArrayLike | None = None, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array: | ||
| """Compute the percentile of the data along the specified axis. | ||
|
|
||
| JAX implementation of :func:`numpy.percentile`. | ||
|
|
@@ -2647,14 +2723,17 @@ def percentile(a: ArrayLike, q: ArrayLike, | |
| >>> jnp.percentile(x, q, method='nearest') | ||
| Array([1., 3., 4.], dtype=float32) | ||
| """ | ||
| a, q = ensure_arraylike("percentile", a, q) | ||
| if weights is None: | ||
| a, q = ensure_arraylike("percentile", a, q) | ||
| else: | ||
| a, q, weights = ensure_arraylike("percentile", a, q, weights) | ||
| q, = promote_dtypes_inexact(q) | ||
| # TODO(jakevdp): remove the interpolation argument in JAX v0.9.0 | ||
| if not isinstance(interpolation, DeprecatedArg): | ||
| raise TypeError("percentile() argument interpolation was removed in JAX" | ||
| " v0.8.0. Use method instead.") | ||
| return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input, | ||
| method=method, keepdims=keepdims) | ||
| method=method, keepdims=keepdims, weights=weights) | ||
|
|
||
|
|
||
| # TODO(jakevdp): interpolation argument deprecated 2024-05-16 | ||
|
|
@@ -2663,7 +2742,7 @@ def percentile(a: ArrayLike, q: ArrayLike, | |
| def nanpercentile(a: ArrayLike, q: ArrayLike, | ||
| axis: int | tuple[int, ...] | None = None, | ||
| out: None = None, overwrite_input: bool = False, method: str = "linear", | ||
| keepdims: bool = False, *, interpolation: DeprecatedArg = DeprecatedArg()) -> Array: | ||
| keepdims: bool = False, *, weights: ArrayLike | None = None, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array: | ||
| """Compute the percentile of the data along the specified axis, ignoring NaN values. | ||
|
|
||
| JAX implementation of :func:`numpy.nanpercentile`. | ||
|
|
@@ -2703,15 +2782,18 @@ def nanpercentile(a: ArrayLike, q: ArrayLike, | |
| >>> jnp.nanpercentile(x, q) | ||
| Array([1.5, 3. , 4.5], dtype=float32) | ||
| """ | ||
| a, q = ensure_arraylike("nanpercentile", a, q) | ||
| if weights is None: | ||
| a, q = ensure_arraylike("nanpercentile", a, q) | ||
| else: | ||
| a, q, weights = ensure_arraylike("nanpercentile", a, q, weights) | ||
| q, = promote_dtypes_inexact(q) | ||
| q = q / 100 | ||
| # TODO(jakevdp): remove the interpolation argument in JAX v0.9.0 | ||
| if not isinstance(interpolation, DeprecatedArg): | ||
| raise TypeError("nanpercentile() argument interpolation was removed in JAX" | ||
| " v0.8.0. Use method instead.") | ||
| return nanquantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input, | ||
| method=method, keepdims=keepdims) | ||
| method=method, keepdims=keepdims, weights=weights) | ||
|
|
||
|
|
||
| @export | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.