From e726259e86d555c7055fb48bd5842cf37af78bfd Mon Sep 17 00:00:00 2001 From: James Doran Date: Wed, 11 Oct 2023 12:07:41 +0100 Subject: [PATCH] Wave Function Collapse Environments (#371) Co-authored-by: Isaac Karth Co-authored-by: Valentin Valls Co-authored-by: Isaac Karth Co-authored-by: Kyle Benesch <4b796c65+github@gmail.com> Co-authored-by: Samuel Garcin Co-authored-by: Mark Towers --- .pre-commit-config.yaml | 2 +- minigrid/__init__.py | 10 + minigrid/core/world_object.py | 2 +- minigrid/envs/wfc/__init__.py | 24 + minigrid/envs/wfc/config.py | 220 +++++++++ minigrid/envs/wfc/graphtransforms.py | 396 +++++++++++++++ minigrid/envs/wfc/patterns/Angular.png | Bin 0 -> 101 bytes minigrid/envs/wfc/patterns/Blackdots.png | Bin 0 -> 2804 bytes minigrid/envs/wfc/patterns/Cave.png | Bin 0 -> 168 bytes minigrid/envs/wfc/patterns/City.png | Bin 0 -> 87 bytes minigrid/envs/wfc/patterns/DungeonExtr.png | Bin 0 -> 3147 bytes minigrid/envs/wfc/patterns/Fabric.png | Bin 0 -> 120 bytes minigrid/envs/wfc/patterns/Hogs.png | Bin 0 -> 186 bytes minigrid/envs/wfc/patterns/Knot.png | Bin 0 -> 144 bytes minigrid/envs/wfc/patterns/Lake.png | Bin 0 -> 190 bytes minigrid/envs/wfc/patterns/LessRooms.png | Bin 0 -> 172 bytes minigrid/envs/wfc/patterns/MagicOffice.png | Bin 0 -> 151 bytes minigrid/envs/wfc/patterns/Maze.png | Bin 0 -> 178 bytes minigrid/envs/wfc/patterns/Mazelike.png | Bin 0 -> 161 bytes minigrid/envs/wfc/patterns/Office.png | Bin 0 -> 171 bytes minigrid/envs/wfc/patterns/Paths.png | Bin 0 -> 147 bytes minigrid/envs/wfc/patterns/RedMaze.png | Bin 0 -> 105 bytes minigrid/envs/wfc/patterns/Rooms.png | Bin 0 -> 181 bytes minigrid/envs/wfc/patterns/ScaledMaze.png | Bin 0 -> 135 bytes minigrid/envs/wfc/patterns/SimpleKnot.png | Bin 0 -> 148 bytes minigrid/envs/wfc/patterns/SimpleMaze.png | Bin 0 -> 99 bytes minigrid/envs/wfc/patterns/SimpleWall.png | Bin 0 -> 138 bytes minigrid/envs/wfc/patterns/Skew1.png | Bin 0 -> 253 bytes minigrid/envs/wfc/patterns/Skew2.png | Bin 0 -> 259 bytes minigrid/envs/wfc/patterns/Spirals.png | Bin 0 -> 137 bytes minigrid/envs/wfc/patterns/SpiralsNeg.png | Bin 0 -> 144 bytes minigrid/envs/wfc/wfcenv.py | 226 +++++++++ minigrid/envs/wfc/wfclogic/__init__.py | 0 minigrid/envs/wfc/wfclogic/adjacency.py | 56 +++ minigrid/envs/wfc/wfclogic/control.py | 295 ++++++++++++ minigrid/envs/wfc/wfclogic/patterns.py | 199 ++++++++ minigrid/envs/wfc/wfclogic/solver.py | 530 +++++++++++++++++++++ minigrid/envs/wfc/wfclogic/tiles.py | 64 +++ minigrid/envs/wfc/wfclogic/utilities.py | 77 +++ minigrid/wrappers.py | 1 + py.Dockerfile | 2 +- pyproject.toml | 4 + tests/test_wfc/__init__.py | 0 tests/test_wfc/conftest.py | 40 ++ tests/test_wfc/test_wfc_adjacency.py | 41 ++ tests/test_wfc/test_wfc_patterns.py | 60 +++ tests/test_wfc/test_wfc_solver.py | 148 ++++++ tests/test_wfc/test_wfc_tiles.py | 17 + tests/utils.py | 10 + 49 files changed, 2421 insertions(+), 3 deletions(-) create mode 100644 minigrid/envs/wfc/__init__.py create mode 100644 minigrid/envs/wfc/config.py create mode 100644 minigrid/envs/wfc/graphtransforms.py create mode 100644 minigrid/envs/wfc/patterns/Angular.png create mode 100644 minigrid/envs/wfc/patterns/Blackdots.png create mode 100644 minigrid/envs/wfc/patterns/Cave.png create mode 100644 minigrid/envs/wfc/patterns/City.png create mode 100644 minigrid/envs/wfc/patterns/DungeonExtr.png create mode 100644 minigrid/envs/wfc/patterns/Fabric.png create mode 100644 minigrid/envs/wfc/patterns/Hogs.png create mode 100644 minigrid/envs/wfc/patterns/Knot.png create mode 100644 minigrid/envs/wfc/patterns/Lake.png create mode 100644 minigrid/envs/wfc/patterns/LessRooms.png create mode 100644 minigrid/envs/wfc/patterns/MagicOffice.png create mode 100644 minigrid/envs/wfc/patterns/Maze.png create mode 100644 minigrid/envs/wfc/patterns/Mazelike.png create mode 100644 minigrid/envs/wfc/patterns/Office.png create mode 100644 minigrid/envs/wfc/patterns/Paths.png create mode 100644 minigrid/envs/wfc/patterns/RedMaze.png create mode 100644 minigrid/envs/wfc/patterns/Rooms.png create mode 100644 minigrid/envs/wfc/patterns/ScaledMaze.png create mode 100644 minigrid/envs/wfc/patterns/SimpleKnot.png create mode 100644 minigrid/envs/wfc/patterns/SimpleMaze.png create mode 100644 minigrid/envs/wfc/patterns/SimpleWall.png create mode 100644 minigrid/envs/wfc/patterns/Skew1.png create mode 100644 minigrid/envs/wfc/patterns/Skew2.png create mode 100644 minigrid/envs/wfc/patterns/Spirals.png create mode 100644 minigrid/envs/wfc/patterns/SpiralsNeg.png create mode 100644 minigrid/envs/wfc/wfcenv.py create mode 100644 minigrid/envs/wfc/wfclogic/__init__.py create mode 100644 minigrid/envs/wfc/wfclogic/adjacency.py create mode 100644 minigrid/envs/wfc/wfclogic/control.py create mode 100644 minigrid/envs/wfc/wfclogic/patterns.py create mode 100644 minigrid/envs/wfc/wfclogic/solver.py create mode 100644 minigrid/envs/wfc/wfclogic/tiles.py create mode 100644 minigrid/envs/wfc/wfclogic/utilities.py create mode 100644 tests/test_wfc/__init__.py create mode 100644 tests/test_wfc/conftest.py create mode 100644 tests/test_wfc/test_wfc_adjacency.py create mode 100644 tests/test_wfc/test_wfc_patterns.py create mode 100644 tests/test_wfc/test_wfc_solver.py create mode 100644 tests/test_wfc/test_wfc_tiles.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 93bb2434b..50df300f0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,7 +16,7 @@ repos: - id: flake8 args: - '--per-file-ignores=*/__init__.py:F401' -# - --ignore= + - --ignore=E203, W503 - --max-complexity=30 - --max-line-length=456 - --show-source diff --git a/minigrid/__init__.py b/minigrid/__init__.py index 6371da484..be2f75264 100644 --- a/minigrid/__init__.py +++ b/minigrid/__init__.py @@ -5,6 +5,7 @@ from minigrid import minigrid_env, wrappers from minigrid.core import roomgrid from minigrid.core.world_object import Wall +from minigrid.envs.wfc.config import WFC_PRESETS __version__ = "2.3.1" @@ -565,6 +566,15 @@ def register_minigrid_envs(): entry_point="minigrid.envs:UnlockPickupEnv", ) + # WaveFunctionCollapse + # ---------------------------------------- + for name in WFC_PRESETS.keys(): + register( + id=f"MiniGrid-WFC-{name}-v0", + entry_point="minigrid.envs.wfc:WFCEnv", + kwargs={"wfc_config": name}, + ) + # BabyAI - Language based levels - GoTo # ---------------------------------------- diff --git a/minigrid/core/world_object.py b/minigrid/core/world_object.py index 592be953a..de4e550b3 100644 --- a/minigrid/core/world_object.py +++ b/minigrid/core/world_object.py @@ -74,7 +74,7 @@ def decode(type_idx: int, color_idx: int, state: int) -> WorldObj | None: obj_type = IDX_TO_OBJECT[type_idx] color = IDX_TO_COLOR[color_idx] - if obj_type == "empty" or obj_type == "unseen": + if obj_type == "empty" or obj_type == "unseen" or obj_type == "agent": return None # State, 0: open, 1: closed, 2: locked diff --git a/minigrid/envs/wfc/__init__.py b/minigrid/envs/wfc/__init__.py new file mode 100644 index 000000000..6135608ce --- /dev/null +++ b/minigrid/envs/wfc/__init__.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from minigrid.envs.wfc.config import ( + WFC_PRESETS, + WFC_PRESETS_INCONSISTENT, + WFC_PRESETS_SLOW, + WFCConfig, +) + +# This is wrapped in a try-except block so the presets can be accessed for registration +# Otherwise, importing here will fail when networkx is not installed +try: + from minigrid.envs.wfc.wfcenv import WFCEnv +except ImportError: + + class WFCEnv: + """Dummy class to give a helpful error message when dependencies are missing""" + + def __init__(self, *args, **kwargs): + from gymnasium.error import DependencyNotInstalled + + raise DependencyNotInstalled( + 'WFC dependencies are missing, please run `pip install "minigrid[wfc]"`' + ) diff --git a/minigrid/envs/wfc/config.py b/minigrid/envs/wfc/config.py new file mode 100644 index 000000000..f2b615623 --- /dev/null +++ b/minigrid/envs/wfc/config.py @@ -0,0 +1,220 @@ +from __future__ import annotations + +from dataclasses import asdict, dataclass +from pathlib import Path + +from typing_extensions import Literal + +PATTERN_PATH = Path(__file__).parent / "patterns" + + +@dataclass +class WFCConfig: + """Dataclass for holding WFC configuration parameters. + + This controls the behavior of the WFC algorithm. The parameters are passed directly to the WFC solver. + + Attributes: + pattern_path: Path to the pattern image that will be automatically loaded. + tile_size: Size of the tiles in pixels to create from the pattern image. + pattern_width: Size of the patterns in tiles to take from the pattern image. (greater than 3 is quite slow) + rotations: Number of rotations for each tile. + output_periodic: Whether the output should be periodic (wraps over edges). + input_periodic: Whether the input should be periodic (wraps over edges). + loc_heuristic: Heuristic for choosing the next tile location to collapse. + choice_heuristic: Heuristic for choosing the next tile to use between possible tiles. + backtracking: Whether to backtrack when contradictions are discovered. + """ + + pattern_path: Path + tile_size: int = 1 + pattern_width: int = 2 + rotations: int = 8 + output_periodic: bool = False + input_periodic: bool = False + loc_heuristic: Literal[ + "lexical", "spiral", "entropy", "anti-entropy", "simple", "random" + ] = "entropy" + choice_heuristic: Literal["lexical", "rarest", "weighted", "random"] = "weighted" + backtracking: bool = False + + @property + def wfc_kwargs(self): + try: + from imageio.v2 import imread + except ImportError as e: + from gymnasium.error import DependencyNotInstalled + + raise DependencyNotInstalled( + 'imageio is missing, please run `pip install "minigrid[wfc]"`' + ) from e + kwargs = asdict(self) + kwargs["image"] = imread(kwargs.pop("pattern_path"))[:, :, :3] + return kwargs + + +# Basic presets for WFC configurations (that should generate in <1 min) +WFC_PRESETS = { + "MazeSimple": WFCConfig( + pattern_path=PATTERN_PATH / "SimpleMaze.png", + tile_size=1, + pattern_width=2, + output_periodic=False, + input_periodic=False, + ), + "DungeonMazeScaled": WFCConfig( + pattern_path=PATTERN_PATH / "ScaledMaze.png", + tile_size=1, + pattern_width=2, + output_periodic=True, + input_periodic=True, + ), + "RoomsFabric": WFCConfig( + pattern_path=PATTERN_PATH / "Fabric.png", + tile_size=1, + pattern_width=3, + output_periodic=False, + input_periodic=False, + ), + "ObstaclesBlackdots": WFCConfig( + pattern_path=PATTERN_PATH / "Blackdots.png", + tile_size=1, + pattern_width=2, + output_periodic=False, + input_periodic=False, + ), + "ObstaclesAngular": WFCConfig( + pattern_path=PATTERN_PATH / "Angular.png", + tile_size=1, + pattern_width=3, + output_periodic=True, + input_periodic=True, + ), + "ObstaclesHogs3": WFCConfig( + pattern_path=PATTERN_PATH / "Hogs.png", + tile_size=1, + pattern_width=3, + output_periodic=True, + input_periodic=True, + ), +} + +# Presets that take a large number of attempts to generate a consistent environment +WFC_PRESETS_INCONSISTENT = { + "MazeKnot": WFCConfig( + pattern_path=PATTERN_PATH / "Knot.png", + tile_size=1, + pattern_width=3, + output_periodic=True, + input_periodic=True, + ), # This is not too inconsistent (often 10 attempts is enough) + "MazeWall": WFCConfig( + pattern_path=PATTERN_PATH / "SimpleWall.png", + tile_size=1, + pattern_width=2, + output_periodic=True, + input_periodic=True, + ), + "RoomsOffice": WFCConfig( + pattern_path=PATTERN_PATH / "Office.png", + tile_size=1, + pattern_width=3, + output_periodic=True, + input_periodic=True, + ), + "ObstaclesHogs2": WFCConfig( + pattern_path=PATTERN_PATH / "Hogs.png", + tile_size=1, + pattern_width=2, + output_periodic=True, + input_periodic=True, + ), + "Skew2": WFCConfig( + pattern_path=PATTERN_PATH / "Skew2.png", + tile_size=1, + pattern_width=3, + output_periodic=True, + input_periodic=True, + ), +} + +# Slow presets for WFC configurations (Most take about 2-4 min but some take 10+ min) +WFC_PRESETS_SLOW = { + "Maze": WFCConfig( + pattern_path=PATTERN_PATH / "Maze.png", + tile_size=1, + pattern_width=3, + output_periodic=True, + input_periodic=True, + ), # This is unusually slow: ~20min per 25x25 room + "MazeSpirals": WFCConfig( + pattern_path=PATTERN_PATH / "Spirals.png", + tile_size=1, + pattern_width=3, + output_periodic=True, + input_periodic=True, + ), + "MazePaths": WFCConfig( + pattern_path=PATTERN_PATH / "Paths.png", + tile_size=1, + pattern_width=3, + output_periodic=True, + input_periodic=True, + ), + "Mazelike": WFCConfig( + pattern_path=PATTERN_PATH / "Mazelike.png", + tile_size=1, + pattern_width=3, + output_periodic=True, + input_periodic=True, + ), + "Dungeon": WFCConfig( + pattern_path=PATTERN_PATH / "DungeonExtr.png", + tile_size=1, + pattern_width=3, + output_periodic=True, + input_periodic=True, + ), # ~10 mins + "DungeonRooms": WFCConfig( + pattern_path=PATTERN_PATH / "Rooms.png", + tile_size=1, + pattern_width=3, + output_periodic=True, + input_periodic=True, + ), + "DungeonLessRooms": WFCConfig( + pattern_path=PATTERN_PATH / "LessRooms.png", + tile_size=1, + pattern_width=3, + output_periodic=True, + input_periodic=True, + ), + "DungeonSpirals": WFCConfig( + pattern_path=PATTERN_PATH / "SpiralsNeg.png", + tile_size=1, + pattern_width=3, + output_periodic=True, + input_periodic=True, + ), + "RoomsMagicOffice": WFCConfig( + pattern_path=PATTERN_PATH / "MagicOffice.png", + tile_size=1, + pattern_width=3, + output_periodic=True, + input_periodic=True, + ), + "SkewCave": WFCConfig( + pattern_path=PATTERN_PATH / "Cave.png", + tile_size=1, + pattern_width=3, + output_periodic=False, + input_periodic=False, + ), + "SkewLake": WFCConfig( + pattern_path=PATTERN_PATH / "Lake.png", + tile_size=1, + pattern_width=3, + output_periodic=True, + input_periodic=True, + ), # ~10 mins +} diff --git a/minigrid/envs/wfc/graphtransforms.py b/minigrid/envs/wfc/graphtransforms.py new file mode 100644 index 000000000..c2ea41297 --- /dev/null +++ b/minigrid/envs/wfc/graphtransforms.py @@ -0,0 +1,396 @@ +from __future__ import annotations + +from collections import OrderedDict, defaultdict +from dataclasses import dataclass +from itertools import product + +import networkx as nx +import numpy as np + +from minigrid.core.constants import COLOR_TO_IDX, IDX_TO_OBJECT, OBJECT_TO_IDX +from minigrid.minigrid_env import MiniGridEnv + + +@dataclass +class EdgeDescriptor: + between: tuple[str, str] | tuple[str] + structure: str | None = None + + +# This is maybe general enough to be in utils +class GraphTransforms: + OBJECT_TO_DENSE_GRAPH_ATTRIBUTE = { + "empty": ("navigable", "empty"), + "start": ("navigable", "start"), + "agent": ("navigable", "start"), + "goal": ("navigable", "goal"), + "moss": ("navigable", "moss"), + "wall": ("non_navigable", "wall"), + "lava": ("non_navigable", "lava"), + } + + DENSE_GRAPH_ATTRIBUTE_TO_OBJECT = { + "empty": "empty", + "start": "start", + "goal": "goal", + "moss": "moss", + "wall": "wall", + "lava": "lava", + "navigable": None, + "non_navigable": None, + } + + MINIGRID_COLOR_CONFIG = { + "empty": None, + "wall": "grey", + "agent": "blue", + "goal": "green", + "lava": "red", + "moss": "purple", + } + + @staticmethod + def minigrid_to_bitmap(grids): + + layout = grids[..., 0] + bitmap = np.zeros_like(layout) + bitmap[layout == 2] = 1 + bitmap = list(bitmap) + + start_pos_id = np.where(layout == 10) + goal_pos_id = np.where(layout == 8) + + start_pos = [] + goal_pos = [] + for i in range(len(bitmap)): + bitmap[i] = bitmap[i][1:-1, 1:-1] + start_pos.append(np.array([start_pos_id[2][i], start_pos_id[1][i]])) + goal_pos.append(np.array([goal_pos_id[2][i], goal_pos_id[1][i]])) + + return bitmap, start_pos, goal_pos + + @staticmethod + def minigrid_to_dense_graph( + minigrids: np.ndarray | list[MiniGridEnv], + node_attr=None, + edge_config=None, + ) -> list[nx.Graph]: + if isinstance(minigrids[0], np.ndarray): + minigrids = np.array(minigrids) + layouts = minigrids[..., 0] + elif isinstance(minigrids[0], MiniGridEnv): + layouts = [minigrid.grid.encode()[..., 0] for minigrid in minigrids] + for i in range(len(minigrids)): + layouts[i][tuple(minigrids[i].agent_pos)] = OBJECT_TO_IDX["agent"] + layouts = np.array(layouts) + else: + raise TypeError( + f"minigrids must be of type List[np.ndarray], List[MiniGridEnv], " + f"List[MultiGridEnv], not {type(minigrids[0])}" + ) + graphs, _ = GraphTransforms.minigrid_layout_to_dense_graph( + layouts, remove_border=True, node_attr=node_attr, edge_config=edge_config + ) + return graphs + + @staticmethod + def minigrid_layout_to_dense_graph( + layouts: np.ndarray, remove_border=True, node_attr=None, edge_config=None + ) -> tuple[list[nx.Graph], dict[str, list[nx.Graph]]]: + + assert ( + layouts.ndim == 3 + ), f"Wrong dimensions for minigrid layout, expected 3 dimensions, got {layouts.ndim}." + + node_attr = [] if node_attr is None else node_attr + + # Remove borders + if remove_border: + layouts = layouts[:, 1:-1, 1:-1] # remove edges + dim_grid = layouts.shape[1:] + + # Get the objects present in the layout + objects_idx = np.unique(layouts) + object_instances = [IDX_TO_OBJECT[obj] for obj in objects_idx] + assert set(object_instances).issubset( + {"empty", "wall", "start", "goal", "agent", "lava", "moss"} + ), ( + f"Unsupported object(s) in minigrid layout. Supported objects are: " + f"empty, wall, start, goal, agent, lava, moss. Got {object_instances}." + ) + + # Get location of each object in the layout + object_locations = {} + for obj in object_instances: + object_locations[obj] = defaultdict(list) + ids = list(zip(*np.where(layouts == OBJECT_TO_IDX[obj]))) + for tup in ids: + object_locations[obj][tup[0]].append(tup[1:]) + for m in range(layouts.shape[0]): + if m not in object_locations[obj]: + object_locations[obj][m] = [] + object_locations[obj] = OrderedDict(sorted(object_locations[obj].items())) + if "start" not in object_instances and "agent" in object_instances: + object_locations["start"] = object_locations["agent"] + if "agent" not in object_instances and "start" in object_instances: + object_locations["agent"] = object_locations["start"] + + # Create one-hot graph feature tensor + graph_feats = {} + object_to_attr = GraphTransforms.OBJECT_TO_DENSE_GRAPH_ATTRIBUTE + for obj in object_instances: + for attr in object_to_attr[obj]: + if attr not in graph_feats and attr in node_attr: + graph_feats[attr] = np.zeros(layouts.shape) + loc = list(object_locations[obj].values()) + assert len(loc) == layouts.shape[0] + for m in range(layouts.shape[0]): + if loc[m]: + loc_m = np.array(loc[m]) + graph_feats[attr][m][loc_m[:, 0], loc_m[:, 1]] = 1 + for attr in node_attr: + if attr not in graph_feats: + graph_feats[attr] = np.zeros(layouts.shape) + graph_feats[attr] = graph_feats[attr].reshape(layouts.shape[0], -1) + + graphs, edge_graphs = GraphTransforms.features_to_dense_graph( + graph_feats, dim_grid, edge_config + ) + + return graphs, edge_graphs + + @staticmethod + def features_to_dense_graph( + features: dict[str, np.ndarray], + dim_grid: tuple, + edge_config: dict[str, EdgeDescriptor] = None, + ) -> tuple[list[nx.Graph], dict[str, list[nx.Graph]]]: + + graphs = [] + edge_graphs = defaultdict(list) + for m in range(features[list(features.keys())[0]].shape[0]): + g_temp = nx.grid_2d_graph(*dim_grid) + g = nx.Graph() + g.add_nodes_from(sorted(g_temp.nodes(data=True))) + for attr in features: + nx.set_node_attributes( + g, {k: v for k, v in zip(g.nodes, features[attr][m].tolist())}, attr + ) + if edge_config is not None: + edge_layers = GraphTransforms.get_edge_layers( + g, edge_config, list(features.keys()), dim_grid + ) + for edge_n, edge_g in edge_layers.items(): + g.add_edges_from(edge_g.edges(data=True), label=edge_n) + edge_graphs[edge_n].append(edge_g) + graphs.append(g) + + return graphs, edge_graphs + + @staticmethod + def graph_features_to_minigrid( + graph_features: dict[str, np.ndarray], shape: tuple[int, int], padding=1 + ) -> np.ndarray: + + features = graph_features.copy() + node_attributes = list(features.keys()) + + color_config = GraphTransforms.MINIGRID_COLOR_CONFIG + + # shape_no_padding = (features[node_attributes[0]].shape[-2], shape[0] - 2, shape[1] - 2, 3) + shape_no_padding = (shape[0] - 2 * padding, shape[1] - 2 * padding, 3) + for attr in node_attributes: + features[attr] = features[attr].reshape(*shape_no_padding[:-1]) + grids = np.ones(shape_no_padding, dtype=np.uint8) * OBJECT_TO_IDX["empty"] + + minigrid_object_to_encoding_map = {} # [object_id, color, state] + for feature in node_attributes: + obj_type = GraphTransforms.DENSE_GRAPH_ATTRIBUTE_TO_OBJECT[feature] + if ( + obj_type is not None + and obj_type not in minigrid_object_to_encoding_map.keys() + ): + if obj_type == "empty": + minigrid_object_to_encoding_map[obj_type] = [ + OBJECT_TO_IDX["empty"], + 0, + 0, + ] + elif obj_type == "agent": + minigrid_object_to_encoding_map[obj_type] = [ + OBJECT_TO_IDX["agent"], + 0, + 0, + ] + elif obj_type == "start": + color_str = color_config["agent"] + minigrid_object_to_encoding_map[obj_type] = [ + OBJECT_TO_IDX["agent"], + COLOR_TO_IDX[color_str], + 0, + ] + else: + color_str = color_config[obj_type] + minigrid_object_to_encoding_map[obj_type] = [ + OBJECT_TO_IDX[obj_type], + COLOR_TO_IDX[color_str], + 0, + ] + + if ( + "start" not in minigrid_object_to_encoding_map.keys() + and "agent" in minigrid_object_to_encoding_map.keys() + ): + minigrid_object_to_encoding_map["start"] = minigrid_object_to_encoding_map[ + "agent" + ] + if ( + "agent" not in minigrid_object_to_encoding_map.keys() + and "start" in minigrid_object_to_encoding_map.keys() + ): + minigrid_object_to_encoding_map["agent"] = minigrid_object_to_encoding_map[ + "start" + ] + + for i, attr in enumerate(node_attributes): + if "wall" not in node_attributes: + if attr == "navigable" and "wall" not in node_attributes: + mapping = minigrid_object_to_encoding_map["wall"] + grids[features[attr] == 0] = np.array(mapping, dtype=np.uint8) + else: + mapping = minigrid_object_to_encoding_map[attr] + grids[features[attr] == 1] = np.array(mapping, dtype=np.uint8) + else: + try: + mapping = minigrid_object_to_encoding_map[attr] + grids[features[attr] == 1] = np.array(mapping, dtype=np.uint8) + except KeyError: + pass + + wall_encoding = np.array( + minigrid_object_to_encoding_map["wall"], dtype=np.uint8 + ) + padded_grid = np.pad( + grids, + ((padding, padding), (padding, padding), (0, 0)), + "constant", + constant_values=-1, + ) + padded_grid = np.where( + padded_grid == -np.ones(3, dtype=np.uint8), wall_encoding, padded_grid + ) + return padded_grid + + @staticmethod + def get_node_features( + graph: nx.Graph, pattern_shape, node_attributes: list[str] = None, reshape=True + ) -> tuple[np.ndarray, list[str]]: + + if node_attributes is None: + # Get node attributes from some node + node_attributes = list(next(iter(graph.nodes.data()))[1].keys()) + + # Get node features + Fx = [] + for attr in node_attributes: + if attr == "non_navigable" or attr == "wall": + # The graph we are getting is only the navigable nodes so those that + # are not present should be assumed to be walls and non-navigable + f = np.ones(pattern_shape) + else: + f = np.zeros(pattern_shape) + for node, data in graph.nodes.data(attr): + f[node] = data + if reshape: + f = f.ravel() + Fx.append(f) + # Fx = torch.stack(Fx, dim=-1).to(device) + Fx = np.stack(Fx, axis=-1) + + return Fx, node_attributes + + @staticmethod + def dense_graph_to_minigrid( + graph: nx.Graph, shape: tuple[int, int], padding=1 + ) -> np.ndarray: + + pattern_shape = (shape[0] - 2 * padding, shape[1] - 2 * padding) + features, node_attributes = GraphTransforms.get_node_features( + graph, pattern_shape, node_attributes=None + ) + # num_zeros = features[features == 0.0].numel() + # num_ones = features[features == 1.0].numel() + num_zeros = (features == 0.0).sum() + num_ones = (features == 1.0).sum() + + assert num_zeros + num_ones == features.size, "Graph features should be binary" + features_dict = {} + for i, key in enumerate(node_attributes): + features_dict[key] = features[..., i] + grids = GraphTransforms.graph_features_to_minigrid( + features_dict, shape=shape, padding=padding + ) + + return grids + + @staticmethod + def get_edge_layers( + graph: nx.Graph, + edge_config: dict[str, EdgeDescriptor], + node_attr: list[str], + dim_grid: tuple[int, int], + ) -> dict[str, nx.Graph]: + + navigable_nodes = ["empty", "start", "goal", "moss"] + non_navigable_nodes = ["wall", "lava"] + assert all([isinstance(n, tuple) for n in graph.nodes]) + assert all([len(n) == 2 for n in graph.nodes]) + + def partial_grid(graph, nodes, dim_grid): + non_grid_nodes = [n for n in graph.nodes if n not in nodes] + g_temp = nx.grid_2d_graph(*dim_grid) + g_temp.remove_nodes_from(non_grid_nodes) + g_temp.add_nodes_from(non_grid_nodes) + g = nx.Graph() + g.add_nodes_from(graph.nodes(data=True)) + g.add_edges_from(g_temp.edges) + return g + + def pair_edges(graph, node_types): + all_nodes = [] + for n_type in node_types: + all_nodes.append( + [n for n, a in graph.nodes.items() if a[n_type] >= 1.0] + ) + edges = list(product(*all_nodes)) + edged_graph = nx.create_empty_copy(graph, with_data=True) + edged_graph.add_edges_from(edges) + return edged_graph + + edge_graphs = {} + for edge_ in edge_config.keys(): + if edge_ == "navigable" and "navigable" not in node_attr: + edge_config[edge_].between = navigable_nodes + elif edge_ == "non_navigable" and "non_navigable" not in node_attr: + edge_config[edge_].between = non_navigable_nodes + elif not set(edge_config[edge_].between).issubset(set(node_attr)): + # TODO: remove + # logger.warning(f"Edge {edge_} not compatible with node attributes {node_attr}. Skipping.") + continue + if edge_config[edge_].structure is None: + edge_graphs[edge_] = pair_edges(graph, edge_config[edge_].between) + elif edge_config[edge_].structure == "grid": + nodes = [] + for n_type in edge_config[edge_].between: + nodes += [ + n + for n, a in graph.nodes.items() + if a[n_type] >= 1.0 and n not in nodes + ] + edge_graphs[edge_] = partial_grid(graph, nodes, dim_grid) + else: + raise NotImplementedError( + f"Edge structure {edge_config[edge_].structure} not supported." + ) + + return edge_graphs diff --git a/minigrid/envs/wfc/patterns/Angular.png b/minigrid/envs/wfc/patterns/Angular.png new file mode 100644 index 0000000000000000000000000000000000000000..b188366d97ed03c3b913d91e027b54f94c4cf827 GIT binary patch literal 101 zcmeAS@N?(olHy`uVBq!ia0vp^+#t-s1SHkYJtzcHI-V|$Ar-fh|NQ@N&m4E4=~#M7 zip#Tq|NqbLY!q0~#wF&nJ2QdXG{iuNCzv7o38xs(hv}<<`WQT2{an^LB{Ts5krW@_ literal 0 HcmV?d00001 diff --git a/minigrid/envs/wfc/patterns/Blackdots.png b/minigrid/envs/wfc/patterns/Blackdots.png new file mode 100644 index 0000000000000000000000000000000000000000..6feedb68803fa14116b958ec4281401f5974fc51 GIT binary patch literal 2804 zcmVKLZ*U+IBfRsybQWXdwQbLP>6pAqfylh#{fb6;Z(vMMVS~$e@S=j*ftg6;Uhf59&ghTmgWD0l;*T zI709Y^p6lP1rIRMx#05C~cW=H_Aw*bJ-5DT&Z2n+x)QHX^p z00esgV8|mQcmRZ%02D^@S3L16t`O%c004NIvOKvYIYoh62rY33S640`D9%Y2D-rV&neh&#Q1i z007~1e$oCcFS8neI|hJl{-P!B1ZZ9hpmq0)X0i`JwE&>$+E?>%_LC6RbVIkUx0b+_+BaR3cnT7Zv!AJxW zizFb)h!jyGOOZ85F;a?DAXP{m@;!0_IfqH8(HlgRxt7s3}k3K`kFu>>-2Q$QMFfPW!La{h336o>X zu_CMttHv6zR;&ZNiS=X8v3CR#fknUxHUxJ0uoBa_M6WNWeqIg~6QE69c9o#eyhGvpiOA@W-aonk<7r1(?fC{oI5N*U!4 zfg=2N-7=cNnjjOr{yriy6mMFgG#l znCF=fnQv8CDz++o6_Lscl}eQ+l^ZHARH>?_s@|##Rr6KLRFA1%Q+=*RRWnoLsR`7U zt5vFIcfW3@?wFpwUVxrVZ>QdQz32KIeJ}k~{cZZE^+ya? z2D1z#2HOnI7(B%_ac?{wFUQ;QQA1tBKtrWrm0_3Rgps+?Jfqb{jYbcQX~taRB;#$y zZN{S}1|}gUOHJxc?wV3fxuz+mJ4`!F$IZ;mqRrNsHJd##*D~ju=bP7?-?v~|cv>vB zsJ6IeNwVZxrdjT`yl#bBIa#GxRa#xMMy;K#CDyyGyQdMSxlWT#tDe?p!?5wT$+oGt z8L;Kp2HUQ-ZMJ=3XJQv;x5ci*?vuTfeY$;({XGW_huIFR9a(?@3)XSs8O^N5RyOM=TTmp(3=8^+zpz2r)C z^>JO{deZfso3oq3?Wo(Y?l$ge?uXo;%ru`Vo>?<<(8I_>;8Eq#KMS9gFl*neeosSB zfoHYnBQIkwkyowPu(zdms`p{<7e4kra-ZWq<2*OsGTvEV%s0Td$hXT+!*8Bnh2KMe zBmZRodjHV?r+_5^X9J0WL4jKW`}lf%A-|44I@@LTvf1rHjG(ze6+w@Jt%Bvjts!X0 z?2xS?_ve_-kiKB_KiJlZ$9G`c^=E@oNG)mWWaNo-3TIW8)$Hg0Ub-~8?KhvJ>$ z3*&nim@mj(aCxE5!t{lw7O5^0EIO7zOo&c6l<+|iDySBWCGrz@C5{St!X3hAA}`T4 z(TLbXTq+(;@<=L8dXnssyft|w#WSTW<++3>sgS%(4NTpeI-VAqb|7ssJvzNHgOZVu zaYCvgO_R1~>SyL=cFU|~g|hy|Zi}}s9+d~lYqOB71z9Z$wnC=pR9Yz4DhIM>Wmjgu z&56o6maCpC&F##y%G;1PobR9i?GnNg;gYtchD%p19a!eQtZF&3JaKv33gZ<8D~47E ztUS1iwkmDaPpj=$m#%)jCVEY4fnLGNg2A-`YwHVD3gv};>)hAvT~AmqS>Lr``i7kw zJ{5_It`yrBmlc25DBO7E8;5VoznR>Ww5hAaxn$2~(q`%A-YuS64wkBy=9dm`4cXeX z4c}I@?e+FW+b@^RDBHV(wnMq2zdX3SWv9u`%{xC-q*U}&`cyXV(%rRT*Z6MH?i+i& z_B8C(+grT%{XWUQ+f@NoP1R=AW&26{v-dx)iK^-Nmiuj8txj!m?Z*Ss1N{dh4z}01 z)YTo*JycSU)+_5r4#yw9{+;i4Ee$peRgIj+;v;ZGdF1K$3E%e~4LaI(jC-u%2h$&R z9cLXcYC@Xwnns&bn)_Q~Te?roKGD|d-g^8;+aC{{G(1^(O7m37Y1-+6)01cN&y1aw zoqc{T`P^XJqPBbIW6s}d4{z_f5Om?vMgNQEJG?v2T=KYd^0M3I6IZxbny)%vZR&LD zJpPl@Psh8QyPB@KTx+@RdcC!KX7}kEo;S|j^u2lU7XQ}Oo;f|;z4Ll+_r>@1-xl3| zawq-H%e&ckC+@AhPrP6BKT#_XdT7&;F71j}Joy zkC~6lh7E@6o;W@^IpRNZ{ptLtL(gQ-CY~4mqW;US7Zxvm_|@yz&e53Bp_lTPlfP|z zrTyx_>lv@x#=^!PzR7qqF<$gm`|ZJZ+;<)Cqu&ot2z=0000WV@Og>004R=004l4008;_004mL004C`008P>0026e000+nl3&F} z0000VNklHA#rhYWpPWovV37p#cLM-$ItE7?L{2OK00002BR0pkS1z zi(`n!#HAAt3Na|KxJ3W|f8J__l#}6G-H&0nCG3-)oepHF=Y8;ET9+h8BHNy{6}xpS zFYjK(H#_6NP46xBCl9K8J0fov^5)&zB?~jxFdSO*cuL7?k=OEH*K*!WR~M|%|Jehy Og~8L+&t;ucLK6T|nLNY* literal 0 HcmV?d00001 diff --git a/minigrid/envs/wfc/patterns/City.png b/minigrid/envs/wfc/patterns/City.png new file mode 100644 index 0000000000000000000000000000000000000000..5865e2fc9bf2b199e6386c5e8db31fddc855bfac GIT binary patch literal 87 zcmeAS@N?(olHy`uVBq!ia0vp^oFL4>1SIo6Pjm-Ta-J@ZAr-fh|NQ@N&#c$r{qFz& k|2{lB4kk=WKJwU}LBofKIm>X>I-p7hPgg&ebxsLQ0NA%0y#N3J literal 0 HcmV?d00001 diff --git a/minigrid/envs/wfc/patterns/DungeonExtr.png b/minigrid/envs/wfc/patterns/DungeonExtr.png new file mode 100644 index 0000000000000000000000000000000000000000..bfef84e2852f558e8fd0a227ff8f0fa5c9df2aaf GIT binary patch literal 3147 zcmV-R47Br!P)KLZ*U+IBfRsybQWXdwQbLP>6pAqfylh#{fb6;Z(vMMVS~$e@S=j*ftg6;Uhf59&ghTmgWD0l;*T zI709Y^p6lP1rIRMx#05C~cW=H_Aw*bJ-5DT&Z2n+x)QHX^p z00esgV8|mQcmRZ%02D^@S3L16t`O%c004NIvOKvYIYoh62rY33S640`D9%Y2D-rV&neh&#Q1i z007~1e$oCcFS8neI|hJl{-P!B1ZZ9hpmq0)X0i`JwE&>$+E?>%_LC6RbVIkUx0b+_+BaR3cnT7Zv!AJxW zizFb)h!jyGOOZ85F;a?DAXP{m@;!0_IfqH8(HlgRxt7s3}k3K`kFu>>-2Q$QMFfPW!La{h336o>X zu_CMttHv6zR;&ZNiS=X8v3CR#fknUxHUxJ0uoBa_M6WNWeqIg~6QE69c9o#eyhGvpiOA@W-aonk<7r1(?fC{oI5N*U!4 zfg=2N-7=cNnjjOr{yriy6mMFgG#l znCF=fnQv8CDz++o6_Lscl}eQ+l^ZHARH>?_s@|##Rr6KLRFA1%Q+=*RRWnoLsR`7U zt5vFIcfW3@?wFpwUVxrVZ>QdQz32KIeJ}k~{cZZE^+ya? z2D1z#2HOnI7(B%_ac?{wFUQ;QQA1tBKtrWrm0_3Rgps+?Jfqb{jYbcQX~taRB;#$y zZN{S}1|}gUOHJxc?wV3fxuz+mJ4`!F$IZ;mqRrNsHJd##*D~ju=bP7?-?v~|cv>vB zsJ6IeNwVZxrdjT`yl#bBIa#GxRa#xMMy;K#CDyyGyQdMSxlWT#tDe?p!?5wT$+oGt z8L;Kp2HUQ-ZMJ=3XJQv;x5ci*?vuTfeY$;({XGW_huIFR9a(?@3)XSs8O^N5RyOM=TTmp(3=8^+zpz2r)C z^>JO{deZfso3oq3?Wo(Y?l$ge?uXo;%ru`Vo>?<<(8I_>;8Eq#KMS9gFl*neeosSB zfoHYnBQIkwkyowPu(zdms`p{<7e4kra-ZWq<2*OsGTvEV%s0Td$hXT+!*8Bnh2KMe zBmZRodjHV?r+_5^X9J0WL4jKW`}lf%A-|44I@@LTvf1rHjG(ze6+w@Jt%Bvjts!X0 z?2xS?_ve_-kiKB_KiJlZ$9G`c^=E@oNG)mWWaNo-3TIW8)$Hg0Ub-~8?KhvJ>$ z3*&nim@mj(aCxE5!t{lw7O5^0EIO7zOo&c6l<+|iDySBWCGrz@C5{St!X3hAA}`T4 z(TLbXTq+(;@<=L8dXnssyft|w#WSTW<++3>sgS%(4NTpeI-VAqb|7ssJvzNHgOZVu zaYCvgO_R1~>SyL=cFU|~g|hy|Zi}}s9+d~lYqOB71z9Z$wnC=pR9Yz4DhIM>Wmjgu z&56o6maCpC&F##y%G;1PobR9i?GnNg;gYtchD%p19a!eQtZF&3JaKv33gZ<8D~47E ztUS1iwkmDaPpj=$m#%)jCVEY4fnLGNg2A-`YwHVD3gv};>)hAvT~AmqS>Lr``i7kw zJ{5_It`yrBmlc25DBO7E8;5VoznR>Ww5hAaxn$2~(q`%A-YuS64wkBy=9dm`4cXeX z4c}I@?e+FW+b@^RDBHV(wnMq2zdX3SWv9u`%{xC-q*U}&`cyXV(%rRT*Z6MH?i+i& z_B8C(+grT%{XWUQ+f@NoP1R=AW&26{v-dx)iK^-Nmiuj8txj!m?Z*Ss1N{dh4z}01 z)YTo*JycSU)+_5r4#yw9{+;i4Ee$peRgIj+;v;ZGdF1K$3E%e~4LaI(jC-u%2h$&R z9cLXcYC@Xwnns&bn)_Q~Te?roKGD|d-g^8;+aC{{G(1^(O7m37Y1-+6)01cN&y1aw zoqc{T`P^XJqPBbIW6s}d4{z_f5Om?vMgNQEJG?v2T=KYd^0M3I6IZxbny)%vZR&LD zJpPl@Psh8QyPB@KTx+@RdcC!KX7}kEo;S|j^u2lU7XQ}Oo;f|;z4Ll+_r>@1-xl3| zawq-H%e&ckC+@AhPrP6BKT#_XdT7&;F71j}Joy zkC~6lh7E@6o;W@^IpRNZ{ptLtL(gQ-CY~4mqW;US7Zxvm_|@yz&e53Bp_lTPlfP|z zrTyx_>lv@x#=^!PzR7qqF<$gm`|ZJZ+;<)Cqu&ot2z=0000WV@Og>004R=004l4008;_004mL004C`008P>0026e000+nl3&F} z0004YNkl9%-V$Wq;9&Y3#D*?JwS{VtKZ)I)wvai=u&4MaBt-G16BeTYN8`MiUnW)`}Y z`mEc5NP}c+;NF{MvToU|>2Z1C5P}oY9wkgeC=A^>>+kDm9~XRC8fC)^Is~qY>$vo$ zNc)BzJayn3m~vqPr>^#Otpv58QPdNu0~dlKv;6(U#~jtggk?NZJ5o2qQwwSi{w~H9 zrbq{=6;SKu<7Anoq-YTJ&`!oaR|_C+!xgAI8f|g9pG?xp630NPhv(P34*A^NedOL+ z26>l%`YN lb(}v|+Z4#J-v~Ii&j34LTr?%of~Nog002ovPDHLkV1g%n2BR0prDqg zi(?4K%-mCod<+IWhaGnRuP-**;H7MAm2jKs)lmiqp?zF+5 s!j&5f4Zm4#f8*8B-*33{0@qc>d9QSNqIWO-1hj&|)78&qol`;+0Crz9ga7~l literal 0 HcmV?d00001 diff --git a/minigrid/envs/wfc/patterns/Lake.png b/minigrid/envs/wfc/patterns/Lake.png new file mode 100644 index 0000000000000000000000000000000000000000..0466a0b103f937110ac1959d2a8da8f87f37d1cc GIT binary patch literal 190 zcmeAS@N?(olHy`uVBq!ia0vp^B0wz6!2~4b-|cz?q*^^)978H@Ee-MHJD|Xk!twvV zk?j_#W{ZSVfoY5`ic0VFlo?J~s(or*`G;f8wj-7Y>2BR0pkSP* zi(`n!#I=(G`5Fv3oGVYitAF|{QAOo&Wba>P?O08hmu^5JbJ7~yD-z3{SN)L6zHz;f!LLSF{D|%T Q0H93_p00i_>zopr0QgimUjP6A literal 0 HcmV?d00001 diff --git a/minigrid/envs/wfc/patterns/MagicOffice.png b/minigrid/envs/wfc/patterns/MagicOffice.png new file mode 100644 index 0000000000000000000000000000000000000000..7164edbc2261c39150b7228690aeaca6d9370f7e GIT binary patch literal 151 zcmeAS@N?(olHy`uVBq!ia0vp^{2ah@)YAr-f_2A$?RpujWf*#G~j zWd_RYT%Ys3_T=eLU_Du8bJ33X8N4EY8$#Xk`r9rnkG`?AlZRX65af zEbQrcf716G&p+$t?o2&>b=Da@^Cri`{rg`ri5F&E>p!a(544HF)78&qol`;+0M;Bj A(*OVf literal 0 HcmV?d00001 diff --git a/minigrid/envs/wfc/patterns/Maze.png b/minigrid/envs/wfc/patterns/Maze.png new file mode 100644 index 0000000000000000000000000000000000000000..da2bf4df688e3b7cc69811e504c963faf81552f9 GIT binary patch literal 178 zcmeAS@N?(olHy`uVBq!ia0vp^0wB!61SBU+%rFB|m7Xq+Ar-fJgAR%vFko42`~Uyx zn+wxsZa8}|SW9W`+U1NVZk#dNnapr*dYM^v%&7z164STq7st%hU>QbN{etL>|PJczUXKKVt{ zGc}gE{&R0!+0x_A_uKYF>yA556P_NI-R^sORqo}(?Y})Uk{R>l>{Yn8?7I%Mlfl!~ K&t;ucLK6Vt)I3%I literal 0 HcmV?d00001 diff --git a/minigrid/envs/wfc/patterns/Office.png b/minigrid/envs/wfc/patterns/Office.png new file mode 100644 index 0000000000000000000000000000000000000000..a4912f28fd57271aaa5fdfdbd647204a992fb6d8 GIT binary patch literal 171 zcmeAS@N?(olHy`uVBq!ia0vp^{2oCO|{#S9GGLLkg|>2BR0pkS=0 zi(`n!#HkYl`3@-XIGx_~JAP*E3=1_`w|`5P*ll^`vO;T;CgYg{rF(9>3N^@83+=r= z*LHi>(=di|sYzw8Gm5q|tIPfKnq>WEr+x5Iu8{m+{8#SV{oPaUk(pl25IMnyt$X?d Q6`(~7p00i_>zopr0E_rK!T|^4RYi=V8GFI^nd+K z*=c5{yN_ w#W93qX7Zo^|LvJ&6pCjoa5^F5#s&fm;+xqS+~k)o1S(_jboFyt=akR{0RFcZC;$Ke literal 0 HcmV?d00001 diff --git a/minigrid/envs/wfc/patterns/Rooms.png b/minigrid/envs/wfc/patterns/Rooms.png new file mode 100644 index 0000000000000000000000000000000000000000..4210deb1f6e15dc462fdd7e7c874737e38552040 GIT binary patch literal 181 zcmeAS@N?(olHy`uVBq!ia0vp^0wB!61SBU+%rFB|oCO|{#S9GGLLkg|>2BR0pkTVE zi(`n!#JPcvdzDHJ@IZsD(?&( z#Y_X{2J1PKREp1beUBAu-(8dSe`Jepr6HBW!pWVI9R aC*vHxTZ;KfD=!6F$>8bg=d#Wzp$P!w;yrNy literal 0 HcmV?d00001 diff --git a/minigrid/envs/wfc/patterns/ScaledMaze.png b/minigrid/envs/wfc/patterns/ScaledMaze.png new file mode 100644 index 0000000000000000000000000000000000000000..537c506f15f0f679c76533697d70acbb652f34cb GIT binary patch literal 135 zcmeAS@N?(olHy`uVBq!ia0vp^93afW1SGw4HSYi^#^NA%Cx&(BWL^R}Ea{HEjtmSN z`?>!lvI6;>1s;*b3=DjSL74G){)!Z!pp2)BV~9j}@{fLIfk_DmSX?|6Hg`^TV7#F! aA;ECDGk#&xjXzG1_ch5=(qpvue}v;>0G&e>QPCSlY!@WBoD}X9y@mLUhYak&%88Li&pg>pB}Vg!QYomyhlZ479A*i1T>Ms)78&qol`;+0NQdhE&u=k literal 0 HcmV?d00001 diff --git a/minigrid/envs/wfc/patterns/SimpleMaze.png b/minigrid/envs/wfc/patterns/SimpleMaze.png new file mode 100644 index 0000000000000000000000000000000000000000..c6c9bd4ebc5b311b88ef7ac561b68ffbf04d5481 GIT binary patch literal 99 zcmeAS@N?(olHy`uVBq!ia0vp^EFjFm1SHiab7}%9&H|6fVg?3oArNM~bhqvgP*Bv< s#W93qX7Zo^|LvJ&6pCjAv`k24i1z3H(Ow@g3#fp>)78&qol`;+0G@Ui;s5{u literal 0 HcmV?d00001 diff --git a/minigrid/envs/wfc/patterns/SimpleWall.png b/minigrid/envs/wfc/patterns/SimpleWall.png new file mode 100644 index 0000000000000000000000000000000000000000..9369ff6ca81e0eaa30c1a0b21a53d309613d3a19 GIT binary patch literal 138 zcmeAS@N?(olHy`uVBq!ia0vp^f*{Pn1SGfcUswyII14-?iy0WWg+Q3`(%rg0KtUT% z7sn8diOGNd|F>t>N=QjLBH|*JerU>KP6rF->5~nXo+!MXaLt)7AtfO}!Dz>W1cP?Q hzn7(*I>ag&8Il(&^4b-eWCG1$@O1TaS?83{1ORBZC@ug1 literal 0 HcmV?d00001 diff --git a/minigrid/envs/wfc/patterns/Skew1.png b/minigrid/envs/wfc/patterns/Skew1.png new file mode 100644 index 0000000000000000000000000000000000000000..65f27eafd92eb8c6f9d5ab616d9b3825663227e0 GIT binary patch literal 253 zcmeAS@N?(olHy`uVBq!ia0vp^!XV7S1|*9D%+3HQ#^NA%Cx&(BWL^R}Ea{HEjtmSN z`?>!lvI6;>1s;*b3=DinK$vl=HlH+5aFVBsV~BOM-;eG&+QnrQm{KOp z=u@UFRhT)ei$f-L6^Bly)`m$wkxPV(1s7z_(`$b7Cc=1@n-Wu|d2)d&(@LH@9&Af} u80YKWXL+r+HdFo<^J|B5DfY`UzcF1*w-u9k#S3&Q1B0ilpUXO@geCxS2342< literal 0 HcmV?d00001 diff --git a/minigrid/envs/wfc/patterns/Skew2.png b/minigrid/envs/wfc/patterns/Skew2.png new file mode 100644 index 0000000000000000000000000000000000000000..8c60b39956770543efe5c9e542f5bb6c98dbe784 GIT binary patch literal 259 zcmeAS@N?(olHy`uVBq!ia0vp^!ayv_!3HFka9+I#q!^2X+?^QKos)S9a~60+7BevL9RXp+soH$fK*1TFE{-7<{%a>j^ED{&uyB9hZ@fvX?fAEa z7MvfO-`rU0>EdOp=%qPHNNG+@eZ8X2;Sf2+Nc~Mq*7dZSF4Ww-ETR0BnXVvn^s>pW zE*#r0N2CM_CQo?V;jpUd!{dxa%T5%(Uw`1^l^%}0LWgI+<(i)Rj$`5ybNz2cO^P9s zA6kmLm!|cH&b&L*=z!BL{}{;^hp%l>l-!^quEai%)lRqf#=HqYH!^s-`njxgN@xNA DPm^7i literal 0 HcmV?d00001 diff --git a/minigrid/envs/wfc/patterns/Spirals.png b/minigrid/envs/wfc/patterns/Spirals.png new file mode 100644 index 0000000000000000000000000000000000000000..b68a9615c767227fb8241323dd194ffe0e7e5de0 GIT binary patch literal 137 zcmeAS@N?(olHy`uVBq!ia0vp@Ak4u8B>#36YXhkOPZ!6Kid(G{9EBJZIG7@r|F6Hv zXX=}}(mar5iqp<&mT05uH=ayo=gzU{1hRnoM+ j+sB)D?G98weUWhiqcO|L?J7HgW;1xY`njxgN@xNACL1f| literal 0 HcmV?d00001 diff --git a/minigrid/envs/wfc/patterns/SpiralsNeg.png b/minigrid/envs/wfc/patterns/SpiralsNeg.png new file mode 100644 index 0000000000000000000000000000000000000000..5592020be2458a9025045fd488d44e1b9fa3daa0 GIT binary patch literal 144 zcmeAS@N?(olHy`uVBq!ia0vp@Ak4u8B>#36YXhlpPZ!6Kid$1BIPx(laxk}k_^*Fn z&MfoBBDLF-!`QhP*Q{4cnzeL$Y3R!(Qu>!Y71k#2XKMQV<>Z`A%Z!iNvnF)!c nx.Graph: + wall_graph_attr = GraphTransforms.OBJECT_TO_DENSE_GRAPH_ATTRIBUTE["wall"] + # Prepare graph + inactive_nodes = [x for x, y in graph.nodes(data=True) if y["navigable"] < 0.5] + graph.remove_nodes_from(inactive_nodes) + + components = [ + graph.subgraph(c).copy() + for c in sorted(nx.connected_components(graph), key=len, reverse=True) + if len(c) > 1 + ] + component = components[0] + graph = graph.subgraph(component) + + for node in graph.nodes(): + if node not in component.nodes(): + for feat in graph.nodes[node]: + if feat in wall_graph_attr: + graph.nodes[node][feat] = 1.0 + else: + graph.nodes[node][feat] = 0.0 + # TODO: Check if this is necessary + g = nx.Graph() + g.add_nodes_from(graph.nodes(data=True)) + g.add_edges_from(component.edges(data=True)) + + g_out = copy.deepcopy(g) + + return g_out + + def _place_start_and_goal_random(self, graph: nx.Graph) -> nx.Graph: + node_set = "navigable" + + # Get two random navigable nodes + possible_nodes = [n for n, d in graph.nodes(data=True) if d[node_set]] + inds = self.np_random.permutation(len(possible_nodes))[:2] + start_node, goal_node = possible_nodes[inds[0]], possible_nodes[inds[1]] + + graph.nodes[start_node]["start"] = 1 + graph.nodes[goal_node]["goal"] = 1 + + return graph diff --git a/minigrid/envs/wfc/wfclogic/__init__.py b/minigrid/envs/wfc/wfclogic/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/minigrid/envs/wfc/wfclogic/adjacency.py b/minigrid/envs/wfc/wfclogic/adjacency.py new file mode 100644 index 000000000..0923a79bd --- /dev/null +++ b/minigrid/envs/wfc/wfclogic/adjacency.py @@ -0,0 +1,56 @@ +"""Convert input data to adjacency information. Implementation based on https://github.com/ikarth/wfc_2019f""" +from __future__ import annotations + +import numpy as np +from numpy.typing import NDArray + + +def adjacency_extraction( + pattern_grid: NDArray[np.int64], + pattern_catalog: dict[int, NDArray[np.int64]], + direction_offsets: list[tuple[int, tuple[int, int]]], + pattern_size: tuple[int, int] = (2, 2), +) -> list[tuple[tuple[int, int], int, int]]: + """Takes a pattern grid and returns a list of all of the legal adjacencies found in it.""" + + def is_valid_overlap_xy( + adjacency_direction: tuple[int, int], pattern_1: int, pattern_2: int + ) -> bool: + """Given a direction and two patterns, find the overlap of the two patterns + and return True if the intersection matches.""" + dimensions = (1, 0) + not_a_number = -1 + + # TODO: can probably speed this up by using the right slices, rather than rolling the whole pattern... + shifted = np.roll( + np.pad( + pattern_catalog[pattern_2], + max(pattern_size), + mode="constant", + constant_values=not_a_number, + ), + adjacency_direction, + dimensions, + ) + compare = shifted[ + pattern_size[0] : pattern_size[0] + pattern_size[0], + pattern_size[1] : pattern_size[1] + pattern_size[1], + ] + + left = max(0, 0, +adjacency_direction[0]) + right = min(pattern_size[0], pattern_size[0] + adjacency_direction[0]) + top = max(0, 0 + adjacency_direction[1]) + bottom = min(pattern_size[1], pattern_size[1] + adjacency_direction[1]) + a = pattern_catalog[pattern_1][top:bottom, left:right] + b = compare[top:bottom, left:right] + res = np.array_equal(a, b) + return res + + pattern_list = list(pattern_catalog.keys()) + legal = [] + for pattern_1 in pattern_list: + for pattern_2 in pattern_list: + for _direction_index, direction in direction_offsets: + if is_valid_overlap_xy(direction, pattern_1, pattern_2): + legal.append((direction, pattern_1, pattern_2)) + return legal diff --git a/minigrid/envs/wfc/wfclogic/control.py b/minigrid/envs/wfc/wfclogic/control.py new file mode 100644 index 000000000..9d8e89870 --- /dev/null +++ b/minigrid/envs/wfc/wfclogic/control.py @@ -0,0 +1,295 @@ +"""Main WFC execution function. Implementation based on https://github.com/ikarth/wfc_2019f""" +from __future__ import annotations + +import logging +import time +from typing import Any, Callable + +import numpy as np +from numpy.typing import NDArray +from typing_extensions import Literal + +from minigrid.envs.wfc.wfclogic.adjacency import adjacency_extraction +from minigrid.envs.wfc.wfclogic.patterns import ( + make_pattern_catalog_with_rotations, + pattern_grid_to_tiles, +) +from minigrid.envs.wfc.wfclogic.solver import ( + Contradiction, + StopEarly, + TimedOut, + lexicalLocationHeuristic, + lexicalPatternHeuristic, + make_global_use_all_patterns, + makeAdj, + makeAntiEntropyLocationHeuristic, + makeEntropyLocationHeuristic, + makeHilbertLocationHeuristic, + makeRandomLocationHeuristic, + makeRandomPatternHeuristic, + makeRarestPatternHeuristic, + makeSpiralLocationHeuristic, + makeWave, + makeWeightedPatternHeuristic, + run, + simpleLocationHeuristic, +) + +from .tiles import make_tile_catalog +from .utilities import tile_grid_to_image + +logger = logging.getLogger(__name__) + + +def make_log_stats() -> Callable[[dict[str, Any], str], None]: + log_line = 0 + + def log_stats(stats: dict[str, Any], filename: str) -> None: + nonlocal log_line + if stats: + log_line += 1 + with open(filename, "a", encoding="utf_8") as logf: + if log_line < 2: + for s in stats.keys(): + print(str(s), end="\t", file=logf) + print("", file=logf) + for s in stats.keys(): + print(str(stats[s]), end="\t", file=logf) + print("", file=logf) + + return log_stats + + +def execute_wfc( + image: NDArray[np.integer], + tile_size: int = 1, + pattern_width: int = 2, + rotations: int = 8, + output_size: tuple[int, int] = (48, 48), + ground: int | None = None, + attempt_limit: int = 10, + output_periodic: bool = True, + input_periodic: bool = True, + loc_heuristic: Literal[ + "lexical", "hilbert", "spiral", "entropy", "anti-entropy", "simple", "random" + ] = "entropy", + choice_heuristic: Literal["lexical", "rarest", "weighted", "random"] = "weighted", + global_constraint: Literal[False, "allpatterns"] = False, + backtracking: bool = False, + log_filename: str = "log", + logging: bool = False, + global_constraints: None = None, + log_stats_to_output: Callable[[dict[str, Any], str], None] | None = None, + np_random: np.random.Generator | None = None, +) -> NDArray[np.integer]: + time_begin = time.perf_counter() + output_destination = r"./output/" + np_random: np.random.Generator = ( + np.random.default_rng() if np_random is None else np_random + ) + + rotations -= 1 # change to zero-based + + input_stats = { + "tile_size": tile_size, + "pattern_width": pattern_width, + "rotations": rotations, + "output_size": output_size, + "ground": ground, + "attempt_limit": attempt_limit, + "output_periodic": output_periodic, + "input_periodic": input_periodic, + "location heuristic": loc_heuristic, + "choice heuristic": choice_heuristic, + "global constraint": global_constraint, + "backtracking": backtracking, + } + # TODO: generalize this to more than the four cardinal directions + direction_offsets = list(enumerate([(0, -1), (1, 0), (0, 1), (-1, 0)])) + + tile_catalog, tile_grid, _code_list, _unique_tiles = make_tile_catalog( + image, tile_size + ) + ( + pattern_catalog, + pattern_weights, + pattern_list, + pattern_grid, + ) = make_pattern_catalog_with_rotations( + tile_grid, pattern_width, input_is_periodic=input_periodic, rotations=rotations + ) + + logger.debug("profiling adjacency relations") + + adjacency_relations = adjacency_extraction( + pattern_grid, + pattern_catalog, + direction_offsets, + (pattern_width, pattern_width), + ) + + logger.debug("adjacency_relations") + + logger.debug(f"output size: {output_size}\noutput periodic: {output_periodic}") + number_of_patterns = len(pattern_weights) + logger.debug(f"# patterns: {number_of_patterns}") + decode_patterns = dict(enumerate(pattern_list)) + encode_patterns = {x: i for i, x in enumerate(pattern_list)} + + adjacency_list: dict[tuple[int, int], list[set[int]]] = {} + for _, adjacency in direction_offsets: + adjacency_list[adjacency] = [set() for _ in pattern_weights] + # logger.debug(adjacency_list) + for adjacency, pattern1, pattern2 in adjacency_relations: + # logger.debug(adjacency) + # logger.debug(decode_patterns[pattern1]) + adjacency_list[adjacency][encode_patterns[pattern1]].add( + encode_patterns[pattern2] + ) + + logger.debug(f"adjacency: {len(adjacency_list)}") + + time_adjacency = time.perf_counter() + + # Ground # + + ground_list: NDArray[np.int64] | None = None + if ground: + ground_list = np.vectorize(lambda x: encode_patterns[x])( + pattern_grid.flat[(ground - 1) :] + ) + if ground_list is None or ground_list.size == 0: + ground_list = None + + wave = makeWave( + number_of_patterns, output_size[0], output_size[1], ground=ground_list + ) + adjacency_matrix = makeAdj(adjacency_list) + + # Heuristics # + + encoded_weights: NDArray[np.float64] = np.zeros( + (number_of_patterns), dtype=np.float64 + ) + for w_id, w_val in pattern_weights.items(): + encoded_weights[encode_patterns[w_id]] = w_val + choice_random_weighting: NDArray[np.float64] = ( + np_random.random(wave.shape[1:]) * 0.1 + ) + + pattern_heuristic: Callable[ + [NDArray[np.bool_], NDArray[np.bool_]], int + ] = lexicalPatternHeuristic + if choice_heuristic == "rarest": + pattern_heuristic = makeRarestPatternHeuristic(encoded_weights, np_random) + if choice_heuristic == "weighted": + pattern_heuristic = makeWeightedPatternHeuristic(encoded_weights, np_random) + if choice_heuristic == "random": + pattern_heuristic = makeRandomPatternHeuristic(encoded_weights, np_random) + + logger.debug(loc_heuristic) + location_heuristic: Callable[ + [NDArray[np.bool_]], tuple[int, int] + ] = lexicalLocationHeuristic + if loc_heuristic == "anti-entropy": + location_heuristic = makeAntiEntropyLocationHeuristic(choice_random_weighting) + if loc_heuristic == "entropy": + location_heuristic = makeEntropyLocationHeuristic(choice_random_weighting) + if loc_heuristic == "random": + location_heuristic = makeRandomLocationHeuristic(choice_random_weighting) + if loc_heuristic == "simple": + location_heuristic = simpleLocationHeuristic + if loc_heuristic == "spiral": + location_heuristic = makeSpiralLocationHeuristic(choice_random_weighting) + if loc_heuristic == "hilbert": + # This requires hilbert_curve to be installed + location_heuristic = makeHilbertLocationHeuristic(choice_random_weighting) + + # Global Constraints # + + if global_constraint == "allpatterns": + active_global_constraint = make_global_use_all_patterns() + else: + + def active_global_constraint(wave) -> bool: + return True + + logger.debug(active_global_constraint) + combined_constraints = [active_global_constraint] + + def combinedConstraints(wave: NDArray[np.bool_]) -> bool: + return all(fn(wave) for fn in combined_constraints) + + # Solving # + + time_solve_start = None + time_solve_end = None + + solution_tile_grid = None + logger.debug("solving...") + attempts = 0 + while attempts < attempt_limit: + attempts += 1 + time_solve_start = time.perf_counter() + stats = {} + try: + solution = run( + wave.copy(), + adjacency_matrix, + locationHeuristic=location_heuristic, + patternHeuristic=pattern_heuristic, + periodic=output_periodic, + backtracking=backtracking, + checkFeasible=combinedConstraints, + ) + solution_as_ids = np.vectorize(lambda x: decode_patterns[x])(solution) + solution_tile_grid = pattern_grid_to_tiles(solution_as_ids, pattern_catalog) + + time_solve_end = time.perf_counter() + stats.update({"outcome": "success"}) + except StopEarly: + logger.debug("Skipping...") + stats.update({"outcome": "skipped"}) + raise + except TimedOut: + logger.debug("Timed Out") + stats.update({"outcome": "timed_out"}) + except Contradiction: + # logger.warning(f"Contradiction: {exc}") + stats.update({"outcome": "contradiction"}) + finally: + # profiler.dump_stats(f"logs/profile_{filename}_{timecode}.txt") + outstats = {} + outstats.update(input_stats) + solve_duration = time.perf_counter() - time_solve_start + if time_solve_end is not None: + solve_duration = time_solve_end - time_solve_start + adjacency_duration = time_solve_start - time_adjacency + outstats.update( + { + "attempts": attempts, + "time_start": time_begin, + "time_adjacency": time_adjacency, + "adjacency_duration": adjacency_duration, + "time solve start": time_solve_start, + "time solve end": time_solve_end, + "solve duration": solve_duration, + "pattern count": number_of_patterns, + } + ) + outstats.update(stats) + if log_stats_to_output is not None: + log_stats_to_output( + outstats, output_destination + log_filename + ".tsv" + ) + if solution_tile_grid is not None: + return ( + tile_grid_to_image( + solution_tile_grid, tile_catalog, (tile_size, tile_size) + ), + outstats, + ) + else: + return None, outstats + + raise TimedOut("Attempt limit exceeded.") diff --git a/minigrid/envs/wfc/wfclogic/patterns.py b/minigrid/envs/wfc/wfclogic/patterns.py new file mode 100644 index 000000000..d975d1463 --- /dev/null +++ b/minigrid/envs/wfc/wfclogic/patterns.py @@ -0,0 +1,199 @@ +"Extract patterns from grids of tiles. Implementation based on https://github.com/ikarth/wfc_2019f" +from __future__ import annotations + +import logging +from collections import Counter +from typing import Any, Mapping + +import numpy as np +from numpy.typing import NDArray + +from minigrid.envs.wfc.wfclogic.utilities import hash_downto + +logger = logging.getLogger(__name__) + + +def unique_patterns_2d( + agrid: NDArray[np.int64], ksize: int, periodic_input: bool +) -> tuple[NDArray[np.int64], NDArray[np.int64], NDArray[np.int64]]: + assert ksize >= 1 + if periodic_input: + agrid = np.pad( + agrid, + ((0, ksize - 1), (0, ksize - 1), *(((0, 0),) * (len(agrid.shape) - 2))), + mode="wrap", + ) + else: + # TODO: implement non-wrapped image handling + # a = np.pad(a, ((0,k-1),(0,k-1),*(((0,0),)*(len(a.shape)-2))), mode='constant', constant_values=None) + agrid = np.pad( + agrid, + ((0, ksize - 1), (0, ksize - 1), *(((0, 0),) * (len(agrid.shape) - 2))), + mode="wrap", + ) + + patches: NDArray[np.int64] = np.lib.stride_tricks.as_strided( + agrid, + ( + agrid.shape[0] - ksize + 1, + agrid.shape[1] - ksize + 1, + ksize, + ksize, + *agrid.shape[2:], + ), + agrid.strides[:2] + agrid.strides[:2] + agrid.strides[2:], + writeable=False, + ) + patch_codes = hash_downto(patches, 2) + uc, ui = np.unique(patch_codes, return_index=True) + locs = np.unravel_index(ui, patch_codes.shape) + up: NDArray[np.int64] = patches[locs[0], locs[1]] + ids: NDArray[np.int64] = np.vectorize( + {code: ind for ind, code in enumerate(uc)}.get + )(patch_codes) + return ids, up, patch_codes + + +def unique_patterns_brute_force(grid, size, periodic_input): + padded_grid = np.pad( + grid, + ((0, size - 1), (0, size - 1), *(((0, 0),) * (len(grid.shape) - 2))), + mode="wrap", + ) + patches = [] + for x in range(grid.shape[0]): + row_patches = [] + for y in range(grid.shape[1]): + row_patches.append( + np.ndarray.tolist(padded_grid[x : x + size, y : y + size]) + ) + patches.append(row_patches) + patches = np.array(patches) + patch_codes = hash_downto(patches, 2) + uc, ui = np.unique(patch_codes, return_index=True) + locs = np.unravel_index(ui, patch_codes.shape) + up = patches[locs[0], locs[1]] + ids = np.vectorize({c: i for i, c in enumerate(uc)}.get)(patch_codes) + return ids, up + + +def make_pattern_catalog( + tile_grid: NDArray[np.int64], pattern_width: int, input_is_periodic: bool = True +) -> tuple[dict[int, NDArray[np.int64]], Counter, NDArray[np.int64], NDArray[np.int64]]: + """Returns a pattern catalog (dictionary of pattern hashes to constituent tiles), + an ordered list of pattern weights, and an ordered list of pattern contents.""" + _patterns_in_grid, pattern_contents_list, patch_codes = unique_patterns_2d( + tile_grid, pattern_width, input_is_periodic + ) + dict_of_pattern_contents: dict[int, NDArray[np.int64]] = {} + for pat_idx in range(pattern_contents_list.shape[0]): + p_hash = hash_downto(pattern_contents_list[pat_idx], 0) + dict_of_pattern_contents.update({p_hash.item(): pattern_contents_list[pat_idx]}) + pattern_frequency = Counter(hash_downto(pattern_contents_list, 1)) + return ( + dict_of_pattern_contents, + pattern_frequency, + hash_downto(pattern_contents_list, 1), + patch_codes, + ) + + +def identity_grid(grid): + """Do nothing to the grid""" + # return np.array([[7,5,5,5],[5,0,0,0],[5,0,1,0],[5,0,0,0]]) + return grid + + +def reflect_grid(grid): + """Reflect the grid left/right""" + return np.fliplr(grid) + + +def rotate_grid(grid): + """Rotate the grid""" + return np.rot90(grid, axes=(1, 0)) + + +def make_pattern_catalog_with_rotations( + tile_grid: NDArray[np.int64], + pattern_width: int, + rotations: int = 7, + input_is_periodic: bool = True, +) -> tuple[dict[int, NDArray[np.int64]], Counter, NDArray[np.int64], NDArray[np.int64]]: + rotated_tile_grid = tile_grid.copy() + merged_dict_of_pattern_contents: dict[int, NDArray[np.int64]] = {} + merged_pattern_frequency: Counter = Counter() + merged_pattern_contents_list: NDArray[np.int64] | None = None + merged_patch_codes: NDArray[np.int64] | None = None + + def _make_catalog() -> None: + nonlocal rotated_tile_grid, merged_dict_of_pattern_contents, merged_pattern_contents_list, merged_pattern_frequency, merged_patch_codes + ( + dict_of_pattern_contents, + pattern_frequency, + pattern_contents_list, + patch_codes, + ) = make_pattern_catalog(rotated_tile_grid, pattern_width, input_is_periodic) + merged_dict_of_pattern_contents.update(dict_of_pattern_contents) + merged_pattern_frequency.update(pattern_frequency) + if merged_pattern_contents_list is None: + merged_pattern_contents_list = pattern_contents_list.copy() + else: + merged_pattern_contents_list = np.unique( + np.concatenate((merged_pattern_contents_list, pattern_contents_list)) + ) + if merged_patch_codes is None: + merged_patch_codes = patch_codes.copy() + + counter = 0 + grid_ops = [ + identity_grid, + reflect_grid, + rotate_grid, + reflect_grid, + rotate_grid, + reflect_grid, + rotate_grid, + reflect_grid, + ] + while counter <= (rotations): + # logger.debug(rotated_tile_grid.shape) + # logger.debug(np.array_equiv(reflect_grid(rotated_tile_grid.copy()), rotate_grid(rotated_tile_grid.copy()))) + + # logger.debug(counter) + # logger.debug(grid_ops[counter].__name__) + rotated_tile_grid = grid_ops[counter](rotated_tile_grid.copy()) + # logger.debug(rotated_tile_grid) + # logger.debug("---") + _make_catalog() + counter += 1 + + # assert False + assert merged_pattern_contents_list is not None + assert merged_patch_codes is not None + return ( + merged_dict_of_pattern_contents, + merged_pattern_frequency, + merged_pattern_contents_list, + merged_patch_codes, + ) + + +def pattern_grid_to_tiles( + pattern_grid: NDArray[np.int64], pattern_catalog: Mapping[int, NDArray[np.int64]] +) -> NDArray[np.int64]: + anchor_x = 0 + anchor_y = 0 + + def pattern_to_tile(pattern: int) -> Any: + # if isinstance(pattern, list): + # ptrns = [] + # for p in pattern: + # logger.debug(p) + # ptrns.push(pattern_to_tile(p)) + # logger.debug(ptrns) + # assert False + # return ptrns + return pattern_catalog[pattern][anchor_x][anchor_y] + + return np.vectorize(pattern_to_tile)(pattern_grid) diff --git a/minigrid/envs/wfc/wfclogic/solver.py b/minigrid/envs/wfc/wfclogic/solver.py new file mode 100644 index 000000000..a4cadad9f --- /dev/null +++ b/minigrid/envs/wfc/wfclogic/solver.py @@ -0,0 +1,530 @@ +"""Wave Function Collapse solver. Implementation based on https://github.com/ikarth/wfc_2019f""" +from __future__ import annotations + +import itertools +import logging +import math +from typing import Any, Callable, Collection, Iterable, Iterator, Mapping, TypeVar + +# from scipy import sparse # type: ignore +import numpy +import numpy as np +from numpy.typing import NBitBase, NDArray + +logger = logging.getLogger(__name__) + +T = TypeVar("T", bound=NBitBase) + + +class Contradiction(Exception): + """Solving could not proceed without backtracking/restarting.""" + + pass + + +class TimedOut(Exception): + """Solve timed out.""" + + pass + + +class StopEarly(Exception): + """Aborting solve early.""" + + pass + + +class Solver: + """WFC Solver which can hold wave and backtracking state.""" + + def __init__( + self, + *, + wave: NDArray[np.bool_], + adj: Mapping[tuple[int, int], NDArray[numpy.bool_]], + periodic: bool = False, + backtracking: bool = False, + on_backtrack: Callable[[], None] | None = None, + on_choice: Callable[[int, int, int], None] | None = None, + on_observe: Callable[[NDArray[numpy.bool_]], None] | None = None, + on_propagate: Callable[[NDArray[numpy.bool_]], None] | None = None, + check_feasible: Callable[[NDArray[numpy.bool_]], bool] | None = None, + ) -> None: + self.wave = wave + self.adj = adj + self.periodic = periodic + self.backtracking = backtracking + self.history: list[NDArray[np.bool_]] = [] # An undo history for backtracking. + self.on_backtrack = on_backtrack + self.on_choice = on_choice + self.on_observe = on_observe + self.on_propagate = on_propagate + self.check_feasible = check_feasible + + @property + def is_solved(self) -> bool: + """Is True if the wave has been fully resolved.""" + return ( + self.wave.sum() == self.wave.shape[1] * self.wave.shape[2] + and (self.wave.sum(axis=0) == 1).all() + ) + + def solve_next( + self, + location_heuristic: Callable[[NDArray[numpy.bool_]], tuple[int, int]], + pattern_heuristic: Callable[[NDArray[np.bool_], NDArray[np.bool_]], int], + ) -> bool: + """Attempt to collapse one wave. Returns True if no more steps remain.""" + if self.is_solved: + return True + if self.check_feasible and not self.check_feasible(self.wave): + raise Contradiction("Not feasible.") + if self.backtracking: + self.history.append(self.wave.copy()) + propagate( + self.wave, self.adj, periodic=self.periodic, onPropagate=self.on_propagate + ) + pattern, i, j = None, None, None + try: + pattern, i, j = observe(self.wave, location_heuristic, pattern_heuristic) + if self.on_choice: + self.on_choice(pattern, i, j) + self.wave[:, i, j] = False + self.wave[pattern, i, j] = True + if self.on_observe: + self.on_observe(self.wave) + propagate( + self.wave, + self.adj, + periodic=self.periodic, + onPropagate=self.on_propagate, + ) + return False # Assume there is remaining steps, if not then the next call will return True. + except Contradiction: + if not self.backtracking: + raise + if not self.history: + raise Contradiction("Every permutation has been attempted.") + if self.on_backtrack: + self.on_backtrack() + self.wave = self.history.pop() + self.wave[pattern, i, j] = False + return False + + def solve( + self, + location_heuristic: Callable[[NDArray[numpy.bool_]], tuple[int, int]], + pattern_heuristic: Callable[[NDArray[np.bool_], NDArray[np.bool_]], int], + ) -> NDArray[np.int64]: + """Attempts to solve all waves and returns the solution.""" + while not self.solve_next( + location_heuristic=location_heuristic, pattern_heuristic=pattern_heuristic + ): + pass + return numpy.argmax(self.wave, axis=0) + + +def makeWave( + n: int, w: int, h: int, ground: Iterable[int] | None = None +) -> NDArray[numpy.bool_]: + wave: NDArray[numpy.bool_] = numpy.ones((n, w, h), dtype=numpy.bool_) + if ground is not None: + wave[:, :, h - 1] = False + for g in ground: + wave[ + g, + :, + ] = False + wave[g, :, h - 1] = True + # logger.debug(wave) + # for i in range(wave.shape[0]): + # logger.debug(wave[i]) + return wave + + +def makeAdj( + adjLists: Mapping[tuple[int, int], Collection[Iterable[int]]] +) -> dict[tuple[int, int], NDArray[numpy.bool_]]: + adjMatrices = {} + # logger.debug(adjLists) + num_patterns = len(list(adjLists.values())[0]) + for d in adjLists: + m = numpy.zeros((num_patterns, num_patterns), dtype=bool) + for i, js in enumerate(adjLists[d]): + # logger.debug(js) + for j in js: + m[i, j] = 1 + # If scipy is available, use sparse matrices. + # adjMatrices[d] = sparse.csr_matrix(m) + adjMatrices[d] = m + return adjMatrices + + +###################################### +# Location Heuristics + + +def makeRandomLocationHeuristic( + preferences: NDArray[np.floating[Any]], +) -> Callable[[NDArray[np.bool_]], tuple[int, int]]: + def randomLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]: + unresolved_cell_mask = numpy.count_nonzero(wave, axis=0) > 1 + cell_weights = numpy.where(unresolved_cell_mask, preferences, numpy.inf) + row, col = numpy.unravel_index(numpy.argmin(cell_weights), cell_weights.shape) + return row.item(), col.item() + + return randomLocationHeuristic + + +def makeEntropyLocationHeuristic( + preferences: NDArray[np.floating[Any]], +) -> Callable[[NDArray[np.bool_]], tuple[int, int]]: + def entropyLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]: + unresolved_cell_mask = numpy.count_nonzero(wave, axis=0) > 1 + cell_weights = numpy.where( + unresolved_cell_mask, + preferences + numpy.count_nonzero(wave, axis=0), + numpy.inf, + ) + row, col = numpy.unravel_index(numpy.argmin(cell_weights), cell_weights.shape) + return row.item(), col.item() + + return entropyLocationHeuristic + + +def makeAntiEntropyLocationHeuristic( + preferences: NDArray[np.floating[Any]], +) -> Callable[[NDArray[np.bool_]], tuple[int, int]]: + def antiEntropyLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]: + unresolved_cell_mask = numpy.count_nonzero(wave, axis=0) > 1 + cell_weights = numpy.where( + unresolved_cell_mask, + preferences + numpy.count_nonzero(wave, axis=0), + -numpy.inf, + ) + row, col = numpy.unravel_index(numpy.argmax(cell_weights), cell_weights.shape) + return row.item(), col.item() + + return antiEntropyLocationHeuristic + + +def spiral_transforms() -> Iterator[tuple[int, int]]: + for N in itertools.count(start=1): + if N % 2 == 0: + yield (0, 1) # right + for _ in range(N): + yield (1, 0) # down + for _ in range(N): + yield (0, -1) # left + else: + yield (0, -1) # left + for _ in range(N): + yield (-1, 0) # up + for _ in range(N): + yield (0, 1) # right + + +def spiral_coords(x: int, y: int) -> Iterator[tuple[int, int]]: + yield x, y + for transform in spiral_transforms(): + x += transform[0] + y += transform[1] + yield x, y + + +def fill_with_curve( + arr: NDArray[np.floating[T]], curve_gen: Iterable[Iterable[int]] +) -> NDArray[np.floating[T]]: + arr_len = numpy.prod(arr.shape) + fill = 0 + for coord in curve_gen: + # logger.debug(fill, idx, coord) + if fill < arr_len: + try: + arr[tuple(coord)] = fill / arr_len + fill += 1 + except IndexError: + pass + else: + break + # logger.debug(arr) + return arr + + +def makeSpiralLocationHeuristic( + preferences: NDArray[np.floating[Any]], +) -> Callable[[NDArray[np.bool_]], tuple[int, int]]: + # https://stackoverflow.com/a/23707273/5562922 + + spiral_gen = ( + sc for sc in spiral_coords(preferences.shape[0] // 2, preferences.shape[1] // 2) + ) + + cell_order = fill_with_curve(preferences, spiral_gen) + + def spiralLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]: + unresolved_cell_mask = numpy.count_nonzero(wave, axis=0) > 1 + cell_weights = numpy.where(unresolved_cell_mask, cell_order, numpy.inf) + row, col = numpy.unravel_index(numpy.argmin(cell_weights), cell_weights.shape) + return row.item(), col.item() + + return spiralLocationHeuristic + + +def makeHilbertLocationHeuristic( + preferences: NDArray[np.floating[Any]], +) -> Callable[[NDArray[np.bool_]], tuple[int, int]]: + from hilbertcurve.hilbertcurve import HilbertCurve # type: ignore + + curve_size = math.ceil(math.sqrt(max(preferences.shape[0], preferences.shape[1]))) + logger.debug(curve_size) + curve_size = 4 + h_curve = HilbertCurve(curve_size, 2) + h_coords = (h_curve.point_from_distance(i) for i in itertools.count()) + cell_order = fill_with_curve(preferences, h_coords) + # logger.debug(cell_order) + + def hilbertLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]: + unresolved_cell_mask = numpy.count_nonzero(wave, axis=0) > 1 + cell_weights = numpy.where(unresolved_cell_mask, cell_order, numpy.inf) + row, col = numpy.unravel_index(numpy.argmin(cell_weights), cell_weights.shape) + return row.item(), col.item() + + return hilbertLocationHeuristic + + +def simpleLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]: + unresolved_cell_mask = numpy.count_nonzero(wave, axis=0) > 1 + cell_weights = numpy.where( + unresolved_cell_mask, numpy.count_nonzero(wave, axis=0), numpy.inf + ) + row, col = numpy.unravel_index(numpy.argmin(cell_weights), cell_weights.shape) + return row.item(), col.item() + + +def lexicalLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]: + unresolved_cell_mask = numpy.count_nonzero(wave, axis=0) > 1 + cell_weights = numpy.where(unresolved_cell_mask, 1.0, numpy.inf) + row, col = numpy.unravel_index(numpy.argmin(cell_weights), cell_weights.shape) + return row.item(), col.item() + + +##################################### +# Pattern Heuristics + + +def lexicalPatternHeuristic(weights: NDArray[np.bool_], wave: NDArray[np.bool_]) -> int: + return numpy.nonzero(weights)[0][0].item() + + +def makeWeightedPatternHeuristic( + weights: NDArray[np.floating[Any]], + np_random: numpy.random.Generator | None = None, +): + num_of_patterns = len(weights) + np_random: numpy.random.Generator = ( + numpy.random.default_rng() if np_random is None else np_random + ) + + def weightedPatternHeuristic(wave: NDArray[np.bool_], _: NDArray[np.bool_]) -> int: + # TODO: there's maybe a faster, more controlled way to do this sampling... + weighted_wave: NDArray[np.floating[Any]] = weights * wave + weighted_wave /= weighted_wave.sum() + result = np_random.choice(num_of_patterns, p=weighted_wave) + return result + + return weightedPatternHeuristic + + +def makeRarestPatternHeuristic( + weights: NDArray[np.floating[Any]], + np_random: numpy.random.Generator | None = None, +) -> Callable[[NDArray[np.bool_], NDArray[np.bool_]], int]: + """Return a function that chooses the rarest (currently least-used) pattern.""" + np_random: numpy.random.Generator = ( + numpy.random.default_rng() if np_random is None else np_random + ) + + def weightedPatternHeuristic( + wave: NDArray[np.bool_], total_wave: NDArray[np.bool_] + ) -> int: + logger.debug(total_wave.shape) + # [logger.debug(e) for e in wave] + wave_sums = numpy.sum(total_wave, (1, 2)) + # logger.debug(wave_sums) + selected_pattern = np_random.choice( + numpy.where(wave_sums == wave_sums.max())[0] + ) + return selected_pattern + + return weightedPatternHeuristic + + +def makeMostCommonPatternHeuristic( + weights: NDArray[np.floating[Any]], + np_random: numpy.random.Generator | None = None, +) -> Callable[[NDArray[np.bool_], NDArray[np.bool_]], int]: + """Return a function that chooses the most common (currently most-used) pattern.""" + np_random: numpy.random.Generator = ( + numpy.random.default_rng() if np_random is None else np_random + ) + + def weightedPatternHeuristic( + wave: NDArray[np.bool_], total_wave: NDArray[np.bool_] + ) -> int: + logger.debug(total_wave.shape) + # [logger.debug(e) for e in wave] + wave_sums = numpy.sum(total_wave, (1, 2)) + selected_pattern = np_random.choice( + numpy.where(wave_sums == wave_sums.min())[0] + ) + return selected_pattern + + return weightedPatternHeuristic + + +def makeRandomPatternHeuristic( + weights: NDArray[np.floating[Any]], + np_random: numpy.random.Generator | None = None, +) -> Callable[[NDArray[np.bool_], NDArray[np.bool_]], int]: + num_of_patterns = len(weights) + np_random: numpy.random.Generator = ( + numpy.random.default_rng() if np_random is None else np_random + ) + + def randomPatternHeuristic(wave: NDArray[np.bool_], _: NDArray[np.bool_]) -> int: + # TODO: there's maybe a faster, more controlled way to do this sampling... + weighted_wave = 1.0 * wave + weighted_wave /= weighted_wave.sum() + result = np_random.choice(num_of_patterns, p=weighted_wave) + return result + + return randomPatternHeuristic + + +###################################### +# Global Constraints + + +def make_global_use_all_patterns() -> Callable[[NDArray[np.bool_]], bool]: + def global_use_all_patterns(wave: NDArray[np.bool_]) -> bool: + """Returns true if at least one instance of each pattern is still possible.""" + return numpy.all(numpy.any(wave, axis=(1, 2))).item() + + return global_use_all_patterns + + +##################################### +# Solver + + +def propagate( + wave: NDArray[np.bool_], + adj: Mapping[tuple[int, int], NDArray[numpy.bool_]], + periodic: bool = False, + onPropagate: Callable[[NDArray[numpy.bool_]], None] | None = None, +) -> None: + """Completely probagate any newly collapsed waves to all areas.""" + last_count = wave.sum() + + while True: + supports = {} + if periodic: + padded = numpy.pad(wave, ((0, 0), (1, 1), (1, 1)), mode="wrap") + else: + padded = numpy.pad( + wave, ((0, 0), (1, 1), (1, 1)), mode="constant", constant_values=True + ) + + # adj is the list of adjacencies. For each direction d in adjacency, + # check which patterns are still valid... + for d in adj: + dx, dy = d + # padded[] is a version of the adjacency matrix with the values wrapped around + # shifted[] is the padded version with the values shifted over in one direction + # because my code stores the directions as relative (x,y) coordinates, we can find + # the adjacent cell for each direction by simply shifting the matrix in that direction, + # which allows for arbitrary adjacency directions. This is somewhat excessive, but elegant. + + shifted = padded[ + :, 1 + dx : 1 + wave.shape[1] + dx, 1 + dy : 1 + wave.shape[2] + dy + ] + # logger.debug(f"shifted: {shifted.shape} | adj[d]: {adj[d].shape} | d: {d}") + # raise StopEarly + # supports[d] = numpy.einsum('pwh,pq->qwh', shifted, adj[d]) > 0 + + # The adjacency matrix is a boolean matrix, indexed by the direction and the two patterns. + # If the value for (direction, pattern1, pattern2) is True, then this is a valid adjacency. + # This gives us a rapid way to compare: True is 1, False is 0, so multiplying the matrices + # gives us the adjacency compatibility. + supports[d] = (adj[d] @ shifted.reshape(shifted.shape[0], -1)).reshape( + shifted.shape + ) > 0 + # supports[d] = ( <- for each cell in the matrix + # adj[d] <- the adjacency matrix [sliced by the direction d] + # @ <- Matrix multiplication + # shifted.reshape(shifted.shape[0], -1)) <- change the shape of the shifted matrix to 2-dimensions, to make the matrix multiplication easier + # .reshape( <- reshape our matrix-multiplied result... + # shifted.shape) <- ...to match the original shape of the shifted matrix + # > 0 <- is not false + + # multiply the wave matrix by the support matrix to find which patterns are still in the domain + for d in adj: + wave *= supports[d] + + if wave.sum() == last_count: + break # No changes since the last loop, changed waves have been fully propagated. + last_count = wave.sum() + + if onPropagate: + onPropagate(wave) + + if (wave.sum(axis=0) == 0).any(): + raise Contradiction("Wave is in a contradictory state and can not be solved.") + + +def observe( + wave: NDArray[np.bool_], + locationHeuristic: Callable[[NDArray[np.bool_]], tuple[int, int]], + patternHeuristic: Callable[[NDArray[np.bool_], NDArray[np.bool_]], int], +) -> tuple[int, int, int]: + """Return the next best wave to collapse based on the provided heuristics.""" + i, j = locationHeuristic(wave) + pattern = patternHeuristic(wave[:, i, j], wave) + return pattern, i, j + + +def run( + wave: NDArray[np.bool_], + adj: Mapping[tuple[int, int], NDArray[numpy.bool_]], + locationHeuristic: Callable[[NDArray[numpy.bool_]], tuple[int, int]], + patternHeuristic: Callable[[NDArray[np.bool_], NDArray[np.bool_]], int], + periodic: bool = False, + backtracking: bool = False, + onBacktrack: Callable[[], None] | None = None, + onChoice: Callable[[int, int, int], None] | None = None, + onObserve: Callable[[NDArray[numpy.bool_]], None] | None = None, + onPropagate: Callable[[NDArray[numpy.bool_]], None] | None = None, + checkFeasible: Callable[[NDArray[numpy.bool_]], bool] | None = None, + onFinal: Callable[[NDArray[numpy.bool_]], None] | None = None, + depth: int = 0, + depth_limit: int | None = None, +) -> NDArray[numpy.int64]: + solver = Solver( + wave=wave, + adj=adj, + periodic=periodic, + backtracking=backtracking, + on_backtrack=onBacktrack, + on_choice=onChoice, + on_observe=onObserve, + on_propagate=onPropagate, + check_feasible=checkFeasible, + ) + while not solver.solve_next( + location_heuristic=locationHeuristic, pattern_heuristic=patternHeuristic + ): + pass + if onFinal: + onFinal(solver.wave) + return numpy.argmax(solver.wave, axis=0) diff --git a/minigrid/envs/wfc/wfclogic/tiles.py b/minigrid/envs/wfc/wfclogic/tiles.py new file mode 100644 index 000000000..a4ffcdf2e --- /dev/null +++ b/minigrid/envs/wfc/wfclogic/tiles.py @@ -0,0 +1,64 @@ +"""Breaks an image into consituant tiles. Implementation based on https://github.com/ikarth/wfc_2019f""" +from __future__ import annotations + +import numpy as np +from numpy.typing import NDArray + +from minigrid.envs.wfc.wfclogic.utilities import hash_downto + + +def image_to_tiles(img: NDArray[np.integer], tile_size: int) -> NDArray[np.integer]: + """ + Takes an images, divides it into tiles, return an array of tiles. + """ + padding_argument = [(0, 0), (0, 0), (0, 0)] + for input_dim in [0, 1]: + padding_argument[input_dim] = ( + 0, + (tile_size - img.shape[input_dim]) % tile_size, + ) + img = np.pad(img, padding_argument, mode="constant") + tiles = img.reshape( + ( + img.shape[0] // tile_size, + tile_size, + img.shape[1] // tile_size, + tile_size, + img.shape[2], + ) + ).swapaxes(1, 2) + return tiles + + +def make_tile_catalog( + image_data: NDArray[np.integer], tile_size: int +) -> tuple[ + dict[int, NDArray[np.integer]], + NDArray[np.int64], + NDArray[np.int64], + tuple[NDArray[np.int64], NDArray[np.int64]], +]: + """ + Takes an image and tile size and returns the following: + tile_catalog is a dictionary tiles, with the hashed ID as the key + tile_grid is the original image, expressed in terms of hashed tile IDs + code_list is the original image, expressed in terms of hashed tile IDs and reduced to one dimension + unique_tiles is the set of tiles, plus the frequency of their occurrence + """ + channels = image_data.shape[2] # Number of color channels in the image + tiles = image_to_tiles(image_data, tile_size) + tile_list: NDArray[np.integer] = tiles.reshape( + (tiles.shape[0] * tiles.shape[1], tile_size, tile_size, channels) + ) + code_list: NDArray[np.int64] = hash_downto(tiles, 2).reshape( + tiles.shape[0] * tiles.shape[1] + ) + tile_grid: NDArray[np.int64] = hash_downto(tiles, 2) + unique_tiles: tuple[NDArray[np.int64], NDArray[np.int64]] = np.unique( + tile_grid, return_counts=True + ) + + tile_catalog: dict[int, NDArray[np.integer]] = {} + for i, j in enumerate(code_list): + tile_catalog[j] = tile_list[i] + return tile_catalog, tile_grid, code_list, unique_tiles diff --git a/minigrid/envs/wfc/wfclogic/utilities.py b/minigrid/envs/wfc/wfclogic/utilities.py new file mode 100644 index 000000000..8c9e6e997 --- /dev/null +++ b/minigrid/envs/wfc/wfclogic/utilities.py @@ -0,0 +1,77 @@ +"""Utility data and functions for WFC. Implementation based on https://github.com/ikarth/wfc_2019f""" +from __future__ import annotations + +import collections +import logging + +import numpy as np +from numpy.typing import NDArray + +logger = logging.getLogger(__name__) + +CoordXY = collections.namedtuple("CoordXY", ["x", "y"]) +CoordRC = collections.namedtuple("CoordRC", ["row", "column"]) + + +def hash_downto(a: NDArray[np.integer], rank: int, seed=0) -> NDArray[np.int64]: + state = np.random.RandomState(seed) + # np_random = np.random.default_rng(seed) + assert rank < len(a.shape) + + u: NDArray[np.integer] = a.reshape((np.prod(a.shape[:rank], dtype=np.int64), -1)) + v = state.randint(1 - (1 << 63), 1 << 63, np.prod(a.shape[rank:]), dtype=np.int64) + # v = np_random.integers(1 - (1 << 63), 1 << 63, np.prod(a.shape[rank:]), dtype=np.int64) + return np.asarray(np.inner(u, v).reshape(a.shape[:rank]), dtype=np.int64) + + +def find_pattern_center(wfc_ns): + # wfc_ns.pattern_center = (math.floor((wfc_ns.pattern_width - 1) / 2), math.floor((wfc_ns.pattern_width - 1) / 2)) + wfc_ns.pattern_center = (0, 0) + return wfc_ns + + +def tile_grid_to_image( + tile_grid: NDArray[np.int64], + tile_catalog: dict[int, NDArray[np.integer]], + tile_size: tuple[int, int], + partial: bool = False, + color_channels: int = 3, +) -> NDArray[np.integer]: + """ + Takes a tile_grid and transforms it into an image, using the information + in tile_catalog. We use tile_size to figure out the size the new image + should be. + """ + tile_dtype = next(iter(tile_catalog.values())).dtype + new_img = np.zeros( + ( + tile_grid.shape[0] * tile_size[0], + tile_grid.shape[1] * tile_size[1], + color_channels, + ), + dtype=tile_dtype, + ) + if partial and (len(tile_grid.shape)) > 2: + # TODO: implement rendering partially completed solution + # Call tile_grid_to_average() instead. + assert False + else: + for i in range(tile_grid.shape[0]): + for j in range(tile_grid.shape[1]): + tile = tile_grid[i, j] + for u in range(tile_size[0]): + for v in range(tile_size[1]): + pixel = [200, 0, 200] + # If we want to display a partial pattern, it is helpful to + # be able to show empty cells. + pixel = tile_catalog[tile][u, v] + # TODO: will need to change if using an image with more than 3 channels + new_img[ + (i * tile_size[0]) + u, (j * tile_size[1]) + v + ] = np.resize( + pixel, + new_img[ + (i * tile_size[0]) + u, (j * tile_size[1]) + v + ].shape, + ) + return new_img diff --git a/minigrid/wrappers.py b/minigrid/wrappers.py index 86a714e53..569fa11d0 100644 --- a/minigrid/wrappers.py +++ b/minigrid/wrappers.py @@ -522,6 +522,7 @@ def get_minigrid_words(): "object", "from", "room", + "maze", ] all_words = colors + objects + verbs + extra_words diff --git a/py.Dockerfile b/py.Dockerfile index 04d073434..f51bff278 100644 --- a/py.Dockerfile +++ b/py.Dockerfile @@ -11,7 +11,7 @@ RUN apt-get -y update \ COPY . /usr/local/minigrid/ WORKDIR /usr/local/minigrid/ -RUN pip install .[testing] --no-cache-dir +RUN pip install .[wfc,testing] --no-cache-dir RUN ["chmod", "+x", "/usr/local/minigrid/docker_entrypoint"] diff --git a/pyproject.toml b/pyproject.toml index cfcab75c1..5d71ab691 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,10 @@ testing = [ "pytest-mock>=3.10.0", "matplotlib>=3.0" ] +wfc = [ + "networkx", + "imageio>=2.31.1", +] [project.urls] Homepage = "https://farama.org" diff --git a/tests/test_wfc/__init__.py b/tests/test_wfc/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_wfc/conftest.py b/tests/test_wfc/conftest.py new file mode 100644 index 000000000..1f88603da --- /dev/null +++ b/tests/test_wfc/conftest.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +import pytest +from numpy import array, uint8 + +from minigrid.envs.wfc.config import PATTERN_PATH + + +class Resources: + def get_pattern(self, image: str) -> str: + return PATTERN_PATH / image + + +@pytest.fixture(scope="session") +def resources() -> Resources: + return Resources() + + +@pytest.fixture(scope="session") +def img_redmaze(resources: Resources) -> array: + try: + import imageio # type: ignore + + pattern = resources.get_pattern("RedMaze.png") + img = imageio.v2.imread(pattern) + except ImportError: + b = [0, 0, 0] + w = [255, 255, 255] + r = [255, 0, 0] + img = array( + [ + [w, w, w, w], + [w, b, b, b], + [w, b, r, b], + [w, b, b, b], + ], + dtype=uint8, + ) + + return img diff --git a/tests/test_wfc/test_wfc_adjacency.py b/tests/test_wfc/test_wfc_adjacency.py new file mode 100644 index 000000000..be9cf22e2 --- /dev/null +++ b/tests/test_wfc/test_wfc_adjacency.py @@ -0,0 +1,41 @@ +"""Convert input data to adjacency information""" +from __future__ import annotations + +import numpy as np + +from minigrid.envs.wfc.wfclogic import adjacency as wfc_adjacency +from minigrid.envs.wfc.wfclogic import patterns as wfc_patterns +from minigrid.envs.wfc.wfclogic import tiles as wfc_tiles + + +def test_adjacency_extraction(img_redmaze: np.ndarray) -> None: + # TODO: generalize this to more than the four cardinal directions + direction_offsets = list(enumerate([(0, -1), (1, 0), (0, 1), (-1, 0)])) + + img = img_redmaze + tile_size = 1 + pattern_width = 2 + periodic = False + _tile_catalog, tile_grid, _code_list, _unique_tiles = wfc_tiles.make_tile_catalog( + img, tile_size + ) + ( + pattern_catalog, + _pattern_weights, + _pattern_list, + pattern_grid, + ) = wfc_patterns.make_pattern_catalog(tile_grid, pattern_width, periodic) + adjacency_relations = wfc_adjacency.adjacency_extraction( + pattern_grid, pattern_catalog, direction_offsets + ) + assert ((0, -1), -6150964001204120324, -4042134092912931260) in adjacency_relations + assert ((-1, 0), -4042134092912931260, 3069048847358774683) in adjacency_relations + assert ((1, 0), -3950451988873469076, -3950451988873469076) in adjacency_relations + assert ((-1, 0), -3950451988873469076, -3950451988873469076) in adjacency_relations + assert ((0, 1), -3950451988873469076, 3336256675067683735) in adjacency_relations + assert ( + not ((0, -1), -3950451988873469076, -3950451988873469076) in adjacency_relations + ) + assert ( + not ((0, 1), -3950451988873469076, -3950451988873469076) in adjacency_relations + ) diff --git a/tests/test_wfc/test_wfc_patterns.py b/tests/test_wfc/test_wfc_patterns.py new file mode 100644 index 000000000..3e4b93a0f --- /dev/null +++ b/tests/test_wfc/test_wfc_patterns.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import numpy as np + +from minigrid.envs.wfc.wfclogic import patterns as wfc_patterns +from minigrid.envs.wfc.wfclogic import tiles as wfc_tiles + + +def test_unique_patterns_2d(img_redmaze) -> None: + img = img_redmaze + tile_size = 1 + pattern_width = 2 + _tile_catalog, tile_grid, _code_list, _unique_tiles = wfc_tiles.make_tile_catalog( + img, tile_size + ) + + ( + _patterns_in_grid, + pattern_contents_list, + patch_codes, + ) = wfc_patterns.unique_patterns_2d(tile_grid, pattern_width, True) + assert patch_codes[1][2] == 4867810695119132864 + assert pattern_contents_list[7][1][1] == 8253868773529191888 + + +def test_make_pattern_catalog(img_redmaze) -> None: + img = img_redmaze + tile_size = 1 + pattern_width = 2 + _tile_catalog, tile_grid, _code_list, _unique_tiles = wfc_tiles.make_tile_catalog( + img, tile_size + ) + + ( + pattern_catalog, + pattern_weights, + pattern_list, + _pattern_grid, + ) = wfc_patterns.make_pattern_catalog(tile_grid, pattern_width) + assert pattern_weights[-6150964001204120324] == 1 + assert pattern_list[3] == 2800765426490226432 + assert pattern_catalog[5177878755649963747][0][1] == -8754995591521426669 + + +def test_pattern_to_tile(img_redmaze) -> None: + img = img_redmaze + tile_size = 1 + pattern_width = 2 + _tile_catalog, tile_grid, _code_list, _unique_tiles = wfc_tiles.make_tile_catalog( + img, tile_size + ) + + ( + pattern_catalog, + _pattern_weights, + _pattern_list, + pattern_grid, + ) = wfc_patterns.make_pattern_catalog(tile_grid, pattern_width) + new_tile_grid = wfc_patterns.pattern_grid_to_tiles(pattern_grid, pattern_catalog) + assert np.array_equal(tile_grid, new_tile_grid) diff --git a/tests/test_wfc/test_wfc_solver.py b/tests/test_wfc/test_wfc_solver.py new file mode 100644 index 000000000..831238b82 --- /dev/null +++ b/tests/test_wfc/test_wfc_solver.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +import numpy as np +import pytest +from numpy.typing import NDArray + +from minigrid.envs.wfc.wfclogic import solver as wfc_solver + + +def test_makeWave() -> None: + wave = wfc_solver.makeWave(3, 10, 20, ground=[-1]) + assert wave.sum() == (2 * 10 * 19) + (1 * 10 * 1) + assert wave[2, 5, 19] + assert not wave[1, 5, 19] + + +def test_entropyLocationHeuristic() -> None: + wave = np.ones((5, 3, 4), dtype=bool) # everything is possible + wave[1:, 0, 0] = False # first cell is fully observed + wave[4, :, 2] = False + preferences: NDArray[np.float_] = np.ones((3, 4), dtype=np.float_) * 0.5 + preferences[1, 2] = 0.3 + preferences[1, 1] = 0.1 + heu = wfc_solver.makeEntropyLocationHeuristic(preferences) + result = heu(wave) + assert (1, 2) == result + + +def test_observe() -> None: + my_wave = np.ones((5, 3, 4), dtype=np.bool_) + my_wave[0, 1, 2] = False + + def locHeu(wave: NDArray[np.bool_]) -> tuple[int, int]: + assert np.array_equal(wave, my_wave) + return 1, 2 + + def patHeu(weights: NDArray[np.bool_], wave: NDArray[np.bool_]) -> int: + assert np.array_equal(weights, my_wave[:, 1, 2]) + return 3 + + assert wfc_solver.observe( + my_wave, locationHeuristic=locHeu, patternHeuristic=patHeu + ) == ( + 3, + 1, + 2, + ) + + +def test_propagate() -> None: + wave = np.ones((3, 3, 4), dtype=bool) + adjLists = {} + # checkerboard #0/#1 or solid fill #2 + adjLists[(+1, 0)] = adjLists[(-1, 0)] = adjLists[(0, +1)] = adjLists[(0, -1)] = [ + [1], + [0], + [2], + ] + wave[:, 0, 0] = False + wave[0, 0, 0] = True + adj = wfc_solver.makeAdj(adjLists) + wfc_solver.propagate(wave, adj, periodic=False) + expected_result = np.array( + [ + [ + [True, False, True, False], + [False, True, False, True], + [True, False, True, False], + ], + [ + [False, True, False, True], + [True, False, True, False], + [False, True, False, True], + ], + [ + [False, False, False, False], + [False, False, False, False], + [False, False, False, False], + ], + ] + ) + assert np.array_equal(wave, expected_result) + + +def test_run() -> None: + wave = wfc_solver.makeWave(3, 3, 4) + adjLists = {} + adjLists[(+1, 0)] = adjLists[(-1, 0)] = adjLists[(0, +1)] = adjLists[(0, -1)] = [ + [1], + [0], + [2], + ] + adj = wfc_solver.makeAdj(adjLists) + + first_result = wfc_solver.run( + wave.copy(), + adj, + locationHeuristic=wfc_solver.lexicalLocationHeuristic, + patternHeuristic=wfc_solver.lexicalPatternHeuristic, + periodic=False, + ) + + expected_first_result = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]]) + + assert np.array_equal(first_result, expected_first_result) + + event_log: list = [] + + def onChoice(pattern: int, i: int, j: int) -> None: + event_log.append((pattern, i, j)) + + def onBacktrack() -> None: + event_log.append("backtrack") + + second_result = wfc_solver.run( + wave.copy(), + adj, + locationHeuristic=wfc_solver.lexicalLocationHeuristic, + patternHeuristic=wfc_solver.lexicalPatternHeuristic, + periodic=True, + backtracking=True, + onChoice=onChoice, + onBacktrack=onBacktrack, + ) + + expected_second_result = np.array([[2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2]]) + + assert np.array_equal(second_result, expected_second_result) + assert event_log == [(0, 0, 0), "backtrack", (2, 0, 0)] + + class Infeasible(Exception): + pass + + def explode(wave: NDArray[np.bool_]) -> bool: + if wave.sum() < 20: + raise Infeasible + return False + + with pytest.raises(wfc_solver.Contradiction): + wfc_solver.run( + wave.copy(), + adj, + locationHeuristic=wfc_solver.lexicalLocationHeuristic, + patternHeuristic=wfc_solver.lexicalPatternHeuristic, + periodic=True, + backtracking=True, + checkFeasible=explode, + ) diff --git a/tests/test_wfc/test_wfc_tiles.py b/tests/test_wfc/test_wfc_tiles.py new file mode 100644 index 000000000..4a114b4f7 --- /dev/null +++ b/tests/test_wfc/test_wfc_tiles.py @@ -0,0 +1,17 @@ +"""Breaks an image into consituant tiles.""" +from __future__ import annotations + +from minigrid.envs.wfc.wfclogic import tiles as wfc_tiles + + +def test_image_to_tile(img_redmaze) -> None: + img = img_redmaze + tiles = wfc_tiles.image_to_tiles(img, 1) + assert tiles[2][2][0][0][0] == 255 + assert tiles[2][2][0][0][1] == 0 + + +def test_make_tile_catalog(img_redmaze) -> None: + img = img_redmaze + tc, tg, cl, ut = wfc_tiles.make_tile_catalog(img, 1) + assert ut[1][0] == 7 diff --git a/tests/utils.py b/tests/utils.py index 12ddf6da0..1a8f27cb3 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,6 +1,8 @@ """Finds all the specs that we can test with""" from __future__ import annotations +from importlib.util import find_spec + import gymnasium as gym import numpy as np @@ -13,6 +15,14 @@ ) ] +if find_spec("imageio") is None or find_spec("networkx") is None: + # Do not test WFC environments if dependencies are not installed + all_testing_env_specs = [ + env_spec + for env_spec in all_testing_env_specs + if not env_spec.entry_point.startswith("minigrid.envs.wfc") + ] + minigrid_testing_env_specs = [ env_spec for env_spec in all_testing_env_specs