Skip to content

Commit

Permalink
Merge pull request #22028 from rajasekharporeddy:stats-sem
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 645518083
  • Loading branch information
jax authors committed Jun 21, 2024
2 parents 8c7e0d4 + edde7d9 commit 56e8fe6
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 10 deletions.
10 changes: 4 additions & 6 deletions jax/_src/scipy/stats/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,13 +237,11 @@ def sem(a: ArrayLike, axis: int | None = 0, ddof: int = 1, nan_policy: str = "pr
array
"""
b, = promote_args_inexact("sem", a)
if axis is None:
b = b.ravel()
axis = 0
if nan_policy == "propagate":
return b.std(axis, ddof=ddof) / jnp.sqrt(b.shape[axis]).astype(b.dtype)
size = b.size if axis is None else b.shape[axis]
return b.std(axis, ddof=ddof, keepdims=keepdims) / jnp.sqrt(size).astype(b.dtype)
elif nan_policy == "omit":
count = (~jnp.isnan(b)).sum(axis)
return jnp.nanstd(b, axis, ddof=ddof) / jnp.sqrt(count).astype(b.dtype)
count = (~jnp.isnan(b)).sum(axis, keepdims=keepdims)
return jnp.nanstd(b, axis, ddof=ddof, keepdims=keepdims) / jnp.sqrt(count).astype(b.dtype)
else:
raise ValueError(f"{nan_policy} is not supported")
12 changes: 8 additions & 4 deletions tests/scipy_stats_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1633,21 +1633,25 @@ def testRankData(self, shape, dtype, axis, method):
self._CompileAndCheck(lax_fun, args_maker, rtol=tol)

@jtu.sample_product(
[dict(shape=shape, axis=axis, ddof=ddof, nan_policy=nan_policy)
[dict(shape=shape, axis=axis, ddof=ddof, nan_policy=nan_policy, keepdims=keepdims)
for shape in [(5,), (5, 6), (5, 6, 7)]
for axis in [None, *range(len(shape))]
for ddof in [0, 1, 2, 3]
for nan_policy in ["propagate", "omit"]
for keepdims in [True, False]
],
dtype=jtu.dtypes.integer + jtu.dtypes.floating,
)
def testSEM(self, shape, dtype, axis, ddof, nan_policy):
def testSEM(self, shape, dtype, axis, ddof, nan_policy, keepdims):

rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]

scipy_fun = partial(osp_stats.sem, axis=axis, ddof=ddof, nan_policy=nan_policy)
lax_fun = partial(lsp_stats.sem, axis=axis, ddof=ddof, nan_policy=nan_policy)
kwds = {} if scipy_version < (1, 11) else {'keepdims': keepdims}
scipy_fun = partial(osp_stats.sem, axis=axis, ddof=ddof, nan_policy=nan_policy,
**kwds)
lax_fun = partial(lsp_stats.sem, axis=axis, ddof=ddof, nan_policy=nan_policy,
**kwds)
tol_spec = {np.float32: 2e-4, np.float64: 5e-6}
tol = jtu.tolerance(dtype, tol_spec)
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
Expand Down

0 comments on commit 56e8fe6

Please sign in to comment.