Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
btschwertfeger committed Nov 28, 2024
1 parent 179740c commit f4d0b20
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 14 deletions.
4 changes: 1 addition & 3 deletions src/websockets/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,9 +913,7 @@ async def send_context(
if wait_for_close:
try:
async with asyncio_timeout_at(self.close_deadline):
self.recv_messages.cancelling = True
if self.recv_messages.paused:
self.recv_messages.resume()
self.recv_messages.prepare_close()
await asyncio.shield(self.connection_lost_waiter)
except TimeoutError:
# There's no risk to overwrite another error because
Expand Down
25 changes: 15 additions & 10 deletions src/websockets/asyncio/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,12 @@ def __init__( # pragma: no cover
# This flag prevents concurrent calls to get() by user code.
self.get_in_progress = False

# This flag marks a soon cancellation
self.cancelling = False
# This flag marks a soon end of the connection.
self.closing = False

# This flag marks the end of the connection.
self.closed = False


async def get(self, decode: bool | None = None) -> Data:
"""
Read the next message.
Expand All @@ -142,8 +141,6 @@ async def get(self, decode: bool | None = None) -> Data:
:meth:`get_iter` concurrently.
"""
if self.cancelling:
return
if self.get_in_progress:
raise ConcurrencyError("get() or get_iter() is already running")
self.get_in_progress = True
Expand Down Expand Up @@ -207,8 +204,6 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]:
:meth:`get_iter` concurrently.
"""
if self.cancelling:
return
if self.get_in_progress:
raise ConcurrencyError("get() or get_iter() is already running")
self.get_in_progress = True
Expand Down Expand Up @@ -259,13 +254,13 @@ def put(self, frame: Frame) -> None:
EOFError: If the stream of frames has ended.
"""
if self.cancelling:
return
if self.closed:
raise EOFError("stream of frames ended")

self.frames.put(frame)
self.maybe_pause()

if not self.closing:
self.maybe_pause()

def maybe_pause(self) -> None:
"""Pause the writer if queue is above the high water mark."""
Expand All @@ -289,6 +284,16 @@ def maybe_resume(self) -> None:
self.paused = False
self.resume()

def prepare_close(self) -> None:
"""
Prepare to close by ensuring that no more messages will be processed.
"""
self.closing = True

# Resuming the writer to avoid deadlocks
if self.paused:
self.resume()

def close(self) -> None:
"""
End the stream of frames.
Expand Down
2 changes: 1 addition & 1 deletion src/websockets/sync/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def close(self) -> None:
"""
End the stream of frames.
Callling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`,
Calling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`,
or :meth:`put` is safe. They will raise :exc:`EOFError`.
"""
Expand Down

0 comments on commit f4d0b20

Please sign in to comment.