Skip to content

Commit

Permalink
Merge pull request #43 from LCOGT/feature/no_reclassify
Browse files Browse the repository at this point in the history
Only reclassify if we haven't seen the target before.
  • Loading branch information
cmccully authored Oct 1, 2020
2 parents e70a1b1 + 107fbb9 commit 5bb1e92
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 8 deletions.
22 changes: 17 additions & 5 deletions banzai_nres/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,22 @@ def find_object_in_catalog(image, db_address):

class StellarClassifier(Stage):
def do_stage(self, image) -> NRESObservationFrame:

closest_previous_classification = dbs.get_closest_existing_classification(self.runtime_context.db_address,
image.ra, image.dec)

if closest_previous_classification is not None:
previous_coordinate = SkyCoord(closest_previous_classification.ra, closest_previous_classification.dec,
unit=(units.deg, units.deg))
this_coordinate = SkyCoord(image.ra, image.dec, unit=(units.deg, units.deg))

# Short circuit if the object is already classified
# We choose 2.6 arcseconds as the don't reclassify cutoff radius as it is the fiber size
if this_coordinate.separation(previous_coordinate) < 2.6 * units.arcsec:
image.classification = closest_previous_classification
image.meta['CLASSIFY'] = 0, 'Was this spectrum classified'
return image

find_object_in_catalog(image, self.runtime_context.db_address)

# TODO: For each param: Fix the other params, get the N closest models and save the results
Expand All @@ -66,9 +82,5 @@ def do_stage(self, image) -> NRESObservationFrame:
image.meta['CLASSIFY'] = 0, 'Was this spectrum classified'
else:
image.meta['CLASSIFY'] = 1, 'Was this spectrum classified'
image.meta['TEFF'] = image.classification.T_effective
image.meta['LOG_G'] = image.classification.log_g
image.meta['FE_H'] = image.classification.metallicity
image.meta['ALPHA'] = image.classification.alpha

dbs.save_classification(self.runtime_context.db_address, image)
return image
74 changes: 72 additions & 2 deletions banzai_nres/dbs.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from banzai.dbs import Base, create_db
from banzai.dbs import Base, create_db, add_or_update_record
from sqlalchemy import Column, String, Integer, Float, Index
import boto3
import banzai.dbs
import os
from glob import glob
import logging
from sqlalchemy import func
from sqlalchemy import func, desc
from sqlalchemy.ext.hybrid import hybrid_method
import numpy as np

from banzai_nres.utils.phoenix_utils import parse_phoenix_header

Expand Down Expand Up @@ -69,6 +70,53 @@ def diff_luminosity(cls, value):
return func.abs(cls.luminosity - value)


# We define the great circle distance here instead of using astropy because we need it to work inside the db.
def cos_great_circle_distance(sin_ra1, cos_ra1, sin_dec1, cos_dec1, sin_ra2, cos_ra2, sin_dec2, cos_dec2):
"""
:param sin_ra1: sin(ra1)
:param cos_ra1: cos(ra1)
:param sin_dec1: sin(dec1)
:param cos_dec1: cos(dec1)
:param sin_ra2: sin(ra2)
:param cos_ra2: cos(ra2)
:param sin_dec2: sin(dec2)
:param cos_dec2: cos(dec2)
:return: cos(D) where D is the great circle distance
This is the standard great circle distance from e.g. https://mathworld.wolfram.com/GreatCircle.html
The only difference is we also use the identity for cos(x1 - x2) (e.g. https://mathworld.wolfram.com/TrigonometricAdditionFormulas.html)
so that we can calculate the sin and cos terms ahead of time.
"""
cos_distance = sin_dec1 * sin_dec2 + cos_dec1 * cos_dec2 * (cos_ra1 * cos_ra2 + sin_ra1 * sin_ra2)
return cos_distance


class Classification(Base):
__tablename__ = 'classifications'
id = Column(Integer, primary_key=True, autoincrement=True)
ra = Column(Float)
dec = Column(Float)
sin_ra = Column(Float)
cos_ra = Column(Float)
sin_dec = Column(Float)
cos_dec = Column(Float)
T_effective = Column(Float)
log_g = Column(Float)
metallicity = Column(Float)
alpha = Column(Float)
Index('idx_radec', "ra", "dec")
Index('idx_radectrig', "sin_ra", "cos_ra", "sin_dec", "cos_dec")

@hybrid_method
def cos_distance(self, sin_ra, cos_ra, sin_dec, cos_dec):
return cos_great_circle_distance(sin_ra, cos_ra, sin_dec, cos_dec, self.sin_ra, self.cos_ra, self.sin_dec, self.cos_dec)

@cos_distance.expression
def cos_distance(cls, sin_ra, cos_ra, sin_dec, cos_dec):
return cos_great_circle_distance(sin_ra, cos_ra, sin_dec, cos_dec, cls.sin_ra, cls.cos_ra, cls.sin_dec, cls.cos_dec)


class ResourceFile(Base):
__tablename__ = 'resourcefiles'
id = Column(Integer, primary_key=True, autoincrement=True)
Expand Down Expand Up @@ -163,3 +211,25 @@ def get_resource_file(db_address, key):
with banzai.dbs.get_session(db_address=db_address) as db_session:
resource_file = db_session.query(ResourceFile).filter(ResourceFile.key == key).first()
return resource_file


def get_closest_existing_classification(db_address, ra, dec):
with banzai.dbs.get_session(db_address=db_address) as db_session:
# Note the desc here. Because sqlite does not have trig functions, we can't take an arc cos. So we need the
# value when the cos is maximum (which is theta = minimum)
order = [desc(Classification.cos_distance(np.sin(np.deg2rad(ra)), np.cos(np.deg2rad(ra)),
np.sin(np.deg2rad(dec)), np.cos(np.deg2rad(dec))))]
model = db_session.query(Classification).order_by(*order).first()
return model


def save_classification(db_address, frame):
with banzai.dbs.get_session(db_address=db_address) as db_session:
equivalence_criteria = {'ra': frame.ra, 'dec': frame.dec}
record_attributes = {'ra': frame.ra, 'dec': frame.dec, 'sin_ra': np.sin(np.deg2rad(frame.ra)),
'cos_ra': np.cos(np.deg2rad(frame.ra)), 'sin_dec': np.sin(np.deg2rad(frame.dec)),
'cos_dec': np.cos(np.deg2rad(frame.dec)),
'T_effective': frame.classification.T_effective,
'log_g': frame.classification.log_g, 'metallicity': frame.classification.metallicity,
'alpha': frame.classification.alpha}
add_or_update_record(db_session, Classification, equivalence_criteria, record_attributes)
18 changes: 18 additions & 0 deletions banzai_nres/frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,24 @@ def pm_dec(self, value):
# Proper motion is stored in arcseconds/year but we always use it in mas/year
self.primary_hdu.meta['PM-DEC'] = value / 1000.0

@property
def classification(self):
return self._classification

@classification.setter
def classification(self, value):
self._classification = value
if value is not None:
self.meta['TEFF'] = value.T_effective
self.meta['LOG_G'] = value.log_g
self.meta['FE_H'] = value.metallicity
self.meta['ALPHA'] = value.alpha
else:
self.meta['TEFF'] = ''
self.meta['LOG_G'] = ''
self.meta['FE_H'] = ''
self.meta['ALPHA'] = ''


class NRESCalibrationFrame(LCOCalibrationFrame, NRESObservationFrame):
def __init__(self, hdu_list: list, file_path: str, frame_id: int = None, grouping_criteria: list = None,
Expand Down
18 changes: 18 additions & 0 deletions banzai_nres/tests/test_gc_distance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from banzai_nres.dbs import cos_great_circle_distance
from astropy.coordinates import SkyCoord
from astropy import units
import numpy as np


def test_gc_distance():
ra1, dec1 = 150.0, 25.0
ra2, dec2 = 100.0, 10.0
coord1 = SkyCoord(ra1, dec1, unit=(units.deg, units.deg))
coord2 = SkyCoord(ra2, dec2, unit=(units.deg, units.deg))
expected = np.cos(np.deg2rad(coord1.separation(coord2).deg))

actual = cos_great_circle_distance(np.sin(np.deg2rad(ra1)), np.cos(np.deg2rad(ra1)),
np.sin(np.deg2rad(dec1)), np.cos(np.deg2rad(dec1)),
np.sin(np.deg2rad(ra2)), np.cos(np.deg2rad(ra2)),
np.sin(np.deg2rad(dec2)), np.cos(np.deg2rad(dec2)))
np.testing.assert_allclose(actual, expected)
2 changes: 1 addition & 1 deletion banzai_nres/tests/test_rv.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def load(self, *args):
image.instrument = SimpleNamespace()
image.instrument.site = 'npt'
# Classification just can't be None so that the stage does not abort.
image.classification = 'foo'
image.classification = SimpleNamespace(**{'T_effective': 5000.0, 'log_g': 0.0, 'metallicity': 0.0, 'alpha': 0.0})
mock_db.return_value = SimpleNamespace(**site_info)

# Run the RV code
Expand Down

0 comments on commit 5bb1e92

Please sign in to comment.