Skip to content

Commit

Permalink
Merge pull request #161 from sot/allow-grid-model-extrapolation
Browse files Browse the repository at this point in the history
Add tolerance args to clip_and_warn
  • Loading branch information
taldcroft authored Nov 22, 2023
2 parents e25f0fc + 7d6134e commit 6c6c8fa
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 5 deletions.
20 changes: 15 additions & 5 deletions chandra_aca/star_probs.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,11 +389,11 @@ def get_default_acq_prob_model_info(verbose=True):
return info


def clip_and_warn(name, val, val_lo, val_hi, model):
def clip_and_warn(name, val, val_lo, val_hi, model, tol_lo=0.0, tol_hi=0.0):
"""
Clip ``val`` to be in the range ``val_lo`` to ``val_hi`` and issue a
warning if clipping occurs. The ``name`` and ``model`` are just used in
the warning.
warning if clipping occurs, subject to ``tol_lo`` and ``tol_hi`` expansions.
The ``name`` and ``model`` are just used in the warning.
Parameters
----------
Expand All @@ -407,17 +407,26 @@ def clip_and_warn(name, val, val_lo, val_hi, model):
Maximum
model
Model name
tol_lo
Tolerance below ``val_lo`` for issuing a warning (default=0.0)
tol_hi
Tolerance above ``val_hi`` for issuing a warning (default=0.0)
Returns
-------
Clipped value
"""
val = np.asarray(val)
if np.any((val > val_hi) | (val < val_lo)):

# Provide a tolerance for emitting a warning clipping
if np.any((val > val_hi + tol_hi) | (val < val_lo - tol_lo)):
warnings.warn(
f"\nModel {model} computed between {val_lo} <= {name} <= {val_hi}, "
f"clipping input {name}(s) outside that range."
)

# Now clip to the actual limits
if np.any((val > val_hi) | (val < val_lo)):
val = np.clip(val, val_lo, val_hi)

return val
Expand Down Expand Up @@ -621,7 +630,8 @@ def grid_model_acq_prob(
model_filename = Path(gfm["info"]["data_file_path"]).name

# Make sure inputs are within range of gridded model
mag = clip_and_warn("mag", mag, mag_lo, mag_hi, model_filename)
# TODO: run additional test cases on ASVT, make a new model, remove tol_hi for mag.
mag = clip_and_warn("mag", mag, mag_lo, mag_hi, model_filename, tol_hi=0.25)
t_ccd = clip_and_warn("t_ccd", t_ccd, t_ccd_lo, t_ccd_hi, model_filename)
halfwidth = clip_and_warn("halfw", halfwidth, halfw_lo, halfw_hi, model_filename)

Expand Down
27 changes: 27 additions & 0 deletions chandra_aca/tests/test_star_probs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from chandra_aca.star_probs import (
acq_success_prob,
binom_ppf,
clip_and_warn,
conf,
get_default_acq_prob_model_info,
get_grid_func_model,
Expand Down Expand Up @@ -676,3 +677,29 @@ def test_grid_model_3_48(monkeypatch):
assert info["version"] == "3.48"
assert info["data_file_path"].endswith("grid-floor-2020-02.fits.gz")
assert info["commit"] == "68a58099a9b51bef52ef14fbd0f1971f950e6ba3"


def test_clip_and_warn():
"""Test that we get a warning when clipping occurs"""
name = "mag"
model = "grid-floor"
val_lo = 5.0
val_hi = 11.75

# No warnings
clip_and_warn(name, 11.75, val_lo, val_hi, model)
clip_and_warn(name, 5.0, val_lo, val_hi, model)
clip_and_warn(name, 12.0, val_lo, val_hi, model, tol_hi=0.25)
clip_and_warn(name, 4.75, val_lo, val_hi, model, tol_lo=0.25)

# Expected warnings
match_re = rf"{model} computed between 5.0 <= {name} <= 11.75"
for val, tol_hi, tol_lo in [
(12.01, 0.25, 0.0),
(4.74, 0.0, 0.25),
([4.74, 7.0, 12.01], 0.25, 0.25),
]:
with pytest.warns(UserWarning, match=match_re):
clip_and_warn(
name, val, val_lo, val_hi, model, tol_hi=tol_hi, tol_lo=tol_lo
)

0 comments on commit 6c6c8fa

Please sign in to comment.