Skip to content

Commit

Permalink
fix: default model_display_name to _CustomTrainingJob.display_name wh…
Browse files Browse the repository at this point in the history
…en model_serving_container_image_uri is provided (#324)
  • Loading branch information
morgandu committed Apr 20, 2021
1 parent 9bd02ad commit a5fa7a2
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 0 deletions.
11 changes: 11 additions & 0 deletions google/cloud/aiplatform/training_jobs.py
Expand Up @@ -1455,6 +1455,8 @@ def _prepare_and_validate_run(
If the script produces a managed AI Platform Model. The display name of
the Model. The name can be up to 128 characters long and can consist
of any UTF-8 characters.
If not provided upon creation, the job's display_name is used.
replica_count (int):
The number of worker replicas. If replica count = 1 then one chief
replica will be provisioned. If replica_count > 1 the remainder will be
Expand Down Expand Up @@ -1491,6 +1493,9 @@ def _prepare_and_validate_run(
"""
)

if self._managed_model.container_spec.image_uri:
model_display_name = model_display_name or self._display_name + "-model"

# validates args and will raise
worker_pool_specs = _DistributedTrainingSpec.chief_worker_pool(
replica_count=replica_count,
Expand Down Expand Up @@ -1854,6 +1859,8 @@ def run(
If the script produces a managed AI Platform Model. The display name of
the Model. The name can be up to 128 characters long and can consist
of any UTF-8 characters.
If not provided upon creation, the job's display_name is used.
base_output_dir (str):
GCS output directory of job. If not provided a
timestamped directory in the staging directory will be used.
Expand Down Expand Up @@ -2371,6 +2378,8 @@ def run(
If the script produces a managed AI Platform Model. The display name of
the Model. The name can be up to 128 characters long and can consist
of any UTF-8 characters.
If not provided upon creation, the job's display_name is used.
base_output_dir (str):
GCS output directory of job. If not provided a
timestamped directory in the staging directory will be used.
Expand Down Expand Up @@ -3636,6 +3645,8 @@ def run(
If the script produces a managed AI Platform Model. The display name of
the Model. The name can be up to 128 characters long and can consist
of any UTF-8 characters.
If not provided upon creation, the job's display_name is used.
base_output_dir (str):
GCS output directory of job. If not provided a
timestamped directory in the staging directory will be used.
Expand Down
150 changes: 150 additions & 0 deletions tests/unit/aiplatform/test_training_jobs.py
Expand Up @@ -3043,6 +3043,156 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(

assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED

@pytest.mark.parametrize("sync", [True, False])
def test_run_call_pipeline_service_create_with_tabular_dataset_without_model_display_name(
    self,
    mock_pipeline_service_create,
    mock_pipeline_service_get,
    mock_tabular_dataset,
    mock_model_service_get,
    sync,
):
    """When ``model_display_name`` is omitted from ``run()`` but a serving
    container image is configured on the job, the uploaded Model defaults
    its display name to the job's display name with a "-model" suffix.

    Verifies the exact TrainingPipeline proto sent to the pipeline service
    and the resulting job/model state, in both sync and async modes.
    """
    aiplatform.init(
        project=_TEST_PROJECT,
        staging_bucket=_TEST_BUCKET_NAME,
        encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
    )

    job = training_jobs.CustomPythonPackageTrainingJob(
        display_name=_TEST_DISPLAY_NAME,
        python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH,
        python_module_name=_TEST_PYTHON_MODULE_NAME,
        container_uri=_TEST_TRAINING_CONTAINER_IMAGE,
        model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE,
        model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE,
        model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE,
        model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND,
        model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS,
        model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES,
        model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS,
        model_description=_TEST_MODEL_DESCRIPTION,
        model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI,
        model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI,
        model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI,
    )

    # Deliberately omit model_display_name to exercise the default path.
    model_from_job = job.run(
        dataset=mock_tabular_dataset,
        base_output_dir=_TEST_BASE_OUTPUT_DIR,
        args=_TEST_RUN_ARGS,
        replica_count=1,
        machine_type=_TEST_MACHINE_TYPE,
        accelerator_type=_TEST_ACCELERATOR_TYPE,
        accelerator_count=_TEST_ACCELERATOR_COUNT,
        training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT,
        validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT,
        test_fraction_split=_TEST_TEST_FRACTION_SPLIT,
        predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME,
        sync=sync,
    )

    if not sync:
        model_from_job.wait()

    true_args = _TEST_RUN_ARGS

    true_worker_pool_spec = {
        "replicaCount": _TEST_REPLICA_COUNT,
        "machineSpec": {
            "machineType": _TEST_MACHINE_TYPE,
            "acceleratorType": _TEST_ACCELERATOR_TYPE,
            "acceleratorCount": _TEST_ACCELERATOR_COUNT,
        },
        "pythonPackageSpec": {
            "executorImageUri": _TEST_TRAINING_CONTAINER_IMAGE,
            "pythonModule": _TEST_PYTHON_MODULE_NAME,
            "packageUris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH],
            "args": true_args,
        },
    }

    true_fraction_split = gca_training_pipeline.FractionSplit(
        training_fraction=_TEST_TRAINING_FRACTION_SPLIT,
        validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT,
        test_fraction=_TEST_TEST_FRACTION_SPLIT,
    )

    env = [
        gca_env_var.EnvVar(name=str(key), value=str(value))
        for key, value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items()
    ]

    ports = [
        gca_model.Port(container_port=port)
        for port in _TEST_MODEL_SERVING_CONTAINER_PORTS
    ]

    true_container_spec = gca_model.ModelContainerSpec(
        image_uri=_TEST_SERVING_CONTAINER_IMAGE,
        predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE,
        health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE,
        command=_TEST_MODEL_SERVING_CONTAINER_COMMAND,
        args=_TEST_MODEL_SERVING_CONTAINER_ARGS,
        env=env,
        ports=ports,
    )

    # The key expectation of this test: the defaulted display name is the
    # job's display name plus the "-model" suffix.
    true_managed_model = gca_model.Model(
        display_name=_TEST_DISPLAY_NAME + "-model",
        description=_TEST_MODEL_DESCRIPTION,
        container_spec=true_container_spec,
        predict_schemata=gca_model.PredictSchemata(
            instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI,
            parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI,
            prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI,
        ),
        encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
    )

    true_input_data_config = gca_training_pipeline.InputDataConfig(
        fraction_split=true_fraction_split,
        predefined_split=gca_training_pipeline.PredefinedSplit(
            key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME
        ),
        dataset_id=mock_tabular_dataset.name,
        gcs_destination=gca_io.GcsDestination(
            output_uri_prefix=_TEST_BASE_OUTPUT_DIR
        ),
    )

    true_training_pipeline = gca_training_pipeline.TrainingPipeline(
        display_name=_TEST_DISPLAY_NAME,
        training_task_definition=schema.training_job.definition.custom_task,
        training_task_inputs=json_format.ParseDict(
            {
                "workerPoolSpecs": [true_worker_pool_spec],
                "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR},
            },
            struct_pb2.Value(),
        ),
        model_to_upload=true_managed_model,
        input_data_config=true_input_data_config,
        encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
    )

    mock_pipeline_service_create.assert_called_once_with(
        parent=initializer.global_config.common_location_path(),
        training_pipeline=true_training_pipeline,
    )

    assert job._gca_resource is mock_pipeline_service_get.return_value

    mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME)

    assert model_from_job._gca_resource is mock_model_service_get.return_value

    assert job.get_model()._gca_resource is mock_model_service_get.return_value

    assert not job.has_failed

    assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED

@pytest.mark.parametrize("sync", [True, False])
def test_run_call_pipeline_service_create_with_bigquery_destination(
self,
Expand Down

0 comments on commit a5fa7a2

Please sign in to comment.