Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add context manager support in client #328

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -467,6 +467,12 @@ async def split_read_stream(
# Done; return the response.
return response

async def __aenter__(self):
    """Enters the async runtime context, returning the client itself."""
    return self

async def __aexit__(self, exc_type, exc, tb):
    """Releases underlying transport's resources.

    .. warning::
        ONLY use as a context manager if the transport is NOT shared
        with other clients! Exiting the with block will CLOSE the transport
        and may cause errors in other clients!
    """
    await self.transport.close()


try:
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
Expand Down
Expand Up @@ -381,10 +381,7 @@ def __init__(
client_cert_source_for_mtls=client_cert_source_func,
quota_project_id=client_options.quota_project_id,
client_info=client_info,
always_use_jwt_access=(
Transport == type(self).get_transport_class("grpc")
or Transport == type(self).get_transport_class("grpc_asyncio")
),
always_use_jwt_access=True,
)

def create_read_session(
Expand Down Expand Up @@ -660,6 +657,19 @@ def split_read_stream(
# Done; return the response.
return response

def __enter__(self):
    """Enters the runtime context, returning the client itself."""
    return self

def __exit__(self, exc_type, exc_value, traceback):
    """Releases underlying transport's resources.

    .. warning::
        ONLY use as a context manager if the transport is NOT shared
        with other clients! Exiting the with block will CLOSE the transport
        and may cause errors in other clients!
    """
    # Parameters renamed from ``type``/``value`` so the builtin ``type`` is
    # not shadowed; ``__exit__`` is always invoked positionally, so this is
    # interface-safe, and it matches the async variant's naming.
    self.transport.close()


try:
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
Expand Down
Expand Up @@ -205,6 +205,15 @@ def _prep_wrapped_messages(self, client_info):
),
}

def close(self):
    """Closes resources associated with the transport.

    .. warning::
        Only call this method if the transport is NOT shared
        with other clients - this may cause errors in other clients!

    Raises:
        NotImplementedError: always; concrete transport subclasses must
            override this to release their channel resources.
    """
    raise NotImplementedError()

@property
def create_read_session(
self,
Expand Down
Expand Up @@ -346,5 +346,8 @@ def split_read_stream(
)
return self._stubs["split_read_stream"]

def close(self):
    """Closes resources associated with the transport.

    .. warning::
        Only call this method if the transport is NOT shared
        with other clients - this may cause errors in other clients!
    """
    channel = self.grpc_channel
    channel.close()


__all__ = ("BigQueryReadGrpcTransport",)
Expand Up @@ -351,5 +351,8 @@ def split_read_stream(
)
return self._stubs["split_read_stream"]

def close(self):
    """Closes resources associated with the transport.

    .. warning::
        Only call this method if the transport is NOT shared
        with other clients - this may cause errors in other clients!
    """
    # The async channel's close() result is returned so callers can await it.
    channel = self.grpc_channel
    return channel.close()


__all__ = ("BigQueryReadGrpcAsyncIOTransport",)
Expand Up @@ -700,6 +700,12 @@ async def flush_rows(
# Done; return the response.
return response

async def __aenter__(self):
    """Enters the async runtime context, returning the client itself."""
    return self

async def __aexit__(self, exc_type, exc, tb):
    """Releases underlying transport's resources.

    .. warning::
        ONLY use as a context manager if the transport is NOT shared
        with other clients! Exiting the with block will CLOSE the transport
        and may cause errors in other clients!
    """
    await self.transport.close()


try:
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
Expand Down
Expand Up @@ -365,10 +365,7 @@ def __init__(
client_cert_source_for_mtls=client_cert_source_func,
quota_project_id=client_options.quota_project_id,
client_info=client_info,
always_use_jwt_access=(
Transport == type(self).get_transport_class("grpc")
or Transport == type(self).get_transport_class("grpc_asyncio")
),
always_use_jwt_access=True,
)

def create_write_stream(
Expand Down Expand Up @@ -831,6 +828,19 @@ def flush_rows(
# Done; return the response.
return response

def __enter__(self):
    """Enters the runtime context, returning the client itself."""
    return self

def __exit__(self, exc_type, exc_value, traceback):
    """Releases underlying transport's resources.

    .. warning::
        ONLY use as a context manager if the transport is NOT shared
        with other clients! Exiting the with block will CLOSE the transport
        and may cause errors in other clients!
    """
    # Parameters renamed from ``type``/``value`` so the builtin ``type`` is
    # not shadowed; ``__exit__`` is always invoked positionally, so this is
    # interface-safe, and it matches the async variant's naming.
    self.transport.close()


try:
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
Expand Down
Expand Up @@ -250,6 +250,15 @@ def _prep_wrapped_messages(self, client_info):
),
}

def close(self):
    """Closes resources associated with the transport.

    .. warning::
        Only call this method if the transport is NOT shared
        with other clients - this may cause errors in other clients!

    Raises:
        NotImplementedError: always; concrete transport subclasses must
            override this to release their channel resources.
    """
    raise NotImplementedError()

@property
def create_write_stream(
self,
Expand Down
Expand Up @@ -444,5 +444,8 @@ def flush_rows(
)
return self._stubs["flush_rows"]

def close(self):
    """Closes resources associated with the transport.

    .. warning::
        Only call this method if the transport is NOT shared
        with other clients - this may cause errors in other clients!
    """
    channel = self.grpc_channel
    channel.close()


__all__ = ("BigQueryWriteGrpcTransport",)
Expand Up @@ -448,5 +448,8 @@ def flush_rows(
)
return self._stubs["flush_rows"]

def close(self):
    """Closes resources associated with the transport.

    .. warning::
        Only call this method if the transport is NOT shared
        with other clients - this may cause errors in other clients!
    """
    # The async channel's close() result is returned so callers can await it.
    channel = self.grpc_channel
    return channel.close()


__all__ = ("BigQueryWriteGrpcAsyncIOTransport",)
2 changes: 2 additions & 0 deletions google/cloud/bigquery_storage_v1/types/avro.py
Expand Up @@ -23,6 +23,7 @@

class AvroSchema(proto.Message):
r"""Avro schema.

Attributes:
schema (str):
Json serialized schema, as described at
Expand All @@ -34,6 +35,7 @@ class AvroSchema(proto.Message):

class AvroRows(proto.Message):
r"""Avro rows.

Attributes:
serialized_binary_rows (bytes):
Binary serialized rows in a block.
Expand Down
1 change: 1 addition & 0 deletions google/cloud/bigquery_storage_v1/types/protobuf.py
Expand Up @@ -49,6 +49,7 @@ class ProtoSchema(proto.Message):

class ProtoRows(proto.Message):
r"""

Attributes:
serialized_rows (Sequence[bytes]):
A sequence of rows serialized as a Protocol
Expand Down
17 changes: 17 additions & 0 deletions google/cloud/bigquery_storage_v1/types/storage.py
Expand Up @@ -52,6 +52,7 @@

class CreateReadSessionRequest(proto.Message):
r"""Request message for ``CreateReadSession``.

Attributes:
parent (str):
Required. The request project that owns the session, in the
Expand Down Expand Up @@ -79,6 +80,7 @@ class CreateReadSessionRequest(proto.Message):

class ReadRowsRequest(proto.Message):
r"""Request message for ``ReadRows``.

Attributes:
read_stream (str):
Required. Stream to read rows from.
Expand All @@ -95,6 +97,7 @@ class ReadRowsRequest(proto.Message):

class ThrottleState(proto.Message):
r"""Information on if the current connection is being throttled.

Attributes:
throttle_percent (int):
How much this connection is being throttled.
Expand All @@ -107,6 +110,7 @@ class ThrottleState(proto.Message):

class StreamStats(proto.Message):
r"""Estimated stream statistics for a given read Stream.

Attributes:
progress (google.cloud.bigquery_storage_v1.types.StreamStats.Progress):
Represents the progress of the current
Expand All @@ -115,6 +119,7 @@ class StreamStats(proto.Message):

class Progress(proto.Message):
r"""

Attributes:
at_response_start (float):
The fraction of rows assigned to the stream that have been
Expand Down Expand Up @@ -183,6 +188,7 @@ class ReadRowsResponse(proto.Message):

class SplitReadStreamRequest(proto.Message):
r"""Request message for ``SplitReadStream``.

Attributes:
name (str):
Required. Name of the stream to split.
Expand All @@ -207,6 +213,7 @@ class SplitReadStreamRequest(proto.Message):

class SplitReadStreamResponse(proto.Message):
r"""Response message for ``SplitReadStream``.

Attributes:
primary_stream (google.cloud.bigquery_storage_v1.types.ReadStream):
Primary stream, which contains the beginning portion of
Expand All @@ -224,6 +231,7 @@ class SplitReadStreamResponse(proto.Message):

class CreateWriteStreamRequest(proto.Message):
r"""Request message for ``CreateWriteStream``.

Attributes:
parent (str):
Required. Reference to the table to which the stream
Expand Down Expand Up @@ -303,6 +311,7 @@ class ProtoData(proto.Message):

class AppendRowsResponse(proto.Message):
r"""Response message for ``AppendRows``.

Attributes:
append_result (google.cloud.bigquery_storage_v1.types.AppendRowsResponse.AppendResult):
Result if the append is successful.
Expand Down Expand Up @@ -339,6 +348,7 @@ class AppendRowsResponse(proto.Message):

class AppendResult(proto.Message):
r"""AppendResult is returned for successful append requests.

Attributes:
offset (google.protobuf.wrappers_pb2.Int64Value):
The row offset at which the last append
Expand All @@ -359,6 +369,7 @@ class AppendResult(proto.Message):

class GetWriteStreamRequest(proto.Message):
r"""Request message for ``GetWriteStreamRequest``.

Attributes:
name (str):
Required. Name of the stream to get, in the form of
Expand All @@ -370,6 +381,7 @@ class GetWriteStreamRequest(proto.Message):

class BatchCommitWriteStreamsRequest(proto.Message):
r"""Request message for ``BatchCommitWriteStreams``.

Attributes:
parent (str):
Required. Parent table that all the streams should belong
Expand All @@ -386,6 +398,7 @@ class BatchCommitWriteStreamsRequest(proto.Message):

class BatchCommitWriteStreamsResponse(proto.Message):
r"""Response message for ``BatchCommitWriteStreams``.

Attributes:
commit_time (google.protobuf.timestamp_pb2.Timestamp):
The time at which streams were committed in microseconds
Expand All @@ -409,6 +422,7 @@ class BatchCommitWriteStreamsResponse(proto.Message):

class FinalizeWriteStreamRequest(proto.Message):
r"""Request message for invoking ``FinalizeWriteStream``.

Attributes:
name (str):
Required. Name of the stream to finalize, in the form of
Expand All @@ -420,6 +434,7 @@ class FinalizeWriteStreamRequest(proto.Message):

class FinalizeWriteStreamResponse(proto.Message):
r"""Response message for ``FinalizeWriteStream``.

Attributes:
row_count (int):
Number of rows in the finalized stream.
Expand All @@ -430,6 +445,7 @@ class FinalizeWriteStreamResponse(proto.Message):

class FlushRowsRequest(proto.Message):
r"""Request message for ``FlushRows``.

Attributes:
write_stream (str):
Required. The stream that is the target of
Expand All @@ -446,6 +462,7 @@ class FlushRowsRequest(proto.Message):

class FlushRowsResponse(proto.Message):
r"""Response message for ``FlushRows``.

Attributes:
offset (int):
The rows before this offset (including this
Expand Down
3 changes: 3 additions & 0 deletions google/cloud/bigquery_storage_v1/types/stream.py
Expand Up @@ -36,6 +36,7 @@ class DataFormat(proto.Enum):

class ReadSession(proto.Message):
r"""Information about the ReadSession.

Attributes:
name (str):
Output only. Unique identifier for the session, in the form
Expand Down Expand Up @@ -79,6 +80,7 @@ class ReadSession(proto.Message):

class TableModifiers(proto.Message):
r"""Additional attributes when reading a table.

Attributes:
snapshot_time (google.protobuf.timestamp_pb2.Timestamp):
The snapshot time of the table. If not set,
Expand All @@ -91,6 +93,7 @@ class TableModifiers(proto.Message):

class TableReadOptions(proto.Message):
r"""Options dictating how we read a table.

Attributes:
selected_fields (Sequence[str]):
Names of the fields in the table that should be read. If
Expand Down
1 change: 1 addition & 0 deletions google/cloud/bigquery_storage_v1/types/table.py
Expand Up @@ -24,6 +24,7 @@

class TableSchema(proto.Message):
r"""Schema of a table.

Attributes:
fields (Sequence[google.cloud.bigquery_storage_v1.types.TableFieldSchema]):
Describes the fields in a table.
Expand Down
Expand Up @@ -469,6 +469,12 @@ async def split_read_stream(
# Done; return the response.
return response

async def __aenter__(self):
    """Enters the async runtime context, returning the client itself."""
    return self

async def __aexit__(self, exc_type, exc, tb):
    """Releases underlying transport's resources.

    .. warning::
        ONLY use as a context manager if the transport is NOT shared
        with other clients! Exiting the with block will CLOSE the transport
        and may cause errors in other clients!
    """
    await self.transport.close()


try:
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
Expand Down
Expand Up @@ -383,10 +383,7 @@ def __init__(
client_cert_source_for_mtls=client_cert_source_func,
quota_project_id=client_options.quota_project_id,
client_info=client_info,
always_use_jwt_access=(
Transport == type(self).get_transport_class("grpc")
or Transport == type(self).get_transport_class("grpc_asyncio")
),
always_use_jwt_access=True,
)

def create_read_session(
Expand Down Expand Up @@ -662,6 +659,19 @@ def split_read_stream(
# Done; return the response.
return response

def __enter__(self):
    """Enters the runtime context, returning the client itself."""
    return self

def __exit__(self, exc_type, exc_value, traceback):
    """Releases underlying transport's resources.

    .. warning::
        ONLY use as a context manager if the transport is NOT shared
        with other clients! Exiting the with block will CLOSE the transport
        and may cause errors in other clients!
    """
    # Parameters renamed from ``type``/``value`` so the builtin ``type`` is
    # not shadowed; ``__exit__`` is always invoked positionally, so this is
    # interface-safe, and it matches the async variant's naming.
    self.transport.close()


try:
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
Expand Down