Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Features] PettingZoo possibility to choose reset strategy #2048

Merged
merged 16 commits into from
Apr 8, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ dependencies:
- expecttest
- pyyaml
- autorom[accept-rom-license]
- pettingzoo[all]==1.24.1
- pettingzoo[all]==1.24.3
vmoens marked this conversation as resolved.
Show resolved Hide resolved
58 changes: 54 additions & 4 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,6 @@ def _make_spec( # noqa: F811

@pytest.mark.parametrize("categorical", [True, False])
def test_gym_spec_cast(self, categorical):

batch_size = [3, 4]
cat = DiscreteTensorSpec if categorical else OneHotDiscreteTensorSpec
cat_shape = batch_size if categorical else (*batch_size, 5)
Expand Down Expand Up @@ -543,7 +542,6 @@ def test_torchrl_to_gym(self, backend, numpy):
],
)
def test_gym(self, env_name, frame_skip, from_pixels, pixels_only):

if env_name == PONG_VERSIONED() and not from_pixels:
# raise pytest.skip("already pixel")
# we don't skip because that would raise an exception
Expand Down Expand Up @@ -3126,7 +3124,6 @@ class TestPettingZoo:
def test_pistonball(
self, parallel, continuous_actions, use_mask, return_state, group_map
):

kwargs = {"n_pistons": 21, "continuous": continuous_actions}

env = PettingZooEnv(
Expand All @@ -3141,6 +3138,60 @@ def test_pistonball(

check_env_specs(env)

def test_dead_agents_done(self, seed=0):
    """Check done flags and masks reported for dead agents in multiwalker.

    Without ``use_mask`` a rollout that encounters a dead agent must raise;
    with masking enabled, the root done key follows ``done_on_any``.
    """
    scenario_args = {"n_walkers": 3, "terminate_on_fall": False}

    # Without masking, a dead agent in the environment is an error.
    env = PettingZooEnv(
        task="multiwalker_v9",
        parallel=True,
        seed=seed,
        use_mask=False,
        done_on_any=False,
        **scenario_args,
    )
    td_reset = env.reset(seed=seed)
    with pytest.raises(
        ValueError,
        match="Dead agents found in the environment, "
        "you need to set use_mask=True to allow this.",
    ):
        env.rollout(
            max_steps=500,
            break_when_any_done=True,  # This looks at root done set with done_on_any
            auto_reset=False,
            tensordict=td_reset,
        )

    for done_on_any in (True, False):
        masked_env = PettingZooEnv(
            task="multiwalker_v9",
            parallel=True,
            seed=seed,
            use_mask=True,
            done_on_any=done_on_any,
            **scenario_args,
        )
        rollout = masked_env.rollout(
            max_steps=500,
            break_when_any_done=True,  # This looks at root done set with done_on_any
            auto_reset=False,
            tensordict=masked_env.reset(seed=seed),
        )
        done = rollout.get(("next", "walker", "done"))
        mask = rollout.get(("next", "walker", "mask"))

        last_step_done = done[-1]
        if done_on_any:
            # Rollout stopped as soon as any single agent finished.
            assert not last_step_done.all()
        else:
            # Rollout only stopped once every agent had finished.
            assert last_step_done.all()
        # Alive agents (mask True) are never done; dead agents (mask False)
        # are always done.
        assert not done[mask].any()
        assert done[~mask].all()

@pytest.mark.parametrize(
"wins_player_0",
[True, False],
Expand All @@ -3156,7 +3207,6 @@ def test_tic_tac_toe(self, wins_player_0):
)

class Policy:

action = 0
t = 0

Expand Down
49 changes: 38 additions & 11 deletions torchrl/envs/libs/pettingzoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,12 @@ class PettingZooWrapper(_EnvWrapper):
If the number of agents during the task varies, please set ``use_mask=True``.
``"mask"`` will be provided
as an output in each group and should be used to mask out dead agents.
The environment will be reset as soon as one agent is done.
The environment will be reset as soon as one agent is done (unless ``done_on_any`` is ``False``).

In wrapped ``pettingzoo.AECEnv``, at each step only one agent will act.
For this reason, it is compulsory to set ``use_mask=True`` for this type of environment.
``"mask"`` will be provided as an output for each group and can be used to mask out non-acting agents.
The environment will be reset only when all agents are done.
The environment will be reset only when all agents are done (unless ``done_on_any`` is ``True``).

If there are any unavailable actions for an agent,
the environment will also automatically update the mask of its ``action_spec`` and output an ``"action_mask"``
Expand Down Expand Up @@ -156,6 +156,9 @@ class PettingZooWrapper(_EnvWrapper):
categorical_actions (bool, optional): if the environment's actions are discrete, whether to transform
them to categorical or one-hot.
seed (int, optional): the seed. Defaults to ``None``.
done_on_any (bool, optional): whether the environment's done keys are set by aggregating the agent keys
using ``any()`` (when ``True``) or ``all()`` (when ``False``). Default (``None``) is to use ``any()`` for
parallel environments and ``all()`` for AEC ones.

Examples:
>>> # Parallel env
Expand Down Expand Up @@ -204,6 +207,7 @@ def __init__(
use_mask: bool = False,
categorical_actions: bool = True,
seed: int | None = None,
done_on_any: bool | None = None,
**kwargs,
):
if env is not None:
Expand All @@ -214,6 +218,7 @@ def __init__(
self.seed = seed
self.use_mask = use_mask
self.categorical_actions = categorical_actions
self.done_on_any = done_on_any

super().__init__(**kwargs, allow_done_after_reset=True)

Expand Down Expand Up @@ -283,6 +288,9 @@ def _make_specs(
"pettingzoo.utils.env.AECEnv", # noqa: F821
],
) -> None:
# Set default for done on any or all
if self.done_on_any is None:
self.done_on_any = self.parallel

# Create and check group map
if self.group_map is None:
Expand Down Expand Up @@ -582,7 +590,6 @@ def _step(
self,
tensordict: TensorDictBase,
) -> TensorDictBase:

if self.parallel:
(
observation_dict,
Expand Down Expand Up @@ -651,16 +658,33 @@ def _step(
value, device=self.device
)

elif not self.use_action_mask:
elif self.use_mask:
if agent in self.agents:
raise ValueError(
f"Dead agent {agent} not found in step observation but still available in {self.agents}"
)
# Dead agent
terminated = (
terminations_dict[agent] if agent in terminations_dict else True
)
truncated = (
truncations_dict[agent] if agent in truncations_dict else True
)
done = terminated or truncated
group_done[index] = done
group_terminated[index] = terminated
group_truncated[index] = truncated

else:
# Dead agent, if we are not masking it out, this is not allowed
raise ValueError(
"Dead agents found in the environment,"
" you need to set use_action_mask=True to allow this."
" you need to set use_mask=True to allow this."
)

# set done values
done, terminated, truncated = self._aggregate_done(
tensordict_out, use_any=self.parallel
tensordict_out, use_any=self.done_on_any
)

tensordict_out.set("done", done)
Expand All @@ -673,7 +697,7 @@ def _aggregate_done(self, tensordict_out, use_any):
truncated = False if use_any else True
terminated = False if use_any else True
for key in self.done_keys:
if isinstance(key, tuple):
if isinstance(key, tuple): # Only look at group keys
if use_any:
if key[-1] == "done":
done = done | tensordict_out.get(key).any()
Expand Down Expand Up @@ -719,7 +743,6 @@ def _step_aec(
self,
tensordict: TensorDictBase,
) -> Tuple[Dict, Dict, Dict, Dict, Dict]:

for group, agents in self.group_map.items():
if self.agent_selection in agents:
agent_index = agents.index(self._env.agent_selection)
Expand Down Expand Up @@ -747,7 +770,6 @@ def _step_aec(
)

def _update_action_mask(self, td, observation_dict, info_dict):

# Since we remove the action_mask keys we need to copy the data
observation_dict = copy.deepcopy(observation_dict)
info_dict = copy.deepcopy(info_dict)
Expand Down Expand Up @@ -821,15 +843,15 @@ class PettingZooEnv(PettingZooWrapper):
If the number of agents during the task varies, please set ``use_mask=True``.
``"mask"`` will be provided
as an output in each group and should be used to mask out dead agents.
The environment will be reset as soon as one agent is done.
The environment will be reset as soon as one agent is done (unless ``done_on_any`` is ``False``).

For wrapping ``pettingzoo.AECEnv`` provide the name of your petting zoo task (in the ``task`` argument)
and specify ``parallel=False``. This will construct the ``pettingzoo.AECEnv`` version of that task
and wrap it for torchrl.
In wrapped ``pettingzoo.AECEnv``, at each step only one agent will act.
For this reason, it is compulsory to set ``use_mask=True`` for this type of environment.
``"mask"`` will be provided as an output for each group and can be used to mask out non-acting agents.
The environment will be reset only when all agents are done.
The environment will be reset only when all agents are done (unless ``done_on_any`` is ``True``).

If there are any unavailable actions for an agent,
the environment will also automatically update the mask of its ``action_spec`` and output an ``"action_mask"``
Expand Down Expand Up @@ -892,6 +914,9 @@ class PettingZooEnv(PettingZooWrapper):
categorical_actions (bool, optional): if the environment's actions are discrete, whether to transform
them to categorical or one-hot.
seed (int, optional): the seed. Defaults to ``None``.
done_on_any (bool, optional): whether the environment's done keys are set by aggregating the agent keys
using ``any()`` (when ``True``) or ``all()`` (when ``False``). Default (``None``) is to use ``any()`` for
parallel environments and ``all()`` for AEC ones.

Examples:
>>> # Parallel env
Expand Down Expand Up @@ -930,6 +955,7 @@ def __init__(
use_mask: bool = False,
categorical_actions: bool = True,
seed: int | None = None,
done_on_any: bool | None = None,
**kwargs,
):
if not _has_pettingzoo:
Expand All @@ -944,6 +970,7 @@ def __init__(
kwargs["use_mask"] = use_mask
kwargs["categorical_actions"] = categorical_actions
kwargs["seed"] = seed
kwargs["done_on_any"] = done_on_any

super().__init__(**kwargs)

Expand Down
Loading