From 45421e73bfcddb244822e6a5cd43be6bd1ca2256 Mon Sep 17 00:00:00 2001 From: Jim Fulton Date: Fri, 21 May 2021 10:50:55 -0600 Subject: [PATCH] feat: Support parameterized NUMERIC, BIGNUMERIC, STRING, and BYTES types (#673) * parse parameterized schema info * Fixed SchemaField repr/key * Fix code duplication between _parse_schema_resource and from_api_repr Move new parameterized-type code from _parse_schema_resource to from_api_repr and implement _parse_schema_resource in terms of from_api_repr. * empty schemas are lists now, just like non-empty schemas. * changed new parameterized-type tests to use from_api_repr Because that's more direct and it uncovered duplicate code. * parameterized the from_api_repr tests and added to_api_repr tests * Test BYTES and _key (repr) too. * Added round-trip parameterized types schema tests * handle BYTES in _key/repr * blacken * Move _get_int close to use * Updated documentation. * Oops, forgot BIGNUMERIC * Improve argument doc and better argument name to __get_int * doom tables before creating them. * Use max_length in the Python for the REST api maxLength --- google/cloud/bigquery/schema.py | 71 +++++++++++++----- tests/system/test_client.py | 29 ++++++++ tests/unit/test_query.py | 4 +- tests/unit/test_schema.py | 123 ++++++++++++++++++++++++++++++++ 4 files changed, 209 insertions(+), 18 deletions(-) diff --git a/google/cloud/bigquery/schema.py b/google/cloud/bigquery/schema.py index cb221d6de..919d78b23 100644 --- a/google/cloud/bigquery/schema.py +++ b/google/cloud/bigquery/schema.py @@ -67,6 +67,15 @@ class SchemaField(object): policy_tags (Optional[PolicyTagList]): The policy tag list for the field. + precision (Optional[int]): + Precision (number of digits) of fields with NUMERIC or BIGNUMERIC type. + + scale (Optional[int]): + Scale (digits after decimal) of fields with NUMERIC or BIGNUMERIC type. + + max_length (Optional[int]): + Maximum length of fields with STRING or BYTES type. 
+ """ def __init__( @@ -77,6 +86,9 @@ def __init__( description=_DEFAULT_VALUE, fields=(), policy_tags=None, + precision=_DEFAULT_VALUE, + scale=_DEFAULT_VALUE, + max_length=_DEFAULT_VALUE, ): self._properties = { "name": name, @@ -86,9 +98,22 @@ def __init__( self._properties["mode"] = mode.upper() if description is not _DEFAULT_VALUE: self._properties["description"] = description + if precision is not _DEFAULT_VALUE: + self._properties["precision"] = precision + if scale is not _DEFAULT_VALUE: + self._properties["scale"] = scale + if max_length is not _DEFAULT_VALUE: + self._properties["maxLength"] = max_length self._fields = tuple(fields) self._policy_tags = policy_tags + @staticmethod + def __get_int(api_repr, name): + v = api_repr.get(name, _DEFAULT_VALUE) + if v is not _DEFAULT_VALUE: + v = int(v) + return v + @classmethod def from_api_repr(cls, api_repr: dict) -> "SchemaField": """Return a ``SchemaField`` object deserialized from a dictionary. @@ -113,6 +138,9 @@ def from_api_repr(cls, api_repr: dict) -> "SchemaField": description=description, name=api_repr["name"], policy_tags=PolicyTagList.from_api_repr(api_repr.get("policyTags")), + precision=cls.__get_int(api_repr, "precision"), + scale=cls.__get_int(api_repr, "scale"), + max_length=cls.__get_int(api_repr, "maxLength"), ) @property @@ -148,6 +176,21 @@ def description(self): """Optional[str]: description for the field.""" return self._properties.get("description") + @property + def precision(self): + """Optional[int]: Precision (number of digits) for the NUMERIC field.""" + return self._properties.get("precision") + + @property + def scale(self): + """Optional[int]: Scale (digits after decimal) for the NUMERIC field.""" + return self._properties.get("scale") + + @property + def max_length(self): + """Optional[int]: Maximum length for the STRING or BYTES field.""" + return self._properties.get("maxLength") + @property def fields(self): """Optional[tuple]: Subfields contained in this field. 
@@ -191,9 +234,19 @@ def _key(self): Returns: Tuple: The contents of this :class:`~google.cloud.bigquery.schema.SchemaField`. """ + field_type = self.field_type.upper() + if field_type == "STRING" or field_type == "BYTES": + if self.max_length is not None: + field_type = f"{field_type}({self.max_length})" + elif field_type.endswith("NUMERIC"): + if self.precision is not None: + if self.scale is not None: + field_type = f"{field_type}({self.precision}, {self.scale})" + else: + field_type = f"{field_type}({self.precision})" return ( self.name, - self.field_type.upper(), + field_type, # Mode is always str, if not given it defaults to a str value self.mode.upper(), # pytype: disable=attribute-error self.description, @@ -269,21 +322,7 @@ def _parse_schema_resource(info): Optional[Sequence[google.cloud.bigquery.schema.SchemaField`]: A list of parsed fields, or ``None`` if no "fields" key found. """ - if "fields" not in info: - return () - - schema = [] - for r_field in info["fields"]: - name = r_field["name"] - field_type = r_field["type"] - mode = r_field.get("mode", "NULLABLE") - description = r_field.get("description") - sub_fields = _parse_schema_resource(r_field) - policy_tags = PolicyTagList.from_api_repr(r_field.get("policyTags")) - schema.append( - SchemaField(name, field_type, mode, description, sub_fields, policy_tags) - ) - return schema + return [SchemaField.from_api_repr(f) for f in info.get("fields", ())] def _build_schema_resource(fields): diff --git a/tests/system/test_client.py b/tests/system/test_client.py index 7c8ef50fa..b4b0c053d 100644 --- a/tests/system/test_client.py +++ b/tests/system/test_client.py @@ -2173,6 +2173,35 @@ def test_list_rows_page_size(self): page = next(pages) self.assertEqual(page.num_items, num_last_page) + def test_parameterized_types_round_trip(self): + client = Config.CLIENT + table_id = f"{Config.DATASET}.test_parameterized_types_round_trip" + fields = ( + ("n", "NUMERIC"), + ("n9", "NUMERIC(9)"), + ("n92", "NUMERIC(9, 2)"), 
+ ("bn", "BIGNUMERIC"), + ("bn9", "BIGNUMERIC(38)"), + ("bn92", "BIGNUMERIC(38, 22)"), + ("s", "STRING"), + ("s9", "STRING(9)"), + ("b", "BYTES"), + ("b9", "BYTES(9)"), + ) + self.to_delete.insert(0, Table(f"{client.project}.{table_id}")) + client.query( + "create table {} ({})".format( + table_id, ", ".join(" ".join(f) for f in fields) + ) + ).result() + table = client.get_table(table_id) + table_id2 = table_id + "2" + self.to_delete.insert(0, Table(f"{client.project}.{table_id2}")) + client.create_table(Table(f"{client.project}.{table_id2}", table.schema)) + table2 = client.get_table(table_id2) + + self.assertEqual(tuple(s._key()[:2] for s in table2.schema), fields) + def temp_dataset(self, dataset_id, location=None): project = Config.CLIENT.project dataset_ref = bigquery.DatasetReference(project, dataset_id) diff --git a/tests/unit/test_query.py b/tests/unit/test_query.py index 90fc30b20..9483fe8dd 100644 --- a/tests/unit/test_query.py +++ b/tests/unit/test_query.py @@ -1302,7 +1302,7 @@ def _verifySchema(self, query, resource): self.assertEqual(found.description, expected.get("description")) self.assertEqual(found.fields, expected.get("fields", ())) else: - self.assertEqual(query.schema, ()) + self.assertEqual(query.schema, []) def test_ctor_defaults(self): query = self._make_one(self._make_resource()) @@ -1312,7 +1312,7 @@ def test_ctor_defaults(self): self.assertIsNone(query.page_token) self.assertEqual(query.project, self.PROJECT) self.assertEqual(query.rows, []) - self.assertEqual(query.schema, ()) + self.assertEqual(query.schema, []) self.assertIsNone(query.total_rows) self.assertIsNone(query.total_bytes_processed) diff --git a/tests/unit/test_schema.py b/tests/unit/test_schema.py index 87baaf379..29c3bace5 100644 --- a/tests/unit/test_schema.py +++ b/tests/unit/test_schema.py @@ -15,6 +15,7 @@ import unittest import mock +import pytest class TestSchemaField(unittest.TestCase): @@ -715,3 +716,125 @@ def test___hash__not_equals(self): set_one = {policy1} 
set_two = {policy2} self.assertNotEqual(set_one, set_two) + + +@pytest.mark.parametrize( + "api,expect,key2", + [ + ( + dict(name="n", type="NUMERIC"), + ("n", "NUMERIC", None, None, None), + ("n", "NUMERIC"), + ), + ( + dict(name="n", type="NUMERIC", precision=9), + ("n", "NUMERIC", 9, None, None), + ("n", "NUMERIC(9)"), + ), + ( + dict(name="n", type="NUMERIC", precision=9, scale=2), + ("n", "NUMERIC", 9, 2, None), + ("n", "NUMERIC(9, 2)"), + ), + ( + dict(name="n", type="BIGNUMERIC"), + ("n", "BIGNUMERIC", None, None, None), + ("n", "BIGNUMERIC"), + ), + ( + dict(name="n", type="BIGNUMERIC", precision=40), + ("n", "BIGNUMERIC", 40, None, None), + ("n", "BIGNUMERIC(40)"), + ), + ( + dict(name="n", type="BIGNUMERIC", precision=40, scale=2), + ("n", "BIGNUMERIC", 40, 2, None), + ("n", "BIGNUMERIC(40, 2)"), + ), + ( + dict(name="n", type="STRING"), + ("n", "STRING", None, None, None), + ("n", "STRING"), + ), + ( + dict(name="n", type="STRING", maxLength=9), + ("n", "STRING", None, None, 9), + ("n", "STRING(9)"), + ), + ( + dict(name="n", type="BYTES"), + ("n", "BYTES", None, None, None), + ("n", "BYTES"), + ), + ( + dict(name="n", type="BYTES", maxLength=9), + ("n", "BYTES", None, None, 9), + ("n", "BYTES(9)"), + ), + ], +) +def test_from_api_repr_parameterized(api, expect, key2): + from google.cloud.bigquery.schema import SchemaField + + field = SchemaField.from_api_repr(api) + + assert ( + field.name, + field.field_type, + field.precision, + field.scale, + field.max_length, + ) == expect + + assert field._key()[:2] == key2 + + +@pytest.mark.parametrize( + "field,api", + [ + ( + dict(name="n", field_type="NUMERIC"), + dict(name="n", type="NUMERIC", mode="NULLABLE"), + ), + ( + dict(name="n", field_type="NUMERIC", precision=9), + dict(name="n", type="NUMERIC", mode="NULLABLE", precision=9), + ), + ( + dict(name="n", field_type="NUMERIC", precision=9, scale=2), + dict(name="n", type="NUMERIC", mode="NULLABLE", precision=9, scale=2), + ), + ( + dict(name="n", 
field_type="BIGNUMERIC"), + dict(name="n", type="BIGNUMERIC", mode="NULLABLE"), + ), + ( + dict(name="n", field_type="BIGNUMERIC", precision=40), + dict(name="n", type="BIGNUMERIC", mode="NULLABLE", precision=40), + ), + ( + dict(name="n", field_type="BIGNUMERIC", precision=40, scale=2), + dict(name="n", type="BIGNUMERIC", mode="NULLABLE", precision=40, scale=2), + ), + ( + dict(name="n", field_type="STRING"), + dict(name="n", type="STRING", mode="NULLABLE"), + ), + ( + dict(name="n", field_type="STRING", max_length=9), + dict(name="n", type="STRING", mode="NULLABLE", maxLength=9), + ), + ( + dict(name="n", field_type="BYTES"), + dict(name="n", type="BYTES", mode="NULLABLE"), + ), + ( + dict(name="n", field_type="BYTES", max_length=9), + dict(name="n", type="BYTES", mode="NULLABLE", maxLength=9), + ), + ], +) +def test_to_api_repr_parameterized(field, api): + from google.cloud.bigquery.schema import SchemaField + + assert SchemaField(**field).to_api_repr() == api