diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index c2e893a098..4b5a0d9652 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -85,6 +85,9 @@ def description(self): - ``precision`` - ``scale`` - ``null_ok`` + + :rtype: tuple + :returns: A tuple of columns' information. """ if not (self._result_set and self._result_set.metadata): return None @@ -107,7 +110,11 @@ def description(self): @property def rowcount(self): - """The number of rows produced by the last `.execute()`.""" + """The number of rows produced by the last `.execute()`. + + :rtype: int + :returns: The number of rows produced by the last .execute*(). + """ return self._row_count def _raise_if_closed(self): @@ -127,7 +134,14 @@ def callproc(self, procname, args=None): self._raise_if_closed() def close(self): - """Closes this Cursor, making it unusable from this point forward.""" + """Prepare and execute a Spanner database operation. + + :type sql: str + :param sql: A SQL query statement. + + :type args: list + :param args: Additional parameters to supplement the SQL query. + """ self._is_closed = True def _do_execute_update(self, transaction, sql, params, param_types=None): @@ -358,6 +372,11 @@ def __iter__(self): return self._itr def list_tables(self): + """List the tables of the linked Database. + + :rtype: list + :returns: The list of tables within the Database. + """ return self.run_sql_in_snapshot(_helpers.SQL_LIST_TABLES) def run_sql_in_snapshot(self, sql, params=None, param_types=None): diff --git a/google/cloud/spanner_dbapi/parse_utils.py b/google/cloud/spanner_dbapi/parse_utils.py index d3dd98dda6..abc36b397c 100644 --- a/google/cloud/spanner_dbapi/parse_utils.py +++ b/google/cloud/spanner_dbapi/parse_utils.py @@ -176,11 +176,11 @@ def classify_stmt(query): """Determine SQL query type. - :type query: :class:`str` - :param query: SQL query. + :type query: str + :param query: A SQL query. - :rtype: :class:`str` - :returns: Query type name. + :rtype: str + :returns: The query type name. """ if RE_DDL.match(query): return STMT_DDL @@ -253,6 +253,17 @@ def parse_insert(insert_sql, params): ('INSERT INTO T (f1, f2) VALUES (UPPER(%s), %s)', ('c', 'd',)) ], } + + :type insert_sql: str + :param insert_sql: A SQL insert request. + + :type params: list + :param params: A list of parameters. + + :rtype: dict + :returns: A dictionary that maps `sql_params_list` to the list of + parameters in cases a), b), d) or the dictionary with information + about the resulting table in case c). """ # noqa match = RE_INSERT.search(insert_sql) @@ -348,8 +359,16 @@ def rows_for_insert_or_update(columns, params, pyformat_args=None): We'll have to convert both params types into: Params: [(1, 2, 3,), (4, 5, 6,), (7, 8, 9,)] - """ # noqa + :type columns: list + :param columns: A list of the columns of the table. + + :type params: list + :param params: A list of parameters. + + :rtype: list + :returns: A properly restructured list of the parameters. + """ # noqa if not pyformat_args: # This is the case where we have for example: # SQL: 'INSERT INTO t (f1, f2, f3)' @@ -445,6 +464,16 @@ def sql_pyformat_args_to_spanner(sql, params): becomes: SQL: 'SELECT * from t where f1=@a0, f2=@a1, f3=@a2' Params: {'a0': 'a', 'a1': 23, 'a2': '888***'} + + :type sql: str + :param sql: A SQL request. + + :type params: list + :param params: A list of parameters. + + :rtype: tuple(str, dict) + :returns: A tuple of the sanitized SQL and a dictionary of the named + arguments. """ if not params: return sanitize_literals_for_upload(sql), params @@ -488,10 +517,10 @@ def cast_for_spanner(value): """Convert the param to its Cloud Spanner equivalent type. :type value: Any - :param value: Value to convert to a Cloud Spanner type. + :param value: The value to convert to a Cloud Spanner type. :rtype: Any - :returns: Value converted to a Cloud Spanner type. + :returns: The value converted to a Cloud Spanner type. """ if isinstance(value, decimal.Decimal): return str(value) @@ -501,10 +530,10 @@ def cast_for_spanner(value): def get_param_types(params): """Determine Cloud Spanner types for the given parameters. - :type params: :class:`dict` + :type params: dict :param params: Parameters requiring to find Cloud Spanner types. - :rtype: :class:`dict` + :rtype: dict :returns: The types index for the given parameters. """ if params is None: @@ -525,7 +554,7 @@ def ensure_where_clause(sql): Cloud Spanner requires a WHERE clause on UPDATE and DELETE statements. Add a dummy WHERE clause if non detected. - :type sql: `str` + :type sql: str :param sql: SQL code to check. """ if any(isinstance(token, sqlparse.sql.Where) for token in sqlparse.parse(sql)[0]): @@ -539,10 +568,10 @@ def escape_name(name): Apply backticks to the name that either contain '-' or ' ', or is a Cloud Spanner's reserved keyword. - :type name: :class:`str` + :type name: str :param name: Name to escape. - :rtype: :class:`str` + :rtype: str :returns: Name escaped if it has to be escaped. """ if "-" in name or " " in name or name.upper() in SPANNER_RESERVED_KEYWORDS: diff --git a/google/cloud/spanner_dbapi/parser.py b/google/cloud/spanner_dbapi/parser.py index 9271631b25..43e446c58e 100644 --- a/google/cloud/spanner_dbapi/parser.py +++ b/google/cloud/spanner_dbapi/parser.py @@ -68,14 +68,18 @@ def __len__(self): class terminal(str): - """ - terminal represents the unit symbol that can be part of a SQL values clause. - """ + """Represent the unit symbol that can be part of a SQL values clause.""" pass class a_args(object): + """Expression arguments. + + :type argv: list + :param argv: A List of expression arguments. + """ + def __init__(self, argv): self.argv = argv @@ -108,9 +112,11 @@ def __getitem__(self, index): return self.argv[index] def homogenous(self): - """ - Return True if all the arguments are pyformat - args and have the same number of arguments. + """Check arguments of the expression to be homogeneous. + + :rtype: bool + :return: True if all the arguments of the expression are in pyformat + and each has the same length, False otherwise. """ if not self._is_equal_length(): return False @@ -126,8 +132,10 @@ def homogenous(self): return True def _is_equal_length(self): - """ - Return False if all the arguments have the same length. + """Return False if all the arguments have the same length. + + :rtype: bool + :return: False if the sequences of the arguments have the same length. """ if len(self) == 0: return True @@ -141,6 +149,12 @@ def _is_equal_length(self): class values(a_args): + """A wrapper for values. + + :rtype: str + :returns: A string of the values expression in a tree view. + """ + def __str__(self): return "VALUES%s" % super().__str__() @@ -153,6 +167,21 @@ def parse_values(stmt): def expect(word, token): + """Parse the given expression recursively. + + :type word: str + :param word: A string expression. + + :type token: str + :param token: An expression token. + + :rtype: `Tuple(str, Any)` + :returns: A tuple containing the rest of the expression string and the + parse tree for the part of the expression that has already been + parsed. + + :raises :class:`ProgrammingError`: If there is a parsing error. + """ word = word.strip() if token == VALUES: if not word.startswith("VALUES"): @@ -242,5 +271,13 @@ def expect(word, token): def as_values(values_stmt): + """Return the parsed values. + + :type values_stmt: str + :param values_stmt: Raw values. + + :rtype: Any + :returns: A tree of the already parsed expression. + """ _, _values = parse_values(values_stmt) return _values diff --git a/google/cloud/spanner_dbapi/utils.py b/google/cloud/spanner_dbapi/utils.py index 7cafaaa609..bfb97346cf 100644 --- a/google/cloud/spanner_dbapi/utils.py +++ b/google/cloud/spanner_dbapi/utils.py @@ -19,11 +19,13 @@ class PeekIterator: """ - PeekIterator peeks at the first element out of an iterator - for the sake of operations like auto-population of fields on reading - the first element. - If next's result is an instance of list, it'll be converted into a tuple - to conform with DBAPI v2's sequence expectations. + Peek at the first element out of an iterator for the sake of operations + like auto-population of fields on reading the first element. + If next's result is an instance of list, it'll be converted into a tuple to + conform with DBAPI v2's sequence expectations. + + :type source: list + :param source: A list of source for the Iterator. """ def __init__(self, source): @@ -97,6 +99,15 @@ def __iter__(self): def backtick_unicode(sql): + """Check the SQL to be valid and split it by segments. + + :type sql: str + :param sql: A SQL request. + + :rtype: str + :returns: A SQL parsed by segments in unicode if initial SQL is valid, + initial string otherwise. + """ matches = list(re_UNICODE_POINTS.finditer(sql)) if not matches: return sql @@ -117,11 +128,20 @@ def backtick_unicode(sql): def sanitize_literals_for_upload(s): - """ - Convert literals in s, to be fit for consumption by Cloud Spanner. - 1. Convert %% (escaped percent literals) to %. Percent signs must be escaped when - values like %s are used as SQL parameter placeholders but Spanner's query language - uses placeholders like @a0 and doesn't expect percent signs to be escaped. - 2. Quote words containing non-ASCII, with backticks, for example föö to `föö`. + """Convert literals in s, to be fit for consumption by Cloud Spanner. + + * Convert %% (escaped percent literals) to %. Percent signs must be escaped + when values like %s are used as SQL parameter placeholders but Spanner's + query language uses placeholders like @a0 and doesn't expect percent + signs to be escaped. + * Quote words containing non-ASCII, with backticks, for example föö to + `föö`. + + :type s: str + :param s: A string with literals to escaped for consumption by Cloud + Spanner. + + :rtype: str + :returns: A sanitized string for uploading. """ return backtick_unicode(s.replace("%%", "%"))