forked from googleapis/python-spanner-django
/
connection.py
144 lines (111 loc) · 4.16 KB
/
connection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# Copyright 2020 Google LLC
#
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd
from collections import namedtuple
from google.cloud import spanner_v1 as spanner
from .cursor import Cursor
from .exceptions import InterfaceError
# Per-column metadata produced by Connection.get_table_column_schema():
#   null_ok      -- True when INFORMATION_SCHEMA.COLUMNS.IS_NULLABLE is "YES".
#   spanner_type -- the column's SPANNER_TYPE string, as reported by Spanner.
# The typename must match the module-level name: pickle locates the class by
# its qualname, so a mismatched typename makes instances unpicklable.
ColumnDetails = namedtuple("ColumnDetails", ["null_ok", "spanner_type"])
class Connection:
    """DB-API-style connection wrapping a Spanner database handle.

    DDL statements are not executed as they arrive: they are queued by
    :meth:`append_ddl_statement` and sent as a single batch by
    :meth:`run_prior_DDL_statements`. That method is called by
    :meth:`commit` and before every :meth:`run_sql_in_snapshot` query.

    Args:
        db_handle: the underlying database object. It must provide
            ``update_ddl()``, ``snapshot()`` and ``run_in_transaction()``.
            Presumably a ``google.cloud.spanner_v1`` ``Database`` -- confirm
            against callers.
    """

    def __init__(self, db_handle):
        self._dbhandle = db_handle
        # DDL statements waiting for the next run_prior_DDL_statements().
        self._ddl_statements = []
        self.is_closed = False

    def cursor(self):
        """Return a new :class:`Cursor` bound to this connection.

        :raises: :class:`InterfaceError` if this connection is closed.
        """
        self._raise_if_closed()
        return Cursor(self)

    def _raise_if_closed(self):
        """Raise an exception if this connection is closed.

        Helper to check the connection state before
        running a SQL/DDL/DML query.

        :raises: :class:`InterfaceError` if this connection is closed.
        """
        if self.is_closed:
            raise InterfaceError("connection is already closed")

    def __handle_update_ddl(self, ddl_statements):
        """
        Run the list of Data Definition Language (DDL) statements on the underlying
        database. Each DDL statement MUST NOT contain a semicolon.

        Args:
            ddl_statements: a list of DDL statements, each without a semicolon.

        Returns:
            google.api_core.operation.Operation.result()
        """
        self._raise_if_closed()
        # Synchronously wait on the operation's completion.
        return self._dbhandle.update_ddl(ddl_statements).result()

    def read_snapshot(self):
        """Return ``snapshot()`` of the underlying database handle.

        :raises: :class:`InterfaceError` if this connection is closed.
        """
        self._raise_if_closed()
        return self._dbhandle.snapshot()

    def in_transaction(self, fn, *args, **kwargs):
        """Run ``fn`` through the handle's ``run_in_transaction``.

        ``args``/``kwargs`` are forwarded unchanged. Returns whatever
        ``run_in_transaction`` returns.

        :raises: :class:`InterfaceError` if this connection is closed.
        """
        self._raise_if_closed()
        return self._dbhandle.run_in_transaction(fn, *args, **kwargs)

    def append_ddl_statement(self, ddl_statement):
        """Queue ``ddl_statement`` to run with the next DDL batch.

        :raises: :class:`InterfaceError` if this connection is closed.
        """
        self._raise_if_closed()
        self._ddl_statements.append(ddl_statement)

    def run_prior_DDL_statements(self):
        """Run all queued DDL statements as one batch.

        Returns:
            ``None`` when nothing is queued; otherwise the result of the
            DDL operation.

        :raises: :class:`InterfaceError` if this connection is closed.
        """
        self._raise_if_closed()
        if not self._ddl_statements:
            return None
        # Detach the queue before running it so the same statements are never
        # re-sent by a later call, even if this batch fails.
        ddl_statements, self._ddl_statements = self._ddl_statements, []
        return self.__handle_update_ddl(ddl_statements)

    def list_tables(self):
        """Return rows naming every table in the default schema."""
        return self.run_sql_in_snapshot(
            """
SELECT
t.table_name
FROM
information_schema.tables AS t
WHERE
t.table_catalog = '' and t.table_schema = ''
"""
        )

    def run_sql_in_snapshot(self, sql, params=None, param_types=None):
        """Execute ``sql`` in a snapshot and return all result rows as a list.

        Queued DDL is flushed first, so the query sees pending schema changes.
        That flush also performs the closed-connection check.
        """
        # Some SQL e.g. for INFORMATION_SCHEMA cannot be run in read-write transactions
        # hence this method exists to circumvent that limit.
        self.run_prior_DDL_statements()
        with self._dbhandle.snapshot() as snapshot:
            res = snapshot.execute_sql(
                sql, params=params, param_types=param_types
            )
            # Materialize inside the ``with`` so rows are read while the
            # snapshot is still open.
            return list(res)

    def get_table_column_schema(self, table_name):
        """Map each column of ``table_name`` to its :class:`ColumnDetails`.

        Returns:
            dict of column name -> ColumnDetails(null_ok, spanner_type).
        """
        rows = self.run_sql_in_snapshot(
            """SELECT
COLUMN_NAME, IS_NULLABLE, SPANNER_TYPE
FROM
INFORMATION_SCHEMA.COLUMNS
WHERE
TABLE_SCHEMA = ''
AND
TABLE_NAME = @table_name""",
            params={"table_name": table_name},
            param_types={"table_name": spanner.param_types.STRING},
        )
        return {
            column_name: ColumnDetails(
                null_ok=is_nullable == "YES", spanner_type=spanner_type
            )
            for column_name, is_nullable, spanner_type in rows
        }

    def close(self):
        """Close this connection.

        The connection will be unusable from this point forward.
        Closing an already-closed connection is a no-op.
        """
        if self.is_closed:
            return
        self.rollback()
        # Single underscore on purpose: ``self.__dbhandle`` would be
        # name-mangled to ``_Connection__dbhandle`` and leave the real
        # handle referenced.
        self._dbhandle = None
        self.is_closed = True

    def commit(self):
        """Commit by running any queued DDL statements.

        :raises: :class:`InterfaceError` if this connection is closed.
        """
        self._raise_if_closed()
        self.run_prior_DDL_statements()

    def rollback(self):
        """Roll back; currently only validates that the connection is open.

        :raises: :class:`InterfaceError` if this connection is closed.
        """
        self._raise_if_closed()
        # TODO: to be added.

    def __enter__(self):
        return self

    def __exit__(self, etype, value, traceback):
        """Commit on a clean exit, then always close.

        If the ``with`` body raised, queued DDL is not run. The connection is
        closed even when ``commit()`` itself fails. Exceptions are never
        suppressed (returns ``None``).
        """
        try:
            if etype is None:
                self.commit()
        finally:
            self.close()