Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/74 petting zoo #104

Merged
merged 2 commits into from
Dec 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"matplotlib",
"numpy",
"pandas",
"pettingzoo",
"pytest",
"pytest-cov",
"pytest-repeat",
Expand Down
227 changes: 197 additions & 30 deletions src/bsk_rl/envs/general_satellite_tasking/gym_env.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import functools
from copy import deepcopy
from typing import Any, Iterable, Optional, Union
from typing import Any, Generic, Iterable, Optional, TypeVar, Union

import numpy as np
from gymnasium import Env, spaces
from pettingzoo.utils.env import AgentID, ParallelEnv

from bsk_rl.envs.general_satellite_tasking.scenario.communication import NoCommunication
from bsk_rl.envs.general_satellite_tasking.simulation.simulator import Simulator
Expand All @@ -14,13 +16,13 @@
Satellite,
)

SatObs = Any
SatAct = Any
SatObs = TypeVar("SatObs")
SatAct = TypeVar("SatAct")
MultiSatObs = tuple[SatObs, ...]
MultiSatAct = Iterable[SatAct]


class GeneralSatelliteTasking(Env):
class GeneralSatelliteTasking(Env, Generic[SatObs, SatAct]):
def __init__(
self,
satellites: Union[Satellite, list[Satellite]],
Expand Down Expand Up @@ -67,7 +69,7 @@ def __init__(
communicator: Object to manage communication between satellites
sim_rate: Rate for model simulation [s].
max_step_duration: Maximum time to propagate sim at a step [s].
failure_penalty: Reward for satellite failure.
failure_penalty: Reward for satellite failure. Should be nonpositive.
time_limit: Time at which to truncate the simulation [s].
terminate_on_time_limit: Send terminations signal time_limit instead of just
truncation.
Expand Down Expand Up @@ -189,10 +191,29 @@ def _get_info(self) -> dict[str, Any]:
info["requires_retasking"] = [
satellite.id
for satellite in self.satellites
if satellite.requires_retasking
if satellite.requires_retasking and satellite.is_alive()
]
return info

def _get_reward(self):
"""Return a scalar reward for the step."""
reward = sum(self.reward_dict.values())
for satellite in self.satellites:
if not satellite.is_alive():
reward += self.failure_penalty
return reward

def _get_terminated(self) -> bool:
"""Return the terminated flag for the step."""
if self.terminate_on_time_limit and self._get_truncated():
return True
else:
return not all(satellite.is_alive() for satellite in self.satellites)

def _get_truncated(self) -> bool:
"""Return the truncated flag for the step."""
return self.simulator.sim_time >= self.time_limit

@property
def action_space(self) -> spaces.Space[MultiSatAct]:
"""Compose satellite action spaces
Expand All @@ -219,17 +240,7 @@ def observation_space(self) -> spaces.Space[MultiSatObs]:
[satellite.observation_space for satellite in self.satellites]
)

def step(
self, actions: MultiSatAct
) -> tuple[MultiSatObs, float, bool, bool, dict[str, Any]]:
"""Propagate the simulation, update information, and get rewards

Args:
Joint action for satellites

Returns:
observation, reward, terminated, truncated, info
"""
def _step(self, actions: MultiSatAct) -> None:
if len(actions) != len(self.satellites):
raise ValueError("There must be the same number of actions and satellites")
for satellite, action in zip(self.satellites, actions):
Expand All @@ -252,23 +263,27 @@ def step(
satellite.id: satellite.data_store.internal_update()
for satellite in self.satellites
}
reward = self.data_manager.reward(new_data)
self.reward_dict = self.data_manager.reward(new_data)

self.communicator.communicate()

terminated = False
for satellite in self.satellites:
if not satellite.is_alive():
terminated = True
reward += self.failure_penalty
def step(
    self, actions: "MultiSatAct"
) -> "tuple[MultiSatObs, float, bool, bool, dict[str, Any]]":
    """Propagate the simulation, update information, and get rewards.

    Args:
        actions: Joint action for satellites.

    Returns:
        observation, reward, terminated, truncated, info
    """
    # All simulation/communication/data work happens in _step; the remaining
    # calls only format the Gymnasium step tuple.
    self._step(actions)

    observation = self._get_obs()
    reward = self._get_reward()
    terminated = self._get_terminated()
    truncated = self._get_truncated()
    info = self._get_info()
    return observation, reward, terminated, truncated, info

Expand All @@ -282,7 +297,7 @@ def close(self) -> None:
del self.simulator


class SingleSatelliteTasking(GeneralSatelliteTasking):
class SingleSatelliteTasking(GeneralSatelliteTasking, Generic[SatObs, SatAct]):
"""A special case of the GeneralSatelliteTasking for one satellite. For
compatibility with standard training APIs, actions and observations are directly
exposed for the single satellite and are not wrapped in a tuple.
Expand All @@ -296,7 +311,7 @@ def __init__(self, *args, **kwargs) -> None:
)

@property
def action_space(self) -> spaces.Space[SatAct]:
    """Expose the lone satellite's action space directly (no tuple wrap)."""
    sat = self.satellite
    return sat.action_space

Expand All @@ -316,3 +331,155 @@ def step(self, action) -> tuple[Any, float, bool, bool, dict[str, Any]]:

def _get_obs(self) -> Any:
return self.satellite.get_obs()


class MultiagentSatelliteTasking(
    GeneralSatelliteTasking, ParallelEnv, Generic[SatObs, SatAct, AgentID]
):
    """Implements the environment with the PettingZoo parallel API.

    Agents are keyed by satellite id. An agent is removed from ``agents`` when
    its satellite dies or when the episode is truncated; ``newly_dead`` tracks
    agents that died on the most recent step so they still receive one final
    observation/reward/terminated/truncated/info entry before being dropped
    (PettingZoo convention).
    """

    def reset(
        self, seed: int | None = None, options=None
    ) -> tuple[MultiSatObs, dict[str, Any]]:
        """Reset the environment and clear the newly-dead agent list."""
        self.newly_dead = []
        return super().reset(seed, options)

    @property
    def agents(self) -> list[AgentID]:
        """Agents currently in the environment."""
        # On truncation all agents are removed at once; otherwise only dead
        # satellites are excluded.
        truncated = super()._get_truncated()
        return [
            satellite.id
            for satellite in self.satellites
            if (satellite.is_alive() and not truncated)
        ]

    @property
    def num_agents(self) -> int:
        """Number of agents currently in the environment."""
        return len(self.agents)

    @property
    def possible_agents(self) -> list[AgentID]:
        """Return the list of all possible agents."""
        return [satellite.id for satellite in self.satellites]

    @property
    def max_num_agents(self) -> int:
        """Maximum number of agents possible in the environment."""
        return len(self.possible_agents)

    @property
    def previously_dead(self) -> list[AgentID]:
        """Return the list of agents that died at least one step ago.

        These agents no longer receive entries in the per-agent step dicts.
        """
        return list(set(self.possible_agents) - set(self.agents) - set(self.newly_dead))

    @property
    def observation_spaces(self) -> dict[AgentID, spaces.Box]:
        """Return the observation space for each agent.

        NOTE(review): relies on the parent ``observation_space`` (a
        ``spaces.Tuple``) being iterable in satellite order.
        """
        return {
            agent: obs_space
            for agent, obs_space in zip(self.possible_agents, super().observation_space)
        }

    @functools.lru_cache(maxsize=None)
    def observation_space(self, agent: AgentID) -> spaces.Space[SatObs]:
        """Return the observation space for a certain agent."""
        # NOTE(review): lru_cache on an instance method keeps the env alive for
        # the cache's lifetime; this mirrors the standard PettingZoo template.
        return self.observation_spaces[agent]

    @property
    def action_spaces(self) -> dict[AgentID, spaces.Space[SatAct]]:
        """Return the action space for each agent."""
        return {
            agent: act_space
            for agent, act_space in zip(self.possible_agents, super().action_space)
        }

    @functools.lru_cache(maxsize=None)
    def action_space(self, agent: AgentID) -> spaces.Space[SatAct]:
        """Return the action space for a certain agent."""
        return self.action_spaces[agent]

    def _get_obs(self) -> dict[AgentID, SatObs]:
        """Format the observation per the PettingZoo Parallel API."""
        # Newly-dead agents still get a final observation; previously-dead
        # agents are omitted.
        return {
            agent: satellite.get_obs()
            for agent, satellite in zip(self.possible_agents, self.satellites)
            if agent not in self.previously_dead
        }

    def _get_reward(self) -> dict[AgentID, float]:
        """Format the reward per the PettingZoo Parallel API."""
        # reward_dict is set in _step; presumably keyed by satellite id —
        # confirm against data_manager.reward().
        reward = deepcopy(self.reward_dict)
        for agent, satellite in zip(self.possible_agents, self.satellites):
            if not satellite.is_alive():
                reward[agent] += self.failure_penalty

        # Drop agents that have been dead for more than one step.
        reward_keys = list(reward.keys())
        for agent in reward_keys:
            if agent in self.previously_dead:
                del reward[agent]

        return reward

    def _get_terminated(self) -> dict[AgentID, bool]:
        """Format terminations per the PettingZoo Parallel API."""
        if self.terminate_on_time_limit and super()._get_truncated():
            # Hitting the time limit terminates every remaining agent.
            return {
                agent: True
                for agent in self.possible_agents
                if agent not in self.previously_dead
            }
        else:
            return {
                agent: not satellite.is_alive()
                for agent, satellite in zip(self.possible_agents, self.satellites)
                if agent not in self.previously_dead
            }

    def _get_truncated(self) -> dict[AgentID, bool]:
        """Format truncations per the PettingZoo Parallel API."""
        # Truncation is global: every remaining agent shares the same flag.
        truncated = super()._get_truncated()
        return {
            agent: truncated
            for agent in self.possible_agents
            if agent not in self.previously_dead
        }

    def _get_info(self) -> dict[AgentID, dict]:
        """Format info per the PettingZoo Parallel API."""
        # NOTE(review): assumes super()._get_info() keys per-agent entries by
        # satellite id — confirm against the parent implementation.
        info = super()._get_info()
        for agent in self.possible_agents:
            if agent in self.previously_dead:
                del info[agent]
        return info

    def step(
        self,
        actions: dict[AgentID, SatAct],
    ) -> tuple[
        dict[AgentID, SatObs],
        dict[AgentID, float],
        dict[AgentID, bool],
        dict[AgentID, bool],
        dict[AgentID, dict],
    ]:
        """Step the environment and return PettingZoo Parallel API format."""
        previous_alive = self.agents

        # Build the ordered joint action; agents without a supplied action are
        # passed None — presumably "keep current task"; confirm in _step.
        action_vector = []
        for agent in self.possible_agents:
            if agent in actions.keys():
                action_vector.append(actions[agent])
            else:
                action_vector.append(None)
        self._step(action_vector)

        # Agents alive before this step but not after are "newly dead" and
        # receive one final set of per-agent entries below.
        self.newly_dead = list(set(previous_alive) - set(self.agents))

        observation = self._get_obs()
        reward = self._get_reward()
        terminated = self._get_terminated()
        truncated = self._get_truncated()
        info = self._get_info()
        return observation, reward, terminated, truncated, info
Loading