Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add filter and timestamp splits #627

Merged
merged 11 commits into from Aug 18, 2021
1,181 changes: 943 additions & 238 deletions google/cloud/aiplatform/training_jobs.py

Large diffs are not rendered by default.

28 changes: 2 additions & 26 deletions tests/unit/aiplatform/test_automl_forecasting_training_jobs.py
Expand Up @@ -103,14 +103,11 @@
_TEST_DATASET_NAME = "test-dataset-name"

_TEST_MODEL_DISPLAY_NAME = "model-display-name"

_TEST_LABELS = {"key": "value"}
_TEST_MODEL_LABELS = {"model_key": "model_value"}
_TEST_TRAINING_FRACTION_SPLIT = 0.8
_TEST_VALIDATION_FRACTION_SPLIT = 0.1
_TEST_TEST_FRACTION_SPLIT = 0.1
_TEST_PREDEFINED_SPLIT_COLUMN_NAME = "split"

_TEST_OUTPUT_PYTHON_PACKAGE_PATH = "gs://test/ouput/python/trainer.tar.gz"
_TEST_PREDEFINED_SPLIT_COLUMN_NAME = "split"

_TEST_MODEL_NAME = "projects/my-project/locations/us-central1/models/12345"

Expand Down Expand Up @@ -261,18 +258,11 @@ def test_run_call_pipeline_service_create(
if not sync:
model_from_job.wait()

true_fraction_split = gca_training_pipeline.FractionSplit(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed since we test splits separately later.

training_fraction=_TEST_TRAINING_FRACTION_SPLIT,
validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT,
test_fraction=_TEST_TEST_FRACTION_SPLIT,
)

true_managed_model = gca_model.Model(
display_name=_TEST_MODEL_DISPLAY_NAME, labels=_TEST_MODEL_LABELS
)

true_input_data_config = gca_training_pipeline.InputDataConfig(
fraction_split=true_fraction_split,
predefined_split=gca_training_pipeline.PredefinedSplit(
key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME
),
Expand Down Expand Up @@ -348,19 +338,12 @@ def test_run_call_pipeline_if_no_model_display_name_nor_model_labels(
if not sync:
model_from_job.wait()

true_fraction_split = gca_training_pipeline.FractionSplit(
training_fraction=_TEST_TRAINING_FRACTION_SPLIT,
validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT,
test_fraction=_TEST_TEST_FRACTION_SPLIT,
)

# Test that if defaults to the job display name
true_managed_model = gca_model.Model(
display_name=_TEST_DISPLAY_NAME, labels=_TEST_LABELS,
)

true_input_data_config = gca_training_pipeline.InputDataConfig(
fraction_split=true_fraction_split,
dataset_id=mock_dataset_time_series.name,
)

Expand Down Expand Up @@ -422,17 +405,10 @@ def test_run_call_pipeline_if_set_additional_experiments(
if not sync:
model_from_job.wait()

true_fraction_split = gca_training_pipeline.FractionSplit(
training_fraction=_TEST_TRAINING_FRACTION_SPLIT,
validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT,
test_fraction=_TEST_TEST_FRACTION_SPLIT,
)

# Test that if defaults to the job display name
true_managed_model = gca_model.Model(display_name=_TEST_DISPLAY_NAME)

true_input_data_config = gca_training_pipeline.InputDataConfig(
fraction_split=true_fraction_split,
dataset_id=mock_dataset_time_series.name,
)

Expand Down