diff --git a/loguru/_error_interceptor.py b/loguru/_error_interceptor.py index 9f63d3db..3d2d6e5e 100644 --- a/loguru/_error_interceptor.py +++ b/loguru/_error_interceptor.py @@ -1,11 +1,14 @@ import sys import traceback +from ._locks_machinery import create_error_lock + class ErrorInterceptor: def __init__(self, should_catch, handler_id): self._should_catch = should_catch self._handler_id = handler_id + self._lock = create_error_lock() def should_catch(self): return self._should_catch @@ -14,21 +17,33 @@ def print(self, record=None, *, exception=None): if not sys.stderr: return - if exception is None: - type_, value, traceback_ = sys.exc_info() - else: - type_, value, traceback_ = (type(exception), exception, exception.__traceback__) + # The Lock prevents concurrent writes to standard error. Also, it's registered into the + # machinery to make sure no fork occurs while internal Lock of "sys.stderr" is acquired. + with self._lock: + if exception is None: + type_, value, traceback_ = sys.exc_info() + else: + type_, value, traceback_ = (type(exception), exception, exception.__traceback__) - try: - sys.stderr.write("--- Logging error in Loguru Handler #%d ---\n" % self._handler_id) try: - record_repr = str(record) - except Exception: - record_repr = "/!\\ Unprintable record /!\\" - sys.stderr.write("Record was: %s\n" % record_repr) - traceback.print_exception(type_, value, traceback_, None, sys.stderr) - sys.stderr.write("--- End of logging error ---\n") - except OSError: - pass - finally: - del type_, value, traceback_ + sys.stderr.write("--- Logging error in Loguru Handler #%d ---\n" % self._handler_id) + try: + record_repr = str(record) + except Exception: + record_repr = "/!\\ Unprintable record /!\\" + sys.stderr.write("Record was: %s\n" % record_repr) + traceback.print_exception(type_, value, traceback_, None, sys.stderr) + sys.stderr.write("--- End of logging error ---\n") + except OSError: + pass + finally: + del type_, value, traceback_ + + def __getstate__(self): + state = self.__dict__.copy() + state["_lock"] = None + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self._lock = create_error_lock() diff --git a/loguru/_handler.py b/loguru/_handler.py index 81a3dca0..7a2724e0 100644 --- a/loguru/_handler.py +++ b/loguru/_handler.py @@ -8,6 +8,7 @@ from ._colorizer import Colorizer from ._locks_machinery import create_handler_lock +from ._record_queue import RecordQueue def prepare_colored_format(format_, ansi_level): @@ -44,8 +45,7 @@ def __init__( multiprocessing_context, error_interceptor, exception_formatter, - id_, - levels_ansi_codes + id_ ): self._name = name self._sink = sink @@ -60,7 +60,7 @@ def __init__( self._error_interceptor = error_interceptor self._exception_formatter = exception_formatter self._id = id_ - self._levels_ansi_codes = levels_ansi_codes # Warning, reference shared among handlers + self._levels_ansi_codes = {} self._decolorized_format = None self._precolorized_formats = {} @@ -68,41 +68,45 @@ def __init__( self._stopped = False self._lock = create_handler_lock() - self._lock_acquired = threading.local() + self._thread_locals = threading.local() self._queue = None self._queue_lock = None self._confirmation_event = None self._confirmation_lock = None self._owner_process_pid = None - self._thread = None + self._writer_thread = None + + # We can't use "object()" because their identity doesn't survive pickling. + self._confirmation_sentinel = True + self._stop_sentinel = None if self._is_formatter_dynamic: if self._colorize: self._memoize_dynamic_format = memoize(prepare_colored_format) else: self._memoize_dynamic_format = memoize(prepare_stripped_format) - else: - if self._colorize: - for level_name in self._levels_ansi_codes: - self.update_format(level_name) - else: - self._decolorized_format = self._formatter.strip() + elif not self._colorize: + self._decolorized_format = self._formatter.strip() if self._enqueue: if self._multiprocessing_context is None: - self._queue = multiprocessing.SimpleQueue() + self._queue = RecordQueue( + self._multiprocessing_context, self._error_interceptor, self._id + ) self._confirmation_event = multiprocessing.Event() self._confirmation_lock = multiprocessing.Lock() else: - self._queue = self._multiprocessing_context.SimpleQueue() + self._queue = RecordQueue( + self._multiprocessing_context, self._error_interceptor, self._id + ) self._confirmation_event = self._multiprocessing_context.Event() self._confirmation_lock = self._multiprocessing_context.Lock() self._queue_lock = create_handler_lock() self._owner_process_pid = os.getpid() - self._thread = Thread( - target=self._queued_writer, daemon=True, name="loguru-writer-%d" % self._id + self._writer_thread = Thread( + target=self._threaded_writer, daemon=True, name="loguru-writer-%d" % self._id ) - self._thread.start() + self._writer_thread.start() def __repr__(self): return "(id=%d, level=%d, sink=%s)" % (self._id, self._levelno, self._name) @@ -110,19 +114,19 @@ def __repr__(self): @contextmanager def _protected_lock(self): """Acquire the lock, but fail fast if its already acquired by the current thread.""" - if getattr(self._lock_acquired, "acquired", False): + if getattr(self._thread_locals, "lock_acquired", False): raise RuntimeError( "Could not acquire internal lock because it was already in use (deadlock avoided). " "This likely happened because the logger was re-used inside a sink, a signal " "handler or a '__del__' method. This is not permitted because the logger and its " "handlers are not re-entrant." ) - self._lock_acquired.acquired = True + self._thread_locals.lock_acquired = True try: with self._lock: yield finally: - self._lock_acquired.acquired = False + self._thread_locals.lock_acquired = False def emit(self, record, level_id, from_decorator, is_raw, colored_message): try: @@ -214,10 +218,15 @@ def stop(self): self._stopped = True if self._enqueue: if self._owner_process_pid != os.getpid(): + self._queue.stop() return - self._queue.put(None) - self._thread.join() - if hasattr(self._queue, "close"): + # Although we're not waiting for any confirmation, we still need to acquire + # the underlying Lock to ensure that not two processes try to stop and complete + # the queue at the same time (would possibly cause deadlock). + with self._confirmation_lock: + self._queue.put_final(self._stop_sentinel) + self._writer_thread.join() + self._queue.stop() self._queue.close() self._sink.stop() @@ -227,7 +236,10 @@ def complete_queue(self): return with self._confirmation_lock: - self._queue.put(True) + if self._queue.is_closed(): + return + with self._protected_lock(): + self._queue.put(self._confirmation_sentinel) self._confirmation_event.wait() self._confirmation_event.clear() @@ -238,11 +250,11 @@ def tasks_to_complete(self): with lock: return self._sink.tasks_to_complete() - def update_format(self, level_id): - if not self._colorize or self._is_formatter_dynamic: - return - ansi_code = self._levels_ansi_codes[level_id] - self._precolorized_formats[level_id] = self._formatter.colorize(ansi_code) + def update_format(self, level_id, ansi_code): + with self._protected_lock(): + self._levels_ansi_codes[level_id] = ansi_code + if self._colorize and not self._is_formatter_dynamic: + self._precolorized_formats[level_id] = self._formatter.colorize(ansi_code) @property def levelno(self): @@ -287,43 +299,40 @@ def _serialize_record(text, record): return json.dumps(serializable, default=str, ensure_ascii=False) + "\n" - def _queued_writer(self): - message = None - queue = self._queue - - # We need to use a lock to protect sink during fork. - # Particularly, writing to stderr may lead to deadlock in child process. - lock = self._queue_lock - + def _threaded_writer(self): while True: try: - message = queue.get() + message = self._queue.get() except Exception: - with lock: + with self._queue_lock: self._error_interceptor.print(None) continue - if message is None: + if message is self._stop_sentinel: break - if message is True: + if message is self._confirmation_sentinel: self._confirmation_event.set() continue - with lock: - try: + try: + # We need to use a registered Lock to protect sink during fork. In particular, if + # this thread is writing to stderr while the main thread is forked, the lock + # internally used by stderr might be copied while being in locked state. That would + # cause a deadlock in the child process. + with self._queue_lock: self._sink.write(message) - except Exception: - self._error_interceptor.print(message.record) + except Exception: + self._error_interceptor.print(message.record) def __getstate__(self): state = self.__dict__.copy() state["_lock"] = None - state["_lock_acquired"] = None + state["_thread_locals"] = None state["_memoize_dynamic_format"] = None if self._enqueue: state["_sink"] = None - state["_thread"] = None + state["_writer_thread"] = None state["_owner_process"] = None state["_queue_lock"] = None return state @@ -331,9 +340,7 @@ def __getstate__(self): def __setstate__(self, state): self.__dict__.update(state) self._lock = create_handler_lock() - self._lock_acquired = threading.local() - if self._enqueue: - self._queue_lock = create_handler_lock() + self._thread_locals = threading.local() if self._is_formatter_dynamic: if self._colorize: self._memoize_dynamic_format = memoize(prepare_colored_format) diff --git a/loguru/_locks_machinery.py b/loguru/_locks_machinery.py index 6f021109..e6815f37 100644 --- a/loguru/_locks_machinery.py +++ b/loguru/_locks_machinery.py @@ -10,29 +10,95 @@ def create_logger_lock(): def create_handler_lock(): return threading.Lock() + def create_queue_lock(): + return threading.Lock() + + def create_error_lock(): + return threading.Lock() + else: - # While forking, we need to sanitize all locks to make sure the child process doesn't run into - # a deadlock (if a lock already acquired is inherited) and to protect sink from corrupted state. - # It's very important to acquire logger locks before handlers one to prevent possible deadlock - # while 'remove()' is called for example. + # Using "fork()" in a multi-threaded Python app is kind of deprecated. Although it can work most + # of the time, it is not guaranteed to be safe because it doesn't respect the POSIX standard. + # Still, we need to support it as it can be used by some users, therefore we need to take some + # precautions. Additionally, Loguru itself makes use of threads when "enqueue=True"; it remains + # to be decided whether we should drop compatibility with "fork()" ore re-implement the + # feature using "multiprocessing" instead. + # + # Apart from the non compliance to standards, mixing threads and multiprocessing "fork" will + # create problems if some important principles are not respected. The entire memory is copied to + # the child process, so if a new process is created while a lock is in the "acquired" state, the + # copied lock will also be in the "acquired" state in the newly started process, causing a + # potential deadlock. This can occur if the process is forked by the main thread while there is + # another thread using locks running in the background. + # + # A possible workaround to this problem consists of acquiring all locks before forking, and then + # releasing them in the parent and child processes. This is what is done below using the + # "os.register_at_fork()" function. This also ensures that the sinks are not interrupted during + # execution and that the possible internal resources they use are not copied in an invalid + # state. + # + # However, this technique requires attention to the order in which the locks are acquired. If a + # function uses nested locks, it is crucial to acquire the "outer" lock before the "inner" lock. + # For example, "Logger.remove()" acquires a Lock and then calls "Handler.stop()" which itself + # acquires a second Lock. If a fork occurs between these two steps in a different thread, the + # "acquire_locks()" function must not acquire the second Lock first, as this could lead to a + # deadlock. For this reason, locks are identified by four different types representing their + # intended use according to the current implementation of Loguru. This makes it possible to + # guarantee their correct order of acquisition. + # + # Additionally, it is important to ensure that no new locks are created while forking is + # occurring in a different thread. This can result in errors, such as changes in the set size + # during iteration or attempts to release a lock that was not previously acquired. To address + # this, a global "machinery_lock" is used. + # + # Special consideration must be paid to the registration of new locks, though. Creating a new + # lock requires first acquiring the global lock. However, we stated above that the order of + # acquisition of locks during a fork must be the same as the order of (nested) acquisition in + # the code. This constraint implies that no lock should be created when another lock is already + # in use. Consequently, the Logger must take care that all internal locks are created in + # advance, outside the scope of any other lock, to ensure that the above measures are effective. + # + # Finally, usage of threading Condition can also cause problems. During a fork, the number of + # current waiters is also copied to the child process. To prevent deadlocks, each Condition + # instance must be re-created in the child process, and the inherited one must not be re-used. + + machinery_lock = threading.Lock() logger_locks = weakref.WeakSet() handler_locks = weakref.WeakSet() + queue_locks = weakref.WeakSet() + error_locks = weakref.WeakSet() def acquire_locks(): + machinery_lock.acquire() + for lock in logger_locks: lock.acquire() for lock in handler_locks: lock.acquire() + for lock in queue_locks: + lock.acquire() + + for lock in error_locks: + lock.acquire() + def release_locks(): - for lock in logger_locks: + for lock in error_locks: + lock.release() + + for lock in queue_locks: lock.release() for lock in handler_locks: lock.release() + for lock in logger_locks: + lock.release() + + machinery_lock.release() + os.register_at_fork( before=acquire_locks, after_in_parent=release_locks, @@ -40,11 +106,25 @@ def release_locks(): ) def create_logger_lock(): - lock = threading.Lock() - logger_locks.add(lock) + with machinery_lock: + lock = threading.Lock() + logger_locks.add(lock) return lock def create_handler_lock(): - lock = threading.Lock() - handler_locks.add(lock) + with machinery_lock: + lock = threading.Lock() + handler_locks.add(lock) + return lock + + def create_queue_lock(): + with machinery_lock: + lock = threading.Lock() + queue_locks.add(lock) + return lock + + def create_error_lock(): + with machinery_lock: + lock = threading.Lock() + error_locks.add(lock) return lock diff --git a/loguru/_logger.py b/loguru/_logger.py index f750967a..0b66ba38 100644 --- a/loguru/_logger.py +++ b/loguru/_logger.py @@ -975,32 +975,38 @@ def add( "not: '%s'" % type(context).__name__ ) - with self._core.lock: - exception_formatter = ExceptionFormatter( - colorize=colorize, - encoding=encoding, - diagnose=diagnose, - backtrace=backtrace, - hidden_frames_filename=self.catch.__code__.co_filename, - prefix=exception_prefix, - ) + exception_formatter = ExceptionFormatter( + colorize=colorize, + encoding=encoding, + diagnose=diagnose, + backtrace=backtrace, + hidden_frames_filename=self.catch.__code__.co_filename, + prefix=exception_prefix, + ) - handler = Handler( - name=name, - sink=wrapped_sink, - levelno=levelno, - formatter=formatter, - is_formatter_dynamic=is_formatter_dynamic, - filter_=filter_func, - colorize=colorize, - serialize=serialize, - enqueue=enqueue, - multiprocessing_context=context, - id_=handler_id, - error_interceptor=error_interceptor, - exception_formatter=exception_formatter, - levels_ansi_codes=self._core.levels_ansi_codes, - ) + handler = Handler( + name=name, + sink=wrapped_sink, + levelno=levelno, + formatter=formatter, + is_formatter_dynamic=is_formatter_dynamic, + filter_=filter_func, + colorize=colorize, + serialize=serialize, + enqueue=enqueue, + multiprocessing_context=context, + id_=handler_id, + error_interceptor=error_interceptor, + exception_formatter=exception_formatter, + ) + + with self._core.lock: + # For thread-safety reasons, the handler is updated under the lock as new levels could + # be added in parallel. Because the Handler uses the global "machinery_lock" during its + # initialization, it must not be created under the lock. That explains why the handler + # is updated this way instead of being created directly with the right levels. + for level_id, ansi_code in self._core.levels_ansi_codes.items(): + handler.update_format(level_id, ansi_code) handlers = self._core.handlers.copy() handlers[handler_id] = handler @@ -1611,7 +1617,7 @@ def level(self, name, no=None, color=None, icon=None): self._core.levels_ansi_codes[name] = ansi self._core.levels_lookup[name] = (name, name, no, icon) for handler in self._core.handlers.values(): - handler.update_format(name) + handler.update_format(name, ansi) return level diff --git a/loguru/_record_queue.py b/loguru/_record_queue.py new file mode 100644 index 00000000..80b45582 --- /dev/null +++ b/loguru/_record_queue.py @@ -0,0 +1,170 @@ +import collections +import multiprocessing +from multiprocessing.util import Finalize +from threading import Condition, Thread + +from ._locks_machinery import create_queue_lock + + +class RecordQueue: + """A multiprocessing-safe queue in charge of transferring records between processes. + + The design is very closely coupled to the intended usage by the Handler class. + + This class is not fully thread-safe. Concurrent calls likely need to be protected by a Lock. + """ + + def __init__(self, multiprocessing_context, error_interceptor, handler_id): + self._error_interceptor = error_interceptor + self._handler_id = handler_id + self._buffer = None + self._multiprocessing_context = multiprocessing_context + + if self._multiprocessing_context is None: + self._lock = multiprocessing.Lock() + self._receiver, self._sender = multiprocessing.Pipe(duplex=False) + self._is_closed = multiprocessing.Event() + else: + self._lock = self._multiprocessing_context.Lock() + self._receiver, self._sender = self._multiprocessing_context.Pipe(duplex=False) + self._is_closed = self._multiprocessing_context.Event() + + self._broker_thread = None + self._condition_lock = create_queue_lock() + self._condition = None + + self._finalize = None + + self._sentinel_stop = object() + self._sentinel_close = object() + + def put(self, item): + """Put a logging record in the queue and return immediately.""" + # Each process needs its own thread. However, when a child process is started, the inherited + # thread will appear nullified ("spawn" method) or stopped ("fork" method). In such case, + # that means that we are in a new process, and that we must therefore start a new thread. + # To reduce repetition, this initialization strategy also applies to the owner process (the + # thread, initially "None", is only created at the time of the first "put()" call). + # Note that we don't need to acquire a Lock here as concurrent calls are already protected + # by a Lock in the Handler. The Condition only serves to wake up the broker thread. + if not self._broker_thread or not self._broker_thread.is_alive(): + # Items copied during fork must be discarded. + self._buffer = collections.deque() + + # Must be re-created in each forked process because the number of waiters is copied. + self._condition = Condition(self._condition_lock) + + # We must ensure the process is not abruptly terminated while the broker thread has + # acquired the Lock. Otherwise, others processes might be blocked forever. We can't + # expect all users to call "logger.complete()" before terminating the process, so we + # use the undocumented "Finalize" class which allows a function to be called when the + # process is about to terminate. This is required because "atexit.register()" is not + # working in child processes started by "fork" method. + self._finalize = Finalize(self, self.stop, exitpriority=0) + + self._broker_thread = Thread( + target=self._threaded_broker, + daemon=True, + name="loguru-broker-%d" % self._handler_id, + ) + self._broker_thread.start() + + with self._condition: + self._buffer.append(item) + self._condition.notify() + + def get(self): + """Get the next pending item from the queue (block until one is available if necessary).""" + # The Handler calls this method from a single thread and from a single process, therefore + # no Lock is necessary (contrary to the "put()" method which is called from multiple + # processes). + return self._receiver.recv() + + def put_final(self, item): + """Put one last item in the queue and disable it for further use. + + Once the item has been processed, subsequent elements possibly added by other processes + will be ignored. This means once the item is read from the queue, it is guaranteed that + no other item will ever be read from it. + + This method is intended to be called exactly once to prepare termination of reader thread. + """ + # The Handler lock protects this method call, therefore the two elements will be added + # atomically (they are guaranteed to appear consecutively in the queue). + self.put(self._sentinel_close) + self.put(item) + + def stop(self): + """Stop processing queued items and wait for the internal thread to finish. + + This method is expected to be called before closing the queue, to ensure all items have + been processed. + """ + if not self._broker_thread or not self._broker_thread.is_alive(): + return + self.put(self._sentinel_stop) + self._broker_thread.join() + + def close(self): + """Close the queue definitively and release its internal resources. + + This method must not be called if the queue is still in use by any process. In particular, + the queue must first be stopped and neither "put()" nor "get()" must be called afterwards or + concurrently. + + This method should be called exactly once, generally after a call to "put_final()". + """ + self._buffer.clear() + self._receiver.close() + self._sender.close() + + def is_closed(self): + """Check whether the queue has been closed (possibly by another process). + + This avoids queuing items that will never be processed. + """ + return self._is_closed.is_set() + + def _threaded_broker(self): + is_final = False + + while True: + with self._condition: + if not self._buffer: + self._condition.wait() + record = self._buffer.popleft() + + if record is self._sentinel_close: + is_final = True + continue + + if record is self._sentinel_stop: + break + + try: + with self._lock: + if self._is_closed.is_set(): + continue + # It's crucial to toggle the "is_closed" flag and send the final record + # atomically (under the same lock acquisition). Other processes must not be able + # to send records after the one that is expected to be last. + self._sender.send(record) + if is_final: + self._is_closed.set() + except Exception: + record = record.record if hasattr(record, "record") else record + self._error_interceptor.print(record) + + def __getstate__(self): + state = self.__dict__.copy() + state["_buffer"] = None + state["_broker_thread"] = None + state["_condition_lock"] = None + state["_condition"] = None + state["_finalize"] = None + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self._condition_lock = create_queue_lock() + self._waiter_lock = create_queue_lock() diff --git a/tests/test_add_option_catch.py b/tests/test_add_option_catch.py index 486f920c..7c66a0dd 100644 --- a/tests/test_add_option_catch.py +++ b/tests/test_add_option_catch.py @@ -1,6 +1,5 @@ import re import sys -import time import pytest @@ -93,7 +92,7 @@ def __repr__(self): def test_broken_sink_message(capsys, enqueue): logger.add(broken_sink, catch=True, enqueue=enqueue) logger.debug("Oops") - time.sleep(0.1) + logger.complete() out, err = capsys.readouterr() lines = err.strip().splitlines() @@ -121,7 +120,7 @@ def half_broken_sink(m): logger.info("NOK") logger.info("B") - time.sleep(0.1) + logger.complete() assert output == "A\nB\n" @@ -138,6 +137,6 @@ def broken_sink(m): with default_threading_excepthook(): logger.info("A") logger.info("B") - time.sleep(0.1) + logger.complete() assert called == 2 diff --git a/tests/test_add_option_context.py b/tests/test_add_option_context.py index c9d3c27e..2cbb5529 100644 --- a/tests/test_add_option_context.py +++ b/tests/test_add_option_context.py @@ -7,13 +7,14 @@ from loguru import logger -@pytest.fixture +@pytest.fixture(autouse=True) def reset_start_method(): + """Ensure tests (from this module or not) aren't affecting each others.""" + multiprocessing.set_start_method(None, force=True) yield multiprocessing.set_start_method(None, force=True) -@pytest.mark.usefixtures("reset_start_method") def test_using_multiprocessing_directly_if_context_is_none(): logger.add(lambda _: None, enqueue=True, context=None) assert multiprocessing.get_start_method(allow_none=True) is not None diff --git a/tests/test_add_option_enqueue.py b/tests/test_add_option_enqueue.py index c367e1d7..792086db 100644 --- a/tests/test_add_option_enqueue.py +++ b/tests/test_add_option_enqueue.py @@ -1,6 +1,7 @@ import pickle import re import sys +import threading import time import pytest @@ -144,16 +145,19 @@ def test_not_caught_exception_queue_put(writer, capsys): logger.add(writer, enqueue=True, catch=False, format="{message}") logger.info("It's fine") - - with pytest.raises(pickle.PicklingError, match=r"You shall not serialize me!"): - logger.bind(broken=NotPicklable()).info("Bye bye...") + logger.bind(broken=NotPicklable()).info("Bye bye...") + logger.info("It's fine again") logger.remove() out, err = capsys.readouterr() - assert writer.read() == "It's fine\n" + lines = err.strip().splitlines() + assert writer.read() == "It's fine\nIt's fine again\n" assert out == "" - assert err == "" + assert lines[0] == "--- Logging error in Loguru Handler #0 ---" + assert re.match(r"Record was: \{.*Bye bye.*\}", lines[1]) + assert lines[-2].endswith("PicklingError: You shall not serialize me!") + assert lines[-1] == "--- End of logging error ---" def test_not_caught_exception_queue_get(writer, capsys): @@ -248,6 +252,39 @@ def slow_sink(message): assert err == "".join("%d\n" % i for i in range(10)) +def test_complete_without_logging_any_message(writer): + logger.add(writer, enqueue=True, catch=False, format="{message}") + logger.complete() + assert writer.read() == "" + + +def test_remove_without_logging_any_message(writer): + logger.add(writer, enqueue=True, catch=False, format="{message}") + logger.remove() + assert writer.read() == "" + + +def test_main_thread_not_blocked(writer): + event = threading.Event() + + def sink(message): + event.wait() + writer(message) + + logger.add(sink, enqueue=True, catch=False, format=lambda r: "{message}") + + # Pipes have default capacity of 65,536 bytes. + # If it's full, the logger must not block. + for _ in range(1000): + logger.info("." * 10000) + + event.set() + + logger.complete() + + assert writer.read() == "." * 10000 * 1000 + + @pytest.mark.parametrize("exception_value", [NotPicklable(), NotPicklableTypeError()]) def test_logging_not_picklable_exception(exception_value): exception = None diff --git a/tests/test_multiprocessing.py b/tests/test_multiprocessing.py index cd25cd64..42b9617b 100644 --- a/tests/test_multiprocessing.py +++ b/tests/test_multiprocessing.py @@ -43,12 +43,34 @@ def subworker_remove(logger_): logger_.info("Child") logger_.remove() logger_.info("Nope") + logger_.complete() def subworker_remove_inheritance(): logger.info("Child") logger.remove() logger.info("Nope") + logger.complete() + + +def subworker_remove_no_logging(logger_): + logger_.remove() + logger_.info("Nope") + logger_.complete() + + +def subworker_remove_no_logging_inheritance(): + logger.remove() + logger.info("Nope") + logger.complete() + + +def subworker_complete_no_logging(logger_): + logger_.complete() + + +def subworker_complete_no_logging_inheritance(): + logger.complete() def subworker_complete(logger_): @@ -69,18 +91,22 @@ async def work(): loop.run_until_complete(work()) -def subworker_barrier(logger_, barrier): +def subworker_barrier(logger_, barrier_1, barrier_2): logger_.info("Child") - barrier.wait() - time.sleep(0.5) + logger_.complete() + barrier_1.wait() + barrier_2.wait() logger_.info("Nope") + logger_.complete() -def subworker_barrier_inheritance(barrier): +def subworker_barrier_inheritance(barrier_1, barrier_2): logger.info("Child") - barrier.wait() - time.sleep(0.5) + logger.complete() + barrier_1.wait() + barrier_2.wait() logger.info("Nope") + logger.complete() class Writer: @@ -314,15 +340,20 @@ def test_remove_in_main_process_spawn(spawn_context): # In such situation, it seems the child process has not enough time to initialize itself # It may fail with an "EOFError" during unpickling of the (garbage collected / closed) Queue writer = Writer() - barrier = spawn_context.Barrier(2) + init_barrier = spawn_context.Barrier(2) + remove_barrier = spawn_context.Barrier(2) logger.add(writer, context=spawn_context, format="{message}", enqueue=True, catch=False) - process = spawn_context.Process(target=subworker_barrier, args=(logger, barrier)) + process = spawn_context.Process( + target=subworker_barrier, args=(logger, init_barrier, remove_barrier) + ) process.start() - barrier.wait() + init_barrier.wait() + logger.info("Main") logger.remove() + remove_barrier.wait() process.join() assert process.exitcode == 0 @@ -333,15 +364,19 @@ def test_remove_in_main_process_spawn(spawn_context): @pytest.mark.skipif(os.name == "nt", reason="Windows does not support forking") def test_remove_in_main_process_fork(fork_context): writer = Writer() - barrier = fork_context.Barrier(2) + init_barrier = fork_context.Barrier(2) + remove_barrier = fork_context.Barrier(2) logger.add(writer, context=fork_context, format="{message}", enqueue=True, catch=False) - process = fork_context.Process(target=subworker_barrier, args=(logger, barrier)) + process = fork_context.Process( + target=subworker_barrier, args=(logger, init_barrier, remove_barrier) + ) process.start() - barrier.wait() + init_barrier.wait() logger.info("Main") logger.remove() + remove_barrier.wait() process.join() assert process.exitcode == 0 @@ -352,15 +387,20 @@ def test_remove_in_main_process_fork(fork_context): @pytest.mark.skipif(os.name == "nt", reason="Windows does not support forking") def test_remove_in_main_process_inheritance(fork_context): writer = Writer() - barrier = fork_context.Barrier(2) + init_barrier = fork_context.Barrier(2) + remove_barrier = fork_context.Barrier(2) logger.add(writer, context=fork_context, format="{message}", enqueue=True, catch=False) - process = fork_context.Process(target=subworker_barrier_inheritance, args=(barrier,)) + process = fork_context.Process( + target=subworker_barrier_inheritance, + args=(init_barrier, remove_barrier), + ) process.start() - barrier.wait() + init_barrier.wait() logger.info("Main") logger.remove() + remove_barrier.wait() process.join() assert process.exitcode == 0 @@ -368,6 +408,112 @@ def test_remove_in_main_process_inheritance(fork_context): assert writer.read() == "Child\nMain\n" +def test_remove_in_child_without_logging_spawn(spawn_context): + writer = Writer() + + logger.add(writer, context=spawn_context, format="{message}", enqueue=True, catch=False) + + process = spawn_context.Process(target=subworker_remove_no_logging, args=(logger,)) + process.start() + process.join() + + assert process.exitcode == 0 + + logger.info("Main") + logger.remove() + + assert writer.read() == "Main\n" + + +@pytest.mark.skipif(os.name == "nt", reason="Windows does not support forking") +def test_remove_in_child_without_logging_fork(fork_context): + writer = Writer() + + logger.add(writer, context=fork_context, format="{message}", enqueue=True, catch=False) + + process = fork_context.Process(target=subworker_remove_no_logging, args=(logger,)) + process.start() + process.join() + + assert process.exitcode == 0 + + logger.info("Main") + logger.remove() + + assert writer.read() == "Main\n" + + +@pytest.mark.skipif(os.name == "nt", reason="Windows does not support forking") +def test_remove_in_child_without_logging_inheritance(fork_context): + writer = Writer() + + logger.add(writer, context=fork_context, format="{message}", enqueue=True, catch=False) + + process = fork_context.Process(target=subworker_remove_no_logging_inheritance) + process.start() + process.join() + + assert process.exitcode == 0 + + logger.info("Main") + logger.remove() + + assert writer.read() == "Main\n" + + +def test_complete_in_child_without_logging_spawn(spawn_context): + writer = Writer() + + logger.add(writer, context=spawn_context, format="{message}", enqueue=True, catch=False) + + process = spawn_context.Process(target=subworker_complete_no_logging, args=(logger,)) + process.start() + process.join() + + assert process.exitcode == 0 + + logger.info("Main") + logger.complete() + + assert writer.read() == "Main\n" + + +@pytest.mark.skipif(os.name == "nt", reason="Windows does not support forking") +def test_complete_in_child_without_logging_fork(fork_context): + writer = Writer() + + logger.add(writer, context=fork_context, format="{message}", enqueue=True, catch=False) + + process = fork_context.Process(target=subworker_complete_no_logging, args=(logger,)) + process.start() + process.join() + + assert process.exitcode == 0 + + logger.info("Main") + logger.complete() + + assert writer.read() == "Main\n" + + +@pytest.mark.skipif(os.name == "nt", reason="Windows does not support forking") +def test_complete_in_child_without_logging_inheritance(fork_context): + writer = Writer() + + logger.add(writer, context=fork_context, format="{message}", enqueue=True, catch=False) + + process = fork_context.Process(target=subworker_complete_no_logging_inheritance) + process.start() + process.join() + + assert process.exitcode == 0 + + logger.info("Main") + logger.complete() + + assert writer.read() == "Main\n" + + def test_await_complete_spawn(capsys, spawn_context): async def writer(msg): print(msg, end="") @@ -604,6 +750,241 @@ def test_no_deadlock_if_external_lock_in_use(enqueue, capsys, fork_context): assert err == "".join("This is a message: %d\n" % i for i in range(num)) +@pytest.mark.skipif(os.name == "nt", reason="Windows does not support forking") +def test_concurrent_logging_from_multiple_children(capsys, fork_context): + writer = Writer() + num = 10 + log_count = 100 + sentence = "This is some message from a child process." + + barrier = fork_context.Barrier(num) + + def sink(message): + for character in message: + writer.write(character) + + def worker(): + barrier.wait() + for _ in range(log_count): + logger.info(sentence) + + logger.add(sink, context=fork_context, format="{message}", enqueue=True, catch=False) + + processes = [] + + for _ in range(num): + process = fork_context.Process(target=worker) + process.start() + processes.append(process) + + for process in processes: + process.join(5) + assert process.exitcode == 0 + + logger.complete() + + assert writer.read() == (sentence + "\n") * num * log_count + + out, err = capsys.readouterr() + assert out == err == "" + + +@pytest.mark.skipif(os.name == "nt", reason="Windows does not support forking") +def test_remove_from_main_while_child_is_processing(fork_context): + barrier = fork_context.Barrier(2) + + count = 10000 + + def worker(): + for _ in range(count): + logger.info(".") + barrier.wait() + + logger.add(lambda m: None, enqueue=True, context=fork_context, catch=False, format="{message}") + + process = fork_context.Process(target=worker) + process.start() + + barrier.wait() + logger.remove() + + process.join(1) + assert process.exitcode == 0 + + +@pytest.mark.skipif(os.name == "nt", reason="Windows does not support forking") +def test_remove_from_main_while_pipe_is_full(fork_context): + barrier = fork_context.Barrier(2) + + def worker(): + for _ in range(10): + logger.info("." * 100000) + barrier.wait() + + logger.add(lambda m: None, enqueue=True, context=fork_context, catch=False, format="{message}") + + process = fork_context.Process(target=worker) + process.start() + + barrier.wait() + logger.remove() + + process.join(1) + assert process.exitcode == 0 + + +@pytest.mark.parametrize("init", [True, False]) +@pytest.mark.parametrize("complete", [True, False]) +@pytest.mark.skipif(os.name == "nt", reason="Windows does not support forking") +def test_remove_from_main_while_children_processing_big_messages(fork_context, init, complete): + barrier1 = fork_context.Barrier(31) + barrier2 = fork_context.Barrier(31) + + def worker(): + barrier1.wait() + for i in range(50): + logger.bind(data="." * 100000).info(i) + barrier2.wait() + for i in range(50): + logger.bind(data="." * 100000).info(i) + if complete: + logger.complete() + + logger.add( + lambda _: None, + enqueue=True, + format="{message}", + context=fork_context, + catch=False, + ) + + if init: + logger.info("Init") + + processes = [] + + for _ in range(30): + process = fork_context.Process(target=worker) + process.start() + processes.append(process) + + barrier1.wait() + barrier2.wait() + logger.remove() + + for process in processes: + process.join(5) + assert process.exitcode == 0 + + +@pytest.mark.skipif(os.name == "nt", reason="Windows does not support forking") +def test_log_and_complete_concurrently_initialize_queue_thread(fork_context, writer): + def worker(i): + barrier = threading.Barrier(2) + + def log(): + barrier.wait() + logger.info(i) + + def complete(): + barrier.wait() + logger.complete() + + thread_1 = threading.Thread(target=log) + thread_2 = threading.Thread(target=complete) + + thread_1.start() + thread_2.start() + + thread_1.join() + thread_2.join() + + logger.add( + writer, + enqueue=True, + format="{message}", + context=fork_context, + catch=False, + ) + + logger.info("Init") + logger.complete() + + for i in range(100): + process = fork_context.Process(target=worker, args=(i,)) + process.start() + process.join(1) + assert process.exitcode == 0 + + logger.complete() + + assert writer.read() == "Init\n" + "".join("%d\n" % i for i in range(100)) + + +@pytest.mark.skipif(os.name == "nt", reason="Windows does not support forking") +def test_concurrent_remove_and_complete(capsys, fork_context): + barrier = fork_context.Barrier(2) + event = fork_context.Event() + num = 100 + + handler_ids = [] + + for _ in range(num): + handler_id = logger.add(lambda m: None, enqueue=True, context=fork_context, catch=False) + handler_ids.append(handler_id) + logger.info("Message") + + def worker(): + barrier.wait() + while not event.is_set(): + logger.complete() + + process = fork_context.Process(target=worker) + process.start() + + barrier.wait() + + for handler_id in handler_ids: + logger.remove(handler_id) + + event.set() + + process.join(1) + assert process.exitcode == 0 + + out, err = capsys.readouterr() + assert out == err == "" + + +@pytest.mark.skipif(os.name == "nt", reason="Windows does not support forking") +def test_creating_machinery_locks_and_concurrent_forking(capsys, fork_context): + running = True + + def worker_thread(): + while running: + i = logger.add(lambda _: None, enqueue=True, context=fork_context, catch=False) + logger.remove(i) + + def worker_process(): + logger.info("Message") + logger.complete() + + thread = threading.Thread(target=worker_thread) + thread.start() + + for _ in range(100): + process = fork_context.Process(target=worker_process) + process.start() + process.join(1) + assert process.exitcode == 0 + + running = False + thread.join() + + out, err = capsys.readouterr() + assert out == err == "" + + @pytest.mark.skipif(os.name == "nt", reason="Windows does not support forking") @pytest.mark.skipif(platform.python_implementation() == "PyPy", reason="PyPy is too slow") def test_complete_from_multiple_child_processes(capsys, fork_context): @@ -629,3 +1010,4 @@ def worker(barrier): out, err = capsys.readouterr() assert out == err == "" + assert out == err == ""