Add new "context" option for multiprocessing while adding handler (#852)
Delgan committed Apr 22, 2023
1 parent 9fc929a commit 9faba68
Showing 8 changed files with 159 additions and 93 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.rst
@@ -1,3 +1,9 @@
`Unreleased`_
=============

+- Add a new ``context`` optional argument to ``logger.add()`` specifying the ``multiprocessing`` context (like ``"spawn"`` or ``"fork"``) to be used internally instead of the default one (`#851 <https://github.com/Delgan/loguru/issues/851>`_).


`0.7.0`_ (2023-04-10)
=====================
43 changes: 4 additions & 39 deletions docs/resources/recipes.rst
@@ -1027,53 +1027,18 @@ The |multiprocessing| library is also commonly used to start a pool of workers
 Independently of the operating system, note that the process in which a handler is added with ``enqueue=True`` is in charge of the queue used internally. This means that you should not ``.remove()`` such a handler from the parent process if any child is likely to continue using it. More importantly, note that a |Thread| is started internally to consume the queue. Therefore, it is recommended to call |complete| before leaving |Process| to make sure the queue is left in a stable state.

-Another thing to keep in mind when dealing with multiprocessing is the fact that handlers created with ``enqueue=True`` create a queue internally in the current multiprocessing context. If they are passed through to a subprocess instantiated within a different context (e.g. with :code:`multiprocessing.get_context("spawn")` on Linux, where the default context is :code:`"fork"`) it will most likely result in crashing the subprocess. This is also noted in the `python multiprocessing docs <https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods>`_.
+Another thing to keep in mind when dealing with multiprocessing is the fact that handlers created with ``enqueue=True`` create a queue internally in the default multiprocessing context. If they are passed through to a subprocess instantiated within a different context (e.g. with ``multiprocessing.get_context("spawn")`` on Linux, where the default context is ``"fork"``) it will most likely result in crashing the subprocess. This is also noted in the `python multiprocessing docs <https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods>`_. To prevent any problems, you should specify the context to be used by Loguru while adding the handler. This can be done by passing the ``context`` argument to the ``add()`` method::

-So, running the following on Linux, where the default context is ``fork``, will not work since the handler is added in a different context:
-
-.. code::
-
-    # main.py
-    import multiprocessing
-    from loguru import logger
-    import workers_a
-    import workers_b
-
-    if __name__ == "__main__":
-        logger.remove()
-        logger.add("file.log", enqueue=True)
-        worker = workers_a.Worker()
-        with multiprocessing.get_context("spawn").Pool(4, initializer=worker.set_logger, initargs=(logger, )) as pool:
-            results = pool.map(worker.work, [1, 10, 100])
-        with multiprocessing.get_context("spawn").Pool(4, initializer=workers_b.set_logger, initargs=(logger, )) as pool:
-            results = pool.map(workers_b.work, [1, 10, 100])
-        logger.info("Done")
-
-To fix this, you can set the multiprocessing context globally so that the handler is created in the same context as the one the subprocesses run in:
-
-.. code::
-
     # main.py
     import multiprocessing
     from loguru import logger
     import workers_a
     import workers_b

     if __name__ == "__main__":
-        multiprocessing.set_start_method("spawn")
+        context = multiprocessing.get_context("spawn")

         logger.remove()
-        logger.add("file.log", enqueue=True)
+        logger.add("file.log", enqueue=True, context=context)

         worker = workers_a.Worker()
-        with multiprocessing.Pool(4, initializer=worker.set_logger, initargs=(logger, )) as pool:
+        with context.Pool(4, initializer=worker.set_logger, initargs=(logger, )) as pool:
             results = pool.map(worker.work, [1, 10, 100])
         with multiprocessing.Pool(4, initializer=workers_b.set_logger, initargs=(logger, )) as pool:
             results = pool.map(workers_b.work, [1, 10, 100])
         logger.info("Done")
1 change: 1 addition & 0 deletions loguru/_defaults.py
@@ -42,6 +42,7 @@ def env(key, type_, default=None):
 LOGURU_BACKTRACE = env("LOGURU_BACKTRACE", bool, True)
 LOGURU_DIAGNOSE = env("LOGURU_DIAGNOSE", bool, True)
 LOGURU_ENQUEUE = env("LOGURU_ENQUEUE", bool, False)
+LOGURU_CONTEXT = env("LOGURU_CONTEXT", str, None)
 LOGURU_CATCH = env("LOGURU_CATCH", bool, True)

 LOGURU_TRACE_NO = env("LOGURU_TRACE_NO", int, 5)
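Because this default is read through ``env()``, the new option can also be configured with the ``LOGURU_CONTEXT`` environment variable, like the other ``LOGURU_*`` settings. A minimal sketch (assuming the variable is set before Loguru is first imported):

    # Illustrative only: LOGURU_CONTEXT is read once, when loguru is first
    # imported, so it must be set earlier (e.g. in the shell: LOGURU_CONTEXT=spawn).
    import os

    os.environ["LOGURU_CONTEXT"] = "spawn"

    from loguru import logger  # the "context" parameter now defaults to "spawn"

    logger.add(lambda _: None, enqueue=True)  # internally uses get_context("spawn")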
9 changes: 5 additions & 4 deletions loguru/_handler.py
@@ -1,6 +1,5 @@
 import functools
 import json
-import multiprocessing
 import os
 import threading
 from contextlib import contextmanager
@@ -41,6 +40,7 @@ def __init__(
         colorize,
         serialize,
         enqueue,
+        multiprocessing_context,
         error_interceptor,
         exception_formatter,
         id_,
@@ -55,6 +55,7 @@ def __init__(
         self._colorize = colorize
         self._serialize = serialize
         self._enqueue = enqueue
+        self._multiprocessing_context = multiprocessing_context
         self._error_interceptor = error_interceptor
         self._exception_formatter = exception_formatter
         self._id = id_
@@ -86,9 +87,9 @@ def __init__(
         self._decolorized_format = self._formatter.strip()

         if self._enqueue:
-            self._queue = multiprocessing.SimpleQueue()
-            self._confirmation_event = multiprocessing.Event()
-            self._confirmation_lock = multiprocessing.Lock()
+            self._queue = self._multiprocessing_context.SimpleQueue()
+            self._confirmation_event = self._multiprocessing_context.Event()
+            self._confirmation_lock = self._multiprocessing_context.Lock()
             self._owner_process_pid = os.getpid()
             self._thread = Thread(
                 target=self._queued_writer, daemon=True, name="loguru-writer-%d" % self._id
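This is the core of the change: the handler's internal queue, event, and lock are now created from its own context rather than from the module-level default. A standalone sketch of the underlying ``multiprocessing`` behavior (plain stdlib, no Loguru):

    import multiprocessing

    # Synchronization primitives are bound to the context that creates them.
    ctx = multiprocessing.get_context("spawn")

    queue = ctx.SimpleQueue()   # what the handler now does, instead of multiprocessing.SimpleQueue()
    event = ctx.Event()         # ... instead of multiprocessing.Event()
    lock = ctx.Lock()           # ... instead of multiprocessing.Lock()


    def child(q):
        q.put("hello from a spawn-context child")


    if __name__ == "__main__":
        # A process started from the same context can safely inherit these objects.
        process = ctx.Process(target=child, args=(queue,))
        process.start()
        process.join()
        print(queue.get())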
23 changes: 21 additions & 2 deletions loguru/_logger.py
@@ -37,6 +37,7 @@
 .. |logging| replace:: :mod:`logging`
 .. |signal| replace:: :mod:`signal`
 .. |contextvars| replace:: :mod:`contextvars`
+.. |multiprocessing| replace:: :mod:`multiprocessing`
 .. |Thread.run| replace:: :meth:`Thread.run()<threading.Thread.run()>`
 .. |Exception| replace:: :class:`Exception`
 .. |AbstractEventLoop| replace:: :class:`AbstractEventLoop<asyncio.AbstractEventLoop>`
@@ -62,6 +63,9 @@
 .. _coroutine function: https://docs.python.org/3/glossary.html#term-coroutine-function
 .. |re.Pattern| replace:: ``re.Pattern``
 .. _re.Pattern: https://docs.python.org/3/library/re.html#re-objects
+.. |multiprocessing.Context| replace:: ``multiprocessing.Context``
+.. _multiprocessing.Context:
+   https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
 .. |better_exceptions| replace:: ``better_exceptions``
 .. _better_exceptions: https://github.com/Qix-/better-exceptions
@@ -82,7 +86,8 @@
 import warnings
 from collections import namedtuple
 from inspect import isclass, iscoroutinefunction, isgeneratorfunction
-from multiprocessing import current_process
+from multiprocessing import current_process, get_context
+from multiprocessing.context import BaseContext
 from os.path import basename, splitext
 from threading import current_thread
@@ -236,6 +241,7 @@ def add(
         backtrace=_defaults.LOGURU_BACKTRACE,
         diagnose=_defaults.LOGURU_DIAGNOSE,
         enqueue=_defaults.LOGURU_ENQUEUE,
+        context=_defaults.LOGURU_CONTEXT,
         catch=_defaults.LOGURU_CATCH,
         **kwargs
     ):
@@ -270,6 +276,10 @@
             Whether the messages to be logged should first pass through a multiprocessing-safe queue
             before reaching the sink. This is useful while logging to a file through multiple
             processes. This also has the advantage of making logging calls non-blocking.
+        context : |multiprocessing.Context| or |str|, optional
+            A context object or name that will be used internally for all tasks involving the
+            |multiprocessing| module, in particular when ``enqueue=True``. If ``None``, the default
+            context is used.
         catch : |bool|, optional
             Whether errors occurring while the sink handles log messages should be automatically
             caught. If ``True``, an exception message is displayed on |sys.stderr| but the exception
@@ -340,7 +350,7 @@
         The ``sink`` handles incoming log messages and proceeds to their writing somewhere and
         somehow. A sink can take many forms:

-        - A |file-like object|_ like ``sys.stderr`` or ``open("somefile.log", "w")``. Anything with
+        - A |file-like object|_ like ``sys.stderr`` or ``open("file.log", "w")``. Anything with
           a ``.write()`` method is considered as a file-like object. Custom handlers may also
           implement ``flush()`` (called after each logged message), ``stop()`` (called at sink
           termination) and ``complete()`` (awaited by the eponymous method).
@@ -948,6 +958,14 @@ def add(
         if not isinstance(encoding, str):
             encoding = "ascii"

+        if context is None or isinstance(context, str):
+            context = get_context(context)
+        elif not isinstance(context, BaseContext):
+            raise TypeError(
+                "Invalid context, it should be a string or a multiprocessing context, "
+                "not: '%s'" % type(context).__name__
+            )
+
         with self._core.lock:
             exception_formatter = ExceptionFormatter(
                 colorize=colorize,
@@ -968,6 +986,7 @@
                 colorize=colorize,
                 serialize=serialize,
                 enqueue=enqueue,
+                multiprocessing_context=context,
                 id_=handler_id,
                 error_interceptor=error_interceptor,
                 exception_formatter=exception_formatter,
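The validation above accepts three shapes of input: ``None`` (the default context), a start-method name (resolved with ``get_context()``, which itself raises ``ValueError`` for unknown names), or an existing ``BaseContext`` instance; anything else is rejected with a ``TypeError``. A quick sketch of the resulting ``add()`` behavior, assuming this commit:

    import multiprocessing

    from loguru import logger

    logger.add(lambda _: None, context=None)      # default context
    logger.add(lambda _: None, context="spawn")   # resolved via get_context("spawn")
    logger.add(lambda _: None, context=multiprocessing.get_context("spawn"))

    try:
        logger.add(lambda _: None, context=42)
    except TypeError as error:
        # Invalid context, it should be a string or a multiprocessing context, not: 'int'
        print(error)

    try:
        logger.add(lambda _: None, context="foobar")
    except ValueError as error:
        # Raised by get_context() itself: "cannot find context for 'foobar'"
        print(error)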
66 changes: 66 additions & 0 deletions tests/test_add_option_context.py
@@ -0,0 +1,66 @@
import multiprocessing
import os
from unittest.mock import MagicMock

import pytest

from loguru import logger


def get_handler_context():
    # No better way to test correct value than to access the private attribute.
    handler = next(iter(logger._core.handlers.values()))
    return handler._multiprocessing_context


def test_default_context():
    logger.add(lambda _: None, context=None)
    assert get_handler_context() == multiprocessing.get_context(None)


@pytest.mark.skipif(os.name == "nt", reason="Windows does not support forking")
@pytest.mark.parametrize("context_name", ["fork", "forkserver"])
def test_fork_context_as_string(context_name):
    logger.add(lambda _: None, context=context_name)
    assert get_handler_context() == multiprocessing.get_context(context_name)


def test_spawn_context_as_string():
    logger.add(lambda _: None, context="spawn")
    assert get_handler_context() == multiprocessing.get_context("spawn")


@pytest.mark.skipif(os.name == "nt", reason="Windows does not support forking")
@pytest.mark.parametrize("context_name", ["fork", "forkserver"])
def test_fork_context_as_object(context_name):
    context = multiprocessing.get_context(context_name)
    logger.add(lambda _: None, context=context)
    assert get_handler_context() == context


def test_spawn_context_as_object():
    context = multiprocessing.get_context("spawn")
    logger.add(lambda _: None, context=context)
    assert get_handler_context() == context


def test_context_effectively_used():
    default_context = multiprocessing.get_context()
    mocked_context = MagicMock(spec=default_context, wraps=default_context)
    logger.add(lambda _: None, context=mocked_context, enqueue=True)
    logger.complete()
    assert mocked_context.Lock.called


def test_invalid_context_name():
    with pytest.raises(ValueError, match=r"cannot find context for"):
        logger.add(lambda _: None, context="foobar")


@pytest.mark.parametrize("context", [42, object()])
def test_invalid_context_object(context):
    with pytest.raises(
        TypeError,
        match=r"Invalid context, it should be a string or a multiprocessing context",
    ):
        logger.add(lambda _: None, context=context)
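One detail that makes the equality assertions above reliable: in CPython, ``multiprocessing.get_context()`` returns a per-start-method singleton context object, so two lookups of the same name yield the very same object. A quick check:

    import multiprocessing

    # CPython keeps one concrete context object per start method, so repeated
    # lookups return the identical object and the "==" assertions above hold.
    assert multiprocessing.get_context("spawn") is multiprocessing.get_context("spawn")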
8 changes: 3 additions & 5 deletions tests/test_coroutine_sink.py
@@ -7,7 +7,6 @@

 import pytest

-import loguru
 from loguru import logger


@@ -590,14 +589,13 @@ async def write(self, message):


 def test_complete_with_sub_processes(monkeypatch, capsys):
-    ctx = multiprocessing.get_context("spawn")
-    monkeypatch.setattr(loguru._handler, "multiprocessing", ctx)
+    spawn_context = multiprocessing.get_context("spawn")

     loop = asyncio.new_event_loop()
     writer = Writer()
-    logger.add(writer.write, format="{message}", enqueue=True, loop=loop)
+    logger.add(writer.write, context=spawn_context, format="{message}", enqueue=True, loop=loop)

-    process = ctx.Process(target=subworker, args=[logger])
+    process = spawn_context.Process(target=subworker, args=[logger])
     process.start()
     process.join()
