From 7b966226f50d227964dfc669bd94bcb3dd501452 Mon Sep 17 00:00:00 2001 From: Santiago Fraire Willemoes Date: Thu, 28 Nov 2024 16:07:12 +0100 Subject: [PATCH] test: add class based test sample --- tests/test_streams.py | 74 ++++++++++++++++++++++++++++++++----------- 1 file changed, 55 insertions(+), 19 deletions(-) diff --git a/tests/test_streams.py b/tests/test_streams.py index 7ab253f..c55cc00 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -3,18 +3,17 @@ from typing import Callable, Set from unittest import mock -import pytest - from kstreams import ConsumerRecord, Send, TopicPartition from kstreams.clients import Consumer, Producer from kstreams.engine import Stream, StreamEngine from kstreams.streams import stream from kstreams.structs import TopicPartitionOffset +from kstreams.test_utils import TestStreamClient from tests import TimeoutErrorException - # NOTE: remove the test when `no typing` support is deprecated -@pytest.mark.asyncio + + async def test_stream_no_typing(stream_engine: StreamEngine, consumer_record_factory): topic_name = "local--kstreams" value = b"test" @@ -50,7 +49,6 @@ async def stream(stream_instance): await stream.stop() -@pytest.mark.asyncio async def test_stream_cr_with_typing( stream_engine: StreamEngine, consumer_record_factory ): @@ -87,7 +85,6 @@ async def stream(cr: ConsumerRecord): await stream.stop() -@pytest.mark.asyncio async def test_stream_generic_cr_with_typing( stream_engine: StreamEngine, consumer_record_factory ): @@ -124,7 +121,43 @@ async def stream(cr: ConsumerRecord[str, bytes]): await stream.stop() -@pytest.mark.asyncio +async def test_stream_class_cr_with_typing( + stream_engine: StreamEngine, consumer_record_factory +): + topic_name = "local--kstreams" + target_topic = "local--kstreams-target" + value = "test" + + client = TestStreamClient(stream_engine, topics=[target_topic]) + + class TestClass: + def __init__(self) -> None: + self.bar = value + + async def streaming_fn(self, cr: ConsumerRecord, send: Send): + """text from func""" + + await send(target_topic, value=self.bar) + + foo = TestClass() + _stream = Stream( + topics=[topic_name], + func=foo.streaming_fn, + ) + stream_engine.add_stream(_stream) + + async with client: + await stream_engine.start() + await client.send(topic_name, value=value) + # import ipdb; ipdb.set_trace() + client.get_topic(topic_name=target_topic) + r = await asyncio.wait_for( + client.get_event(topic_name=target_topic), timeout=0.2 + ) + # r = await client.get_event(topic_name=target_topic) + assert r.value == value + + async def test_stream_cr_and_stream_with_typing( stream_engine: StreamEngine, consumer_record_factory ): @@ -154,7 +187,6 @@ async def stream(cr: ConsumerRecord, stream: Stream): await stream.stop() -@pytest.mark.asyncio async def test_stream_all_typing(stream_engine: StreamEngine, consumer_record_factory): topic_name = "local--kstreams" value = b"test" @@ -191,7 +223,6 @@ async def stream(cr: ConsumerRecord, send: Send, stream: Stream): await stream.stop() -@pytest.mark.asyncio async def test_stream_all_typing_order_in_setup_type( stream_engine: StreamEngine, consumer_record_factory ): @@ -230,7 +261,6 @@ async def stream(stream: Stream, cr: ConsumerRecord, send: Send): await stream.stop() -@pytest.mark.asyncio async def test_stream_multiple_topics(stream_engine: StreamEngine): topics = ["local--hello-kpn", "local--hello-kpn-2"] @@ -251,7 +281,6 @@ async def stream(_): ... ) -@pytest.mark.asyncio async def test_stream_subscribe_topics_pattern(stream_engine: StreamEngine): pattern = "^dev--customer-.*$" @@ -273,7 +302,6 @@ async def stream(_): ... ) -@pytest.mark.asyncio async def test_stream_subscribe_topics_only_one_pattern(stream_engine: StreamEngine): """ We can use only one pattern, so we use the first one @@ -299,7 +327,6 @@ async def stream(_): ... ) -@pytest.mark.asyncio async def test_stream_custom_conf(stream_engine: StreamEngine): @stream_engine.stream( "local--hello-kpn", @@ -323,7 +350,6 @@ async def stream(_): ... assert not stream.consumer._enable_auto_commit -@pytest.mark.asyncio async def test_stream_getmany( stream_engine: StreamEngine, consumer_record_factory: Callable[..., ConsumerRecord] ): @@ -351,7 +377,6 @@ async def getmany(*args, **kwargs): save_to_db.assert_called_once_with(topic_partition_crs) -@pytest.mark.asyncio async def test_stream_decorator(stream_engine: StreamEngine): topic = "local--hello-kpn" @@ -377,7 +402,6 @@ async def streaming_fn(_): Consumer.stop.assert_awaited() -@pytest.mark.asyncio async def test_stream_decorates_properly(stream_engine: StreamEngine): topic = "local--hello-kpn" @@ -389,7 +413,6 @@ async def streaming_fn(_): assert streaming_fn.__doc__ == "text from func" -@pytest.mark.asyncio async def test_recreate_consumer_on_re_start_stream( stream_engine: StreamEngine, consumer_record_factory ): @@ -418,7 +441,6 @@ async def stream(my_stream): assert consumer is not stream.consumer -@pytest.mark.asyncio async def test_seek_to_initial_offsets_normal( stream_engine: StreamEngine, consumer_record_factory ): @@ -465,7 +487,6 @@ async def stream(my_stream): ) -@pytest.mark.asyncio async def test_seek_to_initial_offsets_ignores_wrong_input( stream_engine: StreamEngine, consumer_record_factory ): @@ -511,3 +532,18 @@ async def stream(my_stream): assert stream.rebalance_listener is not None await stream.rebalance_listener.on_partitions_assigned(assigned=assignments) seek_mock.assert_not_called() + + +async def test_stream_simple_di_works( + stream_engine: StreamEngine, consumer_record_factory +): + topic = "local--hello-kpn" + cr: ConsumerRecord = consumer_record_factory(topic=topic, value=b"test") + + @stream_engine.stream(topic) + async def streaming_fn(cr: ConsumerRecord): + """text from func""" + return cr.value + + r = await streaming_fn.func(cr) + assert r == b"test"