diff --git a/jax/_src/scipy/stats/_core.py b/jax/_src/scipy/stats/_core.py index 6f4f1f40a8fc..3df940e8364d 100644 --- a/jax/_src/scipy/stats/_core.py +++ b/jax/_src/scipy/stats/_core.py @@ -237,13 +237,11 @@ def sem(a: ArrayLike, axis: int | None = 0, ddof: int = 1, nan_policy: str = "pr array """ b, = promote_args_inexact("sem", a) - if axis is None: - b = b.ravel() - axis = 0 if nan_policy == "propagate": - return b.std(axis, ddof=ddof) / jnp.sqrt(b.shape[axis]).astype(b.dtype) + size = b.size if axis is None else b.shape[axis] + return b.std(axis, ddof=ddof, keepdims=keepdims) / jnp.sqrt(size).astype(b.dtype) elif nan_policy == "omit": - count = (~jnp.isnan(b)).sum(axis) - return jnp.nanstd(b, axis, ddof=ddof) / jnp.sqrt(count).astype(b.dtype) + count = (~jnp.isnan(b)).sum(axis, keepdims=keepdims) + return jnp.nanstd(b, axis, ddof=ddof, keepdims=keepdims) / jnp.sqrt(count).astype(b.dtype) else: raise ValueError(f"{nan_policy} is not supported") diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index a0fe16d0849c..501f4cbe5e5f 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -1633,21 +1633,25 @@ def testRankData(self, shape, dtype, axis, method): self._CompileAndCheck(lax_fun, args_maker, rtol=tol) @jtu.sample_product( - [dict(shape=shape, axis=axis, ddof=ddof, nan_policy=nan_policy) + [dict(shape=shape, axis=axis, ddof=ddof, nan_policy=nan_policy, keepdims=keepdims) for shape in [(5,), (5, 6), (5, 6, 7)] for axis in [None, *range(len(shape))] for ddof in [0, 1, 2, 3] for nan_policy in ["propagate", "omit"] + for keepdims in [True, False] ], dtype=jtu.dtypes.integer + jtu.dtypes.floating, ) - def testSEM(self, shape, dtype, axis, ddof, nan_policy): + def testSEM(self, shape, dtype, axis, ddof, nan_policy, keepdims): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] - scipy_fun = partial(osp_stats.sem, axis=axis, ddof=ddof, nan_policy=nan_policy) - lax_fun = partial(lsp_stats.sem, axis=axis, ddof=ddof, nan_policy=nan_policy) + kwds = {} if scipy_version < (1, 11) else {'keepdims': keepdims} + scipy_fun = partial(osp_stats.sem, axis=axis, ddof=ddof, nan_policy=nan_policy, + **kwds) + lax_fun = partial(lsp_stats.sem, axis=axis, ddof=ddof, nan_policy=nan_policy, + **kwds) tol_spec = {np.float32: 2e-4, np.float64: 5e-6} tol = jtu.tolerance(dtype, tol_spec) self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,