Skip to content

Commit

Permalink
Added unit test for nested BigQuery fields and refactored column_name…
Browse files Browse the repository at this point in the history
…s to return a Set instead of a List
  • Loading branch information
ivanmkc committed Jul 1, 2021
1 parent d1a1c4d commit e581093
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 20 deletions.
40 changes: 21 additions & 19 deletions google/cloud/aiplatform/datasets/tabular_dataset.py
Expand Up @@ -18,7 +18,7 @@
import csv
import logging

from typing import List, Optional, Sequence, Tuple, Union
from typing import List, Optional, Sequence, Set, Tuple, Union

from google.cloud.bigquery.schema import SchemaField

Expand All @@ -42,7 +42,7 @@ class TabularDataset(datasets._Dataset):
)

@property
def column_names(self) -> List[str]:
def column_names(self) -> Set[str]:
"""Retrieve the columns for the dataset by extracting it from the Google Cloud Storage or
Google BigQuery source.
Expand Down Expand Up @@ -96,7 +96,7 @@ def _retrieve_gcs_source_columns(
project: str,
gcs_csv_file_path: str,
credentials: Optional[auth_credentials.Credentials] = None,
) -> List[str]:
) -> Set[str]:
"""Retrieve the columns from a comma-delimited CSV file stored on Google Cloud Storage
Example Usage:
Expand All @@ -106,7 +106,7 @@ def _retrieve_gcs_source_columns(
"gs://example-bucket/path/to/csv_file"
)
# column_names = ["column_1", "column_2"]
# column_names = {"column_1", "column_2"}
Args:
project (str):
Expand All @@ -117,8 +117,8 @@ def _retrieve_gcs_source_columns(
credentials (auth_credentials.Credentials):
Credentials to use to with GCS Client.
Returns:
List[str]
A list of columns names in the CSV file.
Set[str]
A set of column names in the CSV file.
Raises:
RuntimeError: When the retrieved CSV file is invalid.
Expand Down Expand Up @@ -165,10 +165,10 @@ def _retrieve_gcs_source_columns(
finally:
logger.removeFilter(logging_warning_filter)

return next(csv_reader)
return set(next(csv_reader))

@staticmethod
def _get_schema_field_names_recursively(schema_field: SchemaField) -> List[str]:
def _get_bq_schema_field_names_recursively(schema_field: SchemaField) -> Set[str]:
"""Retrieve the name for a schema field along with ancestor fields.
Nested schema fields are flattened and concatenated with a ".".
Schema fields with child fields are not included, but the children are.
Expand All @@ -183,29 +183,29 @@ def _get_schema_field_names_recursively(schema_field: SchemaField) -> List[str]:
Credentials to use with BQ Client.
Returns:
List[str]
A list of columns names in the BigQuery table.
Set[str]
A set of column names in the BigQuery table.
"""

ancestor_names = [
nested_field_name
for field in schema_field.fields
for nested_field_name in TabularDataset._get_schema_field_names_recursively(
for nested_field_name in TabularDataset._get_bq_schema_field_names_recursively(
field
)
]

if len(ancestor_names) == 0:
return [schema_field.name]
return {schema_field.name}
else:
return [f"{schema_field.name}.{name}" for name in ancestor_names]
return {f"{schema_field.name}.{name}" for name in ancestor_names}

@staticmethod
def _retrieve_bq_source_columns(
project: str,
bq_table_uri: str,
credentials: Optional[auth_credentials.Credentials] = None,
) -> List[str]:
) -> Set[str]:
"""Retrieve the column names from a table on Google BigQuery
Nested schema fields are flattened and concatenated with a ".".
Schema fields with child fields are not included, but the children are.
Expand All @@ -217,7 +217,7 @@ def _retrieve_bq_source_columns(
"bq://project_id.dataset.table"
)
# column_names = ["column_1", "column_2", "column_3.nested_field"]
# column_names = {"column_1", "column_2", "column_3.nested_field"}
Args:
project (str):
Expand All @@ -229,7 +229,7 @@ def _retrieve_bq_source_columns(
Credentials to use with BQ Client.
Returns:
List[str]
Set[str]
A set of column names in the BigQuery table.
"""

Expand All @@ -242,11 +242,13 @@ def _retrieve_bq_source_columns(
table = client.get_table(bq_table_uri)
schema = table.schema

return [
return {
field_name
for field in schema
for field_name in TabularDataset._get_schema_field_names_recursively(field)
]
for field_name in TabularDataset._get_bq_schema_field_names_recursively(
field
)
}

@classmethod
def create(
Expand Down
64 changes: 63 additions & 1 deletion tests/unit/aiplatform/test_datasets.py
Expand Up @@ -375,6 +375,59 @@ def bigquery_table_schema_mock():
bigquery_table_schema_mock.return_value = [
bigquery.SchemaField("column_1", "FLOAT", "NULLABLE", "", (), None),
bigquery.SchemaField("column_2", "FLOAT", "NULLABLE", "", (), None),
bigquery.SchemaField(
"column_3",
"FLOAT",
"NULLABLE",
"",
(
bigquery.SchemaField(
"nested_3_1",
"FLOAT",
"NULLABLE",
"",
(
bigquery.SchemaField(
"nested_3_1_1", "FLOAT", "NULLABLE", "", (), None
),
bigquery.SchemaField(
"nested_3_1_2", "FLOAT", "NULLABLE", "", (), None
),
),
None,
),
bigquery.SchemaField(
"nested_3_2", "FLOAT", "NULLABLE", "", (), None
),
bigquery.SchemaField(
"nested_3_3",
"FLOAT",
"NULLABLE",
"",
(
bigquery.SchemaField(
"nested_3_3_1",
"FLOAT",
"NULLABLE",
"",
(
bigquery.SchemaField(
"nested_3_3_1_1",
"FLOAT",
"NULLABLE",
"",
(),
None,
),
),
None,
),
),
None,
),
),
None,
),
]
yield bigquery_table_schema_mock

Expand Down Expand Up @@ -1045,7 +1098,16 @@ def test_tabular_dataset_column_name_bq_with_creds(self, bq_client_mock):
def test_tabular_dataset_column_name_bigquery(self):
    """column_names flattens nested BigQuery schema fields with "." separators.

    Only leaf fields appear; parent RECORD fields (column_3, nested_3_1, ...)
    contribute their names as dotted prefixes. The result is an (unordered)
    set, so the expected value is a set literal rather than an ordered list.
    """
    my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME)

    # Set literal instead of set([...]) — same value, no throwaway list (C405).
    assert my_dataset.column_names == {
        "column_1",
        "column_2",
        "column_3.nested_3_1.nested_3_1_1",
        "column_3.nested_3_1.nested_3_1_2",
        "column_3.nested_3_2",
        "column_3.nested_3_3.nested_3_3_1.nested_3_3_1_1",
    }


class TestTextDataset:
Expand Down

0 comments on commit e581093

Please sign in to comment.