fix: Handle nested fields from BigQuery source when getting default column_names #522

Merged
69 changes: 57 additions & 12 deletions google/cloud/aiplatform/datasets/tabular_dataset.py
@@ -18,7 +18,9 @@
import csv
import logging

from typing import List, Optional, Sequence, Tuple, Union
from typing import List, Optional, Sequence, Set, Tuple, Union

from google.cloud.bigquery.schema import SchemaField

from google.auth import credentials as auth_credentials

@@ -40,7 +42,7 @@ class TabularDataset(datasets._Dataset):
)

@property
def column_names(self) -> List[str]:
Member:
Changing the return type introduces a breaking change. Perhaps convert back to a list after deduping with a set. What is the scenario where the same column name is populated more than once? Assuming that's the motivation for using a set. Trying to understand if this is worth introducing a breaking change.

Contributor Author:
The set comparison is trivial because col_names are unique and order is not a factor.

A list implies an ordering, which is not relevant for column names and makes the unit tests a tiny (very tiny) bit more complicated to write.

Contributor Author:
I see your point about a breaking change. What do you recommend?

Member:
I agree with the motivation to change the return type to a Set. Let's do the following:

  1. Leave the return type as List to avoid the breaking change (see the sketch after this list).
  2. Proceed with this PR mainly as is.
  3. Open a ticket to track the return type change to Set.
  4. Tentatively plan to implement the return type change when we get closer to a larger breaking change and major version rev.
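
A minimal sketch of point 1 — the helper name below is hypothetical, not the PR's final code — keeping the documented List[str] contract while deduplicating with a set internally:

from typing import List, Set

def _column_names_as_set() -> Set[str]:
    # Hypothetical stand-in for the private helpers in this PR that now
    # return Set[str].
    return {"column_2", "column_1"}

def column_names() -> List[str]:
    # The public contract stays List[str]; sorting gives a deterministic
    # order, since the internal set carries none.
    return sorted(_column_names_as_set())

assert column_names() == ["column_1", "column_2"]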

Contributor Author:
Agreed on all points.

Contributor Author:
Tracked in b/193044977

Contributor Author (@ivanmkc, Jul 7, 2021):
I kept the private method return types as Set, as I assume it's acceptable to make breaking changes to private methods.

Contributor Author:
@sasha-gitg made the changes.

def column_names(self) -> Set[str]:
"""Retrieve the columns for the dataset by extracting it from the Google Cloud Storage or
Google BigQuery source.

@@ -94,7 +96,7 @@ def _retrieve_gcs_source_columns(
project: str,
gcs_csv_file_path: str,
credentials: Optional[auth_credentials.Credentials] = None,
) -> List[str]:
) -> Set[str]:
"""Retrieve the columns from a comma-delimited CSV file stored on Google Cloud Storage

Example Usage:
@@ -104,7 +106,7 @@ def _retrieve_gcs_source_columns(
"gs://example-bucket/path/to/csv_file"
)

# column_names = ["column_1", "column_2"]
# column_names = {"column_1", "column_2"}

Args:
project (str):
@@ -115,8 +117,8 @@ def _retrieve_gcs_source_columns(
credentials (auth_credentials.Credentials):
Credentials to use to with GCS Client.
Returns:
List[str]
A list of columns names in the CSV file.
Set[str]
A set of column names in the CSV file.

Raises:
RuntimeError: When the retrieved CSV file is invalid.
@@ -163,15 +165,51 @@ def _retrieve_gcs_source_columns(
finally:
logger.removeFilter(logging_warning_filter)

return next(csv_reader)
return set(next(csv_reader))
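
# For context, a self-contained sketch of the header-extraction pattern used
# above, assuming the google-cloud-storage client; the function name and the
# bounded range read are illustrative, not the exact implementation elided
# from this diff.
import csv
from google.cloud import storage

def _gcs_csv_header_sketch(project, gcs_csv_file_path, credentials=None):
    client = storage.Client(project=project, credentials=credentials)
    blob = storage.Blob.from_string(gcs_csv_file_path, client=client)
    # The header row sits in the first bytes of the file, so a bounded
    # range download avoids fetching the whole object.
    first_chunk = blob.download_as_bytes(start=0, end=64 * 1024)
    first_line = first_chunk.decode("utf-8").splitlines()[0]
    return set(next(csv.reader([first_line])))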

@staticmethod
def _get_bq_schema_field_names_recursively(schema_field: SchemaField) -> Set[str]:
"""Retrieve the name for a schema field along with ancestor fields.
Nested schema fields are flattened and concatenated with a ".".
Schema fields with child fields are not included, but the children are.

Args:
schema_field (SchemaField):
Required. A BigQuery schema field, which may itself contain
nested schema fields.

Returns:
Set[str]
A set of flattened column names for the schema field, with
nested names joined by ".".
"""

nested_field_names = {
nested_field_name
for field in schema_field.fields
for nested_field_name in TabularDataset._get_bq_schema_field_names_recursively(
field
)
}

# Only "leaf" fields are returned: a field with no children yields its own
# name; otherwise each child path is prefixed with this field's name.
if len(nested_field_names) == 0:
return {schema_field.name}
else:
return {f"{schema_field.name}.{name}" for name in nested_field_names}

@staticmethod
def _retrieve_bq_source_columns(
project: str,
bq_table_uri: str,
credentials: Optional[auth_credentials.Credentials] = None,
) -> List[str]:
"""Retrieve the columns from a table on Google BigQuery
) -> Set[str]:
"""Retrieve the column names from a table on Google BigQuery
Nested schema fields are flattened and concatenated with a ".".
Schema fields with child fields are not included, but the children are.

Example Usage:

@@ -180,7 +218,7 @@ def _retrieve_bq_source_columns(
"bq://project_id.dataset.table"
)

# column_names = ["column_1", "column_2"]
# column_names = {"column_1", "column_2", "column_3.nested_field"}

Args:
project (str):
@@ -192,7 +230,7 @@ def _retrieve_bq_source_columns(
Credentials to use with BQ Client.

Returns:
List[str]
Set[str]
A set of column names in the BigQuery table.
"""

@@ -204,7 +242,14 @@ def _retrieve_bq_source_columns(
client = bigquery.Client(project=project, credentials=credentials)
table = client.get_table(bq_table_uri)
schema = table.schema
return [schema.name for schema in schema]

return {
field_name
for field in schema
for field_name in TabularDataset._get_bq_schema_field_names_recursively(
field
)
}

@classmethod
def create(
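
As a quick sanity check of the flattening behavior introduced above (a sketch — it assumes the private helper is importable as shown, and the field names are made up):

from google.cloud.aiplatform.datasets.tabular_dataset import TabularDataset
from google.cloud.bigquery.schema import SchemaField

nested = SchemaField(
    "address",
    "RECORD",
    fields=(
        SchemaField("city", "STRING"),
        SchemaField("zip", "STRING"),
    ),
)

# A RECORD's own name appears only as a prefix of its leaf paths.
assert TabularDataset._get_bq_schema_field_names_recursively(nested) == {
    "address.city",
    "address.zip",
}

# A field without children is itself a leaf.
assert TabularDataset._get_bq_schema_field_names_recursively(
    SchemaField("age", "INTEGER")
) == {"age"}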
64 changes: 63 additions & 1 deletion tests/unit/aiplatform/test_datasets.py
@@ -375,6 +375,59 @@ def bigquery_table_schema_mock():
bigquery_table_schema_mock.return_value = [
bigquery.SchemaField("column_1", "FLOAT", "NULLABLE", "", (), None),
bigquery.SchemaField("column_2", "FLOAT", "NULLABLE", "", (), None),
bigquery.SchemaField(
"column_3",
"RECORD",
"NULLABLE",
"",
(
bigquery.SchemaField(
"nested_3_1",
"RECORD",
"NULLABLE",
"",
(
bigquery.SchemaField(
"nested_3_1_1", "FLOAT", "NULLABLE", "", (), None
),
bigquery.SchemaField(
"nested_3_1_2", "FLOAT", "NULLABLE", "", (), None
),
),
None,
),
bigquery.SchemaField(
"nested_3_2", "FLOAT", "NULLABLE", "", (), None
),
bigquery.SchemaField(
"nested_3_3",
"RECORD",
"NULLABLE",
"",
(
bigquery.SchemaField(
"nested_3_3_1",
"RECORD",
"NULLABLE",
"",
(
bigquery.SchemaField(
"nested_3_3_1_1",
"FLOAT",
"NULLABLE",
"",
(),
None,
),
),
None,
),
),
None,
),
),
None,
),
]
yield bigquery_table_schema_mock

@@ -1045,7 +1098,16 @@ def test_tabular_dataset_column_name_bq_with_creds(self, bq_client_mock):
def test_tabular_dataset_column_name_bigquery(self):
my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME)

assert my_dataset.column_names == ["column_1", "column_2"]
assert my_dataset.column_names == set(
Contributor Author:
@sasha-gitg Using a set means that when I write a unit test, I don't have to know the implementation details of how the list is ordered.

There are workarounds, but using a set seems cleanest.

[
"column_1",
"column_2",
"column_3.nested_3_1.nested_3_1_1",
"column_3.nested_3_1.nested_3_1_2",
"column_3.nested_3_2",
"column_3.nested_3_3.nested_3_3_1.nested_3_3_1_1",
]
)


class TestTextDataset: