Skip to content

Commit

Permalink
Fixed test_automl_tabular_training_jobs.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ivanmkc committed Aug 17, 2021
1 parent 78872ab commit b7eac85
Showing 1 changed file with 8 additions and 28 deletions.
36 changes: 8 additions & 28 deletions tests/unit/aiplatform/test_automl_tabular_training_jobs.py
Expand Up @@ -125,12 +125,8 @@
_TEST_FRACTION_SPLIT_VALIDATION = 0.2
_TEST_FRACTION_SPLIT_TEST = 0.2

_TEST_SPLIT_DEFAULT = gca_training_pipeline.FractionSplit(
training_fraction=0.8, validation_fraction=0.1, test_fraction=0.1,
)

_TEST_SPLIT_PREDEFINED_COLUMN_NAME = "split"
_TEST_SPLIT_PREDEFINED_COLUMN_NAME = "timestamp"
_TEST_SPLIT_TIMESTAMP_COLUMN_NAME = "timestamp"

_TEST_OUTPUT_PYTHON_PACKAGE_PATH = "gs://test/ouput/python/trainer.tar.gz"

Expand Down Expand Up @@ -307,7 +303,7 @@ def test_run_call_pipeline_service_create(
training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING,
validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION,
test_fraction_split=_TEST_FRACTION_SPLIT_TEST,
timestamp_split_column_name=_TEST_SPLIT_PREDEFINED_COLUMN_NAME,
timestamp_split_column_name=_TEST_SPLIT_TIMESTAMP_COLUMN_NAME,
weight_column=_TEST_TRAINING_WEIGHT_COLUMN,
budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS,
disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING,
Expand All @@ -321,7 +317,7 @@ def test_run_call_pipeline_service_create(
training_fraction=_TEST_FRACTION_SPLIT_TRAINING,
validation_fraction=_TEST_FRACTION_SPLIT_VALIDATION,
test_fraction=_TEST_FRACTION_SPLIT_TEST,
key=_TEST_SPLIT_PREDEFINED_COLUMN_NAME,
key=_TEST_SPLIT_TIMESTAMP_COLUMN_NAME,
)

true_managed_model = gca_model.Model(
Expand Down Expand Up @@ -392,15 +388,13 @@ def test_run_call_pipeline_if_no_model_display_name(
if not sync:
model_from_job.wait()

true_fraction_split = _TEST_SPLIT_DEFAULT

# Test that if defaults to the job display name
true_managed_model = gca_model.Model(
display_name=_TEST_DISPLAY_NAME, encryption_spec=_TEST_MODEL_ENCRYPTION_SPEC
)

true_input_data_config = gca_training_pipeline.InputDataConfig(
fraction_split=true_fraction_split, dataset_id=mock_dataset_tabular.name,
dataset_id=mock_dataset_tabular.name,
)

true_training_pipeline = gca_training_pipeline.TrainingPipeline(
Expand Down Expand Up @@ -527,19 +521,13 @@ def test_run_call_pipeline_service_create_if_set_additional_experiments(
if not sync:
model_from_job.wait()

true_fraction_split = gca_training_pipeline.FractionSplit(
training_fraction=_TEST_FRACTION_SPLIT_TRAINING,
validation_fraction=_TEST_FRACTION_SPLIT_VALIDATION,
test_fraction=_TEST_FRACTION_SPLIT_TEST,
)

true_managed_model = gca_model.Model(
display_name=_TEST_MODEL_DISPLAY_NAME,
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
)

true_input_data_config = gca_training_pipeline.InputDataConfig(
fraction_split=true_fraction_split, dataset_id=mock_dataset_tabular.name,
dataset_id=mock_dataset_tabular.name,
)

true_training_pipeline = gca_training_pipeline.TrainingPipeline(
Expand Down Expand Up @@ -860,7 +848,6 @@ def test_splits_fraction(
mock_pipeline_service_get,
mock_dataset_tabular,
mock_model_service_get,
mock_model,
sync,
):
"""
Expand Down Expand Up @@ -902,7 +889,6 @@ def test_splits_fraction(

true_managed_model = gca_model.Model(
display_name=_TEST_MODEL_DISPLAY_NAME,
description=mock_model._gca_resource.description,
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
)

Expand Down Expand Up @@ -931,7 +917,6 @@ def test_splits_timestamp(
mock_pipeline_service_get,
mock_dataset_tabular,
mock_model_service_get,
mock_model,
sync,
):
"""
Expand Down Expand Up @@ -975,7 +960,6 @@ def test_splits_timestamp(

true_managed_model = gca_model.Model(
display_name=_TEST_MODEL_DISPLAY_NAME,
description=mock_model._gca_resource.description,
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
)

Expand Down Expand Up @@ -1004,7 +988,6 @@ def test_splits_predefined(
mock_pipeline_service_get,
mock_dataset_tabular,
mock_model_service_get,
mock_model,
sync,
):
"""
Expand Down Expand Up @@ -1043,7 +1026,6 @@ def test_splits_predefined(

true_managed_model = gca_model.Model(
display_name=_TEST_MODEL_DISPLAY_NAME,
description=mock_model._gca_resource.description,
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
)

Expand Down Expand Up @@ -1072,7 +1054,6 @@ def test_splits_default(
mock_pipeline_service_get,
mock_dataset_tabular,
mock_model_service_get,
mock_model,
sync,
):
"""
Expand All @@ -1096,6 +1077,8 @@ def test_splits_default(

model_from_job = job.run(
dataset=mock_dataset_tabular,
target_column=_TEST_TRAINING_TARGET_COLUMN,
weight_column=_TEST_TRAINING_WEIGHT_COLUMN,
model_display_name=_TEST_MODEL_DISPLAY_NAME,
disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING,
sync=sync,
Expand All @@ -1104,16 +1087,13 @@ def test_splits_default(
if not sync:
model_from_job.wait()

true_default_split = _TEST_SPLIT_DEFAULT

true_managed_model = gca_model.Model(
display_name=_TEST_MODEL_DISPLAY_NAME,
description=mock_model._gca_resource.description,
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
)

true_input_data_config = gca_training_pipeline.InputDataConfig(
fraction_split=true_default_split, dataset_id=mock_dataset_tabular.name,
dataset_id=mock_dataset_tabular.name,
)

true_training_pipeline = gca_training_pipeline.TrainingPipeline(
Expand Down

0 comments on commit b7eac85

Please sign in to comment.