Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We'll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add VPC Peering support to CustomTrainingJob classes #378

Merged
merged 2 commits into from May 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
62 changes: 59 additions & 3 deletions google/cloud/aiplatform/training_jobs.py
Expand Up @@ -1526,6 +1526,7 @@ def _prepare_training_task_inputs_and_output_dir(
worker_pool_specs: _DistributedTrainingSpec,
base_output_dir: Optional[str] = None,
service_account: Optional[str] = None,
network: Optional[str] = None,
) -> Tuple[Dict, str]:
"""Prepares training task inputs and output directory for custom job.

Expand All @@ -1538,6 +1539,11 @@ def _prepare_training_task_inputs_and_output_dir(
service_account (str):
Specifies the service account for workload run-as account.
Users submitting jobs must have act-as permission on this run-as account.
network (str):
The full name of the Compute Engine network to which the job
should be peered. For example, projects/12345/global/networks/myVPC.
Private services access must already be configured for the network.
If left unspecified, the job is not peered with any network.
Returns:
Training task inputs and Output directory for custom job.
"""
Expand All @@ -1556,6 +1562,8 @@ def _prepare_training_task_inputs_and_output_dir(

if service_account:
training_task_inputs["serviceAccount"] = service_account
if network:
training_task_inputs["network"] = network

return training_task_inputs, base_output_dir

Expand Down Expand Up @@ -1803,6 +1811,7 @@ def run(
model_display_name: Optional[str] = None,
base_output_dir: Optional[str] = None,
service_account: Optional[str] = None,
network: Optional[str] = None,
bigquery_destination: Optional[str] = None,
args: Optional[List[Union[str, float, int]]] = None,
environment_variables: Optional[Dict[str, str]] = None,
Expand Down Expand Up @@ -1891,6 +1900,11 @@ def run(
service_account (str):
Specifies the service account for workload run-as account.
Users submitting jobs must have act-as permission on this run-as account.
network (str):
The full name of the Compute Engine network to which the job
should be peered. For example, projects/12345/global/networks/myVPC.
Private services access must already be configured for the network.
If left unspecified, the job is not peered with any network.
bigquery_destination (str):
Provide this field if `dataset` is a BigQuery dataset.
The BigQuery project location where the training data is to
Expand Down Expand Up @@ -1981,6 +1995,7 @@ def run(
environment_variables=environment_variables,
base_output_dir=base_output_dir,
service_account=service_account,
network=network,
bigquery_destination=bigquery_destination,
training_fraction_split=training_fraction_split,
validation_fraction_split=validation_fraction_split,
Expand Down Expand Up @@ -2008,6 +2023,7 @@ def _run(
environment_variables: Optional[Dict[str, str]] = None,
base_output_dir: Optional[str] = None,
service_account: Optional[str] = None,
network: Optional[str] = None,
bigquery_destination: Optional[str] = None,
training_fraction_split: float = 0.8,
validation_fraction_split: float = 0.1,
Expand Down Expand Up @@ -2061,6 +2077,11 @@ def _run(
service_account (str):
Specifies the service account for workload run-as account.
Users submitting jobs must have act-as permission on this run-as account.
network (str):
The full name of the Compute Engine network to which the job
should be peered. For example, projects/12345/global/networks/myVPC.
Private services access must already be configured for the network.
If left unspecified, the job is not peered with any network.
bigquery_destination (str):
Provide this field if `dataset` is a BigQuery dataset.
The BigQuery project location where the training data is to
Expand Down Expand Up @@ -2127,7 +2148,10 @@ def _run(
training_task_inputs,
base_output_dir,
) = self._prepare_training_task_inputs_and_output_dir(
worker_pool_specs, base_output_dir, service_account
worker_pool_specs=worker_pool_specs,
base_output_dir=base_output_dir,
service_account=service_account,
network=network,
)

model = self._run_job(
Expand Down Expand Up @@ -2372,6 +2396,7 @@ def run(
model_display_name: Optional[str] = None,
base_output_dir: Optional[str] = None,
service_account: Optional[str] = None,
network: Optional[str] = None,
bigquery_destination: Optional[str] = None,
args: Optional[List[Union[str, float, int]]] = None,
environment_variables: Optional[Dict[str, str]] = None,
Expand Down Expand Up @@ -2453,6 +2478,11 @@ def run(
service_account (str):
Specifies the service account for workload run-as account.
Users submitting jobs must have act-as permission on this run-as account.
network (str):
The full name of the Compute Engine network to which the job
should be peered. For example, projects/12345/global/networks/myVPC.
Private services access must already be configured for the network.
If left unspecified, the job is not peered with any network.
bigquery_destination (str):
Provide this field if `dataset` is a BigQuery dataset.
The BigQuery project location where the training data is to
Expand Down Expand Up @@ -2542,6 +2572,7 @@ def run(
environment_variables=environment_variables,
base_output_dir=base_output_dir,
service_account=service_account,
network=network,
bigquery_destination=bigquery_destination,
training_fraction_split=training_fraction_split,
validation_fraction_split=validation_fraction_split,
Expand All @@ -2568,6 +2599,7 @@ def _run(
environment_variables: Optional[Dict[str, str]] = None,
base_output_dir: Optional[str] = None,
service_account: Optional[str] = None,
network: Optional[str] = None,
bigquery_destination: Optional[str] = None,
training_fraction_split: float = 0.8,
validation_fraction_split: float = 0.1,
Expand Down Expand Up @@ -2618,6 +2650,11 @@ def _run(
service_account (str):
Specifies the service account for workload run-as account.
Users submitting jobs must have act-as permission on this run-as account.
network (str):
The full name of the Compute Engine network to which the job
should be peered. For example, projects/12345/global/networks/myVPC.
Private services access must already be configured for the network.
If left unspecified, the job is not peered with any network.
bigquery_destination (str):
The BigQuery project location where the training data is to
be written to. In the given project a new dataset is created
Expand Down Expand Up @@ -2677,7 +2714,10 @@ def _run(
training_task_inputs,
base_output_dir,
) = self._prepare_training_task_inputs_and_output_dir(
worker_pool_specs, base_output_dir, service_account
worker_pool_specs=worker_pool_specs,
base_output_dir=base_output_dir,
service_account=service_account,
network=network,
)

model = self._run_job(
Expand Down Expand Up @@ -3703,6 +3743,7 @@ def run(
model_display_name: Optional[str] = None,
base_output_dir: Optional[str] = None,
service_account: Optional[str] = None,
network: Optional[str] = None,
bigquery_destination: Optional[str] = None,
args: Optional[List[Union[str, float, int]]] = None,
environment_variables: Optional[Dict[str, str]] = None,
Expand Down Expand Up @@ -3784,6 +3825,11 @@ def run(
service_account (str):
Specifies the service account for workload run-as account.
Users submitting jobs must have act-as permission on this run-as account.
network (str):
The full name of the Compute Engine network to which the job
should be peered. For example, projects/12345/global/networks/myVPC.
Private services access must already be configured for the network.
If left unspecified, the job is not peered with any network.
bigquery_destination (str):
Provide this field if `dataset` is a BigQuery dataset.
The BigQuery project location where the training data is to
Expand Down Expand Up @@ -3868,6 +3914,7 @@ def run(
environment_variables=environment_variables,
base_output_dir=base_output_dir,
service_account=service_account,
network=network,
training_fraction_split=training_fraction_split,
validation_fraction_split=validation_fraction_split,
test_fraction_split=test_fraction_split,
Expand All @@ -3894,6 +3941,7 @@ def _run(
environment_variables: Optional[Dict[str, str]] = None,
base_output_dir: Optional[str] = None,
service_account: Optional[str] = None,
network: Optional[str] = None,
training_fraction_split: float = 0.8,
validation_fraction_split: float = 0.1,
test_fraction_split: float = 0.1,
Expand Down Expand Up @@ -3945,6 +3993,11 @@ def _run(
service_account (str):
Specifies the service account for workload run-as account.
Users submitting jobs must have act-as permission on this run-as account.
network (str):
The full name of the Compute Engine network to which the job
should be peered. For example, projects/12345/global/networks/myVPC.
Private services access must already be configured for the network.
If left unspecified, the job is not peered with any network.
training_fraction_split (float):
The fraction of the input data that is to be
used to train the Model.
Expand Down Expand Up @@ -3990,7 +4043,10 @@ def _run(
training_task_inputs,
base_output_dir,
) = self._prepare_training_task_inputs_and_output_dir(
worker_pool_specs, base_output_dir, service_account
worker_pool_specs=worker_pool_specs,
base_output_dir=base_output_dir,
service_account=service_account,
network=network,
)

model = self._run_job(
Expand Down
7 changes: 7 additions & 0 deletions tests/unit/aiplatform/test_training_jobs.py
Expand Up @@ -109,6 +109,7 @@
)
_TEST_ALT_PROJECT = "test-project-alt"
_TEST_ALT_LOCATION = "europe-west4"
_TEST_NETWORK = f"projects/{_TEST_PROJECT}/global/networks/{_TEST_ID}"

_TEST_MODEL_INSTANCE_SCHEMA_URI = "instance_schema_uri.yaml"
_TEST_MODEL_PARAMETERS_SCHEMA_URI = "parameters_schema_uri.yaml"
Expand Down Expand Up @@ -598,6 +599,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
dataset=mock_tabular_dataset,
base_output_dir=_TEST_BASE_OUTPUT_DIR,
service_account=_TEST_SERVICE_ACCOUNT,
network=_TEST_NETWORK,
args=_TEST_RUN_ARGS,
environment_variables=_TEST_ENVIRONMENT_VARIABLES,
replica_count=1,
Expand Down Expand Up @@ -697,6 +699,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
"workerPoolSpecs": [true_worker_pool_spec],
"baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR},
"serviceAccount": _TEST_SERVICE_ACCOUNT,
"network": _TEST_NETWORK,
},
struct_pb2.Value(),
),
Expand Down Expand Up @@ -2524,6 +2527,7 @@ def test_run_call_pipeline_service_create_with_nontabular_dataset(
annotation_schema_uri=_TEST_ANNOTATION_SCHEMA_URI,
base_output_dir=_TEST_BASE_OUTPUT_DIR,
service_account=_TEST_SERVICE_ACCOUNT,
network=_TEST_NETWORK,
args=_TEST_RUN_ARGS,
replica_count=1,
machine_type=_TEST_MACHINE_TYPE,
Expand Down Expand Up @@ -2606,6 +2610,7 @@ def test_run_call_pipeline_service_create_with_nontabular_dataset(
"workerPoolSpecs": [true_worker_pool_spec],
"baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR},
"serviceAccount": _TEST_SERVICE_ACCOUNT,
"network": _TEST_NETWORK,
},
struct_pb2.Value(),
),
Expand Down Expand Up @@ -2955,6 +2960,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
model_display_name=_TEST_MODEL_DISPLAY_NAME,
base_output_dir=_TEST_BASE_OUTPUT_DIR,
service_account=_TEST_SERVICE_ACCOUNT,
network=_TEST_NETWORK,
args=_TEST_RUN_ARGS,
environment_variables=_TEST_ENVIRONMENT_VARIABLES,
replica_count=1,
Expand Down Expand Up @@ -3047,6 +3053,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
"workerPoolSpecs": [true_worker_pool_spec],
"baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR},
"serviceAccount": _TEST_SERVICE_ACCOUNT,
"network": _TEST_NETWORK,
},
struct_pb2.Value(),
),
Expand Down