Skip to content

Commit

Permalink
fix: Handle nested fields from BigQuery source when getting default column_names (#522)
Browse files Browse the repository at this point in the history

* Handle nested fields from BigQuery source

* Added unit test for nested BigQuery fields and refactored column_names to return a Set instead of a List

* Added comment

* Fixed minor issues with tabular_dataset

* Switched TabularDataset.column_names back to returning a List so as not to introduce a breaking change at this time
  • Loading branch information
ivanmkc committed Jul 8, 2021
1 parent 2508fe9 commit 3fc1d44
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 22 deletions.
91 changes: 71 additions & 20 deletions google/cloud/aiplatform/datasets/tabular_dataset.py
Expand Up @@ -18,7 +18,7 @@
import csv
import logging

from typing import List, Optional, Sequence, Tuple, Union
from typing import List, Optional, Sequence, Set, Tuple, Union

from google.auth import credentials as auth_credentials

Expand Down Expand Up @@ -73,18 +73,24 @@ def column_names(self) -> List[str]:
gcs_source_uris.sort()

# Get the first file in sorted list
return self._retrieve_gcs_source_columns(
project=self.project,
gcs_csv_file_path=gcs_source_uris[0],
credentials=self.credentials,
# TODO(b/193044977): Return as Set instead of List
return list(
self._retrieve_gcs_source_columns(
project=self.project,
gcs_csv_file_path=gcs_source_uris[0],
credentials=self.credentials,
)
)
elif bq_source:
bq_table_uri = bq_source.get("uri")
if bq_table_uri:
return self._retrieve_bq_source_columns(
project=self.project,
bq_table_uri=bq_table_uri,
credentials=self.credentials,
# TODO(b/193044977): Return as Set instead of List
return list(
self._retrieve_bq_source_columns(
project=self.project,
bq_table_uri=bq_table_uri,
credentials=self.credentials,
)
)

raise RuntimeError("No valid CSV or BigQuery datasource found.")
Expand All @@ -94,7 +100,7 @@ def _retrieve_gcs_source_columns(
project: str,
gcs_csv_file_path: str,
credentials: Optional[auth_credentials.Credentials] = None,
) -> List[str]:
) -> Set[str]:
"""Retrieve the columns from a comma-delimited CSV file stored on Google Cloud Storage
Example Usage:
Expand All @@ -104,7 +110,7 @@ def _retrieve_gcs_source_columns(
"gs://example-bucket/path/to/csv_file"
)
# column_names = ["column_1", "column_2"]
# column_names = {"column_1", "column_2"}
Args:
project (str):
Expand All @@ -115,8 +121,8 @@ def _retrieve_gcs_source_columns(
credentials (auth_credentials.Credentials):
Credentials to use to with GCS Client.
Returns:
List[str]
A list of columns names in the CSV file.
Set[str]
A set of column names in the CSV file.
Raises:
RuntimeError: When the retrieved CSV file is invalid.
Expand Down Expand Up @@ -163,15 +169,53 @@ def _retrieve_gcs_source_columns(
finally:
logger.removeFilter(logging_warning_filter)

return next(csv_reader)
return set(next(csv_reader))

@staticmethod
def _get_bq_schema_field_names_recursively(
    schema_field: bigquery.SchemaField,
) -> Set[str]:
    """Flatten a BigQuery schema field into its leaf column names.

    Nested (RECORD) fields are flattened and concatenated with a ".".
    A field that has child fields is not itself included in the result;
    only its leaf descendants are, each prefixed with the names of all
    of its ancestors.

    Args:
        schema_field (bigquery.SchemaField):
            Required. The schema field to flatten. Its ``fields``
            attribute holds the child fields (empty for leaf columns).

    Returns:
        Set[str]
            The fully-qualified leaf column names reachable from this
            field, e.g. {"col", "record_col.nested_leaf"}.
    """

    # Recursively collect the flattened names of every descendant field.
    child_field_names = {
        nested_field_name
        for field in schema_field.fields
        for nested_field_name in TabularDataset._get_bq_schema_field_names_recursively(
            field
        )
    }

    # Only return "leaf nodes": a field with no children is a column in
    # its own right; otherwise emit each descendant prefixed with this
    # field's name and drop the (non-leaf) field itself.
    if len(child_field_names) == 0:
        return {schema_field.name}
    else:
        return {f"{schema_field.name}.{name}" for name in child_field_names}

@staticmethod
def _retrieve_bq_source_columns(
project: str,
bq_table_uri: str,
credentials: Optional[auth_credentials.Credentials] = None,
) -> List[str]:
"""Retrieve the columns from a table on Google BigQuery
) -> Set[str]:
"""Retrieve the column names from a table on Google BigQuery
Nested schema fields are flattened and concatenated with a ".".
Schema fields with child fields are not included, but the children are.
Example Usage:
Expand All @@ -180,7 +224,7 @@ def _retrieve_bq_source_columns(
"bq://project_id.dataset.table"
)
# column_names = ["column_1", "column_2"]
# column_names = {"column_1", "column_2", "column_3.nested_field"}
Args:
project (str):
Expand All @@ -192,8 +236,8 @@ def _retrieve_bq_source_columns(
Credentials to use with BQ Client.
Returns:
List[str]
A list of columns names in the BigQuery table.
Set[str]
A set of column names in the BigQuery table.
"""

# Remove bq:// prefix
Expand All @@ -204,7 +248,14 @@ def _retrieve_bq_source_columns(
client = bigquery.Client(project=project, credentials=credentials)
table = client.get_table(bq_table_uri)
schema = table.schema
return [schema.name for schema in schema]

return {
field_name
for field in schema
for field_name in TabularDataset._get_bq_schema_field_names_recursively(
field
)
}

@classmethod
def create(
Expand Down
66 changes: 64 additions & 2 deletions tests/unit/aiplatform/test_datasets.py
Expand Up @@ -375,6 +375,59 @@ def bigquery_table_schema_mock():
bigquery_table_schema_mock.return_value = [
bigquery.SchemaField("column_1", "FLOAT", "NULLABLE", "", (), None),
bigquery.SchemaField("column_2", "FLOAT", "NULLABLE", "", (), None),
bigquery.SchemaField(
"column_3",
"RECORD",
"NULLABLE",
"",
(
bigquery.SchemaField(
"nested_3_1",
"RECORD",
"NULLABLE",
"",
(
bigquery.SchemaField(
"nested_3_1_1", "FLOAT", "NULLABLE", "", (), None
),
bigquery.SchemaField(
"nested_3_1_2", "FLOAT", "NULLABLE", "", (), None
),
),
None,
),
bigquery.SchemaField(
"nested_3_2", "FLOAT", "NULLABLE", "", (), None
),
bigquery.SchemaField(
"nested_3_3",
"RECORD",
"NULLABLE",
"",
(
bigquery.SchemaField(
"nested_3_3_1",
"RECORD",
"NULLABLE",
"",
(
bigquery.SchemaField(
"nested_3_3_1_1",
"FLOAT",
"NULLABLE",
"",
(),
None,
),
),
None,
),
),
None,
),
),
None,
),
]
yield bigquery_table_schema_mock

Expand Down Expand Up @@ -1007,7 +1060,7 @@ def test_tabular_dataset_column_name_missing_datasource(self):
def test_tabular_dataset_column_name_gcs(self):
my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME)

assert my_dataset.column_names == ["column_1", "column_2"]
assert set(my_dataset.column_names) == {"column_1", "column_2"}

@pytest.mark.usefixtures("get_dataset_tabular_gcs_mock")
def test_tabular_dataset_column_name_gcs_with_creds(self, gcs_client_mock):
Expand Down Expand Up @@ -1045,7 +1098,16 @@ def test_tabular_dataset_column_name_bq_with_creds(self, bq_client_mock):
def test_tabular_dataset_column_name_bigquery(self):
my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME)

assert my_dataset.column_names == ["column_1", "column_2"]
assert set(my_dataset.column_names) == set(
[
"column_1",
"column_2",
"column_3.nested_3_1.nested_3_1_1",
"column_3.nested_3_1.nested_3_1_2",
"column_3.nested_3_2",
"column_3.nested_3_3.nested_3_3_1.nested_3_3_1_1",
]
)


class TestTextDataset:
Expand Down

0 comments on commit 3fc1d44

Please sign in to comment.