Skip to content

Commit

Permalink
feat: Implementation for batch dml in dbapi (#1055)
Browse files Browse the repository at this point in the history
* feat: Implementation for batch dml in dbapi

* Few changes

* Incorporated comments
  • Loading branch information
ankiaga committed Dec 14, 2023
1 parent c70d7da commit 7a92315
Show file tree
Hide file tree
Showing 11 changed files with 574 additions and 122 deletions.
131 changes: 131 additions & 0 deletions google/cloud/spanner_dbapi/batch_dml_executor.py
@@ -0,0 +1,131 @@
# Copyright 2023 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from enum import Enum
from typing import TYPE_CHECKING, List
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
from google.cloud.spanner_dbapi.parsed_statement import (
ParsedStatement,
StatementType,
Statement,
)
from google.rpc.code_pb2 import ABORTED, OK
from google.api_core.exceptions import Aborted

from google.cloud.spanner_dbapi.utils import StreamedManyResultSets

if TYPE_CHECKING:
from google.cloud.spanner_dbapi.cursor import Cursor


class BatchDmlExecutor:
"""Executor that is used when a DML batch is started. These batches only
accept DML statements. All DML statements are buffered locally and sent to
Spanner when runBatch() is called.
:type "Cursor": :class:`~google.cloud.spanner_dbapi.cursor.Cursor`
:param cursor:
"""

def __init__(self, cursor: "Cursor"):
self._cursor = cursor
self._connection = cursor.connection
self._statements: List[Statement] = []

def execute_statement(self, parsed_statement: ParsedStatement):
"""Executes the statement when dml batch is active by buffering the
statement in-memory.
:type parsed_statement: ParsedStatement
:param parsed_statement: parsed statement containing sql query and query
params
"""
from google.cloud.spanner_dbapi import ProgrammingError

if (
parsed_statement.statement_type != StatementType.UPDATE
and parsed_statement.statement_type != StatementType.INSERT
):
raise ProgrammingError("Only DML statements are allowed in batch DML mode.")
self._statements.append(parsed_statement.statement)

def run_batch_dml(self):
"""Executes all the buffered statements on the active dml batch by
making a call to Spanner.
"""
return run_batch_dml(self._cursor, self._statements)


def run_batch_dml(cursor: "Cursor", statements: List[Statement]):
"""Executes all the dml statements by making a batch call to Spanner.
:type cursor: Cursor
:param cursor: Database Cursor object
:type statements: List[Statement]
:param statements: list of statements to execute in batch
"""
from google.cloud.spanner_dbapi import OperationalError

connection = cursor.connection
many_result_set = StreamedManyResultSets()
statements_tuple = []
for statement in statements:
statements_tuple.append(statement.get_tuple())
if not connection._client_transaction_started:
res = connection.database.run_in_transaction(_do_batch_update, statements_tuple)
many_result_set.add_iter(res)
cursor._row_count = sum([max(val, 0) for val in res])
else:
retried = False
while True:
try:
transaction = connection.transaction_checkout()
status, res = transaction.batch_update(statements_tuple)
many_result_set.add_iter(res)
res_checksum = ResultsChecksum()
res_checksum.consume_result(res)
res_checksum.consume_result(status.code)
if not retried:
connection._statements.append((statements, res_checksum))
cursor._row_count = sum([max(val, 0) for val in res])

if status.code == ABORTED:
connection._transaction = None
raise Aborted(status.message)
elif status.code != OK:
raise OperationalError(status.message)
return many_result_set
except Aborted:
connection.retry_transaction()
retried = True


def _do_batch_update(transaction, statements):
from google.cloud.spanner_dbapi import OperationalError

status, res = transaction.batch_update(statements)
if status.code == ABORTED:
raise Aborted(status.message)
elif status.code != OK:
raise OperationalError(status.message)
return res


class BatchMode(Enum):
DML = 1
DDL = 2
NONE = 3
16 changes: 12 additions & 4 deletions google/cloud/spanner_dbapi/client_side_statement_executor.py
Expand Up @@ -14,7 +14,7 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from google.cloud.spanner_dbapi import Connection
from google.cloud.spanner_dbapi.cursor import Cursor
from google.cloud.spanner_dbapi import ProgrammingError

from google.cloud.spanner_dbapi.parsed_statement import (
Expand All @@ -38,17 +38,18 @@
)


def execute(connection: "Connection", parsed_statement: ParsedStatement):
def execute(cursor: "Cursor", parsed_statement: ParsedStatement):
"""Executes the client side statements by calling the relevant method.
It is an internal method that can make backwards-incompatible changes.
:type connection: Connection
:param connection: Connection object of the dbApi
:type cursor: Cursor
:param cursor: Cursor object of the dbApi
:type parsed_statement: ParsedStatement
:param parsed_statement: parsed_statement based on the sql query
"""
connection = cursor.connection
if connection.is_closed:
raise ProgrammingError(CONNECTION_CLOSED_ERROR)
statement_type = parsed_statement.client_side_statement_type
Expand Down Expand Up @@ -81,6 +82,13 @@ def execute(connection: "Connection", parsed_statement: ParsedStatement):
TypeCode.TIMESTAMP,
read_timestamp,
)
if statement_type == ClientSideStatementType.START_BATCH_DML:
connection.start_batch_dml(cursor)
return None
if statement_type == ClientSideStatementType.RUN_BATCH:
return connection.run_batch()
if statement_type == ClientSideStatementType.ABORT_BATCH:
return connection.abort_batch()


def _get_streamed_result_set(column_name, type_code, column_value):
Expand Down
12 changes: 11 additions & 1 deletion google/cloud/spanner_dbapi/client_side_statement_parser.py
Expand Up @@ -18,6 +18,7 @@
ParsedStatement,
StatementType,
ClientSideStatementType,
Statement,
)

RE_BEGIN = re.compile(r"^\s*(BEGIN|START)(TRANSACTION)?", re.IGNORECASE)
Expand All @@ -29,6 +30,9 @@
RE_SHOW_READ_TIMESTAMP = re.compile(
r"^\s*(SHOW)\s+(VARIABLE)\s+(READ_TIMESTAMP)", re.IGNORECASE
)
RE_START_BATCH_DML = re.compile(r"^\s*(START)\s+(BATCH)\s+(DML)", re.IGNORECASE)
RE_RUN_BATCH = re.compile(r"^\s*(RUN)\s+(BATCH)", re.IGNORECASE)
RE_ABORT_BATCH = re.compile(r"^\s*(ABORT)\s+(BATCH)", re.IGNORECASE)


def parse_stmt(query):
Expand All @@ -54,8 +58,14 @@ def parse_stmt(query):
client_side_statement_type = ClientSideStatementType.SHOW_COMMIT_TIMESTAMP
if RE_SHOW_READ_TIMESTAMP.match(query):
client_side_statement_type = ClientSideStatementType.SHOW_READ_TIMESTAMP
if RE_START_BATCH_DML.match(query):
client_side_statement_type = ClientSideStatementType.START_BATCH_DML
if RE_RUN_BATCH.match(query):
client_side_statement_type = ClientSideStatementType.RUN_BATCH
if RE_ABORT_BATCH.match(query):
client_side_statement_type = ClientSideStatementType.ABORT_BATCH
if client_side_statement_type is not None:
return ParsedStatement(
StatementType.CLIENT_SIDE, query, client_side_statement_type
StatementType.CLIENT_SIDE, Statement(query), client_side_statement_type
)
return None
61 changes: 56 additions & 5 deletions google/cloud/spanner_dbapi/connection.py
Expand Up @@ -13,13 +13,14 @@
# limitations under the License.

"""DB-API Connection for the Google Cloud Spanner."""

import time
import warnings

from google.api_core.exceptions import Aborted
from google.api_core.gapic_v1.client_info import ClientInfo
from google.cloud import spanner_v1 as spanner
from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode, BatchDmlExecutor
from google.cloud.spanner_dbapi.parsed_statement import ParsedStatement, Statement
from google.cloud.spanner_v1 import RequestOptions
from google.cloud.spanner_v1.session import _get_retry_delay
from google.cloud.spanner_v1.snapshot import Snapshot
Expand All @@ -28,7 +29,11 @@
from google.cloud.spanner_dbapi.checksum import _compare_checksums
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
from google.cloud.spanner_dbapi.cursor import Cursor
from google.cloud.spanner_dbapi.exceptions import InterfaceError, OperationalError
from google.cloud.spanner_dbapi.exceptions import (
InterfaceError,
OperationalError,
ProgrammingError,
)
from google.cloud.spanner_dbapi.version import DEFAULT_USER_AGENT
from google.cloud.spanner_dbapi.version import PY_VERSION

Expand Down Expand Up @@ -111,6 +116,8 @@ def __init__(self, instance, database=None, read_only=False):
# whether transaction started at Spanner. This means that we had
# made atleast one call to Spanner.
self._spanner_transaction_started = False
self._batch_mode = BatchMode.NONE
self._batch_dml_executor: BatchDmlExecutor = None

@property
def autocommit(self):
Expand Down Expand Up @@ -310,7 +317,10 @@ def _rerun_previous_statements(self):
statements, checksum = statement

transaction = self.transaction_checkout()
status, res = transaction.batch_update(statements)
statements_tuple = []
for single_statement in statements:
statements_tuple.append(single_statement.get_tuple())
status, res = transaction.batch_update(statements_tuple)

if status.code == ABORTED:
raise Aborted(status.details)
Expand Down Expand Up @@ -476,14 +486,14 @@ def run_prior_DDL_statements(self):

return self.database.update_ddl(ddl_statements).result()

def run_statement(self, statement, retried=False):
def run_statement(self, statement: Statement, retried=False):
"""Run single SQL statement in begun transaction.
This method is never used in autocommit mode. In
!autocommit mode however it remembers every executed
SQL statement with its parameters.
:type statement: :class:`dict`
:type statement: :class:`Statement`
:param statement: SQL statement to execute.
:type retried: bool
Expand Down Expand Up @@ -534,6 +544,47 @@ def validate(self):
"Expected: [[1]]" % result
)

@check_not_closed
def start_batch_dml(self, cursor):
if self._batch_mode is not BatchMode.NONE:
raise ProgrammingError(
"Cannot start a DML batch when a batch is already active"
)
if self.read_only:
raise ProgrammingError(
"Cannot start a DML batch when the connection is in read-only mode"
)
self._batch_mode = BatchMode.DML
self._batch_dml_executor = BatchDmlExecutor(cursor)

@check_not_closed
def execute_batch_dml_statement(self, parsed_statement: ParsedStatement):
if self._batch_mode is not BatchMode.DML:
raise ProgrammingError(
"Cannot execute statement when the BatchMode is not DML"
)
self._batch_dml_executor.execute_statement(parsed_statement)

@check_not_closed
def run_batch(self):
if self._batch_mode is BatchMode.NONE:
raise ProgrammingError("Cannot run a batch when the BatchMode is not set")
try:
if self._batch_mode is BatchMode.DML:
many_result_set = self._batch_dml_executor.run_batch_dml()
finally:
self._batch_mode = BatchMode.NONE
self._batch_dml_executor = None
return many_result_set

@check_not_closed
def abort_batch(self):
if self._batch_mode is BatchMode.NONE:
raise ProgrammingError("Cannot abort a batch when the BatchMode is not set")
if self._batch_mode is BatchMode.DML:
self._batch_dml_executor = None
self._batch_mode = BatchMode.NONE

def __enter__(self):
return self

Expand Down

0 comments on commit 7a92315

Please sign in to comment.