Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support parameterized NUMERIC, BIGNUMERIC, STRING, and BYTES types #180

Merged
merged 8 commits into from May 24, 2021
7 changes: 6 additions & 1 deletion pybigquery/_helpers.py
Expand Up @@ -65,7 +65,12 @@ def create_bigquery_client(

def substitute_re_method(r, flags=0, repl=None):
if repl is None:
return lambda f: substitute_re_method(r, flags, f)
if isinstance(flags, int):
return lambda f: substitute_re_method(r, flags, f)
else:
# someone passed 2 args, a pattern and a replacement
repl = flags
flags = 0
tseaver marked this conversation as resolved.
Show resolved Hide resolved

r = re.compile(r, flags)

Expand Down
61 changes: 44 additions & 17 deletions pybigquery/sqlalchemy_bigquery.py
Expand Up @@ -122,8 +122,8 @@ def format_label(self, label, name=None):
"BYTES": types.BINARY,
"TIME": types.TIME,
"RECORD": types.JSON,
"NUMERIC": types.DECIMAL,
"BIGNUMERIC": types.DECIMAL,
"NUMERIC": types.Numeric,
"BIGNUMERIC": types.Numeric,
}

STRING = _type_map["STRING"]
Expand Down Expand Up @@ -360,6 +360,10 @@ def visit_notendswith_op_binary(self, binary, operator, **kw):

__expanded_param = re.compile(fr"\(\[" fr"{__expandng_text}" fr"_[^\]]+\]\)$").match

__remove_type_parameter = _helpers.substitute_re_method(
r"(\w+)\(\s*\d+\s*(,\s*\d+\s*)*\)", r"\1"
)

tseaver marked this conversation as resolved.
Show resolved Hide resolved
def visit_bindparam(
self,
bindparam,
Expand Down Expand Up @@ -397,6 +401,7 @@ def visit_bindparam(
if bq_type[-1] == ">" and bq_type.startswith("ARRAY<"):
# Values get arrayified at a lower level.
bq_type = bq_type[6:-1]
bq_type = self.__remove_type_parameter(bq_type)

assert_(param != "%s", f"Unexpected param: {param}")

Expand Down Expand Up @@ -429,6 +434,10 @@ def visit_FLOAT(self, type_, **kw):
visit_REAL = visit_FLOAT

def visit_STRING(self, type_, **kw):
if (type_.length is not None) and isinstance(
kw.get("type_expression"), Column
): # column def
return f"STRING({type_.length})"
return "STRING"

visit_CHAR = visit_NCHAR = visit_STRING
Expand All @@ -438,17 +447,29 @@ def visit_ARRAY(self, type_, **kw):
return "ARRAY<{}>".format(self.process(type_.item_type, **kw))

def visit_BINARY(self, type_, **kw):
if type_.length is not None:
return f"BYTES({type_.length})"
return "BYTES"

visit_VARBINARY = visit_BINARY

def visit_NUMERIC(self, type_, **kw):
if (type_.precision is not None and type_.precision > 38) or (
type_.scale is not None and type_.scale > 9
):
return "BIGNUMERIC"
if (type_.precision is not None) and isinstance(
kw.get("type_expression"), Column
): # column def
if type_.scale is not None:
suffix = f"({type_.precision}, {type_.scale})"
else:
suffix = f"({type_.precision})"
else:
return "NUMERIC"
suffix = ""

return (
"BIGNUMERIC"
if (type_.precision is not None and type_.precision > 38)
or (type_.scale is not None and type_.scale > 9)
else "NUMERIC"
) + suffix

visit_DECIMAL = visit_NUMERIC

Expand Down Expand Up @@ -800,18 +821,16 @@ def _get_columns_helper(self, columns, cur_columns):
"""
results = []
for col in columns:
results += [
SchemaField(
name=".".join(col.name for col in cur_columns + [col]),
field_type=col.field_type,
mode=col.mode,
description=col.description,
fields=col.fields,
)
]
results += [col]
if col.field_type == "RECORD":
cur_columns.append(col)
results += self._get_columns_helper(col.fields, cur_columns)
fields = [
SchemaField.from_api_repr(
dict(f.to_api_repr(), name=f"{col.name}.{f.name}")
)
for f in col.fields
]
results += self._get_columns_helper(fields, cur_columns)
cur_columns.pop()
return results

Expand All @@ -829,13 +848,21 @@ def get_columns(self, connection, table_name, schema=None, **kw):
)
coltype = types.NullType

if col.field_type.endswith("NUMERIC"):
coltype = coltype(precision=col.precision, scale=col.scale)
elif col.field_type == "STRING" or col.field_type == "BYTES":
coltype = coltype(col.max_length)

result.append(
{
"name": col.name,
"type": types.ARRAY(coltype) if col.mode == "REPEATED" else coltype,
"nullable": col.mode == "NULLABLE" or col.mode == "REPEATED",
"comment": col.description,
"default": None,
"precision": col.precision,
"scale": col.scale,
"max_length": col.max_length,
}
)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -67,7 +67,7 @@ def readme():
install_requires=[
"google-api-core>=1.23.0", # Work-around bug in cloud core deps.
"google-auth>=1.24.0,<2.0dev", # Work around pip wack.
"google-cloud-bigquery>=2.16.1",
"google-cloud-bigquery>=2.17.0",
"sqlalchemy>=1.2.0,<1.5.0dev",
"future",
],
Expand Down
2 changes: 1 addition & 1 deletion testing/constraints-3.6.txt
Expand Up @@ -6,5 +6,5 @@
# e.g., if setup.py has "foo >= 1.14.0, < 2.0.0dev",
sqlalchemy==1.2.0
google-auth==1.24.0
google-cloud-bigquery==2.16.1
google-cloud-bigquery==2.17.0
google-api-core==1.23.0
Expand Up @@ -71,7 +71,6 @@ def literal(value):

else:
from sqlalchemy.testing.suite import (
ComponentReflectionTestExtra as _ComponentReflectionTestExtra,
FetchLimitOffsetTest as _FetchLimitOffsetTest,
RowCountTest as _RowCountTest,
)
Expand Down Expand Up @@ -107,13 +106,6 @@ def test_limit_render_multiple_times(self, connection):
# over the backquotes that we add everywhere. XXX Why do we do that?
del PostCompileParamsTest

class ComponentReflectionTestExtra(_ComponentReflectionTestExtra):
@pytest.mark.skip("BQ types don't have parameters like precision and length")
def test_numeric_reflection(self):
pass

test_varchar_reflection = test_numeric_reflection

class TimestampMicrosecondsTest(_TimestampMicrosecondsTest):

data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396, tzinfo=pytz.UTC)
Expand Down
2 changes: 1 addition & 1 deletion tests/system/test_sqlalchemy_bigquery.py
Expand Up @@ -92,7 +92,7 @@
{"name": "timestamp", "type": types.TIMESTAMP(), "nullable": True, "default": None},
{"name": "string", "type": types.String(), "nullable": True, "default": None},
{"name": "float", "type": types.Float(), "nullable": True, "default": None},
{"name": "numeric", "type": types.DECIMAL(), "nullable": True, "default": None},
{"name": "numeric", "type": types.Numeric(), "nullable": True, "default": None},
{"name": "boolean", "type": types.Boolean(), "nullable": True, "default": None},
{"name": "date", "type": types.DATE(), "nullable": True, "default": None},
{"name": "datetime", "type": types.DATETIME(), "nullable": True, "default": None},
Expand Down
20 changes: 19 additions & 1 deletion tests/unit/fauxdbi.py
Expand Up @@ -329,6 +329,20 @@ def _row_dict(row, cursor):
result = {d[0]: value for d, value in zip(cursor.description, row)}
return result

__string_types = "STRING", "BYTES"

@substitute_re_method(r"(\w+)\s*\(\s*(\d+)\s*(?:,\s*(\d+)\s*)?\)")
def __parse_type_parameters(self, m, parameters):
name, precision, scale = m.groups()
if scale is not None:
parameters.update(precision=precision, scale=scale)
elif name in self.__string_types:
parameters.update(max_length=precision)
else:
parameters.update(precision=precision)

return name

def _get_field(
self,
type,
Expand All @@ -348,12 +362,16 @@ def _get_field(
if not mode:
mode = "REQUIRED" if notnull else "NULLABLE"

parameters = {}
type_ = self.__parse_type_parameters(type, parameters)

field = google.cloud.bigquery.schema.SchemaField(
name=name,
field_type=type,
field_type=type_,
mode=mode,
description=description,
fields=tuple(self._get_field(**f) for f in fields),
**parameters,
)

return field
Expand Down
74 changes: 55 additions & 19 deletions tests/unit/test_catalog_functions.py
Expand Up @@ -159,35 +159,51 @@ def test_get_table_comment(faux_conn):


@pytest.mark.parametrize(
"btype,atype",
"btype,atype,extra",
[
("STRING", sqlalchemy.types.String),
("BYTES", sqlalchemy.types.BINARY),
("INT64", sqlalchemy.types.Integer),
("FLOAT64", sqlalchemy.types.Float),
("NUMERIC", sqlalchemy.types.DECIMAL),
("BIGNUMERIC", sqlalchemy.types.DECIMAL),
("BOOL", sqlalchemy.types.Boolean),
("TIMESTAMP", sqlalchemy.types.TIMESTAMP),
("DATE", sqlalchemy.types.DATE),
("TIME", sqlalchemy.types.TIME),
("DATETIME", sqlalchemy.types.DATETIME),
("THURSDAY", sqlalchemy.types.NullType),
("STRING", sqlalchemy.types.String(), ()),
("STRING(42)", sqlalchemy.types.String(42), dict(max_length=42)),
("BYTES", sqlalchemy.types.BINARY(), ()),
("BYTES(42)", sqlalchemy.types.BINARY(42), dict(max_length=42)),
("INT64", sqlalchemy.types.Integer, ()),
("FLOAT64", sqlalchemy.types.Float, ()),
("NUMERIC", sqlalchemy.types.NUMERIC(), ()),
("NUMERIC(4)", sqlalchemy.types.NUMERIC(4), dict(precision=4)),
("NUMERIC(4, 2)", sqlalchemy.types.NUMERIC(4, 2), dict(precision=4, scale=2)),
("BIGNUMERIC", sqlalchemy.types.NUMERIC(), ()),
("BIGNUMERIC(42)", sqlalchemy.types.NUMERIC(42), dict(precision=42)),
(
"BIGNUMERIC(42, 2)",
sqlalchemy.types.NUMERIC(42, 2),
dict(precision=42, scale=2),
),
("BOOL", sqlalchemy.types.Boolean, ()),
("TIMESTAMP", sqlalchemy.types.TIMESTAMP, ()),
("DATE", sqlalchemy.types.DATE, ()),
("TIME", sqlalchemy.types.TIME, ()),
("DATETIME", sqlalchemy.types.DATETIME, ()),
("THURSDAY", sqlalchemy.types.NullType, ()),
],
)
def test_get_table_columns(faux_conn, btype, atype):
def test_get_table_columns(faux_conn, btype, atype, extra):
cursor = faux_conn.connection.cursor()
cursor.execute(f"create table foo (x {btype})")

assert faux_conn.dialect.get_columns(faux_conn, "foo") == [
[col] = faux_conn.dialect.get_columns(faux_conn, "foo")
col["type"] = str(col["type"])
assert col == dict(
{
"comment": None,
"default": None,
"max_length": None,
"name": "x",
"nullable": True,
"type": atype,
}
]
"type": str(atype),
"precision": None,
"scale": None,
},
**(extra or {}),
)


def test_get_table_columns_special_cases(faux_conn):
Expand All @@ -206,34 +222,54 @@ def test_get_table_columns_special_cases(faux_conn):
assert isinstance(stype, sqlalchemy.types.ARRAY)
assert isinstance(stype.item_type, sqlalchemy.types.String)
assert actual == [
{"comment": "a fine column", "default": None, "name": "s", "nullable": True},
{
"comment": "a fine column",
"default": None,
"name": "s",
"nullable": True,
"max_length": None,
"precision": None,
"scale": None,
},
{
"comment": None,
"default": None,
"name": "n",
"nullable": False,
"type": sqlalchemy.types.Integer,
"max_length": None,
"precision": None,
"scale": None,
},
{
"comment": None,
"default": None,
"name": "r",
"nullable": True,
"type": sqlalchemy.types.JSON,
"max_length": None,
"precision": None,
"scale": None,
},
{
"comment": None,
"default": None,
"name": "r.i",
"nullable": True,
"type": sqlalchemy.types.Integer,
"max_length": None,
"precision": None,
"scale": None,
},
{
"comment": None,
"default": None,
"name": "r.f",
"nullable": True,
"type": sqlalchemy.types.Float,
"max_length": None,
"precision": None,
"scale": None,
},
]

Expand Down
7 changes: 7 additions & 0 deletions tests/unit/test_helpers.py
Expand Up @@ -192,3 +192,10 @@ def foo_to_bar(self, m):
Replacer("hah").foo_to_bar("some foo and FOO is good")
== "some hah and FOO is good"
)


def test_substitute_re_string_repl_wo_flags(module_under_test):
foo_to_baz = module_under_test.substitute_re_method("foo", "baz")
assert (
foo_to_baz(object(), "some foo and FOO is good") == "some baz and FOO is good"
)