Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implementation for batch dml in dbapi #1055

Merged
merged 4 commits into from Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
119 changes: 119 additions & 0 deletions google/cloud/spanner_dbapi/batch_dml_executor.py
@@ -0,0 +1,119 @@
from __future__ import annotations
olavloite marked this conversation as resolved.
Show resolved Hide resolved

from enum import Enum
from typing import TYPE_CHECKING, List
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
from google.cloud.spanner_dbapi.parsed_statement import (
ParsedStatement,
StatementType,
Statement,
)
from google.rpc.code_pb2 import ABORTED, OK
from google.api_core.exceptions import Aborted

from google.cloud.spanner_dbapi.utils import StreamedManyResultSets

if TYPE_CHECKING:
from google.cloud.spanner_dbapi.cursor import Cursor


class BatchDmlExecutor:
    """Buffers DML statements for a cursor while a DML batch is active.

    Statements handed to :meth:`execute_statement` are only collected in
    memory; they are sent to Spanner together when :meth:`run_batch_dml`
    is called.

    :type cursor: :class:`~google.cloud.spanner_dbapi.cursor.Cursor`
    :param cursor: cursor the batch is executed for; its ``connection``
        attribute is also stored on the executor.
    """

    def __init__(self, cursor: "Cursor"):
        # Statements collected so far, in the order they were submitted.
        self._statements: List[Statement] = []
        self._cursor = cursor
        self._connection = cursor.connection

    def execute_statement(self, parsed_statement: ParsedStatement):
        """Add a statement to the in-memory batch without contacting Spanner.

        :type parsed_statement: ParsedStatement
        :param parsed_statement: parsed statement containing sql query and
            query params; its ``statement_type`` must be ``UPDATE`` or
            ``INSERT``.

        :raises ProgrammingError: if the statement type is neither
            ``UPDATE`` nor ``INSERT``.
        """
        from google.cloud.spanner_dbapi import ProgrammingError

        accepted_types = (StatementType.UPDATE, StatementType.INSERT)
        if parsed_statement.statement_type not in accepted_types:
            raise ProgrammingError(
                "Only DML statements are allowed in batch DML mode."
            )
        self._statements.append(parsed_statement.statement)

    def run_batch_dml(self):
        """Send every buffered statement to Spanner in one batch call.

        :returns: whatever the module-level ``run_batch_dml`` function
            returns for this executor's cursor and buffered statements.
        """
        cursor, buffered = self._cursor, self._statements
        return run_batch_dml(cursor, buffered)


def run_batch_dml(cursor: "Cursor", statements: List[Statement]):
    """Executes all the dml statements by making a batch call to Spanner.

    When the connection has not started a client-side transaction, the batch
    is run through ``connection.database.run_in_transaction``. Otherwise it is
    run on the connection's checked-out transaction; if that attempt is
    aborted, ``connection.retry_transaction()`` is called and the batch is
    sent again.

    In both paths ``cursor._row_count`` is set to the sum of the returned
    values, with negative values counted as 0.

    :type cursor: Cursor
    :param cursor: Database Cursor object

    :type statements: List[Statement]
    :param statements: list of statements to execute in batch

    :rtype: StreamedManyResultSets or None
    :returns: the result set holding the batch results when a client-side
        transaction is active; ``None`` when the batch was run through
        ``run_in_transaction``.

    :raises OperationalError: if the batch status is neither ``OK`` nor
        ``ABORTED``.
    """
    from google.cloud.spanner_dbapi import OperationalError

    connection = cursor.connection
    many_result_set = StreamedManyResultSets()
    statements_tuple = [statement.get_tuple() for statement in statements]

    if not connection._client_transaction_started:
        res = connection.database.run_in_transaction(_do_batch_update, statements_tuple)
        many_result_set.add_iter(res)
        cursor._row_count = sum(max(val, 0) for val in res)
        # NOTE(review): this path has always fallen through without returning
        # many_result_set, so callers receive None here. Kept as-is because
        # callers are not visible from this module - confirm it is intended.
        return None

    retried = False
    while True:
        try:
            transaction = connection.transaction_checkout()
            status, res = transaction.batch_update(statements_tuple)
            res_checksum = ResultsChecksum()
            res_checksum.consume_result(res)
            res_checksum.consume_result(status.code)
            if not retried:
                # Register the batch and its checksum on the connection only
                # for the first attempt.
                connection._statements.append((statements, res_checksum))
            cursor._row_count = sum(max(val, 0) for val in res)

            if status.code == ABORTED:
                connection._transaction = None
                raise Aborted(status.message)
            elif status.code != OK:
                # NOTE(review): a reviewer asked whether the status code
                # should be part of this message; deferred to a follow-up.
                raise OperationalError(status.message)

            # Add the results only once the attempt has succeeded. Adding
            # them before the status check left the results of an aborted
            # attempt in the set, so a retry returned them twice.
            many_result_set.add_iter(res)
            return many_result_set
        except Aborted:
            connection.retry_transaction()
            retried = True


def _do_batch_update(transaction, statements):
    """Run ``statements`` as one batch on ``transaction`` and check the status.

    :param transaction: object exposing ``batch_update(statements)`` that
        returns a ``(status, results)`` pair.
    :param statements: statements passed through to ``batch_update``.

    :returns: the second element returned by ``batch_update`` when the
        status code is ``OK``.

    :raises Aborted: if the status code is ``ABORTED``.
    :raises OperationalError: for any other non-``OK`` status code.
    """
    from google.cloud.spanner_dbapi import OperationalError

    batch_status, batch_results = transaction.batch_update(statements)
    code = batch_status.code
    if code == ABORTED:
        raise Aborted(batch_status.message)
    if code != OK:
        raise OperationalError(batch_status.message)
    return batch_results


class BatchMode(Enum):
    """Kind of statement batch that is active on a connection."""

    # A DML batch is active.
    DML = 1
    # Presumably a DDL batch; not set anywhere in the visible code - confirm.
    DDL = 2
    # No batch is active.
    NONE = 3
14 changes: 10 additions & 4 deletions google/cloud/spanner_dbapi/client_side_statement_executor.py
Expand Up @@ -14,7 +14,7 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from google.cloud.spanner_dbapi import Connection
from google.cloud.spanner_dbapi.cursor import Cursor
from google.cloud.spanner_dbapi import ProgrammingError

from google.cloud.spanner_dbapi.parsed_statement import (
Expand All @@ -38,17 +38,18 @@
)


def execute(connection: "Connection", parsed_statement: ParsedStatement):
def execute(cursor: "Cursor", parsed_statement: ParsedStatement):
"""Executes the client side statements by calling the relevant method.

It is an internal method that can make backwards-incompatible changes.

:type connection: Connection
:param connection: Connection object of the dbApi
:type cursor: Cursor
:param cursor: Cursor object of the dbApi

:type parsed_statement: ParsedStatement
:param parsed_statement: parsed_statement based on the sql query
"""
connection = cursor.connection
if connection.is_closed:
raise ProgrammingError(CONNECTION_CLOSED_ERROR)
statement_type = parsed_statement.client_side_statement_type
Expand Down Expand Up @@ -81,6 +82,11 @@ def execute(connection: "Connection", parsed_statement: ParsedStatement):
TypeCode.TIMESTAMP,
read_timestamp,
)
if statement_type == ClientSideStatementType.START_BATCH_DML:
connection.start_batch_dml(cursor)
return None
if statement_type == ClientSideStatementType.RUN_BATCH:
return connection.run_batch()


def _get_streamed_result_set(column_name, type_code, column_value):
Expand Down
Expand Up @@ -18,6 +18,7 @@
ParsedStatement,
StatementType,
ClientSideStatementType,
Statement,
)

RE_BEGIN = re.compile(r"^\s*(BEGIN|START)(TRANSACTION)?", re.IGNORECASE)
Expand All @@ -29,6 +30,8 @@
RE_SHOW_READ_TIMESTAMP = re.compile(
r"^\s*(SHOW)\s+(VARIABLE)\s+(READ_TIMESTAMP)", re.IGNORECASE
)
# Client-side "START BATCH DML" statement: optional leading whitespace, then
# the three keywords separated by whitespace, matched case-insensitively.
# Only the start is anchored, so text after the keywords still matches.
RE_START_BATCH_DML = re.compile(r"^\s*(START)\s+(BATCH)\s+(DML)", re.IGNORECASE)
# Client-side "RUN BATCH" statement: optional leading whitespace, then the
# two keywords separated by whitespace, case-insensitive, start-anchored only.
RE_RUN_BATCH = re.compile(r"^\s*(RUN)\s+(BATCH)", re.IGNORECASE)
olavloite marked this conversation as resolved.
Show resolved Hide resolved


def parse_stmt(query):
Expand All @@ -54,8 +57,12 @@ def parse_stmt(query):
client_side_statement_type = ClientSideStatementType.SHOW_COMMIT_TIMESTAMP
if RE_SHOW_READ_TIMESTAMP.match(query):
client_side_statement_type = ClientSideStatementType.SHOW_READ_TIMESTAMP
if RE_START_BATCH_DML.match(query):
client_side_statement_type = ClientSideStatementType.START_BATCH_DML
if RE_RUN_BATCH.match(query):
client_side_statement_type = ClientSideStatementType.RUN_BATCH
if client_side_statement_type is not None:
return ParsedStatement(
StatementType.CLIENT_SIDE, query, client_side_statement_type
StatementType.CLIENT_SIDE, Statement(query), client_side_statement_type
)
return None
71 changes: 66 additions & 5 deletions google/cloud/spanner_dbapi/connection.py
Expand Up @@ -13,13 +13,14 @@
# limitations under the License.

"""DB-API Connection for the Google Cloud Spanner."""

import time
import warnings

from google.api_core.exceptions import Aborted
from google.api_core.gapic_v1.client_info import ClientInfo
from google.cloud import spanner_v1 as spanner
from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode, BatchDmlExecutor
from google.cloud.spanner_dbapi.parsed_statement import ParsedStatement, Statement
from google.cloud.spanner_v1 import RequestOptions
from google.cloud.spanner_v1.session import _get_retry_delay
from google.cloud.spanner_v1.snapshot import Snapshot
Expand All @@ -28,7 +29,11 @@
from google.cloud.spanner_dbapi.checksum import _compare_checksums
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
from google.cloud.spanner_dbapi.cursor import Cursor
from google.cloud.spanner_dbapi.exceptions import InterfaceError, OperationalError
from google.cloud.spanner_dbapi.exceptions import (
InterfaceError,
OperationalError,
ProgrammingError,
)
from google.cloud.spanner_dbapi.version import DEFAULT_USER_AGENT
from google.cloud.spanner_dbapi.version import PY_VERSION

Expand Down Expand Up @@ -111,6 +116,8 @@ def __init__(self, instance, database=None, read_only=False):
# whether transaction started at Spanner. This means that we had
# made at least one call to Spanner.
self._spanner_transaction_started = False
self._batch_mode = BatchMode.NONE
self._batch_dml_executor: BatchDmlExecutor = None

@property
def autocommit(self):
Expand Down Expand Up @@ -196,6 +203,24 @@ def read_only(self, value):
)
self._read_only = value

@property
def batch_mode(self):
    """Batch mode currently set on this connection.

    :rtype: :class:`~google.cloud.spanner_dbapi.batch_dml_executor.BatchMode`
    :returns: value of the ``_batch_mode`` attribute.
    """
    # NOTE(review): review feedback on this change pointed out that a name
    # without a leading underscore becomes part of the public API and
    # suggested making it internal instead.
    return self._batch_mode

@batch_mode.setter
def batch_mode(self, value):
    """`batch_mode` flag setter.

    The value is stored as-is; nothing here checks whether a batch of a
    different kind is already active.

    Args:
        value (BatchMode): batch mode to store on the connection.
    """
    self._batch_mode = value
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if an external user calls this function when the connection is already in the middle of a different type of batch (e.g. it is now in a DML batch, and it is called to set it to DDL batch)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed this property


@property
def request_options(self):
"""Options for the next SQL operations.
Expand Down Expand Up @@ -310,7 +335,10 @@ def _rerun_previous_statements(self):
statements, checksum = statement

transaction = self.transaction_checkout()
status, res = transaction.batch_update(statements)
statements_tuple = []
for single_statement in statements:
statements_tuple.append(single_statement.get_tuple())
status, res = transaction.batch_update(statements_tuple)

if status.code == ABORTED:
raise Aborted(status.details)
Expand Down Expand Up @@ -476,14 +504,14 @@ def run_prior_DDL_statements(self):

return self.database.update_ddl(ddl_statements).result()

def run_statement(self, statement, retried=False):
def run_statement(self, statement: Statement, retried=False):
"""Run single SQL statement in begun transaction.

This method is never used in autocommit mode. In
!autocommit mode however it remembers every executed
SQL statement with its parameters.

:type statement: :class:`dict`
:type statement: :class:`Statement`
:param statement: SQL statement to execute.

:type retried: bool
Expand Down Expand Up @@ -534,6 +562,39 @@ def validate(self):
"Expected: [[1]]" % result
)

@check_not_closed
def start_batch_dml(self, cursor):
    """Switch the connection into DML batch mode for ``cursor``.

    Sets the batch mode to ``BatchMode.DML`` and creates the
    ``BatchDmlExecutor`` that buffers the statements of the batch.

    :param cursor: cursor passed to the new ``BatchDmlExecutor``.

    :raises ProgrammingError: if a batch is already active, or if the
        connection is in read-only mode.
    """
    problem = None
    if self.batch_mode is not BatchMode.NONE:
        problem = "Cannot start a DML batch when a batch is already active"
    elif self.read_only:
        problem = (
            "Cannot start a DML batch when the connection is in read-only mode"
        )
    if problem is not None:
        raise ProgrammingError(problem)

    self.batch_mode = BatchMode.DML
    self._batch_dml_executor = BatchDmlExecutor(cursor)

@check_not_closed
def execute_batch_dml_statement(self, parsed_statement: ParsedStatement):
    """Hand a statement to the active DML batch executor.

    :type parsed_statement: ParsedStatement
    :param parsed_statement: statement forwarded to the executor's
        ``execute_statement``.

    :raises ProgrammingError: if the connection is not in DML batch mode.
    """
    in_dml_batch = self.batch_mode is BatchMode.DML
    if in_dml_batch:
        self._batch_dml_executor.execute_statement(parsed_statement)
        return
    raise ProgrammingError("Cannot execute statement when the BatchMode is not DML")

@check_not_closed
def run_batch(self):
    """Run the active batch and leave batch mode.

    For a DML batch the buffered statements are executed through the batch
    DML executor. Whether or not execution succeeds, the batch mode is
    reset to ``BatchMode.NONE`` and the executor is discarded.

    :returns: the value returned by the executor's ``run_batch_dml`` for a
        DML batch; ``None`` for any other active batch mode.

    :raises ProgrammingError: if no batch is active.
    """
    if self.batch_mode is BatchMode.NONE:
        raise ProgrammingError("Cannot run a batch when the BatchMode is not set")
    # Bound up front so the final return is defined for every batch mode.
    # Previously a non-DML mode such as BatchMode.DDL reached the return
    # with the name unbound and raised UnboundLocalError.
    many_result_set = None
    try:
        if self.batch_mode is BatchMode.DML:
            many_result_set = self._batch_dml_executor.run_batch_dml()
    finally:
        # Always leave batch mode, even when running the batch failed.
        self.batch_mode = BatchMode.NONE
        self._batch_dml_executor = None
    return many_result_set

def __enter__(self):
    """Return this connection as the value bound by a ``with`` statement."""
    return self

Expand Down