From 18ef25d4493d02e4fada76c293bde1b21fd177ba Mon Sep 17 00:00:00 2001 From: epignatelli Date: Sat, 22 Jun 2024 10:11:27 +0100 Subject: [PATCH 1/4] polish benchamrks --- docs/benchmarks/envs.ipynb | 24 +++++++++++++++---- docs/benchmarks/timesteps.ipynb | 41 +++++++++++++++++++-------------- docs/index.md | 2 +- 3 files changed, 45 insertions(+), 22 deletions(-) diff --git a/docs/benchmarks/envs.ipynb b/docs/benchmarks/envs.ipynb index b583fee..6c138be 100644 --- a/docs/benchmarks/envs.ipynb +++ b/docs/benchmarks/envs.ipynb @@ -11,13 +11,27 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "In this notebook, we will look at the difference in **throughput** between MiniGrid and NAVIX environments." + "In this notebook, we will look at the difference in **THROUGHPUT** between MiniGrid and NAVIX environments: how the performance scales with the number of environments.\n", + "We will still use random actions, but the environments now run in batch mode." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ + "#### TL;DR;\n", + "\n", + "NAVIX can scale up to over **$2M$ environments** in parallel in **less than $10$s**, less than the time required by MiniGrid to run a single environment.\n", + "\n", + "---" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Benchmarking MiniGrid\n", + "\n", "Let's start with MiniGrid." ] }, @@ -110,6 +124,9 @@ "source": [ "It scales quite linearly with the number of steps. That's a reasonable time for a single environment that runs on the CPU.\n", "\n", + "\n", + "#### Benchmarking NAVIX\n", + "\n", "Let's see how NAVIX compares to MiniGrid with a single environment." 
] }, @@ -139,14 +156,13 @@ "navix_times = {}\n", "i = 2\n", "while True:\n", - " if i >= 524288: # \n", - " break\n", " try:\n", " seeds = jax.random.split(jax.random.PRNGKey(0), i)\n", " benchmark_navix_jit = benchmark_navix.lower(seeds).compile()\n", " navix_times[i] = timeit.timeit(lambda: benchmark_navix_jit(seeds), number=1)\n", " i *= 2\n", " except:\n", + " print(\"Max num_envs reached\", i)\n", " break" ] }, @@ -191,7 +207,7 @@ "source": [ "## Conclusions\n", "\n", - "NAVIX can scale up to $2^21 = 2097152$ environments (over $2M$ environments in parallel!) and still do that in less than the time required by MiniGrid to run a single environment." + "NAVIX can scale up to $2^{21} = 2097152$ environments (over $2M$ environments in parallel!) and still do that in less than the time required by MiniGrid to run a single environment." ] } ], diff --git a/docs/benchmarks/timesteps.ipynb b/docs/benchmarks/timesteps.ipynb index 1055331..19ea445 100644 --- a/docs/benchmarks/timesteps.ipynb +++ b/docs/benchmarks/timesteps.ipynb @@ -24,6 +24,19 @@ "cell_type": "markdown", "metadata": {}, "source": [ + "#### TL;DR;\n", + "\n", + "NAVIX is **at least** $10\\times$ faster than MiniGrid. The runtime of NAVIX stays constant as the number of parallel environments grows up to $10^4$, after which it grows at least linearly, probably due to memory saturation.\n", + "\n", + "---" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Benchmarking MiniGrid\n", + "\n", "Let's start with MiniGrid." ] }, @@ -111,6 +124,8 @@ "source": [ "It scales quite linearly with the number of steps. That's a reasonable time for a single environment that runs on the CPU.\n", "\n", + "#### Benchmarking NAVIX\n", + "\n", "Let's see how NAVIX compares to MiniGrid with a single environment."
] }, @@ -127,25 +142,22 @@ "\n", "@functools.partial(jax.jit, static_argnums=1)\n", "def benchmark_navix(seed, num_steps):\n", - " def run(seed):\n", - " env = nx.make('Navix-Empty-8x8-v0') # Create the environment\n", - " key = jax.random.PRNGKey(seed)\n", - " timestep = env.reset(key)\n", - " actions = jax.random.randint(key, (num_steps,), 0, env.action_space.n)\n", + " env = nx.make('Navix-Empty-8x8-v0') # Create the environment\n", + " key = jax.random.PRNGKey(seed)\n", + " timestep = env.reset(key)\n", + " actions = jax.random.randint(key, (num_steps,), 0, env.action_space.n)\n", "\n", - " def body_fun(timestep, action):\n", - " timestep = env.step(timestep, action) # Update the environment state\n", - " return timestep, ()\n", + " def body_fun(timestep, action):\n", + " timestep = env.step(timestep, action) # Update the environment state\n", + " return timestep, ()\n", "\n", - " return jax.lax.scan(body_fun, timestep, actions, unroll=10)[0]\n", + " return jax.lax.scan(body_fun, timestep, actions, unroll=10)[0]\n", "\n", - " final_timestep = jax.jit(jax.vmap(run))(seed)\n", - " return final_timestep\n", "\n", "seed = jax.random.PRNGKey(0)\n", "navix_times = {}\n", "for i in num_steps_set:\n", - " benchmark_navix_jit = benchmark_navix.lower(seed, i).compile()\n", + " benchmark_navix_jit = benchmark_navix.lower(seed, i).compile() # AOT compilation\n", " navix_times[i] = timeit.timeit(lambda: benchmark_navix_jit(seed), number=1)\n" ] }, @@ -194,11 +206,6 @@ "\n", "This is not all. NAVIX is designed to scale with the number of parallel environments. Let's see how it performs in the batched case with the next benchmark." ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [] } ], "metadata": { diff --git a/docs/index.md b/docs/index.md index afeb140..e295d33 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,4 +1,4 @@ -

A fast, fully jittable MiniGrid reimplemented in JAX

+

A fast, fully jittable MiniGrid reimplemented in JAX for HIGH THROUGHPUT

Welcome to NAVIX!

From 251d74b1a137e332afa5b86ed21a7b4a680e6659 Mon Sep 17 00:00:00 2001 From: epignatelli Date: Wed, 26 Jun 2024 19:59:57 +0100 Subject: [PATCH 2/4] fix logo position --- README.md | 2 +- docs/assets/stylesheets/extra.css | 7 ++----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 771af4d..6c4a3bc 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ ![PyPI version](https://img.shields.io/pypi/v/navix?label=PyPI&color=%230099ab) -**[Quickstart](#what-is-navix)** | **[Install](#installation)** | **[Examples](#examples)** | **[Docs](https://navix.readthedocs.io)** | **[The JAX ecosystem](#jax-ecosystem-for-rl)** | **[Contribute](#join-us)** | **[Cite](#cite)** +**[Quickstart](#what-is-navix)** | **[Install](#installation)** | **[Examples](#examples)** | **[Docs](https://epignatelli.com/navix)** | **[The JAX ecosystem](#jax-ecosystem-for-rl)** | **[Contribute](#join-us)** | **[Cite](#cite)** diff --git a/docs/assets/stylesheets/extra.css b/docs/assets/stylesheets/extra.css index 7a19883..916abe0 100644 --- a/docs/assets/stylesheets/extra.css +++ b/docs/assets/stylesheets/extra.css @@ -18,14 +18,11 @@ --md-text-font: "Sherpa"; } */ -.md-header__button.md-logo { - margin-bottom: 0; - padding-bottom: 0; -} - .md-header__button.md-logo img, .md-header__button.md-logo svg { height: 3rem !important; + margin-bottom: 0em; + padding-bottom: 0; } .no-bottom-margin { From 6786ba0f5e73fac0cdcfe40f5b260697ede1bc46 Mon Sep 17 00:00:00 2001 From: epignatelli Date: Wed, 26 Jun 2024 20:00:19 +0100 Subject: [PATCH 3/4] fix typo --- docs/examples/getting_started.ipynb | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/examples/getting_started.ipynb b/docs/examples/getting_started.ipynb index e4298a1..e7ebd4c 100644 --- a/docs/examples/getting_started.ipynb +++ b/docs/examples/getting_started.ipynb @@ -7,10 +7,10 @@ "# NAVIX 101\n", "\n", "This tutorial will guide you through the basics of using NAVIX. 
You will learn:\n", - "- How to create a `navix.Environment`,\n", - "- A vanilla, suboptimal interaction with it\n", - "- How to `jax.jit` compile the environment for faster execution\n", - "- How to run batched simulations" + "- [How to create a `navix.Environment`](#creating-an-environment),\n", + "- [A vanilla, suboptimal interaction with it](#the-environment-interface),\n", + "- [How to `jax.jit` compile the environment for faster execution](#optimizing-with-jax)\n", + "- [How to run batched simulations](#batched-environments)" ] }, { From 49f57985fd82777faeff3a1e2baeb0b5659e0f6b Mon Sep 17 00:00:00 2001 From: epignatelli Date: Wed, 26 Jun 2024 20:00:51 +0100 Subject: [PATCH 4/4] more typos --- docs/home/environments.md | 56 +++++++++++++++++++++++++++++++++++++++ docs/index.md | 4 +-- mkdocs.yml | 8 ++---- navix/_version.py | 2 +- 4 files changed, 61 insertions(+), 9 deletions(-) diff --git a/docs/home/environments.md b/docs/home/environments.md index e69de29..abd4a08 100644 --- a/docs/home/environments.md +++ b/docs/home/environments.md @@ -0,0 +1,56 @@ +# Supported environments + +NAVIX is designed to be a drop-in replacement for the official MiniGrid environments. +You can reuse your existing code and scripts with NAVIX with little to no modification. + +You can find the original MiniGrid environments in the [MiniGrid documentation](https://minigrid.huggingface.co/docs/). +For more details on MiniGrid, have a look also at the [original publication](https://arxiv.org/pdf/2306.13831). + +The following table lists the supported MiniGrid environments and their corresponding NAVIX environments. +If you cannot find the environment you are looking for, please consider [opening a feature request](https://github.com/epignatelli/navix/issues/new?assignees=&labels=enhancement&template=feature_request.md&title=) on GitHub. 
+ +| MiniGrid ID | NAVIX ID | Description | +| -------------------------------------------- | ------------------------------------------------------------------------------------- | ------------------------------------------------------------------------ | +| `MiniGrid-Empty-5x5-v0` | [`Navix-Empty-5x5-v0`](../api/environments/empty.md) | Empty 5x5 grid | +| `MiniGrid-Empty-6x6-v0` | [`Navix-Empty-6x6-v0`](../api/environments/empty.md) | Empty 6x6 grid | +| `MiniGrid-Empty-8x8-v0` | [`Navix-Empty-8x8-v0`](../api/environments/empty.md) | Empty 8x8 grid | +| `MiniGrid-Empty-16x16-v0` | [`Navix-Empty-16x16-v0`](../api/environments/empty.md) | Empty 16x16 grid | +| `MiniGrid-Empty-Random-5x5-v0` | [`Navix-Empty-Random-5x5-v0`](../api/environments/empty.md) | Empty 5x5 grid with random starts | +| `MiniGrid-Empty-Random-6x6-v0` | [`Navix-Empty-Random-6x6-v0`](../api/environments/empty.md) | Empty 6x6 grid with random starts | +| `MiniGrid-Empty-Random-8x8-v0` | [`Navix-Empty-Random-8x8-v0`](../api/environments/empty.md) | Empty 8x8 grid with random starts | +| `MiniGrid-Empty-Random-16x16-v0` | [`Navix-Empty-Random-16x16-v0`](../api/environments/empty.md) | Empty 16x16 grid with random starts | +| `MiniGrid-FourRooms-v0` | [`Navix-FourRooms-v0`](../api/environments/four_rooms.md) | Four rooms | +| `MiniGrid-DoorKey-5x5-v0` | [`Navix-DoorKey-5x5-v0`](../api/environments/door_key.md) | 5x5 grid with a key and a door | +| `MiniGrid-DoorKey-6x6-v0` | [`Navix-DoorKey-6x6-v0`](../api/environments/door_key.md) | 6x6 grid with a key and a door | +| `MiniGrid-DoorKey-8x8-v0` | [`Navix-DoorKey-8x8-v0`](../api/environments/door_key.md) | 8x8 grid with a key and a door | +| `MiniGrid-DoorKey-16x16-v0` | [`Navix-DoorKey-16x16-v0`](../api/environments/door_key.md) | 16x16 grid with a key and a door | +| `MiniGrid-DoorKey-5x5-Random-v0` | [`Navix-DoorKey-5x5-Random-v0`](../api/environments/door_key.md) | 5x5 grid with a key and a door | +| `MiniGrid-DoorKey-6x6-Random-v0` | 
[`Navix-DoorKey-6x6-Random-v0`](../api/environments/door_key.md) | 6x6 grid with a key and a door | +| `MiniGrid-DoorKey-8x8-Random-v0` | [`Navix-DoorKey-8x8-Random-v0`](../api/environments/door_key.md) | 8x8 grid with a key and a door | +| `MiniGrid-DoorKey-16x16-Random-v0` | [`Navix-DoorKey-16x16-Random-v0`](../api/environments/door_key.md) | 16x16 grid with a key and a door | +| `MiniGrid-KeyCorridorS3R1-v0` | [`Navix-KeyCorridorS3R1-v0`](../api/environments/key_corridor.md) | Corridor with a key 3 cells away | +| `MiniGrid-KeyCorridorS3R2-v0` | [`Navix-KeyCorridorS3R2-v0`](../api/environments/key_corridor.md) | Corridor with a key 3 cells away | +| `MiniGrid-KeyCorridorS3R3-v0` | [`Navix-KeyCorridorS3R3-v0`](../api/environments/key_corridor.md) | Corridor with a key 3 cells away | +| `MiniGrid-KeyCorridorS4R3-v0` | [`Navix-KeyCorridorS4R3-v0`](../api/environments/key_corridor.md) | Corridor with a key 4 cells away | +| `MiniGrid-KeyCorridorS5R3-v0` | [`Navix-KeyCorridorS5R3-v0`](../api/environments/key_corridor.md) | Corridor with a key 5 cells away | +| `MiniGrid-KeyCorridorS6R3-v0` | [`Navix-KeyCorridorS6R3-v0`](../api/environments/key_corridor.md) | Corridor with a key 6 cells away | +| `MiniGrid-Crossings-S9N1-v0` | [`Navix-Crossings-S9N1-v0`](../api/environments/crossings.md) | A 9x9 room with 1 wall crossing it | +| `MiniGrid-Crossings-S9N2-v0` | [`Navix-Crossings-S9N2-v0`](../api/environments/crossings.md) | A 9x9 room with 2 walls crossing it | +| `MiniGrid-Crossings-S9N3-v0` | [`Navix-Crossings-S9N3-v0`](../api/environments/crossings.md) | A 9x9 room with 3 walls crossing it | +| `MiniGrid-Crossings-S11N5-v0` | [`Navix-Crossings-S11N5-v0`](../api/environments/crossings.md) | A 11x11 room with 5 walls crossing it | +| `MiniGrid-DistShift1-v0` | [`Navix-DistShift1-v0`](../api/environments/dist_shift.md) | DistShift with 1 goal | +| `MiniGrid-DistShift2-v0` | [`Navix-DistShift2-v0`](../api/environments/dist_shift.md) | DistShift with 2 goals | +| 
`MiniGrid-LavaGap-S5-v0` | [`Navix-LavaGap-S5-v0`](../api/environments/lava_gap.md) | LavaGap in a 5x5 room | +| `MiniGrid-LavaGap-S6-v0` | [`Navix-LavaGap-S6-v0`](../api/environments/lava_gap.md) | LavaGap in a 6x6 room | +| `MiniGrid-LavaGap-S7-v0` | [`Navix-LavaGap-S7-v0`](../api/environments/lava_gap.md) | LavaGap in a 7x7 room | +| `MiniGrid-GoToDoor-5x5-v0` | [`Navix-GoToDoor-5x5-v0`](../api/environments/go_to_door.md) | 5x5 grid that terminates with a `done` action next to a certain door | +| `MiniGrid-GoToDoor-6x6-v0` | [`Navix-GoToDoor-6x6-v0`](../api/environments/go_to_door.md) | 6x6 grid that terminates with a `done` action next to a certain door | +| `MiniGrid-GoToDoor-8x8-v0` | [`Navix-GoToDoor-8x8-v0`](../api/environments/go_to_door.md) | 8x8 grid that terminates with a `done` action next to a certain door | +| `MiniGrid-Dynamic-Obstacles-5x5-v0` | [`Navix-Dynamic-Obstacles-5x5-v0`](../api/environments/dynamic_obstacles.md) | 5x5 grid with dynamic obstacles | +| `MiniGrid-Dynamic-Obstacles-6x6-v0` | [`Navix-Dynamic-Obstacles-6x6-v0`](../api/environments/dynamic_obstacles.md) | 6x6 grid with dynamic obstacles | +| `MiniGrid-Dynamic-Obstacles-8x8-v0` | [`Navix-Dynamic-Obstacles-8x8-v0`](../api/environments/dynamic_obstacles.md) | 8x8 grid with dynamic obstacles | +| `MiniGrid-Dynamic-Obstacles-16x16-v0` | [`Navix-Dynamic-Obstacles-16x16-v0`](../api/environments/dynamic_obstacles.md) | 16x16 grid with dynamic obstacles | +| `MiniGrid-Dynamic-Obstacles-Random-5x5-v0` | [`Navix-Dynamic-Obstacles-Random-5x5-v0`](../api/environments/dynamic_obstacles.md) | 5x5 grid with dynamic obstacles and random starts | +| `MiniGrid-Dynamic-Obstacles-Random-6x6-v0` | [`Navix-Dynamic-Obstacles-Random-6x6-v0`](../api/environments/dynamic_obstacles.md) | 6x6 grid with dynamic obstacles and random starts | +| `MiniGrid-Dynamic-Obstacles-Random-8x8-v0` | [`Navix-Dynamic-Obstacles-Random-8x8-v0`](../api/environments/dynamic_obstacles.md) | 8x8 grid with dynamic
obstacles and random starts | +| `MiniGrid-Dynamic-Obstacles-Random-16x16-v0` | [`Navix-Dynamic-Obstacles-Random-16x16-v0`](../api/environments/dynamic_obstacles.md) | 16x16 grid with dynamic obstacles and random starts | diff --git a/docs/index.md b/docs/index.md index e295d33..869417e 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,4 +1,4 @@ -

A fast, fully jittable MiniGrid reimplemented in JAX for HIGH THROUGHPUT

+

A fast, fully jittable, batched MiniGrid reimplemented in JAX for HIGH THROUGHPUT

Welcome to NAVIX!

@@ -7,7 +7,7 @@ NAVIX is designed to be a drop-in replacement for the original MiniGrid environment, with the added benefit of being significantly faster. Experiments that took **1 week**, now take **15 minutes**. -A `navix.Environment` is a `flax.struct.PyTreeNode` and supports `jax.vmap`, `jax.jit`, `jax.grad`, and all the other JAX's transformations. +A [`navix.Environment`](api/environments/environment.md) is a [`flax.struct.PyTreeNode`](https://flax.readthedocs.io/en/latest/api_reference/flax.struct.html#flax.struct.PyTreeNode) and supports [`jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html), [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html), [`jax.grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html), and all other JAX transformations. See some examples [here](examples/getting_started.ipynb).
diff --git a/mkdocs.yml b/mkdocs.yml index aa93e18..bf53b35 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -15,12 +15,12 @@ nav: - Home: - Welcome: index.md - Environments: home/environments.md + - Install: install/index.md - Quickstart: - "Getting started": examples/getting_started.ipynb - "PPO": examples/ppo.ipynb # - "Customizing envs": examples/customisation.ipynb - - Install: install/index.md - - Becnhmarks: + - Benchmarks: - "Timesteps": benchmarks/timesteps.ipynb - "Environments": benchmarks/envs.ipynb - API: api/ @@ -63,12 +63,9 @@ theme: - content.tooltips - navigation.instant - navigation.footer - - navigation.sections - navigation.tabs - - navigation.tabs.sticky - navigation.top - navigation.path - - navigation.tracking - search.highlight - search.share - search.suggest @@ -116,7 +113,6 @@ plugins: - docs/scripts/gen_doc_stubs.py # or any other name or path - literate-nav: nav_file: SUMMARY.md - - section-index markdown_extensions: - toc: diff --git a/navix/_version.py b/navix/_version.py index 55d7f03..cd6f1bb 100644 --- a/navix/_version.py +++ b/navix/_version.py @@ -18,5 +18,5 @@ # under the License. -__version__ = "0.6.11" +__version__ = "0.6.12" __version_info__ = tuple(int(i) for i in __version__.split(".") if i.isdigit())