From 2b06795f0e51f0dd3e06a9cc47e4734c66ad8199 Mon Sep 17 00:00:00 2001 From: Brianna Smart Date: Fri, 16 Feb 2024 15:46:21 -0800 Subject: [PATCH] Update unit tests and import handling --- python/lsst/ap/association/diaPipe.py | 6 +- python/lsst/ap/association/packageAlerts.py | 183 +++++++------------- tests/test_packageAlerts.py | 113 +++++++----- 3 files changed, 136 insertions(+), 166 deletions(-) diff --git a/python/lsst/ap/association/diaPipe.py b/python/lsst/ap/association/diaPipe.py index 318b4512..02bb630b 100644 --- a/python/lsst/ap/association/diaPipe.py +++ b/python/lsst/ap/association/diaPipe.py @@ -34,7 +34,6 @@ import numpy as np import pandas as pd -import logging from lsst.daf.base import DateTime import lsst.dax.apdb as daxApdb @@ -51,9 +50,6 @@ PackageAlertsTask) from lsst.ap.association.ssoAssociation import SolarSystemAssociationTask -_log = logging.getLogger("lsst." + __name__) -_log.setLevel(logging.DEBUG) - class DiaPipelineConnections( pipeBase.PipelineTaskConnections, @@ -554,7 +550,7 @@ def run(self, template) except ValueError as err: # Continue processing even if alert sending fails - _log.error(err) + self.log.error(err) return pipeBase.Struct(apdbMarker=self.config.apdb.value, associatedDiaSources=associatedDiaSources, diff --git a/python/lsst/ap/association/packageAlerts.py b/python/lsst/ap/association/packageAlerts.py index 08a2d968..bf438481 100644 --- a/python/lsst/ap/association/packageAlerts.py +++ b/python/lsst/ap/association/packageAlerts.py @@ -24,13 +24,18 @@ import io import os import sys -import logging from astropy import wcs import astropy.units as u from astropy.nddata import CCDData, VarianceUncertainty import pandas as pd import struct +import fastavro +try: + import confluent_kafka + from confluent_kafka import KafkaException +except ImportError: + confluent_kafka = None import lsst.alert.packet as alertPack import lsst.afw.geom as afwGeom @@ -39,13 +44,6 @@ from lsst.pex.exceptions import InvalidParameterError import lsst.pipe.base as pipeBase from lsst.utils.timer import timeMethod -import fastavro - -"""Methods for packaging Apdb and Pipelines data into Avro alerts. -""" - -_log = logging.getLogger("lsst." + __name__) -_log.setLevel(logging.DEBUG) class PackageAlertsConfig(pexConfig.Config): @@ -71,13 +69,13 @@ class PackageAlertsConfig(pexConfig.Config): doProduceAlerts = pexConfig.Field( dtype=bool, - doc="Turn on alert production to kafka if true. Set to false by default", + doc="Turn on alert production to kafka if true and if confluent_kafka is in the environment.", default=False, ) doWriteAlerts = pexConfig.Field( dtype=bool, - doc="Write alerts to disk if true. Set to true by default", + doc="Write alerts to disk if true.", default=True, ) @@ -97,50 +95,49 @@ def __init__(self, **kwargs): if self.config.doProduceAlerts: - self.password = os.getenv("AP_KAFKA_PRODUCER_PASSWORD") - self.username = os.getenv("AP_KAFKA_PRODUCER_USERNAME") - self.server = os.getenv("AP_KAFKA_SERVER") - self.kafka_topic = os.getenv("AP_KAFKA_TOPIC") - # confluent_kafka configures all of its classes with dictionaries. This one - # sets up the bare minimum that is needed. - self.kafka_config = { - # This is the URL to use to connect to the Kafka cluster. - "bootstrap.servers": self.server, - # These next two properties tell the Kafka client about the specific - # authentication and authorization protocols that should be used when - # connecting. - "security.protocol": "SASL_PLAINTEXT", - "sasl.mechanisms": "SCRAM-SHA-512", - # The sasl.username and sasl.password are passed through over - # SCRAM-SHA-512 auth to connect to the cluster. The username is not - # sensitive, but the password is (of course) a secret value which - # should never be committed to source code. - "sasl.username": self.username, - "sasl.password": self.password, - # Batch size limits the largest size of a kafka alert that can be sent. - # We set the batch size to 2 Mb. - "batch.size": 2097152, - "linger.ms": 5, - } + if confluent_kafka is not None: + + self.password = os.getenv("AP_KAFKA_PRODUCER_PASSWORD") + self.username = os.getenv("AP_KAFKA_PRODUCER_USERNAME") + self.server = os.getenv("AP_KAFKA_SERVER") + self.kafkaTopic = os.getenv("AP_KAFKA_TOPIC") + + if not self.password: + raise ValueError("Kafka password environment variable was not set.") + if not self.username: + raise ValueError("Kafka username environment variable was not set.") + if not self.server: + raise ValueError("Kafka server environment variable was not set.") + if not self.kafkaTopic: + raise ValueError("Kafka topic environment variable was not set.") + + # confluent_kafka configures all of its classes with dictionaries. This one + # sets up the bare minimum that is needed. + self.kafkaConfig = { + # This is the URL to use to connect to the Kafka cluster. + "bootstrap.servers": self.server, + # These next two properties tell the Kafka client about the specific + # authentication and authorization protocols that should be used when + # connecting. + "security.protocol": "SASL_PLAINTEXT", + "sasl.mechanisms": "SCRAM-SHA-512", + # The sasl.username and sasl.password are passed through over + # SCRAM-SHA-512 auth to connect to the cluster. The username is not + # sensitive, but the password is (of course) a secret value which + # should never be committed to source code. + "sasl.username": self.username, + "sasl.password": self.password, + # Batch size limits the largest size of a kafka alert that can be sent. + # We set the batch size to 2 Mb. + "batch.size": 2097152, + "linger.ms": 5, + } + + self.producer = confluent_kafka.Producer(**self.kafkaConfig) - try: - from confluent_kafka import KafkaException - self.kafka_exception = KafkaException - import confluent_kafka - except ImportError as error: - error.add_note("Could not import confluent_kafka. Alerts will not be sent " - "to the alert stream") - _log.error(error) - - if not self.password: - raise ValueError("Kafka password environment variable was not set.") - if not self.username: - raise ValueError("Kafka username environment variable was not set.") - if not self.server: - raise ValueError("Kafka server environment variable was not set.") - if not self.kafka_topic: - raise ValueError("Kafka topic environment variable was not set.") - self.producer = confluent_kafka.Producer(**self.kafka_config) + else: + self.log.error("Produce alerts is set but confluent_kafka is not present in " + "the environment. Alerts will not be sent to the alert stream.") @timeMethod def run(self, @@ -153,6 +150,10 @@ def run(self, ): """Package DiaSources/Object and exposure data into Avro alerts. + Alerts can be sent to the alert stream if ``doProduceAlerts`` is set + and written to disk if ``doWriteAlerts`` is set. Both can be set at the + same time, and are independent of one another. + Writes Avro alerts to a location determined by the ``alertWriteLocation`` configurable. @@ -226,18 +227,12 @@ def run(self, if self.config.doProduceAlerts and "confluent_kafka" in sys.modules: self.produceAlerts(alerts, ccdVisitId) - elif self.config.doProduceAlerts and "confluent_kafka" not in sys.modules: - raise Exception("Produce alerts is set but confluent_kafka is not in the environment.") - if self.config.doWriteAlerts: with open(os.path.join(self.config.alertWriteLocation, f"{ccdVisitId}.avro"), "wb") as f: self.alertSchema.store_alerts(f, alerts) - if not self.config.doProduceAlerts and not self.config.doWriteAlerts: - raise Exception("Neither produce alerts nor write alerts is set.") - def _patchDiaSources(self, diaSources): """Add the ``programId`` column to the data. @@ -249,7 +244,7 @@ def _patchDiaSources(self, diaSources): diaSources["programId"] = 0 def createDiaSourceExtent(self, bboxSize): - """Create a extent for a box for the cutouts given the size of the + """Create an extent for a box for the cutouts given the size of the square BBox that covers the source footprint. Parameters @@ -272,17 +267,17 @@ def createDiaSourceExtent(self, bboxSize): def produceAlerts(self, alerts, ccdVisitId): for alert in alerts: - alert_bytes = self._serialize_alert(alert, schema=self.alertSchema.definition, schema_id=1) + alertBytes = self._serializeAlert(alert, schema=self.alertSchema.definition, schema_id=1) try: - self.producer.produce(self.kafka_topic, alert_bytes, callback=self._delivery_callback) + self.producer.produce(self.kafkaTopic, alertBytes, callback=self._delivery_callback) self.producer.flush() - except self.kafka_exception as e: - _log.error('Kafka error: {}, message was {} bytes'.format(e, sys.getsizeof(alert_bytes))) + except KafkaException as e: + self.log.warning('Kafka error: {}, message was {} bytes'.format(e, sys.getsizeof(alertBytes))) with open(os.path.join(self.config.alertWriteLocation, f"{ccdVisitId}_{alert['alertId']}.avro"), "wb") as f: - f.write(alert_bytes) + f.write(alertBytes) self.producer.flush() @@ -466,7 +461,7 @@ def streamCcdDataToBytes(self, cutout): cutoutBytes = streamer.getvalue() return cutoutBytes - def _serialize_alert(self, alert, schema=None, schema_id=0): + def _serializeAlert(self, alert, schema=None, schema_id=0): """Serialize an alert to a byte sequence for sending to Kafka. Parameters @@ -491,38 +486,12 @@ def _serialize_alert(self, alert, schema=None, schema_id=0): buf = io.BytesIO() # TODO: Use a proper schema versioning system (DM-42606) - buf.write(self._serialize_confluent_wire_header(schema_id)) + buf.write(self._serializeConfluentWireHeader(schema_id)) fastavro.schemaless_writer(buf, schema, alert) return buf.getvalue() - def _deserialize_alert(self, alert_bytes, schema=None): - """Deserialize an alert message from Kafka. - - Paramaters - ---------- - alert_bytes : `bytes` - Binary-encoding serialized Avro alert, including Confluent Wire - Format prefix. - schema : `dict`, optional - An Avro schema definition describing how to encode `alert`. By default, - the latest schema is used. - - Returns - ------- - alert : `dict` - An alert payload. - """ - if schema is None: - schema = self.alertSchema.definition - - header_bytes = alert_bytes[:5] - version = self._deserialize_confluent_wire_header(header_bytes) - assert version == 0 - content_bytes = io.BytesIO(alert_bytes[5:]) - return fastavro.schemaless_reader(content_bytes, schema) - @staticmethod - def _serialize_confluent_wire_header(schema_version): + def _serializeConfluentWireHeader(schema_version): """Returns the byte prefix for Confluent Wire Format-style Kafka messages. Parameters @@ -545,29 +514,9 @@ def _serialize_confluent_wire_header(schema_version): ConfluentWireFormatHeader = struct.Struct(">bi") return ConfluentWireFormatHeader.pack(0, schema_version) - @staticmethod - def _deserialize_confluent_wire_header(raw): - """Parses the byte prefix for Confluent Wire Format-style Kafka messages. - - Parameters - ---------- - raw : `bytes` - The 5-byte encoded message prefix. - - Returns - ------- - schema_version : `int` - A version number which indicates the Confluent Schema Registry ID - number of the Avro schema used to encode the message that follows this - header. - """ - ConfluentWireFormatHeader = struct.Struct(">bi") - _, version = ConfluentWireFormatHeader.unpack(raw) - return version - def _delivery_callback(self, err, msg): if err: - _log.debug('%% Message failed delivery: %s\n' % err) + self.log.debug('%% Message failed delivery: %s\n' % err) else: - _log.debug('%% Message delivered to %s [%d] @ %d\n' % - (msg.topic(), msg.partition(), msg.offset())) + self.log.debug('%% Message delivered to %s [%d] @ %d\n' + % (msg.topic(), msg.partition(), msg.offset())) diff --git a/tests/test_packageAlerts.py b/tests/test_packageAlerts.py index 5c86952f..63a94d09 100644 --- a/tests/test_packageAlerts.py +++ b/tests/test_packageAlerts.py @@ -21,17 +21,24 @@ import io import os +import struct + import numpy as np import pandas as pd import shutil import tempfile import unittest from unittest.mock import patch, Mock -import sys from astropy import wcs from astropy.nddata import CCDData -import logging +import fastavro +try: + import confluent_kafka + from confluent_kafka import KafkaException +except ImportError: + confluent_kafka = None +import lsst.alert.packet as alertPack from lsst.ap.association import PackageAlertsConfig, PackageAlertsTask from lsst.afw.cameraGeom.testUtils import DetectorWrapper import lsst.afw.image as afwImage @@ -43,15 +50,6 @@ import lsst.utils.tests import utils_tests -_log = logging.getLogger("lsst." + __name__) -_log.setLevel(logging.DEBUG) - -try: - import confluent_kafka # noqa: F401 - from confluent_kafka import KafkaException -except ModuleNotFoundError as e: - _log.error('Kafka module not found: {}'.format(e)) - def _roundTripThroughApdb(objects, sources, forcedSources, dateTime): """Run object and source catalogs through the Apdb to get the correct @@ -116,7 +114,6 @@ def mock_alert(alert_id): return { "alertId": alert_id, "diaSource": { - # Below are all the required fields containing random values. "midpointMjdTai": 5, "diaSourceId": 4, "ccdVisitId": 2, @@ -138,11 +135,8 @@ def mock_alert(alert_id): class TestPackageAlerts(lsst.utils.tests.TestCase): - kafka_enabled = "confluent_kafka" in sys.modules def __init__(self, *args, **kwargs): - TestPackageAlerts.kafka_enabled = "confluent_kafka" in sys.modules - _log.debug('TestPackageAlerts: kafka_enabled={}'.format(self.kafka_enabled)) super(TestPackageAlerts, self).__init__(*args, **kwargs) def setUp(self): @@ -210,6 +204,49 @@ def setUp(self): self.cutoutWcs.wcs.cd = self.exposure.getWcs().getCdMatrix() self.cutoutWcs.wcs.ctype = ["RA---TAN", "DEC--TAN"] + def _deserialize_alert(self, alert_bytes): + """Deserialize an alert message from Kafka. + + Parameters + ---------- + alert_bytes : `bytes` + Binary-encoding serialized Avro alert, including Confluent Wire + Format prefix. + + Returns + ------- + alert : `dict` + An alert payload. + """ + schema = alertPack.Schema.from_uri( + str(alertPack.get_uri_to_latest_schema())) + + header_bytes = alert_bytes[:5] + version = self._deserialize_confluent_wire_header(header_bytes) + assert version == 0 + content_bytes = io.BytesIO(alert_bytes[5:]) + return fastavro.schemaless_reader(content_bytes, schema.definition) + + @staticmethod + def _deserialize_confluent_wire_header(raw): + """Parses the byte prefix for Confluent Wire Format-style Kafka messages. + + Parameters + ---------- + raw : `bytes` + The 5-byte encoded message prefix. + + Returns + ------- + schema_version : `int` + A version number which indicates the Confluent Schema Registry ID + number of the Avro schema used to encode the message that follows this + header. + """ + ConfluentWireFormatHeader = struct.Struct(">bi") + _, version = ConfluentWireFormatHeader.unpack(raw) + return version + def testCreateExtent(self): """Test the extent creation for the cutout bbox. """ @@ -332,60 +369,64 @@ def testMakeAlertDict(self): self.assertEqual(alert["cutoutTemplate"], cutoutBytes) + @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled') def test_produceAlerts_empty_password(self): """ Test that produceAlerts raises if the password is empty or missing. """ self.environ['AP_KAFKA_PRODUCER_PASSWORD'] = "" with self.assertRaisesRegex(ValueError, "Kafka password"): packConfig = PackageAlertsConfig(doProduceAlerts=True) - packageAlerts = PackageAlertsTask(config=packConfig) # noqa: F841 + PackageAlertsTask(config=packConfig) del self.environ['AP_KAFKA_PRODUCER_PASSWORD'] with self.assertRaisesRegex(ValueError, "Kafka password"): packConfig = PackageAlertsConfig(doProduceAlerts=True) - packageAlerts = PackageAlertsTask(config=packConfig) # noqa: F841 + PackageAlertsTask(config=packConfig) + @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled') def test_produceAlerts_empty_username(self): """ Test that produceAlerts raises if the username is empty or missing. """ self.environ['AP_KAFKA_PRODUCER_USERNAME'] = "" with self.assertRaisesRegex(ValueError, "Kafka username"): packConfig = PackageAlertsConfig(doProduceAlerts=True) - packageAlerts = PackageAlertsTask(config=packConfig) # noqa: F841 + PackageAlertsTask(config=packConfig) del self.environ['AP_KAFKA_PRODUCER_USERNAME'] with self.assertRaisesRegex(ValueError, "Kafka username"): packConfig = PackageAlertsConfig(doProduceAlerts=True) - packageAlerts = PackageAlertsTask(config=packConfig) # noqa: F841 + PackageAlertsTask(config=packConfig) + @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled') def test_produceAlerts_empty_server(self): """ Test that produceAlerts raises if the server is empty or missing. """ self.environ['AP_KAFKA_SERVER'] = "" with self.assertRaisesRegex(ValueError, "Kafka server"): packConfig = PackageAlertsConfig(doProduceAlerts=True) - packageAlerts = PackageAlertsTask(config=packConfig) # noqa: F841 + PackageAlertsTask(config=packConfig) del self.environ['AP_KAFKA_SERVER'] with self.assertRaisesRegex(ValueError, "Kafka server"): packConfig = PackageAlertsConfig(doProduceAlerts=True) - packageAlerts = PackageAlertsTask(config=packConfig) # noqa: F841 + PackageAlertsTask(config=packConfig) + @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled') def test_produceAlerts_empty_topic(self): """ Test that produceAlerts raises if the topic is empty or missing. """ self.environ['AP_KAFKA_TOPIC'] = "" with self.assertRaisesRegex(ValueError, "Kafka topic"): packConfig = PackageAlertsConfig(doProduceAlerts=True) - packageAlerts = PackageAlertsTask(config=packConfig) # noqa: F841 + PackageAlertsTask(config=packConfig) del self.environ['AP_KAFKA_TOPIC'] with self.assertRaisesRegex(ValueError, "Kafka topic"): packConfig = PackageAlertsConfig(doProduceAlerts=True) - packageAlerts = PackageAlertsTask(config=packConfig) # noqa: F841 + PackageAlertsTask(config=packConfig) @patch('confluent_kafka.Producer') - @unittest.skipIf("confluent_kafka" not in sys.modules, 'Kafka is not enabled') + @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled') def test_produceAlerts_success(self, mock_producer): """ Test that produceAlerts calls the producer on all provided alerts when the alerts are all under the batch size limit. @@ -405,7 +446,7 @@ def test_produceAlerts_success(self, mock_producer): self.assertEqual(producer_instance.flush.call_count, len(alerts)+1) @patch('confluent_kafka.Producer') - @unittest.skipIf("confluent_kafka" not in sys.modules, 'Kafka is not enabled') + @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled') def test_produceAlerts_one_failure(self, mock_producer): """ Test that produceAlerts correctly fails on one alert and is writing the failure to disk. @@ -496,7 +537,7 @@ def testRun_without_produce(self): @patch.object(PackageAlertsTask, 'produceAlerts') @patch('confluent_kafka.Producer') - @unittest.skipIf("confluent_kafka" not in sys.modules, 'Kafka is not enabled') + @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled') def testRun_with_produce(self, mock_produceAlerts, mock_producer): """Test that packageAlerts calls produceAlerts when doProduceAlerts is set to True. @@ -513,22 +554,6 @@ def testRun_with_produce(self, mock_produceAlerts, mock_producer): self.assertEqual(mock_produceAlerts.call_count, 1) - def testRun_without_produce_or_write(self): - """Test that packageAlerts calls produceAlerts when doProduceAlerts - is set to True. - """ - packConfig = PackageAlertsConfig(doProduceAlerts=False, - doWriteAlerts=False) - packageAlerts = PackageAlertsTask(config=packConfig) - - with self.assertRaisesRegex(Exception, "Neither produce alerts"): - packageAlerts.run(self.diaSources, - self.diaObjects, - self.diaSourceHistory, - self.diaForcedSources, - self.exposure, - self.exposure) - def test_serialize_alert_round_trip(self): """Test that values in the alert packet exactly round trip. """ @@ -536,8 +561,8 @@ def test_serialize_alert_round_trip(self): packageAlerts = PackageAlertsTask(config=ConfigClass) alert = mock_alert(1) - serialized = PackageAlertsTask._serialize_alert(packageAlerts, alert) - deserialized = PackageAlertsTask._deserialize_alert(packageAlerts, serialized) + serialized = PackageAlertsTask._serializeAlert(packageAlerts, alert) + deserialized = self._deserialize_alert(serialized) for field in alert['diaSource']: self.assertEqual(alert['diaSource'][field], deserialized['diaSource'][field])