-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
748 additions
and
30 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
"""Data system for recording RSO surface.""" | ||
|
||
import logging | ||
from typing import TYPE_CHECKING, Optional | ||
|
||
import numpy as np | ||
|
||
from bsk_rl.data.base import Data, DataStore, GlobalReward | ||
from bsk_rl.sats import Satellite | ||
from bsk_rl.scene.rso_points import RSOPoint | ||
from bsk_rl.sim.dyn import RSODynModel, RSOImagingDynModel | ||
|
||
if TYPE_CHECKING: | ||
from bsk_rl.sats import Satellite | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
RSO = "rso" | ||
OBSERVER = "observer" | ||
|
||
|
||
class RSOInspectionData(Data): | ||
def __init__(self, point_inspect_status: Optional[dict[RSOPoint, bool]] = None): | ||
if point_inspect_status is None: | ||
point_inspect_status = {} | ||
self.point_inspect_status = point_inspect_status | ||
|
||
def __add__(self, other: "RSOInspectionData"): | ||
point_inspect_status = {} | ||
point_inspect_status.update(self.point_inspect_status) | ||
for point, access in other.point_inspect_status.items(): | ||
if point not in point_inspect_status: | ||
point_inspect_status[point] = access | ||
else: | ||
point_inspect_status[point] = point_inspect_status[point] or access | ||
|
||
return RSOInspectionData(point_inspect_status) | ||
|
||
|
||
class RSOInspectionDataStore(DataStore): | ||
data_type = RSOInspectionData | ||
|
||
def __init__(self, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
self.point_access_recorders = [] | ||
self.storage_recorder = None | ||
|
||
if issubclass(self.satellite.dyn_type, RSOImagingDynModel): | ||
self.role = OBSERVER | ||
else: | ||
self.role = RSO | ||
|
||
def set_storage_recorder(self, recorder): | ||
self.storage_recorder = recorder | ||
self.satellite.simulator.AddModelToTask( | ||
self.satellite.dynamics.task_name, recorder, ModelPriority=1000 | ||
) | ||
|
||
def add_point_access_recorder(self, recorder): | ||
self.point_access_recorders.append(recorder) | ||
self.satellite.simulator.AddModelToTask( | ||
self.satellite.dynamics.task_name, recorder, ModelPriority=1000 | ||
) | ||
|
||
def clear_recorders(self): | ||
if self.storage_recorder: | ||
self.storage_recorder.clear() | ||
for recorder in self.point_access_recorders: | ||
recorder.clear() | ||
|
||
def get_log_state(self) -> list[list[bool]]: | ||
"""Log the storage unit state and point access state for all times in the step. | ||
Returns: | ||
todo | ||
""" | ||
if self.role == RSO: | ||
return None | ||
|
||
log_len = len(self.storage_recorder.storageLevel) | ||
if log_len <= 1: | ||
imaging_req = np.zeros(log_len) | ||
else: | ||
imaging_req = np.diff(self.storage_recorder.storageLevel) | ||
imaging_req = np.concatenate((imaging_req, [imaging_req[-1]])) | ||
|
||
inspected_logs = [] | ||
for recorder in self.point_access_recorders: | ||
# inspected = np.logical_and(imaging_req, recorder.hasAccess) | ||
inspected = recorder.hasAccess | ||
inspected_logs.append(list(inspected)) | ||
|
||
self.clear_recorders() | ||
|
||
return inspected_logs | ||
|
||
def compare_log_states(self, _, inspected_logs) -> Data: | ||
if self.role == RSO: | ||
return RSOInspectionData() | ||
|
||
point_inspect_status = {} | ||
for rso_point, log in zip( | ||
self.data.point_inspect_status.keys(), inspected_logs | ||
): | ||
if any(log): | ||
print(log) | ||
point_inspect_status[rso_point] = True | ||
|
||
if len(point_inspect_status) > 0: | ||
self.satellite.logger.info( | ||
f"Inspected {len(point_inspect_status)} points this step" | ||
) | ||
|
||
return RSOInspectionData(point_inspect_status) | ||
|
||
|
||
class RSOInspectionReward(GlobalReward): | ||
datastore_type = RSOInspectionDataStore | ||
|
||
def reset_post_sim_init(self) -> None: | ||
super().reset_post_sim_init() | ||
|
||
for i, observer in enumerate(self.scenario.observers): | ||
observer.data_store.set_storage_recorder( | ||
observer.dynamics.storageUnit.storageUnitDataOutMsg.recorder() | ||
) | ||
logger.debug( | ||
f"Logging {len(self.scenario.rso.dynamics.rso_points)} access points" | ||
) | ||
for rso_point_model in self.scenario.rso.dynamics.rso_points: | ||
observer.data_store.add_point_access_recorder( | ||
rso_point_model.accessOutMsgs[i].recorder() | ||
) | ||
|
||
def initial_data(self, satellite: Satellite) -> Data: | ||
if not issubclass(satellite.dyn_type, RSOImagingDynModel): | ||
return RSOInspectionData() | ||
|
||
return RSOInspectionData({point: False for point in self.scenario.rso_points}) | ||
|
||
def calculate_reward(self, new_data_dict: dict[str, Data]) -> dict[str, float]: | ||
return {} # TODO | ||
|
||
|
||
__doc_title__ = "RSO Inspection" | ||
__all__ = ["RSOInspectionReward", "RSOInspectionDataStore", "RSOInspectionData"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
"""TODO: Add docstring.""" | ||
|
||
import logging | ||
from abc import abstractmethod | ||
from dataclasses import dataclass | ||
from typing import TYPE_CHECKING | ||
|
||
import numpy as np | ||
import pandas as pd | ||
from Basilisk.utilities import orbitalMotion | ||
|
||
from bsk_rl.scene import Scenario | ||
from bsk_rl.sim.dyn import RSODynModel, RSOImagingDynModel | ||
from bsk_rl.sim.fsw import RSOImagingFSWModel | ||
from bsk_rl.utils.orbital import lla2ecef | ||
|
||
if TYPE_CHECKING: # pragma: no cover | ||
from bsk_rl.data.base import Data | ||
from bsk_rl.sats import Satellite | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@dataclass | ||
class RSOPoint: | ||
r_PB_B: np.ndarray | ||
n_B: np.ndarray | ||
theta_min: float | ||
range: float | ||
|
||
def __hash__(self) -> int: | ||
"""Hash target by unique id.""" | ||
return hash(id(self)) # THIS IS ALMOST CERTAINLY A BAD IDEA | ||
|
||
|
||
class RSOPoints(Scenario): | ||
def reset_overwrite_previous(self) -> None: | ||
"""Overwrite target list from previous episode.""" | ||
self.rso_points = [] | ||
|
||
def reset_pre_sim_init(self) -> None: | ||
self.rso_points = self.generate_points() | ||
return super().reset_pre_sim_init() | ||
|
||
def reset_post_sim_init(self) -> None: | ||
# Check for RSOs and observers | ||
rsos = [sat for sat in self.satellites if isinstance(sat.dynamics, RSODynModel)] | ||
if len(rsos) == 0: | ||
logger.warning("No RSODynModel satellites found in scenario.") | ||
return | ||
assert len(rsos) == 1, "Only one RSODynModel satellite is supported." | ||
self.rso = rsos[0] | ||
|
||
self.observers = [ | ||
sat | ||
for sat in self.satellites | ||
if isinstance(sat.dynamics, RSOImagingDynModel) | ||
] | ||
if len(self.observers) == 0: | ||
logger.warning("No RSOImagingDynModel satellites found in scenario.") | ||
return | ||
|
||
# Add points to dynamics and fsw of RSO | ||
assert isinstance(self.rso.dynamics, RSODynModel) | ||
logger.debug("Adding inspection points to RSO and observers") | ||
for point in self.rso_points: | ||
rso_point_model = self.rso.dynamics.add_rso_point( | ||
point.r_PB_B, point.n_B, point.theta_min, point.range | ||
) | ||
# Add point to each observer | ||
for observer in self.observers: | ||
assert isinstance(observer.dynamics, RSOImagingDynModel) | ||
assert isinstance(observer.fsw, RSOImagingFSWModel) | ||
observer.dynamics.add_rso_point(rso_point_model) | ||
|
||
logger.debug("Targeting RSO with observers") | ||
for observer in self.observers: | ||
observer.fsw.set_target_rso(self.rso) | ||
|
||
@abstractmethod | ||
def generate_points(self) -> list[RSOPoint]: | ||
pass | ||
|
||
|
||
class FibonacciSphereRSOPoints(RSOPoints): | ||
def __init__( | ||
self, | ||
n_points: int = 100, | ||
radius: float = 1.0, | ||
theta_min: float = np.radians(45), | ||
range: float = -1, | ||
# incidence_min: float = np.radians(60), # TODO handle | ||
): | ||
self.n_points = n_points | ||
self.radius = radius | ||
self.theta_min = theta_min | ||
self.range = range | ||
# self.incidence_min = incidence_min | ||
|
||
def generate_points(self) -> list[RSOPoint]: | ||
points = [] | ||
|
||
# https://gist.github.com/Seanmatthews/a51ac697db1a4f58a6bca7996d75f68c | ||
ga = (3 - np.sqrt(5)) * np.pi # golden angle | ||
theta = ga * np.arange(self.n_points) | ||
z = np.linspace(1 / self.n_points - 1, 1 - 1 / self.n_points, self.n_points) | ||
radius = np.sqrt(1 - z * z) | ||
y = radius * np.sin(theta) | ||
x = radius * np.cos(theta) | ||
|
||
for i in range(self.n_points): | ||
r_PB_B = np.array([x[i], y[i], z[i]]) * self.radius | ||
n_B = np.array([x[i], y[i], z[i]]) | ||
points.append( | ||
RSOPoint( | ||
r_PB_B, | ||
n_B, | ||
self.theta_min, | ||
self.range, | ||
) | ||
) | ||
|
||
return points | ||
|
||
|
||
__doc_title__ = "RSO Scenarios" | ||
__all__ = ["RSOPoints"] |
Oops, something went wrong.