Skip to content

Commit

Permalink
add tests for boss flags
Browse files Browse the repository at this point in the history
  • Loading branch information
andycasey committed Nov 12, 2024
1 parent a6e9eed commit ca24cdf
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 74 deletions.
75 changes: 29 additions & 46 deletions src/astra/models/bossnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,50 +62,33 @@ def flag_bad(self):
| self.flag_runtime_exception
)

@classmethod
def from_spectrum(cls, spectrum, **kwargs):

# TODO: Move this to happen at runtime
def apply_result_flags():

(
BossNet
.update(result_flags=BossNet.flag_runtime_exception.set())
.where(BossNet.teff.is_null())
.execute()
)

(
BossNet
.update(result_flags=BossNet.flag_unreliable_teff.set())
.where(
(BossNet.teff < 1700) | (BossNet.teff > 100_000)
)
.execute()
)
(
BossNet
.update(result_flags=BossNet.flag_unreliable_logg.set())
.where(
(BossNet.logg < -1) | (BossNet.logg > 10)
)
.execute()
)
(
BossNet
.update(result_flags=BossNet.flag_unreliable_fe_h.set())
.where(
(BossNet.teff < 3200)
| (BossNet.logg > 5)
| (BossNet.fe_h < -4)
| (BossNet.fe_h > 2)
)
.execute()
)
(
BossNet
.update(result_flags=BossNet.flag_suspicious_fe_h.set())
.where(
(BossNet.teff < 3900)
& ((6 > BossNet.logg) & (BossNet.logg > 3))
)
.execute()
)
kwds = kwargs.copy()
teff = kwargs.get("teff", None)
if teff is not None:
kwds["flag_unreliable_teff"] = ((teff < 1700) | (teff > 100_000))
else:
kwds["flag_runtime_exception"] = True

logg = kwargs.get("logg", None)
if logg is not None:
kwds["flag_unreliable_logg"] = ((logg < -1) | (logg > 10))

fe_h = kwargs.get("fe_h", None)
if fe_h is not None and logg is not None and teff is not None:
kwds["flag_unreliable_fe_h"] = (
(teff < 3200)
| (logg > 5)
| (fe_h < -4)
| (fe_h > 2)
)

if teff is not None and logg is not None:
kwds["flag_suspicious_fe_h"] = (
(teff < 3900)
& (6 > logg > 3)
)

return super().from_spectrum(spectrum, **kwds)
18 changes: 9 additions & 9 deletions src/astra/pipelines/bossnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,23 +391,23 @@ def bossnet(spectra: Iterable[BossVisitSpectrum], num_uncertainty_draws: Optiona
log_G,log_Teff,FeH,rv,log_G_std,log_Teff_std,Feh_std,rv_std = make_prediction(flux, e_flux, wavelen, num_uncertainty_draws,model,device)
except:
log.exception(f"Exception when running ANet on {spectrum}")
yield BossNet(
spectrum_pk=spectrum.spectrum_pk,
source_pk=spectrum.source_pk,
yield BossNet.from_spectrum(
spectrum,
flag_runtime_exception=True
)
else:
yield BossNet(
spectrum_pk=spectrum.spectrum_pk,
source_pk=spectrum.source_pk,
teff = 10**log_Teff
e_teff = 10**log_Teff * log_Teff_std * np.log(10)
yield BossNet.from_spectrum(
spectrum,
fe_h=FeH,
e_fe_h=Feh_std,
logg=log_G,
e_logg=log_G_std,
teff=10**log_Teff,
e_teff=10**log_Teff * log_Teff_std * np.log(10),
teff=teff,
e_teff=e_teff,
bn_v_r=rv,
e_bn_v_r=rv_std
e_bn_v_r=rv_std,
)


Expand Down
88 changes: 88 additions & 0 deletions tests/test_bossnet_flags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import pytest
import peewee
import os
os.environ["ASTRA_DATABASE_PATH"] = ":memory:"

def test_bossnet_flags():

import datetime
from astra import __version__
from astra.fields import AutoField, ForeignKeyField, DateTimeField
from astra.models.base import database, BaseModel
from astra.models.source import Source
from astra.models.spectrum import Spectrum, SpectrumMixin
from astra.models.bossnet import BossNet
from astra.models.boss import BossVisitSpectrum

class DummyBossVisitSpectrum(BaseModel, SpectrumMixin):

"""A BOSS visit spectrum, where a visit is defined by spectra taken on the same MJD."""

pk = AutoField()

#> Identifiers
spectrum_pk = ForeignKeyField(
Spectrum,
null=True,
index=True,
unique=True,
lazy_load=False,
column_name="spectrum_pk"
)
source = ForeignKeyField(
Source,
null=True,
index=True,
column_name="source_pk",
backref="boss_visit_spectra"
)

created = DateTimeField(default=datetime.datetime.now)
modified = DateTimeField(default=datetime.datetime.now)

models = (Source, Spectrum, BossNet, DummyBossVisitSpectrum)
database.create_tables(models)

Source.create()
Spectrum.create()
s = DummyBossVisitSpectrum.create(
source_pk=1,
spectrum_pk=1,
release="sdss5",
filetype="specFull",
run2d="run2d",
mjd=1,
fieldid=1,
catalogid=1,
healpix=1
)

scenarios = [
({}, (lambda r: r.flag_runtime_exception, lambda r: r.result_flags > 0)),
(dict(teff=5000), (lambda r: not r.flag_unreliable_teff, lambda r: not r.flag_runtime_exception)),
(dict(logg=3), (lambda r: not r.flag_unreliable_logg, )),
(dict(fe_h=-1), (lambda r: not r.flag_unreliable_fe_h, )),
(dict(teff=5000, logg=3, fe_h=-1), (lambda r: not r.flag_unreliable_teff, lambda r: not r.flag_unreliable_logg, lambda r: not r.flag_unreliable_fe_h, lambda r: not r.flag_runtime_exception)),
(dict(teff=1699), (lambda r: r.flag_unreliable_teff, lambda r: r.result_flags > 0)),
(dict(teff=100001), (lambda r: r.flag_unreliable_teff, lambda r: r.result_flags > 0)),
(dict(teff=5000, fe_h=0, logg=-1.1), (lambda r: r.flag_unreliable_logg, lambda r: r.result_flags > 0)),
(dict(teff=5000, fe_h=0, logg=10.1), (lambda r: r.flag_unreliable_logg, lambda r: r.result_flags > 0)),
(dict(teff=5000, logg=3, fe_h=-4.1), (lambda r: r.flag_unreliable_fe_h, lambda r: r.result_flags > 0)),
(dict(teff=5000, logg=3, fe_h=2.1), (lambda r: r.flag_unreliable_fe_h, lambda r: r.result_flags > 0)),
(dict(teff=3100, logg=3, fe_h=-1), (lambda r: r.flag_unreliable_fe_h, lambda r: r.result_flags > 0)),
(dict(teff=5000, logg=6, fe_h=-1), (lambda r: r.flag_unreliable_fe_h, lambda r: r.result_flags > 0)),
(dict(teff=3100, logg=5, fe_h=-1), (lambda r: r.flag_suspicious_fe_h, )),
]

for kwds, expectations in scenarios:
r = BossNet.from_spectrum(s, **kwds)
for n, fun in enumerate(expectations):
assert fun(r), f"Failed on scenario {n} with {kwds}"
r.save()
r = BossNet.get(task_pk=r.task_pk)
for n, fun in enumerate(expectations):
assert fun(r), f"Failed on scenario {n} with {kwds}"
r.delete_instance()



15 changes: 9 additions & 6 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,16 @@ def dummy_task(spectra) -> Iterable[ThisDummy]:

for model in (Source, Spectrum, ThisDummy, ApogeeVisitSpectrum):
model.create_table()

#for model in (ApogeeVisitSpectrum, Spectrum, Source):
# model.delete().execute()

source_pk = Source.create().pk
spectrum_pk1 = Spectrum.create().pk
spectrum_pk2 = Spectrum.create().pk

Source.create()
Spectrum.create()
Spectrum.create()

s = ApogeeVisitSpectrum.create(spectrum_pk=1, source_pk=1, release="test", apred="apred", plate="plate", telescope="telescope", fiber=0, mjd=0, field="field", prefix="ap")
s2 = ApogeeVisitSpectrum.create(spectrum_pk=2, source_pk=1, release="test", apred="apred", plate="plate", telescope="telescope", fiber=1, mjd=0, field="field", prefix="ap")
s = ApogeeVisitSpectrum.create(spectrum_pk=spectrum_pk1, source_pk=source_pk, release="test", apred="apred", plate="plate", telescope="telescope", fiber=0, mjd=0, field="field", prefix="ap")
s2 = ApogeeVisitSpectrum.create(spectrum_pk=spectrum_pk2, source_pk=source_pk, release="test", apred="apred", plate="plate", telescope="telescope", fiber=1, mjd=0, field="field", prefix="ap")

r1 = list(dummy_task([s]))[0].__data__
sleep(1)
Expand Down
24 changes: 11 additions & 13 deletions tests/test_task_query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

def test_task_query_builder_with_group_by():
from astra import task, generate_queries_for_task
from astra.models.base import database
from astra.models.source import Source
from astra.models.spectrum import Spectrum
from astra.models.apogee import ApogeeVisitSpectrum
Expand All @@ -13,18 +14,14 @@ def test_task_query_builder_with_group_by():
from datetime import datetime
from typing import Iterable



class ApogeeVisitState(PipelineOutputModel):
used = BooleanField()

for model in (ApogeeVisitState, Source, Spectrum, ApogeeVisitSpectrum):
model.create_table()
models = (Source, Spectrum, ApogeeVisitSpectrum, ApogeeVisitState)
database.create_tables(models)

for n in range(3):
Source.create()
for n in range(4):
Spectrum.create()
source_pks = [Source.create().pk for n in range(3)]
spectrum_pks = [Spectrum.create().pk for n in range(4)]


@task(group_by=("source_pk", "telescope"))
Expand All @@ -42,10 +39,11 @@ def make_stack(spectra: Iterable[ApogeeVisitSpectrum]) -> Iterable[ApogeeVisitSt
yield ApogeeVisitState.from_spectrum(s, used=True)



s1 = ApogeeVisitSpectrum.create(spectrum_pk=1, source_pk=1, release="test", apred="apred", plate="plate", telescope="apo", fiber=0, mjd=0, field="field", prefix="ap")
s2 = ApogeeVisitSpectrum.create(spectrum_pk=2, source_pk=1, release="test", apred="apred", plate="plate", telescope="apo", fiber=1, mjd=0, field="field", prefix="ap")
s3 = ApogeeVisitSpectrum.create(spectrum_pk=3, source_pk=2, release="test", apred="apred", plate="plate", telescope="lco", fiber=1, mjd=0, field="field", prefix="ap")
ApogeeVisitSpectrum.delete().execute()

s1 = ApogeeVisitSpectrum.create(spectrum_pk=spectrum_pks[1-1], source_pk=source_pks[1-1], release="test", apred="apred", plate="plate", telescope="apo", fiber=0, mjd=0, field="field", prefix="ap")
s2 = ApogeeVisitSpectrum.create(spectrum_pk=spectrum_pks[2-1], source_pk=source_pks[1-1], release="test", apred="apred", plate="plate", telescope="apo", fiber=1, mjd=0, field="field", prefix="ap")
s3 = ApogeeVisitSpectrum.create(spectrum_pk=spectrum_pks[3-1], source_pk=source_pks[2-1], release="test", apred="apred", plate="plate", telescope="lco", fiber=1, mjd=0, field="field", prefix="ap")


_, q = next(generate_queries_for_task(make_stack))
Expand All @@ -69,7 +67,7 @@ def make_stack(spectra: Iterable[ApogeeVisitSpectrum]) -> Iterable[ApogeeVisitSt
s1.modified = datetime.now()
s1.save()

s4 = ApogeeVisitSpectrum.create(spectrum_pk=4, source_pk=3, release="test", apred="apred", plate="plate", telescope="apo", fiber=12, mjd=1, field="field", prefix="ap")
s4 = ApogeeVisitSpectrum.create(spectrum_pk=spectrum_pks[4-1], source_pk=source_pks[3-1], release="test", apred="apred", plate="plate", telescope="apo", fiber=12, mjd=1, field="field", prefix="ap")
_, q = next(generate_queries_for_task(make_stack))
assert q.count() == 2
q = list(q)
Expand Down

0 comments on commit ca24cdf

Please sign in to comment.