diff --git a/rubin_scheduler/scheduler/surveys/scripted_surveys.py b/rubin_scheduler/scheduler/surveys/scripted_surveys.py index 2dd01f75..353a3391 100644 --- a/rubin_scheduler/scheduler/surveys/scripted_surveys.py +++ b/rubin_scheduler/scheduler/surveys/scripted_surveys.py @@ -6,7 +6,7 @@ import numpy as np from rubin_scheduler.scheduler.surveys import BaseSurvey -from rubin_scheduler.scheduler.utils import ObservationArray, ScheduledObservationArray +from rubin_scheduler.scheduler.utils import obsarray_concat, ScheduledObservationArray from rubin_scheduler.utils import DEFAULT_NSIDE, _angular_separation, _approx_ra_dec2_alt_az log = logging.getLogger(__name__) @@ -318,7 +318,7 @@ def set_script(self, obs_wanted, append=True, add_index=True): self.id_start = self.script_id_array.max() + 1 if append & (self.obs_wanted is not None): - self.obs_wanted = np.concatenate([self.obs_wanted, obs_wanted]) + self.obs_wanted = obsarray_concat([self.obs_wanted, obs_wanted]) self.obs_wanted.sort(order=["mjd", "filter"]) else: self.obs_wanted = obs_wanted @@ -373,5 +373,6 @@ def generate_observations_rough(self, conditions): return self.observations self.last_mjd = conditions.mjd - + # Cache results, convert to ObservationArray + self.observations = observations.to_observation_array() return self.observations diff --git a/rubin_scheduler/scheduler/utils/utils.py b/rubin_scheduler/scheduler/utils/utils.py index 9c53f659..66274b23 100644 --- a/rubin_scheduler/scheduler/utils/utils.py +++ b/rubin_scheduler/scheduler/utils/utils.py @@ -793,6 +793,16 @@ def __new__(cls, n=1): obj = np.zeros(n, dtype=dtypes1 + dtype2).view(cls) return obj + def to_observation_array(self): + """Convert the scheduled observation to a + Regular ObservationArray + """ + result = ObservationArray(n=self.size) + in_common = np.intersect1d(self.dtype.names, result.dtype.names) + for key in in_common: + result[key] = self[key] + return result + def obsarray_concat(in_arrays): """Concatenate ObservationArray objects.