Skip to content

Commit

Permalink
Fixed splits
Browse files Browse the repository at this point in the history
  • Loading branch information
ivanmkc committed Aug 17, 2021
1 parent 9dcf6fb commit 07b8843
Show file tree
Hide file tree
Showing 7 changed files with 1,926 additions and 495 deletions.
966 changes: 790 additions & 176 deletions google/cloud/aiplatform/training_jobs.py

Large diffs are not rendered by default.

28 changes: 2 additions & 26 deletions tests/unit/aiplatform/test_automl_forecasting_training_jobs.py
Expand Up @@ -103,14 +103,11 @@
_TEST_DATASET_NAME = "test-dataset-name"

_TEST_MODEL_DISPLAY_NAME = "model-display-name"

_TEST_LABELS = {"key": "value"}
_TEST_MODEL_LABELS = {"model_key": "model_value"}
_TEST_TRAINING_FRACTION_SPLIT = 0.8
_TEST_VALIDATION_FRACTION_SPLIT = 0.1
_TEST_TEST_FRACTION_SPLIT = 0.1
_TEST_PREDEFINED_SPLIT_COLUMN_NAME = "split"

_TEST_OUTPUT_PYTHON_PACKAGE_PATH = "gs://test/ouput/python/trainer.tar.gz"
_TEST_PREDEFINED_SPLIT_COLUMN_NAME = "split"

_TEST_MODEL_NAME = "projects/my-project/locations/us-central1/models/12345"

Expand Down Expand Up @@ -261,18 +258,11 @@ def test_run_call_pipeline_service_create(
if not sync:
model_from_job.wait()

true_fraction_split = gca_training_pipeline.FractionSplit(
training_fraction=_TEST_TRAINING_FRACTION_SPLIT,
validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT,
test_fraction=_TEST_TEST_FRACTION_SPLIT,
)

true_managed_model = gca_model.Model(
display_name=_TEST_MODEL_DISPLAY_NAME, labels=_TEST_MODEL_LABELS
)

true_input_data_config = gca_training_pipeline.InputDataConfig(
fraction_split=true_fraction_split,
predefined_split=gca_training_pipeline.PredefinedSplit(
key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME
),
Expand Down Expand Up @@ -348,19 +338,12 @@ def test_run_call_pipeline_if_no_model_display_name_nor_model_labels(
if not sync:
model_from_job.wait()

true_fraction_split = gca_training_pipeline.FractionSplit(
training_fraction=_TEST_TRAINING_FRACTION_SPLIT,
validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT,
test_fraction=_TEST_TEST_FRACTION_SPLIT,
)

# Test that if defaults to the job display name
true_managed_model = gca_model.Model(
display_name=_TEST_DISPLAY_NAME, labels=_TEST_LABELS,
)

true_input_data_config = gca_training_pipeline.InputDataConfig(
fraction_split=true_fraction_split,
dataset_id=mock_dataset_time_series.name,
)

Expand Down Expand Up @@ -422,17 +405,10 @@ def test_run_call_pipeline_if_set_additional_experiments(
if not sync:
model_from_job.wait()

true_fraction_split = gca_training_pipeline.FractionSplit(
training_fraction=_TEST_TRAINING_FRACTION_SPLIT,
validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT,
test_fraction=_TEST_TEST_FRACTION_SPLIT,
)

# Test that if defaults to the job display name
true_managed_model = gca_model.Model(display_name=_TEST_DISPLAY_NAME)

true_input_data_config = gca_training_pipeline.InputDataConfig(
fraction_split=true_fraction_split,
dataset_id=mock_dataset_time_series.name,
)

Expand Down

0 comments on commit 07b8843

Please sign in to comment.