Skip to content
This repository has been archived by the owner on Sep 5, 2023. It is now read-only.

feat: add context manager support in client #34

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -536,6 +536,12 @@ async def update_iap_settings(
# Done; return the response.
return response

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc, tb):
await self.transport.close()


try:
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
Expand Down
Expand Up @@ -338,10 +338,7 @@ def __init__(
client_cert_source_for_mtls=client_cert_source_func,
quota_project_id=client_options.quota_project_id,
client_info=client_info,
always_use_jwt_access=(
Transport == type(self).get_transport_class("grpc")
or Transport == type(self).get_transport_class("grpc_asyncio")
),
always_use_jwt_access=True,
)

def set_iam_policy(
Expand Down Expand Up @@ -709,6 +706,19 @@ def update_iap_settings(
# Done; return the response.
return response

def __enter__(self):
return self

def __exit__(self, type, value, traceback):
"""Releases underlying transport's resources.

.. warning::
ONLY use as a context manager if the transport is NOT shared
with other clients! Exiting the with block will CLOSE the transport
and may cause errors in other clients!
"""
self.transport.close()


try:
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
Expand Down
Expand Up @@ -173,6 +173,15 @@ def _prep_wrapped_messages(self, client_info):
),
}

def close(self):
"""Closes resources associated with the transport.

.. warning::
Only call this method if the transport is NOT shared
with other clients - this may cause errors in other clients!
"""
raise NotImplementedError()

@property
def set_iam_policy(
self,
Expand Down
Expand Up @@ -373,5 +373,8 @@ def update_iap_settings(
)
return self._stubs["update_iap_settings"]

def close(self):
self.grpc_channel.close()


__all__ = ("IdentityAwareProxyAdminServiceGrpcTransport",)
Expand Up @@ -376,5 +376,8 @@ def update_iap_settings(
)
return self._stubs["update_iap_settings"]

def close(self):
return self.grpc_channel.close()


__all__ = ("IdentityAwareProxyAdminServiceGrpcAsyncIOTransport",)
Expand Up @@ -581,6 +581,12 @@ async def delete_identity_aware_proxy_client(
request, retry=retry, timeout=timeout, metadata=metadata,
)

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc, tb):
await self.transport.close()


try:
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
Expand Down
Expand Up @@ -340,10 +340,7 @@ def __init__(
client_cert_source_for_mtls=client_cert_source_func,
quota_project_id=client_options.quota_project_id,
client_info=client_info,
always_use_jwt_access=(
Transport == type(self).get_transport_class("grpc")
or Transport == type(self).get_transport_class("grpc_asyncio")
),
always_use_jwt_access=True,
)

def list_brands(
Expand Down Expand Up @@ -773,6 +770,19 @@ def delete_identity_aware_proxy_client(
request, retry=retry, timeout=timeout, metadata=metadata,
)

def __enter__(self):
return self

def __exit__(self, type, value, traceback):
"""Releases underlying transport's resources.

.. warning::
ONLY use as a context manager if the transport is NOT shared
with other clients! Exiting the with block will CLOSE the transport
and may cause errors in other clients!
"""
self.transport.close()


try:
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
Expand Down
Expand Up @@ -189,6 +189,15 @@ def _prep_wrapped_messages(self, client_info):
),
}

def close(self):
"""Closes resources associated with the transport.

.. warning::
Only call this method if the transport is NOT shared
with other clients - this may cause errors in other clients!
"""
raise NotImplementedError()

@property
def list_brands(
self,
Expand Down
Expand Up @@ -478,5 +478,8 @@ def delete_identity_aware_proxy_client(
)
return self._stubs["delete_identity_aware_proxy_client"]

def close(self):
self.grpc_channel.close()


__all__ = ("IdentityAwareProxyOAuthServiceGrpcTransport",)
Expand Up @@ -488,5 +488,8 @@ def delete_identity_aware_proxy_client(
)
return self._stubs["delete_identity_aware_proxy_client"]

def close(self):
return self.grpc_channel.close()


__all__ = ("IdentityAwareProxyOAuthServiceGrpcAsyncIOTransport",)
16 changes: 16 additions & 0 deletions google/cloud/iap_v1/types/service.py
Expand Up @@ -50,6 +50,7 @@

class GetIapSettingsRequest(proto.Message):
r"""The request sent to GetIapSettings.

Attributes:
name (str):
Required. The resource name for which to retrieve the
Expand All @@ -62,6 +63,7 @@ class GetIapSettingsRequest(proto.Message):

class UpdateIapSettingsRequest(proto.Message):
r"""The request sent to UpdateIapSettings.

Attributes:
iap_settings (google.cloud.iap_v1.types.IapSettings):
Required. The new values for the IAP settings to be updated.
Expand All @@ -83,6 +85,7 @@ class UpdateIapSettingsRequest(proto.Message):

class IapSettings(proto.Message):
r"""The IAP configurable settings.

Attributes:
name (str):
Required. The resource name of the IAP
Expand All @@ -104,6 +107,7 @@ class IapSettings(proto.Message):

class AccessSettings(proto.Message):
r"""Access related settings for IAP protected apps.

Attributes:
gcip_settings (google.cloud.iap_v1.types.GcipSettings):
GCIP claims and endpoint configurations for
Expand All @@ -122,6 +126,7 @@ class AccessSettings(proto.Message):

class GcipSettings(proto.Message):
r"""Allows customers to configure tenant_id for GCIP instance per-app.

Attributes:
tenant_ids (Sequence[str]):
GCIP tenant ids that are linked to the IAP resource.
Expand Down Expand Up @@ -182,6 +187,7 @@ class OAuthSettings(proto.Message):

class ApplicationSettings(proto.Message):
r"""Wrapper over application specific settings for IAP.

Attributes:
csm_settings (google.cloud.iap_v1.types.CsmSettings):
Settings to configure IAP's behavior for a
Expand Down Expand Up @@ -246,6 +252,7 @@ class AccessDeniedPageSettings(proto.Message):

class ListBrandsRequest(proto.Message):
r"""The request sent to ListBrands.

Attributes:
parent (str):
Required. GCP Project number/id. In the following format:
Expand All @@ -257,6 +264,7 @@ class ListBrandsRequest(proto.Message):

class ListBrandsResponse(proto.Message):
r"""Response message for ListBrands.

Attributes:
brands (Sequence[google.cloud.iap_v1.types.Brand]):
Brands existing in the project.
Expand All @@ -267,6 +275,7 @@ class ListBrandsResponse(proto.Message):

class CreateBrandRequest(proto.Message):
r"""The request sent to CreateBrand.

Attributes:
parent (str):
Required. GCP Project number/id under which the brand is to
Expand All @@ -282,6 +291,7 @@ class CreateBrandRequest(proto.Message):

class GetBrandRequest(proto.Message):
r"""The request sent to GetBrand.

Attributes:
name (str):
Required. Name of the brand to be fetched. In the following
Expand All @@ -293,6 +303,7 @@ class GetBrandRequest(proto.Message):

class ListIdentityAwareProxyClientsRequest(proto.Message):
r"""The request sent to ListIdentityAwareProxyClients.

Attributes:
parent (str):
Required. Full brand path. In the following format:
Expand Down Expand Up @@ -320,6 +331,7 @@ class ListIdentityAwareProxyClientsRequest(proto.Message):

class ListIdentityAwareProxyClientsResponse(proto.Message):
r"""Response message for ListIdentityAwareProxyClients.

Attributes:
identity_aware_proxy_clients (Sequence[google.cloud.iap_v1.types.IdentityAwareProxyClient]):
Clients existing in the brand.
Expand All @@ -341,6 +353,7 @@ def raw_page(self):

class CreateIdentityAwareProxyClientRequest(proto.Message):
r"""The request sent to CreateIdentityAwareProxyClient.

Attributes:
parent (str):
Required. Path to create the client in. In the following
Expand All @@ -359,6 +372,7 @@ class CreateIdentityAwareProxyClientRequest(proto.Message):

class GetIdentityAwareProxyClientRequest(proto.Message):
r"""The request sent to GetIdentityAwareProxyClient.

Attributes:
name (str):
Required. Name of the Identity Aware Proxy client to be
Expand All @@ -371,6 +385,7 @@ class GetIdentityAwareProxyClientRequest(proto.Message):

class ResetIdentityAwareProxyClientSecretRequest(proto.Message):
r"""The request sent to ResetIdentityAwareProxyClientSecret.

Attributes:
name (str):
Required. Name of the Identity Aware Proxy client to that
Expand All @@ -383,6 +398,7 @@ class ResetIdentityAwareProxyClientSecretRequest(proto.Message):

class DeleteIdentityAwareProxyClientRequest(proto.Message):
r"""The request sent to DeleteIdentityAwareProxyClient.

Attributes:
name (str):
Required. Name of the Identity Aware Proxy client to be
Expand Down
50 changes: 50 additions & 0 deletions tests/unit/gapic/iap_v1/test_identity_aware_proxy_admin_service.py
Expand Up @@ -29,6 +29,7 @@
from google.api_core import gapic_v1
from google.api_core import grpc_helpers
from google.api_core import grpc_helpers_async
from google.api_core import path_template
from google.auth import credentials as ga_credentials
from google.auth.exceptions import MutualTLSChannelError
from google.cloud.iap_v1.services.identity_aware_proxy_admin_service import (
Expand Down Expand Up @@ -1413,6 +1414,9 @@ def test_identity_aware_proxy_admin_service_base_transport():
with pytest.raises(NotImplementedError):
getattr(transport, method)(request=object())

with pytest.raises(NotImplementedError):
transport.close()


@requires_google_auth_gte_1_25_0
def test_identity_aware_proxy_admin_service_base_transport_with_credentials_file():
Expand Down Expand Up @@ -1886,3 +1890,49 @@ def test_client_withDEFAULT_CLIENT_INFO():
credentials=ga_credentials.AnonymousCredentials(), client_info=client_info,
)
prep.assert_called_once_with(client_info)


@pytest.mark.asyncio
async def test_transport_close_async():
client = IdentityAwareProxyAdminServiceAsyncClient(
credentials=ga_credentials.AnonymousCredentials(), transport="grpc_asyncio",
)
with mock.patch.object(
type(getattr(client.transport, "grpc_channel")), "close"
) as close:
async with client:
close.assert_not_called()
close.assert_called_once()


def test_transport_close():
transports = {
"grpc": "_grpc_channel",
}

for transport, close_name in transports.items():
client = IdentityAwareProxyAdminServiceClient(
credentials=ga_credentials.AnonymousCredentials(), transport=transport
)
with mock.patch.object(
type(getattr(client.transport, close_name)), "close"
) as close:
with client:
close.assert_not_called()
close.assert_called_once()


def test_client_ctx():
transports = [
"grpc",
]
for transport in transports:
client = IdentityAwareProxyAdminServiceClient(
credentials=ga_credentials.AnonymousCredentials(), transport=transport
)
# Test client calls underlying transport.
with mock.patch.object(type(client.transport), "close") as close:
close.assert_not_called()
with client:
pass
close.assert_called()