diff --git a/python/lsst/ap/association/packageAlerts.py b/python/lsst/ap/association/packageAlerts.py
index 5336802e..832acf04 100644
--- a/python/lsst/ap/association/packageAlerts.py
+++ b/python/lsst/ap/association/packageAlerts.py
@@ -36,6 +36,7 @@
 try:
     import confluent_kafka
     from confluent_kafka import KafkaException
+    from confluent_kafka.admin import AdminClient
 except ImportError:
     confluent_kafka = None
@@ -131,6 +132,25 @@ def __init__(self, **kwargs):
                 "batch.size": 2097152,
                 "linger.ms": 5,
             }
+            self.kafkaAdminConfig = {
+                # This is the URL used to connect to the Kafka cluster.
+                "bootstrap.servers": self.server,
+                # These next two properties tell the Kafka client about the specific
+                # authentication and authorization protocols that should be used when
+                # connecting.
+                "security.protocol": "SASL_PLAINTEXT",
+                "sasl.mechanisms": "SCRAM-SHA-512",
+                # The sasl.username and sasl.password are passed through over
+                # SCRAM-SHA-512 auth to connect to the cluster. The username is not
+                # sensitive, but the password is (of course) a secret value which
+                # should never be committed to source code.
+                "sasl.username": self.username,
+                "sasl.password": self.password,
+                # The admin client is only used to verify that the cluster is
+                # reachable, so no producer tuning options are set here.
+            }
+
+            self._server_check()
             self.producer = confluent_kafka.Producer(**self.kafkaConfig)
         else:
@@ -293,6 +313,7 @@ def produceAlerts(self, alerts, ccdVisitId):
             ccdVisitId of the alerts sent to the alert stream. Used to write
             out alerts which fail to be sent to the alert stream.
         """
+        self._server_check()
         for alert in alerts:
             alertBytes = self._serializeAlert(alert, schema=self.alertSchema.definition, schema_id=1)
             try:
@@ -301,7 +322,6 @@
             except KafkaException as e:
                 self.log.warning('Kafka error: {}, message was {} bytes'.format(e, sys.getsizeof(alertBytes)))
-
                 with open(os.path.join(self.config.alertWriteLocation, f"{ccdVisitId}_{alert['alertId']}.avro"), "wb") as f:
                     f.write(alertBytes)
@@ -555,3 +575,15 @@ def _delivery_callback(self, err, msg):
             self.log.warning('Message failed delivery: %s\n' % err)
         else:
             self.log.debug('Message delivered to %s [%d] @ %d', msg.topic(), msg.partition(), msg.offset())
+
+    def _server_check(self):
+        try:
+            admin_client = AdminClient(self.kafkaAdminConfig)
+            topics = admin_client.list_topics(timeout=0.5).topics
+
+            if not topics:
+                raise RuntimeError("Connected to the Kafka cluster but no topics were found.")
+
+        except KafkaException as e:
+            self.log.error(e)
+            raise
diff --git a/tests/test_packageAlerts.py b/tests/test_packageAlerts.py
index eefc0237..c8d7c1bc 100644
--- a/tests/test_packageAlerts.py
+++ b/tests/test_packageAlerts.py
@@ -402,8 +402,9 @@ def test_produceAlerts_empty_topic(self):
         PackageAlertsTask(config=packConfig)
 
     @patch('confluent_kafka.Producer')
+    @patch.object(PackageAlertsTask, '_server_check')
     @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
-    def test_produceAlerts_success(self, mock_producer):
+    def test_produceAlerts_success(self, mock_server_check, mock_producer):
         """ Test that produceAlerts calls the producer on all provided alerts
         when the alerts are all under the batch size limit.
""" @@ -418,12 +419,14 @@ def test_produceAlerts_success(self, mock_producer): producer_instance.flush = Mock() packageAlerts.produceAlerts(alerts, ccdVisitId) + self.assertEqual(mock_server_check.call_count, 2) self.assertEqual(producer_instance.produce.call_count, len(alerts)) self.assertEqual(producer_instance.flush.call_count, len(alerts)+1) @patch('confluent_kafka.Producer') + @patch.object(PackageAlertsTask, '_server_check') @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled') - def test_produceAlerts_one_failure(self, mock_producer): + def test_produceAlerts_one_failure(self, mock_server_check, mock_producer): """ Test that produceAlerts correctly fails on one alert and is writing the failure to disk. """ @@ -451,6 +454,7 @@ def mock_produce(*args, **kwargs): packageAlerts.produceAlerts(alerts, ccdVisitId) + self.assertEqual(mock_server_check.call_count, 2) self.assertEqual(producer_instance.produce.call_count, len(alerts)) self.assertEqual(patch_open.call_count, 1) self.assertIn("123_2.avro", patch_open.call_args.args[0]) @@ -459,11 +463,11 @@ def mock_produce(*args, **kwargs): self.assertEqual(producer_instance.flush.call_count, len(alerts)) patcher.stop() - def testRun_without_produce(self): + @patch.object(PackageAlertsTask, '_server_check') + def testRun_without_produce(self, mock_server_check): """Test the run method of package alerts with produce set to False and doWriteAlerts set to true. """ - packConfig = PackageAlertsConfig(doWriteAlerts=True) tempdir = tempfile.mkdtemp(prefix='alerts') packConfig.alertWriteLocation = tempdir @@ -477,6 +481,8 @@ def testRun_without_produce(self): self.exposure, self.exposure) + self.assertEqual(mock_server_check.call_count, 0) + ccdVisitId = self.exposure.info.id with open(os.path.join(tempdir, f"{ccdVisitId}.avro"), 'rb') as f: writer_schema, data_stream = \ @@ -513,8 +519,9 @@ def testRun_without_produce(self): @patch.object(PackageAlertsTask, 'produceAlerts') @patch('confluent_kafka.Producer') + @patch.object(PackageAlertsTask, '_server_check') @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled') - def testRun_with_produce(self, mock_produceAlerts, mock_producer): + def testRun_with_produce(self, mock_produceAlerts, mock_server_check, mock_producer): """Test that packageAlerts calls produceAlerts when doProduceAlerts is set to True. """ @@ -526,15 +533,16 @@ def testRun_with_produce(self, mock_produceAlerts, mock_producer): self.diaSourceHistory, self.diaForcedSources, self.exposure, + self.exposure, self.exposure) - + self.assertEqual(mock_server_check.call_count, 1) self.assertEqual(mock_produceAlerts.call_count, 1) def test_serialize_alert_round_trip(self): """Test that values in the alert packet exactly round trip. """ - ConfigClass = PackageAlertsConfig() - packageAlerts = PackageAlertsTask(config=ConfigClass) + packClass = PackageAlertsConfig() + packageAlerts = PackageAlertsTask(config=packClass) alert = mock_alert(1) serialized = PackageAlertsTask._serializeAlert(packageAlerts, alert) @@ -544,6 +552,13 @@ def test_serialize_alert_round_trip(self): self.assertEqual(alert['diaSource'][field], deserialized['diaSource'][field]) self.assertEqual(1, deserialized["alertId"]) + @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled') + def test_server_check(self): + + with self.assertRaisesRegex(KafkaException, "_TRANSPORT"): + packConfig = PackageAlertsConfig(doProduceAlerts=True) + PackageAlertsTask(config=packConfig) + class MemoryTester(lsst.utils.tests.MemoryTestCase): pass