diff --git a/google/area120/tables_v1alpha1/services/tables_service/async_client.py b/google/area120/tables_v1alpha1/services/tables_service/async_client.py index 84d5c2e..547f6eb 100644 --- a/google/area120/tables_v1alpha1/services/tables_service/async_client.py +++ b/google/area120/tables_v1alpha1/services/tables_service/async_client.py @@ -934,6 +934,12 @@ async def batch_delete_rows( request, retry=retry, timeout=timeout, metadata=metadata, ) + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.transport.close() + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( diff --git a/google/area120/tables_v1alpha1/services/tables_service/client.py b/google/area120/tables_v1alpha1/services/tables_service/client.py index 8878cc1..d7630ee 100644 --- a/google/area120/tables_v1alpha1/services/tables_service/client.py +++ b/google/area120/tables_v1alpha1/services/tables_service/client.py @@ -374,10 +374,7 @@ def __init__( client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, - always_use_jwt_access=( - Transport == type(self).get_transport_class("grpc") - or Transport == type(self).get_transport_class("grpc_asyncio") - ), + always_use_jwt_access=True, ) def get_table( @@ -1139,6 +1136,19 @@ def batch_delete_rows( request, retry=retry, timeout=timeout, metadata=metadata, ) + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + """Releases underlying transport's resources. + + .. warning:: + ONLY use as a context manager if the transport is NOT shared + with other clients! Exiting the with block will CLOSE the transport + and may cause errors in other clients! + """ + self.transport.close() + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( diff --git a/google/area120/tables_v1alpha1/services/tables_service/transports/base.py b/google/area120/tables_v1alpha1/services/tables_service/transports/base.py index 348f156..389f523 100644 --- a/google/area120/tables_v1alpha1/services/tables_service/transports/base.py +++ b/google/area120/tables_v1alpha1/services/tables_service/transports/base.py @@ -198,6 +198,15 @@ def _prep_wrapped_messages(self, client_info): ), } + def close(self): + """Closes resources associated with the transport. + + .. warning:: + Only call this method if the transport is NOT shared + with other clients - this may cause errors in other clients! + """ + raise NotImplementedError() + @property def get_table( self, diff --git a/google/area120/tables_v1alpha1/services/tables_service/transports/grpc.py b/google/area120/tables_v1alpha1/services/tables_service/transports/grpc.py index 6eeff66..2de12ae 100644 --- a/google/area120/tables_v1alpha1/services/tables_service/transports/grpc.py +++ b/google/area120/tables_v1alpha1/services/tables_service/transports/grpc.py @@ -540,5 +540,8 @@ def batch_delete_rows( ) return self._stubs["batch_delete_rows"] + def close(self): + self.grpc_channel.close() + __all__ = ("TablesServiceGrpcTransport",) diff --git a/google/area120/tables_v1alpha1/services/tables_service/transports/grpc_asyncio.py b/google/area120/tables_v1alpha1/services/tables_service/transports/grpc_asyncio.py index c5a1076..dcb67a9 100644 --- a/google/area120/tables_v1alpha1/services/tables_service/transports/grpc_asyncio.py +++ b/google/area120/tables_v1alpha1/services/tables_service/transports/grpc_asyncio.py @@ -555,5 +555,8 @@ def batch_delete_rows( ) return self._stubs["batch_delete_rows"] + def close(self): + return self.grpc_channel.close() + __all__ = ("TablesServiceGrpcAsyncIOTransport",) diff --git a/google/area120/tables_v1alpha1/types/tables.py b/google/area120/tables_v1alpha1/types/tables.py index 359bcd0..25387eb 100644 --- a/google/area120/tables_v1alpha1/types/tables.py +++ b/google/area120/tables_v1alpha1/types/tables.py @@ -59,6 +59,7 @@ class View(proto.Enum): class GetTableRequest(proto.Message): r"""Request message for TablesService.GetTable. + Attributes: name (str): Required. The name of the table to retrieve. @@ -70,6 +71,7 @@ class GetTableRequest(proto.Message): class ListTablesRequest(proto.Message): r"""Request message for TablesService.ListTables. + Attributes: page_size (int): The maximum number of tables to return. The @@ -93,6 +95,7 @@ class ListTablesRequest(proto.Message): class ListTablesResponse(proto.Message): r"""Response message for TablesService.ListTables. + Attributes: tables (Sequence[google.area120.tables_v1alpha1.types.Table]): The list of tables. @@ -112,6 +115,7 @@ def raw_page(self): class GetWorkspaceRequest(proto.Message): r"""Request message for TablesService.GetWorkspace. + Attributes: name (str): Required. The name of the workspace to @@ -123,6 +127,7 @@ class GetWorkspaceRequest(proto.Message): class ListWorkspacesRequest(proto.Message): r"""Request message for TablesService.ListWorkspaces. + Attributes: page_size (int): The maximum number of workspaces to return. @@ -145,6 +150,7 @@ class ListWorkspacesRequest(proto.Message): class ListWorkspacesResponse(proto.Message): r"""Response message for TablesService.ListWorkspaces. + Attributes: workspaces (Sequence[google.area120.tables_v1alpha1.types.Workspace]): The list of workspaces. @@ -164,6 +170,7 @@ def raw_page(self): class GetRowRequest(proto.Message): r"""Request message for TablesService.GetRow. + Attributes: name (str): Required. The name of the row to retrieve. @@ -179,6 +186,7 @@ class GetRowRequest(proto.Message): class ListRowsRequest(proto.Message): r"""Request message for TablesService.ListRows. + Attributes: parent (str): Required. The parent table. @@ -216,6 +224,7 @@ class ListRowsRequest(proto.Message): class ListRowsResponse(proto.Message): r"""Response message for TablesService.ListRows. + Attributes: rows (Sequence[google.area120.tables_v1alpha1.types.Row]): The rows from the specified table. @@ -235,6 +244,7 @@ def raw_page(self): class CreateRowRequest(proto.Message): r"""Request message for TablesService.CreateRow. + Attributes: parent (str): Required. The parent table where this row @@ -253,6 +263,7 @@ class CreateRowRequest(proto.Message): class BatchCreateRowsRequest(proto.Message): r"""Request message for TablesService.BatchCreateRows. + Attributes: parent (str): Required. The parent table where the rows @@ -270,6 +281,7 @@ class BatchCreateRowsRequest(proto.Message): class BatchCreateRowsResponse(proto.Message): r"""Response message for TablesService.BatchCreateRows. + Attributes: rows (Sequence[google.area120.tables_v1alpha1.types.Row]): The created rows. @@ -280,6 +292,7 @@ class BatchCreateRowsResponse(proto.Message): class UpdateRowRequest(proto.Message): r"""Request message for TablesService.UpdateRow. + Attributes: row (google.area120.tables_v1alpha1.types.Row): Required. The row to update. @@ -299,6 +312,7 @@ class UpdateRowRequest(proto.Message): class BatchUpdateRowsRequest(proto.Message): r"""Request message for TablesService.BatchUpdateRows. + Attributes: parent (str): Required. The parent table shared by all rows @@ -316,6 +330,7 @@ class BatchUpdateRowsRequest(proto.Message): class BatchUpdateRowsResponse(proto.Message): r"""Response message for TablesService.BatchUpdateRows. + Attributes: rows (Sequence[google.area120.tables_v1alpha1.types.Row]): The updated rows. @@ -326,6 +341,7 @@ class BatchUpdateRowsResponse(proto.Message): class DeleteRowRequest(proto.Message): r"""Request message for TablesService.DeleteRow + Attributes: name (str): Required. The name of the row to delete. @@ -337,6 +353,7 @@ class DeleteRowRequest(proto.Message): class BatchDeleteRowsRequest(proto.Message): r"""Request message for TablesService.BatchDeleteRows + Attributes: parent (str): Required. The parent table shared by all rows @@ -355,6 +372,7 @@ class BatchDeleteRowsRequest(proto.Message): class Table(proto.Message): r"""A single table. + Attributes: name (str): The resource name of the table. Table names have the form @@ -373,6 +391,7 @@ class Table(proto.Message): class ColumnDescription(proto.Message): r"""Details on a column in the table. + Attributes: name (str): column name @@ -416,6 +435,7 @@ class ColumnDescription(proto.Message): class LabeledItem(proto.Message): r"""A single item in a labeled column. + Attributes: name (str): Display string as entered by user. @@ -429,6 +449,7 @@ class LabeledItem(proto.Message): class RelationshipDetails(proto.Message): r"""Details about a relationship column. + Attributes: linked_table (str): The name of the table this relationship is @@ -456,6 +477,7 @@ class LookupDetails(proto.Message): class Row(proto.Message): r"""A single row in a table. + Attributes: name (str): The resource name of the row. Row names have the form @@ -476,6 +498,7 @@ class Row(proto.Message): class Workspace(proto.Message): r"""A single workspace. + Attributes: name (str): The resource name of the workspace. Workspace names have the diff --git a/tests/unit/gapic/tables_v1alpha1/test_tables_service.py b/tests/unit/gapic/tables_v1alpha1/test_tables_service.py index bd4b9f4..134f495 100644 --- a/tests/unit/gapic/tables_v1alpha1/test_tables_service.py +++ b/tests/unit/gapic/tables_v1alpha1/test_tables_service.py @@ -29,6 +29,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async +from google.api_core import path_template from google.area120.tables_v1alpha1.services.tables_service import ( TablesServiceAsyncClient, ) @@ -2944,6 +2945,9 @@ def test_tables_service_base_transport(): with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) + with pytest.raises(NotImplementedError): + transport.close() + @requires_google_auth_gte_1_25_0 def test_tables_service_base_transport_with_credentials_file(): @@ -3505,3 +3509,49 @@ def test_client_withDEFAULT_CLIENT_INFO(): credentials=ga_credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) + + +@pytest.mark.asyncio +async def test_transport_close_async(): + client = TablesServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc_asyncio", + ) + with mock.patch.object( + type(getattr(client.transport, "grpc_channel")), "close" + ) as close: + async with client: + close.assert_not_called() + close.assert_called_once() + + +def test_transport_close(): + transports = { + "grpc": "_grpc_channel", + } + + for transport, close_name in transports.items(): + client = TablesServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + with mock.patch.object( + type(getattr(client.transport, close_name)), "close" + ) as close: + with client: + close.assert_not_called() + close.assert_called_once() + + +def test_client_ctx(): + transports = [ + "grpc", + ] + for transport in transports: + client = TablesServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + # Test client calls underlying transport. + with mock.patch.object(type(client.transport), "close") as close: + close.assert_not_called() + with client: + pass + close.assert_called()