Skip to content

Commit

Permalink
Merge pull request jax-ml#20923 from pearu:pearu/asinh-2
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 629130313
  • Loading branch information
jax authors committed Apr 29, 2024
2 parents 1b5c49e + ee5c134 commit b44e9bf
Showing 1 changed file with 27 additions and 10 deletions.
37 changes: 27 additions & 10 deletions jax/_src/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1796,12 +1796,12 @@ def exp2(self, x):
def arcsin(self, x):
ctx = x.context
if isinstance(x, ctx.mpc):
# Workaround mpmath 1.3 bug in asin(+-inf+-infj) evaluation (see mpmath/mpmath#793).
# TODO(pearu): remove this function when mpmath 1.4 or newer
# will be the required test dependency.
# Workaround mpmath 1.3 bug in asin(+-inf+-infj) evaluation (see
# mpmath/mpmath#793).
# TODO(pearu): remove the if-block below when mpmath 1.4 or
# newer will be the required test dependency.
pi = ctx.pi
inf = ctx.inf
nan = ctx.nan
zero = ctx.zero
if ctx.isinf(x.real):
sign_real = -1 if x.real < 0 else 1
Expand All @@ -1811,9 +1811,10 @@ def arcsin(self, x):
elif ctx.isinf(x.imag):
return ctx.make_mpc((zero._mpf_, x.imag._mpf_))

# TODO(pearu): adjust this code according to mpmath/mpmath#786
# resolution when mpmath 1.4 or newer will be the required test
# dependency.
# On branch cut, mpmath.mp.asin returns different value compared
# to mpmath.fp.asin and numpy.arcsin (see
# mpmath/mpmath#786). The following if-block ensures
# compatibiliy with numpy.arcsin.
if x.real > 1 and x.imag == 0:
return ctx.asin(x).conjugate()

Expand All @@ -1822,10 +1823,26 @@ def arcsin(self, x):
def arcsinh(self, x):
ctx = x.context

# TODO(pearu): adjust this code according to mpmath/mpmath#786
# resolution when mpmath 1.4 or newer will be the required test
# dependency.
if isinstance(x, ctx.mpc):
# Workaround mpmath 1.3 bug in asinh(+-inf+-infj) evaluation
# (see mpmath/mpmath#749).
# TODO(pearu): remove the if-block below when mpmath 1.4 or
# newer will be the required test dependency.
pi = ctx.pi
inf = ctx.inf
zero = ctx.zero
if ctx.isinf(x.imag):
sign_imag = -1 if x.imag < 0 else 1
real = -inf if x.real < 0 else inf
imag = sign_imag * pi / (4 if ctx.isinf(x.real) else 2)
return ctx.make_mpc((real._mpf_, imag._mpf_))
elif ctx.isinf(x.real):
return ctx.make_mpc((x.real._mpf_, zero._mpf_))

# On branch cut, mpmath.mp.asinh returns different value
# compared to mpmath.fp.asinh and numpy.arcsinh (see
# mpmath/mpmath#786). The following if-block ensures
# compatibiliy with numpy.arcsinh.
if x.real == 0 and x.imag < -1:
return (-ctx.asinh(x)).conjugate()
return ctx.asinh(x)
Expand Down

0 comments on commit b44e9bf

Please sign in to comment.