Skip to content

Commit

Permalink
Merge branch 'main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
sash-a authored Nov 4, 2024
2 parents c955320 + 66dfc93 commit 6b34657
Show file tree
Hide file tree
Showing 23 changed files with 128 additions and 129 deletions.
27 changes: 22 additions & 5 deletions .github/workflows/tests_linters.yml
Original file line number Diff line number Diff line change
@@ -1,34 +1,51 @@
name: Tests and Linters 🧪

on: [ push, pull_request ]
on: [ pull_request ]

jobs:
tests-and-linters:
name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}"
runs-on: "${{ matrix.os }}"
timeout-minutes: 20

strategy:
matrix:
python-version: ["3.8", "3.9"]
python-version: ["3.10", "3.11", "3.12"]
os: [ubuntu-latest]

steps:
- name: Install dependencies for viewer test
run: sudo apt-get update && sudo apt-get install -y xvfb

- name: Checkout jumanji 🐍
uses: actions/checkout@v3
- uses: actions/setup-python@v4
uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v3
with:
version: "0.4.26"
enable-cache: true
cache-dependency-glob: "requirements/requirements**.txt" # invalidate cache when requirements file changes

- uses: actions/setup-python@v5
with:
python-version: "${{ matrix.python-version }}"

- name: Install python dependencies 🔧
run: pip install .[dev,train]
run: uv pip install .[dev,train]
env:
UV_SYSTEM_PYTHON: 1

- name: Run linters 🖌️
run: pre-commit run --all-files --verbose

- name: Run tests 🧪
run: pytest -n 2 --cov=jumanji --cov-report=term-missing --junit-xml=test-results.xml -vv jumanji

- name: Run coverage
run: |
coverage html --directory=coverage_html_report
coverage report --fail-under=0.97
- name: Test build docs 📖
run: mkdocs build --verbose --site-dir docs_public
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ cython_debug/
# MacBook Finder
.DS_Store

3.8/
3.10/
jumanji_env/
**/outputs/
*.xml
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ repos:
name: "Trailing whitespace fixer"

- repo: https://github.com/PyCQA/flake8
rev: 3.9.2
rev: 7.1.1
hooks:
- id: flake8
name: "Linter"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ Alternatively, you can install the latest development version directly from GitH
pip install git+https://github.com/instadeepai/jumanji.git
```

Jumanji has been tested on Python 3.8 and 3.9.
Jumanji has been tested on Python 3.10, 3.11 and 3.12.
Note that because the installation of JAX differs depending on your hardware accelerator,
we advise users to explicitly install the correct JAX version (see the
[official installation guide](https://github.com/google/jax#installation)).
Expand Down
2 changes: 1 addition & 1 deletion jumanji/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def reward_spec(self) -> specs.Array:

@cached_property
def discount_spec(self) -> specs.BoundedArray:
"""Returns the discount spec. By default, this is assumed to be a single float between 0 and 1.
"""Returns the discount spec. By default, this is assumed to be a float between 0 and 1.
Returns:
discount_spec: a `specs.BoundedArray` spec.
Expand Down
2 changes: 1 addition & 1 deletion jumanji/environments/logic/game_2048/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def make_frame(state_index: int) -> None:
return self._animation

def get_fig_ax(self) -> Tuple[plt.Figure, plt.Axes]:
"""This function returns a `Matplotlib` figure and axes object for displaying the 2048 game board.
"""This function returns a `Matplotlib` figure and axes for displaying the 2048 game board.
Returns:
A tuple containing the figure and axes objects.
Expand Down
2 changes: 1 addition & 1 deletion jumanji/environments/logic/graph_coloring/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def animate(
interval: int = 200,
save_path: Optional[str] = None,
) -> animation.FuncAnimation:
"""Creates an animated gif of the `GraphColoring` environment based on the sequence of game states.
"""Creates an animated gif of the `GraphColoring` environment based on a sequence of states.
Args:
states: is a list of `State` objects representing the sequence of game states.
Expand Down
4 changes: 2 additions & 2 deletions jumanji/environments/logic/sliding_tile_puzzle/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def animate(
interval: int = 200,
save_path: Optional[str] = None,
) -> matplotlib.animation.FuncAnimation:
"""Creates an animated gif of the sliding tiles puzzle game based on the sequence of game states.
"""Creates an animated gif of the sliding tiles puzzle game based on a sequence of states.
Args:
states: is a list of `State` objects representing the sequence of game states.
Expand Down Expand Up @@ -101,7 +101,7 @@ def make_frame(state_index: int) -> None:
return self._animation

def get_fig_ax(self) -> Tuple[plt.Figure, plt.Axes]:
"""This function returns a `Matplotlib` figure and axes object for displaying the game puzzle.
"""This function returns a `Matplotlib` figure and axes for displaying the puzzle.
Returns:
A tuple containing the figure and axes objects.
Expand Down
6 changes: 3 additions & 3 deletions jumanji/environments/packing/bin_pack/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,9 +397,9 @@ def close(self) -> None:
def _make_observation_and_extras(
self, state: State
) -> Tuple[State, Observation, Dict]:
"""Computes the observation and the environment metrics to include in `timestep.extras`. Also
updates the `action_mask` and `sorted_ems_indexes` in the state. The observation is obtained
by selecting a subset of all EMSs, namely the `obs_num_ems` largest ones.
"""Computes the observation and the environment metrics to include in `timestep.extras`.
Also updates the `action_mask` and `sorted_ems_indexes` in the state. The observation is
obtained by selecting a subset of all EMSs, namely the `obs_num_ems` largest ones.
"""
obs_ems, obs_ems_mask, sorted_ems_indexes = self._get_set_of_largest_ems(
state.ems, state.ems_mask
Expand Down
4 changes: 2 additions & 2 deletions jumanji/environments/packing/flat_pack/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@


class InstanceGenerator(abc.ABC):
"""Base class for generators for the flat_pack environment. An `InstanceGenerator` is responsible
for generating a problem instance when the environment is reset.
"""Base class for generators for the flat_pack environment. An `InstanceGenerator` is
responsible for generating a problem instance when the environment is reset.
"""

def __init__(
Expand Down
2 changes: 1 addition & 1 deletion jumanji/environments/routing/mmst/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __call__(self, key: chex.PRNGKey) -> State:


class SplitRandomGenerator(Generator):
"""Generates a random environments that is solvable by spliting the graph into multiple sub graphs.
"""Generates a random environments that is solvable by spliting the graph into sub graphs.
Returns a graph and with a desired number of edges and nodes to connect per agent.
"""
Expand Down
8 changes: 4 additions & 4 deletions jumanji/environments/routing/multi_cvrp/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,10 @@ def _draw_route(self, ax: plt.Axes, coords: chex.Array, col_id: int) -> None:
ax.scatter(x, y, s=self.NODE_SIZE, color=self._cmap(col_id))

def _add_tour(self, ax: plt.Axes, state: State) -> None:
"""Add the customers and the depot to the plot, and draw each route in the tour in a different
colour. The tour is the entire trajectory between the visited customers and a route is a
trajectory either starting and ending at the depot or starting at the depot and ending at
the current city."""
"""Add the customers and the depot to the plot, and draw each route in the tour in a
different colour. The tour is the entire trajectory between the visited customers and a
route is a trajectory either starting and ending at the depot or starting at the depot
and ending at the current city."""
x_coords, y_coords = (
state.nodes.coordinates[:, 0] / self._map_max,
state.nodes.coordinates[:, 1] / self._map_max,
Expand Down
4 changes: 2 additions & 2 deletions jumanji/environments/routing/robot_warehouse/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@

@pytest.fixture(scope="module")
def robot_warehouse_env() -> RobotWarehouse:
"""Instantiates a default RobotWarehouse environment with 2 agents, 1 shelf row, 3 shelf columns,
a column height of 2, sensor range of 1 and a request queue size of 4."""
"""Instantiates a default RobotWarehouse environment with 2 agents, 1 shelf row, 3 shelf
columns, a column height of 2, sensor range of 1 and a request queue size of 4."""
generator = RandomGenerator(
shelf_rows=1,
shelf_columns=3,
Expand Down
4 changes: 3 additions & 1 deletion jumanji/environments/routing/robot_warehouse/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,9 @@ def observation_spec(self) -> specs.Spec[Observation]:

@cached_property
def action_spec(self) -> specs.MultiDiscreteArray:
"""Returns the action spec. 5 actions: [0,1,2,3,4] -> [No Op, Forward, Left, Right, Toggle_load].
"""Returns the action spec. 5 actions:
[0,1,2,3,4] -> [No Op, Forward, Left, Right, Toggle_load].
Since this is a multi-agent environment, the environment expects an array of actions.
This array is of shape (num_agents,).
"""
Expand Down
8 changes: 4 additions & 4 deletions jumanji/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@


class Spec(abc.ABC, Generic[T]):
"""Adapted from `dm_env.spec.Array`. This is an augmentation of the `Array` spec to allow for nested
specs. `self.name`, `self.generate_value` and `self.validate` methods are adapted from the
`dm_env` object."""
"""Adapted from `dm_env.spec.Array`. This is an augmentation of the `Array` spec to allow for
nested specs. `self.name`, `self.generate_value` and `self.validate` methods are adapted from
the `dm_env` object."""

def __init__(
self,
Expand Down Expand Up @@ -139,7 +139,7 @@ def __getitem__(self, item: str) -> "Spec":


class Array(Spec[chex.Array]):
"""Describes a jax array spec. This is adapted from `dm_env.specs.Array` to suit Jax environments.
"""Describes a jax array spec. This is adapted from `dm_env.specs.Array` for Jax environments.
An `Array` spec allows an API to describe the arrays that it accepts or returns, before that
array exists.
Expand Down
14 changes: 7 additions & 7 deletions jumanji/specs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ def test_array(self) -> None:
converted_spec: dm_env.specs.Array = specs.jumanji_specs_to_dm_env_specs(
jumanji_spec
)
assert type(converted_spec) == type(dm_env_spec)
assert type(converted_spec) is type(dm_env_spec)
assert converted_spec.shape == dm_env_spec.shape
assert converted_spec.dtype == dm_env_spec.dtype
assert converted_spec.name == dm_env_spec.name
Expand All @@ -602,7 +602,7 @@ def test_bounded_array(self) -> None:
converted_spec: dm_env.specs.BoundedArray = specs.jumanji_specs_to_dm_env_specs(
jumanji_spec
)
assert type(converted_spec) == type(dm_env_spec)
assert type(converted_spec) is type(dm_env_spec)
assert converted_spec.shape == dm_env_spec.shape
assert converted_spec.dtype == dm_env_spec.dtype
assert converted_spec.name == dm_env_spec.name
Expand All @@ -615,7 +615,7 @@ def test_discrete_array(self) -> None:
converted_spec: dm_env.specs.DiscreteArray = (
specs.jumanji_specs_to_dm_env_specs(jumanji_spec)
)
assert type(converted_spec) == type(dm_env_spec)
assert type(converted_spec) is type(dm_env_spec)
assert converted_spec.shape == dm_env_spec.shape
assert converted_spec.dtype == dm_env_spec.dtype
assert converted_spec.name == dm_env_spec.name
Expand Down Expand Up @@ -675,7 +675,7 @@ def test_array(self) -> None:
jumanji_spec = specs.Array((1, 2), jnp.int32)
gym_space = gym.spaces.Box(-np.inf, np.inf, (1, 2), jnp.int32)
converted_spec = specs.jumanji_specs_to_gym_spaces(jumanji_spec)
assert type(converted_spec) == type(gym_space)
assert type(converted_spec) is type(gym_space)
assert_trees_all_equal(converted_spec.low, gym_space.low)
assert_trees_all_equal(converted_spec.high, gym_space.high)
assert converted_spec.shape == gym_space.shape
Expand All @@ -687,7 +687,7 @@ def test_bounded_array(self) -> None:
)
gym_space = gym.spaces.Box(low=0.0, high=1.0, shape=(1, 2), dtype=jnp.float32)
converted_spec = specs.jumanji_specs_to_gym_spaces(jumanji_spec)
assert type(converted_spec) == type(gym_space)
assert type(converted_spec) is type(gym_space)
assert converted_spec.shape == gym_space.shape
assert converted_spec.dtype == gym_space.dtype
assert_trees_all_equal(converted_spec.low, gym_space.low)
Expand All @@ -697,7 +697,7 @@ def test_discrete_array(self) -> None:
jumanji_spec = specs.DiscreteArray(num_values=5, dtype=jnp.int32)
gym_space = gym.spaces.Discrete(n=5)
converted_spec = specs.jumanji_specs_to_gym_spaces(jumanji_spec)
assert type(converted_spec) == type(gym_space)
assert type(converted_spec) is type(gym_space)
assert converted_spec.shape == gym_space.shape
assert converted_spec.dtype == gym_space.dtype
assert converted_spec.n == gym_space.n
Expand All @@ -708,7 +708,7 @@ def test_multi_discrete_array(self) -> None:
)
gym_space = gym.spaces.MultiDiscrete(nvec=[5, 6])
converted_spec = specs.jumanji_specs_to_gym_spaces(jumanji_spec)
assert type(converted_spec) == type(gym_space)
assert type(converted_spec) is type(gym_space)
assert converted_spec.shape == gym_space.shape
assert converted_spec.dtype == gym_space.dtype
assert jnp.array_equal(converted_spec.nvec, gym_space.nvec)
Expand Down
2 changes: 1 addition & 1 deletion jumanji/training/networks/graph_coloring/actor_critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def __call__(self, observation: Observation) -> chex.Array:
mlp_units=self.transformer_mlp_units,
w_init_scale=2 / self.num_transformer_layers,
model_size=self.model_size,
name=f"cross_attention_color_node_block_{block_id+1}",
name=f"cross_attention_color_node_block_{block_id + 1}",
)(color_embeddings, current_node_embeddings, current_node_embeddings)

return new_embedding
Expand Down
2 changes: 1 addition & 1 deletion jumanji/training/networks/mmst/actor_critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def __call__(self, observation: Observation) -> chex.Array:
mlp_units=self.transformer_mlp_units,
w_init_scale=2 / self.num_transformer_layers,
model_size=self.model_size,
name=f"cross_attention_agent_node_block_{block_id+1}",
name=f"cross_attention_agent_node_block_{block_id + 1}",
)(agents_embeddings, current_node_embeddings, current_node_embeddings)

return new_embedding
Expand Down
4 changes: 2 additions & 2 deletions jumanji/training/networks/tsp/actor_critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def __init__(
transformer_mlp_units: Sequence[int],
name: Optional[str] = None,
):
"""Linear embedding of all cities' coordinates followed by `transformer_num_blocks` blocks of self
attention.
"""Linear embedding of all cities' coordinates followed by `transformer_num_blocks` blocks
of self attention.
"""
super().__init__(name=name)
self.transformer_num_blocks = transformer_num_blocks
Expand Down
52 changes: 49 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,51 @@
[tool.isort]
profile = "black"
[build-system]
requires=["setuptools>=62.6"]
build-backend="setuptools.build_meta"

[project]
name="jumanji"
authors=[{name="InstaDeep Ltd", email="[email protected]"}]
dynamic=["version", "dependencies", "optional-dependencies"]
license={file="LICENSE"}
description="A diverse suite of scalable reinforcement learning environments in JAX"
readme ="README.md"
requires-python=">=3.10"
keywords=["reinforcement-learning", "python", "jax"]
classifiers=[
"Development Status :: 5 - Production/Stable",
"Environment :: Console",
"Intended Audience :: Science/Research",
"Intended Audience :: Developers",
"Operating System :: OS Independent",
"Programming Language :: Python",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development :: Libraries :: Python Modules",
"License :: OSI Approved :: Apache Software License",
]

[tool.setuptools.packages.find]
include=["jumanji*"]

[tool.setuptools.package-data]
"jumanji" = ["py.typed"]

[tool.setuptools.dynamic]
version={attr="jumanji.version.__version__"}
dependencies={file="requirements/requirements.txt"}
optional-dependencies.dev={file=["requirements/requirements-dev.txt"]}
optional-dependencies.train={file=["requirements/requirements-train.txt"]}


[project.urls]
"Homepage"="https://github.com/instadeep/jumanji"
"Bug Tracker"="https://github.com/instadeep/jumanji/issues"
"Documentation"="https://instadeepai.github.io/jumanji"

[tool.mypy]
python_version = 3.8
python_version = "3.10"
namespace_packages = true
incremental = false
cache_dir = ""
Expand Down Expand Up @@ -47,3 +90,6 @@ module = [
"PIL.*",
]
ignore_missing_imports = true

[tool.isort]
profile = "black"
7 changes: 1 addition & 6 deletions requirements/requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
black==22.3.0
coverage
flake8==3.9.2
importlib-metadata<5.0
flake8
isort==5.11.5
livereload
mkdocs==1.2.3
Expand All @@ -22,9 +21,5 @@ pytest-cov
pytest-mock
pytest-parallel
pytest-xdist
pytype
scipy>=1.7.3
testfixtures
types-Pillow
types-requests<1.27
types-setuptools
Loading

0 comments on commit 6b34657

Please sign in to comment.