From c24251fdd230e73c2aadb4369266b78979a31015 Mon Sep 17 00:00:00 2001 From: Morgan Du Date: Fri, 30 Jul 2021 11:18:19 -0700 Subject: [PATCH] fix: change default replica count to 1 for custom training job classes (#579) --- google/cloud/aiplatform/training_jobs.py | 8 ++++---- tests/unit/aiplatform/test_training_jobs.py | 13 ------------- 2 files changed, 4 insertions(+), 17 deletions(-) diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py index 5f9d7c3445..05a9a3aeb3 100644 --- a/google/cloud/aiplatform/training_jobs.py +++ b/google/cloud/aiplatform/training_jobs.py @@ -1107,7 +1107,7 @@ def network(self) -> Optional[str]: def _prepare_and_validate_run( self, model_display_name: Optional[str] = None, - replica_count: int = 0, + replica_count: int = 1, machine_type: str = "n1-standard-4", accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", accelerator_count: int = 0, @@ -1521,7 +1521,7 @@ def run( bigquery_destination: Optional[str] = None, args: Optional[List[Union[str, float, int]]] = None, environment_variables: Optional[Dict[str, str]] = None, - replica_count: int = 0, + replica_count: int = 1, machine_type: str = "n1-standard-4", accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", accelerator_count: int = 0, @@ -2143,7 +2143,7 @@ def run( bigquery_destination: Optional[str] = None, args: Optional[List[Union[str, float, int]]] = None, environment_variables: Optional[Dict[str, str]] = None, - replica_count: int = 0, + replica_count: int = 1, machine_type: str = "n1-standard-4", accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", accelerator_count: int = 0, @@ -4095,7 +4095,7 @@ def run( bigquery_destination: Optional[str] = None, args: Optional[List[Union[str, float, int]]] = None, environment_variables: Optional[Dict[str, str]] = None, - replica_count: int = 0, + replica_count: int = 1, machine_type: str = "n1-standard-4", accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", accelerator_count: int = 0, diff --git a/tests/unit/aiplatform/test_training_jobs.py b/tests/unit/aiplatform/test_training_jobs.py index 72c17dedc5..c639c462cb 100644 --- a/tests/unit/aiplatform/test_training_jobs.py +++ b/tests/unit/aiplatform/test_training_jobs.py @@ -652,7 +652,6 @@ def test_run_call_pipeline_service_create_with_tabular_dataset( network=_TEST_NETWORK, args=_TEST_RUN_ARGS, environment_variables=_TEST_ENVIRONMENT_VARIABLES, - replica_count=1, machine_type=_TEST_MACHINE_TYPE, accelerator_type=_TEST_ACCELERATOR_TYPE, accelerator_count=_TEST_ACCELERATOR_COUNT, @@ -825,7 +824,6 @@ def test_run_call_pipeline_service_create_with_bigquery_destination( bigquery_destination=_TEST_BIGQUERY_DESTINATION, args=_TEST_RUN_ARGS, environment_variables=_TEST_ENVIRONMENT_VARIABLES, - replica_count=1, machine_type=_TEST_MACHINE_TYPE, accelerator_type=_TEST_ACCELERATOR_TYPE, accelerator_count=_TEST_ACCELERATOR_COUNT, @@ -1099,7 +1097,6 @@ def test_run_call_pipeline_service_create_with_no_dataset( base_output_dir=_TEST_BASE_OUTPUT_DIR, args=_TEST_RUN_ARGS, environment_variables=_TEST_ENVIRONMENT_VARIABLES, - replica_count=1, machine_type=_TEST_MACHINE_TYPE, accelerator_type=_TEST_ACCELERATOR_TYPE, accelerator_count=_TEST_ACCELERATOR_COUNT, @@ -1628,7 +1625,6 @@ def test_run_call_pipeline_service_create_with_nontabular_dataset( annotation_schema_uri=_TEST_ANNOTATION_SCHEMA_URI, base_output_dir=_TEST_BASE_OUTPUT_DIR, args=_TEST_RUN_ARGS, - replica_count=1, machine_type=_TEST_MACHINE_TYPE, accelerator_type=_TEST_ACCELERATOR_TYPE, accelerator_count=_TEST_ACCELERATOR_COUNT, @@ -1870,7 +1866,6 @@ def test_run_call_pipeline_service_create_with_tabular_dataset( base_output_dir=_TEST_BASE_OUTPUT_DIR, args=_TEST_RUN_ARGS, environment_variables=_TEST_ENVIRONMENT_VARIABLES, - replica_count=1, machine_type=_TEST_MACHINE_TYPE, accelerator_type=_TEST_ACCELERATOR_TYPE, accelerator_count=_TEST_ACCELERATOR_COUNT, @@ -2032,7 +2027,6 @@ def test_run_call_pipeline_service_create_with_bigquery_destination( base_output_dir=_TEST_BASE_OUTPUT_DIR, bigquery_destination=_TEST_BIGQUERY_DESTINATION, args=_TEST_RUN_ARGS, - replica_count=1, machine_type=_TEST_MACHINE_TYPE, accelerator_type=_TEST_ACCELERATOR_TYPE, accelerator_count=_TEST_ACCELERATOR_COUNT, @@ -2294,7 +2288,6 @@ def test_run_call_pipeline_service_create_with_no_dataset( model_from_job = job.run( base_output_dir=_TEST_BASE_OUTPUT_DIR, args=_TEST_RUN_ARGS, - replica_count=1, machine_type=_TEST_MACHINE_TYPE, accelerator_type=_TEST_ACCELERATOR_TYPE, accelerator_count=_TEST_ACCELERATOR_COUNT, @@ -2674,7 +2667,6 @@ def test_run_call_pipeline_service_create_with_nontabular_dataset( service_account=_TEST_SERVICE_ACCOUNT, network=_TEST_NETWORK, args=_TEST_RUN_ARGS, - replica_count=1, machine_type=_TEST_MACHINE_TYPE, accelerator_type=_TEST_ACCELERATOR_TYPE, accelerator_count=_TEST_ACCELERATOR_COUNT, @@ -3112,7 +3104,6 @@ def test_run_call_pipeline_service_create_with_tabular_dataset( network=_TEST_NETWORK, args=_TEST_RUN_ARGS, environment_variables=_TEST_ENVIRONMENT_VARIABLES, - replica_count=1, machine_type=_TEST_MACHINE_TYPE, accelerator_type=_TEST_ACCELERATOR_TYPE, accelerator_count=_TEST_ACCELERATOR_COUNT, @@ -3273,7 +3264,6 @@ def test_run_call_pipeline_service_create_with_tabular_dataset_without_model_dis # model_display_name=_TEST_MODEL_DISPLAY_NAME, base_output_dir=_TEST_BASE_OUTPUT_DIR, args=_TEST_RUN_ARGS, - replica_count=1, machine_type=_TEST_MACHINE_TYPE, accelerator_type=_TEST_ACCELERATOR_TYPE, accelerator_count=_TEST_ACCELERATOR_COUNT, @@ -3426,7 +3416,6 @@ def test_run_call_pipeline_service_create_with_bigquery_destination( base_output_dir=_TEST_BASE_OUTPUT_DIR, bigquery_destination=_TEST_BIGQUERY_DESTINATION, args=_TEST_RUN_ARGS, - replica_count=1, machine_type=_TEST_MACHINE_TYPE, accelerator_type=_TEST_ACCELERATOR_TYPE, accelerator_count=_TEST_ACCELERATOR_COUNT, @@ -3693,7 +3682,6 @@ def test_run_call_pipeline_service_create_with_no_dataset( model_display_name=_TEST_MODEL_DISPLAY_NAME, base_output_dir=_TEST_BASE_OUTPUT_DIR, args=_TEST_RUN_ARGS, - replica_count=1, machine_type=_TEST_MACHINE_TYPE, accelerator_type=_TEST_ACCELERATOR_TYPE, accelerator_count=_TEST_ACCELERATOR_COUNT, @@ -4080,7 +4068,6 @@ def test_run_call_pipeline_service_create_with_nontabular_dataset( annotation_schema_uri=_TEST_ANNOTATION_SCHEMA_URI, base_output_dir=_TEST_BASE_OUTPUT_DIR, args=_TEST_RUN_ARGS, - replica_count=1, machine_type=_TEST_MACHINE_TYPE, accelerator_type=_TEST_ACCELERATOR_TYPE, accelerator_count=_TEST_ACCELERATOR_COUNT,