Skip to content
This repository has been archived by the owner on Jul 6, 2023. It is now read-only.

feat: add context manager support in client #70

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -2322,6 +2322,12 @@ async def update_settings(
# Done; return the response.
return response

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc, tb):
await self.transport.close()


try:
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
Expand Down
Expand Up @@ -457,10 +457,7 @@ def __init__(
client_cert_source_for_mtls=client_cert_source_func,
quota_project_id=client_options.quota_project_id,
client_info=client_info,
always_use_jwt_access=(
Transport == type(self).get_transport_class("grpc")
or Transport == type(self).get_transport_class("grpc_asyncio")
),
always_use_jwt_access=True,
)

def create_conversation(
Expand Down Expand Up @@ -2596,6 +2593,19 @@ def update_settings(
# Done; return the response.
return response

def __enter__(self):
return self

def __exit__(self, type, value, traceback):
"""Releases underlying transport's resources.

.. warning::
ONLY use as a context manager if the transport is NOT shared
with other clients! Exiting the with block will CLOSE the transport
and may cause errors in other clients!
"""
self.transport.close()


try:
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
Expand Down
Expand Up @@ -256,6 +256,15 @@ def _prep_wrapped_messages(self, client_info):
),
}

def close(self):
"""Closes resources associated with the transport.

.. warning::
Only call this method if the transport is NOT shared
with other clients - this may cause errors in other clients!
"""
raise NotImplementedError()

@property
def operations_client(self) -> operations_v1.OperationsClient:
"""Return the client designed to process long-running operations."""
Expand Down
Expand Up @@ -1026,5 +1026,8 @@ def update_settings(
)
return self._stubs["update_settings"]

def close(self):
self.grpc_channel.close()


__all__ = ("ContactCenterInsightsGrpcTransport",)
Expand Up @@ -1059,5 +1059,8 @@ def update_settings(
)
return self._stubs["update_settings"]

def close(self):
return self.grpc_channel.close()


__all__ = ("ContactCenterInsightsGrpcAsyncIOTransport",)
Expand Up @@ -132,6 +132,7 @@ class CalculateStatsResponse(proto.Message):

class TimeSeries(proto.Message):
r"""A time series representing conversations over time.

Attributes:
interval_duration (google.protobuf.duration_pb2.Duration):
The duration of each interval.
Expand All @@ -144,6 +145,7 @@ class TimeSeries(proto.Message):

class Interval(proto.Message):
r"""A single interval in a time series.

Attributes:
start_time (google.protobuf.timestamp_pb2.Timestamp):
The start time of this interval.
Expand Down Expand Up @@ -441,6 +443,7 @@ class ExportInsightsDataRequest(proto.Message):

class BigQueryDestination(proto.Message):
r"""A BigQuery Table Reference.

Attributes:
project_id (str):
A project ID or number. If specified, then
Expand Down Expand Up @@ -496,7 +499,8 @@ class ExportInsightsDataMetadata(proto.Message):


class ExportInsightsDataResponse(proto.Message):
r"""Response for an export insights operation. """
r"""Response for an export insights operation.
"""


class CreateIssueModelRequest(proto.Message):
Expand Down Expand Up @@ -628,7 +632,8 @@ class DeployIssueModelRequest(proto.Message):


class DeployIssueModelResponse(proto.Message):
r"""The response to deploy an issue model. """
r"""The response to deploy an issue model.
"""


class DeployIssueModelMetadata(proto.Message):
Expand Down Expand Up @@ -662,7 +667,8 @@ class UndeployIssueModelRequest(proto.Message):


class UndeployIssueModelResponse(proto.Message):
r"""The response to undeploy an issue model. """
r"""The response to undeploy an issue model.
"""


class UndeployIssueModelMetadata(proto.Message):
Expand Down
17 changes: 14 additions & 3 deletions google/cloud/contact_center_insights_v1/types/resources.py
Expand Up @@ -137,6 +137,7 @@ class Medium(proto.Enum):

class CallMetadata(proto.Message):
r"""Call-specific metadata.

Attributes:
customer_channel (int):
The audio channel that contains the customer.
Expand All @@ -149,6 +150,7 @@ class CallMetadata(proto.Message):

class Transcript(proto.Message):
r"""A message representing the transcript of a conversation.

Attributes:
transcript_segments (Sequence[google.cloud.contact_center_insights_v1.types.Conversation.Transcript.TranscriptSegment]):
A list of sequential transcript segments that
Expand All @@ -157,6 +159,7 @@ class Transcript(proto.Message):

class TranscriptSegment(proto.Message):
r"""A segment of a full transcript.

Attributes:
message_time (google.protobuf.timestamp_pb2.Timestamp):
The time that the message occurred, if
Expand Down Expand Up @@ -192,6 +195,7 @@ class TranscriptSegment(proto.Message):

class WordInfo(proto.Message):
r"""Word-level info for words in a transcript.

Attributes:
start_offset (google.protobuf.duration_pb2.Duration):
Time offset of the start of this word
Expand Down Expand Up @@ -387,6 +391,7 @@ class AnalysisResult(proto.Message):

class CallAnalysisMetadata(proto.Message):
r"""Call-specific metadata created during analysis.

Attributes:
annotations (Sequence[google.cloud.contact_center_insights_v1.types.CallAnnotation]):
A list of call annotations that apply to this
Expand Down Expand Up @@ -659,15 +664,18 @@ class DialogflowIntent(proto.Message):


class InterruptionData(proto.Message):
r"""The data for an interruption annotation. """
r"""The data for an interruption annotation.
"""


class SilenceData(proto.Message):
r"""The data for a silence annotation. """
r"""The data for a silence annotation.
"""


class HoldData(proto.Message):
r"""The data for a hold annotation. """
r"""The data for a hold annotation.
"""


class EntityMentionData(proto.Message):
Expand Down Expand Up @@ -766,6 +774,7 @@ class State(proto.Enum):

class InputDataConfig(proto.Message):
r"""Configs for the input data used to create the issue model.

Attributes:
medium (google.cloud.contact_center_insights_v1.types.Conversation.Medium):
Medium of conversations used in training data. This field is
Expand Down Expand Up @@ -835,6 +844,7 @@ class IssueModelLabelStats(proto.Message):

class IssueStats(proto.Message):
r"""Aggregated statistics about an issue.

Attributes:
issue (str):
Issue resource. Format:
Expand Down Expand Up @@ -1055,6 +1065,7 @@ class Settings(proto.Message):

class AnalysisConfig(proto.Message):
r"""Default configuration when creating Analyses in Insights.

Attributes:
runtime_integration_analysis_percentage (float):
Percentage of conversations created using Dialogflow runtime
Expand Down
Expand Up @@ -32,6 +32,7 @@
from google.api_core import grpc_helpers_async
from google.api_core import operation_async # type: ignore
from google.api_core import operations_v1
from google.api_core import path_template
from google.auth import credentials as ga_credentials
from google.auth.exceptions import MutualTLSChannelError
from google.cloud.contact_center_insights_v1.services.contact_center_insights import (
Expand Down Expand Up @@ -7496,6 +7497,9 @@ def test_contact_center_insights_base_transport():
with pytest.raises(NotImplementedError):
getattr(transport, method)(request=object())

with pytest.raises(NotImplementedError):
transport.close()

# Additionally, the LRO client (a property) should
# also raise NotImplementedError
with pytest.raises(NotImplementedError):
Expand Down Expand Up @@ -8174,3 +8178,49 @@ def test_client_withDEFAULT_CLIENT_INFO():
credentials=ga_credentials.AnonymousCredentials(), client_info=client_info,
)
prep.assert_called_once_with(client_info)


@pytest.mark.asyncio
async def test_transport_close_async():
client = ContactCenterInsightsAsyncClient(
credentials=ga_credentials.AnonymousCredentials(), transport="grpc_asyncio",
)
with mock.patch.object(
type(getattr(client.transport, "grpc_channel")), "close"
) as close:
async with client:
close.assert_not_called()
close.assert_called_once()


def test_transport_close():
transports = {
"grpc": "_grpc_channel",
}

for transport, close_name in transports.items():
client = ContactCenterInsightsClient(
credentials=ga_credentials.AnonymousCredentials(), transport=transport
)
with mock.patch.object(
type(getattr(client.transport, close_name)), "close"
) as close:
with client:
close.assert_not_called()
close.assert_called_once()


def test_client_ctx():
transports = [
"grpc",
]
for transport in transports:
client = ContactCenterInsightsClient(
credentials=ga_credentials.AnonymousCredentials(), transport=transport
)
# Test client calls underlying transport.
with mock.patch.object(type(client.transport), "close") as close:
close.assert_not_called()
with client:
pass
close.assert_called()