Skip to content

Commit

Permalink
sse support
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Dec 16, 2023
1 parent 00eae13 commit 8441c5f
Show file tree
Hide file tree
Showing 9 changed files with 188 additions and 25 deletions.
16 changes: 16 additions & 0 deletions examples/sse/counter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import asyncio
from microdot import Microdot
from microdot.sse import with_sse

app = Microdot()


@app.route('/events')
@with_sse
async def events(request, sse):
for i in range(10):
await asyncio.sleep(1)
await sse.send({'counter': i})


app.run(debug=True)
23 changes: 15 additions & 8 deletions examples/streaming/video_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,20 @@ async def index(request):

@app.route('/video_feed')
async def video_feed(request):
print('Starting video stream.')

if sys.implementation.name != 'micropython':
# CPython supports yielding async generators
# CPython supports async generator function
async def stream():
yield b'--frame\r\n'
while True:
for frame in frames:
yield b'Content-Type: image/jpeg\r\n\r\n' + frame + \
b'\r\n--frame\r\n'
await asyncio.sleep(1)

try:
yield b'--frame\r\n'
while True:
for frame in frames:
yield b'Content-Type: image/jpeg\r\n\r\n' + frame + \
b'\r\n--frame\r\n'
await asyncio.sleep(1)
except GeneratorExit:
print('Stopping video stream.')
else:
# MicroPython can only use class-based async generators
class stream():
Expand All @@ -52,6 +56,9 @@ async def __anext__(self):
return b'Content-Type: image/jpeg\r\n\r\n' + \
frames[self.i] + b'\r\n--frame\r\n'

async def aclose(self):
print('Stopping video stream.')

return stream(), 200, {'Content-Type':
'multipart/x-mixed-replace; boundary=frame'}

Expand Down
2 changes: 2 additions & 0 deletions src/microdot/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ async def cancel_monitor():
await send({'type': 'http.response.body',
'body': res_body,
'more_body': False})
if hasattr(body_iter, 'aclose'): # pragma: no branch
await body_iter.aclose()
cancelled = True
await monitor_task

Expand Down
50 changes: 35 additions & 15 deletions src/microdot/microdot.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,10 +646,21 @@ async def write(self, stream):

# body
if not self.is_head:
async for body in self.body_iter():
iter = self.body_iter()
async for body in iter:
if isinstance(body, str): # pragma: no cover
body = body.encode()
await stream.awrite(body)
try:
await stream.awrite(body)
except OSError as exc: # pragma: no cover
if exc.errno in MUTED_SOCKET_ERRORS or \
exc.args[0] == 'Connection lost':
if hasattr(iter, 'aclose'):
await iter.aclose()
raise
if hasattr(iter, 'aclose'):
await iter.aclose()

except OSError as exc: # pragma: no cover
if exc.errno in MUTED_SOCKET_ERRORS or \
exc.args[0] == 'Connection lost':
Expand All @@ -665,41 +676,50 @@ def body_iter(self):
response = self

class iter:
ITER_UNKNOWN = 0
ITER_SYNC_GEN = 1
ITER_FILE_OBJ = 2
ITER_NO_BODY = -1

def __aiter__(self):
if response.body:
self.i = 0 # need to determine type of response.body
self.i = self.ITER_UNKNOWN # need to determine type
else:
self.i = -1 # no response body
self.i = self.ITER_NO_BODY
return self

async def __anext__(self):
if self.i == -1:
if self.i == self.ITER_NO_BODY:
await self.aclose()
raise StopAsyncIteration
if self.i == 0:
if self.i == self.ITER_UNKNOWN:
if hasattr(response.body, 'read'):
self.i = 2 # response body is a file-like object
self.i = self.ITER_FILE_OBJ
elif hasattr(response.body, '__next__'):
self.i = 1 # response body is a sync generator
self.i = self.ITER_SYNC_GEN
return next(response.body)
else:
self.i = -1 # response body is a plain string
self.i = self.ITER_NO_BODY
return response.body
elif self.i == 1:
elif self.i == self.ITER_SYNC_GEN:
try:
return next(response.body)
except StopIteration:
await self.aclose()
raise StopAsyncIteration
buf = response.body.read(response.send_file_buffer_size)
if iscoroutine(buf): # pragma: no cover
buf = await buf
if len(buf) < response.send_file_buffer_size:
self.i = -1
if hasattr(response.body, 'close'): # pragma: no cover
result = response.body.close()
if iscoroutine(result):
await result
self.i = self.ITER_NO_BODY
return buf

async def aclose(self):
if hasattr(response.body, 'close'):
result = response.body.close()
if iscoroutine(result): # pragma: no cover
await result

return iter()

@classmethod
Expand Down
104 changes: 104 additions & 0 deletions src/microdot/sse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import asyncio
import json


class SSE:
def __init__(self):
self.event = asyncio.Event()
self.queue = []

async def send(self, data, event=None):
if isinstance(data, (dict, list)):
data = json.dumps(data)
elif not isinstance(data, str):
data = str(data)
data = f'data: {data}\n\n'
if event:
data = f'event: {event}\n{data}'
self.queue.append(data)
self.event.set()

async def events(self):
while True:
await self.event.wait()
self.event.clear()
queue = self.queue
self.queue = []
for event in queue:
yield event


def sse_response(request, event_function, *args, **kwargs):
"""Return a response object that initiates an event stream.
:param request: the request object.
:param event_function: an asynchronous function that will send events to
the client. The function is invoked with ``request``
and an ``sse`` object. The function should use
``sse.send()`` to send events to the client.
:param args: additional positional arguments to be passed to the response.
:param kwargs: additional keyword arguments to be passed to the response.
Example::
@app.route('/events')
async def events_route(request):
async def events(request, sse):
# send an unnamed event with string data
await sse.send('hello')
# send an unnamed event with JSON data
await sse.send({'foo': 'bar'})
# send a named event
await sse.send('hello', event='greeting')
return sse_response(request, events)
"""
sse = SSE()

async def sse_task_wrapper():
await event_function(request, sse, *args, **kwargs)
sse.event.set()

task = asyncio.ensure_future(sse_task_wrapper())

class sse_loop:
def __aiter__(self):
return self

async def __anext__(self):
event = None
while sse.queue or not task.done():
try:
event = sse.queue.pop()
break
except IndexError:
await sse.event.wait()
sse.event.clear()
if event is None:
raise StopAsyncIteration
return event

async def aclose(self):
task.cancel()

return sse_loop(), 200, {'Content-Type': 'text/event-stream'}


def with_sse(f):
"""Decorator to make a route a Server-Sent Events endpoint.
This decorator is used to define a route that accepts SSE connections. The
route then receives a sse object as a second argument that it can use to
send events to the client::
@app.route('/events')
@with_sse
async def events(request, sse):
for i in range(10):
await asyncio.sleep(1)
await sse.send(f'{i}')
"""
async def sse_handler(request, *args, **kwargs):
return sse_response(request, f, *args, **kwargs)

return sse_handler
5 changes: 4 additions & 1 deletion src/microdot/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,13 @@ def _initialize_response(self, res):

async def _initialize_body(self, res):
self.body = b''
async for body in res.body_iter(): # pragma: no branch
iter = res.body_iter()
async for body in iter: # pragma: no branch
if isinstance(body, str):
body = body.encode()
self.body += body
if hasattr(iter, 'aclose'): # pragma: no branch
await iter.aclose()

def _process_text_body(self):
try:
Expand Down
3 changes: 2 additions & 1 deletion src/microdot/websocket.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import binascii
import hashlib
from microdot import Response
from microdot.microdot import MUTED_SOCKET_ERRORS


class WebSocket:
Expand Down Expand Up @@ -162,7 +163,7 @@ async def wrapper(request, *args, **kwargs):
await f(request, ws, *args, **kwargs)
await ws.close() # pragma: no cover
except OSError as exc:
if exc.errno not in [32, 54, 104]: # pragma: no cover
if exc.errno not in MUTED_SOCKET_ERRORS: # pragma: no cover
raise
return ''
return wrapper
Expand Down
4 changes: 4 additions & 0 deletions src/microdot/wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ def __next__(self):
except StopAsyncIteration:
raise StopIteration

def close(self): # pragma: no cover
if hasattr(self.iter, 'aclose'):
self.loop.run_until_complete(self.iter.aclose())

return async_to_sync_iter(res.body_iter(), self.loop)

def __call__(self, environ, start_response):
Expand Down
6 changes: 6 additions & 0 deletions tests/test_microdot.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,7 @@ def index(req):

def test_streaming(self):
app = Microdot()
done = False

@app.route('/')
def index(req):
Expand All @@ -700,6 +701,10 @@ async def __anext__(self):
self.i += 1
return data

async def aclose(self):
nonlocal done
done = True

return stream()

client = TestClient(app)
Expand All @@ -708,6 +713,7 @@ async def __anext__(self):
self.assertEqual(res.headers['Content-Type'],
'text/plain; charset=UTF-8')
self.assertEqual(res.text, 'foobar')
self.assertEqual(done, True)

def test_already_handled_response(self):
app = Microdot()
Expand Down

0 comments on commit 8441c5f

Please sign in to comment.