diff --git a/google/cloud/trace_v1/services/trace_service/async_client.py b/google/cloud/trace_v1/services/trace_service/async_client.py index 7004e379..3861f643 100644 --- a/google/cloud/trace_v1/services/trace_service/async_client.py +++ b/google/cloud/trace_v1/services/trace_service/async_client.py @@ -414,6 +414,12 @@ async def patch_traces( request, retry=retry, timeout=timeout, metadata=metadata, ) + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.transport.close() + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( diff --git a/google/cloud/trace_v1/services/trace_service/client.py b/google/cloud/trace_v1/services/trace_service/client.py index 6d19ff43..39104cd5 100644 --- a/google/cloud/trace_v1/services/trace_service/client.py +++ b/google/cloud/trace_v1/services/trace_service/client.py @@ -332,10 +332,7 @@ def __init__( client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, - always_use_jwt_access=( - Transport == type(self).get_transport_class("grpc") - or Transport == type(self).get_transport_class("grpc_asyncio") - ), + always_use_jwt_access=True, ) def list_traces( @@ -561,6 +558,19 @@ def patch_traces( request, retry=retry, timeout=timeout, metadata=metadata, ) + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + """Releases underlying transport's resources. + + .. warning:: + ONLY use as a context manager if the transport is NOT shared + with other clients! Exiting the with block will CLOSE the transport + and may cause errors in other clients! + """ + self.transport.close() + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( diff --git a/google/cloud/trace_v1/services/trace_service/transports/base.py b/google/cloud/trace_v1/services/trace_service/transports/base.py index f1fd8503..899a0a8c 100644 --- a/google/cloud/trace_v1/services/trace_service/transports/base.py +++ b/google/cloud/trace_v1/services/trace_service/transports/base.py @@ -204,6 +204,15 @@ def _prep_wrapped_messages(self, client_info): ), } + def close(self): + """Closes resources associated with the transport. + + .. warning:: + Only call this method if the transport is NOT shared + with other clients - this may cause errors in other clients! + """ + raise NotImplementedError() + @property def list_traces( self, diff --git a/google/cloud/trace_v1/services/trace_service/transports/grpc.py b/google/cloud/trace_v1/services/trace_service/transports/grpc.py index e27a7e4d..39bc4cee 100644 --- a/google/cloud/trace_v1/services/trace_service/transports/grpc.py +++ b/google/cloud/trace_v1/services/trace_service/transports/grpc.py @@ -312,5 +312,8 @@ def patch_traces(self) -> Callable[[trace.PatchTracesRequest], empty_pb2.Empty]: ) return self._stubs["patch_traces"] + def close(self): + self.grpc_channel.close() + __all__ = ("TraceServiceGrpcTransport",) diff --git a/google/cloud/trace_v1/services/trace_service/transports/grpc_asyncio.py b/google/cloud/trace_v1/services/trace_service/transports/grpc_asyncio.py index 13f9afb7..13f6b4c4 100644 --- a/google/cloud/trace_v1/services/trace_service/transports/grpc_asyncio.py +++ b/google/cloud/trace_v1/services/trace_service/transports/grpc_asyncio.py @@ -317,5 +317,8 @@ def patch_traces( ) return self._stubs["patch_traces"] + def close(self): + return self.grpc_channel.close() + __all__ = ("TraceServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/trace_v1/types/trace.py b/google/cloud/trace_v1/types/trace.py index a9e232eb..46778209 100644 --- a/google/cloud/trace_v1/types/trace.py +++ b/google/cloud/trace_v1/types/trace.py @@ -56,6 +56,7 @@ class Trace(proto.Message): class Traces(proto.Message): r"""List of new or updated traces. + Attributes: traces (Sequence[google.cloud.trace_v1.types.Trace]): List of traces. @@ -254,6 +255,7 @@ class ViewType(proto.Enum): class ListTracesResponse(proto.Message): r"""The response message for the ``ListTraces`` method. + Attributes: traces (Sequence[google.cloud.trace_v1.types.Trace]): List of trace records as specified by the @@ -275,6 +277,7 @@ def raw_page(self): class GetTraceRequest(proto.Message): r"""The request message for the ``GetTrace`` method. + Attributes: project_id (str): Required. ID of the Cloud project where the @@ -289,6 +292,7 @@ class GetTraceRequest(proto.Message): class PatchTracesRequest(proto.Message): r"""The request message for the ``PatchTraces`` method. + Attributes: project_id (str): Required. ID of the Cloud project where the diff --git a/google/cloud/trace_v2/services/trace_service/async_client.py b/google/cloud/trace_v2/services/trace_service/async_client.py index b9b377c6..cc719539 100644 --- a/google/cloud/trace_v2/services/trace_service/async_client.py +++ b/google/cloud/trace_v2/services/trace_service/async_client.py @@ -315,6 +315,12 @@ async def create_span( # Done; return the response. return response + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.transport.close() + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( diff --git a/google/cloud/trace_v2/services/trace_service/client.py b/google/cloud/trace_v2/services/trace_service/client.py index 7c3ccc40..fdbdecc6 100644 --- a/google/cloud/trace_v2/services/trace_service/client.py +++ b/google/cloud/trace_v2/services/trace_service/client.py @@ -351,10 +351,7 @@ def __init__( client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, - always_use_jwt_access=( - Transport == type(self).get_transport_class("grpc") - or Transport == type(self).get_transport_class("grpc_asyncio") - ), + always_use_jwt_access=True, ) def batch_write_spans( @@ -497,6 +494,19 @@ def create_span( # Done; return the response. return response + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + """Releases underlying transport's resources. + + .. warning:: + ONLY use as a context manager if the transport is NOT shared + with other clients! Exiting the with block will CLOSE the transport + and may cause errors in other clients! + """ + self.transport.close() + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( diff --git a/google/cloud/trace_v2/services/trace_service/transports/base.py b/google/cloud/trace_v2/services/trace_service/transports/base.py index b4adadf1..a964c714 100644 --- a/google/cloud/trace_v2/services/trace_service/transports/base.py +++ b/google/cloud/trace_v2/services/trace_service/transports/base.py @@ -177,6 +177,15 @@ def _prep_wrapped_messages(self, client_info): ), } + def close(self): + """Closes resources associated with the transport. + + .. warning:: + Only call this method if the transport is NOT shared + with other clients - this may cause errors in other clients! + """ + raise NotImplementedError() + @property def batch_write_spans( self, diff --git a/google/cloud/trace_v2/services/trace_service/transports/grpc.py b/google/cloud/trace_v2/services/trace_service/transports/grpc.py index 304de9ae..7483d31c 100644 --- a/google/cloud/trace_v2/services/trace_service/transports/grpc.py +++ b/google/cloud/trace_v2/services/trace_service/transports/grpc.py @@ -283,5 +283,8 @@ def create_span(self) -> Callable[[trace.Span], trace.Span]: ) return self._stubs["create_span"] + def close(self): + self.grpc_channel.close() + __all__ = ("TraceServiceGrpcTransport",) diff --git a/google/cloud/trace_v2/services/trace_service/transports/grpc_asyncio.py b/google/cloud/trace_v2/services/trace_service/transports/grpc_asyncio.py index 5c3dc8a7..d52fbc88 100644 --- a/google/cloud/trace_v2/services/trace_service/transports/grpc_asyncio.py +++ b/google/cloud/trace_v2/services/trace_service/transports/grpc_asyncio.py @@ -286,5 +286,8 @@ def create_span(self) -> Callable[[trace.Span], Awaitable[trace.Span]]: ) return self._stubs["create_span"] + def close(self): + return self.grpc_channel.close() + __all__ = ("TraceServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/trace_v2/types/trace.py b/google/cloud/trace_v2/types/trace.py index db598507..220f47f0 100644 --- a/google/cloud/trace_v2/types/trace.py +++ b/google/cloud/trace_v2/types/trace.py @@ -125,6 +125,7 @@ class SpanKind(proto.Enum): class Attributes(proto.Message): r"""A set of attributes, each in the format ``[KEY]:[VALUE]``. + Attributes: attribute_map (Sequence[google.cloud.trace_v2.types.Span.Attributes.AttributeMapEntry]): The set of attributes. Each attribute's key can be up to 128 @@ -152,6 +153,7 @@ class Attributes(proto.Message): class TimeEvent(proto.Message): r"""A time-stamped annotation or message event in the Span. + Attributes: time (google.protobuf.timestamp_pb2.Timestamp): The timestamp indicating the time the event @@ -165,6 +167,7 @@ class TimeEvent(proto.Message): class Annotation(proto.Message): r"""Text annotation with a set of attributes. + Attributes: description (google.cloud.trace_v2.types.TruncatableString): A user-supplied message describing the event. @@ -184,6 +187,7 @@ class Annotation(proto.Message): class MessageEvent(proto.Message): r"""An event describing a message sent/received between Spans. + Attributes: type (google.cloud.trace_v2.types.Span.TimeEvent.MessageEvent.Type): Type of MessageEvent. Indicates whether the @@ -321,6 +325,7 @@ class Links(proto.Message): class AttributeValue(proto.Message): r"""The allowed types for [VALUE] in a ``[KEY]:[VALUE]`` attribute. + Attributes: string_value (google.cloud.trace_v2.types.TruncatableString): A string up to 256 bytes long. @@ -339,6 +344,7 @@ class AttributeValue(proto.Message): class StackTrace(proto.Message): r"""A call stack appearing in a trace. + Attributes: stack_frames (google.cloud.trace_v2.types.StackTrace.StackFrames): Stack frames in this stack trace. A maximum @@ -357,6 +363,7 @@ class StackTrace(proto.Message): class StackFrame(proto.Message): r"""Represents a single stack frame in a stack trace. + Attributes: function_name (google.cloud.trace_v2.types.TruncatableString): The fully-qualified name that uniquely @@ -400,6 +407,7 @@ class StackFrame(proto.Message): class StackFrames(proto.Message): r"""A collection of stack frames, which can be truncated. + Attributes: frame (Sequence[google.cloud.trace_v2.types.StackTrace.StackFrame]): Stack frames in this call stack. @@ -421,6 +429,7 @@ class StackFrames(proto.Message): class Module(proto.Message): r"""Binary module. + Attributes: module (google.cloud.trace_v2.types.TruncatableString): For example: main binary, kernel modules, and diff --git a/google/cloud/trace_v2/types/tracing.py b/google/cloud/trace_v2/types/tracing.py index bead9f56..98fbff3f 100644 --- a/google/cloud/trace_v2/types/tracing.py +++ b/google/cloud/trace_v2/types/tracing.py @@ -25,6 +25,7 @@ class BatchWriteSpansRequest(proto.Message): r"""The request message for the ``BatchWriteSpans`` method. + Attributes: name (str): Required. The name of the project where the spans belong. diff --git a/tests/unit/gapic/trace_v1/test_trace_service.py b/tests/unit/gapic/trace_v1/test_trace_service.py index acd8724e..07546651 100644 --- a/tests/unit/gapic/trace_v1/test_trace_service.py +++ b/tests/unit/gapic/trace_v1/test_trace_service.py @@ -29,6 +29,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async +from google.api_core import path_template from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.trace_v1.services.trace_service import TraceServiceAsyncClient @@ -1143,6 +1144,9 @@ def test_trace_service_base_transport(): with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) + with pytest.raises(NotImplementedError): + transport.close() + @requires_google_auth_gte_1_25_0 def test_trace_service_base_transport_with_credentials_file(): @@ -1618,3 +1622,49 @@ def test_client_withDEFAULT_CLIENT_INFO(): credentials=ga_credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) + + +@pytest.mark.asyncio +async def test_transport_close_async(): + client = TraceServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc_asyncio", + ) + with mock.patch.object( + type(getattr(client.transport, "grpc_channel")), "close" + ) as close: + async with client: + close.assert_not_called() + close.assert_called_once() + + +def test_transport_close(): + transports = { + "grpc": "_grpc_channel", + } + + for transport, close_name in transports.items(): + client = TraceServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + with mock.patch.object( + type(getattr(client.transport, close_name)), "close" + ) as close: + with client: + close.assert_not_called() + close.assert_called_once() + + +def test_client_ctx(): + transports = [ + "grpc", + ] + for transport in transports: + client = TraceServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + # Test client calls underlying transport. + with mock.patch.object(type(client.transport), "close") as close: + close.assert_not_called() + with client: + pass + close.assert_called() diff --git a/tests/unit/gapic/trace_v2/test_trace_service.py b/tests/unit/gapic/trace_v2/test_trace_service.py index 6cb35598..c5c0023a 100644 --- a/tests/unit/gapic/trace_v2/test_trace_service.py +++ b/tests/unit/gapic/trace_v2/test_trace_service.py @@ -29,6 +29,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async +from google.api_core import path_template from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.trace_v2.services.trace_service import TraceServiceAsyncClient @@ -945,6 +946,9 @@ def test_trace_service_base_transport(): with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) + with pytest.raises(NotImplementedError): + transport.close() + @requires_google_auth_gte_1_25_0 def test_trace_service_base_transport_with_credentials_file(): @@ -1437,3 +1441,49 @@ def test_client_withDEFAULT_CLIENT_INFO(): credentials=ga_credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) + + +@pytest.mark.asyncio +async def test_transport_close_async(): + client = TraceServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc_asyncio", + ) + with mock.patch.object( + type(getattr(client.transport, "grpc_channel")), "close" + ) as close: + async with client: + close.assert_not_called() + close.assert_called_once() + + +def test_transport_close(): + transports = { + "grpc": "_grpc_channel", + } + + for transport, close_name in transports.items(): + client = TraceServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + with mock.patch.object( + type(getattr(client.transport, close_name)), "close" + ) as close: + with client: + close.assert_not_called() + close.assert_called_once() + + +def test_client_ctx(): + transports = [ + "grpc", + ] + for transport in transports: + client = TraceServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + # Test client calls underlying transport. + with mock.patch.object(type(client.transport), "close") as close: + close.assert_not_called() + with client: + pass + close.assert_called()