From 54767c7a163bf297dd3e8e9734d9a18b40906aef Mon Sep 17 00:00:00 2001 From: Mark Stephenson Date: Tue, 19 Dec 2023 17:25:11 -0700 Subject: [PATCH] Issue #74: Add PettingZoo Parallel API support --- pyproject.toml | 1 + .../envs/general_satellite_tasking/gym_env.py | 190 ++++++++++++-- .../test_int_full_environments.py | 33 +++ .../general_satellite_tasking/test_gym_env.py | 247 ++++++++++++++++++ 4 files changed, 456 insertions(+), 15 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 02ea968e..24b0d74b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ "matplotlib", "numpy", "pandas", + "pettingzoo", "pytest", "pytest-cov", "pytest-repeat", diff --git a/src/bsk_rl/envs/general_satellite_tasking/gym_env.py b/src/bsk_rl/envs/general_satellite_tasking/gym_env.py index 28998f76..257baf09 100644 --- a/src/bsk_rl/envs/general_satellite_tasking/gym_env.py +++ b/src/bsk_rl/envs/general_satellite_tasking/gym_env.py @@ -1,8 +1,10 @@ +import functools from copy import deepcopy -from typing import Any, Iterable, Optional, TypeVar, Union +from typing import Any, Generic, Iterable, Optional, TypeVar, Union import numpy as np from gymnasium import Env, spaces +from pettingzoo.utils.env import AgentID, ParallelEnv from bsk_rl.envs.general_satellite_tasking.scenario.communication import NoCommunication from bsk_rl.envs.general_satellite_tasking.simulation.simulator import Simulator @@ -20,7 +22,7 @@ MultiSatAct = Iterable[SatAct] -class GeneralSatelliteTasking(Env): +class GeneralSatelliteTasking(Env, Generic[SatObs, SatAct]): def __init__( self, satellites: Union[Satellite, list[Satellite]], @@ -189,11 +191,12 @@ def _get_info(self) -> dict[str, Any]: info["requires_retasking"] = [ satellite.id for satellite in self.satellites - if satellite.requires_retasking + if satellite.requires_retasking and satellite.is_alive() ] return info def _get_reward(self): + """Return a scalar reward for the step.""" reward = sum(self.reward_dict.values()) for satellite in self.satellites: if not satellite.is_alive(): @@ -201,12 +204,14 @@ def _get_reward(self): return reward def _get_terminated(self) -> bool: + """Return the terminated flag for the step.""" if self.terminate_on_time_limit and self._get_truncated(): return True else: return not all(satellite.is_alive() for satellite in self.satellites) def _get_truncated(self) -> bool: + """Return the truncated flag for the step.""" return self.simulator.sim_time >= self.time_limit @property @@ -235,17 +240,7 @@ def observation_space(self) -> spaces.Space[MultiSatObs]: [satellite.observation_space for satellite in self.satellites] ) - def step( - self, actions: MultiSatAct - ) -> tuple[MultiSatObs, float, bool, bool, dict[str, Any]]: - """Propagate the simulation, update information, and get rewards - - Args: - Joint action for satellites - - Returns: - observation, reward, terminated, truncated, info - """ + def _step(self, actions: MultiSatAct) -> None: if len(actions) != len(self.satellites): raise ValueError("There must be the same number of actions and satellites") for satellite, action in zip(self.satellites, actions): @@ -272,6 +267,19 @@ def step( self.communicator.communicate() + def step( + self, actions: MultiSatAct + ) -> tuple[MultiSatObs, float, bool, bool, dict[str, Any]]: + """Propagate the simulation, update information, and get rewards + + Args: + Joint action for satellites + + Returns: + observation, reward, terminated, truncated, info + """ + self._step(actions) + observation = self._get_obs() reward = self._get_reward() terminated = self._get_terminated() @@ -289,7 +297,7 @@ def close(self) -> None: del self.simulator -class SingleSatelliteTasking(GeneralSatelliteTasking): +class SingleSatelliteTasking(GeneralSatelliteTasking, Generic[SatObs, SatAct]): """A special case of the GeneralSatelliteTasking for one satellite. For compatibility with standard training APIs, actions and observations are directly exposed for the single satellite and are not wrapped in a tuple. @@ -323,3 +331,155 @@ def step(self, action) -> tuple[Any, float, bool, bool, dict[str, Any]]: def _get_obs(self) -> Any: return self.satellite.get_obs() + + +class MultiagentSatelliteTasking( + GeneralSatelliteTasking, ParallelEnv, Generic[SatObs, SatAct, AgentID] +): + """Implements the environment with the PettingZoo parallel API.""" + + def reset( + self, seed: int | None = None, options=None + ) -> tuple[MultiSatObs, dict[str, Any]]: + self.newly_dead = [] + return super().reset(seed, options) + + @property + def agents(self) -> list[AgentID]: + """Agents currently in the environment""" + truncated = super()._get_truncated() + return [ + satellite.id + for satellite in self.satellites + if (satellite.is_alive() and not truncated) + ] + + @property + def num_agents(self) -> int: + """Number of agents currently in the environment""" + return len(self.agents) + + @property + def possible_agents(self) -> list[AgentID]: + """Return the list of all possible agents.""" + return [satellite.id for satellite in self.satellites] + + @property + def max_num_agents(self) -> int: + """Maximum number of agents possible in the environment""" + return len(self.possible_agents) + + @property + def previously_dead(self) -> list[AgentID]: + """Return the list of agents that died at least one step ago.""" + return list(set(self.possible_agents) - set(self.agents) - set(self.newly_dead)) + + @property + def observation_spaces(self) -> dict[AgentID, spaces.Box]: + """Return the observation space for each agent""" + return { + agent: obs_space + for agent, obs_space in zip(self.possible_agents, super().observation_space) + } + + @functools.lru_cache(maxsize=None) + def observation_space(self, agent: AgentID) -> spaces.Space[SatObs]: + """Return the observation space for a certain agent""" + return self.observation_spaces[agent] + + @property + def action_spaces(self) -> dict[AgentID, spaces.Space[SatAct]]: + """Return the action space for each agent""" + return { + agent: act_space + for agent, act_space in zip(self.possible_agents, super().action_space) + } + + @functools.lru_cache(maxsize=None) + def action_space(self, agent: AgentID) -> spaces.Space[SatAct]: + """Return the action space for a certain agent""" + return self.action_spaces[agent] + + def _get_obs(self) -> dict[AgentID, SatObs]: + """Format the observation per the PettingZoo Parallel API""" + return { + agent: satellite.get_obs() + for agent, satellite in zip(self.possible_agents, self.satellites) + if agent not in self.previously_dead + } + + def _get_reward(self) -> dict[AgentID, float]: + """Format the reward per the PettingZoo Parallel API""" + reward = deepcopy(self.reward_dict) + for agent, satellite in zip(self.possible_agents, self.satellites): + if not satellite.is_alive(): + reward[agent] += self.failure_penalty + + reward_keys = list(reward.keys()) + for agent in reward_keys: + if agent in self.previously_dead: + del reward[agent] + + return reward + + def _get_terminated(self) -> dict[AgentID, bool]: + """Format terminations per the PettingZoo Parallel API""" + if self.terminate_on_time_limit and super()._get_truncated(): + return { + agent: True + for agent in self.possible_agents + if agent not in self.previously_dead + } + else: + return { + agent: not satellite.is_alive() + for agent, satellite in zip(self.possible_agents, self.satellites) + if agent not in self.previously_dead + } + + def _get_truncated(self) -> dict[AgentID, bool]: + """Format truncations per the PettingZoo Parallel API""" + truncated = super()._get_truncated() + return { + agent: truncated + for agent in self.possible_agents + if agent not in self.previously_dead + } + + def _get_info(self) -> dict[AgentID, dict]: + """Format info per the PettingZoo Parallel API""" + info = super()._get_info() + for agent in self.possible_agents: + if agent in self.previously_dead: + del info[agent] + return info + + def step( + self, + actions: dict[AgentID, SatAct], + ) -> tuple[ + dict[AgentID, SatObs], + dict[AgentID, float], + dict[AgentID, bool], + dict[AgentID, bool], + dict[AgentID, dict], + ]: + """Step the environment and return PettingZoo Parallel API format""" + previous_alive = self.agents + + action_vector = [] + for agent in self.possible_agents: + if agent in actions.keys(): + action_vector.append(actions[agent]) + else: + action_vector.append(None) + self._step(action_vector) + + self.newly_dead = list(set(previous_alive) - set(self.agents)) + + observation = self._get_obs() + reward = self._get_reward() + terminated = self._get_terminated() + truncated = self._get_truncated() + info = self._get_info() + return observation, reward, terminated, truncated, info diff --git a/tests/integration/envs/general_satellite_tasking/test_int_full_environments.py b/tests/integration/envs/general_satellite_tasking/test_int_full_environments.py index ab344ed1..c068c9b2 100644 --- a/tests/integration/envs/general_satellite_tasking/test_int_full_environments.py +++ b/tests/integration/envs/general_satellite_tasking/test_int_full_environments.py @@ -1,6 +1,8 @@ import gymnasium as gym import pytest +from pettingzoo.test.parallel_test import parallel_api_test +from bsk_rl.envs.general_satellite_tasking.gym_env import MultiagentSatelliteTasking from bsk_rl.envs.general_satellite_tasking.scenario import data from bsk_rl.envs.general_satellite_tasking.scenario import satellites as sats from bsk_rl.envs.general_satellite_tasking.scenario.environment_features import ( @@ -35,6 +37,30 @@ disable_env_checker=True, ) +parallel_env = MultiagentSatelliteTasking( + satellites=[ + sats.FullFeaturedSatellite( + "Sentinel-2A", + sat_args=sats.FullFeaturedSatellite.default_sat_args(oe=random_orbit), + imageAttErrorRequirement=0.01, + imageRateErrorRequirement=0.01, + ), + sats.FullFeaturedSatellite( + "Sentinel-2B", + sat_args=sats.FullFeaturedSatellite.default_sat_args(oe=random_orbit), + imageAttErrorRequirement=0.01, + imageRateErrorRequirement=0.01, + ), + ], + env_type=environment.GroundStationEnvModel, + env_args=None, + env_features=StaticTargets(n_targets=1000), + data_manager=data.UniqueImagingManager(), + sim_rate=0.5, + max_step_duration=1e9, + time_limit=5700.0, +) + @pytest.mark.parametrize("env", [multi_env]) def test_reproducibility(env): @@ -61,3 +87,10 @@ def test_reproducibility(env): break assert reward_sum_2 == reward_sum_1 + + +@pytest.mark.repeat(5) +def test_parallel_api(): + with pytest.warns(UserWarning): + # expect an erroneous warning about the info dict due to our additional info + parallel_api_test(parallel_env) diff --git a/tests/unittest/envs/general_satellite_tasking/test_gym_env.py b/tests/unittest/envs/general_satellite_tasking/test_gym_env.py index 67a517da..69492961 100644 --- a/tests/unittest/envs/general_satellite_tasking/test_gym_env.py +++ b/tests/unittest/envs/general_satellite_tasking/test_gym_env.py @@ -5,6 +5,7 @@ from bsk_rl.envs.general_satellite_tasking.gym_env import ( GeneralSatelliteTasking, + MultiagentSatelliteTasking, SingleSatelliteTasking, ) from bsk_rl.envs.general_satellite_tasking.scenario.satellites import Satellite @@ -260,3 +261,249 @@ def test_step(self, step_patch): def test_get_obs(self): env, mock_sat = self.make_env() assert env._get_obs() == mock_sat.get_obs() + + +class TestMultiagentSatelliteTasking: + @patch( + "bsk_rl.envs.general_satellite_tasking.gym_env.Simulator", + ) + @patch( + "bsk_rl.envs.general_satellite_tasking.gym_env.MultiagentSatelliteTasking._get_obs", + ) + @patch( + "bsk_rl.envs.general_satellite_tasking.gym_env.MultiagentSatelliteTasking._get_info", + ) + def test_reset(self, mock_sim, obs_fn, info_fn): + mock_sat_1 = MagicMock() + mock_sat_2 = MagicMock() + mock_sat_1.sat_args_generator = {} + mock_sat_2.sat_args_generator = {} + mock_data = MagicMock(env_features=None) + env = MultiagentSatelliteTasking( + satellites=[mock_sat_1, mock_sat_2], + env_type=MagicMock(), + env_features=MagicMock(), + data_manager=mock_data, + ) + env.env_args_generator = {"utc_init": "a long time ago"} + env.communicator = MagicMock() + obs, info = env.reset() + obs_fn.assert_called_once() + info_fn.assert_called_once() + + @patch( + "bsk_rl.envs.general_satellite_tasking.gym_env.GeneralSatelliteTasking._get_truncated", + MagicMock(return_value=False), + ) + def test_agents(self): + env = MultiagentSatelliteTasking( + satellites=[MagicMock() for i in range(3)], + env_type=MagicMock(), + env_features=MagicMock(), + data_manager=MagicMock(), + ) + assert env.agents == [sat.id for sat in env.satellites] + assert env.num_agents == 3 + assert env.possible_agents == [sat.id for sat in env.satellites] + assert env.max_num_agents == 3 + + @patch( + "bsk_rl.envs.general_satellite_tasking.gym_env.GeneralSatelliteTasking._get_truncated", + MagicMock(return_value=False), + ) + def test_get_obs(self): + env = MultiagentSatelliteTasking( + satellites=[MagicMock(get_obs=MagicMock(return_value=i)) for i in range(3)], + env_type=MagicMock(), + env_features=MagicMock(), + data_manager=MagicMock(), + ) + env.newly_dead = [] + assert env._get_obs() == {sat.id: i for i, sat in enumerate(env.satellites)} + + @patch( + "bsk_rl.envs.general_satellite_tasking.gym_env.GeneralSatelliteTasking._get_truncated", + MagicMock(return_value=False), + ) + def test_get_info(self): + mock_sats = [MagicMock(info={"sat_index": i}) for i in range(3)] + env = MultiagentSatelliteTasking( + satellites=mock_sats, + env_type=MagicMock(), + env_features=MagicMock(), + data_manager=MagicMock(), + ) + env.newly_dead = [] + env.latest_step_duration = 10.0 + expected = {sat.id: {"sat_index": i} for i, sat in enumerate(mock_sats)} + expected["d_ts"] = 10.0 + expected["requires_retasking"] = [sat.id for sat in mock_sats] + assert env._get_info() == expected + + def test_action_spaces(self): + env = MultiagentSatelliteTasking( + satellites=[ + MagicMock(action_space=spaces.Discrete(i + 1)) for i in range(3) + ], + env_type=MagicMock(), + env_features=MagicMock(), + data_manager=MagicMock(), + ) + assert env.action_spaces == { + env.satellites[0].id: spaces.Discrete(1), + env.satellites[1].id: spaces.Discrete(2), + env.satellites[2].id: spaces.Discrete(3), + } + + def test_obs_spaces(self): + env = MultiagentSatelliteTasking( + satellites=[ + MagicMock(observation_space=spaces.Discrete(i + 1)) for i in range(3) + ], + env_type=MagicMock(), + env_features=MagicMock(), + data_manager=MagicMock(), + ) + env.simulator = MagicMock() + env.reset = MagicMock() + assert env.observation_spaces == { + env.satellites[0].id: spaces.Discrete(1), + env.satellites[1].id: spaces.Discrete(2), + env.satellites[2].id: spaces.Discrete(3), + } + + @patch( + "bsk_rl.envs.general_satellite_tasking.gym_env.GeneralSatelliteTasking._get_truncated", + MagicMock(return_value=False), + ) + def test_get_reward(self): + env = MultiagentSatelliteTasking( + satellites=[ + MagicMock(is_alive=MagicMock(return_value=False)) for i in range(3) + ], + env_type=MagicMock(), + env_features=MagicMock(), + data_manager=MagicMock(), + failure_penalty=-20.0, + ) + env.newly_dead = [sat.id for sat in env.satellites] + env.reward_dict = {sat.id: 10.0 for i, sat in enumerate(env.satellites)} + assert env._get_reward() == { + sat.id: -10.0 for i, sat in enumerate(env.satellites) + } + + @pytest.mark.parametrize("timeout", [False, True]) + @pytest.mark.parametrize("terminate_on_time_limit", [False, True]) + def test_get_terminated(self, timeout, terminate_on_time_limit): + env = MultiagentSatelliteTasking( + satellites=[ + MagicMock(is_alive=MagicMock(return_value=True if i != 0 else False)) + for i in range(3) + ], + env_type=MagicMock(), + env_features=MagicMock(), + data_manager=MagicMock(), + terminate_on_time_limit=terminate_on_time_limit, + time_limit=100, + ) + env.simulator = MagicMock(sim_time=101 if timeout else 99) + + if not timeout or not terminate_on_time_limit: + env.newly_dead = [sat.id for sat in env.satellites] + assert env._get_terminated() == { + env.satellites[0].id: True, + env.satellites[1].id: False, + env.satellites[2].id: False, + } + else: + env.newly_dead = [sat.id for sat in env.satellites] + assert env._get_terminated() == { + env.satellites[0].id: True, + env.satellites[1].id: True, + env.satellites[2].id: True, + } + + @pytest.mark.parametrize("time", [99, 101]) + def test_get_truncated(self, time): + env = MultiagentSatelliteTasking( + satellites=[MagicMock() for i in range(3)], + env_type=MagicMock(), + env_features=MagicMock(), + data_manager=MagicMock(), + time_limit=100, + ) + env.simulator = MagicMock(sim_time=time) + env.newly_dead = [sat.id for sat in env.satellites] if time >= 100 else [] + assert env._get_truncated() == { + env.satellites[0].id: time >= 100, + env.satellites[1].id: time >= 100, + env.satellites[2].id: time >= 100, + } + + def test_close(self): + env = MultiagentSatelliteTasking( + satellites=[MagicMock()], + env_type=MagicMock(), + env_features=MagicMock(), + data_manager=MagicMock(), + ) + env.simulator = MagicMock() + env.close() + assert not hasattr(env, "simulator") + + @patch( + "bsk_rl.envs.general_satellite_tasking.gym_env.GeneralSatelliteTasking._get_truncated", + MagicMock(return_value=False), + ) + def test_dead(self): + env = MultiagentSatelliteTasking( + satellites=[MagicMock() for _ in range(3)], + env_type=MagicMock(), + env_features=MagicMock(), + data_manager=MagicMock(), + ) + env.satellites[1].is_alive = MagicMock(return_value=False) + env.satellites[2].is_alive = MagicMock(return_value=False) + env.newly_dead = [env.satellites[2].id] + assert env.previously_dead == [env.satellites[1].id] + assert env.agents == [env.satellites[0].id] + assert env.possible_agents == [sat.id for sat in env.satellites] + + mst = "bsk_rl.envs.general_satellite_tasking.gym_env.MultiagentSatelliteTasking." + + @patch( + "bsk_rl.envs.general_satellite_tasking.gym_env.GeneralSatelliteTasking._get_truncated", + MagicMock(return_value=False), + ) + @patch(mst + "_get_obs", MagicMock()) + @patch(mst + "_get_reward", MagicMock()) + @patch(mst + "_get_terminated", MagicMock()) + @patch(mst + "_get_truncated", MagicMock()) + @patch(mst + "_get_info", MagicMock()) + @patch( + "bsk_rl.envs.general_satellite_tasking.gym_env.GeneralSatelliteTasking._step", + MagicMock(), + ) + def test_step(self): + env = MultiagentSatelliteTasking( + satellites=[ + MagicMock(is_alive=MagicMock(return_value=True)) for _ in range(3) + ], + env_type=MagicMock(), + env_features=MagicMock(), + data_manager=MagicMock(), + ) + + def kill_sat_2(): + env.satellites[2].is_alive.return_value = False + + env._step.side_effect = lambda _: kill_sat_2() + env.satellites[1].is_alive.return_value = False + env.step( + { + env.satellites[0].id: 0, + env.satellites[2].id: 2, + } + ) + env._step.assert_called_with([0, None, 2]) + assert env.newly_dead == [env.satellites[2].id]