From 3e68d78e6f0c5d2e65f148935446baa92b5dd8ef Mon Sep 17 00:00:00 2001 From: "gcf-owl-bot[bot]" <78513119+gcf-owl-bot[bot]@users.noreply.github.com> Date: Thu, 7 Oct 2021 00:36:35 +0000 Subject: [PATCH] feat: add context manager support in client (#101) - [ ] Regenerate this pull request now. chore: fix docstring for first attribute of protos committer: @busunkim96 PiperOrigin-RevId: 401271153 Source-Link: https://github.com/googleapis/googleapis/commit/787f8c9a731f44e74a90b9847d48659ca9462d10 Source-Link: https://github.com/googleapis/googleapis-gen/commit/81decffe9fc72396a8153e756d1d67a6eecfd620 Copy-Tag: eyJwIjoiLmdpdGh1Yi8uT3dsQm90LnlhbWwiLCJoIjoiODFkZWNmZmU5ZmM3MjM5NmE4MTUzZTc1NmQxZDY3YTZlZWNmZDYyMCJ9 --- .../services/catalog_service/async_client.py | 6 +++ .../services/catalog_service/client.py | 18 +++++-- .../catalog_service/transports/base.py | 9 ++++ .../catalog_service/transports/grpc.py | 3 ++ .../transports/grpc_asyncio.py | 3 ++ .../completion_service/async_client.py | 6 +++ .../services/completion_service/client.py | 18 +++++-- .../completion_service/transports/base.py | 9 ++++ .../completion_service/transports/grpc.py | 3 ++ .../transports/grpc_asyncio.py | 3 ++ .../prediction_service/async_client.py | 6 +++ .../services/prediction_service/client.py | 18 +++++-- .../prediction_service/transports/base.py | 9 ++++ .../prediction_service/transports/grpc.py | 3 ++ .../transports/grpc_asyncio.py | 3 ++ .../services/product_service/async_client.py | 6 +++ .../services/product_service/client.py | 18 +++++-- .../product_service/transports/base.py | 9 ++++ .../product_service/transports/grpc.py | 3 ++ .../transports/grpc_asyncio.py | 3 ++ .../services/search_service/async_client.py | 6 +++ .../services/search_service/client.py | 18 +++++-- .../search_service/transports/base.py | 9 ++++ .../search_service/transports/grpc.py | 3 ++ .../search_service/transports/grpc_asyncio.py | 3 ++ .../user_event_service/async_client.py | 6 +++ .../services/user_event_service/client.py | 18 +++++-- .../user_event_service/transports/base.py | 9 ++++ .../user_event_service/transports/grpc.py | 3 ++ .../transports/grpc_asyncio.py | 3 ++ google/cloud/retail_v2/types/catalog.py | 1 + .../cloud/retail_v2/types/catalog_service.py | 1 + google/cloud/retail_v2/types/common.py | 5 ++ .../retail_v2/types/completion_service.py | 4 ++ google/cloud/retail_v2/types/import_config.py | 8 +++ .../retail_v2/types/prediction_service.py | 2 + .../cloud/retail_v2/types/product_service.py | 25 +++++++--- google/cloud/retail_v2/types/purge_config.py | 4 +- .../cloud/retail_v2/types/search_service.py | 8 +++ google/cloud/retail_v2/types/user_event.py | 2 + .../retail_v2/types/user_event_service.py | 7 ++- .../gapic/retail_v2/test_catalog_service.py | 50 +++++++++++++++++++ .../retail_v2/test_completion_service.py | 50 +++++++++++++++++++ .../retail_v2/test_prediction_service.py | 50 +++++++++++++++++++ .../gapic/retail_v2/test_product_service.py | 50 +++++++++++++++++++ .../gapic/retail_v2/test_search_service.py | 50 +++++++++++++++++++ .../retail_v2/test_user_event_service.py | 50 +++++++++++++++++++ 47 files changed, 569 insertions(+), 32 deletions(-) diff --git a/google/cloud/retail_v2/services/catalog_service/async_client.py b/google/cloud/retail_v2/services/catalog_service/async_client.py index 38888836..de45252c 100644 --- a/google/cloud/retail_v2/services/catalog_service/async_client.py +++ b/google/cloud/retail_v2/services/catalog_service/async_client.py @@ -529,6 +529,12 @@ async def get_default_branch( # Done; return the response. return response + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.transport.close() + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( diff --git a/google/cloud/retail_v2/services/catalog_service/client.py b/google/cloud/retail_v2/services/catalog_service/client.py index 7a0adad3..fdbb65dc 100644 --- a/google/cloud/retail_v2/services/catalog_service/client.py +++ b/google/cloud/retail_v2/services/catalog_service/client.py @@ -364,10 +364,7 @@ def __init__( client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, - always_use_jwt_access=( - Transport == type(self).get_transport_class("grpc") - or Transport == type(self).get_transport_class("grpc_asyncio") - ), + always_use_jwt_access=True, ) def list_catalogs( @@ -732,6 +729,19 @@ def get_default_branch( # Done; return the response. return response + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + """Releases underlying transport's resources. + + .. warning:: + ONLY use as a context manager if the transport is NOT shared + with other clients! Exiting the with block will CLOSE the transport + and may cause errors in other clients! + """ + self.transport.close() + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( diff --git a/google/cloud/retail_v2/services/catalog_service/transports/base.py b/google/cloud/retail_v2/services/catalog_service/transports/base.py index 27a8d130..9d692726 100644 --- a/google/cloud/retail_v2/services/catalog_service/transports/base.py +++ b/google/cloud/retail_v2/services/catalog_service/transports/base.py @@ -168,6 +168,15 @@ def _prep_wrapped_messages(self, client_info): ), } + def close(self): + """Closes resources associated with the transport. + + .. warning:: + Only call this method if the transport is NOT shared + with other clients - this may cause errors in other clients! + """ + raise NotImplementedError() + @property def list_catalogs( self, diff --git a/google/cloud/retail_v2/services/catalog_service/transports/grpc.py b/google/cloud/retail_v2/services/catalog_service/transports/grpc.py index 5f9185d1..98a36126 100644 --- a/google/cloud/retail_v2/services/catalog_service/transports/grpc.py +++ b/google/cloud/retail_v2/services/catalog_service/transports/grpc.py @@ -383,5 +383,8 @@ def get_default_branch( ) return self._stubs["get_default_branch"] + def close(self): + self.grpc_channel.close() + __all__ = ("CatalogServiceGrpcTransport",) diff --git a/google/cloud/retail_v2/services/catalog_service/transports/grpc_asyncio.py b/google/cloud/retail_v2/services/catalog_service/transports/grpc_asyncio.py index d967c1fd..3ec7e33d 100644 --- a/google/cloud/retail_v2/services/catalog_service/transports/grpc_asyncio.py +++ b/google/cloud/retail_v2/services/catalog_service/transports/grpc_asyncio.py @@ -391,5 +391,8 @@ def get_default_branch( ) return self._stubs["get_default_branch"] + def close(self): + return self.grpc_channel.close() + __all__ = ("CatalogServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/retail_v2/services/completion_service/async_client.py b/google/cloud/retail_v2/services/completion_service/async_client.py index afef37ef..428509e2 100644 --- a/google/cloud/retail_v2/services/completion_service/async_client.py +++ b/google/cloud/retail_v2/services/completion_service/async_client.py @@ -291,6 +291,12 @@ async def import_completion_data( # Done; return the response. return response + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.transport.close() + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( diff --git a/google/cloud/retail_v2/services/completion_service/client.py b/google/cloud/retail_v2/services/completion_service/client.py index 4708f498..e425ee0b 100644 --- a/google/cloud/retail_v2/services/completion_service/client.py +++ b/google/cloud/retail_v2/services/completion_service/client.py @@ -354,10 +354,7 @@ def __init__( client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, - always_use_jwt_access=( - Transport == type(self).get_transport_class("grpc") - or Transport == type(self).get_transport_class("grpc_asyncio") - ), + always_use_jwt_access=True, ) def complete_query( @@ -484,6 +481,19 @@ def import_completion_data( # Done; return the response. return response + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + """Releases underlying transport's resources. + + .. warning:: + ONLY use as a context manager if the transport is NOT shared + with other clients! Exiting the with block will CLOSE the transport + and may cause errors in other clients! + """ + self.transport.close() + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( diff --git a/google/cloud/retail_v2/services/completion_service/transports/base.py b/google/cloud/retail_v2/services/completion_service/transports/base.py index 3cfb7c1d..bf793811 100644 --- a/google/cloud/retail_v2/services/completion_service/transports/base.py +++ b/google/cloud/retail_v2/services/completion_service/transports/base.py @@ -165,6 +165,15 @@ def _prep_wrapped_messages(self, client_info): ), } + def close(self): + """Closes resources associated with the transport. + + .. warning:: + Only call this method if the transport is NOT shared + with other clients - this may cause errors in other clients! + """ + raise NotImplementedError() + @property def operations_client(self) -> operations_v1.OperationsClient: """Return the client designed to process long-running operations.""" diff --git a/google/cloud/retail_v2/services/completion_service/transports/grpc.py b/google/cloud/retail_v2/services/completion_service/transports/grpc.py index 2bc86b79..5ab1e783 100644 --- a/google/cloud/retail_v2/services/completion_service/transports/grpc.py +++ b/google/cloud/retail_v2/services/completion_service/transports/grpc.py @@ -318,5 +318,8 @@ def import_completion_data( ) return self._stubs["import_completion_data"] + def close(self): + self.grpc_channel.close() + __all__ = ("CompletionServiceGrpcTransport",) diff --git a/google/cloud/retail_v2/services/completion_service/transports/grpc_asyncio.py b/google/cloud/retail_v2/services/completion_service/transports/grpc_asyncio.py index bf777c8c..a7832c7e 100644 --- a/google/cloud/retail_v2/services/completion_service/transports/grpc_asyncio.py +++ b/google/cloud/retail_v2/services/completion_service/transports/grpc_asyncio.py @@ -323,5 +323,8 @@ def import_completion_data( ) return self._stubs["import_completion_data"] + def close(self): + return self.grpc_channel.close() + __all__ = ("CompletionServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/retail_v2/services/prediction_service/async_client.py b/google/cloud/retail_v2/services/prediction_service/async_client.py index 6bf0811c..4ceffa17 100644 --- a/google/cloud/retail_v2/services/prediction_service/async_client.py +++ b/google/cloud/retail_v2/services/prediction_service/async_client.py @@ -208,6 +208,12 @@ async def predict( # Done; return the response. return response + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.transport.close() + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( diff --git a/google/cloud/retail_v2/services/prediction_service/client.py b/google/cloud/retail_v2/services/prediction_service/client.py index 39183b9a..fff246f6 100644 --- a/google/cloud/retail_v2/services/prediction_service/client.py +++ b/google/cloud/retail_v2/services/prediction_service/client.py @@ -351,10 +351,7 @@ def __init__( client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, - always_use_jwt_access=( - Transport == type(self).get_transport_class("grpc") - or Transport == type(self).get_transport_class("grpc_asyncio") - ), + always_use_jwt_access=True, ) def predict( @@ -406,6 +403,19 @@ def predict( # Done; return the response. return response + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + """Releases underlying transport's resources. + + .. warning:: + ONLY use as a context manager if the transport is NOT shared + with other clients! Exiting the with block will CLOSE the transport + and may cause errors in other clients! + """ + self.transport.close() + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( diff --git a/google/cloud/retail_v2/services/prediction_service/transports/base.py b/google/cloud/retail_v2/services/prediction_service/transports/base.py index eb850e7b..a348426b 100644 --- a/google/cloud/retail_v2/services/prediction_service/transports/base.py +++ b/google/cloud/retail_v2/services/prediction_service/transports/base.py @@ -157,6 +157,15 @@ def _prep_wrapped_messages(self, client_info): ), } + def close(self): + """Closes resources associated with the transport. + + .. warning:: + Only call this method if the transport is NOT shared + with other clients - this may cause errors in other clients! + """ + raise NotImplementedError() + @property def predict( self, diff --git a/google/cloud/retail_v2/services/prediction_service/transports/grpc.py b/google/cloud/retail_v2/services/prediction_service/transports/grpc.py index ca2e3089..89e7e769 100644 --- a/google/cloud/retail_v2/services/prediction_service/transports/grpc.py +++ b/google/cloud/retail_v2/services/prediction_service/transports/grpc.py @@ -253,5 +253,8 @@ def predict( ) return self._stubs["predict"] + def close(self): + self.grpc_channel.close() + __all__ = ("PredictionServiceGrpcTransport",) diff --git a/google/cloud/retail_v2/services/prediction_service/transports/grpc_asyncio.py b/google/cloud/retail_v2/services/prediction_service/transports/grpc_asyncio.py index d7bf8ff3..f5994fd9 100644 --- a/google/cloud/retail_v2/services/prediction_service/transports/grpc_asyncio.py +++ b/google/cloud/retail_v2/services/prediction_service/transports/grpc_asyncio.py @@ -257,5 +257,8 @@ def predict( ) return self._stubs["predict"] + def close(self): + return self.grpc_channel.close() + __all__ = ("PredictionServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/retail_v2/services/product_service/async_client.py b/google/cloud/retail_v2/services/product_service/async_client.py index 288a418c..5e95bbd6 100644 --- a/google/cloud/retail_v2/services/product_service/async_client.py +++ b/google/cloud/retail_v2/services/product_service/async_client.py @@ -1115,6 +1115,12 @@ async def remove_fulfillment_places( # Done; return the response. return response + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.transport.close() + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( diff --git a/google/cloud/retail_v2/services/product_service/client.py b/google/cloud/retail_v2/services/product_service/client.py index a9afa377..77b76cd6 100644 --- a/google/cloud/retail_v2/services/product_service/client.py +++ b/google/cloud/retail_v2/services/product_service/client.py @@ -378,10 +378,7 @@ def __init__( client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, - always_use_jwt_access=( - Transport == type(self).get_transport_class("grpc") - or Transport == type(self).get_transport_class("grpc_asyncio") - ), + always_use_jwt_access=True, ) def create_product( @@ -1317,6 +1314,19 @@ def remove_fulfillment_places( # Done; return the response. return response + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + """Releases underlying transport's resources. + + .. warning:: + ONLY use as a context manager if the transport is NOT shared + with other clients! Exiting the with block will CLOSE the transport + and may cause errors in other clients! + """ + self.transport.close() + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( diff --git a/google/cloud/retail_v2/services/product_service/transports/base.py b/google/cloud/retail_v2/services/product_service/transports/base.py index 3444c918..e062d131 100644 --- a/google/cloud/retail_v2/services/product_service/transports/base.py +++ b/google/cloud/retail_v2/services/product_service/transports/base.py @@ -203,6 +203,15 @@ def _prep_wrapped_messages(self, client_info): ), } + def close(self): + """Closes resources associated with the transport. + + .. warning:: + Only call this method if the transport is NOT shared + with other clients - this may cause errors in other clients! + """ + raise NotImplementedError() + @property def operations_client(self) -> operations_v1.OperationsClient: """Return the client designed to process long-running operations.""" diff --git a/google/cloud/retail_v2/services/product_service/transports/grpc.py b/google/cloud/retail_v2/services/product_service/transports/grpc.py index bcfef051..9f0fe118 100644 --- a/google/cloud/retail_v2/services/product_service/transports/grpc.py +++ b/google/cloud/retail_v2/services/product_service/transports/grpc.py @@ -575,5 +575,8 @@ def remove_fulfillment_places( ) return self._stubs["remove_fulfillment_places"] + def close(self): + self.grpc_channel.close() + __all__ = ("ProductServiceGrpcTransport",) diff --git a/google/cloud/retail_v2/services/product_service/transports/grpc_asyncio.py b/google/cloud/retail_v2/services/product_service/transports/grpc_asyncio.py index baf40438..6c838484 100644 --- a/google/cloud/retail_v2/services/product_service/transports/grpc_asyncio.py +++ b/google/cloud/retail_v2/services/product_service/transports/grpc_asyncio.py @@ -591,5 +591,8 @@ def remove_fulfillment_places( ) return self._stubs["remove_fulfillment_places"] + def close(self): + return self.grpc_channel.close() + __all__ = ("ProductServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/retail_v2/services/search_service/async_client.py b/google/cloud/retail_v2/services/search_service/async_client.py index 093b856b..089818d2 100644 --- a/google/cloud/retail_v2/services/search_service/async_client.py +++ b/google/cloud/retail_v2/services/search_service/async_client.py @@ -236,6 +236,12 @@ async def search( # Done; return the response. return response + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.transport.close() + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( diff --git a/google/cloud/retail_v2/services/search_service/client.py b/google/cloud/retail_v2/services/search_service/client.py index d2359507..88923d14 100644 --- a/google/cloud/retail_v2/services/search_service/client.py +++ b/google/cloud/retail_v2/services/search_service/client.py @@ -370,10 +370,7 @@ def __init__( client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, - always_use_jwt_access=( - Transport == type(self).get_transport_class("grpc") - or Transport == type(self).get_transport_class("grpc_asyncio") - ), + always_use_jwt_access=True, ) def search( @@ -444,6 +441,19 @@ def search( # Done; return the response. return response + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + """Releases underlying transport's resources. + + .. warning:: + ONLY use as a context manager if the transport is NOT shared + with other clients! Exiting the with block will CLOSE the transport + and may cause errors in other clients! + """ + self.transport.close() + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( diff --git a/google/cloud/retail_v2/services/search_service/transports/base.py b/google/cloud/retail_v2/services/search_service/transports/base.py index fb44bac5..349a4d00 100644 --- a/google/cloud/retail_v2/services/search_service/transports/base.py +++ b/google/cloud/retail_v2/services/search_service/transports/base.py @@ -157,6 +157,15 @@ def _prep_wrapped_messages(self, client_info): ), } + def close(self): + """Closes resources associated with the transport. + + .. warning:: + Only call this method if the transport is NOT shared + with other clients - this may cause errors in other clients! + """ + raise NotImplementedError() + @property def search( self, diff --git a/google/cloud/retail_v2/services/search_service/transports/grpc.py b/google/cloud/retail_v2/services/search_service/transports/grpc.py index d232fee8..1c15621e 100644 --- a/google/cloud/retail_v2/services/search_service/transports/grpc.py +++ b/google/cloud/retail_v2/services/search_service/transports/grpc.py @@ -261,5 +261,8 @@ def search( ) return self._stubs["search"] + def close(self): + self.grpc_channel.close() + __all__ = ("SearchServiceGrpcTransport",) diff --git a/google/cloud/retail_v2/services/search_service/transports/grpc_asyncio.py b/google/cloud/retail_v2/services/search_service/transports/grpc_asyncio.py index 681dfb40..f36d1c2a 100644 --- a/google/cloud/retail_v2/services/search_service/transports/grpc_asyncio.py +++ b/google/cloud/retail_v2/services/search_service/transports/grpc_asyncio.py @@ -266,5 +266,8 @@ def search( ) return self._stubs["search"] + def close(self): + return self.grpc_channel.close() + __all__ = ("SearchServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/retail_v2/services/user_event_service/async_client.py b/google/cloud/retail_v2/services/user_event_service/async_client.py index beccce65..6652e673 100644 --- a/google/cloud/retail_v2/services/user_event_service/async_client.py +++ b/google/cloud/retail_v2/services/user_event_service/async_client.py @@ -542,6 +542,12 @@ async def rejoin_user_events( # Done; return the response. return response + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.transport.close() + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( diff --git a/google/cloud/retail_v2/services/user_event_service/client.py b/google/cloud/retail_v2/services/user_event_service/client.py index f8b4ee18..0a73e7e9 100644 --- a/google/cloud/retail_v2/services/user_event_service/client.py +++ b/google/cloud/retail_v2/services/user_event_service/client.py @@ -376,10 +376,7 @@ def __init__( client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, - always_use_jwt_access=( - Transport == type(self).get_transport_class("grpc") - or Transport == type(self).get_transport_class("grpc_asyncio") - ), + always_use_jwt_access=True, ) def write_user_event( @@ -736,6 +733,19 @@ def rejoin_user_events( # Done; return the response. return response + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + """Releases underlying transport's resources. + + .. warning:: + ONLY use as a context manager if the transport is NOT shared + with other clients! Exiting the with block will CLOSE the transport + and may cause errors in other clients! + """ + self.transport.close() + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( diff --git a/google/cloud/retail_v2/services/user_event_service/transports/base.py b/google/cloud/retail_v2/services/user_event_service/transports/base.py index 2ba27a08..e44c6216 100644 --- a/google/cloud/retail_v2/services/user_event_service/transports/base.py +++ b/google/cloud/retail_v2/services/user_event_service/transports/base.py @@ -199,6 +199,15 @@ def _prep_wrapped_messages(self, client_info): ), } + def close(self): + """Closes resources associated with the transport. + + .. warning:: + Only call this method if the transport is NOT shared + with other clients - this may cause errors in other clients! + """ + raise NotImplementedError() + @property def operations_client(self) -> operations_v1.OperationsClient: """Return the client designed to process long-running operations.""" diff --git a/google/cloud/retail_v2/services/user_event_service/transports/grpc.py b/google/cloud/retail_v2/services/user_event_service/transports/grpc.py index 21e77785..415d3db5 100644 --- a/google/cloud/retail_v2/services/user_event_service/transports/grpc.py +++ b/google/cloud/retail_v2/services/user_event_service/transports/grpc.py @@ -403,5 +403,8 @@ def rejoin_user_events( ) return self._stubs["rejoin_user_events"] + def close(self): + self.grpc_channel.close() + __all__ = ("UserEventServiceGrpcTransport",) diff --git a/google/cloud/retail_v2/services/user_event_service/transports/grpc_asyncio.py b/google/cloud/retail_v2/services/user_event_service/transports/grpc_asyncio.py index ac42c6bf..4dc009d8 100644 --- a/google/cloud/retail_v2/services/user_event_service/transports/grpc_asyncio.py +++ b/google/cloud/retail_v2/services/user_event_service/transports/grpc_asyncio.py @@ -417,5 +417,8 @@ def rejoin_user_events( ) return self._stubs["rejoin_user_events"] + def close(self): + return self.grpc_channel.close() + __all__ = ("UserEventServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/retail_v2/types/catalog.py b/google/cloud/retail_v2/types/catalog.py index a0d42016..f7a76f3e 100644 --- a/google/cloud/retail_v2/types/catalog.py +++ b/google/cloud/retail_v2/types/catalog.py @@ -86,6 +86,7 @@ class ProductLevelConfig(proto.Message): class Catalog(proto.Message): r"""The catalog configuration. + Attributes: name (str): Required. Immutable. The fully qualified diff --git a/google/cloud/retail_v2/types/catalog_service.py b/google/cloud/retail_v2/types/catalog_service.py index 4611d996..a51a5b59 100644 --- a/google/cloud/retail_v2/types/catalog_service.py +++ b/google/cloud/retail_v2/types/catalog_service.py @@ -131,6 +131,7 @@ class UpdateCatalogRequest(proto.Message): class SetDefaultBranchRequest(proto.Message): r"""Request message to set a specified branch as new default_branch. + Attributes: catalog (str): Full resource name of the catalog, such as diff --git a/google/cloud/retail_v2/types/common.py b/google/cloud/retail_v2/types/common.py index b0584707..676bf19c 100644 --- a/google/cloud/retail_v2/types/common.py +++ b/google/cloud/retail_v2/types/common.py @@ -225,6 +225,7 @@ class FulfillmentInfo(proto.Message): class Image(proto.Message): r"""[Product][google.cloud.retail.v2.Product] thumbnail/detail image. + Attributes: uri (str): Required. URI of the image. @@ -256,6 +257,7 @@ class Image(proto.Message): class Interval(proto.Message): r"""A floating point interval. + Attributes: minimum (float): Inclusive lower bound. @@ -405,6 +407,7 @@ class PriceRange(proto.Message): class Rating(proto.Message): r"""The rating of a [Product][google.cloud.retail.v2.Product]. + Attributes: rating_count (int): The total number of ratings. This value is independent of @@ -437,6 +440,7 @@ class Rating(proto.Message): class UserInfo(proto.Message): r"""Information of an end user. + Attributes: user_id (str): Highly recommended for logged-in users. Unique identifier @@ -498,6 +502,7 @@ class UserInfo(proto.Message): class Promotion(proto.Message): r"""Promotion information. + Attributes: promotion_id (str): ID of the promotion. For example, "free gift". diff --git a/google/cloud/retail_v2/types/completion_service.py b/google/cloud/retail_v2/types/completion_service.py index 3ea97969..43788966 100644 --- a/google/cloud/retail_v2/types/completion_service.py +++ b/google/cloud/retail_v2/types/completion_service.py @@ -26,6 +26,7 @@ class CompleteQueryRequest(proto.Message): r"""Auto-complete parameters. + Attributes: catalog (str): Required. Catalog for which the completion is performed. @@ -105,6 +106,7 @@ class CompleteQueryRequest(proto.Message): class CompleteQueryResponse(proto.Message): r"""Response of the auto-complete query. + Attributes: completion_results (Sequence[google.cloud.retail_v2.types.CompleteQueryResponse.CompletionResult]): Results of the matching suggestions. The @@ -140,6 +142,7 @@ class CompleteQueryResponse(proto.Message): class CompletionResult(proto.Message): r"""Resource that represents completion results. + Attributes: suggestion (str): The suggestion for the query. @@ -155,6 +158,7 @@ class CompletionResult(proto.Message): class RecentSearchResult(proto.Message): r"""Recent search of this user. + Attributes: recent_search (str): The recent search query. diff --git a/google/cloud/retail_v2/types/import_config.py b/google/cloud/retail_v2/types/import_config.py index 0e4e64f0..f9b712b8 100644 --- a/google/cloud/retail_v2/types/import_config.py +++ b/google/cloud/retail_v2/types/import_config.py @@ -88,6 +88,7 @@ class GcsSource(proto.Message): class BigQuerySource(proto.Message): r"""BigQuery source import data from. + Attributes: partition_date (google.type.date_pb2.Date): BigQuery time partitioned table's \_PARTITIONDATE in @@ -177,6 +178,7 @@ class UserEventInlineSource(proto.Message): class ImportErrorsConfig(proto.Message): r"""Configuration of destination for Import related errors. + Attributes: gcs_prefix (str): Google Cloud Storage path for import errors. This must be an @@ -190,6 +192,7 @@ class ImportErrorsConfig(proto.Message): class ImportProductsRequest(proto.Message): r"""Request message for Import methods. + Attributes: parent (str): Required. @@ -257,6 +260,7 @@ class ReconciliationMode(proto.Enum): class ImportUserEventsRequest(proto.Message): r"""Request message for the ImportUserEvents request. + Attributes: parent (str): Required. @@ -277,6 +281,7 @@ class ImportUserEventsRequest(proto.Message): class ImportCompletionDataRequest(proto.Message): r"""Request message for ImportCompletionData methods. + Attributes: parent (str): Required. The catalog which the suggestions dataset belongs @@ -305,6 +310,7 @@ class ImportCompletionDataRequest(proto.Message): class ProductInputConfig(proto.Message): r"""The input config source for products. + Attributes: product_inline_source (google.cloud.retail_v2.types.ProductInlineSource): The Inline source for the input content for @@ -329,6 +335,7 @@ class ProductInputConfig(proto.Message): class UserEventInputConfig(proto.Message): r"""The input config source for user events. + Attributes: user_event_inline_source (google.cloud.retail_v2.types.UserEventInlineSource): Required. The Inline source for the input @@ -353,6 +360,7 @@ class UserEventInputConfig(proto.Message): class CompletionDataInputConfig(proto.Message): r"""The input config source for completion data. + Attributes: big_query_source (google.cloud.retail_v2.types.BigQuerySource): Required. BigQuery input source. diff --git a/google/cloud/retail_v2/types/prediction_service.py b/google/cloud/retail_v2/types/prediction_service.py index f6880776..d36bda57 100644 --- a/google/cloud/retail_v2/types/prediction_service.py +++ b/google/cloud/retail_v2/types/prediction_service.py @@ -26,6 +26,7 @@ class PredictRequest(proto.Message): r"""Request message for Predict method. + Attributes: placement (str): Required. Full resource name of the format: @@ -157,6 +158,7 @@ class PredictRequest(proto.Message): class PredictResponse(proto.Message): r"""Response message for predict method. + Attributes: results (Sequence[google.cloud.retail_v2.types.PredictResponse.PredictionResult]): A list of recommended products. The order diff --git a/google/cloud/retail_v2/types/product_service.py b/google/cloud/retail_v2/types/product_service.py index 61a9d6ae..45c9db6e 100644 --- a/google/cloud/retail_v2/types/product_service.py +++ b/google/cloud/retail_v2/types/product_service.py @@ -44,6 +44,7 @@ class CreateProductRequest(proto.Message): r"""Request message for [CreateProduct][] method. + Attributes: parent (str): Required. The parent catalog resource name, such as @@ -79,6 +80,7 @@ class CreateProductRequest(proto.Message): class GetProductRequest(proto.Message): r"""Request message for [GetProduct][] method. + Attributes: name (str): Required. Full resource name of @@ -99,6 +101,7 @@ class GetProductRequest(proto.Message): class UpdateProductRequest(proto.Message): r"""Request message for [UpdateProduct][] method. + Attributes: product (google.cloud.retail_v2.types.Product): Required. The product to update/create. @@ -137,6 +140,7 @@ class UpdateProductRequest(proto.Message): class DeleteProductRequest(proto.Message): r"""Request message for [DeleteProduct][] method. + Attributes: name (str): Required. Full resource name of @@ -288,6 +292,7 @@ def raw_page(self): class SetInventoryRequest(proto.Message): r"""Request message for [SetInventory][] method. + Attributes: inventory (google.cloud.retail_v2.types.Product): Required. The inventory information to update. The allowable @@ -366,18 +371,21 @@ class SetInventoryMetadata(proto.Message): r"""Metadata related to the progress of the SetInventory operation. Currently empty because there is no meaningful metadata populated from the [SetInventory][] method. - """ + + """ class SetInventoryResponse(proto.Message): r"""Response of the SetInventoryRequest. Currently empty because there is no meaningful response populated from the [SetInventory][] method. - """ + + """ class AddFulfillmentPlacesRequest(proto.Message): r"""Request message for [AddFulfillmentPlaces][] method. + Attributes: product (str): Required. Full resource name of @@ -453,18 +461,21 @@ class AddFulfillmentPlacesMetadata(proto.Message): r"""Metadata related to the progress of the AddFulfillmentPlaces operation. Currently empty because there is no meaningful metadata populated from the [AddFulfillmentPlaces][] method. - """ + + """ class AddFulfillmentPlacesResponse(proto.Message): r"""Response of the RemoveFulfillmentPlacesRequest. Currently empty because there is no meaningful response populated from the [AddFulfillmentPlaces][] method. - """ + + """ class RemoveFulfillmentPlacesRequest(proto.Message): r"""Request message for [RemoveFulfillmentPlaces][] method. + Attributes: product (str): Required. Full resource name of @@ -535,14 +546,16 @@ class RemoveFulfillmentPlacesMetadata(proto.Message): r"""Metadata related to the progress of the RemoveFulfillmentPlaces operation. Currently empty because there is no meaningful metadata populated from the [RemoveFulfillmentPlaces][] method. - """ + + """ class RemoveFulfillmentPlacesResponse(proto.Message): r"""Response of the RemoveFulfillmentPlacesRequest. Currently empty because there is no meaningful response populated from the [RemoveFulfillmentPlaces][] method. - """ + + """ __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/retail_v2/types/purge_config.py b/google/cloud/retail_v2/types/purge_config.py index 9a69c35f..748cfac8 100644 --- a/google/cloud/retail_v2/types/purge_config.py +++ b/google/cloud/retail_v2/types/purge_config.py @@ -26,11 +26,13 @@ class PurgeMetadata(proto.Message): r"""Metadata related to the progress of the Purge operation. This will be returned by the google.longrunning.Operation.metadata field. - """ + + """ class PurgeUserEventsRequest(proto.Message): r"""Request message for PurgeUserEvents method. + Attributes: parent (str): Required. The resource name of the catalog under which the diff --git a/google/cloud/retail_v2/types/search_service.py b/google/cloud/retail_v2/types/search_service.py index f617d497..c1fe08ba 100644 --- a/google/cloud/retail_v2/types/search_service.py +++ b/google/cloud/retail_v2/types/search_service.py @@ -234,6 +234,7 @@ class SearchRequest(proto.Message): class FacetSpec(proto.Message): r"""A facet specification to perform faceted search. + Attributes: facet_key (google.cloud.retail_v2.types.SearchRequest.FacetSpec.FacetKey): Required. The facet key specification. @@ -303,6 +304,7 @@ class FacetSpec(proto.Message): class FacetKey(proto.Message): r"""Specifies how a facet is computed. + Attributes: key (str): Required. Supported textual and numerical facet keys in @@ -455,6 +457,7 @@ class FacetKey(proto.Message): class DynamicFacetSpec(proto.Message): r"""The specifications of dynamically generated facets. + Attributes: mode (google.cloud.retail_v2.types.SearchRequest.DynamicFacetSpec.Mode): Mode of the DynamicFacet feature. Defaults to @@ -474,6 +477,7 @@ class Mode(proto.Enum): class BoostSpec(proto.Message): r"""Boost specification to boost certain items. + Attributes: condition_boost_specs (Sequence[google.cloud.retail_v2.types.SearchRequest.BoostSpec.ConditionBoostSpec]): Condition boost specifications. If a product @@ -486,6 +490,7 @@ class BoostSpec(proto.Message): class ConditionBoostSpec(proto.Message): r"""Boost applies to products which match a condition. + Attributes: condition (str): An expression which specifies a boost condition. The syntax @@ -628,6 +633,7 @@ class SearchResponse(proto.Message): class SearchResult(proto.Message): r"""Represents the search results. + Attributes: id (str): [Product.id][google.cloud.retail.v2.Product.id] of the @@ -718,6 +724,7 @@ class SearchResult(proto.Message): class Facet(proto.Message): r"""A facet result. + Attributes: key (str): The key for this facet. E.g., "colorFamilies" @@ -730,6 +737,7 @@ class Facet(proto.Message): class FacetValue(proto.Message): r"""A facet value which contains value names and their count. + Attributes: value (str): Text value of a facet, such as "Black" for diff --git a/google/cloud/retail_v2/types/user_event.py b/google/cloud/retail_v2/types/user_event.py index 1c934780..790656d1 100644 --- a/google/cloud/retail_v2/types/user_event.py +++ b/google/cloud/retail_v2/types/user_event.py @@ -300,6 +300,7 @@ class UserEvent(proto.Message): class ProductDetail(proto.Message): r"""Detailed product information associated with a user event. + Attributes: product (google.cloud.retail_v2.types.Product): Required. [Product][google.cloud.retail.v2.Product] @@ -357,6 +358,7 @@ class CompletionDetail(proto.Message): class PurchaseTransaction(proto.Message): r"""A transaction represents the entire purchase transaction. + Attributes: id (str): The transaction ID with a length limit of 128 diff --git a/google/cloud/retail_v2/types/user_event_service.py b/google/cloud/retail_v2/types/user_event_service.py index eed85538..c87a208d 100644 --- a/google/cloud/retail_v2/types/user_event_service.py +++ b/google/cloud/retail_v2/types/user_event_service.py @@ -32,6 +32,7 @@ class WriteUserEventRequest(proto.Message): r"""Request message for WriteUserEvent method. + Attributes: parent (str): Required. The parent catalog resource name, such as @@ -46,6 +47,7 @@ class WriteUserEventRequest(proto.Message): class CollectUserEventRequest(proto.Message): r"""Request message for CollectUserEvent method. + Attributes: parent (str): Required. The parent catalog name, such as @@ -74,6 +76,7 @@ class CollectUserEventRequest(proto.Message): class RejoinUserEventsRequest(proto.Message): r"""Request message for RejoinUserEvents method. + Attributes: parent (str): Required. The parent catalog resource name, such as @@ -107,6 +110,7 @@ class UserEventRejoinScope(proto.Enum): class RejoinUserEventsResponse(proto.Message): r"""Response message for RejoinUserEvents method. + Attributes: rejoined_user_events_count (int): Number of user events that were joined with @@ -117,7 +121,8 @@ class RejoinUserEventsResponse(proto.Message): class RejoinUserEventsMetadata(proto.Message): - r"""Metadata for RejoinUserEvents method. """ + r"""Metadata for RejoinUserEvents method. + """ __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/tests/unit/gapic/retail_v2/test_catalog_service.py b/tests/unit/gapic/retail_v2/test_catalog_service.py index 0118a31e..c68fe805 100644 --- a/tests/unit/gapic/retail_v2/test_catalog_service.py +++ b/tests/unit/gapic/retail_v2/test_catalog_service.py @@ -29,6 +29,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async +from google.api_core import path_template from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.retail_v2.services.catalog_service import CatalogServiceAsyncClient @@ -1582,6 +1583,9 @@ def test_catalog_service_base_transport(): with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) + with pytest.raises(NotImplementedError): + transport.close() + @requires_google_auth_gte_1_25_0 def test_catalog_service_base_transport_with_credentials_file(): @@ -2090,3 +2094,49 @@ def test_client_withDEFAULT_CLIENT_INFO(): credentials=ga_credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) + + +@pytest.mark.asyncio +async def test_transport_close_async(): + client = CatalogServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc_asyncio", + ) + with mock.patch.object( + type(getattr(client.transport, "grpc_channel")), "close" + ) as close: + async with client: + close.assert_not_called() + close.assert_called_once() + + +def test_transport_close(): + transports = { + "grpc": "_grpc_channel", + } + + for transport, close_name in transports.items(): + client = CatalogServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + with mock.patch.object( + type(getattr(client.transport, close_name)), "close" + ) as close: + with client: + close.assert_not_called() + close.assert_called_once() + + +def test_client_ctx(): + transports = [ + "grpc", + ] + for transport in transports: + client = CatalogServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + # Test client calls underlying transport. + with mock.patch.object(type(client.transport), "close") as close: + close.assert_not_called() + with client: + pass + close.assert_called() diff --git a/tests/unit/gapic/retail_v2/test_completion_service.py b/tests/unit/gapic/retail_v2/test_completion_service.py index ec0cc266..48df9d7e 100644 --- a/tests/unit/gapic/retail_v2/test_completion_service.py +++ b/tests/unit/gapic/retail_v2/test_completion_service.py @@ -32,6 +32,7 @@ from google.api_core import grpc_helpers_async from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 +from google.api_core import path_template from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.retail_v2.services.completion_service import ( @@ -898,6 +899,9 @@ def test_completion_service_base_transport(): with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) + with pytest.raises(NotImplementedError): + transport.close() + # Additionally, the LRO client (a property) should # also raise NotImplementedError with pytest.raises(NotImplementedError): @@ -1411,3 +1415,49 @@ def test_client_withDEFAULT_CLIENT_INFO(): credentials=ga_credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) + + +@pytest.mark.asyncio +async def test_transport_close_async(): + client = CompletionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc_asyncio", + ) + with mock.patch.object( + type(getattr(client.transport, "grpc_channel")), "close" + ) as close: + async with client: + close.assert_not_called() + close.assert_called_once() + + +def test_transport_close(): + transports = { + "grpc": "_grpc_channel", + } + + for transport, close_name in transports.items(): + client = CompletionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + with mock.patch.object( + type(getattr(client.transport, close_name)), "close" + ) as close: + with client: + close.assert_not_called() + close.assert_called_once() + + +def test_client_ctx(): + transports = [ + "grpc", + ] + for transport in transports: + client = CompletionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + # Test client calls underlying transport. + with mock.patch.object(type(client.transport), "close") as close: + close.assert_not_called() + with client: + pass + close.assert_called() diff --git a/tests/unit/gapic/retail_v2/test_prediction_service.py b/tests/unit/gapic/retail_v2/test_prediction_service.py index b5cbf7be..fca07646 100644 --- a/tests/unit/gapic/retail_v2/test_prediction_service.py +++ b/tests/unit/gapic/retail_v2/test_prediction_service.py @@ -29,6 +29,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async +from google.api_core import path_template from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.retail_v2.services.prediction_service import ( @@ -761,6 +762,9 @@ def test_prediction_service_base_transport(): with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) + with pytest.raises(NotImplementedError): + transport.close() + @requires_google_auth_gte_1_25_0 def test_prediction_service_base_transport_with_credentials_file(): @@ -1253,3 +1257,49 @@ def test_client_withDEFAULT_CLIENT_INFO(): credentials=ga_credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) + + +@pytest.mark.asyncio +async def test_transport_close_async(): + client = PredictionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc_asyncio", + ) + with mock.patch.object( + type(getattr(client.transport, "grpc_channel")), "close" + ) as close: + async with client: + close.assert_not_called() + close.assert_called_once() + + +def test_transport_close(): + transports = { + "grpc": "_grpc_channel", + } + + for transport, close_name in transports.items(): + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + with mock.patch.object( + type(getattr(client.transport, close_name)), "close" + ) as close: + with client: + close.assert_not_called() + close.assert_called_once() + + +def test_client_ctx(): + transports = [ + "grpc", + ] + for transport in transports: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + # Test client calls underlying transport. + with mock.patch.object(type(client.transport), "close") as close: + close.assert_not_called() + with client: + pass + close.assert_called() diff --git a/tests/unit/gapic/retail_v2/test_product_service.py b/tests/unit/gapic/retail_v2/test_product_service.py index 67f22750..9e095889 100644 --- a/tests/unit/gapic/retail_v2/test_product_service.py +++ b/tests/unit/gapic/retail_v2/test_product_service.py @@ -32,6 +32,7 @@ from google.api_core import grpc_helpers_async from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 +from google.api_core import path_template from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.retail_v2.services.product_service import ProductServiceAsyncClient @@ -2793,6 +2794,9 @@ def test_product_service_base_transport(): with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) + with pytest.raises(NotImplementedError): + transport.close() + # Additionally, the LRO client (a property) should # also raise NotImplementedError with pytest.raises(NotImplementedError): @@ -3342,3 +3346,49 @@ def test_client_withDEFAULT_CLIENT_INFO(): credentials=ga_credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) + + +@pytest.mark.asyncio +async def test_transport_close_async(): + client = ProductServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc_asyncio", + ) + with mock.patch.object( + type(getattr(client.transport, "grpc_channel")), "close" + ) as close: + async with client: + close.assert_not_called() + close.assert_called_once() + + +def test_transport_close(): + transports = { + "grpc": "_grpc_channel", + } + + for transport, close_name in transports.items(): + client = ProductServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + with mock.patch.object( + type(getattr(client.transport, close_name)), "close" + ) as close: + with client: + close.assert_not_called() + close.assert_called_once() + + +def test_client_ctx(): + transports = [ + "grpc", + ] + for transport in transports: + client = ProductServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + # Test client calls underlying transport. + with mock.patch.object(type(client.transport), "close") as close: + close.assert_not_called() + with client: + pass + close.assert_called() diff --git a/tests/unit/gapic/retail_v2/test_search_service.py b/tests/unit/gapic/retail_v2/test_search_service.py index 932bce09..5ecb1839 100644 --- a/tests/unit/gapic/retail_v2/test_search_service.py +++ b/tests/unit/gapic/retail_v2/test_search_service.py @@ -29,6 +29,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async +from google.api_core import path_template from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.retail_v2.services.search_service import SearchServiceAsyncClient @@ -906,6 +907,9 @@ def test_search_service_base_transport(): with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) + with pytest.raises(NotImplementedError): + transport.close() + @requires_google_auth_gte_1_25_0 def test_search_service_base_transport_with_credentials_file(): @@ -1422,3 +1426,49 @@ def test_client_withDEFAULT_CLIENT_INFO(): credentials=ga_credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) + + +@pytest.mark.asyncio +async def test_transport_close_async(): + client = SearchServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc_asyncio", + ) + with mock.patch.object( + type(getattr(client.transport, "grpc_channel")), "close" + ) as close: + async with client: + close.assert_not_called() + close.assert_called_once() + + +def test_transport_close(): + transports = { + "grpc": "_grpc_channel", + } + + for transport, close_name in transports.items(): + client = SearchServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + with mock.patch.object( + type(getattr(client.transport, close_name)), "close" + ) as close: + with client: + close.assert_not_called() + close.assert_called_once() + + +def test_client_ctx(): + transports = [ + "grpc", + ] + for transport in transports: + client = SearchServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + # Test client calls underlying transport. + with mock.patch.object(type(client.transport), "close") as close: + close.assert_not_called() + with client: + pass + close.assert_called() diff --git a/tests/unit/gapic/retail_v2/test_user_event_service.py b/tests/unit/gapic/retail_v2/test_user_event_service.py index 8216cf15..eb2a27f4 100644 --- a/tests/unit/gapic/retail_v2/test_user_event_service.py +++ b/tests/unit/gapic/retail_v2/test_user_event_service.py @@ -33,6 +33,7 @@ from google.api_core import grpc_helpers_async from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 +from google.api_core import path_template from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.retail_v2.services.user_event_service import ( @@ -1396,6 +1397,9 @@ def test_user_event_service_base_transport(): with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) + with pytest.raises(NotImplementedError): + transport.close() + # Additionally, the LRO client (a property) should # also raise NotImplementedError with pytest.raises(NotImplementedError): @@ -1943,3 +1947,49 @@ def test_client_withDEFAULT_CLIENT_INFO(): credentials=ga_credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) + + +@pytest.mark.asyncio +async def test_transport_close_async(): + client = UserEventServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc_asyncio", + ) + with mock.patch.object( + type(getattr(client.transport, "grpc_channel")), "close" + ) as close: + async with client: + close.assert_not_called() + close.assert_called_once() + + +def test_transport_close(): + transports = { + "grpc": "_grpc_channel", + } + + for transport, close_name in transports.items(): + client = UserEventServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + with mock.patch.object( + type(getattr(client.transport, close_name)), "close" + ) as close: + with client: + close.assert_not_called() + close.assert_called_once() + + +def test_client_ctx(): + transports = [ + "grpc", + ] + for transport in transports: + client = UserEventServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + # Test client calls underlying transport. + with mock.patch.object(type(client.transport), "close") as close: + close.assert_not_called() + with client: + pass + close.assert_called()