Skip to content

Commit

Permalink
feat: add support for policy tags (#77)
Browse files Browse the repository at this point in the history
* feat: add support for policy tags in schema

* blacken

* add more unit coverage

* more test cleanup

* more tests

* formatting

* more testing of names setter

* address reviewer comments

* docstrings migrate from unions -> optional

* stashing changes

* revision to list-based representation, update tests

* changes to equality and testing, towards satisfying coverage

* cleanup

* return copy

* address api repr feedback

* make PolicyTagList fully immutable

* update docstring

* simplify to_api_repr

* remove stale doc comments

Co-authored-by: Peter Lamut <plamut@users.noreply.github.com>
  • Loading branch information
shollyman and plamut committed May 18, 2020
1 parent 23a173b commit 38a5c01
Show file tree
Hide file tree
Showing 3 changed files with 278 additions and 5 deletions.
118 changes: 116 additions & 2 deletions google/cloud/bigquery/schema.py
Expand Up @@ -62,14 +62,26 @@ class SchemaField(object):
fields (Tuple[google.cloud.bigquery.schema.SchemaField]):
subfields (requires ``field_type`` of 'RECORD').
policy_tags (Optional[PolicyTagList]): The policy tag list for the field.
"""

def __init__(self, name, field_type, mode="NULLABLE", description=None, fields=()):
def __init__(
self,
name,
field_type,
mode="NULLABLE",
description=None,
fields=(),
policy_tags=None,
):
self._name = name
self._field_type = field_type
self._mode = mode
self._description = description
self._fields = tuple(fields)
self._policy_tags = policy_tags

@classmethod
def from_api_repr(cls, api_repr):
Expand All @@ -87,12 +99,14 @@ def from_api_repr(cls, api_repr):
mode = api_repr.get("mode", "NULLABLE")
description = api_repr.get("description")
fields = api_repr.get("fields", ())

return cls(
field_type=api_repr["type"].upper(),
fields=[cls.from_api_repr(f) for f in fields],
mode=mode.upper(),
description=description,
name=api_repr["name"],
policy_tags=PolicyTagList.from_api_repr(api_repr.get("policyTags")),
)

@property
Expand Down Expand Up @@ -136,6 +150,13 @@ def fields(self):
"""
return self._fields

@property
def policy_tags(self):
"""Optional[google.cloud.bigquery.schema.PolicyTagList]: Policy tag list
definition for this field.
"""
return self._policy_tags

def to_api_repr(self):
"""Return a dictionary representing this schema field.
Expand All @@ -155,6 +176,10 @@ def to_api_repr(self):
if self.field_type.upper() in _STRUCT_TYPES:
answer["fields"] = [f.to_api_repr() for f in self.fields]

# If this contains a policy tag definition, include that as well:
if self.policy_tags is not None:
answer["policyTags"] = self.policy_tags.to_api_repr()

# Done; return the serialized dictionary.
return answer

Expand All @@ -172,6 +197,7 @@ def _key(self):
self._mode.upper(),
self._description,
self._fields,
self._policy_tags,
)

def to_standard_sql(self):
Expand Down Expand Up @@ -244,7 +270,10 @@ def _parse_schema_resource(info):
mode = r_field.get("mode", "NULLABLE")
description = r_field.get("description")
sub_fields = _parse_schema_resource(r_field)
schema.append(SchemaField(name, field_type, mode, description, sub_fields))
policy_tags = PolicyTagList.from_api_repr(r_field.get("policyTags"))
schema.append(
SchemaField(name, field_type, mode, description, sub_fields, policy_tags)
)
return schema


Expand Down Expand Up @@ -291,3 +320,88 @@ def _to_schema_fields(schema):
field if isinstance(field, SchemaField) else SchemaField.from_api_repr(field)
for field in schema
]


class PolicyTagList(object):
"""Define Policy Tags for a column.
Args:
names (
Optional[Tuple[str]]): list of policy tags to associate with
the column. Policy tag identifiers are of the form
`projects/*/locations/*/taxonomies/*/policyTags/*`.
"""

def __init__(self, names=()):
self._properties = {}
self._properties["names"] = tuple(names)

@property
def names(self):
"""Tuple[str]: Policy tags associated with this definition.
"""
return self._properties.get("names", ())

def _key(self):
"""A tuple key that uniquely describes this PolicyTagList.
Used to compute this instance's hashcode and evaluate equality.
Returns:
Tuple: The contents of this :class:`~google.cloud.bigquery.schema.PolicyTagList`.
"""
return tuple(sorted(self._properties.items()))

def __eq__(self, other):
if not isinstance(other, PolicyTagList):
return NotImplemented
return self._key() == other._key()

def __ne__(self, other):
return not self == other

def __hash__(self):
return hash(self._key())

def __repr__(self):
return "PolicyTagList{}".format(self._key())

@classmethod
def from_api_repr(cls, api_repr):
"""Return a :class:`PolicyTagList` object deserialized from a dict.
This method creates a new ``PolicyTagList`` instance that points to
the ``api_repr`` parameter as its internal properties dict. This means
that when a ``PolicyTagList`` instance is stored as a property of
another object, any changes made at the higher level will also appear
here.
Args:
api_repr (Mapping[str, str]):
The serialized representation of the PolicyTagList, such as
what is output by :meth:`to_api_repr`.
Returns:
Optional[google.cloud.bigquery.schema.PolicyTagList]:
The ``PolicyTagList`` object or None.
"""
if api_repr is None:
return None
names = api_repr.get("names", ())
return cls(names=names)

def to_api_repr(self):
"""Return a dictionary representing this object.
This method returns the properties dict of the ``PolicyTagList``
instance rather than making a copy. This means that when a
``PolicyTagList`` instance is stored as a property of another
object, any changes made at the higher level will also appear here.
Returns:
dict:
A dictionary representing the PolicyTagList object in
serialized form.
"""
answer = {"names": [name for name in self.names]}
return answer
51 changes: 51 additions & 0 deletions tests/system.py
Expand Up @@ -339,6 +339,57 @@ def test_create_table(self):
self.assertTrue(_table_exists(table))
self.assertEqual(table.table_id, table_id)

def test_create_table_with_policy(self):
from google.cloud.bigquery.schema import PolicyTagList

dataset = self.temp_dataset(_make_dataset_id("create_table_with_policy"))
table_id = "test_table"
policy_1 = PolicyTagList(
names=[
"projects/{}/locations/us/taxonomies/1/policyTags/2".format(
Config.CLIENT.project
),
]
)
policy_2 = PolicyTagList(
names=[
"projects/{}/locations/us/taxonomies/3/policyTags/4".format(
Config.CLIENT.project
),
]
)

schema = [
bigquery.SchemaField("full_name", "STRING", mode="REQUIRED"),
bigquery.SchemaField(
"secret_int", "INTEGER", mode="REQUIRED", policy_tags=policy_1
),
]
table_arg = Table(dataset.table(table_id), schema=schema)
self.assertFalse(_table_exists(table_arg))

table = retry_403(Config.CLIENT.create_table)(table_arg)
self.to_delete.insert(0, table)

self.assertTrue(_table_exists(table))
self.assertEqual(policy_1, table.schema[1].policy_tags)

# Amend the schema to replace the policy tags
new_schema = table.schema[:]
old_field = table.schema[1]
new_schema[1] = bigquery.SchemaField(
name=old_field.name,
field_type=old_field.field_type,
mode=old_field.mode,
description=old_field.description,
fields=old_field.fields,
policy_tags=policy_2,
)

table.schema = new_schema
table2 = Config.CLIENT.update_table(table, ["schema"])
self.assertEqual(policy_2, table2.schema[1].policy_tags)

def test_create_table_w_time_partitioning_w_clustering_fields(self):
from google.cloud.bigquery.table import TimePartitioning
from google.cloud.bigquery.table import TimePartitioningType
Expand Down
114 changes: 111 additions & 3 deletions tests/unit/test_schema.py
Expand Up @@ -63,11 +63,38 @@ def test_constructor_subfields(self):
self.assertIs(field._fields[0], sub_field1)
self.assertIs(field._fields[1], sub_field2)

def test_constructor_with_policy_tags(self):
from google.cloud.bigquery.schema import PolicyTagList

policy = PolicyTagList(names=("foo", "bar"))
field = self._make_one(
"test", "STRING", mode="REQUIRED", description="Testing", policy_tags=policy
)
self.assertEqual(field._name, "test")
self.assertEqual(field._field_type, "STRING")
self.assertEqual(field._mode, "REQUIRED")
self.assertEqual(field._description, "Testing")
self.assertEqual(field._fields, ())
self.assertEqual(field._policy_tags, policy)

def test_to_api_repr(self):
field = self._make_one("foo", "INTEGER", "NULLABLE")
from google.cloud.bigquery.schema import PolicyTagList

policy = PolicyTagList(names=("foo", "bar"))
self.assertEqual(
policy.to_api_repr(), {"names": ["foo", "bar"]},
)

field = self._make_one("foo", "INTEGER", "NULLABLE", policy_tags=policy)
self.assertEqual(
field.to_api_repr(),
{"mode": "NULLABLE", "name": "foo", "type": "INTEGER", "description": None},
{
"mode": "NULLABLE",
"name": "foo",
"type": "INTEGER",
"description": None,
"policyTags": {"names": ["foo", "bar"]},
},
)

def test_to_api_repr_with_subfield(self):
Expand Down Expand Up @@ -111,6 +138,23 @@ def test_from_api_repr(self):
self.assertEqual(field.fields[0].field_type, "INTEGER")
self.assertEqual(field.fields[0].mode, "NULLABLE")

def test_from_api_repr_policy(self):
field = self._get_target_class().from_api_repr(
{
"fields": [{"mode": "nullable", "name": "bar", "type": "integer"}],
"name": "foo",
"type": "record",
"policyTags": {"names": ["one", "two"]},
}
)
self.assertEqual(field.name, "foo")
self.assertEqual(field.field_type, "RECORD")
self.assertEqual(field.policy_tags.names, ("one", "two"))
self.assertEqual(len(field.fields), 1)
self.assertEqual(field.fields[0].name, "bar")
self.assertEqual(field.fields[0].field_type, "INTEGER")
self.assertEqual(field.fields[0].mode, "NULLABLE")

def test_from_api_repr_defaults(self):
field = self._get_target_class().from_api_repr(
{"name": "foo", "type": "record"}
Expand Down Expand Up @@ -408,7 +452,7 @@ def test___hash__not_equals(self):

def test___repr__(self):
field1 = self._make_one("field1", "STRING")
expected = "SchemaField('field1', 'STRING', 'NULLABLE', None, ())"
expected = "SchemaField('field1', 'STRING', 'NULLABLE', None, (), None)"
self.assertEqual(repr(field1), expected)


Expand Down Expand Up @@ -632,3 +676,67 @@ def test_valid_mapping_representation(self):

result = self._call_fut(schema)
self.assertEqual(result, expected_schema)


class TestPolicyTags(unittest.TestCase):
@staticmethod
def _get_target_class():
from google.cloud.bigquery.schema import PolicyTagList

return PolicyTagList

def _make_one(self, *args, **kw):
return self._get_target_class()(*args, **kw)

def test_constructor(self):
empty_policy_tags = self._make_one()
self.assertIsNotNone(empty_policy_tags.names)
self.assertEqual(len(empty_policy_tags.names), 0)
policy_tags = self._make_one(["foo", "bar"])
self.assertEqual(policy_tags.names, ("foo", "bar"))

def test_from_api_repr(self):
klass = self._get_target_class()
api_repr = {"names": ["foo"]}
policy_tags = klass.from_api_repr(api_repr)
self.assertEqual(policy_tags.to_api_repr(), api_repr)

# Ensure the None case correctly returns None, rather
# than an empty instance.
policy_tags2 = klass.from_api_repr(None)
self.assertIsNone(policy_tags2)

def test_to_api_repr(self):
taglist = self._make_one(names=["foo", "bar"])
self.assertEqual(
taglist.to_api_repr(), {"names": ["foo", "bar"]},
)
taglist2 = self._make_one(names=("foo", "bar"))
self.assertEqual(
taglist2.to_api_repr(), {"names": ["foo", "bar"]},
)

def test___eq___wrong_type(self):
policy = self._make_one(names=["foo"])
other = object()
self.assertNotEqual(policy, other)
self.assertEqual(policy, mock.ANY)

def test___eq___names_mismatch(self):
policy = self._make_one(names=["foo", "bar"])
other = self._make_one(names=["bar", "baz"])
self.assertNotEqual(policy, other)

def test___hash__set_equality(self):
policy1 = self._make_one(["foo", "bar"])
policy2 = self._make_one(["bar", "baz"])
set_one = {policy1, policy2}
set_two = {policy1, policy2}
self.assertEqual(set_one, set_two)

def test___hash__not_equals(self):
policy1 = self._make_one(["foo", "bar"])
policy2 = self._make_one(["bar", "baz"])
set_one = {policy1}
set_two = {policy2}
self.assertNotEqual(set_one, set_two)

0 comments on commit 38a5c01

Please sign in to comment.