Skip to content

Commit

Permalink
Fixed tests/unit/aiplatform/test_automl_video_training_jobs.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ivanmkc committed Aug 17, 2021
1 parent 72e20f0 commit 91d9c20
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 16 deletions.
2 changes: 1 addition & 1 deletion google/cloud/aiplatform/training_jobs.py
Expand Up @@ -5235,7 +5235,7 @@ def _run(
training_task_inputs=training_task_inputs_dict,
dataset=dataset,
training_fraction_split=training_fraction_split,
validation_fraction_split=None,
validation_fraction_split=0,
test_fraction_split=test_fraction_split,
training_filter_split=training_filter_split,
validation_filter_split="-",
Expand Down
18 changes: 3 additions & 15 deletions tests/unit/aiplatform/test_automl_video_training_jobs.py
Expand Up @@ -58,10 +58,6 @@
_TEST_FILTER_SPLIT_VALIDATION = "-"
_TEST_FILTER_SPLIT_TEST = "test"

_TEST_SPLIT_DEFAULT = gca_training_pipeline.FractionSplit(
training_fraction=0.8, validation_fraction=0.1, test_fraction=0.1,
)

_TEST_MODEL_NAME = (
f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/models/{_TEST_MODEL_ID}"
)
Expand Down Expand Up @@ -247,20 +243,14 @@ def test_init_aiplatform_with_encryption_key_name_and_create_training_job(
if not sync:
model_from_job.wait()

true_fraction_split = gca_training_pipeline.FractionSplit(
training_fraction=_TEST_FRACTION_SPLIT_TRAINING,
validation_fraction=_TEST_FRACTION_SPLIT_VALIDATION,
test_fraction=_TEST_FRACTION_SPLIT_TEST,
)

true_managed_model = gca_model.Model(
display_name=_TEST_MODEL_DISPLAY_NAME,
description=mock_model._gca_resource.description,
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
)

true_input_data_config = gca_training_pipeline.InputDataConfig(
fraction_split=true_fraction_split, dataset_id=mock_dataset_video.name,
dataset_id=mock_dataset_video.name,
)

true_training_pipeline = gca_training_pipeline.TrainingPipeline(
Expand Down Expand Up @@ -453,16 +443,14 @@ def test_splits_default(
if not sync:
model_from_job.wait()

true_default_split = _TEST_SPLIT_DEFAULT

true_managed_model = gca_model.Model(
display_name=_TEST_MODEL_DISPLAY_NAME,
description=mock_model._gca_resource.description,
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
)

true_input_data_config = gca_training_pipeline.InputDataConfig(
fraction_split=true_default_split, dataset_id=mock_dataset_video.name,
dataset_id=mock_dataset_video.name,
)

true_training_pipeline = gca_training_pipeline.TrainingPipeline(
Expand Down Expand Up @@ -646,7 +634,7 @@ def test_run_with_two_split_raises(
model_display_name=_TEST_MODEL_DISPLAY_NAME,
training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING,
test_fraction_split=_TEST_FRACTION_SPLIT_TEST,
training_filter_split=_TEST_FILTER_SPLIT_TRAINING,
training_filter_split=_TEST_FILTER_SPLIT_TEST,
test_filter_split=_TEST_FILTER_SPLIT_TEST,
sync=sync,
)
Expand Down

0 comments on commit 91d9c20

Please sign in to comment.