What considerations prevent using float32? For now we have only this statement:
jaxfit/src/jaxfit/__init__.py
Lines 9 to 11 in e6b28d4
Trying to reproduce the claim in this note, I see what is probably a negligible numerical error:
```python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax.scipy import special


def poisson_logpdf(x, mu):
    return special.xlogy(x, mu) - mu - special.gammaln(x + 1)


@jax.custom_jvp
def poisson_logpdf_fancy(x, mu):
    """Standard log(P) for Poisson

    Implemented with custom jvp to improve accuracy
    when x and mu are both very large?
    """
    return special.xlogy(x, mu) - mu - special.gammaln(x + 1)


@poisson_logpdf_fancy.defjvp
def _poisson_logpdf_jvp(primals, tangents):
    x, mu = primals
    dx, dmu = tangents
    df_dx = jnp.log(mu) - jax.grad(special.gammaln)(x + 1)
    df_dmu = (x - mu) / mu
    return poisson_logpdf(x, mu), df_dx * dx + df_dmu * dmu


# Difference in d/dmu between the custom-JVP version and plain autodiff.
@jax.vmap
def compare(x, mu):
    return jax.grad(poisson_logpdf_fancy, 1)(x, mu) - jax.grad(poisson_logpdf, 1)(x, mu)


x = jnp.geomspace(1, 1e8, 20, dtype=jnp.float32)
mu = x + 1

fig, ax = plt.subplots()
ax.plot(compare(x, mu), label="Gradient diff")
ax.plot(x / mu - 1, label=r"$x/\mu-1$")
ax.set_ylim(-1e-7, 1e-7)
ax.legend()
```
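One detail that may matter when reading the plot: float32 has a 24-bit significand, so `mu = x + 1` can no longer be distinguished from `x` once `x` reaches 2^24 (about 1.7e7), which the upper end of the `geomspace` above exceeds. The sketch below illustrates this with the standard library only. The helpers `to_f32` and `dlogp_dmu_f32` are hypothetical names introduced here, and rounding through `struct` is an assumed stand-in for what a float32 array does, not a call into jaxfit or JAX.

```python
import struct


def to_f32(value: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", value))[0]


def dlogp_dmu_f32(x: float) -> float:
    """Evaluate (x - mu) / mu with mu = x + 1, rounding every step to float32."""
    x = to_f32(x)
    mu = to_f32(x + 1.0)   # float32 addition
    diff = to_f32(x - mu)  # float32 subtraction
    return to_f32(diff / mu)


for x in (1e4, 2.0**24, 1e8):
    mu = to_f32(to_f32(x) + 1.0)
    # Exact value of (x - mu) / mu for mu = x + 1 is -1 / (x + 1).
    print(f"x={x:.3e}  mu == x: {mu == to_f32(x)}  "
          f"float32: {dlogp_dmu_f32(x):.3e}  exact: {-1.0 / (x + 1.0):.3e}")
```

For `x = 1e8` the float32 result is exactly zero while the exact value is about `-1e-8`, so an absolute discrepancy of that size at the right edge of the plot would be expected from the inputs alone, independent of which gradient implementation is used.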