
Commit

Added test for AutoMLTabularTrainingJob for no transformations
ivanmkc committed Apr 29, 2021
1 parent ea5ef12 commit 3300faa
Showing 4 changed files with 146 additions and 16 deletions.
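
The diff below changes how AutoMLTabularTrainingJob builds its default per-column "auto" transformations when none are supplied, and adds a unit test covering that path. For context, a minimal usage sketch of the scenario under test; the project, bucket, dataset resource name, target column, and budget values are placeholders for illustration, not taken from this commit:

from google.cloud import aiplatform

aiplatform.init(project="my-project", staging_bucket="gs://my-staging-bucket")

# Placeholder dataset resource name.
dataset = aiplatform.TabularDataset(
    "projects/my-project/locations/us-central1/datasets/1234567890"
)

job = aiplatform.AutoMLTabularTrainingJob(
    display_name="iris-automl",
    optimization_prediction_type="classification",
    column_transformations=None,  # exercise the auto-derived transformations
)

model = job.run(
    dataset=dataset,
    target_column="species",
    budget_milli_node_hours=1000,
)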
20 changes: 17 additions & 3 deletions google/cloud/aiplatform/datasets/tabular_dataset.py
@@ -41,7 +41,23 @@ class TabularDataset(datasets._Dataset):

@property
def column_names(self) -> List[str]:
input_config = self._gca_resource.metadata.get("inputConfig")
"""Retrieve the columns for the dataset by extracting it from the Google Cloud Storage or
Google BigQuery source.
Returns:
List[str]
A list of columns names
Raises:
RuntimeError: When no valid source is found.
"""

metadata = self._gca_resource.metadata

if metadata is None:
raise RuntimeError("No metadata found for dataset")

input_config = metadata.get("inputConfig")

if input_config is None:
raise RuntimeError("No inputConfig found for dataset")
@@ -92,8 +108,6 @@ def _retrieve_gcs_source_columns(
Must include "gs://" prefix.
Returns:
str
List[str]
A list of column names in the CSV file.
4 changes: 2 additions & 2 deletions google/cloud/aiplatform/training_jobs.py
@@ -2919,8 +2919,8 @@ def _run(

if self._column_transformations is None:
column_transformations = [
{"AUTO": {"column_name": column_name}}
for column_name in dataset.column_names
{"auto": {"column_name": column_name}}
for column_name in dataset.column_names()
]
else:
column_transformations = self._column_transformations
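For illustration, a rough sketch of the default mapping that fallback builds, one "auto" entry per dataset column; the column list mirrors the test constants defined below and is not part of this hunk:

# Hypothetical column list; mirrors _TEST_TRAINING_COLUMN_NAMES in the tests.
column_names = ["sepal_width", "sepal_length", "petal_length", "petal_width"]

# One "auto" transformation per column when none are supplied explicitly.
column_transformations = [
    {"auto": {"column_name": name}} for name in column_names
]
# -> [{'auto': {'column_name': 'sepal_width'}}, ...]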
95 changes: 93 additions & 2 deletions tests/unit/aiplatform/test_automl_tabular_training_jobs.py
@@ -34,10 +34,16 @@
_TEST_DATASET_DISPLAY_NAME = "test-dataset-display-name"
_TEST_DATASET_NAME = "test-dataset-name"
_TEST_DISPLAY_NAME = "test-display-name"
_TEST_TRAINING_CONTAINER_IMAGE = "gcr.io/test-training/container:image"
_TEST_METADATA_SCHEMA_URI_TABULAR = schema.dataset.metadata.tabular
_TEST_METADATA_SCHEMA_URI_NONTABULAR = schema.dataset.metadata.image

_TEST_TRAINING_COLUMN_NAMES = [
"sepal_width",
"sepal_length",
"petal_length",
"petal_width",
]

_TEST_TRAINING_COLUMN_TRANSFORMATIONS = [
{"auto": {"column_name": "sepal_width"}},
{"auto": {"column_name": "sepal_length"}},
@@ -169,7 +175,17 @@ def mock_dataset_tabular():
name=_TEST_DATASET_NAME,
metadata={},
)
return ds

yield ds


@pytest.fixture
def mock_dataset_tabular_column_names(mock_dataset_tabular):
with mock.patch.object(
mock_dataset_tabular, "column_names", new_callable=mock.PropertyMock
) as mock_dataset_tabular_column_names:
mock_dataset_tabular_column_names.return_value = _TEST_TRAINING_COLUMN_NAMES
yield mock_dataset_tabular_column_names
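
Because column_names is exposed as a property rather than a plain method, the fixture above patches it with mock.PropertyMock. A self-contained sketch of that general pattern, using a made-up class and column values for illustration:

from unittest import mock

class Dataset:
    @property
    def column_names(self):
        raise RuntimeError("would normally read the GCS/BigQuery source")

# Properties are looked up on the class, so the PropertyMock is patched there.
with mock.patch.object(
    Dataset, "column_names", new_callable=mock.PropertyMock
) as mock_cols:
    mock_cols.return_value = ["sepal_width", "sepal_length"]
    assert Dataset().column_names == ["sepal_width", "sepal_length"]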


@pytest.fixture
@@ -347,6 +363,81 @@ def test_run_call_pipeline_if_no_model_display_name(
training_pipeline=true_training_pipeline,
)

@pytest.mark.parametrize("sync", [True, False])
def test_run_call_pipeline_service_create_if_no_column_transformations(
self,
mock_pipeline_service_create,
mock_pipeline_service_get,
mock_dataset_tabular,
mock_dataset_tabular_column_names,
mock_model_service_get,
sync,
):
aiplatform.init(
project=_TEST_PROJECT,
staging_bucket=_TEST_BUCKET_NAME,
encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
)

job = training_jobs.AutoMLTabularTrainingJob(
display_name=_TEST_DISPLAY_NAME,
optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME,
optimization_prediction_type=_TEST_TRAINING_OPTIMIZATION_PREDICTION_TYPE,
column_transformations=None,
optimization_objective_recall_value=None,
optimization_objective_precision_value=None,
)

model_from_job = job.run(
dataset=mock_dataset_tabular,
target_column=_TEST_TRAINING_TARGET_COLUMN,
model_display_name=_TEST_MODEL_DISPLAY_NAME,
training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT,
validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT,
test_fraction_split=_TEST_TEST_FRACTION_SPLIT,
predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME,
weight_column=_TEST_TRAINING_WEIGHT_COLUMN,
budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS,
disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING,
sync=sync,
)

if not sync:
model_from_job.wait()

true_fraction_split = gca_training_pipeline.FractionSplit(
training_fraction=_TEST_TRAINING_FRACTION_SPLIT,
validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT,
test_fraction=_TEST_TEST_FRACTION_SPLIT,
)

true_managed_model = gca_model.Model(
display_name=_TEST_MODEL_DISPLAY_NAME,
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
)

true_input_data_config = gca_training_pipeline.InputDataConfig(
fraction_split=true_fraction_split,
predefined_split=gca_training_pipeline.PredefinedSplit(
key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME
),
dataset_id=mock_dataset_tabular.name,
)

true_training_pipeline = gca_training_pipeline.TrainingPipeline(
display_name=_TEST_DISPLAY_NAME,
training_task_definition=schema.training_job.definition.automl_tabular,
training_task_inputs=_TEST_TRAINING_TASK_INPUTS,
model_to_upload=true_managed_model,
input_data_config=true_input_data_config,
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
)

mock_pipeline_service_create.assert_called_once_with(
parent=initializer.global_config.common_location_path(),
training_pipeline=true_training_pipeline,
)

@pytest.mark.usefixtures(
"mock_pipeline_service_create",
"mock_pipeline_service_get",
43 changes: 34 additions & 9 deletions tests/unit/aiplatform/test_datasets.py
@@ -193,7 +193,7 @@ def get_dataset_tabular_gcs_mock():


@pytest.fixture
def get_dataset_tabular_mock():
def get_dataset_tabular_bq_mock():
with patch.object(
dataset_service_client.DatasetServiceClient, "get_dataset"
) as get_dataset_mock:
@@ -207,6 +207,21 @@ def get_dataset_tabular_mock():
yield get_dataset_mock


@pytest.fixture
def get_dataset_tabular_missing_metadata_mock():
with patch.object(
dataset_service_client.DatasetServiceClient, "get_dataset"
) as get_dataset_mock:
get_dataset_mock.return_value = gca_dataset.Dataset(
display_name=_TEST_DISPLAY_NAME,
metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR,
metadata=None,
name=_TEST_NAME,
encryption_spec=_TEST_ENCRYPTION_SPEC,
)
yield get_dataset_mock


@pytest.fixture
def get_dataset_text_mock():
with patch.object(
@@ -616,7 +631,7 @@ def test_create_then_import(
expected_dataset.name = _TEST_NAME
assert my_dataset._gca_resource == expected_dataset

@pytest.mark.usefixtures("get_dataset_tabular_mock")
@pytest.mark.usefixtures("get_dataset_tabular_bq_mock")
@pytest.mark.parametrize("sync", [True, False])
def test_delete_dataset(self, delete_dataset_mock, sync):
aiplatform.init(project=_TEST_PROJECT)
@@ -643,7 +658,7 @@ def test_init_dataset_image(self, get_dataset_image_mock):
datasets.ImageDataset(dataset_name=_TEST_NAME)
get_dataset_image_mock.assert_called_once_with(name=_TEST_NAME)

@pytest.mark.usefixtures("get_dataset_tabular_mock")
@pytest.mark.usefixtures("get_dataset_tabular_bq_mock")
def test_init_dataset_non_image(self):
aiplatform.init(project=_TEST_PROJECT)
with pytest.raises(ValueError):
@@ -801,18 +816,18 @@ def setup_method(self):
def teardown_method(self):
initializer.global_pool.shutdown(wait=True)

def test_init_dataset_tabular(self, get_dataset_tabular_mock):
def test_init_dataset_tabular(self, get_dataset_tabular_bq_mock):

datasets.TabularDataset(dataset_name=_TEST_NAME)
get_dataset_tabular_mock.assert_called_once_with(name=_TEST_NAME)
get_dataset_tabular_bq_mock.assert_called_once_with(name=_TEST_NAME)

@pytest.mark.usefixtures("get_dataset_image_mock")
def test_init_dataset_non_tabular(self):

with pytest.raises(ValueError):
datasets.TabularDataset(dataset_name=_TEST_NAME)

@pytest.mark.usefixtures("get_dataset_tabular_mock")
@pytest.mark.usefixtures("get_dataset_tabular_bq_mock")
@pytest.mark.parametrize("sync", [True, False])
def test_create_dataset_with_default_encryption_key(
self, create_dataset_mock, sync
@@ -841,7 +856,7 @@ def test_create_dataset_with_default_encryption_key(
metadata=_TEST_REQUEST_METADATA,
)

@pytest.mark.usefixtures("get_dataset_tabular_mock")
@pytest.mark.usefixtures("get_dataset_tabular_bq_mock")
@pytest.mark.parametrize("sync", [True, False])
def test_create_dataset(self, create_dataset_mock, sync):

@@ -868,7 +883,7 @@ def test_create_dataset(self, create_dataset_mock, sync):
metadata=_TEST_REQUEST_METADATA,
)

@pytest.mark.usefixtures("get_dataset_tabular_mock")
@pytest.mark.usefixtures("get_dataset_tabular_bq_mock")
def test_no_import_data_method(self):

my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME)
@@ -906,6 +921,13 @@ def test_list_dataset_no_order_or_filter(self, list_datasets_mock):
for ds in ds_list:
assert type(ds) == aiplatform.TabularDataset

@pytest.mark.usefixtures("get_dataset_tabular_missing_metadata_mock")
def test_tabular_dataset_column_name_missing_metadata(self):
my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME)

with pytest.raises(RuntimeError):
my_dataset.column_names

@pytest.mark.usefixtures(
"get_dataset_tabular_gcs_mock", "gcs_client_download_as_bytes_mock"
)
@@ -915,7 +937,9 @@ def test_tabular_dataset_column_name_gcs(self):
assert my_dataset.column_names == ["column_1", "column_2"]

@pytest.mark.usefixtures(
"get_dataset_tabular_mock", "bigquery_client_mock", "bigquery_table_schema_mock"
"get_dataset_tabular_bq_mock",
"bigquery_client_mock",
"bigquery_table_schema_mock",
)
def test_tabular_dataset_column_name_bigquery(self):
my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME)
@@ -1100,6 +1124,7 @@ def test_init_dataset_video(self, get_dataset_video_mock):
datasets.VideoDataset(dataset_name=_TEST_NAME)
get_dataset_video_mock.assert_called_once_with(name=_TEST_NAME)

@pytest.mark.usefixtures("get_dataset_tabular_bq_mock")
def test_init_dataset_non_video(self):
aiplatform.init(project=_TEST_PROJECT)
with pytest.raises(ValueError):
