diff --git a/google/cloud/bigquery/schema.py b/google/cloud/bigquery/schema.py index 919d78b23..157db7ce6 100644 --- a/google/cloud/bigquery/schema.py +++ b/google/cloud/bigquery/schema.py @@ -15,6 +15,7 @@ """Schemas for BigQuery tables / queries.""" import collections +from typing import Optional from google.cloud.bigquery_v2 import types @@ -105,7 +106,26 @@ def __init__( if max_length is not _DEFAULT_VALUE: self._properties["maxLength"] = max_length self._fields = tuple(fields) - self._policy_tags = policy_tags + + self._policy_tags = self._determine_policy_tags(field_type, policy_tags) + + @staticmethod + def _determine_policy_tags( + field_type: str, given_policy_tags: Optional["PolicyTagList"] + ) -> Optional["PolicyTagList"]: + """Return the given policy tags, or their suitable representation if `None`. + + Args: + field_type: The type of the schema field. + given_policy_tags: The policy tags to maybe ajdust. + """ + if given_policy_tags is not None: + return given_policy_tags + + if field_type is not None and field_type.upper() in _STRUCT_TYPES: + return None + + return PolicyTagList() @staticmethod def __get_int(api_repr, name): @@ -126,18 +146,24 @@ def from_api_repr(cls, api_repr: dict) -> "SchemaField": Returns: google.cloud.biquery.schema.SchemaField: The ``SchemaField`` object. """ + field_type = api_repr["type"].upper() + # Handle optional properties with default values mode = api_repr.get("mode", "NULLABLE") description = api_repr.get("description", _DEFAULT_VALUE) fields = api_repr.get("fields", ()) + policy_tags = cls._determine_policy_tags( + field_type, PolicyTagList.from_api_repr(api_repr.get("policyTags")) + ) + return cls( - field_type=api_repr["type"].upper(), + field_type=field_type, fields=[cls.from_api_repr(f) for f in fields], mode=mode.upper(), description=description, name=api_repr["name"], - policy_tags=PolicyTagList.from_api_repr(api_repr.get("policyTags")), + policy_tags=policy_tags, precision=cls.__get_int(api_repr, "precision"), scale=cls.__get_int(api_repr, "scale"), max_length=cls.__get_int(api_repr, "maxLength"), @@ -218,9 +244,9 @@ def to_api_repr(self) -> dict: # add this to the serialized representation. if self.field_type.upper() in _STRUCT_TYPES: answer["fields"] = [f.to_api_repr() for f in self.fields] - - # If this contains a policy tag definition, include that as well: - if self.policy_tags is not None: + else: + # Explicitly include policy tag definition (we must not do it for RECORD + # fields, because those are not leaf fields). answer["policyTags"] = self.policy_tags.to_api_repr() # Done; return the serialized dictionary. @@ -244,6 +270,11 @@ def _key(self): field_type = f"{field_type}({self.precision}, {self.scale})" else: field_type = f"{field_type}({self.precision})" + + policy_tags = ( + () if self._policy_tags is None else tuple(sorted(self._policy_tags.names)) + ) + return ( self.name, field_type, @@ -251,7 +282,7 @@ def _key(self): self.mode.upper(), # pytype: disable=attribute-error self.description, self._fields, - self._policy_tags, + policy_tags, ) def to_standard_sql(self) -> types.StandardSqlField: diff --git a/tests/system/test_client.py b/tests/system/test_client.py index c4caadbe9..ce3021399 100644 --- a/tests/system/test_client.py +++ b/tests/system/test_client.py @@ -653,6 +653,56 @@ def test_update_table_schema(self): self.assertEqual(found.field_type, expected.field_type) self.assertEqual(found.mode, expected.mode) + def test_unset_table_schema_attributes(self): + from google.cloud.bigquery.schema import PolicyTagList + + dataset = self.temp_dataset(_make_dataset_id("unset_policy_tags")) + table_id = "test_table" + policy_tags = PolicyTagList( + names=[ + "projects/{}/locations/us/taxonomies/1/policyTags/2".format( + Config.CLIENT.project + ), + ] + ) + + schema = [ + bigquery.SchemaField("full_name", "STRING", mode="REQUIRED"), + bigquery.SchemaField( + "secret_int", + "INTEGER", + mode="REQUIRED", + description="This field is numeric", + policy_tags=policy_tags, + ), + ] + table_arg = Table(dataset.table(table_id), schema=schema) + self.assertFalse(_table_exists(table_arg)) + + table = helpers.retry_403(Config.CLIENT.create_table)(table_arg) + self.to_delete.insert(0, table) + + self.assertTrue(_table_exists(table)) + self.assertEqual(policy_tags, table.schema[1].policy_tags) + + # Amend the schema to replace the policy tags + new_schema = table.schema[:] + old_field = table.schema[1] + new_schema[1] = bigquery.SchemaField( + name=old_field.name, + field_type=old_field.field_type, + mode=old_field.mode, + description=None, + fields=old_field.fields, + policy_tags=None, + ) + + table.schema = new_schema + updated_table = Config.CLIENT.update_table(table, ["schema"]) + + self.assertFalse(updated_table.schema[1].description) # Empty string or None. + self.assertEqual(updated_table.schema[1].policy_tags.names, ()) + def test_update_table_clustering_configuration(self): dataset = self.temp_dataset(_make_dataset_id("update_table")) diff --git a/tests/unit/job/test_load_config.py b/tests/unit/job/test_load_config.py index b0729e428..eafe7e046 100644 --- a/tests/unit/job/test_load_config.py +++ b/tests/unit/job/test_load_config.py @@ -434,11 +434,13 @@ def test_schema_setter_fields(self): "name": "full_name", "type": "STRING", "mode": "REQUIRED", + "policyTags": {"names": []}, } age_repr = { "name": "age", "type": "INTEGER", "mode": "REQUIRED", + "policyTags": {"names": []}, } self.assertEqual( config._properties["load"]["schema"], {"fields": [full_name_repr, age_repr]} @@ -451,11 +453,13 @@ def test_schema_setter_valid_mappings_list(self): "name": "full_name", "type": "STRING", "mode": "REQUIRED", + "policyTags": {"names": []}, } age_repr = { "name": "age", "type": "INTEGER", "mode": "REQUIRED", + "policyTags": {"names": []}, } schema = [full_name_repr, age_repr] config.schema = schema diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 7a28ef248..f6811e207 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -1019,8 +1019,18 @@ def test_create_table_w_schema_and_query(self): { "schema": { "fields": [ - {"name": "full_name", "type": "STRING", "mode": "REQUIRED"}, - {"name": "age", "type": "INTEGER", "mode": "REQUIRED"}, + { + "name": "full_name", + "type": "STRING", + "mode": "REQUIRED", + "policyTags": {"names": []}, + }, + { + "name": "age", + "type": "INTEGER", + "mode": "REQUIRED", + "policyTags": {"names": []}, + }, ] }, "view": {"query": query}, @@ -1054,8 +1064,18 @@ def test_create_table_w_schema_and_query(self): }, "schema": { "fields": [ - {"name": "full_name", "type": "STRING", "mode": "REQUIRED"}, - {"name": "age", "type": "INTEGER", "mode": "REQUIRED"}, + { + "name": "full_name", + "type": "STRING", + "mode": "REQUIRED", + "policyTags": {"names": []}, + }, + { + "name": "age", + "type": "INTEGER", + "mode": "REQUIRED", + "policyTags": {"names": []}, + }, ] }, "view": {"query": query, "useLegacySql": False}, @@ -2000,12 +2020,14 @@ def test_update_table(self): "type": "STRING", "mode": "REQUIRED", "description": None, + "policyTags": {"names": []}, }, { "name": "age", "type": "INTEGER", "mode": "REQUIRED", "description": "New field description", + "policyTags": {"names": []}, }, ] }, @@ -2047,12 +2069,14 @@ def test_update_table(self): "type": "STRING", "mode": "REQUIRED", "description": None, + "policyTags": {"names": []}, }, { "name": "age", "type": "INTEGER", "mode": "REQUIRED", "description": "New field description", + "policyTags": {"names": []}, }, ] }, @@ -2173,14 +2197,21 @@ def test_update_table_w_query(self): "type": "STRING", "mode": "REQUIRED", "description": None, + "policyTags": {"names": []}, }, { "name": "age", "type": "INTEGER", "mode": "REQUIRED", "description": "this is a column", + "policyTags": {"names": []}, + }, + { + "name": "country", + "type": "STRING", + "mode": "NULLABLE", + "policyTags": {"names": []}, }, - {"name": "country", "type": "STRING", "mode": "NULLABLE"}, ] } schema = [ @@ -6516,10 +6547,10 @@ def test_load_table_from_dataframe(self): assert field["type"] == table_field.field_type assert field["mode"] == table_field.mode assert len(field.get("fields", [])) == len(table_field.fields) + assert field["policyTags"]["names"] == [] # Omit unnecessary fields when they come from getting the table # (not passed in via job_config) assert "description" not in field - assert "policyTags" not in field @unittest.skipIf(pandas is None, "Requires `pandas`") @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") @@ -7718,18 +7749,21 @@ def test_schema_to_json_with_file_path(self): "description": "quarter", "mode": "REQUIRED", "name": "qtr", + "policyTags": {"names": []}, "type": "STRING", }, { "description": "sales representative", "mode": "NULLABLE", "name": "rep", + "policyTags": {"names": []}, "type": "STRING", }, { "description": "total sales", "mode": "NULLABLE", "name": "sales", + "policyTags": {"names": []}, "type": "FLOAT", }, ] @@ -7762,18 +7796,21 @@ def test_schema_to_json_with_file_object(self): "description": "quarter", "mode": "REQUIRED", "name": "qtr", + "policyTags": {"names": []}, "type": "STRING", }, { "description": "sales representative", "mode": "NULLABLE", "name": "rep", + "policyTags": {"names": []}, "type": "STRING", }, { "description": "total sales", "mode": "NULLABLE", "name": "sales", + "policyTags": {"names": []}, "type": "FLOAT", }, ] diff --git a/tests/unit/test_external_config.py b/tests/unit/test_external_config.py index 7178367ea..393df931e 100644 --- a/tests/unit/test_external_config.py +++ b/tests/unit/test_external_config.py @@ -78,7 +78,14 @@ def test_to_api_repr_base(self): ec.schema = [schema.SchemaField("full_name", "STRING", mode="REQUIRED")] exp_schema = { - "fields": [{"name": "full_name", "type": "STRING", "mode": "REQUIRED"}] + "fields": [ + { + "name": "full_name", + "type": "STRING", + "mode": "REQUIRED", + "policyTags": {"names": []}, + } + ] } got_resource = ec.to_api_repr() exp_resource = { diff --git a/tests/unit/test_schema.py b/tests/unit/test_schema.py index 29c3bace5..d0b5ca54c 100644 --- a/tests/unit/test_schema.py +++ b/tests/unit/test_schema.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from google.cloud.bigquery.schema import PolicyTagList import unittest import mock @@ -41,6 +42,7 @@ def test_constructor_defaults(self): self.assertEqual(field.mode, "NULLABLE") self.assertIsNone(field.description) self.assertEqual(field.fields, ()) + self.assertEqual(field.policy_tags, PolicyTagList()) def test_constructor_explicit(self): field = self._make_one("test", "STRING", mode="REQUIRED", description="Testing") @@ -104,7 +106,14 @@ def test_to_api_repr_with_subfield(self): self.assertEqual( field.to_api_repr(), { - "fields": [{"mode": "NULLABLE", "name": "bar", "type": "INTEGER"}], + "fields": [ + { + "mode": "NULLABLE", + "name": "bar", + "type": "INTEGER", + "policyTags": {"names": []}, + } + ], "mode": "REQUIRED", "name": "foo", "type": record_type, @@ -404,6 +413,23 @@ def test___eq___hit_w_fields(self): other = self._make_one("test", "RECORD", fields=[sub1, sub2]) self.assertEqual(field, other) + def test___eq___hit_w_policy_tags(self): + field = self._make_one( + "test", + "STRING", + mode="REQUIRED", + description="Testing", + policy_tags=PolicyTagList(names=["foo", "bar"]), + ) + other = self._make_one( + "test", + "STRING", + mode="REQUIRED", + description="Testing", + policy_tags=PolicyTagList(names=["bar", "foo"]), + ) + self.assertEqual(field, other) # Policy tags order does not matter. + def test___ne___wrong_type(self): field = self._make_one("toast", "INTEGER") other = object() @@ -426,6 +452,23 @@ def test___ne___different_values(self): ) self.assertNotEqual(field1, field2) + def test___ne___different_policy_tags(self): + field = self._make_one( + "test", + "STRING", + mode="REQUIRED", + description="Testing", + policy_tags=PolicyTagList(names=["foo", "bar"]), + ) + other = self._make_one( + "test", + "STRING", + mode="REQUIRED", + description="Testing", + policy_tags=PolicyTagList(names=["foo", "baz"]), + ) + self.assertNotEqual(field, other) + def test___hash__set_equality(self): sub1 = self._make_one("sub1", "STRING") sub2 = self._make_one("sub2", "STRING") @@ -446,7 +489,7 @@ def test___hash__not_equals(self): def test___repr__(self): field1 = self._make_one("field1", "STRING") - expected = "SchemaField('field1', 'STRING', 'NULLABLE', None, (), None)" + expected = "SchemaField('field1', 'STRING', 'NULLABLE', None, (), ())" self.assertEqual(repr(field1), expected) @@ -524,10 +567,22 @@ def test_defaults(self): resource = self._call_fut([full_name, age]) self.assertEqual(len(resource), 2) self.assertEqual( - resource[0], {"name": "full_name", "type": "STRING", "mode": "REQUIRED"}, + resource[0], + { + "name": "full_name", + "type": "STRING", + "mode": "REQUIRED", + "policyTags": {"names": []}, + }, ) self.assertEqual( - resource[1], {"name": "age", "type": "INTEGER", "mode": "REQUIRED"} + resource[1], + { + "name": "age", + "type": "INTEGER", + "mode": "REQUIRED", + "policyTags": {"names": []}, + }, ) def test_w_description(self): @@ -553,11 +608,18 @@ def test_w_description(self): "type": "STRING", "mode": "REQUIRED", "description": DESCRIPTION, + "policyTags": {"names": []}, }, ) self.assertEqual( resource[1], - {"name": "age", "type": "INTEGER", "mode": "REQUIRED", "description": None}, + { + "name": "age", + "type": "INTEGER", + "mode": "REQUIRED", + "description": None, + "policyTags": {"names": []}, + }, ) def test_w_subfields(self): @@ -572,7 +634,13 @@ def test_w_subfields(self): resource = self._call_fut([full_name, phone]) self.assertEqual(len(resource), 2) self.assertEqual( - resource[0], {"name": "full_name", "type": "STRING", "mode": "REQUIRED"}, + resource[0], + { + "name": "full_name", + "type": "STRING", + "mode": "REQUIRED", + "policyTags": {"names": []}, + }, ) self.assertEqual( resource[1], @@ -581,8 +649,18 @@ def test_w_subfields(self): "type": "RECORD", "mode": "REPEATED", "fields": [ - {"name": "type", "type": "STRING", "mode": "REQUIRED"}, - {"name": "number", "type": "STRING", "mode": "REQUIRED"}, + { + "name": "type", + "type": "STRING", + "mode": "REQUIRED", + "policyTags": {"names": []}, + }, + { + "name": "number", + "type": "STRING", + "mode": "REQUIRED", + "policyTags": {"names": []}, + }, ], }, ) @@ -794,43 +872,83 @@ def test_from_api_repr_parameterized(api, expect, key2): [ ( dict(name="n", field_type="NUMERIC"), - dict(name="n", type="NUMERIC", mode="NULLABLE"), + dict(name="n", type="NUMERIC", mode="NULLABLE", policyTags={"names": []}), ), ( dict(name="n", field_type="NUMERIC", precision=9), - dict(name="n", type="NUMERIC", mode="NULLABLE", precision=9), + dict( + name="n", + type="NUMERIC", + mode="NULLABLE", + precision=9, + policyTags={"names": []}, + ), ), ( dict(name="n", field_type="NUMERIC", precision=9, scale=2), - dict(name="n", type="NUMERIC", mode="NULLABLE", precision=9, scale=2), + dict( + name="n", + type="NUMERIC", + mode="NULLABLE", + precision=9, + scale=2, + policyTags={"names": []}, + ), ), ( dict(name="n", field_type="BIGNUMERIC"), - dict(name="n", type="BIGNUMERIC", mode="NULLABLE"), + dict( + name="n", type="BIGNUMERIC", mode="NULLABLE", policyTags={"names": []} + ), ), ( dict(name="n", field_type="BIGNUMERIC", precision=40), - dict(name="n", type="BIGNUMERIC", mode="NULLABLE", precision=40), + dict( + name="n", + type="BIGNUMERIC", + mode="NULLABLE", + precision=40, + policyTags={"names": []}, + ), ), ( dict(name="n", field_type="BIGNUMERIC", precision=40, scale=2), - dict(name="n", type="BIGNUMERIC", mode="NULLABLE", precision=40, scale=2), + dict( + name="n", + type="BIGNUMERIC", + mode="NULLABLE", + precision=40, + scale=2, + policyTags={"names": []}, + ), ), ( dict(name="n", field_type="STRING"), - dict(name="n", type="STRING", mode="NULLABLE"), + dict(name="n", type="STRING", mode="NULLABLE", policyTags={"names": []}), ), ( dict(name="n", field_type="STRING", max_length=9), - dict(name="n", type="STRING", mode="NULLABLE", maxLength=9), + dict( + name="n", + type="STRING", + mode="NULLABLE", + maxLength=9, + policyTags={"names": []}, + ), ), ( dict(name="n", field_type="BYTES"), - dict(name="n", type="BYTES", mode="NULLABLE"), + dict(name="n", type="BYTES", mode="NULLABLE", policyTags={"names": []}), ), ( dict(name="n", field_type="BYTES", max_length=9), - dict(name="n", type="BYTES", mode="NULLABLE", maxLength=9), + dict( + name="n", + type="BYTES", + mode="NULLABLE", + maxLength=9, + policyTags={"names": []}, + ), ), ], )