Skip to content

Commit

Permalink
Feat: Add debugging terminal support for CustomJob, HyperparameterTuningJob, and Custom(*)TrainingJob
Browse files Browse the repository at this point in the history
  • Loading branch information
morgandu committed Oct 15, 2021
1 parent 293809e commit 41eca99
Show file tree
Hide file tree
Showing 6 changed files with 787 additions and 21 deletions.
154 changes: 138 additions & 16 deletions google/cloud/aiplatform/jobs.py
Expand Up @@ -173,6 +173,17 @@ def _dashboard_uri(self) -> Optional[str]:
url = f"https://console.cloud.google.com/ai/platform/locations/{fields.location}/{self._job_type}/{fields.id}?project={fields.project}"
return url

def _log_job_state(self):
    """Log this job's class name, resource name and current state."""
    resource = self._gca_resource
    message = "%s %s current state:\n%s" % (
        self.__class__.__name__,
        resource.name,
        resource.state,
    )
    _LOGGER.info(message)

def _block_until_complete(self):
"""Helper method to block and check on job until complete.
Expand All @@ -190,26 +201,13 @@ def _block_until_complete(self):
while self.state not in _JOB_COMPLETE_STATES:
current_time = time.time()
if current_time - previous_time >= log_wait:
_LOGGER.info(
"%s %s current state:\n%s"
% (
self.__class__.__name__,
self._gca_resource.name,
self._gca_resource.state,
)
)
self._log_job_state()
log_wait = min(log_wait * multiplier, max_wait)
previous_time = current_time
time.sleep(wait)

_LOGGER.info(
"%s %s current state:\n%s"
% (
self.__class__.__name__,
self._gca_resource.name,
self._gca_resource.state,
)
)
self._log_job_state()

# Error is only populated when the job state is
# JOB_STATE_FAILED or JOB_STATE_CANCELLED.
if self._gca_resource.state in _JOB_ERROR_STATES:
Expand Down Expand Up @@ -845,6 +843,56 @@ def __init__(
project=project, location=location
)

@abc.abstractmethod
def _get_web_access_uris(self) -> Dict[str, str]:
    """Return the web access uris for the runnable job.

    Abstract hook with no implementation here; each concrete job class
    supplies its own lookup.

    Returns:
        (Dict[str, str]) - web access uris of the runnable job
    """

@property
def web_access_uris(self) -> Dict[str, str]:
    """Web access uris of this job, as produced by `_get_web_access_uris`.

    Returns:
        (Dict[str, str]) - whatever the job's `_get_web_access_uris` returns
    """
    uris = self._get_web_access_uris()
    return uris

@abc.abstractmethod
def _log_web_access_uris(self):
    """Log the web access uris for the runnable job.

    Abstract hook with no implementation here; each concrete job class
    provides its own logging.
    """

def _block_until_complete(self):
    """Poll the runnable job until it reaches a completed state.

    While waiting, the job state is logged at a growing interval, and the
    web access uris are logged whenever they are non-empty and differ from
    the ones that were logged last.

    Raises:
        RuntimeError: If the final job state is one of _JOB_ERROR_STATES
            (job failed or cancelled).
    """
    # Fixed, short polling period so failures surface fast.
    poll_seconds = 5
    # State-log interval: starts at five seconds, is multiplied by `growth`
    # after every state log, and is capped at five minutes.
    log_interval = 5
    log_interval_cap = 60 * 5
    growth = 2

    last_logged_uris = {}
    last_log_time = time.time()
    while self.state not in _JOB_COMPLETE_STATES:
        now = time.time()
        if now - last_log_time >= log_interval:
            self._log_job_state()
            log_interval = min(log_interval * growth, log_interval_cap)
            last_log_time = now

        # Only log the uris when there is something new to show.
        current_uris = self._get_web_access_uris()
        if current_uris and current_uris != last_logged_uris:
            last_logged_uris = current_uris
            self._log_web_access_uris()

        time.sleep(poll_seconds)

    self._log_job_state()

    # Error is only populated when the job state is
    # JOB_STATE_FAILED or JOB_STATE_CANCELLED.
    if self._gca_resource.state in _JOB_ERROR_STATES:
        raise RuntimeError("Job failed with:\n%s" % self._gca_resource.error)

    _LOGGER.log_action_completed_against_resource("run", "completed", self)

@abc.abstractmethod
def run(self) -> None:
    """Abstract entry point; concrete job classes must implement `run`."""
Expand Down Expand Up @@ -1219,13 +1267,39 @@ def from_local_script(
staging_bucket=staging_bucket,
)

def _get_web_access_uris(self) -> Dict[str, str]:
    """Fetch the web access uris of the custom job.

    Returns:
        (Dict[str, str]) - web access uris of the custom job
    """
    # Sync before reading; presumably this refreshes the cached resource
    # from the service - confirm against `_sync_gca_resource`.
    self._sync_gca_resource()
    custom_job = self._gca_resource
    return custom_job.web_access_uris

def _log_web_access_uris(self):
    """Log the interactive shell links of the custom job, one pair of
    lines (worker, then uri) per entry in its web access uris."""
    custom_job = self._gca_resource
    links = "\n".join(
        "%s:\n%s" % (worker, uri)
        for worker, uri in custom_job.web_access_uris.items()
    )
    _LOGGER.info(
        "%s %s access the interactive shell terminals for this job at the following links:\n%s"
        % (self.__class__.__name__, custom_job.name, links)
    )

@base.optional_sync()
def run(
self,
service_account: Optional[str] = None,
network: Optional[str] = None,
timeout: Optional[int] = None,
restart_job_on_worker_restart: bool = False,
enable_web_access: Optional[bool] = False,
tensorboard: Optional[str] = None,
sync: bool = True,
) -> None:
Expand All @@ -1247,6 +1321,10 @@ def run(
gets restarted. This feature can be used by
distributed training jobs that are not resilient
to workers leaving and joining a job.
enable_web_access (bool):
Optional. Whether you want Vertex AI to enable interactive shell access
to training containers.
https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell
tensorboard (str):
Optional. The name of a Vertex AI
[Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard]
Expand Down Expand Up @@ -1280,6 +1358,9 @@ def run(
restart_job_on_worker_restart=restart_job_on_worker_restart,
)

if enable_web_access:
self._gca_resource.job_spec.enable_web_access = enable_web_access

if tensorboard:
v1beta1_gca_resource = gca_custom_job_v1beta1.CustomJob()
v1beta1_gca_resource._pb.MergeFromString(
Expand Down Expand Up @@ -1564,13 +1645,47 @@ def network(self) -> Optional[str]:
self._assert_gca_resource_is_available()
return getattr(self._gca_resource.trial_job_spec, "network")

def _get_web_access_uris(self) -> Dict[str, str]:
    """Fetch the web access uris of the current trial of the hp job.

    The current trial is taken to be the last entry of the job's trials.

    Returns:
        (Dict[str, str]) - web access uris of current trial of the hp job,
        or an empty dict when the job has no trials yet
    """
    # Sync before reading; presumably this refreshes the cached resource
    # from the service - confirm against `_sync_gca_resource`.
    self._sync_gca_resource()

    trials = self._gca_resource.trials
    if not trials:
        return {}

    return trials[-1].web_access_uris

def _log_web_access_uris(self):
    """Log the interactive shell links of the current trial of the hp job.

    The current trial is the last entry of the job's trials; one pair of
    lines (worker, then uri) is emitted per entry in its web access uris.
    """
    hp_job = self._gca_resource
    latest_trial = hp_job.trials[-1]
    links = "\n".join(
        "%s:\n%s" % (worker, uri)
        for worker, uri in latest_trial.web_access_uris.items()
    )
    _LOGGER.info(
        "%s %s access the interactive shell terminals for trial %s at the following links:\n%s"
        % (self.__class__.__name__, hp_job.name, latest_trial.id, links)
    )

@base.optional_sync()
def run(
self,
service_account: Optional[str] = None,
network: Optional[str] = None,
timeout: Optional[int] = None, # seconds
restart_job_on_worker_restart: bool = False,
enable_web_access: Optional[bool] = False,
tensorboard: Optional[str] = None,
sync: bool = True,
) -> None:
Expand All @@ -1592,6 +1707,10 @@ def run(
gets restarted. This feature can be used by
distributed training jobs that are not resilient
to workers leaving and joining a job.
enable_web_access (bool):
Optional. Whether you want Vertex AI to enable interactive shell access
to training containers.
https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell
tensorboard (str):
Optional. The name of a Vertex AI
[Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard]
Expand Down Expand Up @@ -1625,6 +1744,9 @@ def run(
restart_job_on_worker_restart=restart_job_on_worker_restart,
)

if enable_web_access:
self._gca_resource.trial_job_spec.enable_web_access = enable_web_access

if tensorboard:
v1beta1_gca_resource = (
gca_hyperparameter_tuning_job_v1beta1.HyperparameterTuningJob()
Expand Down

0 comments on commit 41eca99

Please sign in to comment.