Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(StreamEngine): graceful shutdown must wait for all events to be processed #171

Merged
merged 1 commit into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions kstreams/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def __init__(
self.name = name or str(uuid.uuid4())
self.deserializer = deserializer
self.running = False
self.is_processing = asyncio.Lock()
self.initial_offsets = initial_offsets
self.seeked_initial_offsets = False
self.rebalance_listener = rebalance_listener
Expand All @@ -121,15 +122,18 @@ def _create_consumer(self) -> Consumer:
return self.consumer_class(**config)

async def stop(self) -> None:
if not self.running:
return None

if self.consumer is not None:
await self.consumer.stop()
if self.running:
# Don't run anymore to prevent new events coming
self.running = False

if self._consumer_task is not None:
self._consumer_task.cancel()
async with self.is_processing:
# Only enter this block when all the events have been
# processed in the middleware chain
if self.consumer is not None:
await self.consumer.stop()

if self._consumer_task is not None:
self._consumer_task.cancel()

async def _subscribe(self) -> None:
# Always create a consumer on stream.start
Expand All @@ -141,7 +145,6 @@ async def _subscribe(self) -> None:
self.consumer.subscribe(
topics=self.topics, listener=self.rebalance_listener
)
self.running = True

async def commit(
self, offsets: typing.Optional[typing.Dict[TopicPartition, int]] = None
Expand Down Expand Up @@ -206,6 +209,7 @@ async def start(self) -> None:
return None

await self._subscribe()
self.running = True

if self.udf_handler.type == UDFType.NO_TYPING:
# normal use case
Expand Down Expand Up @@ -236,9 +240,10 @@ async def func_wrapper(self, func: typing.Awaitable) -> None:
logger.exception(f"CRASHED Stream!!! Task {self._consumer_task} \n\n {e}")

async def func_wrapper_with_typing(self) -> None:
while True:
while self.running:
cr = await self.getone()
await self.func(cr)
async with self.is_processing:
await self.func(cr)

def seek_to_initial_offsets(self) -> None:
if not self.seeked_initial_offsets and self.consumer is not None:
Expand Down
3 changes: 1 addition & 2 deletions kstreams/test_utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ class TestMonitor(PrometheusMonitor):
__test__ = False

def start(self, *args, **kwargs) -> None:
print("herte....")
# ...
...

async def stop(self, *args, **kwargs) -> None:
...
Expand Down
33 changes: 33 additions & 0 deletions tests/test_stream_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,39 @@ async def stream(_):
Consumer.stop.assert_awaited()


@pytest.mark.asyncio
async def test_wait_for_streams_before_stop(
    stream_engine: StreamEngine, consumer_record_factory: Callable[..., ConsumerRecord]
):
    """Engine shutdown must wait for in-flight events before stopping consumers.

    A stream handler that takes a long time to process a single event is
    started, then the engine is stopped while that event is still being
    processed.  Graceful shutdown should let the handler run to completion
    (``save_to_db`` is awaited with the event value) rather than cancelling
    it mid-flight, and the consumer must still be stopped afterwards.
    """
    topic = "local--hello-kpn"
    value = b"Hello world"
    # Sentinel side effect: awaited only if the slow handler finished.
    save_to_db = mock.AsyncMock()

    async def getone(_):
        # Replacement for Consumer.getone: always yields one synthetic record.
        return consumer_record_factory(value=value)

    @stream_engine.stream(topic)
    async def stream(cr: ConsumerRecord):
        # Use 5 seconds sleep to simulate a super slow event processing
        await asyncio.sleep(5)
        await save_to_db(cr.value)

    # Patch the Kafka clients so no real broker is needed; getone feeds the
    # stream one record.  Assertions stay inside the `with` block because the
    # mocks are reverted on exit.
    with mock.patch.multiple(
        Consumer,
        start=mock.DEFAULT,
        stop=mock.DEFAULT,
        getone=getone,
    ), mock.patch.multiple(Producer, start=mock.DEFAULT, stop=mock.DEFAULT):
        await stream_engine.start()
        await asyncio.sleep(0)  # Allow stream coroutine to run once

        # stop engine immediately, this should not break the streams
        # and it should wait until the event is processed.
        await stream_engine.stop()
        Consumer.stop.assert_awaited()
        save_to_db.assert_awaited_once_with(value)


@pytest.mark.asyncio
async def test_recreate_consumer_on_re_start_stream(stream_engine: StreamEngine):
with mock.patch("kstreams.clients.aiokafka.AIOKafkaConsumer.start"):
Expand Down
Loading