diff --git a/google/cloud/aiplatform/datasets/tabular_dataset.py b/google/cloud/aiplatform/datasets/tabular_dataset.py
index 71c9d4f7d7..0e812892e4 100644
--- a/google/cloud/aiplatform/datasets/tabular_dataset.py
+++ b/google/cloud/aiplatform/datasets/tabular_dataset.py
@@ -18,7 +18,7 @@
 import csv
 import logging

-from typing import List, Optional, Sequence, Tuple, Union
+from typing import List, Optional, Sequence, Set, Tuple, Union

 from google.auth import credentials as auth_credentials

@@ -73,18 +73,24 @@ def column_names(self) -> List[str]:
             gcs_source_uris.sort()

             # Get the first file in sorted list
-            return self._retrieve_gcs_source_columns(
-                project=self.project,
-                gcs_csv_file_path=gcs_source_uris[0],
-                credentials=self.credentials,
+            # TODO(b/193044977): Return as Set instead of List
+            return list(
+                self._retrieve_gcs_source_columns(
+                    project=self.project,
+                    gcs_csv_file_path=gcs_source_uris[0],
+                    credentials=self.credentials,
+                )
             )
         elif bq_source:
             bq_table_uri = bq_source.get("uri")
             if bq_table_uri:
-                return self._retrieve_bq_source_columns(
-                    project=self.project,
-                    bq_table_uri=bq_table_uri,
-                    credentials=self.credentials,
+                # TODO(b/193044977): Return as Set instead of List
+                return list(
+                    self._retrieve_bq_source_columns(
+                        project=self.project,
+                        bq_table_uri=bq_table_uri,
+                        credentials=self.credentials,
+                    )
                 )

         raise RuntimeError("No valid CSV or BigQuery datasource found.")
@@ -94,7 +100,7 @@ def _retrieve_gcs_source_columns(
         project: str,
         gcs_csv_file_path: str,
         credentials: Optional[auth_credentials.Credentials] = None,
-    ) -> List[str]:
+    ) -> Set[str]:
         """Retrieve the columns from a comma-delimited CSV file stored on Google Cloud Storage

         Example Usage:
@@ -104,7 +110,7 @@
                 "gs://example-bucket/path/to/csv_file"
             )

-            # column_names = ["column_1", "column_2"]
+            # column_names = {"column_1", "column_2"}

         Args:
             project (str):
@@ -115,8 +121,8 @@
             credentials (auth_credentials.Credentials):
                 Credentials to use to with GCS Client.
         Returns:
-            List[str]
-                A list of columns names in the CSV file.
+            Set[str]
+                A set of column names in the CSV file.

         Raises:
             RuntimeError: When the retrieved CSV file is invalid.
@@ -163,15 +169,48 @@ def _retrieve_gcs_source_columns(
         finally:
             logger.removeFilter(logging_warning_filter)

-        return next(csv_reader)
+        return set(next(csv_reader))
+
+    @staticmethod
+    def _get_bq_schema_field_names_recursively(
+        schema_field: bigquery.SchemaField,
+    ) -> Set[str]:
+        """Retrieve the flattened names for a schema field and its nested fields.
+        Nested schema fields are flattened and concatenated with a ".".
+        Schema fields with child fields are not included, but the children are.
+
+        Args:
+            schema_field (bigquery.SchemaField):
+                Required. A BigQuery schema field to flatten.
+
+        Returns:
+            Set[str]
+                A set of flattened field names, with nested names joined by ".".
+ """ + + ancestor_names = { + nested_field_name + for field in schema_field.fields + for nested_field_name in TabularDataset._get_bq_schema_field_names_recursively( + field + ) + } + + # Only return "leaf nodes", basically any field that doesn't have children + if len(ancestor_names) == 0: + return {schema_field.name} + else: + return {f"{schema_field.name}.{name}" for name in ancestor_names} @staticmethod def _retrieve_bq_source_columns( project: str, bq_table_uri: str, credentials: Optional[auth_credentials.Credentials] = None, - ) -> List[str]: - """Retrieve the columns from a table on Google BigQuery + ) -> Set[str]: + """Retrieve the column names from a table on Google BigQuery + Nested schema fields are flattened and concatenated with a ".". + Schema fields with child fields are not included, but the children are. Example Usage: @@ -180,7 +224,7 @@ def _retrieve_bq_source_columns( "bq://project_id.dataset.table" ) - # column_names = ["column_1", "column_2"] + # column_names = {"column_1", "column_2", "column_3.nested_field"} Args: project (str): @@ -192,8 +236,8 @@ def _retrieve_bq_source_columns( Credentials to use with BQ Client. Returns: - List[str] - A list of columns names in the BigQuery table. + Set[str] + A set of column names in the BigQuery table. """ # Remove bq:// prefix @@ -204,7 +248,14 @@ def _retrieve_bq_source_columns( client = bigquery.Client(project=project, credentials=credentials) table = client.get_table(bq_table_uri) schema = table.schema - return [schema.name for schema in schema] + + return { + field_name + for field in schema + for field_name in TabularDataset._get_bq_schema_field_names_recursively( + field + ) + } @classmethod def create( diff --git a/tests/unit/aiplatform/test_datasets.py b/tests/unit/aiplatform/test_datasets.py index 5da47bea59..25b8f27b63 100644 --- a/tests/unit/aiplatform/test_datasets.py +++ b/tests/unit/aiplatform/test_datasets.py @@ -375,6 +375,59 @@ def bigquery_table_schema_mock(): bigquery_table_schema_mock.return_value = [ bigquery.SchemaField("column_1", "FLOAT", "NULLABLE", "", (), None), bigquery.SchemaField("column_2", "FLOAT", "NULLABLE", "", (), None), + bigquery.SchemaField( + "column_3", + "RECORD", + "NULLABLE", + "", + ( + bigquery.SchemaField( + "nested_3_1", + "RECORD", + "NULLABLE", + "", + ( + bigquery.SchemaField( + "nested_3_1_1", "FLOAT", "NULLABLE", "", (), None + ), + bigquery.SchemaField( + "nested_3_1_2", "FLOAT", "NULLABLE", "", (), None + ), + ), + None, + ), + bigquery.SchemaField( + "nested_3_2", "FLOAT", "NULLABLE", "", (), None + ), + bigquery.SchemaField( + "nested_3_3", + "RECORD", + "NULLABLE", + "", + ( + bigquery.SchemaField( + "nested_3_3_1", + "RECORD", + "NULLABLE", + "", + ( + bigquery.SchemaField( + "nested_3_3_1_1", + "FLOAT", + "NULLABLE", + "", + (), + None, + ), + ), + None, + ), + ), + None, + ), + ), + None, + ), ] yield bigquery_table_schema_mock @@ -1007,7 +1060,7 @@ def test_tabular_dataset_column_name_missing_datasource(self): def test_tabular_dataset_column_name_gcs(self): my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME) - assert my_dataset.column_names == ["column_1", "column_2"] + assert set(my_dataset.column_names) == {"column_1", "column_2"} @pytest.mark.usefixtures("get_dataset_tabular_gcs_mock") def test_tabular_dataset_column_name_gcs_with_creds(self, gcs_client_mock): @@ -1045,7 +1098,16 @@ def test_tabular_dataset_column_name_bq_with_creds(self, bq_client_mock): def test_tabular_dataset_column_name_bigquery(self): my_dataset = 

-        assert my_dataset.column_names == ["column_1", "column_2"]
+        assert set(my_dataset.column_names) == set(
+            [
+                "column_1",
+                "column_2",
+                "column_3.nested_3_1.nested_3_1_1",
+                "column_3.nested_3_1.nested_3_1_2",
+                "column_3.nested_3_2",
+                "column_3.nested_3_3.nested_3_3_1.nested_3_3_1_1",
+            ]
+        )


 class TestTextDataset:
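
For illustration, here is a minimal standalone sketch (not part of the patch)
of the flattening behavior the new _get_bq_schema_field_names_recursively
helper implements. flatten_field_names is a hypothetical stand-in name used
only here, and the snippet assumes google-cloud-bigquery is installed:

    from typing import Set

    from google.cloud import bigquery

    def flatten_field_names(field: bigquery.SchemaField) -> Set[str]:
        # Recurse into child fields (the .fields tuple is empty for scalar columns).
        nested = {
            name for child in field.fields for name in flatten_field_names(child)
        }
        # A leaf field contributes its own name; a RECORD field is represented
        # only through its children, with name segments joined by ".".
        if not nested:
            return {field.name}
        return {f"{field.name}.{name}" for name in nested}

    record = bigquery.SchemaField(
        "column_3",
        "RECORD",
        fields=(
            bigquery.SchemaField("nested_3_2", "FLOAT"),
            bigquery.SchemaField(
                "nested_3_1",
                "RECORD",
                fields=(bigquery.SchemaField("nested_3_1_1", "FLOAT"),),
            ),
        ),
    )

    print(flatten_field_names(record))
    # e.g. {'column_3.nested_3_2', 'column_3.nested_3_1.nested_3_1_1'}
    # (set ordering may vary)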