Skip to content

Commit

Permalink
fix: change default replica count to 1 for custom training job classes (
Browse files Browse the repository at this point in the history
  • Loading branch information
morgandu committed Jul 30, 2021
1 parent 6a99b12 commit c24251f
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 17 deletions.
8 changes: 4 additions & 4 deletions google/cloud/aiplatform/training_jobs.py
Expand Up @@ -1107,7 +1107,7 @@ def network(self) -> Optional[str]:
def _prepare_and_validate_run(
self,
model_display_name: Optional[str] = None,
replica_count: int = 0,
replica_count: int = 1,
machine_type: str = "n1-standard-4",
accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
accelerator_count: int = 0,
Expand Down Expand Up @@ -1521,7 +1521,7 @@ def run(
bigquery_destination: Optional[str] = None,
args: Optional[List[Union[str, float, int]]] = None,
environment_variables: Optional[Dict[str, str]] = None,
replica_count: int = 0,
replica_count: int = 1,
machine_type: str = "n1-standard-4",
accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
accelerator_count: int = 0,
Expand Down Expand Up @@ -2143,7 +2143,7 @@ def run(
bigquery_destination: Optional[str] = None,
args: Optional[List[Union[str, float, int]]] = None,
environment_variables: Optional[Dict[str, str]] = None,
replica_count: int = 0,
replica_count: int = 1,
machine_type: str = "n1-standard-4",
accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
accelerator_count: int = 0,
Expand Down Expand Up @@ -4095,7 +4095,7 @@ def run(
bigquery_destination: Optional[str] = None,
args: Optional[List[Union[str, float, int]]] = None,
environment_variables: Optional[Dict[str, str]] = None,
replica_count: int = 0,
replica_count: int = 1,
machine_type: str = "n1-standard-4",
accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
accelerator_count: int = 0,
Expand Down
13 changes: 0 additions & 13 deletions tests/unit/aiplatform/test_training_jobs.py
Expand Up @@ -652,7 +652,6 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
network=_TEST_NETWORK,
args=_TEST_RUN_ARGS,
environment_variables=_TEST_ENVIRONMENT_VARIABLES,
replica_count=1,
machine_type=_TEST_MACHINE_TYPE,
accelerator_type=_TEST_ACCELERATOR_TYPE,
accelerator_count=_TEST_ACCELERATOR_COUNT,
Expand Down Expand Up @@ -825,7 +824,6 @@ def test_run_call_pipeline_service_create_with_bigquery_destination(
bigquery_destination=_TEST_BIGQUERY_DESTINATION,
args=_TEST_RUN_ARGS,
environment_variables=_TEST_ENVIRONMENT_VARIABLES,
replica_count=1,
machine_type=_TEST_MACHINE_TYPE,
accelerator_type=_TEST_ACCELERATOR_TYPE,
accelerator_count=_TEST_ACCELERATOR_COUNT,
Expand Down Expand Up @@ -1099,7 +1097,6 @@ def test_run_call_pipeline_service_create_with_no_dataset(
base_output_dir=_TEST_BASE_OUTPUT_DIR,
args=_TEST_RUN_ARGS,
environment_variables=_TEST_ENVIRONMENT_VARIABLES,
replica_count=1,
machine_type=_TEST_MACHINE_TYPE,
accelerator_type=_TEST_ACCELERATOR_TYPE,
accelerator_count=_TEST_ACCELERATOR_COUNT,
Expand Down Expand Up @@ -1628,7 +1625,6 @@ def test_run_call_pipeline_service_create_with_nontabular_dataset(
annotation_schema_uri=_TEST_ANNOTATION_SCHEMA_URI,
base_output_dir=_TEST_BASE_OUTPUT_DIR,
args=_TEST_RUN_ARGS,
replica_count=1,
machine_type=_TEST_MACHINE_TYPE,
accelerator_type=_TEST_ACCELERATOR_TYPE,
accelerator_count=_TEST_ACCELERATOR_COUNT,
Expand Down Expand Up @@ -1870,7 +1866,6 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
base_output_dir=_TEST_BASE_OUTPUT_DIR,
args=_TEST_RUN_ARGS,
environment_variables=_TEST_ENVIRONMENT_VARIABLES,
replica_count=1,
machine_type=_TEST_MACHINE_TYPE,
accelerator_type=_TEST_ACCELERATOR_TYPE,
accelerator_count=_TEST_ACCELERATOR_COUNT,
Expand Down Expand Up @@ -2032,7 +2027,6 @@ def test_run_call_pipeline_service_create_with_bigquery_destination(
base_output_dir=_TEST_BASE_OUTPUT_DIR,
bigquery_destination=_TEST_BIGQUERY_DESTINATION,
args=_TEST_RUN_ARGS,
replica_count=1,
machine_type=_TEST_MACHINE_TYPE,
accelerator_type=_TEST_ACCELERATOR_TYPE,
accelerator_count=_TEST_ACCELERATOR_COUNT,
Expand Down Expand Up @@ -2294,7 +2288,6 @@ def test_run_call_pipeline_service_create_with_no_dataset(
model_from_job = job.run(
base_output_dir=_TEST_BASE_OUTPUT_DIR,
args=_TEST_RUN_ARGS,
replica_count=1,
machine_type=_TEST_MACHINE_TYPE,
accelerator_type=_TEST_ACCELERATOR_TYPE,
accelerator_count=_TEST_ACCELERATOR_COUNT,
Expand Down Expand Up @@ -2674,7 +2667,6 @@ def test_run_call_pipeline_service_create_with_nontabular_dataset(
service_account=_TEST_SERVICE_ACCOUNT,
network=_TEST_NETWORK,
args=_TEST_RUN_ARGS,
replica_count=1,
machine_type=_TEST_MACHINE_TYPE,
accelerator_type=_TEST_ACCELERATOR_TYPE,
accelerator_count=_TEST_ACCELERATOR_COUNT,
Expand Down Expand Up @@ -3112,7 +3104,6 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
network=_TEST_NETWORK,
args=_TEST_RUN_ARGS,
environment_variables=_TEST_ENVIRONMENT_VARIABLES,
replica_count=1,
machine_type=_TEST_MACHINE_TYPE,
accelerator_type=_TEST_ACCELERATOR_TYPE,
accelerator_count=_TEST_ACCELERATOR_COUNT,
Expand Down Expand Up @@ -3273,7 +3264,6 @@ def test_run_call_pipeline_service_create_with_tabular_dataset_without_model_dis
# model_display_name=_TEST_MODEL_DISPLAY_NAME,
base_output_dir=_TEST_BASE_OUTPUT_DIR,
args=_TEST_RUN_ARGS,
replica_count=1,
machine_type=_TEST_MACHINE_TYPE,
accelerator_type=_TEST_ACCELERATOR_TYPE,
accelerator_count=_TEST_ACCELERATOR_COUNT,
Expand Down Expand Up @@ -3426,7 +3416,6 @@ def test_run_call_pipeline_service_create_with_bigquery_destination(
base_output_dir=_TEST_BASE_OUTPUT_DIR,
bigquery_destination=_TEST_BIGQUERY_DESTINATION,
args=_TEST_RUN_ARGS,
replica_count=1,
machine_type=_TEST_MACHINE_TYPE,
accelerator_type=_TEST_ACCELERATOR_TYPE,
accelerator_count=_TEST_ACCELERATOR_COUNT,
Expand Down Expand Up @@ -3693,7 +3682,6 @@ def test_run_call_pipeline_service_create_with_no_dataset(
model_display_name=_TEST_MODEL_DISPLAY_NAME,
base_output_dir=_TEST_BASE_OUTPUT_DIR,
args=_TEST_RUN_ARGS,
replica_count=1,
machine_type=_TEST_MACHINE_TYPE,
accelerator_type=_TEST_ACCELERATOR_TYPE,
accelerator_count=_TEST_ACCELERATOR_COUNT,
Expand Down Expand Up @@ -4080,7 +4068,6 @@ def test_run_call_pipeline_service_create_with_nontabular_dataset(
annotation_schema_uri=_TEST_ANNOTATION_SCHEMA_URI,
base_output_dir=_TEST_BASE_OUTPUT_DIR,
args=_TEST_RUN_ARGS,
replica_count=1,
machine_type=_TEST_MACHINE_TYPE,
accelerator_type=_TEST_ACCELERATOR_TYPE,
accelerator_count=_TEST_ACCELERATOR_COUNT,
Expand Down

0 comments on commit c24251f

Please sign in to comment.