Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(db_api): use sqlparse to split DDL statements #372

Merged
merged 2 commits into from Jun 22, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 3 additions & 2 deletions google/cloud/spanner_dbapi/cursor.py
Expand Up @@ -14,6 +14,8 @@

"""Database cursor for Google Cloud Spanner DB-API."""

import sqlparse

from google.api_core.exceptions import Aborted
from google.api_core.exceptions import AlreadyExists
from google.api_core.exceptions import FailedPrecondition
Expand Down Expand Up @@ -174,8 +176,7 @@ def execute(self, sql, args=None):
try:
classification = parse_utils.classify_stmt(sql)
if classification == parse_utils.STMT_DDL:
for ddl in sql.split(";"):
ddl = ddl.strip()
for ddl in sqlparse.split(sql):
if ddl:
self.connection._ddl_statements.append(ddl)
if self.connection.autocommit:
Expand Down
2 changes: 1 addition & 1 deletion google/cloud/spanner_dbapi/parse_utils.py
Expand Up @@ -199,7 +199,7 @@ def classify_stmt(query):

def parse_insert(insert_sql, params):
"""
Parse an INSERT statement an generate a list of tuples of the form:
Parse an INSERT statement and generate a list of tuples of the form:
[
(SQL, params_per_row1),
(SQL, params_per_row2),
Expand Down
20 changes: 16 additions & 4 deletions tests/unit/spanner_dbapi/test_cursor.py
Expand Up @@ -939,9 +939,16 @@ def test_ddls_with_semicolon(self):
from google.cloud.spanner_dbapi.connection import connect

EXP_DDLS = [
"CREATE TABLE table_name (row_id INT64) PRIMARY KEY ()",
"DROP INDEX index_name",
"DROP TABLE table_name",
"CREATE TABLE table_name (row_id INT64) PRIMARY KEY ();",
"DROP INDEX index_name;",
(
"CREATE TABLE papers ("
"\n id INT64,"
"\n authors ARRAY<STRING(100)>,"
'\n author_list STRING(MAX) AS (ARRAY_TO_STRING(authors, ";")) stored'
") PRIMARY KEY (id);"
),
"DROP TABLE table_name;",
]

with mock.patch(
Expand All @@ -956,7 +963,12 @@ def test_ddls_with_semicolon(self):
cursor.execute(
"CREATE TABLE table_name (row_id INT64) PRIMARY KEY ();"
"DROP INDEX index_name;\n"
"DROP TABLE table_name;"
"CREATE TABLE papers ("
"\n id INT64,"
"\n authors ARRAY<STRING(100)>,"
'\n author_list STRING(MAX) AS (ARRAY_TO_STRING(authors, ";")) stored'
") PRIMARY KEY (id);"
"DROP TABLE table_name;",
IlyaFaer marked this conversation as resolved.
Show resolved Hide resolved
)

self.assertEqual(connection._ddl_statements, EXP_DDLS)