Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add mutual exclusion for synchronized stream access in logging handlers and CLPLoglevelTimeout (fixes #55). #59

Merged
merged 12 commits into from
Feb 28, 2025
134 changes: 74 additions & 60 deletions src/clp_logging/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pathlib import Path
from queue import Empty, Queue
from signal import SIGINT, signal, SIGTERM
from threading import Thread, Timer
from threading import RLock, Thread, Timer
from types import FrameType
from typing import Any, Callable, ClassVar, Dict, IO, Optional, Tuple, Union

Expand Down Expand Up @@ -226,29 +226,34 @@ def __init__(
self.ostream: Optional[Union[ZstdCompressionWriter, IO[bytes]]] = None
self.hard_timeout_thread: Optional[Timer] = None
self.soft_timeout_thread: Optional[Timer] = None
self.lock: RLock = RLock()

def set_ostream(self, ostream: Union[ZstdCompressionWriter, IO[bytes]]) -> None:
self.ostream = ostream

def get_lock(self) -> RLock:
return self.lock

def timeout(self) -> None:
"""
Wraps the call to the user supplied `timeout_fn` ensuring that any
existing timeout threads are cancelled, `next_hard_timeout_ts` and
`min_soft_timeout_delta` are reset, and the zstandard frame is flushed.
"""
if self.hard_timeout_thread:
self.hard_timeout_thread.cancel()
if self.soft_timeout_thread:
self.soft_timeout_thread.cancel()
self.next_hard_timeout_ts = ULONG_MAX
self.min_soft_timeout_delta = ULONG_MAX

if self.ostream:
if isinstance(self.ostream, ZstdCompressionWriter):
self.ostream.flush(FLUSH_FRAME)
else:
self.ostream.flush()
self.timeout_fn()
with self.get_lock():
if self.hard_timeout_thread:
self.hard_timeout_thread.cancel()
if self.soft_timeout_thread:
self.soft_timeout_thread.cancel()
self.next_hard_timeout_ts = ULONG_MAX
self.min_soft_timeout_delta = ULONG_MAX

if self.ostream:
if isinstance(self.ostream, ZstdCompressionWriter):
self.ostream.flush(FLUSH_FRAME)
else:
self.ostream.flush()
self.timeout_fn()

def update(self, loglevel: int, log_timestamp_ms: int, log_fn: Callable[[str], None]) -> None:
"""
Expand All @@ -262,44 +267,47 @@ def update(self, loglevel: int, log_timestamp_ms: int, log_fn: Callable[[str], N
allows us to correctly log through the handler itself rather than
just printing to stdout/stderr.
"""
hard_timeout_delta: int
if loglevel not in self.hard_timeout_deltas:
log_fn(
f"{WARN_PREFIX} log level {loglevel} not in self.hard_timeout_deltas; defaulting"
" to _HARD_TIMEOUT_DELTAS[logging.INFO].\n"
)
hard_timeout_delta = CLPLogLevelTimeout._HARD_TIMEOUT_DELTAS[logging.INFO]
else:
hard_timeout_delta = self.hard_timeout_deltas[loglevel]

new_hard_timeout_ts: int = log_timestamp_ms + hard_timeout_delta
if new_hard_timeout_ts < self.next_hard_timeout_ts:
if self.hard_timeout_thread:
self.hard_timeout_thread.cancel()
self.hard_timeout_thread = Timer(new_hard_timeout_ts / 1000 - time.time(), self.timeout)
self.hard_timeout_thread.setDaemon(True)
self.hard_timeout_thread.start()
self.next_hard_timeout_ts = new_hard_timeout_ts

soft_timeout_delta: int
if loglevel not in self.soft_timeout_deltas:
log_fn(
f"{WARN_PREFIX} log level {loglevel} not in self.soft_timeout_deltas; defaulting"
" to _SOFT_TIMEOUT_DELTAS[logging.INFO].\n"
)
soft_timeout_delta = CLPLogLevelTimeout._SOFT_TIMEOUT_DELTAS[logging.INFO]
else:
soft_timeout_delta = self.soft_timeout_deltas[loglevel]
with self.get_lock():
hard_timeout_delta: int
if loglevel not in self.hard_timeout_deltas:
log_fn(
f"{WARN_PREFIX} log level {loglevel} not in self.hard_timeout_deltas; "
"defaulting to _HARD_TIMEOUT_DELTAS[logging.INFO].\n"
)
hard_timeout_delta = CLPLogLevelTimeout._HARD_TIMEOUT_DELTAS[logging.INFO]
else:
hard_timeout_delta = self.hard_timeout_deltas[loglevel]

new_hard_timeout_ts: int = log_timestamp_ms + hard_timeout_delta
if new_hard_timeout_ts < self.next_hard_timeout_ts:
if self.hard_timeout_thread:
self.hard_timeout_thread.cancel()
self.hard_timeout_thread = Timer(
new_hard_timeout_ts / 1000 - time.time(), self.timeout
)
self.hard_timeout_thread.setDaemon(True)
self.hard_timeout_thread.start()
self.next_hard_timeout_ts = new_hard_timeout_ts

soft_timeout_delta: int
if loglevel not in self.soft_timeout_deltas:
log_fn(
f"{WARN_PREFIX} log level {loglevel} not in self.soft_timeout_deltas; "
"defaulting to _SOFT_TIMEOUT_DELTAS[logging.INFO].\n"
)
soft_timeout_delta = CLPLogLevelTimeout._SOFT_TIMEOUT_DELTAS[logging.INFO]
else:
soft_timeout_delta = self.soft_timeout_deltas[loglevel]

if soft_timeout_delta < self.min_soft_timeout_delta:
self.min_soft_timeout_delta = soft_timeout_delta
if soft_timeout_delta < self.min_soft_timeout_delta:
self.min_soft_timeout_delta = soft_timeout_delta

new_soft_timeout_ms: int = log_timestamp_ms + soft_timeout_delta
if self.soft_timeout_thread:
self.soft_timeout_thread.cancel()
self.soft_timeout_thread = Timer(new_soft_timeout_ms / 1000 - time.time(), self.timeout)
self.soft_timeout_thread.setDaemon(True)
self.soft_timeout_thread.start()
new_soft_timeout_ms: int = log_timestamp_ms + soft_timeout_delta
if self.soft_timeout_thread:
self.soft_timeout_thread.cancel()
self.soft_timeout_thread = Timer(new_soft_timeout_ms / 1000 - time.time(), self.timeout)
self.soft_timeout_thread.setDaemon(True)
self.soft_timeout_thread.start()


class CLPSockListener:
Expand Down Expand Up @@ -454,15 +462,18 @@ def log_fn(msg: str) -> None:
if loglevel_timeout:
loglevel_timeout.update(loglevel, last_timestamp_ms, log_fn)
buf += timestamp_buf
ostream.write(buf)
with loglevel_timeout.get_lock() if loglevel_timeout else nullcontext():
ostream.write(buf)
if loglevel_timeout:
loglevel_timeout.timeout()
ostream.write(EOF_CHAR)

if enable_compression:
# Since we are not using context manager, the ostream should be
# explicitly closed.
ostream.close()
with loglevel_timeout.get_lock() if loglevel_timeout else nullcontext():
ostream.write(EOF_CHAR)

if enable_compression:
# Since we are not using context manager, the ostream should be
# explicitly closed.
ostream.close()
# tell _server to exit
CLPSockListener._signaled = True
return 0
Expand Down Expand Up @@ -740,7 +751,8 @@ def _direct_write(self, msg: str) -> None:
raise RuntimeError("Stream already closed")
clp_msg: bytearray
clp_msg, self.last_timestamp_ms = _encode_log_event(msg, self.last_timestamp_ms)
self.ostream.write(clp_msg)
with self.loglevel_timeout.get_lock() if self.loglevel_timeout else nullcontext():
self.ostream.write(clp_msg)

# override
def _write(self, loglevel: int, msg: str) -> None:
Expand All @@ -750,7 +762,8 @@ def _write(self, loglevel: int, msg: str) -> None:
clp_msg, self.last_timestamp_ms = _encode_log_event(msg, self.last_timestamp_ms)
if self.loglevel_timeout:
self.loglevel_timeout.update(loglevel, self.last_timestamp_ms, self._direct_write)
self.ostream.write(clp_msg)
with self.loglevel_timeout.get_lock() if self.loglevel_timeout else nullcontext():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about creating a member function like this:

    def _mutex_context(self) -> AbstractContextManager[Any]:
        return self.loglevel_timeout.get_lock() if self.loglevel_timeout else nullcontext()

Benefits:

  • Improve readability
  • We can properly document how this context works and why we need it

Notice that you might need to check whether we can narrow down the Any I put in the type parameter above. I just put Any to silence mypy for a proof-of-concept.

After doing this, we should also consider to move self.loglevel_timeout_update into the locked context.

self.ostream.write(clp_msg)

# Added to logging.StreamHandler in python 3.7
# override
Expand All @@ -775,8 +788,9 @@ def setStream(self, stream: IO[bytes]) -> Optional[IO[bytes]]:
def close(self) -> None:
if self.loglevel_timeout:
self.loglevel_timeout.timeout()
self.ostream.write(EOF_CHAR)
self.ostream.close()
with self.loglevel_timeout.get_lock() if self.loglevel_timeout else nullcontext():
self.ostream.write(EOF_CHAR)
self.ostream.close()
self.closed = True
super().close()

Expand Down
Loading