diff --git a/jupyter_events/logger.py b/jupyter_events/logger.py index 79b4fba..63838e0 100644 --- a/jupyter_events/logger.py +++ b/jupyter_events/logger.py @@ -142,12 +142,16 @@ def register_event_schema(self, schema: SchemaType) -> None: Get this registered schema using the EventLogger.schema.get() method. """ - event_schema = self.schemas.register(schema) # type:ignore[arg-type] key = event_schema.id - self._modifiers[key] = set() - self._modified_listeners[key] = set() - self._unmodified_listeners[key] = set() + # It's possible that listeners and modifiers have been added for this + # schema before the schema is registered. + if key not in self._modifiers: + self._modifiers[key] = set() + if key not in self._modified_listeners: + self._modified_listeners[key] = set() + if key not in self._unmodified_listeners: + self._unmodified_listeners[key] = set() def register_handler(self, handler: logging.Handler) -> None: """Register a new logging handler to the Event Logger. @@ -205,7 +209,11 @@ def add_modifier( # If the schema ID and version is given, only add # this modifier to that schema if schema_id: - self._modifiers[schema_id].add(modifier) + # If the schema hasn't been added yet, + # start a placeholder set. + modifiers = self._modifiers.get(schema_id, set()) + modifiers.add(modifier) + self._modifiers[schema_id] = modifiers return for id_ in self._modifiers: if schema_id is None or id_ == schema_id: @@ -264,9 +272,16 @@ def add_listener( # this modifier to that schema if schema_id: if modified: - self._modified_listeners[schema_id].add(listener) + # If the schema hasn't been added yet, + # start a placeholder set. + listeners = self._modified_listeners.get(schema_id, set()) + listeners.add(listener) + self._modified_listeners[schema_id] = listeners return - self._unmodified_listeners[schema_id].add(listener) + listeners = self._unmodified_listeners.get(schema_id, set()) + listeners.add(listener) + self._unmodified_listeners[schema_id] = listeners + return for id_ in self.schemas.schema_ids: if schema_id is None or id_ == schema_id: if modified: diff --git a/tests/test_listeners.py b/tests/test_listeners.py index 2e167a1..684a51b 100644 --- a/tests/test_listeners.py +++ b/tests/test_listeners.py @@ -5,7 +5,7 @@ import pytest -from jupyter_events.logger import EventLogger +from jupyter_events.logger import EventLogger, SchemaNotRegistered from jupyter_events.schema import EventSchema from .utils import SCHEMA_PATH @@ -138,3 +138,41 @@ async def my_listener(logger: EventLogger, schema_id: str, data: dict) -> None: assert listener_was_called # Check that the active listeners are cleaned up. assert len(event_logger._active_listeners) == 0 + + +@pytest.mark.parametrize( + # Make sure no schemas are added at the start of this test. + "jp_event_schemas", + [ + # Empty events list. + [] + ], +) +async def test_listener_added_before_schemas_passes(jp_event_logger, schema): + # Ensure there are no schemas listed. + assert len(jp_event_logger.schemas.schema_ids) == 0 + + listener_was_called = False + + async def my_listener(logger: EventLogger, schema_id: str, data: dict) -> None: + nonlocal listener_was_called + listener_was_called = True + + # Add the listener without any schemas + jp_event_logger.add_listener(schema_id=schema.id, listener=my_listener) + + # Proof that emitting the event won't success + with pytest.warns(SchemaNotRegistered): + jp_event_logger.emit(schema_id=schema.id, data={"prop": "hello, world"}) + + assert not listener_was_called + + # Now register the event and emit. + jp_event_logger.register_event_schema(schema) + + # Try emitting the event again and ensure the listener saw it. + jp_event_logger.emit(schema_id=schema.id, data={"prop": "hello, world"}) + await jp_event_logger.gather_listeners() + assert listener_was_called + # Check that the active listeners are cleaned up. + assert len(jp_event_logger._active_listeners) == 0