Skip to content
This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

feat: add context manager support in client #63

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -1595,6 +1595,12 @@ async def disable_service(
# Done; return the response.
return response

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc, tb):
await self.transport.close()


try:
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
Expand Down
Expand Up @@ -358,10 +358,7 @@ def __init__(
client_cert_source_for_mtls=client_cert_source_func,
quota_project_id=client_options.quota_project_id,
client_info=client_info,
always_use_jwt_access=(
Transport == type(self).get_transport_class("grpc")
or Transport == type(self).get_transport_class("grpc_asyncio")
),
always_use_jwt_access=True,
)

def list_services(
Expand Down Expand Up @@ -1769,6 +1766,19 @@ def disable_service(
# Done; return the response.
return response

def __enter__(self):
return self

def __exit__(self, type, value, traceback):
"""Releases underlying transport's resources.

.. warning::
ONLY use as a context manager if the transport is NOT shared
with other clients! Exiting the with block will CLOSE the transport
and may cause errors in other clients!
"""
self.transport.close()


try:
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
Expand Down
Expand Up @@ -222,6 +222,15 @@ def _prep_wrapped_messages(self, client_info):
),
}

def close(self):
"""Closes resources associated with the transport.

.. warning::
Only call this method if the transport is NOT shared
with other clients - this may cause errors in other clients!
"""
raise NotImplementedError()

@property
def operations_client(self) -> operations_v1.OperationsClient:
"""Return the client designed to process long-running operations."""
Expand Down
Expand Up @@ -737,5 +737,8 @@ def disable_service(
)
return self._stubs["disable_service"]

def close(self):
self.grpc_channel.close()


__all__ = ("ServiceManagerGrpcTransport",)
Expand Up @@ -764,5 +764,8 @@ def disable_service(
)
return self._stubs["disable_service"]

def close(self):
return self.grpc_channel.close()


__all__ = ("ServiceManagerGrpcAsyncIOTransport",)
7 changes: 6 additions & 1 deletion google/cloud/servicemanagement_v1/types/resources.py
Expand Up @@ -83,6 +83,7 @@ class Status(proto.Enum):

class Step(proto.Message):
r"""Represents the status of one operation step.

Attributes:
description (str):
The short description of the step.
Expand All @@ -101,6 +102,7 @@ class Step(proto.Message):

class Diagnostic(proto.Message):
r"""Represents a diagnostic message (error or warning)

Attributes:
location (str):
File name and line number of the error or
Expand Down Expand Up @@ -142,6 +144,7 @@ class ConfigSource(proto.Message):

class ConfigFile(proto.Message):
r"""Generic specification of a source configuration file

Attributes:
file_path (str):
The file name of the configuration file (full
Expand Down Expand Up @@ -169,6 +172,7 @@ class FileType(proto.Enum):

class ConfigRef(proto.Message):
r"""Represents a service configuration with its name and id.

Attributes:
name (str):
Resource name of a service config. It must
Expand Down Expand Up @@ -296,7 +300,8 @@ class DeleteServiceStrategy(proto.Message):
r"""Strategy used to delete a service. This strategy is a
placeholder only used by the system generated rollout to delete
a service.
"""

"""

rollout_id = proto.Field(proto.STRING, number=1,)
create_time = proto.Field(proto.MESSAGE, number=2, message=timestamp_pb2.Timestamp,)
Expand Down
27 changes: 25 additions & 2 deletions google/cloud/servicemanagement_v1/types/servicemanager.py
Expand Up @@ -52,6 +52,7 @@

class ListServicesRequest(proto.Message):
r"""Request message for ``ListServices`` method.

Attributes:
producer_project_id (str):
Include services produced by the specified
Expand Down Expand Up @@ -80,6 +81,7 @@ class ListServicesRequest(proto.Message):

class ListServicesResponse(proto.Message):
r"""Response message for ``ListServices`` method.

Attributes:
services (Sequence[google.cloud.servicemanagement_v1.types.ManagedService]):
The returned services will only have the name
Expand All @@ -101,6 +103,7 @@ def raw_page(self):

class GetServiceRequest(proto.Message):
r"""Request message for ``GetService`` method.

Attributes:
service_name (str):
Required. The name of the service. See the
Expand All @@ -113,6 +116,7 @@ class GetServiceRequest(proto.Message):

class CreateServiceRequest(proto.Message):
r"""Request message for CreateService method.

Attributes:
service (google.cloud.servicemanagement_v1.types.ManagedService):
Required. Initial values for the service
Expand All @@ -124,6 +128,7 @@ class CreateServiceRequest(proto.Message):

class DeleteServiceRequest(proto.Message):
r"""Request message for DeleteService method.

Attributes:
service_name (str):
Required. The name of the service. See the
Expand All @@ -137,6 +142,7 @@ class DeleteServiceRequest(proto.Message):

class UndeleteServiceRequest(proto.Message):
r"""Request message for UndeleteService method.

Attributes:
service_name (str):
Required. The name of the service. See the
Expand All @@ -150,6 +156,7 @@ class UndeleteServiceRequest(proto.Message):

class UndeleteServiceResponse(proto.Message):
r"""Response message for UndeleteService method.

Attributes:
service (google.cloud.servicemanagement_v1.types.ManagedService):
Revived service resource.
Expand All @@ -160,6 +167,7 @@ class UndeleteServiceResponse(proto.Message):

class GetServiceConfigRequest(proto.Message):
r"""Request message for GetServiceConfig method.

Attributes:
service_name (str):
Required. The name of the service. See the
Expand Down Expand Up @@ -188,6 +196,7 @@ class ConfigView(proto.Enum):

class ListServiceConfigsRequest(proto.Message):
r"""Request message for ListServiceConfigs method.

Attributes:
service_name (str):
Required. The name of the service. See the
Expand All @@ -209,6 +218,7 @@ class ListServiceConfigsRequest(proto.Message):

class ListServiceConfigsResponse(proto.Message):
r"""Response message for ListServiceConfigs method.

Attributes:
service_configs (Sequence[google.api.service_pb2.Service]):
The list of service configuration resources.
Expand All @@ -228,6 +238,7 @@ def raw_page(self):

class CreateServiceConfigRequest(proto.Message):
r"""Request message for CreateServiceConfig method.

Attributes:
service_name (str):
Required. The name of the service. See the
Expand All @@ -244,6 +255,7 @@ class CreateServiceConfigRequest(proto.Message):

class SubmitConfigSourceRequest(proto.Message):
r"""Request message for SubmitConfigSource method.

Attributes:
service_name (str):
Required. The name of the service. See the
Expand All @@ -269,6 +281,7 @@ class SubmitConfigSourceRequest(proto.Message):

class SubmitConfigSourceResponse(proto.Message):
r"""Response message for SubmitConfigSource method.

Attributes:
service_config (google.api.service_pb2.Service):
The generated service configuration.
Expand All @@ -279,6 +292,7 @@ class SubmitConfigSourceResponse(proto.Message):

class CreateServiceRolloutRequest(proto.Message):
r"""Request message for 'CreateServiceRollout'

Attributes:
service_name (str):
Required. The name of the service. See the
Expand All @@ -296,6 +310,7 @@ class CreateServiceRolloutRequest(proto.Message):

class ListServiceRolloutsRequest(proto.Message):
r"""Request message for 'ListServiceRollouts'

Attributes:
service_name (str):
Required. The name of the service. See the
Expand Down Expand Up @@ -327,6 +342,7 @@ class ListServiceRolloutsRequest(proto.Message):

class ListServiceRolloutsResponse(proto.Message):
r"""Response message for ListServiceRollouts method.

Attributes:
rollouts (Sequence[google.cloud.servicemanagement_v1.types.Rollout]):
The list of rollout resources.
Expand All @@ -344,6 +360,7 @@ def raw_page(self):

class GetServiceRolloutRequest(proto.Message):
r"""Request message for GetServiceRollout method.

Attributes:
service_name (str):
Required. The name of the service. See the
Expand All @@ -360,6 +377,7 @@ class GetServiceRolloutRequest(proto.Message):

class EnableServiceRequest(proto.Message):
r"""Request message for EnableService method.

Attributes:
service_name (str):
Required. Name of the service to enable.
Expand All @@ -383,11 +401,13 @@ class EnableServiceRequest(proto.Message):


class EnableServiceResponse(proto.Message):
r"""Operation payload for EnableService method. """
r"""Operation payload for EnableService method.
"""


class DisableServiceRequest(proto.Message):
r"""Request message for DisableService method.

Attributes:
service_name (str):
Required. Name of the service to disable.
Expand All @@ -411,11 +431,13 @@ class DisableServiceRequest(proto.Message):


class DisableServiceResponse(proto.Message):
r"""Operation payload for DisableService method. """
r"""Operation payload for DisableService method.
"""


class GenerateConfigReportRequest(proto.Message):
r"""Request message for GenerateConfigReport method.

Attributes:
new_config (google.protobuf.any_pb2.Any):
Required. Service configuration for which we want to
Expand All @@ -439,6 +461,7 @@ class GenerateConfigReportRequest(proto.Message):

class GenerateConfigReportResponse(proto.Message):
r"""Response message for GenerateConfigReport method.

Attributes:
service_name (str):
Name of the service this report belongs to.
Expand Down
50 changes: 50 additions & 0 deletions tests/unit/gapic/servicemanagement_v1/test_service_manager.py
Expand Up @@ -52,6 +52,7 @@
from google.api_core import grpc_helpers_async
from google.api_core import operation_async # type: ignore
from google.api_core import operations_v1
from google.api_core import path_template
from google.auth import credentials as ga_credentials
from google.auth.exceptions import MutualTLSChannelError
from google.cloud.servicemanagement_v1.services.service_manager import (
Expand Down Expand Up @@ -3599,6 +3600,9 @@ def test_service_manager_base_transport():
with pytest.raises(NotImplementedError):
getattr(transport, method)(request=object())

with pytest.raises(NotImplementedError):
transport.close()

# Additionally, the LRO client (a property) should
# also raise NotImplementedError
with pytest.raises(NotImplementedError):
Expand Down Expand Up @@ -4123,3 +4127,49 @@ def test_client_withDEFAULT_CLIENT_INFO():
credentials=ga_credentials.AnonymousCredentials(), client_info=client_info,
)
prep.assert_called_once_with(client_info)


@pytest.mark.asyncio
async def test_transport_close_async():
client = ServiceManagerAsyncClient(
credentials=ga_credentials.AnonymousCredentials(), transport="grpc_asyncio",
)
with mock.patch.object(
type(getattr(client.transport, "grpc_channel")), "close"
) as close:
async with client:
close.assert_not_called()
close.assert_called_once()


def test_transport_close():
transports = {
"grpc": "_grpc_channel",
}

for transport, close_name in transports.items():
client = ServiceManagerClient(
credentials=ga_credentials.AnonymousCredentials(), transport=transport
)
with mock.patch.object(
type(getattr(client.transport, close_name)), "close"
) as close:
with client:
close.assert_not_called()
close.assert_called_once()


def test_client_ctx():
transports = [
"grpc",
]
for transport in transports:
client = ServiceManagerClient(
credentials=ga_credentials.AnonymousCredentials(), transport=transport
)
# Test client calls underlying transport.
with mock.patch.object(type(client.transport), "close") as close:
close.assert_not_called()
with client:
pass
close.assert_called()