Skip to content

Commit

Permalink
chore: upgrade flake and fix errors
Browse files Browse the repository at this point in the history
  • Loading branch information
sash-a committed Nov 1, 2024
1 parent 333bb4a commit c4bacb7
Show file tree
Hide file tree
Showing 17 changed files with 54 additions and 43 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ repos:
name: "Trailing whitespace fixer"

- repo: https://github.com/PyCQA/flake8
rev: 3.9.2
rev: 7.1.1
hooks:
- id: flake8
name: "Linter"
Expand Down
2 changes: 1 addition & 1 deletion jumanji/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def reward_spec(self) -> specs.Array:

@cached_property
def discount_spec(self) -> specs.BoundedArray:
"""Returns the discount spec. By default, this is assumed to be a single float between 0 and 1.
"""Returns the discount spec. By default, this is assumed to be a float between 0 and 1.
Returns:
discount_spec: a `specs.BoundedArray` spec.
Expand Down
2 changes: 1 addition & 1 deletion jumanji/environments/logic/game_2048/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def make_frame(state_index: int) -> None:
return self._animation

def get_fig_ax(self) -> Tuple[plt.Figure, plt.Axes]:
"""This function returns a `Matplotlib` figure and axes object for displaying the 2048 game board.
"""This function returns a `Matplotlib` figure and axes for displaying the 2048 game board.
Returns:
A tuple containing the figure and axes objects.
Expand Down
2 changes: 1 addition & 1 deletion jumanji/environments/logic/graph_coloring/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def animate(
interval: int = 200,
save_path: Optional[str] = None,
) -> animation.FuncAnimation:
"""Creates an animated gif of the `GraphColoring` environment based on the sequence of game states.
"""Creates an animated gif of the `GraphColoring` environment based on a sequence of states.
Args:
states: is a list of `State` objects representing the sequence of game states.
Expand Down
4 changes: 2 additions & 2 deletions jumanji/environments/logic/sliding_tile_puzzle/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def animate(
interval: int = 200,
save_path: Optional[str] = None,
) -> matplotlib.animation.FuncAnimation:
"""Creates an animated gif of the sliding tiles puzzle game based on the sequence of game states.
"""Creates an animated gif of the sliding tiles puzzle game based on a sequence of states.
Args:
states: is a list of `State` objects representing the sequence of game states.
Expand Down Expand Up @@ -101,7 +101,7 @@ def make_frame(state_index: int) -> None:
return self._animation

def get_fig_ax(self) -> Tuple[plt.Figure, plt.Axes]:
"""This function returns a `Matplotlib` figure and axes object for displaying the game puzzle.
"""This function returns a `Matplotlib` figure and axes for displaying the puzzle.
Returns:
A tuple containing the figure and axes objects.
Expand Down
6 changes: 3 additions & 3 deletions jumanji/environments/packing/bin_pack/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,9 +397,9 @@ def close(self) -> None:
def _make_observation_and_extras(
self, state: State
) -> Tuple[State, Observation, Dict]:
"""Computes the observation and the environment metrics to include in `timestep.extras`. Also
updates the `action_mask` and `sorted_ems_indexes` in the state. The observation is obtained
by selecting a subset of all EMSs, namely the `obs_num_ems` largest ones.
"""Computes the observation and the environment metrics to include in `timestep.extras`.
Also updates the `action_mask` and `sorted_ems_indexes` in the state. The observation is
obtained by selecting a subset of all EMSs, namely the `obs_num_ems` largest ones.
"""
obs_ems, obs_ems_mask, sorted_ems_indexes = self._get_set_of_largest_ems(
state.ems, state.ems_mask
Expand Down
4 changes: 2 additions & 2 deletions jumanji/environments/packing/flat_pack/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@


class InstanceGenerator(abc.ABC):
"""Base class for generators for the flat_pack environment. An `InstanceGenerator` is responsible
for generating a problem instance when the environment is reset.
"""Base class for generators for the flat_pack environment. An `InstanceGenerator` is
responsible for generating a problem instance when the environment is reset.
"""

def __init__(
Expand Down
2 changes: 1 addition & 1 deletion jumanji/environments/routing/mmst/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __call__(self, key: chex.PRNGKey) -> State:


class SplitRandomGenerator(Generator):
"""Generates a random environments that is solvable by spliting the graph into multiple sub graphs.
"""Generates a random environments that is solvable by spliting the graph into sub graphs.
Returns a graph and with a desired number of edges and nodes to connect per agent.
"""
Expand Down
8 changes: 4 additions & 4 deletions jumanji/environments/routing/multi_cvrp/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,10 @@ def _draw_route(self, ax: plt.Axes, coords: chex.Array, col_id: int) -> None:
ax.scatter(x, y, s=self.NODE_SIZE, color=self._cmap(col_id))

def _add_tour(self, ax: plt.Axes, state: State) -> None:
"""Add the customers and the depot to the plot, and draw each route in the tour in a different
colour. The tour is the entire trajectory between the visited customers and a route is a
trajectory either starting and ending at the depot or starting at the depot and ending at
the current city."""
"""Add the customers and the depot to the plot, and draw each route in the tour in a
different colour. The tour is the entire trajectory between the visited customers and a
route is a trajectory either starting and ending at the depot or starting at the depot
and ending at the current city."""
x_coords, y_coords = (
state.nodes.coordinates[:, 0] / self._map_max,
state.nodes.coordinates[:, 1] / self._map_max,
Expand Down
4 changes: 2 additions & 2 deletions jumanji/environments/routing/robot_warehouse/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@

@pytest.fixture(scope="module")
def robot_warehouse_env() -> RobotWarehouse:
"""Instantiates a default RobotWarehouse environment with 2 agents, 1 shelf row, 3 shelf columns,
a column height of 2, sensor range of 1 and a request queue size of 4."""
"""Instantiates a default RobotWarehouse environment with 2 agents, 1 shelf row, 3 shelf
columns, a column height of 2, sensor range of 1 and a request queue size of 4."""
generator = RandomGenerator(
shelf_rows=1,
shelf_columns=3,
Expand Down
4 changes: 3 additions & 1 deletion jumanji/environments/routing/robot_warehouse/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,9 @@ def observation_spec(self) -> specs.Spec[Observation]:

@cached_property
def action_spec(self) -> specs.MultiDiscreteArray:
"""Returns the action spec. 5 actions: [0,1,2,3,4] -> [No Op, Forward, Left, Right, Toggle_load].
"""Returns the action spec. 5 actions:
[0,1,2,3,4] -> [No Op, Forward, Left, Right, Toggle_load].
Since this is a multi-agent environment, the environment expects an array of actions.
This array is of shape (num_agents,).
"""
Expand Down
8 changes: 4 additions & 4 deletions jumanji/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@


class Spec(abc.ABC, Generic[T]):
"""Adapted from `dm_env.spec.Array`. This is an augmentation of the `Array` spec to allow for nested
specs. `self.name`, `self.generate_value` and `self.validate` methods are adapted from the
`dm_env` object."""
"""Adapted from `dm_env.spec.Array`. This is an augmentation of the `Array` spec to allow for
nested specs. `self.name`, `self.generate_value` and `self.validate` methods are adapted from
the `dm_env` object."""

def __init__(
self,
Expand Down Expand Up @@ -139,7 +139,7 @@ def __getitem__(self, item: str) -> "Spec":


class Array(Spec[chex.Array]):
"""Describes a jax array spec. This is adapted from `dm_env.specs.Array` to suit Jax environments.
"""Describes a jax array spec. This is adapted from `dm_env.specs.Array` for Jax environments.
An `Array` spec allows an API to describe the arrays that it accepts or returns, before that
array exists.
Expand Down
14 changes: 7 additions & 7 deletions jumanji/specs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ def test_array(self) -> None:
converted_spec: dm_env.specs.Array = specs.jumanji_specs_to_dm_env_specs(
jumanji_spec
)
assert type(converted_spec) == type(dm_env_spec)
assert type(converted_spec) is type(dm_env_spec)
assert converted_spec.shape == dm_env_spec.shape
assert converted_spec.dtype == dm_env_spec.dtype
assert converted_spec.name == dm_env_spec.name
Expand All @@ -602,7 +602,7 @@ def test_bounded_array(self) -> None:
converted_spec: dm_env.specs.BoundedArray = specs.jumanji_specs_to_dm_env_specs(
jumanji_spec
)
assert type(converted_spec) == type(dm_env_spec)
assert type(converted_spec) is type(dm_env_spec)
assert converted_spec.shape == dm_env_spec.shape
assert converted_spec.dtype == dm_env_spec.dtype
assert converted_spec.name == dm_env_spec.name
Expand All @@ -615,7 +615,7 @@ def test_discrete_array(self) -> None:
converted_spec: dm_env.specs.DiscreteArray = (
specs.jumanji_specs_to_dm_env_specs(jumanji_spec)
)
assert type(converted_spec) == type(dm_env_spec)
assert type(converted_spec) is type(dm_env_spec)
assert converted_spec.shape == dm_env_spec.shape
assert converted_spec.dtype == dm_env_spec.dtype
assert converted_spec.name == dm_env_spec.name
Expand Down Expand Up @@ -675,7 +675,7 @@ def test_array(self) -> None:
jumanji_spec = specs.Array((1, 2), jnp.int32)
gym_space = gym.spaces.Box(-np.inf, np.inf, (1, 2), jnp.int32)
converted_spec = specs.jumanji_specs_to_gym_spaces(jumanji_spec)
assert type(converted_spec) == type(gym_space)
assert type(converted_spec) is type(gym_space)
assert_trees_all_equal(converted_spec.low, gym_space.low)
assert_trees_all_equal(converted_spec.high, gym_space.high)
assert converted_spec.shape == gym_space.shape
Expand All @@ -687,7 +687,7 @@ def test_bounded_array(self) -> None:
)
gym_space = gym.spaces.Box(low=0.0, high=1.0, shape=(1, 2), dtype=jnp.float32)
converted_spec = specs.jumanji_specs_to_gym_spaces(jumanji_spec)
assert type(converted_spec) == type(gym_space)
assert type(converted_spec) is type(gym_space)
assert converted_spec.shape == gym_space.shape
assert converted_spec.dtype == gym_space.dtype
assert_trees_all_equal(converted_spec.low, gym_space.low)
Expand All @@ -697,7 +697,7 @@ def test_discrete_array(self) -> None:
jumanji_spec = specs.DiscreteArray(num_values=5, dtype=jnp.int32)
gym_space = gym.spaces.Discrete(n=5)
converted_spec = specs.jumanji_specs_to_gym_spaces(jumanji_spec)
assert type(converted_spec) == type(gym_space)
assert type(converted_spec) is type(gym_space)
assert converted_spec.shape == gym_space.shape
assert converted_spec.dtype == gym_space.dtype
assert converted_spec.n == gym_space.n
Expand All @@ -708,7 +708,7 @@ def test_multi_discrete_array(self) -> None:
)
gym_space = gym.spaces.MultiDiscrete(nvec=[5, 6])
converted_spec = specs.jumanji_specs_to_gym_spaces(jumanji_spec)
assert type(converted_spec) == type(gym_space)
assert type(converted_spec) is type(gym_space)
assert converted_spec.shape == gym_space.shape
assert converted_spec.dtype == gym_space.dtype
assert jnp.array_equal(converted_spec.nvec, gym_space.nvec)
Expand Down
2 changes: 1 addition & 1 deletion jumanji/training/networks/graph_coloring/actor_critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def __call__(self, observation: Observation) -> chex.Array:
mlp_units=self.transformer_mlp_units,
w_init_scale=2 / self.num_transformer_layers,
model_size=self.model_size,
name=f"cross_attention_color_node_block_{block_id+1}",
name=f"cross_attention_color_node_block_{block_id + 1}",
)(color_embeddings, current_node_embeddings, current_node_embeddings)

return new_embedding
Expand Down
2 changes: 1 addition & 1 deletion jumanji/training/networks/mmst/actor_critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def __call__(self, observation: Observation) -> chex.Array:
mlp_units=self.transformer_mlp_units,
w_init_scale=2 / self.num_transformer_layers,
model_size=self.model_size,
name=f"cross_attention_agent_node_block_{block_id+1}",
name=f"cross_attention_agent_node_block_{block_id + 1}",
)(agents_embeddings, current_node_embeddings, current_node_embeddings)

return new_embedding
Expand Down
4 changes: 2 additions & 2 deletions jumanji/training/networks/tsp/actor_critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def __init__(
transformer_mlp_units: Sequence[int],
name: Optional[str] = None,
):
"""Linear embedding of all cities' coordinates followed by `transformer_num_blocks` blocks of self
attention.
"""Linear embedding of all cities' coordinates followed by `transformer_num_blocks` blocks
of self attention.
"""
super().__init__(name=name)
self.transformer_num_blocks = transformer_num_blocks
Expand Down
27 changes: 18 additions & 9 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,21 @@ docstring-convention = google
per-file-ignores = __init__.py:F401

ignore =
A002 # Argument shadowing a Python builtin.
A003 # Class attribute shadowing a Python builtin.
A005 # Module shadowing a Python builtin.
D107 # Do not require docstrings for __init__.
E266 # Do not require block comments to only have a single leading #.
E731 # Do not assign a lambda expression, use a def.
W503 # Line break before binary operator (not compatible with black).
B017 # assertRaises(Exception): or pytest.raises(Exception) should be considered evil.
E203 # black and flake8 disagree on whitespace before ':'.
# Argument shadowing a Python builtin.
A002
# Class attribute shadowing a Python builtin.
A003
# Module shadowing a Python builtin.
A005
# Do not require docstrings for __init__.
D107
# Do not require block comments to only have a single leading #.
E266
# Do not assign a lambda expression, use a def.
E731
# Line break before binary operator (not compatible with black).
W503
# assertRaises(Exception): or pytest.raises(Exception) should be considered evil.
B017
# black and flake8 disagree on whitespace before ':'.
E203

0 comments on commit c4bacb7

Please sign in to comment.