From 98d60e9e18b1b6301cbb98ffb6b0b7639e6e6fb9 Mon Sep 17 00:00:00 2001 From: "gcf-owl-bot[bot]" <78513119+gcf-owl-bot[bot]@users.noreply.github.com> Date: Thu, 7 Oct 2021 00:34:53 +0000 Subject: [PATCH] feat: add context manager support in client (#73) - [ ] Regenerate this pull request now. chore: fix docstring for first attribute of protos committer: @busunkim96 PiperOrigin-RevId: 401271153 Source-Link: https://github.com/googleapis/googleapis/commit/787f8c9a731f44e74a90b9847d48659ca9462d10 Source-Link: https://github.com/googleapis/googleapis-gen/commit/81decffe9fc72396a8153e756d1d67a6eecfd620 Copy-Tag: eyJwIjoiLmdpdGh1Yi8uT3dsQm90LnlhbWwiLCJoIjoiODFkZWNmZmU5ZmM3MjM5NmE4MTUzZTc1NmQxZDY3YTZlZWNmZDYyMCJ9 --- .../services/quota_controller/async_client.py | 6 +++ .../services/quota_controller/client.py | 18 +++++-- .../quota_controller/transports/base.py | 9 ++++ .../quota_controller/transports/grpc.py | 3 ++ .../transports/grpc_asyncio.py | 3 ++ .../service_controller/async_client.py | 6 +++ .../services/service_controller/client.py | 18 +++++-- .../service_controller/transports/base.py | 9 ++++ .../service_controller/transports/grpc.py | 3 ++ .../transports/grpc_asyncio.py | 3 ++ .../servicecontrol_v1/types/distribution.py | 3 ++ .../servicecontrol_v1/types/log_entry.py | 1 + .../servicecontrol_v1/types/metric_value.py | 1 + .../servicecontrol_v1/types/operation.py | 1 + .../types/quota_controller.py | 3 ++ .../types/service_controller.py | 6 +++ .../test_quota_controller.py | 50 +++++++++++++++++++ .../test_service_controller.py | 50 +++++++++++++++++++ 18 files changed, 185 insertions(+), 8 deletions(-) diff --git a/google/cloud/servicecontrol_v1/services/quota_controller/async_client.py b/google/cloud/servicecontrol_v1/services/quota_controller/async_client.py index 72f7bba..49a268d 100644 --- a/google/cloud/servicecontrol_v1/services/quota_controller/async_client.py +++ b/google/cloud/servicecontrol_v1/services/quota_controller/async_client.py @@ -217,6 +217,12 @@ async def allocate_quota( # Done; return the response. return response + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.transport.close() + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( diff --git a/google/cloud/servicecontrol_v1/services/quota_controller/client.py b/google/cloud/servicecontrol_v1/services/quota_controller/client.py index cf8cc78..b2a4e64 100644 --- a/google/cloud/servicecontrol_v1/services/quota_controller/client.py +++ b/google/cloud/servicecontrol_v1/services/quota_controller/client.py @@ -332,10 +332,7 @@ def __init__( client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, - always_use_jwt_access=( - Transport == type(self).get_transport_class("grpc") - or Transport == type(self).get_transport_class("grpc_asyncio") - ), + always_use_jwt_access=True, ) def allocate_quota( @@ -393,6 +390,19 @@ def allocate_quota( # Done; return the response. return response + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + """Releases underlying transport's resources. + + .. warning:: + ONLY use as a context manager if the transport is NOT shared + with other clients! Exiting the with block will CLOSE the transport + and may cause errors in other clients! + """ + self.transport.close() + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( diff --git a/google/cloud/servicecontrol_v1/services/quota_controller/transports/base.py b/google/cloud/servicecontrol_v1/services/quota_controller/transports/base.py index 5a7c560..3257a28 100644 --- a/google/cloud/servicecontrol_v1/services/quota_controller/transports/base.py +++ b/google/cloud/servicecontrol_v1/services/quota_controller/transports/base.py @@ -162,6 +162,15 @@ def _prep_wrapped_messages(self, client_info): ), } + def close(self): + """Closes resources associated with the transport. + + .. warning:: + Only call this method if the transport is NOT shared + with other clients - this may cause errors in other clients! + """ + raise NotImplementedError() + @property def allocate_quota( self, diff --git a/google/cloud/servicecontrol_v1/services/quota_controller/transports/grpc.py b/google/cloud/servicecontrol_v1/services/quota_controller/transports/grpc.py index c823260..86e4b17 100644 --- a/google/cloud/servicecontrol_v1/services/quota_controller/transports/grpc.py +++ b/google/cloud/servicecontrol_v1/services/quota_controller/transports/grpc.py @@ -267,5 +267,8 @@ def allocate_quota( ) return self._stubs["allocate_quota"] + def close(self): + self.grpc_channel.close() + __all__ = ("QuotaControllerGrpcTransport",) diff --git a/google/cloud/servicecontrol_v1/services/quota_controller/transports/grpc_asyncio.py b/google/cloud/servicecontrol_v1/services/quota_controller/transports/grpc_asyncio.py index fc5a5eb..a513a09 100644 --- a/google/cloud/servicecontrol_v1/services/quota_controller/transports/grpc_asyncio.py +++ b/google/cloud/servicecontrol_v1/services/quota_controller/transports/grpc_asyncio.py @@ -271,5 +271,8 @@ def allocate_quota( ) return self._stubs["allocate_quota"] + def close(self): + return self.grpc_channel.close() + __all__ = ("QuotaControllerGrpcAsyncIOTransport",) diff --git a/google/cloud/servicecontrol_v1/services/service_controller/async_client.py b/google/cloud/servicecontrol_v1/services/service_controller/async_client.py index 90935a7..fbcc3ee 100644 --- a/google/cloud/servicecontrol_v1/services/service_controller/async_client.py +++ b/google/cloud/servicecontrol_v1/services/service_controller/async_client.py @@ -293,6 +293,12 @@ async def report( # Done; return the response. return response + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.transport.close() + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( diff --git a/google/cloud/servicecontrol_v1/services/service_controller/client.py b/google/cloud/servicecontrol_v1/services/service_controller/client.py index 8f95eeb..0ab18a6 100644 --- a/google/cloud/servicecontrol_v1/services/service_controller/client.py +++ b/google/cloud/servicecontrol_v1/services/service_controller/client.py @@ -334,10 +334,7 @@ def __init__( client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, - always_use_jwt_access=( - Transport == type(self).get_transport_class("grpc") - or Transport == type(self).get_transport_class("grpc_asyncio") - ), + always_use_jwt_access=True, ) def check( @@ -463,6 +460,19 @@ def report( # Done; return the response. return response + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + """Releases underlying transport's resources. + + .. warning:: + ONLY use as a context manager if the transport is NOT shared + with other clients! Exiting the with block will CLOSE the transport + and may cause errors in other clients! + """ + self.transport.close() + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( diff --git a/google/cloud/servicecontrol_v1/services/service_controller/transports/base.py b/google/cloud/servicecontrol_v1/services/service_controller/transports/base.py index c6e632e..c6559eb 100644 --- a/google/cloud/servicecontrol_v1/services/service_controller/transports/base.py +++ b/google/cloud/servicecontrol_v1/services/service_controller/transports/base.py @@ -176,6 +176,15 @@ def _prep_wrapped_messages(self, client_info): ), } + def close(self): + """Closes resources associated with the transport. + + .. warning:: + Only call this method if the transport is NOT shared + with other clients - this may cause errors in other clients! + """ + raise NotImplementedError() + @property def check( self, diff --git a/google/cloud/servicecontrol_v1/services/service_controller/transports/grpc.py b/google/cloud/servicecontrol_v1/services/service_controller/transports/grpc.py index 4ef1311..7100c53 100644 --- a/google/cloud/servicecontrol_v1/services/service_controller/transports/grpc.py +++ b/google/cloud/servicecontrol_v1/services/service_controller/transports/grpc.py @@ -317,5 +317,8 @@ def report( ) return self._stubs["report"] + def close(self): + self.grpc_channel.close() + __all__ = ("ServiceControllerGrpcTransport",) diff --git a/google/cloud/servicecontrol_v1/services/service_controller/transports/grpc_asyncio.py b/google/cloud/servicecontrol_v1/services/service_controller/transports/grpc_asyncio.py index c99b439..5dbd054 100644 --- a/google/cloud/servicecontrol_v1/services/service_controller/transports/grpc_asyncio.py +++ b/google/cloud/servicecontrol_v1/services/service_controller/transports/grpc_asyncio.py @@ -322,5 +322,8 @@ def report( ) return self._stubs["report"] + def close(self): + return self.grpc_channel.close() + __all__ = ("ServiceControllerGrpcAsyncIOTransport",) diff --git a/google/cloud/servicecontrol_v1/types/distribution.py b/google/cloud/servicecontrol_v1/types/distribution.py index 5c5ffa0..d66536e 100644 --- a/google/cloud/servicecontrol_v1/types/distribution.py +++ b/google/cloud/servicecontrol_v1/types/distribution.py @@ -79,6 +79,7 @@ class Distribution(proto.Message): class LinearBuckets(proto.Message): r"""Describing buckets with constant width. + Attributes: num_finite_buckets (int): The number of finite buckets. With the underflow and @@ -101,6 +102,7 @@ class LinearBuckets(proto.Message): class ExponentialBuckets(proto.Message): r"""Describing buckets with exponentially growing width. + Attributes: num_finite_buckets (int): The number of finite buckets. With the underflow and @@ -124,6 +126,7 @@ class ExponentialBuckets(proto.Message): class ExplicitBuckets(proto.Message): r"""Describing buckets with arbitrary user-provided width. + Attributes: bounds (Sequence[float]): 'bound' is a list of strictly increasing boundaries between diff --git a/google/cloud/servicecontrol_v1/types/log_entry.py b/google/cloud/servicecontrol_v1/types/log_entry.py index 80ca9c6..0ac9e95 100644 --- a/google/cloud/servicecontrol_v1/types/log_entry.py +++ b/google/cloud/servicecontrol_v1/types/log_entry.py @@ -30,6 +30,7 @@ class LogEntry(proto.Message): r"""An individual log entry. + Attributes: name (str): Required. The log to which this log entry belongs. Examples: diff --git a/google/cloud/servicecontrol_v1/types/metric_value.py b/google/cloud/servicecontrol_v1/types/metric_value.py index fa81d37..9f1e628 100644 --- a/google/cloud/servicecontrol_v1/types/metric_value.py +++ b/google/cloud/servicecontrol_v1/types/metric_value.py @@ -26,6 +26,7 @@ class MetricValue(proto.Message): r"""Represents a single metric value. + Attributes: labels (Sequence[google.cloud.servicecontrol_v1.types.MetricValue.LabelsEntry]): The labels describing the metric value. See comments on diff --git a/google/cloud/servicecontrol_v1/types/operation.py b/google/cloud/servicecontrol_v1/types/operation.py index 69d55c3..debcea7 100644 --- a/google/cloud/servicecontrol_v1/types/operation.py +++ b/google/cloud/servicecontrol_v1/types/operation.py @@ -28,6 +28,7 @@ class Operation(proto.Message): r"""Represents information regarding an operation. + Attributes: operation_id (str): Identity of the operation. This must be diff --git a/google/cloud/servicecontrol_v1/types/quota_controller.py b/google/cloud/servicecontrol_v1/types/quota_controller.py index 0b74958..3c54e1e 100644 --- a/google/cloud/servicecontrol_v1/types/quota_controller.py +++ b/google/cloud/servicecontrol_v1/types/quota_controller.py @@ -32,6 +32,7 @@ class AllocateQuotaRequest(proto.Message): r"""Request message for the AllocateQuota method. + Attributes: service_name (str): Name of the service as specified in the service @@ -56,6 +57,7 @@ class AllocateQuotaRequest(proto.Message): class QuotaOperation(proto.Message): r"""Represents information regarding a quota operation. + Attributes: operation_id (str): Identity of the operation. This is expected to be unique @@ -128,6 +130,7 @@ class QuotaMode(proto.Enum): class AllocateQuotaResponse(proto.Message): r"""Response message for the AllocateQuota method. + Attributes: operation_id (str): The same operation_id value used in the diff --git a/google/cloud/servicecontrol_v1/types/service_controller.py b/google/cloud/servicecontrol_v1/types/service_controller.py index 0b2b93f..aa4b73b 100644 --- a/google/cloud/servicecontrol_v1/types/service_controller.py +++ b/google/cloud/servicecontrol_v1/types/service_controller.py @@ -28,6 +28,7 @@ class CheckRequest(proto.Message): r"""Request message for the Check method. + Attributes: service_name (str): The service name as specified in its service configuration. @@ -53,6 +54,7 @@ class CheckRequest(proto.Message): class CheckResponse(proto.Message): r"""Response message for the Check method. + Attributes: operation_id (str): The same operation_id value used in the @@ -77,6 +79,7 @@ class CheckResponse(proto.Message): class CheckInfo(proto.Message): r"""Contains additional information about the check operation. + Attributes: unused_arguments (Sequence[str]): A list of fields and label keys that are @@ -94,6 +97,7 @@ class CheckInfo(proto.Message): class ConsumerInfo(proto.Message): r"""``ConsumerInfo`` provides information about the consumer. + Attributes: project_number (int): The Google cloud project number, e.g. @@ -140,6 +144,7 @@ class ConsumerType(proto.Enum): class ReportRequest(proto.Message): r"""Request message for the Report method. + Attributes: service_name (str): The service name as specified in its service configuration. @@ -178,6 +183,7 @@ class ReportRequest(proto.Message): class ReportResponse(proto.Message): r"""Response message for the Report method. + Attributes: report_errors (Sequence[google.cloud.servicecontrol_v1.types.ReportResponse.ReportError]): Partial failures, one for each ``Operation`` in the request diff --git a/tests/unit/gapic/servicecontrol_v1/test_quota_controller.py b/tests/unit/gapic/servicecontrol_v1/test_quota_controller.py index b7b497f..9f4ae58 100644 --- a/tests/unit/gapic/servicecontrol_v1/test_quota_controller.py +++ b/tests/unit/gapic/servicecontrol_v1/test_quota_controller.py @@ -30,6 +30,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async +from google.api_core import path_template from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.servicecontrol_v1.services.quota_controller import ( @@ -702,6 +703,9 @@ def test_quota_controller_base_transport(): with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) + with pytest.raises(NotImplementedError): + transport.close() + @requires_google_auth_gte_1_25_0 def test_quota_controller_base_transport_with_credentials_file(): @@ -1181,3 +1185,49 @@ def test_client_withDEFAULT_CLIENT_INFO(): credentials=ga_credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) + + +@pytest.mark.asyncio +async def test_transport_close_async(): + client = QuotaControllerAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc_asyncio", + ) + with mock.patch.object( + type(getattr(client.transport, "grpc_channel")), "close" + ) as close: + async with client: + close.assert_not_called() + close.assert_called_once() + + +def test_transport_close(): + transports = { + "grpc": "_grpc_channel", + } + + for transport, close_name in transports.items(): + client = QuotaControllerClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + with mock.patch.object( + type(getattr(client.transport, close_name)), "close" + ) as close: + with client: + close.assert_not_called() + close.assert_called_once() + + +def test_client_ctx(): + transports = [ + "grpc", + ] + for transport in transports: + client = QuotaControllerClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + # Test client calls underlying transport. + with mock.patch.object(type(client.transport), "close") as close: + close.assert_not_called() + with client: + pass + close.assert_called() diff --git a/tests/unit/gapic/servicecontrol_v1/test_service_controller.py b/tests/unit/gapic/servicecontrol_v1/test_service_controller.py index 3fe646f..3e2934d 100644 --- a/tests/unit/gapic/servicecontrol_v1/test_service_controller.py +++ b/tests/unit/gapic/servicecontrol_v1/test_service_controller.py @@ -30,6 +30,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async +from google.api_core import path_template from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.servicecontrol_v1.services.service_controller import ( @@ -802,6 +803,9 @@ def test_service_controller_base_transport(): with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) + with pytest.raises(NotImplementedError): + transport.close() + @requires_google_auth_gte_1_25_0 def test_service_controller_base_transport_with_credentials_file(): @@ -1281,3 +1285,49 @@ def test_client_withDEFAULT_CLIENT_INFO(): credentials=ga_credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) + + +@pytest.mark.asyncio +async def test_transport_close_async(): + client = ServiceControllerAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc_asyncio", + ) + with mock.patch.object( + type(getattr(client.transport, "grpc_channel")), "close" + ) as close: + async with client: + close.assert_not_called() + close.assert_called_once() + + +def test_transport_close(): + transports = { + "grpc": "_grpc_channel", + } + + for transport, close_name in transports.items(): + client = ServiceControllerClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + with mock.patch.object( + type(getattr(client.transport, close_name)), "close" + ) as close: + with client: + close.assert_not_called() + close.assert_called_once() + + +def test_client_ctx(): + transports = [ + "grpc", + ] + for transport in transports: + client = ServiceControllerClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + # Test client calls underlying transport. + with mock.patch.object(type(client.transport), "close") as close: + close.assert_not_called() + with client: + pass + close.assert_called()