feat: add wait for creation and more informative exception when properties are not available #566

Merged
merged 13 commits on Jul 29, 2021
66 changes: 66 additions & 0 deletions google/cloud/aiplatform/base.py
@@ -23,6 +23,7 @@
import logging
import sys
import threading
import time
from typing import (
Any,
Callable,
@@ -540,21 +541,25 @@ def _sync_gca_resource(self):
@property
def name(self) -> str:
"""Name of this resource."""
self._assert_gca_resource_is_available()
return self._gca_resource.name.split("/")[-1]

@property
def resource_name(self) -> str:
"""Full qualified resource name."""
self._assert_gca_resource_is_available()
return self._gca_resource.name

@property
def display_name(self) -> str:
"""Display name of this resource."""
self._assert_gca_resource_is_available()
return self._gca_resource.display_name

@property
def create_time(self) -> datetime.datetime:
"""Time this resource was created."""
self._assert_gca_resource_is_available()
return self._gca_resource.create_time

@property
@@ -570,6 +575,7 @@ def encryption_spec(self) -> Optional[gca_encryption_spec.EncryptionSpec]:
If this is set, then all resources created by this Vertex AI resource will
be encrypted with the provided encryption key.
"""
self._assert_gca_resource_is_available()
return getattr(self._gca_resource, "encryption_spec")

@property
@@ -578,13 +584,26 @@ def labels(self) -> Dict[str, str]:

Read more about labels at https://goo.gl/xmQnxf
"""
self._assert_gca_resource_is_available()
return self._gca_resource.labels

@property
def gca_resource(self) -> proto.Message:
"""The underlying resource proto representation."""
self._assert_gca_resource_is_available()
return self._gca_resource

def _assert_gca_resource_is_available(self) -> None:
"""Helper method to raise when property is not accessible.

Raises:
RuntimeError if _gca_resource has not been created.
"""
if self._gca_resource is None:
raise RuntimeError(
f"{self.__class__.__name__} resource has not been created"
)

def __repr__(self) -> str:
return f"{object.__repr__(self)} \nresource name: {self.resource_name}"

@@ -1061,6 +1080,53 @@ def __repr__(self) -> str:

return FutureManager.__repr__(self)

def _wait_for_resource_creation(self) -> None:
"""Wait until underlying resource is created.

Currently this should only be used on subclasses that implement the construct-then-`run`
pattern, because the underlying sync=False implementation will not update the downstream
resource noun object's _gca_resource until the entire invoked method is complete.

Ex:
job = CustomTrainingJob()
job.run(sync=False, ...)
job._wait_for_resource_creation()
Raises:
RuntimeError if the resource has not been scheduled to be created.
"""

# If the user calls this but didn't actually invoke an API to create the resource
if self._are_futures_done() and not getattr(self._gca_resource, "name", None):
self._raise_future_exception()
raise RuntimeError(
f"{self.__class__.__name__} resource is not scheduled to be created."
)

while not getattr(self._gca_resource, "name", None):
# exits the loop by raising if creation has failed asynchronously
if self._are_futures_done() and not getattr(
self._gca_resource, "name", None
):
self._raise_future_exception()

time.sleep(1)
Comment on lines +1105 to +1112

Contributor:

Is this frequency of polling alright? Since it doesn't call the service I'm presuming so.

Member Author:

Correct, it doesn't call the service. Also, when in this loop there is no _gca_resource available so the resource either hasn't been created yet or our creation request hasn't returned yet.

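To make the polling discussion concrete, here is a minimal standalone sketch of the same wait-on-a-local-attribute pattern. _FakeJob and the hard-coded resource name are hypothetical illustrations, not SDK classes; the point is that the wait loop inspects an in-memory attribute once per second and sends nothing to the service.

import threading
import time


class _FakeJob:
    """Hypothetical stand-in for a resource noun object; not part of the SDK."""

    def __init__(self):
        self._gca_resource = None

    def _create_async(self):
        # Simulates the sync=False path: the backing proto only gets a name
        # once the (slow) create request returns.
        def _create():
            time.sleep(3)
            self._gca_resource = type(
                "FakeProto", (), {"name": "projects/123/locations/us-central1/customJobs/456"}
            )()

        threading.Thread(target=_create, daemon=True).start()

    def _wait_for_resource_creation(self):
        # Client-side polling only: no request is made while waiting.
        while not getattr(self._gca_resource, "name", None):
            time.sleep(1)


job = _FakeJob()
job._create_async()
job._wait_for_resource_creation()
print(job._gca_resource.name)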

def _assert_gca_resource_is_available(self) -> None:
"""Helper method to raise when accessing properties that do not exist.

Raises:
RuntimeError when resource has not been created.
"""
if not getattr(self._gca_resource, "name", None):
raise RuntimeError(
f"{self.__class__.__name__} resource has not been created."
+ (
f" Resource failed with: {self._exception}"
if self._exception
else ""
)
)

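For illustration, a self-contained sketch of the property guard above and the error a caller sees when a property is read too early. _FakeResource is a hypothetical stand-in; the message format mirrors the code in this diff rather than an exact SDK transcript.

class _FakeResource:
    """Hypothetical stand-in mirroring the property guard added above."""

    def __init__(self):
        self._gca_resource = None
        self._exception = None  # would be populated if async creation failed

    def _assert_gca_resource_is_available(self):
        if not getattr(self._gca_resource, "name", None):
            raise RuntimeError(
                f"{self.__class__.__name__} resource has not been created."
                + (
                    f" Resource failed with: {self._exception}"
                    if self._exception
                    else ""
                )
            )

    @property
    def display_name(self):
        self._assert_gca_resource_is_available()
        return self._gca_resource.display_name


resource = _FakeResource()
try:
    resource.display_name
except RuntimeError as e:
    print(e)  # _FakeResource resource has not been created.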

def get_annotation_class(annotation: type) -> type:
"""Helper method to retrieve type annotation.
1 change: 1 addition & 0 deletions google/cloud/aiplatform/datasets/dataset.py
@@ -84,6 +84,7 @@ def __init__(
@property
def metadata_schema_uri(self) -> str:
"""The metadata schema uri of this dataset resource."""
self._assert_gca_resource_is_available()
return self._gca_resource.metadata_schema_uri

def _validate_metadata_schema_uri(self) -> None:
2 changes: 2 additions & 0 deletions google/cloud/aiplatform/datasets/tabular_dataset.py
@@ -52,6 +52,8 @@ def column_names(self) -> List[str]:
RuntimeError: When no valid source is found.
"""

self._assert_gca_resource_is_available()

metadata = self._gca_resource.metadata

if metadata is None:
31 changes: 13 additions & 18 deletions google/cloud/aiplatform/jobs.py
@@ -330,18 +330,21 @@ def output_info(self,) -> Optional[aiplatform.gapic.BatchPredictionJob.OutputInfo]:

This is only available for batch prediction jobs that have run successfully.
"""
self._assert_gca_resource_is_available()
return self._gca_resource.output_info

@property
def partial_failures(self) -> Optional[Sequence[status_pb2.Status]]:
"""Partial failures encountered. For example, single files that can't be read.
This field never exceeds 20 entries. Status details fields contain standard
GCP error details."""
self._assert_gca_resource_is_available()
return getattr(self._gca_resource, "partial_failures")

@property
def completion_stats(self) -> Optional[gca_completion_stats.CompletionStats]:
"""Statistics on completed and failed prediction instances."""
self._assert_gca_resource_is_available()
return getattr(self._gca_resource, "completion_stats")

@classmethod
@@ -772,6 +775,8 @@ def iter_outputs(
GCS or BQ output provided.
"""

self._assert_gca_resource_is_available()

if self.state != gca_job_state.JobState.JOB_STATE_SUCCEEDED:
raise RuntimeError(
f"Cannot read outputs until BatchPredictionJob has succeeded, "
@@ -859,23 +864,6 @@ def __init__(
def run(self) -> None:
pass

@property
def _has_run(self) -> bool:
"""Property returns true if this class has a resource name."""
return bool(self._gca_resource.name)

@property
def state(self) -> gca_job_state.JobState:
"""Current state of job.

Raises:
RuntimeError if job run has not been called.
"""
if not self._has_run:
raise RuntimeError("Job has not run. No state available.")

return super().state

@classmethod
def get(
cls,
@@ -913,6 +901,10 @@ def get(

return self

def wait_for_resource_creation(self) -> None:
"""Waits until resource has been created."""
self._wait_for_resource_creation()

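A usage sketch for the new public method on runnable jobs: the project, region, bucket, and container image below are placeholders, and the worker pool spec is only a plausible minimal example, not one taken from this PR.

from google.cloud import aiplatform

aiplatform.init(
    project="my-project",             # placeholder
    location="us-central1",           # placeholder
    staging_bucket="gs://my-bucket",  # placeholder
)

job = aiplatform.CustomJob(
    display_name="example-custom-job",
    worker_pool_specs=[
        {
            "machine_spec": {"machine_type": "n1-standard-4"},
            "replica_count": 1,
            "container_spec": {"image_uri": "gcr.io/my-project/my-image"},  # placeholder image
        }
    ],
)

job.run(sync=False)               # returns immediately; creation runs on a worker thread
job.wait_for_resource_creation()  # blocks until the CustomJob resource exists or surfaces the failure
print(job.resource_name)          # safe to read now; would raise RuntimeError before creation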

class DataLabelingJob(_Job):
_resource_noun = "dataLabelingJobs"
@@ -1041,7 +1033,8 @@ def network(self) -> Optional[str]:
Private services access must already be configured for the network. If left
unspecified, the CustomJob is not peered with any network.
"""
return getattr(self._gca_resource, "network")
self._assert_gca_resource_is_available()
return self._gca_resource.job_spec.network

@classmethod
def from_local_script(
@@ -1512,6 +1505,7 @@ def network(self) -> Optional[str]:
Private services access must already be configured for the network. If left
unspecified, the HyperparameterTuningJob is not peered with any network.
"""
self._assert_gca_resource_is_available()
return getattr(self._gca_resource.trial_job_spec, "network")

@base.optional_sync()
@@ -1612,4 +1606,5 @@ def run(

@property
def trials(self) -> List[gca_study_compat.Trial]:
self._assert_gca_resource_is_available()
return list(self._gca_resource.trials)
12 changes: 11 additions & 1 deletion google/cloud/aiplatform/models.py
@@ -146,7 +146,8 @@ def network(self) -> Optional[str]:
Private services access must already be configured for the network. If left
unspecified, the Endpoint is not peered with any network.
"""
return getattr(self._gca_resource, "network")
self._assert_gca_resource_is_available()
return getattr(self._gca_resource, "network", None)

@classmethod
def create(
@@ -1283,11 +1284,13 @@ class Model(base.VertexAiResourceNounWithFutureManager):
def uri(self) -> Optional[str]:
"""Path to the directory containing the Model artifact and any of its
supporting files. Not present for AutoML Models."""
self._assert_gca_resource_is_available()
return self._gca_resource.artifact_uri or None

@property
def description(self) -> str:
"""Description of the model."""
self._assert_gca_resource_is_available()
return self._gca_resource.description

@property
@@ -1302,6 +1305,7 @@ def supported_export_formats(

{'tf-saved-model': [<ExportableContent.ARTIFACT: 1>]}
"""
self._assert_gca_resource_is_available()
return {
export_format.id: [
gca_model_compat.Model.ExportFormat.ExportableContent(content)
@@ -1328,6 +1332,7 @@ def supported_deployment_resources_types(
predictions by using a `BatchPredictionJob`, if it has at least one entry
each in `Model.supported_input_storage_formats` and
`Model.supported_output_storage_formats`."""
self._assert_gca_resource_is_available()
return list(self._gca_resource.supported_deployment_resources_types)

@property
@@ -1343,6 +1348,7 @@ def supported_input_storage_formats(self) -> List[str]:
`supported_deployment_resources_types`, it could serve online predictions
by using `Endpoint.predict()` or `Endpoint.explain()`.
"""
self._assert_gca_resource_is_available()
return list(self._gca_resource.supported_input_storage_formats)

@property
@@ -1363,12 +1369,14 @@ def supported_output_storage_formats(self) -> List[str]:
`supported_deployment_resources_types`, it could serve online predictions
by using `Endpoint.predict()` or `Endpoint.explain()`.
"""
self._assert_gca_resource_is_available()
return list(self._gca_resource.supported_output_storage_formats)

@property
def predict_schemata(self) -> Optional[aiplatform.gapic.PredictSchemata]:
"""The schemata that describe formats of the Model's predictions and
explanations, if available."""
self._assert_gca_resource_is_available()
return getattr(self._gca_resource, "predict_schemata")

@property
@@ -1379,6 +1387,7 @@ def training_job(self) -> Optional["aiplatform.training_jobs._TrainingJob"]:
api_core.exceptions.NotFound: If the Model's training job resource
cannot be found on the Vertex service.
"""
self._assert_gca_resource_is_available()
job_name = getattr(self._gca_resource, "training_pipeline")

if not job_name:
@@ -1400,6 +1409,7 @@ def training_job(self) -> Optional["aiplatform.training_jobs._TrainingJob"]:
def container_spec(self) -> Optional[aiplatform.gapic.ModelContainerSpec]:
"""The specification of the container that is to be used when deploying
this Model. Not present for AutoML Models."""
self._assert_gca_resource_is_available()
return getattr(self._gca_resource, "container_spec")

def __init__(
37 changes: 17 additions & 20 deletions google/cloud/aiplatform/pipeline_jobs.py
@@ -220,6 +220,18 @@ def __init__(
),
)

def _assert_gca_resource_is_available(self) -> None:
# TODO(b/193800063): change this to check `name` once the linked bug is fixed
if not getattr(self._gca_resource, "create_time", None):
raise RuntimeError(
f"{self.__class__.__name__} resource has not been created."
+ (
f" Resource failed with: {self._exception}"
if self._exception
else ""
)
)

@base.optional_sync()
def run(
self,
@@ -236,6 +248,7 @@ def run(
network (str):
Optional. The full name of the Compute Engine network to which the job
should be peered. For example, projects/12345/global/networks/myVPC.

Private services access must already be configured for the network.
If left unspecified, the job is not peered with any network.
sync (bool):
@@ -268,17 +281,9 @@ def pipeline_spec(self):
@property
def state(self) -> Optional[gca_pipeline_state_v1beta1.PipelineState]:
"""Current pipeline state."""
if not self._has_run:
raise RuntimeError("Job has not run. No state available.")

self._sync_gca_resource()
return self._gca_resource.state

@property
def _has_run(self) -> bool:
"""Helper property to check if this pipeline job has been run."""
return bool(self._gca_resource.create_time)

@property
def has_failed(self) -> bool:
"""Returns True if pipeline has failed.
@@ -296,10 +301,6 @@ def _dashboard_uri(self) -> str:
url = f"https://console.cloud.google.com/vertex-ai/locations/{fields.location}/pipelines/runs/{fields.id}?project={fields.project}"
return url

def _sync_gca_resource(self):
"""Helper method to sync the local gca_source against the service."""
self._gca_resource = self.api_client.get_pipeline_job(name=self.resource_name)

def _block_until_complete(self):
"""Helper method to block and check on job until complete."""
# Used these numbers so failures surface fast
@@ -373,13 +374,9 @@ def cancel(self) -> None:
makes a best effort to cancel the job, but success is not guaranteed.
On successful cancellation, the PipelineJob is not deleted; instead it
becomes a job with state set to `CANCELLED`.

Raises:
RuntimeError: If this PipelineJob has not started running.
"""
if not self._has_run:
raise RuntimeError(
"This PipelineJob has not been launched, use the `run()` method "
"to start. `cancel()` can only be called on a job that is running."
)
self.api_client.cancel_pipeline_job(name=self.resource_name)

def wait_for_resource_creation(self) -> None:
"""Waits until resource has been created."""
self._wait_for_resource_creation()
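
Similarly, a hedged sketch of the PipelineJob flow; the template path and pipeline root are placeholders and assume an already-compiled pipeline spec.

from google.cloud import aiplatform
from google.cloud.aiplatform.pipeline_jobs import PipelineJob

aiplatform.init(project="my-project", location="us-central1")  # placeholders

pipeline_job = PipelineJob(
    display_name="example-pipeline",
    template_path="gs://my-bucket/pipeline_spec.json",  # placeholder compiled spec
    pipeline_root="gs://my-bucket/pipeline_root",       # placeholder
)

pipeline_job.run(sync=False)               # schedules creation asynchronously
pipeline_job.wait_for_resource_creation()  # blocks until the PipelineJob resource exists
print(pipeline_job.state)                  # readable now without raising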