Commit

Moved functions in unit tests and cleaned up packageAlerts
bsmartradio committed Feb 22, 2024
1 parent 60552ea · commit 68516ad
Showing 3 changed files with 49 additions and 75 deletions.
python/lsst/ap/association/diaPipe.py (16 changes: 6 additions & 10 deletions)
@@ -541,16 +541,12 @@ def run(self,
                     ["diaObjectId", "diaForcedSourceId"],
                     drop=False,
                     inplace=True)
-            try:
-                self.alertPackager.run(associatedDiaSources,
-                                       diaCalResult.diaObjectCat,
-                                       loaderResult.diaSources,
-                                       diaForcedSources,
-                                       diffIm,
-                                       template)
-            except ValueError as err:
-                # Continue processing even if alert sending fails
-                self.log.error(err)
+            self.alertPackager.run(associatedDiaSources,
+                                   diaCalResult.diaObjectCat,
+                                   loaderResult.diaSources,
+                                   diaForcedSources,
+                                   diffIm,
+                                   template)
 
         return pipeBase.Struct(apdbMarker=self.config.apdb.value,
                                associatedDiaSources=associatedDiaSources,
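
With the try/except removed, a ValueError raised during alert packaging now propagates out of run() instead of being logged and swallowed. A minimal sketch of caller-side handling that would restore the old continue-on-failure behavior (the wrapper and its logger are hypothetical, not part of this commit):

    import logging

    log = logging.getLogger("ap.association.caller")  # hypothetical logger

    def run_tolerating_alert_failures(task, **run_kwargs):
        # Call DiaPipelineTask.run, logging alert-packaging failures
        # instead of letting them abort processing (the old behavior).
        try:
            return task.run(**run_kwargs)
        except ValueError as err:
            log.error("Alert packaging failed: %s", err)
            return None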
python/lsst/ap/association/packageAlerts.py (37 changes: 22 additions & 15 deletions)
@@ -31,6 +31,8 @@
 import pandas as pd
 import struct
 import fastavro
+# confluent_kafka is not in the standard Rubin environment as it is a third
+# party package and is only needed when producing alerts.
 try:
     import confluent_kafka
     from confluent_kafka import KafkaException
@@ -94,20 +96,17 @@ def __init__(self, **kwargs):
             os.makedirs(self.config.alertWriteLocation, exist_ok=True)
 
         if self.config.doProduceAlerts:
-
             if confluent_kafka is not None:
-
                 self.password = os.getenv("AP_KAFKA_PRODUCER_PASSWORD")
-                self.username = os.getenv("AP_KAFKA_PRODUCER_USERNAME")
-                self.server = os.getenv("AP_KAFKA_SERVER")
-                self.kafkaTopic = os.getenv("AP_KAFKA_TOPIC")
-
                 if not self.password:
                     raise ValueError("Kafka password environment variable was not set.")
+                self.username = os.getenv("AP_KAFKA_PRODUCER_USERNAME")
                 if not self.username:
                     raise ValueError("Kafka username environment variable was not set.")
+                self.server = os.getenv("AP_KAFKA_SERVER")
                 if not self.server:
                     raise ValueError("Kafka server environment variable was not set.")
+                self.kafkaTopic = os.getenv("AP_KAFKA_TOPIC")
                 if not self.kafkaTopic:
                     raise ValueError("Kafka topic environment variable was not set.")
 
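Each getenv call is now paired directly with its check, so a missing variable fails fast with a specific message. The same pattern, condensed into a hypothetical helper for illustration (not part of this commit):

    import os

    def require_env(name):
        # Return the named environment variable's value, or raise if unset.
        value = os.getenv(name)
        if not value:
            raise ValueError(f"{name} environment variable was not set.")
        return value

    # Usage mirroring __init__ above:
    # password = require_env("AP_KAFKA_PRODUCER_PASSWORD")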
@@ -132,12 +131,11 @@ def __init__(self, **kwargs):
                     "batch.size": 2097152,
                     "linger.ms": 5,
                 }
-
                 self.producer = confluent_kafka.Producer(**self.kafkaConfig)
 
             else:
-                self.log.error("Produce alerts is set but confluent_kafka is not present in "
-                               "the environment. Alerts will not be sent to the alert stream.")
+                raise RuntimeError("Produce alerts is set but confluent_kafka is not present in "
+                                   "the environment. Alerts will not be sent to the alert stream.")
 
     @timeMethod
     def run(self,
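
For reference, a standalone sketch of constructing the producer the way __init__ does. The connection values below are placeholders and the SASL settings are assumptions; only batch.size and linger.ms are visible in the hunk above:

    import confluent_kafka

    kafka_config = {
        # Placeholders: the real values come from the AP_KAFKA_*
        # environment variables validated earlier in __init__.
        "bootstrap.servers": "kafka.example.org:9092",
        "security.protocol": "SASL_PLAINTEXT",  # assumed
        "sasl.mechanism": "SCRAM-SHA-512",      # assumed
        "sasl.username": "alert-producer",
        "sasl.password": "not-a-real-password",
        # Visible in the diff:
        "batch.size": 2097152,
        "linger.ms": 5,
    }
    producer = confluent_kafka.Producer(**kafka_config)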
@@ -224,7 +222,7 @@ def run(self,
                                            diffImCutout,
                                            templateCutout))
 
-        if self.config.doProduceAlerts and confluent_kafka is not None:
+        if self.config.doProduceAlerts:
             self.produceAlerts(alerts, ccdVisitId)
 
         if self.config.doPackageAlerts:
@@ -242,7 +240,7 @@ def _patchDiaSources(self, diaSources):
         diaSources["programId"] = 0
 
     def createDiaSourceExtent(self, bboxSize):
-        """Create an extent for a box for the cutouts given the size of the 
+        """Create an extent for a box for the cutouts given the size of the
         square BBox that covers the source footprint.
 
         Parameters
@@ -263,7 +261,17 @@ def createDiaSourceExtent(self, bboxSize):
         return extent
 
     def produceAlerts(self, alerts, ccdVisitId):
+        """Serialize alerts and send them to the alert stream using
+        confluent_kafka's producer.
+
+        Parameters
+        ----------
+        alerts : `dict`
+            Dictionary of alerts to be sent to the alert stream.
+        ccdVisitId : `int`
+            ccdVisitId of the alerts sent to the alert stream. Used to write
+            out alerts which fail to be sent to the alert stream.
+        """
         for alert in alerts:
             alertBytes = self._serializeAlert(alert, schema=self.alertSchema.definition, schema_id=1)
             try:
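
The loop that follows the new docstring (truncated here) serializes each alert and hands it to Kafka. A minimal sketch of that produce-and-flush pattern, assuming an already-configured producer, topic, delivery callback, and logger; the payloads stand in for the bytes returned by _serializeAlert:

    from confluent_kafka import KafkaException

    def send_alerts(producer, topic, payloads, on_delivery, log):
        # Produce each serialized alert, logging failures, then block
        # until all queued messages are delivered or have errored out.
        for payload in payloads:
            try:
                producer.produce(topic, payload, callback=on_delivery)
            except KafkaException as err:
                log.warning("Kafka produce failed: %s", err)
        producer.flush()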
@@ -472,7 +480,7 @@ def _serializeAlert(self, alert, schema=None, schema_id=0):
         schema_id : `int`, optional
             The Confluent Schema Registry ID of the schema. By default, 0 (an
             invalid ID) is used, indicating that the schema is not registered.
-        `
+
         Returns
         -------
         serialized : `bytes`
@@ -514,7 +522,6 @@ def _serializeConfluentWireHeader(schema_version):
 
     def _delivery_callback(self, err, msg):
         if err:
-            self.log.debug('%% Message failed delivery: %s\n' % err)
+            self.log.warning('Message failed delivery: %s\n' % err)
         else:
-            self.log.debug('%% Message delivered to %s [%d] @ %d\n'
-                           % (msg.topic(), msg.partition(), msg.offset()))
+            self.log.debug('Message delivered to %s [%d] @ %d', msg.topic(), msg.partition(), msg.offset())
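
The Confluent Wire Format prefix that _serializeConfluentWireHeader writes (and that the test helper removed below used to unpack) is one magic byte of 0 followed by a 4-byte big-endian schema ID, five bytes in total. A self-contained round trip of just the header:

    import struct

    ConfluentWireFormatHeader = struct.Struct(">bi")  # magic byte + schema ID

    def pack_wire_header(schema_version):
        return ConfluentWireFormatHeader.pack(0, schema_version)

    def unpack_wire_header(raw):
        _, version = ConfluentWireFormatHeader.unpack(raw)
        return version

    header = pack_wire_header(1)
    assert len(header) == 5
    assert unpack_wire_header(header) == 1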
tests/test_packageAlerts.py (71 changes: 21 additions & 50 deletions)
@@ -21,7 +21,6 @@
 
 import io
 import os
-import struct
 
 import numpy as np
 import pandas as pd
@@ -134,10 +133,27 @@ def mock_alert(alert_id):
     }
 
 
-class TestPackageAlerts(lsst.utils.tests.TestCase):
+def _deserialize_alert(alert_bytes):
+    """Deserialize an alert message from Kafka.
+
+    Parameters
+    ----------
+    alert_bytes : `bytes`
+        Binary-encoding serialized Avro alert, including Confluent Wire
+        Format prefix.
+
+    Returns
+    -------
+    alert : `dict`
+        An alert payload.
+    """
+    schema = alertPack.Schema.from_uri(str(alertPack.get_uri_to_latest_schema()))
+    content_bytes = io.BytesIO(alert_bytes[5:])
 
-    def __init__(self, *args, **kwargs):
-        super(TestPackageAlerts, self).__init__(*args, **kwargs)
+    return fastavro.schemaless_reader(content_bytes, schema.definition)
+
+
+class TestPackageAlerts(lsst.utils.tests.TestCase):
 
     def setUp(self):
         patcher = patch.dict(os.environ, {"AP_KAFKA_PRODUCER_PASSWORD": "fake_password",
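
The relocated _deserialize_alert no longer checks the header version; it simply skips the 5-byte Confluent prefix and decodes the remainder. A self-contained illustration of that strip-then-decode pattern, using a toy schema in place of the Rubin alert schema:

    import io
    import struct

    import fastavro

    toy_schema = {"type": "record", "name": "Toy",
                  "fields": [{"name": "id", "type": "long"}]}

    buf = io.BytesIO()
    fastavro.schemaless_writer(buf, toy_schema, {"id": 42})
    message = struct.pack(">bi", 0, 1) + buf.getvalue()  # prepend wire header

    record = fastavro.schemaless_reader(io.BytesIO(message[5:]), toy_schema)
    assert record == {"id": 42}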
@@ -204,49 +220,6 @@ def setUp(self):
         self.cutoutWcs.wcs.cd = self.exposure.getWcs().getCdMatrix()
         self.cutoutWcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]
 
-    def _deserialize_alert(self, alert_bytes):
-        """Deserialize an alert message from Kafka.
-
-        Parameters
-        ----------
-        alert_bytes : `bytes`
-            Binary-encoding serialized Avro alert, including Confluent Wire
-            Format prefix.
-
-        Returns
-        -------
-        alert : `dict`
-            An alert payload.
-        """
-        schema = alertPack.Schema.from_uri(
-            str(alertPack.get_uri_to_latest_schema()))
-
-        header_bytes = alert_bytes[:5]
-        version = self._deserialize_confluent_wire_header(header_bytes)
-        self.assertEqual(version, 0)
-        content_bytes = io.BytesIO(alert_bytes[5:])
-        return fastavro.schemaless_reader(content_bytes, schema.definition)
-
-    @staticmethod
-    def _deserialize_confluent_wire_header(raw):
-        """Parses the byte prefix for Confluent Wire Format-style Kafka messages.
-
-        Parameters
-        ----------
-        raw : `bytes`
-            The 5-byte encoded message prefix.
-
-        Returns
-        -------
-        schema_version : `int`
-            A version number which indicates the Confluent Schema Registry ID
-            number of the Avro schema used to encode the message that follows this
-            header.
-        """
-        ConfluentWireFormatHeader = struct.Struct(">bi")
-        _, version = ConfluentWireFormatHeader.unpack(raw)
-        return version
-
     def testCreateExtent(self):
         """Test the extent creation for the cutout bbox.
         """
@@ -453,8 +426,6 @@ def test_produceAlerts_one_failure(self, mock_producer):
         """
         counter = 0
 
-        # confluent_kafka is not visible to mock_producer and needs to be
-        # re-imported here.
         def mock_produce(*args, **kwargs):
             nonlocal counter
             counter += 1
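
The surrounding test simulates a single failed produce() among several successes via this counter. A condensed, self-contained sketch of that side-effect technique, with a generic Exception standing in for KafkaException:

    from unittest.mock import MagicMock

    state = {"calls": 0}

    def flaky_produce(*args, **kwargs):
        # Fail only the second call; all others succeed silently.
        state["calls"] += 1
        if state["calls"] == 2:
            raise Exception("simulated Kafka delivery failure")

    producer = MagicMock()
    producer.produce.side_effect = flaky_produce
    producer.produce("topic", b"alert-1")  # succeeds
    # a second producer.produce(...) call would raise here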
@@ -563,7 +534,7 @@ def test_serialize_alert_round_trip(self):
 
         alert = mock_alert(1)
         serialized = PackageAlertsTask._serializeAlert(packageAlerts, alert)
-        deserialized = self._deserialize_alert(serialized)
+        deserialized = _deserialize_alert(serialized)
 
         for field in alert['diaSource']:
             self.assertEqual(alert['diaSource'][field], deserialized['diaSource'][field])
