diff --git a/src/frequenz/client/dispatch/_client.py b/src/frequenz/client/dispatch/_client.py index 04e59f96..e05f9fc0 100644 --- a/src/frequenz/client/dispatch/_client.py +++ b/src/frequenz/client/dispatch/_client.py @@ -270,12 +270,9 @@ def _get_stream( request = StreamMicrogridDispatchesRequest(microgrid_id=int(microgrid_id)) broadcaster = GrpcStreamBroadcaster( stream_name="StreamMicrogridDispatches", - stream_method=lambda: cast( - AsyncIterator[StreamMicrogridDispatchesResponse], - self.stub.StreamMicrogridDispatches( - request, - timeout=self._stream_timeout_seconds, - ), + stream_method=lambda: self.stub.StreamMicrogridDispatches( + request, + timeout=self._stream_timeout_seconds, ), transform=DispatchEvent.from_protobuf, retry_strategy=LinearBackoff(interval=1, limit=None), diff --git a/src/frequenz/client/dispatch/test/_service.py b/src/frequenz/client/dispatch/test/_service.py index 83a0a519..3a41d8aa 100644 --- a/src/frequenz/client/dispatch/test/_service.py +++ b/src/frequenz/client/dispatch/test/_service.py @@ -9,6 +9,7 @@ from dataclasses import dataclass, replace from datetime import datetime, timezone from typing import AsyncIterator +from unittest.mock import AsyncMock, MagicMock import grpc import grpc.aio @@ -109,7 +110,7 @@ async def ListMicrogridDispatches( ), ) - async def StreamMicrogridDispatches( + def StreamMicrogridDispatches( self, request: StreamMicrogridDispatchesRequest, timeout: int = 5, # pylint: disable=unused-argument @@ -122,20 +123,37 @@ async def StreamMicrogridDispatches( Returns: An async generator for dispatch changes. - - Yields: - An event for each dispatch change. """ - receiver = self._stream_channel.new_receiver() - - async for message in receiver: - _logger.debug("Received message: %s", message) - if message.microgrid_id == MicrogridId(request.microgrid_id): - response = StreamMicrogridDispatchesResponse( - event=message.event.event.value, - dispatch=message.event.dispatch.to_protobuf(), - ) - yield response + + async def stream() -> AsyncIterator[StreamMicrogridDispatchesResponse]: + """Stream microgrid dispatches changes.""" + _logger.debug("Starting stream for microgrid %s", request.microgrid_id) + receiver = self._stream_channel.new_receiver() + + async for message in receiver: + _logger.debug("Received message: %s", message) + if message.microgrid_id == MicrogridId(request.microgrid_id): + response = StreamMicrogridDispatchesResponse( + event=message.event.event.value, + dispatch=message.event.dispatch.to_protobuf(), + ) + yield response + else: + _logger.debug( + "Skipping message for microgrid %s", + message.microgrid_id, + ) + + _logger.debug("Creating mock stream for microgrid %s", request.microgrid_id) + + mock_stream = MagicMock(name="StreamMicrogridDispatches") + mock_stream.__aiter__.side_effect = stream + mock_stream.initial_metadata = AsyncMock( + side_effect=lambda: _logger.debug( + "Initial metadata requested for microgrid %s", request.microgrid_id + ) + ) + return mock_stream # pylint: disable=too-many-branches @staticmethod