From 3300faa380dd282b2dd2ea45c77478795744ab86 Mon Sep 17 00:00:00 2001
From: ivanmkc
Date: Thu, 29 Apr 2021 17:59:44 -0400
Subject: [PATCH] Add test for AutoMLTabularTrainingJob with no column
 transformations

---
 .../aiplatform/datasets/tabular_dataset.py    | 20 +++-
 google/cloud/aiplatform/training_jobs.py      |  4 +-
 .../test_automl_tabular_training_jobs.py      | 95 ++++++++++++++++++-
 tests/unit/aiplatform/test_datasets.py        | 43 +++++++--
 4 files changed, 146 insertions(+), 16 deletions(-)

diff --git a/google/cloud/aiplatform/datasets/tabular_dataset.py b/google/cloud/aiplatform/datasets/tabular_dataset.py
index b293c75c02..687f3a51ab 100644
--- a/google/cloud/aiplatform/datasets/tabular_dataset.py
+++ b/google/cloud/aiplatform/datasets/tabular_dataset.py
@@ -41,7 +41,23 @@ class TabularDataset(datasets._Dataset):
 
     @property
     def column_names(self) -> List[str]:
-        input_config = self._gca_resource.metadata.get("inputConfig")
+        """Retrieve the column names for the dataset by extracting them from
+        the Google Cloud Storage or Google BigQuery source.
+
+        Returns:
+            List[str]
+                A list of column names.
+
+        Raises:
+            RuntimeError: When no valid source is found.
+        """
+
+        metadata = self._gca_resource.metadata
+
+        if metadata is None:
+            raise RuntimeError("No metadata found for dataset")
+
+        input_config = metadata.get("inputConfig")
 
         if input_config is None:
             raise RuntimeError("No inputConfig found for dataset")
@@ -92,8 +108,6 @@ def _retrieve_gcs_source_columns(
                 Must include "gs://" prefix.
 
         Returns:
-            str
-
             List[str]
                 A list of columns names in the CSV file.
diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py
index a7488f10e9..6488a8a6f6 100644
--- a/google/cloud/aiplatform/training_jobs.py
+++ b/google/cloud/aiplatform/training_jobs.py
@@ -2919,8 +2919,8 @@ def _run(
 
         if self._column_transformations is None:
             column_transformations = [
-                {"AUTO": {"column_name": column_name}}
-                for column_name in dataset.column_names
+                {"auto": {"column_name": column_name}}
+                for column_name in dataset.column_names
             ]
         else:
             column_transformations = self._column_transformations
diff --git a/tests/unit/aiplatform/test_automl_tabular_training_jobs.py b/tests/unit/aiplatform/test_automl_tabular_training_jobs.py
index 62cab4b3c3..2d4d4253d7 100644
--- a/tests/unit/aiplatform/test_automl_tabular_training_jobs.py
+++ b/tests/unit/aiplatform/test_automl_tabular_training_jobs.py
@@ -34,10 +34,16 @@
 _TEST_DATASET_DISPLAY_NAME = "test-dataset-display-name"
 _TEST_DATASET_NAME = "test-dataset-name"
 _TEST_DISPLAY_NAME = "test-display-name"
-_TEST_TRAINING_CONTAINER_IMAGE = "gcr.io/test-training/container:image"
 _TEST_METADATA_SCHEMA_URI_TABULAR = schema.dataset.metadata.tabular
 _TEST_METADATA_SCHEMA_URI_NONTABULAR = schema.dataset.metadata.image
 
+_TEST_TRAINING_COLUMN_NAMES = [
+    "sepal_width",
+    "sepal_length",
+    "petal_length",
+    "petal_width",
+]
+
 _TEST_TRAINING_COLUMN_TRANSFORMATIONS = [
     {"auto": {"column_name": "sepal_width"}},
     {"auto": {"column_name": "sepal_length"}},
@@ -169,7 +175,17 @@ def mock_dataset_tabular():
         name=_TEST_DATASET_NAME,
         metadata={},
     )
-    return ds
+
+    yield ds
+
+
+@pytest.fixture
+def mock_dataset_tabular_column_names(mock_dataset_tabular):
+    # `column_names` is a property, so the PropertyMock must be patched onto
+    # the mock's type, not the instance (see the mock docs on PropertyMock).
+    with mock.patch.object(
+        type(mock_dataset_tabular), "column_names", new_callable=mock.PropertyMock, create=True
+    ) as mock_dataset_tabular_column_names:
+        mock_dataset_tabular_column_names.return_value = _TEST_TRAINING_COLUMN_NAMES
+        yield mock_dataset_tabular_column_names
 
 
 @pytest.fixture
@@ -347,6 +363,81 @@ def test_run_call_pipeline_if_no_model_display_name(
             training_pipeline=true_training_pipeline,
         )
 
+    @pytest.mark.parametrize("sync", [True, False])
+    def test_run_call_pipeline_service_create_if_no_column_transformations(
+        self,
+        mock_pipeline_service_create,
+        mock_pipeline_service_get,
+        mock_dataset_tabular,
+        mock_dataset_tabular_column_names,
+        mock_model_service_get,
+        sync,
+    ):
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            staging_bucket=_TEST_BUCKET_NAME,
+            encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
+        )
+
+        job = training_jobs.AutoMLTabularTrainingJob(
+            display_name=_TEST_DISPLAY_NAME,
+            optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME,
+            optimization_prediction_type=_TEST_TRAINING_OPTIMIZATION_PREDICTION_TYPE,
+            column_transformations=None,
+            optimization_objective_recall_value=None,
+            optimization_objective_precision_value=None,
+        )
+
+        model_from_job = job.run(
+            dataset=mock_dataset_tabular,
+            target_column=_TEST_TRAINING_TARGET_COLUMN,
+            model_display_name=_TEST_MODEL_DISPLAY_NAME,
+            training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT,
+            validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT,
+            test_fraction_split=_TEST_TEST_FRACTION_SPLIT,
+            predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME,
+            weight_column=_TEST_TRAINING_WEIGHT_COLUMN,
+            budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS,
+            disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING,
+            sync=sync,
+        )
+
+        if not sync:
+            model_from_job.wait()
+
+        true_fraction_split = gca_training_pipeline.FractionSplit(
+            training_fraction=_TEST_TRAINING_FRACTION_SPLIT,
+            validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT,
+            test_fraction=_TEST_TEST_FRACTION_SPLIT,
+        )
+
+        true_managed_model = gca_model.Model(
+            display_name=_TEST_MODEL_DISPLAY_NAME,
+            encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
+        )
+
+        true_input_data_config = gca_training_pipeline.InputDataConfig(
+            fraction_split=true_fraction_split,
+            predefined_split=gca_training_pipeline.PredefinedSplit(
+                key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME
+            ),
+            dataset_id=mock_dataset_tabular.name,
+        )
+
+        true_training_pipeline = gca_training_pipeline.TrainingPipeline(
+            display_name=_TEST_DISPLAY_NAME,
+            training_task_definition=schema.training_job.definition.automl_tabular,
+            training_task_inputs=_TEST_TRAINING_TASK_INPUTS,
+            model_to_upload=true_managed_model,
+            input_data_config=true_input_data_config,
+            encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
+        )
+
+        mock_pipeline_service_create.assert_called_once_with(
+            parent=initializer.global_config.common_location_path(),
+            training_pipeline=true_training_pipeline,
+        )
+
     @pytest.mark.usefixtures(
         "mock_pipeline_service_create",
         "mock_pipeline_service_get",
diff --git a/tests/unit/aiplatform/test_datasets.py b/tests/unit/aiplatform/test_datasets.py
index db4e19d073..cbe1e046ee 100644
--- a/tests/unit/aiplatform/test_datasets.py
+++ b/tests/unit/aiplatform/test_datasets.py
@@ -193,7 +193,7 @@ def get_dataset_tabular_gcs_mock():
 
 
 @pytest.fixture
-def get_dataset_tabular_mock():
+def get_dataset_tabular_bq_mock():
     with patch.object(
         dataset_service_client.DatasetServiceClient, "get_dataset"
     ) as get_dataset_mock:
@@ -207,6 +207,21 @@ def get_dataset_tabular_bq_mock():
         yield get_dataset_mock
 
 
+@pytest.fixture
+def get_dataset_tabular_missing_metadata_mock():
+    with patch.object(
+        dataset_service_client.DatasetServiceClient, "get_dataset"
+    ) as get_dataset_mock:
+        get_dataset_mock.return_value = gca_dataset.Dataset(
+            display_name=_TEST_DISPLAY_NAME,
+            metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR,
+            metadata=None,
+            name=_TEST_NAME,
+            encryption_spec=_TEST_ENCRYPTION_SPEC,
+        )
+        yield get_dataset_mock
+
+
 @pytest.fixture
 def get_dataset_text_mock():
     with patch.object(
@@ -616,7 +631,7 @@ def test_create_then_import(
         expected_dataset.name = _TEST_NAME
         assert my_dataset._gca_resource == expected_dataset
 
-    @pytest.mark.usefixtures("get_dataset_tabular_mock")
+    @pytest.mark.usefixtures("get_dataset_tabular_bq_mock")
     @pytest.mark.parametrize("sync", [True, False])
     def test_delete_dataset(self, delete_dataset_mock, sync):
         aiplatform.init(project=_TEST_PROJECT)
@@ -643,7 +658,7 @@ def test_init_dataset_image(self, get_dataset_image_mock):
         datasets.ImageDataset(dataset_name=_TEST_NAME)
         get_dataset_image_mock.assert_called_once_with(name=_TEST_NAME)
 
-    @pytest.mark.usefixtures("get_dataset_tabular_mock")
+    @pytest.mark.usefixtures("get_dataset_tabular_bq_mock")
     def test_init_dataset_non_image(self):
         aiplatform.init(project=_TEST_PROJECT)
         with pytest.raises(ValueError):
@@ -801,10 +816,10 @@ def setup_method(self):
 
     def teardown_method(self):
         initializer.global_pool.shutdown(wait=True)
 
-    def test_init_dataset_tabular(self, get_dataset_tabular_mock):
+    def test_init_dataset_tabular(self, get_dataset_tabular_bq_mock):
 
         datasets.TabularDataset(dataset_name=_TEST_NAME)
 
-        get_dataset_tabular_mock.assert_called_once_with(name=_TEST_NAME)
+        get_dataset_tabular_bq_mock.assert_called_once_with(name=_TEST_NAME)
 
     @pytest.mark.usefixtures("get_dataset_image_mock")
     def test_init_dataset_non_tabular(self):
@@ -812,7 +827,7 @@ def test_init_dataset_non_tabular(self):
         with pytest.raises(ValueError):
             datasets.TabularDataset(dataset_name=_TEST_NAME)
 
-    @pytest.mark.usefixtures("get_dataset_tabular_mock")
+    @pytest.mark.usefixtures("get_dataset_tabular_bq_mock")
     @pytest.mark.parametrize("sync", [True, False])
     def test_create_dataset_with_default_encryption_key(
         self, create_dataset_mock, sync
@@ -841,7 +856,7 @@ def test_create_dataset_with_default_encryption_key(
             metadata=_TEST_REQUEST_METADATA,
         )
 
-    @pytest.mark.usefixtures("get_dataset_tabular_mock")
+    @pytest.mark.usefixtures("get_dataset_tabular_bq_mock")
     @pytest.mark.parametrize("sync", [True, False])
     def test_create_dataset(self, create_dataset_mock, sync):
@@ -868,7 +883,7 @@ def test_create_dataset(self, create_dataset_mock, sync):
             metadata=_TEST_REQUEST_METADATA,
         )
 
-    @pytest.mark.usefixtures("get_dataset_tabular_mock")
+    @pytest.mark.usefixtures("get_dataset_tabular_bq_mock")
     def test_no_import_data_method(self):
 
         my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME)
@@ -906,6 +921,13 @@ def test_list_dataset_no_order_or_filter(self, list_datasets_mock):
         for ds in ds_list:
             assert type(ds) == aiplatform.TabularDataset
 
+    @pytest.mark.usefixtures("get_dataset_tabular_missing_metadata_mock")
+    def test_tabular_dataset_column_name_missing_metadata(self):
+        my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME)
+
+        with pytest.raises(RuntimeError):
+            my_dataset.column_names
+
     @pytest.mark.usefixtures(
         "get_dataset_tabular_gcs_mock", "gcs_client_download_as_bytes_mock"
     )
@@ -915,7 +937,9 @@ def test_tabular_dataset_column_name_gcs(self):
 
         assert my_dataset.column_names == ["column_1", "column_2"]
 
     @pytest.mark.usefixtures(
-        "get_dataset_tabular_mock", "bigquery_client_mock", "bigquery_table_schema_mock"
+        "get_dataset_tabular_bq_mock",
+        "bigquery_client_mock",
+        "bigquery_table_schema_mock",
     )
     def test_tabular_dataset_column_name_bigquery(self):
         my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME)
@@ -1100,6 +1124,7 @@ def test_init_dataset_video(self, get_dataset_video_mock):
         datasets.VideoDataset(dataset_name=_TEST_NAME)
         get_dataset_video_mock.assert_called_once_with(name=_TEST_NAME)
 
+    @pytest.mark.usefixtures("get_dataset_tabular_bq_mock")
    def test_init_dataset_non_video(self):
         aiplatform.init(project=_TEST_PROJECT)
         with pytest.raises(ValueError):
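
Usage sketch (illustrative, not part of the patch): with this change, callers
may omit `column_transformations` and `AutoMLTabularTrainingJob._run` will
derive an `{"auto": ...}` transformation for each column reported by the
`TabularDataset.column_names` property. The project, location, dataset
resource name, and column values below are hypothetical placeholders, not
values taken from this patch.

    from google.cloud import aiplatform

    # Placeholder project and location; substitute real values.
    aiplatform.init(project="my-project", location="us-central1")

    # `column_names` is a property (no call parentheses). It raises
    # RuntimeError when the dataset has no metadata or no inputConfig.
    ds = aiplatform.TabularDataset(
        "projects/my-project/locations/us-central1/datasets/1234567890"
    )
    print(ds.column_names)

    job = aiplatform.AutoMLTabularTrainingJob(
        display_name="iris-automl",
        optimization_prediction_type="classification",
        column_transformations=None,  # exercises the auto-derivation path
    )

    model = job.run(
        dataset=ds,
        target_column="species",  # placeholder target column
        budget_milli_node_hours=1000,
    )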