diff --git a/google/cloud/bigquery/schema.py b/google/cloud/bigquery/schema.py index 3878a80a9..0eaf1201b 100644 --- a/google/cloud/bigquery/schema.py +++ b/google/cloud/bigquery/schema.py @@ -62,14 +62,26 @@ class SchemaField(object): fields (Tuple[google.cloud.bigquery.schema.SchemaField]): subfields (requires ``field_type`` of 'RECORD'). + + policy_tags (Optional[PolicyTagList]): The policy tag list for the field. + """ - def __init__(self, name, field_type, mode="NULLABLE", description=None, fields=()): + def __init__( + self, + name, + field_type, + mode="NULLABLE", + description=None, + fields=(), + policy_tags=None, + ): self._name = name self._field_type = field_type self._mode = mode self._description = description self._fields = tuple(fields) + self._policy_tags = policy_tags @classmethod def from_api_repr(cls, api_repr): @@ -87,12 +99,14 @@ def from_api_repr(cls, api_repr): mode = api_repr.get("mode", "NULLABLE") description = api_repr.get("description") fields = api_repr.get("fields", ()) + return cls( field_type=api_repr["type"].upper(), fields=[cls.from_api_repr(f) for f in fields], mode=mode.upper(), description=description, name=api_repr["name"], + policy_tags=PolicyTagList.from_api_repr(api_repr.get("policyTags")), ) @property @@ -136,6 +150,13 @@ def fields(self): """ return self._fields + @property + def policy_tags(self): + """Optional[google.cloud.bigquery.schema.PolicyTagList]: Policy tag list + definition for this field. + """ + return self._policy_tags + def to_api_repr(self): """Return a dictionary representing this schema field. @@ -155,6 +176,10 @@ def to_api_repr(self): if self.field_type.upper() in _STRUCT_TYPES: answer["fields"] = [f.to_api_repr() for f in self.fields] + # If this contains a policy tag definition, include that as well: + if self.policy_tags is not None: + answer["policyTags"] = self.policy_tags.to_api_repr() + # Done; return the serialized dictionary. return answer @@ -172,6 +197,7 @@ def _key(self): self._mode.upper(), self._description, self._fields, + self._policy_tags, ) def to_standard_sql(self): @@ -244,7 +270,10 @@ def _parse_schema_resource(info): mode = r_field.get("mode", "NULLABLE") description = r_field.get("description") sub_fields = _parse_schema_resource(r_field) - schema.append(SchemaField(name, field_type, mode, description, sub_fields)) + policy_tags = PolicyTagList.from_api_repr(r_field.get("policyTags")) + schema.append( + SchemaField(name, field_type, mode, description, sub_fields, policy_tags) + ) return schema @@ -291,3 +320,88 @@ def _to_schema_fields(schema): field if isinstance(field, SchemaField) else SchemaField.from_api_repr(field) for field in schema ] + + +class PolicyTagList(object): + """Define Policy Tags for a column. + + Args: + names ( + Optional[Tuple[str]]): list of policy tags to associate with + the column. Policy tag identifiers are of the form + `projects/*/locations/*/taxonomies/*/policyTags/*`. + """ + + def __init__(self, names=()): + self._properties = {} + self._properties["names"] = tuple(names) + + @property + def names(self): + """Tuple[str]: Policy tags associated with this definition. + """ + return self._properties.get("names", ()) + + def _key(self): + """A tuple key that uniquely describes this PolicyTagList. + + Used to compute this instance's hashcode and evaluate equality. + + Returns: + Tuple: The contents of this :class:`~google.cloud.bigquery.schema.PolicyTagList`. + """ + return tuple(sorted(self._properties.items())) + + def __eq__(self, other): + if not isinstance(other, PolicyTagList): + return NotImplemented + return self._key() == other._key() + + def __ne__(self, other): + return not self == other + + def __hash__(self): + return hash(self._key()) + + def __repr__(self): + return "PolicyTagList{}".format(self._key()) + + @classmethod + def from_api_repr(cls, api_repr): + """Return a :class:`PolicyTagList` object deserialized from a dict. + + This method creates a new ``PolicyTagList`` instance that points to + the ``api_repr`` parameter as its internal properties dict. This means + that when a ``PolicyTagList`` instance is stored as a property of + another object, any changes made at the higher level will also appear + here. + + Args: + api_repr (Mapping[str, str]): + The serialized representation of the PolicyTagList, such as + what is output by :meth:`to_api_repr`. + + Returns: + Optional[google.cloud.bigquery.schema.PolicyTagList]: + The ``PolicyTagList`` object or None. + """ + if api_repr is None: + return None + names = api_repr.get("names", ()) + return cls(names=names) + + def to_api_repr(self): + """Return a dictionary representing this object. + + This method returns the properties dict of the ``PolicyTagList`` + instance rather than making a copy. This means that when a + ``PolicyTagList`` instance is stored as a property of another + object, any changes made at the higher level will also appear here. + + Returns: + dict: + A dictionary representing the PolicyTagList object in + serialized form. + """ + answer = {"names": [name for name in self.names]} + return answer diff --git a/tests/system.py b/tests/system.py index b86684675..49e45c772 100644 --- a/tests/system.py +++ b/tests/system.py @@ -339,6 +339,57 @@ def test_create_table(self): self.assertTrue(_table_exists(table)) self.assertEqual(table.table_id, table_id) + def test_create_table_with_policy(self): + from google.cloud.bigquery.schema import PolicyTagList + + dataset = self.temp_dataset(_make_dataset_id("create_table_with_policy")) + table_id = "test_table" + policy_1 = PolicyTagList( + names=[ + "projects/{}/locations/us/taxonomies/1/policyTags/2".format( + Config.CLIENT.project + ), + ] + ) + policy_2 = PolicyTagList( + names=[ + "projects/{}/locations/us/taxonomies/3/policyTags/4".format( + Config.CLIENT.project + ), + ] + ) + + schema = [ + bigquery.SchemaField("full_name", "STRING", mode="REQUIRED"), + bigquery.SchemaField( + "secret_int", "INTEGER", mode="REQUIRED", policy_tags=policy_1 + ), + ] + table_arg = Table(dataset.table(table_id), schema=schema) + self.assertFalse(_table_exists(table_arg)) + + table = retry_403(Config.CLIENT.create_table)(table_arg) + self.to_delete.insert(0, table) + + self.assertTrue(_table_exists(table)) + self.assertEqual(policy_1, table.schema[1].policy_tags) + + # Amend the schema to replace the policy tags + new_schema = table.schema[:] + old_field = table.schema[1] + new_schema[1] = bigquery.SchemaField( + name=old_field.name, + field_type=old_field.field_type, + mode=old_field.mode, + description=old_field.description, + fields=old_field.fields, + policy_tags=policy_2, + ) + + table.schema = new_schema + table2 = Config.CLIENT.update_table(table, ["schema"]) + self.assertEqual(policy_2, table2.schema[1].policy_tags) + def test_create_table_w_time_partitioning_w_clustering_fields(self): from google.cloud.bigquery.table import TimePartitioning from google.cloud.bigquery.table import TimePartitioningType diff --git a/tests/unit/test_schema.py b/tests/unit/test_schema.py index e1bdd7b2f..9f7ee7bb3 100644 --- a/tests/unit/test_schema.py +++ b/tests/unit/test_schema.py @@ -63,11 +63,38 @@ def test_constructor_subfields(self): self.assertIs(field._fields[0], sub_field1) self.assertIs(field._fields[1], sub_field2) + def test_constructor_with_policy_tags(self): + from google.cloud.bigquery.schema import PolicyTagList + + policy = PolicyTagList(names=("foo", "bar")) + field = self._make_one( + "test", "STRING", mode="REQUIRED", description="Testing", policy_tags=policy + ) + self.assertEqual(field._name, "test") + self.assertEqual(field._field_type, "STRING") + self.assertEqual(field._mode, "REQUIRED") + self.assertEqual(field._description, "Testing") + self.assertEqual(field._fields, ()) + self.assertEqual(field._policy_tags, policy) + def test_to_api_repr(self): - field = self._make_one("foo", "INTEGER", "NULLABLE") + from google.cloud.bigquery.schema import PolicyTagList + + policy = PolicyTagList(names=("foo", "bar")) + self.assertEqual( + policy.to_api_repr(), {"names": ["foo", "bar"]}, + ) + + field = self._make_one("foo", "INTEGER", "NULLABLE", policy_tags=policy) self.assertEqual( field.to_api_repr(), - {"mode": "NULLABLE", "name": "foo", "type": "INTEGER", "description": None}, + { + "mode": "NULLABLE", + "name": "foo", + "type": "INTEGER", + "description": None, + "policyTags": {"names": ["foo", "bar"]}, + }, ) def test_to_api_repr_with_subfield(self): @@ -111,6 +138,23 @@ def test_from_api_repr(self): self.assertEqual(field.fields[0].field_type, "INTEGER") self.assertEqual(field.fields[0].mode, "NULLABLE") + def test_from_api_repr_policy(self): + field = self._get_target_class().from_api_repr( + { + "fields": [{"mode": "nullable", "name": "bar", "type": "integer"}], + "name": "foo", + "type": "record", + "policyTags": {"names": ["one", "two"]}, + } + ) + self.assertEqual(field.name, "foo") + self.assertEqual(field.field_type, "RECORD") + self.assertEqual(field.policy_tags.names, ("one", "two")) + self.assertEqual(len(field.fields), 1) + self.assertEqual(field.fields[0].name, "bar") + self.assertEqual(field.fields[0].field_type, "INTEGER") + self.assertEqual(field.fields[0].mode, "NULLABLE") + def test_from_api_repr_defaults(self): field = self._get_target_class().from_api_repr( {"name": "foo", "type": "record"} @@ -408,7 +452,7 @@ def test___hash__not_equals(self): def test___repr__(self): field1 = self._make_one("field1", "STRING") - expected = "SchemaField('field1', 'STRING', 'NULLABLE', None, ())" + expected = "SchemaField('field1', 'STRING', 'NULLABLE', None, (), None)" self.assertEqual(repr(field1), expected) @@ -632,3 +676,67 @@ def test_valid_mapping_representation(self): result = self._call_fut(schema) self.assertEqual(result, expected_schema) + + +class TestPolicyTags(unittest.TestCase): + @staticmethod + def _get_target_class(): + from google.cloud.bigquery.schema import PolicyTagList + + return PolicyTagList + + def _make_one(self, *args, **kw): + return self._get_target_class()(*args, **kw) + + def test_constructor(self): + empty_policy_tags = self._make_one() + self.assertIsNotNone(empty_policy_tags.names) + self.assertEqual(len(empty_policy_tags.names), 0) + policy_tags = self._make_one(["foo", "bar"]) + self.assertEqual(policy_tags.names, ("foo", "bar")) + + def test_from_api_repr(self): + klass = self._get_target_class() + api_repr = {"names": ["foo"]} + policy_tags = klass.from_api_repr(api_repr) + self.assertEqual(policy_tags.to_api_repr(), api_repr) + + # Ensure the None case correctly returns None, rather + # than an empty instance. + policy_tags2 = klass.from_api_repr(None) + self.assertIsNone(policy_tags2) + + def test_to_api_repr(self): + taglist = self._make_one(names=["foo", "bar"]) + self.assertEqual( + taglist.to_api_repr(), {"names": ["foo", "bar"]}, + ) + taglist2 = self._make_one(names=("foo", "bar")) + self.assertEqual( + taglist2.to_api_repr(), {"names": ["foo", "bar"]}, + ) + + def test___eq___wrong_type(self): + policy = self._make_one(names=["foo"]) + other = object() + self.assertNotEqual(policy, other) + self.assertEqual(policy, mock.ANY) + + def test___eq___names_mismatch(self): + policy = self._make_one(names=["foo", "bar"]) + other = self._make_one(names=["bar", "baz"]) + self.assertNotEqual(policy, other) + + def test___hash__set_equality(self): + policy1 = self._make_one(["foo", "bar"]) + policy2 = self._make_one(["bar", "baz"]) + set_one = {policy1, policy2} + set_two = {policy1, policy2} + self.assertEqual(set_one, set_two) + + def test___hash__not_equals(self): + policy1 = self._make_one(["foo", "bar"]) + policy2 = self._make_one(["bar", "baz"]) + set_one = {policy1} + set_two = {policy2} + self.assertNotEqual(set_one, set_two)