diff --git a/src/bsk_rl/envs/general_satellite_tasking/gym_env.py b/src/bsk_rl/envs/general_satellite_tasking/gym_env.py index bc944981..28998f76 100644 --- a/src/bsk_rl/envs/general_satellite_tasking/gym_env.py +++ b/src/bsk_rl/envs/general_satellite_tasking/gym_env.py @@ -1,5 +1,5 @@ from copy import deepcopy -from typing import Any, Iterable, Optional, Union +from typing import Any, Iterable, Optional, TypeVar, Union import numpy as np from gymnasium import Env, spaces @@ -14,8 +14,8 @@ Satellite, ) -SatObs = Any -SatAct = Any +SatObs = TypeVar("SatObs") +SatAct = TypeVar("SatAct") MultiSatObs = tuple[SatObs, ...] MultiSatAct = Iterable[SatAct] @@ -67,7 +67,7 @@ def __init__( communicator: Object to manage communication between satellites sim_rate: Rate for model simulation [s]. max_step_duration: Maximum time to propagate sim at a step [s]. - failure_penalty: Reward for satellite failure. + failure_penalty: Reward for satellite failure. Should be nonpositive. time_limit: Time at which to truncate the simulation [s]. terminate_on_time_limit: Send terminations signal time_limit instead of just truncation. @@ -193,6 +193,22 @@ def _get_info(self) -> dict[str, Any]: ] return info + def _get_reward(self): + reward = sum(self.reward_dict.values()) + for satellite in self.satellites: + if not satellite.is_alive(): + reward += self.failure_penalty + return reward + + def _get_terminated(self) -> bool: + if self.terminate_on_time_limit and self._get_truncated(): + return True + else: + return not all(satellite.is_alive() for satellite in self.satellites) + + def _get_truncated(self) -> bool: + return self.simulator.sim_time >= self.time_limit + @property def action_space(self) -> spaces.Space[MultiSatAct]: """Compose satellite action spaces @@ -252,23 +268,14 @@ def step( satellite.id: satellite.data_store.internal_update() for satellite in self.satellites } - reward = self.data_manager.reward(new_data) + self.reward_dict = self.data_manager.reward(new_data) self.communicator.communicate() - terminated = False - for satellite in self.satellites: - if not satellite.is_alive(): - terminated = True - reward += self.failure_penalty - - truncated = False - if self.simulator.sim_time >= self.time_limit: - truncated = True - if self.terminate_on_time_limit: - terminated = True - observation = self._get_obs() + reward = self._get_reward() + terminated = self._get_terminated() + truncated = self._get_truncated() info = self._get_info() return observation, reward, terminated, truncated, info @@ -296,7 +303,7 @@ def __init__(self, *args, **kwargs) -> None: ) @property - def action_space(self) -> spaces.Discrete: + def action_space(self) -> spaces.Space[SatAct]: """Return the single satellite action space""" return self.satellite.action_space diff --git a/src/bsk_rl/envs/general_satellite_tasking/scenario/data.py b/src/bsk_rl/envs/general_satellite_tasking/scenario/data.py index daa06579..6b15e585 100644 --- a/src/bsk_rl/envs/general_satellite_tasking/scenario/data.py +++ b/src/bsk_rl/envs/general_satellite_tasking/scenario/data.py @@ -120,31 +120,34 @@ def __init__(self, env_features: Optional["EnvironmentFeatures"] = None) -> None data """ self.env_features = env_features + self.DataType = self.DataStore.DataType def reset(self) -> None: - self.data = self.DataStore.DataType() - self.cum_reward = 0.0 + self.data = self.DataType() + self.cum_reward = {} def create_data_store(self, satellite: "Satellite") -> None: """Create a data store for a satellite""" satellite.data_store = self.DataStore(self, satellite) + self.cum_reward[satellite.id] = 0.0 @abstractmethod # pragma: no cover - def _calc_reward(self, new_data_dict: dict[str, DataType]) -> float: + def _calc_reward(self, new_data_dict: dict[str, DataType]) -> dict[str, float]: """Calculate step reward based on all satellite data from a step Args: new_data_dict: Satellite-DataType pairs of new data from a step Returns: - Step reward + Step reward for each satellite """ pass - def reward(self, new_data_dict: dict[str, DataType]) -> float: + def reward(self, new_data_dict: dict[str, DataType]) -> dict[str, float]: """Calls _calc_reward and logs cumulative reward""" reward = self._calc_reward(new_data_dict) - self.cum_reward += reward + for satellite_id, sat_reward in reward.items(): + self.cum_reward[satellite_id] += sat_reward return reward @@ -167,7 +170,7 @@ class NoDataManager(DataManager): DataStore = NoDataStore def _calc_reward(self, new_data_dict): - return 0 + return {sat: 0.0 for sat in new_data_dict.keys()} ####################################### @@ -261,7 +264,9 @@ def __init__( super().__init__(env_features) self.reward_fn = reward_fn - def _calc_reward(self, new_data_dict: dict[str, UniqueImageData]) -> float: + def _calc_reward( + self, new_data_dict: dict[str, UniqueImageData] + ) -> dict[str, float]: """Reward new each unique image once using self.reward_fn() Args: @@ -270,11 +275,19 @@ def _calc_reward(self, new_data_dict: dict[str, UniqueImageData]) -> float: Returns: reward: Cumulative reward across satellites for one step """ - reward = 0.0 - for new_data in new_data_dict.values(): + reward = {} + imaged_targets = sum( + [new_data.imaged for new_data in new_data_dict.values()], [] + ) + for sat_id, new_data in new_data_dict.items(): + reward[sat_id] = 0.0 for target in new_data.imaged: if target not in self.data.imaged: - reward += self.reward_fn(target.priority) + reward[sat_id] += self.reward_fn( + target.priority + ) / imaged_targets.count(target) + + for new_data in new_data_dict.values(): self.data += new_data return reward @@ -454,7 +467,9 @@ def reward_fn(p): self.reward_fn = reward_fn - def _calc_reward(self, new_data_dict: ["NadirScanningTimeData"]) -> float: + def _calc_reward( + self, new_data_dict: dict[str, "NadirScanningTimeData"] + ) -> dict[str, float]: """Calculate step reward based on all satellite data from a step Args: @@ -463,8 +478,8 @@ def _calc_reward(self, new_data_dict: ["NadirScanningTimeData"]) -> float: Returns: Step reward """ - reward = 0.0 - for scanning_time in new_data_dict.values(): - reward += self.reward_fn(scanning_time.scanning_time) + reward = {} + for sat, scanning_time in new_data_dict.items(): + reward[sat] = self.reward_fn(scanning_time.scanning_time) return reward diff --git a/tests/unittest/envs/general_satellite_tasking/scenario/test_data.py b/tests/unittest/envs/general_satellite_tasking/scenario/test_data.py index 93f7ad14..19f6442b 100644 --- a/tests/unittest/envs/general_satellite_tasking/scenario/test_data.py +++ b/tests/unittest/envs/general_satellite_tasking/scenario/test_data.py @@ -38,21 +38,23 @@ def test_reset(self): data.DataManager.DataStore = MagicMock() dm = data.DataManager(MagicMock()) dm.reset() - assert dm.cum_reward == 0 + assert dm.cum_reward == {} def test_create_data_store(self): sat = MagicMock() data.DataManager.DataStore = MagicMock(return_value="ds") dm = data.DataManager(MagicMock()) + dm.reset() dm.create_data_store(sat) assert sat.data_store == "ds" + assert sat.id in dm.cum_reward def test_reward(self): dm = data.DataManager(MagicMock()) - dm._calc_reward = MagicMock(return_value=10.0) - dm.cum_reward = 0 - assert 10.0 == dm.reward({"new": "data"}) - assert dm.cum_reward == 10.0 + dm._calc_reward = MagicMock(return_value={"sat": 10.0}) + dm.cum_reward = {"sat": 5.0} + assert {"sat": 10.0} == dm.reward({"sat": "data"}) + assert dm.cum_reward == {"sat": 15.0} class TestNoData: @@ -73,7 +75,7 @@ class TestNoDataManager: def test_calc_reward(self): dm = data.NoDataManager(MagicMock()) reward = dm._calc_reward({"sat1": 0, "sat2": 1}) - assert reward == 0 + assert reward == {"sat1": 0.0, "sat2": 0.0} class TestUniqueImageData: @@ -159,7 +161,7 @@ def test_calc_reward(self): "sat2": data.UniqueImageData([MagicMock(priority=0.2)]), } ) - assert reward == approx(0.3) + assert reward == {"sat1": approx(0.1), "sat2": approx(0.2)} def test_calc_reward_existing(self): tgt = MagicMock(priority=0.2) @@ -171,7 +173,19 @@ def test_calc_reward_existing(self): "sat2": data.UniqueImageData([tgt]), } ) - assert reward == approx(0.1) + assert reward == {"sat1": approx(0.1), "sat2": 0.0} + + def test_calc_reward_repeated(self): + tgt = MagicMock(priority=0.2) + dm = data.UniqueImagingManager(MagicMock()) + dm.data = data.UniqueImageData([]) + reward = dm._calc_reward( + { + "sat1": data.UniqueImageData([tgt]), + "sat2": data.UniqueImageData([tgt]), + } + ) + assert reward == {"sat1": approx(0.1), "sat2": approx(0.1)} def test_calc_reward_custom_fn(self): dm = data.UniqueImagingManager(MagicMock(), reward_fn=lambda x: 1 / x) @@ -182,7 +196,7 @@ def test_calc_reward_custom_fn(self): "sat2": data.UniqueImageData([MagicMock(priority=2)]), } ) - assert reward == approx(1.5) + assert reward == {"sat1": approx(1.0), "sat2": 0.5} class TestNadirScanningTimeData: @@ -240,7 +254,7 @@ def test_calc_reward(self): "sat2": data.NadirScanningTimeData(2), } ) - assert reward == approx(3) + assert reward == {"sat1": 1.0, "sat2": 2.0} def test_calc_reward_existing(self): dm = data.NadirScanningManager(MagicMock()) @@ -252,7 +266,7 @@ def test_calc_reward_existing(self): "sat2": data.NadirScanningTimeData(3), } ) - assert reward == approx(5) + assert reward == {"sat1": 2.0, "sat2": 3.0} def test_calc_reward_custom_fn(self): dm = data.NadirScanningManager(MagicMock(), reward_fn=lambda x: 1 / x) @@ -263,4 +277,4 @@ def test_calc_reward_custom_fn(self): "sat2": data.NadirScanningTimeData(2), } ) - assert reward == approx(1.0) + assert reward == {"sat1": 0.5, "sat2": 0.5} diff --git a/tests/unittest/envs/general_satellite_tasking/test_gym_env.py b/tests/unittest/envs/general_satellite_tasking/test_gym_env.py index 6fd9f015..67a517da 100644 --- a/tests/unittest/envs/general_satellite_tasking/test_gym_env.py +++ b/tests/unittest/envs/general_satellite_tasking/test_gym_env.py @@ -119,7 +119,9 @@ def test_step(self): satellites=mock_sats, env_type=MagicMock(), env_features=MagicMock(), - data_manager=MagicMock(reward=MagicMock(return_value=25.0)), + data_manager=MagicMock( + reward=MagicMock(return_value={sat.id: 12.5 for sat in mock_sats}) + ), ) env.simulator = MagicMock(sim_time=101.0) _, reward, _, _, info = env.step((0, 10)) @@ -154,7 +156,9 @@ def test_step_stopped(self, sat_death, timeout, terminate_on_time_limit): satellites=mock_sats, env_type=MagicMock(), env_features=MagicMock(), - data_manager=MagicMock(reward=MagicMock(return_value=25.0)), + data_manager=MagicMock( + reward=MagicMock(return_value={sat.id: 12.5 for sat in mock_sats}) + ), terminate_on_time_limit=terminate_on_time_limit, ) env.simulator = MagicMock(sim_time=101.0) @@ -178,7 +182,7 @@ def test_step_retask_needed(self, capfd): satellites=[mock_sat], env_type=MagicMock(), env_features=MagicMock(), - data_manager=MagicMock(reward=MagicMock(return_value=25.0)), + data_manager=MagicMock(reward=MagicMock(return_value={mock_sat.id: 25.0})), ) env.simulator = MagicMock(sim_time=101.0) env.step(None)