From a5fa7a224570901988e5e7579c46cc2b823caa9b Mon Sep 17 00:00:00 2001
From: Morgan Du
Date: Tue, 20 Apr 2021 14:22:43 -0700
Subject: [PATCH] fix: default model_display_name to
 _CustomTrainingJob.display_name when model_serving_container_image_uri is
 provided (#324)

---
 google/cloud/aiplatform/training_jobs.py    |  11 ++
 tests/unit/aiplatform/test_training_jobs.py | 150 ++++++++++++++++++++
 2 files changed, 161 insertions(+)

diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py
index 220a34637e..5a12c3286b 100644
--- a/google/cloud/aiplatform/training_jobs.py
+++ b/google/cloud/aiplatform/training_jobs.py
@@ -1455,6 +1455,8 @@ def _prepare_and_validate_run(
                 If the script produces a managed AI Platform Model. The display name of
                 the Model. The name can be up to 128 characters long and can be consist
                 of any UTF-8 characters.
+
+                If not provided upon creation, the job's display_name is used.
             replica_count (int):
                 The number of worker replicas. If replica count = 1 then one chief
                 replica will be provisioned. If replica_count > 1 the remainder will be
@@ -1491,6 +1493,9 @@
                 """
             )
 
+        if self._managed_model.container_spec.image_uri:
+            model_display_name = model_display_name or self._display_name + "-model"
+
         # validates args and will raise
         worker_pool_specs = _DistributedTrainingSpec.chief_worker_pool(
             replica_count=replica_count,
@@ -1854,6 +1859,8 @@ def run(
                 If the script produces a managed AI Platform Model. The display name of
                 the Model. The name can be up to 128 characters long and can be consist
                 of any UTF-8 characters.
+
+                If not provided upon creation, the job's display_name is used.
             base_output_dir (str):
                 GCS output directory of job. If not provided a
                 timestamped directory in the staging directory will be used.
@@ -2371,6 +2378,8 @@ def run(
                 If the script produces a managed AI Platform Model. The display name of
                 the Model. The name can be up to 128 characters long and can be consist
                 of any UTF-8 characters.
+
+                If not provided upon creation, the job's display_name is used.
             base_output_dir (str):
                 GCS output directory of job. If not provided a
                 timestamped directory in the staging directory will be used.
@@ -3636,6 +3645,8 @@ def run(
                 If the script produces a managed AI Platform Model. The display name of
                 the Model. The name can be up to 128 characters long and can be consist
                 of any UTF-8 characters.
+
+                If not provided upon creation, the job's display_name is used.
             base_output_dir (str):
                 GCS output directory of job. If not provided a
                 timestamped directory in the staging directory will be used.
diff --git a/tests/unit/aiplatform/test_training_jobs.py b/tests/unit/aiplatform/test_training_jobs.py
index b5520a5f4c..d63b028445 100644
--- a/tests/unit/aiplatform/test_training_jobs.py
+++ b/tests/unit/aiplatform/test_training_jobs.py
@@ -3043,6 +3043,156 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
 
         assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
 
+    @pytest.mark.parametrize("sync", [True, False])
+    def test_run_call_pipeline_service_create_with_tabular_dataset_without_model_display_name(
+        self,
+        mock_pipeline_service_create,
+        mock_pipeline_service_get,
+        mock_tabular_dataset,
+        mock_model_service_get,
+        sync,
+    ):
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            staging_bucket=_TEST_BUCKET_NAME,
+            encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
+        )
+
+        job = training_jobs.CustomPythonPackageTrainingJob(
+            display_name=_TEST_DISPLAY_NAME,
+            python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH,
+            python_module_name=_TEST_PYTHON_MODULE_NAME,
+            container_uri=_TEST_TRAINING_CONTAINER_IMAGE,
+            model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE,
+            model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE,
+            model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE,
+            model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND,
+            model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS,
+            model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES,
+            model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS,
+            model_description=_TEST_MODEL_DESCRIPTION,
+            model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI,
+            model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI,
+            model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI,
+        )
+
+        model_from_job = job.run(
+            dataset=mock_tabular_dataset,
+            # model_display_name=_TEST_MODEL_DISPLAY_NAME,
+            base_output_dir=_TEST_BASE_OUTPUT_DIR,
+            args=_TEST_RUN_ARGS,
+            replica_count=1,
+            machine_type=_TEST_MACHINE_TYPE,
+            accelerator_type=_TEST_ACCELERATOR_TYPE,
+            accelerator_count=_TEST_ACCELERATOR_COUNT,
+            training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT,
+            validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT,
+            test_fraction_split=_TEST_TEST_FRACTION_SPLIT,
+            predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME,
+            sync=sync,
+        )
+
+        if not sync:
+            model_from_job.wait()
+
+        true_args = _TEST_RUN_ARGS
+
+        true_worker_pool_spec = {
+            "replicaCount": _TEST_REPLICA_COUNT,
+            "machineSpec": {
+                "machineType": _TEST_MACHINE_TYPE,
+                "acceleratorType": _TEST_ACCELERATOR_TYPE,
+                "acceleratorCount": _TEST_ACCELERATOR_COUNT,
+            },
+            "pythonPackageSpec": {
+                "executorImageUri": _TEST_TRAINING_CONTAINER_IMAGE,
+                "pythonModule": _TEST_PYTHON_MODULE_NAME,
+                "packageUris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH],
+                "args": true_args,
+            },
+        }
+
+        true_fraction_split = gca_training_pipeline.FractionSplit(
+            training_fraction=_TEST_TRAINING_FRACTION_SPLIT,
+            validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT,
+            test_fraction=_TEST_TEST_FRACTION_SPLIT,
+        )
+
+        env = [
+            gca_env_var.EnvVar(name=str(key), value=str(value))
+            for key, value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items()
+        ]
+
+        ports = [
+            gca_model.Port(container_port=port)
+            for port in _TEST_MODEL_SERVING_CONTAINER_PORTS
+        ]
+
+        true_container_spec = gca_model.ModelContainerSpec(
+            image_uri=_TEST_SERVING_CONTAINER_IMAGE,
+            predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE,
+            health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE,
+            command=_TEST_MODEL_SERVING_CONTAINER_COMMAND,
+            args=_TEST_MODEL_SERVING_CONTAINER_ARGS,
+            env=env,
+            ports=ports,
+        )
+
+        true_managed_model = gca_model.Model(
+            display_name=_TEST_DISPLAY_NAME + "-model",
+            description=_TEST_MODEL_DESCRIPTION,
+            container_spec=true_container_spec,
+            predict_schemata=gca_model.PredictSchemata(
+                instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI,
+                parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI,
+                prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI,
+            ),
+            encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
+        )
+
+        true_input_data_config = gca_training_pipeline.InputDataConfig(
+            fraction_split=true_fraction_split,
+            predefined_split=gca_training_pipeline.PredefinedSplit(
+                key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME
+            ),
+            dataset_id=mock_tabular_dataset.name,
+            gcs_destination=gca_io.GcsDestination(
+                output_uri_prefix=_TEST_BASE_OUTPUT_DIR
+            ),
+        )
+
+        true_training_pipeline = gca_training_pipeline.TrainingPipeline(
+            display_name=_TEST_DISPLAY_NAME,
+            training_task_definition=schema.training_job.definition.custom_task,
+            training_task_inputs=json_format.ParseDict(
+                {
+                    "workerPoolSpecs": [true_worker_pool_spec],
+                    "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR},
+                },
+                struct_pb2.Value(),
+            ),
+            model_to_upload=true_managed_model,
+            input_data_config=true_input_data_config,
+            encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
+        )
+
+        mock_pipeline_service_create.assert_called_once_with(
+            parent=initializer.global_config.common_location_path(),
+            training_pipeline=true_training_pipeline,
+        )
+
+        assert job._gca_resource is mock_pipeline_service_get.return_value
+
+        mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME)
+
+        assert model_from_job._gca_resource is mock_model_service_get.return_value
+
+        assert job.get_model()._gca_resource is mock_model_service_get.return_value
+
+        assert not job.has_failed
+
+        assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
+
     @pytest.mark.parametrize("sync", [True, False])
     def test_run_call_pipeline_service_create_with_bigquery_destination(
         self,
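
Usage note (editorial sketch, not part of the patch): the snippet below illustrates the behavior this fix introduces. The project, bucket, package path, and image URIs are hypothetical placeholders; only the defaulting rule itself, model_display_name falling back to the job's display_name plus a "-model" suffix whenever model_serving_container_image_uri is set, comes from the change above, and that suffix is exactly what the new unit test asserts via display_name=_TEST_DISPLAY_NAME + "-model".

    from google.cloud import aiplatform
    from google.cloud.aiplatform import training_jobs

    # Placeholder project and staging bucket; adjust for a real environment.
    aiplatform.init(project="my-project", staging_bucket="gs://my-staging-bucket")

    job = training_jobs.CustomPythonPackageTrainingJob(
        display_name="my-training-job",
        python_package_gcs_uri="gs://my-staging-bucket/trainer-0.1.tar.gz",
        python_module_name="trainer.task",
        container_uri="gcr.io/cloud-aiplatform/training/tf-cpu.2-4:latest",
        # Supplying a serving container image makes the job upload a managed Model.
        model_serving_container_image_uri="gcr.io/cloud-aiplatform/prediction/tf2-cpu.2-4:latest",
    )

    # model_display_name is intentionally omitted; with this fix the uploaded
    # Model's display name defaults to the job's display name plus "-model",
    # i.e. "my-training-job-model".
    model = job.run(replica_count=1, machine_type="n1-standard-4")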