diff --git a/google/cloud/pubsub_v1/subscriber/client.py b/google/cloud/pubsub_v1/subscriber/client.py index b255fe476..00d97231e 100644 --- a/google/cloud/pubsub_v1/subscriber/client.py +++ b/google/cloud/pubsub_v1/subscriber/client.py @@ -228,3 +228,19 @@ def callback(message): manager.open(callback=callback, on_callback_error=future.set_exception) return future + + def close(self): + """Close the underlying channel to release socket resources. + + After a channel has been closed, the client instance cannot be used + anymore. + + This method is idempotent. + """ + self.api.transport.channel.close() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() diff --git a/noxfile.py b/noxfile.py index 99d73295c..03c570f57 100644 --- a/noxfile.py +++ b/noxfile.py @@ -110,7 +110,7 @@ def system(session): # Install all test dependencies, then install this package into the # virtualenv's dist-packages. - session.install("mock", "pytest") + session.install("mock", "pytest", "psutil") session.install("-e", "test_utils") session.install("-e", ".") diff --git a/tests/system.py b/tests/system.py index 65baaf016..37a39766a 100644 --- a/tests/system.py +++ b/tests/system.py @@ -18,6 +18,7 @@ import itertools import operator as op import os +import psutil import threading import time @@ -46,7 +47,7 @@ def publisher(): yield pubsub_v1.PublisherClient() -@pytest.fixture(scope=u"module") +@pytest.fixture(scope="module") def subscriber(): yield pubsub_v1.SubscriberClient() @@ -383,6 +384,54 @@ def test_managing_subscription_iam_policy( assert bindings[1].members == ["group:cloud-logs@google.com"] +def test_subscriber_not_leaking_open_sockets( + publisher, topic_path, subscription_path, cleanup +): + # Make sure the topic and the supscription get deleted. + # NOTE: Since subscriber client will be closed in the test, we should not + # use the shared `subscriber` fixture, but instead construct a new client + # in this test. + # Also, since the client will get closed, we need another subscriber client + # to clean up the subscription. We also need to make sure that auxiliary + # subscriber releases the sockets, too. + subscriber = pubsub_v1.SubscriberClient() + subscriber_2 = pubsub_v1.SubscriberClient() + cleanup.append((subscriber_2.delete_subscription, subscription_path)) + + def one_arg_close(subscriber): # the cleanup helper expects exactly one argument + subscriber.close() + + cleanup.append((one_arg_close, subscriber_2)) + cleanup.append((publisher.delete_topic, topic_path)) + + # Create topic before starting to track connection count (any sockets opened + # by the publisher client are not counted by this test). + publisher.create_topic(topic_path) + + current_process = psutil.Process() + conn_count_start = len(current_process.connections()) + + # Publish a few messages, then synchronously pull them and check that + # no sockets are leaked. + with subscriber: + subscriber.create_subscription(name=subscription_path, topic=topic_path) + + # Publish a few messages, wait for the publish to succeed. + publish_futures = [ + publisher.publish(topic_path, u"message {}".format(i).encode()) + for i in range(1, 4) + ] + for future in publish_futures: + future.result() + + # Synchronously pull messages. + response = subscriber.pull(subscription_path, max_messages=3) + assert len(response.received_messages) == 3 + + conn_count_end = len(current_process.connections()) + assert conn_count_end == conn_count_start + + class TestStreamingPull(object): def test_streaming_pull_callback_error_propagation( self, publisher, topic_path, subscriber, subscription_path, cleanup diff --git a/tests/unit/pubsub_v1/subscriber/test_subscriber_client.py b/tests/unit/pubsub_v1/subscriber/test_subscriber_client.py index b367733aa..19ec194ce 100644 --- a/tests/unit/pubsub_v1/subscriber/test_subscriber_client.py +++ b/tests/unit/pubsub_v1/subscriber/test_subscriber_client.py @@ -106,3 +106,22 @@ def test_subscribe_options(manager_open): callback=mock.sentinel.callback, on_callback_error=future.set_exception, ) + + +def test_close(): + mock_transport = mock.NonCallableMock() + client = subscriber.Client(transport=mock_transport) + + client.close() + + mock_transport.channel.close.assert_called() + + +def test_closes_channel_as_context_manager(): + mock_transport = mock.NonCallableMock() + client = subscriber.Client(transport=mock_transport) + + with client: + pass + + mock_transport.channel.close.assert_called()