Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SP-1841: Backwards compatible updates for Band/Filter #149

Merged
merged 1 commit into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
221 changes: 87 additions & 134 deletions rubin_scheduler/scheduler/basis_functions/basis_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
"M5DiffBasisFunction",
"M5DiffAtHpixBasisFunction",
"StrictBandBasisFunction",
"StrictFilterBasisFunction",
"BandChangeBasisFunction",
"FilterChangeBasisFunction",
"SlewtimeBasisFunction",
"CadenceEnhanceBasisFunction",
"CadenceEnhanceTrapezoidBasisFunction",
Expand All @@ -21,15 +23,14 @@
"NObsPerYearBasisFunction",
"CadenceInSeasonBasisFunction",
"NearSunHighAirmassBasisFunction",
"NObsHighAmBasisFunction",
"GoodSeeingBasisFunction",
"EclipticBasisFunction",
"VisitGap",
"NGoodSeeingBasisFunction",
"AvoidDirectWind",
"BalanceVisits",
"RewardNObsSequence",
"BandDistBasisFunction",
"FilterDistBasisFunction",
"RewardRisingBasisFunction",
"send_unused_deprecation_warning",
)
Expand Down Expand Up @@ -61,7 +62,7 @@
"""Class that takes features and computes a reward function when
called."""

def __init__(self, nside=DEFAULT_NSIDE, bandname=None, **kwargs):
def __init__(self, nside=DEFAULT_NSIDE, bandname=None, filtername=None, **kwargs):
# Set if basis function needs to be recalculated if there is a new
# observation
self.update_on_newobs = True
Expand All @@ -86,6 +87,13 @@
else:
self.nside = nside

if filtername is not None:
warnings.warn(

Check warning on line 91 in rubin_scheduler/scheduler/basis_functions/basis_functions.py

View check run for this annotation

Codecov / codecov/patch

rubin_scheduler/scheduler/basis_functions/basis_functions.py#L91

Added line #L91 was not covered by tests
"Use of `filtername` will be deprecated in favor of `bandname` at v4", FutureWarning
)
bandname = filtername

Check warning on line 94 in rubin_scheduler/scheduler/basis_functions/basis_functions.py

View check run for this annotation

Codecov / codecov/patch

rubin_scheduler/scheduler/basis_functions/basis_functions.py#L94

Added line #L94 was not covered by tests
# Save filtername as a backup in case someone tries to access it
self.filtername = filtername

Check warning on line 96 in rubin_scheduler/scheduler/basis_functions/basis_functions.py

View check run for this annotation

Codecov / codecov/patch

rubin_scheduler/scheduler/basis_functions/basis_functions.py#L96

Added line #L96 was not covered by tests
self.bandname = bandname

def add_observations_array(self, observations_array, observations_hpid):
Expand Down Expand Up @@ -291,6 +299,14 @@
return result


class FilterDistBasisFunction(BandDistBasisFunction):
"""Deprecated version of BandDistBasisFunction"""

def __init__(self, filtername="r"):
warnings.warn("FilterDistBasisFunction deprecated for BandDistBasisFunction", FutureWarning)
super().__init__(bandname=filtername)


class NObsPerYearBasisFunction(BaseBasisFunction):
"""Reward areas that have not been observed N-times in the last year

Expand Down Expand Up @@ -325,7 +341,11 @@
season_start_hour=-4.0,
season_end_hour=2.0,
night_max=365,
filtername=None,
):
if filtername is not None:
warnings.warn("filtername deprecated in favor of bandname", FutureWarning)
bandname = filtername

Check warning on line 348 in rubin_scheduler/scheduler/basis_functions/basis_functions.py

View check run for this annotation

Codecov / codecov/patch

rubin_scheduler/scheduler/basis_functions/basis_functions.py#L347-L348

Added lines #L347 - L348 were not covered by tests
super(NObsPerYearBasisFunction, self).__init__(nside=nside, bandname=bandname)
self.footprint = footprint
self.n_obs = n_obs
Expand Down Expand Up @@ -406,7 +426,11 @@
n_obs_desired=3,
mjd_start=None,
footprint=None,
filtername=None,
):
if filtername is not None:
warnings.warn("filtername deprecated in favor of bandname", FutureWarning)
bandname = filtername

Check warning on line 433 in rubin_scheduler/scheduler/basis_functions/basis_functions.py

View check run for this annotation

Codecov / codecov/patch

rubin_scheduler/scheduler/basis_functions/basis_functions.py#L432-L433

Added lines #L432 - L433 were not covered by tests
super().__init__(nside=nside, bandname=bandname)
self.seeing_fwhm_max = seeing_fwhm_max
self.m5_penalty_max = m5_penalty_max
Expand Down Expand Up @@ -464,76 +488,6 @@
return az_rel_moon


class NObsHighAmBasisFunction(BaseBasisFunction):
"""Reward only reward/count observations at high airmass."""

def __init__(
self,
nside=DEFAULT_NSIDE,
bandname="r",
footprint=None,
n_obs=3,
season=300.0,
am_limits=[1.5, 2.2],
out_of_bounds_val=np.nan,
):
send_unused_deprecation_warning("NObsHighAmBasisFunction")
return
super(NObsHighAmBasisFunction, self).__init__(nside=nside, bandname=bandname)
if footprint is None:
footprints, labels = get_current_footprint(self.nside)
footprint = footprints[self.bandname]
self.footprint = footprint
self.out_footprint = np.where((footprint == 0) | np.isnan(footprint))
self.am_limits = am_limits
self.season = season
self.survey_features["last_n_mjds"] = features.Last_n_obs_times(
nside=nside, bandname=bandname, n_obs=n_obs
)

self.result = np.zeros(hp.nside2npix(self.nside), dtype=float) + out_of_bounds_val
self.out_of_bounds_val = out_of_bounds_val

def add_observation(self, observation, indx=None):
# Only count the observations if they are at the airmass limits
if (observation["airmass"] > np.min(self.am_limits)) & (
observation["airmass"] < np.max(self.am_limits)
):
for feature in self.survey_features:
self.survey_features[feature].add_observation(observation, indx=indx)
if self.update_on_newobs:
self.recalc = True

def check_feasibility(self, conditions):
result = True
reward = self._calc_value(conditions)
# If there are no non-NaN values, we're not feasible now
if True not in np.isfinite(reward):
result = False

return result

def _calc_value(self, conditions, indx=None):
result = self.result.copy()
behind_pix = np.where(
(
IntRounded(conditions.mjd - self.survey_features["last_n_mjds"].feature[0])
> IntRounded(self.season)
)
& (IntRounded(conditions.airmass) > IntRounded(np.min(self.am_limits)))
& (IntRounded(conditions.airmass) < IntRounded(np.max(self.am_limits)))
)
result[behind_pix] = 1
result[self.out_footprint] = self.out_of_bounds_val

# Update the last time we had an mjd
self.mjd_last = conditions.mjd + 0
self.recalc = False
self.value = result

return result


class EclipticBasisFunction(BaseBasisFunction):
"""Mark the area around the ecliptic"""

Expand Down Expand Up @@ -566,7 +520,12 @@
How long to wait before activating the basis function (days).
"""

def __init__(self, drive_map, bandname="griz", season_span=2.5, cadence=2.5, nside=DEFAULT_NSIDE):
def __init__(
self, drive_map, bandname="griz", season_span=2.5, cadence=2.5, nside=DEFAULT_NSIDE, filtername=None
):
if filtername is not None:
warnings.warn("filtername deprecated in favor of bandname", FutureWarning)
bandname = filtername

Check warning on line 528 in rubin_scheduler/scheduler/basis_functions/basis_functions.py

View check run for this annotation

Codecov / codecov/patch

rubin_scheduler/scheduler/basis_functions/basis_functions.py#L527-L528

Added lines #L527 - L528 were not covered by tests
super(CadenceInSeasonBasisFunction, self).__init__(nside=nside, bandname=bandname)
self.drive_map = drive_map
self.season_span = season_span / 12.0 * np.pi # To radians
Expand Down Expand Up @@ -630,7 +589,11 @@
n_per_season=3,
mjd_start=None,
season_frac_start=0.5,
filtername=None,
):
if filtername is not None:
warnings.warn("filtername deprecated in favor of bandname", FutureWarning)
bandname = filtername

Check warning on line 596 in rubin_scheduler/scheduler/basis_functions/basis_functions.py

View check run for this annotation

Codecov / codecov/patch

rubin_scheduler/scheduler/basis_functions/basis_functions.py#L595-L596

Added lines #L595 - L596 were not covered by tests
send_unused_deprecation_warning("SeasonCoverageBasisFunction")
super().__init__(nside=nside, bandname=bandname)

Expand Down Expand Up @@ -693,7 +656,10 @@
Will be masked if set to np.nan (default).
"""

def __init__(self, bandname="r", nside=DEFAULT_NSIDE, gap_min=25.0, penalty_val=np.nan):
def __init__(self, bandname="r", nside=DEFAULT_NSIDE, gap_min=25.0, penalty_val=np.nan, filtername=None):
if filtername is not None:
warnings.warn("filtername deprecated in favor of bandname", FutureWarning)
bandname = filtername

Check warning on line 662 in rubin_scheduler/scheduler/basis_functions/basis_functions.py

View check run for this annotation

Codecov / codecov/patch

rubin_scheduler/scheduler/basis_functions/basis_functions.py#L661-L662

Added lines #L661 - L662 were not covered by tests
super().__init__(nside=nside, bandname=bandname)

self.bandname = bandname
Expand Down Expand Up @@ -767,7 +733,12 @@
The number of pairs of observations to attempt to gather
"""

def __init__(self, gap_min=25.0, gap_max=45.0, bandname="r", nside=DEFAULT_NSIDE, npairs=1):
def __init__(
self, gap_min=25.0, gap_max=45.0, bandname="r", nside=DEFAULT_NSIDE, npairs=1, filtername=None
):
if filtername is not None:
warnings.warn("filtername deprecated in favor of bandname", FutureWarning)
bandname = filtername

Check warning on line 741 in rubin_scheduler/scheduler/basis_functions/basis_functions.py

View check run for this annotation

Codecov / codecov/patch

rubin_scheduler/scheduler/basis_functions/basis_functions.py#L740-L741

Added lines #L740 - L741 were not covered by tests
super(VisitRepeatBasisFunction, self).__init__(nside=nside, bandname=bandname)

self.gap_min = IntRounded(gap_min / 60.0 / 24.0)
Expand Down Expand Up @@ -824,7 +795,10 @@
Default None uses `set_default_nside()`.
"""

def __init__(self, bandname="r", fiducial_FWHMEff=0.7, nside=DEFAULT_NSIDE):
def __init__(self, bandname="r", fiducial_FWHMEff=0.7, nside=DEFAULT_NSIDE, filtername=None):
if filtername is not None:
warnings.warn("filtername deprecated in favor of bandname", FutureWarning)
bandname = filtername

Check warning on line 801 in rubin_scheduler/scheduler/basis_functions/basis_functions.py

View check run for this annotation

Codecov / codecov/patch

rubin_scheduler/scheduler/basis_functions/basis_functions.py#L800-L801

Added lines #L800 - L801 were not covered by tests
super().__init__(nside=nside, bandname=bandname)
# The dark sky surface brightness values
self.dark_map = None
Expand Down Expand Up @@ -908,6 +882,16 @@
return result


class StrictFilterBasisFunction(StrictBandBasisFunction):
"""Deprecated in favor of StrictBandBasisFunction"""

def __init__(self, time_lag=10.0, filtername="r", twi_change=-18.0, note_free="DD"):
warnings.warn(
"StrictFilterBasisFunction deprecated in favor of StrictBandBasisFunction", FutureWarning
)
super().__init__(time_lag=time_lag, bandname=filtername, twi_change=twi_change, note_free=note_free)


class BandChangeBasisFunction(BaseBasisFunction):
"""Reward staying in the current band."""

Expand All @@ -922,6 +906,16 @@
return result


class FilterChangeBasisFunction(BandChangeBasisFunction):
"""Deprecated in favor of BandChangeBasisFunction"""

def __init__(self, filtername="r"):
warnings.warn(
"FilterChangeBasisFunction deprecated in favor of BandChangeBasisFunction", FutureWarning
)
super().__init__(bandname=filtername)


class SlewtimeBasisFunction(BaseBasisFunction):
"""Reward slews that take little time

Expand All @@ -941,7 +935,10 @@
Default None will use `set_default_nside()`.
"""

def __init__(self, max_time=135.0, bandname="r", nside=DEFAULT_NSIDE):
def __init__(self, max_time=135.0, bandname="r", nside=DEFAULT_NSIDE, filtername=None):
if filtername is not None:
warnings.warn("filtername deprecated in favor of bandname", FutureWarning)
bandname = filtername

Check warning on line 941 in rubin_scheduler/scheduler/basis_functions/basis_functions.py

View check run for this annotation

Codecov / codecov/patch

rubin_scheduler/scheduler/basis_functions/basis_functions.py#L940-L941

Added lines #L940 - L941 were not covered by tests
super(SlewtimeBasisFunction, self).__init__(nside=nside, bandname=bandname)

self.maxtime = max_time
Expand Down Expand Up @@ -998,7 +995,11 @@
enhance_window=[2.1, 3.2],
enhance_val=1.0,
apply_area=None,
filtername=None,
):
if filtername is not None:
warnings.warn("filtername deprecated in favor of bandname", FutureWarning)
bandname = filtername

Check warning on line 1002 in rubin_scheduler/scheduler/basis_functions/basis_functions.py

View check run for this annotation

Codecov / codecov/patch

rubin_scheduler/scheduler/basis_functions/basis_functions.py#L1001-L1002

Added lines #L1001 - L1002 were not covered by tests
super(CadenceEnhanceBasisFunction, self).__init__(nside=nside, bandname=bandname)

self.supress_window = np.sort(supress_window)
Expand Down Expand Up @@ -1091,7 +1092,11 @@
enhance_amp=1.0,
apply_area=None,
season_limit=None,
filtername=None,
):
if filtername is not None:
warnings.warn("filtername deprecated in favor of bandname", FutureWarning)
bandname = filtername

Check warning on line 1099 in rubin_scheduler/scheduler/basis_functions/basis_functions.py

View check run for this annotation

Codecov / codecov/patch

rubin_scheduler/scheduler/basis_functions/basis_functions.py#L1098-L1099

Added lines #L1098 - L1099 were not covered by tests
super(CadenceEnhanceTrapezoidBasisFunction, self).__init__(nside=nside, bandname=bandname)

self.delay_width = delay_width
Expand Down Expand Up @@ -1280,61 +1285,6 @@
return result


class GoodSeeingBasisFunction(BaseBasisFunction):
"""Drive observations in good seeing conditions"""

def __init__(
self,
nside=DEFAULT_NSIDE,
bandname="r",
footprint=None,
fwhm_eff_limit=0.8,
mag_diff=0.75,
):
send_unused_deprecation_warning("GoodSeeingBasisFunction")
return
super(GoodSeeingBasisFunction, self).__init__(nside=nside)

self.bandname = bandname
self.fwhm_eff_limit = IntRounded(fwhm_eff_limit)
if footprint is None:
footprints, labels = get_current_footprint(nside=self.nside)
fp = footprints[self.bandname]
else:
fp = footprint
self.out_of_bounds = np.where(fp == 0)[0]
self.result = fp * 0

self.mag_diff = IntRounded(mag_diff)
self.survey_features = {}
self.survey_features["coadd_depth_all"] = features.CoaddedDepth(
bandname=self.bandname, nside=self.nside
)
self.survey_features["coadd_depth_good"] = features.CoaddedDepth(
bandname=self.bandname, nside=self.nside, fwhm_eff_limit=fwhm_eff_limit
)

def _calc_value(self, conditions, **kwargs):
# Seeing is "bad"
if IntRounded(conditions.FWHMeff[self.bandname].min()) > self.fwhm_eff_limit:
return 0
result = self.result.copy()

diff = (
self.survey_features["coadd_depth_all"].feature - self.survey_features["coadd_depth_good"].feature
)
# Where are there things we want to observe?
good_pix = np.where(
(IntRounded(diff) > self.mag_diff)
& (IntRounded(conditions.FWHMeff[self.bandname]) <= self.fwhm_eff_limit)
)
# Hm, should this scale by the mag differences? Probably.
result[good_pix] = diff[good_pix]
result[self.out_of_bounds] = 0

return result


class VisitGap(BaseBasisFunction):
"""Basis function to create a visit gap based on the survey note field.

Expand All @@ -1358,7 +1308,10 @@
the last observation was at least gap in the past.
"""

def __init__(self, note, band_names=None, gap_min=25.0, penalty_val=np.nan):
def __init__(self, note, band_names=None, gap_min=25.0, penalty_val=np.nan, filter_names=None):
if filter_names is not None:
warnings.warn("filter_names deprecated in favor of band_names", FutureWarning)
band_names = filter_names

Check warning on line 1314 in rubin_scheduler/scheduler/basis_functions/basis_functions.py

View check run for this annotation

Codecov / codecov/patch

rubin_scheduler/scheduler/basis_functions/basis_functions.py#L1313-L1314

Added lines #L1313 - L1314 were not covered by tests
super().__init__()
self.penalty_val = penalty_val

Expand Down
Loading
Loading