diff --git a/python/lsst/ap/association/packageAlerts.py b/python/lsst/ap/association/packageAlerts.py
index 30f0efff..08a2d968 100644
--- a/python/lsst/ap/association/packageAlerts.py
+++ b/python/lsst/ap/association/packageAlerts.py
@@ -40,17 +40,9 @@ import lsst.pipe.base as pipeBase
 from lsst.utils.timer import timeMethod
 import fastavro
-from lsst.alert.packet import Schema
-try:
-    from confluent_kafka import KafkaException
-    import confluent_kafka
-except ImportError as error:
-    error.add_note("Could not import confluent_kafka. Alerts will not be sent to the alert stream")
 
 """Methods for packaging Apdb and Pipelines data into Avro alerts.
 """
 
-_ConfluentWireFormatHeader = struct.Struct(">bi")
-latest_schema = Schema.from_file().definition
 _log = logging.getLogger("lsst." + __name__)
 _log.setLevel(logging.DEBUG)
 
@@ -103,31 +95,52 @@ def __init__(self, **kwargs):
         self.alertSchema = alertPack.Schema.from_uri(self.config.schemaFile)
         os.makedirs(self.config.alertWriteLocation, exist_ok=True)
 
-        self.password = os.getenv("AP_KAFKA_PRODUCER_PASSWORD")
-        self.username = os.getenv("AP_KAFKA_PRODUCER_USERNAME")
-        self.server = os.getenv("AP_KAFKA_SERVER")
-        # confluent_kafka configures all of its classes with dictionaries. This one
-        # sets up the bare minimum that is needed.
-        self.kafka_config = {
-            # This is the URL to use to connect to the Kafka cluster.
-            "bootstrap.servers": self.server,
-            # These next two properties tell the Kafka client about the specific
-            # authentication and authorization protocols that should be used when
-            # connecting.
-            "security.protocol": "SASL_PLAINTEXT",
-            "sasl.mechanisms": "SCRAM-SHA-512",
-            # The sasl.username and sasl.password are passed through over
-            # SCRAM-SHA-512 auth to connect to the cluster. The username is not
-            # sensitive, but the password is (of course) a secret value which
-            # should never be committed to source code.
-            "sasl.username": self.username,
-            "sasl.password": self.password,
-            # Batch size limits the largest size of a kafka alert that can be sent.
-            # We set the batch size to 2 Mb.
-            "batch.size": 2097152,
-            "linger.ms": 5,
-        }
-        self.kafka_topic = os.getenv("AP_KAFKA_TOPIC")
+        if self.config.doProduceAlerts:
+
+            self.password = os.getenv("AP_KAFKA_PRODUCER_PASSWORD")
+            self.username = os.getenv("AP_KAFKA_PRODUCER_USERNAME")
+            self.server = os.getenv("AP_KAFKA_SERVER")
+            self.kafka_topic = os.getenv("AP_KAFKA_TOPIC")
+            # confluent_kafka configures all of its classes with dictionaries. This one
+            # sets up the bare minimum that is needed.
+            self.kafka_config = {
+                # This is the URL to use to connect to the Kafka cluster.
+                "bootstrap.servers": self.server,
+                # These next two properties tell the Kafka client about the specific
+                # authentication and authorization protocols that should be used when
+                # connecting.
+                "security.protocol": "SASL_PLAINTEXT",
+                "sasl.mechanisms": "SCRAM-SHA-512",
+                # The sasl.username and sasl.password are passed through over
+                # SCRAM-SHA-512 auth to connect to the cluster. The username is not
+                # sensitive, but the password is (of course) a secret value which
+                # should never be committed to source code.
+                "sasl.username": self.username,
+                "sasl.password": self.password,
+                # Batch size limits the largest size of a kafka alert that can be sent.
+                # We set the batch size to 2 Mb.
+                "batch.size": 2097152,
+                "linger.ms": 5,
+            }
+
+            try:
+                from confluent_kafka import KafkaException
+                self.kafka_exception = KafkaException
+                import confluent_kafka
+            except ImportError as error:
+                error.add_note("Could not import confluent_kafka. Alerts will not be sent "
+                               "to the alert stream")
+                _log.error(error)
+
+            if not self.password:
+                raise ValueError("Kafka password environment variable was not set.")
+            if not self.username:
+                raise ValueError("Kafka username environment variable was not set.")
+            if not self.server:
+                raise ValueError("Kafka server environment variable was not set.")
+            if not self.kafka_topic:
+                raise ValueError("Kafka topic environment variable was not set.")
+            self.producer = confluent_kafka.Producer(**self.kafka_config)
 
     @timeMethod
     def run(self,
@@ -258,31 +271,20 @@ def createDiaSourceExtent(self, bboxSize):
 
     def produceAlerts(self, alerts, ccdVisitId):
 
-        if not self.password:
-            raise ValueError("Kafka password environment variable was not set.")
-        if not self.username:
-            raise ValueError("Kafka username environment variable was not set.")
-        if not self.server:
-            raise ValueError("Kafka server environment variable was not set.")
-        if not self.kafka_topic:
-            raise ValueError("Kafka topic environment variable was not set.")
-        p = confluent_kafka.Producer(**self.kafka_config)
-        topic = self.kafka_topic
-
         for alert in alerts:
             alert_bytes = self._serialize_alert(alert, schema=self.alertSchema.definition, schema_id=1)
             try:
-                p.produce(topic, alert_bytes, callback=self._delivery_callback)
-                p.flush()
+                self.producer.produce(self.kafka_topic, alert_bytes, callback=self._delivery_callback)
+                self.producer.flush()
 
-            except KafkaException as e:
+            except self.kafka_exception as e:
                 _log.error('Kafka error: {}, message was {} bytes'.format(e, sys.getsizeof(alert_bytes)))
 
                 with open(os.path.join(self.config.alertWriteLocation,
                                        f"{ccdVisitId}_{alert['alertId']}.avro"), "wb") as f:
                     f.write(alert_bytes)
 
-        p.flush()
+        self.producer.flush()
 
     def createCcdDataCutout(self, image, skyCenter, extent, photoCalib, srcId):
         """Grab an image as a cutout and return a calibrated CCDData image.
@@ -540,7 +542,8 @@ def _serialize_confluent_wire_header(schema_version):
         The Confluent Wire Format is described more fully here:
         https://docs.confluent.io/current/schema-registry/serdes-develop/index.html#wire-format
         """
-        return _ConfluentWireFormatHeader.pack(0, schema_version)
+        ConfluentWireFormatHeader = struct.Struct(">bi")
+        return ConfluentWireFormatHeader.pack(0, schema_version)
 
     @staticmethod
     def _deserialize_confluent_wire_header(raw):
@@ -558,7 +561,8 @@ def _deserialize_confluent_wire_header(raw):
             number of the Avro schema used to encode the message that follows
             this header.
         """
-        _, version = _ConfluentWireFormatHeader.unpack(raw)
+        ConfluentWireFormatHeader = struct.Struct(">bi")
+        _, version = ConfluentWireFormatHeader.unpack(raw)
         return version
 
     def _delivery_callback(self, err, msg):
diff --git a/tests/test_packageAlerts.py b/tests/test_packageAlerts.py
index 64d4822c..9e307635 100644
--- a/tests/test_packageAlerts.py
+++ b/tests/test_packageAlerts.py
@@ -30,10 +30,7 @@ import sys
 
 from astropy import wcs
 from astropy.nddata import CCDData
-try:
-    import confluent_kafka
-except ImportError as error:
-    error.msg += ("Could not import confluent_kafka. Alerts will not be sent to the alert stream")
+import logging
 
 from lsst.ap.association import PackageAlertsConfig, PackageAlertsTask
 from lsst.afw.cameraGeom.testUtils import DetectorWrapper
@@ -46,6 +43,15 @@ import lsst.utils.tests
 
 import utils_tests
 
+_log = logging.getLogger("lsst." + __name__)
+_log.setLevel(logging.DEBUG)
+
+try:
+    import confluent_kafka  # noqa: F401
+    from confluent_kafka import KafkaException
+except ModuleNotFoundError as e:
+    _log.error('Kafka module not found: {}'.format(e))
+
 
 def _roundTripThroughApdb(objects, sources, forcedSources, dateTime):
     """Run object and source catalogs through the Apdb to get the correct
@@ -105,32 +111,37 @@ def _roundTripThroughApdb(objects, sources, forcedSources, dateTime):
 
 
 def mock_alert(alert_id):
-    """Generate a minimal mock alert. """
+    """Generate a minimal mock alert.
+    """
     return {
         "alertId": alert_id,
         "diaSource": {
-            # Below are all the required fields. Set them to zero.
-            "midpointMjdTai": 0,
-            "diaSourceId": 0,
-            "ccdVisitId": 0,
-            "filterName": "",
+            # Below are all the required fields containing random values.
+            "midpointMjdTai": 5,
+            "diaSourceId": 4,
+            "ccdVisitId": 2,
             "band": 'g',
-            "programId": 0,
-            "ra": 0,
-            "dec": 0,
-            "x": 0,
-            "y": 0,
-            "apFlux": 0,
-            "apFluxErr": 0,
-            "snr": 0,
-            "psfFlux": 0,
-            "psfFluxErr": 0,
+            "ra": 12.5,
+            "dec": -16.9,
+            "x": 15.7,
+            "y": 89.8,
+            "apFlux": 54.85,
+            "apFluxErr": 70.0,
+            "snr": 6.7,
+            "psfFlux": 700.0,
+            "psfFluxErr": 90.0,
             "flags": 0,
         }
     }
 
 
 class TestPackageAlerts(lsst.utils.tests.TestCase):
+    kafka_enabled = "confluent_kafka" in sys.modules
+
+    def __init__(self, *args, **kwargs):
+        TestPackageAlerts.kafka_enabled = "confluent_kafka" in sys.modules
+        _log.debug('TestPackageAlerts: kafka_enabled={}'.format(self.kafka_enabled))
+        super(TestPackageAlerts, self).__init__(*args, **kwargs)
 
     def setUp(self):
         patcher = patch.dict(os.environ, {"AP_KAFKA_PRODUCER_PASSWORD": "fake_password",
@@ -323,62 +334,62 @@ def test_produceAlerts_empty_password(self):
         """ Test that produceAlerts raises if the password is empty or missing.
         """
         self.environ['AP_KAFKA_PRODUCER_PASSWORD'] = ""
-        task = PackageAlertsTask()
         with self.assertRaisesRegex(ValueError, "Kafka password"):
-            task.produceAlerts(None, None)
+            packConfig = PackageAlertsConfig(doProduceAlerts=True)
+            packageAlerts = PackageAlertsTask(config=packConfig)  # noqa: F841
 
         del self.environ['AP_KAFKA_PRODUCER_PASSWORD']
-        task = PackageAlertsTask()
         with self.assertRaisesRegex(ValueError, "Kafka password"):
-            task.produceAlerts(None, None)
+            packConfig = PackageAlertsConfig(doProduceAlerts=True)
+            packageAlerts = PackageAlertsTask(config=packConfig)  # noqa: F841
 
     def test_produceAlerts_empty_username(self):
         """ Test that produceAlerts raises if the username is empty or missing.
         """
         self.environ['AP_KAFKA_PRODUCER_USERNAME'] = ""
-        task = PackageAlertsTask()
         with self.assertRaisesRegex(ValueError, "Kafka username"):
-            task.produceAlerts(None, None)
+            packConfig = PackageAlertsConfig(doProduceAlerts=True)
+            packageAlerts = PackageAlertsTask(config=packConfig)  # noqa: F841
 
         del self.environ['AP_KAFKA_PRODUCER_USERNAME']
-        task = PackageAlertsTask()
         with self.assertRaisesRegex(ValueError, "Kafka username"):
-            task.produceAlerts(None, None)
+            packConfig = PackageAlertsConfig(doProduceAlerts=True)
+            packageAlerts = PackageAlertsTask(config=packConfig)  # noqa: F841
 
     def test_produceAlerts_empty_server(self):
         """ Test that produceAlerts raises if the server is empty or missing.
         """
         self.environ['AP_KAFKA_SERVER'] = ""
-        task = PackageAlertsTask()
         with self.assertRaisesRegex(ValueError, "Kafka server"):
-            task.produceAlerts(None, None)
+            packConfig = PackageAlertsConfig(doProduceAlerts=True)
+            packageAlerts = PackageAlertsTask(config=packConfig)  # noqa: F841
 
         del self.environ['AP_KAFKA_SERVER']
-        task = PackageAlertsTask()
         with self.assertRaisesRegex(ValueError, "Kafka server"):
-            task.produceAlerts(None, None)
+            packConfig = PackageAlertsConfig(doProduceAlerts=True)
+            packageAlerts = PackageAlertsTask(config=packConfig)  # noqa: F841
 
     def test_produceAlerts_empty_topic(self):
         """ Test that produceAlerts raises if the topic is empty or missing.
         """
         self.environ['AP_KAFKA_TOPIC'] = ""
-        task = PackageAlertsTask()
         with self.assertRaisesRegex(ValueError, "Kafka topic"):
-            task.produceAlerts(None, None)
+            packConfig = PackageAlertsConfig(doProduceAlerts=True)
+            packageAlerts = PackageAlertsTask(config=packConfig)  # noqa: F841
 
         del self.environ['AP_KAFKA_TOPIC']
-        task = PackageAlertsTask()
         with self.assertRaisesRegex(ValueError, "Kafka topic"):
-            task.produceAlerts(None, None)
+            packConfig = PackageAlertsConfig(doProduceAlerts=True)
+            packageAlerts = PackageAlertsTask(config=packConfig)  # noqa: F841
 
     @patch('confluent_kafka.Producer')
-    @unittest.skipUnless("confluent_kafka" in sys.modules, "Test requires confluent_kafka.")
+    @unittest.skipIf("confluent_kafka" not in sys.modules, 'Kafka is not enabled')
     def test_produceAlerts_success(self, mock_producer):
         """ Test that produceAlerts calls the producer on all provided alerts
        when the alerts are all under the batch size limit.
         """
-
-        task = PackageAlertsTask()
+        packConfig = PackageAlertsConfig(doProduceAlerts=True)
+        packageAlerts = PackageAlertsTask(config=packConfig)
         alerts = [mock_alert(1), mock_alert(2)]
         ccdVisitId = 123
 
@@ -386,28 +397,31 @@ def test_produceAlerts_success(self, mock_producer):
         producer_instance = mock_producer.return_value
         producer_instance.produce = Mock()
         producer_instance.flush = Mock()
-        task.produceAlerts(alerts, ccdVisitId)
+        packageAlerts.produceAlerts(alerts, ccdVisitId)
 
         self.assertEqual(producer_instance.produce.call_count, len(alerts))
         self.assertEqual(producer_instance.flush.call_count, len(alerts)+1)
 
     @patch('confluent_kafka.Producer')
-    @unittest.skipUnless("confluent_kafka" in sys.modules, "Test requires confluent_kafka.")
+    @unittest.skipIf("confluent_kafka" not in sys.modules, 'Kafka is not enabled')
     def test_produceAlerts_one_failure(self, mock_producer):
         """ Test that produceAlerts correctly fails on one alert
         and is writing the failure to disk.
         """
         counter = 0
 
+        # confluent_kafka is not visible to mock_producer and needs to be
+        # re-imported here.
         def mock_produce(*args, **kwargs):
             nonlocal counter
             counter += 1
             if counter == 2:
-                raise confluent_kafka.KafkaException
+                raise KafkaException
             else:
                 return
 
-        task = PackageAlertsTask()
+        packConfig = PackageAlertsConfig(doProduceAlerts=True)
+        packageAlerts = PackageAlertsTask(config=packConfig)
         patcher = patch("builtins.open")
         patch_open = patcher.start()
@@ -418,7 +432,7 @@ def mock_produce(*args, **kwargs):
 
         producer_instance.produce = Mock(side_effect=mock_produce)
         producer_instance.flush = Mock()
-        task.produceAlerts(alerts, ccdVisitId)
+        packageAlerts.produceAlerts(alerts, ccdVisitId)
 
         self.assertEqual(producer_instance.produce.call_count, len(alerts))
         self.assertEqual(patch_open.call_count, 1)
@@ -479,12 +493,12 @@ def testRun_without_produce(self):
         shutil.rmtree(tempdir)
 
     @patch.object(PackageAlertsTask, 'produceAlerts')
-    @unittest.skipUnless("confluent_kafka" in sys.modules, "Test requires confluent_kafka.")
-    def testRun_with_produce(self, mock_produceAlerts):
+    @patch('confluent_kafka.Producer')
+    @unittest.skipIf("confluent_kafka" not in sys.modules, 'Kafka is not enabled')
+    def testRun_with_produce(self, mock_produceAlerts, mock_producer):
         """Test that packageAlerts calls produceAlerts when doProduceAlerts
         is set to True.
         """
-
         packConfig = PackageAlertsConfig(doProduceAlerts=True)
         packageAlerts = PackageAlertsTask(config=packConfig)
 
@@ -521,6 +535,9 @@ def test_serialize_alert_round_trip(self, **kwargs):
         alert = mock_alert(1)
         serialized = PackageAlertsTask._serialize_alert(packageAlerts, alert)
         deserialized = PackageAlertsTask._deserialize_alert(packageAlerts, serialized)
+
+        for field in alert['diaSource']:
+            self.assertAlmostEqual(alert['diaSource'][field], deserialized['diaSource'][field], places=5)
 
         self.assertEqual(1, deserialized["alertId"])
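
For reference, the Confluent wire format handled by _serialize_confluent_wire_header and _deserialize_confluent_wire_header above is a one-byte magic value of 0 followed by a 4-byte big-endian schema ID (struct format ">bi"), prepended to a schemaless Avro payload. Below is a minimal standalone sketch of that round trip using fastavro directly, outside the patch; the toy schema, record, and helper names are illustrative only and are not part of the Rubin alert packet code, which loads its schema from config.schemaFile.

    import io
    import struct

    import fastavro

    # ">bi" = big-endian: 1-byte magic (always 0) + 4-byte schema ID.
    HEADER = struct.Struct(">bi")

    # Toy schema for illustration; the real task uses the full Rubin alert schema.
    SCHEMA = fastavro.parse_schema({
        "type": "record",
        "name": "toyAlert",
        "fields": [{"name": "alertId", "type": "long"}],
    })

    def serialize(record, schema_id=1):
        buf = io.BytesIO()
        buf.write(HEADER.pack(0, schema_id))             # Confluent wire-format header
        fastavro.schemaless_writer(buf, SCHEMA, record)  # Avro body, no embedded schema
        return buf.getvalue()

    def deserialize(raw):
        magic, schema_id = HEADER.unpack(raw[:HEADER.size])
        assert magic == 0
        return schema_id, fastavro.schemaless_reader(io.BytesIO(raw[HEADER.size:]), SCHEMA)

    print(deserialize(serialize({"alertId": 1})))  # -> (1, {'alertId': 1})

This is the same layout the patch inlines via struct.Struct(">bi") in the two static methods, so the test_serialize_alert_round_trip check that every diaSource field survives serialization exercises exactly this header-plus-payload framing.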