diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py index 8ef054fc97..f3f447deb6 100644 --- a/google/cloud/aiplatform/training_jobs.py +++ b/google/cloud/aiplatform/training_jobs.py @@ -1526,6 +1526,7 @@ def _prepare_training_task_inputs_and_output_dir( worker_pool_specs: _DistributedTrainingSpec, base_output_dir: Optional[str] = None, service_account: Optional[str] = None, + network: Optional[str] = None, ) -> Tuple[Dict, str]: """Prepares training task inputs and output directory for custom job. @@ -1538,6 +1539,11 @@ def _prepare_training_task_inputs_and_output_dir( service_account (str): Specifies the service account for workload run-as account. Users submitting jobs must have act-as permission on this run-as account. + network (str): + The full name of the Compute Engine network to which the job + should be peered. For example, projects/12345/global/networks/myVPC. + Private services access must already be configured for the network. + If left unspecified, the job is not peered with any network. Returns: Training task inputs and Output directory for custom job. """ @@ -1556,6 +1562,8 @@ def _prepare_training_task_inputs_and_output_dir( if service_account: training_task_inputs["serviceAccount"] = service_account + if network: + training_task_inputs["network"] = network return training_task_inputs, base_output_dir @@ -1803,6 +1811,7 @@ def run( model_display_name: Optional[str] = None, base_output_dir: Optional[str] = None, service_account: Optional[str] = None, + network: Optional[str] = None, bigquery_destination: Optional[str] = None, args: Optional[List[Union[str, float, int]]] = None, environment_variables: Optional[Dict[str, str]] = None, @@ -1891,6 +1900,11 @@ def run( service_account (str): Specifies the service account for workload run-as account. Users submitting jobs must have act-as permission on this run-as account. + network (str): + The full name of the Compute Engine network to which the job + should be peered. For example, projects/12345/global/networks/myVPC. + Private services access must already be configured for the network. + If left unspecified, the job is not peered with any network. bigquery_destination (str): Provide this field if `dataset` is a BiqQuery dataset. The BigQuery project location where the training data is to @@ -1981,6 +1995,7 @@ def run( environment_variables=environment_variables, base_output_dir=base_output_dir, service_account=service_account, + network=network, bigquery_destination=bigquery_destination, training_fraction_split=training_fraction_split, validation_fraction_split=validation_fraction_split, @@ -2008,6 +2023,7 @@ def _run( environment_variables: Optional[Dict[str, str]] = None, base_output_dir: Optional[str] = None, service_account: Optional[str] = None, + network: Optional[str] = None, bigquery_destination: Optional[str] = None, training_fraction_split: float = 0.8, validation_fraction_split: float = 0.1, @@ -2061,6 +2077,11 @@ def _run( service_account (str): Specifies the service account for workload run-as account. Users submitting jobs must have act-as permission on this run-as account. + network (str): + The full name of the Compute Engine network to which the job + should be peered. For example, projects/12345/global/networks/myVPC. + Private services access must already be configured for the network. + If left unspecified, the job is not peered with any network. bigquery_destination (str): Provide this field if `dataset` is a BiqQuery dataset. The BigQuery project location where the training data is to @@ -2130,7 +2151,10 @@ def _run( training_task_inputs, base_output_dir, ) = self._prepare_training_task_inputs_and_output_dir( - worker_pool_specs, base_output_dir, service_account + worker_pool_specs=worker_pool_specs, + base_output_dir=base_output_dir, + service_account=service_account, + network=network, ) model = self._run_job( @@ -2375,6 +2399,7 @@ def run( model_display_name: Optional[str] = None, base_output_dir: Optional[str] = None, service_account: Optional[str] = None, + network: Optional[str] = None, bigquery_destination: Optional[str] = None, args: Optional[List[Union[str, float, int]]] = None, environment_variables: Optional[Dict[str, str]] = None, @@ -2456,6 +2481,11 @@ def run( service_account (str): Specifies the service account for workload run-as account. Users submitting jobs must have act-as permission on this run-as account. + network (str): + The full name of the Compute Engine network to which the job + should be peered. For example, projects/12345/global/networks/myVPC. + Private services access must already be configured for the network. + If left unspecified, the job is not peered with any network. bigquery_destination (str): Provide this field if `dataset` is a BiqQuery dataset. The BigQuery project location where the training data is to @@ -2545,6 +2575,7 @@ def run( environment_variables=environment_variables, base_output_dir=base_output_dir, service_account=service_account, + network=network, bigquery_destination=bigquery_destination, training_fraction_split=training_fraction_split, validation_fraction_split=validation_fraction_split, @@ -2571,6 +2602,7 @@ def _run( environment_variables: Optional[Dict[str, str]] = None, base_output_dir: Optional[str] = None, service_account: Optional[str] = None, + network: Optional[str] = None, bigquery_destination: Optional[str] = None, training_fraction_split: float = 0.8, validation_fraction_split: float = 0.1, @@ -2621,6 +2653,11 @@ def _run( service_account (str): Specifies the service account for workload run-as account. Users submitting jobs must have act-as permission on this run-as account. + network (str): + The full name of the Compute Engine network to which the job + should be peered. For example, projects/12345/global/networks/myVPC. + Private services access must already be configured for the network. + If left unspecified, the job is not peered with any network. bigquery_destination (str): The BigQuery project location where the training data is to be written to. In the given project a new dataset is created @@ -2683,7 +2720,10 @@ def _run( training_task_inputs, base_output_dir, ) = self._prepare_training_task_inputs_and_output_dir( - worker_pool_specs, base_output_dir, service_account + worker_pool_specs=worker_pool_specs, + base_output_dir=base_output_dir, + service_account=service_account, + network=network, ) model = self._run_job( @@ -3709,6 +3749,7 @@ def run( model_display_name: Optional[str] = None, base_output_dir: Optional[str] = None, service_account: Optional[str] = None, + network: Optional[str] = None, bigquery_destination: Optional[str] = None, args: Optional[List[Union[str, float, int]]] = None, environment_variables: Optional[Dict[str, str]] = None, @@ -3790,6 +3831,11 @@ def run( service_account (str): Specifies the service account for workload run-as account. Users submitting jobs must have act-as permission on this run-as account. + network (str): + The full name of the Compute Engine network to which the job + should be peered. For example, projects/12345/global/networks/myVPC. + Private services access must already be configured for the network. + If left unspecified, the job is not peered with any network. bigquery_destination (str): Provide this field if `dataset` is a BiqQuery dataset. The BigQuery project location where the training data is to @@ -3874,6 +3920,7 @@ def run( environment_variables=environment_variables, base_output_dir=base_output_dir, service_account=service_account, + network=network, training_fraction_split=training_fraction_split, validation_fraction_split=validation_fraction_split, test_fraction_split=test_fraction_split, @@ -3900,6 +3947,7 @@ def _run( environment_variables: Optional[Dict[str, str]] = None, base_output_dir: Optional[str] = None, service_account: Optional[str] = None, + network: Optional[str] = None, training_fraction_split: float = 0.8, validation_fraction_split: float = 0.1, test_fraction_split: float = 0.1, @@ -3951,6 +3999,11 @@ def _run( service_account (str): Specifies the service account for workload run-as account. Users submitting jobs must have act-as permission on this run-as account. + network (str): + The full name of the Compute Engine network to which the job + should be peered. For example, projects/12345/global/networks/myVPC. + Private services access must already be configured for the network. + If left unspecified, the job is not peered with any network. training_fraction_split (float): The fraction of the input data that is to be used to train the Model. @@ -3999,7 +4052,10 @@ def _run( training_task_inputs, base_output_dir, ) = self._prepare_training_task_inputs_and_output_dir( - worker_pool_specs, base_output_dir, service_account + worker_pool_specs=worker_pool_specs, + base_output_dir=base_output_dir, + service_account=service_account, + network=network, ) model = self._run_job( diff --git a/tests/unit/aiplatform/test_training_jobs.py b/tests/unit/aiplatform/test_training_jobs.py index c3c0e33863..8fd82c7727 100644 --- a/tests/unit/aiplatform/test_training_jobs.py +++ b/tests/unit/aiplatform/test_training_jobs.py @@ -109,6 +109,7 @@ ) _TEST_ALT_PROJECT = "test-project-alt" _TEST_ALT_LOCATION = "europe-west4" +_TEST_NETWORK = f"projects/{_TEST_PROJECT}/global/networks/{_TEST_ID}" _TEST_MODEL_INSTANCE_SCHEMA_URI = "instance_schema_uri.yaml" _TEST_MODEL_PARAMETERS_SCHEMA_URI = "parameters_schema_uri.yaml" @@ -598,6 +599,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset( dataset=mock_tabular_dataset, base_output_dir=_TEST_BASE_OUTPUT_DIR, service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, args=_TEST_RUN_ARGS, environment_variables=_TEST_ENVIRONMENT_VARIABLES, replica_count=1, @@ -700,6 +702,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset( "workerPoolSpecs": [true_worker_pool_spec], "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR}, "serviceAccount": _TEST_SERVICE_ACCOUNT, + "network": _TEST_NETWORK, }, struct_pb2.Value(), ), @@ -2539,6 +2542,7 @@ def test_run_call_pipeline_service_create_with_nontabular_dataset( annotation_schema_uri=_TEST_ANNOTATION_SCHEMA_URI, base_output_dir=_TEST_BASE_OUTPUT_DIR, service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, args=_TEST_RUN_ARGS, replica_count=1, machine_type=_TEST_MACHINE_TYPE, @@ -2621,6 +2625,7 @@ def test_run_call_pipeline_service_create_with_nontabular_dataset( "workerPoolSpecs": [true_worker_pool_spec], "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR}, "serviceAccount": _TEST_SERVICE_ACCOUNT, + "network": _TEST_NETWORK, }, struct_pb2.Value(), ), @@ -2970,6 +2975,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset( model_display_name=_TEST_MODEL_DISPLAY_NAME, base_output_dir=_TEST_BASE_OUTPUT_DIR, service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, args=_TEST_RUN_ARGS, environment_variables=_TEST_ENVIRONMENT_VARIABLES, replica_count=1, @@ -3065,6 +3071,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset( "workerPoolSpecs": [true_worker_pool_spec], "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR}, "serviceAccount": _TEST_SERVICE_ACCOUNT, + "network": _TEST_NETWORK, }, struct_pb2.Value(), ),