Skip to content

Commit

Permalink
feat: enable reduction server (#741)
Browse files Browse the repository at this point in the history
* feat: enable reduction server

* fix: remove optional for reduction_server_replica_count, add comment for _SPEC_ORDERS
  • Loading branch information
morgandu committed Oct 20, 2021
1 parent 3fd0ab7 commit 8ef0ded
Show file tree
Hide file tree
Showing 5 changed files with 574 additions and 69 deletions.
54 changes: 39 additions & 15 deletions google/cloud/aiplatform/jobs.py
Expand Up @@ -1061,6 +1061,9 @@ def from_local_script(
accelerator_count: int = 0,
boot_disk_type: str = "pd-ssd",
boot_disk_size_gb: int = 100,
reduction_server_replica_count: int = 0,
reduction_server_machine_type: Optional[str] = None,
reduction_server_container_uri: Optional[str] = None,
base_output_dir: Optional[str] = None,
project: Optional[str] = None,
location: Optional[str] = None,
Expand Down Expand Up @@ -1127,6 +1130,13 @@ def from_local_script(
boot_disk_size_gb (int):
Optional. Size in GB of the boot disk, default is 100GB.
boot disk size must be within the range of [100, 64000].
reduction_server_replica_count (int):
The number of reduction server replicas, default is 0.
reduction_server_machine_type (str):
Optional. The type of machine to use for reduction server.
reduction_server_container_uri (str):
Optional. The Uri of the reduction server container image.
See details: https://cloud.google.com/vertex-ai/docs/training/distributed-training#reduce_training_time_with_reduction_server
base_output_dir (str):
Optional. GCS output directory of job. If not provided a
timestamped directory in the staging directory will be used.
Expand Down Expand Up @@ -1181,6 +1191,8 @@ def from_local_script(
accelerator_type=accelerator_type,
boot_disk_type=boot_disk_type,
boot_disk_size_gb=boot_disk_size_gb,
reduction_server_replica_count=reduction_server_replica_count,
reduction_server_machine_type=reduction_server_machine_type,
).pool_specs

python_packager = source_utils._TrainingScriptPythonPackager(
Expand All @@ -1191,21 +1203,33 @@ def from_local_script(
gcs_staging_dir=staging_bucket, project=project, credentials=credentials,
)

for spec in worker_pool_specs:
spec["python_package_spec"] = {
"executor_image_uri": container_uri,
"python_module": python_packager.module_name,
"package_uris": [package_gcs_uri],
}

if args:
spec["python_package_spec"]["args"] = args

if environment_variables:
spec["python_package_spec"]["env"] = [
{"name": key, "value": value}
for key, value in environment_variables.items()
]
for spec_order, spec in enumerate(worker_pool_specs):

if not spec:
continue

if (
spec_order == worker_spec_utils._SPEC_ORDERS["server_spec"]
and reduction_server_replica_count > 0
):
spec["container_spec"] = {
"image_uri": reduction_server_container_uri,
}
else:
spec["python_package_spec"] = {
"executor_image_uri": container_uri,
"python_module": python_packager.module_name,
"package_uris": [package_gcs_uri],
}

if args:
spec["python_package_spec"]["args"] = args

if environment_variables:
spec["python_package_spec"]["env"] = [
{"name": key, "value": value}
for key, value in environment_variables.items()
]

return cls(
display_name=display_name,
Expand Down

0 comments on commit 8ef0ded

Please sign in to comment.