Skip to content

Commit

Permalink
Issue #97: Report per-satellite rewards
Browse files Browse the repository at this point in the history
  • Loading branch information
Mark2000 committed Dec 20, 2023
1 parent 8fc01ea commit 7e332f7
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 48 deletions.
43 changes: 25 additions & 18 deletions src/bsk_rl/envs/general_satellite_tasking/gym_env.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from copy import deepcopy
from typing import Any, Iterable, Optional, Union
from typing import Any, Iterable, Optional, TypeVar, Union

import numpy as np
from gymnasium import Env, spaces
Expand All @@ -14,8 +14,8 @@
Satellite,
)

SatObs = Any
SatAct = Any
SatObs = TypeVar("SatObs")
SatAct = TypeVar("SatAct")
MultiSatObs = tuple[SatObs, ...]
MultiSatAct = Iterable[SatAct]

Expand Down Expand Up @@ -67,7 +67,7 @@ def __init__(
communicator: Object to manage communication between satellites
sim_rate: Rate for model simulation [s].
max_step_duration: Maximum time to propagate sim at a step [s].
failure_penalty: Reward for satellite failure.
failure_penalty: Reward for satellite failure. Should be nonpositive.
time_limit: Time at which to truncate the simulation [s].
terminate_on_time_limit: Send termination signal at time_limit instead of just
truncation.
Expand Down Expand Up @@ -193,6 +193,22 @@ def _get_info(self) -> dict[str, Any]:
]
return info

def _get_reward(self) -> float:
    """Return the scalar step reward for the whole environment.

    Sums the per-satellite rewards stored in ``self.reward_dict`` and adds
    ``self.failure_penalty`` once for every satellite that is not alive.

    Returns:
        Total reward for the step.
    """
    total = sum(self.reward_dict.values())
    for satellite in self.satellites:
        # Each failed satellite contributes one failure penalty.
        if not satellite.is_alive():
            total += self.failure_penalty
    return total

def _get_terminated(self) -> bool:
    """Report whether the episode has terminated.

    Returns:
        True when ``terminate_on_time_limit`` is set and the episode is
        truncated, or when at least one satellite is not alive.
    """
    timed_out = self.terminate_on_time_limit and self._get_truncated()
    if timed_out:
        return True
    # Any satellite that is not alive ends the episode.
    return any(not sat.is_alive() for sat in self.satellites)

def _get_truncated(self) -> bool:
    """Report whether the simulation time has reached the time limit.

    Returns:
        True when ``self.simulator.sim_time`` is at or beyond
        ``self.time_limit``.
    """
    elapsed = self.simulator.sim_time
    return elapsed >= self.time_limit

@property
def action_space(self) -> spaces.Space[MultiSatAct]:
"""Compose satellite action spaces
Expand Down Expand Up @@ -252,23 +268,14 @@ def step(
satellite.id: satellite.data_store.internal_update()
for satellite in self.satellites
}
reward = self.data_manager.reward(new_data)
self.reward_dict = self.data_manager.reward(new_data)

self.communicator.communicate()

terminated = False
for satellite in self.satellites:
if not satellite.is_alive():
terminated = True
reward += self.failure_penalty

truncated = False
if self.simulator.sim_time >= self.time_limit:
truncated = True
if self.terminate_on_time_limit:
terminated = True

observation = self._get_obs()
reward = self._get_reward()
terminated = self._get_terminated()
truncated = self._get_truncated()
info = self._get_info()
return observation, reward, terminated, truncated, info

Expand Down Expand Up @@ -296,7 +303,7 @@ def __init__(self, *args, **kwargs) -> None:
)

@property
def action_space(self) -> spaces.Discrete:
def action_space(self) -> spaces.Space[SatAct]:
"""Return the single satellite action space"""
return self.satellite.action_space

Expand Down
45 changes: 30 additions & 15 deletions src/bsk_rl/envs/general_satellite_tasking/scenario/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,31 +120,34 @@ def __init__(self, env_features: Optional["EnvironmentFeatures"] = None) -> None
data
"""
self.env_features = env_features
self.DataType = self.DataStore.DataType

def reset(self) -> None:
self.data = self.DataStore.DataType()
self.cum_reward = 0.0
self.data = self.DataType()
self.cum_reward = {}

def create_data_store(self, satellite: "Satellite") -> None:
    """Attach a new data store to a satellite and start tracking its reward.

    Args:
        satellite: Satellite that receives the new ``data_store`` attribute.
    """
    store = self.DataStore(self, satellite)
    satellite.data_store = store
    # Begin this satellite's cumulative reward at zero.
    self.cum_reward[satellite.id] = 0.0

@abstractmethod # pragma: no cover
def _calc_reward(self, new_data_dict: dict[str, DataType]) -> float:
def _calc_reward(self, new_data_dict: dict[str, DataType]) -> dict[str, float]:
"""Calculate step reward based on all satellite data from a step
Args:
new_data_dict: Satellite-DataType pairs of new data from a step
Returns:
Step reward
Step reward for each satellite
"""
pass

def reward(self, new_data_dict: dict[str, DataType]) -> float:
def reward(self, new_data_dict: dict[str, DataType]) -> dict[str, float]:
"""Calls _calc_reward and logs cumulative reward"""
reward = self._calc_reward(new_data_dict)
self.cum_reward += reward
for satellite_id, sat_reward in reward.items():
self.cum_reward[satellite_id] += sat_reward
return reward


Expand All @@ -167,7 +170,7 @@ class NoDataManager(DataManager):
DataStore = NoDataStore

def _calc_reward(self, new_data_dict):
return 0
return {sat: 0.0 for sat in new_data_dict.keys()}


#######################################
Expand Down Expand Up @@ -261,7 +264,9 @@ def __init__(
super().__init__(env_features)
self.reward_fn = reward_fn

def _calc_reward(self, new_data_dict: dict[str, UniqueImageData]) -> float:
def _calc_reward(
self, new_data_dict: dict[str, UniqueImageData]
) -> dict[str, float]:
"""Reward each new unique image once using self.reward_fn()
Args:
Expand All @@ -270,11 +275,19 @@ def _calc_reward(self, new_data_dict: dict[str, UniqueImageData]) -> float:
Returns:
reward: Reward for each satellite for one step, keyed by satellite id; the reward for a target is divided by the number of times it was imaged in that step
"""
reward = 0.0
for new_data in new_data_dict.values():
reward = {}
imaged_targets = sum(
[new_data.imaged for new_data in new_data_dict.values()], []
)
for sat_id, new_data in new_data_dict.items():
reward[sat_id] = 0.0
for target in new_data.imaged:
if target not in self.data.imaged:
reward += self.reward_fn(target.priority)
reward[sat_id] += self.reward_fn(
target.priority
) / imaged_targets.count(target)

for new_data in new_data_dict.values():
self.data += new_data
return reward

Expand Down Expand Up @@ -454,7 +467,9 @@ def reward_fn(p):

self.reward_fn = reward_fn

def _calc_reward(self, new_data_dict: ["NadirScanningTimeData"]) -> float:
def _calc_reward(
self, new_data_dict: dict[str, "NadirScanningTimeData"]
) -> dict[str, float]:
"""Calculate step reward based on all satellite data from a step
Args:
Expand All @@ -463,8 +478,8 @@ def _calc_reward(self, new_data_dict: ["NadirScanningTimeData"]) -> float:
Returns:
Step reward for each satellite
"""
reward = 0.0
for scanning_time in new_data_dict.values():
reward += self.reward_fn(scanning_time.scanning_time)
reward = {}
for sat, scanning_time in new_data_dict.items():
reward[sat] = self.reward_fn(scanning_time.scanning_time)

return reward
38 changes: 26 additions & 12 deletions tests/unittest/envs/general_satellite_tasking/scenario/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,23 @@ def test_reset(self):
data.DataManager.DataStore = MagicMock()
dm = data.DataManager(MagicMock())
dm.reset()
assert dm.cum_reward == 0
assert dm.cum_reward == {}

def test_create_data_store(self):
# Check that create_data_store assigns the constructed DataStore to the
# satellite and registers the satellite's id in the manager's cum_reward.
sat = MagicMock()
# Stub the DataStore class so that constructing it returns the sentinel "ds".
data.DataManager.DataStore = MagicMock(return_value="ds")
dm = data.DataManager(MagicMock())
# reset() sets cum_reward to an empty dict before any store is created.
dm.reset()
dm.create_data_store(sat)
assert sat.data_store == "ds"
assert sat.id in dm.cum_reward

def test_reward(self):
dm = data.DataManager(MagicMock())
dm._calc_reward = MagicMock(return_value=10.0)
dm.cum_reward = 0
assert 10.0 == dm.reward({"new": "data"})
assert dm.cum_reward == 10.0
dm._calc_reward = MagicMock(return_value={"sat": 10.0})
dm.cum_reward = {"sat": 5.0}
assert {"sat": 10.0} == dm.reward({"sat": "data"})
assert dm.cum_reward == {"sat": 15.0}


class TestNoData:
Expand All @@ -73,7 +75,7 @@ class TestNoDataManager:
def test_calc_reward(self):
dm = data.NoDataManager(MagicMock())
reward = dm._calc_reward({"sat1": 0, "sat2": 1})
assert reward == 0
assert reward == {"sat1": 0.0, "sat2": 0.0}


class TestUniqueImageData:
Expand Down Expand Up @@ -159,7 +161,7 @@ def test_calc_reward(self):
"sat2": data.UniqueImageData([MagicMock(priority=0.2)]),
}
)
assert reward == approx(0.3)
assert reward == {"sat1": approx(0.1), "sat2": approx(0.2)}

def test_calc_reward_existing(self):
tgt = MagicMock(priority=0.2)
Expand All @@ -171,7 +173,19 @@ def test_calc_reward_existing(self):
"sat2": data.UniqueImageData([tgt]),
}
)
assert reward == approx(0.1)
assert reward == {"sat1": approx(0.1), "sat2": 0.0}

def test_calc_reward_repeated(self):
# Two satellites report the same target in a single step; the test expects
# each satellite to receive half of the 0.2 priority (0.1 each).
tgt = MagicMock(priority=0.2)
dm = data.UniqueImagingManager(MagicMock())
# The manager starts from an empty UniqueImageData, so the target is new.
dm.data = data.UniqueImageData([])
reward = dm._calc_reward(
{
"sat1": data.UniqueImageData([tgt]),
"sat2": data.UniqueImageData([tgt]),
}
)
assert reward == {"sat1": approx(0.1), "sat2": approx(0.1)}

def test_calc_reward_custom_fn(self):
dm = data.UniqueImagingManager(MagicMock(), reward_fn=lambda x: 1 / x)
Expand All @@ -182,7 +196,7 @@ def test_calc_reward_custom_fn(self):
"sat2": data.UniqueImageData([MagicMock(priority=2)]),
}
)
assert reward == approx(1.5)
assert reward == {"sat1": approx(1.0), "sat2": 0.5}


class TestNadirScanningTimeData:
Expand Down Expand Up @@ -240,7 +254,7 @@ def test_calc_reward(self):
"sat2": data.NadirScanningTimeData(2),
}
)
assert reward == approx(3)
assert reward == {"sat1": 1.0, "sat2": 2.0}

def test_calc_reward_existing(self):
dm = data.NadirScanningManager(MagicMock())
Expand All @@ -252,7 +266,7 @@ def test_calc_reward_existing(self):
"sat2": data.NadirScanningTimeData(3),
}
)
assert reward == approx(5)
assert reward == {"sat1": 2.0, "sat2": 3.0}

def test_calc_reward_custom_fn(self):
dm = data.NadirScanningManager(MagicMock(), reward_fn=lambda x: 1 / x)
Expand All @@ -263,4 +277,4 @@ def test_calc_reward_custom_fn(self):
"sat2": data.NadirScanningTimeData(2),
}
)
assert reward == approx(1.0)
assert reward == {"sat1": 0.5, "sat2": 0.5}
10 changes: 7 additions & 3 deletions tests/unittest/envs/general_satellite_tasking/test_gym_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,9 @@ def test_step(self):
satellites=mock_sats,
env_type=MagicMock(),
env_features=MagicMock(),
data_manager=MagicMock(reward=MagicMock(return_value=25.0)),
data_manager=MagicMock(
reward=MagicMock(return_value={sat.id: 12.5 for sat in mock_sats})
),
)
env.simulator = MagicMock(sim_time=101.0)
_, reward, _, _, info = env.step((0, 10))
Expand Down Expand Up @@ -154,7 +156,9 @@ def test_step_stopped(self, sat_death, timeout, terminate_on_time_limit):
satellites=mock_sats,
env_type=MagicMock(),
env_features=MagicMock(),
data_manager=MagicMock(reward=MagicMock(return_value=25.0)),
data_manager=MagicMock(
reward=MagicMock(return_value={sat.id: 12.5 for sat in mock_sats})
),
terminate_on_time_limit=terminate_on_time_limit,
)
env.simulator = MagicMock(sim_time=101.0)
Expand All @@ -178,7 +182,7 @@ def test_step_retask_needed(self, capfd):
satellites=[mock_sat],
env_type=MagicMock(),
env_features=MagicMock(),
data_manager=MagicMock(reward=MagicMock(return_value=25.0)),
data_manager=MagicMock(reward=MagicMock(return_value={mock_sat.id: 25.0})),
)
env.simulator = MagicMock(sim_time=101.0)
env.step(None)
Expand Down

0 comments on commit 7e332f7

Please sign in to comment.