Skip to content

Commit

Permalink
feat: enable unsetting policy tags on schema fields (#703)
Browse files Browse the repository at this point in the history
* feat: enable unsetting policy tags on schema fields

* Adjust API representation for STRUCT schema fields

* De-dup logic for converting None policy tags
  • Loading branch information
plamut committed Jun 21, 2021
1 parent 0b20015 commit 18bb443
Show file tree
Hide file tree
Showing 6 changed files with 279 additions and 32 deletions.
45 changes: 38 additions & 7 deletions google/cloud/bigquery/schema.py
Expand Up @@ -15,6 +15,7 @@
"""Schemas for BigQuery tables / queries."""

import collections
from typing import Optional

from google.cloud.bigquery_v2 import types

Expand Down Expand Up @@ -105,7 +106,26 @@ def __init__(
if max_length is not _DEFAULT_VALUE:
self._properties["maxLength"] = max_length
self._fields = tuple(fields)
self._policy_tags = policy_tags

self._policy_tags = self._determine_policy_tags(field_type, policy_tags)

@staticmethod
def _determine_policy_tags(
field_type: str, given_policy_tags: Optional["PolicyTagList"]
) -> Optional["PolicyTagList"]:
"""Return the given policy tags, or their suitable representation if `None`.
Args:
field_type: The type of the schema field.
given_policy_tags: The policy tags to maybe ajdust.
"""
if given_policy_tags is not None:
return given_policy_tags

if field_type is not None and field_type.upper() in _STRUCT_TYPES:
return None

return PolicyTagList()

@staticmethod
def __get_int(api_repr, name):
Expand All @@ -126,18 +146,24 @@ def from_api_repr(cls, api_repr: dict) -> "SchemaField":
Returns:
google.cloud.biquery.schema.SchemaField: The ``SchemaField`` object.
"""
field_type = api_repr["type"].upper()

# Handle optional properties with default values
mode = api_repr.get("mode", "NULLABLE")
description = api_repr.get("description", _DEFAULT_VALUE)
fields = api_repr.get("fields", ())

policy_tags = cls._determine_policy_tags(
field_type, PolicyTagList.from_api_repr(api_repr.get("policyTags"))
)

return cls(
field_type=api_repr["type"].upper(),
field_type=field_type,
fields=[cls.from_api_repr(f) for f in fields],
mode=mode.upper(),
description=description,
name=api_repr["name"],
policy_tags=PolicyTagList.from_api_repr(api_repr.get("policyTags")),
policy_tags=policy_tags,
precision=cls.__get_int(api_repr, "precision"),
scale=cls.__get_int(api_repr, "scale"),
max_length=cls.__get_int(api_repr, "maxLength"),
Expand Down Expand Up @@ -218,9 +244,9 @@ def to_api_repr(self) -> dict:
# add this to the serialized representation.
if self.field_type.upper() in _STRUCT_TYPES:
answer["fields"] = [f.to_api_repr() for f in self.fields]

# If this contains a policy tag definition, include that as well:
if self.policy_tags is not None:
else:
# Explicitly include policy tag definition (we must not do it for RECORD
# fields, because those are not leaf fields).
answer["policyTags"] = self.policy_tags.to_api_repr()

# Done; return the serialized dictionary.
Expand All @@ -244,14 +270,19 @@ def _key(self):
field_type = f"{field_type}({self.precision}, {self.scale})"
else:
field_type = f"{field_type}({self.precision})"

policy_tags = (
() if self._policy_tags is None else tuple(sorted(self._policy_tags.names))
)

return (
self.name,
field_type,
# Mode is always str, if not given it defaults to a str value
self.mode.upper(), # pytype: disable=attribute-error
self.description,
self._fields,
self._policy_tags,
policy_tags,
)

def to_standard_sql(self) -> types.StandardSqlField:
Expand Down
50 changes: 50 additions & 0 deletions tests/system/test_client.py
Expand Up @@ -653,6 +653,56 @@ def test_update_table_schema(self):
self.assertEqual(found.field_type, expected.field_type)
self.assertEqual(found.mode, expected.mode)

def test_unset_table_schema_attributes(self):
from google.cloud.bigquery.schema import PolicyTagList

dataset = self.temp_dataset(_make_dataset_id("unset_policy_tags"))
table_id = "test_table"
policy_tags = PolicyTagList(
names=[
"projects/{}/locations/us/taxonomies/1/policyTags/2".format(
Config.CLIENT.project
),
]
)

schema = [
bigquery.SchemaField("full_name", "STRING", mode="REQUIRED"),
bigquery.SchemaField(
"secret_int",
"INTEGER",
mode="REQUIRED",
description="This field is numeric",
policy_tags=policy_tags,
),
]
table_arg = Table(dataset.table(table_id), schema=schema)
self.assertFalse(_table_exists(table_arg))

table = helpers.retry_403(Config.CLIENT.create_table)(table_arg)
self.to_delete.insert(0, table)

self.assertTrue(_table_exists(table))
self.assertEqual(policy_tags, table.schema[1].policy_tags)

# Amend the schema to replace the policy tags
new_schema = table.schema[:]
old_field = table.schema[1]
new_schema[1] = bigquery.SchemaField(
name=old_field.name,
field_type=old_field.field_type,
mode=old_field.mode,
description=None,
fields=old_field.fields,
policy_tags=None,
)

table.schema = new_schema
updated_table = Config.CLIENT.update_table(table, ["schema"])

self.assertFalse(updated_table.schema[1].description) # Empty string or None.
self.assertEqual(updated_table.schema[1].policy_tags.names, ())

def test_update_table_clustering_configuration(self):
dataset = self.temp_dataset(_make_dataset_id("update_table"))

Expand Down
4 changes: 4 additions & 0 deletions tests/unit/job/test_load_config.py
Expand Up @@ -434,11 +434,13 @@ def test_schema_setter_fields(self):
"name": "full_name",
"type": "STRING",
"mode": "REQUIRED",
"policyTags": {"names": []},
}
age_repr = {
"name": "age",
"type": "INTEGER",
"mode": "REQUIRED",
"policyTags": {"names": []},
}
self.assertEqual(
config._properties["load"]["schema"], {"fields": [full_name_repr, age_repr]}
Expand All @@ -451,11 +453,13 @@ def test_schema_setter_valid_mappings_list(self):
"name": "full_name",
"type": "STRING",
"mode": "REQUIRED",
"policyTags": {"names": []},
}
age_repr = {
"name": "age",
"type": "INTEGER",
"mode": "REQUIRED",
"policyTags": {"names": []},
}
schema = [full_name_repr, age_repr]
config.schema = schema
Expand Down
49 changes: 43 additions & 6 deletions tests/unit/test_client.py
Expand Up @@ -1019,8 +1019,18 @@ def test_create_table_w_schema_and_query(self):
{
"schema": {
"fields": [
{"name": "full_name", "type": "STRING", "mode": "REQUIRED"},
{"name": "age", "type": "INTEGER", "mode": "REQUIRED"},
{
"name": "full_name",
"type": "STRING",
"mode": "REQUIRED",
"policyTags": {"names": []},
},
{
"name": "age",
"type": "INTEGER",
"mode": "REQUIRED",
"policyTags": {"names": []},
},
]
},
"view": {"query": query},
Expand Down Expand Up @@ -1054,8 +1064,18 @@ def test_create_table_w_schema_and_query(self):
},
"schema": {
"fields": [
{"name": "full_name", "type": "STRING", "mode": "REQUIRED"},
{"name": "age", "type": "INTEGER", "mode": "REQUIRED"},
{
"name": "full_name",
"type": "STRING",
"mode": "REQUIRED",
"policyTags": {"names": []},
},
{
"name": "age",
"type": "INTEGER",
"mode": "REQUIRED",
"policyTags": {"names": []},
},
]
},
"view": {"query": query, "useLegacySql": False},
Expand Down Expand Up @@ -2000,12 +2020,14 @@ def test_update_table(self):
"type": "STRING",
"mode": "REQUIRED",
"description": None,
"policyTags": {"names": []},
},
{
"name": "age",
"type": "INTEGER",
"mode": "REQUIRED",
"description": "New field description",
"policyTags": {"names": []},
},
]
},
Expand Down Expand Up @@ -2047,12 +2069,14 @@ def test_update_table(self):
"type": "STRING",
"mode": "REQUIRED",
"description": None,
"policyTags": {"names": []},
},
{
"name": "age",
"type": "INTEGER",
"mode": "REQUIRED",
"description": "New field description",
"policyTags": {"names": []},
},
]
},
Expand Down Expand Up @@ -2173,14 +2197,21 @@ def test_update_table_w_query(self):
"type": "STRING",
"mode": "REQUIRED",
"description": None,
"policyTags": {"names": []},
},
{
"name": "age",
"type": "INTEGER",
"mode": "REQUIRED",
"description": "this is a column",
"policyTags": {"names": []},
},
{
"name": "country",
"type": "STRING",
"mode": "NULLABLE",
"policyTags": {"names": []},
},
{"name": "country", "type": "STRING", "mode": "NULLABLE"},
]
}
schema = [
Expand Down Expand Up @@ -6516,10 +6547,10 @@ def test_load_table_from_dataframe(self):
assert field["type"] == table_field.field_type
assert field["mode"] == table_field.mode
assert len(field.get("fields", [])) == len(table_field.fields)
assert field["policyTags"]["names"] == []
# Omit unnecessary fields when they come from getting the table
# (not passed in via job_config)
assert "description" not in field
assert "policyTags" not in field

@unittest.skipIf(pandas is None, "Requires `pandas`")
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
Expand Down Expand Up @@ -7718,18 +7749,21 @@ def test_schema_to_json_with_file_path(self):
"description": "quarter",
"mode": "REQUIRED",
"name": "qtr",
"policyTags": {"names": []},
"type": "STRING",
},
{
"description": "sales representative",
"mode": "NULLABLE",
"name": "rep",
"policyTags": {"names": []},
"type": "STRING",
},
{
"description": "total sales",
"mode": "NULLABLE",
"name": "sales",
"policyTags": {"names": []},
"type": "FLOAT",
},
]
Expand Down Expand Up @@ -7762,18 +7796,21 @@ def test_schema_to_json_with_file_object(self):
"description": "quarter",
"mode": "REQUIRED",
"name": "qtr",
"policyTags": {"names": []},
"type": "STRING",
},
{
"description": "sales representative",
"mode": "NULLABLE",
"name": "rep",
"policyTags": {"names": []},
"type": "STRING",
},
{
"description": "total sales",
"mode": "NULLABLE",
"name": "sales",
"policyTags": {"names": []},
"type": "FLOAT",
},
]
Expand Down
9 changes: 8 additions & 1 deletion tests/unit/test_external_config.py
Expand Up @@ -78,7 +78,14 @@ def test_to_api_repr_base(self):
ec.schema = [schema.SchemaField("full_name", "STRING", mode="REQUIRED")]

exp_schema = {
"fields": [{"name": "full_name", "type": "STRING", "mode": "REQUIRED"}]
"fields": [
{
"name": "full_name",
"type": "STRING",
"mode": "REQUIRED",
"policyTags": {"names": []},
}
]
}
got_resource = ec.to_api_repr()
exp_resource = {
Expand Down

0 comments on commit 18bb443

Please sign in to comment.