Skip to content

Commit

Permalink
Merge branch 'master' into cmek-support
Browse files Browse the repository at this point in the history
  • Loading branch information
larkee committed Jan 5, 2021
2 parents 514e675 + 87789c9 commit 2e85bc6
Show file tree
Hide file tree
Showing 25 changed files with 350 additions and 79 deletions.
13 changes: 13 additions & 0 deletions .github/sync-repo-settings.yaml
@@ -0,0 +1,13 @@
# https://github.com/googleapis/repo-automation-bots/tree/master/packages/sync-repo-settings
# Rules for master branch protection
branchProtectionRules:
# Identifies the protection rule pattern. Name of the branch to be protected.
# Defaults to `master`
- pattern: master
requiredStatusCheckContexts:
- 'Kokoro'
- 'cla/google'
- 'Samples - Lint'
- 'Samples - Python 3.6'
- 'Samples - Python 3.7'
- 'Samples - Python 3.8'
Expand Up @@ -28,7 +28,6 @@
_transport_registry["grpc"] = DatabaseAdminGrpcTransport
_transport_registry["grpc_asyncio"] = DatabaseAdminGrpcAsyncIOTransport


__all__ = (
"DatabaseAdminTransport",
"DatabaseAdminGrpcTransport",
Expand Down
Expand Up @@ -158,6 +158,10 @@ def __init__(
ssl_credentials=ssl_credentials,
scopes=scopes or self.AUTH_SCOPES,
quota_project_id=quota_project_id,
options=[
("grpc.max_send_message_length", -1),
("grpc.max_receive_message_length", -1),
],
)
self._ssl_channel_credentials = ssl_credentials
else:
Expand All @@ -176,9 +180,14 @@ def __init__(
ssl_credentials=ssl_channel_credentials,
scopes=scopes or self.AUTH_SCOPES,
quota_project_id=quota_project_id,
options=[
("grpc.max_send_message_length", -1),
("grpc.max_receive_message_length", -1),
],
)

self._stubs = {} # type: Dict[str, Callable]
self._operations_client = None

# Run the base constructor.
super().__init__(
Expand All @@ -202,7 +211,7 @@ def create_channel(
) -> grpc.Channel:
"""Create and return a gRPC channel object.
Args:
address (Optionsl[str]): The host for the channel to use.
address (Optional[str]): The host for the channel to use.
credentials (Optional[~.Credentials]): The
authorization credentials to attach to requests. These
credentials identify this application to the service. If
Expand Down Expand Up @@ -249,13 +258,11 @@ def operations_client(self) -> operations_v1.OperationsClient:
client.
"""
# Sanity check: Only create a new client if we do not already have one.
if "operations_client" not in self.__dict__:
self.__dict__["operations_client"] = operations_v1.OperationsClient(
self.grpc_channel
)
if self._operations_client is None:
self._operations_client = operations_v1.OperationsClient(self.grpc_channel)

# Return the client from cache.
return self.__dict__["operations_client"]
return self._operations_client

@property
def list_databases(
Expand Down
Expand Up @@ -203,6 +203,10 @@ def __init__(
ssl_credentials=ssl_credentials,
scopes=scopes or self.AUTH_SCOPES,
quota_project_id=quota_project_id,
options=[
("grpc.max_send_message_length", -1),
("grpc.max_receive_message_length", -1),
],
)
self._ssl_channel_credentials = ssl_credentials
else:
Expand All @@ -221,6 +225,10 @@ def __init__(
ssl_credentials=ssl_channel_credentials,
scopes=scopes or self.AUTH_SCOPES,
quota_project_id=quota_project_id,
options=[
("grpc.max_send_message_length", -1),
("grpc.max_receive_message_length", -1),
],
)

# Run the base constructor.
Expand All @@ -234,6 +242,7 @@ def __init__(
)

self._stubs = {}
self._operations_client = None

@property
def grpc_channel(self) -> aio.Channel:
Expand All @@ -253,13 +262,13 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient:
client.
"""
# Sanity check: Only create a new client if we do not already have one.
if "operations_client" not in self.__dict__:
self.__dict__["operations_client"] = operations_v1.OperationsAsyncClient(
if self._operations_client is None:
self._operations_client = operations_v1.OperationsAsyncClient(
self.grpc_channel
)

# Return the client from cache.
return self.__dict__["operations_client"]
return self._operations_client

@property
def list_databases(
Expand Down
3 changes: 2 additions & 1 deletion google/cloud/spanner_admin_database_v1/types/__init__.py
Expand Up @@ -53,9 +53,9 @@
RestoreDatabaseEncryptionConfig,
RestoreDatabaseMetadata,
OptimizeRestoredDatabaseMetadata,
RestoreSourceType,
)


__all__ = (
"OperationProgress",
"EncryptionConfig",
Expand Down Expand Up @@ -90,4 +90,5 @@
"RestoreDatabaseEncryptionConfig",
"RestoreDatabaseMetadata",
"OptimizeRestoredDatabaseMetadata",
"RestoreSourceType",
)
Expand Up @@ -28,7 +28,6 @@
_transport_registry["grpc"] = InstanceAdminGrpcTransport
_transport_registry["grpc_asyncio"] = InstanceAdminGrpcAsyncIOTransport


__all__ = (
"InstanceAdminTransport",
"InstanceAdminGrpcTransport",
Expand Down
Expand Up @@ -171,6 +171,10 @@ def __init__(
ssl_credentials=ssl_credentials,
scopes=scopes or self.AUTH_SCOPES,
quota_project_id=quota_project_id,
options=[
("grpc.max_send_message_length", -1),
("grpc.max_receive_message_length", -1),
],
)
self._ssl_channel_credentials = ssl_credentials
else:
Expand All @@ -189,9 +193,14 @@ def __init__(
ssl_credentials=ssl_channel_credentials,
scopes=scopes or self.AUTH_SCOPES,
quota_project_id=quota_project_id,
options=[
("grpc.max_send_message_length", -1),
("grpc.max_receive_message_length", -1),
],
)

self._stubs = {} # type: Dict[str, Callable]
self._operations_client = None

# Run the base constructor.
super().__init__(
Expand All @@ -215,7 +224,7 @@ def create_channel(
) -> grpc.Channel:
"""Create and return a gRPC channel object.
Args:
address (Optionsl[str]): The host for the channel to use.
address (Optional[str]): The host for the channel to use.
credentials (Optional[~.Credentials]): The
authorization credentials to attach to requests. These
credentials identify this application to the service. If
Expand Down Expand Up @@ -262,13 +271,11 @@ def operations_client(self) -> operations_v1.OperationsClient:
client.
"""
# Sanity check: Only create a new client if we do not already have one.
if "operations_client" not in self.__dict__:
self.__dict__["operations_client"] = operations_v1.OperationsClient(
self.grpc_channel
)
if self._operations_client is None:
self._operations_client = operations_v1.OperationsClient(self.grpc_channel)

# Return the client from cache.
return self.__dict__["operations_client"]
return self._operations_client

@property
def list_instance_configs(
Expand Down
Expand Up @@ -216,6 +216,10 @@ def __init__(
ssl_credentials=ssl_credentials,
scopes=scopes or self.AUTH_SCOPES,
quota_project_id=quota_project_id,
options=[
("grpc.max_send_message_length", -1),
("grpc.max_receive_message_length", -1),
],
)
self._ssl_channel_credentials = ssl_credentials
else:
Expand All @@ -234,6 +238,10 @@ def __init__(
ssl_credentials=ssl_channel_credentials,
scopes=scopes or self.AUTH_SCOPES,
quota_project_id=quota_project_id,
options=[
("grpc.max_send_message_length", -1),
("grpc.max_receive_message_length", -1),
],
)

# Run the base constructor.
Expand All @@ -247,6 +255,7 @@ def __init__(
)

self._stubs = {}
self._operations_client = None

@property
def grpc_channel(self) -> aio.Channel:
Expand All @@ -266,13 +275,13 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient:
client.
"""
# Sanity check: Only create a new client if we do not already have one.
if "operations_client" not in self.__dict__:
self.__dict__["operations_client"] = operations_v1.OperationsAsyncClient(
if self._operations_client is None:
self._operations_client = operations_v1.OperationsAsyncClient(
self.grpc_channel
)

# Return the client from cache.
return self.__dict__["operations_client"]
return self._operations_client

@property
def list_instance_configs(
Expand Down
1 change: 0 additions & 1 deletion google/cloud/spanner_admin_instance_v1/types/__init__.py
Expand Up @@ -32,7 +32,6 @@
UpdateInstanceMetadata,
)


__all__ = (
"ReplicaInfo",
"InstanceConfig",
Expand Down
50 changes: 38 additions & 12 deletions google/cloud/spanner_dbapi/connection.py
Expand Up @@ -22,6 +22,9 @@
from google.cloud import spanner_v1 as spanner
from google.cloud.spanner_v1.session import _get_retry_delay

from google.cloud.spanner_dbapi._helpers import _execute_insert_heterogenous
from google.cloud.spanner_dbapi._helpers import _execute_insert_homogenous
from google.cloud.spanner_dbapi._helpers import parse_insert
from google.cloud.spanner_dbapi.checksum import _compare_checksums
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
from google.cloud.spanner_dbapi.cursor import Cursor
Expand Down Expand Up @@ -82,7 +85,7 @@ def autocommit(self, value):
:type value: bool
:param value: New autocommit mode state.
"""
if value and not self._autocommit:
if value and not self._autocommit and self.inside_transaction:
self.commit()

self._autocommit = value
Expand All @@ -96,6 +99,19 @@ def database(self):
"""
return self._database

@property
def inside_transaction(self):
"""Flag: transaction is started.
Returns:
bool: True if transaction begun, False otherwise.
"""
return (
self._transaction
and not self._transaction.committed
and not self._transaction.rolled_back
)

@property
def instance(self):
"""Instance to which this connection relates.
Expand Down Expand Up @@ -191,11 +207,7 @@ def transaction_checkout(self):
:returns: A Cloud Spanner transaction object, ready to use.
"""
if not self.autocommit:
if (
not self._transaction
or self._transaction.committed
or self._transaction.rolled_back
):
if not self.inside_transaction:
self._transaction = self._session_checkout().transaction()
self._transaction.begin()

Expand All @@ -216,11 +228,7 @@ def close(self):
The connection will be unusable from this point forward. If the
connection has an active transaction, it will be rolled back.
"""
if (
self._transaction
and not self._transaction.committed
and not self._transaction.rolled_back
):
if self.inside_transaction:
self._transaction.rollback()

if self._own_pool:
Expand All @@ -235,7 +243,7 @@ def commit(self):
"""
if self._autocommit:
warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2)
elif self._transaction:
elif self.inside_transaction:
try:
self._transaction.commit()
self._release_session()
Expand Down Expand Up @@ -291,6 +299,24 @@ def run_statement(self, statement, retried=False):
if not retried:
self._statements.append(statement)

if statement.is_insert:
parts = parse_insert(statement.sql, statement.params)

if parts.get("homogenous"):
_execute_insert_homogenous(transaction, parts)
return (
iter(()),
ResultsChecksum() if retried else statement.checksum,
)
else:
_execute_insert_heterogenous(
transaction, parts.get("sql_params_list"),
)
return (
iter(()),
ResultsChecksum() if retried else statement.checksum,
)

return (
transaction.execute_sql(
statement.sql, statement.params, param_types=statement.param_types,
Expand Down

0 comments on commit 2e85bc6

Please sign in to comment.