Skip to content

Commit

Permalink
feat: Add VPC Peering support to CustomTrainingJob classes (#378)
Browse files Browse the repository at this point in the history
* Add 'network' for VPC Peering in custom training

* Blacken code
  • Loading branch information
vinnysenthil committed May 12, 2021
1 parent 7eaedb6 commit 56273f7
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 3 deletions.
62 changes: 59 additions & 3 deletions google/cloud/aiplatform/training_jobs.py
Expand Up @@ -1526,6 +1526,7 @@ def _prepare_training_task_inputs_and_output_dir(
worker_pool_specs: _DistributedTrainingSpec,
base_output_dir: Optional[str] = None,
service_account: Optional[str] = None,
network: Optional[str] = None,
) -> Tuple[Dict, str]:
"""Prepares training task inputs and output directory for custom job.
Expand All @@ -1538,6 +1539,11 @@ def _prepare_training_task_inputs_and_output_dir(
service_account (str):
Specifies the service account for workload run-as account.
Users submitting jobs must have act-as permission on this run-as account.
network (str):
The full name of the Compute Engine network to which the job
should be peered. For example, projects/12345/global/networks/myVPC.
Private services access must already be configured for the network.
If left unspecified, the job is not peered with any network.
Returns:
Training task inputs and Output directory for custom job.
"""
Expand All @@ -1556,6 +1562,8 @@ def _prepare_training_task_inputs_and_output_dir(

if service_account:
training_task_inputs["serviceAccount"] = service_account
if network:
training_task_inputs["network"] = network

return training_task_inputs, base_output_dir

Expand Down Expand Up @@ -1803,6 +1811,7 @@ def run(
model_display_name: Optional[str] = None,
base_output_dir: Optional[str] = None,
service_account: Optional[str] = None,
network: Optional[str] = None,
bigquery_destination: Optional[str] = None,
args: Optional[List[Union[str, float, int]]] = None,
environment_variables: Optional[Dict[str, str]] = None,
Expand Down Expand Up @@ -1891,6 +1900,11 @@ def run(
service_account (str):
Specifies the service account for workload run-as account.
Users submitting jobs must have act-as permission on this run-as account.
network (str):
The full name of the Compute Engine network to which the job
should be peered. For example, projects/12345/global/networks/myVPC.
Private services access must already be configured for the network.
If left unspecified, the job is not peered with any network.
bigquery_destination (str):
Provide this field if `dataset` is a BigQuery dataset.
The BigQuery project location where the training data is to
Expand Down Expand Up @@ -1981,6 +1995,7 @@ def run(
environment_variables=environment_variables,
base_output_dir=base_output_dir,
service_account=service_account,
network=network,
bigquery_destination=bigquery_destination,
training_fraction_split=training_fraction_split,
validation_fraction_split=validation_fraction_split,
Expand Down Expand Up @@ -2008,6 +2023,7 @@ def _run(
environment_variables: Optional[Dict[str, str]] = None,
base_output_dir: Optional[str] = None,
service_account: Optional[str] = None,
network: Optional[str] = None,
bigquery_destination: Optional[str] = None,
training_fraction_split: float = 0.8,
validation_fraction_split: float = 0.1,
Expand Down Expand Up @@ -2061,6 +2077,11 @@ def _run(
service_account (str):
Specifies the service account for workload run-as account.
Users submitting jobs must have act-as permission on this run-as account.
network (str):
The full name of the Compute Engine network to which the job
should be peered. For example, projects/12345/global/networks/myVPC.
Private services access must already be configured for the network.
If left unspecified, the job is not peered with any network.
bigquery_destination (str):
Provide this field if `dataset` is a BigQuery dataset.
The BigQuery project location where the training data is to
Expand Down Expand Up @@ -2130,7 +2151,10 @@ def _run(
training_task_inputs,
base_output_dir,
) = self._prepare_training_task_inputs_and_output_dir(
worker_pool_specs, base_output_dir, service_account
worker_pool_specs=worker_pool_specs,
base_output_dir=base_output_dir,
service_account=service_account,
network=network,
)

model = self._run_job(
Expand Down Expand Up @@ -2375,6 +2399,7 @@ def run(
model_display_name: Optional[str] = None,
base_output_dir: Optional[str] = None,
service_account: Optional[str] = None,
network: Optional[str] = None,
bigquery_destination: Optional[str] = None,
args: Optional[List[Union[str, float, int]]] = None,
environment_variables: Optional[Dict[str, str]] = None,
Expand Down Expand Up @@ -2456,6 +2481,11 @@ def run(
service_account (str):
Specifies the service account for workload run-as account.
Users submitting jobs must have act-as permission on this run-as account.
network (str):
The full name of the Compute Engine network to which the job
should be peered. For example, projects/12345/global/networks/myVPC.
Private services access must already be configured for the network.
If left unspecified, the job is not peered with any network.
bigquery_destination (str):
Provide this field if `dataset` is a BigQuery dataset.
The BigQuery project location where the training data is to
Expand Down Expand Up @@ -2545,6 +2575,7 @@ def run(
environment_variables=environment_variables,
base_output_dir=base_output_dir,
service_account=service_account,
network=network,
bigquery_destination=bigquery_destination,
training_fraction_split=training_fraction_split,
validation_fraction_split=validation_fraction_split,
Expand All @@ -2571,6 +2602,7 @@ def _run(
environment_variables: Optional[Dict[str, str]] = None,
base_output_dir: Optional[str] = None,
service_account: Optional[str] = None,
network: Optional[str] = None,
bigquery_destination: Optional[str] = None,
training_fraction_split: float = 0.8,
validation_fraction_split: float = 0.1,
Expand Down Expand Up @@ -2621,6 +2653,11 @@ def _run(
service_account (str):
Specifies the service account for workload run-as account.
Users submitting jobs must have act-as permission on this run-as account.
network (str):
The full name of the Compute Engine network to which the job
should be peered. For example, projects/12345/global/networks/myVPC.
Private services access must already be configured for the network.
If left unspecified, the job is not peered with any network.
bigquery_destination (str):
The BigQuery project location where the training data is to
be written to. In the given project a new dataset is created
Expand Down Expand Up @@ -2683,7 +2720,10 @@ def _run(
training_task_inputs,
base_output_dir,
) = self._prepare_training_task_inputs_and_output_dir(
worker_pool_specs, base_output_dir, service_account
worker_pool_specs=worker_pool_specs,
base_output_dir=base_output_dir,
service_account=service_account,
network=network,
)

model = self._run_job(
Expand Down Expand Up @@ -3709,6 +3749,7 @@ def run(
model_display_name: Optional[str] = None,
base_output_dir: Optional[str] = None,
service_account: Optional[str] = None,
network: Optional[str] = None,
bigquery_destination: Optional[str] = None,
args: Optional[List[Union[str, float, int]]] = None,
environment_variables: Optional[Dict[str, str]] = None,
Expand Down Expand Up @@ -3790,6 +3831,11 @@ def run(
service_account (str):
Specifies the service account for workload run-as account.
Users submitting jobs must have act-as permission on this run-as account.
network (str):
The full name of the Compute Engine network to which the job
should be peered. For example, projects/12345/global/networks/myVPC.
Private services access must already be configured for the network.
If left unspecified, the job is not peered with any network.
bigquery_destination (str):
Provide this field if `dataset` is a BigQuery dataset.
The BigQuery project location where the training data is to
Expand Down Expand Up @@ -3874,6 +3920,7 @@ def run(
environment_variables=environment_variables,
base_output_dir=base_output_dir,
service_account=service_account,
network=network,
training_fraction_split=training_fraction_split,
validation_fraction_split=validation_fraction_split,
test_fraction_split=test_fraction_split,
Expand All @@ -3900,6 +3947,7 @@ def _run(
environment_variables: Optional[Dict[str, str]] = None,
base_output_dir: Optional[str] = None,
service_account: Optional[str] = None,
network: Optional[str] = None,
training_fraction_split: float = 0.8,
validation_fraction_split: float = 0.1,
test_fraction_split: float = 0.1,
Expand Down Expand Up @@ -3951,6 +3999,11 @@ def _run(
service_account (str):
Specifies the service account for workload run-as account.
Users submitting jobs must have act-as permission on this run-as account.
network (str):
The full name of the Compute Engine network to which the job
should be peered. For example, projects/12345/global/networks/myVPC.
Private services access must already be configured for the network.
If left unspecified, the job is not peered with any network.
training_fraction_split (float):
The fraction of the input data that is to be
used to train the Model.
Expand Down Expand Up @@ -3999,7 +4052,10 @@ def _run(
training_task_inputs,
base_output_dir,
) = self._prepare_training_task_inputs_and_output_dir(
worker_pool_specs, base_output_dir, service_account
worker_pool_specs=worker_pool_specs,
base_output_dir=base_output_dir,
service_account=service_account,
network=network,
)

model = self._run_job(
Expand Down
7 changes: 7 additions & 0 deletions tests/unit/aiplatform/test_training_jobs.py
Expand Up @@ -109,6 +109,7 @@
)
_TEST_ALT_PROJECT = "test-project-alt"
_TEST_ALT_LOCATION = "europe-west4"
_TEST_NETWORK = f"projects/{_TEST_PROJECT}/global/networks/{_TEST_ID}"

_TEST_MODEL_INSTANCE_SCHEMA_URI = "instance_schema_uri.yaml"
_TEST_MODEL_PARAMETERS_SCHEMA_URI = "parameters_schema_uri.yaml"
Expand Down Expand Up @@ -598,6 +599,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
dataset=mock_tabular_dataset,
base_output_dir=_TEST_BASE_OUTPUT_DIR,
service_account=_TEST_SERVICE_ACCOUNT,
network=_TEST_NETWORK,
args=_TEST_RUN_ARGS,
environment_variables=_TEST_ENVIRONMENT_VARIABLES,
replica_count=1,
Expand Down Expand Up @@ -700,6 +702,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
"workerPoolSpecs": [true_worker_pool_spec],
"baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR},
"serviceAccount": _TEST_SERVICE_ACCOUNT,
"network": _TEST_NETWORK,
},
struct_pb2.Value(),
),
Expand Down Expand Up @@ -2539,6 +2542,7 @@ def test_run_call_pipeline_service_create_with_nontabular_dataset(
annotation_schema_uri=_TEST_ANNOTATION_SCHEMA_URI,
base_output_dir=_TEST_BASE_OUTPUT_DIR,
service_account=_TEST_SERVICE_ACCOUNT,
network=_TEST_NETWORK,
args=_TEST_RUN_ARGS,
replica_count=1,
machine_type=_TEST_MACHINE_TYPE,
Expand Down Expand Up @@ -2621,6 +2625,7 @@ def test_run_call_pipeline_service_create_with_nontabular_dataset(
"workerPoolSpecs": [true_worker_pool_spec],
"baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR},
"serviceAccount": _TEST_SERVICE_ACCOUNT,
"network": _TEST_NETWORK,
},
struct_pb2.Value(),
),
Expand Down Expand Up @@ -2970,6 +2975,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
model_display_name=_TEST_MODEL_DISPLAY_NAME,
base_output_dir=_TEST_BASE_OUTPUT_DIR,
service_account=_TEST_SERVICE_ACCOUNT,
network=_TEST_NETWORK,
args=_TEST_RUN_ARGS,
environment_variables=_TEST_ENVIRONMENT_VARIABLES,
replica_count=1,
Expand Down Expand Up @@ -3065,6 +3071,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
"workerPoolSpecs": [true_worker_pool_spec],
"baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR},
"serviceAccount": _TEST_SERVICE_ACCOUNT,
"network": _TEST_NETWORK,
},
struct_pb2.Value(),
),
Expand Down

0 comments on commit 56273f7

Please sign in to comment.