Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for policy tags #77

Merged
merged 22 commits into from May 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
118 changes: 116 additions & 2 deletions google/cloud/bigquery/schema.py
Expand Up @@ -62,14 +62,26 @@ class SchemaField(object):

fields (Tuple[google.cloud.bigquery.schema.SchemaField]):
subfields (requires ``field_type`` of 'RECORD').

policy_tags (Optional[PolicyTagList]): The policy tag list for the field.

"""

def __init__(self, name, field_type, mode="NULLABLE", description=None, fields=()):
def __init__(
self,
name,
field_type,
mode="NULLABLE",
description=None,
fields=(),
policy_tags=None,
):
self._name = name
self._field_type = field_type
self._mode = mode
self._description = description
self._fields = tuple(fields)
self._policy_tags = policy_tags

@classmethod
def from_api_repr(cls, api_repr):
Expand All @@ -87,12 +99,14 @@ def from_api_repr(cls, api_repr):
mode = api_repr.get("mode", "NULLABLE")
description = api_repr.get("description")
fields = api_repr.get("fields", ())

return cls(
field_type=api_repr["type"].upper(),
fields=[cls.from_api_repr(f) for f in fields],
mode=mode.upper(),
description=description,
name=api_repr["name"],
policy_tags=PolicyTagList.from_api_repr(api_repr.get("policyTags")),
)

@property
Expand Down Expand Up @@ -136,6 +150,13 @@ def fields(self):
"""
return self._fields

@property
def policy_tags(self):
"""Optional[google.cloud.bigquery.schema.PolicyTagList]: Policy tag list
definition for this field.
"""
return self._policy_tags

def to_api_repr(self):
"""Return a dictionary representing this schema field.

Expand All @@ -155,6 +176,10 @@ def to_api_repr(self):
if self.field_type.upper() in _STRUCT_TYPES:
answer["fields"] = [f.to_api_repr() for f in self.fields]

# If this contains a policy tag definition, include that as well:
if self.policy_tags is not None:
answer["policyTags"] = self.policy_tags.to_api_repr()

# Done; return the serialized dictionary.
return answer

Expand All @@ -172,6 +197,7 @@ def _key(self):
self._mode.upper(),
self._description,
self._fields,
self._policy_tags,
)

def to_standard_sql(self):
Expand Down Expand Up @@ -244,7 +270,10 @@ def _parse_schema_resource(info):
mode = r_field.get("mode", "NULLABLE")
description = r_field.get("description")
sub_fields = _parse_schema_resource(r_field)
schema.append(SchemaField(name, field_type, mode, description, sub_fields))
policy_tags = PolicyTagList.from_api_repr(r_field.get("policyTags"))
schema.append(
SchemaField(name, field_type, mode, description, sub_fields, policy_tags)
)
return schema


Expand Down Expand Up @@ -291,3 +320,88 @@ def _to_schema_fields(schema):
field if isinstance(field, SchemaField) else SchemaField.from_api_repr(field)
for field in schema
]


class PolicyTagList(object):
"""Define Policy Tags for a column.

Args:
names (
Optional[Tuple[str]]): list of policy tags to associate with
the column. Policy tag identifiers are of the form
`projects/*/locations/*/taxonomies/*/policyTags/*`.
"""

def __init__(self, names=()):
self._properties = {}
self._properties["names"] = tuple(names)

@property
def names(self):
"""Tuple[str]: Policy tags associated with this definition.
"""
return self._properties.get("names", ())
plamut marked this conversation as resolved.
Show resolved Hide resolved

def _key(self):
"""A tuple key that uniquely describes this PolicyTagList.

Used to compute this instance's hashcode and evaluate equality.

Returns:
Tuple: The contents of this :class:`~google.cloud.bigquery.schema.PolicyTagList`.
"""
return tuple(sorted(self._properties.items()))

def __eq__(self, other):
if not isinstance(other, PolicyTagList):
return NotImplemented
return self._key() == other._key()

def __ne__(self, other):
return not self == other

def __hash__(self):
return hash(self._key())
plamut marked this conversation as resolved.
Show resolved Hide resolved

def __repr__(self):
return "PolicyTagList{}".format(self._key())

@classmethod
def from_api_repr(cls, api_repr):
"""Return a :class:`PolicyTagList` object deserialized from a dict.

This method creates a new ``PolicyTagList`` instance that points to
the ``api_repr`` parameter as its internal properties dict. This means
that when a ``PolicyTagList`` instance is stored as a property of
another object, any changes made at the higher level will also appear
here.

Args:
api_repr (Mapping[str, str]):
The serialized representation of the PolicyTagList, such as
what is output by :meth:`to_api_repr`.

Returns:
Optional[google.cloud.bigquery.schema.PolicyTagList]:
The ``PolicyTagList`` object or None.
"""
if api_repr is None:
return None
names = api_repr.get("names", ())
return cls(names=names)

def to_api_repr(self):
"""Return a dictionary representing this object.

This method returns the properties dict of the ``PolicyTagList``
instance rather than making a copy. This means that when a
``PolicyTagList`` instance is stored as a property of another
object, any changes made at the higher level will also appear here.

Returns:
dict:
A dictionary representing the PolicyTagList object in
serialized form.
"""
answer = {"names": [name for name in self.names]}
return answer
51 changes: 51 additions & 0 deletions tests/system.py
Expand Up @@ -339,6 +339,57 @@ def test_create_table(self):
self.assertTrue(_table_exists(table))
self.assertEqual(table.table_id, table_id)

def test_create_table_with_policy(self):
from google.cloud.bigquery.schema import PolicyTagList

dataset = self.temp_dataset(_make_dataset_id("create_table_with_policy"))
table_id = "test_table"
policy_1 = PolicyTagList(
names=[
"projects/{}/locations/us/taxonomies/1/policyTags/2".format(
Config.CLIENT.project
),
]
)
policy_2 = PolicyTagList(
names=[
"projects/{}/locations/us/taxonomies/3/policyTags/4".format(
Config.CLIENT.project
),
]
)

schema = [
bigquery.SchemaField("full_name", "STRING", mode="REQUIRED"),
bigquery.SchemaField(
"secret_int", "INTEGER", mode="REQUIRED", policy_tags=policy_1
),
]
table_arg = Table(dataset.table(table_id), schema=schema)
self.assertFalse(_table_exists(table_arg))

table = retry_403(Config.CLIENT.create_table)(table_arg)
self.to_delete.insert(0, table)

self.assertTrue(_table_exists(table))
self.assertEqual(policy_1, table.schema[1].policy_tags)

# Amend the schema to replace the policy tags
new_schema = table.schema[:]
old_field = table.schema[1]
new_schema[1] = bigquery.SchemaField(
name=old_field.name,
field_type=old_field.field_type,
mode=old_field.mode,
description=old_field.description,
fields=old_field.fields,
policy_tags=policy_2,
)

table.schema = new_schema
table2 = Config.CLIENT.update_table(table, ["schema"])
self.assertEqual(policy_2, table2.schema[1].policy_tags)

def test_create_table_w_time_partitioning_w_clustering_fields(self):
from google.cloud.bigquery.table import TimePartitioning
from google.cloud.bigquery.table import TimePartitioningType
Expand Down
114 changes: 111 additions & 3 deletions tests/unit/test_schema.py
Expand Up @@ -63,11 +63,38 @@ def test_constructor_subfields(self):
self.assertIs(field._fields[0], sub_field1)
self.assertIs(field._fields[1], sub_field2)

def test_constructor_with_policy_tags(self):
from google.cloud.bigquery.schema import PolicyTagList

policy = PolicyTagList(names=("foo", "bar"))
field = self._make_one(
"test", "STRING", mode="REQUIRED", description="Testing", policy_tags=policy
)
self.assertEqual(field._name, "test")
self.assertEqual(field._field_type, "STRING")
self.assertEqual(field._mode, "REQUIRED")
self.assertEqual(field._description, "Testing")
self.assertEqual(field._fields, ())
self.assertEqual(field._policy_tags, policy)

def test_to_api_repr(self):
field = self._make_one("foo", "INTEGER", "NULLABLE")
from google.cloud.bigquery.schema import PolicyTagList

policy = PolicyTagList(names=("foo", "bar"))
self.assertEqual(
policy.to_api_repr(), {"names": ["foo", "bar"]},
)

field = self._make_one("foo", "INTEGER", "NULLABLE", policy_tags=policy)
self.assertEqual(
field.to_api_repr(),
{"mode": "NULLABLE", "name": "foo", "type": "INTEGER", "description": None},
{
"mode": "NULLABLE",
"name": "foo",
"type": "INTEGER",
"description": None,
"policyTags": {"names": ["foo", "bar"]},
},
)

def test_to_api_repr_with_subfield(self):
Expand Down Expand Up @@ -111,6 +138,23 @@ def test_from_api_repr(self):
self.assertEqual(field.fields[0].field_type, "INTEGER")
self.assertEqual(field.fields[0].mode, "NULLABLE")

def test_from_api_repr_policy(self):
field = self._get_target_class().from_api_repr(
{
"fields": [{"mode": "nullable", "name": "bar", "type": "integer"}],
"name": "foo",
"type": "record",
"policyTags": {"names": ["one", "two"]},
}
)
self.assertEqual(field.name, "foo")
self.assertEqual(field.field_type, "RECORD")
self.assertEqual(field.policy_tags.names, ("one", "two"))
self.assertEqual(len(field.fields), 1)
self.assertEqual(field.fields[0].name, "bar")
self.assertEqual(field.fields[0].field_type, "INTEGER")
self.assertEqual(field.fields[0].mode, "NULLABLE")

def test_from_api_repr_defaults(self):
field = self._get_target_class().from_api_repr(
{"name": "foo", "type": "record"}
Expand Down Expand Up @@ -408,7 +452,7 @@ def test___hash__not_equals(self):

def test___repr__(self):
field1 = self._make_one("field1", "STRING")
expected = "SchemaField('field1', 'STRING', 'NULLABLE', None, ())"
expected = "SchemaField('field1', 'STRING', 'NULLABLE', None, (), None)"
self.assertEqual(repr(field1), expected)


Expand Down Expand Up @@ -632,3 +676,67 @@ def test_valid_mapping_representation(self):

result = self._call_fut(schema)
self.assertEqual(result, expected_schema)


class TestPolicyTags(unittest.TestCase):
@staticmethod
def _get_target_class():
from google.cloud.bigquery.schema import PolicyTagList

return PolicyTagList

def _make_one(self, *args, **kw):
return self._get_target_class()(*args, **kw)

def test_constructor(self):
empty_policy_tags = self._make_one()
self.assertIsNotNone(empty_policy_tags.names)
self.assertEqual(len(empty_policy_tags.names), 0)
policy_tags = self._make_one(["foo", "bar"])
self.assertEqual(policy_tags.names, ("foo", "bar"))

def test_from_api_repr(self):
klass = self._get_target_class()
api_repr = {"names": ["foo"]}
policy_tags = klass.from_api_repr(api_repr)
self.assertEqual(policy_tags.to_api_repr(), api_repr)
shollyman marked this conversation as resolved.
Show resolved Hide resolved

# Ensure the None case correctly returns None, rather
# than an empty instance.
policy_tags2 = klass.from_api_repr(None)
self.assertIsNone(policy_tags2)

def test_to_api_repr(self):
taglist = self._make_one(names=["foo", "bar"])
self.assertEqual(
taglist.to_api_repr(), {"names": ["foo", "bar"]},
)
taglist2 = self._make_one(names=("foo", "bar"))
self.assertEqual(
taglist2.to_api_repr(), {"names": ["foo", "bar"]},
)

def test___eq___wrong_type(self):
policy = self._make_one(names=["foo"])
other = object()
self.assertNotEqual(policy, other)
self.assertEqual(policy, mock.ANY)

def test___eq___names_mismatch(self):
policy = self._make_one(names=["foo", "bar"])
other = self._make_one(names=["bar", "baz"])
self.assertNotEqual(policy, other)

def test___hash__set_equality(self):
policy1 = self._make_one(["foo", "bar"])
policy2 = self._make_one(["bar", "baz"])
set_one = {policy1, policy2}
set_two = {policy1, policy2}
self.assertEqual(set_one, set_two)

def test___hash__not_equals(self):
policy1 = self._make_one(["foo", "bar"])
policy2 = self._make_one(["bar", "baz"])
set_one = {policy1}
set_two = {policy2}
self.assertNotEqual(set_one, set_two)