|
23 | 23 |
|
24 | 24 | import numpy as np |
25 | 25 |
|
26 | | -import jax |
27 | | -from jax import lax |
28 | | -import jax._src.numpy as jnp |
| 26 | +from jax._src.lax import lax |
29 | 27 | from jax._src import api |
30 | 28 | from jax._src import core |
31 | 29 | from jax._src import deprecations |
@@ -2486,7 +2484,8 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, |
2486 | 2484 |
|
2487 | 2485 | q, = promote_dtypes_inexact(q) |
2488 | 2486 | q = lax_internal.asarray(q) |
2489 | | - if getattr(q, "ndim", 0) == 0: |
| 2487 | + q_was_scalar = getattr(q, "ndim", 0) == 0 |
| 2488 | + if q_was_scalar: |
2490 | 2489 | q = lax.expand_dims(q, (0,)) |
2491 | 2490 | q_shape = q.shape |
2492 | 2491 | q_ndim = q.ndim |
@@ -2534,7 +2533,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, |
2534 | 2533 | a = _where(any(lax_internal._isnan(a), axis=axis, keepdims=True), np.nan, a) |
2535 | 2534 |
|
2536 | 2535 | total_weight = sum(weights, axis=axis, keepdims=True) |
2537 | | - a_sorted, weights_sorted = lax.sort_key_val(a, weights, dimension=axis) |
| 2536 | + a_sorted, weights_sorted = lax_internal.sort_key_val(a, weights, dimension=axis) |
2538 | 2537 | cum_weights = cumsum(weights_sorted, axis=axis) |
2539 | 2538 | cum_weights_norm = lax.div(cum_weights, total_weight) |
2540 | 2539 |
|
@@ -2576,17 +2575,16 @@ def _weighted_quantile(qi): |
2576 | 2575 | else: |
2577 | 2576 | raise ValueError(f"{method=!r} not recognized") |
2578 | 2577 | return out |
| 2578 | + result = api.vmap(_weighted_quantile)(q) |
| 2579 | + keepdim_out = list(keepdim) |
| 2580 | + if not q_was_scalar: |
| 2581 | + keepdim_out = [q_shape[0], *keepdim_out] |
| 2582 | + result = result.reshape(tuple(keepdim_out)) |
| 2583 | + elif q_was_scalar and result.ndim > 0 and result.shape[0] == 1: |
| 2584 | + result = result.squeeze(axis=0) |
| 2585 | + return result |
| 2586 | + return result |
2579 | 2587 |
|
2580 | | - result = jax.vmap(_weighted_quantile)(q) |
2581 | | - if keepdims and keepdim: |
2582 | | - if q_ndim > 0: |
2583 | | - keepdim = [q_shape[0], *keepdim] |
2584 | | - result = result.reshape(tuple(keepdim)) |
2585 | | - else: |
2586 | | - if q_ndim == 0 or (q_ndim == 1 and q_shape[0] == 1): |
2587 | | - if result.ndim > 0 and result.shape[0] == 1: |
2588 | | - result = lax.squeeze(result, (0,)) |
2589 | | - return lax.convert_element_type(result, a.dtype) |
2590 | 2588 |
|
2591 | 2589 | if squash_nans: |
2592 | 2590 | a = _where(lax_internal._isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end. |
@@ -2666,15 +2664,13 @@ def _weighted_quantile(qi): |
2666 | 2664 | result = high_value |
2667 | 2665 | else: |
2668 | 2666 | raise ValueError(f"{method=!r} not recognized") |
2669 | | - if keepdims and keepdim: |
2670 | | - if q_ndim > 0: |
2671 | | - keepdim = [np.shape(q)[0], *keepdim] |
2672 | | - result = result.reshape(keepdim) |
2673 | | - else: |
2674 | | - if q_ndim == 0 or (q_ndim == 1 and q_shape[0] == 1): |
2675 | | - if result.ndim > 0 and result.shape[0] == 1: |
2676 | | - result = lax.squeeze(result, (0,)) |
2677 | | - return lax.convert_element_type(result, a.dtype) |
| 2667 | + keepdim_out = list(keepdim) |
| 2668 | + if not q_was_scalar: |
| 2669 | + keepdim_out = [q_shape[0], *keepdim_out] |
| 2670 | + result = result.reshape(tuple(keepdim_out)) |
| 2671 | + elif q_was_scalar and result.ndim > 0 and result.shape[0] == 1: |
| 2672 | + result = result.squeeze(axis=0) |
| 2673 | + return result |
2678 | 2674 |
|
2679 | 2675 |
|
2680 | 2676 | # TODO(jakevdp): interpolation argument deprecated 2024-05-16 |
|
0 commit comments