Skip to content

Commit

Permalink
feat: add filter and timestamp splits (#627)
Browse files Browse the repository at this point in the history
* Fixed splits

* Fixed docstrings

* Fix test bug

* Ran linter

* Fixed FractionSplit and AutoMLVideo FilterSplit issues

* Added warning for incomplete filter splits

* Fixed AutoMLVideo tests

* Fixed type

* Moved annotation_schema_uri

* Tweaked docstrings

Co-authored-by: sasha-gitg <44654632+sasha-gitg@users.noreply.github.com>
  • Loading branch information
ivanmkc and sasha-gitg committed Aug 18, 2021
1 parent 74f81e6 commit 1a13577
Show file tree
Hide file tree
Showing 7 changed files with 2,115 additions and 556 deletions.
1,181 changes: 943 additions & 238 deletions google/cloud/aiplatform/training_jobs.py

Large diffs are not rendered by default.

28 changes: 2 additions & 26 deletions tests/unit/aiplatform/test_automl_forecasting_training_jobs.py
Expand Up @@ -103,14 +103,11 @@
_TEST_DATASET_NAME = "test-dataset-name"

_TEST_MODEL_DISPLAY_NAME = "model-display-name"

_TEST_LABELS = {"key": "value"}
_TEST_MODEL_LABELS = {"model_key": "model_value"}
_TEST_TRAINING_FRACTION_SPLIT = 0.8
_TEST_VALIDATION_FRACTION_SPLIT = 0.1
_TEST_TEST_FRACTION_SPLIT = 0.1
_TEST_PREDEFINED_SPLIT_COLUMN_NAME = "split"

_TEST_OUTPUT_PYTHON_PACKAGE_PATH = "gs://test/ouput/python/trainer.tar.gz"
_TEST_PREDEFINED_SPLIT_COLUMN_NAME = "split"

_TEST_MODEL_NAME = "projects/my-project/locations/us-central1/models/12345"

Expand Down Expand Up @@ -261,18 +258,11 @@ def test_run_call_pipeline_service_create(
if not sync:
model_from_job.wait()

true_fraction_split = gca_training_pipeline.FractionSplit(
training_fraction=_TEST_TRAINING_FRACTION_SPLIT,
validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT,
test_fraction=_TEST_TEST_FRACTION_SPLIT,
)

true_managed_model = gca_model.Model(
display_name=_TEST_MODEL_DISPLAY_NAME, labels=_TEST_MODEL_LABELS
)

true_input_data_config = gca_training_pipeline.InputDataConfig(
fraction_split=true_fraction_split,
predefined_split=gca_training_pipeline.PredefinedSplit(
key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME
),
Expand Down Expand Up @@ -348,19 +338,12 @@ def test_run_call_pipeline_if_no_model_display_name_nor_model_labels(
if not sync:
model_from_job.wait()

true_fraction_split = gca_training_pipeline.FractionSplit(
training_fraction=_TEST_TRAINING_FRACTION_SPLIT,
validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT,
test_fraction=_TEST_TEST_FRACTION_SPLIT,
)

# Test that it defaults to the job display name
true_managed_model = gca_model.Model(
display_name=_TEST_DISPLAY_NAME, labels=_TEST_LABELS,
)

true_input_data_config = gca_training_pipeline.InputDataConfig(
fraction_split=true_fraction_split,
dataset_id=mock_dataset_time_series.name,
)

Expand Down Expand Up @@ -422,17 +405,10 @@ def test_run_call_pipeline_if_set_additional_experiments(
if not sync:
model_from_job.wait()

true_fraction_split = gca_training_pipeline.FractionSplit(
training_fraction=_TEST_TRAINING_FRACTION_SPLIT,
validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT,
test_fraction=_TEST_TEST_FRACTION_SPLIT,
)

# Test that it defaults to the job display name
true_managed_model = gca_model.Model(display_name=_TEST_DISPLAY_NAME)

true_input_data_config = gca_training_pipeline.InputDataConfig(
fraction_split=true_fraction_split,
dataset_id=mock_dataset_time_series.name,
)

Expand Down

0 comments on commit 1a13577

Please sign in to comment.