Skip to content

Commit

Permalink
FSW and dynamics w.i.p.
Browse files Browse the repository at this point in the history
  • Loading branch information
Mark2000 committed Dec 9, 2024
1 parent 52766ae commit 4ef3ca4
Show file tree
Hide file tree
Showing 8 changed files with 748 additions and 30 deletions.
335 changes: 310 additions & 25 deletions examples/continuous_orbit_manuevers.ipynb

Large diffs are not rendered by default.

24 changes: 19 additions & 5 deletions src/bsk_rl/act/continuous_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
from abc import abstractmethod
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

import numpy as np
from gymnasium import spaces
Expand Down Expand Up @@ -76,7 +76,13 @@ def set_action(self, action: np.ndarray) -> None:


class MagicThrust(ContinuousAction):
def __init__(self, name: str = "thrust_act", max_dv: float = float("inf")) -> None:
# TODO set the fsw mode to carry out after action
def __init__(
self,
name: str = "thrust_act",
max_dv: float = float("inf"),
fsw_action: Optional[str] = None,
) -> None:
"""Instantaneously change the satellite's velocity, and drift for some duration.
TODO: Support specifying frame of thrust.
Expand All @@ -87,6 +93,7 @@ def __init__(self, name: str = "thrust_act", max_dv: float = float("inf")) -> No
"""
super().__init__(name)
self.max_dv = max_dv
self.fsw_action = fsw_action

@property
def space(self) -> spaces.Box:
Expand All @@ -106,10 +113,17 @@ def action_description(self) -> list[str]:
def set_action(self, action: np.ndarray) -> None:
"""Thrust the satellite with a given inertial delta-V and drift for some duration."""
assert len(action) == 4, "Action must have 4 elements."
dv_N = action[0:3]
dt = action[3]

self.satellite.log_info(
f"Thrusting with inertial dV {action[0:3]} with {action[3]} second drift."
f"Thrusting with inertial dV {dv_N} with {dt} second drift."
)
self.satellite.fsw.action_magic_thrust(action[0:3])
self.satellite.fsw.action_magic_thrust(dv_N)
self.satellite.update_timed_terminal_event(
self.satellite.simulator.sim_time + action[3]
self.satellite.simulator.sim_time + dt
)

# Activate the FSW action for the drift period
getattr(self.satellite.fsw, self.fsw_action)()
self.satellite.log_info(f"FSW action {self.fsw_action} activated.")
2 changes: 2 additions & 0 deletions src/bsk_rl/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
from bsk_rl.data.base import GlobalReward
from bsk_rl.data.nadir_data import ScanningTimeReward
from bsk_rl.data.no_data import NoReward
from bsk_rl.data.rso_data import RSOInspectionReward
from bsk_rl.data.unique_image_data import UniqueImageReward

__doc_title__ = "Data & Reward"
Expand All @@ -87,4 +88,5 @@
"NoReward",
"UniqueImageReward",
"ScanningTimeReward",
"RSOInspectionReward",
]
146 changes: 146 additions & 0 deletions src/bsk_rl/data/rso_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
"""Data system for recording RSO surface."""

import logging
from typing import TYPE_CHECKING, Optional

import numpy as np

from bsk_rl.data.base import Data, DataStore, GlobalReward
from bsk_rl.sats import Satellite
from bsk_rl.scene.rso_points import RSOPoint
from bsk_rl.sim.dyn import RSODynModel, RSOImagingDynModel

if TYPE_CHECKING:
from bsk_rl.sats import Satellite

logger = logging.getLogger(__name__)

RSO = "rso"
OBSERVER = "observer"


class RSOInspectionData(Data):
def __init__(self, point_inspect_status: Optional[dict[RSOPoint, bool]] = None):
if point_inspect_status is None:
point_inspect_status = {}
self.point_inspect_status = point_inspect_status

def __add__(self, other: "RSOInspectionData"):
point_inspect_status = {}
point_inspect_status.update(self.point_inspect_status)
for point, access in other.point_inspect_status.items():
if point not in point_inspect_status:
point_inspect_status[point] = access
else:
point_inspect_status[point] = point_inspect_status[point] or access

return RSOInspectionData(point_inspect_status)


class RSOInspectionDataStore(DataStore):
data_type = RSOInspectionData

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.point_access_recorders = []
self.storage_recorder = None

if issubclass(self.satellite.dyn_type, RSOImagingDynModel):
self.role = OBSERVER
else:
self.role = RSO

def set_storage_recorder(self, recorder):
self.storage_recorder = recorder
self.satellite.simulator.AddModelToTask(
self.satellite.dynamics.task_name, recorder, ModelPriority=1000
)

def add_point_access_recorder(self, recorder):
self.point_access_recorders.append(recorder)
self.satellite.simulator.AddModelToTask(
self.satellite.dynamics.task_name, recorder, ModelPriority=1000
)

def clear_recorders(self):
if self.storage_recorder:
self.storage_recorder.clear()
for recorder in self.point_access_recorders:
recorder.clear()

def get_log_state(self) -> list[list[bool]]:
"""Log the storage unit state and point access state for all times in the step.
Returns:
todo
"""
if self.role == RSO:
return None

log_len = len(self.storage_recorder.storageLevel)
if log_len <= 1:
imaging_req = np.zeros(log_len)
else:
imaging_req = np.diff(self.storage_recorder.storageLevel)
imaging_req = np.concatenate((imaging_req, [imaging_req[-1]]))

inspected_logs = []
for recorder in self.point_access_recorders:
# inspected = np.logical_and(imaging_req, recorder.hasAccess)
inspected = recorder.hasAccess
inspected_logs.append(list(inspected))

self.clear_recorders()

return inspected_logs

def compare_log_states(self, _, inspected_logs) -> Data:
if self.role == RSO:
return RSOInspectionData()

point_inspect_status = {}
for rso_point, log in zip(
self.data.point_inspect_status.keys(), inspected_logs
):
if any(log):
print(log)
point_inspect_status[rso_point] = True

if len(point_inspect_status) > 0:
self.satellite.logger.info(
f"Inspected {len(point_inspect_status)} points this step"
)

return RSOInspectionData(point_inspect_status)


class RSOInspectionReward(GlobalReward):
datastore_type = RSOInspectionDataStore

def reset_post_sim_init(self) -> None:
super().reset_post_sim_init()

for i, observer in enumerate(self.scenario.observers):
observer.data_store.set_storage_recorder(
observer.dynamics.storageUnit.storageUnitDataOutMsg.recorder()
)
logger.debug(
f"Logging {len(self.scenario.rso.dynamics.rso_points)} access points"
)
for rso_point_model in self.scenario.rso.dynamics.rso_points:
observer.data_store.add_point_access_recorder(
rso_point_model.accessOutMsgs[i].recorder()
)

def initial_data(self, satellite: Satellite) -> Data:
if not issubclass(satellite.dyn_type, RSOImagingDynModel):
return RSOInspectionData()

return RSOInspectionData({point: False for point in self.scenario.rso_points})

def calculate_reward(self, new_data_dict: dict[str, Data]) -> dict[str, float]:
return {} # TODO


__doc_title__ = "RSO Inspection"
__all__ = ["RSOInspectionReward", "RSOInspectionDataStore", "RSOInspectionData"]
5 changes: 5 additions & 0 deletions src/bsk_rl/scene/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
"""

from bsk_rl.scene.scenario import Scenario, UniformNadirScanning

pass # Other imports must come after Scenario
from bsk_rl.scene.rso_points import FibonacciSphereRSOPoints, RSOPoints
from bsk_rl.scene.targets import CityTargets, UniformTargets

__doc_title__ = "Scenario"
Expand All @@ -17,4 +20,6 @@
"UniformTargets",
"CityTargets",
"UniformNadirScanning",
"RSOPoints",
"FibonacciSphereRSOPoints",
]
127 changes: 127 additions & 0 deletions src/bsk_rl/scene/rso_points.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
"""TODO: Add docstring."""

import logging
from abc import abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING

import numpy as np
import pandas as pd
from Basilisk.utilities import orbitalMotion

from bsk_rl.scene import Scenario
from bsk_rl.sim.dyn import RSODynModel, RSOImagingDynModel
from bsk_rl.sim.fsw import RSOImagingFSWModel
from bsk_rl.utils.orbital import lla2ecef

if TYPE_CHECKING: # pragma: no cover
from bsk_rl.data.base import Data
from bsk_rl.sats import Satellite

logger = logging.getLogger(__name__)


@dataclass
class RSOPoint:
r_PB_B: np.ndarray
n_B: np.ndarray
theta_min: float
range: float

def __hash__(self) -> int:
"""Hash target by unique id."""
return hash(id(self)) # THIS IS ALMOST CERTAINLY A BAD IDEA


class RSOPoints(Scenario):
def reset_overwrite_previous(self) -> None:
"""Overwrite target list from previous episode."""
self.rso_points = []

def reset_pre_sim_init(self) -> None:
self.rso_points = self.generate_points()
return super().reset_pre_sim_init()

def reset_post_sim_init(self) -> None:
# Check for RSOs and observers
rsos = [sat for sat in self.satellites if isinstance(sat.dynamics, RSODynModel)]
if len(rsos) == 0:
logger.warning("No RSODynModel satellites found in scenario.")
return
assert len(rsos) == 1, "Only one RSODynModel satellite is supported."
self.rso = rsos[0]

self.observers = [
sat
for sat in self.satellites
if isinstance(sat.dynamics, RSOImagingDynModel)
]
if len(self.observers) == 0:
logger.warning("No RSOImagingDynModel satellites found in scenario.")
return

# Add points to dynamics and fsw of RSO
assert isinstance(self.rso.dynamics, RSODynModel)
logger.debug("Adding inspection points to RSO and observers")
for point in self.rso_points:
rso_point_model = self.rso.dynamics.add_rso_point(
point.r_PB_B, point.n_B, point.theta_min, point.range
)
# Add point to each observer
for observer in self.observers:
assert isinstance(observer.dynamics, RSOImagingDynModel)
assert isinstance(observer.fsw, RSOImagingFSWModel)
observer.dynamics.add_rso_point(rso_point_model)

logger.debug("Targeting RSO with observers")
for observer in self.observers:
observer.fsw.set_target_rso(self.rso)

@abstractmethod
def generate_points(self) -> list[RSOPoint]:
pass


class FibonacciSphereRSOPoints(RSOPoints):
def __init__(
self,
n_points: int = 100,
radius: float = 1.0,
theta_min: float = np.radians(45),
range: float = -1,
# incidence_min: float = np.radians(60), # TODO handle
):
self.n_points = n_points
self.radius = radius
self.theta_min = theta_min
self.range = range
# self.incidence_min = incidence_min

def generate_points(self) -> list[RSOPoint]:
points = []

# https://gist.github.com/Seanmatthews/a51ac697db1a4f58a6bca7996d75f68c
ga = (3 - np.sqrt(5)) * np.pi # golden angle
theta = ga * np.arange(self.n_points)
z = np.linspace(1 / self.n_points - 1, 1 - 1 / self.n_points, self.n_points)
radius = np.sqrt(1 - z * z)
y = radius * np.sin(theta)
x = radius * np.cos(theta)

for i in range(self.n_points):
r_PB_B = np.array([x[i], y[i], z[i]]) * self.radius
n_B = np.array([x[i], y[i], z[i]])
points.append(
RSOPoint(
r_PB_B,
n_B,
self.theta_min,
self.range,
)
)

return points


__doc_title__ = "RSO Scenarios"
__all__ = ["RSOPoints"]
Loading

0 comments on commit 4ef3ca4

Please sign in to comment.