Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for policy tags #77

Merged
merged 22 commits into from May 18, 2020
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
111 changes: 109 additions & 2 deletions google/cloud/bigquery/schema.py
Expand Up @@ -14,6 +14,8 @@

"""Schemas for BigQuery tables / queries."""

import copy

from six.moves import collections_abc

from google.cloud.bigquery_v2 import types
Expand Down Expand Up @@ -62,14 +64,26 @@ class SchemaField(object):

fields (Tuple[google.cloud.bigquery.schema.SchemaField]):
subfields (requires ``field_type`` of 'RECORD').

policy_tags (Optional[PolicyTagList]): The policy tag list for the field.

"""

def __init__(self, name, field_type, mode="NULLABLE", description=None, fields=()):
def __init__(
self,
name,
field_type,
mode="NULLABLE",
description=None,
fields=(),
policy_tags=None,
):
self._name = name
self._field_type = field_type
self._mode = mode
self._description = description
self._fields = tuple(fields)
self._policy_tags = policy_tags

@classmethod
def from_api_repr(cls, api_repr):
Expand All @@ -87,12 +101,16 @@ def from_api_repr(cls, api_repr):
mode = api_repr.get("mode", "NULLABLE")
description = api_repr.get("description")
fields = api_repr.get("fields", ())
policy_tags = api_repr.get("policyTags")
if policy_tags is not None:
policy_tags = PolicyTagList.from_api_repr(policy_tags)
return cls(
field_type=api_repr["type"].upper(),
fields=[cls.from_api_repr(f) for f in fields],
mode=mode.upper(),
description=description,
name=api_repr["name"],
policy_tags=policy_tags,
)

@property
Expand Down Expand Up @@ -136,6 +154,18 @@ def fields(self):
"""
return self._fields

@property
def policy_tags(self):
    """Optional[google.cloud.bigquery.schema.PolicyTagList]: Policy tag list
    definition for this field.

    The getter performs no validation and raises nothing: it returns the
    stored ``_policy_tags`` attribute as-is, which is :data:`None` when no
    policy tags were given.
    """
    return self._policy_tags

def to_api_repr(self):
"""Return a dictionary representing this schema field.

Expand All @@ -155,6 +185,10 @@ def to_api_repr(self):
if self.field_type.upper() in _STRUCT_TYPES:
answer["fields"] = [f.to_api_repr() for f in self.fields]

# If this contains a policy tag definition, include that as well:
if self.policy_tags is not None:
answer["policyTags"] = self.policy_tags.to_api_repr()

# Done; return the serialized dictionary.
return answer

Expand All @@ -172,6 +206,7 @@ def _key(self):
self._mode.upper(),
self._description,
self._fields,
self._policy_tags,
)

def to_standard_sql(self):
Expand Down Expand Up @@ -244,7 +279,10 @@ def _parse_schema_resource(info):
mode = r_field.get("mode", "NULLABLE")
description = r_field.get("description")
sub_fields = _parse_schema_resource(r_field)
schema.append(SchemaField(name, field_type, mode, description, sub_fields))
policy_tags = r_field.get("policyTags")
schema.append(
SchemaField(name, field_type, mode, description, sub_fields, policy_tags)
)
return schema


Expand Down Expand Up @@ -291,3 +329,72 @@ def _to_schema_fields(schema):
field if isinstance(field, SchemaField) else SchemaField.from_api_repr(field)
for field in schema
]


class PolicyTagList(object):
    """Define Policy Tags for a column.

    Instances compare, hash and print by value (the sequence of policy tag
    names), so that objects embedding a ``PolicyTagList`` in their own
    equality key -- such as ``SchemaField`` -- keep value semantics instead of
    silently falling back to object identity.

    Because ``names`` can be reassigned, the hash of an instance changes when
    its names change; do not mutate an instance while it is used as a
    dictionary key or set member.

    Args:
        names (Optional[Tuple[str]]): list of policy tags to associate with
            the column.
    """

    def __init__(self, names=None):
        self._properties = {}
        self.names = names

    @property
    def names(self):
        """Tuple[str]: Policy tags associated with this definition.

        An empty tuple is returned when no names have been set.
        """
        return self._properties.get("names", ())

    @names.setter
    def names(self, value):
        """Optional[Tuple[str]]: Policy tags associated with this definition.

        (Defaults to :data:`None`).
        """
        if value is None:
            # ``None`` clears the property entirely, so that it is omitted
            # from the serialized representation.
            self._properties.pop("names", None)
        else:
            self._properties["names"] = value

    def _key(self):
        """Return a hashable tuple of the policy tag names.

        Used to compute this instance's hash code and to evaluate equality.
        The names are normalized to a tuple so that list-valued names (for
        example, from a parsed API resource) remain hashable and compare
        equal to the same names given as a tuple.

        Returns:
            Tuple[str]: The policy tag names of this instance.
        """
        names = self.names
        return tuple(names) if names is not None else ()

    def __eq__(self, other):
        if not isinstance(other, PolicyTagList):
            return NotImplemented
        return self._key() == other._key()

    def __ne__(self, other):
        # Spelled out explicitly: Python 2 does not derive ``!=`` from ``==``.
        return not self == other

    def __hash__(self):
        return hash(self._key())

    def __repr__(self):
        return "PolicyTagList(names={!r})".format(self._key())

    @classmethod
    def from_api_repr(cls, api_repr):
        """Return a :class:`PolicyTagList` object deserialized from a dict.

        The new instance stores a deep copy of ``api_repr`` as its internal
        properties dict, so later changes to the mapping passed in do not
        affect the returned object.

        Args:
            api_repr (Mapping[str, str]):
                The serialized representation of the PolicyTagList, such as
                what is output by :meth:`to_api_repr`.

        Returns:
            google.cloud.bigquery.schema.PolicyTagList:
                The ``PolicyTagList`` object.
        """
        instance = cls()
        instance._properties = copy.deepcopy(api_repr)
        return instance

    def to_api_repr(self):
        """Return a dictionary representing this object.

        This method returns the properties dict of the ``PolicyTagList``
        instance rather than making a copy. This means that changes made to
        the returned dictionary are visible through this instance, and vice
        versa.

        Returns:
            dict:
                A dictionary representing the PolicyTagList object in
                serialized form.
        """
        return self._properties
plamut marked this conversation as resolved.
Show resolved Hide resolved
90 changes: 87 additions & 3 deletions tests/unit/test_schema.py
Expand Up @@ -63,11 +63,38 @@ def test_constructor_subfields(self):
self.assertIs(field._fields[0], sub_field1)
self.assertIs(field._fields[1], sub_field2)

def test_constructor_with_policy_tags(self):
    """The constructor keeps the given policy tag list next to the other
    field attributes."""
    from google.cloud.bigquery.schema import PolicyTagList

    tag_list = PolicyTagList(names=("foo", "bar"))
    schema_field = self._make_one(
        "test",
        "STRING",
        mode="REQUIRED",
        description="Testing",
        policy_tags=tag_list,
    )

    # Checked in the same order as the constructor arguments.
    expected_attributes = (
        ("_name", "test"),
        ("_field_type", "STRING"),
        ("_mode", "REQUIRED"),
        ("_description", "Testing"),
        ("_fields", ()),
        ("_policy_tags", tag_list),
    )
    for attribute, expected in expected_attributes:
        self.assertEqual(getattr(schema_field, attribute), expected)

def test_to_api_repr(self):
field = self._make_one("foo", "INTEGER", "NULLABLE")
from google.cloud.bigquery.schema import PolicyTagList

policy = PolicyTagList(names=("foo", "bar"))
self.assertEqual(
policy.to_api_repr(), {"names": ("foo", "bar")},
)

field = self._make_one("foo", "INTEGER", "NULLABLE", policy_tags=policy)
self.assertEqual(
field.to_api_repr(),
{"mode": "NULLABLE", "name": "foo", "type": "INTEGER", "description": None},
{
"mode": "NULLABLE",
"name": "foo",
"type": "INTEGER",
"description": None,
"policyTags": {"names": ("foo", "bar")},
},
)

def test_to_api_repr_with_subfield(self):
Expand Down Expand Up @@ -111,6 +138,23 @@ def test_from_api_repr(self):
self.assertEqual(field.fields[0].field_type, "INTEGER")
self.assertEqual(field.fields[0].mode, "NULLABLE")

def test_from_api_repr_policy(self):
    """``from_api_repr`` picks up ``policyTags`` along with the subfields."""
    resource = {
        "fields": [{"mode": "nullable", "name": "bar", "type": "integer"}],
        "name": "foo",
        "type": "record",
        "policyTags": {"names": ("one", "two")},
    }
    parsed = self._get_target_class().from_api_repr(resource)

    # Top-level attributes, including the parsed policy tags.
    self.assertEqual(parsed.name, "foo")
    self.assertEqual(parsed.field_type, "RECORD")
    self.assertEqual(parsed.policy_tags.names, ("one", "two"))

    # The single nested field is parsed and upper-cased as usual.
    self.assertEqual(len(parsed.fields), 1)
    child = parsed.fields[0]
    self.assertEqual(child.name, "bar")
    self.assertEqual(child.field_type, "INTEGER")
    self.assertEqual(child.mode, "NULLABLE")

def test_from_api_repr_defaults(self):
field = self._get_target_class().from_api_repr(
{"name": "foo", "type": "record"}
Expand Down Expand Up @@ -408,7 +452,7 @@ def test___hash__not_equals(self):

def test___repr__(self):
field1 = self._make_one("field1", "STRING")
expected = "SchemaField('field1', 'STRING', 'NULLABLE', None, ())"
expected = "SchemaField('field1', 'STRING', 'NULLABLE', None, (), None)"
self.assertEqual(repr(field1), expected)


Expand Down Expand Up @@ -632,3 +676,43 @@ def test_valid_mapping_representation(self):

result = self._call_fut(schema)
self.assertEqual(result, expected_schema)


class TestPolicyTags(unittest.TestCase):
    """Unit tests for ``google.cloud.bigquery.schema.PolicyTagList``."""

    @staticmethod
    def _get_target_class():
        # Imported inside the helper, so the class under test is resolved at
        # call time rather than when this test module is imported.
        from google.cloud.bigquery.schema import PolicyTagList

        return PolicyTagList

    def _make_one(self, *args, **kw):
        """Build an instance of the class under test."""
        return self._get_target_class()(*args, **kw)

    def test_constructor(self):
        """With no arguments ``names`` is empty; otherwise it is stored."""
        empty_policy_tags = self._make_one()
        self.assertIsNotNone(empty_policy_tags.names)
        self.assertEqual(len(empty_policy_tags.names), 0)
        policy_tags = self._make_one(("foo", "bar"))
        self.assertEqual(policy_tags.names, ("foo", "bar"))

    def test_from_api_repr(self):
        """A resource dict round-trips through from/to_api_repr."""
        klass = self._get_target_class()
        # The trailing comma matters: ``("foo")`` is just the string "foo",
        # not a one-element tuple of names.
        api_repr = {"names": ("foo",)}
        policy_tags = klass.from_api_repr(api_repr)
        self.assertEqual(policy_tags.names, ("foo",))
        self.assertEqual(policy_tags.to_api_repr(), api_repr)

    def test_to_api_repr(self):
        """Names given to the constructor appear under the "names" key."""
        taglist = self._make_one(names=("foo", "bar"))
        self.assertEqual(
            taglist.to_api_repr(), {"names": ("foo", "bar")},
        )

    def test_setter(self):
        """Assigning ``None`` clears the names; a tuple replaces them."""
        policy_tags = self._make_one()
        self.assertEqual(policy_tags.names, ())
        # Clearing an already-empty instance must be a no-op.
        policy_tags.names = None
        self.assertEqual(policy_tags.names, ())
        policy_tags.names = ("foo", "bar")
        self.assertEqual(policy_tags.names, ("foo", "bar"))
        policy_tags.names = None
        self.assertEqual(policy_tags.names, ())