From d40d02713c8c189937ae5c21d099b88a3131a59f Mon Sep 17 00:00:00 2001 From: Peter Lamut Date: Wed, 24 Mar 2021 09:36:58 +0100 Subject: [PATCH] fix: move await_msg_callbacks flag to subscribe() method (#320) * Revert "revert: add graceful streaming pull shutdown (#315)" This reverts commit 16bf58823020c6b20e03a21b8b1de46eda3a2340. * Move await_msg_callbacks to subscribe() method This is to keep the StreamingPullFuture's surface intact for compatibility with PubSub Lite client. * Make streaming pull close() method non-blocking * Add a blocking streaming pull shutdown sample * Refine docs on awaiting callbacks on shutdown --- .../subscriber/_protocol/dispatcher.py | 3 - .../subscriber/_protocol/heartbeater.py | 9 +- .../pubsub_v1/subscriber/_protocol/leaser.py | 2 +- .../_protocol/streaming_pull_manager.py | 108 +++++++++++++----- google/cloud/pubsub_v1/subscriber/client.py | 13 +++ .../cloud/pubsub_v1/subscriber/scheduler.py | 49 ++++++-- samples/snippets/subscriber.py | 57 +++++++++ samples/snippets/subscriber_test.py | 47 ++++++++ tests/system.py | 86 +++++++++++++- .../pubsub_v1/subscriber/test_dispatcher.py | 31 +++-- .../pubsub_v1/subscriber/test_heartbeater.py | 41 +++++-- .../unit/pubsub_v1/subscriber/test_leaser.py | 24 ++-- .../pubsub_v1/subscriber/test_scheduler.py | 85 ++++++++++++-- .../subscriber/test_streaming_pull_manager.py | 82 +++++++++---- .../subscriber/test_subscriber_client.py | 2 + 15 files changed, 533 insertions(+), 106 deletions(-) diff --git a/google/cloud/pubsub_v1/subscriber/_protocol/dispatcher.py b/google/cloud/pubsub_v1/subscriber/_protocol/dispatcher.py index 7a8950844..382c5c38a 100644 --- a/google/cloud/pubsub_v1/subscriber/_protocol/dispatcher.py +++ b/google/cloud/pubsub_v1/subscriber/_protocol/dispatcher.py @@ -99,9 +99,6 @@ def dispatch_callback(self, items): ValueError: If ``action`` isn't one of the expected actions "ack", "drop", "lease", "modify_ack_deadline" or "nack". """ - if not self._manager.is_active: - return - batched_commands = collections.defaultdict(list) for item in items: diff --git a/google/cloud/pubsub_v1/subscriber/_protocol/heartbeater.py b/google/cloud/pubsub_v1/subscriber/_protocol/heartbeater.py index 9cd84a1e2..fef158965 100644 --- a/google/cloud/pubsub_v1/subscriber/_protocol/heartbeater.py +++ b/google/cloud/pubsub_v1/subscriber/_protocol/heartbeater.py @@ -35,10 +35,11 @@ def __init__(self, manager, period=_DEFAULT_PERIOD): self._period = period def heartbeat(self): - """Periodically send heartbeats.""" - while self._manager.is_active and not self._stop_event.is_set(): - self._manager.heartbeat() - _LOGGER.debug("Sent heartbeat.") + """Periodically send streaming pull heartbeats. + """ + while not self._stop_event.is_set(): + if self._manager.heartbeat(): + _LOGGER.debug("Sent heartbeat.") self._stop_event.wait(timeout=self._period) _LOGGER.info("%s exiting.", _HEARTBEAT_WORKER_NAME) diff --git a/google/cloud/pubsub_v1/subscriber/_protocol/leaser.py b/google/cloud/pubsub_v1/subscriber/_protocol/leaser.py index adb1650d2..c1f8b46d2 100644 --- a/google/cloud/pubsub_v1/subscriber/_protocol/leaser.py +++ b/google/cloud/pubsub_v1/subscriber/_protocol/leaser.py @@ -124,7 +124,7 @@ def maintain_leases(self): ack IDs, then waits for most of that time (but with jitter), and repeats. """ - while self._manager.is_active and not self._stop_event.is_set(): + while not self._stop_event.is_set(): # Determine the appropriate duration for the lease. 
This is # based off of how long previous messages have taken to ack, with # a sensible default and within the ranges allowed by Pub/Sub. diff --git a/google/cloud/pubsub_v1/subscriber/_protocol/streaming_pull_manager.py b/google/cloud/pubsub_v1/subscriber/_protocol/streaming_pull_manager.py index de333c539..ac940de26 100644 --- a/google/cloud/pubsub_v1/subscriber/_protocol/streaming_pull_manager.py +++ b/google/cloud/pubsub_v1/subscriber/_protocol/streaming_pull_manager.py @@ -16,6 +16,7 @@ import collections import functools +import itertools import logging import threading import uuid @@ -36,6 +37,7 @@ from google.pubsub_v1 import types as gapic_types _LOGGER = logging.getLogger(__name__) +_REGULAR_SHUTDOWN_THREAD_NAME = "Thread-RegularStreamShutdown" _RPC_ERROR_THREAD_NAME = "Thread-OnRpcTerminated" _RETRYABLE_STREAM_ERRORS = ( exceptions.DeadlineExceeded, @@ -110,12 +112,20 @@ class StreamingPullManager(object): scheduler (~google.cloud.pubsub_v1.scheduler.Scheduler): The scheduler to use to process messages. If not provided, a thread pool-based scheduler will be used. + await_callbacks_on_shutdown (bool): + If ``True``, the shutdown thread will wait until all scheduler threads + terminate and only then proceed with shutting down the remaining running + helper threads. + + If ``False`` (default), the shutdown thread will shut the scheduler down, + but it will not wait for the currently executing scheduler threads to + terminate. + + This setting affects when the on close callbacks get invoked, and + consequently, when the StreamingPullFuture associated with the stream gets + resolved. """ - _UNARY_REQUESTS = True - """If set to True, this class will make requests over a separate unary - RPC instead of over the streaming RPC.""" - def __init__( self, client, @@ -123,11 +133,13 @@ def __init__( flow_control=types.FlowControl(), scheduler=None, use_legacy_flow_control=False, + await_callbacks_on_shutdown=False, ): self._client = client self._subscription = subscription self._flow_control = flow_control self._use_legacy_flow_control = use_legacy_flow_control + self._await_callbacks_on_shutdown = await_callbacks_on_shutdown self._ack_histogram = histogram.Histogram() self._last_histogram_size = 0 self._ack_deadline = 10 @@ -291,6 +303,9 @@ def activate_ordering_keys(self, ordering_keys): activate. May be empty. """ with self._pause_resume_lock: + if self._scheduler is None: + return # We are shutting down, don't try to dispatch any more messages. + self._messages_on_hold.activate_ordering_keys( ordering_keys, self._schedule_message_on_hold ) @@ -420,37 +435,36 @@ def send(self, request): If a RetryError occurs, the manager shutdown is triggered, and the error is re-raised. """ - if self._UNARY_REQUESTS: - try: - self._send_unary_request(request) - except exceptions.GoogleAPICallError: - _LOGGER.debug( - "Exception while sending unary RPC. This is typically " - "non-fatal as stream requests are best-effort.", - exc_info=True, - ) - except exceptions.RetryError as exc: - _LOGGER.debug( - "RetryError while sending unary RPC. Waiting on a transient " - "error resolution for too long, will now trigger shutdown.", - exc_info=False, - ) - # The underlying channel has been suffering from a retryable error - # for too long, time to give up and shut the streaming pull down. - self._on_rpc_done(exc) - raise - - else: - self._rpc.send(request) + try: + self._send_unary_request(request) + except exceptions.GoogleAPICallError: + _LOGGER.debug( + "Exception while sending unary RPC. 
This is typically " + "non-fatal as stream requests are best-effort.", + exc_info=True, + ) + except exceptions.RetryError as exc: + _LOGGER.debug( + "RetryError while sending unary RPC. Waiting on a transient " + "error resolution for too long, will now trigger shutdown.", + exc_info=False, + ) + # The underlying channel has been suffering from a retryable error + # for too long, time to give up and shut the streaming pull down. + self._on_rpc_done(exc) + raise def heartbeat(self): """Sends an empty request over the streaming pull RPC. - This always sends over the stream, regardless of if - ``self._UNARY_REQUESTS`` is set or not. + Returns: + bool: If a heartbeat request has actually been sent. """ if self._rpc is not None and self._rpc.is_active: self._rpc.send(gapic_types.StreamingPullRequest()) + return True + + return False def open(self, callback, on_callback_error): """Begin consuming messages. @@ -517,11 +531,29 @@ def close(self, reason=None): This method is idempotent. Additional calls will have no effect. + The method does not block, it delegates the shutdown operations to a background + thread. + Args: - reason (Any): The reason to close this. If None, this is considered + reason (Any): The reason to close this. If ``None``, this is considered an "intentional" shutdown. This is passed to the callbacks specified via :meth:`add_close_callback`. """ + thread = threading.Thread( + name=_REGULAR_SHUTDOWN_THREAD_NAME, + daemon=True, + target=self._shutdown, + kwargs={"reason": reason}, + ) + thread.start() + + def _shutdown(self, reason=None): + """Run the actual shutdown sequence (stop the stream and all helper threads). + + Args: + reason (Any): The reason to close the stream. If ``None``, this is + considered an "intentional" shutdown. + """ with self._closing: if self._closed: return @@ -534,7 +566,9 @@ def close(self, reason=None): # Shutdown all helper threads _LOGGER.debug("Stopping scheduler.") - self._scheduler.shutdown() + dropped_messages = self._scheduler.shutdown( + await_msg_callbacks=self._await_callbacks_on_shutdown + ) self._scheduler = None # Leaser and dispatcher reference each other through the shared @@ -548,11 +582,23 @@ def close(self, reason=None): # because the consumer gets shut down first. 
_LOGGER.debug("Stopping leaser.") self._leaser.stop() + + total = len(dropped_messages) + len( + self._messages_on_hold._messages_on_hold + ) + _LOGGER.debug(f"NACK-ing all not-yet-dispatched messages (total: {total}).") + messages_to_nack = itertools.chain( + dropped_messages, self._messages_on_hold._messages_on_hold + ) + for msg in messages_to_nack: + msg.nack() + _LOGGER.debug("Stopping dispatcher.") self._dispatcher.stop() self._dispatcher = None # dispatcher terminated, OK to dispose the leaser reference now self._leaser = None + _LOGGER.debug("Stopping heartbeater.") self._heartbeater.stop() self._heartbeater = None @@ -722,7 +768,7 @@ def _on_rpc_done(self, future): _LOGGER.info("RPC termination has signaled streaming pull manager shutdown.") error = _wrap_as_exception(future) thread = threading.Thread( - name=_RPC_ERROR_THREAD_NAME, target=self.close, kwargs={"reason": error} + name=_RPC_ERROR_THREAD_NAME, target=self._shutdown, kwargs={"reason": error} ) thread.daemon = True thread.start() diff --git a/google/cloud/pubsub_v1/subscriber/client.py b/google/cloud/pubsub_v1/subscriber/client.py index f306d2d99..51bdc106c 100644 --- a/google/cloud/pubsub_v1/subscriber/client.py +++ b/google/cloud/pubsub_v1/subscriber/client.py @@ -122,6 +122,7 @@ def subscribe( flow_control=(), scheduler=None, use_legacy_flow_control=False, + await_callbacks_on_shutdown=False, ): """Asynchronously start receiving messages on a given subscription. @@ -199,6 +200,17 @@ def callback(message): *scheduler* to use when executing the callback. This controls how callbacks are executed concurrently. This object must not be shared across multiple SubscriberClients. + await_callbacks_on_shutdown (bool): + If ``True``, after canceling the returned future, the latter's + ``result()`` method will block until the background stream and its + helper threads have been terminated, and all currently executing message + callbacks are done processing. + + If ``False`` (default), the returned future's ``result()`` method will + not block after canceling the future. The method will instead return + immediately after the background stream and its helper threads have been + terminated, but some of the message callback threads might still be + running at that point. Returns: A :class:`~google.cloud.pubsub_v1.subscriber.futures.StreamingPullFuture` @@ -212,6 +224,7 @@ def callback(message): flow_control=flow_control, scheduler=scheduler, use_legacy_flow_control=use_legacy_flow_control, + await_callbacks_on_shutdown=await_callbacks_on_shutdown, ) future = futures.StreamingPullFuture(manager) diff --git a/google/cloud/pubsub_v1/subscriber/scheduler.py b/google/cloud/pubsub_v1/subscriber/scheduler.py index 84f494eb9..dd623517c 100644 --- a/google/cloud/pubsub_v1/subscriber/scheduler.py +++ b/google/cloud/pubsub_v1/subscriber/scheduler.py @@ -54,8 +54,21 @@ def schedule(self, callback, *args, **kwargs): raise NotImplementedError @abc.abstractmethod - def shutdown(self): + def shutdown(self, await_msg_callbacks=False): """Shuts down the scheduler and immediately end all pending callbacks. + + Args: + await_msg_callbacks (bool): + If ``True``, the method will block until all currently executing + callbacks are done processing. If ``False`` (default), the + method will not wait for the currently running callbacks to complete. + + Returns: + List[pubsub_v1.subscriber.message.Message]: + The messages submitted to the scheduler that were not yet dispatched + to their callbacks. 
+ It is assumed that each message was submitted to the scheduler as the + first positional argument to the provided callback. """ raise NotImplementedError @@ -103,15 +116,35 @@ def schedule(self, callback, *args, **kwargs): """ self._executor.submit(callback, *args, **kwargs) - def shutdown(self): - """Shuts down the scheduler and immediately end all pending callbacks. + def shutdown(self, await_msg_callbacks=False): + """Shut down the scheduler and immediately end all pending callbacks. + + Args: + await_msg_callbacks (bool): + If ``True``, the method will block until all currently executing + executor threads are done processing. If ``False`` (default), the + method will not wait for the currently running threads to complete. + + Returns: + List[pubsub_v1.subscriber.message.Message]: + The messages submitted to the scheduler that were not yet dispatched + to their callbacks. + It is assumed that each message was submitted to the scheduler as the + first positional argument to the provided callback. """ - # Drop all pending item from the executor. Without this, the executor - # will block until all pending items are complete, which is - # undesirable. + dropped_messages = [] + + # Drop all pending item from the executor. Without this, the executor will also + # try to process any pending work items before termination, which is undesirable. + # + # TODO: Replace the logic below by passing `cancel_futures=True` to shutdown() + # once we only need to support Python 3.9+. try: while True: - self._executor._work_queue.get(block=False) + work_item = self._executor._work_queue.get(block=False) + dropped_messages.append(work_item.args[0]) except queue.Empty: pass - self._executor.shutdown() + + self._executor.shutdown(wait=await_msg_callbacks) + return dropped_messages diff --git a/samples/snippets/subscriber.py b/samples/snippets/subscriber.py index aa5771b86..112c5a96a 100644 --- a/samples/snippets/subscriber.py +++ b/samples/snippets/subscriber.py @@ -478,6 +478,50 @@ def callback(message): # [END pubsub_subscriber_flow_settings] +def receive_messages_with_blocking_shutdown(project_id, subscription_id, timeout=5.0): + """Shuts down a pull subscription by awaiting message callbacks to complete.""" + # [START pubsub_subscriber_blocking_shutdown] + import time + from concurrent.futures import TimeoutError + from google.cloud import pubsub_v1 + + # TODO(developer) + # project_id = "your-project-id" + # subscription_id = "your-subscription-id" + # Number of seconds the subscriber should listen for messages + # timeout = 5.0 + + subscriber = pubsub_v1.SubscriberClient() + subscription_path = subscriber.subscription_path(project_id, subscription_id) + + def callback(message): + print(f"Received {message.data}.") + time.sleep(timeout + 5.0) # Pocess longer than streaming pull future timeout. + message.ack() + print(f"Done processing the message {message.data}.") + + streaming_pull_future = subscriber.subscribe( + subscription_path, callback=callback, await_callbacks_on_shutdown=True, + ) + print(f"Listening for messages on {subscription_path}..\n") + + # Wrap subscriber in a 'with' block to automatically call close() when done. + with subscriber: + try: + # When `timeout` is not set, result() will block indefinitely, + # unless an exception is encountered first. + streaming_pull_future.result(timeout=timeout) + except TimeoutError: + streaming_pull_future.cancel() + print("Streaming pull future canceled.") + streaming_pull_future.result() # Blocks until shutdown complete. 
+ print("Done waiting for the stream shutdown.") + + # The "Done waiting..." message is only printed *after* the processing of all + # received messages has completed. + # [END pubsub_subscriber_blocking_shutdown] + + def synchronous_pull(project_id, subscription_id): """Pulling messages synchronously.""" # [START pubsub_subscriber_sync_pull] @@ -749,6 +793,15 @@ def callback(message): "timeout", default=None, type=float, nargs="?" ) + receive_with_blocking_shutdown_parser = subparsers.add_parser( + "receive-blocking-shutdown", + help=receive_messages_with_blocking_shutdown.__doc__, + ) + receive_with_blocking_shutdown_parser.add_argument("subscription_id") + receive_with_blocking_shutdown_parser.add_argument( + "timeout", default=None, type=float, nargs="?" + ) + synchronous_pull_parser = subparsers.add_parser( "receive-synchronously", help=synchronous_pull.__doc__ ) @@ -827,6 +880,10 @@ def callback(message): receive_messages_with_flow_control( args.project_id, args.subscription_id, args.timeout ) + elif args.command == "receive-blocking-shutdown": + receive_messages_with_blocking_shutdown( + args.project_id, args.subscription_id, args.timeout + ) elif args.command == "receive-synchronously": synchronous_pull(args.project_id, args.subscription_id) elif args.command == "receive-synchronously-with-lease": diff --git a/samples/snippets/subscriber_test.py b/samples/snippets/subscriber_test.py index de54598a5..8d034949d 100644 --- a/samples/snippets/subscriber_test.py +++ b/samples/snippets/subscriber_test.py @@ -13,6 +13,7 @@ # limitations under the License. import os +import re import sys import uuid @@ -395,6 +396,52 @@ def test_receive_with_flow_control(publisher_client, topic, subscription_async, assert "message" in out +def test_receive_with_blocking_shutdown( + publisher_client, topic, subscription_async, capsys +): + _publish_messages(publisher_client, topic, message_num=3) + + subscriber.receive_messages_with_blocking_shutdown( + PROJECT_ID, SUBSCRIPTION_ASYNC, timeout=5.0 + ) + + out, _ = capsys.readouterr() + out_lines = out.splitlines() + + msg_received_lines = [ + i for i, line in enumerate(out_lines) + if re.search(r".*received.*message.*", line, flags=re.IGNORECASE) + ] + msg_done_lines = [ + i for i, line in enumerate(out_lines) + if re.search(r".*done processing.*message.*", line, flags=re.IGNORECASE) + ] + stream_canceled_lines = [ + i for i, line in enumerate(out_lines) + if re.search(r".*streaming pull future canceled.*", line, flags=re.IGNORECASE) + ] + shutdown_done_waiting_lines = [ + i for i, line in enumerate(out_lines) + if re.search(r".*done waiting.*stream shutdown.*", line, flags=re.IGNORECASE) + ] + + assert "Listening" in out + assert subscription_async in out + + assert len(stream_canceled_lines) == 1 + assert len(shutdown_done_waiting_lines) == 1 + assert len(msg_received_lines) == 3 + assert len(msg_done_lines) == 3 + + # The stream should have been canceled *after* receiving messages, but before + # message processing was done. + assert msg_received_lines[-1] < stream_canceled_lines[0] < msg_done_lines[0] + + # Yet, waiting on the stream shutdown should have completed *after* the processing + # of received messages has ended. 
+ assert msg_done_lines[-1] < shutdown_done_waiting_lines[0] + + def test_listen_for_errors(publisher_client, topic, subscription_async, capsys): _publish_messages(publisher_client, topic) diff --git a/tests/system.py b/tests/system.py index 512d75a5c..181632d79 100644 --- a/tests/system.py +++ b/tests/system.py @@ -14,6 +14,7 @@ from __future__ import absolute_import +import concurrent.futures import datetime import itertools import operator as op @@ -608,6 +609,82 @@ def test_streaming_pull_max_messages( finally: subscription_future.cancel() # trigger clean shutdown + def test_streaming_pull_blocking_shutdown( + self, publisher, topic_path, subscriber, subscription_path, cleanup + ): + # Make sure the topic and subscription get deleted. + cleanup.append((publisher.delete_topic, (), {"topic": topic_path})) + cleanup.append( + (subscriber.delete_subscription, (), {"subscription": subscription_path}) + ) + + # The ACK-s are only persisted if *all* messages published in the same batch + # are ACK-ed. We thus publish each message in its own batch so that the backend + # treats all messages' ACKs independently of each other. + publisher.create_topic(name=topic_path) + subscriber.create_subscription(name=subscription_path, topic=topic_path) + _publish_messages(publisher, topic_path, batch_sizes=[1] * 10) + + # Artificially delay message processing, gracefully shutdown the streaming pull + # in the meantime, then verify that those messages were nevertheless processed. + processed_messages = [] + + def callback(message): + time.sleep(15) + processed_messages.append(message.data) + message.ack() + + # Flow control limits should exceed the number of worker threads, so that some + # of the messages will be blocked on waiting for free scheduler threads. + flow_control = pubsub_v1.types.FlowControl(max_messages=5) + executor = concurrent.futures.ThreadPoolExecutor(max_workers=3) + scheduler = pubsub_v1.subscriber.scheduler.ThreadScheduler(executor=executor) + subscription_future = subscriber.subscribe( + subscription_path, + callback=callback, + flow_control=flow_control, + scheduler=scheduler, + await_callbacks_on_shutdown=True, + ) + + try: + subscription_future.result(timeout=10) # less than the sleep in callback + except exceptions.TimeoutError: + subscription_future.cancel() + subscription_future.result() # block until shutdown completes + + # Blocking om shutdown should have waited for the already executing + # callbacks to finish. + assert len(processed_messages) == 3 + + # The messages that were not processed should have been NACK-ed and we should + # receive them again quite soon. + all_done = threading.Barrier(7 + 1, timeout=5) # +1 because of the main thread + remaining = [] + + def callback2(message): + remaining.append(message.data) + message.ack() + all_done.wait() + + subscription_future = subscriber.subscribe( + subscription_path, callback=callback2, await_callbacks_on_shutdown=False + ) + + try: + all_done.wait() + except threading.BrokenBarrierError: # PRAGMA: no cover + pytest.fail("The remaining messages have not been re-delivered in time.") + finally: + subscription_future.cancel() + subscription_future.result() # block until shutdown completes + + # There should be 7 messages left that were not yet processed and none of them + # should be a message that should have already been sucessfully processed in the + # first streaming pull. 
+ assert len(remaining) == 7 + assert not (set(processed_messages) & set(remaining)) # no re-delivery + @pytest.mark.skipif( "KOKORO_GFILE_DIR" not in os.environ, @@ -789,8 +866,8 @@ def _publish_messages(publisher, topic_path, batch_sizes): publish_futures = [] msg_counter = itertools.count(start=1) - for batch_size in batch_sizes: - msg_batch = _make_messages(count=batch_size) + for batch_num, batch_size in enumerate(batch_sizes, start=1): + msg_batch = _make_messages(count=batch_size, batch_num=batch_num) for msg in msg_batch: future = publisher.publish(topic_path, msg, seq_num=str(next(msg_counter))) publish_futures.append(future) @@ -801,9 +878,10 @@ def _publish_messages(publisher, topic_path, batch_sizes): future.result(timeout=30) -def _make_messages(count): +def _make_messages(count, batch_num): messages = [ - "message {}/{}".format(i, count).encode("utf-8") for i in range(1, count + 1) + f"message {i}/{count} of batch {batch_num}".encode("utf-8") + for i in range(1, count + 1) ] return messages diff --git a/tests/unit/pubsub_v1/subscriber/test_dispatcher.py b/tests/unit/pubsub_v1/subscriber/test_dispatcher.py index 097ff46af..84e04df1b 100644 --- a/tests/unit/pubsub_v1/subscriber/test_dispatcher.py +++ b/tests/unit/pubsub_v1/subscriber/test_dispatcher.py @@ -29,14 +29,14 @@ @pytest.mark.parametrize( "item,method_name", [ - (requests.AckRequest(0, 0, 0, ""), "ack"), - (requests.DropRequest(0, 0, ""), "drop"), - (requests.LeaseRequest(0, 0, ""), "lease"), - (requests.ModAckRequest(0, 0), "modify_ack_deadline"), - (requests.NackRequest(0, 0, ""), "nack"), + (requests.AckRequest("0", 0, 0, ""), "ack"), + (requests.DropRequest("0", 0, ""), "drop"), + (requests.LeaseRequest("0", 0, ""), "lease"), + (requests.ModAckRequest("0", 0), "modify_ack_deadline"), + (requests.NackRequest("0", 0, ""), "nack"), ], ) -def test_dispatch_callback(item, method_name): +def test_dispatch_callback_active_manager(item, method_name): manager = mock.create_autospec( streaming_pull_manager.StreamingPullManager, instance=True ) @@ -50,16 +50,29 @@ def test_dispatch_callback(item, method_name): method.assert_called_once_with([item]) -def test_dispatch_callback_inactive(): +@pytest.mark.parametrize( + "item,method_name", + [ + (requests.AckRequest("0", 0, 0, ""), "ack"), + (requests.DropRequest("0", 0, ""), "drop"), + (requests.LeaseRequest("0", 0, ""), "lease"), + (requests.ModAckRequest("0", 0), "modify_ack_deadline"), + (requests.NackRequest("0", 0, ""), "nack"), + ], +) +def test_dispatch_callback_inactive_manager(item, method_name): manager = mock.create_autospec( streaming_pull_manager.StreamingPullManager, instance=True ) manager.is_active = False dispatcher_ = dispatcher.Dispatcher(manager, mock.sentinel.queue) - dispatcher_.dispatch_callback([requests.AckRequest(0, 0, 0, "")]) + items = [item] - manager.send.assert_not_called() + with mock.patch.object(dispatcher_, method_name) as method: + dispatcher_.dispatch_callback(items) + + method.assert_called_once_with([item]) def test_ack(): diff --git a/tests/unit/pubsub_v1/subscriber/test_heartbeater.py b/tests/unit/pubsub_v1/subscriber/test_heartbeater.py index 8f5049691..1a52af231 100644 --- a/tests/unit/pubsub_v1/subscriber/test_heartbeater.py +++ b/tests/unit/pubsub_v1/subscriber/test_heartbeater.py @@ -22,22 +22,44 @@ import pytest -def test_heartbeat_inactive(caplog): - caplog.set_level(logging.INFO) +def test_heartbeat_inactive_manager_active_rpc(caplog): + caplog.set_level(logging.DEBUG) + + manager = mock.create_autospec( + 
streaming_pull_manager.StreamingPullManager, instance=True + ) + manager.is_active = False + manager.heartbeat.return_value = True # because of active rpc + + heartbeater_ = heartbeater.Heartbeater(manager) + make_sleep_mark_event_as_done(heartbeater_) + + heartbeater_.heartbeat() + + assert "Sent heartbeat" in caplog.text + assert "exiting" in caplog.text + + +def test_heartbeat_inactive_manager_inactive_rpc(caplog): + caplog.set_level(logging.DEBUG) + manager = mock.create_autospec( streaming_pull_manager.StreamingPullManager, instance=True ) manager.is_active = False + manager.heartbeat.return_value = False # because of inactive rpc heartbeater_ = heartbeater.Heartbeater(manager) + make_sleep_mark_event_as_done(heartbeater_) heartbeater_.heartbeat() + assert "Sent heartbeat" not in caplog.text assert "exiting" in caplog.text def test_heartbeat_stopped(caplog): - caplog.set_level(logging.INFO) + caplog.set_level(logging.DEBUG) manager = mock.create_autospec( streaming_pull_manager.StreamingPullManager, instance=True ) @@ -47,17 +69,18 @@ def test_heartbeat_stopped(caplog): heartbeater_.heartbeat() + assert "Sent heartbeat" not in caplog.text assert "exiting" in caplog.text -def make_sleep_mark_manager_as_inactive(heartbeater): - # Make sleep mark the manager as inactive so that heartbeat() +def make_sleep_mark_event_as_done(heartbeater): + # Make sleep actually trigger the done event so that heartbeat() # exits at the end of the first run. - def trigger_inactive(timeout): + def trigger_done(timeout): assert timeout - heartbeater._manager.is_active = False + heartbeater._stop_event.set() - heartbeater._stop_event.wait = trigger_inactive + heartbeater._stop_event.wait = trigger_done def test_heartbeat_once(): @@ -65,7 +88,7 @@ def test_heartbeat_once(): streaming_pull_manager.StreamingPullManager, instance=True ) heartbeater_ = heartbeater.Heartbeater(manager) - make_sleep_mark_manager_as_inactive(heartbeater_) + make_sleep_mark_event_as_done(heartbeater_) heartbeater_.heartbeat() diff --git a/tests/unit/pubsub_v1/subscriber/test_leaser.py b/tests/unit/pubsub_v1/subscriber/test_leaser.py index 17409cb3f..2ecc0b9f3 100644 --- a/tests/unit/pubsub_v1/subscriber/test_leaser.py +++ b/tests/unit/pubsub_v1/subscriber/test_leaser.py @@ -88,15 +88,21 @@ def create_manager(flow_control=types.FlowControl()): return manager -def test_maintain_leases_inactive(caplog): +def test_maintain_leases_inactive_manager(caplog): caplog.set_level(logging.INFO) manager = create_manager() manager.is_active = False leaser_ = leaser.Leaser(manager) + make_sleep_mark_event_as_done(leaser_) + leaser_.add( + [requests.LeaseRequest(ack_id="my_ack_ID", byte_size=42, ordering_key="")] + ) leaser_.maintain_leases() + # Leases should still be maintained even if the manager is inactive. + manager.dispatcher.modify_ack_deadline.assert_called() assert "exiting" in caplog.text @@ -112,20 +118,20 @@ def test_maintain_leases_stopped(caplog): assert "exiting" in caplog.text -def make_sleep_mark_manager_as_inactive(leaser): - # Make sleep mark the manager as inactive so that maintain_leases +def make_sleep_mark_event_as_done(leaser): + # Make sleep actually trigger the done event so that heartbeat() # exits at the end of the first run. 
- def trigger_inactive(timeout): + def trigger_done(timeout): assert 0 < timeout < 10 - leaser._manager.is_active = False + leaser._stop_event.set() - leaser._stop_event.wait = trigger_inactive + leaser._stop_event.wait = trigger_done def test_maintain_leases_ack_ids(): manager = create_manager() leaser_ = leaser.Leaser(manager) - make_sleep_mark_manager_as_inactive(leaser_) + make_sleep_mark_event_as_done(leaser_) leaser_.add( [requests.LeaseRequest(ack_id="my ack id", byte_size=50, ordering_key="")] ) @@ -140,7 +146,7 @@ def test_maintain_leases_ack_ids(): def test_maintain_leases_no_ack_ids(): manager = create_manager() leaser_ = leaser.Leaser(manager) - make_sleep_mark_manager_as_inactive(leaser_) + make_sleep_mark_event_as_done(leaser_) leaser_.maintain_leases() @@ -151,7 +157,7 @@ def test_maintain_leases_no_ack_ids(): def test_maintain_leases_outdated_items(time): manager = create_manager() leaser_ = leaser.Leaser(manager) - make_sleep_mark_manager_as_inactive(leaser_) + make_sleep_mark_event_as_done(leaser_) # Add and start expiry timer at the beginning of the timeline. time.return_value = 0 diff --git a/tests/unit/pubsub_v1/subscriber/test_scheduler.py b/tests/unit/pubsub_v1/subscriber/test_scheduler.py index 2ed1ea55a..82a6719d7 100644 --- a/tests/unit/pubsub_v1/subscriber/test_scheduler.py +++ b/tests/unit/pubsub_v1/subscriber/test_scheduler.py @@ -15,6 +15,7 @@ import concurrent.futures import queue import threading +import time import mock @@ -38,19 +39,89 @@ def test_constructor_options(): assert scheduler_._executor == mock.sentinel.executor -def test_schedule(): +def test_schedule_executes_submitted_items(): called_with = [] - called = threading.Event() + callback_done_twice = threading.Barrier(3) # 3 == 2x callback + 1x main thread def callback(*args, **kwargs): - called_with.append((args, kwargs)) - called.set() + called_with.append((args, kwargs)) # appends are thread-safe + callback_done_twice.wait() scheduler_ = scheduler.ThreadScheduler() scheduler_.schedule(callback, "arg1", kwarg1="meep") + scheduler_.schedule(callback, "arg2", kwarg2="boop") - called.wait() - scheduler_.shutdown() + callback_done_twice.wait(timeout=3.0) + result = scheduler_.shutdown() - assert called_with == [(("arg1",), {"kwarg1": "meep"})] + assert result == [] # no scheduled items dropped + + expected_calls = [(("arg1",), {"kwarg1": "meep"}), (("arg2",), {"kwarg2": "boop"})] + assert sorted(called_with) == expected_calls + + +def test_shutdown_nonblocking_by_default(): + called_with = [] + at_least_one_called = threading.Event() + at_least_one_completed = threading.Event() + + def callback(message): + called_with.append(message) # appends are thread-safe + at_least_one_called.set() + time.sleep(1.0) + at_least_one_completed.set() + + executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + scheduler_ = scheduler.ThreadScheduler(executor=executor) + + scheduler_.schedule(callback, "message_1") + scheduler_.schedule(callback, "message_2") + + at_least_one_called.wait() + dropped = scheduler_.shutdown() + + assert len(called_with) == 1 + assert called_with[0] in {"message_1", "message_2"} + + assert len(dropped) == 1 + assert dropped[0] in {"message_1", "message_2"} + assert dropped[0] != called_with[0] # the dropped message was not the processed one + + err_msg = ( + "Shutdown should not have waited " + "for the already running callbacks to complete." 
+ ) + assert not at_least_one_completed.is_set(), err_msg + + +def test_shutdown_blocking_awaits_running_callbacks(): + called_with = [] + at_least_one_called = threading.Event() + at_least_one_completed = threading.Event() + + def callback(message): + called_with.append(message) # appends are thread-safe + at_least_one_called.set() + time.sleep(1.0) + at_least_one_completed.set() + + executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + scheduler_ = scheduler.ThreadScheduler(executor=executor) + + scheduler_.schedule(callback, "message_1") + scheduler_.schedule(callback, "message_2") + + at_least_one_called.wait() + dropped = scheduler_.shutdown(await_msg_callbacks=True) + + assert len(called_with) == 1 + assert called_with[0] in {"message_1", "message_2"} + + # The work items that have not been started yet should still be dropped. + assert len(dropped) == 1 + assert dropped[0] in {"message_1", "message_2"} + assert dropped[0] != called_with[0] # the dropped message was not the processed one + + err_msg = "Shutdown did not wait for the already running callbacks to complete." + assert at_least_one_completed.is_set(), err_msg diff --git a/tests/unit/pubsub_v1/subscriber/test_streaming_pull_manager.py b/tests/unit/pubsub_v1/subscriber/test_streaming_pull_manager.py index 242c0804a..9930e8f14 100644 --- a/tests/unit/pubsub_v1/subscriber/test_streaming_pull_manager.py +++ b/tests/unit/pubsub_v1/subscriber/test_streaming_pull_manager.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import functools import logging import threading import time @@ -372,7 +373,6 @@ def test__maybe_release_messages_negative_on_hold_bytes_warning(caplog): def test_send_unary(): manager = make_manager() - manager._UNARY_REQUESTS = True manager.send( gapic_types.StreamingPullRequest( @@ -405,7 +405,6 @@ def test_send_unary(): def test_send_unary_empty(): manager = make_manager() - manager._UNARY_REQUESTS = True manager.send(gapic_types.StreamingPullRequest()) @@ -417,7 +416,6 @@ def test_send_unary_api_call_error(caplog): caplog.set_level(logging.DEBUG) manager = make_manager() - manager._UNARY_REQUESTS = True error = exceptions.GoogleAPICallError("The front fell off") manager._client.acknowledge.side_effect = error @@ -431,7 +429,6 @@ def test_send_unary_retry_error(caplog): caplog.set_level(logging.DEBUG) manager, _, _, _, _, _ = make_running_manager() - manager._UNARY_REQUESTS = True error = exceptions.RetryError( "Too long a transient error", cause=Exception("Out of time!") @@ -445,24 +442,15 @@ def test_send_unary_retry_error(caplog): assert "signaled streaming pull manager shutdown" in caplog.text -def test_send_streaming(): - manager = make_manager() - manager._UNARY_REQUESTS = False - manager._rpc = mock.create_autospec(bidi.BidiRpc, instance=True) - - manager.send(mock.sentinel.request) - - manager._rpc.send.assert_called_once_with(mock.sentinel.request) - - def test_heartbeat(): manager = make_manager() manager._rpc = mock.create_autospec(bidi.BidiRpc, instance=True) manager._rpc.is_active = True - manager.heartbeat() + result = manager.heartbeat() manager._rpc.send.assert_called_once_with(gapic_types.StreamingPullRequest()) + assert result def test_heartbeat_inactive(): @@ -472,7 +460,8 @@ def test_heartbeat_inactive(): manager.heartbeat() - manager._rpc.send.assert_not_called() + result = manager._rpc.send.assert_not_called() + assert not result @mock.patch("google.api_core.bidi.ResumableBidiRpc", autospec=True) @@ 
-543,8 +532,8 @@ def test_open_has_been_closed(): manager.open(mock.sentinel.callback, mock.sentinel.on_callback_error) -def make_running_manager(): - manager = make_manager() +def make_running_manager(**kwargs): + manager = make_manager(**kwargs) manager._consumer = mock.create_autospec(bidi.BackgroundConsumer, instance=True) manager._consumer.is_active = True manager._dispatcher = mock.create_autospec(dispatcher.Dispatcher, instance=True) @@ -632,14 +621,14 @@ def _do_work(self): while not self._stop: try: self._manager.leaser.add([mock.Mock()]) - except Exception as exc: + except Exception as exc: # pragma: NO COVER self._error_callback(exc) time.sleep(0.1) # also try to interact with the leaser after the stop flag has been set try: self._manager.leaser.remove([mock.Mock()]) - except Exception as exc: + except Exception as exc: # pragma: NO COVER self._error_callback(exc) @@ -666,6 +655,45 @@ def test_close_callbacks(): callback.assert_called_once_with(manager, "meep") +def test_close_blocking_scheduler_shutdown(): + manager, _, _, _, _, _ = make_running_manager(await_callbacks_on_shutdown=True) + scheduler = manager._scheduler + + manager.close() + + scheduler.shutdown.assert_called_once_with(await_msg_callbacks=True) + + +def test_close_nonblocking_scheduler_shutdown(): + manager, _, _, _, _, _ = make_running_manager(await_callbacks_on_shutdown=False) + scheduler = manager._scheduler + + manager.close() + + scheduler.shutdown.assert_called_once_with(await_msg_callbacks=False) + + +def test_close_nacks_internally_queued_messages(): + nacked_messages = [] + + def fake_nack(self): + nacked_messages.append(self.data) + + MockMsg = functools.partial(mock.create_autospec, message.Message, instance=True) + messages = [MockMsg(data=b"msg1"), MockMsg(data=b"msg2"), MockMsg(data=b"msg3")] + for msg in messages: + msg.nack = stdlib_types.MethodType(fake_nack, msg) + + manager, _, _, _, _, _ = make_running_manager() + dropped_by_scheduler = messages[:2] + manager._scheduler.shutdown.return_value = dropped_by_scheduler + manager._messages_on_hold._messages_on_hold.append(messages[2]) + + manager.close() + + assert sorted(nacked_messages) == [b"msg1", b"msg2", b"msg3"] + + def test__get_initial_request(): manager = make_manager() manager._leaser = mock.create_autospec(leaser.Leaser, instance=True) @@ -960,7 +988,7 @@ def test__on_rpc_done(thread): manager._on_rpc_done(mock.sentinel.error) thread.assert_called_once_with( - name=mock.ANY, target=manager.close, kwargs={"reason": mock.ANY} + name=mock.ANY, target=manager._shutdown, kwargs={"reason": mock.ANY} ) _, kwargs = thread.call_args reason = kwargs["kwargs"]["reason"] @@ -979,3 +1007,15 @@ def test_activate_ordering_keys(): manager._messages_on_hold.activate_ordering_keys.assert_called_once_with( ["key1", "key2"], mock.ANY ) + + +def test_activate_ordering_keys_stopped_scheduler(): + manager = make_manager() + manager._messages_on_hold = mock.create_autospec( + messages_on_hold.MessagesOnHold, instance=True + ) + manager._scheduler = None + + manager.activate_ordering_keys(["key1", "key2"]) + + manager._messages_on_hold.activate_ordering_keys.assert_not_called() diff --git a/tests/unit/pubsub_v1/subscriber/test_subscriber_client.py b/tests/unit/pubsub_v1/subscriber/test_subscriber_client.py index 780c20de4..6dad4b12a 100644 --- a/tests/unit/pubsub_v1/subscriber/test_subscriber_client.py +++ b/tests/unit/pubsub_v1/subscriber/test_subscriber_client.py @@ -172,12 +172,14 @@ def test_subscribe_options(manager_open): callback=mock.sentinel.callback, 
flow_control=flow_control, scheduler=scheduler, + await_callbacks_on_shutdown=mock.sentinel.await_callbacks, ) assert isinstance(future, futures.StreamingPullFuture) assert future._manager._subscription == "sub_name_a" assert future._manager.flow_control == flow_control assert future._manager._scheduler == scheduler + assert future._manager._await_callbacks_on_shutdown is mock.sentinel.await_callbacks manager_open.assert_called_once_with( mock.ANY, callback=mock.sentinel.callback,
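
Usage note: the snippet below is a minimal sketch of how the new ``await_callbacks_on_shutdown`` flag added by this patch is meant to be used from ``subscribe()``; it mirrors the ``receive_messages_with_blocking_shutdown`` sample introduced above. The project ID, subscription ID, and timeout value are placeholders rather than values taken from the patch.

    # Graceful (blocking) streaming pull shutdown, assuming the subscription exists.
    from concurrent.futures import TimeoutError

    from google.cloud import pubsub_v1

    project_id = "your-project-id"  # placeholder
    subscription_id = "your-subscription-id"  # placeholder

    subscriber = pubsub_v1.SubscriberClient()
    subscription_path = subscriber.subscription_path(project_id, subscription_id)

    def callback(message):
        print(f"Received {message.data}.")
        message.ack()

    # With await_callbacks_on_shutdown=True, result() called after cancel() blocks
    # until all currently executing message callbacks have finished.
    streaming_pull_future = subscriber.subscribe(
        subscription_path, callback=callback, await_callbacks_on_shutdown=True
    )

    with subscriber:
        try:
            streaming_pull_future.result(timeout=5.0)
        except TimeoutError:
            streaming_pull_future.cancel()  # non-blocking, shutdown runs in the background
            streaming_pull_future.result()  # blocks until the shutdown has completed

With the flag left at its default of ``False``, the final ``result()`` call returns as soon as the background stream and its helper threads are stopped, even if some message callbacks are still running.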
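
For custom scheduler implementations, the revised ``Scheduler.shutdown()`` contract can be exercised as sketched below: work items that were queued but never dispatched to a callback are returned so the caller can NACK them, which is what ``StreamingPullManager._shutdown()`` does in this patch. The ``received_messages`` list is an assumed stand-in for messages the streaming pull manager would hand to the scheduler.

    from google.cloud import pubsub_v1

    def callback(message):
        message.ack()

    scheduler = pubsub_v1.subscriber.scheduler.ThreadScheduler()

    # Assumed stand-in; in real use the manager schedules received messages itself.
    received_messages = []
    for message in received_messages:
        scheduler.schedule(callback, message)

    # Blocking shutdown: wait for running callbacks, then NACK whatever was queued
    # but never dispatched (the manager NACKs these during _shutdown()).
    dropped = scheduler.shutdown(await_msg_callbacks=True)
    for message in dropped:
        message.nack()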