diff --git a/google/cloud/pubsub_v1/subscriber/client.py b/google/cloud/pubsub_v1/subscriber/client.py index 85b88006d..c4b229a17 100644 --- a/google/cloud/pubsub_v1/subscriber/client.py +++ b/google/cloud/pubsub_v1/subscriber/client.py @@ -85,6 +85,7 @@ def __init__(self, **kwargs): # Instantiate the underlying GAPIC client. self._api = subscriber_client.SubscriberClient(**kwargs) self._target = self._api._transport._host + self._closed = False @classmethod def from_service_account_file(cls, filename, **kwargs): @@ -120,6 +121,14 @@ def api(self): """The underlying gapic API client.""" return self._api + @property + def closed(self) -> bool: + """Return whether the client has been closed and cannot be used anymore. + + .. versionadded:: 2.8.0 + """ + return self._closed + def subscribe( self, subscription, @@ -252,8 +261,11 @@ def close(self): This method is idempotent. """ self.api._transport.grpc_channel.close() + self._closed = True def __enter__(self): + if self._closed: + raise RuntimeError("Closed subscriber cannot be used as context manager.") return self def __exit__(self, exc_type, exc_val, exc_tb): diff --git a/tests/unit/pubsub_v1/subscriber/test_subscriber_client.py b/tests/unit/pubsub_v1/subscriber/test_subscriber_client.py index 7624c9212..601b40bcc 100644 --- a/tests/unit/pubsub_v1/subscriber/test_subscriber_client.py +++ b/tests/unit/pubsub_v1/subscriber/test_subscriber_client.py @@ -50,6 +50,11 @@ def test_init_default_client_info(creds): assert expected_client_info in user_agent +def test_init_default_closed_state(creds): + client = subscriber.Client(credentials=creds) + assert not client.closed + + def test_init_w_custom_transport(creds): transport = SubscriberGrpcTransport(credentials=creds) client = subscriber.Client(transport=transport) @@ -185,6 +190,7 @@ def test_close(creds): client.close() patched_close.assert_called() + assert client.closed def test_closes_channel_as_context_manager(creds): @@ -198,6 +204,18 @@ def test_closes_channel_as_context_manager(creds): patched_close.assert_called() +def test_context_manager_raises_if_closed(creds): + client = subscriber.Client(credentials=creds) + + with mock.patch.object(client.api._transport.grpc_channel, "close"): + client.close() + + expetect_msg = r"(?i).*closed.*cannot.*context manager.*" + with pytest.raises(RuntimeError, match=expetect_msg): + with client: + pass + + def test_streaming_pull_gapic_monkeypatch(creds): client = subscriber.Client(credentials=creds)