Merge main to branch (#250)
* BTS model v1 + extra ML features (#243)

* update the BTS model
* extra ML features

* Auto follow-up: cancel if already classified or has spectra (#244)

* don't trigger follow-up requests if there is already a classification or spectra for the source

* fix duplicate comments, test no followup when classified

* Temporary: use all candidates (not just prv_candidates) to generate lightcurve (#249)

While we are retrieving the prv_candidates data in prod, make sure to post all of an object's candidates as detections.
Theodlz authored Sep 28, 2023
1 parent bcfda33 commit 5bb09e9
Showing 4 changed files with 212 additions and 26 deletions.
16 changes: 8 additions & 8 deletions config.defaults.yaml
@@ -1118,7 +1118,7 @@ kowalski:
ml:
ZTF:
# instruments need to have a list of allowed features (as tuple), and a list of models
- allowed_features: ('drb', 'diffmaglim', 'ra', 'dec', 'magpsf', 'sigmapsf', 'chipsf', 'fwhm', 'sky', 'chinr', 'sharpnr', 'sgscore1', 'distpsnr1', 'sgscore2', 'distpsnr2', 'sgscore3', 'distpsnr3', 'ndethist', 'ncovhist', 'scorr', 'nmtchps', 'clrcoeff', 'clrcounc', 'neargaia', 'neargaiabright', 'classtar', 'peakmag', 'age')
+ allowed_features: ('drb', 'diffmaglim', 'ra', 'dec', 'magpsf', 'sigmapsf', 'chipsf', 'fwhm', 'sky', 'chinr', 'sharpnr', 'sgscore1', 'distpsnr1', 'sgscore2', 'distpsnr2', 'sgscore3', 'distpsnr3', 'ndethist', 'ncovhist', 'scorr', 'nmtchps', 'clrcoeff', 'clrcounc', 'neargaia', 'neargaiabright', 'classtar', 'peakmag_so_far', 'maxmag_so_far', 'days_since_peak', 'days_to_peak', 'nnondet', 'age')
models:
# models need: a version (string, e.g. "v1"), a triplet (bool), and feature_names (bool, or list of feature names as tuple)
# if feature_names is True, all features from allowed_features are used
@@ -1152,13 +1152,13 @@ kowalski:
feature_names: ('drb', 'diffmaglim', 'ra', 'dec', 'magpsf', 'sigmapsf', 'chipsf', 'fwhm', 'sky', 'chinr', 'sharpnr', 'sgscore1', 'distpsnr1', 'sgscore2', 'distpsnr2', 'sgscore3', 'distpsnr3', 'ndethist', 'ncovhist', 'scorr', 'nmtchps', 'clrcoeff', 'clrcounc', 'neargaia', 'neargaiabright')
version: "d1_dnn_20201130"
url: "https://github.com/dmitryduev/acai/raw/master/models/acai_b.d1_dnn_20201130.h5"
- # bts:
- #   triplet: True
- #   feature_names: ('sgscore1','distpsnr1','sgscore2','distpsnr2','fwhm','magpsf','sigmapsf','ra','dec','diffmaglim','ndethist','nmtchps','age','peakmag')
- #   version: "v03"
- #   format: "pb"
- #   order: ["triplet", "features"]
- #   url: "https://raw.githubusercontent.com/nabeelre/BNB-models/main/v03.tar.gz"
+ bts:
+   triplet: True
+   feature_names: ('sgscore1', 'distpsnr1', 'sgscore2', 'distpsnr2', 'fwhm', 'magpsf', 'sigmapsf', 'chipsf', 'ra', 'dec', 'diffmaglim', 'ndethist', 'nmtchps', 'age', 'days_since_peak', 'days_to_peak', 'peakmag_so_far', 'drb', 'ncovhist', 'nnondet', 'chinr', 'sharpnr', 'scorr', 'sky', 'maxmag_so_far')
+   version: "v1"
+   format: "pb"
+   order: ["triplet", "features"]
+   url: "https://raw.githubusercontent.com/nabeelre/BTSbot/main/production_models/v1.tar.gz"


skyportal:
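Since YAML parses the feature tuples above as plain strings, they have to be evaluated before use. A minimal sketch of reading and validating the new `bts` entry (the parsing shown here is an illustrative assumption, not Kowalski's actual config loader):

```python
import ast

import yaml

# load the config shown above
with open("config.defaults.yaml") as f:
    ml_cfg = yaml.safe_load(f)["kowalski"]["ml"]["ZTF"]

# allowed_features is a Python tuple literal serialized as a YAML string
allowed = set(ast.literal_eval(ml_cfg["allowed_features"]))

bts = ml_cfg["models"]["bts"]
feature_names = ast.literal_eval(bts["feature_names"])

# every feature a model consumes must be declared in allowed_features
missing = set(feature_names) - allowed
assert not missing, f"undeclared features: {missing}"
```

With the expanded `allowed_features` tuple, all 25 `bts` features pass this check.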
80 changes: 78 additions & 2 deletions kowalski/alert_brokers/alert_broker.py
@@ -45,6 +45,9 @@
time_stamp,
timer,
)
from warnings import simplefilter

simplefilter(action="ignore", category=pd.errors.PerformanceWarning)

# Tensorflow is problematic on Macs currently, so we can add an option to disable it
USE_TENSORFLOW = os.environ.get("USE_TENSORFLOW", True) in [
@@ -1336,7 +1339,10 @@ def alert_filter__user_defined(
).strftime("%Y-%m-%dT%H:%M:%S.%f"),
# one week validity window
},
# constraints
"source_group_ids": [_filter["group_id"]],
"not_if_classified": True,
"not_if_spectra_exist": True,
},
}
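A minimal sketch of the guard the two new flags imply on the consuming side (the function and the `source` dict shape are illustrative assumptions, not SkyPortal's actual API):

```python
def should_submit_followup(source: dict, payload: dict) -> bool:
    """Skip the auto follow-up if the source already has what we would chase."""
    if payload.get("not_if_classified") and source.get("classifications"):
        return False  # already classified
    if payload.get("not_if_spectra_exist") and source.get("spectra"):
        return False  # spectra already taken
    return True
```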

@@ -1677,6 +1683,37 @@ def alert_sentinel_skyportal(self, alert, prv_candidates, passed_filters):
log(e)
alert["prv_candidates"] = prv_candidates

# also fetch all the other alerts for this object, to make sure we have all the detections
try:
all_alerts = list(
retry(self.mongo.db[self.collection_alerts].find)(
{
"objectId": alert["objectId"],
"candid": {"$ne": alert["candid"]},
},
{
"candidate": 1,
},
)
)
all_alerts = [
{**a["candidate"]} for a in all_alerts if "candidate" in a
]
# add to prv_candidates the detections that are not already in there
# use the jd and the fid to match
for a in all_alerts:
if not any(
[
(a["jd"] == p["jd"]) and (a["fid"] == p["fid"])
for p in alert["prv_candidates"]
]
):
alert["prv_candidates"].append(a)
del all_alerts
except Exception as e:
# this should never happen, but just in case
log(f"Failed to get all alerts for {alert['objectId']}: {e}")

self.alert_put_photometry(alert)

# post thumbnails
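The membership check above rescans `alert["prv_candidates"]` once per archived alert, which is quadratic in the history length. An equivalent linear-time sketch keyed on `(jd, fid)` pairs (same dedupe logic, not the committed implementation):

```python
# dedupe detections by (jd, fid) with a set instead of a nested scan
seen = {(p["jd"], p["fid"]) for p in alert["prv_candidates"]}
for a in all_alerts:
    key = (a["jd"], a["fid"])
    if key not in seen:
        alert["prv_candidates"].append(a)
        seen.add(key)
```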
Expand Down Expand Up @@ -1765,6 +1802,37 @@ def alert_sentinel_skyportal(self, alert, prv_candidates, passed_filters):
# post alert photometry in single call to /api/photometry
alert["prv_candidates"] = prv_candidates

# also fetch all the other alerts for this object, to make sure we have all the detections
try:
all_alerts = list(
retry(self.mongo.db[self.collection_alerts].find)(
{
"objectId": alert["objectId"],
"candid": {"$ne": alert["candid"]},
},
{
"candidate": 1,
},
)
)
all_alerts = [
{**a["candidate"]} for a in all_alerts if "candidate" in a
]
# add to prv_candidates the detections that are not already in there
# use the jd and the fid to match
for a in all_alerts:
if not any(
[
(a["jd"] == p["jd"]) and (a["fid"] == p["fid"])
for p in alert["prv_candidates"]
]
):
alert["prv_candidates"].append(a)
del all_alerts
except Exception as e:
# this should never happen, but just in case
log(f"Failed to get all alerts for {alert['objectId']}: {e}")

self.alert_put_photometry(alert)

if len(autosave_group_ids):
@@ -1885,14 +1953,22 @@ def alert_sentinel_skyportal(self, alert, prv_candidates, passed_filters):
)
if response.json()["status"] != "success":
raise ValueError(
- response.json()["message"]
+ response.json().get(
+     "message",
+     "unknown error posting comment",
+ )
)
except Exception as e:
log(
f"Failed to post followup comment {comment['text']} for {alert['objectId']} to SkyPortal: {e}"
)
else:
- raise ValueError(response.json()["message"])
+ raise ValueError(
+     response.json().get(
+         "message",
+         "unknown error posting followup request",
+     )
+ )
except Exception as e:
log(
f"Failed to post followup request for {alert['objectId']} to SkyPortal: {e}"
87 changes: 83 additions & 4 deletions kowalski/tests/test_alert_broker_ztf.py
@@ -3,6 +3,7 @@
from kowalski.alert_brokers.alert_broker_ztf import ZTFAlertWorker
from kowalski.config import load_config
from kowalski.log import log
from copy import deepcopy

""" load config and secrets """
config = load_config(config_files=["config.yaml"])["kowalski"]
@@ -109,10 +110,11 @@ def test_make_thumbnails(self):

def test_alert_filter__ml(self):
"""Test executing ML models on the alert"""
- alert, _ = self.worker.alert_mongify(self.alert)
- scores = self.worker.alert_filter__ml(alert)
+ alert, prv_candidates = self.worker.alert_mongify(self.alert)
+ all_prv_candidates = deepcopy(prv_candidates) + [deepcopy(alert["candidate"])]
+ scores = self.worker.alert_filter__ml(alert, all_prv_candidates)
  assert len(scores) > 0
- log(scores)
+ # print(scores)

def test_alert_filter__xmatch(self):
"""Test cross matching with external catalog"""
@@ -348,7 +350,6 @@ def test_alert_filter__user_defined_followup_with_broker(self):
}

passed_filters = self.worker.alert_filter__user_defined([filter], self.alert)
- delete_alert(self.worker, self.alert)
assert passed_filters is not None
assert len(passed_filters) == 1
assert "auto_followup" in passed_filters[0]
@@ -383,3 +384,81 @@
if (f["allocation_id"] == allocation_id and f["status"] == "submitted")
]
assert len(followup_requests) == 0

# rerun the first filter, now without the ignore_group_ids constraint;
# this time we are testing that it does not trigger a follow-up request
# if the source is already classified

# first post a classification
response = self.worker.api_skyportal(
"POST",
"/api/classification",
{
"obj_id": alert["objectId"],
"classification": "Ia",
"probability": 0.8,
"taxonomy_id": 1,
"group_ids": [saved_group_id],
},
)
assert response.status_code == 200
classification_id = response.json()["data"]["classification_id"]
assert classification_id is not None

# now rerun the filter
filter["group_id"] = saved_group_id
del filter["autosave"]["ignore_group_ids"]

passed_filters = self.worker.alert_filter__user_defined([filter], self.alert)
assert passed_filters is not None
assert len(passed_filters) == 1
assert "auto_followup" in passed_filters[0]
assert (
passed_filters[0]["auto_followup"]["data"]["payload"]["observation_type"]
== "IFU"
)

alert, prv_candidates = self.worker.alert_mongify(self.alert)
self.worker.alert_sentinel_skyportal(alert, prv_candidates, passed_filters)

# now fetch the source from SP
response = self.worker.api_skyportal(
"GET", f"/api/sources/{alert['objectId']}", None
)
assert response.status_code == 200
source = response.json()["data"]
assert source["id"] == "ZTF20aajcbhr"
assert len(source["groups"]) == 1
# should only be saved to the group of the first filter
assert source["groups"][0]["id"] == saved_group_id

# verify that there is a follow-up request
response = self.worker.api_skyportal(
"GET", f"/api/followup_request?sourceID={alert['objectId']}", None
)
assert response.status_code == 200
followup_requests = response.json()["data"].get("followup_requests", [])
followup_requests = [
f
for f in followup_requests
if (f["allocation_id"] == allocation_id and f["status"] == "submitted")
]
assert len(followup_requests) == 0

# delete the classification
response = self.worker.api_skyportal(
"DELETE", f"/api/classification/{classification_id}", None
)

# unsave the source from the group
response = self.worker.api_skyportal(
"POST",
"/api/source_groups",
{
"objId": alert["objectId"],
"unsaveGroupIds": [saved_group_id],
},
)
assert response.status_code == 200

delete_alert(self.worker, self.alert)
55 changes: 43 additions & 12 deletions kowalski/utils.py
@@ -1176,19 +1176,48 @@ def __init__(self, alert, alert_history, models, label=None, **kwargs):

self.alert = deepcopy(alert)

- # add a peakmag field to the alert (min of all magpsf)
- self.alert["candidate"]["peakmag"] = min(
-     [30]
-     + [
-         a.get("magpsf", 30)
-         for a in alert_history
-         if a.get("magpsf", None) is not None
-     ]
+ # ADD EXTRA FEATURES
+ peakmag_jd = alert["candidate"]["jd"]
+ peakmag = 30
+ maxmag = 0
+ # find the jd of the peak magnitude
+ for a in alert_history:
+     if a.get("magpsf", None) is not None:
+         if a["magpsf"] < peakmag:
+             peakmag = a["magpsf"]
+             peakmag_jd = a["jd"]
+         if a["magpsf"] > maxmag:
+             maxmag = a["magpsf"]
+
+ first_alert_jd = min(
+     alert["candidate"].get("jdstarthist") or alert["candidate"]["jd"],
+     min(
+         [alert["candidate"]["jd"]]
+         + [a["jd"] for a in alert_history if a.get("magpsf") is not None]
+     ),
  )
- # add an age field to the alert (alert["candidate"].jd - alert["candidate"].jdstarthist)
- self.alert["candidate"]["age"] = self.alert["candidate"]["jd"] - self.alert[
-     "candidate"
- ].get("jdstarthist", self.alert["candidate"]["jd"])

+ # add a peakmag_so_far field to the alert (min of all magpsf)
+ self.alert["candidate"]["peakmag_so_far"] = peakmag
+
+ # add a maxmag_so_far field to the alert (max of all magpsf)
+ self.alert["candidate"]["maxmag_so_far"] = maxmag
+
+ # add a days_since_peak field to the alert (jd - peakmag_jd)
+ self.alert["candidate"]["days_since_peak"] = (
+     self.alert["candidate"]["jd"] - peakmag_jd
+ )
+
+ # add a days_to_peak field to the alert (peakmag_jd - first_alert_jd)
+ self.alert["candidate"]["days_to_peak"] = peakmag_jd - first_alert_jd
+
+ # add an age field to the alert: (jd - first_alert_jd)
+ self.alert["candidate"]["age"] = self.alert["candidate"]["jd"] - first_alert_jd
+
+ # number of non-detections: ncovhist - ndethist
+ self.alert["candidate"]["nnondet"] = alert["candidate"].get(
+     "ncovhist", 0
+ ) - alert["candidate"].get("ndethist", 0)
+
triplet_normalize = kwargs.get("triplet_normalize", True)
to_tpu = kwargs.get("to_tpu", False)
@@ -1204,6 +1233,8 @@ def __init__(self, alert, alert_history, models, label=None, **kwargs):
# "dmdt": dmdt
}

del peakmag_jd, peakmag, maxmag, first_alert_jd

def make_triplet(self, normalize: bool = True, to_tpu: bool = False):
"""
Feed in alert packet
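For reference, a toy walk-through of the derived features on a made-up three-detection history (the values and the trimmed-down `alert` dict are illustrative only):

```python
alert = {
    "candidate": {"jd": 2460100.0, "jdstarthist": 2460090.0, "ncovhist": 25, "ndethist": 5}
}
alert_history = [
    {"jd": 2460090.0, "fid": 1, "magpsf": 19.5},
    {"jd": 2460095.0, "fid": 1, "magpsf": 18.2},  # brightest (lowest magnitude)
    {"jd": 2460098.0, "fid": 2, "magpsf": 18.9},
]

peak = min(alert_history, key=lambda a: a["magpsf"])
cand = alert["candidate"]
features = {
    "peakmag_so_far": peak["magpsf"],                          # 18.2
    "maxmag_so_far": max(a["magpsf"] for a in alert_history),  # 19.5 (faintest detection)
    "days_since_peak": cand["jd"] - peak["jd"],                # 5.0
    "days_to_peak": peak["jd"] - cand["jdstarthist"],          # 5.0
    "age": cand["jd"] - cand["jdstarthist"],                   # 10.0
    "nnondet": cand["ncovhist"] - cand["ndethist"],            # 20
}
```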
