Skip to content

Commit

Permalink
Issue #74: Add PettingZoo Parallel API support
Browse files Browse the repository at this point in the history
  • Loading branch information
Mark2000 committed Dec 20, 2023
1 parent 7e332f7 commit 9831942
Show file tree
Hide file tree
Showing 4 changed files with 456 additions and 15 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"matplotlib",
"numpy",
"pandas",
"pettingzoo",
"pytest",
"pytest-cov",
"pytest-repeat",
Expand Down
190 changes: 175 additions & 15 deletions src/bsk_rl/envs/general_satellite_tasking/gym_env.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import functools
from copy import deepcopy
from typing import Any, Iterable, Optional, TypeVar, Union
from typing import Any, Generic, Iterable, Optional, TypeVar, Union

import numpy as np
from gymnasium import Env, spaces
from pettingzoo.utils.env import AgentID, ParallelEnv

from bsk_rl.envs.general_satellite_tasking.scenario.communication import NoCommunication
from bsk_rl.envs.general_satellite_tasking.simulation.simulator import Simulator
Expand All @@ -20,7 +22,7 @@
MultiSatAct = Iterable[SatAct]


class GeneralSatelliteTasking(Env):
class GeneralSatelliteTasking(Env, Generic[SatObs, SatAct]):
def __init__(
self,
satellites: Union[Satellite, list[Satellite]],
Expand Down Expand Up @@ -189,24 +191,27 @@ def _get_info(self) -> dict[str, Any]:
info["requires_retasking"] = [
satellite.id
for satellite in self.satellites
if satellite.requires_retasking
if satellite.requires_retasking and satellite.is_alive()
]
return info

def _get_reward(self):
"""Return a scalar reward for the step."""
reward = sum(self.reward_dict.values())
for satellite in self.satellites:
if not satellite.is_alive():
reward += self.failure_penalty
return reward

def _get_terminated(self) -> bool:
"""Return the terminated flag for the step."""
if self.terminate_on_time_limit and self._get_truncated():
return True
else:
return not all(satellite.is_alive() for satellite in self.satellites)

def _get_truncated(self) -> bool:
"""Return the truncated flag for the step."""
return self.simulator.sim_time >= self.time_limit

@property
Expand Down Expand Up @@ -235,17 +240,7 @@ def observation_space(self) -> spaces.Space[MultiSatObs]:
[satellite.observation_space for satellite in self.satellites]
)

def step(
self, actions: MultiSatAct
) -> tuple[MultiSatObs, float, bool, bool, dict[str, Any]]:
"""Propagate the simulation, update information, and get rewards
Args:
Joint action for satellites
Returns:
observation, reward, terminated, truncated, info
"""
def _step(self, actions: MultiSatAct) -> None:
if len(actions) != len(self.satellites):
raise ValueError("There must be the same number of actions and satellites")
for satellite, action in zip(self.satellites, actions):
Expand All @@ -272,6 +267,19 @@ def step(

self.communicator.communicate()

def step(
self, actions: MultiSatAct
) -> tuple[MultiSatObs, float, bool, bool, dict[str, Any]]:
"""Propagate the simulation, update information, and get rewards
Args:
Joint action for satellites
Returns:
observation, reward, terminated, truncated, info
"""
self._step(actions)

observation = self._get_obs()
reward = self._get_reward()
terminated = self._get_terminated()
Expand All @@ -289,7 +297,7 @@ def close(self) -> None:
del self.simulator


class SingleSatelliteTasking(GeneralSatelliteTasking):
class SingleSatelliteTasking(GeneralSatelliteTasking, Generic[SatObs, SatAct]):
"""A special case of the GeneralSatelliteTasking for one satellite. For
compatibility with standard training APIs, actions and observations are directly
exposed for the single satellite and are not wrapped in a tuple.
Expand Down Expand Up @@ -323,3 +331,155 @@ def step(self, action) -> tuple[Any, float, bool, bool, dict[str, Any]]:

def _get_obs(self) -> Any:
return self.satellite.get_obs()


class MultiagentSatelliteTasking(
GeneralSatelliteTasking, ParallelEnv, Generic[SatObs, SatAct, AgentID]
):
"""Implements the environment with the PettingZoo parallel API."""

def reset(
self, seed: int | None = None, options=None
) -> tuple[MultiSatObs, dict[str, Any]]:
self.newly_dead = []
return super().reset(seed, options)

@property
def agents(self) -> list[AgentID]:
"""Agents currently in the environment"""
truncated = super()._get_truncated()
return [
satellite.id
for satellite in self.satellites
if (satellite.is_alive() and not truncated)
]

@property
def num_agents(self) -> int:
"""Number of agents currently in the environment"""
return len(self.agents)

@property
def possible_agents(self) -> list[AgentID]:
"""Return the list of all possible agents."""
return [satellite.id for satellite in self.satellites]

@property
def max_num_agents(self) -> int:
"""Maximum number of agents possible in the environment"""
return len(self.possible_agents)

@property
def previously_dead(self) -> list[AgentID]:
"""Return the list of agents that died at least one step ago."""
return list(set(self.possible_agents) - set(self.agents) - set(self.newly_dead))

@property
def observation_spaces(self) -> dict[AgentID, spaces.Box]:
"""Return the observation space for each agent"""
return {
agent: obs_space
for agent, obs_space in zip(self.possible_agents, super().observation_space)
}

@functools.lru_cache(maxsize=None)
def observation_space(self, agent: AgentID) -> spaces.Space[SatObs]:
"""Return the observation space for a certain agent"""
return self.observation_spaces[agent]

@property
def action_spaces(self) -> dict[AgentID, spaces.Space[SatAct]]:
"""Return the action space for each agent"""
return {
agent: act_space
for agent, act_space in zip(self.possible_agents, super().action_space)
}

@functools.lru_cache(maxsize=None)
def action_space(self, agent: AgentID) -> spaces.Space[SatAct]:
"""Return the action space for a certain agent"""
return self.action_spaces[agent]

def _get_obs(self) -> dict[AgentID, SatObs]:
"""Format the observation per the PettingZoo Parallel API"""
return {
agent: satellite.get_obs()
for agent, satellite in zip(self.possible_agents, self.satellites)
if agent not in self.previously_dead
}

def _get_reward(self) -> dict[AgentID, float]:
"""Format the reward per the PettingZoo Parallel API"""
reward = deepcopy(self.reward_dict)
for agent, satellite in zip(self.possible_agents, self.satellites):
if not satellite.is_alive():
reward[agent] += self.failure_penalty

reward_keys = list(reward.keys())
for agent in reward_keys:
if agent in self.previously_dead:
del reward[agent]

return reward

def _get_terminated(self) -> dict[AgentID, bool]:
"""Format terminations per the PettingZoo Parallel API"""
if self.terminate_on_time_limit and super()._get_truncated():
return {
agent: True
for agent in self.possible_agents
if agent not in self.previously_dead
}
else:
return {
agent: not satellite.is_alive()
for agent, satellite in zip(self.possible_agents, self.satellites)
if agent not in self.previously_dead
}

def _get_truncated(self) -> dict[AgentID, bool]:
"""Format truncations per the PettingZoo Parallel API"""
truncated = super()._get_truncated()
return {
agent: truncated
for agent in self.possible_agents
if agent not in self.previously_dead
}

def _get_info(self) -> dict[AgentID, dict]:
"""Format info per the PettingZoo Parallel API"""
info = super()._get_info()
for agent in self.possible_agents:
if agent in self.previously_dead:
del info[agent]
return info

def step(
self,
actions: dict[AgentID, SatAct],
) -> tuple[
dict[AgentID, SatObs],
dict[AgentID, float],
dict[AgentID, bool],
dict[AgentID, bool],
dict[AgentID, dict],
]:
"""Step the environment and return PettingZoo Parallel API format"""
previous_alive = self.agents

action_vector = []
for agent in self.possible_agents:
if agent in actions.keys():
action_vector.append(actions[agent])
else:
action_vector.append(None)
self._step(action_vector)

self.newly_dead = list(set(previous_alive) - set(self.agents))

observation = self._get_obs()
reward = self._get_reward()
terminated = self._get_terminated()
truncated = self._get_truncated()
info = self._get_info()
return observation, reward, terminated, truncated, info
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import gymnasium as gym
import pytest
from pettingzoo.test.parallel_test import parallel_api_test

from bsk_rl.envs.general_satellite_tasking.gym_env import MultiagentSatelliteTasking
from bsk_rl.envs.general_satellite_tasking.scenario import data
from bsk_rl.envs.general_satellite_tasking.scenario import satellites as sats
from bsk_rl.envs.general_satellite_tasking.scenario.environment_features import (
Expand Down Expand Up @@ -35,6 +37,30 @@
disable_env_checker=True,
)

parallel_env = MultiagentSatelliteTasking(
satellites=[
sats.FullFeaturedSatellite(
"Sentinel-2A",
sat_args=sats.FullFeaturedSatellite.default_sat_args(oe=random_orbit),
imageAttErrorRequirement=0.01,
imageRateErrorRequirement=0.01,
),
sats.FullFeaturedSatellite(
"Sentinel-2B",
sat_args=sats.FullFeaturedSatellite.default_sat_args(oe=random_orbit),
imageAttErrorRequirement=0.01,
imageRateErrorRequirement=0.01,
),
],
env_type=environment.GroundStationEnvModel,
env_args=None,
env_features=StaticTargets(n_targets=1000),
data_manager=data.UniqueImagingManager(),
sim_rate=0.5,
max_step_duration=1e9,
time_limit=5700.0,
)


@pytest.mark.parametrize("env", [multi_env])
def test_reproducibility(env):
Expand All @@ -61,3 +87,10 @@ def test_reproducibility(env):
break

assert reward_sum_2 == reward_sum_1


@pytest.mark.repeat(5)
def test_parallel_api():
with pytest.warns(UserWarning):
# expect an erroneous warning about the info dict due to our additional info
parallel_api_test(parallel_env)
Loading

0 comments on commit 9831942

Please sign in to comment.