diff --git a/docs/api/environments/sliding_tile_puzzle.md b/docs/api/environments/sliding_tile_puzzle.md new file mode 100644 index 000000000..bb7855cd0 --- /dev/null +++ b/docs/api/environments/sliding_tile_puzzle.md @@ -0,0 +1,8 @@ +::: jumanji.environments.logic.sliding_tile_puzzle.env.SlidingTilePuzzle + selection: + members: + - __init__ + - reset + - step + - observation_spec + - action_spec diff --git a/docs/env_anim/sliding_tile_puzzle.gif b/docs/env_anim/sliding_tile_puzzle.gif new file mode 100644 index 000000000..960ed5748 Binary files /dev/null and b/docs/env_anim/sliding_tile_puzzle.gif differ diff --git a/docs/env_img/sliding_tile_puzzle.png b/docs/env_img/sliding_tile_puzzle.png new file mode 100644 index 000000000..5f35b2b7f Binary files /dev/null and b/docs/env_img/sliding_tile_puzzle.png differ diff --git a/docs/environments/sliding_tile_puzzle.md b/docs/environments/sliding_tile_puzzle.md new file mode 100644 index 000000000..b62e9012d --- /dev/null +++ b/docs/environments/sliding_tile_puzzle.md @@ -0,0 +1,52 @@ +# Sliding Tile Puzzle Environment + +

+ +

+ +This is a JAX JIT-able implementation of the classic [Sliding Tile Puzzle game](https://en.wikipedia.org/wiki/Sliding_puzzle). + +The Sliding Tile Puzzle game is a classic puzzle that challenges a player to slide (typically flat) pieces along certain routes (usually on a board) to establish a certain end-configuration. The pieces to be moved may consist of simple shapes, or they may be imprinted with colors, patterns, sections of a larger picture (like a jigsaw puzzle), numbers, or letters. + +The puzzle is often 3×3, 4×4 or 5×5 in size and made up of square tiles that are slid into a square base, larger than the tiles by one tile space, in a specific large configuration. Tiles are moved/arranged by sliding an adjacent tile into a position occupied by the missing tile, which creates a new space. The sliding puzzle is mechanical and requires the use of no other equipment or tools. + +## Observation + +The observation in the Sliding Tile Puzzle game includes information about the puzzle, the position of the empty tile, the action mask, and the step count. + +- `puzzle`: jax array (int32) of shape `(grid_size, grid_size)`, representing the current game state. Each element in the array corresponds to a puzzle tile. The tile represented by 0 is the empty tile. + + - Here is an example of a random observation of the game board: + + ``` + [[ 1 2 3 4] + [ 5 6 7 8] + [ 9 10 0 12] + [ 13 14 15 11]] + ``` + - In this array, the tile represented by 0 is the empty tile that can be moved. + +- `empty_tile_position`: jax array (int32) of shape `(2,)` representing the position of the empty tile in the grid. For example, (2, 2) would represent the third row and the third column in a zero-indexed grid. + +- `action_mask`: jax array (bool) of shape `(4,)`, indicating which actions are valid in the current state of the environment. The actions include moving the empty tile up, right, down, or left.
For example, an action mask `[True, False, True, False]` means that the valid actions are to move the empty tile upward or downward. + +- `step_count`: jax array (int32) of shape `()`, current number of steps in the episode. + +## Action + +The action space is a `DiscreteArray` of integer values in `[0, 1, 2, 3]`. Specifically, these four actions correspond to moving the empty tile: up (0), right (1), down (2), or left (3). + +## Reward + +The reward could be either: + +- **DenseRewardFn**: This reward function provides a dense reward based on the difference of correctly placed tiles between the current state and the next state. The reward is positive for each newly correctly placed tile and negative for each newly incorrectly placed tile. + +- **SparseRewardFn**: This reward function provides a sparse reward, only rewarding when the puzzle is solved. +The reward is 1 if the puzzle is solved, and 0 otherwise. + +The goal in all cases is to solve the puzzle in a way that maximizes the reward. + +## Registered Versions 📖 + +- `SlidingTilePuzzle-v0`, the Sliding Tile Puzzle with a grid size of 5x5. diff --git a/jumanji/__init__.py b/jumanji/__init__.py index cfa526965..5e04ad474 100644 --- a/jumanji/__init__.py +++ b/jumanji/__init__.py @@ -134,3 +134,7 @@ register(id="Sokoban-v0", entry_point="jumanji.environments:Sokoban") # Pacman - minimal version of Atarti Pacman game register(id="PacMan-v0", entry_point="jumanji.environments:PacMan") +# SlidingTilePuzzle - A sliding tile puzzle environment with the default grid size of 5x5. 
+register( + id="SlidingTilePuzzle-v0", entry_point="jumanji.environments:SlidingTilePuzzle" +) diff --git a/jumanji/environments/__init__.py b/jumanji/environments/__init__.py index 4e69e2c2a..d69fbbf8e 100644 --- a/jumanji/environments/__init__.py +++ b/jumanji/environments/__init__.py @@ -14,12 +14,20 @@ import sys -from jumanji.environments.logic import game_2048, minesweeper, rubiks_cube +from jumanji.environments.logic import ( + game_2048, + graph_coloring, + minesweeper, + rubiks_cube, + sliding_tile_puzzle, + sudoku, +) from jumanji.environments.logic.game_2048.env import Game2048 from jumanji.environments.logic.graph_coloring.env import GraphColoring -from jumanji.environments.logic.minesweeper import Minesweeper -from jumanji.environments.logic.rubiks_cube import RubiksCube -from jumanji.environments.logic.sudoku import Sudoku +from jumanji.environments.logic.minesweeper.env import Minesweeper +from jumanji.environments.logic.rubiks_cube.env import RubiksCube +from jumanji.environments.logic.sliding_tile_puzzle.env import SlidingTilePuzzle +from jumanji.environments.logic.sudoku.env import Sudoku from jumanji.environments.packing import bin_pack, flat_pack, job_shop, knapsack, tetris from jumanji.environments.packing.bin_pack.env import BinPack from jumanji.environments.packing.flat_pack.env import FlatPack @@ -44,7 +52,7 @@ from jumanji.environments.routing.cvrp.env import CVRP from jumanji.environments.routing.maze.env import Maze from jumanji.environments.routing.mmst.env import MMST -from jumanji.environments.routing.multi_cvrp import MultiCVRP +from jumanji.environments.routing.multi_cvrp.env import MultiCVRP from jumanji.environments.routing.pac_man.env import PacMan from jumanji.environments.routing.robot_warehouse.env import RobotWarehouse from jumanji.environments.routing.snake.env import Snake diff --git a/jumanji/environments/logic/sliding_tile_puzzle/__init__.py b/jumanji/environments/logic/sliding_tile_puzzle/__init__.py new file mode 100644 
index 000000000..ac684d878 --- /dev/null +++ b/jumanji/environments/logic/sliding_tile_puzzle/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from jumanji.environments.logic.sliding_tile_puzzle.env import SlidingTilePuzzle +from jumanji.environments.logic.sliding_tile_puzzle.types import Observation, State diff --git a/jumanji/environments/logic/sliding_tile_puzzle/conftest.py b/jumanji/environments/logic/sliding_tile_puzzle/conftest.py new file mode 100644 index 000000000..4138b9357 --- /dev/null +++ b/jumanji/environments/logic/sliding_tile_puzzle/conftest.py @@ -0,0 +1,42 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import jax +import jax.numpy as jnp +import pytest + +from jumanji.environments.logic.sliding_tile_puzzle import SlidingTilePuzzle +from jumanji.environments.logic.sliding_tile_puzzle.generator import RandomWalkGenerator +from jumanji.environments.logic.sliding_tile_puzzle.types import State + + +@pytest.fixture +def sliding_tile_puzzle() -> SlidingTilePuzzle: + """Instantiates a SlidingTilePuzzle environment with a 3x3 grid.""" + generator = RandomWalkGenerator(grid_size=3) + return SlidingTilePuzzle(generator=generator) + + +@pytest.fixture +def state() -> State: + key = jax.random.PRNGKey(0) + empty_pos = jnp.array([0, 0]) # the empty tile (0) sits in the top-left corner + puzzle = jnp.array( + [ + [0, 1, 3], + [4, 2, 5], + [7, 8, 6], + ] + ) + return State(puzzle=puzzle, empty_tile_position=empty_pos, key=key, step_count=0) diff --git a/jumanji/environments/logic/sliding_tile_puzzle/constants.py b/jumanji/environments/logic/sliding_tile_puzzle/constants.py new file mode 100644 index 000000000..401ae9803 --- /dev/null +++ b/jumanji/environments/logic/sliding_tile_puzzle/constants.py @@ -0,0 +1,24 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+import jax.numpy as jnp + +EMPTY_TILE = 0 +INITIAL_STEP_COUNT = 0 + +UP = [-1, 0] +RIGHT = [0, 1] +DOWN = [1, 0] +LEFT = [0, -1] + +MOVES = jnp.array([UP, RIGHT, DOWN, LEFT]) # indexed by action: 0=up, 1=right, 2=down, 3=left diff --git a/jumanji/environments/logic/sliding_tile_puzzle/env.py b/jumanji/environments/logic/sliding_tile_puzzle/env.py new file mode 100644 index 000000000..fe6a29f3d --- /dev/null +++ b/jumanji/environments/logic/sliding_tile_puzzle/env.py @@ -0,0 +1,284 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Optional, Sequence, Tuple + +import chex +import jax +import jax.numpy as jnp +import matplotlib.animation as animation +from jax import lax +from numpy.typing import NDArray + +from jumanji import specs +from jumanji.env import Environment +from jumanji.environments.logic.sliding_tile_puzzle.constants import MOVES +from jumanji.environments.logic.sliding_tile_puzzle.generator import ( + Generator, + RandomWalkGenerator, +) +from jumanji.environments.logic.sliding_tile_puzzle.reward import ( + DenseRewardFn, + RewardFn, +) +from jumanji.environments.logic.sliding_tile_puzzle.types import Observation, State +from jumanji.environments.logic.sliding_tile_puzzle.viewer import ( + SlidingTilePuzzleViewer, +) +from jumanji.types import TimeStep, restart, termination, transition +from jumanji.viewer import Viewer + + +class SlidingTilePuzzle(Environment[State]): + """Environment for the Sliding Tile Puzzle problem.
+ + The problem is a combinatorial optimization task where the goal is + to move the empty tile around in order to arrange all the tiles in order. + See more info: https://en.wikipedia.org/wiki/Sliding_puzzle. + + - observation: `Observation` + - puzzle: jax array (int32) of shape (N, N), representing the current state of the puzzle. + - empty_tile_position: jax array (int32) of shape (2,), position of the empty tile. + - action_mask: jax array (bool) of shape (4,), indicating which actions are valid. + - step_count: jax array (int32) of shape (), number of steps taken in the episode. + + - action: int32, representing the direction to move the empty tile + (0: up, 1: right, 2: down, 3: left) + + - reward: float, computed by `reward_fn`. With the default `DenseRewardFn` it is the + number of tiles newly placed correctly by the move minus the number of tiles newly + placed incorrectly; with `SparseRewardFn` it is 1 when the puzzle is solved and + 0 otherwise. + + - episode termination: if the puzzle is solved or the step count reaches `time_limit`. + + - state: `State` + - puzzle: jax array (int32) of shape (N, N), representing the current state of the puzzle. + - empty_tile_position: jax array (int32) of shape (2,), position of the empty tile. + - key: jax array (uint32) of shape (2,), random key (passed through unchanged by `step`). + - step_count: jax array (int32) of shape (), number of steps taken in the episode. + """ + + def __init__( + self, + generator: Optional[Generator] = None, + reward_fn: Optional[RewardFn] = None, + time_limit: int = 500, + viewer: Optional[Viewer[State]] = None, + ) -> None: + """Instantiate a `SlidingTilePuzzle` environment. + + Args: + generator: callable to instantiate environment instances. + Defaults to `RandomWalkGenerator` which generates shuffled puzzles with + a size of 5x5 using 200 random moves. + reward_fn: RewardFn whose `__call__` method computes the reward of an environment + transition. The function must compute the reward based on the current state, + the chosen action and the next state. + Implemented options are [`DenseRewardFn`, `SparseRewardFn`].
+ Defaults to `DenseRewardFn`. + time_limit: maximum number of steps before the episode is terminated, defaults to 500. + viewer: environment viewer for rendering. + """ + self.generator = generator or RandomWalkGenerator( + grid_size=5, num_random_moves=200 + ) + self.reward_fn = reward_fn or DenseRewardFn() + + self.time_limit = time_limit + + # Create viewer used for rendering + self._env_viewer = viewer or SlidingTilePuzzleViewer(name="SlidingTilePuzzle") + self.solved_puzzle = self.generator.make_solved_puzzle() # goal board used for termination, reward and extras + + def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: + """Resets the environment to an initial state produced by the generator.""" + key, subkey = jax.random.split(key) + state = self.generator(subkey) + action_mask = self._get_valid_actions(state.empty_tile_position) + obs = Observation( + puzzle=state.puzzle, + empty_tile_position=state.empty_tile_position, + action_mask=action_mask, + step_count=state.step_count, + ) + timestep = restart(observation=obs, extras=self._get_extras(state)) + return state, timestep + + def step( + self, state: State, action: chex.Array + ) -> Tuple[State, TimeStep[Observation]]: + """Updates the environment state after the agent takes an action.""" + (updated_puzzle, updated_empty_tile_position) = self._move_empty_tile( + state.puzzle, state.empty_tile_position, action + ) + # Check if the puzzle is solved + done = jnp.array_equal(updated_puzzle, self.solved_puzzle) + + # Update the action mask + action_mask = self._get_valid_actions(updated_empty_tile_position) + + next_state = State( + puzzle=updated_puzzle, + empty_tile_position=updated_empty_tile_position, + key=state.key, + step_count=state.step_count + 1, + ) + obs = Observation( + puzzle=updated_puzzle, + empty_tile_position=updated_empty_tile_position, + action_mask=action_mask, + step_count=next_state.step_count, + ) + + reward = self.reward_fn(state, action, next_state, self.solved_puzzle) + extras = self._get_extras(next_state) + + timestep = jax.lax.cond(
+ done | (next_state.step_count >= self.time_limit), + lambda: termination( + reward=reward, + observation=obs, + extras=extras, + ), + lambda: transition( + reward=reward, + observation=obs, + extras=extras, + ), + ) + + return next_state, timestep + + def _move_empty_tile( + self, + puzzle: chex.Array, + empty_tile_position: chex.Array, + action: chex.Array, + ) -> Tuple[chex.Array, chex.Array]: + """Moves the empty tile if the move is valid; returns the puzzle and empty tile position.""" + + # Compute the new position + new_empty_tile_position = empty_tile_position + MOVES[action] + + # The move is valid only if the new position lies inside the grid + is_valid_move = jnp.all( + (new_empty_tile_position >= 0) + & (new_empty_tile_position < self.generator.grid_size) + ) + + # Swap the empty tile and the tile at the new position (computed unconditionally; kept only if valid) + updated_puzzle = puzzle.at[tuple(empty_tile_position)].set( + puzzle[tuple(new_empty_tile_position)] + ) + updated_puzzle = updated_puzzle.at[tuple(new_empty_tile_position)].set(0) + + return lax.cond( + is_valid_move, + lambda: (updated_puzzle, new_empty_tile_position), + lambda: (puzzle, empty_tile_position), + ) + + def _get_valid_actions(self, empty_tile_position: chex.Array) -> chex.Array: + # Compute the position reached by each of the four moves (up, right, down, left) + new_positions = empty_tile_position + MOVES + + # Check if the new positions are within the grid boundaries + valid_moves_mask = jnp.all( + (new_positions >= 0) & (new_positions < self.generator.grid_size), axis=-1 + ) + + return valid_moves_mask + + def _get_extras(self, state: State) -> Dict[str, chex.Array]: + num_correct_tiles = jnp.sum(self.solved_puzzle == state.puzzle) # cells matching the goal, empty cell included + return {"prop_correctly_placed": num_correct_tiles / state.puzzle.size} + + def observation_spec(self) -> specs.Spec[Observation]: + """Returns the observation spec.""" + grid_size = self.generator.grid_size + return specs.Spec( + Observation, + "ObservationSpec", + puzzle=specs.BoundedArray( + shape=(grid_size, grid_size), + dtype=jnp.int32, + minimum=0,
+ maximum=grid_size * grid_size - 1, + name="puzzle", + ), + empty_tile_position=specs.BoundedArray( + shape=(2,), + dtype=jnp.int32, + minimum=0, + maximum=grid_size - 1, + name="empty_tile_position", + ), + action_mask=specs.BoundedArray( + shape=(4,), + dtype=bool, + minimum=False, + maximum=True, + name="action_mask", + ), + step_count=specs.BoundedArray( + shape=(), + dtype=jnp.int32, + minimum=0, + maximum=self.time_limit, + name="step_count", + ), + ) + + def action_spec(self) -> specs.DiscreteArray: + """Returns the action spec.""" + # Actions: 0 = up, 1 = right, 2 = down, 3 = left + return specs.DiscreteArray(num_values=4, name="action", dtype=jnp.int32) + + def render(self, state: State) -> Optional[NDArray]: + """Renders the current state of the puzzle board. + + Args: + state: the current game state to be rendered. + """ + return self._env_viewer.render(state=state) + + def animate( + self, + states: Sequence[State], + interval: int = 200, + save_path: Optional[str] = None, + ) -> animation.FuncAnimation: + """Creates an animated gif of the puzzle board based on the sequence of game states. + + Args: + states: a sequence of `State` objects representing the sequence of game states. + interval: the delay between frames in milliseconds, defaults to 200. + save_path: the path where the animation file should be saved. If it is None, the plot + will not be stored. + + Returns: + animation.FuncAnimation: the animation object that was created. + """ + return self._env_viewer.animate( + states=states, interval=interval, save_path=save_path + ) + + def close(self) -> None: + """Perform any necessary cleanup. + + Environments will automatically :meth:`close()` themselves when + garbage collected or when the program exits.
+ """ + self._env_viewer.close() diff --git a/jumanji/environments/logic/sliding_tile_puzzle/env_test.py b/jumanji/environments/logic/sliding_tile_puzzle/env_test.py new file mode 100644 index 000000000..31bab5f0e --- /dev/null +++ b/jumanji/environments/logic/sliding_tile_puzzle/env_test.py @@ -0,0 +1,172 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import chex +import jax +import jax.numpy as jnp + +from jumanji.environments.logic.sliding_tile_puzzle import SlidingTilePuzzle +from jumanji.environments.logic.sliding_tile_puzzle.types import State +from jumanji.testing.env_not_smoke import check_env_does_not_smoke +from jumanji.testing.pytrees import assert_is_jax_array_tree +from jumanji.types import TimeStep + + +def test_sliding_tile_puzzle_reset_jit(sliding_tile_puzzle: SlidingTilePuzzle) -> None: + """Confirm that the reset method is only compiled once when jitted.""" + chex.clear_trace_counter() + reset_fn = jax.jit(chex.assert_max_traces(sliding_tile_puzzle.reset, n=1)) + key = jax.random.PRNGKey(0) + state, timestep = reset_fn(key) + + # Verify the data type of the output. + assert isinstance(timestep, TimeStep) + assert isinstance(state, State) + + # Check that the state is made of DeviceArrays, this is false for the non-jitted. + assert_is_jax_array_tree(state.puzzle) + assert_is_jax_array_tree(state.empty_tile_position) + + # Call again to check it does not compile twice.
+ state, timestep = reset_fn(key) + assert isinstance(timestep, TimeStep) + assert isinstance(state, State) + + +def test_sliding_tile_puzzle_step_jit( + sliding_tile_puzzle: SlidingTilePuzzle, state: State +) -> None: + """Confirm that the step is only compiled once when jitted.""" + up_action = jnp.array(0) + down_action = jnp.array(2) + + chex.clear_trace_counter() + step_fn = jax.jit(chex.assert_max_traces(sliding_tile_puzzle.step, n=1)) + + new_state, _ = step_fn(state, down_action) + + # Check that the state has changed. + assert not jnp.array_equal(new_state.puzzle, state.puzzle) + + # Check that the state is made of DeviceArrays; this is false for the non-jitted step. + assert_is_jax_array_tree(new_state) + + new_state, _ = step_fn(new_state, up_action) + + # We went down and then back up, so the puzzle should be back to its initial layout. + assert jnp.array_equal(new_state.puzzle, state.puzzle) + + +def test_sliding_tile_puzzle_get_action_mask( + sliding_tile_puzzle: SlidingTilePuzzle, state: State +) -> None: + """Verify that the action mask generated by `_get_valid_actions` is correct.""" + get_valid_actions_fn = jax.jit(sliding_tile_puzzle._get_valid_actions) + action_mask = get_valid_actions_fn(state.empty_tile_position) + + # Check dtype, shape and values: with the empty tile at (0, 0) only right and down are valid. + assert action_mask.dtype == jnp.bool_ + assert action_mask.shape == (4,) + assert jnp.array_equal(action_mask, jnp.array([False, True, True, False])) + + +def test_sliding_tile_puzzle_does_not_smoke( + sliding_tile_puzzle: SlidingTilePuzzle, +) -> None: + """Test that we can run an episode without any errors.""" + check_env_does_not_smoke(sliding_tile_puzzle) + + +def test_env_one_move_to_solve(sliding_tile_puzzle: SlidingTilePuzzle) -> None: + """Test that the environment correctly handles a situation + where the puzzle is one move away from being solved. + """ + # Set up a state that is one move away from being solved.
+ one_move_away = jnp.array( + [ + [1, 2, 3], + [4, 5, 0], + [7, 8, 6], + ] + ) + empty_tile_position = jnp.array([1, 2]) + state = State( + puzzle=one_move_away, + empty_tile_position=empty_tile_position, + key=jax.random.PRNGKey(0), + step_count=0, + ) + + # The correct action to solve the puzzle is to move the empty tile down (action=2). + down_action = jnp.array(2) + next_state, timestep = sliding_tile_puzzle.step(state, down_action) + + assert jnp.array_equal(next_state.puzzle, sliding_tile_puzzle.solved_puzzle) + assert timestep.last() + assert timestep.discount == 0.0 + + +def test_env_illegal_move_does_not_change_board( + sliding_tile_puzzle: SlidingTilePuzzle, state: State +) -> None: + """Test that an illegal move does not change the board.""" + # The empty tile is in the top row, so moving it up (action=0) is illegal. + action = jnp.array(0) + next_state, _ = sliding_tile_puzzle.step(state, action) + + assert jnp.array_equal(next_state.puzzle, state.puzzle) + + +def test_env_legal_move_changes_board_as_expected( + sliding_tile_puzzle: SlidingTilePuzzle, state: State +) -> None: + """Test that legal moves change the board as expected.""" + # [ + # [0, 1, 3], + # [4, 2, 5], + # [7, 8, 6], + # ] + # A legal move is to move the empty tile down (action=2).
+ action = jnp.array(2) + next_state, _ = sliding_tile_puzzle.step(state, action) + expected_puzzle = jnp.array( + [ + [4, 1, 3], + [0, 2, 5], + [7, 8, 6], + ] + ) + assert jnp.array_equal(next_state.puzzle, expected_puzzle) + + action = jnp.array(1) # then move the empty tile right + next_state, _ = sliding_tile_puzzle.step(next_state, action) + expected_puzzle = jnp.array( + [ + [4, 1, 3], + [2, 0, 5], + [7, 8, 6], + ] + ) + assert jnp.array_equal(next_state.puzzle, expected_puzzle) + + action = jnp.array(0) # then move the empty tile up + next_state, _ = sliding_tile_puzzle.step(next_state, action) + expected_puzzle = jnp.array( + [ + [4, 0, 3], + [2, 1, 5], + [7, 8, 6], + ] + ) + assert jnp.array_equal(next_state.puzzle, expected_puzzle) diff --git a/jumanji/environments/logic/sliding_tile_puzzle/generator.py b/jumanji/environments/logic/sliding_tile_puzzle/generator.py new file mode 100644 index 000000000..f2848e5d2 --- /dev/null +++ b/jumanji/environments/logic/sliding_tile_puzzle/generator.py @@ -0,0 +1,133 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +import abc +from typing import Tuple + +import chex +import jax +from jax import numpy as jnp + +from jumanji.environments.logic.sliding_tile_puzzle.constants import EMPTY_TILE, MOVES +from jumanji.environments.logic.sliding_tile_puzzle.types import State + + +class Generator(abc.ABC): + def __init__(self, grid_size: int) -> None: + self._grid_size = grid_size + + @property + def grid_size(self) -> int: + """Size of the puzzle (n x n grid).""" + return self._grid_size + + def make_solved_puzzle(self) -> chex.Array: + """Creates a solved Sliding Tile Puzzle. + + Returns: + A (grid_size, grid_size) array holding 1..grid_size**2 - 1 in row-major order, EMPTY_TILE last. + """ + return ( + jnp.arange(1, self.grid_size**2 + 1) + .at[-1] + .set(EMPTY_TILE) + .reshape((self.grid_size, self.grid_size)) + ) + + @abc.abstractmethod + def __call__(self, key: chex.PRNGKey) -> State: + """Generate a problem instance. + + Args: + key: jax random key for any stochasticity used in the instance generation process. + + Returns: + State of the problem instance. + """ + + +class RandomWalkGenerator(Generator): + """A Sliding Tile Puzzle generator that samples solvable puzzles using a random walk + starting from the solved board. + + This generator creates puzzle configurations that are guaranteed to be solvable. + It starts with a solved configuration and makes a series of valid random moves to shuffle + the tiles. + + Args: + grid_size: The size of the puzzle (n x n grid). + num_random_moves: The number of moves to perform from the solved state. + """ + + def __init__(self, grid_size: int, num_random_moves: int = 100) -> None: + super().__init__(grid_size) + self.num_random_moves = num_random_moves + self._solved_puzzle = self.make_solved_puzzle() + + def __call__(self, key: chex.PRNGKey) -> State: + """Generate a random Sliding Tile Puzzle configuration. + + Args: + key: PRNGKey used for sampling random actions in the generation process. + + Returns: + State of the problem instance.
+ """ + # Start with the solved puzzle, whose empty tile is in the bottom-right corner + puzzle = self._solved_puzzle + empty_tile_position = jnp.array([self._grid_size - 1, self._grid_size - 1]) + + # Perform `num_random_moves` random valid moves, using one PRNG key per move + key, moves_key = jax.random.split(key) + keys = jax.random.split(moves_key, self.num_random_moves) + (puzzle, empty_tile_position), _ = jax.lax.scan( + lambda carry, key: (self._make_random_move(key, *carry), None), + (puzzle, empty_tile_position), + keys, + ) + state = State( + puzzle=puzzle, + empty_tile_position=empty_tile_position, + key=key, + step_count=jnp.zeros((), jnp.int32), + ) + return state + + def _make_random_move( + self, key: chex.PRNGKey, puzzle: chex.Array, empty_tile_position: chex.Array + ) -> Tuple[chex.Array, chex.Array]: + """Moves the empty tile in a randomly chosen valid direction.""" + new_positions = empty_tile_position + MOVES + + # Determine valid moves (fixed-size boolean mask, one entry per row of MOVES) + valid_moves_mask = jnp.all( + (new_positions >= 0) & (new_positions < self._grid_size), axis=-1 + ) + move = jax.random.choice(key, MOVES, shape=(), p=valid_moves_mask) # NOTE(review): p is a boolean mask, not normalised probabilities - confirm this is supported + new_empty_tile_position = empty_tile_position + move + # Swap the empty tile with the tile at the new position using _swap_tiles + updated_puzzle = self._swap_tiles( + puzzle, empty_tile_position, new_empty_tile_position + ) + + return updated_puzzle, new_empty_tile_position + + def _swap_tiles( + self, puzzle: chex.Array, pos1: chex.Array, pos2: chex.Array + ) -> chex.Array: + """Returns a copy of the puzzle with the values at `pos1` and `pos2` exchanged.""" + temp = puzzle[tuple(pos1)] + puzzle = puzzle.at[tuple(pos1)].set(puzzle[tuple(pos2)]) + puzzle = puzzle.at[tuple(pos2)].set(temp) + return puzzle diff --git a/jumanji/environments/logic/sliding_tile_puzzle/reward.py b/jumanji/environments/logic/sliding_tile_puzzle/reward.py new file mode 100644 index 000000000..585fad2c1 --- /dev/null +++ b/jumanji/environments/logic/sliding_tile_puzzle/reward.py @@ -0,0 +1,93 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc + +import chex +import jax.numpy as jnp + +from jumanji.environments.logic.sliding_tile_puzzle.types import State + + +class RewardFn(abc.ABC): + @abc.abstractmethod + def __call__( + self, + state: State, + action: chex.Numeric, + next_state: State, + solved_puzzle: chex.Array, + ) -> chex.Numeric: + """Compute the reward based on the current state, the chosen action, the next state, + and the solved puzzle state.""" + + +class SparseRewardFn(RewardFn): + """Reward function that provides a sparse reward, only rewarding when the puzzle is solved.""" + + def __call__( + self, + state: State, + action: chex.Numeric, + next_state: State, + solved_puzzle: chex.Array, + ) -> chex.Numeric: + """ + Calculates the sparse reward for the given transition. + + Args: + state: The current state (unused). + action: The chosen action (unused). + next_state: The state resulting from the chosen action. + solved_puzzle: What the solved puzzle looks like. + + Returns: + 1.0 if `next_state.puzzle` equals `solved_puzzle`, 0.0 otherwise. + """ + # The sparse reward is 1 if the puzzle is solved, and 0 otherwise.
+ return jnp.array_equal(next_state.puzzle, solved_puzzle).astype(float) + + +class DenseRewardFn(RewardFn): + """Reward function that provides a dense reward based on + the difference of correctly placed tiles between states.""" + + def __call__( + self, + state: State, + action: chex.Numeric, + next_state: State, + solved_puzzle: chex.Array, + ) -> chex.Numeric: + """ + Calculates a dense reward for the given state and action. This dense + reward is +1 for each cell (the empty cell included) that becomes correct + and -1 for each cell that becomes incorrect. + + Args: + state: The current state. + action: The chosen action (unused). + next_state: The state resulting from the chosen action. + solved_puzzle: What the solved puzzle looks like. + + Returns: + Number of newly correct cells minus number of newly incorrect cells, as a float. + """ + new_correct_tiles = jnp.sum( + (next_state.puzzle == solved_puzzle) & (state.puzzle != solved_puzzle) + ) + new_incorrect_tiles = jnp.sum( + (next_state.puzzle != solved_puzzle) & (state.puzzle == solved_puzzle) + ) + return (new_correct_tiles - new_incorrect_tiles).astype(float) diff --git a/jumanji/environments/logic/sliding_tile_puzzle/types.py b/jumanji/environments/logic/sliding_tile_puzzle/types.py new file mode 100644 index 000000000..df7c34ae7 --- /dev/null +++ b/jumanji/environments/logic/sliding_tile_puzzle/types.py @@ -0,0 +1,50 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+
+from typing import TYPE_CHECKING, NamedTuple
+
+import chex
+
+# Static type checkers see the standard-library dataclass, while at runtime
+# the chex dataclass is used.
+if TYPE_CHECKING:
+    from dataclasses import dataclass
+else:
+    from chex import dataclass
+
+
+@dataclass
+class State:
+    """
+    puzzle: 2D array representing the current state of the puzzle.
+    empty_tile_position: the position of the empty tile in the puzzle.
+    key: random key used for generating random numbers at each step.
+    step_count: current timestep of the episode.
+    """
+
+    puzzle: chex.Array  # (N, N)
+    empty_tile_position: chex.Array  # (2,)
+    key: chex.PRNGKey  # (2,)
+    step_count: chex.Array  # (1,)
+
+
+class Observation(NamedTuple):
+    """
+    puzzle: 2D array representing the current state of the puzzle.
+    empty_tile_position: the position of the empty tile in the puzzle.
+    action_mask: 1D array indicating the validity of each action.
+    step_count: current timestep of the episode.
+    """
+
+    puzzle: chex.Array  # (N, N)
+    empty_tile_position: chex.Array  # (2,)
+    # 4 possible actions; the environment docs list them as up, right, down,
+    # left -- NOTE(review): confirm the index order against the env.
+    action_mask: chex.Array  # (4,)
+    step_count: int  # Current timestep
diff --git a/jumanji/environments/logic/sliding_tile_puzzle/viewer.py b/jumanji/environments/logic/sliding_tile_puzzle/viewer.py
new file mode 100644
index 000000000..6596a323d
--- /dev/null
+++ b/jumanji/environments/logic/sliding_tile_puzzle/viewer.py
@@ -0,0 +1,169 @@
+# Copyright 2022 InstaDeep Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Copyright 2023 InstaDeep Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Sequence, Tuple
+
+import matplotlib.animation
+import matplotlib.colors as mcolors
+import matplotlib.pyplot as plt
+import numpy as np
+
+import jumanji.environments
+from jumanji.environments.logic.sliding_tile_puzzle.types import State
+from jumanji.viewer import Viewer
+
+
+class SlidingTilePuzzleViewer(Viewer):
+    """Matplotlib viewer that renders and animates sliding tile puzzle states."""
+
+    EMPTY_TILE_COLOR = "#ccc0b3"
+
+    def __init__(self, name: str = "SlidingTilePuzzle") -> None:
+        """Viewer for the Sliding Tile Puzzle environment.
+
+        Args:
+            name: the window name to be used when initialising the window.
+        """
+        self._name = name
+        self._animation: Optional[matplotlib.animation.Animation] = None
+        # Tiles are shaded on a white-to-blue gradient according to their value.
+        self._color_map = mcolors.LinearSegmentedColormap.from_list(
+            "", ["white", "blue"]
+        )
+
+    def render(self, state: State) -> None:
+        """Renders the current state of the game puzzle.
+
+        Args:
+            state: is the current game state to be rendered.
+        """
+        self._clear_display()
+        fig, ax = self.get_fig_ax()
+        self.draw_puzzle(ax, state)
+        self._display_human(fig)
+
+    def animate(
+        self,
+        states: Sequence[State],
+        interval: int = 200,
+        save_path: Optional[str] = None,
+    ) -> matplotlib.animation.FuncAnimation:
+        """Creates an animated gif of the sliding tiles puzzle game based on the sequence of game states.
+
+        Args:
+            states: is a list of `State` objects representing the sequence of game states.
+            interval: the delay between frames in milliseconds, default to 200.
+            save_path: the path where the animation file should be saved. If it is None, the plot
+                will not be saved.
+
+        Returns:
+            Animation object that can be saved as a GIF, MP4, or rendered with HTML.
+        """
+        fig, ax = self.get_fig_ax()
+
+        def make_frame(state_index: int) -> None:
+            state = states[state_index]
+            self.draw_puzzle(ax, state)
+
+        # Keep a reference on the instance so the animation is not garbage
+        # collected before it is displayed or saved.
+        self._animation = matplotlib.animation.FuncAnimation(
+            fig,
+            make_frame,
+            frames=len(states),
+            interval=interval,
+        )
+
+        if save_path:
+            self._animation.save(save_path)
+
+        return self._animation
+
+    def get_fig_ax(self) -> Tuple[plt.Figure, plt.Axes]:
+        """This function returns a `Matplotlib` figure and axes object for displaying the game puzzle.
+
+        Returns:
+            A tuple containing the figure and axes objects.
+        """
+        exists = plt.fignum_exists(self._name)
+        if exists:
+            # Reuse the figure (and its first axes) registered under this name.
+            fig = plt.figure(self._name)
+            ax = fig.get_axes()[0]
+        else:
+            fig = plt.figure(self._name, figsize=(6.0, 6.0))
+            if not plt.isinteractive():
+                fig.show()
+            ax = fig.add_subplot()
+
+        return fig, ax
+
+    def draw_puzzle(self, ax: plt.Axes, state: State) -> None:
+        """Draw the game puzzle with the current state.
+
+        Args:
+            ax: the axis to draw the puzzle on.
+            state: the current state of the game.
+        """
+        ax.clear()
+        # Convert the board to a NumPy array once: indexing the original array
+        # cell by cell inside the Python loops below would otherwise issue one
+        # array operation per tile.
+        puzzle = np.asarray(state.puzzle)
+        grid_size = puzzle.shape[0]
+        ax.set_xticks(np.arange(-0.5, grid_size - 1, 1), minor=True)
+        ax.set_yticks(np.arange(-0.5, grid_size - 1, 1), minor=True)
+        ax.grid(which="minor", color="black", linestyle="-", linewidth=2)
+
+        # Render the puzzle
+        for row in range(grid_size):
+            for col in range(grid_size):
+                tile_value = puzzle[row, col]
+                if tile_value == 0:
+                    # Render the empty tile
+                    rect = plt.Rectangle(
+                        [col - 0.5, row - 0.5], 1, 1, color=self.EMPTY_TILE_COLOR
+                    )
+                    ax.add_patch(rect)
+                else:
+                    # Render the numbered tile
+                    ax.text(col, row, str(tile_value), ha="center", va="center")
+
+        # Show the image of the puzzle.
+        ax.imshow(puzzle, cmap=self._color_map)
+
+    def close(self) -> None:
+        """Close the matplotlib figure registered under this viewer's name."""
+        plt.close(self._name)
+
+    def _display_human(self, fig: plt.Figure) -> None:
+        """Refresh `fig` on screen, for both interactive and non-interactive modes."""
+        if plt.isinteractive():
+            # Required to update render when using Jupyter Notebook.
+            fig.canvas.draw()
+            if jumanji.environments.is_colab():
+                plt.show(self._name)
+        else:
+            # Required to update render when not using Jupyter Notebook.
+            fig.canvas.draw_idle()
+            fig.canvas.flush_events()
+
+    def _clear_display(self) -> None:
+        """Clear the previous cell output when running in Google Colab."""
+        if jumanji.environments.is_colab():
+            import IPython.display
+
+            IPython.display.clear_output(True)
diff --git a/jumanji/environments/routing/sokoban/env_test.py b/jumanji/environments/routing/sokoban/env_test.py
index 8c3d8da93..c6e935e4f 100644
--- a/jumanji/environments/routing/sokoban/env_test.py
+++ b/jumanji/environments/routing/sokoban/env_test.py
@@ -22,7 +22,7 @@
 from jumanji.environments.routing.sokoban.constants import AGENT, BOX, TARGET, WALL
 from jumanji.environments.routing.sokoban.env import Sokoban
 from jumanji.environments.routing.sokoban.generator import (
-    DeepMindGenerator,
+    HuggingFaceDeepMindGenerator,
     SimpleSolveGenerator,
 )
 from jumanji.environments.routing.sokoban.types import State
@@ -33,10 +33,9 @@
 @pytest.fixture(scope="session")
 def sokoban() -> Sokoban:
     env = Sokoban(
-        generator=DeepMindGenerator(
-            difficulty="unfiltered",
-            split="train",
-            proportion_of_files=0.005,
+        generator=HuggingFaceDeepMindGenerator(
+            "unfiltered-train",
+            proportion_of_files=0.01,
         )
     )
     return env
diff --git a/jumanji/training/configs/config.yaml b/jumanji/training/configs/config.yaml
index 6ad8f62fc..a8333eecc 100644
--- a/jumanji/training/configs/config.yaml
+++ b/jumanji/training/configs/config.yaml
@@ -1,6 +1,6 @@
 defaults:
   - _self_
-  - env: snake # [bin_pack, cleaner, connector, cvrp, flat_pack, game_2048, graph_coloring, job_shop, knapsack, maze, minesweeper, mmst, multi_cvrp, pac_man, robot_warehouse, rubiks_cube, snake, sokoban, sudoku, tetris, tsp]
+  - env: snake # 
[bin_pack, cleaner, connector, cvrp, flat_pack, game_2048, graph_coloring, job_shop, knapsack, maze, minesweeper, mmst, multi_cvrp, pac_man, robot_warehouse, rubiks_cube, sliding_tile_puzzle, snake, sokoban, sudoku, tetris, tsp] agent: random # [random, a2c] diff --git a/jumanji/training/configs/env/sliding_tile_puzzle.yaml b/jumanji/training/configs/env/sliding_tile_puzzle.yaml new file mode 100644 index 000000000..6571c73fe --- /dev/null +++ b/jumanji/training/configs/env/sliding_tile_puzzle.yaml @@ -0,0 +1,26 @@ +name: sliding_tile_puzzle +registered_version: SlidingTilePuzzle-v0 + +network: + num_channels: 32 + policy_layers: [128, 128] + value_layers: [256, 256] + +training: + num_epochs: 400 + num_learner_steps_per_epoch: 100 + n_steps: 30 + total_batch_size: 256 + +evaluation: + eval_total_batch_size: 512 + greedy_eval_total_batch_size: 512 + +a2c: + normalize_advantage: False + discount_factor: 0.99 + bootstrapping_factor: 0.95 + l_pg: 1.0 + l_td: 1.0 + l_en: 0.01 + learning_rate: 2e-4 diff --git a/jumanji/training/networks/__init__.py b/jumanji/training/networks/__init__.py index 956d8dbb3..910aa044b 100644 --- a/jumanji/training/networks/__init__.py +++ b/jumanji/training/networks/__init__.py @@ -78,6 +78,12 @@ make_actor_critic_networks_rubiks_cube, ) from jumanji.training.networks.rubiks_cube.random import make_random_policy_rubiks_cube +from jumanji.training.networks.sliding_tile_puzzle.actor_critic import ( + make_actor_critic_networks_sliding_tile_puzzle, +) +from jumanji.training.networks.sliding_tile_puzzle.random import ( + make_random_policy_sliding_tile_puzzle, +) from jumanji.training.networks.snake.actor_critic import ( make_actor_critic_networks_snake, ) diff --git a/jumanji/training/networks/sliding_tile_puzzle/__init__.py b/jumanji/training/networks/sliding_tile_puzzle/__init__.py new file mode 100644 index 000000000..21db9ec1c --- /dev/null +++ b/jumanji/training/networks/sliding_tile_puzzle/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 
InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/jumanji/training/networks/sliding_tile_puzzle/actor_critic.py b/jumanji/training/networks/sliding_tile_puzzle/actor_critic.py new file mode 100644 index 000000000..71625f7cc --- /dev/null +++ b/jumanji/training/networks/sliding_tile_puzzle/actor_critic.py @@ -0,0 +1,95 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+
+from typing import Sequence
+
+import chex
+import haiku as hk
+import jax
+import jax.numpy as jnp
+
+from jumanji.environments.logic.sliding_tile_puzzle import (
+    Observation,
+    SlidingTilePuzzle,
+)
+from jumanji.training.networks.actor_critic import (
+    ActorCriticNetworks,
+    FeedForwardNetwork,
+)
+from jumanji.training.networks.parametric_distribution import (
+    CategoricalParametricDistribution,
+)
+
+
+def make_actor_critic_networks_sliding_tile_puzzle(
+    sliding_tile_puzzle: SlidingTilePuzzle,
+    num_channels: int,
+    policy_layers: Sequence[int],
+    value_layers: Sequence[int],
+) -> ActorCriticNetworks:
+    """Build the policy and value networks for the `SlidingTilePuzzle` environment.
+
+    Args:
+        sliding_tile_puzzle: environment whose action spec sets the number of actions.
+        num_channels: number of channels of each convolution in both networks.
+        policy_layers: hidden layer sizes of the policy head.
+        value_layers: hidden layer sizes of the value head.
+
+    Returns:
+        The policy network, the value network and a categorical action
+        distribution, bundled as `ActorCriticNetworks`.
+    """
+    n_actions = sliding_tile_puzzle.action_spec().num_values
+    action_distribution = CategoricalParametricDistribution(num_actions=n_actions)
+    # One logit per action for the actor, a single scalar for the critic.
+    actor = make_mlp_network(
+        num_outputs=n_actions,
+        mlp_units=policy_layers,
+        conv_n_channels=num_channels,
+    )
+    critic = make_mlp_network(
+        num_outputs=1,
+        mlp_units=value_layers,
+        conv_n_channels=num_channels,
+    )
+    return ActorCriticNetworks(
+        policy_network=actor,
+        value_network=critic,
+        parametric_action_distribution=action_distribution,
+    )
+
+
+def make_mlp_network(
+    num_outputs: int,
+    mlp_units: Sequence[int],
+    conv_n_channels: int,
+) -> FeedForwardNetwork:
+    """Build a network made of three 3x3 convolutions followed by an MLP head.
+
+    Args:
+        num_outputs: size of the final layer; a value of 1 yields a squeezed
+            scalar output, any other value yields action-masked logits.
+        mlp_units: hidden layer sizes of the MLP head.
+        conv_n_channels: number of output channels of each convolution.
+
+    Returns:
+        The `init`/`apply` pair of the transformed network.
+    """
+
+    def network_fn(observation: Observation) -> chex.Array:
+        # Add a trailing channel axis to the board for the convolutions.
+        board = observation.puzzle.astype(float)[..., None]
+        conv_stack = []
+        for _ in range(3):
+            conv_stack.append(hk.Conv2D(conv_n_channels, (3, 3), padding="SAME"))
+            conv_stack.append(jax.nn.relu)
+        encoder = hk.Sequential([*conv_stack, hk.Flatten()])
+        features = encoder(board)
+        mlp_head = hk.nets.MLP((*mlp_units, num_outputs), activate_final=False)
+        outputs = mlp_head(features)
+
+        if num_outputs == 1:
+            # Value head: drop the singleton output axis.
+            return jnp.squeeze(outputs, axis=-1)
+        # Policy head: give invalid actions the lowest representable logit.
+        return jnp.where(
+            observation.action_mask, outputs, jnp.finfo(jnp.float32).min
+        )
+
+    transformed = hk.without_apply_rng(hk.transform(network_fn))
+    return FeedForwardNetwork(init=transformed.init, apply=transformed.apply)
diff --git a/jumanji/training/networks/sliding_tile_puzzle/random.py b/jumanji/training/networks/sliding_tile_puzzle/random.py
new file mode 100644
index 000000000..bea36a682
--- /dev/null
+++ b/jumanji/training/networks/sliding_tile_puzzle/random.py
@@ -0,0 +1,23 @@
+# Copyright 2022 InstaDeep Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from jumanji.training.networks.masked_categorical_random import (
+    masked_categorical_random,
+)
+from jumanji.training.networks.protocols import RandomPolicy
+
+
+def make_random_policy_sliding_tile_puzzle() -> RandomPolicy:
+    """Make random policy for SlidingTilePuzzle.
+
+    Returns:
+        The shared `masked_categorical_random` policy, returned as-is
+        (presumably it samples among the actions allowed by the mask --
+        see its definition to confirm).
+    """
+    return masked_categorical_random
diff --git a/jumanji/training/setup_train.py b/jumanji/training/setup_train.py
index a8b3ed6f1..ef30367ac 100644
--- a/jumanji/training/setup_train.py
+++ b/jumanji/training/setup_train.py
@@ -40,6 +40,7 @@
     PacMan,
     RobotWarehouse,
     RubiksCube,
+    SlidingTilePuzzle,
     Snake,
     Sokoban,
     Sudoku,
@@ -143,6 +144,9 @@ def _setup_random_policy(  # noqa: CCR001
     elif cfg.env.name == "snake":
         assert isinstance(env.unwrapped, Snake)
         random_policy = networks.make_random_policy_snake()
+    elif cfg.env.name == "sliding_tile_puzzle":
+        assert isinstance(env.unwrapped, SlidingTilePuzzle)
+        random_policy = networks.make_random_policy_sliding_tile_puzzle()
     elif cfg.env.name == "tsp":
         assert isinstance(env.unwrapped, TSP)
         random_policy = networks.make_random_policy_tsp()
@@ -302,6 +306,14 @@ def _setup_actor_critic_neworks(  # noqa: CCR001
             policy_layers=cfg.env.network.policy_layers,
             value_layers=cfg.env.network.value_layers,
         )
+    elif cfg.env.name == "sliding_tile_puzzle":
+        assert isinstance(env.unwrapped, SlidingTilePuzzle)
+        actor_critic_networks = networks.make_actor_critic_networks_sliding_tile_puzzle(
+            sliding_tile_puzzle=env.unwrapped,
+            num_channels=cfg.env.network.num_channels,
+            policy_layers=cfg.env.network.policy_layers,
+            value_layers=cfg.env.network.value_layers,
+        )
     elif cfg.env.name == "rubiks_cube":
         assert isinstance(env.unwrapped, RubiksCube)
         actor_critic_networks = networks.make_actor_critic_networks_rubiks_cube(
diff --git a/mkdocs.yml b/mkdocs.yml
index fe048d0c4..5adc18f9e 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -22,6 +22,7 @@ nav:
           - GraphColoring: environments/graph_coloring.md
           - Minesweeper: environments/minesweeper.md
           - RubiksCube: 
environments/rubiks_cube.md + - SlidingTilePuzzle: environments/sliding_tile_puzzle.md - Sudoku: environments/sudoku.md - Packing: - BinPack: environments/bin_pack.md @@ -54,6 +55,7 @@ nav: - GraphColoring: api/environments/graph_coloring.md - Minesweeper: api/environments/minesweeper.md - RubiksCube: api/environments/rubiks_cube.md + - SlidingTilePuzzle: api/environments/sliding_tile_puzzle.md - Sudoku: api/environments/sudoku.md - Packing: - BinPack: api/environments/bin_pack.md