Update references to JAX's GitHub repo
JAX has moved from https://github.com/google/jax to https://github.com/jax-ml/jax

PiperOrigin-RevId: 703058838
jakeharmon8 authored and copybara-github committed Dec 5, 2024
1 parent 72c6db4 commit 19a2ed6
Showing 7 changed files with 8 additions and 8 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -230,7 +230,7 @@ Haiku is written in pure Python, but depends on C++ code via JAX.
 Because JAX installation is different depending on your CUDA version, Haiku does
 not list JAX as a dependency in `requirements.txt`.
 
-First, follow [these instructions](https://github.com/google/jax#installation)
+First, follow [these instructions](https://github.com/jax-ml/jax#installation)
 to install JAX with the relevant accelerator support.
 
 Then, install Haiku using pip:
@@ -462,7 +462,7 @@ In this bibtex entry, the version number is intended to be from
 [`haiku/__init__.py`](https://github.com/deepmind/dm-haiku/blob/main/haiku/__init__.py),
 and the year corresponds to the project's open-source release.
 
-[JAX]: https://github.com/google/jax
+[JAX]: https://github.com/jax-ml/jax
 [Sonnet]: https://github.com/deepmind/sonnet
 [Tensorflow]: https://github.com/tensorflow/tensorflow
 [Flax]: https://github.com/google/flax
2 changes: 1 addition & 1 deletion docs/index.rst
@@ -26,7 +26,7 @@ abstractions for machine learning research.
 Installation
 ------------
 
-See https://github.com/google/jax#pip-installation for instructions on
+See https://github.com/jax-ml/jax#pip-installation for instructions on
 installing JAX.
 
 We suggest installing the latest version of Haiku by running::
2 changes: 1 addition & 1 deletion docs/notebooks/non_trainable.ipynb
@@ -63,7 +63,7 @@
 "output_type": "stream",
 "text": [
 "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n",
-"/tmp/haiku-docs-env/lib/python3.8/site-packages/jax/_src/lax/lax.py:6271: UserWarning: Explicitly requested dtype float64 requested in zeros is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.\n",
+"/tmp/haiku-docs-env/lib/python3.8/site-packages/jax/_src/lax/lax.py:6271: UserWarning: Explicitly requested dtype float64 requested in zeros is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.\n",
 " warnings.warn(msg.format(dtype, fun_name , truncated_dtype))\n"
 ]
 },
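
Context for the warning quoted in this hunk: JAX defaults to 32-bit values, so a requested float64 is silently truncated unless 64-bit mode is turned on. A minimal sketch of the two mechanisms the warning names (not part of this commit, shown only for reference):

```python
import jax

# Option 1: flip the config flag before creating any arrays.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

x = jnp.zeros(3, dtype=jnp.float64)
print(x.dtype)  # float64, with no truncation warning

# Option 2: leave the config untouched and launch with the environment
# variable set instead, e.g. JAX_ENABLE_X64=1 python my_script.py
```
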
2 changes: 1 addition & 1 deletion examples/haiku_lstms.ipynb
@@ -46,7 +46,7 @@
 "source": [
 "# LSTMs in Haiku\n",
 "\n",
-"**[Haiku](https://github.com/deepmind/dm-haiku) is a simple neural network library for [JAX](https://github.com/google/jax).**\n",
+"**[Haiku](https://github.com/deepmind/dm-haiku) is a simple neural network library for [JAX](https://github.com/jax-ml/jax).**\n",
 "\n",
 "This notebook walks through a simple LSTM in JAX with Haiku.\n",
 "\n",
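
As a pointer to what the notebook covers, here is a minimal sketch of an LSTM unrolled with Haiku. The layer sizes and shapes are illustrative assumptions, not taken from the notebook itself:

```python
import haiku as hk
import jax
import jax.numpy as jnp

def unroll_net(seqs: jnp.ndarray):
  """Runs an LSTM over [T, B, F] sequences, then a per-step Linear head."""
  core = hk.LSTM(32)  # hidden size is a made-up choice
  batch_size = seqs.shape[1]
  outs, state = hk.dynamic_unroll(core, seqs, core.initial_state(batch_size))
  return hk.BatchApply(hk.Linear(1))(outs), state

model = hk.transform(unroll_net)
seqs = jnp.zeros([16, 8, 1])  # T=16, B=8, F=1 (illustrative)
params = model.init(jax.random.PRNGKey(42), seqs)
preds, _ = model.apply(params, None, seqs)  # rng unused at apply time
print(preds.shape)  # (16, 8, 1)
```
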
2 changes: 1 addition & 1 deletion examples/impala_lite.py
@@ -91,7 +91,7 @@ def step(
   def loss(self, params: hk.Params, trajs: Transition) -> jax.Array:
     """Computes a loss of trajs wrt params."""
     # Re-run the agent over the trajectories.
-    # Due to https://github.com/google/jax/issues/1459, we use hk.BatchApply
+    # Due to https://github.com/jax-ml/jax/issues/1459, we use hk.BatchApply
     # instead of vmap.
     # BatchApply turns the input tensors from [T, B, ...] into [T*B, ...].
     # We `functools.partial` params in so it does not get transformed.
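
The comment in this hunk refers to `hk.BatchApply`, which merges the leading [T, B] dimensions into one, applies a module, and then restores them. A standalone sketch with made-up shapes:

```python
import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):  # x: [T, B, F]
  # hk.BatchApply reshapes [T, B, F] -> [T*B, F], applies the module,
  # then splits the leading dims back apart: output is [T, B, 4].
  return hk.BatchApply(hk.Linear(4))(x)

f = hk.without_apply_rng(hk.transform(forward))
x = jnp.ones([5, 3, 7])  # T=5, B=3, F=7
params = f.init(jax.random.PRNGKey(0), x)
print(f.apply(params, x).shape)  # (5, 3, 4)
```
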
2 changes: 1 addition & 1 deletion haiku/_src/batch_norm_test.py
@@ -174,7 +174,7 @@ def test_no_offset_beta_init_provided(self):
         offset_init=jnp.zeros)
 
   def test_eps_cast_to_var_dtype(self):
-    # See https://github.com/google/jax/issues/4718 for more info. In the
+    # See https://github.com/jax-ml/jax/issues/4718 for more info. In the
     # context of this test we need to assert NumPy bf16 params/state and a
     # Python float for eps preserve bf16 output.
 
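
The dtype behaviour this test asserts can be reproduced in isolation: under JAX's type promotion rules a Python float is weakly typed, so combining it with a bfloat16 array keeps the result in bfloat16. A small illustration, assuming default promotion semantics (this is not the test itself):

```python
import jax.numpy as jnp

var = jnp.ones([3], dtype=jnp.bfloat16)  # stand-in for bf16 moving variance
eps = 1e-5                               # Python float, weakly typed
out = jnp.sqrt(var + eps)                # eps does not upcast the result
print(out.dtype)  # bfloat16
```
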
2 changes: 1 addition & 1 deletion haiku/_src/embed.py
@@ -177,7 +177,7 @@ def __call__(
     # it along the row dimension and treat each row as a separate index into
     # one of the dimensions of the array. The error only surfaces when
     # indexing with DeviceArray, while indexing with numpy.ndarray works fine.
-    # See https://github.com/google/jax/issues/620 for more details.
+    # See https://github.com/jax-ml/jax/issues/620 for more details.
     # Cast to a jnp array in case `ids` is a tracer (e.g. in a dynamic_unroll).
     return jnp.asarray(self.embeddings)[(ids,)]

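
A minimal, self-contained version of the lookup in this hunk, with an illustrative vocabulary and embedding size: wrapping the table in `jnp.asarray` before integer indexing is what lets tracer ids index it correctly.

```python
import numpy as np
import jax.numpy as jnp

embeddings = np.random.randn(10, 4)        # vocab=10, embed_dim=4 (made up)
ids = jnp.array([1, 3, 3, 7])
vectors = jnp.asarray(embeddings)[(ids,)]  # one embedding row per id
print(vectors.shape)  # (4, 4)
```
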
