diff --git a/README.md b/README.md index 9e8cd00..b1d9b05 100644 --- a/README.md +++ b/README.md @@ -8,72 +8,133 @@ **[Quickstart](#what-is-navix)** | **[Installation](#installation)** | **[Examples](#examples)** | **[Cite](#cite)** ## What is NAVIX? -NAVIX is [minigrid](https://github.com/Farama-Foundation/Minigrid) in JAX, **>1000x** faster with Autograd and XLA support. -You can see a superficial performance comparison [here](docs/performance.ipynb). +NAVIX is a JAX-powered reimplementation of [minigrid](https://github.com/Farama-Foundation/Minigrid). Key features: +- Performance Boost: NAVIX offers a **~>1000x** speed increase compared to the original Minigrid, enabling faster experimentation and scaling. You can see a preliminary performance comparison [here](docs/performance.py). +- XLA Compilation: Leverage the power of XLA to optimize NAVIX computations for your hardware (CPU, GPU, TPU). +- Autograd Support: Differentiate through environment transitions, opening up new possibilities such as learned world models. The library is in active development, and we are working on adding more environments and features. If you want join the development and contribute, please [open a discussion](https://github.com/epignatelli/navix/discussions/new?category=general) and let's have a chat! ## Installation -We currently support the OSs supported by JAX. -You can find a description [here](https://github.com/google/jax#installation). +#### Install JAX +Follow the official installation guide for your OS and preferred accelerator: https://github.com/google/jax#installation. -You might want to follow the same guide to install jax for your faviourite accelerator -(e.g. [CPU](https://github.com/google/jax#pip-installation-cpu), -[GPU](https://github.com/google/jax#pip-installation-gpu-cuda-installed-locally-harder), or -[TPU](https://github.com/google/jax#pip-installation-colab-tpu) -). 
- -- ### Stable -Then, install the stable version of `navix` and its dependencies with: +#### Install NAVIX ```bash pip install navix ``` -- ### Nightly -Or, if you prefer to install the latest version from source: +Or, for the latest version from source: ```bash pip install git+https://github.com/epignatelli/navix ``` ## Examples -### XLA compilation -One straightforward use case is to accelerate the computation of the environment with XLA compilation. -For example, here we vectorise the environment to run multiple environments in parallel, and compile **the full training run**. - -You can find a partial performance comparison with [minigrid](https://github.com/Farama-Foundation/Minigrid) in the [docs](docs/profiling.ipynb). - +### Compiling a collection step ```python import jax import navix as nx +import jax.numpy as jnp -def run(seed) - env = nx.environments.Room(16, 16, 8) +def run(seed): + env = nx.make('MiniGrid-Empty-8x8-v0') # Create the environment key = jax.random.PRNGKey(seed) timestep = env.reset(key) - actions = jax.random.randint(key, (N_TIMESTEPS,), 0, 6) + actions = jax.random.randint(key, (N_TIMESTEPS,), 0, env.action_space.n) def body_fun(timestep, action): - timestep = env.step(timestep, jnp.asarray(action)) + timestep = env.step(timestep, action) # Advance the environment by one step return timestep, () - return jax.lax.scan(body_fun, timestep, jnp.asarray(actions, dtype=jnp.int32))[0] + return jax.lax.scan(body_fun, timestep, actions)[0] -final_timestep = jax.jit(jax.vmap(run))(jax.numpy.arange(1000)) +# Vectorise the rollout over 1000 seeds and compile it with XLA +final_timestep = jax.jit(jax.vmap(run))(jnp.arange(1000)) +``` + +### Compiling a full training run +```python +import jax +import navix as nx +import jax.numpy as jnp +from jax import random + +def run_episode(seed, env, policy): + """Simulates a single episode with a given policy""" + key = random.PRNGKey(seed) + timestep = env.reset(key) + done = False + total_reward = 0 + + while not done: + action = 
policy(timestep.observation) + timestep, reward, done, _ = env.step(action) + total_reward += reward + + return total_reward + +def train_policy(policy, num_episodes): + """Trains a policy over multiple parallel episodes""" + envs = jax.vmap(nx.make, in_axes=0)(['MiniGrid-MultiRoom-N2-S4-v0'] * num_episodes) + seeds = random.split(random.PRNGKey(0), num_episodes) + + # Compile the entire training loop with XLA + compiled_episode = jax.jit(run_episode) + compiled_train = jax.jit(jax.vmap(compiled_episode, in_axes=(0, 0, None))) + + for _ in range(num_episodes): + rewards = compiled_train(seeds, envs, policy) + # ... Update the policy based on rewards ... + +# Hypothetical policy function +def policy(observation): + # ... your policy logic ... + return action + +# Start the training +train_policy(policy, num_episodes=100) ``` ### Backpropagation through the environment +```python +import jax +import navix as nx +import jax.numpy as jnp +from jax import grad +import flax.linen as nn + + +class Model(nn.Module): + @nn.compact + def __call__(self, x): + ... # your NN here -Another use case it to backpropagate through the environment transition function, for example to learn a world model. +model = Model() +env = nx.environments.Room(16, 16, 8) + +def loss(params, timestep): + action = jnp.asarray(0) + pred_obs = model.apply(params, timestep.observation) + timestep = env.step(timestep, action) + return jnp.square(timestep.observation - pred_obs).mean() + +key = jax.random.PRNGKey(0) +timestep = env.reset(key) +params = model.init(key, timestep.observation) + +gradients = grad(loss)(params, timestep) +``` -TODO(epignatelli): add example. +## Join Us! +NAVIX is actively developed. If you'd like to contribute to this open-source project, we welcome your involvement! Start a discussion or open a pull request. 
## Cite -If you use `navix` please consider citing it as: +If you use `navix`, please cite it as: ```bibtex @misc{pignatelli2023navix, diff --git a/assets/COPYRIGHT b/assets/COPYRIGHT new file mode 100644 index 0000000..db956a8 --- /dev/null +++ b/assets/COPYRIGHT @@ -0,0 +1,4 @@ +Copyright 2024 https://github.com/Farama-Foundation/Minigrid +The following images are licensed under the Apache 2.0 License, as per https://github.com/Farama-Foundation/Minigrid/LICENSE. +A copy of the license is provided in the file assets/LICENSE. + diff --git a/assets/LICENSE b/assets/LICENSE new file mode 100644 index 0000000..8800573 --- /dev/null +++ b/assets/LICENSE @@ -0,0 +1,40 @@ +Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + 1. Definitions. + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. 
+ "You" (or "Your") shall mean an individual or Legal Entity + @@ -154,49 +194,23 @@ + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS diff --git a/assets/sprites/ball_blue.png b/assets/sprites/ball_blue.png new file mode 100644 index 0000000..75f25fd Binary files /dev/null and b/assets/sprites/ball_blue.png differ diff --git a/assets/sprites/ball_green.png b/assets/sprites/ball_green.png new file mode 100644 index 0000000..8f30d9e Binary files /dev/null and b/assets/sprites/ball_green.png differ diff --git a/assets/sprites/ball_grey.png b/assets/sprites/ball_grey.png new file mode 100644 index 0000000..ea7294b Binary files /dev/null and b/assets/sprites/ball_grey.png differ diff --git a/assets/sprites/ball_purple.png b/assets/sprites/ball_purple.png new file mode 100644 index 0000000..270fd4a Binary files /dev/null and b/assets/sprites/ball_purple.png differ diff --git a/assets/sprites/ball_red.png b/assets/sprites/ball_red.png new file mode 100644 index 0000000..1694f7f Binary files /dev/null and b/assets/sprites/ball_red.png differ diff --git a/assets/sprites/ball_yellow.png b/assets/sprites/ball_yellow.png new file mode 100644 index 0000000..5a43e15 Binary files /dev/null and b/assets/sprites/ball_yellow.png differ diff --git a/assets/sprites/box_blue.png b/assets/sprites/box_blue.png new file mode 100644 index 0000000..3a55523 Binary files /dev/null and b/assets/sprites/box_blue.png differ diff --git a/assets/sprites/box_green.png b/assets/sprites/box_green.png new file mode 100644 index 0000000..a0b2bb3 Binary files /dev/null and b/assets/sprites/box_green.png differ diff --git a/assets/sprites/box_grey.png b/assets/sprites/box_grey.png new file mode 100644 index 0000000..1a7edeb Binary files /dev/null and b/assets/sprites/box_grey.png differ diff --git a/assets/sprites/box_purple.png b/assets/sprites/box_purple.png new file mode 100644 index 0000000..b6fd511 Binary files /dev/null and b/assets/sprites/box_purple.png differ diff --git a/assets/sprites/box_red.png b/assets/sprites/box_red.png new file mode 100644 index 0000000..ab524e5 Binary files /dev/null 
and b/assets/sprites/box_red.png differ diff --git a/assets/sprites/box_yellow.png b/assets/sprites/box_yellow.png new file mode 100644 index 0000000..299f10d Binary files /dev/null and b/assets/sprites/box_yellow.png differ diff --git a/assets/sprites/door_closed_blue.png b/assets/sprites/door_closed_blue.png new file mode 100644 index 0000000..ec0aec7 Binary files /dev/null and b/assets/sprites/door_closed_blue.png differ diff --git a/assets/sprites/door_closed_green.png b/assets/sprites/door_closed_green.png new file mode 100644 index 0000000..a937a27 Binary files /dev/null and b/assets/sprites/door_closed_green.png differ diff --git a/assets/sprites/door_closed_grey.png b/assets/sprites/door_closed_grey.png new file mode 100644 index 0000000..4dc90c5 Binary files /dev/null and b/assets/sprites/door_closed_grey.png differ diff --git a/assets/sprites/door_closed_purple.png b/assets/sprites/door_closed_purple.png new file mode 100644 index 0000000..b6ee905 Binary files /dev/null and b/assets/sprites/door_closed_purple.png differ diff --git a/assets/sprites/door_closed_red.png b/assets/sprites/door_closed_red.png new file mode 100644 index 0000000..f617eed Binary files /dev/null and b/assets/sprites/door_closed_red.png differ diff --git a/assets/sprites/door_closed_yellow.png b/assets/sprites/door_closed_yellow.png new file mode 100644 index 0000000..47d381e Binary files /dev/null and b/assets/sprites/door_closed_yellow.png differ diff --git a/assets/sprites/door_locked_blue.png b/assets/sprites/door_locked_blue.png new file mode 100644 index 0000000..ecccff6 Binary files /dev/null and b/assets/sprites/door_locked_blue.png differ diff --git a/assets/sprites/door_locked_green.png b/assets/sprites/door_locked_green.png new file mode 100644 index 0000000..da7f89b Binary files /dev/null and b/assets/sprites/door_locked_green.png differ diff --git a/assets/sprites/door_locked_grey.png b/assets/sprites/door_locked_grey.png new file mode 100644 index 0000000..0b99d79 
Binary files /dev/null and b/assets/sprites/door_locked_grey.png differ diff --git a/assets/sprites/door_locked_purple.png b/assets/sprites/door_locked_purple.png new file mode 100644 index 0000000..9800d68 Binary files /dev/null and b/assets/sprites/door_locked_purple.png differ diff --git a/assets/sprites/door_locked_red.png b/assets/sprites/door_locked_red.png new file mode 100644 index 0000000..6478edb Binary files /dev/null and b/assets/sprites/door_locked_red.png differ diff --git a/assets/sprites/door_locked_yellow.png b/assets/sprites/door_locked_yellow.png new file mode 100644 index 0000000..d1eea50 Binary files /dev/null and b/assets/sprites/door_locked_yellow.png differ diff --git a/assets/sprites/door_open_blue.png b/assets/sprites/door_open_blue.png new file mode 100644 index 0000000..05e97eb Binary files /dev/null and b/assets/sprites/door_open_blue.png differ diff --git a/assets/sprites/door_open_green.png b/assets/sprites/door_open_green.png new file mode 100644 index 0000000..8a36b63 Binary files /dev/null and b/assets/sprites/door_open_green.png differ diff --git a/assets/sprites/door_open_grey.png b/assets/sprites/door_open_grey.png new file mode 100644 index 0000000..f5c8078 Binary files /dev/null and b/assets/sprites/door_open_grey.png differ diff --git a/assets/sprites/door_open_purple.png b/assets/sprites/door_open_purple.png new file mode 100644 index 0000000..16ced01 Binary files /dev/null and b/assets/sprites/door_open_purple.png differ diff --git a/assets/sprites/door_open_red.png b/assets/sprites/door_open_red.png new file mode 100644 index 0000000..5ea1e41 Binary files /dev/null and b/assets/sprites/door_open_red.png differ diff --git a/assets/sprites/door_open_yellow.png b/assets/sprites/door_open_yellow.png new file mode 100644 index 0000000..ee75dcd Binary files /dev/null and b/assets/sprites/door_open_yellow.png differ diff --git a/assets/sprites/floor.png b/assets/sprites/floor.png new file mode 100644 index 0000000..b97777f Binary 
files /dev/null and b/assets/sprites/floor.png differ diff --git a/assets/sprites/goal.png b/assets/sprites/goal.png new file mode 100644 index 0000000..230537a Binary files /dev/null and b/assets/sprites/goal.png differ diff --git a/assets/sprites/key_blue.png b/assets/sprites/key_blue.png new file mode 100644 index 0000000..65ffd12 Binary files /dev/null and b/assets/sprites/key_blue.png differ diff --git a/assets/sprites/key_green.png b/assets/sprites/key_green.png new file mode 100644 index 0000000..5317e99 Binary files /dev/null and b/assets/sprites/key_green.png differ diff --git a/assets/sprites/key_grey.png b/assets/sprites/key_grey.png new file mode 100644 index 0000000..77cc0c0 Binary files /dev/null and b/assets/sprites/key_grey.png differ diff --git a/assets/sprites/key_purple.png b/assets/sprites/key_purple.png new file mode 100644 index 0000000..2830ee3 Binary files /dev/null and b/assets/sprites/key_purple.png differ diff --git a/assets/sprites/key_red.png b/assets/sprites/key_red.png new file mode 100644 index 0000000..8d52bb7 Binary files /dev/null and b/assets/sprites/key_red.png differ diff --git a/assets/sprites/key_yellow.png b/assets/sprites/key_yellow.png new file mode 100644 index 0000000..a9e6c15 Binary files /dev/null and b/assets/sprites/key_yellow.png differ diff --git a/assets/sprites/lava.png b/assets/sprites/lava.png new file mode 100644 index 0000000..f5ff0ea Binary files /dev/null and b/assets/sprites/lava.png differ diff --git a/assets/sprites/player_east.png b/assets/sprites/player_east.png new file mode 100644 index 0000000..65e6af9 Binary files /dev/null and b/assets/sprites/player_east.png differ diff --git a/assets/sprites/player_north.png b/assets/sprites/player_north.png new file mode 100644 index 0000000..3dbf789 Binary files /dev/null and b/assets/sprites/player_north.png differ diff --git a/assets/sprites/player_south.png b/assets/sprites/player_south.png new file mode 100644 index 0000000..370f615 Binary files /dev/null 
and b/assets/sprites/player_south.png differ diff --git a/assets/sprites/player_west.png b/assets/sprites/player_west.png new file mode 100644 index 0000000..0a9674d Binary files /dev/null and b/assets/sprites/player_west.png differ diff --git a/assets/sprites/wall.png b/assets/sprites/wall.png new file mode 100644 index 0000000..79974e1 Binary files /dev/null and b/assets/sprites/wall.png differ diff --git a/docs/performance.ipynb b/docs/performance.ipynb deleted file mode 100644 index 38cb6bc..0000000 --- a/docs/performance.ipynb +++ /dev/null @@ -1,138 +0,0 @@ -{ - "cells": [ - { - "attachments": {}, - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "view-in-github" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "LWXsFY0rpNs7", - "outputId": "2299207e-7ee7-470a-cfae-db3594560033" - }, - "outputs": [], - "source": [ - "!pip install -q git+https://github.com/epignatelli/navix minigrid" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "E7Nd2DmipPkg" - }, - "outputs": [], - "source": [ - "import jax\n", - "import jax.numpy as jnp\n", - "import navix as nx\n", - "\n", - "import gymnasium as gym\n", - "import minigrid\n", - "import random\n", - "import time\n", - "\n", - "from timeit import timeit\n", - "\n", - "\n", - "N_TIMEIT_LOOPS = 5\n", - "N_TIMESTEPS = 1_000\n", - "N_SEEDS = 10_000\n", - "\n", - "\n", - "def profile_navix(seed):\n", - " env = nx.environments.Room(16, 16, 8)\n", - " key = jax.random.PRNGKey(seed)\n", - " timestep = env.reset(key)\n", - " actions = jax.random.randint(key, (N_TIMESTEPS,), 0, 6)\n", - "\n", - " timestep, _ = jax.lax.while_loop(\n", - " lambda x: x[1] < N_TIMESTEPS,\n", - " lambda x: (env.step(x[0], actions[x[1]]), x[1] + 1),\n", - " (timestep, jnp.asarray(0)),\n", - " )\n", - "\n", - " return timestep\n", - "\n", - "\n", - "def 
profile_minigrid(seed):\n", - " env = gym.make(\"MiniGrid-Empty-16x16-v0\", render_mode=None)\n", - " observation, info = env.reset(seed=42)\n", - " for _ in range(N_TIMESTEPS):\n", - " action = random.randint(0, 4)\n", - " observation, reward, terminated, truncated, info = env.step(action)\n", - "\n", - " if terminated or truncated:\n", - " observation, info = env.reset()\n", - " env.close()\n", - " return observation\n", - "\n", - "\n", - "if __name__ == \"__main__\":\n", - " # profile navix\n", - " print(\n", - " \"Profiling navix, N_SEEDS = {}, N_TIMESTEPS = {}\".format(N_SEEDS, N_TIMESTEPS)\n", - " )\n", - " seeds = jnp.arange(N_SEEDS)\n", - "\n", - " print(\"\\tCompiling...\")\n", - " start = time.time()\n", - " n_devices = jax.local_device_count()\n", - " seeds = seeds.reshape(n_devices, N_SEEDS // n_devices)\n", - " f = jax.vmap(profile_navix, axis_name=\"batch\")\n", - " f = jax.pmap(f, axis_name=\"device\")\n", - " f = f.lower(seeds).compile()\n", - " print(\"\\tCompiled in {:.2f}s\".format(time.time() - start))\n", - "\n", - " print(\"\\tRunning ...\")\n", - " res_navix = timeit(\n", - " lambda: f(seeds).state.grid.block_until_ready(), number=N_TIMEIT_LOOPS\n", - " )\n", - " print(res_navix)\n", - "\n", - " # profile minigrid\n", - " print(\"Profiling minigrid, N_SEEDS = 1, N_TIMESTEPS = {}\".format(N_TIMESTEPS))\n", - " res_minigrid = timeit(lambda: profile_minigrid(0), number=N_TIMEIT_LOOPS)\n", - " print(res_minigrid)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "5reQuYCeuP_q" - }, - "outputs": [], - "source": [] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "gpuType": "T4", - "include_colab_link": true, - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.9.16" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/docs/performance.py b/docs/performance.py index f5408fa..7179584 
100644 --- a/docs/performance.py +++ b/docs/performance.py @@ -4,6 +4,7 @@ import gymnasium as gym import minigrid +from minigrid.wrappers import ImgObsWrapper import random import time @@ -11,7 +12,7 @@ N_TIMEIT_LOOPS = 5 -N_TIMESTEPS = 1_000 +N_TIMESTEPS = 10 N_SEEDS = 10_000 @@ -21,24 +22,27 @@ def profile_navix(seed): timestep = env.reset(key) actions = jax.random.randint(key, (N_TIMESTEPS,), 0, 6) - timestep, _ = jax.lax.while_loop( - lambda x: x[1] < N_TIMESTEPS, - lambda x: (env.step(x[0], actions[x[1]]), x[1] + 1), - (timestep, jnp.asarray(0)), - ) + # Python for loop (unrolled if traced by jit/vmap) + for i in range(N_TIMESTEPS): + timestep = env.step(timestep, actions[i]) return timestep def profile_minigrid(seed): - env = gym.make("MiniGrid-Empty-16x16-v0", render_mode=None) + num_envs = N_SEEDS // 1000 + env = gym.vector.make( + "MiniGrid-Empty-16x16-v0", + wrappers=ImgObsWrapper, + num_envs=num_envs, + render_mode=None, + asynchronous=True, + ) observation, info = env.reset(seed=42) for _ in range(N_TIMESTEPS): action = random.randint(0, 4) - observation, reward, terminated, truncated, info = env.step(action) + timestep = env.step([action] * num_envs) - if terminated or truncated: - observation, info = env.reset() env.close() return observation @@ -66,6 +70,10 @@ def profile_minigrid(seed): print(res_navix) # profile minigrid - print("Profiling minigrid, N_SEEDS = 1, N_TIMESTEPS = {}".format(N_TIMESTEPS)) + print( + "Profiling minigrid, N_SEEDS = {}, N_TIMESTEPS = {}".format( + N_SEEDS // 1000, N_TIMESTEPS + ) + ) res_minigrid = timeit(lambda: profile_minigrid(0), number=N_TIMEIT_LOOPS) print(res_minigrid) diff --git a/navix/__init__.py b/navix/__init__.py index d809db3..e9b7e23 100644 --- a/navix/__init__.py +++ b/navix/__init__.py @@ -22,7 +22,6 @@ actions, components, entities, - graphics, grid, observations, tasks, @@ -30,4 +29,5 @@ terminations, config, spaces, + rendering ) diff --git a/navix/_version.py b/navix/_version.py index 5d0379a..0b700c7 100644 --- a/navix/_version.py +++ 
b/navix/_version.py @@ -18,5 +18,5 @@ # under the License. -__version__ = "0.3.14" +__version__ = "0.4.0" __version_info__ = tuple(int(i) for i in __version__.split(".") if i.isdigit()) diff --git a/navix/actions.py b/navix/actions.py index 4c605d5..ccae567 100644 --- a/navix/actions.py +++ b/navix/actions.py @@ -163,14 +163,15 @@ def open(state: State) -> State: door_found = positions_equal(position_in_front, doors.position) # and that, if so, either it does not require a key or the player has the key - requires_key = doors.requires != -1 + locked = doors.requires != -1 key_match = player.pocket == doors.requires - can_open = door_found & (key_match | ~requires_key) + can_open = door_found & (key_match | ~locked) # update doors if closed and can_open do_open = ~doors.open & can_open open = jnp.where(do_open, True, doors.open) - doors = doors.replace(open=open) + requires = jnp.where(do_open, -1, doors.requires) + doors = doors.replace(open=open, requires=requires) # remove key from player's pocket pocket = jnp.asarray(player.pocket * jnp.any(can_open), dtype=jnp.int32) diff --git a/navix/components.py b/navix/components.py index 18dc81b..c18a3af 100644 --- a/navix/components.py +++ b/navix/components.py @@ -54,6 +54,11 @@ class Directional(Component): """The direction the entity: 0 = east, 1 = south, 2 = west, 3 = north""" +class HasColour(Component): + colour: Array = field(shape=()) + """The colour of the object for rendering. """ + + class Stochastic(Component): probability: Array = field(shape=()) """The probability of receiving the reward, if reached.""" @@ -61,9 +66,10 @@ class Stochastic(Component): class Openable(Component): requires: Array = field(shape=()) - """The id of the item required to consume this item. If set, it must be >= 1.""" + """The id of the item required to consume this item. If set, it must be > 0. 
+ If -1, the door is unlocked and does not require any key to open.""" open: Array = field(shape=()) - """Whether the item is open or not.""" + """Open is jnp.asarray(0) if the entity is closed and 1 if open.""" class Pickable(Component): diff --git a/navix/entities.py b/navix/entities.py index 65f818d..edfd6c0 100644 --- a/navix/entities.py +++ b/navix/entities.py @@ -6,11 +6,12 @@ from jax import Array import jax.numpy as jnp from flax import struct -from jax.random import KeyArray + from .components import ( Positionable, Directional, + HasColour, HasTag, Stochastic, Openable, @@ -18,7 +19,8 @@ Holder, HasSprite, ) -from .graphics import RenderingCache, SPRITES_REGISTRY +from .rendering.cache import RenderingCache +from .rendering.registry import PALETTE, SPRITES_REGISTRY from .config import config T = TypeVar("T", bound="Entity") @@ -64,10 +66,18 @@ def check_ndim(self, batched: bool = False) -> None: def __getitem__(self: T, idx) -> T: return jax.tree_util.tree_map(lambda x: x[idx], self) + @property + def name(self) -> str: + return self.__class__.__name__ + @property def shape(self) -> Tuple[int, ...]: """The batch shape of the entity""" - return self.position.shape[: self.position.ndim - 1] + return self.position.shape[:-1] + + @property + def ndim(self) -> int: + return self.position.ndim - 1 @property def walkable(self) -> Array: @@ -175,7 +185,7 @@ def tag(self) -> Array: return jnp.broadcast_to(jnp.asarray(3), self.shape) -class Key(Entity, Pickable): +class Key(Entity, Pickable, HasColour): """Pickable items are world objects that can be picked up by the player. 
Examples of pickable items are keys, coins, etc.""" @@ -183,9 +193,10 @@ class Key(Entity, Pickable): def create( cls, position: Array, + colour: Array, id: Array, ) -> Key: - return cls(position=position, id=id) + return cls(position=position, id=id, colour=colour) @property def walkable(self) -> Array: @@ -197,7 +208,7 @@ def transparent(self) -> Array: @property def sprite(self) -> Array: - sprite = SPRITES_REGISTRY[Entities.KEY] + sprite = SPRITES_REGISTRY[Entities.KEY][self.colour] if sprite.ndim == 3: # batch it sprite = sprite[None] @@ -211,7 +222,7 @@ def tag(self) -> Array: return jnp.broadcast_to(jnp.asarray(4), self.shape) -class Door(Entity, Directional, Openable): +class Door(Entity, Openable, HasColour): """Consumable items are world objects that can be consumed by the player. Consuming an item requires a tool (e.g. a key to open a door). A tool is an id (int) of another item, specified in the `requires` field (-1 if no tool is required). @@ -224,11 +235,16 @@ class Door(Entity, Directional, Openable): def create( cls, position: Array, - direction: Array, requires: Array, + colour: Array, open: Array, ) -> Door: - return cls(position=position, direction=direction, requires=requires, open=open) + return cls( + position=position, + requires=requires, + open=open, + colour=colour, + ) @property def walkable(self) -> Array: @@ -241,7 +257,7 @@ def transparent(self) -> Array: @property def sprite(self) -> Array: sprite = SPRITES_REGISTRY[Entities.DOOR][ - self.direction, jnp.asarray(self.open, dtype=jnp.int32) + self.colour, jnp.asarray(self.open + 2 * self.locked, dtype=jnp.int32) ] if sprite.ndim == 3: # batch it @@ -255,11 +271,15 @@ def sprite(self) -> Array: def tag(self) -> Array: return jnp.broadcast_to(jnp.asarray(5), self.shape) + @property + def locked(self) -> Array: + return self.requires != jnp.asarray(-1) + class State(struct.PyTreeNode): """The Markovian state of the environment""" - key: KeyArray + key: Array """The random number generator 
state""" grid: Array """The base map of the environment that remains constant throughout the training""" diff --git a/navix/environments/environment.py b/navix/environments/environment.py index 9c78b91..bc8d157 100644 --- a/navix/environments/environment.py +++ b/navix/environments/environment.py @@ -22,13 +22,12 @@ from typing import Any, Callable, Dict import jax import jax.numpy as jnp -from jax.random import KeyArray from jax import Array from flax import struct from .. import tasks, terminations, observations -from ..graphics import RenderingCache, TILE_SIZE +from ..rendering.cache import RenderingCache, TILE_SIZE from ..entities import State from ..actions import ACTIONS from ..spaces import Space, Discrete, Continuous @@ -111,7 +110,7 @@ def action_space(self) -> Space: return Discrete(len(ACTIONS)) @abc.abstractmethod - def reset(self, key: KeyArray, cache: RenderingCache | None = None) -> Timestep: + def reset(self, key: Array, cache: RenderingCache | None = None) -> Timestep: raise NotImplementedError() def _step(self, timestep: Timestep, action: Array, actions_set=ACTIONS) -> Timestep: diff --git a/navix/environments/keydoor.py b/navix/environments/keydoor.py index cb3e628..402ee29 100644 --- a/navix/environments/keydoor.py +++ b/navix/environments/keydoor.py @@ -1,10 +1,11 @@ import jax import jax.numpy as jnp -from jax.random import KeyArray +from jax import Array from typing import Union from ..components import EMPTY_POCKET_ID -from ..graphics import RenderingCache +from ..rendering.cache import RenderingCache +from ..rendering.registry import PALETTE from ..environments import Environment from ..entities import State, Player, Key, Door, Goal, Wall from ..environments import Timestep @@ -12,7 +13,7 @@ class KeyDoor(Environment): - def reset(self, key: KeyArray, cache: Union[RenderingCache, None] = None) -> Timestep: # type: ignore + def reset(self, key: Array, cache: Union[RenderingCache, None] = None) -> Timestep: # type: ignore # check minimum 
height and width assert ( self.height > 3 @@ -34,8 +35,8 @@ def reset(self, key: KeyArray, cache: Union[RenderingCache, None] = None) -> Tim doors = Door( position=door_pos, requires=jnp.asarray(3), - direction=jnp.asarray(0), open=jnp.asarray(False), + colour=PALETTE.YELLOW, ) # wall positions @@ -67,7 +68,7 @@ def reset(self, key: KeyArray, cache: Union[RenderingCache, None] = None) -> Tim # spawn key key_pos = random_positions(k2, first_room, exclude=player_pos) - keys = Key(position=key_pos, id=jnp.asarray(3)) + keys = Key(position=key_pos, id=jnp.asarray(3), colour=PALETTE.YELLOW) # mask the second room diff --git a/navix/environments/room.py b/navix/environments/room.py index 9946653..6e3e4e8 100644 --- a/navix/environments/room.py +++ b/navix/environments/room.py @@ -23,18 +23,18 @@ import jax import jax.numpy as jnp -from jax.random import KeyArray +from jax import Array from ..components import EMPTY_POCKET_ID from ..entities import Entities, Goal, Player, State from ..grid import random_positions, random_directions, room -from ..graphics import RenderingCache +from ..rendering.cache import RenderingCache from .environment import Environment, Timestep class Room(Environment): def reset( - self, key: KeyArray, cache: Union[RenderingCache, None] = None + self, key: Array, cache: Union[RenderingCache, None] = None ) -> Timestep: key, k1, k2 = jax.random.split(key, 3) diff --git a/navix/graphics.py b/navix/graphics.py deleted file mode 100644 index 0e31e86..0000000 --- a/navix/graphics.py +++ /dev/null @@ -1,392 +0,0 @@ -from __future__ import annotations - -from typing import Any, Dict, Tuple - -import jax -import jax.numpy as jnp -from flax import struct -from jax import Array - - -TILE_SIZE = 32 -RED = jnp.asarray([255, 0, 0], dtype=jnp.uint8) -GREEN = jnp.asarray([0, 255, 0], dtype=jnp.uint8) -BLUE = jnp.asarray([0, 0, 255], dtype=jnp.uint8) -BLACK = jnp.asarray([0, 0, 0], dtype=jnp.uint8) -WHITE = jnp.asarray([255, 255, 255], dtype=jnp.uint8) -YELLOW = 
jnp.asarray([255, 255, 0], dtype=jnp.uint8) -PURPLE = jnp.asarray([255, 0, 255], dtype=jnp.uint8) -CYAN = jnp.asarray([0, 255, 255], dtype=jnp.uint8) -ORANGE = jnp.asarray([255, 128, 0], dtype=jnp.uint8) -PINK = jnp.asarray([255, 0, 128], dtype=jnp.uint8) -BROWN = jnp.asarray([128, 64, 0], dtype=jnp.uint8) -GRAY_10 = jnp.asarray([230, 230, 230], dtype=jnp.uint8) -GRAY_20 = jnp.asarray([205, 205, 205], dtype=jnp.uint8) -GRAY_40 = jnp.asarray([153, 153, 153], dtype=jnp.uint8) -GRAY_50 = jnp.asarray([128, 128, 128], dtype=jnp.uint8) -GRAY_70 = jnp.asarray([77, 77, 77], dtype=jnp.uint8) -GRAY_80 = jnp.asarray([51, 51, 51], dtype=jnp.uint8) -GRAY_90 = jnp.asarray([25, 25, 25], dtype=jnp.uint8) -GOLD = jnp.asarray([255, 215, 0], dtype=jnp.uint8) -SILVER = jnp.asarray([192, 192, 192], dtype=jnp.uint8) -BRONZE = jnp.asarray([205, 127, 50], dtype=jnp.uint8) -MAROON = jnp.asarray([128, 0, 0], dtype=jnp.uint8) -NAVY = jnp.asarray([0, 0, 128], dtype=jnp.uint8) -TEAL = jnp.asarray([0, 128, 128], dtype=jnp.uint8) -OLIVE = jnp.asarray([128, 128, 0], dtype=jnp.uint8) -LIME = jnp.asarray([0, 255, 0], dtype=jnp.uint8) -AQUA = jnp.asarray([0, 255, 255], dtype=jnp.uint8) -FUCHSIA = jnp.asarray([255, 0, 255], dtype=jnp.uint8) -SALMON = jnp.asarray([250, 128, 114], dtype=jnp.uint8) -TURQUOISE = jnp.asarray([64, 224, 208], dtype=jnp.uint8) -VIOLET = jnp.asarray([238, 130, 238], dtype=jnp.uint8) -INDIGO = jnp.asarray([75, 0, 130], dtype=jnp.uint8) -BEIGE = jnp.asarray([245, 245, 220], dtype=jnp.uint8) -MINT = jnp.asarray([189, 252, 201], dtype=jnp.uint8) -LAVENDER = jnp.asarray([230, 230, 250], dtype=jnp.uint8) -APRICOT = jnp.asarray([251, 206, 177], dtype=jnp.uint8) -MAUVE = jnp.asarray([224, 176, 255], dtype=jnp.uint8) -LILAC = jnp.asarray([200, 162, 200], dtype=jnp.uint8) -TAN = jnp.asarray([210, 180, 140], dtype=jnp.uint8) - - -def colour_chart(size: int = TILE_SIZE) -> Array: - colours = [ - RED, - GREEN, - BLUE, - BLACK, - WHITE, - YELLOW, - PURPLE, - CYAN, - ORANGE, - PINK, - 
BROWN, - GRAY_20, - GRAY_50, - GRAY_70, - GRAY_90, - GOLD, - SILVER, - BRONZE, - MAROON, - NAVY, - TEAL, - OLIVE, - LIME, - AQUA, - FUCHSIA, - SALMON, - TURQUOISE, - VIOLET, - INDIGO, - BEIGE, - MINT, - LAVENDER, - APRICOT, - MAUVE, - LILAC, - TAN, - ] - grid = jnp.zeros((size * len(colours), size * len(colours), 3), dtype=jnp.uint8) - for i, colour in enumerate(colours): - for j, colour in enumerate(colours): - grid = grid.at[i * size : (i + 1) * size, j * size : (j + 1) * size].set( - colour - ) - return grid - - -class RenderingCache(struct.PyTreeNode): - patches: Array - """A flat set of patches representing the RGB values of each tile in the base map""" - - @classmethod - def init(cls, grid: Array) -> RenderingCache: - background = render_background(grid) - patches = flatten_patches(background) - - # add discard pile - patches = jnp.concatenate( - [ - patches, - jnp.zeros((1, TILE_SIZE, TILE_SIZE, 3), dtype=jnp.uint8), - ], - axis=0, - ) - return cls(patches=patches) - - -def colorise_tile(tile: Array, colour: Array, background: Array = WHITE) -> Array: - assert tile.shape == ( - TILE_SIZE, - TILE_SIZE, - ), "Tile must be of size TILE_SIZE, TILE_SIZE, 3, got {}".format(tile.shape) - tile = jnp.stack([tile] * colour.shape[0], axis=-1) - tile = jnp.where(tile, colour, background) - return tile - - -def render_rectangle(size: int = TILE_SIZE, colour: Array = BLACK) -> Array: - rectangle = jnp.ones((size - 2, size - 2), dtype=jnp.int32) - rectangle = jnp.pad(rectangle, 1, "constant", constant_values=0) - return colorise_tile(rectangle, colour) - - -def render_triangle_east(size: int = TILE_SIZE, colour: Array = RED) -> Array: - triangle = jnp.ones((size, size), dtype=jnp.int32) - triangle = jnp.tril(triangle, k=0) - triangle = jnp.flip(triangle, axis=0) - triangle = jnp.tril(triangle, k=0) - triangle = jnp.roll(triangle, (0, size // 3)) - return colorise_tile(triangle, colour) - - -def render_triangle_south(size: int = TILE_SIZE, colour: Array = RED) -> Array: - 
triangle = render_triangle_east(size) - triangle = jnp.rot90(triangle, k=3) - return colorise_tile(triangle, colour) - - -def render_triangle_west(size: int = TILE_SIZE, colour: Array = RED) -> Array: - triangle = render_triangle_east(size) - triangle = jnp.rot90(triangle, k=2) - return colorise_tile(triangle, colour) - - -def render_triangle_north(size: int = TILE_SIZE, colour: Array = RED) -> Array: - triangle = render_triangle_east(size) - triangle = jnp.rot90(triangle, k=1) - return colorise_tile(triangle, colour) - - -def render_diamond(size: int = TILE_SIZE, colour: Array = GOLD) -> Array: - diamond = jnp.ones((size, size), dtype=jnp.int32) - diamond = jnp.tril(diamond, k=size // 2.5) - diamond = jnp.flip(diamond, axis=0) - diamond = jnp.tril(diamond, k=size // 2.5) - diamond = jnp.flip(diamond, axis=1) - diamond = jnp.tril(diamond, k=size // 2.5) - diamond = jnp.flip(diamond, axis=0) - diamond = jnp.tril(diamond, k=size // 2.5) - return colorise_tile(diamond, colour) - - -def render_door_closed(size: int = TILE_SIZE, colour: Array = BROWN) -> Array: - frame_size = size - 6 - door = jnp.zeros((frame_size, frame_size), dtype=jnp.int32) - door = jnp.pad(door, 1, "constant", constant_values=1) - door = jnp.pad(door, 1, "constant", constant_values=0) - door = jnp.pad(door, 1, "constant", constant_values=1) - - x_0 = size - size // 4 - y_centre = size // 2 - y_size = size // 5 - door = door.at[y_centre - y_size // 2 : y_centre + y_size // 2, x_0 : x_0 + 1].set( - 1 - ) - return colorise_tile(door, colour) - - -def render_door_locked(size: int = TILE_SIZE, colour: Array = BROWN) -> Array: - frame_size = size - 4 - door = jnp.zeros((frame_size, frame_size), dtype=jnp.int32) - door = jnp.pad(door, 2, "constant", constant_values=1) - - x_0 = size - size // 4 - y_centre = size // 2 - y_size = size // 5 - door = door.at[y_centre - y_size // 2 : y_centre + y_size // 2, x_0 : x_0 + 1].set( - 1 - ) - return colorise_tile(door, colour, background=colour / 2) - - -def 
render_door_open(size: int = TILE_SIZE, colour: Array = BROWN) -> Array: - door = jnp.zeros((size, size), dtype=jnp.int32) - door = door.at[0].set(1) - door = door.at[3].set(1) - door = door.at[:3, 0].set(1) - door = door.at[:3, -1].set(1) - return colorise_tile(door, colour) - - -def render_key(size: int = TILE_SIZE, colour: Array = BRONZE) -> Array: - key = jnp.zeros((size, size), dtype=jnp.int32) - - # Handle (Round Part) - handle_radius = size // 4 - handle_center = (size // 2, size // 4) - y, x = jnp.ogrid[:size, :size] - mask = (x - handle_center[0]) ** 2 + ( - y - handle_center[1] - ) ** 2 <= handle_radius**2 - key = jnp.where(mask, 1, key) - - # Shaft (Straight Part) - shaft_width = size // 8 - shaft_height = size // 2 - shaft_start = (size // 2 - shaft_width // 2, size // 2 - shaft_height // 2) - shaft_end = (size // 2 + shaft_width // 2, size // 2 + shaft_height // 2) - shaft_mask = jnp.logical_and( - jnp.logical_and(x >= shaft_start[0], x <= shaft_end[0]), - jnp.logical_and(y >= shaft_start[1], y <= shaft_end[1]), - ) - key = jnp.where(shaft_mask, 1, key) - - # Tooth (Pointy End) - tooth_width = size // 15 - tooth_height = size // 2 - tooth_position = (size // 2 - tooth_width // 2, size - tooth_height) - tooth_mask = jnp.logical_and( - jnp.logical_and(x >= tooth_position[0], x <= tooth_position[0] + tooth_width), - jnp.logical_and(y >= tooth_position[1], y <= tooth_position[1] + tooth_height), - ) - key = jnp.where(tooth_mask, 1, key) - - return colorise_tile(key, colour) - - -def render_floor(size: int = TILE_SIZE, colour: Array = WHITE) -> Array: - floor = jnp.ones((size - 2, size - 2), dtype=jnp.int32) - floor = jnp.pad(floor, 1, "constant", constant_values=0) - return colorise_tile(floor, colour, background=GRAY_10) - - -def render_wall(size: int = TILE_SIZE, colour: Array = GRAY_80) -> Array: - wall = jnp.ones((size, size), dtype=jnp.int32) - return colorise_tile(wall, colour) - - -def tile_grid(grid: Array, tile: Array) -> Array: - tiled = 
jnp.tile(tile, (*grid.shape, 1)) - return jnp.asarray(tiled, dtype=jnp.uint8) - - -@jax.jit -def build_sprites_registry() -> Dict[str, Any]: - registry = {} - - wall = render_wall() - floor = render_floor() - player = render_triangle_east() - goal = render_diamond() - key = render_key() - door_closed = render_door_closed() - door_open = render_door_open() - - # 0: set wall sprites - registry["wall"] = wall - - # 1: set floor sprites - registry["floor"] = floor - - # 2: set player sprites - registry["player"] = jnp.stack( - [ - player, - jnp.rot90(player, k=3), - jnp.rot90(player, k=2), - jnp.rot90(player, k=1), - ] - ) - - # 3: set goal sprites - registry["goal"] = goal - - # 4: set key sprites - registry["key"] = key - - # 5: set door sprites - door = jnp.zeros((4, 2, TILE_SIZE, TILE_SIZE, 3), dtype=jnp.uint8) - - door_closed_by_direction = jnp.stack( - [ - jnp.rot90(door_closed, k=1), - door_closed, - jnp.rot90(door_closed, k=3), - jnp.rot90(door_closed, k=2), - ] - ) - door = door.at[:, 0].set(door_closed_by_direction) - - door_open_by_direction = jnp.stack( - [ - door_open, - jnp.rot90(door_open, k=1), - jnp.rot90(door_open, k=2), - jnp.rot90(door_open, k=3), - ] - ) - door = door.at[:, 1].set(door_open_by_direction) - - registry["door"] = door - - return registry - - -SPRITES_REGISTRY: Dict[str, Any] = build_sprites_registry() - - -def render_background( - grid: Array, sprites_registry: Dict[str, Any] = SPRITES_REGISTRY -) -> Array: - image_width = grid.shape[0] * TILE_SIZE - image_height = grid.shape[1] * TILE_SIZE - n_channels = 3 - - background = jnp.zeros((image_height, image_width, n_channels), dtype=jnp.uint8) - grid_resized = jax.image.resize( - grid, (grid.shape[0] * TILE_SIZE, grid.shape[1] * TILE_SIZE), method="nearest" - ) - - mask = jnp.asarray(grid_resized, dtype=bool) # 0 = floor, 1 = wall - # index by [entity_type, direction, open/closed, y, x, channel] - wall_tile = tile_grid(grid, sprites_registry["wall"]) - floor_tile = tile_grid(grid, 
sprites_registry["floor"]) - background = jnp.where(mask[..., None], wall_tile, floor_tile) - return background - - -def flatten_patches( - image: Array, patch_size: Tuple[int, int] = (TILE_SIZE, TILE_SIZE) -) -> Array: - height = image.shape[0] // patch_size[0] - width = image.shape[1] // patch_size[1] - n_channels = image.shape[2] - - grid = image.reshape(height, patch_size[0], width, patch_size[1], n_channels) - - # Swap the first and second axes of the grid to revert the stacking order - grid = jnp.swapaxes(grid, 1, 2) - - # Reshape the grid of tiles into the original list of tiles - patches = grid.reshape(height * width, patch_size[0], patch_size[1], n_channels) - - return patches - - -def unflatten_patches(patches: Array, image_size: Tuple[int, int]) -> Array: - image_height = image_size[0] - image_width = image_size[1] - patch_height = patches.shape[1] - patch_width = patches.shape[2] - n_channels = patches.shape[3] - - # Reshape the list of tiles into a 2D grid - grid = patches.reshape( - image_height // patch_height, - image_width // patch_width, - patch_height, - patch_width, - n_channels, - ) - - # Swap the first and second axes of the grid to change the order of stacking - grid = jnp.swapaxes(grid, 1, 2) - - # Reshape and stack the grid tiles horizontally and vertically to form the final image - image = grid.reshape(image_height, image_width, n_channels) - - return image diff --git a/navix/grid.py b/navix/grid.py index eb8440e..1b05f0b 100644 --- a/navix/grid.py +++ b/navix/grid.py @@ -24,7 +24,6 @@ from typing import Callable, Dict, Tuple import jax import jax.numpy as jnp -from jax.random import KeyArray from jax import Array @@ -32,7 +31,7 @@ def coordinates(grid: Array) -> Coordinates: - return tuple(jnp.mgrid[0 : grid.shape[0], 0 : grid.shape[1]]) + return tuple(jnp.mgrid[0 : grid.shape[0], 0 : grid.shape[1]]) # type: ignore def idx_from_coordinates(grid: Array, coordinates: Array): @@ -110,7 +109,7 @@ def align(patch: Array, current_direction: 
Array, desired_direction: Array) -> A def random_positions( - key: KeyArray, grid: Array, n: int = 1, exclude: Array = jnp.asarray((-1, -1)) + key: Array, grid: Array, n: int = 1, exclude: Array = jnp.asarray((-1, -1)) ) -> Array: probs = grid.reshape(-1) indices = idx_from_coordinates(grid, exclude) @@ -120,7 +119,7 @@ def random_positions( return position.squeeze() -def random_directions(key: KeyArray, n=1) -> Array: +def random_directions(key: Array, n=1) -> Array: return jax.random.randint(key, (n,), 0, 4).squeeze() @@ -141,7 +140,7 @@ def room(height: int, width: int): return jnp.pad(grid, 1, mode="constant", constant_values=-1) -def two_rooms(height: int, width: int, key: KeyArray) -> Tuple[Array, Array]: +def two_rooms(height: int, width: int, key: Array) -> Tuple[Array, Array]: """Two rooms separated by a vertical wall at `width // 2`""" # create room grid = jnp.zeros((height - 2, width - 2), dtype=jnp.int32) diff --git a/navix/observations.py b/navix/observations.py index 2ff6d0e..fbe3bfb 100644 --- a/navix/observations.py +++ b/navix/observations.py @@ -24,7 +24,7 @@ import jax.numpy as jnp from jax import Array -from . 
import graphics +from .rendering.cache import TILE_SIZE, unflatten_patches from .components import DISCARD_PILE_IDX from .entities import State from .grid import align, idx_from_coordinates, crop, view_cone @@ -92,10 +92,10 @@ def rgb( patches = patches[:DISCARD_PILE_IDX] # unflatten patches to reconstruct the image image_size = ( - state.grid.shape[0] * graphics.TILE_SIZE, - state.grid.shape[1] * graphics.TILE_SIZE, + state.grid.shape[0] * TILE_SIZE, + state.grid.shape[1] * TILE_SIZE, ) - image = graphics.unflatten_patches(patches, image_size) + image = unflatten_patches(patches, image_size) return image @@ -104,8 +104,8 @@ def rgb_first_person( ) -> Array: # calculate final image size image_size = ( - state.grid.shape[0] * graphics.TILE_SIZE, - state.grid.shape[1] * graphics.TILE_SIZE, + state.grid.shape[0] * TILE_SIZE, + state.grid.shape[1] * TILE_SIZE, ) # get agent's view diff --git a/navix/rendering/__init__.py b/navix/rendering/__init__.py new file mode 100644 index 0000000..6afdaf9 --- /dev/null +++ b/navix/rendering/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2023 The Navix Authors. + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +from . 
import ( + cache, + registry +) diff --git a/navix/rendering/cache.py b/navix/rendering/cache.py new file mode 100644 index 0000000..53710c3 --- /dev/null +++ b/navix/rendering/cache.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +from typing import Dict, Tuple + +import jax +from jax import Array +import jax.numpy as jnp +from flax import struct + +from .registry import SpritesRegistry, TILE_SIZE, SPRITES_REGISTRY + + +class RenderingCache(struct.PyTreeNode): + patches: Array + """A flat set of patches representing the RGB values of each tile in the base map""" + + @classmethod + def init(cls, grid: Array) -> RenderingCache: + background = render_background(grid) + patches = flatten_patches(background) + + # add discard pile + patches = jnp.concatenate( + [ + patches, + jnp.zeros((1, TILE_SIZE, TILE_SIZE, 3), dtype=jnp.uint8), + ], + axis=0, + ) + return cls(patches=patches) + + +def render_background( + grid: Array, sprites_registry: Dict[str, Array] = SPRITES_REGISTRY +) -> Array: + image_width = grid.shape[0] * TILE_SIZE + image_height = grid.shape[1] * TILE_SIZE + n_channels = 3 + + background = jnp.zeros((image_height, image_width, n_channels), dtype=jnp.uint8) + grid_resized = jax.image.resize( + grid, (grid.shape[0] * TILE_SIZE, grid.shape[1] * TILE_SIZE), method="nearest" + ) + + mask = jnp.asarray(grid_resized, dtype=bool) # 0 = floor, 1 = wall + # index by [entity_type, direction, open/closed, y, x, channel] + wall_tile = tile_grid(grid, sprites_registry["wall"]) + floor_tile = tile_grid(grid, sprites_registry["floor"]) + background = jnp.where(mask[..., None], wall_tile, floor_tile) + return background + + +def tile_grid(grid: Array, tile: Array) -> Array: + """Tiles a grid (H, W) with copies of `tile` (h, w, 3) to get a final array + of shape (H * h, W * w, 3) and dtype `jnp.uint8`""" + tiled = jnp.tile(tile, (*grid.shape, 1)) + return jnp.asarray(tiled, dtype=jnp.uint8) + + +def flatten_patches( + image: Array, patch_size: Tuple[int, 
int] = (TILE_SIZE, TILE_SIZE) +) -> Array: + height = image.shape[0] // patch_size[0] + width = image.shape[1] // patch_size[1] + n_channels = image.shape[2] + + grid = image.reshape(height, patch_size[0], width, patch_size[1], n_channels) + + # Swap the first and second axes of the grid to revert the stacking order + grid = jnp.swapaxes(grid, 1, 2) + + # Reshape the grid of tiles into the original list of tiles + patches = grid.reshape(height * width, patch_size[0], patch_size[1], n_channels) + + return patches + + +def unflatten_patches(patches: Array, image_size: Tuple[int, int]) -> Array: + image_height = image_size[0] + image_width = image_size[1] + patch_height = patches.shape[1] + patch_width = patches.shape[2] + n_channels = patches.shape[3] + + # Reshape the list of tiles into a 2D grid + grid = patches.reshape( + image_height // patch_height, + image_width // patch_width, + patch_height, + patch_width, + n_channels, + ) + + # Swap the first and second axes of the grid to change the order of stacking + grid = jnp.swapaxes(grid, 1, 2) + + # Reshape and stack the grid tiles horizontally and vertically to form the final image + image = grid.reshape(image_height, image_width, n_channels) + + return image diff --git a/navix/rendering/registry.py b/navix/rendering/registry.py new file mode 100644 index 0000000..f899519 --- /dev/null +++ b/navix/rendering/registry.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +import os +from typing import Dict, Tuple +from PIL import Image + +import jax +from jax import Array +import jax.numpy as jnp + + +SPRITES_DIR = os.path.normpath( + os.path.join(__file__, "..", "..", "..", "assets", "sprites") +) +MIN_TILE__SIZE = 8 +TILE_SIZE = 32 + + +def load_sprite(name: str) -> Array: + """Loads an image from disk in RGB space. 
+ Args: + name(str): the name of the sprite file in `SPRITES_DIR`, without the `.png` extension + + Returns: + (Array): a jax.Array of shape (TILE_SIZE, TILE_SIZE, 3)""" + path = os.path.join(SPRITES_DIR, f"{name}.png") + image = Image.open(path) + array = jnp.asarray(image) + resized = jax.image.resize(array, (TILE_SIZE, TILE_SIZE, 3), method="nearest") + return resized + + +class PALETTE: + RED: Array = jnp.asarray(0) + GREEN: Array = jnp.asarray(1) + BLUE: Array = jnp.asarray(2) + PURPLE: Array = jnp.asarray(3) + YELLOW: Array = jnp.asarray(4) + GREY: Array = jnp.asarray(5) + + @classmethod + def as_string(cls): + return ["red", "green", "blue", "purple", "yellow", "grey"] + + @classmethod + def as_array(cls): + return [cls.RED, cls.GREEN, cls.BLUE, cls.PURPLE, cls.YELLOW, cls.GREY] + + +class SpritesRegistry: + def __init__(self): + self.registry = {} + self.build_registry() + + def build_registry(self): + """Populates the sprites registry for all entities.""" + self.set_wall_sprite() + self.set_floor_sprite() + self.set_goal_sprite() + self.set_key_sprite() + self.set_player_sprite() + self.set_door_sprite() + + def set_wall_sprite(self): + self.registry["wall"] = load_sprite("wall") + + def set_floor_sprite(self): + self.registry["floor"] = load_sprite("floor") + + def set_goal_sprite(self): + self.registry["goal"] = load_sprite("goal") + + def set_key_sprite(self): + keys_coloured = [ + load_sprite("key" + f"_{colour}") for colour in PALETTE.as_string() + ] + self.registry["key"] = jnp.stack(keys_coloured, axis=0) + + def set_player_sprite(self): + self.registry["player"] = jnp.stack( + [ + load_sprite("player_east"), + load_sprite("player_south"), + load_sprite("player_west"), + load_sprite("player_north"), + ] + ) + + def set_door_sprite(self): + door = jnp.zeros( + (len(PALETTE.as_string()), 3, TILE_SIZE, TILE_SIZE, 3), dtype=jnp.uint8 + ) + for c_idx, colour in enumerate(PALETTE.as_string()): + for s_idx, state in enumerate(["closed", "open", "locked"]): + sprite = load_sprite("door" + f"_{state}" + 
f"_{colour}") + door = door.at[c_idx, s_idx].set(sprite) + self.registry["door"] = door + + +# initialise sprites registry +SPRITES_REGISTRY = SpritesRegistry().registry diff --git a/navix/spaces.py b/navix/spaces.py index 748908d..f4d37b8 100644 --- a/navix/spaces.py +++ b/navix/spaces.py @@ -14,14 +14,15 @@ from __future__ import annotations +from typing import Any, Sequence, Union import jax import jax.numpy as jnp -from jax.random import KeyArray -from jax.core import Shape from jax import Array -from jax.core import ShapedArray, Shape +from jax.core import ShapedArray + +Shape = Sequence[Union[int, Any]] MIN_INT = jax.numpy.iinfo(jnp.int16).min @@ -39,7 +40,7 @@ def __repr__(self): super().__repr__()[:-1], self.minimum, self.maximum ) - def sample(self, key: KeyArray) -> Array: + def sample(self, key: Array) -> Array: raise NotImplementedError() @@ -49,10 +50,8 @@ def __init__(self, n_elements: int = MAX_INT, shape: Shape = (), dtype=jnp.int32 self.minimum = jnp.asarray(0) self.maximum = jnp.asarray(n_elements - 1) - def sample(self, key: KeyArray) -> Array: - item = jax.random.randint( - key, self.shape, self.minimum, self.maximum - ) + def sample(self, key: Array) -> Array: + item = jax.random.randint(key, self.shape, self.minimum, self.maximum) # randint cannot draw jnp.uint, so we cast it later return jnp.asarray(item, dtype=self.dtype) @@ -63,7 +62,7 @@ def __init__(self, shape: Shape = (), minimum=MIN_INT_ARR, maximum=MAX_INT_ARR): self.minimum = minimum self.maximum = maximum - def sample(self, key: KeyArray) -> Array: + def sample(self, key: Array) -> Array: assert jnp.issubdtype(self.dtype, jnp.floating) # see: https://github.com/google/jax/issues/14003 lower = jnp.nan_to_num(self.minimum) diff --git a/pyproject.toml b/pyproject.toml index 9a6fc9c..0bc110b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,7 +63,7 @@ dependencies = {file = "./requirements.txt"} [tool.setuptools.packages.find] -include = ["navix*"] +include = ["navix*", "assets*"] 
exclude = ["tests", "examples", "scripts", "docs"] diff --git a/requirements.txt b/requirements.txt index 4758433..6ec6554 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,6 +9,7 @@ setuptools_scm pytest absl-py # compute libraries +pillow numpy jax chex diff --git a/tests/test_actions.py b/tests/test_actions.py index a70ea8d..918dbce 100644 --- a/tests/test_actions.py +++ b/tests/test_actions.py @@ -5,6 +5,7 @@ import navix as nx from navix.components import EMPTY_POCKET_ID, DISCARD_PILE_COORDS from navix.entities import Entities, Entity, State +from navix.rendering.registry import PALETTE def test_rotation(): @@ -15,7 +16,7 @@ def test_rotation(): player = nx.entities.Player( position=jnp.asarray((1, 1)), direction=direction, pocket=EMPTY_POCKET_ID )[None] - cache = nx.graphics.RenderingCache.init(grid) + cache = nx.rendering.cache.RenderingCache.init(grid) entities: Dict[str, Entity] = { Entities.PLAYER: player, @@ -62,14 +63,16 @@ def test_move(): position=jnp.asarray((1, 1)), direction=jnp.asarray(0), pocket=EMPTY_POCKET_ID ) goals = nx.entities.Goal(position=jnp.asarray((3, 3)), probability=jnp.asarray(1.0)) - keys = nx.entities.Key(position=jnp.asarray((3, 1)), id=jnp.asarray(-1)) + keys = nx.entities.Key( + position=jnp.asarray((3, 1)), id=jnp.asarray(-1), colour=PALETTE.YELLOW + ) doors = nx.entities.Door( position=jnp.asarray((2, 2)), - direction=jnp.asarray(0), requires=jnp.asarray(-1), open=jnp.asarray(False), + colour=PALETTE.YELLOW, ) - cache = nx.graphics.RenderingCache.init(grid) + cache = nx.rendering.cache.RenderingCache.init(grid) player.check_ndim(batched=False) goals.check_ndim(batched=False) @@ -160,14 +163,16 @@ def test_walkable(): position=jnp.asarray((1, 1)), direction=jnp.asarray(0), pocket=EMPTY_POCKET_ID ) goals = nx.entities.Goal(position=jnp.asarray((3, 3)), probability=jnp.asarray(1.0)) - keys = nx.entities.Key(position=jnp.asarray((3, 1)), id=jnp.asarray(1)) + keys = nx.entities.Key( + position=jnp.asarray((3, 1)), 
id=jnp.asarray(1), colour=PALETTE.YELLOW + ) doors = nx.entities.Door( position=jnp.asarray((1, 3)), - direction=jnp.asarray(0), requires=jnp.asarray(1), open=jnp.asarray(False), + colour=PALETTE.YELLOW, ) - cache = nx.graphics.RenderingCache.init(grid) + cache = nx.rendering.cache.RenderingCache.init(grid) player.check_ndim(batched=False) goals.check_ndim(batched=False) @@ -245,14 +250,16 @@ def test_pickup(): position=jnp.asarray((1, 1)), direction=jnp.asarray(1), pocket=EMPTY_POCKET_ID ) goals = nx.entities.Goal(position=jnp.asarray((3, 3)), probability=jnp.asarray(1.0)) - keys = nx.entities.Key(position=jnp.asarray((2, 1)), id=jnp.asarray(1)) + keys = nx.entities.Key( + position=jnp.asarray((2, 1)), id=jnp.asarray(1), colour=PALETTE.YELLOW + ) doors = nx.entities.Door( position=jnp.asarray((1, 3)), - direction=jnp.asarray(0), requires=jnp.asarray(1), open=jnp.asarray(False), + colour=PALETTE.YELLOW, ) - cache = nx.graphics.RenderingCache.init(grid) + cache = nx.rendering.cache.RenderingCache.init(grid) # Looks like this """ @@ -311,14 +318,16 @@ def test_open(): position=jnp.asarray((1, 1)), direction=jnp.asarray(0), pocket=EMPTY_POCKET_ID ) goals = nx.entities.Goal(position=jnp.asarray((3, 3)), probability=jnp.asarray(1.0)) - keys = nx.entities.Key(position=jnp.asarray((3, 1)), id=jnp.asarray(1)) + keys = nx.entities.Key( + position=jnp.asarray((3, 1)), id=jnp.asarray(1), colour=PALETTE.YELLOW + ) doors = nx.entities.Door( position=jnp.asarray((1, 3)), - direction=jnp.asarray(0), requires=jnp.asarray(1), open=jnp.asarray(False), + colour=PALETTE.YELLOW ) - cache = nx.graphics.RenderingCache.init(grid) + cache = nx.rendering.cache.RenderingCache.init(grid) player.check_ndim(batched=False) goals.check_ndim(batched=False) diff --git a/tests/test_entities.py b/tests/test_entities.py index 9283455..f7205e8 100644 --- a/tests/test_entities.py +++ b/tests/test_entities.py @@ -3,6 +3,7 @@ import navix as nx from navix.entities import Goal, Player +from 
navix.rendering.registry import TILE_SIZE def test_indexing(): @@ -19,11 +20,11 @@ def test_indexing(): def test_get_sprites(): # batched entity with batch size 1 entity = Goal(position=jnp.ones((1, 2)), probability=jnp.ones((1,))) - assert entity.sprite.shape == (1, nx.graphics.TILE_SIZE, nx.graphics.TILE_SIZE, 3) + assert entity.sprite.shape == (1, TILE_SIZE, TILE_SIZE, 3) # batched entity with batch size > 1 entity = Goal(position=jnp.ones((5, 2)), probability=jnp.ones((5,))) - assert entity.sprite.shape == (5, nx.graphics.TILE_SIZE, nx.graphics.TILE_SIZE, 3) + assert entity.sprite.shape == (5, TILE_SIZE, TILE_SIZE, 3) if __name__ == "__main__": diff --git a/tests/test_observations.py b/tests/test_observations.py index 22c6b59..5f10197 100644 --- a/tests/test_observations.py +++ b/tests/test_observations.py @@ -4,6 +4,8 @@ import navix as nx from navix.entities import Entities, Player, Goal, Key, Door from navix.components import EMPTY_POCKET_ID +from navix.rendering.cache import RenderingCache, TILE_SIZE +from navix.rendering.registry import SPRITES_REGISTRY, PALETTE def test_rgb(): @@ -16,12 +18,12 @@ def test_rgb(): position=jnp.asarray((1, 1)), direction=jnp.asarray(0), pocket=EMPTY_POCKET_ID ) goals = Goal(position=jnp.asarray((4, 4)), probability=jnp.asarray(1.0)) - keys = Key(position=jnp.asarray((2, 2)), id=jnp.asarray(0)) + keys = Key(position=jnp.asarray((2, 2)), id=jnp.asarray(0), colour=PALETTE.YELLOW) doors = Door( position=jnp.asarray([(1, 5), (1, 6)]), - direction=jnp.asarray((0, 2)), requires=jnp.asarray((0, 0)), open=jnp.asarray((False, True)), + colour=PALETTE.YELLOW[None], ) entities = { @@ -34,10 +36,10 @@ def test_rgb(): state = nx.entities.State( key=jax.random.PRNGKey(0), grid=grid, - cache=nx.graphics.RenderingCache.init(grid), + cache=RenderingCache.init(grid), entities=entities, ) - sprites_registry = nx.graphics.SPRITES_REGISTRY + sprites_registry = SPRITES_REGISTRY doors = state.get_doors() doors = 
doors.replace(open=jnp.asarray((False, True))) @@ -45,8 +47,8 @@ def test_rgb(): obs = nx.observations.rgb(state) expected_obs_shape = ( - height * nx.graphics.TILE_SIZE, - width * nx.graphics.TILE_SIZE, + height * TILE_SIZE, + width * TILE_SIZE, 3, ) assert ( @@ -54,9 +56,9 @@ def test_rgb(): ), f"Expected observation {expected_obs_shape}, got {obs.shape} instead" def get_tile(position): - x = position[0] * nx.graphics.TILE_SIZE - y = position[1] * nx.graphics.TILE_SIZE - return obs[x : x + nx.graphics.TILE_SIZE, y : y + nx.graphics.TILE_SIZE, :] + x = position[0] * TILE_SIZE + y = position[1] * TILE_SIZE + return obs[x : x + TILE_SIZE, y : y + TILE_SIZE, :] player = state.get_player() player_tile = get_tile(player.position) @@ -70,21 +72,26 @@ def get_tile(position): keys = state.get_keys() key_tile = get_tile(keys.position[0]) - assert jnp.array_equal(key_tile, sprites_registry[Entities.KEY]), key_tile + colour = keys.colour[0] + assert jnp.array_equal( + key_tile, sprites_registry[Entities.KEY][colour] + ), key_tile doors = state.get_doors() - door_tile = get_tile(doors.position[0]) - direction = doors.direction[0] - open = jnp.asarray(doors.open[0], dtype=jnp.int32) + door = doors[0] + door_tile = get_tile(door.position) + colour = door.colour + idx = jnp.asarray(door.open + 2 * door.locked, dtype=jnp.int32) assert jnp.array_equal( - door_tile, sprites_registry[Entities.DOOR][direction, open] + door_tile, sprites_registry[Entities.DOOR][colour, idx] ), door_tile - door_tile = get_tile(doors.position[1]) - direction = doors.direction[1] - open = jnp.asarray(doors.open[1], dtype=jnp.int32) + door = doors[1] + door_tile = get_tile(door.position) + colour = door.colour + idx = jnp.asarray(door.open + 2 * door.locked, dtype=jnp.int32) assert jnp.array_equal( - door_tile, sprites_registry[Entities.DOOR][direction, open] + door_tile, sprites_registry[Entities.DOOR][colour, idx] ), door_tile return @@ -100,12 +107,12 @@ def test_categorical_first_person(): 
position=jnp.asarray((1, 1)), direction=jnp.asarray(0), pocket=EMPTY_POCKET_ID ) goals = Goal(position=jnp.asarray((4, 4)), probability=jnp.asarray(1.0)) - keys = Key(position=jnp.asarray((2, 2)), id=jnp.asarray(0)) + keys = Key(position=jnp.asarray((2, 2)), id=jnp.asarray(0), colour=PALETTE.YELLOW) doors = Door( position=jnp.asarray([(1, 5), (1, 6)]), - direction=jnp.asarray((0, 2)), requires=jnp.asarray((0, 0)), open=jnp.asarray((False, True)), + colour=PALETTE.YELLOW, ) entities = { Entities.PLAYER: players[None], @@ -117,7 +124,7 @@ def test_categorical_first_person(): state = nx.entities.State( key=jax.random.PRNGKey(0), grid=grid, - cache=nx.graphics.RenderingCache.init(grid), + cache=RenderingCache.init(grid), entities=entities, ) @@ -127,5 +134,5 @@ def test_categorical_first_person(): if __name__ == "__main__": test_rgb() - test_categorical_first_person() + # test_categorical_first_person() # jax.jit(test_categorical_first_person)() diff --git a/tests/test_tasks.py b/tests/test_tasks.py index ac740dc..2e35bfd 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -4,6 +4,7 @@ import navix as nx from navix.entities import Entities, Player, Goal, Key, Door from navix.components import EMPTY_POCKET_ID +from navix.rendering.registry import PALETTE def test_navigation(): @@ -16,13 +17,15 @@ def test_navigation(): players = Player( position=jnp.asarray((1, 1)), direction=jnp.asarray(0), pocket=EMPTY_POCKET_ID ) - goals = Goal(position=jnp.asarray([(1, 1), (1, 1)]), probability=jnp.asarray([0.0, 0.0])) - keys = Key(position=jnp.asarray((2, 2)), id=jnp.asarray(0)) + goals = Goal( + position=jnp.asarray([(1, 1), (1, 1)]), probability=jnp.asarray([0.0, 0.0]) + ) + keys = Key(position=jnp.asarray((2, 2)), id=jnp.asarray(0), colour=PALETTE.YELLOW) doors = Door( position=jnp.asarray([(1, 5), (1, 6)]), - direction=jnp.asarray((0, 2)), requires=jnp.asarray((0, 0)), open=jnp.asarray((False, True)), + colour=PALETTE.YELLOW, ) entities = { @@ -35,7 +38,7 @@ def 
test_navigation(): state = nx.entities.State( key=jax.random.PRNGKey(0), grid=grid, - cache=nx.graphics.RenderingCache.init(grid), + cache=nx.rendering.cache.RenderingCache.init(grid), entities=entities, ) action = jnp.asarray(0)