Skip to content

Commit

Permalink
feat: add wait for creation and more informative exception when prope…
Browse files Browse the repository at this point in the history
…rties are not available (#566)
  • Loading branch information
sasha-gitg committed Jul 29, 2021
1 parent c6614cd commit e346117
Show file tree
Hide file tree
Showing 20 changed files with 685 additions and 65 deletions.
69 changes: 69 additions & 0 deletions google/cloud/aiplatform/base.py
Expand Up @@ -23,6 +23,7 @@
import logging
import sys
import threading
import time
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -540,21 +541,25 @@ def _sync_gca_resource(self):
@property
def name(self) -> str:
    """Short ID of this resource: the last "/"-separated segment of its name.

    Raises:
        Whatever `_assert_gca_resource_is_available` raises when the
        underlying resource is not available.
    """
    self._assert_gca_resource_is_available()
    # Everything after the final "/" (the whole string when there is none).
    _, _, short_id = self._gca_resource.name.rpartition("/")
    return short_id

@property
def resource_name(self) -> str:
    """Fully qualified resource name, exactly as stored on the resource proto.

    Raises:
        Whatever `_assert_gca_resource_is_available` raises when the
        underlying resource is not available.
    """
    self._assert_gca_resource_is_available()
    full_name = self._gca_resource.name
    return full_name

@property
def display_name(self) -> str:
    """Display name stored on the underlying resource.

    Raises:
        Whatever `_assert_gca_resource_is_available` raises when the
        underlying resource is not available.
    """
    self._assert_gca_resource_is_available()
    label = self._gca_resource.display_name
    return label

@property
def create_time(self) -> datetime.datetime:
    """Creation time recorded on the underlying resource.

    Raises:
        Whatever `_assert_gca_resource_is_available` raises when the
        underlying resource is not available.
    """
    self._assert_gca_resource_is_available()
    created_at = self._gca_resource.create_time
    return created_at

@property
Expand All @@ -570,6 +575,7 @@ def encryption_spec(self) -> Optional[gca_encryption_spec.EncryptionSpec]:
If this is set, then all resources created by this Vertex AI resource will
be encrypted with the provided encryption key.
"""
self._assert_gca_resource_is_available()
return getattr(self._gca_resource, "encryption_spec")

@property
Expand All @@ -578,13 +584,26 @@ def labels(self) -> Dict[str, str]:
Read more about labels at https://goo.gl/xmQnxf
"""
self._assert_gca_resource_is_available()
return self._gca_resource.labels

@property
def gca_resource(self) -> proto.Message:
    """The proto message that backs this resource object.

    Raises:
        Whatever `_assert_gca_resource_is_available` raises when the
        underlying resource is not available.
    """
    self._assert_gca_resource_is_available()
    backing_proto = self._gca_resource
    return backing_proto

def _assert_gca_resource_is_available(self) -> None:
    """Guard used by properties that read from `_gca_resource`.

    Returns silently when `_gca_resource` is set.

    Raises:
        RuntimeError if `_gca_resource` is None, i.e. the resource has not
        been created.
    """
    if self._gca_resource is not None:
        return

    raise RuntimeError(
        f"{self.__class__.__name__} resource has not been created"
    )

def __repr__(self) -> str:
    """Return the default object repr followed by the resource name.

    `resource_name` raises RuntimeError while the underlying resource is not
    available. A repr is routinely produced by debuggers, loggers and
    tracebacks, so it must not fail in that state; the RuntimeError is caught
    and its message is shown in place of the name.
    """
    try:
        shown_name = self.resource_name
    except RuntimeError as error:
        # Resource not available yet: describe that instead of propagating.
        shown_name = f"<unavailable: {error}>"
    return f"{object.__repr__(self)} \nresource name: {shown_name}"

Expand Down Expand Up @@ -1061,6 +1080,56 @@ def __repr__(self) -> str:

return FutureManager.__repr__(self)

def _wait_for_resource_creation(self) -> None:
    """Poll until the underlying resource has a name.

    Currently this should only be used on subclasses that implement the
    construct then `run` pattern because the underlying sync=False
    implementation will not update downstream resource noun object's
    _gca_resource until the entire invoked method is complete.

    Ex:
        job = CustomTrainingJob()
        job.run(sync=False, ...)
        job._wait_for_resource_creation()

    Raises:
        RuntimeError if the resource has not been scheduled to be created.
    """

    def _is_created() -> bool:
        # A resource counts as created once `_gca_resource` has a truthy name.
        return bool(getattr(self._gca_resource, "name", None))

    # Nothing pending and nothing created: surface a stored failure if there
    # is one, otherwise the caller never invoked an API that creates it.
    if self._are_futures_done() and not _is_created():
        self._raise_future_exception()
        raise RuntimeError(
            f"{self.__class__.__name__} resource is not scheduled to be created."
        )

    while not _is_created():
        # Leave the loop by raising if creation failed asynchronously.
        # NOTE(review): if the futures finish without a stored exception and
        # without a name, this keeps polling - confirm that cannot happen.
        if self._are_futures_done() and not _is_created():
            self._raise_future_exception()

        time.sleep(1)

def _assert_gca_resource_is_available(self) -> None:
    """Guard for properties that need a created resource.

    Overrides VertexAiResourceNoun to provide a more informative exception
    if resource creation has failed asynchronously: when `_exception` is
    set, it is appended to the error message.

    Raises:
        RuntimeError when `_gca_resource` has no truthy `name`.
    """
    if getattr(self._gca_resource, "name", None):
        return

    failure_detail = (
        f" Resource failed with: {self._exception}" if self._exception else ""
    )
    raise RuntimeError(
        f"{self.__class__.__name__} resource has not been created." + failure_detail
    )


def get_annotation_class(annotation: type) -> type:
"""Helper method to retrieve type annotation.
Expand Down
1 change: 1 addition & 0 deletions google/cloud/aiplatform/datasets/dataset.py
Expand Up @@ -84,6 +84,7 @@ def __init__(
@property
def metadata_schema_uri(self) -> str:
    """Metadata schema URI stored on this dataset resource.

    Raises:
        Whatever `_assert_gca_resource_is_available` raises when the
        underlying resource is not available.
    """
    self._assert_gca_resource_is_available()
    schema_uri = self._gca_resource.metadata_schema_uri
    return schema_uri

def _validate_metadata_schema_uri(self) -> None:
Expand Down
2 changes: 2 additions & 0 deletions google/cloud/aiplatform/datasets/tabular_dataset.py
Expand Up @@ -52,6 +52,8 @@ def column_names(self) -> List[str]:
RuntimeError: When no valid source is found.
"""

self._assert_gca_resource_is_available()

metadata = self._gca_resource.metadata

if metadata is None:
Expand Down
31 changes: 13 additions & 18 deletions google/cloud/aiplatform/jobs.py
Expand Up @@ -330,18 +330,21 @@ def output_info(self,) -> Optional[aiplatform.gapic.BatchPredictionJob.OutputInf
This is only available for batch prediction jobs that have run successfully.
"""
self._assert_gca_resource_is_available()
return self._gca_resource.output_info

@property
def partial_failures(self) -> Optional[Sequence[status_pb2.Status]]:
    """Partial failures encountered, e.g. single files that can't be read.

    This field never exceeds 20 entries. Status details fields contain
    standard GCP error details.

    Raises:
        Whatever `_assert_gca_resource_is_available` raises when the
        underlying resource is not available.
    """
    self._assert_gca_resource_is_available()
    return self._gca_resource.partial_failures

@property
def completion_stats(self) -> Optional[gca_completion_stats.CompletionStats]:
    """Statistics on completed and failed prediction instances.

    Raises:
        Whatever `_assert_gca_resource_is_available` raises when the
        underlying resource is not available.
    """
    self._assert_gca_resource_is_available()
    return self._gca_resource.completion_stats

@classmethod
Expand Down Expand Up @@ -772,6 +775,8 @@ def iter_outputs(
GCS or BQ output provided.
"""

self._assert_gca_resource_is_available()

if self.state != gca_job_state.JobState.JOB_STATE_SUCCEEDED:
raise RuntimeError(
f"Cannot read outputs until BatchPredictionJob has succeeded, "
Expand Down Expand Up @@ -859,23 +864,6 @@ def __init__(
def run(self) -> None:
    """Do nothing and return None.

    NOTE(review): presumably a placeholder that concrete job classes
    override - confirm against the enclosing class.
    """
    return None

@property
def _has_run(self) -> bool:
    """True when the underlying resource has a non-empty name."""
    return True if self._gca_resource.name else False

@property
def state(self) -> gca_job_state.JobState:
    """Current state of the job, as reported by the parent class.

    Raises:
        RuntimeError if job run has not been called.
    """
    if self._has_run:
        return super().state

    raise RuntimeError("Job has not run. No state available.")

@classmethod
def get(
cls,
Expand Down Expand Up @@ -913,6 +901,10 @@ def get(

return self

def wait_for_resource_creation(self) -> None:
    """Wait until the resource has been created.

    Public entry point; delegates to `_wait_for_resource_creation` and
    discards its result.
    """
    self._wait_for_resource_creation()


class DataLabelingJob(_Job):
_resource_noun = "dataLabelingJobs"
Expand Down Expand Up @@ -1041,7 +1033,8 @@ def network(self) -> Optional[str]:
Private services access must already be configured for the network. If left
unspecified, the CustomJob is not peered with any network.
"""
return getattr(self._gca_resource, "network")
self._assert_gca_resource_is_available()
return self._gca_resource.job_spec.network

@classmethod
def from_local_script(
Expand Down Expand Up @@ -1512,6 +1505,7 @@ def network(self) -> Optional[str]:
Private services access must already be configured for the network. If left
unspecified, the HyperparameterTuningJob is not peered with any network.
"""
self._assert_gca_resource_is_available()
return getattr(self._gca_resource.trial_job_spec, "network")

@base.optional_sync()
Expand Down Expand Up @@ -1612,4 +1606,5 @@ def run(

@property
def trials(self) -> List[gca_study_compat.Trial]:
    """Trials of the underlying resource, copied into a new list.

    Raises:
        Whatever `_assert_gca_resource_is_available` raises when the
        underlying resource is not available.
    """
    self._assert_gca_resource_is_available()
    return [*self._gca_resource.trials]
12 changes: 11 additions & 1 deletion google/cloud/aiplatform/models.py
Expand Up @@ -146,7 +146,8 @@ def network(self) -> Optional[str]:
Private services access must already be configured for the network. If left
unspecified, the Endpoint is not peered with any network.
"""
return getattr(self._gca_resource, "network")
self._assert_gca_resource_is_available()
return getattr(self._gca_resource, "network", None)

@classmethod
def create(
Expand Down Expand Up @@ -1283,11 +1284,13 @@ class Model(base.VertexAiResourceNounWithFutureManager):
def uri(self) -> Optional[str]:
    """Path to the directory containing the Model artifact and any of its
    supporting files. Not present for AutoML Models.

    Returns:
        The resource's `artifact_uri`, or None when that value is falsy.
    """
    self._assert_gca_resource_is_available()
    artifact_uri = self._gca_resource.artifact_uri
    return artifact_uri if artifact_uri else None

@property
def description(self) -> str:
    """Description stored on the model resource.

    Raises:
        Whatever `_assert_gca_resource_is_available` raises when the
        underlying resource is not available.
    """
    self._assert_gca_resource_is_available()
    text = self._gca_resource.description
    return text

@property
Expand All @@ -1302,6 +1305,7 @@ def supported_export_formats(
{'tf-saved-model': [<ExportableContent.ARTIFACT: 1>]}
"""
self._assert_gca_resource_is_available()
return {
export_format.id: [
gca_model_compat.Model.ExportFormat.ExportableContent(content)
Expand All @@ -1328,6 +1332,7 @@ def supported_deployment_resources_types(
predictions by using a `BatchPredictionJob`, if it has at least one entry
each in `Model.supported_input_storage_formats` and
`Model.supported_output_storage_formats`."""
self._assert_gca_resource_is_available()
return list(self._gca_resource.supported_deployment_resources_types)

@property
Expand All @@ -1343,6 +1348,7 @@ def supported_input_storage_formats(self) -> List[str]:
`supported_deployment_resources_types`, it could serve online predictions
by using `Endpoint.predict()` or `Endpoint.explain()`.
"""
self._assert_gca_resource_is_available()
return list(self._gca_resource.supported_input_storage_formats)

@property
Expand All @@ -1363,12 +1369,14 @@ def supported_output_storage_formats(self) -> List[str]:
`supported_deployment_resources_types`, it could serve online predictions
by using `Endpoint.predict()` or `Endpoint.explain()`.
"""
self._assert_gca_resource_is_available()
return list(self._gca_resource.supported_output_storage_formats)

@property
def predict_schemata(self) -> Optional[aiplatform.gapic.PredictSchemata]:
    """The schemata that describe formats of the Model's predictions and
    explanations, if available.

    Raises:
        Whatever `_assert_gca_resource_is_available` raises when the
        underlying resource is not available.
    """
    self._assert_gca_resource_is_available()
    return self._gca_resource.predict_schemata

@property
Expand All @@ -1379,6 +1387,7 @@ def training_job(self) -> Optional["aiplatform.training_jobs._TrainingJob"]:
api_core.exceptions.NotFound: If the Model's training job resource
cannot be found on the Vertex service.
"""
self._assert_gca_resource_is_available()
job_name = getattr(self._gca_resource, "training_pipeline")

if not job_name:
Expand All @@ -1400,6 +1409,7 @@ def training_job(self) -> Optional["aiplatform.training_jobs._TrainingJob"]:
def container_spec(self) -> Optional[aiplatform.gapic.ModelContainerSpec]:
    """The specification of the container that is to be used when deploying
    this Model. Not present for AutoML Models.

    Raises:
        Whatever `_assert_gca_resource_is_available` raises when the
        underlying resource is not available.
    """
    self._assert_gca_resource_is_available()
    return self._gca_resource.container_spec

def __init__(
Expand Down
37 changes: 17 additions & 20 deletions google/cloud/aiplatform/pipeline_jobs.py
Expand Up @@ -220,6 +220,18 @@ def __init__(
),
)

def _assert_gca_resource_is_available(self) -> None:
    """Guard for properties that need a created pipeline job resource.

    Availability is judged by a truthy `create_time` on `_gca_resource`.
    When `_exception` is set, it is appended to the error message.

    Raises:
        RuntimeError when `_gca_resource` has no truthy `create_time`.
    """
    # TODO(b/193800063) Change this to name after this fix
    if getattr(self._gca_resource, "create_time", None):
        return

    failure_detail = (
        f" Resource failed with: {self._exception}" if self._exception else ""
    )
    raise RuntimeError(
        f"{self.__class__.__name__} resource has not been created." + failure_detail
    )

@base.optional_sync()
def run(
self,
Expand All @@ -236,6 +248,7 @@ def run(
network (str):
Optional. The full name of the Compute Engine network to which the job
should be peered. For example, projects/12345/global/networks/myVPC.
Private services access must already be configured for the network.
If left unspecified, the job is not peered with any network.
sync (bool):
Expand Down Expand Up @@ -272,17 +285,9 @@ def pipeline_spec(self):
@property
def state(self) -> Optional[gca_pipeline_state_v1beta1.PipelineState]:
    """Current pipeline state, read after refreshing the local resource.

    Raises:
        RuntimeError if `_has_run` is false.
    """
    if self._has_run:
        # Refresh through `_sync_gca_resource` before reading the state.
        self._sync_gca_resource()
        return self._gca_resource.state

    raise RuntimeError("Job has not run. No state available.")

@property
def _has_run(self) -> bool:
    """True when the underlying resource has a truthy `create_time`."""
    return True if self._gca_resource.create_time else False

@property
def has_failed(self) -> bool:
"""Returns True if pipeline has failed.
Expand All @@ -300,10 +305,6 @@ def _dashboard_uri(self) -> str:
url = f"https://console.cloud.google.com/vertex-ai/locations/{fields.location}/pipelines/runs/{fields.id}?project={fields.project}"
return url

def _sync_gca_resource(self):
    """Replace the local `_gca_resource` with the result of
    `api_client.get_pipeline_job`, looked up by `resource_name`."""
    latest = self.api_client.get_pipeline_job(name=self.resource_name)
    self._gca_resource = latest

def _block_until_complete(self):
"""Helper method to block and check on job until complete."""
# Used these numbers so failures surface fast
Expand Down Expand Up @@ -377,13 +378,9 @@ def cancel(self) -> None:
makes a best effort to cancel the job, but success is not guaranteed.
On successful cancellation, the PipelineJob is not deleted; instead it
becomes a job with state set to `CANCELLED`.
Raises:
RuntimeError: If this PipelineJob has not started running.
"""
if not self._has_run:
raise RuntimeError(
"This PipelineJob has not been launched, use the `run()` method "
"to start. `cancel()` can only be called on a job that is running."
)
self.api_client.cancel_pipeline_job(name=self.resource_name)

def wait_for_resource_creation(self) -> None:
    """Wait until the resource has been created.

    Public entry point; delegates to `_wait_for_resource_creation` and
    discards its result.
    """
    self._wait_for_resource_creation()

0 comments on commit e346117

Please sign in to comment.