Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

fix: Handle nested fields from BigQuery source when getting default column_names #522

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
91 changes: 71 additions & 20 deletions google/cloud/aiplatform/datasets/tabular_dataset.py
Expand Up @@ -18,7 +18,7 @@
import csv
import logging

from typing import List, Optional, Sequence, Tuple, Union
from typing import List, Optional, Sequence, Set, Tuple, Union

from google.auth import credentials as auth_credentials

Expand Down Expand Up @@ -73,18 +73,24 @@ def column_names(self) -> List[str]:
gcs_source_uris.sort()

# Get the first file in sorted list
return self._retrieve_gcs_source_columns(
project=self.project,
gcs_csv_file_path=gcs_source_uris[0],
credentials=self.credentials,
# TODO(b/193044977): Return as Set instead of List
return list(
self._retrieve_gcs_source_columns(
project=self.project,
gcs_csv_file_path=gcs_source_uris[0],
credentials=self.credentials,
)
)
elif bq_source:
bq_table_uri = bq_source.get("uri")
if bq_table_uri:
return self._retrieve_bq_source_columns(
project=self.project,
bq_table_uri=bq_table_uri,
credentials=self.credentials,
# TODO(b/193044977): Return as Set instead of List
return list(
self._retrieve_bq_source_columns(
project=self.project,
bq_table_uri=bq_table_uri,
credentials=self.credentials,
)
)

raise RuntimeError("No valid CSV or BigQuery datasource found.")
Expand All @@ -94,7 +100,7 @@ def _retrieve_gcs_source_columns(
project: str,
gcs_csv_file_path: str,
credentials: Optional[auth_credentials.Credentials] = None,
) -> List[str]:
) -> Set[str]:
"""Retrieve the columns from a comma-delimited CSV file stored on Google Cloud Storage

Example Usage:
Expand All @@ -104,7 +110,7 @@ def _retrieve_gcs_source_columns(
"gs://example-bucket/path/to/csv_file"
)

# column_names = ["column_1", "column_2"]
# column_names = {"column_1", "column_2"}

Args:
project (str):
Expand All @@ -115,8 +121,8 @@ def _retrieve_gcs_source_columns(
credentials (auth_credentials.Credentials):
Credentials to use to with GCS Client.
Returns:
List[str]
A list of columns names in the CSV file.
Set[str]
            A set of column names in the CSV file.

Raises:
RuntimeError: When the retrieved CSV file is invalid.
Expand Down Expand Up @@ -163,15 +169,53 @@ def _retrieve_gcs_source_columns(
finally:
logger.removeFilter(logging_warning_filter)

return next(csv_reader)
ivanmkc marked this conversation as resolved.
Show resolved Hide resolved
return set(next(csv_reader))

@staticmethod
def _get_bq_schema_field_names_recursively(
    schema_field: bigquery.SchemaField,
) -> Set[str]:
    """Return the flattened leaf-field names for a BigQuery schema field.

    Nested schema fields are flattened and concatenated with a ".".
    Fields that have child fields are not included themselves; only their
    leaf descendants appear, each prefixed with every ancestor's name.

    Example: a RECORD field "a" containing leaf field "b" yields {"a.b"}.

    Args:
        schema_field (bigquery.SchemaField):
            Required. The schema field to flatten. May itself contain
            nested fields (BigQuery RECORD/STRUCT columns).

    Returns:
        Set[str]
            A set of fully-qualified leaf column names reachable from
            ``schema_field``.
    """

    # Flattened names of all descendant leaf fields (empty for a leaf).
    descendant_names = {
        nested_field_name
        for field in schema_field.fields
        for nested_field_name in TabularDataset._get_bq_schema_field_names_recursively(
            field
        )
    }

    # A field with no children is itself a leaf; otherwise qualify each
    # descendant leaf name with this field's name.
    if not descendant_names:
        return {schema_field.name}
    return {f"{schema_field.name}.{name}" for name in descendant_names}
ivanmkc marked this conversation as resolved.
Show resolved Hide resolved

@staticmethod
def _retrieve_bq_source_columns(
project: str,
bq_table_uri: str,
credentials: Optional[auth_credentials.Credentials] = None,
) -> List[str]:
"""Retrieve the columns from a table on Google BigQuery
) -> Set[str]:
"""Retrieve the column names from a table on Google BigQuery
Nested schema fields are flattened and concatenated with a ".".
Schema fields with child fields are not included, but the children are.

Example Usage:

Expand All @@ -180,7 +224,7 @@ def _retrieve_bq_source_columns(
"bq://project_id.dataset.table"
)

# column_names = ["column_1", "column_2"]
# column_names = {"column_1", "column_2", "column_3.nested_field"}

Args:
project (str):
Expand All @@ -192,8 +236,8 @@ def _retrieve_bq_source_columns(
Credentials to use with BQ Client.

Returns:
List[str]
A list of columns names in the BigQuery table.
Set[str]
A set of column names in the BigQuery table.
"""

# Remove bq:// prefix
Expand All @@ -204,7 +248,14 @@ def _retrieve_bq_source_columns(
client = bigquery.Client(project=project, credentials=credentials)
table = client.get_table(bq_table_uri)
schema = table.schema
return [schema.name for schema in schema]

return {
field_name
for field in schema
for field_name in TabularDataset._get_bq_schema_field_names_recursively(
field
)
}

@classmethod
def create(
Expand Down
66 changes: 64 additions & 2 deletions tests/unit/aiplatform/test_datasets.py
Expand Up @@ -375,6 +375,59 @@ def bigquery_table_schema_mock():
bigquery_table_schema_mock.return_value = [
bigquery.SchemaField("column_1", "FLOAT", "NULLABLE", "", (), None),
bigquery.SchemaField("column_2", "FLOAT", "NULLABLE", "", (), None),
bigquery.SchemaField(
"column_3",
"RECORD",
"NULLABLE",
"",
(
bigquery.SchemaField(
"nested_3_1",
"RECORD",
"NULLABLE",
"",
(
bigquery.SchemaField(
"nested_3_1_1", "FLOAT", "NULLABLE", "", (), None
),
bigquery.SchemaField(
"nested_3_1_2", "FLOAT", "NULLABLE", "", (), None
),
),
None,
),
bigquery.SchemaField(
"nested_3_2", "FLOAT", "NULLABLE", "", (), None
),
bigquery.SchemaField(
"nested_3_3",
"RECORD",
"NULLABLE",
"",
(
bigquery.SchemaField(
"nested_3_3_1",
"RECORD",
"NULLABLE",
"",
(
bigquery.SchemaField(
"nested_3_3_1_1",
"FLOAT",
"NULLABLE",
"",
(),
None,
),
),
None,
),
),
None,
),
),
None,
),
]
yield bigquery_table_schema_mock

Expand Down Expand Up @@ -1007,7 +1060,7 @@ def test_tabular_dataset_column_name_missing_datasource(self):
def test_tabular_dataset_column_name_gcs(self):
my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME)

assert my_dataset.column_names == ["column_1", "column_2"]
assert set(my_dataset.column_names) == {"column_1", "column_2"}

@pytest.mark.usefixtures("get_dataset_tabular_gcs_mock")
def test_tabular_dataset_column_name_gcs_with_creds(self, gcs_client_mock):
Expand Down Expand Up @@ -1045,7 +1098,16 @@ def test_tabular_dataset_column_name_bq_with_creds(self, bq_client_mock):
def test_tabular_dataset_column_name_bigquery(self):
my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME)

assert my_dataset.column_names == ["column_1", "column_2"]
assert set(my_dataset.column_names) == set(
[
"column_1",
"column_2",
"column_3.nested_3_1.nested_3_1_1",
"column_3.nested_3_1.nested_3_1_2",
"column_3.nested_3_2",
"column_3.nested_3_3.nested_3_3_1.nested_3_3_1_1",
]
)


class TestTextDataset:
Expand Down