Add bad credential error handling and update unit tests
Add a _server_check method that checks whether the server is contactable, and update the unit tests.
bsmartradio committed Mar 21, 2024
1 parent 6179bb6 commit f459dd6
Showing 2 changed files with 56 additions and 9 deletions.
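For context, the connectivity check this commit introduces can be exercised standalone with confluent-kafka's AdminClient. The sketch below is illustrative only: the broker address and credentials are placeholders, while the task builds the equivalent configuration from its own settings (see the kafkaAdminConfig diff that follows).

from confluent_kafka import KafkaException
from confluent_kafka.admin import AdminClient

admin_config = {
    # Placeholder broker address; the task uses self.server.
    "bootstrap.servers": "kafka.example.org:9092",
    "security.protocol": "SASL_PLAINTEXT",
    "sasl.mechanisms": "SCRAM-SHA-512",
    # Placeholder credentials; the password is a secret and must never
    # be committed to source code.
    "sasl.username": "alert-producer",
    "sasl.password": "not-a-real-secret",
}

try:
    # list_topics() forces a metadata round trip, so an unreachable broker
    # or rejected credentials surface here as a KafkaException.
    metadata = AdminClient(admin_config).list_topics(timeout=0.5)
    if not metadata.topics:
        raise RuntimeError("Server contacted, but no topics are visible.")
except KafkaException as e:
    print(f"Kafka server check failed: {e}")
    raise

The short timeout keeps an unreachable broker from stalling task construction for long; the trade-off is that a slow but healthy cluster could also trip the check.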
34 changes: 33 additions & 1 deletion python/lsst/ap/association/packageAlerts.py
@@ -36,6 +36,7 @@
try:
    import confluent_kafka
    from confluent_kafka import KafkaException
    from confluent_kafka.admin import AdminClient
except ImportError:
    confluent_kafka = None

@@ -131,6 +132,25 @@ def __init__(self, **kwargs):
                "batch.size": 2097152,
                "linger.ms": 5,
            }
            self.kafkaAdminConfig = {
                # This is the URL to use to connect to the Kafka cluster.
                "bootstrap.servers": self.server,
                # These next two properties tell the Kafka client about the specific
                # authentication and authorization protocols that should be used when
                # connecting.
                "security.protocol": "SASL_PLAINTEXT",
                "sasl.mechanisms": "SCRAM-SHA-512",
                # The sasl.username and sasl.password are passed through over
                # SCRAM-SHA-512 auth to connect to the cluster. The username is not
                # sensitive, but the password is (of course) a secret value which
                # should never be committed to source code.
                "sasl.username": self.username,
                "sasl.password": self.password,
            }

            self._server_check()
            self.producer = confluent_kafka.Producer(**self.kafkaConfig)

        else:
@@ -293,6 +313,7 @@ def produceAlerts(self, alerts, ccdVisitId):
            ccdVisitId of the alerts sent to the alert stream. Used to write
            out alerts which fail to be sent to the alert stream.
        """
        self._server_check()
        for alert in alerts:
            alertBytes = self._serializeAlert(alert, schema=self.alertSchema.definition, schema_id=1)
            try:
@@ -301,7 +322,6 @@ def produceAlerts(self, alerts, ccdVisitId):

            except KafkaException as e:
                self.log.warning('Kafka error: {}, message was {} bytes'.format(e, sys.getsizeof(alertBytes)))

                with open(os.path.join(self.config.alertWriteLocation,
                                       f"{ccdVisitId}_{alert['alertId']}.avro"), "wb") as f:
                    f.write(alertBytes)
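The failure branch above follows a preserve-on-failure pattern: log the Kafka error, then keep the serialized bytes on disk so the alert can be re-sent later. Below is a standalone sketch of the same idea; produce_or_save, write_location, and alert_id are hypothetical names, while the task itself uses its configured alertWriteLocation.

import os
import sys

from confluent_kafka import KafkaException


def produce_or_save(producer, topic, alert_bytes, write_location,
                    ccd_visit_id, alert_id, log):
    """Send one serialized alert; on Kafka failure, keep the bytes on disk."""
    try:
        producer.produce(topic, alert_bytes)
        producer.flush()
    except KafkaException as e:
        log.warning('Kafka error: %s, message was %s bytes',
                    e, sys.getsizeof(alert_bytes))
        path = os.path.join(write_location, f"{ccd_visit_id}_{alert_id}.avro")
        with open(path, "wb") as f:
            f.write(alert_bytes)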
@@ -555,3 +575,15 @@ def _delivery_callback(self, err, msg):
            self.log.warning('Message failed delivery: %s\n' % err)
        else:
            self.log.debug('Message delivered to %s [%d] @ %d', msg.topic(), msg.partition(), msg.offset())

    def _server_check(self):
        """Check that the Kafka server is contactable with the configured
        credentials.

        Raises
        ------
        KafkaException
            Raised if the server cannot be reached or the credentials
            are rejected.
        """
        try:
            admin_client = AdminClient(self.kafkaAdminConfig)
            topics = admin_client.list_topics(timeout=0.5).topics

            if not topics:
                raise RuntimeError("Server contacted, but no topics found.")

        except KafkaException as e:
            self.log.error(e)
            raise
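Since __init__ and produceAlerts now contact the broker, the updated unit tests below patch _server_check out so that no network access is needed. A minimal sketch of the pattern, assuming the module path shown in this diff and that the task's credential settings are available in the environment, as the real tests arrange:

from unittest.mock import patch

from lsst.ap.association.packageAlerts import PackageAlertsConfig, PackageAlertsTask


@patch('confluent_kafka.Producer')
@patch.object(PackageAlertsTask, '_server_check')
def make_task(mock_server_check, mock_producer):
    # patch decorators apply bottom-up, so the _server_check mock arrives
    # as the first argument and the Producer mock as the second.
    packConfig = PackageAlertsConfig(doProduceAlerts=True)
    task = PackageAlertsTask(config=packConfig)
    assert mock_server_check.call_count == 1  # called once during __init__
    return task

The same bottom-up decorator ordering explains the mock argument order in each patched test signature below.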
31 changes: 23 additions & 8 deletions tests/test_packageAlerts.py
@@ -402,8 +402,9 @@ def test_produceAlerts_empty_topic(self):
            PackageAlertsTask(config=packConfig)

    @patch('confluent_kafka.Producer')
    @patch.object(PackageAlertsTask, '_server_check')
    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_success(self, mock_producer):
    def test_produceAlerts_success(self, mock_server_check, mock_producer):
        """ Test that produceAlerts calls the producer on all provided alerts
        when the alerts are all under the batch size limit.
        """
@@ -418,12 +419,14 @@ def test_produceAlerts_success(self, mock_producer):
        producer_instance.flush = Mock()
        packageAlerts.produceAlerts(alerts, ccdVisitId)

        self.assertEqual(mock_server_check.call_count, 2)
        self.assertEqual(producer_instance.produce.call_count, len(alerts))
        self.assertEqual(producer_instance.flush.call_count, len(alerts)+1)

    @patch('confluent_kafka.Producer')
    @patch.object(PackageAlertsTask, '_server_check')
    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_one_failure(self, mock_producer):
    def test_produceAlerts_one_failure(self, mock_server_check, mock_producer):
        """ Test that produceAlerts correctly fails on one alert
        and writes the failure to disk.
        """
@@ -451,6 +454,7 @@ def mock_produce(*args, **kwargs):

        packageAlerts.produceAlerts(alerts, ccdVisitId)

        self.assertEqual(mock_server_check.call_count, 2)
        self.assertEqual(producer_instance.produce.call_count, len(alerts))
        self.assertEqual(patch_open.call_count, 1)
        self.assertIn("123_2.avro", patch_open.call_args.args[0])
@@ -459,11 +463,11 @@ def mock_produce(*args, **kwargs):
        self.assertEqual(producer_instance.flush.call_count, len(alerts))
        patcher.stop()

    def testRun_without_produce(self):
    @patch.object(PackageAlertsTask, '_server_check')
    def testRun_without_produce(self, mock_server_check):
        """Test the run method of package alerts with produce set to False and
        doWriteAlerts set to True.
        """

        packConfig = PackageAlertsConfig(doWriteAlerts=True)
        tempdir = tempfile.mkdtemp(prefix='alerts')
        packConfig.alertWriteLocation = tempdir
@@ -477,6 +481,8 @@ def testRun_without_produce(self):
                           self.exposure,
                           self.exposure)

        self.assertEqual(mock_server_check.call_count, 0)

        ccdVisitId = self.exposure.info.id
        with open(os.path.join(tempdir, f"{ccdVisitId}.avro"), 'rb') as f:
            writer_schema, data_stream = \
@@ -513,8 +519,9 @@ def testRun_without_produce(self):

    @patch.object(PackageAlertsTask, 'produceAlerts')
    @patch('confluent_kafka.Producer')
    @patch.object(PackageAlertsTask, '_server_check')
    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def testRun_with_produce(self, mock_produceAlerts, mock_producer):
    def testRun_with_produce(self, mock_server_check, mock_producer, mock_produceAlerts):
        """Test that packageAlerts calls produceAlerts when doProduceAlerts
        is set to True.
        """
@@ -526,15 +533,16 @@ def testRun_with_produce(self, mock_produceAlerts, mock_producer):
                           self.diaSourceHistory,
                           self.diaForcedSources,
                           self.exposure,
                           self.exposure,
                           self.exposure)

        self.assertEqual(mock_server_check.call_count, 1)
        self.assertEqual(mock_produceAlerts.call_count, 1)

    def test_serialize_alert_round_trip(self):
        """Test that values in the alert packet exactly round trip.
        """
        ConfigClass = PackageAlertsConfig()
        packageAlerts = PackageAlertsTask(config=ConfigClass)
        packClass = PackageAlertsConfig()
        packageAlerts = PackageAlertsTask(config=packClass)

        alert = mock_alert(1)
        serialized = PackageAlertsTask._serializeAlert(packageAlerts, alert)
@@ -544,6 +552,13 @@ def test_serialize_alert_round_trip(self):
            self.assertEqual(alert['diaSource'][field], deserialized['diaSource'][field])
        self.assertEqual(1, deserialized["alertId"])

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_server_check(self):
        """Test that instantiating the task with doProduceAlerts=True raises
        a KafkaException when the Kafka server is not contactable.
        """
        with self.assertRaisesRegex(KafkaException, "_TRANSPORT"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)


class MemoryTester(lsst.utils.tests.MemoryTestCase):
    pass