Skip to content

Commit

Permalink
Kill all workers when main process exits in prefork model (#6942)
Browse files Browse the repository at this point in the history
* Kill all workers when main process exits in prefork model

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Make flake8 happy

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
matusvalo and pre-commit-ci[bot] committed Sep 5, 2021
1 parent 917088f commit 8ae1215
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 10 deletions.
2 changes: 2 additions & 0 deletions celery/concurrency/prefork.py
Expand Up @@ -41,6 +41,8 @@ def process_initializer(app, hostname):
Initialize the child pool process to ensure the correct
app instance is used and things like logging works.
"""
# Each running worker gets SIGKILL by OS when main process exits.
platforms.set_pdeathsig('SIGKILL')
_set_task_join_will_block(True)
platforms.signals.reset(*WORKER_SIGRESET)
platforms.signals.ignore(*WORKER_SIGIGNORE)
Expand Down
11 changes: 11 additions & 0 deletions celery/platforms.py
Expand Up @@ -17,6 +17,7 @@
from contextlib import contextmanager

from billiard.compat import close_open_fds, get_fdmax
from billiard.util import set_pdeathsig as _set_pdeathsig
# fileno used to be in this module
from kombu.utils.compat import maybe_fileno
from kombu.utils.encoding import safe_str
Expand Down Expand Up @@ -708,6 +709,16 @@ def strargv(argv):
return ''


def set_pdeathsig(name):
"""Sends signal ``name`` to process when parent process terminates."""
if signals.supported('SIGKILL'):
try:
_set_pdeathsig(signals.signum('SIGKILL'))
except OSError:
# We ignore when OS does not support set_pdeathsig
pass


def set_process_title(progname, info=None):
"""Set the :command:`ps` name for the currently running process.
Expand Down
32 changes: 23 additions & 9 deletions t/unit/concurrency/test_prefork.py
Expand Up @@ -53,11 +53,18 @@ def get(self):
return self.value


@patch('celery.platforms.set_mp_process_title')
class test_process_initializer:

@staticmethod
def Loader(*args, **kwargs):
loader = Mock(*args, **kwargs)
loader.conf = {}
loader.override_backends = {}
return loader

@patch('celery.platforms.signals')
@patch('celery.platforms.set_mp_process_title')
def test_process_initializer(self, set_mp_process_title, _signals):
def test_process_initializer(self, _signals, set_mp_process_title):
with mock.restore_logging():
from celery import signals
from celery._state import _tls
Expand All @@ -67,13 +74,7 @@ def test_process_initializer(self, set_mp_process_title, _signals):
on_worker_process_init = Mock()
signals.worker_process_init.connect(on_worker_process_init)

def Loader(*args, **kwargs):
loader = Mock(*args, **kwargs)
loader.conf = {}
loader.override_backends = {}
return loader

with self.Celery(loader=Loader) as app:
with self.Celery(loader=self.Loader) as app:
app.conf = AttributeDict(DEFAULTS)
process_initializer(app, 'awesome.worker.com')
_signals.ignore.assert_any_call(*WORKER_SIGIGNORE)
Expand All @@ -100,6 +101,19 @@ def Loader(*args, **kwargs):
finally:
os.environ.pop('CELERY_LOG_FILE', None)

@patch('celery.platforms.set_pdeathsig')
def test_pdeath_sig(self, _set_pdeathsig, set_mp_process_title):
with mock.restore_logging():
from celery import signals
on_worker_process_init = Mock()
signals.worker_process_init.connect(on_worker_process_init)
from celery.concurrency.prefork import process_initializer

with self.Celery(loader=self.Loader) as app:
app.conf = AttributeDict(DEFAULTS)
process_initializer(app, 'awesome.worker.com')
_set_pdeathsig.assert_called_once_with('SIGKILL')


class test_process_destructor:

Expand Down
14 changes: 13 additions & 1 deletion t/unit/utils/test_platforms.py
Expand Up @@ -18,7 +18,7 @@
close_open_fds, create_pidlock, detached,
fd_by_path, get_fdmax, ignore_errno, initgroups,
isatty, maybe_drop_privileges, parse_gid,
parse_uid, set_mp_process_title,
parse_uid, set_mp_process_title, set_pdeathsig,
set_process_title, setgid, setgroups, setuid,
signals)
from celery.utils.text import WhateverIO
Expand Down Expand Up @@ -170,6 +170,18 @@ def test_setitem_raises(self, set):
signals['INT'] = lambda *a: a


class test_set_pdeathsig:

def test_call(self):
set_pdeathsig('SIGKILL')

@t.skip.if_win32
def test_call_with_correct_parameter(self):
with patch('celery.platforms._set_pdeathsig') as _set_pdeathsig:
set_pdeathsig('SIGKILL')
_set_pdeathsig.assert_called_once_with(signal.SIGKILL)


@t.skip.if_win32
class test_get_fdmax:

Expand Down

0 comments on commit 8ae1215

Please sign in to comment.