Skip to content

Commit

Permalink
unit tests and dark sky basis function upgrade
Browse files Browse the repository at this point in the history
  • Loading branch information
yoachim committed Jan 12, 2024
1 parent 9354842 commit 2f4db6d
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 36 deletions.
32 changes: 10 additions & 22 deletions rubin_scheduler/scheduler/basis_functions/basis_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,8 @@

from rubin_scheduler.scheduler import features, utils
from rubin_scheduler.scheduler.utils import IntRounded
from rubin_scheduler.site_models import SeeingModel
from rubin_scheduler.skybrightness_pre import dark_sky
from rubin_scheduler.utils import _hpid2_ra_dec, m5_flat_sed
from rubin_scheduler.skybrightness_pre import dark_m5
from rubin_scheduler.utils import _hpid2_ra_dec

Check warning on line 57 in rubin_scheduler/scheduler/basis_functions/basis_functions.py

View check run for this annotation

Codecov / codecov/patch

rubin_scheduler/scheduler/basis_functions/basis_functions.py#L56-L57

Added lines #L56 - L57 were not covered by tests


class BaseBasisFunction:
Expand Down Expand Up @@ -397,11 +396,15 @@ def __init__(
nside=nside,
)
self.result = np.zeros(hp.nside2npix(self.nside))
if self.filtername is not None:
self.dark_map = dark_sky(nside)[filtername]
self.dark_map = None

Check warning on line 399 in rubin_scheduler/scheduler/basis_functions/basis_functions.py

View check run for this annotation

Codecov / codecov/patch

rubin_scheduler/scheduler/basis_functions/basis_functions.py#L399

Added line #L399 was not covered by tests
self.footprint = footprint

def _calc_value(self, conditions, indx=None):
if self.filtername is not None:
if self.dark_map is None:
self.dark_map = dark_m5(

Check warning on line 405 in rubin_scheduler/scheduler/basis_functions/basis_functions.py

View check run for this annotation

Codecov / codecov/patch

rubin_scheduler/scheduler/basis_functions/basis_functions.py#L403-L405

Added lines #L403 - L405 were not covered by tests
conditions.dec, self.filtername, conditions.site.latitude_rad, fiducial_FWHMEff=0.7
)
result = 0
# Need to update the feature to the current season
self.survey_features["N_good_seeing"].season_update(conditions=conditions)
Expand Down Expand Up @@ -952,29 +955,14 @@ class M5DiffBasisFunction(BaseBasisFunction):
def __init__(self, filtername="r", fiducial_FWHMEff=0.7, nside=None):
super().__init__(nside=nside, filtername=filtername)
# The dark sky surface brightness values
self.dark_sky = dark_sky(nside)[filtername]
self.dark_map = None
self.fiducial_FWHMEff = fiducial_FWHMEff
self.filtername = filtername

def _calc_value(self, conditions, indx=None):
if self.dark_map is None:
# compute the maximum altitude each HEALpix reaches,
# this lets us determine the dark sky values with appropriate seeing
# for each declination.
min_z = np.abs(conditions.dec - conditions.site.latitude_rad)
airmass_min = 1 / np.cos(min_z)
airmass_min = np.where(airmass_min < 0, np.nan, airmass_min)
sm = SeeingModel(filter_list=[self.filtername])
fwhm_eff = sm(self.fiducial_FWHMEff, airmass_min)["fwhmEff"][0]
self.dark_map = m5_flat_sed(
self.filtername,
musky=self.dark_sky,
fwhm_eff=fwhm_eff,
exp_time=30.0,
airmass=airmass_min,
nexp=1,
tau_cloud=0,
self.dark_map = dark_m5(

Check warning on line 964 in rubin_scheduler/scheduler/basis_functions/basis_functions.py

View check run for this annotation

Codecov / codecov/patch

rubin_scheduler/scheduler/basis_functions/basis_functions.py#L964

Added line #L964 was not covered by tests
conditions.dec, self.filtername, conditions.site.latitude_rad, self.fiducial_FWHMEff
)

# No way to get the sign on this right the first time.
Expand Down
4 changes: 2 additions & 2 deletions rubin_scheduler/scheduler/surveys/dd_surveys.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def generate_dd_surveys(
frac_total=frac_total,
aggressive_frac=aggressive_frac,
delays=delays,
nside=nside
nside=nside,
)

surveys.append(
Expand Down Expand Up @@ -321,7 +321,7 @@ def generate_dd_surveys(
frac_total=frac_total,
aggressive_frac=aggressive_frac,
delays=delays,
nside=nside
nside=nside,
)
surveys.append(
DeepDrillingSurvey(
Expand Down
9 changes: 7 additions & 2 deletions rubin_scheduler/scheduler/surveys/pointings_survey.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import pandas as pd

Check warning on line 5 in rubin_scheduler/scheduler/surveys/pointings_survey.py

View check run for this annotation

Codecov / codecov/patch

rubin_scheduler/scheduler/surveys/pointings_survey.py#L3-L5

Added lines #L3 - L5 were not covered by tests

from rubin_scheduler.scheduler.detailers import ParallacticRotationDetailer
from rubin_scheduler.scheduler.utils import IntRounded
from rubin_scheduler.skybrightness_pre import dark_m5
from rubin_scheduler.utils import _angular_separation, _approx_ra_dec2_alt_az

Check warning on line 10 in rubin_scheduler/scheduler/surveys/pointings_survey.py

View check run for this annotation

Codecov / codecov/patch

rubin_scheduler/scheduler/surveys/pointings_survey.py#L7-L10

Added lines #L7 - L10 were not covered by tests
from rubin_scheduler.scheduler.utils import IntRounded

from .base_survey import BaseSurvey

Check warning on line 12 in rubin_scheduler/scheduler/surveys/pointings_survey.py

View check run for this annotation

Codecov / codecov/patch

rubin_scheduler/scheduler/surveys/pointings_survey.py#L12

Added line #L12 was not covered by tests

Expand Down Expand Up @@ -200,7 +200,12 @@ def ha_limit(self, conditions):
"""Apply hour angle limits."""
result = self.zeros.copy()

Check warning on line 201 in rubin_scheduler/scheduler/surveys/pointings_survey.py

View check run for this annotation

Codecov / codecov/patch

rubin_scheduler/scheduler/surveys/pointings_survey.py#L201

Added line #L201 was not covered by tests
# apply hour angle limits
result[np.where((IntRounded(self.ha) > IntRounded(self.ha_max)) & (IntRounded(self.ha) < IntRounded(self.ha_min)))] = np.nan
result[

Check warning on line 203 in rubin_scheduler/scheduler/surveys/pointings_survey.py

View check run for this annotation

Codecov / codecov/patch

rubin_scheduler/scheduler/surveys/pointings_survey.py#L203

Added line #L203 was not covered by tests
np.where(
(IntRounded(self.ha) > IntRounded(self.ha_max))
& (IntRounded(self.ha) < IntRounded(self.ha_min))
)
] = np.nan
return result

Check warning on line 209 in rubin_scheduler/scheduler/surveys/pointings_survey.py

View check run for this annotation

Codecov / codecov/patch

rubin_scheduler/scheduler/surveys/pointings_survey.py#L209

Added line #L209 was not covered by tests

def alt_limit(self, conditions):

Check warning on line 211 in rubin_scheduler/scheduler/surveys/pointings_survey.py

View check run for this annotation

Codecov / codecov/patch

rubin_scheduler/scheduler/surveys/pointings_survey.py#L211

Added line #L211 was not covered by tests
Expand Down
2 changes: 1 addition & 1 deletion rubin_scheduler/scheduler/surveys/scripted_surveys.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np

from rubin_scheduler.scheduler.surveys import BaseSurvey
from rubin_scheduler.scheduler.utils import empty_observation, set_default_nside, IntRounded
from rubin_scheduler.scheduler.utils import IntRounded, empty_observation, set_default_nside
from rubin_scheduler.utils import _approx_ra_dec2_alt_az

Check warning on line 10 in rubin_scheduler/scheduler/surveys/scripted_surveys.py

View check run for this annotation

Codecov / codecov/patch

rubin_scheduler/scheduler/surveys/scripted_surveys.py#L9-L10

Added lines #L9 - L10 were not covered by tests

log = logging.getLogger(__name__)
Expand Down
99 changes: 91 additions & 8 deletions tests/scheduler/test_baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,52 @@
from rubin_scheduler.scheduler.example import example_scheduler
from rubin_scheduler.scheduler.model_observatory import ModelObservatory
from rubin_scheduler.scheduler.schedulers import CoreScheduler
from rubin_scheduler.scheduler.surveys import BlobSurvey, GreedySurvey, generate_dd_surveys
from rubin_scheduler.scheduler.surveys import (
BlobSurvey,
GreedySurvey,
ScriptedSurvey,
generate_ddf_scheduled_obs,
)
from rubin_scheduler.scheduler.utils import SkyAreaGenerator, calc_norm_factor_array

SAMPLE_BIG_DATA_FILE = os.path.join(get_data_dir(), "scheduler/dust_maps/dust_nside_32.npz")


class ModelObservatoryWindy(ModelObservatory):
"""Have the model observatory always have a strong wind from the north"""

def return_conditions(self):
"""
Returns
-------
rubin_scheduler.scheduler.features.conditions object
"""
_conditions = super().return_conditions()

# Always have a strong wind from the north
wind_speed = 40.0
wind_direction = 0.0
self.conditions.wind_speed = wind_speed
self.conditions.wind_direction = wind_direction

return self.conditions


def ddf_surveys(detailers=None, season_unobs_frac=0.2, euclid_detailers=None, nside=None):
obs_array = generate_ddf_scheduled_obs(season_unobs_frac=season_unobs_frac)

euclid_obs = np.where((obs_array["note"] == "DD:EDFS_b") | (obs_array["note"] == "DD:EDFS_a"))[0]
all_other = np.where((obs_array["note"] != "DD:EDFS_b") & (obs_array["note"] != "DD:EDFS_a"))[0]

survey1 = ScriptedSurvey([bf.AvoidDirectWind(nside=nside)], detailers=detailers)
survey1.set_script(obs_array[all_other])

survey2 = ScriptedSurvey([bf.AvoidDirectWind(nside=nside)], detailers=euclid_detailers)
survey2.set_script(obs_array[euclid_obs])

return [survey1, survey2]


def gen_greedy_surveys(nside):
"""
Make a quick set of greedy surveys
Expand All @@ -40,14 +80,15 @@ def gen_greedy_surveys(nside):
bfs.append(bf.SlewtimeBasisFunction(filtername=filtername, nside=nside))
bfs.append(bf.StrictFilterBasisFunction(filtername=filtername))
# Masks, give these 0 weight
bfs.append(bf.AvoidDirectWind(nside=nside))
bfs.append(bf.ZenithShadowMaskBasisFunction(nside=nside, shadow_minutes=60.0, max_alt=76.0))
bfs.append(bf.MoonAvoidanceBasisFunction(nside=nside, moon_distance=30.0))
bfs.append(bf.CloudedOutBasisFunction())

bfs.append(bf.FilterLoadedBasisFunction(filternames=filtername))
bfs.append(bf.PlanetMaskBasisFunction(nside=nside))

weights = np.array([3.0, 0.3, 3.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0])
weights = np.array([3.0, 0.3, 3.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
surveys.append(
GreedySurvey(
bfs,
Expand Down Expand Up @@ -102,6 +143,7 @@ def gen_blob_surveys(nside):
bfs.append(bf.SlewtimeBasisFunction(filtername=filtername, nside=nside))
bfs.append(bf.StrictFilterBasisFunction(filtername=filtername))
# Masks, give these 0 weight
bfs.append(bf.AvoidDirectWind(nside=nside))
bfs.append(bf.ZenithShadowMaskBasisFunction(nside=nside, shadow_minutes=60.0, max_alt=76.0))
bfs.append(bf.MoonAvoidanceBasisFunction(nside=nside, moon_distance=30.0))
bfs.append(bf.CloudedOutBasisFunction())
Expand All @@ -112,16 +154,18 @@ def gen_blob_surveys(nside):
bfs.append(bf.NotTwilightBasisFunction())
bfs.append(bf.PlanetMaskBasisFunction(nside=nside))

weights = np.array([3.0, 3.0, 0.3, 0.3, 3.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
weights = np.array([3.0, 3.0, 0.3, 0.3, 3.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
if filtername2 is None:
# Need to scale weights up so filter balancing works properly.
weights = np.array([6.0, 0.6, 3.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
weights = np.array([6.0, 0.6, 3.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
if filtername2 is None:
survey_name = "blob, %s" % filtername
else:
survey_name = "blob, %s%s" % (filtername, filtername2)
if filtername2 is not None:
detailer_list.append(detailers.TakeAsPairsDetailer(filtername=filtername2))

detailer_list.append(detailers.FlushByDetailer())
pair_surveys.append(
BlobSurvey(
bfs,
Expand All @@ -143,7 +187,7 @@ def test_example(self):
"""Try out the example scheduler."""
mjd_start = utils.survey_start_mjd()
nside = 32
survey_length = 2.0 # days
survey_length = 4.0 # days
scheduler = example_scheduler(nside=nside, mjd_start=mjd_start)
observatory = ModelObservatory(nside=nside, mjd_start=mjd_start)
observatory, scheduler, observations = sim_runner(
Expand All @@ -157,6 +201,8 @@ def test_example(self):
assert observations.size > 1000
# Make sure nothing tried to look through the earth
assert np.min(observations["alt"]) > 0
# Make sure a DDF executed
assert np.any(["DD" in note for note in observations["note"]])


class TestFeatures(unittest.TestCase):
Expand All @@ -175,7 +221,7 @@ def test_greedy(self):
# surveys.append(Pairs_survey_scripted(None, ignore_obs='DD'))

# Set up the DD
dd_surveys = generate_dd_surveys(nside=nside)
dd_surveys = ddf_surveys(nside=nside)
surveys.extend(dd_surveys)

scheduler = CoreScheduler(surveys, nside=nside)
Expand All @@ -202,7 +248,7 @@ def test_blobs(self):

surveys = []
# Set up the DD
dd_surveys = generate_dd_surveys(nside=nside)
dd_surveys = ddf_surveys(nside=nside)
surveys.append(dd_surveys)

surveys.append(gen_blob_surveys(nside))
Expand All @@ -224,6 +270,43 @@ def test_blobs(self):
# Make sure nothing tried to look through the earth
assert np.min(observations["alt"]) > 0

@unittest.skipUnless(os.path.isfile(SAMPLE_BIG_DATA_FILE), "Test data not available.")
def test_wind(self):
"""
Test that a wind mask prevent things from being executed in the wrong spot
"""
mjd_start = utils.survey_start_mjd()
nside = 32
survey_length = 4.0 # days

surveys = []
# Set up the DD
dd_surveys = ddf_surveys(nside=nside)
surveys.append(dd_surveys)

surveys.append(gen_blob_surveys(nside))
surveys.append(gen_greedy_surveys(nside))

scheduler = CoreScheduler(surveys, nside=nside)
observatory = ModelObservatoryWindy(nside=nside, mjd_start=mjd_start)
observatory, scheduler, observations = sim_runner(
observatory, scheduler, survey_length=survey_length, filename=None
)

# Make sure some blobs executed
assert "blob, gg, b" in observations["note"]
assert "blob, gg, a" in observations["note"]
# Make sure some greedy executed
assert "greedy" in observations["note"]
# Make sure lots of observations executed
assert observations.size > 1000
# Make sure nothing tried to look through the earth
assert np.min(observations["alt"]) > 0

# Make sure nothing executed in the strong wind
assert np.min(np.degrees(observations["az"])) > 30.0
assert np.max(np.degrees(observations["az"])) < (360.0 - 30.0)

@unittest.skipUnless(os.path.isfile(SAMPLE_BIG_DATA_FILE), "Test data not available.")
def test_nside(self):
"""
Expand All @@ -235,7 +318,7 @@ def test_nside(self):

surveys = []
# Set up the DD
dd_surveys = generate_dd_surveys(nside=nside)
dd_surveys = ddf_surveys(nside=nside)
surveys.append(dd_surveys)

surveys.append(gen_blob_surveys(nside))
Expand Down
39 changes: 38 additions & 1 deletion tests/scheduler/test_surveys.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import rubin_scheduler.scheduler.surveys as surveys
from rubin_scheduler.scheduler.basis_functions import SimpleArrayBasisFunction
from rubin_scheduler.scheduler.model_observatory import ModelObservatory
from rubin_scheduler.scheduler.utils import set_default_nside
from rubin_scheduler.scheduler.utils import empty_observation, set_default_nside


class TestSurveys(unittest.TestCase):
Expand All @@ -34,6 +34,43 @@ def test_field_survey(self):
self.assertIsInstance(reward_df, pd.DataFrame)
reward_df = survey.make_reward_df(conditions, accum=False)

def test_pointings_survey(self):
"""Test the pointing survey."""
mo = ModelObservatory()
conditions = mo.return_conditions()

# Make a ring of points near the equator so
# some should always be visible
fields = empty_observation(n=10)
fields["RA"] = np.arange(0, fields.size) / fields.size * 2.0 * np.pi
fields["dec"] = -0.01
fields["note"] = ["test%i" % ind for ind in range(fields.size)]
fields["filter"] = "r"
survey = surveys.PointingsSurvey(fields)

reward = survey.calc_reward_function(conditions)
assert np.isfinite(reward)

obs = survey.generate_observations(conditions)
# Confirm that our desired input values got passed through
assert obs[0]["dec"] < 0
assert obs[0]["note"][0][0:4] == "test"

# Adding observations
assert np.sum(survey.n_obs) == 0
survey.add_observation(obs[0])
assert np.sum(survey.n_obs) == 1
survey.add_observations_array(fields, None)
assert np.sum(survey.n_obs) == 11

# Check we can get display things out
rc = survey.reward_changes(conditions)
assert len(rc) == len(survey.weights)

# Check we get a dataFrame
df = survey.make_reward_df(conditions)
assert len(df) == len(survey.weights)

def test_roi(self):
random_seed = 6563
infeasible_hpix = 123
Expand Down

0 comments on commit 2f4db6d

Please sign in to comment.