Skip to content

Commit

Permalink
Kill all workers when main process exits in prefork model
Browse files Browse the repository at this point in the history
  • Loading branch information
matusvalo committed Sep 4, 2021
1 parent 917088f commit ef95ed3
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 9 deletions.
4 changes: 4 additions & 0 deletions celery/concurrency/prefork.py
Expand Up @@ -8,6 +8,7 @@
from billiard.common import REMAP_SIGTERM, TERM_SIGNAME
from billiard.pool import CLOSE, RUN
from billiard.pool import Pool as BlockingPool
from billiard.util import set_pdeathsig

from celery import platforms, signals
from celery._state import _set_task_join_will_block, set_default_app
Expand Down Expand Up @@ -41,6 +42,9 @@ def process_initializer(app, hostname):
Initialize the child pool process to ensure the correct
app instance is used and things like logging works.
"""
# Each running worker gets SIGKILL by OS when main process exits.
if platforms.signals.supported('SIGKILL'):
set_pdeathsig(platforms.signals.signum('SIGKILL'))
_set_task_join_will_block(True)
platforms.signals.reset(*WORKER_SIGRESET)
platforms.signals.ignore(*WORKER_SIGIGNORE)
Expand Down
37 changes: 28 additions & 9 deletions t/unit/concurrency/test_prefork.py
@@ -1,9 +1,11 @@
import errno
import os
import socket
import signal
from itertools import cycle
from unittest.mock import Mock, patch

from billiard.util import set_pdeathsig
import pytest
from case import mock

Expand Down Expand Up @@ -53,11 +55,18 @@ def get(self):
return self.value


@patch('celery.platforms.set_mp_process_title')
class test_process_initializer:

@staticmethod
def Loader(*args, **kwargs):
loader = Mock(*args, **kwargs)
loader.conf = {}
loader.override_backends = {}
return loader

@patch('celery.platforms.signals')
@patch('celery.platforms.set_mp_process_title')
def test_process_initializer(self, set_mp_process_title, _signals):
def test_process_initializer(self, _signals, set_mp_process_title):
with mock.restore_logging():
from celery import signals
from celery._state import _tls
Expand All @@ -67,13 +76,7 @@ def test_process_initializer(self, set_mp_process_title, _signals):
on_worker_process_init = Mock()
signals.worker_process_init.connect(on_worker_process_init)

def Loader(*args, **kwargs):
loader = Mock(*args, **kwargs)
loader.conf = {}
loader.override_backends = {}
return loader

with self.Celery(loader=Loader) as app:
with self.Celery(loader=self.Loader) as app:
app.conf = AttributeDict(DEFAULTS)
process_initializer(app, 'awesome.worker.com')
_signals.ignore.assert_any_call(*WORKER_SIGIGNORE)
Expand All @@ -100,6 +103,22 @@ def Loader(*args, **kwargs):
finally:
os.environ.pop('CELERY_LOG_FILE', None)

@patch('celery.concurrency.prefork.set_pdeathsig', wraps=set_pdeathsig)
def test_pdeath_sig(self, _set_pdeathsig, set_mp_process_title):
with mock.restore_logging():
from celery import signals
on_worker_process_init = Mock()
signals.worker_process_init.connect(on_worker_process_init)
from celery.concurrency.prefork import process_initializer

with self.Celery(loader=self.Loader) as app:
app.conf = AttributeDict(DEFAULTS)
process_initializer(app, 'awesome.worker.com')
if hasattr(signal, 'SIGKILL'):
_set_pdeathsig.assert_called_once_with(signal.SIGKILL)
else:
_set_pdeathsig.assert_not_called()


class test_process_destructor:

Expand Down

0 comments on commit ef95ed3

Please sign in to comment.