From ef5171c14498b884a734867d2981972d05e839fe Mon Sep 17 00:00:00 2001 From: Theophile du Laz Date: Wed, 27 Sep 2023 14:34:10 -0700 Subject: [PATCH 1/5] BTS model v1 + extra ML features (#243) * update the BTS model * extra ML features --- config.defaults.yaml | 16 +++---- kowalski/tests/test_alert_broker_ztf.py | 8 ++-- kowalski/utils.py | 55 +++++++++++++++++++------ 3 files changed, 56 insertions(+), 23 deletions(-) diff --git a/config.defaults.yaml b/config.defaults.yaml index a852f40e..4d827a47 100644 --- a/config.defaults.yaml +++ b/config.defaults.yaml @@ -1118,7 +1118,7 @@ kowalski: ml: ZTF: # instruments need to have a list of allowed features (as tuple), and a list of models - allowed_features: ('drb', 'diffmaglim', 'ra', 'dec', 'magpsf', 'sigmapsf', 'chipsf', 'fwhm', 'sky', 'chinr', 'sharpnr', 'sgscore1', 'distpsnr1', 'sgscore2', 'distpsnr2', 'sgscore3', 'distpsnr3', 'ndethist', 'ncovhist', 'scorr', 'nmtchps', 'clrcoeff', 'clrcounc', 'neargaia', 'neargaiabright', 'classtar', 'peakmag', 'age') + allowed_features: ('drb', 'diffmaglim', 'ra', 'dec', 'magpsf', 'sigmapsf', 'chipsf', 'fwhm', 'sky', 'chinr', 'sharpnr', 'sgscore1', 'distpsnr1', 'sgscore2', 'distpsnr2', 'sgscore3', 'distpsnr3', 'ndethist', 'ncovhist', 'scorr', 'nmtchps', 'clrcoeff', 'clrcounc', 'neargaia', 'neargaiabright', 'classtar', 'peakmag_so_far', 'maxmag_so_far', 'days_since_peak', 'days_to_peak', 'nnondet', 'age') models: # models need: a version (string, e.g. 
"v1"), a triplet (bool), and feature_names (bool, or list of feature names as tuple) # if feature_names is True, all features from allowed_features are used @@ -1152,13 +1152,13 @@ kowalski: feature_names: ('drb', 'diffmaglim', 'ra', 'dec', 'magpsf', 'sigmapsf', 'chipsf', 'fwhm', 'sky', 'chinr', 'sharpnr', 'sgscore1', 'distpsnr1', 'sgscore2', 'distpsnr2', 'sgscore3', 'distpsnr3', 'ndethist', 'ncovhist', 'scorr', 'nmtchps', 'clrcoeff', 'clrcounc', 'neargaia', 'neargaiabright') version: "d1_dnn_20201130" url: "https://github.com/dmitryduev/acai/raw/master/models/acai_b.d1_dnn_20201130.h5" - # bts: - # triplet: True - # feature_names: ('sgscore1','distpsnr1','sgscore2','distpsnr2','fwhm','magpsf','sigmapsf','ra','dec','diffmaglim','ndethist','nmtchps','age','peakmag') - # version: "v03" - # format: "pb" - # order: ["triplet", "features"] - # url: "https://raw.githubusercontent.com/nabeelre/BNB-models/main/v03.tar.gz" + bts: + triplet: True + feature_names: ('sgscore1', 'distpsnr1', 'sgscore2', 'distpsnr2', 'fwhm', 'magpsf', 'sigmapsf', 'chipsf', 'ra', 'dec', 'diffmaglim', 'ndethist', 'nmtchps', 'age', 'days_since_peak', 'days_to_peak', 'peakmag_so_far', 'drb', 'ncovhist', 'nnondet', 'chinr', 'sharpnr', 'scorr', 'sky', 'maxmag_so_far') + version: "v1" + format: "pb" + order: ["triplet", "features"] + url: "https://raw.githubusercontent.com/nabeelre/BTSbot/main/production_models/v1.tar.gz" skyportal: diff --git a/kowalski/tests/test_alert_broker_ztf.py b/kowalski/tests/test_alert_broker_ztf.py index 61393c91..22788778 100644 --- a/kowalski/tests/test_alert_broker_ztf.py +++ b/kowalski/tests/test_alert_broker_ztf.py @@ -3,6 +3,7 @@ from kowalski.alert_brokers.alert_broker_ztf import ZTFAlertWorker from kowalski.config import load_config from kowalski.log import log +from copy import deepcopy """ load config and secrets """ config = load_config(config_files=["config.yaml"])["kowalski"] @@ -109,10 +110,11 @@ def test_make_thumbnails(self): def test_alert_filter__ml(self): 
"""Test executing ML models on the alert""" - alert, _ = self.worker.alert_mongify(self.alert) - scores = self.worker.alert_filter__ml(alert) + alert, prv_candidates = self.worker.alert_mongify(self.alert) + all_prv_candidates = deepcopy(prv_candidates) + [deepcopy(alert["candidate"])] + scores = self.worker.alert_filter__ml(alert, all_prv_candidates) assert len(scores) > 0 - log(scores) + # print(scores) def test_alert_filter__xmatch(self): """Test cross matching with external catalog""" diff --git a/kowalski/utils.py b/kowalski/utils.py index b97788c4..221dde12 100644 --- a/kowalski/utils.py +++ b/kowalski/utils.py @@ -1176,19 +1176,48 @@ def __init__(self, alert, alert_history, models, label=None, **kwargs): self.alert = deepcopy(alert) - # add a peakmag field to the alert (min of all magpsf) - self.alert["candidate"]["peakmag"] = min( - [30] - + [ - a.get("magpsf", 30) - for a in alert_history - if a.get("magpsf", None) is not None - ] + # ADD EXTRA FEATURES + peakmag_jd = alert["candidate"]["jd"] + peakmag = 30 + maxmag = 0 + # find the mjd of the peak magnitude + for a in alert_history: + if a.get("magpsf", None) is not None: + if a["magpsf"] < peakmag: + peakmag = a["magpsf"] + peakmag_jd = a["jd"] + if a["magpsf"] > maxmag: + maxmag = a["magpsf"] + + first_alert_jd = min( + alert["candidate"].get("jdstarthist", None), + min( + [alert["candidate"]["jd"]] + + [a["jd"] for a in alert_history if a["magpsf"] is not None] + ), ) - # add an age field to the alert (alert["candidate"].jd - alert["candidate"].jdstarthist) - self.alert["candidate"]["age"] = self.alert["candidate"]["jd"] - self.alert[ - "candidate" - ].get("jdstarthist", self.alert["candidate"]["jd"]) + + # add a peakmag_so_far field to the alert (min of all magpsf) + self.alert["candidate"]["peakmag_so_far"] = peakmag + + # add a maxmag_so_far field to the alert (max of all magpsf) + self.alert["candidate"]["maxmag_so_far"] = maxmag + + # add a days_since_peak field to the alert (jd - peakmag_jd) + 
self.alert["candidate"]["days_since_peak"] = ( + self.alert["candidate"]["jd"] - peakmag_jd + ) + + # add a days_to_peak field to the alert (peakmag_jd - first_alert_jd) + self.alert["candidate"]["days_to_peak"] = peakmag_jd - first_alert_jd + + # add an age field to the alert: (jd - first_alert_jd) + self.alert["candidate"]["age"] = self.alert["candidate"]["jd"] - first_alert_jd + + # number of non-detections: ncovhist - ndethist + self.alert["candidate"]["nnondet"] = alert["candidate"].get( + "ncovhist", 0 + ) - alert["candidate"].get("ndethist", 0) triplet_normalize = kwargs.get("triplet_normalize", True) to_tpu = kwargs.get("to_tpu", False) @@ -1204,6 +1233,8 @@ def __init__(self, alert, alert_history, models, label=None, **kwargs): # "dmdt": dmdt } + del peakmag_jd, peakmag, maxmag, first_alert_jd + def make_triplet(self, normalize: bool = True, to_tpu: bool = False): """ Feed in alert packet From e406925ef1cb26b1d99786ae4ea4f0af9b2ac499 Mon Sep 17 00:00:00 2001 From: Theophile du Laz Date: Wed, 27 Sep 2023 14:34:33 -0700 Subject: [PATCH 2/5] Auto follow-up: cancel if already classified or has spectra (#244) * dont trigger followup requests if there is already classification or spectra for a source * fix duplicate comments, test no followup when classified --- kowalski/alert_brokers/alert_broker.py | 3 + kowalski/tests/test_alert_broker_ztf.py | 79 ++++++++++++++++++++++++- 2 files changed, 81 insertions(+), 1 deletion(-) diff --git a/kowalski/alert_brokers/alert_broker.py b/kowalski/alert_brokers/alert_broker.py index 17a58fb5..6d6cc4be 100644 --- a/kowalski/alert_brokers/alert_broker.py +++ b/kowalski/alert_brokers/alert_broker.py @@ -1336,7 +1336,10 @@ def alert_filter__user_defined( ).strftime("%Y-%m-%dT%H:%M:%S.%f"), # one week validity window }, + # constraints "source_group_ids": [_filter["group_id"]], + "not_if_classified": True, + "not_if_spectra_exist": True, }, } diff --git a/kowalski/tests/test_alert_broker_ztf.py 
b/kowalski/tests/test_alert_broker_ztf.py index 22788778..92740151 100644 --- a/kowalski/tests/test_alert_broker_ztf.py +++ b/kowalski/tests/test_alert_broker_ztf.py @@ -350,7 +350,6 @@ def test_alert_filter__user_defined_followup_with_broker(self): } passed_filters = self.worker.alert_filter__user_defined([filter], self.alert) - delete_alert(self.worker, self.alert) assert passed_filters is not None assert len(passed_filters) == 1 assert "auto_followup" in passed_filters[0] @@ -385,3 +384,81 @@ def test_alert_filter__user_defined_followup_with_broker(self): if (f["allocation_id"] == allocation_id and f["status"] == "submitted") ] assert len(followup_requests) == 0 + + # rerun the first filter, but with the ignore_if_saved_group_id + # this time we are testing that it does not trigger a follow-up request + # if the source is already classified + + # first post a classification + response = self.worker.api_skyportal( + "POST", + "/api/classification", + { + "obj_id": alert["objectId"], + "classification": "Ia", + "probability": 0.8, + "taxonomy_id": 1, + "group_ids": [saved_group_id], + }, + ) + assert response.status_code == 200 + classification_id = response.json()["data"]["classification_id"] + assert classification_id is not None + + # now rerun the filter + filter["group_id"] = saved_group_id + del filter["autosave"]["ignore_group_ids"] + + passed_filters = self.worker.alert_filter__user_defined([filter], self.alert) + assert passed_filters is not None + assert len(passed_filters) == 1 + assert "auto_followup" in passed_filters[0] + assert ( + passed_filters[0]["auto_followup"]["data"]["payload"]["observation_type"] + == "IFU" + ) + + alert, prv_candidates = self.worker.alert_mongify(self.alert) + self.worker.alert_sentinel_skyportal(alert, prv_candidates, passed_filters) + + # now fetch the source from SP + response = self.worker.api_skyportal( + "GET", f"/api/sources/{alert['objectId']}", None + ) + assert response.status_code == 200 + source = 
response.json()["data"] + assert source["id"] == "ZTF20aajcbhr" + assert len(source["groups"]) == 1 + # should only be saved to the group of the first filter + assert source["groups"][0]["id"] == saved_group_id + + # verify that there is a follow-up request + response = self.worker.api_skyportal( + "GET", f"/api/followup_request?sourceID={alert['objectId']}", None + ) + assert response.status_code == 200 + followup_requests = response.json()["data"].get("followup_requests", []) + followup_requests = [ + f + for f in followup_requests + if (f["allocation_id"] == allocation_id and f["status"] == "submitted") + ] + assert len(followup_requests) == 0 + + # delete the classification + response = self.worker.api_skyportal( + "DELETE", f"/api/classification/{classification_id}", None + ) + + # unsave the source from the group + response = self.worker.api_skyportal( + "POST", + "/api/source_groups", + { + "objId": alert["objectId"], + "unsaveGroupIds": [saved_group_id], + }, + ) + assert response.status_code == 200 + + delete_alert(self.worker, self.alert) From 9f4c155d2b44567a296f3dfcc6a1d125ad08d86d Mon Sep 17 00:00:00 2001 From: Theophile du Laz Date: Thu, 28 Sep 2023 15:54:47 -0700 Subject: [PATCH 3/5] Temporary: use all candidates (not just prv_candidates) to generate lightcurve (#249) While we are retrieving data in prod for the prv_candidates, make sure to post all the candidates for an object as detections. 
--- kowalski/alert_brokers/alert_broker.py | 77 +++++++++++++++++++++++++- 1 file changed, 75 insertions(+), 2 deletions(-) diff --git a/kowalski/alert_brokers/alert_broker.py b/kowalski/alert_brokers/alert_broker.py index 6d6cc4be..a9b94cfc 100644 --- a/kowalski/alert_brokers/alert_broker.py +++ b/kowalski/alert_brokers/alert_broker.py @@ -45,6 +45,9 @@ time_stamp, timer, ) +from warnings import simplefilter + +simplefilter(action="ignore", category=pd.errors.PerformanceWarning) # Tensorflow is problematic for Mac's currently, so we can add an option to disable it USE_TENSORFLOW = os.environ.get("USE_TENSORFLOW", True) in [ @@ -1680,6 +1683,37 @@ def alert_sentinel_skyportal(self, alert, prv_candidates, passed_filters): log(e) alert["prv_candidates"] = prv_candidates + # also get all the alerts for this object, to make sure to have all the detections + try: + all_alerts = list( + retry(self.mongo.db[self.collection_alerts].find)( + { + "objectId": alert["objectId"], + "candid": {"$ne": alert["candid"]}, + }, + { + "candidate": 1, + }, + ) + ) + all_alerts = [ + {**a["candidate"]} for a in all_alerts if "candidate" in a + ] + # add to prv_candidates the detections that are not already in there + # use the jd and the fid to match + for a in all_alerts: + if not any( + [ + (a["jd"] == p["jd"]) and (a["fid"] == p["fid"]) + for p in alert["prv_candidates"] + ] + ): + alert["prv_candidates"].append(a) + del all_alerts + except Exception as e: + # this should never happen, but just in case + log(f"Failed to get all alerts for {alert['objectId']}: {e}") + self.alert_put_photometry(alert) # post thumbnails @@ -1768,6 +1802,37 @@ def alert_sentinel_skyportal(self, alert, prv_candidates, passed_filters): # post alert photometry in single call to /api/photometry alert["prv_candidates"] = prv_candidates + # also get all the alerts for this object, to make sure to have all the detections + try: + all_alerts = list( + retry(self.mongo.db[self.collection_alerts].find)( + { + 
"objectId": alert["objectId"], + "candid": {"$ne": alert["candid"]}, + }, + { + "candidate": 1, + }, + ) + ) + all_alerts = [ + {**a["candidate"]} for a in all_alerts if "candidate" in a + ] + # add to prv_candidates the detections that are not already in there + # use the jd and the fid to match + for a in all_alerts: + if not any( + [ + (a["jd"] == p["jd"]) and (a["fid"] == p["fid"]) + for p in alert["prv_candidates"] + ] + ): + alert["prv_candidates"].append(a) + del all_alerts + except Exception as e: + # this should never happen, but just in case + log(f"Failed to get all alerts for {alert['objectId']}: {e}") + self.alert_put_photometry(alert) if len(autosave_group_ids): @@ -1888,14 +1953,22 @@ def alert_sentinel_skyportal(self, alert, prv_candidates, passed_filters): ) if response.json()["status"] != "success": raise ValueError( - response.json()["message"] + response.json().get( + "message", + "unknow error posting comment", + ) ) except Exception as e: log( f"Failed to post followup comment {comment['text']} for {alert['objectId']} to SkyPortal: {e}" ) else: - raise ValueError(response.json()["message"]) + raise ValueError( + response.json().get( + "message", + "unknow error posting followup request", + ) + ) except Exception as e: log( f"Failed to post followup request for {alert['objectId']} to SkyPortal: {e}" From 63c12586f5882606631d97ee42bd7a78ebdbcdf7 Mon Sep 17 00:00:00 2001 From: Theophile du Laz Date: Fri, 29 Sep 2023 11:02:01 -0700 Subject: [PATCH 4/5] ML feature bugfix, remove recent photometry addition due to latency (#251) minor ML feature bugfix, and remove extra photometry block for latency sake --- kowalski/alert_brokers/alert_broker.py | 31 -------------------------- kowalski/utils.py | 2 +- 2 files changed, 1 insertion(+), 32 deletions(-) diff --git a/kowalski/alert_brokers/alert_broker.py b/kowalski/alert_brokers/alert_broker.py index a9b94cfc..c64b8a8a 100644 --- a/kowalski/alert_brokers/alert_broker.py +++ 
b/kowalski/alert_brokers/alert_broker.py @@ -1802,37 +1802,6 @@ def alert_sentinel_skyportal(self, alert, prv_candidates, passed_filters): # post alert photometry in single call to /api/photometry alert["prv_candidates"] = prv_candidates - # also get all the alerts for this object, to make sure to have all the detections - try: - all_alerts = list( - retry(self.mongo.db[self.collection_alerts].find)( - { - "objectId": alert["objectId"], - "candid": {"$ne": alert["candid"]}, - }, - { - "candidate": 1, - }, - ) - ) - all_alerts = [ - {**a["candidate"]} for a in all_alerts if "candidate" in a - ] - # add to prv_candidates the detections that are not already in there - # use the jd and the fid to match - for a in all_alerts: - if not any( - [ - (a["jd"] == p["jd"]) and (a["fid"] == p["fid"]) - for p in alert["prv_candidates"] - ] - ): - alert["prv_candidates"].append(a) - del all_alerts - except Exception as e: - # this should never happen, but just in case - log(f"Failed to get all alerts for {alert['objectId']}: {e}") - self.alert_put_photometry(alert) if len(autosave_group_ids): diff --git a/kowalski/utils.py b/kowalski/utils.py index 221dde12..740864d5 100644 --- a/kowalski/utils.py +++ b/kowalski/utils.py @@ -1193,7 +1193,7 @@ def __init__(self, alert, alert_history, models, label=None, **kwargs): alert["candidate"].get("jdstarthist", None), min( [alert["candidate"]["jd"]] - + [a["jd"] for a in alert_history if a["magpsf"] is not None] + + [a["jd"] for a in alert_history if a.get("magpsf", None) is not None] ), ) From bd886c7951ed2694f99b36532bc8158e12c35814 Mon Sep 17 00:00:00 2001 From: Theophile du Laz Date: Wed, 4 Oct 2023 17:04:04 -0700 Subject: [PATCH 5/5] Update auto follow-up request priority (#252) * update follow-up requests priority, prevent duplicates better, get full alert history for ML features --- kowalski/alert_brokers/alert_broker.py | 131 ++++++++++++++++++--- kowalski/alert_brokers/alert_broker_ztf.py | 13 +- 
kowalski/tests/test_alert_broker_ztf.py | 69 ++++++++++- kowalski/utils.py | 20 ++++ 4 files changed, 211 insertions(+), 22 deletions(-) diff --git a/kowalski/alert_brokers/alert_broker.py b/kowalski/alert_brokers/alert_broker.py index c64b8a8a..3b75138c 100644 --- a/kowalski/alert_brokers/alert_broker.py +++ b/kowalski/alert_brokers/alert_broker.py @@ -44,6 +44,7 @@ retry, time_stamp, timer, + compare_dicts, ) from warnings import simplefilter @@ -1843,7 +1844,14 @@ def alert_sentinel_skyportal(self, alert, prv_candidates, passed_filters): ] if len(passed_filters_followup) > 0: - # first fetch the followup requests on SkyPortal for this alert + # first sort all the filters by priority (highest first) + passed_filters_followup = sorted( + passed_filters_followup, + key=lambda f: f["auto_followup"]["data"]["payload"]["priority"], + reverse=True, + ) + + # then, fetch the existing followup requests on SkyPortal for this alert with timer( f"Getting followup requests for {alert['objectId']} from SkyPortal", self.verbose > 1, @@ -1860,23 +1868,35 @@ def alert_sentinel_skyportal(self, alert, prv_candidates, passed_filters): for r in existing_requests if r["status"] in ["completed", "submitted"] ] + # sort by priority (highest first) + existing_requests = sorted( + existing_requests, + key=lambda r: r["payload"]["priority"], + reverse=True, + ) else: log(f"Failed to get followup requests for {alert['objectId']}") existing_requests = [] + for passed_filter in passed_filters_followup: - # post a followup request with the payload and allocation_id - # if there isn't already a pending request for this alert and this allocation_id - if ( - len( - [ - r - for r in existing_requests - if r["allocation_id"] - == passed_filter["auto_followup"]["allocation_id"] - ] + # look for existing requests with the same allocation, group, and payload + existing_requests_filtered = [ + (i, r) + for (i, r) in enumerate(existing_requests) + if r["allocation_id"] + == 
passed_filter["auto_followup"]["allocation_id"] + and set([passed_filter["group_id"]]).issubset( + [g["id"] for g in r["target_groups"]] ) - == 0 - ): + and compare_dicts( + passed_filter["auto_followup"]["data"]["payload"], + r["payload"], + ignore_keys=["priority", "start_date", "end_date"], + ) + is True + ] + if len(existing_requests_filtered) == 0: + # if no existing request, post a new one with timer( f"Posting auto followup request for {alert['objectId']} to SkyPortal", self.verbose > 1, @@ -1899,6 +1919,24 @@ def alert_sentinel_skyportal(self, alert, prv_candidates, passed_filters): log( f"Posted followup request for {alert['objectId']} to SkyPortal" ) + # add it to the existing requests + existing_requests.append( + { + "allocation_id": passed_filter["auto_followup"][ + "allocation_id" + ], + "payload": passed_filter["auto_followup"][ + "data" + ]["payload"], + "target_groups": [ + { + "id": passed_filter["group_id"], + } + ], + "status": "submitted", + } + ) + if ( passed_filter["auto_followup"].get("comment", None) is not None @@ -1943,6 +1981,67 @@ def alert_sentinel_skyportal(self, alert, prv_candidates, passed_filters): f"Failed to post followup request for {alert['objectId']} to SkyPortal: {e}" ) else: - log( - f"Pending Followup request for {alert['objectId']} and allocation_id {passed_filter['auto_followup']['allocation_id']} already exists on SkyPortal" - ) + # if there is an existing request, but the priority is lower than the one we want to post, + # update the existing request with the new priority + request_to_update = existing_requests_filtered[0][1] + if ( + passed_filter["auto_followup"]["data"]["payload"]["priority"] + > request_to_update["payload"]["priority"] + ): + with timer( + f"Updating priority of auto followup request for {alert['objectId']} to SkyPortal", + self.verbose > 1, + ): + # to update, the api needs to get the request id, target group id, and payload + # so we'll basically get that from the existing request, and simply 
update the priority + try: + data = { + "payload": { + **request_to_update["payload"], + "priority": passed_filter["auto_followup"][ + "data" + ]["payload"]["priority"], + }, + "obj_id": alert["objectId"], + "allocation_id": request_to_update["allocation_id"], + } + response = self.api_skyportal( + "PUT", + f"/api/followup_request/{request_to_update['id']}", + data, + ) + if ( + response.json()["status"] == "success" + and response.json() + .get("data", {}) + .get("ignored", False) + is False + ): + log( + f"Updated priority of followup request for {alert['objectId']} to SkyPortal" + ) + # update the existing_requests list + existing_requests[existing_requests_filtered[0][0]][ + "priority" + ] = passed_filter["auto_followup"]["data"][ + "payload" + ][ + "priority" + ] + + # TODO: post a comment to the source to mention the update + else: + raise ValueError( + response.json().get( + "message", + "unknown error updating followup request", + ) + ) + except Exception as e: + log( + f"Failed to update priority of followup request for {alert['objectId']} to SkyPortal: {e}" + ) + else: + log( + f"Pending Followup request for {alert['objectId']} and allocation_id {passed_filter['auto_followup']['allocation_id']} already exists on SkyPortal, no need for update" + ) diff --git a/kowalski/alert_brokers/alert_broker_ztf.py b/kowalski/alert_brokers/alert_broker_ztf.py index a593d3fa..d355d6ad 100644 --- a/kowalski/alert_brokers/alert_broker_ztf.py +++ b/kowalski/alert_brokers/alert_broker_ztf.py @@ -75,7 +75,18 @@ def process_alert(alert: Mapping, topic: str): and len(existing_aux.get("prv_candidates", [])) > 0 ): all_prv_candidates += existing_aux["prv_candidates"] - del existing_aux + + # get all alerts for this objectId: + existing_alerts = list( + alert_worker.mongo.db[alert_worker.collection_alerts].find( + {"objectId": object_id}, {"candidate": 1} + ) + ) + if len(existing_alerts) > 0: + all_prv_candidates += [ + existing_alert["candidate"] for existing_alert in 
existing_alerts + ] + del existing_aux, existing_alerts # ML models: with timer(f"MLing of {object_id} {candid}", alert_worker.verbose > 1): diff --git a/kowalski/tests/test_alert_broker_ztf.py b/kowalski/tests/test_alert_broker_ztf.py index 92740151..2e2bcd50 100644 --- a/kowalski/tests/test_alert_broker_ztf.py +++ b/kowalski/tests/test_alert_broker_ztf.py @@ -298,24 +298,36 @@ def test_alert_filter__user_defined_followup_with_broker(self): "allocation_id": allocation_id, "payload": { # example payload for SEDM "observation_type": "IFU", - "priority": 3, + "priority": 2, }, } - passed_filters = self.worker.alert_filter__user_defined([filter], self.alert) + # make a copy of that filter, but with priority 3 + filter2 = deepcopy(filter) + filter2["auto_followup"]["payload"]["priority"] = 3 + passed_filters = self.worker.alert_filter__user_defined( + [filter, filter2], self.alert + ) assert passed_filters is not None - assert len(passed_filters) == 1 + assert len(passed_filters) == 2 # both filters should have passed assert "auto_followup" in passed_filters[0] assert ( passed_filters[0]["auto_followup"]["data"]["payload"]["observation_type"] == "IFU" ) - assert passed_filters[0]["auto_followup"]["data"]["payload"]["priority"] == 3 + assert passed_filters[0]["auto_followup"]["data"]["payload"]["priority"] == 2 + assert "auto_followup" in passed_filters[1] + assert ( + passed_filters[1]["auto_followup"]["data"]["payload"]["observation_type"] + == "IFU" + ) + assert passed_filters[1]["auto_followup"]["data"]["payload"]["priority"] == 3 alert, prv_candidates = self.worker.alert_mongify(self.alert) self.worker.alert_sentinel_skyportal(alert, prv_candidates, passed_filters) # now fetch the follow-up request from SP + # it should have deduplicated and used the highest priority response = self.worker.api_skyportal( "GET", f"/api/followup_request?sourceID={alert['objectId']}", None ) @@ -328,7 +340,54 @@ def test_alert_filter__user_defined_followup_with_broker(self): ] assert 
len(followup_requests) == 1 assert followup_requests[0]["payload"]["observation_type"] == "IFU" - assert followup_requests[0]["payload"]["priority"] == 3 + assert ( + followup_requests[0]["payload"]["priority"] == 3 + ) # it should have deduplicated and used the highest priority + + # now run it once more, but with a higher priority to see if the update works + filter2["auto_followup"]["payload"]["priority"] = 4 + passed_filters = self.worker.alert_filter__user_defined( + [filter, filter2], self.alert + ) + + assert passed_filters is not None + assert len(passed_filters) == 2 + assert "auto_followup" in passed_filters[0] + assert ( + passed_filters[0]["auto_followup"]["data"]["payload"]["observation_type"] + == "IFU" + ) + assert passed_filters[0]["auto_followup"]["data"]["payload"]["priority"] == 2 + assert "auto_followup" in passed_filters[1] + assert ( + passed_filters[1]["auto_followup"]["data"]["payload"]["observation_type"] + == "IFU" + ) + assert passed_filters[1]["auto_followup"]["data"]["payload"]["priority"] == 4 + + alert, prv_candidates = self.worker.alert_mongify(self.alert) + self.worker.alert_sentinel_skyportal(alert, prv_candidates, passed_filters) + + # now fetch the follow-up request from SP + # it should have deduplicated and used the highest priority + response = self.worker.api_skyportal( + "GET", f"/api/followup_request?sourceID={alert['objectId']}", None + ) + assert response.status_code == 200 + followup_requests_updated = response.json()["data"].get("followup_requests", []) + followup_requests_updated = [ + f + for f in followup_requests_updated + if (f["allocation_id"] == allocation_id and f["status"] == "submitted") + ] + assert len(followup_requests_updated) == 1 + assert followup_requests_updated[0]["payload"]["observation_type"] == "IFU" + assert ( + followup_requests_updated[0]["payload"]["priority"] == 4 + ) # it should have deduplicated and used the highest priority + assert ( + followup_requests_updated[0]["id"] == 
followup_requests[0]["id"] + ) # the id should be the same # delete the follow-up request response = self.worker.api_skyportal( diff --git a/kowalski/utils.py b/kowalski/utils.py index 740864d5..4f018cd9 100644 --- a/kowalski/utils.py +++ b/kowalski/utils.py @@ -1168,6 +1168,26 @@ def str_to_numeric(s): return float(s) +def compare_dicts(a: dict, b: dict, ignore_keys=[], same_keys=False): + """Compare two followup payloads, making sure that a is the same as b or a subset of b, ignoring certain keys""" + if same_keys and len(a) != len(b): + return False + for k, v in a.items(): + if k in ignore_keys: + continue + if k not in b: + return False + if isinstance(v, dict): + if not compare_dicts(v, b[k]): + return False + elif isinstance(v, list): + if not all([i in b[k] for i in v]): + return False + elif b[k] != v: + return False + return True + + class ZTFAlert: def __init__(self, alert, alert_history, models, label=None, **kwargs): self.kwargs = kwargs