From 390e90361ae87e23e175a62a3bd526ba0114b164 Mon Sep 17 00:00:00 2001 From: ilay menahem <110230889+IlayMenahem@users.noreply.github.com> Date: Sat, 13 Jan 2024 13:01:16 +0000 Subject: [PATCH] Add .hypothesis/ directory to .gitignore and ppf and cdf to scipy.stats.uniform --- .gitignore | 1 + docs/jax.scipy.rst | 2 ++ jax/_src/scipy/stats/uniform.py | 19 +++++++++++++++++++ jax/scipy/stats/uniform.py | 2 ++ tests/scipy_stats_test.py | 31 +++++++++++++++++++++++++++++++ 5 files changed, 55 insertions(+) diff --git a/.gitignore b/.gitignore index 113e34fac244..83f1780df946 100644 --- a/.gitignore +++ b/.gitignore @@ -21,6 +21,7 @@ .envrc jax.iml .bazelrc.user +.hypothesis/ # virtualenv/venv directories /venv/ diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst index 6d10f5071532..358254a99e08 100644 --- a/docs/jax.scipy.rst +++ b/docs/jax.scipy.rst @@ -425,6 +425,8 @@ jax.scipy.stats.uniform logpdf pdf + cdf + ppf jax.scipy.stats.gaussian_kde ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/jax/_src/scipy/stats/uniform.py b/jax/_src/scipy/stats/uniform.py index 2d0c75778081..7fae1408c309 100644 --- a/jax/_src/scipy/stats/uniform.py +++ b/jax/_src/scipy/stats/uniform.py @@ -16,6 +16,7 @@ import scipy.stats as osp_stats from jax import lax +from jax import numpy as jnp from jax.numpy import where, inf, logical_or from jax._src.typing import Array, ArrayLike from jax._src.numpy.util import _wraps, promote_args_inexact @@ -32,3 +33,21 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: @_wraps(osp_stats.uniform.pdf, update_doc=False) def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.exp(logpdf(x, loc, scale)) + +@_wraps(osp_stats.uniform.cdf, update_doc=False) +def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + x, loc, scale = promote_args_inexact("uniform.cdf", x, loc, scale) + zero, one = jnp.array(0, x.dtype), jnp.array(1, x.dtype) + conds = [lax.lt(x, loc), lax.gt(x, lax.add(loc, scale)), lax.ge(x, loc) & lax.le(x, lax.add(loc, scale))] + vals = [zero, one, lax.div(lax.sub(x, loc), scale)] + + return jnp.select(conds, vals) + +@_wraps(osp_stats.uniform.ppf, update_doc=False) +def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + q, loc, scale = promote_args_inexact("uniform.ppf", q, loc, scale) + return where( + jnp.isnan(q) | (q < 0) | (q > 1), + jnp.nan, + lax.add(loc, lax.mul(scale, q)) + ) diff --git a/jax/scipy/stats/uniform.py b/jax/scipy/stats/uniform.py index c485034d8938..d0a06c673b3c 100644 --- a/jax/scipy/stats/uniform.py +++ b/jax/scipy/stats/uniform.py @@ -18,4 +18,6 @@ from jax._src.scipy.stats.uniform import ( logpdf as logpdf, pdf as pdf, + cdf as cdf, + ppf as ppf, ) diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index a8e51627dccb..1551dd54660f 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -1043,6 +1043,36 @@ def args_maker(): tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) + @genNamedParametersNArgs(3) + def testUniformCdf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.uniform.cdf + lax_fun = lsp_stats.uniform.cdf + + def args_maker(): + x, loc, scale = map(rng, shapes, dtypes) + return [x, loc, np.abs(scale)] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-5) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(3) + def testUniformPpf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.uniform.ppf + lax_fun = lsp_stats.uniform.ppf + + def args_maker(): + q, loc, scale = map(rng, shapes, dtypes) + return [q, loc, np.abs(scale)] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-5) + self._CompileAndCheck(lax_fun, args_maker) + @genNamedParametersNArgs(4) def testChi2LogPdf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng()) @@ -1058,6 +1088,7 @@ def args_maker(): tol=5e-4) self._CompileAndCheck(lax_fun, args_maker) + @genNamedParametersNArgs(4) def testChi2LogCdf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng())