From 4fce8c42504c6c5b86025d728819f61284ac5eef Mon Sep 17 00:00:00 2001 From: Ivan Cheung Date: Wed, 5 May 2021 18:16:46 -0400 Subject: [PATCH] feat: Added default AutoMLTabularTrainingJob column transformations (#357) * Added default column_transformation code * Added docstrings * Added tests and moved code to tabular_dataset * Switched to using BigQuery.Table instead of custom SQL query * Fixed bigquery unit test * Added GCS test * Fixed issues with incorrect input config parameter * Added test for AutoMLTabularTrainingJob for no transformations * Added comment * Fixed test * Ran linter * Switched from classmethod to staticmethod where applicable and logged column names * Added extra dataset tests * Added logging suppression * Fixed lint errors * Switched logging filter method Co-authored-by: sasha-gitg <44654632+sasha-gitg@users.noreply.github.com> --- .../cloud/aiplatform/datasets/_datasources.py | 4 +- .../aiplatform/datasets/tabular_dataset.py | 159 +++++++++++++++++- google/cloud/aiplatform/initializer.py | 2 +- google/cloud/aiplatform/training_jobs.py | 24 ++- google/cloud/aiplatform/utils.py | 7 +- tests/system/aiplatform/test_dataset.py | 2 +- .../test_automl_tabular_training_jobs.py | 87 +++++++++- tests/unit/aiplatform/test_datasets.py | 151 +++++++++++++++-- 8 files changed, 413 insertions(+), 23 deletions(-) diff --git a/google/cloud/aiplatform/datasets/_datasources.py b/google/cloud/aiplatform/datasets/_datasources.py index 23a89cc157..a01e68c01f 100644 --- a/google/cloud/aiplatform/datasets/_datasources.py +++ b/google/cloud/aiplatform/datasets/_datasources.py @@ -86,9 +86,9 @@ def __init__( raise ValueError("One of gcs_source or bq_source must be set.") if gcs_source: - dataset_metadata = {"input_config": {"gcs_source": {"uri": gcs_source}}} + dataset_metadata = {"inputConfig": {"gcsSource": {"uri": gcs_source}}} elif bq_source: - dataset_metadata = {"input_config": {"bigquery_source": {"uri": bq_source}}} + dataset_metadata = {"inputConfig": {"bigquerySource": {"uri": bq_source}}} self._dataset_metadata = dataset_metadata diff --git a/google/cloud/aiplatform/datasets/tabular_dataset.py b/google/cloud/aiplatform/datasets/tabular_dataset.py index 06ba4a3394..b80266cf00 100644 --- a/google/cloud/aiplatform/datasets/tabular_dataset.py +++ b/google/cloud/aiplatform/datasets/tabular_dataset.py @@ -15,10 +15,16 @@ # limitations under the License. # -from typing import Optional, Sequence, Tuple, Union +import csv +import logging + +from typing import List, Optional, Sequence, Tuple, Union from google.auth import credentials as auth_credentials +from google.cloud import bigquery +from google.cloud import storage + from google.cloud.aiplatform import datasets from google.cloud.aiplatform.datasets import _datasources from google.cloud.aiplatform import initializer @@ -33,6 +39,157 @@ class TabularDataset(datasets._Dataset): schema.dataset.metadata.tabular, ) + @property + def column_names(self) -> List[str]: + """Retrieve the columns for the dataset by extracting it from the Google Cloud Storage or + Google BigQuery source. + + Returns: + List[str] + A list of columns names + + Raises: + RuntimeError: When no valid source is found. 
+ """ + + metadata = self._gca_resource.metadata + + if metadata is None: + raise RuntimeError("No metadata found for dataset") + + input_config = metadata.get("inputConfig") + + if input_config is None: + raise RuntimeError("No inputConfig found for dataset") + + gcs_source = input_config.get("gcsSource") + bq_source = input_config.get("bigquerySource") + + if gcs_source: + gcs_source_uris = gcs_source.get("uri") + + if gcs_source_uris and len(gcs_source_uris) > 0: + # Lexicographically sort the files + gcs_source_uris.sort() + + # Get the first file in sorted list + return TabularDataset._retrieve_gcs_source_columns( + self.project, gcs_source_uris[0] + ) + elif bq_source: + bq_table_uri = bq_source.get("uri") + if bq_table_uri: + return TabularDataset._retrieve_bq_source_columns( + self.project, bq_table_uri + ) + + raise RuntimeError("No valid CSV or BigQuery datasource found.") + + @staticmethod + def _retrieve_gcs_source_columns(project: str, gcs_csv_file_path: str) -> List[str]: + """Retrieve the columns from a comma-delimited CSV file stored on Google Cloud Storage + + Example Usage: + + column_names = _retrieve_gcs_source_columns( + "project_id", + "gs://example-bucket/path/to/csv_file" + ) + + # column_names = ["column_1", "column_2"] + + Args: + project (str): + Required. Project to initiate the Google Cloud Storage client with. + gcs_csv_file_path (str): + Required. A full path to a CSV files stored on Google Cloud Storage. + Must include "gs://" prefix. + + Returns: + List[str] + A list of columns names in the CSV file. + + Raises: + RuntimeError: When the retrieved CSV file is invalid. + """ + + gcs_bucket, gcs_blob = utils.extract_bucket_and_prefix_from_gcs_path( + gcs_csv_file_path + ) + client = storage.Client(project=project) + bucket = client.bucket(gcs_bucket) + blob = bucket.blob(gcs_blob) + + # Incrementally download the CSV file until the header is retrieved + first_new_line_index = -1 + start_index = 0 + increment = 1000 + line = "" + + try: + logger = logging.getLogger("google.resumable_media._helpers") + logging_warning_filter = utils.LoggingFilter(logging.INFO) + logger.addFilter(logging_warning_filter) + + while first_new_line_index == -1: + line += blob.download_as_bytes( + start=start_index, end=start_index + increment + ).decode("utf-8") + first_new_line_index = line.find("\n") + start_index += increment + + header_line = line[:first_new_line_index] + + # Split to make it an iterable + header_line = header_line.split("\n")[:1] + + csv_reader = csv.reader(header_line, delimiter=",") + except (ValueError, RuntimeError) as err: + raise RuntimeError( + "There was a problem extracting the headers from the CSV file at '{}': {}".format( + gcs_csv_file_path, err + ) + ) + finally: + logger.removeFilter(logging_warning_filter) + + return next(csv_reader) + + @staticmethod + def _retrieve_bq_source_columns(project: str, bq_table_uri: str) -> List[str]: + """Retrieve the columns from a table on Google BigQuery + + Example Usage: + + column_names = _retrieve_bq_source_columns( + "project_id", + "bq://project_id.dataset.table" + ) + + # column_names = ["column_1", "column_2"] + + Args: + project (str): + Required. Project to initiate the BigQuery client with. + bq_table_uri (str): + Required. A URI to a BigQuery table. + Can include "bq://" prefix but not required. + + Returns: + List[str] + A list of columns names in the BigQuery table. 
+        """
+
+        # Remove bq:// prefix
+        prefix = "bq://"
+        if bq_table_uri.startswith(prefix):
+            bq_table_uri = bq_table_uri[len(prefix) :]
+
+        client = bigquery.Client(project=project)
+        table = client.get_table(bq_table_uri)
+        schema = table.schema
+        return [field.name for field in schema]
+
     @classmethod
     def create(
         cls,
diff --git a/google/cloud/aiplatform/initializer.py b/google/cloud/aiplatform/initializer.py
index 41a3b06d7f..fa00fa3387 100644
--- a/google/cloud/aiplatform/initializer.py
+++ b/google/cloud/aiplatform/initializer.py
@@ -170,7 +170,7 @@ def credentials(self) -> Optional[auth_credentials.Credentials]:
         if self._credentials:
             return self._credentials
         logger = logging.getLogger("google.auth._default")
-        logging_warning_filter = utils.LoggingWarningFilter()
+        logging_warning_filter = utils.LoggingFilter(logging.WARNING)
         logger.addFilter(logging_warning_filter)
         credentials, _ = google.auth.default()
         logger.removeFilter(logging_warning_filter)
diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py
index 441f91ca39..ccd0ca7be6 100644
--- a/google/cloud/aiplatform/training_jobs.py
+++ b/google/cloud/aiplatform/training_jobs.py
@@ -130,7 +130,6 @@ def __init__(
         super().__init__(project=project, location=location, credentials=credentials)
         self._display_name = display_name
-        self._project = project
         self._training_encryption_spec = initializer.global_config.get_encryption_spec(
             encryption_spec_key_name=training_encryption_spec_key_name
         )
@@ -2955,10 +2954,31 @@ def _run(
         training_task_definition = schema.training_job.definition.automl_tabular
 
+        if self._column_transformations is None:
+            _LOGGER.info(
+                "No column transformations provided, so now retrieving columns from dataset in order to set default column transformations."
+            )
+
+            column_names = [
+                column_name
+                for column_name in dataset.column_names
+                if column_name != target_column
+            ]
+            column_transformations = [
+                {"auto": {"column_name": column_name}} for column_name in column_names
+            ]
+
+            _LOGGER.info(
+                "The column transformation of type 'auto' was set for the following columns: %s."
+ % column_names + ) + else: + column_transformations = self._column_transformations + training_task_inputs_dict = { # required inputs "targetColumn": target_column, - "transformations": self._column_transformations, + "transformations": column_transformations, "trainBudgetMilliNodeHours": budget_milli_node_hours, # optional inputs "weightColumnName": weight_column, diff --git a/google/cloud/aiplatform/utils.py b/google/cloud/aiplatform/utils.py index 22991290da..64f7a29671 100644 --- a/google/cloud/aiplatform/utils.py +++ b/google/cloud/aiplatform/utils.py @@ -468,6 +468,9 @@ class PredictionClientWithOverride(ClientWithOverride): ) -class LoggingWarningFilter(logging.Filter): +class LoggingFilter(logging.Filter): + def __init__(self, warning_level: int): + self._warning_level = warning_level + def filter(self, record): - return record.levelname == logging.WARNING + return record.levelname == self._warning_level diff --git a/tests/system/aiplatform/test_dataset.py b/tests/system/aiplatform/test_dataset.py index e18390a76a..e852933dc3 100644 --- a/tests/system/aiplatform/test_dataset.py +++ b/tests/system/aiplatform/test_dataset.py @@ -25,8 +25,8 @@ from google.api_core import exceptions from google.api_core import client_options -from google.cloud import storage from google.cloud import aiplatform +from google.cloud import storage from google.cloud.aiplatform import utils from google.cloud.aiplatform import initializer from google.cloud.aiplatform_v1beta1.types import dataset as gca_dataset diff --git a/tests/unit/aiplatform/test_automl_tabular_training_jobs.py b/tests/unit/aiplatform/test_automl_tabular_training_jobs.py index 62cab4b3c3..aea5b66d5f 100644 --- a/tests/unit/aiplatform/test_automl_tabular_training_jobs.py +++ b/tests/unit/aiplatform/test_automl_tabular_training_jobs.py @@ -34,10 +34,16 @@ _TEST_DATASET_DISPLAY_NAME = "test-dataset-display-name" _TEST_DATASET_NAME = "test-dataset-name" _TEST_DISPLAY_NAME = "test-display-name" -_TEST_TRAINING_CONTAINER_IMAGE = "gcr.io/test-training/container:image" _TEST_METADATA_SCHEMA_URI_TABULAR = schema.dataset.metadata.tabular _TEST_METADATA_SCHEMA_URI_NONTABULAR = schema.dataset.metadata.image +_TEST_TRAINING_COLUMN_NAMES = [ + "sepal_width", + "sepal_length", + "petal_length", + "petal_width", +] + _TEST_TRAINING_COLUMN_TRANSFORMATIONS = [ {"auto": {"column_name": "sepal_width"}}, {"auto": {"column_name": "sepal_length"}}, @@ -169,7 +175,9 @@ def mock_dataset_tabular(): name=_TEST_DATASET_NAME, metadata={}, ) - return ds + ds.column_names = _TEST_TRAINING_COLUMN_NAMES + + yield ds @pytest.fixture @@ -347,6 +355,81 @@ def test_run_call_pipeline_if_no_model_display_name( training_pipeline=true_training_pipeline, ) + @pytest.mark.parametrize("sync", [True, False]) + # This test checks that default transformations are used if no columns transformations are provided + def test_run_call_pipeline_service_create_if_no_column_transformations( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_dataset_tabular, + mock_model_service_get, + sync, + ): + aiplatform.init( + project=_TEST_PROJECT, + staging_bucket=_TEST_BUCKET_NAME, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + job = training_jobs.AutoMLTabularTrainingJob( + display_name=_TEST_DISPLAY_NAME, + optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, + optimization_prediction_type=_TEST_TRAINING_OPTIMIZATION_PREDICTION_TYPE, + column_transformations=None, + optimization_objective_recall_value=None, + 
optimization_objective_precision_value=None, + ) + + model_from_job = job.run( + dataset=mock_dataset_tabular, + target_column=_TEST_TRAINING_TARGET_COLUMN, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME, + weight_column=_TEST_TRAINING_WEIGHT_COLUMN, + budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, + disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction=_TEST_TEST_FRACTION_SPLIT, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, + predefined_split=gca_training_pipeline.PredefinedSplit( + key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME + ), + dataset_id=mock_dataset_tabular.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_tabular, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + @pytest.mark.usefixtures( "mock_pipeline_service_create", "mock_pipeline_service_get", diff --git a/tests/unit/aiplatform/test_datasets.py b/tests/unit/aiplatform/test_datasets.py index 918f753dbf..6b67d67a20 100644 --- a/tests/unit/aiplatform/test_datasets.py +++ b/tests/unit/aiplatform/test_datasets.py @@ -28,6 +28,8 @@ from google.auth import credentials as auth_credentials from google.cloud import aiplatform +from google.cloud import bigquery +from google.cloud import storage from google.cloud.aiplatform import datasets from google.cloud.aiplatform import initializer @@ -86,7 +88,7 @@ "gs://my-bucket/index_file_2.jsonl", "gs://my-bucket/index_file_3.jsonl", ] -_TEST_SOURCE_URI_BQ = "bigquery://my-project/my-dataset" +_TEST_SOURCE_URI_BQ = "bq://my-project.my-dataset.table" _TEST_INVALID_SOURCE_URIS = ["gs://my-bucket/index_file_1.jsonl", 123] # request_metadata @@ -95,10 +97,10 @@ # dataset_metadata _TEST_NONTABULAR_DATASET_METADATA = None _TEST_METADATA_TABULAR_GCS = { - "input_config": {"gcs_source": {"uri": [_TEST_SOURCE_URI_GCS]}} + "inputConfig": {"gcsSource": {"uri": [_TEST_SOURCE_URI_GCS]}} } _TEST_METADATA_TABULAR_BQ = { - "input_config": {"bigquery_source": {"uri": _TEST_SOURCE_URI_BQ}} + "inputConfig": {"bigquerySource": {"uri": _TEST_SOURCE_URI_BQ}} } # CMEK encryption @@ -176,7 +178,22 @@ def get_dataset_image_mock(): @pytest.fixture -def get_dataset_tabular_mock(): +def get_dataset_tabular_gcs_mock(): + with patch.object( + dataset_service_client.DatasetServiceClient, "get_dataset" + ) as get_dataset_mock: + get_dataset_mock.return_value = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR, + metadata=_TEST_METADATA_TABULAR_GCS, + 
name=_TEST_NAME, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + yield get_dataset_mock + + +@pytest.fixture +def get_dataset_tabular_bq_mock(): with patch.object( dataset_service_client.DatasetServiceClient, "get_dataset" ) as get_dataset_mock: @@ -190,6 +207,51 @@ def get_dataset_tabular_mock(): yield get_dataset_mock +@pytest.fixture +def get_dataset_tabular_missing_metadata_mock(): + with patch.object( + dataset_service_client.DatasetServiceClient, "get_dataset" + ) as get_dataset_mock: + get_dataset_mock.return_value = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR, + metadata=None, + name=_TEST_NAME, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + yield get_dataset_mock + + +@pytest.fixture +def get_dataset_tabular_missing_input_config_mock(): + with patch.object( + dataset_service_client.DatasetServiceClient, "get_dataset" + ) as get_dataset_mock: + get_dataset_mock.return_value = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR, + metadata={}, + name=_TEST_NAME, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + yield get_dataset_mock + + +@pytest.fixture +def get_dataset_tabular_missing_datasource_mock(): + with patch.object( + dataset_service_client.DatasetServiceClient, "get_dataset" + ) as get_dataset_mock: + get_dataset_mock.return_value = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR, + metadata={"inputConfig": {}}, + name=_TEST_NAME, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + yield get_dataset_mock + + @pytest.fixture def get_dataset_text_mock(): with patch.object( @@ -276,6 +338,32 @@ def list_datasets_mock(): yield list_datasets_mock +@pytest.fixture +def gcs_client_download_as_bytes_mock(): + with patch.object(storage.Blob, "download_as_bytes") as bigquery_blob_mock: + bigquery_blob_mock.return_value = b'"column_1","column_2"\n0, 1' + yield bigquery_blob_mock + + +@pytest.fixture +def bigquery_client_mock(): + with patch.object(bigquery.Client, "get_table") as bigquery_client_mock: + bigquery_client_mock.return_value = bigquery.Table("project.dataset.table") + yield bigquery_client_mock + + +@pytest.fixture +def bigquery_table_schema_mock(): + with patch.object( + bigquery.Table, "schema", new_callable=mock.PropertyMock + ) as bigquery_table_schema_mock: + bigquery_table_schema_mock.return_value = [ + bigquery.SchemaField("column_1", "FLOAT", "NULLABLE", "", (), None), + bigquery.SchemaField("column_2", "FLOAT", "NULLABLE", "", (), None), + ] + yield bigquery_table_schema_mock + + # TODO(b/171333554): Move reusable test fixtures to conftest.py file class TestDataset: def setup_method(self): @@ -573,7 +661,7 @@ def test_create_then_import( expected_dataset.name = _TEST_NAME assert my_dataset._gca_resource == expected_dataset - @pytest.mark.usefixtures("get_dataset_tabular_mock") + @pytest.mark.usefixtures("get_dataset_tabular_bq_mock") @pytest.mark.parametrize("sync", [True, False]) def test_delete_dataset(self, delete_dataset_mock, sync): aiplatform.init(project=_TEST_PROJECT) @@ -600,7 +688,7 @@ def test_init_dataset_image(self, get_dataset_image_mock): datasets.ImageDataset(dataset_name=_TEST_NAME) get_dataset_image_mock.assert_called_once_with(name=_TEST_NAME) - @pytest.mark.usefixtures("get_dataset_tabular_mock") + @pytest.mark.usefixtures("get_dataset_tabular_bq_mock") def test_init_dataset_non_image(self): aiplatform.init(project=_TEST_PROJECT) with pytest.raises(ValueError): 
@@ -758,10 +846,10 @@ def setup_method(self): def teardown_method(self): initializer.global_pool.shutdown(wait=True) - def test_init_dataset_tabular(self, get_dataset_tabular_mock): + def test_init_dataset_tabular(self, get_dataset_tabular_bq_mock): datasets.TabularDataset(dataset_name=_TEST_NAME) - get_dataset_tabular_mock.assert_called_once_with(name=_TEST_NAME) + get_dataset_tabular_bq_mock.assert_called_once_with(name=_TEST_NAME) @pytest.mark.usefixtures("get_dataset_image_mock") def test_init_dataset_non_tabular(self): @@ -769,7 +857,7 @@ def test_init_dataset_non_tabular(self): with pytest.raises(ValueError): datasets.TabularDataset(dataset_name=_TEST_NAME) - @pytest.mark.usefixtures("get_dataset_tabular_mock") + @pytest.mark.usefixtures("get_dataset_tabular_bq_mock") @pytest.mark.parametrize("sync", [True, False]) def test_create_dataset_with_default_encryption_key( self, create_dataset_mock, sync @@ -798,7 +886,7 @@ def test_create_dataset_with_default_encryption_key( metadata=_TEST_REQUEST_METADATA, ) - @pytest.mark.usefixtures("get_dataset_tabular_mock") + @pytest.mark.usefixtures("get_dataset_tabular_bq_mock") @pytest.mark.parametrize("sync", [True, False]) def test_create_dataset(self, create_dataset_mock, sync): @@ -825,7 +913,7 @@ def test_create_dataset(self, create_dataset_mock, sync): metadata=_TEST_REQUEST_METADATA, ) - @pytest.mark.usefixtures("get_dataset_tabular_mock") + @pytest.mark.usefixtures("get_dataset_tabular_bq_mock") def test_no_import_data_method(self): my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME) @@ -863,6 +951,45 @@ def test_list_dataset_no_order_or_filter(self, list_datasets_mock): for ds in ds_list: assert type(ds) == aiplatform.TabularDataset + @pytest.mark.usefixtures("get_dataset_tabular_missing_metadata_mock") + def test_tabular_dataset_column_name_missing_metadata(self): + my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME) + + with pytest.raises(RuntimeError): + my_dataset.column_names + + @pytest.mark.usefixtures("get_dataset_tabular_missing_input_config_mock") + def test_tabular_dataset_column_name_missing_input_config(self): + my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME) + + with pytest.raises(RuntimeError): + my_dataset.column_names + + @pytest.mark.usefixtures("get_dataset_tabular_missing_datasource_mock") + def test_tabular_dataset_column_name_missing_datasource(self): + my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME) + + with pytest.raises(RuntimeError): + my_dataset.column_names + + @pytest.mark.usefixtures( + "get_dataset_tabular_gcs_mock", "gcs_client_download_as_bytes_mock" + ) + def test_tabular_dataset_column_name_gcs(self): + my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME) + + assert my_dataset.column_names == ["column_1", "column_2"] + + @pytest.mark.usefixtures( + "get_dataset_tabular_bq_mock", + "bigquery_client_mock", + "bigquery_table_schema_mock", + ) + def test_tabular_dataset_column_name_bigquery(self): + my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME) + + assert my_dataset.column_names == ["column_1", "column_2"] + class TestTextDataset: def setup_method(self): @@ -1041,7 +1168,7 @@ def test_init_dataset_video(self, get_dataset_video_mock): datasets.VideoDataset(dataset_name=_TEST_NAME) get_dataset_video_mock.assert_called_once_with(name=_TEST_NAME) - @pytest.mark.usefixtures("get_dataset_tabular_mock") + @pytest.mark.usefixtures("get_dataset_tabular_bq_mock") def test_init_dataset_non_video(self): aiplatform.init(project=_TEST_PROJECT) with 
pytest.raises(ValueError):
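
Usage sketch of the behavior this change introduces, assuming the google-cloud-aiplatform SDK at this revision. The project ID, dataset resource name, target column, and display name below are placeholder values, not values taken from the patch:

    from google.cloud import aiplatform

    aiplatform.init(project="example-project", location="us-central1")

    # Any TabularDataset backed by a GCS CSV or a BigQuery table works; the new
    # column_names property reads the CSV header or the BigQuery table schema.
    dataset = aiplatform.TabularDataset(
        "projects/example-project/locations/us-central1/datasets/1234567890"
    )
    print(dataset.column_names)

    # column_transformations is omitted, so the job now defaults every
    # non-target column to an {"auto": {"column_name": ...}} transformation.
    job = aiplatform.AutoMLTabularTrainingJob(
        display_name="example-training-job",
        optimization_prediction_type="classification",
    )

    model = job.run(
        dataset=dataset,
        target_column="species",  # placeholder target column
        budget_milli_node_hours=1000,
    )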