Commit

Merge pull request #78 from epignatelli/docs-bench
Docs bench
epignatelli authored Jun 26, 2024
2 parents be9fba0 + 49f5798 commit 20e53e9
Showing 9 changed files with 112 additions and 40 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -7,7 +7,7 @@
![PyPI version](https://img.shields.io/pypi/v/navix?label=PyPI&color=%230099ab)
<!-- [![arXiv](https://img.shields.io/badge/arXiv-1234.56789-b31b1b.svg?style=flat)](https://arxiv.org/abs/1234.56789) -->

**[Quickstart](#what-is-navix)** | **[Install](#installation)** | **[Examples](#examples)** | **[Docs](https://navix.readthedocs.io)** | **[The JAX ecosystem](#jax-ecosystem-for-rl)** | **[Contribute](#join-us)** | **[Cite](#cite)**
**[Quickstart](#what-is-navix)** | **[Install](#installation)** | **[Examples](#examples)** | **[Docs](https://epignatelli.com/navix)** | **[The JAX ecosystem](#jax-ecosystem-for-rl)** | **[Contribute](#join-us)** | **[Cite](#cite)**

</div>

7 changes: 2 additions & 5 deletions docs/assets/stylesheets/extra.css
@@ -18,14 +18,11 @@
--md-text-font: "Sherpa";
} */

.md-header__button.md-logo {
margin-bottom: 0;
padding-bottom: 0;
}

.md-header__button.md-logo img,
.md-header__button.md-logo svg {
height: 3rem !important;
margin-bottom: 0em;
padding-bottom: 0;
}

.no-bottom-margin {
24 changes: 20 additions & 4 deletions docs/benchmarks/envs.ipynb
@@ -11,13 +11,27 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"In this notebook, we will look at the difference in **throughput** between MiniGrid and NAVIX environments."
"In this notebook, we will look at the difference in **THROUGHPUT** between MiniGrid and NAVIX environments: how the performance scales with the number of environments.\n",
"We will still use random actions, but the environments now run in batch mode."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### TL;DR;\n",
"\n",
"NAVIX can scale up to over **$2M$ environments** in parallel in **less than $10$s**, less than the time required by MiniGrid to run a single environment.\n",
"\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Benchmarking MiniGrid\n",
"\n",
"Let's start with MiniGrid."
]
},
@@ -110,6 +124,9 @@
"source": [
"It scales quite linearly with the number of steps. That's a reasonable time for a single environment that runs on the CPU.\n",
"\n",
"\n",
"#### Benchmarking NAVIX\n",
"\n",
"Let's see how NAVIX compares to MiniGrid with a single environment."
]
},
@@ -139,14 +156,13 @@
"navix_times = {}\n",
"i = 2\n",
"while True:\n",
" if i >= 524288: # \n",
" break\n",
" try:\n",
" seeds = jax.random.split(jax.random.PRNGKey(0), i)\n",
" benchmark_navix_jit = benchmark_navix.lower(seeds).compile()\n",
" navix_times[i] = timeit.timeit(lambda: benchmark_navix_jit(seeds), number=1)\n",
" i *= 2\n",
" except:\n",
" print(\"Max num_envs reached\", i)\n",
" break"
]
},
@@ -191,7 +207,7 @@
"source": [
"## Conclusions\n",
"\n",
"NAVIX can scale up to $2^21 = 2097152$ environments (over $2M$ environments in parallel!) and still do that in less than the time required by MiniGrid to run a single environment."
"NAVIX can scale up to $2^{21} = 2097152$ environments (over $2M$ environments in parallel!) and still do that in less than the time required by MiniGrid to run a single environment."
]
}
],
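The benchmark loop in the hunk above doubles the number of environments until the backend raises an error. A minimal sketch of that sweep pattern in plain Python follows; `sweep_batch_sizes` and `fake_run` are hypothetical names introduced for illustration, and the memory limit of 64 is an arbitrary stand-in for a real out-of-memory failure, not a NAVIX result.

```python
import timeit


def sweep_batch_sizes(run, start=2, max_size=None):
    """Time `run(n)` for n = start, 2*start, 4*start, ... until it fails.

    `run` is a hypothetical callable that executes a batch of n environments.
    Returns a dict mapping each batch size that succeeded to its wall-clock time.
    """
    times = {}
    n = start
    while max_size is None or n <= max_size:
        try:
            times[n] = timeit.timeit(lambda: run(n), number=1)
        except Exception:  # e.g. an out-of-memory error from the backend
            print("Max num_envs reached", n)
            break
        n *= 2
    return times


def fake_run(n):
    # Stand-in workload: pretend batches above 64 environments do not fit in memory.
    if n > 64:
        raise MemoryError(f"cannot allocate a batch of {n}")
    return sum(range(n))


times = sweep_batch_sizes(fake_run)
print(sorted(times))
```

Catching `Exception` rather than using a bare `except:` keeps `KeyboardInterrupt` working, so a long sweep can still be stopped by hand.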
41 changes: 24 additions & 17 deletions docs/benchmarks/timesteps.ipynb
@@ -24,6 +24,19 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"#### TL;DR;\n",
"\n",
"NAVIX is **at least** $10\\times$ faster than MiniGrid. The runtime of NAVIX stays roughly constant as the number of parallel environments grows up to $10^4$, after which it grows at least linearly, probably due to memory saturation.\n",
"\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Benchmarking MiniGrid\n",
"\n",
"Let's start with MiniGrid."
]
},
@@ -111,6 +124,8 @@
"source": [
"It scales quite linearly with the number of steps. That's a reasonable time for a single environment that runs on the CPU.\n",
"\n",
"#### Benchmarking NAVIX\n",
"\n",
"Let's see how NAVIX compares to MiniGrid with a single environment."
]
},
@@ -127,25 +142,22 @@
"\n",
"@functools.partial(jax.jit, static_argnums=1)\n",
"def benchmark_navix(seed, num_steps):\n",
" def run(seed):\n",
" env = nx.make('Navix-Empty-8x8-v0') # Create the environment\n",
" key = jax.random.PRNGKey(seed)\n",
" timestep = env.reset(key)\n",
" actions = jax.random.randint(key, (num_steps,), 0, env.action_space.n)\n",
" env = nx.make('Navix-Empty-8x8-v0') # Create the environment\n",
" key = jax.random.PRNGKey(seed)\n",
" timestep = env.reset(key)\n",
" actions = jax.random.randint(key, (num_steps,), 0, env.action_space.n)\n",
"\n",
" def body_fun(timestep, action):\n",
" timestep = env.step(timestep, action) # Update the environment state\n",
" return timestep, ()\n",
" def body_fun(timestep, action):\n",
" timestep = env.step(timestep, action) # Update the environment state\n",
" return timestep, ()\n",
"\n",
" return jax.lax.scan(body_fun, timestep, actions, unroll=10)[0]\n",
" return jax.lax.scan(body_fun, timestep, actions, unroll=10)[0]\n",
"\n",
" final_timestep = jax.jit(jax.vmap(run))(seed)\n",
" return final_timestep\n",
"\n",
"seed = jax.random.PRNGKey(0)\n",
"navix_times = {}\n",
"for i in num_steps_set:\n",
" benchmark_navix_jit = benchmark_navix.lower(seed, i).compile()\n",
" benchmark_navix_jit = benchmark_navix.lower(seed, i).compile() # AOT compilation\n",
" navix_times[i] = timeit.timeit(lambda: benchmark_navix_jit(seed), number=1)\n"
]
},
@@ -194,11 +206,6 @@
"\n",
"This is not all. NAVIX is designed to scale with the number of parallel environments. Let's see how it performs in the batched case with the next benchmark."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
}
],
"metadata": {
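The timing code in this notebook uses ahead-of-time (AOT) compilation, `.lower(...).compile()`, so that tracing and compilation stay outside the timed region. A minimal sketch of the same pattern is given below, assuming only JAX is installed. The `rollout` function is a hypothetical toy random walk, not a NAVIX environment, and the call to `block_until_ready()` is an addition that makes the timer wait for JAX's asynchronous dispatch to finish.

```python
import functools
import timeit

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnums=1)
def rollout(key, num_steps):
    # Toy stand-in for an environment: a 1-D random walk driven by random actions.
    actions = jax.random.randint(key, (num_steps,), 0, 3)  # values in {0, 1, 2}

    def body_fun(position, action):
        return position + action - 1, ()  # move by -1, 0, or +1

    return jax.lax.scan(body_fun, jnp.int32(0), actions)[0]


key = jax.random.PRNGKey(0)
num_steps = 1000

# Ahead-of-time compilation: trace and compile once, outside the timed region.
compiled = rollout.lower(key, num_steps).compile()

# The static argument is baked into the executable, so only `key` is passed.
elapsed = timeit.timeit(lambda: compiled(key).block_until_ready(), number=1)
final_position = int(compiled(key))
```

Because `num_steps` is a static argument, each value needs its own compiled executable, which is why the notebook recompiles inside the loop over `num_steps_set`.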
8 changes: 4 additions & 4 deletions docs/examples/getting_started.ipynb
@@ -7,10 +7,10 @@
"# NAVIX 101\n",
"\n",
"This tutorial will guide you through the basics of using NAVIX. You will learn:\n",
"- How to create a `navix.Environment`,\n",
"- A vanilla, suboptimal interaction with it\n",
"- How to `jax.jit` compile the environment for faster execution\n",
"- How to run batched simulations"
"- [How to create a `navix.Environment`](#creating-an-environment),\n",
"- [A vanilla, suboptimal interaction with it](#the-environment-interface),\n",
"- [How to `jax.jit` compile the environment for faster execution](#optimizing-with-jax)\n",
"- [How to run batched simulations](#batched-environments)"
]
},
{
56 changes: 56 additions & 0 deletions docs/home/environments.md
@@ -0,0 +1,56 @@
# Supported environments

NAVIX is designed to be a drop-in replacement for the official MiniGrid environments.
You can reuse your existing code and scripts with NAVIX with little to no modification.

You can find the original MiniGrid environments in the [MiniGrid documentation](https://minigrid.huggingface.co/docs/).
For more details on MiniGrid, have a look also at the [original publication](https://arxiv.org/pdf/2306.13831).

The following table lists the supported MiniGrid environments and their corresponding NAVIX environments.
If you cannot find the environment you are looking for, please consider [opening a feature request](https://github.com/epignatelli/navix/issues/new?assignees=&labels=enhancement&template=feature_request.md&title=) on GitHub.

| MiniGrid ID | NAVIX ID | Description |
| -------------------------------------------- | ------------------------------------------------------------------------------------- | ------------------------------------------------------------------------ |
| `MiniGrid-Empty-5x5-v0` | [`Navix-Empty-5x5-v0`](../api/environments/empty.md) | Empty 5x5 grid |
| `MiniGrid-Empty-6x6-v0` | [`Navix-Empty-6x6-v0`](../api/environments/empty.md) | Empty 6x6 grid |
| `MiniGrid-Empty-8x8-v0` | [`Navix-Empty-8x8-v0`](../api/environments/empty.md) | Empty 8x8 grid |
| `MiniGrid-Empty-16x16-v0` | [`Navix-Empty-16x16-v0`](../api/environments/empty.md) | Empty 16x16 grid |
| `MiniGrid-Empty-Random-5x5-v0` | [`Navix-Empty-Random-5x5-v0`](../api/environments/empty.md) | Empty 5x5 grid with random starts |
| `MiniGrid-Empty-Random-6x6-v0` | [`Navix-Empty-Random-6x6-v0`](../api/environments/empty.md) | Empty 6x6 grid with random starts |
| `MiniGrid-Empty-Random-8x8-v0` | [`Navix-Empty-Random-8x8-v0`](../api/environments/empty.md) | Empty 8x8 grid with random starts |
| `MiniGrid-Empty-Random-16x16-v0` | [`Navix-Empty-Random-16x16-v0`](../api/environments/empty.md) | Empty 16x16 grid with random starts |
| `MiniGrid-FourRooms-v0` | [`Navix-FourRooms-v0`](../api/environments/four_rooms.md) | Four rooms |
| `MiniGrid-DoorKey-5x5-v0` | [`Navix-DoorKey-5x5-v0`](../api/environments/door_key.md) | 5x5 grid with a key and a door |
| `MiniGrid-DoorKey-6x6-v0` | [`Navix-DoorKey-6x6-v0`](../api/environments/door_key.md) | 6x6 grid with a key and a door |
| `MiniGrid-DoorKey-8x8-v0` | [`Navix-DoorKey-8x8-v0`](../api/environments/door_key.md) | 8x8 grid with a key and a door |
| `MiniGrid-DoorKey-16x16-v0` | [`Navix-DoorKey-16x16-v0`](../api/environments/door_key.md) | 16x16 grid with a key and a door |
| `MiniGrid-DoorKey-5x5-Random-v0` | [`Navix-DoorKey-5x5-Random-v0`](../api/environments/door_key.md) | 5x5 grid with a key and a door |
| `MiniGrid-DoorKey-6x6-Random-v0` | [`Navix-DoorKey-6x6-Random-v0`](../api/environments/door_key.md) | 6x6 grid with a key and a door |
| `MiniGrid-DoorKey-8x8-Random-v0` | [`Navix-DoorKey-8x8-Random-v0`](../api/environments/door_key.md) | 8x8 grid with a key and a door |
| `MiniGrid-DoorKey-16x16-Random-v0` | [`Navix-DoorKey-16x16-Random-v0`](../api/environments/door_key.md) | 16x16 grid with a key and a door |
| `MiniGrid-KeyCorridorS3R1-v0` | [`Navix-KeyCorridorS3R1-v0`](../api/environments/key_corridor.md) | Corridor with a key 3 cells away |
| `MiniGrid-KeyCorridorS3R2-v0` | [`Navix-KeyCorridorS3R2-v0`](../api/environments/key_corridor.md) | Corridor with a key 3 cells away |
| `MiniGrid-KeyCorridorS3R3-v0` | [`Navix-KeyCorridorS3R3-v0`](../api/environments/key_corridor.md) | Corridor with a key 3 cells away |
| `MiniGrid-KeyCorridorS4R3-v0` | [`Navix-KeyCorridorS4R3-v0`](../api/environments/key_corridor.md) | Corridor with a key 4 cells away |
| `MiniGrid-KeyCorridorS5R3-v0` | [`Navix-KeyCorridorS5R3-v0`](../api/environments/key_corridor.md) | Corridor with a key 5 cells away |
| `MiniGrid-KeyCorridorS6R3-v0` | [`Navix-KeyCorridorS6R3-v0`](../api/environments/key_corridor.md) | Corridor with a key 6 cells away |
| `MiniGrid-Crossings-S9N1-v0` | [`Navix-Crossings-S9N1-v0`](../api/environments/crossings.md) | A 9x9 room with 1 wall crossing it |
| `MiniGrid-Crossings-S9N2-v0` | [`Navix-Crossings-S9N2-v0`](../api/environments/crossings.md) | A 9x9 room with 2 walls crossing it |
| `MiniGrid-Crossings-S9N3-v0` | [`Navix-Crossings-S9N3-v0`](../api/environments/crossings.md) | A 9x9 room with 3 walls crossing it |
| `MiniGrid-Crossings-S11N5-v0` | [`Navix-Crossings-S11N5-v0`](../api/environments/crossings.md) | An 11x11 room with 5 walls crossing it |
| `MiniGrid-DistShift1-v0` | [`Navix-DistShift1-v0`](../api/environments/dist_shift.md) | DistShift with 1 goal |
| `MiniGrid-DistShift2-v0` | [`Navix-DistShift2-v0`](../api/environments/dist_shift.md) | DistShift with 2 goals |
| `MiniGrid-LavaGap-S5-v0` | [`Navix-LavaGap-S5-v0`](../api/environments/lava_gap.md) | LavaGap in a 5x5 room |
| `MiniGrid-LavaGap-S6-v0` | [`Navix-LavaGap-S6-v0`](../api/environments/lava_gap.md) | LavaGap in a 6x6 room |
| `MiniGrid-LavaGap-S7-v0` | [`Navix-LavaGap-S7-v0`](../api/environments/lava_gap.md) | LavaGap in a 7x7 room |
| `MiniGrid-GoToDoor-5x5-v0` | [`Navix-GoToDoor-5x5-v0`](../api/environments/go_to_door.md) | 5x5 grid that terminates with a `done` action next to a certain door |
| `MiniGrid-GoToDoor-6x6-v0` | [`Navix-GoToDoor-6x6-v0`](../api/environments/go_to_door.md) | 6x6 grid that terminates with a `done` action next to a certain door |
| `MiniGrid-GoToDoor-8x8-v0` | [`Navix-GoToDoor-8x8-v0`](../api/environments/go_to_door.md) | 8x8 grid that terminates with a `done` action next to a certain door |
| `MiniGrid-Dynamic-Obstacles-5x5-v0` | [`Navix-Dynamic-Obstacles-5x5-v0`](../api/environments/dynamic_obstacles.md) | 5x5 grid with dynamic obstacles |
| `MiniGrid-Dynamic-Obstacles-6x6-v0` | [`Navix-Dynamic-Obstacles-6x6-v0`](../api/environments/dynamic_obstacles.md) | 6x6 grid with dynamic obstacles |
| `MiniGrid-Dynamic-Obstacles-8x8-v0` | [`Navix-Dynamic-Obstacles-8x8-v0`](../api/environments/dynamic_obstacles.md) | 8x8 grid with dynamic obstacles |
| `MiniGrid-Dynamic-Obstacles-16x16-v0` | [`Navix-Dynamic-Obstacles-16x16-v0`](../api/environments/dynamic_obstacles.md) | 16x16 grid with dynamic obstacles |
| `MiniGrid-Dynamic-Obstacles-Random-5x5-v0` | [`Navix-Dynamic-Obstacles-Random-5x5-v0`](../api/environments/dynamic_obstacles.md) | 5x5 grid with dynamic obstacles and random starts |
| `MiniGrid-Dynamic-Obstacles-Random-6x6-v0` | [`Navix-Dynamic-Obstacles-Random-6x6-v0`](../api/environments/dynamic_obstacles.md) | 6x6 grid with dynamic obstacles and random starts |
| `MiniGrid-Dynamic-Obstacles-Random-8x8-v0` | [`Navix-Dynamic-Obstacles-Random-8x8-v0`](../api/environments/dynamic_obstacles.md) | 8x8 grid with dynamic obstacles and random starts |
| `MiniGrid-Dynamic-Obstacles-Random-16x16-v0` | [`Navix-Dynamic-Obstacles-Random-16x16-v0`](../api/environments/dynamic_obstacles.md) | 16x16 grid with dynamic obstacles and random starts |
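
Every row in the table above pairs a MiniGrid ID with a NAVIX ID that differs only in its prefix. Under that assumption, existing scripts could translate IDs with a small helper like the sketch below. `minigrid_to_navix` is a hypothetical function written for illustration and is not part of the NAVIX API; it does not check that the resulting ID is actually supported.

```python
def minigrid_to_navix(env_id: str) -> str:
    """Translate a MiniGrid ID into the matching NAVIX ID by swapping the prefix."""
    prefix = "MiniGrid-"
    if not env_id.startswith(prefix):
        raise ValueError(f"not a MiniGrid ID: {env_id!r}")
    return "Navix-" + env_id[len(prefix):]


print(minigrid_to_navix("MiniGrid-DoorKey-8x8-v0"))  # Navix-DoorKey-8x8-v0
```

IDs missing from the table would still be translated by this helper, so it is worth checking the result against the list of supported environments before use.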
4 changes: 2 additions & 2 deletions docs/index.md
@@ -1,4 +1,4 @@
<p class="maiusc" style="margin-bottom: 0.5em;">A <b>fast</b>, fully <b>jittable</b> MiniGrid reimplemented in JAX</p>
<p class="maiusc" style="margin-bottom: 0.8em;">A <b>fast</b>, fully <b>jittable, batched MiniGrid</b> reimplemented in JAX for HIGH <b>THROUGHPUT</b></p>
<h1>Welcome to <b>NAVIX</b>!</h1>


@@ -7,7 +7,7 @@
NAVIX is designed to be a drop-in replacement for the original MiniGrid environment, with the added benefit of being significantly faster.
Experiments that took **1 week**, now take **15 minutes**.

A `navix.Environment` is a `flax.struct.PyTreeNode` and supports `jax.vmap`, `jax.jit`, `jax.grad`, and all the other JAX's transformations.
A [`navix.Environment`](api/environments/environment.md) is a [`flax.struct.PyTreeNode`](https://flax.readthedocs.io/en/latest/api_reference/flax.struct.html#flax.struct.PyTreeNode) and supports [`jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html), [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html), [`jax.grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html), and all other JAX transformations.
See some examples [here](examples/getting_started.ipynb).

<br>