Skip to content

Commit

Permalink
feat: Support parameterized NUMERIC, BIGNUMERIC, STRING, and BYTES ty…
Browse files Browse the repository at this point in the history
…pes (#180)
  • Loading branch information
jimfulton committed May 24, 2021
1 parent fe0591a commit d118238
Show file tree
Hide file tree
Showing 10 changed files with 241 additions and 95 deletions.
8 changes: 5 additions & 3 deletions pybigquery/_helpers.py
Expand Up @@ -69,9 +69,6 @@ def substitute_re_method(r, flags=0, repl=None):

r = re.compile(r, flags)

if isinstance(repl, str):
return lambda self, s: r.sub(repl, s)

@functools.wraps(repl)
def sub(self, s, *args, **kw):
def repl_(m):
Expand All @@ -80,3 +77,8 @@ def repl_(m):
return r.sub(repl_, s)

return sub


def substitute_string_re_method(r, *, repl, flags=0):
    """Build a method that applies a fixed regex substitution to a string.

    Args:
        r: Regular expression pattern; compiled once with ``re.compile``.
        repl: Replacement handed to the compiled pattern's ``sub``
            (keyword-only, required).
        flags: ``re`` flags used when compiling the pattern.

    Returns:
        A callable taking ``(self, s)`` that returns ``s`` with every match
        of the pattern replaced by ``repl``.  ``self`` is accepted so the
        callable can be assigned as a method; it is not used.
    """
    pattern = re.compile(r, flags)

    def substitute(self, s):
        return pattern.sub(repl, s)

    return substitute
125 changes: 90 additions & 35 deletions pybigquery/sqlalchemy_bigquery.py
Expand Up @@ -122,8 +122,8 @@ def format_label(self, label, name=None):
"BYTES": types.BINARY,
"TIME": types.TIME,
"RECORD": types.JSON,
"NUMERIC": types.DECIMAL,
"BIGNUMERIC": types.DECIMAL,
"NUMERIC": types.Numeric,
"BIGNUMERIC": types.Numeric,
}

STRING = _type_map["STRING"]
Expand Down Expand Up @@ -158,23 +158,33 @@ def get_insert_default(self, column): # pragma: NO COVER
elif isinstance(column.type, String):
return str(uuid.uuid4())

__remove_type_from_empty_in = _helpers.substitute_re_method(
r" IN UNNEST\(\[ ("
r"(?:NULL|\(NULL(?:, NULL)+\))\)"
r" (?:AND|OR) \(1 !?= 1"
r")"
r"(?:[:][A-Z0-9]+)?"
r" \]\)",
re.IGNORECASE,
r" IN(\1)",
__remove_type_from_empty_in = _helpers.substitute_string_re_method(
r"""
\sIN\sUNNEST\(\[\s # ' IN UNNEST([ '
(
(?:NULL|\(NULL(?:,\sNULL)+\))\) # '(NULL)' or '((NULL, NULL, ...))'
\s(?:AND|OR)\s\(1\s!?=\s1 # ' and 1 != 1' or ' or 1 = 1'
)
(?:[:][A-Z0-9]+)? # Maybe ':TYPE' (e.g. ':INT64')
\s\]\) # Close: ' ])'
""",
flags=re.IGNORECASE | re.VERBOSE,
repl=r" IN(\1)",
)

@_helpers.substitute_re_method(
r" IN UNNEST\(\[ "
r"(%\([^)]+_\d+\)s(?:, %\([^)]+_\d+\)s)*)?" # Placeholders. See below.
r":([A-Z0-9]+)" # Type
r" \]\)",
re.IGNORECASE,
r"""
\sIN\sUNNEST\(\[\s # ' IN UNNEST([ '
( # Placeholders. See below.
%\([^)]+_\d+\)s # Placeholder '%(foo_1)s'
(?:,\s # 0 or more placeholders
%\([^)]+_\d+\)s
)*
)?
:([A-Z0-9]+) # Type ':TYPE' (e.g. ':INT64')
\s\]\) # Close: ' ])'
""",
flags=re.IGNORECASE | re.VERBOSE,
)
def __distribute_types_to_expanded_placeholders(self, m):
# If we have an in parameter, it sometimes gets expanded to 0 or more
Expand Down Expand Up @@ -282,10 +292,20 @@ def group_by_clause(self, select, **kw):
"EXPANDING" if __sqlalchemy_version_info < (1, 4) else "POSTCOMPILE"
)

__in_expanding_bind = _helpers.substitute_re_method(
fr" IN \((\[" fr"{__expandng_text}" fr"_[^\]]+\](:[A-Z0-9]+)?)\)$",
re.IGNORECASE,
r" IN UNNEST([ \1 ])",
__in_expanding_bind = _helpers.substitute_string_re_method(
fr"""
\sIN\s\( # ' IN ('
(
\[ # Expanding placeholder
{__expandng_text} # e.g. [EXPANDING_foo_1]
_[^\]]+ #
\]
(:[A-Z0-9]+)? # type marker (e.g. ':INT64'
)
\)$ # close w ending )
""",
flags=re.IGNORECASE | re.VERBOSE,
repl=r" IN UNNEST([ \1 ])",
)

def visit_in_op_binary(self, binary, operator_, **kw):
Expand Down Expand Up @@ -360,6 +380,18 @@ def visit_notendswith_op_binary(self, binary, operator, **kw):

__expanded_param = re.compile(fr"\(\[" fr"{__expandng_text}" fr"_[^\]]+\]\)$").match

# Substitution that drops the parenthesised dimensions from a parameterized
# STRING, BYTES, NUMERIC or BIGNUMERIC type name, keeping only the base type
# (the replacement is the first capture group).  For example, per the
# pattern below, 'NUMERIC(4, 2)' becomes 'NUMERIC' and 'STRING(42)' becomes
# 'STRING'.  Matching is case-insensitive.
__remove_type_parameter = _helpers.substitute_string_re_method(
r"""
(STRING|BYTES|NUMERIC|BIGNUMERIC) # Base type
\( # Dimensions e.g. '(42)', '(4, 2)':
\s*\d+\s* # First dimension
(?:,\s*\d+\s*)* # Remaining dimensions
\)
""",
repl=r"\1",
flags=re.VERBOSE | re.IGNORECASE,
)

def visit_bindparam(
self,
bindparam,
Expand Down Expand Up @@ -397,6 +429,7 @@ def visit_bindparam(
if bq_type[-1] == ">" and bq_type.startswith("ARRAY<"):
# Values get arrayified at a lower level.
bq_type = bq_type[6:-1]
bq_type = self.__remove_type_parameter(bq_type)

assert_(param != "%s", f"Unexpected param: {param}")

Expand Down Expand Up @@ -429,6 +462,10 @@ def visit_FLOAT(self, type_, **kw):
visit_REAL = visit_FLOAT

def visit_STRING(self, type_, **kw):
    """Render the BigQuery type name for a string type.

    ``STRING(<length>)`` is produced only when the type carries a length
    and the ``type_expression`` keyword is a ``Column`` (a column
    definition).  In every other case the bare ``STRING`` is returned.
    """
    if type_.length is None:
        return "STRING"
    if not isinstance(kw.get("type_expression"), Column):
        return "STRING"
    # Column definition with an explicit length.
    return f"STRING({type_.length})"

visit_CHAR = visit_NCHAR = visit_STRING
Expand All @@ -438,17 +475,29 @@ def visit_ARRAY(self, type_, **kw):
return "ARRAY<{}>".format(self.process(type_.item_type, **kw))

def visit_BINARY(self, type_, **kw):
    """Render the BigQuery type name for a binary type.

    Returns ``BYTES(<length>)`` when the type has a length, otherwise the
    bare ``BYTES``.  Keyword arguments are accepted but not consulted.
    """
    size = type_.length
    return "BYTES" if size is None else f"BYTES({size})"

visit_VARBINARY = visit_BINARY

def visit_NUMERIC(self, type_, **kw):
if (type_.precision is not None and type_.precision > 38) or (
type_.scale is not None and type_.scale > 9
):
return "BIGNUMERIC"
if (type_.precision is not None) and isinstance(
kw.get("type_expression"), Column
): # column def
if type_.scale is not None:
suffix = f"({type_.precision}, {type_.scale})"
else:
suffix = f"({type_.precision})"
else:
return "NUMERIC"
suffix = ""

return (
"BIGNUMERIC"
if (type_.precision is not None and type_.precision > 38)
or (type_.scale is not None and type_.scale > 9)
else "NUMERIC"
) + suffix

visit_DECIMAL = visit_NUMERIC

Expand Down Expand Up @@ -800,18 +849,16 @@ def _get_columns_helper(self, columns, cur_columns):
"""
results = []
for col in columns:
results += [
SchemaField(
name=".".join(col.name for col in cur_columns + [col]),
field_type=col.field_type,
mode=col.mode,
description=col.description,
fields=col.fields,
)
]
results += [col]
if col.field_type == "RECORD":
cur_columns.append(col)
results += self._get_columns_helper(col.fields, cur_columns)
fields = [
SchemaField.from_api_repr(
dict(f.to_api_repr(), name=f"{col.name}.{f.name}")
)
for f in col.fields
]
results += self._get_columns_helper(fields, cur_columns)
cur_columns.pop()
return results

Expand All @@ -829,13 +876,21 @@ def get_columns(self, connection, table_name, schema=None, **kw):
)
coltype = types.NullType

if col.field_type.endswith("NUMERIC"):
coltype = coltype(precision=col.precision, scale=col.scale)
elif col.field_type == "STRING" or col.field_type == "BYTES":
coltype = coltype(col.max_length)

result.append(
{
"name": col.name,
"type": types.ARRAY(coltype) if col.mode == "REPEATED" else coltype,
"nullable": col.mode == "NULLABLE" or col.mode == "REPEATED",
"comment": col.description,
"default": None,
"precision": col.precision,
"scale": col.scale,
"max_length": col.max_length,
}
)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -67,7 +67,7 @@ def readme():
install_requires=[
"google-api-core>=1.23.0", # Work-around bug in cloud core deps.
"google-auth>=1.24.0,<2.0dev", # Work around pip wack.
"google-cloud-bigquery>=2.16.1",
"google-cloud-bigquery>=2.17.0",
"sqlalchemy>=1.2.0,<1.5.0dev",
"future",
],
Expand Down
2 changes: 1 addition & 1 deletion testing/constraints-3.6.txt
Expand Up @@ -6,5 +6,5 @@
# e.g., if setup.py has "foo >= 1.14.0, < 2.0.0dev",
sqlalchemy==1.2.0
google-auth==1.24.0
google-cloud-bigquery==2.16.1
google-cloud-bigquery==2.17.0
google-api-core==1.23.0
Expand Up @@ -71,7 +71,6 @@ def literal(value):

else:
from sqlalchemy.testing.suite import (
ComponentReflectionTestExtra as _ComponentReflectionTestExtra,
FetchLimitOffsetTest as _FetchLimitOffsetTest,
RowCountTest as _RowCountTest,
)
Expand Down Expand Up @@ -107,13 +106,6 @@ def test_limit_render_multiple_times(self, connection):
# over the backquotes that we add everywhere. XXX Why do we do that?
del PostCompileParamsTest

class ComponentReflectionTestExtra(_ComponentReflectionTestExtra):
@pytest.mark.skip("BQ types don't have parameters like precision and length")
def test_numeric_reflection(self):
pass

test_varchar_reflection = test_numeric_reflection

class TimestampMicrosecondsTest(_TimestampMicrosecondsTest):

data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396, tzinfo=pytz.UTC)
Expand Down
2 changes: 1 addition & 1 deletion tests/system/test_sqlalchemy_bigquery.py
Expand Up @@ -92,7 +92,7 @@
{"name": "timestamp", "type": types.TIMESTAMP(), "nullable": True, "default": None},
{"name": "string", "type": types.String(), "nullable": True, "default": None},
{"name": "float", "type": types.Float(), "nullable": True, "default": None},
{"name": "numeric", "type": types.DECIMAL(), "nullable": True, "default": None},
{"name": "numeric", "type": types.Numeric(), "nullable": True, "default": None},
{"name": "boolean", "type": types.Boolean(), "nullable": True, "default": None},
{"name": "date", "type": types.DATE(), "nullable": True, "default": None},
{"name": "datetime", "type": types.DATETIME(), "nullable": True, "default": None},
Expand Down

0 comments on commit d118238

Please sign in to comment.