Skip to content

Commit

Permalink
Update to remove runtime warning, minor restructure to be consistent …
Browse files Browse the repository at this point in the history
…with PyTorch export updates
  • Loading branch information
pavithraes committed Nov 4, 2024
1 parent b2016a3 commit 46adff2
Showing 1 changed file with 8 additions and 37 deletions.
45 changes: 8 additions & 37 deletions docs/tutorials/jax-export.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
"\n",
"### Install required dependencies\n",
"\n",
"We'll be using `jax` and `jaxlib` (JAX's support library with compiled binaries), along with `flax` and `transformers` for some models to export.\n",
"We use `jax` and `jaxlib` (JAX's support library with compiled binaries), along with `flax` and `transformers` for some models to export.\n",
"We also need to install `tensorflow` to work with SavedModel, and recommend using `tensorflow-cpu` or `tf-nightly` for this tutorial.\n",
"\n",
"[jax-tutorial-colab]: https://colab.research.google.com/github/openxla/stablehlo/blob/main/docs/tutorials/jax-export.ipynb\n",
"[jax-tutorial-kaggle]: https://kaggle.com/kernels/welcome?src=https://github.com/openxla/stablehlo/blob/main/docs/tutorials/jax-export.ipynb"
Expand All @@ -35,7 +36,7 @@
},
"outputs": [],
"source": [
"!pip install -U jax jaxlib flax transformers"
"!pip install -U jax jaxlib flax transformers tensorflow-cpu"
]
},
{
Expand Down Expand Up @@ -318,32 +319,6 @@
"JAX makes it easy to pack StableHLO into a SavedModel, and load that SavedModel in the future. For this section, we'll be using our dynamic model from the previous section."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Gt_bIJaSpsYf"
},
"source": [
"### Install latest TensorFlow\n",
"\n",
"SavedModel definition lives in TensorFlow, so we need to install the dependency. We recommend using `tensorflow-cpu` or `tf-nightly`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"id": "KZEd7NavBoem",
"jupyter": {
"outputs_hidden": true
}
},
"outputs": [],
"source": [
"!pip install tensorflow-cpu"
]
},
{
"cell_type": "markdown",
"metadata": {
Expand Down Expand Up @@ -374,14 +349,6 @@
"text": [
"\u001b[34massets\u001b[m\u001b[m fingerprint.pb saved_model.pb \u001b[34mvariables\u001b[m\u001b[m\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/pavithraes/mambaforge/envs/shlo-docs/lib/python3.11/pty.py:89: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
" pid, fd = os.forkpty()\n"
]
}
],
"source": [
Expand Down Expand Up @@ -442,12 +409,16 @@
"id": "a2Dsm2oF5jn4"
},
"source": [
"## Common Troubleshooting\n",
"## Troubleshooting\n",
"\n",
"### `jax.jit` issues\n",
"\n",
"If the function can be JIT'ed, then it can be exported. Ensure `jax.jit` works first, or look in desired project for uses of JIT already (for example, [AlphaFold's `apply`](https://github.com/google-deepmind/alphafold/blob/dbe2a438ebfc6289f960292f15dbf421a05e563d/alphafold/model/model.py#L89) can be exported easily). \n",
"\n",
"See [JAX's JIT compilation documentation](https://jax.readthedocs.io/en/latest/jit-compilation.html) and [`jax.jit` API reference and examples](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) for troubleshooting JIT transformations. The most common issue is control flow, which can often be resolved with `static_argnums` / `static_argnames` as in the linked example.\n",
"\n",
"### Support tickets\n",
"\n",
"You can open an issue on GitHub for further help. Include a reproducible example using one of the above APIs in your issue report, this will help get the issue resolved much quicker!"
]
}
Expand Down

0 comments on commit 46adff2

Please sign in to comment.