diff --git a/google/cloud/pubsub_v1/subscriber/scheduler.py b/google/cloud/pubsub_v1/subscriber/scheduler.py index a11ca490b..b8f2b592c 100644 --- a/google/cloud/pubsub_v1/subscriber/scheduler.py +++ b/google/cloud/pubsub_v1/subscriber/scheduler.py @@ -21,6 +21,7 @@ import abc import concurrent.futures import queue +import warnings class Scheduler(metaclass=abc.ABCMeta): @@ -114,7 +115,14 @@ def schedule(self, callback, *args, **kwargs): Returns: None """ - self._executor.submit(callback, *args, **kwargs) + try: + self._executor.submit(callback, *args, **kwargs) + except RuntimeError: + warnings.warn( + "Scheduling a callback after executor shutdown.", + category=RuntimeWarning, + stacklevel=2, + ) def shutdown(self, await_msg_callbacks=False): """Shut down the scheduler and immediately end all pending callbacks. @@ -142,6 +150,8 @@ def shutdown(self, await_msg_callbacks=False): try: while True: work_item = self._executor._work_queue.get(block=False) + if work_item is None: # Exceutor in shutdown mode. + continue dropped_messages.append(work_item.args[0]) except queue.Empty: pass diff --git a/tests/unit/pubsub_v1/subscriber/test_scheduler.py b/tests/unit/pubsub_v1/subscriber/test_scheduler.py index 82a6719d7..0545c967c 100644 --- a/tests/unit/pubsub_v1/subscriber/test_scheduler.py +++ b/tests/unit/pubsub_v1/subscriber/test_scheduler.py @@ -16,6 +16,7 @@ import queue import threading import time +import warnings import mock @@ -61,6 +62,24 @@ def callback(*args, **kwargs): assert sorted(called_with) == expected_calls +def test_schedule_after_executor_shutdown_warning(): + def callback(*args, **kwargs): + pass + + scheduler_ = scheduler.ThreadScheduler() + + scheduler_.schedule(callback, "arg1", kwarg1="meep") + scheduler_._executor.shutdown() + + with warnings.catch_warnings(record=True) as warned: + scheduler_.schedule(callback, "arg2", kwarg2="boop") + + assert len(warned) == 1 + assert issubclass(warned[0].category, RuntimeWarning) + warning_msg = str(warned[0].message) + assert "after executor shutdown" in warning_msg + + def test_shutdown_nonblocking_by_default(): called_with = [] at_least_one_called = threading.Event() @@ -125,3 +144,30 @@ def callback(message): err_msg = "Shutdown did not wait for the already running callbacks to complete." assert at_least_one_completed.is_set(), err_msg + + +def test_shutdown_handles_executor_queue_sentinels(): + at_least_one_called = threading.Event() + + def callback(_): + at_least_one_called.set() + time.sleep(1.0) + + executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + scheduler_ = scheduler.ThreadScheduler(executor=executor) + + scheduler_.schedule(callback, "message_1") + scheduler_.schedule(callback, "message_2") + scheduler_.schedule(callback, "message_3") + + # Simulate executor shutdown from another thread. + executor._work_queue.put(None) + executor._work_queue.put(None) + + at_least_one_called.wait() + dropped = scheduler_.shutdown(await_msg_callbacks=True) + + assert len(set(dropped)) == 2 # Also test for item uniqueness. + for msg in dropped: + assert msg is not None + assert msg.startswith("message_")