Skip to content

Commit

Permalink
fix: updated internal structure for retrieving and tracking web acces…
Browse files Browse the repository at this point in the history
…s uris
  • Loading branch information
morgandu committed Oct 20, 2021
1 parent 3c7640b commit 5e44f24
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 83 deletions.
97 changes: 46 additions & 51 deletions google/cloud/aiplatform/jobs.py
Expand Up @@ -843,23 +843,20 @@ def __init__(
project=project, location=location
)

self._web_access_uris = None
self._logged_web_access_uris = []
self._logged_web_access_uris = set()

@property
def web_access_uris(self) -> Dict[str, str]:
def web_access_uris(self) -> Dict[str, Union[str, Dict[str, str]]]:
"""Fetch the runnable job again and return the latest web access uris.
Returns:
(Dict[str, str]):
(Dict[str, Union[str, Dict[str, str]]]):
Web access uris of the runnable job.
"""

# Fetch the Job again for most up-to-date web access uris
self._sync_gca_resource()
self._get_web_access_uris()

return self._web_access_uris
return self._get_web_access_uris()

@abc.abstractmethod
def _get_web_access_uris(self):
Expand Down Expand Up @@ -1104,28 +1101,25 @@ def network(self) -> Optional[str]:
self._assert_gca_resource_is_available()
return self._gca_resource.job_spec.network

def _get_web_access_uris(self):
"""Helper method to get the web access uris of the custom job"""
self._web_access_uris = self._gca_resource.web_access_uris
def _get_web_access_uris(self) -> Dict[str, str]:
"""Helper method to get the web access uris of the custom job
Returns:
(Dict[str, str]):
Web access uris of the custom job.
"""
return self._gca_resource.web_access_uris

def _log_web_access_uris(self):
"""Helper method to log the web access uris of the custom job"""

self._get_web_access_uris()

if self._web_access_uris:
for worker, uri in self._web_access_uris.items():
if uri not in self._logged_web_access_uris:
_LOGGER.info(
"%s %s access the interactive shell terminals for the custom job:\n%s:\n%s"
% (
self.__class__.__name__,
self._gca_resource.name,
worker,
uri,
),
)
self._logged_web_access_uris.append(uri)
for worker, uri in self.web_access_uris.items():
if uri not in self._logged_web_access_uris:
_LOGGER.info(
"%s %s access the interactive shell terminals for the custom job:\n%s:\n%s"
% (self.__class__.__name__, self._gca_resource.name, worker, uri,),
)
self._logged_web_access_uris.add(uri)

@classmethod
def from_local_script(
Expand Down Expand Up @@ -1677,36 +1671,37 @@ def network(self) -> Optional[str]:
self._assert_gca_resource_is_available()
return getattr(self._gca_resource.trial_job_spec, "network")

def _get_web_access_uris(self):
"""Helper method to get the web access uris of the hyperparameter job"""
if self.trials:
self._web_access_uris = [
(trial.id, trial.web_access_uris)
for trial in self.trials
if trial.web_access_uris
]
def _get_web_access_uris(self) -> Dict[str, Dict[str, str]]:
"""Helper method to get the web access uris of the hyperparameter job
Returns:
(Dict[str, Dict[str, str]]):
Web access uris of the hyperparameter job.
"""
web_access_uris = {}
for trial in self.trials:
web_access_uris[trial.id] = web_access_uris.get(trial.id, {})
for worker, uri in trial.web_access_uris.items():
web_access_uris[trial.id][worker] = uri
return web_access_uris

def _log_web_access_uris(self):
"""Helper method to log the web access uris of the hyperparameter job"""

self._get_web_access_uris()

if self._web_access_uris:
for (trial_id, tria_web_access_uris) in self._web_access_uris:
if tria_web_access_uris:
for worker, uri in tria_web_access_uris.items():
if uri not in self._logged_web_access_uris:
_LOGGER.info(
"%s %s access the interactive shell terminals for trial - %s:\n%s:\n%s"
% (
self.__class__.__name__,
self._gca_resource.name,
trial_id,
worker,
uri,
),
)
self._logged_web_access_uris.append(uri)
for (trial_id, trial_web_access_uris) in self.web_access_uris.items():
for worker, uri in trial_web_access_uris.items():
if uri not in self._logged_web_access_uris:
_LOGGER.info(
"%s %s access the interactive shell terminals for trial - %s:\n%s:\n%s"
% (
self.__class__.__name__,
self._gca_resource.name,
trial_id,
worker,
uri,
),
)
self._logged_web_access_uris.add(uri)

@base.optional_sync()
def run(
Expand Down
55 changes: 23 additions & 32 deletions google/cloud/aiplatform/training_jobs.py
Expand Up @@ -1252,9 +1252,7 @@ def __init__(
# once Custom Job is known we log the console uri and the tensorboard uri
# this flags keeps that state so we don't log it multiple times
self._has_logged_custom_job = False
self._custom_job = None
self._web_access_uris = None
self._logged_web_access_uris = []
self._logged_web_access_uris = set()

@property
def network(self) -> Optional[str]:
Expand Down Expand Up @@ -1452,33 +1450,36 @@ def _prepare_training_task_inputs_and_output_dir(
return training_task_inputs, base_output_dir

@property
def web_access_uris(self) -> Optional[Dict[str, str]]:
def web_access_uris(self) -> Dict[str, str]:
"""Get the web access uris of the backing custom job.
Returns:
(Dict[str, str]):
Web access uris of the backing custom job.
"""
if self._custom_job:
self._web_access_uris = self._custom_job.web_access_uris
return self._web_access_uris
web_access_uris = {}
if (
self._gca_resource.training_task_metadata
and self._gca_resource.training_task_metadata.get("backingCustomJob")
):
custom_job_resource_name = self._gca_resource.training_task_metadata.get(
"backingCustomJob"
)
custom_job = jobs.CustomJob.get(resource_name=custom_job_resource_name)

web_access_uris = custom_job.web_access_uris

return web_access_uris

def _log_web_access_uris(self):
"""Helper method to log the web access uris of the backing custom job"""

if self._web_access_uris:
for worker, uri in self._web_access_uris.items():
if uri not in self._logged_web_access_uris:
_LOGGER.info(
"%s %s access the interactive shell terminals for the backing custom job:\n%s:\n%s"
% (
self.__class__.__name__,
self._gca_resource.name,
worker,
uri,
),
)
self._logged_web_access_uris.append(uri)
for worker, uri in self.web_access_uris.items():
if uri not in self._logged_web_access_uris:
_LOGGER.info(
"%s %s access the interactive shell terminals for the backing custom job:\n%s:\n%s"
% (self.__class__.__name__, self._gca_resource.name, worker, uri,),
)
self._logged_web_access_uris.add(uri)

def _wait_callback(self):
if (
Expand All @@ -1493,17 +1494,7 @@ def _wait_callback(self):

self._has_logged_custom_job = True

if self._has_logged_custom_job and self._gca_resource.training_task_inputs.get(
"enable_web_access"
):
if not self._custom_job:
custom_job_resource_name = self._gca_resource.training_task_metadata.get(
"backingCustomJob"
)
self._custom_job = jobs.CustomJob.get(
resource_name=custom_job_resource_name
)
self._web_access_uris = self._custom_job.web_access_uris
if self._gca_resource.training_task_inputs.get("enable_web_access"):
self._log_web_access_uris()

def _custom_job_console_uri(self) -> str:
Expand Down

0 comments on commit 5e44f24

Please sign in to comment.