Skip to content

Commit

Permalink
cast web_access_uris to dict type
Browse files Browse the repository at this point in the history
  • Loading branch information
morgandu committed Oct 21, 2021
1 parent 5e44f24 commit 706423f
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 5 deletions.
6 changes: 3 additions & 3 deletions google/cloud/aiplatform/jobs.py
Expand Up @@ -1108,7 +1108,7 @@ def _get_web_access_uris(self) -> Dict[str, str]:
(Dict[str, str]):
Web access uris of the custom job.
"""
return self._gca_resource.web_access_uris
return dict(self._gca_resource.web_access_uris)

def _log_web_access_uris(self):
"""Helper method to log the web access uris of the custom job"""
Expand Down Expand Up @@ -1678,9 +1678,9 @@ def _get_web_access_uris(self) -> Dict[str, Dict[str, str]]:
(Dict[str, Dict[str, str]]):
Web access uris of the hyperparameter job.
"""
web_access_uris = {}
web_access_uris = dict()
for trial in self.trials:
web_access_uris[trial.id] = web_access_uris.get(trial.id, {})
web_access_uris[trial.id] = web_access_uris.get(trial.id, dict())
for worker, uri in trial.web_access_uris.items():
web_access_uris[trial.id][worker] = uri
return web_access_uris
Expand Down
4 changes: 2 additions & 2 deletions google/cloud/aiplatform/training_jobs.py
Expand Up @@ -1457,7 +1457,7 @@ def web_access_uris(self) -> Dict[str, str]:
(Dict[str, str]):
Web access uris of the backing custom job.
"""
web_access_uris = {}
web_access_uris = dict()
if (
self._gca_resource.training_task_metadata
and self._gca_resource.training_task_metadata.get("backingCustomJob")
Expand All @@ -1467,7 +1467,7 @@ def web_access_uris(self) -> Dict[str, str]:
)
custom_job = jobs.CustomJob.get(resource_name=custom_job_resource_name)

web_access_uris = custom_job.web_access_uris
web_access_uris = dict(custom_job.web_access_uris)

return web_access_uris

Expand Down
3 changes: 3 additions & 0 deletions tests/unit/aiplatform/test_training_jobs.py
Expand Up @@ -524,11 +524,13 @@ def make_training_pipeline(state, add_training_task_metadata=True):
else None,
)


def make_training_pipeline_with_no_model_upload(state):
return gca_training_pipeline.TrainingPipeline(
name=_TEST_PIPELINE_RESOURCE_NAME, state=state,
)


def make_training_pipeline_with_enable_web_access(state):
training_pipeline = gca_training_pipeline.TrainingPipeline(
name=_TEST_PIPELINE_RESOURCE_NAME,
Expand All @@ -541,6 +543,7 @@ def make_training_pipeline_with_enable_web_access(state):
}
return training_pipeline


@pytest.fixture
def mock_pipeline_service_get():
with mock.patch.object(
Expand Down

0 comments on commit 706423f

Please sign in to comment.