From c4bacb70c08f07c978c70e85cf74a03731efdbb6 Mon Sep 17 00:00:00 2001 From: Sasha Abramowitz Date: Fri, 1 Nov 2024 14:59:21 +0200 Subject: [PATCH] chore: upgrade flake and fix errors --- .pre-commit-config.yaml | 2 +- jumanji/env.py | 2 +- .../environments/logic/game_2048/viewer.py | 2 +- .../environments/logic/graph_coloring/env.py | 2 +- .../logic/sliding_tile_puzzle/viewer.py | 4 +-- jumanji/environments/packing/bin_pack/env.py | 6 ++--- .../packing/flat_pack/generator.py | 4 +-- .../environments/routing/mmst/generator.py | 2 +- .../environments/routing/multi_cvrp/viewer.py | 8 +++--- .../routing/robot_warehouse/conftest.py | 4 +-- .../routing/robot_warehouse/env.py | 4 ++- jumanji/specs.py | 8 +++--- jumanji/specs_test.py | 14 +++++----- .../networks/graph_coloring/actor_critic.py | 2 +- .../training/networks/mmst/actor_critic.py | 2 +- jumanji/training/networks/tsp/actor_critic.py | 4 +-- setup.cfg | 27 ++++++++++++------- 17 files changed, 54 insertions(+), 43 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 928835b3d..f62452ace 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -33,7 +33,7 @@ repos: name: "Trailing whitespace fixer" - repo: https://github.com/PyCQA/flake8 - rev: 3.9.2 + rev: 7.1.1 hooks: - id: flake8 name: "Linter" diff --git a/jumanji/env.py b/jumanji/env.py index 48035a992..9674960c8 100644 --- a/jumanji/env.py +++ b/jumanji/env.py @@ -110,7 +110,7 @@ def reward_spec(self) -> specs.Array: @cached_property def discount_spec(self) -> specs.BoundedArray: - """Returns the discount spec. By default, this is assumed to be a single float between 0 and 1. + """Returns the discount spec. By default, this is assumed to be a float between 0 and 1. Returns: discount_spec: a `specs.BoundedArray` spec. diff --git a/jumanji/environments/logic/game_2048/viewer.py b/jumanji/environments/logic/game_2048/viewer.py index 819d1c251..b64b48b20 100644 --- a/jumanji/environments/logic/game_2048/viewer.py +++ b/jumanji/environments/logic/game_2048/viewer.py @@ -123,7 +123,7 @@ def make_frame(state_index: int) -> None: return self._animation def get_fig_ax(self) -> Tuple[plt.Figure, plt.Axes]: - """This function returns a `Matplotlib` figure and axes object for displaying the 2048 game board. + """This function returns a `Matplotlib` figure and axes for displaying the 2048 game board. Returns: A tuple containing the figure and axes objects. diff --git a/jumanji/environments/logic/graph_coloring/env.py b/jumanji/environments/logic/graph_coloring/env.py index b5e65a3e5..32de81019 100644 --- a/jumanji/environments/logic/graph_coloring/env.py +++ b/jumanji/environments/logic/graph_coloring/env.py @@ -296,7 +296,7 @@ def animate( interval: int = 200, save_path: Optional[str] = None, ) -> animation.FuncAnimation: - """Creates an animated gif of the `GraphColoring` environment based on the sequence of game states. + """Creates an animated gif of the `GraphColoring` environment based on a sequence of states. Args: states: is a list of `State` objects representing the sequence of game states. diff --git a/jumanji/environments/logic/sliding_tile_puzzle/viewer.py b/jumanji/environments/logic/sliding_tile_puzzle/viewer.py index 6596a323d..7fa905fcf 100644 --- a/jumanji/environments/logic/sliding_tile_puzzle/viewer.py +++ b/jumanji/environments/logic/sliding_tile_puzzle/viewer.py @@ -71,7 +71,7 @@ def animate( interval: int = 200, save_path: Optional[str] = None, ) -> matplotlib.animation.FuncAnimation: - """Creates an animated gif of the sliding tiles puzzle game based on the sequence of game states. + """Creates an animated gif of the sliding tiles puzzle game based on a sequence of states. Args: states: is a list of `State` objects representing the sequence of game states. @@ -101,7 +101,7 @@ def make_frame(state_index: int) -> None: return self._animation def get_fig_ax(self) -> Tuple[plt.Figure, plt.Axes]: - """This function returns a `Matplotlib` figure and axes object for displaying the game puzzle. + """This function returns a `Matplotlib` figure and axes for displaying the puzzle. Returns: A tuple containing the figure and axes objects. diff --git a/jumanji/environments/packing/bin_pack/env.py b/jumanji/environments/packing/bin_pack/env.py index 5b2b2c7cf..3506fa0b0 100644 --- a/jumanji/environments/packing/bin_pack/env.py +++ b/jumanji/environments/packing/bin_pack/env.py @@ -397,9 +397,9 @@ def close(self) -> None: def _make_observation_and_extras( self, state: State ) -> Tuple[State, Observation, Dict]: - """Computes the observation and the environment metrics to include in `timestep.extras`. Also - updates the `action_mask` and `sorted_ems_indexes` in the state. The observation is obtained - by selecting a subset of all EMSs, namely the `obs_num_ems` largest ones. + """Computes the observation and the environment metrics to include in `timestep.extras`. + Also updates the `action_mask` and `sorted_ems_indexes` in the state. The observation is + obtained by selecting a subset of all EMSs, namely the `obs_num_ems` largest ones. """ obs_ems, obs_ems_mask, sorted_ems_indexes = self._get_set_of_largest_ems( state.ems, state.ems_mask diff --git a/jumanji/environments/packing/flat_pack/generator.py b/jumanji/environments/packing/flat_pack/generator.py index 7ea8495d5..412c4d9e1 100644 --- a/jumanji/environments/packing/flat_pack/generator.py +++ b/jumanji/environments/packing/flat_pack/generator.py @@ -28,8 +28,8 @@ class InstanceGenerator(abc.ABC): - """Base class for generators for the flat_pack environment. An `InstanceGenerator` is responsible - for generating a problem instance when the environment is reset. + """Base class for generators for the flat_pack environment. An `InstanceGenerator` is + responsible for generating a problem instance when the environment is reset. """ def __init__( diff --git a/jumanji/environments/routing/mmst/generator.py b/jumanji/environments/routing/mmst/generator.py index c71d01054..be0e4685a 100644 --- a/jumanji/environments/routing/mmst/generator.py +++ b/jumanji/environments/routing/mmst/generator.py @@ -88,7 +88,7 @@ def __call__(self, key: chex.PRNGKey) -> State: class SplitRandomGenerator(Generator): - """Generates a random environments that is solvable by spliting the graph into multiple sub graphs. + """Generates a random environments that is solvable by spliting the graph into sub graphs. Returns a graph and with a desired number of edges and nodes to connect per agent. """ diff --git a/jumanji/environments/routing/multi_cvrp/viewer.py b/jumanji/environments/routing/multi_cvrp/viewer.py index 3dbda1aec..dd6fc5e8e 100644 --- a/jumanji/environments/routing/multi_cvrp/viewer.py +++ b/jumanji/environments/routing/multi_cvrp/viewer.py @@ -205,10 +205,10 @@ def _draw_route(self, ax: plt.Axes, coords: chex.Array, col_id: int) -> None: ax.scatter(x, y, s=self.NODE_SIZE, color=self._cmap(col_id)) def _add_tour(self, ax: plt.Axes, state: State) -> None: - """Add the customers and the depot to the plot, and draw each route in the tour in a different - colour. The tour is the entire trajectory between the visited customers and a route is a - trajectory either starting and ending at the depot or starting at the depot and ending at - the current city.""" + """Add the customers and the depot to the plot, and draw each route in the tour in a + different colour. The tour is the entire trajectory between the visited customers and a + route is a trajectory either starting and ending at the depot or starting at the depot + and ending at the current city.""" x_coords, y_coords = ( state.nodes.coordinates[:, 0] / self._map_max, state.nodes.coordinates[:, 1] / self._map_max, diff --git a/jumanji/environments/routing/robot_warehouse/conftest.py b/jumanji/environments/routing/robot_warehouse/conftest.py index 95ed58271..68d815705 100644 --- a/jumanji/environments/routing/robot_warehouse/conftest.py +++ b/jumanji/environments/routing/robot_warehouse/conftest.py @@ -31,8 +31,8 @@ @pytest.fixture(scope="module") def robot_warehouse_env() -> RobotWarehouse: - """Instantiates a default RobotWarehouse environment with 2 agents, 1 shelf row, 3 shelf columns, - a column height of 2, sensor range of 1 and a request queue size of 4.""" + """Instantiates a default RobotWarehouse environment with 2 agents, 1 shelf row, 3 shelf + columns, a column height of 2, sensor range of 1 and a request queue size of 4.""" generator = RandomGenerator( shelf_rows=1, shelf_columns=3, diff --git a/jumanji/environments/routing/robot_warehouse/env.py b/jumanji/environments/routing/robot_warehouse/env.py index eb9c2c578..8ab107bc4 100644 --- a/jumanji/environments/routing/robot_warehouse/env.py +++ b/jumanji/environments/routing/robot_warehouse/env.py @@ -362,7 +362,9 @@ def observation_spec(self) -> specs.Spec[Observation]: @cached_property def action_spec(self) -> specs.MultiDiscreteArray: - """Returns the action spec. 5 actions: [0,1,2,3,4] -> [No Op, Forward, Left, Right, Toggle_load]. + """Returns the action spec. 5 actions: + [0,1,2,3,4] -> [No Op, Forward, Left, Right, Toggle_load]. + Since this is a multi-agent environment, the environment expects an array of actions. This array is of shape (num_agents,). """ diff --git a/jumanji/specs.py b/jumanji/specs.py index 6dc40237b..6cfacd546 100644 --- a/jumanji/specs.py +++ b/jumanji/specs.py @@ -44,9 +44,9 @@ class Spec(abc.ABC, Generic[T]): - """Adapted from `dm_env.spec.Array`. This is an augmentation of the `Array` spec to allow for nested - specs. `self.name`, `self.generate_value` and `self.validate` methods are adapted from the - `dm_env` object.""" + """Adapted from `dm_env.spec.Array`. This is an augmentation of the `Array` spec to allow for + nested specs. `self.name`, `self.generate_value` and `self.validate` methods are adapted from + the `dm_env` object.""" def __init__( self, @@ -139,7 +139,7 @@ def __getitem__(self, item: str) -> "Spec": class Array(Spec[chex.Array]): - """Describes a jax array spec. This is adapted from `dm_env.specs.Array` to suit Jax environments. + """Describes a jax array spec. This is adapted from `dm_env.specs.Array` for Jax environments. An `Array` spec allows an API to describe the arrays that it accepts or returns, before that array exists. diff --git a/jumanji/specs_test.py b/jumanji/specs_test.py index 74e95b512..09b9f48b1 100644 --- a/jumanji/specs_test.py +++ b/jumanji/specs_test.py @@ -589,7 +589,7 @@ def test_array(self) -> None: converted_spec: dm_env.specs.Array = specs.jumanji_specs_to_dm_env_specs( jumanji_spec ) - assert type(converted_spec) == type(dm_env_spec) + assert type(converted_spec) is type(dm_env_spec) assert converted_spec.shape == dm_env_spec.shape assert converted_spec.dtype == dm_env_spec.dtype assert converted_spec.name == dm_env_spec.name @@ -602,7 +602,7 @@ def test_bounded_array(self) -> None: converted_spec: dm_env.specs.BoundedArray = specs.jumanji_specs_to_dm_env_specs( jumanji_spec ) - assert type(converted_spec) == type(dm_env_spec) + assert type(converted_spec) is type(dm_env_spec) assert converted_spec.shape == dm_env_spec.shape assert converted_spec.dtype == dm_env_spec.dtype assert converted_spec.name == dm_env_spec.name @@ -615,7 +615,7 @@ def test_discrete_array(self) -> None: converted_spec: dm_env.specs.DiscreteArray = ( specs.jumanji_specs_to_dm_env_specs(jumanji_spec) ) - assert type(converted_spec) == type(dm_env_spec) + assert type(converted_spec) is type(dm_env_spec) assert converted_spec.shape == dm_env_spec.shape assert converted_spec.dtype == dm_env_spec.dtype assert converted_spec.name == dm_env_spec.name @@ -675,7 +675,7 @@ def test_array(self) -> None: jumanji_spec = specs.Array((1, 2), jnp.int32) gym_space = gym.spaces.Box(-np.inf, np.inf, (1, 2), jnp.int32) converted_spec = specs.jumanji_specs_to_gym_spaces(jumanji_spec) - assert type(converted_spec) == type(gym_space) + assert type(converted_spec) is type(gym_space) assert_trees_all_equal(converted_spec.low, gym_space.low) assert_trees_all_equal(converted_spec.high, gym_space.high) assert converted_spec.shape == gym_space.shape @@ -687,7 +687,7 @@ def test_bounded_array(self) -> None: ) gym_space = gym.spaces.Box(low=0.0, high=1.0, shape=(1, 2), dtype=jnp.float32) converted_spec = specs.jumanji_specs_to_gym_spaces(jumanji_spec) - assert type(converted_spec) == type(gym_space) + assert type(converted_spec) is type(gym_space) assert converted_spec.shape == gym_space.shape assert converted_spec.dtype == gym_space.dtype assert_trees_all_equal(converted_spec.low, gym_space.low) @@ -697,7 +697,7 @@ def test_discrete_array(self) -> None: jumanji_spec = specs.DiscreteArray(num_values=5, dtype=jnp.int32) gym_space = gym.spaces.Discrete(n=5) converted_spec = specs.jumanji_specs_to_gym_spaces(jumanji_spec) - assert type(converted_spec) == type(gym_space) + assert type(converted_spec) is type(gym_space) assert converted_spec.shape == gym_space.shape assert converted_spec.dtype == gym_space.dtype assert converted_spec.n == gym_space.n @@ -708,7 +708,7 @@ def test_multi_discrete_array(self) -> None: ) gym_space = gym.spaces.MultiDiscrete(nvec=[5, 6]) converted_spec = specs.jumanji_specs_to_gym_spaces(jumanji_spec) - assert type(converted_spec) == type(gym_space) + assert type(converted_spec) is type(gym_space) assert converted_spec.shape == gym_space.shape assert converted_spec.dtype == gym_space.dtype assert jnp.array_equal(converted_spec.nvec, gym_space.nvec) diff --git a/jumanji/training/networks/graph_coloring/actor_critic.py b/jumanji/training/networks/graph_coloring/actor_critic.py index 6e2e336f6..62187d709 100644 --- a/jumanji/training/networks/graph_coloring/actor_critic.py +++ b/jumanji/training/networks/graph_coloring/actor_critic.py @@ -219,7 +219,7 @@ def __call__(self, observation: Observation) -> chex.Array: mlp_units=self.transformer_mlp_units, w_init_scale=2 / self.num_transformer_layers, model_size=self.model_size, - name=f"cross_attention_color_node_block_{block_id+1}", + name=f"cross_attention_color_node_block_{block_id + 1}", )(color_embeddings, current_node_embeddings, current_node_embeddings) return new_embedding diff --git a/jumanji/training/networks/mmst/actor_critic.py b/jumanji/training/networks/mmst/actor_critic.py index 45e776b4c..18b21dec3 100644 --- a/jumanji/training/networks/mmst/actor_critic.py +++ b/jumanji/training/networks/mmst/actor_critic.py @@ -207,7 +207,7 @@ def __call__(self, observation: Observation) -> chex.Array: mlp_units=self.transformer_mlp_units, w_init_scale=2 / self.num_transformer_layers, model_size=self.model_size, - name=f"cross_attention_agent_node_block_{block_id+1}", + name=f"cross_attention_agent_node_block_{block_id + 1}", )(agents_embeddings, current_node_embeddings, current_node_embeddings) return new_embedding diff --git a/jumanji/training/networks/tsp/actor_critic.py b/jumanji/training/networks/tsp/actor_critic.py index cff891c5e..1d720743b 100644 --- a/jumanji/training/networks/tsp/actor_critic.py +++ b/jumanji/training/networks/tsp/actor_critic.py @@ -72,8 +72,8 @@ def __init__( transformer_mlp_units: Sequence[int], name: Optional[str] = None, ): - """Linear embedding of all cities' coordinates followed by `transformer_num_blocks` blocks of self - attention. + """Linear embedding of all cities' coordinates followed by `transformer_num_blocks` blocks + of self attention. """ super().__init__(name=name) self.transformer_num_blocks = transformer_num_blocks diff --git a/setup.cfg b/setup.cfg index 6d09b4167..47a19d216 100644 --- a/setup.cfg +++ b/setup.cfg @@ -20,12 +20,21 @@ docstring-convention = google per-file-ignores = __init__.py:F401 ignore = - A002 # Argument shadowing a Python builtin. - A003 # Class attribute shadowing a Python builtin. - A005 # Module shadowing a Python builtin. - D107 # Do not require docstrings for __init__. - E266 # Do not require block comments to only have a single leading #. - E731 # Do not assign a lambda expression, use a def. - W503 # Line break before binary operator (not compatible with black). - B017 # assertRaises(Exception): or pytest.raises(Exception) should be considered evil. - E203 # black and flake8 disagree on whitespace before ':'. +# Argument shadowing a Python builtin. + A002 +# Class attribute shadowing a Python builtin. + A003 +# Module shadowing a Python builtin. + A005 +# Do not require docstrings for __init__. + D107 +# Do not require block comments to only have a single leading #. + E266 +# Do not assign a lambda expression, use a def. + E731 +# Line break before binary operator (not compatible with black). + W503 +# assertRaises(Exception): or pytest.raises(Exception) should be considered evil. + B017 +# black and flake8 disagree on whitespace before ':'. + E203